stable-diffusion-webui/basic_auth_middleware.py
2024-09-30 11:58:26 -05:00

39 lines
1.5 KiB
Python

import base64
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from fastapi import FastAPI
class BasicAuthMiddleware(BaseHTTPMiddleware):
def __init__(self, app, username: str, password: str):
super().__init__(app)
self.username = username
self.password = password
async def dispatch(self, request, call_next):
# Extract the Authorization header
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Basic "):
return self._unauthorized_response()
try:
# Decode and split the credentials
encoded_credentials = auth_header.split(" ")[1]
# We should add a step to santize the input here
decoded_credentials = base64.b64decode(encoded_credentials).decode("utf-8")
provided_username, provided_password = decoded_credentials.split(":")
# Check credentials
if provided_username == self.username and provided_password == self.password:
response = await call_next(request)
return response
else:
return self._unauthorized_response()
except Exception:
return self._unauthorized_response()
def _unauthorized_response(self):
return Response(
content="Unauthorized",
status_code=401,
headers={"WWW-Authenticate": "Basic realm='api'"}
)