mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2026-01-31 04:42:09 -08:00
46 lines
1.8 KiB
Python
46 lines
1.8 KiB
Python
import base64
import logging
import secrets

from fastapi import FastAPI
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
|
|
|
|
# Module-level logger, keyed by this module's import path.
logger = logging.getLogger(name=__name__)
|
|
class BasicAuthMiddleware(BaseHTTPMiddleware):
    """Require HTTP Basic credentials on every request except the status endpoint.

    Requests whose path is exactly ``/sdapi/v1/status`` are forwarded without
    authentication. Every other request must carry an
    ``Authorization: Basic <base64(username:password)>`` header matching the
    configured credentials; otherwise a 401 response with a ``WWW-Authenticate``
    challenge is returned.
    """

    def __init__(self, app, username: str, password: str):
        """Wrap *app* and store the credentials each request must present.

        Args:
            app: The application being wrapped (passed to ``BaseHTTPMiddleware``).
            username: Expected user-id.
            password: Expected password.
        """
        super().__init__(app)
        self.username = username
        self.password = password

    async def dispatch(self, request, call_next):
        """Authenticate *request*, then forward it to the wrapped application.

        Returns:
            The downstream response when the credentials match, otherwise the
            401 response from ``_unauthorized_response``.

        Exceptions raised by the downstream application propagate unchanged;
        they are never reported as authentication failures.
        """
        # Allow unauthenticated access to the `/sdapi/v1/status` endpoint.
        if request.url.path == "/sdapi/v1/status":
            return await call_next(request)

        # Starlette header lookups are case-insensitive, so one lookup suffices.
        credentials = self._parse_credentials(request.headers.get("Authorization"))
        if credentials is None or not self._credentials_match(*credentials):
            return self._unauthorized_response()

        # Deliberately outside any try/except: errors raised while handling the
        # request must surface as server errors, not be masked as a 401.
        return await call_next(request)

    @staticmethod
    def _parse_credentials(auth_header):
        """Decode a Basic ``Authorization`` header into ``(username, password)``.

        Returns:
            The decoded pair, or ``None`` when the header is absent, uses a
            different scheme, or its payload is not valid base64, not valid
            UTF-8, or lacks the ``:`` separator.
        """
        if not auth_header:
            return None

        # RFC 7235: the auth-scheme token is case-insensitive ("Basic", "basic", ...).
        scheme, _, encoded = auth_header.partition(" ")
        if scheme.lower() != "basic":
            return None

        try:
            # validate=True rejects characters outside the base64 alphabet
            # instead of silently discarding them.
            decoded = base64.b64decode(encoded.strip(), validate=True).decode("utf-8")
        except ValueError:
            # binascii.Error and UnicodeDecodeError are both ValueError subclasses.
            logger.debug("Rejected malformed Basic Authorization header")
            return None

        # RFC 7617: the user-id cannot contain ':' but the password may, so
        # split on the first colon only.
        username, separator, password = decoded.partition(":")
        if not separator:
            return None
        return username, password

    def _credentials_match(self, username, password):
        """Return True if both supplied fields equal the configured credentials.

        Each field is compared with ``secrets.compare_digest`` to avoid leaking
        match progress through timing. Both comparisons always run (no
        short-circuit). Strings are encoded to UTF-8 bytes first because
        ``compare_digest`` rejects non-ASCII ``str`` arguments.
        """
        username_ok = secrets.compare_digest(
            username.encode("utf-8"), self.username.encode("utf-8")
        )
        password_ok = secrets.compare_digest(
            password.encode("utf-8"), self.password.encode("utf-8")
        )
        return username_ok and password_ok

    def _unauthorized_response(self):
        """Build the 401 challenge returned for every authentication failure."""
        return Response(
            content="Unauthorized",
            status_code=401,
            # RFC 7235 §2.2: senders must use quoted-string syntax for realm.
            headers={"WWW-Authenticate": 'Basic realm="api"'},
        )
|