Mirror of https://github.com/xtekky/gpt4free.git (synced 2025-12-15 14:51:19 -08:00)
Add Path and PathLike support when uploading images (#2514)
* Add Path and PathLike support when uploading images
* Improve raise_for_status in special cases
* Move ImageResponse to providers.response module
* Improve OpenaiChat and OpenaiAccount providers
* Add Sources for web_search in OpenaiChat
* Add JsonConversation for import and export conversations to js
* Add RequestLogin response type
* Add TitleGeneration support in OpenaiChat and gui
* Improve Docker Container Guide in README.md
* Add tool calls api support, add search tool support
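The tool-calls path added here can be driven through the client API. A minimal sketch, assuming only what the diff below shows (the `search_tool` name, the argument keys, and the `create(messages, model, tool_calls=...)` call shape come from the new `etc/unittest/web_search.py`; the commented `Path` upload form merely illustrates the commit title and is not a confirmed signature):

```python
import asyncio
from pathlib import Path

from g4f.client import AsyncClient

async def main():
    client = AsyncClient()
    # Search-tool call, mirroring the structure used in the new web_search tests.
    tool_calls = [{
        "type": "function",
        "function": {
            "name": "search_tool",
            "arguments": {"query": "g4f docker setup", "max_results": 5},
        },
    }]
    response = await client.chat.completions.create(
        [{"role": "user", "content": "Summarize the results."}],
        "",  # model: an empty string lets the provider pick its default
        tool_calls=tool_calls,
    )
    print(response.choices[0].message.content)

    # Path/PathLike image upload (the headline feature); the exact shape of the
    # images parameter is an assumption, shown here only as an illustration.
    # await client.chat.completions.create(
    #     [{"role": "user", "content": "Describe this image."}],
    #     "", images=[[Path("example.jpg"), "example.jpg"]],
    # )

asyncio.run(main())
```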
Parent: 9918df98b3
Commit: 86e36efe6b
34 changed files with 935 additions and 1333 deletions
README.md (22 changes)
@@ -97,10 +97,19 @@ To access the space, please use the following login credentials:
 1. **Install Docker:** Begin by [downloading and installing Docker](https://docs.docker.com/get-docker/).
-2. **Set Up the Container:**
+2. **Check Directories:**
+
+   Before running the container, make sure the necessary data directories exist or can be created. For example, you can create and set ownership on these directories by running:
+
+   ```bash
+   mkdir -p ${PWD}/har_and_cookies ${PWD}/generated_images
+   chown -R 1000:1000 ${PWD}/har_and_cookies ${PWD}/generated_images
+   ```
+
+3. **Set Up the Container:**
    Use the following commands to pull the latest image and start the container:
-   ```sh
+   ```bash
    docker pull hlohaus789/g4f
    docker run \
      -p 8080:8080 -p 1337:1337 -p 7900:7900 \
@@ -110,7 +119,9 @@ docker run \
      hlohaus789/g4f:latest
    ```

-To run the slim docker image. Use this command:
+##### Running the Slim Docker Image
+
+Use the following command to run the Slim Docker image. This command also updates the `g4f` package at startup and installs any additional dependencies:

 ```bash
 docker run \
@@ -122,14 +133,13 @@ docker run \
   && pip install -U g4f[slim] \
   && python -m g4f --debug
 ```
-It also updates the `g4f` package at startup and installs any new required dependencies.

-3. **Access the Client:**
+4. **Access the Client:**

   - To use the included client, navigate to: [http://localhost:8080/chat/](http://localhost:8080/chat/) or [http://localhost:1337/chat/](http://localhost:1337/chat/)
   - Or set the API base for your client to: [http://localhost:1337/v1](http://localhost:1337/v1)

-4. **(Optional) Provider Login:**
+5. **(Optional) Provider Login:**
   If required, you can access the container's desktop here: http://localhost:7900/?autoconnect=1&resize=scale&password=secret for provider login purposes.

 #### Installation Guide for Windows (.exe)
@@ -8,6 +8,7 @@ from .client import *
 from .image_client import *
 from .include import *
 from .retry_provider import *
+from .web_search import *
 from .models import *

 unittest.main()
@@ -5,40 +5,45 @@ from g4f.errors import MissingAuthError
 class ProviderMock(AbstractProvider):
+    working = True

     @classmethod
     def create_completion(
-        model, messages, stream, **kwargs
+        cls, model, messages, stream, **kwargs
     ):
         yield "Mock"

 class AsyncProviderMock(AsyncProvider):
+    working = True

     @classmethod
     async def create_async(
-        model, messages, **kwargs
+        cls, model, messages, **kwargs
     ):
         return "Mock"

 class AsyncGeneratorProviderMock(AsyncGeneratorProvider):
+    working = True

     @classmethod
     async def create_async_generator(
-        model, messages, stream, **kwargs
+        cls, model, messages, stream, **kwargs
     ):
         yield "Mock"

 class ModelProviderMock(AbstractProvider):
+    working = True

     @classmethod
     def create_completion(
-        model, messages, stream, **kwargs
+        cls, model, messages, stream, **kwargs
     ):
         yield model

 class YieldProviderMock(AsyncGeneratorProvider):
+    working = True

     @classmethod
     async def create_async_generator(
-        model, messages, stream, **kwargs
+        cls, model, messages, stream, **kwargs
     ):
         for message in messages:
             yield message["content"]

@@ -84,8 +89,9 @@ class AsyncRaiseExceptionProviderMock(AsyncGeneratorProvider):

 class YieldNoneProviderMock(AsyncGeneratorProvider):
+    working = True

     @classmethod
     async def create_async_generator(
-        model, messages, stream, **kwargs
+        cls, model, messages, stream, **kwargs
     ):
         yield None
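For reference, these mocks are exercised through the regular client. A short sketch mirroring how the new web-search tests drive `YieldProviderMock`, which simply echoes each message's content (the import path assumes the repo layout shown here; run it from the repository root):

```python
import asyncio

from g4f.client import AsyncClient
from etc.unittest.mocks import YieldProviderMock  # path as laid out in this repo

async def main():
    client = AsyncClient(provider=YieldProviderMock)
    response = await client.chat.completions.create(
        [{"role": "user", "content": "Hello"}], ""
    )
    print(response.choices[0].message.content)  # -> "Hello"

asyncio.run(main())
```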
etc/unittest/web_search.py (new file, 89 additions)
@@ -0,0 +1,89 @@ (new file; all lines added)
from __future__ import annotations

import json
import unittest

try:
    from duckduckgo_search import DDGS
    from duckduckgo_search.exceptions import DuckDuckGoSearchException
    from bs4 import BeautifulSoup
    has_requirements = True
except ImportError:
    has_requirements = False

from g4f.client import AsyncClient
from .mocks import YieldProviderMock

DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]

class TestIterListProvider(unittest.IsolatedAsyncioTestCase):
    def setUp(self) -> None:
        if not has_requirements:
            self.skipTest('web search requirements not passed')

    async def test_search(self):
        client = AsyncClient(provider=YieldProviderMock)
        tool_calls = [
            {
                "function": {
                    "arguments": {
                        "query": "search query", # content of last message: messages[-1]["content"]
                        "max_results": 5, # maximum number of search results
                        "max_words": 500, # maximum number of used words from search results for generating the response
                        "backend": "html", # or "lite", "api": change it to pypass rate limits
                        "add_text": True, # do scraping websites
                        "timeout": 5, # in seconds for scraping websites
                        "region": "wt-wt",
                        "instructions": "Using the provided web search results, to write a comprehensive reply to the user request.\n"
                                        "Make sure to add the sources of cites using [[Number]](Url) notation after the reference. Example: [[0]](http://google.com)",
                    },
                    "name": "search_tool"
                },
                "type": "function"
            }
        ]
        try:
            response = await client.chat.completions.create([{"content": "", "role": "user"}], "", tool_calls=tool_calls)
            self.assertIn("Using the provided web search results", response.choices[0].message.content)
        except DuckDuckGoSearchException as e:
            self.skipTest(f'DuckDuckGoSearchException: {e}')

    async def test_search2(self):
        client = AsyncClient(provider=YieldProviderMock)
        tool_calls = [
            {
                "function": {
                    "arguments": {
                        "query": "search query",
                    },
                    "name": "search_tool"
                },
                "type": "function"
            }
        ]
        try:
            response = await client.chat.completions.create([{"content": "", "role": "user"}], "", tool_calls=tool_calls)
            self.assertIn("Using the provided web search results", response.choices[0].message.content)
        except DuckDuckGoSearchException as e:
            self.skipTest(f'DuckDuckGoSearchException: {e}')

    async def test_search3(self):
        client = AsyncClient(provider=YieldProviderMock)
        tool_calls = [
            {
                "function": {
                    "arguments": json.dumps({
                        "query": "search query", # content of last message: messages[-1]["content"]
                        "max_results": 5, # maximum number of search results
                        "max_words": 500, # maximum number of used words from search results for generating the response
                    }),
                    "name": "search_tool"
                },
                "type": "function"
            }
        ]
        try:
            response = await client.chat.completions.create([{"content": "", "role": "user"}], "", tool_calls=tool_calls)
            self.assertIn("Using the provided web search results", response.choices[0].message.content)
        except DuckDuckGoSearchException as e:
            self.skipTest(f'DuckDuckGoSearchException: {e}')
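Note that the three tests pass `arguments` in two forms: as a plain dict (`test_search`, `test_search2`) and as a JSON-encoded string via `json.dumps` (`test_search3`), so the search tool evidently accepts either encoding.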
@@ -18,18 +18,19 @@ try:
 except ImportError:
     has_nodriver = False

-from .base_provider import AbstractProvider, ProviderModelMixin, BaseConversation
+from .base_provider import AbstractProvider, ProviderModelMixin
 from .helper import format_prompt_max_length
-from .openai.har_file import get_headers, get_har_files
 from ..typing import CreateResult, Messages, ImagesType
 from ..errors import MissingRequirementsError, NoValidHarFileError
 from ..requests.raise_for_status import raise_for_status
+from ..providers.response import JsonConversation, RequestLogin
 from ..providers.asyncio import get_running_loop
+from .openai.har_file import get_headers, get_har_files
 from ..requests import get_nodriver
 from ..image import ImageResponse, to_bytes, is_accepted_format
 from .. import debug

-class Conversation(BaseConversation):
+class Conversation(JsonConversation):
     conversation_id: str

     def __init__(self, conversation_id: str):

@@ -80,7 +81,7 @@ class Copilot(AbstractProvider, ProviderModelMixin):
         if has_nodriver:
             login_url = os.environ.get("G4F_LOGIN_URL")
             if login_url:
-                yield f"[Login to {cls.label}]({login_url})\n\n"
+                yield RequestLogin(cls.label, login_url)
             get_running_loop(check_nested=True)
             cls._access_token, cls._cookies = asyncio.run(get_access_token_and_cookies(cls.url, proxy))
         else:
@@ -1,12 +1,7 @@
 from __future__ import annotations

-import json
-import logging
-from aiohttp import ClientSession
 from ..typing import AsyncResult, Messages
-from ..requests.raise_for_status import raise_for_status
-from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from .helper import format_prompt
+from .needs_auth.OpenaiAPI import OpenaiAPI

 """
 Mhystical.cc

@@ -19,39 +14,31 @@ from .helper import format_prompt

 """

-logger = logging.getLogger(__name__)
-
-class Mhystical(AsyncGeneratorProvider, ProviderModelMixin):
+class Mhystical(OpenaiAPI):
     url = "https://api.mhystical.cc"
     api_endpoint = "https://api.mhystical.cc/v1/completions"
     working = True
+    needs_auth = False
     supports_stream = False # Set to False, as streaming is not specified in ChatifyAI
     supports_system_message = False
     supports_message_history = True

     default_model = 'gpt-4'
     models = [default_model]
-    model_aliases = {}

     @classmethod
-    def get_model(cls, model: str) -> str:
-        if model in cls.models:
-            return model
-        elif model in cls.model_aliases:
-            return cls.model_aliases.get(model, cls.default_model)
-        else:
-            return cls.default_model
+    def get_model(cls, model: str, **kwargs) -> str:
+        cls.last_model = cls.default_model
+        return cls.default_model

     @classmethod
-    async def create_async_generator(
+    def create_async_generator(
         cls,
         model: str,
         messages: Messages,
         proxy: str = None,
+        stream: bool = False,
         **kwargs
     ) -> AsyncResult:
         model = cls.get_model(model)

         headers = {
             "x-api-key": "mhystical",
             "Content-Type": "application/json",

@@ -61,24 +48,11 @@ class Mhystical(AsyncGeneratorProvider, ProviderModelMixin):
             "referer": f"{cls.url}/",
             "user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36"
         }
-
-        async with ClientSession(headers=headers) as session:
-            data = {
-                "model": model,
-                "messages": [{"role": "user", "content": format_prompt(messages)}]
-            }
-            async with session.post(cls.api_endpoint, json=data, headers=headers, proxy=proxy) as response:
-                await raise_for_status(response)
-                response_text = await response.text()
-                filtered_response = cls.filter_response(response_text)
-                yield filtered_response
-
-    @staticmethod
-    def filter_response(response_text: str) -> str:
-        try:
-            json_response = json.loads(response_text)
-            message_content = json_response["choices"][0]["message"]["content"]
-            return message_content
-        except (KeyError, IndexError, json.JSONDecodeError) as e:
-            logger.error("Error parsing response: %s", e)
-            return "Error: Failed to parse response from API."
+        return super().create_async_generator(
+            model=model,
+            messages=messages,
+            stream=cls.supports_stream,
+            api_endpoint=cls.api_endpoint,
+            headers=headers,
+            **kwargs
+        )
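The refactor above reduces Mhystical to a thin subclass that forwards to `OpenaiAPI.create_async_generator` with a fixed endpoint and headers. A sketch of the same pattern for a hypothetical provider (class name, endpoint, and header values are illustrative, not part of the commit):

```python
from g4f.Provider.needs_auth.OpenaiAPI import OpenaiAPI
from g4f.typing import AsyncResult, Messages

class ExampleAPI(OpenaiAPI):
    # Illustrative values; only the structure mirrors the refactored Mhystical.
    url = "https://api.example.invalid"
    api_endpoint = "https://api.example.invalid/v1/completions"
    working = True
    default_model = "gpt-4"

    @classmethod
    def create_async_generator(cls, model: str, messages: Messages, **kwargs) -> AsyncResult:
        return super().create_async_generator(
            model=cls.default_model,
            messages=messages,
            api_endpoint=cls.api_endpoint,    # enabled by the new api_endpoint parameter
            headers={"x-api-key": "example"}, # provider-specific auth header
            **kwargs,
        )
```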
@@ -72,7 +72,7 @@ class PollinationsAI(OpenaiAPI):
         **kwargs
     ) -> AsyncResult:
         model = cls.get_model(model)
-        if model in cls.image_models:
+        if cls.get_models() and model in cls.image_models:
             async for response in cls._generate_image(model, messages, prompt, proxy, seed, width, height):
                 yield response
         elif model in cls.models:
@@ -1,93 +0,0 @@ (deleted file; imported elsewhere in this diff as ..bing.conversation — all lines removed)
from __future__ import annotations

from ...requests import StreamSession, raise_for_status
from ...errors import RateLimitError
from ...providers.response import BaseConversation

class Conversation(BaseConversation):
    """
    Represents a conversation with specific attributes.
    """
    def __init__(self, conversationId: str, clientId: str, conversationSignature: str) -> None:
        """
        Initialize a new conversation instance.

        Args:
            conversationId (str): Unique identifier for the conversation.
            clientId (str): Client identifier.
            conversationSignature (str): Signature for the conversation.
        """
        self.conversationId = conversationId
        self.clientId = clientId
        self.conversationSignature = conversationSignature

async def create_conversation(session: StreamSession, headers: dict, tone: str) -> Conversation:
    """
    Create a new conversation asynchronously.

    Args:
        session (ClientSession): An instance of aiohttp's ClientSession.
        proxy (str, optional): Proxy URL. Defaults to None.

    Returns:
        Conversation: An instance representing the created conversation.
    """
    if tone == "Copilot":
        url = "https://copilot.microsoft.com/turing/conversation/create?bundleVersion=1.1809.0"
    else:
        url = "https://www.bing.com/turing/conversation/create?bundleVersion=1.1809.0"
    async with session.get(url, headers=headers) as response:
        if response.status == 404:
            raise RateLimitError("Response 404: Do less requests and reuse conversations")
        await raise_for_status(response, "Failed to create conversation")
        data = await response.json()
        if not data:
            raise RuntimeError('Empty response: Failed to create conversation')
        conversationId = data.get('conversationId')
        clientId = data.get('clientId')
        conversationSignature = response.headers.get('X-Sydney-Encryptedconversationsignature')
        if not conversationId or not clientId or not conversationSignature:
            raise RuntimeError('Empty fields: Failed to create conversation')
        return Conversation(conversationId, clientId, conversationSignature)

async def list_conversations(session: StreamSession) -> list:
    """
    List all conversations asynchronously.

    Args:
        session (ClientSession): An instance of aiohttp's ClientSession.

    Returns:
        list: A list of conversations.
    """
    url = "https://www.bing.com/turing/conversation/chats"
    async with session.get(url) as response:
        response = await response.json()
        return response["chats"]

async def delete_conversation(session: StreamSession, conversation: Conversation, headers: dict) -> bool:
    """
    Delete a conversation asynchronously.

    Args:
        session (ClientSession): An instance of aiohttp's ClientSession.
        conversation (Conversation): The conversation to delete.
        proxy (str, optional): Proxy URL. Defaults to None.

    Returns:
        bool: True if deletion was successful, False otherwise.
    """
    url = "https://sydney.bing.com/sydney/DeleteSingleConversation"
    json = {
        "conversationId": conversation.conversationId,
        "conversationSignature": conversation.conversationSignature,
        "participant": {"id": conversation.clientId},
        "source": "cib",
        "optionsSets": ["autosave"]
    }
    try:
        async with session.post(url, json=json, headers=headers) as response:
            response = await response.json()
            return response["result"]["value"] == "Success"
    except:
        return False
@@ -1,150 +0,0 @@ (deleted file; imported elsewhere in this diff as ..bing.upload_image — all lines removed)
"""
Module to handle image uploading and processing for Bing AI integrations.
"""
from __future__ import annotations

import json
import math
from aiohttp import ClientSession, FormData

from ...typing import ImageType, Tuple
from ...image import to_image, process_image, to_base64_jpg, ImageRequest, Image
from ...requests import raise_for_status

IMAGE_CONFIG = {
    "maxImagePixels": 360000,
    "imageCompressionRate": 0.7,
    "enableFaceBlurDebug": False,
}

async def upload_image(
    session: ClientSession,
    image_data: ImageType,
    tone: str,
    headers: dict
) -> ImageRequest:
    """
    Uploads an image to Bing's AI service and returns the image response.

    Args:
        session (ClientSession): The active session.
        image_data (bytes): The image data to be uploaded.
        tone (str): The tone of the conversation.
        proxy (str, optional): Proxy if any. Defaults to None.

    Raises:
        RuntimeError: If the image upload fails.

    Returns:
        ImageRequest: The response from the image upload.
    """
    image = to_image(image_data)
    new_width, new_height = calculate_new_dimensions(image)
    image = process_image(image, new_width, new_height)
    img_binary_data = to_base64_jpg(image, IMAGE_CONFIG['imageCompressionRate'])

    data = build_image_upload_payload(img_binary_data, tone)

    async with session.post("https://www.bing.com/images/kblob", data=data, headers=prepare_headers(headers)) as response:
        await raise_for_status(response, "Failed to upload image")
        return parse_image_response(await response.json())

def calculate_new_dimensions(image: Image) -> Tuple[int, int]:
    """
    Calculates the new dimensions for the image based on the maximum allowed pixels.

    Args:
        image (Image): The PIL Image object.

    Returns:
        Tuple[int, int]: The new width and height for the image.
    """
    width, height = image.size
    max_image_pixels = IMAGE_CONFIG['maxImagePixels']
    if max_image_pixels / (width * height) < 1:
        scale_factor = math.sqrt(max_image_pixels / (width * height))
        return int(width * scale_factor), int(height * scale_factor)
    return width, height

def build_image_upload_payload(image_bin: str, tone: str) -> FormData:
    """
    Builds the payload for image uploading.

    Args:
        image_bin (str): Base64 encoded image binary data.
        tone (str): The tone of the conversation.

    Returns:
        Tuple[str, str]: The data and boundary for the payload.
    """
    data = FormData()
    knowledge_request = json.dumps(build_knowledge_request(tone), ensure_ascii=False)
    data.add_field('knowledgeRequest', knowledge_request, content_type="application/json")
    data.add_field('imageBase64', image_bin)
    return data

def build_knowledge_request(tone: str) -> dict:
    """
    Builds the knowledge request payload.

    Args:
        tone (str): The tone of the conversation.

    Returns:
        dict: The knowledge request payload.
    """
    return {
        "imageInfo": {},
        "knowledgeRequest": {
            'invokedSkills': ["ImageById"],
            'subscriptionId': "Bing.Chat.Multimodal",
            'invokedSkillsRequestData': {
                'enableFaceBlur': True
            },
            'convoData': {
                'convoid': "",
                'convotone': tone
            }
        }
    }

def prepare_headers(headers: dict) -> dict:
    """
    Prepares the headers for the image upload request.

    Args:
        session (ClientSession): The active session.
        boundary (str): The boundary string for the multipart/form-data.

    Returns:
        dict: The headers for the request.
    """
    headers["Referer"] = 'https://www.bing.com/search?q=Bing+AI&showconv=1&FORM=hpcodx'
    headers["Origin"] = 'https://www.bing.com'
    return headers

def parse_image_response(response: dict) -> ImageRequest:
    """
    Parses the response from the image upload.

    Args:
        response (dict): The response dictionary.

    Raises:
        RuntimeError: If parsing the image info fails.

    Returns:
        ImageRequest: The parsed image response.
    """
    if not response.get('blobId'):
        raise RuntimeError("Failed to parse image info.")

    result = {'bcid': response.get('blobId', ""), 'blurredBcid': response.get('processedBlobId', "")}
    result["imageUrl"] = f"https://www.bing.com/images/blob?bcid={result['blurredBcid'] or result['bcid']}"

    result['originalImageUrl'] = (
        f"https://www.bing.com/images/blob?bcid={result['blurredBcid']}"
        if IMAGE_CONFIG["enableFaceBlurDebug"] else
        f"https://www.bing.com/images/blob?bcid={result['bcid']}"
    )
    return ImageRequest(result)
@@ -1,523 +0,0 @@ (deleted file; the Bing provider, imported elsewhere in this diff as .Bing — all lines removed)
from __future__ import annotations

import random
import json
import uuid
import time
import asyncio
from urllib import parse
from datetime import datetime, date

from ...typing import AsyncResult, Messages, ImageType, Cookies
from ...image import ImageRequest
from ...errors import ResponseError, ResponseStatusError, RateLimitError
from ...requests import DEFAULT_HEADERS
from ...requests.aiohttp import StreamSession
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import get_random_hex
from ..bing.upload_image import upload_image
from ..bing.conversation import Conversation, create_conversation, delete_conversation
from ..needs_auth.BingCreateImages import BingCreateImages
from ... import debug

class Tones:
    """
    Defines the different tone options for the Bing provider.
    """
    creative = "Creative"
    balanced = "Balanced"
    precise = "Precise"
    copilot = "Copilot"

class Bing(AsyncGeneratorProvider, ProviderModelMixin):
    """
    Bing provider for generating responses using the Bing API.
    """
    label = "Microsoft Copilot in Bing"
    url = "https://bing.com/chat"
    working = False
    supports_message_history = True
    default_model = "Balanced"
    default_vision_model = "gpt-4-vision"
    models = [getattr(Tones, key) for key in Tones.__dict__ if not key.startswith("__")]

    @classmethod
    def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        timeout: int = 900,
        api_key: str = None,
        cookies: Cookies = None,
        tone: str = None,
        image: ImageType = None,
        web_search: bool = False,
        context: str = None,
        **kwargs
    ) -> AsyncResult:
        """
        Creates an asynchronous generator for producing responses from Bing.

        :param model: The model to use.
        :param messages: Messages to process.
        :param proxy: Proxy to use for requests.
        :param timeout: Timeout for requests.
        :param cookies: Cookies for the session.
        :param tone: The tone of the response.
        :param image: The image type to be used.
        :param web_search: Flag to enable or disable web search.
        :return: An asynchronous result object.
        """
        prompt = messages[-1]["content"]
        if context is None:
            context = create_context(messages[:-1]) if len(messages) > 1 else None
        if tone is None:
            tone = tone if model.startswith("gpt-4") else model
        tone = cls.get_model("" if tone is None else tone)
        gpt4_turbo = True if model.startswith("gpt-4-turbo") else False

        return stream_generate(
            prompt, tone, image, context, cookies, api_key,
            proxy, web_search, gpt4_turbo, timeout,
            **kwargs
        )

def create_context(messages: Messages) -> str:
    """
    Creates a context string from a list of messages.

    :param messages: A list of message dictionaries.
    :return: A string representing the context created from the messages.
    """
    return "".join(
        f"[{message['role']}]" + ("(#message)"
        if message['role'] != "system"
        else "(#additional_instructions)") + f"\n{message['content']}"
        for message in messages
    ) + "\n\n"

def get_ip_address() -> str:
    return f"13.{random.randint(104, 107)}.{random.randint(0, 255)}.{random.randint(0, 255)}"

def get_default_cookies():
    #muid = get_random_hex().upper()
    sid = get_random_hex().upper()
    guid = get_random_hex().upper()
    isodate = date.today().isoformat()
    timestamp = int(time.time())
    zdate = "0001-01-01T00:00:00.0000000"
    return {
        "_C_Auth": "",
        #"MUID": muid,
        #"MUIDB": muid,
        "_EDGE_S": f"F=1&SID={sid}",
        "_EDGE_V": "1",
        "SRCHD": "AF=hpcodx",
        "SRCHUID": f"V=2&GUID={guid}&dmnchg=1",
        "_RwBf": (
            f"r=0&ilt=1&ihpd=0&ispd=0&rc=3&rb=0&gb=0&rg=200&pc=0&mtu=0&rbb=0&g=0&cid="
            f"&clo=0&v=1&l={isodate}&lft={zdate}&aof=0&ard={zdate}"
            f"&rwdbt={zdate}&rwflt={zdate}&o=2&p=&c=&t=0&s={zdate}"
            f"&ts={isodate}&rwred=0&wls=&wlb="
            "&wle=&ccp=&cpt=&lka=0&lkt=0&aad=0&TH="
        ),
        '_Rwho': f'u=d&ts={isodate}',
        "_SS": f"SID={sid}&R=3&RB=0&GB=0&RG=200&RP=0",
        "SRCHUSR": f"DOB={date.today().strftime('%Y%m%d')}&T={timestamp}",
        "SRCHHPGUSR": f"HV={int(time.time())}",
        "BCP": "AD=1&AL=1&SM=1",
        "ipv6": f"hit={timestamp}",
        '_C_ETH' : '1',
    }

async def create_headers(cookies: Cookies = None, api_key: str = None) -> dict:
    if cookies is None:
        # import nodriver as uc
        # browser = await uc.start(headless=False)
        # page = await browser.get(Defaults.home)
        # await asyncio.sleep(10)
        # cookies = {}
        # for c in await page.browser.cookies.get_all():
        #     if c.domain.endswith(".bing.com"):
        #         cookies[c.name] = c.value
        # user_agent = await page.evaluate("window.navigator.userAgent")
        # await page.close()
        cookies = get_default_cookies()
    if api_key is not None:
        cookies["_U"] = api_key
    headers = Defaults.headers.copy()
    headers["cookie"] = "; ".join(f"{k}={v}" for k, v in cookies.items())
    return headers

class Defaults:
    """
    Default settings and configurations for the Bing provider.
    """
    delimiter = "\x1e"

    # List of allowed message types for Bing responses
    allowedMessageTypes = [
        "ActionRequest","Chat",
        "ConfirmationCard", "Context",
        "InternalSearchQuery", #"InternalSearchResult",
        #"Disengaged", "InternalLoaderMessage",
        "Progress", "RenderCardRequest",
        "RenderContentRequest", "AdsQuery",
        "SemanticSerp", "GenerateContentQuery",
        "SearchQuery", "GeneratedCode",
        "InternalTasksMessage"
    ]

    sliceIds = {
        "balanced": [
            "supllmnfe","archnewtf",
            "stpstream", "stpsig", "vnextvoicecf", "scmcbase", "cmcpupsalltf", "sydtransctrl",
            "thdnsrch", "220dcl1s0", "0215wcrwips0", "0305hrthrots0", "0130gpt4t",
            "bingfc", "0225unsticky1", "0228scss0",
            "defquerycf", "defcontrol", "3022tphpv"
        ],
        "creative": [
            "bgstream", "fltltst2c",
            "stpstream", "stpsig", "vnextvoicecf", "cmcpupsalltf", "sydtransctrl",
            "0301techgnd", "220dcl1bt15", "0215wcrwip", "0305hrthrot", "0130gpt4t",
            "bingfccf", "0225unsticky1", "0228scss0",
            "3022tpvs0"
        ],
        "precise": [
            "bgstream", "fltltst2c",
            "stpstream", "stpsig", "vnextvoicecf", "cmcpupsalltf", "sydtransctrl",
            "0301techgnd", "220dcl1bt15", "0215wcrwip", "0305hrthrot", "0130gpt4t",
            "bingfccf", "0225unsticky1", "0228scss0",
            "defquerycf", "3022tpvs0"
        ],
        "copilot": []
    }

    optionsSets = {
        "balanced": {
            "default": [
                "nlu_direct_response_filter", "deepleo",
                "disable_emoji_spoken_text", "responsible_ai_policy_235",
                "enablemm", "dv3sugg", "autosave",
                "iyxapbing", "iycapbing",
                "galileo", "saharagenconv5", "gldcl1p",
                "gpt4tmncnp"
            ],
            "nosearch": [
                "nlu_direct_response_filter", "deepleo",
                "disable_emoji_spoken_text", "responsible_ai_policy_235",
                "enablemm", "dv3sugg", "autosave",
                "iyxapbing", "iycapbing",
                "galileo", "sunoupsell", "base64filter", "uprv4p1upd",
                "hourthrot", "noctprf", "gndlogcf", "nosearchall"
            ]
        },
        "creative": {
            "default": [
                "nlu_direct_response_filter", "deepleo",
                "disable_emoji_spoken_text", "responsible_ai_policy_235",
                "enablemm", "dv3sugg",
                "iyxapbing", "iycapbing",
                "h3imaginative", "techinstgnd", "hourthrot", "clgalileo", "gencontentv3",
                "gpt4tmncnp"
            ],
            "nosearch": [
                "nlu_direct_response_filter", "deepleo",
                "disable_emoji_spoken_text", "responsible_ai_policy_235",
                "enablemm", "dv3sugg", "autosave",
                "iyxapbing", "iycapbing",
                "h3imaginative", "sunoupsell", "base64filter", "uprv4p1upd",
                "hourthrot", "noctprf", "gndlogcf", "nosearchall",
                "clgalileo", "nocache", "up4rp14bstcst"
            ]
        },
        "precise": {
            "default": [
                "nlu_direct_response_filter", "deepleo",
                "disable_emoji_spoken_text", "responsible_ai_policy_235",
                "enablemm", "dv3sugg",
                "iyxapbing", "iycapbing",
                "h3precise", "techinstgnd", "hourthrot", "techinstgnd", "hourthrot",
                "clgalileo", "gencontentv3"
            ],
            "nosearch": [
                "nlu_direct_response_filter", "deepleo",
                "disable_emoji_spoken_text", "responsible_ai_policy_235",
                "enablemm", "dv3sugg", "autosave",
                "iyxapbing", "iycapbing",
                "h3precise", "sunoupsell", "base64filter", "uprv4p1upd",
                "hourthrot", "noctprf", "gndlogcf", "nosearchall",
                "clgalileo", "nocache", "up4rp14bstcst"
            ]
        },
        "copilot": [
            "nlu_direct_response_filter", "deepleo",
            "disable_emoji_spoken_text", "responsible_ai_policy_235",
            "enablemm", "dv3sugg",
            "iyxapbing", "iycapbing",
            "h3precise", "clgalileo", "gencontentv3", "prjupy"
        ],
    }

    # Default location settings
    location = {
        "locale": "en-US", "market": "en-US", "region": "US",
        "location":"lat:34.0536909;long:-118.242766;re=1000m;",
        "locationHints": [{
            "country": "United States", "state": "California", "city": "Los Angeles",
            "timezoneoffset": 8, "countryConfidence": 8,
            "Center": {"Latitude": 34.0536909, "Longitude": -118.242766},
            "RegionType": 2, "SourceType": 1
        }],
    }

    # Default headers for requests
    home = "https://www.bing.com/chat?q=Microsoft+Copilot&FORM=hpcodx"
    headers = {
        **DEFAULT_HEADERS,
        "accept": "application/json",
        "referer": home,
        "x-ms-client-request-id": str(uuid.uuid4()),
        "x-ms-useragent": "azsdk-js-api-client-factory/1.0.0-beta.1 core-rest-pipeline/1.15.1 OS/Windows",
    }

def format_message(msg: dict) -> str:
    """
    Formats a message dictionary into a JSON string with a delimiter.

    :param msg: The message dictionary to format.
    :return: A formatted string representation of the message.
    """
    return json.dumps(msg, ensure_ascii=False) + Defaults.delimiter

def create_message(
    conversation: Conversation,
    prompt: str,
    tone: str,
    context: str = None,
    image_request: ImageRequest = None,
    web_search: bool = False,
    gpt4_turbo: bool = False,
    new_conversation: bool = True
) -> str:
    """
    Creates a message for the Bing API with specified parameters.

    :param conversation: The current conversation object.
    :param prompt: The user's input prompt.
    :param tone: The desired tone for the response.
    :param context: Additional context for the prompt.
    :param image_request: The image request with the url.
    :param web_search: Flag to enable web search.
    :param gpt4_turbo: Flag to enable GPT-4 Turbo.
    :return: A formatted string message for the Bing API.
    """

    options_sets = Defaults.optionsSets[tone.lower()]
    if not web_search and "nosearch" in options_sets:
        options_sets = options_sets["nosearch"]
    elif "default" in options_sets:
        options_sets = options_sets["default"]
    options_sets = options_sets.copy()
    if gpt4_turbo:
        options_sets.append("dlgpt4t")

    request_id = str(uuid.uuid4())
    struct = {
        "arguments":[{
            "source": "cib",
            "optionsSets": options_sets,
            "allowedMessageTypes": Defaults.allowedMessageTypes,
            "sliceIds": Defaults.sliceIds[tone.lower()],
            "verbosity": "verbose",
            "scenario": "CopilotMicrosoftCom" if tone == Tones.copilot else "SERP",
            "plugins": [{"id": "c310c353-b9f0-4d76-ab0d-1dd5e979cf68", "category": 1}] if web_search else [],
            "traceId": get_random_hex(40),
            "conversationHistoryOptionsSets": ["autosave","savemem","uprofupd","uprofgen"],
            "gptId": "copilot",
            "isStartOfSession": new_conversation,
            "requestId": request_id,
            "message":{
                **Defaults.location,
                "userIpAddress": get_ip_address(),
                "timestamp": datetime.now().isoformat(),
                "author": "user",
                "inputMethod": "Keyboard",
                "text": prompt,
                "messageType": "Chat",
                "requestId": request_id,
                "messageId": request_id
            },
            "tone": "Balanced" if tone == Tones.copilot else tone,
            "spokenTextMode": "None",
            "conversationId": conversation.conversationId,
            "participant": {"id": conversation.clientId}
        }],
        "invocationId": "0",
        "target": "chat",
        "type": 4
    }

    if image_request and image_request.get('imageUrl') and image_request.get('originalImageUrl'):
        struct['arguments'][0]['message']['originalImageUrl'] = image_request.get('originalImageUrl')
        struct['arguments'][0]['message']['imageUrl'] = image_request.get('imageUrl')
        struct['arguments'][0]['experienceType'] = None
        struct['arguments'][0]['attachedFileInfo'] = {"fileName": None, "fileType": None}

    if context:
        struct['arguments'][0]['previousMessages'] = [{
            "author": "user",
            "description": context,
            "contextType": "ClientApp",
            "messageType": "Context",
            "messageId": "discover-web--page-ping-mriduna-----"
        }]

    return format_message(struct)

async def stream_generate(
    prompt: str,
    tone: str,
    image: ImageType = None,
    context: str = None,
    cookies: dict = None,
    api_key: str = None,
    proxy: str = None,
    web_search: bool = False,
    gpt4_turbo: bool = False,
    timeout: int = 900,
    conversation: Conversation = None,
    return_conversation: bool = False,
    raise_apology: bool = False,
    max_retries: int = None,
    sleep_retry: int = 15,
    **kwargs
):
    """
    Asynchronously streams generated responses from the Bing API.

    :param prompt: The user's input prompt.
    :param tone: The desired tone for the response.
    :param image: The image type involved in the response.
    :param context: Additional context for the prompt.
    :param cookies: Cookies for the session.
    :param web_search: Flag to enable web search.
    :param gpt4_turbo: Flag to enable GPT-4 Turbo.
    :param timeout: Timeout for the request.
    :return: An asynchronous generator yielding responses.
    """
    headers = await create_headers(cookies, api_key)
    new_conversation = conversation is None
    max_retries = (5 if new_conversation else 0) if max_retries is None else max_retries
    first = True
    while first or conversation is None:
        async with StreamSession(timeout=timeout, proxy=proxy) as session:
            first = False
            do_read = True
            try:
                if conversation is None:
                    conversation = await create_conversation(session, headers, tone)
                    if return_conversation:
                        yield conversation
            except (ResponseStatusError, RateLimitError) as e:
                max_retries -= 1
                if max_retries < 1:
                    raise e
                if debug.logging:
                    print(f"Bing: Retry: {e}")
                headers = await create_headers()
                await asyncio.sleep(sleep_retry)
                continue

            image_request = await upload_image(
                session,
                image,
                "Balanced" if tone == Tones.copilot else tone,
                headers
            ) if image else None
            async with session.ws_connect(
                'wss://s.copilot.microsoft.com/sydney/ChatHub'
                if tone == "Copilot" else
                'wss://sydney.bing.com/sydney/ChatHub',
                autoping=False,
                params={'sec_access_token': conversation.conversationSignature},
                headers=headers
            ) as wss:
                await wss.send_str(format_message({'protocol': 'json', 'version': 1}))
                await wss.send_str(format_message({"type": 6}))
                await wss.receive_str()
                await wss.send_str(create_message(
                    conversation, prompt, tone,
                    context if new_conversation else None,
                    image_request, web_search, gpt4_turbo,
                    new_conversation
                ))
                response_txt = ''
                returned_text = ''
                message_id = None
                while do_read:
                    try:
                        msg = await wss.receive_str()
                    except TypeError:
                        continue
                    objects = msg.split(Defaults.delimiter)
                    for obj in objects:
                        if not obj:
                            continue
                        try:
                            response = json.loads(obj)
                        except ValueError:
                            continue
                        if response and response.get('type') == 1 and response['arguments'][0].get('messages'):
                            message = response['arguments'][0]['messages'][0]
                            if message_id is not None and message_id != message["messageId"]:
                                returned_text = ''
                            message_id = message["messageId"]
                            image_response = None
                            if (raise_apology and message['contentOrigin'] == 'Apology'):
                                raise ResponseError("Apology Response Error")
                            if 'adaptiveCards' in message:
                                card = message['adaptiveCards'][0]['body'][0]
                                if "text" in card:
                                    response_txt = card.get('text')
                                if message.get('messageType') and "inlines" in card:
                                    inline_txt = card['inlines'][0].get('text')
                                    response_txt += f"{inline_txt}\n"
                            elif message.get('contentType') == "IMAGE":
                                prompt = message.get('text')
                                try:
                                    image_client = BingCreateImages(cookies, proxy, api_key)
                                    image_response = await image_client.create_async(prompt)
                                except Exception as e:
                                    if debug.logging:
                                        print(f"Bing: Failed to create images: {e}")
                                    image_response = f"\nhttps://www.bing.com/images/create?q={parse.quote(prompt)}"
                            if response_txt.startswith(returned_text):
                                new = response_txt[len(returned_text):]
                                if new not in ("", "\n"):
                                    yield new
                                    returned_text = response_txt
                            if image_response is not None:
                                yield image_response
                        elif response.get('type') == 2:
                            result = response['item']['result']
                            do_read = False
                            if result.get('error'):
                                max_retries -= 1
                                if max_retries < 1:
                                    if result["value"] == "CaptchaChallenge":
                                        raise RateLimitError(f"{result['value']}: Use other cookies or/and ip address")
                                    else:
                                        raise RuntimeError(f"{result['value']}: {result['message']}")
                                if debug.logging:
                                    print(f"Bing: Retry: {result['value']}: {result['message']}")
                                headers = await create_headers()
                                conversation = None
                                await asyncio.sleep(sleep_retry)
                            break
                        elif response.get('type') == 3:
                            do_read = False
                            break
            if conversation is not None:
                await delete_conversation(session, conversation, headers)
@@ -31,5 +31,4 @@ from .GeekGpt import GeekGpt
 from .GPTalk import GPTalk
 from .Hashnode import Hashnode
 from .Ylokh import Ylokh
-from .OpenAssistant import OpenAssistant
-from .Bing import Bing
+from .OpenAssistant import OpenAssistant
@@ -17,8 +17,9 @@ except ImportError:

 from ... import debug
 from ...typing import Messages, Cookies, ImagesType, AsyncResult, AsyncIterator
-from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, BaseConversation, SynthesizeData
+from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..helper import format_prompt, get_cookies
+from ...providers.response import JsonConversation, SynthesizeData, RequestLogin
 from ...requests.raise_for_status import raise_for_status
 from ...requests.aiohttp import get_connector
 from ...requests import get_nodriver

@@ -81,7 +82,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
             browser = await get_nodriver(proxy=proxy, user_data_dir="gemini")
             login_url = os.environ.get("G4F_LOGIN_URL")
             if login_url:
-                yield f"[Login to {cls.label}]({login_url})\n\n"
+                yield RequestLogin(cls.label, login_url)
             page = await browser.get(f"{cls.url}/app")
             await page.select("div.ql-editor.textarea", 240)
             cookies = {}

@@ -305,37 +306,37 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
     if sid_match:
         cls._sid = sid_match.group(1)

-class Conversation(BaseConversation):
+class Conversation(JsonConversation):
     def __init__(self,
-        conversation_id: str = "",
-        response_id: str = "",
-        choice_id: str = ""
+        conversation_id: str,
+        response_id: str,
+        choice_id: str
    ) -> None:
        self.conversation_id = conversation_id
        self.response_id = response_id
        self.choice_id = choice_id

-async def iter_filter_base64(response_iter: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
+async def iter_filter_base64(chunks: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
    search_for = b'[["wrb.fr","XqA3Ic","[\\"'
    end_with = b'\\'
    is_started = False
-    async for chunk in response_iter:
+    async for chunk in chunks:
        if is_started:
            if end_with in chunk:
-                yield chunk.split(end_with, 1).pop(0)
+                yield chunk.split(end_with, 1, maxsplit=1).pop(0)
                break
            else:
                yield chunk
        elif search_for in chunk:
            is_started = True
-            yield chunk.split(search_for, 1).pop()
+            yield chunk.split(search_for, 1, maxsplit=1).pop()
        else:
-            raise RuntimeError(f"Response: {chunk}")
+            raise ValueError(f"Response: {chunk}")

-async def iter_base64_decode(response_iter: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
+async def iter_base64_decode(chunks: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
    buffer = b""
    rest = 0
-    async for chunk in response_iter:
+    async for chunk in chunks:
        chunk = buffer + chunk
        rest = len(chunk) % 4
        buffer = chunk[-rest:]
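The `iter_base64_decode` helper above buffers to a 4-byte boundary because base64 decodes in 4-character groups. A standalone sketch of the same idea (synchronous for brevity; only the buffering logic mirrors the diff):

```python
import base64

def decode_stream(chunks):
    # Buffer trailing bytes so each decode sees a multiple of 4 characters.
    buffer = b""
    for chunk in chunks:
        chunk = buffer + chunk
        rest = len(chunk) % 4
        buffer = chunk[len(chunk) - rest:] if rest else b""
        yield base64.b64decode(chunk[:len(chunk) - rest])

# "aGVsbG8gd29ybGQ=" split at an awkward point still decodes cleanly:
print(b"".join(decode_stream([b"aGVsbG8gd2", b"9ybGQ="])))  # b'hello world'
```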
@@ -8,19 +8,19 @@ try:
 except ImportError:
     has_curl_cffi = False

+from ..base_provider import ProviderModelMixin, AbstractProvider
+from ..helper import format_prompt
 from ...typing import CreateResult, Messages, Cookies
 from ...errors import MissingRequirementsError
 from ...requests.raise_for_status import raise_for_status
+from ...providers.response import JsonConversation, ImageResponse, Sources
 from ...cookies import get_cookies
-from ...image import ImageResponse
-from ..base_provider import ProviderModelMixin, AbstractProvider, BaseConversation
-from ..helper import format_prompt
+from ... import debug

-class Conversation(BaseConversation):
+class Conversation(JsonConversation):
     def __init__(self, conversation_id: str, message_id: str):
-        self.conversation_id = conversation_id
-        self.message_id = message_id
+        self.conversation_id: str = conversation_id
+        self.message_id: str = message_id

 class HuggingChat(AbstractProvider, ProviderModelMixin):
     url = "https://huggingface.co/chat"

@@ -152,33 +152,35 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
         raise_for_status(response)

         full_response = ""
+        sources = None
         for line in response.iter_lines():
             if not line:
                 continue
             try:
                 line = json.loads(line)
             except json.JSONDecodeError as e:
-                print(f"Failed to decode JSON: {line}, error: {e}")
+                debug.log(f"Failed to decode JSON: {line}, error: {e}")
                 continue

             if "type" not in line:
                 raise RuntimeError(f"Response: {line}")

             elif line["type"] == "stream":
                 token = line["token"].replace('\u0000', '')
                 full_response += token
                 if stream:
                     yield token

             elif line["type"] == "finalAnswer":
                 break
+            elif line["type"] == "file":
+                url = f"https://huggingface.co/chat/conversation/{conversation.conversation_id}/output/{line['sha']}"
+                yield ImageResponse(url, alt=messages[-1]["content"], options={"cookies": cookies})
+            elif line["type"] == "webSearch" and "sources" in line:
+                sources = Sources(line["sources"])

-        full_response = full_response.replace('<|im_end|', '').replace('\u0000', '').strip()
+        full_response = full_response.replace('<|im_end|', '').strip()
         if not stream:
             yield full_response
+        if sources is not None:
+            yield sources

     @classmethod
     def create_conversation(cls, session: Session, model: str):
@@ -48,6 +48,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
         max_new_tokens: int = 1024,
         temperature: float = 0.7,
         prompt: str = None,
+        extra_data: dict = {},
         **kwargs
     ) -> AsyncResult:
         try:

@@ -73,16 +74,16 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
         if api_key is not None:
             headers["Authorization"] = f"Bearer {api_key}"
         payload = None
-        if model in cls.image_models:
+        if cls.get_models() and model in cls.image_models:
             stream = False
             prompt = messages[-1]["content"] if prompt is None else prompt
-            payload = {"inputs": prompt, "parameters": {"seed": random.randint(0, 2**32)}}
+            payload = {"inputs": prompt, "parameters": {"seed": random.randint(0, 2**32), **extra_data}}
         else:
             params = {
                 "return_full_text": False,
                 "max_new_tokens": max_new_tokens,
                 "temperature": temperature,
-                **kwargs
+                **extra_data
             }
             async with StreamSession(
                 headers=headers,
@@ -4,9 +4,10 @@ import json
 import requests

 from ..helper import filter_none
-from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, FinishReason
+from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ...typing import Union, Optional, AsyncResult, Messages, ImagesType
 from ...requests import StreamSession, raise_for_status
+from ...providers.response import FinishReason, ToolCalls, Usage
 from ...errors import MissingAuthError, ResponseError
 from ...image import to_data_uri
 from ... import debug

@@ -50,6 +51,7 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
         timeout: int = 120,
         images: ImagesType = None,
         api_key: str = None,
+        api_endpoint: str = None,
         api_base: str = None,
         temperature: float = None,
         max_tokens: int = None,

@@ -58,6 +60,7 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
         stream: bool = False,
         headers: dict = None,
         impersonate: str = None,
+        tools: Optional[list] = None,
         extra_data: dict = {},
         **kwargs
     ) -> AsyncResult:

@@ -92,16 +95,23 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
             top_p=top_p,
             stop=stop,
             stream=stream,
+            tools=tools,
             **extra_data
         )
-        async with session.post(f"{api_base.rstrip('/')}/chat/completions", json=data) as response:
+        if api_endpoint is None:
+            api_endpoint = f"{api_base.rstrip('/')}/chat/completions"
+        async with session.post(api_endpoint, json=data) as response:
             await raise_for_status(response)
             if not stream:
                 data = await response.json()
                 cls.raise_error(data)
                 choice = data["choices"][0]
-                if "content" in choice["message"]:
+                if "content" in choice["message"] and choice["message"]["content"]:
                     yield choice["message"]["content"].strip()
+                elif "tool_calls" in choice["message"]:
+                    yield ToolCalls(choice["message"]["tool_calls"])
+                if "usage" in data:
+                    yield Usage(**data["usage"])
                 finish = cls.read_finish_reason(choice)
                 if finish is not None:
                     yield finish
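With tool calls and usage now surfaced as typed response objects, a consumer can dispatch on them. A minimal sketch (the import path matches the diff above; the handling itself is illustrative):

```python
from g4f.providers.response import FinishReason, ToolCalls, Usage

async def consume(chunks):
    # Iterate a provider generator and route each typed item appropriately.
    async for item in chunks:
        if isinstance(item, ToolCalls):
            print("tool calls:", item)   # hand off to a tool dispatcher
        elif isinstance(item, Usage):
            print("usage:", item)        # token accounting
        elif isinstance(item, FinishReason):
            break                        # generation finished
        else:
            print(item, end="")          # plain text chunk
```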
@@ -8,5 +8,5 @@ class OpenaiAccount(OpenaiChat):
     image_models = ["dall-e-3", "gpt-4", "gpt-4o"]
     default_vision_model = "gpt-4o"
     default_image_model = "dall-e-3"
-    models = [*OpenaiChat.fallback_models, default_image_model]
+    fallback_models = [*OpenaiChat.fallback_models, default_image_model]
     model_aliases = {default_image_model: default_vision_model}
@@ -9,24 +9,23 @@ import base64
 import time
-import requests
 import random
-from typing import AsyncIterator
+from typing import AsyncIterator, Iterator, Optional, Generator, Dict, List
 from copy import copy

 try:
     import nodriver
     from nodriver.cdp.network import get_response_body
     has_nodriver = True
 except ImportError:
     has_nodriver = False

 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from ...typing import AsyncResult, Messages, Cookies, ImagesType, AsyncIterator
+from ...typing import AsyncResult, Messages, Cookies, ImagesType
 from ...requests.raise_for_status import raise_for_status
-from ...requests import StreamSession
+from ...requests import StreamSession, Session
 from ...requests import get_nodriver
 from ...image import ImageResponse, ImageRequest, to_image, to_bytes, is_accepted_format
 from ...errors import MissingAuthError, NoValidHarFileError
-from ...providers.response import BaseConversation, FinishReason, SynthesizeData
+from ...providers.response import JsonConversation, FinishReason, SynthesizeData, Sources, TitleGeneration, RequestLogin, quote_url
 from ..helper import format_cookies
-from ..openai.har_file import get_request_config
+from ..openai.har_file import RequestConfig, arkReq, arkose_url, start_url, conversation_url, backend_url, backend_anon_url

@@ -106,15 +105,30 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
     _expires: int = None

     @classmethod
-    def get_models(cls):
+    def get_models(cls, proxy: str = None, timeout: int = 180) -> List[str]:
         if not cls.models:
-            try:
-                response = requests.get(f"{cls.url}/backend-anon/models")
-                response.raise_for_status()
-                data = response.json()
-                cls.models = [model.get("slug") for model in data.get("models")]
-            except Exception:
-                cls.models = cls.fallback_models
+            # try:
+            #     headers = {
+            #         **(cls.get_default_headers() if cls._headers is None else cls._headers),
+            #         "accept": "application/json",
+            #     }
+            #     with Session(
+            #         proxy=proxy,
+            #         impersonate="chrome",
+            #         timeout=timeout,
+            #         headers=headers
+            #     ) as session:
+            #         response = session.get(
+            #             f"{cls.url}/backend-anon/models"
+            #             if cls._api_key is None else
+            #             f"{cls.url}/backend-api/models"
+            #         )
+            #         raise_for_status(response)
+            #         data = response.json()
+            #         cls.models = [model.get("slug") for model in data.get("models")]
+            # except Exception as e:
+            #     debug.log(f"OpenaiChat: Failed to get models: {type(e).__name__}: {e}")
+            cls.models = cls.fallback_models
         return cls.models

     @classmethod

@@ -199,13 +213,12 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
         """
         # Create a message object with the user role and the content
         messages = [{
-            "id": str(uuid.uuid4()),
             "author": {"role": message["role"]},
             "content": {"content_type": "text", "parts": [message["content"]]},
+            "id": str(uuid.uuid4()),
+            "create_time": int(time.time()),
-            "metadata": {"serialization_metadata": {"custom_symbol_offsets": []}, "system_hints": system_hints},
+            "metadata": {"serialization_metadata": {"custom_symbol_offsets": []}, **({"system_hints": system_hints} if system_hints else {})},
-            "create_time": time.time(),
         } for message in messages]

         # Check if there is an image response
         if image_requests:
             # Change content in last user message

@@ -236,24 +249,6 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):

     @classmethod
     async def get_generated_image(cls, session: StreamSession, headers: dict, element: dict, prompt: str = None) -> ImageResponse:
-        """
-        Retrieves the image response based on the message content.
-
-        This method processes the message content to extract image information and retrieves the
-        corresponding image from the backend API. It then returns an ImageResponse object containing
-        the image URL and the prompt used to generate the image.
-
-        Args:
-            session (StreamSession): The StreamSession object used for making HTTP requests.
-            headers (dict): HTTP headers to be used for the request.
-            line (dict): A dictionary representing the line of response that contains image information.
-
-        Returns:
-            ImageResponse: An object containing the image URL and the prompt, or None if no image is found.
-
-        Raises:
-            RuntimeError: If there'san error in downloading the image, including issues with the HTTP request or response.
-        """
         try:
             prompt = element["metadata"]["dalle"]["prompt"]
             file_id = element["asset_pointer"].split("file-service://", 1)[1]

@@ -347,6 +342,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
             if cls._api_key is None:
                 auto_continue = False
             conversation.finish_reason = None
+            sources = Sources([])
             while conversation.finish_reason is None:
                 async with session.post(
                     f"{cls.url}/backend-anon/sentinel/chat-requirements"

@@ -387,11 +383,11 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
                 )]
                 data = {
                     "action": action,
                     "messages": None,
                     "parent_message_id": conversation.message_id,
                     "model": model,
                     "timezone_offset_min":-60,
+                    "timezone":"Europe/Berlin",
                     "suggestions":[],
                     "history_and_training_disabled": history_disabled and not auto_continue and not return_conversation or not cls.needs_auth,
                     "conversation_mode":{"kind":"primary_assistant","plugin_ids":None},
                     "force_paragen":False,

@@ -433,17 +429,40 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
                     headers=headers
                 ) as response:
                     cls._update_request_args(session)
-                    if response.status == 403 and max_retries > 0:
+                    if response.status in (403, 404) and max_retries > 0:
                         max_retries -= 1
                         debug.log(f"Retry: Error {response.status}: {await response.text()}")
+                        conversation.conversation_id = None
                         await asyncio.sleep(5)
                         continue
                     await raise_for_status(response)
-                    if return_conversation:
-                        yield conversation
+                    buffer = u""
                     async for line in response.iter_lines():
-                        async for chunk in cls.iter_messages_line(session, line, conversation):
-                            yield chunk
+                        async for chunk in cls.iter_messages_line(session, line, conversation, sources):
+                            if isinstance(chunk, str):
+                                chunk = chunk.replace("\ue203", "").replace("\ue204", "").replace("\ue206", "")
+                                buffer += chunk
+                                if buffer.find(u"\ue200") != -1:
+                                    if buffer.find(u"\ue201") != -1:
+                                        buffer = buffer.replace("\ue200", "").replace("\ue202", "\n").replace("\ue201", "")
+                                        buffer = buffer.replace("navlist\n", "#### ")
+                                        def replacer(match):
+                                            link = None
+                                            if len(sources.list) > int(match.group(1)):
+                                                link = sources.list[int(match.group(1))]["url"]
+                                                return f"[[{int(match.group(1))+1}]]({link})"
+                                            return f" [{int(match.group(1))+1}]"
+                                        buffer = re.sub(r'(?:cite\nturn0search|cite\nturn0news|turn0news)(\d+)', replacer, buffer)
+                                    else:
+                                        continue
+                                yield buffer
+                                buffer = ""
+                            else:
+                                yield chunk
+                    if sources.list:
+                        yield sources
+                    if return_conversation:
+                        yield conversation
                     if not history_disabled and cls._api_key is not None:
                         yield SynthesizeData(cls.__name__, {
                             "conversation_id": conversation.conversation_id,

@@ -459,7 +478,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
             yield FinishReason(conversation.finish_reason)

     @classmethod
-    async def iter_messages_line(cls, session: StreamSession, line: bytes, fields: Conversation) -> AsyncIterator:
+    async def iter_messages_line(cls, session: StreamSession, line: bytes, fields: Conversation, sources: Sources) -> AsyncIterator:
         if not line.startswith(b"data: "):
             return
         elif line.startswith(b"data: [DONE]"):

@@ -470,15 +489,26 @@
|
|||
line = json.loads(line[6:])
|
||||
except:
|
||||
return
|
||||
if isinstance(line, dict) and "v" in line:
|
||||
if not isinstance(line, dict):
|
||||
return
|
||||
if "type" in line:
|
||||
if line["type"] == "title_generation":
|
||||
yield TitleGeneration(line["title"])
|
||||
if "v" in line:
|
||||
v = line.get("v")
|
||||
if isinstance(v, str) and fields.is_recipient:
|
||||
if "p" not in line or line.get("p") == "/message/content/parts/0":
|
||||
yield v
|
||||
elif isinstance(v, list) and fields.is_recipient:
|
||||
elif isinstance(v, list):
|
||||
for m in v:
|
||||
if m.get("p") == "/message/content/parts/0":
|
||||
if m.get("p") == "/message/content/parts/0" and fields.is_recipient:
|
||||
yield m.get("v")
|
||||
elif m.get("p") == "/message/metadata/search_result_groups":
|
||||
for entry in [p.get("entries") for p in m.get("v")]:
|
||||
for link in entry:
|
||||
sources.add_source(link)
|
||||
elif re.match(r"^/message/metadata/content_references/\d+$", m.get("p")):
|
||||
sources.add_source(m.get("v"))
|
||||
elif m.get("p") == "/message/metadata":
|
||||
fields.finish_reason = m.get("v", {}).get("finish_details", {}).get("type")
|
||||
break
|
||||
|
|
@ -529,14 +559,15 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
try:
|
||||
await get_request_config(proxy)
|
||||
cls._create_request_args(RequestConfig.cookies, RequestConfig.headers)
|
||||
if RequestConfig.access_token is not None:
|
||||
cls._set_api_key(RequestConfig.access_token)
|
||||
if RequestConfig.access_token is not None or cls.needs_auth:
|
||||
if not cls._set_api_key(RequestConfig.access_token):
|
||||
raise NoValidHarFileError(f"Access token is not valid: {RequestConfig.access_token}")
|
||||
except NoValidHarFileError:
|
||||
if has_nodriver:
|
||||
if cls._api_key is None:
|
||||
login_url = os.environ.get("G4F_LOGIN_URL")
|
||||
if login_url:
|
||||
yield f"[Login to {cls.label}]({login_url})\n\n"
|
||||
yield RequestLogin(cls.label, login_url)
|
||||
await cls.nodriver_auth(proxy)
|
||||
else:
|
||||
raise
|
||||
|
|
@ -563,7 +594,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
arkBx=None,
|
||||
arkHeader=event.request.headers,
|
||||
arkBody=event.request.post_data,
|
||||
userAgent=event.request.headers.get("user-agent")
|
||||
userAgent=event.request.headers.get("User-Agent")
|
||||
)
|
||||
await page.send(nodriver.cdp.network.enable())
|
||||
page.add_handler(nodriver.cdp.network.RequestWillBeSent, on_request)
|
||||
|
|
@ -585,14 +616,13 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
break
|
||||
await asyncio.sleep(1)
|
||||
RequestConfig.data_build = await page.evaluate("document.documentElement.getAttribute('data-build')")
|
||||
for c in await page.send(get_cookies([cls.url])):
|
||||
RequestConfig.cookies[c["name"]] = c["value"]
|
||||
RequestConfig.cookies = await page.send(get_cookies([cls.url]))
|
||||
await page.close()
|
||||
cls._create_request_args(RequestConfig.cookies, RequestConfig.headers, user_agent=user_agent)
|
||||
cls._set_api_key(cls._api_key)
|
||||
|
||||
@staticmethod
|
||||
def get_default_headers() -> dict:
|
||||
def get_default_headers() -> Dict[str, str]:
|
||||
return {
|
||||
**DEFAULT_HEADERS,
|
||||
"content-type": "application/json",
|
||||
|
|
@ -609,22 +639,30 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
@classmethod
|
||||
def _update_request_args(cls, session: StreamSession):
|
||||
for c in session.cookie_jar if hasattr(session, "cookie_jar") else session.cookies.jar:
|
||||
cls._cookies[c.key if hasattr(c, "key") else c.name] = c.value
|
||||
cls._cookies[getattr(c, "key", getattr(c, "name", ""))] = c.value
|
||||
cls._update_cookie_header()
|
||||
|
||||
@classmethod
|
||||
def _set_api_key(cls, api_key: str):
|
||||
cls._api_key = api_key
|
||||
cls._expires = int(time.time()) + 60 * 60 * 4
|
||||
if api_key:
|
||||
cls._headers["authorization"] = f"Bearer {api_key}"
|
||||
exp = api_key.split(".")[1]
|
||||
exp = (exp + "=" * (4 - len(exp) % 4)).encode()
|
||||
cls._expires = json.loads(base64.b64decode(exp)).get("exp")
|
||||
debug.log(f"OpenaiChat: API key expires at\n {cls._expires} we have:\n {time.time()}")
|
||||
if time.time() > cls._expires:
|
||||
debug.log(f"OpenaiChat: API key is expired")
|
||||
else:
|
||||
cls._api_key = api_key
|
||||
cls._headers["authorization"] = f"Bearer {api_key}"
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _update_cookie_header(cls):
|
||||
if cls._cookies:
|
||||
cls._headers["cookie"] = format_cookies(cls._cookies)
|
||||
|
||||
class Conversation(BaseConversation):
|
||||
class Conversation(JsonConversation):
|
||||
"""
|
||||
Class to encapsulate response fields.
|
||||
"""
|
||||
|
|
@ -633,10 +671,10 @@ class Conversation(BaseConversation):
|
|||
self.message_id = message_id
|
||||
self.finish_reason = finish_reason
|
||||
self.is_recipient = False
|
||||
|
||||
|
||||
def get_cookies(
|
||||
urls: list[str] = None
|
||||
):
|
||||
urls: Optional[Iterator[str]] = None
|
||||
) -> Generator[Dict, Dict, Dict[str, str]]:
|
||||
params = {}
|
||||
if urls is not None:
|
||||
params['urls'] = [i for i in urls]
|
||||
|
|
@ -645,4 +683,4 @@ def get_cookies(
|
|||
'params': params,
|
||||
}
|
||||
json = yield cmd_dict
|
||||
return json['cookies']
|
||||
return {c["name"]: c["value"] for c in json['cookies']} if 'cookies' in json else {}
|
||||
|
|
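The reworked streaming loop above now yields structured response objects (`Sources`, `TitleGeneration`, and the `JsonConversation`-based `Conversation`) alongside plain text chunks. A minimal sketch of driving it through the public API, assuming "gpt-4o" as the model name and a working OpenaiChat session:

```python
import g4f
from g4f.Provider import OpenaiChat

# Sketch only: with return_conversation=True the generator also yields the
# Conversation state that can later be re-imported. Model name and prompt
# are illustrative, not taken from the commit.
response = g4f.ChatCompletion.create(
    model="gpt-4o",
    provider=OpenaiChat,
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
    return_conversation=True,
)
for chunk in response:
    print(chunk, end="")  # chunks may include response objects, not only text
```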
@@ -11,6 +11,7 @@ from .typing import Messages, CreateResult, AsyncResult, ImageType
from .errors import StreamNotSupportedError
from .cookies import get_cookies, set_cookies
from .providers.types import ProviderType
from .providers.helper import concat_chunks
from .client.service import get_model_and_provider, get_last_provider

#Configure "g4f" logger

@@ -48,7 +49,7 @@ class ChatCompletion:

        result = provider.create_completion(model, messages, stream=stream, **kwargs)

        return result if stream else ''.join([str(chunk) for chunk in result if chunk])
        return result if stream else concat_chunks(result)

    @staticmethod
    def create_async(model : Union[Model, str],
@@ -6,7 +6,7 @@ import uvicorn
import secrets
import os
import shutil

from email.utils import formatdate
import os.path
from fastapi import FastAPI, Response, Request, UploadFile, Depends
from fastapi.middleware.wsgi import WSGIMiddleware

@@ -22,23 +22,21 @@ from starlette.status import (
    HTTP_403_FORBIDDEN,
    HTTP_500_INTERNAL_SERVER_ERROR,
)
from starlette.staticfiles import NotModifiedResponse
from fastapi.encoders import jsonable_encoder
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials, HTTPBasic
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import FileResponse
from starlette._compat import md5_hexdigest
from types import SimpleNamespace
from typing import Union, Optional, List
try:
    from typing import Annotated
except ImportError:
    class Annotated:
        pass

import g4f
import g4f.debug
from g4f.client import AsyncClient, ChatCompletion, ImagesResponse, convert_to_provider
from g4f.providers.response import BaseConversation
from g4f.client.helper import filter_none
from g4f.image import is_accepted_format, is_data_uri_an_image, images_dir
from g4f.image import is_data_uri_an_image, images_dir
from g4f.errors import ProviderNotFoundError, ModelNotFoundError, MissingAuthError, NoValidHarFileError
from g4f.cookies import read_cookie_files, get_cookies_dir
from g4f.Provider import ProviderType, ProviderUtils, __providers__

@@ -47,7 +45,7 @@ from .stubs import (
    ChatCompletionsConfig, ImageGenerationConfig,
    ProviderResponseModel, ModelResponseModel,
    ErrorResponseModel, ProviderResponseDetailModel,
    FileResponseModel
    FileResponseModel, Annotated
)

logger = logging.getLogger(__name__)

@@ -445,16 +443,33 @@ class Api:
            HTTP_200_OK: {"content": {"image/*": {}}},
            HTTP_404_NOT_FOUND: {}
        })
        async def get_image(filename):
        async def get_image(filename, request: Request):
            target = os.path.join(images_dir, filename)

            ext = os.path.splitext(filename).pop()
            stat_result = SimpleNamespace()
            stat_result.st_size = 0
            if os.path.isfile(target):
                stat_result.st_size = os.stat(target).st_size
            stat_result.st_mtime = int(f"{filename.split('_')[0]}")
            response = FileResponse(
                target,
                media_type=f"image/{ext.replace('jpg', 'jpeg')}",
                headers={
                    "content-length": str(stat_result.st_size),
                    "last-modified": formatdate(stat_result.st_mtime, usegmt=True),
                    "etag": f'"{md5_hexdigest(filename.encode(), usedforsecurity=False)}"'
                },
            )
            try:
                if_none_match = request.headers["if-none-match"]
                etag = response.headers["etag"]
                if etag in [tag.strip(" W/") for tag in if_none_match.split(",")]:
                    return NotModifiedResponse(response.headers)
            except KeyError:
                pass
            if not os.path.isfile(target):
                return Response(status_code=404)

            with open(target, "rb") as f:
                content_type = is_accepted_format(f.read(12))

            return FileResponse(target, media_type=content_type)
            return Response(status_code=HTTP_404_NOT_FOUND)
            return response

def format_exception(e: Union[Exception, str], config: Union[ChatCompletionsConfig, ImageGenerationConfig] = None, image: bool = False) -> str:
    last_provider = {} if not image else g4f.get_last_provider(True)
@@ -2,7 +2,11 @@ from __future__ import annotations

from pydantic import BaseModel, Field
from typing import Union, Optional

try:
    from typing import Annotated
except ImportError:
    class Annotated:
        pass
from g4f.typing import Messages

class ChatCompletionsConfig(BaseModel):

@@ -23,6 +27,16 @@ class ChatCompletionsConfig(BaseModel):
    history_disabled: Optional[bool] = None
    auto_continue: Optional[bool] = None
    timeout: Optional[int] = None
    tool_calls: list = Field(default=[], examples=[[
        {
            "function": {
                "arguments": {"query":"search query", "max_results":5, "max_words": 2500, "backend": "api", "add_text": True, "timeout": 5},
                "name": "search_tool"
            },
            "type": "function"
        }
    ]])
    tools: list = None

class ImageGenerationConfig(BaseModel):
    prompt: str
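Since `tool_calls` is now part of `ChatCompletionsConfig`, the interference API accepts a `search_tool` call whose arguments follow the example schema above. A hedged sketch against a locally running server (port 1337, as in the Docker guide); model name and query are placeholders:

```python
import requests

payload = {
    "model": "gpt-4o-mini",  # assumed model name
    "messages": [{"role": "user", "content": "What happened in tech this week?"}],
    "tool_calls": [{
        "type": "function",
        "function": {
            "name": "search_tool",
            "arguments": {"query": "tech news this week", "max_results": 5},
        },
    }],
}
# POST to the OpenAI-compatible endpoint exposed by the g4f API server
response = requests.post("http://localhost:1337/v1/chat/completions", json=payload)
print(response.json()["choices"][0]["message"]["content"])
```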
@@ -6,20 +6,22 @@ import random
import string
import asyncio
import base64
import json
from typing import Union, AsyncIterator, Iterator, Coroutine, Optional

from ..image import ImageResponse, copy_images, images_dir
from ..typing import Messages, ImageType
from ..providers.types import ProviderType
from ..providers.response import ResponseType, FinishReason, BaseConversation, SynthesizeData
from ..providers.types import ProviderType, BaseRetryProvider
from ..providers.response import ResponseType, FinishReason, BaseConversation, SynthesizeData, ToolCalls, Usage
from ..errors import NoImageResponseError
from ..providers.retry_provider import IterListProvider
from ..providers.asyncio import to_sync_generator, async_generator_to_list
from ..web_search import get_search_message, do_search
from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
from .image_models import ImageModels
from .types import IterResponse, ImageProvider, Client as BaseClient
from .service import get_model_and_provider, get_last_provider, convert_to_provider
from .service import get_model_and_provider, convert_to_provider
from .helper import find_stop, filter_json, filter_none, safe_aclose, to_async_iterator
from .. import debug

@@ -35,6 +37,47 @@ except NameError:
        except StopAsyncIteration:
            raise StopIteration

def validate_arguments(data: dict):
    if "arguments" in data:
        if isinstance(data["arguments"], str):
            data["arguments"] = json.loads(data["arguments"])
        if not isinstance(data["arguments"], dict):
            raise ValueError("Tool function arguments must be a dictionary or a json string")
        else:
            return filter_none(**data["arguments"])
    else:
        return {}

async def async_iter_run_tools(async_iter_callback, model, messages, tool_calls: Optional[list] = None, **kwargs):
    if tool_calls is not None:
        for tool in tool_calls:
            if tool.get("type") == "function":
                if tool.get("function", {}).get("name") == "search_tool":
                    tool["function"]["arguments"] = validate_arguments(tool["function"])
                    messages = messages.copy()
                    messages[-1]["content"] = await do_search(
                        messages[-1]["content"],
                        **tool["function"]["arguments"]
                    )
    response = async_iter_callback(model=model, messages=messages, **kwargs)
    if not hasattr(response, "__aiter__"):
        response = to_async_iterator(response)
    async for chunk in response:
        yield chunk

def iter_run_tools(iter_callback, model, messages, tool_calls: Optional[list] = None, **kwargs):
    if tool_calls is not None:
        for tool in tool_calls:
            if tool.get("type") == "function":
                if tool.get("function", {}).get("name") == "search_tool":
                    tool["function"]["arguments"] = validate_arguments(tool["function"])
                    messages[-1]["content"] = get_search_message(
                        messages[-1]["content"],
                        raise_search_exceptions=True,
                        **tool["function"]["arguments"]
                    )
    return iter_callback(model=model, messages=messages, **kwargs)

# Synchronous iter_response function
def iter_response(
    response: Union[Iterator[Union[str, ResponseType]]],

@@ -45,6 +88,8 @@ def iter_response(
) -> ChatCompletionResponseType:
    content = ""
    finish_reason = None
    tool_calls = None
    usage = None
    completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
    idx = 0

@@ -55,6 +100,12 @@ def iter_response(
        if isinstance(chunk, FinishReason):
            finish_reason = chunk.reason
            break
        elif isinstance(chunk, ToolCalls):
            tool_calls = chunk.get_list()
            continue
        elif isinstance(chunk, Usage):
            usage = chunk.get_dict()
            continue
        elif isinstance(chunk, BaseConversation):
            yield chunk
            continue

@@ -88,18 +139,21 @@ def iter_response(
    if response_format is not None and "type" in response_format:
        if response_format["type"] == "json_object":
            content = filter_json(content)
    yield ChatCompletion.model_construct(content, finish_reason, completion_id, int(time.time()))
    yield ChatCompletion.model_construct(content, finish_reason, completion_id, int(time.time()), **filter_none(
        tool_calls=tool_calls,
        usage=usage
    ))

# Synchronous iter_append_model_and_provider function
def iter_append_model_and_provider(response: ChatCompletionResponseType) -> ChatCompletionResponseType:
    last_provider = None

def iter_append_model_and_provider(response: ChatCompletionResponseType, last_model: str, last_provider: ProviderType) -> ChatCompletionResponseType:
    if isinstance(last_provider, BaseRetryProvider):
        last_provider = last_provider.last_provider
    for chunk in response:
        if isinstance(chunk, (ChatCompletion, ChatCompletionChunk)):
            last_provider = get_last_provider(True) if last_provider is None else last_provider
            chunk.model = last_provider.get("model")
            chunk.provider = last_provider.get("name")
        yield chunk
            if last_provider is not None:
                chunk.model = getattr(last_provider, "last_model", last_model)
                chunk.provider = last_provider.__name__
        yield chunk

async def async_iter_response(
    response: AsyncIterator[Union[str, ResponseType]],

@@ -155,15 +209,20 @@ async def async_iter_response(
        await safe_aclose(response)

async def async_iter_append_model_and_provider(
    response: AsyncChatCompletionResponseType
    response: AsyncChatCompletionResponseType,
    last_model: str,
    last_provider: ProviderType
) -> AsyncChatCompletionResponseType:
    last_provider = None
    try:
        if isinstance(last_provider, BaseRetryProvider):
        if last_provider is not None:
            last_provider = last_provider.last_provider
        async for chunk in response:
            if isinstance(chunk, (ChatCompletion, ChatCompletionChunk)):
                last_provider = get_last_provider(True) if last_provider is None else last_provider
                chunk.model = last_provider.get("model")
                chunk.provider = last_provider.get("name")
                if last_provider is not None:
                    chunk.model = getattr(last_provider, "last_model", last_model)
                    chunk.provider = last_provider.__name__
            yield chunk
    finally:
        await safe_aclose(response)

@@ -215,7 +274,9 @@ class Completions:
            kwargs["images"] = [(image, image_name)]
        if ignore_stream:
            kwargs["ignore_stream"] = True
        response = provider.create_completion(

        response = iter_run_tools(
            provider.create_completion,
            model,
            messages,
            stream=stream,

@@ -237,7 +298,7 @@ class Completions:
            # If response is an async generator, collect it into a list
            response = asyncio.run(async_generator_to_list(response))
        response = iter_response(response, stream, response_format, max_tokens, stop)
        response = iter_append_model_and_provider(response)
        response = iter_append_model_and_provider(response, model, provider)
        if stream:
            return response
        else:

@@ -296,7 +357,7 @@ class Images:
        if proxy is None:
            proxy = self.client.proxy

        e = None
        error = None
        response = None
        if isinstance(provider_handler, IterListProvider):
            for provider in provider_handler.providers:

@@ -306,6 +367,7 @@ class Images:
                    provider_name = provider.__name__
                    break
                except Exception as e:
                    error = e
                    debug.log(f"Image provider {provider.__name__}: {e}")
        else:
            response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)

@@ -313,14 +375,14 @@ class Images:
        if isinstance(response, ImageResponse):
            return await self._process_image_response(
                response,
                response_format,
                proxy,
                model,
                provider_name
                provider_name,
                response_format,
                proxy
            )
        if response is None:
            if e is not None:
                raise e
            if error is not None:
                raise error
            raise NoImageResponseError(f"No image response from {provider_name}")
        raise NoImageResponseError(f"Unexpected response type: {type(response)}")

@@ -390,7 +452,7 @@ class Images:
        if image is not None:
            kwargs["images"] = [(image, None)]

        e = None
        error = None
        response = None
        if isinstance(provider_handler, IterListProvider):
            for provider in provider_handler.providers:

@@ -400,27 +462,27 @@ class Images:
                    provider_name = provider.__name__
                    break
                except Exception as e:
                    error = e
                    debug.log(f"Image provider {provider.__name__}: {e}")
        else:
            response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)

        if isinstance(response, ImageResponse):
            return await self._process_image_response(response, response_format, proxy, model, provider_name)
            return await self._process_image_response(response, model, provider_name, response_format, proxy)
        if response is None:
            if e is not None:
                raise e
            if error is not None:
                raise error
            raise NoImageResponseError(f"No image response from {provider_name}")
        raise NoImageResponseError(f"Unexpected response type: {type(response)}")

    async def _process_image_response(
        self,
        response: ImageResponse,
        model: str,
        provider: str,
        response_format: Optional[str] = None,
        proxy: str = None,
        model: Optional[str] = None,
        provider: Optional[str] = None
        proxy: str = None
    ) -> ImagesResponse:
        last_provider = get_last_provider(True)
        if response_format == "url":
            # Return original URLs without saving locally
            images = [Image.model_construct(url=image, revised_prompt=response.alt) for image in response.get_list()]

@@ -438,8 +500,8 @@ class Images:
        return ImagesResponse.model_construct(
            created=int(time.time()),
            data=images,
            model=last_provider.get("model") if model is None else model,
            provider=last_provider.get("name") if provider is None else provider
            model=model,
            provider=provider
        )

@@ -500,7 +562,8 @@ class AsyncCompletions:
            create_handler = provider.create_async_generator
        else:
            create_handler = provider.create_completion
        response = create_handler(
        response = async_iter_run_tools(
            create_handler,
            model,
            messages,
            stream=stream,

@@ -512,11 +575,8 @@ class AsyncCompletions:
            ),
            **kwargs
        )

        if not hasattr(response, "__aiter__"):
            response = to_async_iterator(response)
        response = async_iter_response(response, stream, response_format, max_tokens, stop)
        response = async_iter_append_model_and_provider(response)
        response = async_iter_append_model_and_provider(response, model, provider)
        return response if stream else anext(response)

class AsyncImages(Images):
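On the client side, `iter_run_tools` and `async_iter_run_tools` intercept a `search_tool` call, decode its arguments via `validate_arguments`, and rewrite the last user message with web-search results before the provider is invoked. A rough usage sketch with the synchronous client (model name is an assumption):

```python
from g4f.client import Client

client = Client()
completion = client.chat.completions.create(
    model="gpt-4o-mini",  # assumed model name
    messages=[{"role": "user", "content": "Summarize this week's AI news"}],
    tool_calls=[{
        "type": "function",
        "function": {
            "name": "search_tool",
            # arguments may also be a JSON string; validate_arguments decodes it
            "arguments": {"query": "AI news this week", "max_results": 5},
        },
    }],
)
print(completion.choices[0].message.content)
```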
@@ -1,10 +1,13 @@
from __future__ import annotations

from typing import Optional, List, Dict
from typing import Optional, List, Dict, Any
from time import time

from .helper import filter_none

ToolCalls = Optional[List[Dict[str, Any]]]
Usage = Optional[Dict[str, int]]

try:
    from pydantic import BaseModel, Field
except ImportError:

@@ -57,10 +60,11 @@ class ChatCompletionChunk(BaseModel):
class ChatCompletionMessage(BaseModel):
    role: str
    content: str
    tool_calls: ToolCalls

    @classmethod
    def model_construct(cls, content: str):
        return super().model_construct(role="assistant", content=content)
    def model_construct(cls, content: str, tool_calls: ToolCalls = None):
        return super().model_construct(role="assistant", content=content, **filter_none(tool_calls=tool_calls))

class ChatCompletionChoice(BaseModel):
    index: int

@@ -78,11 +82,11 @@ class ChatCompletion(BaseModel):
    model: str
    provider: Optional[str]
    choices: List[ChatCompletionChoice]
    usage: Dict[str, int] = Field(examples=[{
    usage: Usage = Field(default={
        "prompt_tokens": 0, #prompt_tokens,
        "completion_tokens": 0, #completion_tokens,
        "total_tokens": 0, #prompt_tokens + completion_tokens,
    }])
    })

    @classmethod
    def model_construct(

@@ -90,7 +94,9 @@ class ChatCompletion(BaseModel):
        content: str,
        finish_reason: str,
        completion_id: str = None,
        created: int = None
        created: int = None,
        tool_calls: ToolCalls = None,
        usage: Usage = None
    ):
        return super().model_construct(
            id=f"chatcmpl-{completion_id}" if completion_id else None,

@@ -99,14 +105,10 @@ class ChatCompletion(BaseModel):
            model=None,
            provider=None,
            choices=[ChatCompletionChoice.model_construct(
                ChatCompletionMessage.model_construct(content),
                finish_reason
                ChatCompletionMessage.model_construct(content, tool_calls),
                finish_reason,
            )],
            usage={
                "prompt_tokens": 0, #prompt_tokens,
                "completion_tokens": 0, #completion_tokens,
                "total_tokens": 0, #prompt_tokens + completion_tokens,
            }
            **filter_none(usage=usage)
        )

class ChatCompletionDelta(BaseModel):
@@ -37,14 +37,14 @@ class MissingAuthError(Exception):
class NoImageResponseError(Exception):
    ...

class RateLimitError(Exception):
    ...

class ResponseError(Exception):
    ...

class ResponseStatusError(Exception):
    ...

class RateLimitError(ResponseStatusError):
    ...

class NoValidHarFileError(Exception):
    ...
@@ -31,6 +31,7 @@ let controller_storage = {};
let content_storage = {};
let error_storage = {};
let synthesize_storage = {};
let title_storage = {};

messageInput.addEventListener("blur", () => {
    window.scrollTo(0, 0);

@@ -423,13 +424,13 @@ stop_generating.addEventListener("click", async () => {
    let key;
    for (key in controller_storage) {
        if (!controller_storage[key].signal.aborted) {
            controller_storage[key].abort();
            let message = message_storage[key];
            if (message) {
                content_storage[key].inner.innerHTML += " [aborted]";
                message_storage[key] += " [aborted]";
                console.log(`aborted ${window.conversation_id} #${key}`);
            }
            controller_storage[key].abort();
        }
    }
    await load_conversation(window.conversation_id, false);

@@ -491,7 +492,14 @@ const prepare_messages = (messages, message_index = -1) => {
async function add_message_chunk(message, message_id) {
    content_map = content_storage[message_id];
    if (message.type == "conversation") {
        console.info("Conversation used:", message.conversation)
        const conversation = await get_conversation(window.conversation_id);
        if (!conversation.data) {
            conversation.data = {};
        }
        for (const [key, value] of Object.entries(message.conversation)) {
            conversation.data[key] = value;
        }
        await save_conversation(conversation_id, conversation);
    } else if (message.type == "provider") {
        provider_storage[message_id] = message.provider;
        content_map.content.querySelector('.provider').innerHTML = `

@@ -503,6 +511,7 @@ async function add_message_chunk(message, message_id) {
    } else if (message.type == "message") {
        console.error(message.message)
    } else if (message.type == "error") {
        if (content_map.inner.dataset.timeout) clearTimeout(content_map.inner.dataset.timeout);
        error_storage[message_id] = message.error
        console.error(message.error);
        content_map.inner.innerHTML += markdown_render(`**An error occurred:** ${message.error}`);

@@ -512,6 +521,9 @@ async function add_message_chunk(message, message_id) {
    } else if (message.type == "preview") {
        if (content_map.inner.clientHeight > 200)
            content_map.inner.style.height = content_map.inner.clientHeight + "px";
        if (img = content_map.inner.querySelector("img"))
            if (!img.complete)
                return;
        content_map.inner.innerHTML = markdown_render(message.preview);
    } else if (message.type == "content") {
        message_storage[message_id] += message.content;

@@ -523,6 +535,10 @@ async function add_message_chunk(message, message_id) {
        log_storage.appendChild(p);
    } else if (message.type == "synthesize") {
        synthesize_storage[message_id] = message.synthesize;
    } else if (message.type == "title") {
        title_storage[message_id] = message.title;
    } else if (message.type == "login") {
        update_message(content_map, message_id, message.login);
    }
}

@@ -531,12 +547,12 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
        model = get_selected_model()?.value || null;
        provider = providerSelect.options[providerSelect.selectedIndex].value;
    }
    let messages = await get_messages(window.conversation_id);
    messages = prepare_messages(messages, message_index);
    let conversation = await get_conversation(window.conversation_id);
    messages = prepare_messages(conversation.items, message_index);
    message_storage[message_id] = "";
    stop_generating.classList.remove("stop_generating-hidden");

    if (message_index == -1) {
    if (message_index == -1 && !regenerate) {
        await scroll_to_bottom();
    }

@@ -579,7 +595,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
        inner: content_el.querySelector('.content_inner'),
        count: content_el.querySelector('.count'),
    }
    if (message_index == -1) {
    if (message_index == -1 && !regenerate) {
        await scroll_to_bottom();
    }
    try {

@@ -592,6 +608,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
        await api("conversation", {
            id: message_id,
            conversation_id: window.conversation_id,
            conversation: conversation.data && provider in conversation.data ? conversation.data[provider] : null,
            model: model,
            web_search: switchInput.checked,
            provider: provider,

@@ -618,18 +635,21 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
        }
    }
    delete controller_storage[message_id];
    if (!error_storage[message_id] && message_storage[message_id]) {
    if (message_storage[message_id]) {
        const message_provider = message_id in provider_storage ? provider_storage[message_id] : null;
        await add_message(
            window.conversation_id,
            "assistant",
            message_storage[message_id],
            message_storage[message_id] + (error_storage[message_id] ? " [error]" : ""),
            message_provider,
            message_index,
            synthesize_storage[message_id],
            regenerate
            regenerate,
            title_storage[message_id]
        );
        await safe_load_conversation(window.conversation_id, message_index == -1);
        if (!error_storage[message_id]) {
            await safe_load_conversation(window.conversation_id, message_index == -1);
        }
    }
    let cursorDiv = message_el.querySelector(".cursor");
    if (cursorDiv) cursorDiv.parentNode.removeChild(cursorDiv);

@@ -696,6 +716,7 @@ const show_option = async (conversation_id) => {
    const input_el = document.createElement("input");
    input_el.value = title_el.innerText;
    input_el.classList.add("convo-title");
    input_el.onclick = (e) => e.stopPropagation()
    input_el.onfocus = () => trash_el.style.display = "none";
    input_el.onchange = () => set_conversation_title(conversation_id, input_el.value);
    left_el.removeChild(title_el);

@@ -718,7 +739,6 @@ const hide_option = async (conversation_id) => {
    const span_el = document.createElement("span");
    span_el.innerText = input_el.value;
    span_el.classList.add("convo-title");
    span_el.onclick = () => set_conversation(conversation_id);
    left_el.removeChild(input_el);
    left_el.appendChild(span_el);
}

@@ -772,8 +792,11 @@ const load_conversation = async (conversation_id, scroll=true) => {
    if (!conversation) {
        return;
    }

    document.title = conversation.new_title ? `g4f - ${conversation.new_title}` : document.title;
    let title = conversation.title || conversation.new_title;
    title = title ? `${title} - g4f` : window.title;
    if (title) {
        document.title = title;
    }

    if (systemPrompt) {
        systemPrompt.value = conversation.system || "";

@@ -956,10 +979,21 @@ const add_message = async (
    provider = null,
    message_index = -1,
    synthesize_data = null,
    regenerate = false
    regenerate = false,
    title = null
) => {
    const conversation = await get_conversation(conversation_id);
    if (!conversation) return;
    if (!conversation) {
        return;
    }
    if (title) {
        conversation.title = title;
    } else if (!conversation.title) {
        let new_value = content.trim();
        let new_lenght = new_value.indexOf("\n");
        new_lenght = new_lenght > 200 || new_lenght < 0 ? 200 : new_lenght;
        conversation.title = new_value.substring(0, new_lenght);
    }
    const new_message = {
        role: role,
        content: content,

@@ -988,6 +1022,15 @@ const add_message = async (
    return conversation.items.length - 1;
};

const escapeHtml = (unsafe) => {
    return unsafe.replaceAll('&', '&amp;').replaceAll('<', '&lt;').replaceAll('>', '&gt;').replaceAll('"', '&quot;').replaceAll("'", '&#039;');
}

const toLocaleDateString = (date) => {
    date = new Date(date);
    return date.toLocaleString('en-GB', {dateStyle: 'short', timeStyle: 'short', monthStyle: 'short'}).replace("/" + date.getFullYear(), "");
}

const load_conversations = async () => {
    let conversations = [];
    for (let i = 0; i < appStorage.length; i++) {

@@ -998,32 +1041,14 @@ const load_conversations = async () => {
    }
    conversations.sort((a, b) => (b.updated||0)-(a.updated||0));

    await clear_conversations();

    let html = "";
    let html = [];
    conversations.forEach((conversation) => {
        if (conversation?.items.length > 0 && !conversation.new_title) {
            let new_value = (conversation.items[0]["content"]).trim();
            let new_lenght = new_value.indexOf("\n");
            new_lenght = new_lenght > 200 || new_lenght < 0 ? 200 : new_lenght;
            conversation.new_title = new_value.substring(0, new_lenght);
            appStorage.setItem(
                `conversation:${conversation.id}`,
                JSON.stringify(conversation)
            );
        }
        let updated = "";
        if (conversation.updated) {
            const date = new Date(conversation.updated);
            updated = date.toLocaleString('en-GB', {dateStyle: 'short', timeStyle: 'short', monthStyle: 'short'});
            updated = updated.replace("/" + date.getFullYear(), "")
        }
        html += `
        html.push(`
            <div class="convo" id="convo-${conversation.id}">
                <div class="left">
                <div class="left" onclick="set_conversation('${conversation.id}')">
                    <i class="fa-regular fa-comments"></i>
                    <span class="datetime" onclick="set_conversation('${conversation.id}')">${updated}</span>
                    <span class="convo-title" onclick="set_conversation('${conversation.id}')">${conversation.new_title}</span>
                    <span class="datetime">${conversation.updated ? toLocaleDateString(conversation.updated) : ""}</span>
                    <span class="convo-title">${escapeHtml(conversation.new_title ? conversation.new_title : conversation.title)}</span>
                </div>
                <i onclick="show_option('${conversation.id}')" class="fa-solid fa-ellipsis-vertical" id="conv-${conversation.id}"></i>
                <div id="cho-${conversation.id}" class="choise" style="display:none;">

@@ -1031,9 +1056,10 @@ const load_conversations = async () => {
                    <i onclick="hide_option('${conversation.id}')" class="fa-regular fa-x"></i>
                </div>
            </div>
        `;
        `);
    });
    box_conversations.innerHTML += html;
    await clear_conversations();
    box_conversations.innerHTML += html.join("");
};

const hide_input = document.querySelector(".toolbar .hide-input");

@@ -1211,9 +1237,15 @@ function count_words_and_tokens(text, model) {
    return `(${count_words(text)} words, ${count_chars(text)} chars, ${count_tokens(model, text)} tokens)`;
}

function update_message(content_map, message_id) {
function update_message(content_map, message_id, content = null) {
    content_map.inner.dataset.timeout = setTimeout(() => {
        html = markdown_render(message_storage[message_id]);
    let cleared = false;
    if (content_map.inner.dataset.timeout) {
        content_map.inner.dataset.timeout = clearTimeout(content_map.inner.dataset.timeout);
        cleared = true;
    }
    if (!content) content = message_storage[message_id];
    html = markdown_render(content);
    let lastElement, lastIndex = null;
    for (element of ['</p>', '</code></pre>', '</p>\n</li>\n</ol>', '</li>\n</ol>', '</li>\n</ul>']) {
        const index = html.lastIndexOf(element)

@@ -1237,7 +1269,9 @@ function update_message(content_map, message_id) {
        message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" });
    }
    if (content_map.inner.dataset.timeout) clearTimeout(content_map.inner.dataset.timeout);
    if (content_map.inner.dataset.timeout && !cleared){
        content_map.inner.dataset.timeout = clearTimeout(content_map.inner.dataset.timeout);
    }
    }, 100);
};
@@ -14,7 +14,7 @@ from g4f.image import ImagePreview, ImageResponse, copy_images, ensure_images_di
from g4f.Provider import ProviderType, __providers__, __map__
from g4f.providers.base_provider import ProviderModelMixin
from g4f.providers.retry_provider import IterListProvider
from g4f.providers.response import BaseConversation, FinishReason, SynthesizeData
from g4f.providers.response import BaseConversation, JsonConversation, FinishReason, SynthesizeData, TitleGeneration, RequestLogin
from g4f.client.service import convert_to_provider
from g4f import debug

@@ -97,15 +97,18 @@ class Api:
            kwargs['web_search'] = True
            do_web_search = False
        if do_web_search:
            from .internet import get_search_message
            from ...web_search import get_search_message
            messages[-1]["content"] = get_search_message(messages[-1]["content"])
        if json_data.get("auto_continue"):
            kwargs['auto_continue'] = True

        conversation_id = json_data.get("conversation_id")
        if conversation_id and provider:
            if provider in conversations and conversation_id in conversations[provider]:
                kwargs["conversation"] = conversations[provider][conversation_id]
        conversation = json_data.get("conversation")
        if conversation is not None:
            kwargs["conversation"] = JsonConversation(**conversation)
        else:
            conversation_id = json_data.get("conversation_id")
            if conversation_id and provider:
                if provider in conversations and conversation_id in conversations[provider]:
                    kwargs["conversation"] = conversations[provider][conversation_id]

        if json_data.get("ignored"):
            kwargs["ignored"] = json_data["ignored"]

@@ -146,7 +149,12 @@ class Api:
                    if provider not in conversations:
                        conversations[provider] = {}
                    conversations[provider][conversation_id] = chunk
                    yield self._format_json("conversation", conversation_id)
                    if isinstance(chunk, JsonConversation):
                        yield self._format_json("conversation", {
                            provider: chunk.to_dict()
                        })
                    else:
                        yield self._format_json("conversation_id", conversation_id)
                elif isinstance(chunk, Exception):
                    logger.exception(chunk)
                    yield self._format_json("message", get_error_message(chunk))

@@ -160,6 +168,10 @@ class Api:
                    yield self._format_json("content", str(images))
                elif isinstance(chunk, SynthesizeData):
                    yield self._format_json("synthesize", chunk.to_json())
                elif isinstance(chunk, TitleGeneration):
                    yield self._format_json("title", chunk.title)
                elif isinstance(chunk, RequestLogin):
                    yield self._format_json("login", str(chunk))
                elif not isinstance(chunk, FinishReason):
                    yield self._format_json("content", str(chunk))
                if debug.logs:
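The GUI backend now serializes conversation state per provider instead of caching it only server-side. A small sketch of the round trip; the attribute names are illustrative, not taken from the commit:

```python
from g4f.providers.response import JsonConversation

# State is exported with to_dict() and sent to the browser as the
# "conversation" message; on the next request it is rebuilt via **kwargs.
state = JsonConversation(conversation_id="abc123", message_id="def456")  # hypothetical fields
payload = state.to_dict()
restored = JsonConversation(**payload)
```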
@@ -1,160 +1,3 @@
from __future__ import annotations

from aiohttp import ClientSession, ClientTimeout
try:
    from duckduckgo_search import DDGS
    from bs4 import BeautifulSoup
    has_requirements = True
except ImportError:
    has_requirements = False
from ...errors import MissingRequirementsError
from ... import debug

import asyncio

class SearchResults():
    def __init__(self, results: list, used_words: int):
        self.results = results
        self.used_words = used_words

    def __iter__(self):
        yield from self.results

    def __str__(self):
        search = ""
        for idx, result in enumerate(self.results):
            if search:
                search += "\n\n\n"
            search += f"Title: {result.title}\n\n"
            if result.text:
                search += result.text
            else:
                search += result.snippet
            search += f"\n\nSource: [[{idx}]]({result.url})"
        return search

    def __len__(self) -> int:
        return len(self.results)

class SearchResultEntry():
    def __init__(self, title: str, url: str, snippet: str, text: str = None):
        self.title = title
        self.url = url
        self.snippet = snippet
        self.text = text

    def set_text(self, text: str):
        self.text = text

def scrape_text(html: str, max_words: int = None) -> str:
    soup = BeautifulSoup(html, "html.parser")
    for selector in [
            "main",
            ".main-content-wrapper",
            ".main-content",
            ".emt-container-inner",
            ".content-wrapper",
            "#content",
            "#mainContent",
        ]:
        select = soup.select_one(selector)
        if select:
            soup = select
            break
    # Zdnet
    for remove in [".c-globalDisclosure"]:
        select = soup.select_one(remove)
        if select:
            select.extract()
    clean_text = ""
    for paragraph in soup.select("p, h1, h2, h3, h4, h5, h6"):
        text = paragraph.get_text()
        for line in text.splitlines():
            words = []
            for word in line.replace("\t", " ").split(" "):
                if word:
                    words.append(word)
            count = len(words)
            if not count:
                continue
            if max_words:
                max_words -= count
                if max_words <= 0:
                    break
            if clean_text:
                clean_text += "\n"
            clean_text += " ".join(words)

    return clean_text

async def fetch_and_scrape(session: ClientSession, url: str, max_words: int = None) -> str:
    try:
        async with session.get(url) as response:
            if response.status == 200:
                html = await response.text()
                return scrape_text(html, max_words)
    except:
        return

async def search(query: str, n_results: int = 5, max_words: int = 2500, add_text: bool = True) -> SearchResults:
    if not has_requirements:
        raise MissingRequirementsError('Install "duckduckgo-search" and "beautifulsoup4" package | pip install -U g4f[search]')
    with DDGS() as ddgs:
        results = []
        for result in ddgs.text(
                query,
                region="wt-wt",
                safesearch="moderate",
                timelimit="y",
                max_results=n_results,
            ):
            results.append(SearchResultEntry(
                result["title"],
                result["href"],
                result["body"]
            ))

        if add_text:
            requests = []
            async with ClientSession(timeout=ClientTimeout(5)) as session:
                for entry in results:
                    requests.append(fetch_and_scrape(session, entry.url, int(max_words / (n_results - 1))))
                texts = await asyncio.gather(*requests)

        formatted_results = []
        used_words = 0
        left_words = max_words
        for i, entry in enumerate(results):
            if add_text:
                entry.text = texts[i]
            if left_words:
                left_words -= entry.title.count(" ") + 5
                if entry.text:
                    left_words -= entry.text.count(" ")
                else:
                    left_words -= entry.snippet.count(" ")
                if 0 > left_words:
                    break
            used_words = max_words - left_words
            formatted_results.append(entry)

        return SearchResults(formatted_results, used_words)

def get_search_message(prompt, n_results: int = 5, max_words: int = 2500) -> str:
    try:
        search_results = asyncio.run(search(prompt, n_results, max_words))
        message = f"""
{search_results}

Instruction: Using the provided web search results, to write a comprehensive reply to the user request.
Make sure to add the sources of cites using [[Number]](Url) notation after the reference. Example: [[0]](http://google.com)

User request:
{prompt}
"""
        debug.log(f"Web search: '{prompt.strip()[:50]}...' {search_results.used_words} Words")
        return message
    except Exception as e:
        debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
        return prompt
from ...web_search import SearchResults, search, get_search_message
g4f/image.py

@@ -2,11 +2,13 @@ from __future__ import annotations

import os
import re
import io
import time
import uuid
from io import BytesIO
import base64
import asyncio
from io import BytesIO
from pathlib import Path
from aiohttp import ClientSession, ClientError
try:
    from PIL.Image import open as open_image, new as new_image

@@ -17,7 +19,7 @@ except ImportError:

from .typing import ImageType, Union, Image, Optional, Cookies
from .errors import MissingRequirementsError
from .providers.response import ResponseType
from .providers.response import ImageResponse, ImagePreview
from .requests.aiohttp import get_connector
from . import debug

@@ -33,15 +35,6 @@ EXTENSIONS_MAP: dict[str, str] = {
# Define the directory for generated images
images_dir = "./generated_images"

def fix_url(url: str) -> str:
    """ replace ' ' by '+' (to be markdown compliant)"""
    return url.replace(" ","+")

def fix_title(title: str) -> str:
    if title:
        return title.replace("\n", "").replace('"', '')
    return ""

def to_image(image: ImageType, is_svg: bool = False) -> Image:
    """
    Converts the input image to a PIL Image object.

@@ -55,7 +48,7 @@ def to_image(image: ImageType, is_svg: bool = False) -> Image:
    if not has_requirements:
        raise MissingRequirementsError('Install "pillow" package for images')

    if isinstance(image, str):
    if isinstance(image, str) and image.startswith("data:"):
        is_data_uri_an_image(image)
        image = extract_data_uri(image)

@@ -203,47 +196,6 @@ def process_image(image: Image, new_width: int, new_height: int) -> Image:
        image = image.convert("RGB")
    return image

def to_base64_jpg(image: Image, compression_rate: float) -> str:
    """
    Converts the given image to a base64-encoded string.

    Args:
        image (Image.Image): The image to convert.
        compression_rate (float): The compression rate (0.0 to 1.0).

    Returns:
        str: The base64-encoded image.
    """
    output_buffer = BytesIO()
    image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100))
    return base64.b64encode(output_buffer.getvalue()).decode()

def format_images_markdown(images: Union[str, list], alt: str, preview: Union[str, list] = None) -> str:
    """
    Formats the given images as a markdown string.

    Args:
        images: The images to format.
        alt (str): The alt text for the images.
        preview (str, optional): The preview URL format. Defaults to "{image}?w=200&h=200".

    Returns:
        str: The formatted markdown string.
    """
    if isinstance(images, list) and len(images) == 1:
        images = images[0]
    if isinstance(images, str):
        result = f"[![{fix_title(alt)}]({fix_url(preview.replace('{image}', images) if preview else images)})]({fix_url(images)})"
    else:
        if not isinstance(preview, list):
            preview = [preview.replace('{image}', image) if preview else image for image in images]
        result = "\n".join(
            f"[![#{idx+1} {fix_title(alt)}]({fix_url(preview[idx])})]({fix_url(image)})"
            for idx, image in enumerate(images)
        )
    start_flag = "<!-- generated images start -->\n"
    end_flag = "<!-- generated images end -->\n"
    return f"\n{start_flag}{result}\n{end_flag}\n"

def to_bytes(image: ImageType) -> bytes:
    """

@@ -257,7 +209,7 @@ def to_bytes(image: ImageType) -> bytes:
    """
    if isinstance(image, bytes):
        return image
    elif isinstance(image, str):
    elif isinstance(image, str) and image.startswith("data:"):
        is_data_uri_an_image(image)
        return extract_data_uri(image)
    elif isinstance(image, Image):

@@ -265,8 +217,15 @@ def to_bytes(image: ImageType) -> bytes:
        image.save(bytes_io, image.format)
        image.seek(0)
        return bytes_io.getvalue()
    elif isinstance(image, (str, os.PathLike)):
        return Path(image).read_bytes()
    elif isinstance(image, Path):
        return image.read_bytes()
    else:
        image.seek(0)
        try:
            image.seek(0)
        except (AttributeError, io.UnsupportedOperation):
            pass
        return image.read()

def to_data_uri(image: ImageType) -> str:

@@ -314,33 +273,6 @@ async def copy_images(

    return await asyncio.gather(*[copy_image(image) for image in images])

class ImageResponse(ResponseType):
    def __init__(
        self,
        images: Union[str, list],
        alt: str,
        options: dict = {}
    ):
        self.images = images
        self.alt = alt
        self.options = options

    def __str__(self) -> str:
        return format_images_markdown(self.images, self.alt, self.get("preview"))

    def get(self, key: str):
        return self.options.get(key)

    def get_list(self) -> list[str]:
        return [self.images] if isinstance(self.images, str) else self.images

class ImagePreview(ImageResponse):
    def __str__(self):
        return ""

    def to_string(self):
        return super().__str__()

class ImageDataResponse():
    def __init__(
        self,
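With the `to_bytes` change above, an image argument can now be a `Path` or any `os.PathLike` in addition to bytes, data URIs, and file-like objects. A sketch of uploading an image by path; file name and model are placeholders:

```python
from pathlib import Path
import g4f
from g4f.Provider import OpenaiChat

# to_bytes() reads the file from disk, so a Path can be passed directly.
response = g4f.ChatCompletion.create(
    model="gpt-4o",  # assumed model name
    provider=OpenaiChat,
    messages=[{"role": "user", "content": "Describe this image"}],
    image=Path("generated_images/example.jpg"),  # hypothetical local file
)
print(response)
```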
@@ -120,6 +120,12 @@ gpt_4o_mini = Model(
)

# o1
o1 = Model(
    name = 'o1',
    base_provider = 'OpenAI',
    best_provider = OpenaiAccount
)

o1_preview = Model(
    name = 'o1-preview',
    base_provider = 'OpenAI',

@@ -655,6 +661,7 @@ class ModelUtils:
        gpt_4o_mini.name: gpt_4o_mini,

        # o1
        o1.name: o1,
        o1_preview.name: o1_preview,
        o1_mini.name: o1_mini,

@@ -828,11 +835,12 @@ class ModelUtils:
__models__ = {
    model.name: (model, providers)
    for model, providers in [
        (model, model.best_provider.providers
        (model, [provider for provider in model.best_provider.providers if provider.working]
            if isinstance(model.best_provider, IterListProvider)
            else [model.best_provider]
            if model.best_provider is not None
            if model.best_provider is not None and model.best_provider.working
            else [])
        for model in ModelUtils.convert.values()]
    if providers
}
_all_models = list(__models__.keys())
g4f/providers/base_provider.py

@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import sys
 import asyncio
 
 from asyncio import AbstractEventLoop
@@ -11,20 +10,11 @@ from inspect import signature, Parameter
 from ..typing import CreateResult, AsyncResult, Messages
 from .types import BaseProvider
 from .asyncio import get_running_loop, to_sync_generator
-from .response import FinishReason, BaseConversation, SynthesizeData
+from .response import BaseConversation
+from .helper import concat_chunks, async_concat_chunks
 from ..errors import ModelNotSupportedError
 from .. import debug
-
-# Set Windows event loop policy for better compatibility with asyncio and curl_cffi
-if sys.platform == 'win32':
-    try:
-        from curl_cffi import aio
-        if not hasattr(aio, "_get_selector"):
-            if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy):
-                asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
-    except ImportError:
-        pass
 
 class AbstractProvider(BaseProvider):
     """
     Abstract class for providing asynchronous functionality to derived classes.
@@ -36,6 +26,7 @@ class AbstractProvider(BaseProvider):
         model: str,
         messages: Messages,
         *,
+        timeout: int = None,
         loop: AbstractEventLoop = None,
         executor: ThreadPoolExecutor = None,
         **kwargs
@@ -57,13 +48,11 @@ class AbstractProvider(BaseProvider):
         loop = loop or asyncio.get_running_loop()
 
         def create_func() -> str:
-            chunks = [str(chunk) for chunk in cls.create_completion(model, messages, False, **kwargs) if chunk]
-            if chunks:
-                return "".join(chunks)
+            return concat_chunks(cls.create_completion(model, messages, False, **kwargs))
 
         return await asyncio.wait_for(
             loop.run_in_executor(executor, create_func),
-            timeout=kwargs.get("timeout")
+            timeout=timeout
        )
 
     @classmethod
@@ -205,10 +194,7 @@ class AsyncGeneratorProvider(AsyncProvider):
         Returns:
             str: The created result as a string.
         """
-        return "".join([
-            str(chunk) async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
-            if chunk and not isinstance(chunk, (Exception, FinishReason, BaseConversation, SynthesizeData))
-        ])
+        return await async_concat_chunks(cls.create_async_generator(model, messages, stream=False, **kwargs))
 
     @staticmethod
     @abstractmethod
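The net effect for callers is that `timeout` is now an explicit keyword argument of `create_async`, and chunk handling is delegated to the shared helpers. A minimal sketch with a toy provider (the provider class is hypothetical, for illustration only):

```python
import asyncio

from g4f.providers.base_provider import AbstractProvider


class EchoProvider(AbstractProvider):  # hypothetical provider
    working = True

    @classmethod
    def create_completion(cls, model, messages, stream, **kwargs):
        yield "Hello "
        yield "world"


async def main():
    # `timeout` is now a named parameter instead of being fished out of **kwargs.
    result = await EchoProvider.create_async(
        "demo-model",
        [{"role": "user", "content": "hi"}],
        timeout=10,
    )
    print(result)  # "Hello world"

asyncio.run(main())
```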
g4f/providers/helper.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 import random
 import string
 
-from ..typing import Messages, Cookies
+from ..typing import Messages, Cookies, AsyncIterator, Iterator
 from .. import debug
 
 def format_prompt(messages: Messages, add_special_tokens=False) -> str:
@@ -73,5 +73,14 @@ def filter_none(**kwargs) -> dict:
         if value is not None
     }
 
+async def async_concat_chunks(chunks: AsyncIterator) -> str:
+    return concat_chunks([chunk async for chunk in chunks])
+
+def concat_chunks(chunks: Iterator) -> str:
+    return "".join([
+        str(chunk) for chunk in chunks
+        if chunk and not isinstance(chunk, Exception)
+    ])
+
 def format_cookies(cookies: Cookies) -> str:
     return "; ".join([f"{k}={v}" for k, v in cookies.items()])
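These two helpers centralize the chunk-joining logic that was previously duplicated in the provider base classes: falsy chunks and exceptions are skipped, everything else is stringified. For example:

```python
import asyncio

from g4f.providers.helper import async_concat_chunks, concat_chunks
from g4f.providers.response import FinishReason

chunks = ["Hello", " ", "world", FinishReason("stop"), ValueError("skipped")]
# The exception is filtered out; FinishReason stringifies to "".
print(concat_chunks(chunks))  # "Hello world"

async def stream():
    for chunk in ["a", "b", "c"]:
        yield chunk

print(asyncio.run(async_concat_chunks(stream())))  # "abc"
```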
g4f/providers/response.py

@@ -1,12 +1,88 @@
 from __future__ import annotations
 
+import re
+from typing import Union
 from abc import abstractmethod
+from urllib.parse import quote_plus, unquote_plus
 
+def quote_url(url: str) -> str:
+    url = unquote_plus(url)
+    url = url.split("//", maxsplit=1)
+    # If there is no "//" in the URL, then it is a relative URL
+    if len(url) == 1:
+        return quote_plus(url[0], '/?&=#')
+    url[1] = url[1].split("/", maxsplit=1)
+    # If there is no "/" after the domain, then it is a domain URL
+    if len(url[1]) == 1:
+        return url[0] + "//" + url[1][0]
+    return url[0] + "//" + url[1][0] + "/" + quote_plus(url[1][1], '/?&=#')
+
+def quote_title(title: str) -> str:
+    if title:
+        return title.replace("\n", "").replace('"', '')
+    return ""
+
+def format_link(url: str, title: str = None) -> str:
+    if title is None:
+        title = unquote_plus(url.split("//", maxsplit=1)[1].split("?")[0].replace("www.", ""))
+    return f"[{quote_title(title)}]({quote_url(url)})"
+
+def format_image(image: str, alt: str, preview: str = None) -> str:
+    """
+    Formats the given image as a markdown string.
+
+    Args:
+        image: The image to format.
+        alt (str): The alt for the image.
+        preview (str, optional): The preview URL format. Defaults to "{image}?w=200&h=200".
+
+    Returns:
+        str: The formatted markdown string.
+    """
+    return f"[![{alt}]({quote_url(preview.replace('{image}', image) if preview else image)})]({quote_url(image)})"
+
+def format_images_markdown(images: Union[str, list], alt: str, preview: Union[str, list] = None) -> str:
+    """
+    Formats the given images as a markdown string.
+
+    Args:
+        images: The images to format.
+        alt (str): The alt for the images.
+        preview (str, optional): The preview URL format. Defaults to "{image}?w=200&h=200".
+
+    Returns:
+        str: The formatted markdown string.
+    """
+    if isinstance(images, list) and len(images) == 1:
+        images = images[0]
+    if isinstance(images, str):
+        result = format_image(images, alt, preview)
+    else:
+        result = "\n".join(
+            format_image(image, f"#{idx+1} {alt}", preview[idx] if isinstance(preview, list) else preview)
+            for idx, image in enumerate(images)
+        )
+    start_flag = "<!-- generated images start -->\n"
+    end_flag = "<!-- generated images end -->\n"
+    return f"\n{start_flag}{result}\n{end_flag}\n"
 
 class ResponseType:
     @abstractmethod
     def __str__(self) -> str:
         pass
 
+class JsonMixin:
+    def __init__(self, **kwargs) -> None:
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+
+    def get_dict(self):
+        return {
+            key: value
+            for key, value in self.__dict__.items()
+            if not key.startswith("__")
+        }
+
 class FinishReason():
     def __init__(self, reason: str):
         self.reason = reason
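The link helpers normalize URLs before rendering them as markdown. A small sketch of the expected behavior, based on the code above:

```python
from g4f.providers.response import format_link, quote_url

# Percent-encodes the path and query but leaves scheme and domain intact.
print(quote_url("https://example.com/a b?q=1"))
# https://example.com/a+b?q=1

# Derives a readable title from the URL when none is given.
print(format_link("https://www.example.com/article?utm=x"))
# [example.com/article](https://www.example.com/article?utm=x)
```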
@@ -14,26 +90,92 @@ class FinishReason():
     def __str__(self) -> str:
         return ""
 
-class Sources(ResponseType):
-    def __init__(self, sources: list[dict[str, str]]) -> None:
-        self.list = sources
-
-    def __str__(self) -> str:
-        return "\n\n" + ("\n".join([f"{idx+1}. [{link['title']}]({link['url']})" for idx, link in enumerate(self.list)]))
+class ToolCalls(ResponseType):
+    def __init__(self, list: list):
+        self.list = list
+
+    def __str__(self) -> str:
+        return ""
+
+    def get_list(self) -> list:
+        return self.list
+
+class Usage(ResponseType, JsonMixin):
+    def __str__(self) -> str:
+        return ""
+
+class TitleGeneration(ResponseType):
+    def __init__(self, title: str) -> None:
+        self.title = title
+
+    def __str__(self) -> str:
+        return ""
+
+class Sources(ResponseType):
+    def __init__(self, sources: list[dict[str, str]]) -> None:
+        self.list = []
+        for source in sources:
+            self.add_source(source)
+
+    def add_source(self, source: dict[str, str]):
+        url = source.get("url", source.get("link", None))
+        if url is not None:
+            url = re.sub(r"[&?]utm_source=.+", "", url)
+            source["url"] = url
+            self.list.append(source)
+
+    def __str__(self) -> str:
+        return "\n\n" + ("\n".join([
+            f"{idx+1}. {format_link(link['url'], link.get('title', None))}"
+            for idx, link in enumerate(self.list)
+        ]))
 
 class BaseConversation(ResponseType):
     def __str__(self) -> str:
         return ""
 
-class SynthesizeData(ResponseType):
+class JsonConversation(BaseConversation, JsonMixin):
+    pass
+
+class SynthesizeData(ResponseType, JsonMixin):
     def __init__(self, provider: str, data: dict):
         self.provider = provider
         self.data = data
 
-    def to_json(self) -> dict:
-        return {
-            **self.__dict__
-        }
     def __str__(self) -> str:
         return ""
 
+class RequestLogin(ResponseType):
+    def __init__(self, label: str, login_url: str) -> None:
+        self.label = label
+        self.login_url = login_url
+
+    def __str__(self) -> str:
+        return format_link(self.login_url, f"[Login to {self.label}]") + "\n\n"
+
+class ImageResponse(ResponseType):
+    def __init__(
+        self,
+        images: Union[str, list],
+        alt: str,
+        options: dict = {}
+    ):
+        self.images = images
+        self.alt = alt
+        self.options = options
+
+    def __str__(self) -> str:
+        return format_images_markdown(self.images, self.alt, self.get("preview"))
+
+    def get(self, key: str):
+        return self.options.get(key)
+
+    def get_list(self) -> list[str]:
+        return [self.images] if isinstance(self.images, str) else self.images
+
+class ImagePreview(ImageResponse):
+    def __str__(self):
+        return ""
+
+    def to_string(self):
+        return super().__str__()
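Taken together, these response types let providers stream structured events alongside plain text. A short sketch of the two most visible additions; the `JsonConversation` field names are hypothetical, since the mixin accepts arbitrary keyword attributes:

```python
from g4f.providers.response import JsonConversation, Sources

sources = Sources([
    {"url": "https://example.com/article?utm_source=newsletter", "title": "Example Article"},
    {"link": "https://example.com/other"},
])
# utm_source tracking is stripped; entries render as a numbered markdown list.
print(str(sources))

# JsonConversation carries arbitrary state and exports it for the JS frontend.
conversation = JsonConversation(conversation_id="abc123", parent_id=None)  # hypothetical fields
print(conversation.get_dict())  # {'conversation_id': 'abc123', 'parent_id': None}
```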
g4f/requests/raise_for_status.py

@@ -21,21 +21,23 @@ def is_openai(text: str) -> bool:
     return "<p>Unable to load site</p>" in text or 'id="challenge-error-text"' in text
 
 async def raise_for_status_async(response: Union[StreamResponse, ClientResponse], message: str = None):
-    if response.status in (429, 402):
-        raise RateLimitError(f"Response {response.status}: Rate limit reached")
     if response.ok:
         return
     text = await response.text()
+    if message is None:
+        message = "HTML content" if response.headers.get("content-type").startswith("text/html") else text
+    if message == "HTML content":
+        if response.status == 520:
+            message = "Unknown error (Cloudflare)"
+        elif response.status in (429, 402):
+            message = "Rate limit"
     if response.status == 403 and is_cloudflare(text):
         raise CloudflareError(f"Response {response.status}: Cloudflare detected")
     elif response.status == 403 and is_openai(text):
         raise ResponseStatusError(f"Response {response.status}: OpenAI Bot detected")
     elif response.status == 502:
         raise ResponseStatusError(f"Response {response.status}: Bad gateway")
-    elif message is not None:
-        raise ResponseStatusError(f"Response {response.status}: {message}")
     else:
-        message = "HTML content" if response.headers.get("content-type").startswith("text/html") else text
         raise ResponseStatusError(f"Response {response.status}: {message}")
 
 def raise_for_status(response: Union[Response, StreamResponse, ClientResponse, RequestsResponse], message: str = None):
@@ -43,16 +45,19 @@ def raise_for_status(response: Union[Response, StreamResponse, ClientResponse, RequestsResponse], message: str = None):
         return raise_for_status_async(response, message)
     if response.ok:
         return
-    if response.status_code in (429, 402):
-        raise RateLimitError(f"Response {response.status_code}: Rate limit reached")
-    elif response.status_code == 403 and is_cloudflare(response.text):
-        raise CloudflareError(f"Response {response.status_code}: Cloudflare detected")
-    elif response.status == 403 and is_openai(response.text):
-        raise ResponseStatusError(f"Response {response.status}: OpenAI Bot detected")
-    elif message is not None:
-        raise ResponseStatusError(f"Response {response.status}: {message}")
-    elif response.status_code == 502:
-        raise ResponseStatusError(f"Response {response.status}: Bad gateway")
-    else:
-        message = "HTML content" if response.headers.get("content-type").startswith("text/html") else response.text
+    if message is None:
+        message = "HTML content" if response.headers.get("content-type").startswith("text/html") else response.text
+    if message == "HTML content":
+        if response.status_code == 520:
+            message = "Unknown error (Cloudflare)"
+        elif response.status_code in (429, 402):
+            message = "Rate limit"
+            raise RateLimitError(f"Response {response.status_code}: {message}")
+    if response.status_code == 403 and is_cloudflare(response.text):
+        raise CloudflareError(f"Response {response.status_code}: Cloudflare detected")
+    elif response.status_code == 403 and is_openai(response.text):
+        raise ResponseStatusError(f"Response {response.status_code}: OpenAI Bot detected")
+    elif response.status_code == 502:
+        raise ResponseStatusError(f"Response {response.status_code}: Bad gateway")
+    else:
         raise ResponseStatusError(f"Response {response.status_code}: {message}")
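Both variants now build the error message up front, so callers get a specific exception type instead of a generic one. Usage stays a one-liner; a sketch assuming the `g4f.requests.raise_for_status` module path used in this repository:

```python
import requests

from g4f.errors import RateLimitError, ResponseStatusError
from g4f.requests.raise_for_status import raise_for_status

response = requests.get("https://example.com/")
try:
    raise_for_status(response)
except RateLimitError as error:
    print("Rate limited:", error)
except ResponseStatusError as error:
    print("Request failed:", error)
```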
g4f/typing.py

@@ -1,4 +1,5 @@
 import sys
+import os
 from typing import Any, AsyncGenerator, Generator, AsyncIterator, Iterator, NewType, Tuple, Union, List, Dict, Type, IO, Optional
 
 try:
@@ -19,7 +20,7 @@ CreateResult = Iterator[Union[str, ResponseType]]
 AsyncResult = AsyncIterator[Union[str, ResponseType]]
 Messages = List[Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]]
 Cookies = Dict[str, str]
-ImageType = Union[str, bytes, IO, Image]
+ImageType = Union[str, bytes, IO, Image, os.PathLike]
 ImagesType = List[Tuple[ImageType, Optional[str]]]
 
 __all__ = [
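Since `ImageType` now includes `os.PathLike`, typed call sites accept `pathlib.Path` values without casts. A tiny sketch; the `load_image` helper is hypothetical and only mirrors the dispatch in `to_bytes`:

```python
from pathlib import Path

from g4f.typing import ImageType

def load_image(image: ImageType) -> bytes:  # hypothetical helper
    # Paths (and any other os.PathLike) are now valid ImageType values.
    if isinstance(image, (str, Path)):
        return Path(image).read_bytes()
    return image if isinstance(image, bytes) else image.read()

print(len(load_image(Path(__file__))))  # reads this script as a stand-in file
```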
g4f/web_search.py (new file, 172 lines)

@@ -0,0 +1,172 @@
from __future__ import annotations

from aiohttp import ClientSession, ClientTimeout, ClientError
try:
    from duckduckgo_search import DDGS
    from duckduckgo_search.exceptions import DuckDuckGoSearchException
    from bs4 import BeautifulSoup
    has_requirements = True
except ImportError:
    has_requirements = False
from .errors import MissingRequirementsError
from . import debug

import asyncio

DEFAULT_INSTRUCTIONS = """
Using the provided web search results, to write a comprehensive reply to the user request.
Make sure to add the sources of cites using [[Number]](Url) notation after the reference. Example: [[0]](http://google.com)
"""

class SearchResults():
    def __init__(self, results: list, used_words: int):
        self.results = results
        self.used_words = used_words

    def __iter__(self):
        yield from self.results

    def __str__(self):
        search = ""
        for idx, result in enumerate(self.results):
            if search:
                search += "\n\n\n"
            search += f"Title: {result.title}\n\n"
            if result.text:
                search += result.text
            else:
                search += result.snippet
            search += f"\n\nSource: [[{idx}]]({result.url})"
        return search

    def __len__(self) -> int:
        return len(self.results)

class SearchResultEntry():
    def __init__(self, title: str, url: str, snippet: str, text: str = None):
        self.title = title
        self.url = url
        self.snippet = snippet
        self.text = text

    def set_text(self, text: str):
        self.text = text

def scrape_text(html: str, max_words: int = None) -> str:
    soup = BeautifulSoup(html, "html.parser")
    for selector in [
            "main",
            ".main-content-wrapper",
            ".main-content",
            ".emt-container-inner",
            ".content-wrapper",
            "#content",
            "#mainContent",
        ]:
        select = soup.select_one(selector)
        if select:
            soup = select
            break
    # Zdnet
    for remove in [".c-globalDisclosure"]:
        select = soup.select_one(remove)
        if select:
            select.extract()
    clean_text = ""
    for paragraph in soup.select("p, h1, h2, h3, h4, h5, h6"):
        text = paragraph.get_text()
        for line in text.splitlines():
            words = []
            for word in line.replace("\t", " ").split(" "):
                if word:
                    words.append(word)
            count = len(words)
            if not count:
                continue
            if max_words:
                max_words -= count
                if max_words <= 0:
                    break
            if clean_text:
                clean_text += "\n"
            clean_text += " ".join(words)

    return clean_text

async def fetch_and_scrape(session: ClientSession, url: str, max_words: int = None) -> str:
    try:
        async with session.get(url) as response:
            if response.status == 200:
                html = await response.text()
                return scrape_text(html, max_words)
    except ClientError:
        return

async def search(query: str, max_results: int = 5, max_words: int = 2500, backend: str = "api", add_text: bool = True, timeout: int = 5, region: str = "wt-wt") -> SearchResults:
    if not has_requirements:
        raise MissingRequirementsError('Install "duckduckgo-search" and "beautifulsoup4" package | pip install -U g4f[search]')
    with DDGS() as ddgs:
        results = []
        for result in ddgs.text(
                query,
                region=region,
                safesearch="moderate",
                timelimit="y",
                max_results=max_results,
                backend=backend,
            ):
            results.append(SearchResultEntry(
                result["title"],
                result["href"],
                result["body"]
            ))

        if add_text:
            requests = []
            async with ClientSession(timeout=ClientTimeout(timeout)) as session:
                for entry in results:
                    requests.append(fetch_and_scrape(session, entry.url, int(max_words / (max_results - 1))))
                texts = await asyncio.gather(*requests)

        formatted_results = []
        used_words = 0
        left_words = max_words
        for i, entry in enumerate(results):
            if add_text:
                entry.text = texts[i]
            if left_words:
                left_words -= entry.title.count(" ") + 5
                if entry.text:
                    left_words -= entry.text.count(" ")
                else:
                    left_words -= entry.snippet.count(" ")
                if 0 > left_words:
                    break
            used_words = max_words - left_words
            formatted_results.append(entry)

        return SearchResults(formatted_results, used_words)

async def do_search(prompt: str, query: str = None, instructions: str = DEFAULT_INSTRUCTIONS, **kwargs) -> str:
    if query is None:
        query = prompt
    search_results = await search(query, **kwargs)
    new_prompt = f"""
{search_results}

Instruction: {instructions}

User request:
{prompt}
"""
    debug.log(f"Web search: '{query.strip()[:50]}...' {len(search_results.results)} Results {search_results.used_words} Words")
    return new_prompt

def get_search_message(prompt: str, raise_search_exceptions=False, **kwargs) -> str:
    try:
        return asyncio.run(do_search(prompt, **kwargs))
    except (DuckDuckGoSearchException, MissingRequirementsError) as e:
        if raise_search_exceptions:
            raise e
        debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
        return prompt
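End to end, the new module turns a plain prompt into a search-augmented prompt for the provider. A minimal sketch of the synchronous entry point (requires the optional dependencies and network access):

```python
# pip install -U g4f[search]
from g4f.web_search import get_search_message

prompt = "What is the latest stable Python release?"
# Falls back to returning the original prompt if the search fails
# or the optional dependencies are missing.
enriched = get_search_message(prompt, max_results=3, max_words=1000)
print(enriched[:500])
```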