mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-06 10:40:43 -08:00
Refactor Copilot and PollinationsAI classes for improved error handling and timeout adjustments; add rate limiting in API class based on user IP.
This commit is contained in:
parent
3b4ad875cc
commit
23218c4aa3
3 changed files with 30 additions and 75 deletions
|
|
@ -284,7 +284,7 @@ class Copilot(AsyncAuthedProvider, ProviderModelMixin):
|
|||
sources = {}
|
||||
while not wss.closed:
|
||||
try:
|
||||
msg_txt, _ = await asyncio.wait_for(wss.recv(), 3 if done else timeout)
|
||||
msg_txt, _ = await asyncio.wait_for(wss.recv(), 1 if done else timeout)
|
||||
msg = json.loads(msg_txt)
|
||||
except:
|
||||
break
|
||||
|
|
@ -369,7 +369,7 @@ async def get_access_token_and_cookies(url: str, proxy: str = None, needs_auth:
|
|||
button = await page.select("[data-testid=\"submit-button\"]")
|
||||
if button:
|
||||
await button.click()
|
||||
turnstile = await page.select('#cf-turnstile', 300)
|
||||
turnstile = await page.select('#cf-turnstile')
|
||||
if turnstile:
|
||||
debug.log("Found Element: 'cf-turnstile'")
|
||||
await asyncio.sleep(3)
|
||||
|
|
|
|||
|
|
@ -31,37 +31,6 @@ DEFAULT_HEADERS = {
|
|||
"origin": "https://pollinations.ai",
|
||||
}
|
||||
|
||||
FOLLOWUPS_TOOLS = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "options",
|
||||
"description": "Provides options for the conversation",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"title": {
|
||||
"title": "Conversation title. Prefixed with one or more emojies",
|
||||
"type": "string"
|
||||
},
|
||||
"followups": {
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"title": "Suggested 4 Followups (only user messages)",
|
||||
"type": "array"
|
||||
}
|
||||
},
|
||||
"title": "Conversation",
|
||||
"type": "object"
|
||||
}
|
||||
}
|
||||
}]
|
||||
|
||||
FOLLOWUPS_DEVELOPER_MESSAGE = [{
|
||||
"role": "developer",
|
||||
"content": "Provide conversation options.",
|
||||
}]
|
||||
|
||||
|
||||
class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
label = "Pollinations AI 🌸"
|
||||
url = "https://pollinations.ai"
|
||||
|
|
@ -375,12 +344,12 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
await raise_for_status(response)
|
||||
except Exception as e:
|
||||
responses.add(e)
|
||||
debug.error(f"Error fetching image: {e}")
|
||||
debug.error(f"Error fetching image:", e)
|
||||
if response.headers['content-type'].startswith("image/"):
|
||||
responses.add(ImageResponse(str(response.url), prompt, {"headers": headers}))
|
||||
else:
|
||||
t_ = await response.text()
|
||||
debug.error(f"UnHandel Error fetching image: {t_}")
|
||||
debug.error(f"UnHandel Error fetching image:", t_)
|
||||
responses.add(t_)
|
||||
|
||||
tasks: list[asyncio.Task] = []
|
||||
|
|
@ -465,43 +434,6 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
async with session.post(cls.openai_endpoint, json=data, headers=headers) as response:
|
||||
if response.status in (400, 500):
|
||||
debug.error(f"Error: {response.status} - Bad Request: {data}")
|
||||
full_resposne = []
|
||||
async for chunk in read_response(response, stream, format_media_prompt(messages), cls.get_dict(),
|
||||
kwargs.get("download_media", True)):
|
||||
if isinstance(chunk, str):
|
||||
full_resposne.append(chunk)
|
||||
yield chunk
|
||||
if full_resposne:
|
||||
full_content = "".join(full_resposne)
|
||||
if kwargs.get("action") == "next" and model != "evil":
|
||||
tool_messages = []
|
||||
for message in messages:
|
||||
if message.get("role") == "user":
|
||||
if isinstance(message.get("content"), str):
|
||||
tool_messages.append({"role": "user", "content": message.get("content")})
|
||||
elif isinstance(message.get("content"), list):
|
||||
next_value = message.get("content").pop()
|
||||
if isinstance(next_value, dict):
|
||||
next_value = next_value.get("text")
|
||||
if next_value:
|
||||
tool_messages.append({"role": "user", "content": next_value})
|
||||
tool_messages.append({"role": "assistant", "content": full_content})
|
||||
data = {
|
||||
"model": "openai",
|
||||
"messages": tool_messages + FOLLOWUPS_DEVELOPER_MESSAGE,
|
||||
"tool_choice": "required",
|
||||
"tools": FOLLOWUPS_TOOLS
|
||||
}
|
||||
async with session.post(cls.openai_endpoint, json=data, headers=headers) as response:
|
||||
try:
|
||||
await raise_for_status(response)
|
||||
tool_calls = (await response.json()).get("choices", [{}])[0].get("message", {}).get(
|
||||
"tool_calls", [])
|
||||
if tool_calls:
|
||||
arguments = json.loads(tool_calls.pop().get("function", {}).get("arguments"))
|
||||
if arguments.get("title"):
|
||||
yield TitleGeneration(arguments.get("title"))
|
||||
if arguments.get("followups"):
|
||||
yield SuggestedFollowups(arguments.get("followups"))
|
||||
except Exception as e:
|
||||
debug.error("Error generating title and followups:", e)
|
||||
yield chunk
|
||||
|
|
@ -65,7 +65,7 @@ from g4f.client.helper import filter_none
|
|||
from g4f.config import DEFAULT_PORT, DEFAULT_TIMEOUT, DEFAULT_STREAM_TIMEOUT
|
||||
from g4f.image import EXTENSIONS_MAP, is_data_an_media, process_image
|
||||
from g4f.image.copy_images import get_media_dir, copy_media, get_source_url
|
||||
from g4f.errors import ProviderNotFoundError, ModelNotFoundError, MissingAuthError, NoValidHarFileError, MissingRequirementsError
|
||||
from g4f.errors import ProviderNotFoundError, ModelNotFoundError, MissingAuthError, NoValidHarFileError, MissingRequirementsError, RateLimitError
|
||||
from g4f.cookies import read_cookie_files, get_cookies_dir
|
||||
from g4f.providers.types import ProviderType
|
||||
from g4f.providers.response import AudioResponse
|
||||
|
|
@ -417,6 +417,8 @@ class Api:
|
|||
})
|
||||
return ErrorResponse.from_message("The model does not exist.", HTTP_404_NOT_FOUND)
|
||||
|
||||
most_wanted = {}
|
||||
failure_counts = {}
|
||||
responses = {
|
||||
HTTP_200_OK: {"model": ChatCompletion},
|
||||
HTTP_401_UNAUTHORIZED: {"model": ErrorResponseModel},
|
||||
|
|
@ -433,8 +435,29 @@ class Api:
|
|||
provider: str = None,
|
||||
conversation_id: str = None,
|
||||
x_user: Annotated[str | None, Header()] = None,
|
||||
cf_ipcountry: Annotated[str | None, Header()] = None
|
||||
cf_ipcountry: Annotated[str | None, Header()] = None,
|
||||
x_forwarded_for: Annotated[str | None, Header()] = None
|
||||
):
|
||||
if AppConfig.demo and x_forwarded_for is not None:
|
||||
current_most_wanted = next(iter(most_wanted.values()), 0)
|
||||
is_most_wanted = False
|
||||
if x_forwarded_for in most_wanted:
|
||||
if failure_counts.get(x_forwarded_for, 0) > 0:
|
||||
failure_counts[x_forwarded_for] -= 1
|
||||
most_wanted[x_forwarded_for] += 1
|
||||
elif most_wanted[x_forwarded_for] >= current_most_wanted:
|
||||
if x_forwarded_for not in failure_counts:
|
||||
failure_counts[x_forwarded_for] = 0
|
||||
failure_counts[x_forwarded_for] += 1
|
||||
is_most_wanted = True
|
||||
else:
|
||||
most_wanted[x_forwarded_for] += 1
|
||||
else:
|
||||
most_wanted[x_forwarded_for] = 1
|
||||
sorted_most_wanted = dict(sorted(most_wanted.items(), key=lambda item: item[1], reverse=True))
|
||||
debug.log(f"Most wanted IPs: {sorted_most_wanted}")
|
||||
if is_most_wanted:
|
||||
raise RateLimitError("You are most wanted! Please wait before making another request.")
|
||||
if provider is not None and provider not in Provider.__map__:
|
||||
if provider in model_map:
|
||||
config.model = provider
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue