Merge pull request #4 from Moonlite-Media/user/aasherkataria/Basic-Auth

Secure stable diffusion with basic authentication
This commit is contained in:
Aasher Kataria 2024-10-02 21:37:26 -05:00 committed by GitHub
commit e4e30c679f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 63 additions and 1 deletions

2
.env.sample Normal file
View file

@ -0,0 +1,2 @@
SDAPI_USERNAME=admin
SDAPI_PASSWORD=test123

View file

@ -21,6 +21,16 @@ jobs:
with:
ssh-private-key: ${{ secrets.MOONLITE_AWS_EC2_SSH_KEY }}
- name: Create .env file
working-directory: ./
run: |
echo "SDAPI_USERNAME=${{ secrets.MOONLITE_SDAPI_USERNAME }}" >> .env
echo "SDAPI_PASSWORD=${{ secrets.MOONLITE_SDAPI_PASSWORD }}" >> .env
- name: Copy .env file to EC2 instance
run: |
scp -o StrictHostKeyChecking=no -i ${{ secrets.MOONLITE_AWS_EC2_SSH_KEY }} .env ${{ secrets.MOONLITE_AWS_EC2_SSH_USER }}@${{ secrets.MOONLITE_AWS_EC2_SSH_HOST }}:/home/ec2-user/apps/stable-diffusion/
- name: SSH into EC2 and deploy the app
run: |
ssh -o ServerAliveInterval=60 -o ServerAliveCountMax=60 -o StrictHostKeyChecking=no ${{ secrets.MOONLITE_AWS_EC2_SSH_USER }}@${{ secrets.MOONLITE_AWS_EC2_SSH_HOST }} << 'EOF'

1
.gitignore vendored
View file

@ -8,6 +8,7 @@ __pycache__
/venv
/tmp
/model.ckpt
.env
# /models/**/*
/GFPGANv1.3.pth
/gfpgan/weights/*.pth

39
basic_auth_middleware.py Normal file
View file

@ -0,0 +1,39 @@
import base64
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from fastapi import FastAPI
class BasicAuthMiddleware(BaseHTTPMiddleware):
def __init__(self, app, username: str, password: str):
super().__init__(app)
self.username = username
self.password = password
async def dispatch(self, request, call_next):
# Extract the Authorization header
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Basic "):
return self._unauthorized_response()
try:
# Decode and split the credentials
encoded_credentials = auth_header.split(" ")[1]
# We should add a step to santize the input here
decoded_credentials = base64.b64decode(encoded_credentials).decode("utf-8")
provided_username, provided_password = decoded_credentials.split(":")
# Check credentials
if provided_username == self.username and provided_password == self.password:
response = await call_next(request)
return response
else:
return self._unauthorized_response()
except Exception:
return self._unauthorized_response()
def _unauthorized_response(self):
return Response(
content="Unauthorized",
status_code=401,
headers={"WWW-Authenticate": "Basic realm='api'"}
)

@ -1 +1 @@
Subproject commit 8bbbd0e55ef6e5d71b09c2de2727b36e7bc825b0
Subproject commit 56cec5b2958edf3b1807b7e7b2b1b5186dbd2f81

View file

@ -20,6 +20,7 @@ open-clip-torch
piexif
psutil
pytorch_lightning
python-dotenv
requests
resize-right

View file

@ -1,5 +1,6 @@
from __future__ import annotations
from dotenv import load_dotenv
import os
import time
@ -9,6 +10,7 @@ from modules import initialize
startup_timer = timer.startup_timer
startup_timer.record("launcher")
load_dotenv()
initialize.imports()
@ -30,6 +32,13 @@ def api_only():
initialize.initialize()
app = FastAPI()
# Initialize the Basic Authentication middleware
from basic_auth_middleware import BasicAuthMiddleware
USERNAME = os.getenv('SDAPI_USERNAME')
PASSWORD = os.getenv('SDAPI_PASSWORD')
app.add_middleware(BasicAuthMiddleware, username=USERNAME, password=PASSWORD)
initialize_util.setup_middleware(app)
api = create_api(app)