stable-diffusion-webui/basic_auth_middleware.py

43 lines
1.8 KiB
Python

import base64
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from fastapi import FastAPI
import logging
logger = logging.getLogger(__name__)
class BasicAuthMiddleware(BaseHTTPMiddleware):
def __init__(self, app, username: str, password: str):
super().__init__(app)
self.username = username
self.password = password
async def dispatch(self, request, call_next):
# Extract the Authorization header
logger.warning(f"{request.headers}")
auth_header = request.headers.get("Authorization") or request.headers.get("authorization")
if not auth_header or not auth_header.startswith("Basic "):
return self._unauthorized_response()
try:
# Decode and split the credentials
encoded_credentials = auth_header.split(" ")[1]
# We should add a step to santize the input here
decoded_credentials = base64.b64decode(encoded_credentials).decode("utf-8")
provided_username, provided_password = decoded_credentials.split(":")
logger.warning(f"provided_username: {provided_username}: provided_password: {provided_password}")
# Check credentials
if provided_username == self.username and provided_password == self.password:
response = await call_next(request)
return response
else:
return self._unauthorized_response()
except Exception:
return self._unauthorized_response()
def _unauthorized_response(self):
return Response(
content="Unauthorized",
status_code=401,
headers={"WWW-Authenticate": "Basic realm='api'"}
)