gpt4free/g4f/Provider/github/githubOAuth2.py

159 lines
5.9 KiB
Python

import time
from typing import Dict, Optional, Union
import aiohttp
from .stubs import IGithubOAuth2Client, GithubCredentials, ErrorDataDict
from .sharedTokenManager import SharedTokenManager
# --- GitHub OAuth device-flow configuration --------------------------------
# Endpoint that issues device_code / user_code pairs.
GITHUB_DEVICE_CODE_ENDPOINT = "https://github.com/login/device/code"
# Endpoint polled with the device_code until an access token is granted.
GITHUB_TOKEN_ENDPOINT = "https://github.com/login/oauth/access_token"

# OAuth client id used by the GitHub Copilot VS Code extension.
GITHUB_COPILOT_CLIENT_ID = "Iv1.b507a08c87ecfe98"

# Default scope requested for Copilot access.
GITHUB_COPILOT_SCOPE = "read:user"

# Treat a token with an expiry_date as expired this many milliseconds early.
TOKEN_REFRESH_BUFFER_MS = 30_000  # 30 seconds
def object_to_urlencoded(data: Dict[str, str]) -> str:
    """Serialize ``data`` as an ``application/x-www-form-urlencoded`` body.

    Keys and values are percent-encoded (spaces become ``+``), so reserved
    characters such as ``&``, ``=`` or ``:`` inside a value cannot split or
    corrupt the body. Pair order follows the dict's iteration order.

    Args:
        data: form fields; non-string values are converted with ``str()``.

    Returns:
        The encoded body, e.g. ``{"scope": "read:user"}`` -> ``"scope=read%3Auser"``.
        An empty mapping yields ``""``.
    """
    from urllib.parse import urlencode

    return urlencode(data)
def isDeviceAuthorizationSuccess(response: Union[Dict, ErrorDataDict]) -> bool:
    """Tell whether a device-authorization reply contains a ``device_code`` key.

    Only the presence of the key is checked, not its value.
    """
    has_device_code = "device_code" in response
    return has_device_code
def isDeviceTokenSuccess(response: Union[Dict, ErrorDataDict]) -> bool:
    """Return True when ``response`` carries a non-empty string ``access_token``.

    Always returns a real ``bool``. A missing, null, empty, or non-string
    token yields ``False``; it never leaks ``""`` or ``None`` to the caller.
    """
    token = response.get("access_token")
    return isinstance(token, str) and len(token) > 0
def isDeviceTokenPending(response: Union[Dict, ErrorDataDict]) -> bool:
    """Return True when the ``error`` field equals ``authorization_pending``.

    This is the RFC 8628 "keep polling" state.
    """
    error_code = response.get("error")
    return error_code == "authorization_pending"
def isSlowDown(response: Union[Dict, ErrorDataDict]) -> bool:
    """Return True when the ``error`` field equals ``slow_down``.

    This is the RFC 8628 signal to poll less frequently.
    """
    error_code = response.get("error")
    return error_code == "slow_down"
def isErrorResponse(response: Union[Dict, ErrorDataDict]) -> bool:
    """Report whether ``response`` carries a terminal (non-polling) error.

    The polling states ``authorization_pending`` and ``slow_down`` are not
    treated as errors. Any other value under an ``error`` key, including a
    null one, is.
    """
    if "error" not in response:
        return False
    non_terminal = ("authorization_pending", "slow_down")
    return response.get("error") not in non_terminal
class GithubOAuth2Client(IGithubOAuth2Client):
    """OAuth 2.0 device-flow client for GitHub (defaults to the Copilot client id).

    Keeps a locally held copy of the credentials and asks the
    ``SharedTokenManager`` returned by ``getInstance()`` for valid
    credentials when an access token is needed.
    """

    def __init__(self, client_id: str = GITHUB_COPILOT_CLIENT_ID):
        """Create a client for ``client_id`` with empty credentials."""
        self.client_id = client_id
        self.credentials: GithubCredentials = GithubCredentials()
        self.sharedManager = SharedTokenManager.getInstance()

    def setCredentials(self, credentials: GithubCredentials):
        """Replace the locally held credentials."""
        self.credentials = credentials

    def getCredentials(self) -> GithubCredentials:
        """Return the locally held credentials."""
        return self.credentials

    async def getAccessToken(self) -> Dict[str, Optional[str]]:
        """Return ``{"token": <access token or None>}``.

        The shared token manager is asked first. If that raises (for any
        reason), the locally held credentials are used when they contain an
        access token that still passes ``isTokenValid``; otherwise the token
        is ``None``.
        """
        try:
            credentials = await self.sharedManager.getValidCredentials(self)
            return {"token": credentials.get("access_token")}
        except Exception:
            # Deliberate best-effort: any manager failure degrades to the
            # locally cached token instead of propagating to the caller.
            if (
                self.credentials.get("access_token")
                and self.isTokenValid(self.credentials)
            ):
                return {"token": self.credentials["access_token"]}
            return {"token": None}

    async def _postForm(self, url: str, body_data: Dict[str, str]) -> tuple:
        """POST ``body_data`` form-encoded to ``url``; return ``(status, json_body)``.

        Raises:
            Exception: when the response body is not a JSON object (empty,
                malformed, or e.g. an HTML error page). The message includes
                the HTTP status and the start of the raw body, so the real
                failure is not hidden behind a JSON decode error.
        """
        async with aiohttp.ClientSession() as session:
            async with session.post(
                url,
                headers={
                    "Content-Type": "application/x-www-form-urlencoded",
                    "Accept": "application/json",
                },
                data=object_to_urlencoded(body_data),
            ) as resp:
                try:
                    # content_type=None: don't reject on the Content-Type
                    # header; undecodable bodies are reported below instead.
                    payload = await resp.json(content_type=None)
                except ValueError:
                    payload = None
                # aiohttp returns None for an empty body; a list/str is also
                # not something the callers below can interpret.
                if not isinstance(payload, dict):
                    raw = await resp.text(errors="replace")
                    raise Exception(
                        f"Unexpected non-JSON response from {url} {resp.status}: {raw[:500]}"
                    )
                return resp.status, payload

    async def requestDeviceAuthorization(self, options: dict) -> Union[Dict, ErrorDataDict]:
        """
        Request device authorization from GitHub.

        Args:
            options: may contain ``scope``; defaults to ``GITHUB_COPILOT_SCOPE``.
                ``None`` is treated as an empty dict.
        Returns:
            dict with device_code, user_code, verification_uri, expires_in, interval
        Raises:
            Exception: on a non-200 status, a non-JSON body, or a body that
                has no ``device_code``.
        """
        options = options or {}
        body_data = {
            "client_id": self.client_id,
            "scope": options.get("scope", GITHUB_COPILOT_SCOPE),
        }
        status, resp_json = await self._postForm(GITHUB_DEVICE_CODE_ENDPOINT, body_data)
        if status != 200:
            raise Exception(f"Device authorization failed {status}: {resp_json}")
        if not isDeviceAuthorizationSuccess(resp_json):
            raise Exception(
                f"Device authorization error: {resp_json.get('error')} - {resp_json.get('error_description')}"
            )
        return resp_json

    async def pollDeviceToken(self, options: dict) -> Union[Dict, ErrorDataDict]:
        """
        Poll for device token from GitHub.

        Args:
            options: dict with device_code
        Returns:
            dict with access_token, token_type, scope; or ``{"status": "pending"}``
            (plus ``"slowDown": True`` when GitHub asks to poll more slowly)
        Raises:
            Exception: if the device code expired, the user denied access,
                GitHub returned any other ``error``, the status is not 200, or
                the body is not JSON.
        """
        body_data = {
            "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
            "client_id": self.client_id,
            "device_code": options["device_code"],
        }
        status, resp_json = await self._postForm(GITHUB_TOKEN_ENDPOINT, body_data)
        # Check for OAuth RFC 8628 (section 3.5) polling responses.
        if "error" in resp_json:
            error = resp_json["error"]
            if error == "authorization_pending":
                return {"status": "pending"}
            if error == "slow_down":
                return {"status": "pending", "slowDown": True}
            if error == "expired_token":
                raise Exception("Device code expired. Please try again.")
            if error == "access_denied":
                raise Exception("Authorization was denied by the user.")
            raise Exception(f"Token poll failed: {resp_json.get('error')} - {resp_json.get('error_description')}")
        # A non-200 body without an OAuth "error" is not a token response;
        # don't hand it back to the caller as if it were one.
        if status != 200:
            raise Exception(f"Token poll failed {status}: {resp_json}")
        return resp_json

    def isTokenValid(self, credentials: GithubCredentials) -> bool:
        """Return True if ``credentials`` hold an access token usable right now.

        GitHub tokens don't expire by default, so a missing ``expiry_date``
        means valid. When ``expiry_date`` is set it is compared against
        ``time.time() * 1000`` (epoch milliseconds), and the token is treated
        as expired ``TOKEN_REFRESH_BUFFER_MS`` before that moment.
        """
        if not credentials.get("access_token"):
            return False
        expiry_date = credentials.get("expiry_date")
        if expiry_date is None:
            # GitHub tokens don't expire unless explicitly set
            return True
        return time.time() * 1000 < expiry_date - TOKEN_REFRESH_BUFFER_MS