mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-15 14:51:19 -08:00
refactor: simplify websocket message handling and add cache load logic
- Modified `Cloudflare` class in `Cloudflare.py` to add logic for loading `_args` from a cache file if it exists and `_args` is `None` - Inserted code in `Cloudflare.py` to check existence of cache file and read JSON content into `_args` - Refactored `Copilot` class in `Copilot.py` by removing `try`/`finally` block around websocket message loop - Moved websocket close logic to the end of the message handling loop in `Copilot.py` - Removed nested `try`/`except` block inside the websocket loop in `Copilot.py` - Preserved original message handling structure while simplifying control flow in `Copilot.py
This commit is contained in:
parent
8243a47f75
commit
5ff7c88428
2 changed files with 41 additions and 38 deletions
|
|
@ -66,6 +66,11 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
|
|||
cls.model_aliases = {**cls.model_aliases, **model_map}
|
||||
if not cls.models:
|
||||
try:
|
||||
cache_file = cls.get_cache_file()
|
||||
if cls._args is None:
|
||||
if cache_file.exists():
|
||||
with cache_file.open("r") as f:
|
||||
cls._args = json.load(f)
|
||||
if cls._args is None:
|
||||
cls._args = {"headers": DEFAULT_HEADERS, "cookies": {}}
|
||||
read_models()
|
||||
|
|
|
|||
|
|
@ -158,44 +158,42 @@ class Copilot(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
image_prompt: str = None
|
||||
last_msg = None
|
||||
sources = {}
|
||||
try:
|
||||
while not wss.closed:
|
||||
try:
|
||||
msg = await asyncio.wait_for(wss.recv(), 3 if done else timeout)
|
||||
msg = json.loads(msg[0])
|
||||
except:
|
||||
break
|
||||
last_msg = msg
|
||||
if msg.get("event") == "appendText":
|
||||
yield msg.get("text")
|
||||
elif msg.get("event") == "generatingImage":
|
||||
image_prompt = msg.get("prompt")
|
||||
elif msg.get("event") == "imageGenerated":
|
||||
yield ImageResponse(msg.get("url"), image_prompt, {"preview": msg.get("thumbnailUrl")})
|
||||
elif msg.get("event") == "done":
|
||||
yield FinishReason("stop")
|
||||
done = True
|
||||
elif msg.get("event") == "suggestedFollowups":
|
||||
yield SuggestedFollowups(msg.get("suggestions"))
|
||||
break
|
||||
elif msg.get("event") == "replaceText":
|
||||
yield msg.get("text")
|
||||
elif msg.get("event") == "titleUpdate":
|
||||
yield TitleGeneration(msg.get("title"))
|
||||
elif msg.get("event") == "citation":
|
||||
sources[msg.get("url")] = msg
|
||||
yield SourceLink(list(sources.keys()).index(msg.get("url")), msg.get("url"))
|
||||
elif msg.get("event") == "error":
|
||||
raise RuntimeError(f"Error: {msg}")
|
||||
elif msg.get("event") not in ["received", "startMessage", "partCompleted"]:
|
||||
debug.log(f"Copilot Message: {msg}")
|
||||
if not done:
|
||||
raise RuntimeError(f"Invalid response: {last_msg}")
|
||||
if sources:
|
||||
yield Sources(sources.values())
|
||||
finally:
|
||||
if not wss.closed:
|
||||
await wss.close()
|
||||
while not wss.closed:
|
||||
try:
|
||||
msg = await asyncio.wait_for(wss.recv(), 3 if done else timeout)
|
||||
msg = json.loads(msg[0])
|
||||
except:
|
||||
break
|
||||
last_msg = msg
|
||||
if msg.get("event") == "appendText":
|
||||
yield msg.get("text")
|
||||
elif msg.get("event") == "generatingImage":
|
||||
image_prompt = msg.get("prompt")
|
||||
elif msg.get("event") == "imageGenerated":
|
||||
yield ImageResponse(msg.get("url"), image_prompt, {"preview": msg.get("thumbnailUrl")})
|
||||
elif msg.get("event") == "done":
|
||||
yield FinishReason("stop")
|
||||
done = True
|
||||
elif msg.get("event") == "suggestedFollowups":
|
||||
yield SuggestedFollowups(msg.get("suggestions"))
|
||||
break
|
||||
elif msg.get("event") == "replaceText":
|
||||
yield msg.get("text")
|
||||
elif msg.get("event") == "titleUpdate":
|
||||
yield TitleGeneration(msg.get("title"))
|
||||
elif msg.get("event") == "citation":
|
||||
sources[msg.get("url")] = msg
|
||||
yield SourceLink(list(sources.keys()).index(msg.get("url")), msg.get("url"))
|
||||
elif msg.get("event") == "error":
|
||||
raise RuntimeError(f"Error: {msg}")
|
||||
elif msg.get("event") not in ["received", "startMessage", "partCompleted"]:
|
||||
debug.log(f"Copilot Message: {msg}")
|
||||
if not done:
|
||||
raise RuntimeError(f"Invalid response: {last_msg}")
|
||||
if sources:
|
||||
yield Sources(sources.values())
|
||||
if not wss.closed:
|
||||
await wss.close()
|
||||
|
||||
async def get_access_token_and_cookies(url: str, proxy: str = None, target: str = "ChatAI",):
|
||||
browser, stop_browser = await get_nodriver(proxy=proxy, user_data_dir="copilot")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue