Refactor Copilot and PollinationsAI classes for improved error handling and timeout adjustments; add rate limiting in API class based on user IP.

This commit is contained in:
hlohaus 2025-10-31 14:17:41 +01:00
parent 3b4ad875cc
commit 23218c4aa3
3 changed files with 30 additions and 75 deletions

View file

@ -284,7 +284,7 @@ class Copilot(AsyncAuthedProvider, ProviderModelMixin):
sources = {} sources = {}
while not wss.closed: while not wss.closed:
try: try:
msg_txt, _ = await asyncio.wait_for(wss.recv(), 3 if done else timeout) msg_txt, _ = await asyncio.wait_for(wss.recv(), 1 if done else timeout)
msg = json.loads(msg_txt) msg = json.loads(msg_txt)
except: except:
break break
@ -369,7 +369,7 @@ async def get_access_token_and_cookies(url: str, proxy: str = None, needs_auth:
button = await page.select("[data-testid=\"submit-button\"]") button = await page.select("[data-testid=\"submit-button\"]")
if button: if button:
await button.click() await button.click()
turnstile = await page.select('#cf-turnstile', 300) turnstile = await page.select('#cf-turnstile')
if turnstile: if turnstile:
debug.log("Found Element: 'cf-turnstile'") debug.log("Found Element: 'cf-turnstile'")
await asyncio.sleep(3) await asyncio.sleep(3)

View file

@ -31,37 +31,6 @@ DEFAULT_HEADERS = {
"origin": "https://pollinations.ai", "origin": "https://pollinations.ai",
} }
FOLLOWUPS_TOOLS = [{
"type": "function",
"function": {
"name": "options",
"description": "Provides options for the conversation",
"parameters": {
"properties": {
"title": {
"title": "Conversation title. Prefixed with one or more emojies",
"type": "string"
},
"followups": {
"items": {
"type": "string"
},
"title": "Suggested 4 Followups (only user messages)",
"type": "array"
}
},
"title": "Conversation",
"type": "object"
}
}
}]
FOLLOWUPS_DEVELOPER_MESSAGE = [{
"role": "developer",
"content": "Provide conversation options.",
}]
class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin): class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
label = "Pollinations AI 🌸" label = "Pollinations AI 🌸"
url = "https://pollinations.ai" url = "https://pollinations.ai"
@ -375,12 +344,12 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
await raise_for_status(response) await raise_for_status(response)
except Exception as e: except Exception as e:
responses.add(e) responses.add(e)
debug.error(f"Error fetching image: {e}") debug.error(f"Error fetching image:", e)
if response.headers['content-type'].startswith("image/"): if response.headers['content-type'].startswith("image/"):
responses.add(ImageResponse(str(response.url), prompt, {"headers": headers})) responses.add(ImageResponse(str(response.url), prompt, {"headers": headers}))
else: else:
t_ = await response.text() t_ = await response.text()
debug.error(f"UnHandel Error fetching image: {t_}") debug.error(f"UnHandel Error fetching image:", t_)
responses.add(t_) responses.add(t_)
tasks: list[asyncio.Task] = [] tasks: list[asyncio.Task] = []
@ -465,43 +434,6 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
async with session.post(cls.openai_endpoint, json=data, headers=headers) as response: async with session.post(cls.openai_endpoint, json=data, headers=headers) as response:
if response.status in (400, 500): if response.status in (400, 500):
debug.error(f"Error: {response.status} - Bad Request: {data}") debug.error(f"Error: {response.status} - Bad Request: {data}")
full_resposne = []
async for chunk in read_response(response, stream, format_media_prompt(messages), cls.get_dict(), async for chunk in read_response(response, stream, format_media_prompt(messages), cls.get_dict(),
kwargs.get("download_media", True)): kwargs.get("download_media", True)):
if isinstance(chunk, str):
full_resposne.append(chunk)
yield chunk yield chunk
if full_resposne:
full_content = "".join(full_resposne)
if kwargs.get("action") == "next" and model != "evil":
tool_messages = []
for message in messages:
if message.get("role") == "user":
if isinstance(message.get("content"), str):
tool_messages.append({"role": "user", "content": message.get("content")})
elif isinstance(message.get("content"), list):
next_value = message.get("content").pop()
if isinstance(next_value, dict):
next_value = next_value.get("text")
if next_value:
tool_messages.append({"role": "user", "content": next_value})
tool_messages.append({"role": "assistant", "content": full_content})
data = {
"model": "openai",
"messages": tool_messages + FOLLOWUPS_DEVELOPER_MESSAGE,
"tool_choice": "required",
"tools": FOLLOWUPS_TOOLS
}
async with session.post(cls.openai_endpoint, json=data, headers=headers) as response:
try:
await raise_for_status(response)
tool_calls = (await response.json()).get("choices", [{}])[0].get("message", {}).get(
"tool_calls", [])
if tool_calls:
arguments = json.loads(tool_calls.pop().get("function", {}).get("arguments"))
if arguments.get("title"):
yield TitleGeneration(arguments.get("title"))
if arguments.get("followups"):
yield SuggestedFollowups(arguments.get("followups"))
except Exception as e:
debug.error("Error generating title and followups:", e)

View file

@ -65,7 +65,7 @@ from g4f.client.helper import filter_none
from g4f.config import DEFAULT_PORT, DEFAULT_TIMEOUT, DEFAULT_STREAM_TIMEOUT from g4f.config import DEFAULT_PORT, DEFAULT_TIMEOUT, DEFAULT_STREAM_TIMEOUT
from g4f.image import EXTENSIONS_MAP, is_data_an_media, process_image from g4f.image import EXTENSIONS_MAP, is_data_an_media, process_image
from g4f.image.copy_images import get_media_dir, copy_media, get_source_url from g4f.image.copy_images import get_media_dir, copy_media, get_source_url
from g4f.errors import ProviderNotFoundError, ModelNotFoundError, MissingAuthError, NoValidHarFileError, MissingRequirementsError from g4f.errors import ProviderNotFoundError, ModelNotFoundError, MissingAuthError, NoValidHarFileError, MissingRequirementsError, RateLimitError
from g4f.cookies import read_cookie_files, get_cookies_dir from g4f.cookies import read_cookie_files, get_cookies_dir
from g4f.providers.types import ProviderType from g4f.providers.types import ProviderType
from g4f.providers.response import AudioResponse from g4f.providers.response import AudioResponse
@ -417,6 +417,8 @@ class Api:
}) })
return ErrorResponse.from_message("The model does not exist.", HTTP_404_NOT_FOUND) return ErrorResponse.from_message("The model does not exist.", HTTP_404_NOT_FOUND)
most_wanted = {}
failure_counts = {}
responses = { responses = {
HTTP_200_OK: {"model": ChatCompletion}, HTTP_200_OK: {"model": ChatCompletion},
HTTP_401_UNAUTHORIZED: {"model": ErrorResponseModel}, HTTP_401_UNAUTHORIZED: {"model": ErrorResponseModel},
@ -433,8 +435,29 @@ class Api:
provider: str = None, provider: str = None,
conversation_id: str = None, conversation_id: str = None,
x_user: Annotated[str | None, Header()] = None, x_user: Annotated[str | None, Header()] = None,
cf_ipcountry: Annotated[str | None, Header()] = None cf_ipcountry: Annotated[str | None, Header()] = None,
x_forwarded_for: Annotated[str | None, Header()] = None
): ):
if AppConfig.demo and x_forwarded_for is not None:
current_most_wanted = next(iter(most_wanted.values()), 0)
is_most_wanted = False
if x_forwarded_for in most_wanted:
if failure_counts.get(x_forwarded_for, 0) > 0:
failure_counts[x_forwarded_for] -= 1
most_wanted[x_forwarded_for] += 1
elif most_wanted[x_forwarded_for] >= current_most_wanted:
if x_forwarded_for not in failure_counts:
failure_counts[x_forwarded_for] = 0
failure_counts[x_forwarded_for] += 1
is_most_wanted = True
else:
most_wanted[x_forwarded_for] += 1
else:
most_wanted[x_forwarded_for] = 1
sorted_most_wanted = dict(sorted(most_wanted.items(), key=lambda item: item[1], reverse=True))
debug.log(f"Most wanted IPs: {sorted_most_wanted}")
if is_most_wanted:
raise RateLimitError("You are most wanted! Please wait before making another request.")
if provider is not None and provider not in Provider.__map__: if provider is not None and provider not in Provider.__map__:
if provider in model_map: if provider in model_map:
config.model = provider config.model = provider