mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-06 02:30:41 -08:00
Fix api streaming, fix AsyncClient (#2357)
* Fix api streaming, fix AsyncClient, Improve Client class, Some providers fixes, Update models list, Fix some tests, Update model list in Airforce provid er, Add OpenAi image generation url to api, Fix reload and debug in api arguments, Fix websearch in gui * Fix Cloadflare and Pi and AmigoChat provider * Fix conversation support in DDG provider, Add cloudflare bypass with nodriver * Fix unittests without curl_cffi
This commit is contained in:
parent
bc79969e5c
commit
6ce493d4df
34 changed files with 1161 additions and 1132 deletions
|
|
@ -6,14 +6,19 @@ body = {
|
||||||
"provider": "",
|
"provider": "",
|
||||||
"stream": True,
|
"stream": True,
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "assistant", "content": "What can you do? Who are you?"}
|
{"role": "user", "content": "What can you do? Who are you?"}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
lines = requests.post(url, json=body, stream=True).iter_lines()
|
response = requests.post(url, json=body, stream=True)
|
||||||
for line in lines:
|
response.raise_for_status()
|
||||||
|
for line in response.iter_lines():
|
||||||
if line.startswith(b"data: "):
|
if line.startswith(b"data: "):
|
||||||
try:
|
try:
|
||||||
print(json.loads(line[6:]).get("choices", [{"delta": {}}])[0]["delta"].get("content", ""), end="")
|
json_data = json.loads(line[6:])
|
||||||
|
if json_data.get("error"):
|
||||||
|
print(json_data)
|
||||||
|
break
|
||||||
|
print(json_data.get("choices", [{"delta": {}}])[0]["delta"].get("content", ""), end="")
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
pass
|
pass
|
||||||
print()
|
print()
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
import requests
|
import requests
|
||||||
url = "http://localhost:1337/v1/images/generations"
|
url = "http://localhost:1337/v1/images/generations"
|
||||||
body = {
|
body = {
|
||||||
"prompt": "heaven for dogs",
|
"model": "dall-e",
|
||||||
"provider": "OpenaiAccount",
|
"prompt": "hello world user",
|
||||||
"response_format": "b64_json",
|
#"response_format": "b64_json",
|
||||||
}
|
}
|
||||||
data = requests.post(url, json=body, stream=True).json()
|
data = requests.post(url, json=body, stream=True).json()
|
||||||
print(data)
|
print(data)
|
||||||
|
|
@ -219,9 +219,6 @@ def main():
|
||||||
if not pull:
|
if not pull:
|
||||||
print(f"No PR number found")
|
print(f"No PR number found")
|
||||||
exit()
|
exit()
|
||||||
if pull.get_reviews().totalCount > 0 or pull.get_issue_comments().totalCount > 0:
|
|
||||||
print(f"Has already a review")
|
|
||||||
exit()
|
|
||||||
diff = get_diff(pull.diff_url)
|
diff = get_diff(pull.diff_url)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error get details: {e.__class__.__name__}: {e}")
|
print(f"Error get details: {e.__class__.__name__}: {e}")
|
||||||
|
|
@ -231,6 +228,9 @@ def main():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error create review: {e}")
|
print(f"Error create review: {e}")
|
||||||
exit(1)
|
exit(1)
|
||||||
|
if pull.get_reviews().totalCount > 0 or pull.get_issue_comments().totalCount > 0:
|
||||||
|
pull.create_issue_comment(body=review)
|
||||||
|
return
|
||||||
try:
|
try:
|
||||||
comments = analyze_code(pull, diff)
|
comments = analyze_code(pull, diff)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from .asyncio import *
|
from .asyncio import *
|
||||||
from .backend import *
|
from .backend import *
|
||||||
from .main import *
|
from .main import *
|
||||||
from .model import *
|
from .model import *
|
||||||
from .client import *
|
from .client import *
|
||||||
from .client import *
|
|
||||||
from .include import *
|
from .include import *
|
||||||
from .integration import *
|
from .integration import *
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,19 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
import asyncio
|
import asyncio
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
from .mocks import ProviderMock
|
|
||||||
import g4f
|
|
||||||
from g4f.errors import MissingRequirementsError
|
from g4f.errors import MissingRequirementsError
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from g4f.gui.server.backend import Backend_Api, get_error_message
|
from g4f.gui.server.backend import Backend_Api
|
||||||
has_requirements = True
|
has_requirements = True
|
||||||
except:
|
except:
|
||||||
has_requirements = False
|
has_requirements = False
|
||||||
|
try:
|
||||||
|
from duckduckgo_search.exceptions import DuckDuckGoSearchException
|
||||||
|
except ImportError:
|
||||||
|
class DuckDuckGoSearchException:
|
||||||
|
pass
|
||||||
|
|
||||||
class TestBackendApi(unittest.TestCase):
|
class TestBackendApi(unittest.TestCase):
|
||||||
|
|
||||||
|
|
@ -31,28 +35,15 @@ class TestBackendApi(unittest.TestCase):
|
||||||
|
|
||||||
def test_get_providers(self):
|
def test_get_providers(self):
|
||||||
response = self.api.get_providers()
|
response = self.api.get_providers()
|
||||||
self.assertIsInstance(response, list)
|
self.assertIsInstance(response, dict)
|
||||||
self.assertTrue(len(response) > 0)
|
self.assertTrue(len(response) > 0)
|
||||||
|
|
||||||
def test_search(self):
|
def test_search(self):
|
||||||
from g4f.gui.server.internet import search
|
from g4f.gui.server.internet import search
|
||||||
try:
|
try:
|
||||||
result = asyncio.run(search("Hello"))
|
result = asyncio.run(search("Hello"))
|
||||||
|
except DuckDuckGoSearchException as e:
|
||||||
|
self.skipTest(e)
|
||||||
except MissingRequirementsError:
|
except MissingRequirementsError:
|
||||||
self.skipTest("search is not installed")
|
self.skipTest("search is not installed")
|
||||||
self.assertEqual(5, len(result))
|
self.assertEqual(5, len(result))
|
||||||
|
|
||||||
class TestUtilityFunctions(unittest.TestCase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
if not has_requirements:
|
|
||||||
self.skipTest("gui is not installed")
|
|
||||||
|
|
||||||
def test_get_error_message(self):
|
|
||||||
g4f.debug.last_provider = ProviderMock
|
|
||||||
exception = Exception("Message")
|
|
||||||
result = get_error_message(exception)
|
|
||||||
self.assertEqual("ProviderMock: Exception: Message", result)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
unittest.main()
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from g4f.client import Client, ChatCompletion, ChatCompletionChunk
|
from g4f.client import Client, AsyncClient, ChatCompletion, ChatCompletionChunk
|
||||||
from .mocks import AsyncGeneratorProviderMock, ModelProviderMock, YieldProviderMock
|
from .mocks import AsyncGeneratorProviderMock, ModelProviderMock, YieldProviderMock
|
||||||
|
|
||||||
DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]
|
DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]
|
||||||
|
|
@ -8,37 +10,38 @@ DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]
|
||||||
class AsyncTestPassModel(unittest.IsolatedAsyncioTestCase):
|
class AsyncTestPassModel(unittest.IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
async def test_response(self):
|
async def test_response(self):
|
||||||
client = Client(provider=AsyncGeneratorProviderMock)
|
client = AsyncClient(provider=AsyncGeneratorProviderMock)
|
||||||
response = await client.chat.completions.async_create(DEFAULT_MESSAGES, "")
|
response = await client.chat.completions.create(DEFAULT_MESSAGES, "")
|
||||||
self.assertIsInstance(response, ChatCompletion)
|
self.assertIsInstance(response, ChatCompletion)
|
||||||
self.assertEqual("Mock", response.choices[0].message.content)
|
self.assertEqual("Mock", response.choices[0].message.content)
|
||||||
|
|
||||||
async def test_pass_model(self):
|
async def test_pass_model(self):
|
||||||
client = Client(provider=ModelProviderMock)
|
client = AsyncClient(provider=ModelProviderMock)
|
||||||
response = await client.chat.completions.async_create(DEFAULT_MESSAGES, "Hello")
|
response = await client.chat.completions.create(DEFAULT_MESSAGES, "Hello")
|
||||||
self.assertIsInstance(response, ChatCompletion)
|
self.assertIsInstance(response, ChatCompletion)
|
||||||
self.assertEqual("Hello", response.choices[0].message.content)
|
self.assertEqual("Hello", response.choices[0].message.content)
|
||||||
|
|
||||||
async def test_max_tokens(self):
|
async def test_max_tokens(self):
|
||||||
client = Client(provider=YieldProviderMock)
|
client = AsyncClient(provider=YieldProviderMock)
|
||||||
messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
|
messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
|
||||||
response = await client.chat.completions.async_create(messages, "Hello", max_tokens=1)
|
response = await client.chat.completions.create(messages, "Hello", max_tokens=1)
|
||||||
self.assertIsInstance(response, ChatCompletion)
|
self.assertIsInstance(response, ChatCompletion)
|
||||||
self.assertEqual("How ", response.choices[0].message.content)
|
self.assertEqual("How ", response.choices[0].message.content)
|
||||||
response = await client.chat.completions.async_create(messages, "Hello", max_tokens=2)
|
response = await client.chat.completions.create(messages, "Hello", max_tokens=2)
|
||||||
self.assertIsInstance(response, ChatCompletion)
|
self.assertIsInstance(response, ChatCompletion)
|
||||||
self.assertEqual("How are ", response.choices[0].message.content)
|
self.assertEqual("How are ", response.choices[0].message.content)
|
||||||
|
|
||||||
async def test_max_stream(self):
|
async def test_max_stream(self):
|
||||||
client = Client(provider=YieldProviderMock)
|
client = AsyncClient(provider=YieldProviderMock)
|
||||||
messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
|
messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
|
||||||
response = await client.chat.completions.async_create(messages, "Hello", stream=True)
|
response = client.chat.completions.create(messages, "Hello", stream=True)
|
||||||
async for chunk in response:
|
async for chunk in response:
|
||||||
|
chunk: ChatCompletionChunk = chunk
|
||||||
self.assertIsInstance(chunk, ChatCompletionChunk)
|
self.assertIsInstance(chunk, ChatCompletionChunk)
|
||||||
if chunk.choices[0].delta.content is not None:
|
if chunk.choices[0].delta.content is not None:
|
||||||
self.assertIsInstance(chunk.choices[0].delta.content, str)
|
self.assertIsInstance(chunk.choices[0].delta.content, str)
|
||||||
messages = [{'role': 'user', 'content': chunk} for chunk in ["You ", "You ", "Other", "?"]]
|
messages = [{'role': 'user', 'content': chunk} for chunk in ["You ", "You ", "Other", "?"]]
|
||||||
response = await client.chat.completions.async_create(messages, "Hello", stream=True, max_tokens=2)
|
response = client.chat.completions.create(messages, "Hello", stream=True, max_tokens=2)
|
||||||
response_list = []
|
response_list = []
|
||||||
async for chunk in response:
|
async for chunk in response:
|
||||||
response_list.append(chunk)
|
response_list.append(chunk)
|
||||||
|
|
@ -48,9 +51,56 @@ class AsyncTestPassModel(unittest.IsolatedAsyncioTestCase):
|
||||||
self.assertEqual(chunk.choices[0].delta.content, "You ")
|
self.assertEqual(chunk.choices[0].delta.content, "You ")
|
||||||
|
|
||||||
async def test_stop(self):
|
async def test_stop(self):
|
||||||
|
client = AsyncClient(provider=YieldProviderMock)
|
||||||
|
messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
|
||||||
|
response = await client.chat.completions.create(messages, "Hello", stop=["and"])
|
||||||
|
self.assertIsInstance(response, ChatCompletion)
|
||||||
|
self.assertEqual("How are you?", response.choices[0].message.content)
|
||||||
|
|
||||||
|
class TestPassModel(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_response(self):
|
||||||
|
client = Client(provider=AsyncGeneratorProviderMock)
|
||||||
|
response = client.chat.completions.create(DEFAULT_MESSAGES, "")
|
||||||
|
self.assertIsInstance(response, ChatCompletion)
|
||||||
|
self.assertEqual("Mock", response.choices[0].message.content)
|
||||||
|
|
||||||
|
def test_pass_model(self):
|
||||||
|
client = Client(provider=ModelProviderMock)
|
||||||
|
response = client.chat.completions.create(DEFAULT_MESSAGES, "Hello")
|
||||||
|
self.assertIsInstance(response, ChatCompletion)
|
||||||
|
self.assertEqual("Hello", response.choices[0].message.content)
|
||||||
|
|
||||||
|
def test_max_tokens(self):
|
||||||
client = Client(provider=YieldProviderMock)
|
client = Client(provider=YieldProviderMock)
|
||||||
messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
|
messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
|
||||||
response = await client.chat.completions.async_create(messages, "Hello", stop=["and"])
|
response = client.chat.completions.create(messages, "Hello", max_tokens=1)
|
||||||
|
self.assertIsInstance(response, ChatCompletion)
|
||||||
|
self.assertEqual("How ", response.choices[0].message.content)
|
||||||
|
response = client.chat.completions.create(messages, "Hello", max_tokens=2)
|
||||||
|
self.assertIsInstance(response, ChatCompletion)
|
||||||
|
self.assertEqual("How are ", response.choices[0].message.content)
|
||||||
|
|
||||||
|
def test_max_stream(self):
|
||||||
|
client = Client(provider=YieldProviderMock)
|
||||||
|
messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
|
||||||
|
response = client.chat.completions.create(messages, "Hello", stream=True)
|
||||||
|
for chunk in response:
|
||||||
|
self.assertIsInstance(chunk, ChatCompletionChunk)
|
||||||
|
if chunk.choices[0].delta.content is not None:
|
||||||
|
self.assertIsInstance(chunk.choices[0].delta.content, str)
|
||||||
|
messages = [{'role': 'user', 'content': chunk} for chunk in ["You ", "You ", "Other", "?"]]
|
||||||
|
response = client.chat.completions.create(messages, "Hello", stream=True, max_tokens=2)
|
||||||
|
response_list = list(response)
|
||||||
|
self.assertEqual(len(response_list), 3)
|
||||||
|
for chunk in response_list:
|
||||||
|
if chunk.choices[0].delta.content is not None:
|
||||||
|
self.assertEqual(chunk.choices[0].delta.content, "You ")
|
||||||
|
|
||||||
|
def test_stop(self):
|
||||||
|
client = Client(provider=YieldProviderMock)
|
||||||
|
messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
|
||||||
|
response = client.chat.completions.create(messages, "Hello", stop=["and"])
|
||||||
self.assertIsInstance(response, ChatCompletion)
|
self.assertIsInstance(response, ChatCompletion)
|
||||||
self.assertEqual("How are you?", response.choices[0].message.content)
|
self.assertEqual("How are you?", response.choices[0].message.content)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,59 +1,171 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from typing import Any, Dict
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
from aiohttp import ClientSession
|
import random
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
from ..typing import AsyncResult, Messages
|
from ..typing import AsyncResult, Messages
|
||||||
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||||
from .helper import format_prompt
|
from ..image import ImageResponse
|
||||||
|
from ..requests import StreamSession, raise_for_status
|
||||||
from .airforce.AirforceChat import AirforceChat
|
from .airforce.AirforceChat import AirforceChat
|
||||||
from .airforce.AirforceImage import AirforceImage
|
from .airforce.AirforceImage import AirforceImage
|
||||||
|
|
||||||
class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
|
class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
url = "https://api.airforce"
|
url = "https://api.airforce"
|
||||||
api_endpoint_completions = AirforceChat.api_endpoint
|
api_endpoint_completions = AirforceChat.api_endpoint
|
||||||
api_endpoint_imagine2 = AirforceImage.api_endpoint
|
api_endpoint_imagine = AirforceImage.api_endpoint
|
||||||
working = True
|
working = True
|
||||||
supports_stream = AirforceChat.supports_stream
|
default_model = "gpt-4o-mini"
|
||||||
supports_system_message = AirforceChat.supports_system_message
|
supports_system_message = True
|
||||||
supports_message_history = AirforceChat.supports_message_history
|
supports_message_history = True
|
||||||
|
text_models = [
|
||||||
default_model = AirforceChat.default_model
|
'gpt-4-turbo',
|
||||||
models = [*AirforceChat.models, *AirforceImage.models]
|
default_model,
|
||||||
|
'llama-3.1-70b-turbo',
|
||||||
|
'llama-3.1-8b-turbo',
|
||||||
|
]
|
||||||
|
image_models = [
|
||||||
|
'flux',
|
||||||
|
'flux-realism',
|
||||||
|
'flux-anime',
|
||||||
|
'flux-3d',
|
||||||
|
'flux-disney',
|
||||||
|
'flux-pixel',
|
||||||
|
'flux-4o',
|
||||||
|
'any-dark',
|
||||||
|
]
|
||||||
|
models = [
|
||||||
|
*text_models,
|
||||||
|
*image_models,
|
||||||
|
]
|
||||||
model_aliases = {
|
model_aliases = {
|
||||||
**AirforceChat.model_aliases,
|
"gpt-4o": "chatgpt-4o-latest",
|
||||||
**AirforceImage.model_aliases
|
"llama-3.1-70b": "llama-3.1-70b-turbo",
|
||||||
|
"llama-3.1-8b": "llama-3.1-8b-turbo",
|
||||||
|
"gpt-4": "gpt-4-turbo",
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_model(cls, model: str) -> str:
|
def create_async_generator(
|
||||||
if model in cls.models:
|
cls,
|
||||||
return model
|
model: str,
|
||||||
elif model in cls.model_aliases:
|
messages: Messages,
|
||||||
return cls.model_aliases[model]
|
proxy: str = None,
|
||||||
else:
|
seed: int = None,
|
||||||
return cls.default_model
|
size: str = "1:1",
|
||||||
|
stream: bool = False,
|
||||||
@classmethod
|
**kwargs
|
||||||
async def create_async_generator(cls, model: str, messages: Messages, **kwargs) -> AsyncResult:
|
) -> AsyncResult:
|
||||||
model = cls.get_model(model)
|
model = cls.get_model(model)
|
||||||
|
|
||||||
provider = AirforceChat if model in AirforceChat.text_models else AirforceImage
|
if model in cls.image_models:
|
||||||
|
return cls._generate_image(model, messages, proxy, seed, size)
|
||||||
|
else:
|
||||||
|
return cls._generate_text(model, messages, proxy, stream, **kwargs)
|
||||||
|
|
||||||
if model not in provider.models:
|
@classmethod
|
||||||
raise ValueError(f"Unsupported model: {model}")
|
async def _generate_image(
|
||||||
|
cls,
|
||||||
|
model: str,
|
||||||
|
messages: Messages,
|
||||||
|
proxy: str = None,
|
||||||
|
seed: int = None,
|
||||||
|
size: str = "1:1",
|
||||||
|
**kwargs
|
||||||
|
) -> AsyncResult:
|
||||||
|
headers = {
|
||||||
|
"accept": "*/*",
|
||||||
|
"accept-language": "en-US,en;q=0.9",
|
||||||
|
"cache-control": "no-cache",
|
||||||
|
"origin": "https://llmplayground.net",
|
||||||
|
"user-agent": "Mozilla/5.0"
|
||||||
|
}
|
||||||
|
if seed is None:
|
||||||
|
seed = random.randint(0, 100000)
|
||||||
|
prompt = messages[-1]['content']
|
||||||
|
|
||||||
# Get the signature of the provider's create_async_generator method
|
async with StreamSession(headers=headers, proxy=proxy) as session:
|
||||||
sig = inspect.signature(provider.create_async_generator)
|
params = {
|
||||||
|
"model": model,
|
||||||
|
"prompt": prompt,
|
||||||
|
"size": size,
|
||||||
|
"seed": seed
|
||||||
|
}
|
||||||
|
async with session.get(f"{cls.api_endpoint_imagine}", params=params) as response:
|
||||||
|
await raise_for_status(response)
|
||||||
|
content_type = response.headers.get('Content-Type', '').lower()
|
||||||
|
|
||||||
# Filter kwargs to only include parameters that the provider's method accepts
|
if 'application/json' in content_type:
|
||||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters}
|
raise RuntimeError(await response.json().get("error", {}).get("message"))
|
||||||
|
elif 'image' in content_type:
|
||||||
|
image_data = b""
|
||||||
|
async for chunk in response.iter_content():
|
||||||
|
if chunk:
|
||||||
|
image_data += chunk
|
||||||
|
image_url = f"{cls.api_endpoint_imagine}?model={model}&prompt={prompt}&size={size}&seed={seed}"
|
||||||
|
yield ImageResponse(images=image_url, alt=prompt)
|
||||||
|
|
||||||
# Add model and messages to filtered_kwargs
|
@classmethod
|
||||||
filtered_kwargs['model'] = model
|
async def _generate_text(
|
||||||
filtered_kwargs['messages'] = messages
|
cls,
|
||||||
|
model: str,
|
||||||
|
messages: Messages,
|
||||||
|
proxy: str = None,
|
||||||
|
stream: bool = False,
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
temperature: float = 1,
|
||||||
|
top_p: float = 1,
|
||||||
|
**kwargs
|
||||||
|
) -> AsyncResult:
|
||||||
|
headers = {
|
||||||
|
"accept": "*/*",
|
||||||
|
"accept-language": "en-US,en;q=0.9",
|
||||||
|
"authorization": "Bearer missing api key",
|
||||||
|
"content-type": "application/json",
|
||||||
|
"user-agent": "Mozilla/5.0"
|
||||||
|
}
|
||||||
|
async with StreamSession(headers=headers, proxy=proxy) as session:
|
||||||
|
data = {
|
||||||
|
"messages": messages,
|
||||||
|
"model": model,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"temperature": temperature,
|
||||||
|
"top_p": top_p,
|
||||||
|
"stream": stream
|
||||||
|
}
|
||||||
|
async with session.post(cls.api_endpoint_completions, json=data) as response:
|
||||||
|
await raise_for_status(response)
|
||||||
|
content_type = response.headers.get('Content-Type', '').lower()
|
||||||
|
if 'application/json' in content_type:
|
||||||
|
json_data = await response.json()
|
||||||
|
if json_data.get("model") == "error":
|
||||||
|
raise RuntimeError(json_data['choices'][0]['message'].get('content', ''))
|
||||||
|
if stream:
|
||||||
|
async for line in response.iter_lines():
|
||||||
|
if line:
|
||||||
|
line = line.decode('utf-8').strip()
|
||||||
|
if line.startswith("data: ") and line != "data: [DONE]":
|
||||||
|
json_data = json.loads(line[6:])
|
||||||
|
content = json_data['choices'][0]['delta'].get('content', '')
|
||||||
|
if content:
|
||||||
|
yield cls._filter_content(content)
|
||||||
|
else:
|
||||||
|
json_data = await response.json()
|
||||||
|
content = json_data['choices'][0]['message']['content']
|
||||||
|
yield cls._filter_content(content)
|
||||||
|
|
||||||
async for result in provider.create_async_generator(**filtered_kwargs):
|
@classmethod
|
||||||
yield result
|
def _filter_content(cls, part_response: str) -> str:
|
||||||
|
part_response = re.sub(
|
||||||
|
r"One message exceeds the \d+chars per message limit\..+https:\/\/discord\.com\/invite\/\S+",
|
||||||
|
'',
|
||||||
|
part_response
|
||||||
|
)
|
||||||
|
|
||||||
|
part_response = re.sub(
|
||||||
|
r"Rate limit \(\d+\/minute\) exceeded\. Join our discord for more: .+https:\/\/discord\.com\/invite\/\S+",
|
||||||
|
'',
|
||||||
|
part_response
|
||||||
|
)
|
||||||
|
return part_response
|
||||||
|
|
@ -2,18 +2,18 @@ from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from aiohttp import ClientSession, ClientTimeout, ClientResponseError
|
|
||||||
|
|
||||||
from ...typing import AsyncResult, Messages
|
from ..typing import AsyncResult, Messages
|
||||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||||
from ..helper import format_prompt
|
from ..image import ImageResponse
|
||||||
from ...image import ImageResponse
|
from ..requests import StreamSession, raise_for_status
|
||||||
|
from ..errors import ResponseStatusError
|
||||||
|
|
||||||
class AmigoChat(AsyncGeneratorProvider, ProviderModelMixin):
|
class AmigoChat(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
url = "https://amigochat.io/chat/"
|
url = "https://amigochat.io/chat/"
|
||||||
chat_api_endpoint = "https://api.amigochat.io/v1/chat/completions"
|
chat_api_endpoint = "https://api.amigochat.io/v1/chat/completions"
|
||||||
image_api_endpoint = "https://api.amigochat.io/v1/images/generations"
|
image_api_endpoint = "https://api.amigochat.io/v1/images/generations"
|
||||||
working = False
|
working = True
|
||||||
supports_stream = True
|
supports_stream = True
|
||||||
supports_system_message = True
|
supports_system_message = True
|
||||||
supports_message_history = True
|
supports_message_history = True
|
||||||
|
|
@ -66,15 +66,6 @@ class AmigoChat(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
'dalle-e-3': "dalle-three",
|
'dalle-e-3': "dalle-three",
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_model(cls, model: str) -> str:
|
|
||||||
if model in cls.models:
|
|
||||||
return model
|
|
||||||
elif model in cls.model_aliases:
|
|
||||||
return cls.model_aliases[model]
|
|
||||||
else:
|
|
||||||
return cls.default_model
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_personaId(cls, model: str) -> str:
|
def get_personaId(cls, model: str) -> str:
|
||||||
return cls.persona_ids[model]
|
return cls.persona_ids[model]
|
||||||
|
|
@ -86,6 +77,12 @@ class AmigoChat(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
messages: Messages,
|
messages: Messages,
|
||||||
proxy: str = None,
|
proxy: str = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
|
timeout: int = 300,
|
||||||
|
frequency_penalty: float = 0,
|
||||||
|
max_tokens: int = 4000,
|
||||||
|
presence_penalty: float = 0,
|
||||||
|
temperature: float = 0.5,
|
||||||
|
top_p: float = 0.95,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> AsyncResult:
|
) -> AsyncResult:
|
||||||
model = cls.get_model(model)
|
model = cls.get_model(model)
|
||||||
|
|
@ -113,31 +110,25 @@ class AmigoChat(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
"x-device-language": "en-US",
|
"x-device-language": "en-US",
|
||||||
"x-device-platform": "web",
|
"x-device-platform": "web",
|
||||||
"x-device-uuid": device_uuid,
|
"x-device-uuid": device_uuid,
|
||||||
"x-device-version": "1.0.32"
|
"x-device-version": "1.0.41"
|
||||||
}
|
}
|
||||||
|
|
||||||
async with ClientSession(headers=headers) as session:
|
async with StreamSession(headers=headers, proxy=proxy) as session:
|
||||||
if model in cls.chat_models:
|
if model not in cls.image_models:
|
||||||
# Chat completion
|
|
||||||
data = {
|
data = {
|
||||||
"messages": [{"role": m["role"], "content": m["content"]} for m in messages],
|
"messages": messages,
|
||||||
"model": model,
|
"model": model,
|
||||||
"personaId": cls.get_personaId(model),
|
"personaId": cls.get_personaId(model),
|
||||||
"frequency_penalty": 0,
|
"frequency_penalty": frequency_penalty,
|
||||||
"max_tokens": 4000,
|
"max_tokens": max_tokens,
|
||||||
"presence_penalty": 0,
|
"presence_penalty": presence_penalty,
|
||||||
"stream": stream,
|
"stream": stream,
|
||||||
"temperature": 0.5,
|
"temperature": temperature,
|
||||||
"top_p": 0.95
|
"top_p": top_p
|
||||||
}
|
}
|
||||||
|
async with session.post(cls.chat_api_endpoint, json=data, timeout=timeout) as response:
|
||||||
timeout = ClientTimeout(total=300) # 5 minutes timeout
|
await raise_for_status(response)
|
||||||
async with session.post(cls.chat_api_endpoint, json=data, proxy=proxy, timeout=timeout) as response:
|
async for line in response.iter_lines():
|
||||||
if response.status not in (200, 201):
|
|
||||||
error_text = await response.text()
|
|
||||||
raise Exception(f"Error {response.status}: {error_text}")
|
|
||||||
|
|
||||||
async for line in response.content:
|
|
||||||
line = line.decode('utf-8').strip()
|
line = line.decode('utf-8').strip()
|
||||||
if line.startswith('data: '):
|
if line.startswith('data: '):
|
||||||
if line == 'data: [DONE]':
|
if line == 'data: [DONE]':
|
||||||
|
|
@ -164,11 +155,9 @@ class AmigoChat(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
"model": model,
|
"model": model,
|
||||||
"personaId": cls.get_personaId(model)
|
"personaId": cls.get_personaId(model)
|
||||||
}
|
}
|
||||||
async with session.post(cls.image_api_endpoint, json=data, proxy=proxy) as response:
|
async with session.post(cls.image_api_endpoint, json=data) as response:
|
||||||
response.raise_for_status()
|
await raise_for_status(response)
|
||||||
|
|
||||||
response_data = await response.json()
|
response_data = await response.json()
|
||||||
|
|
||||||
if "data" in response_data:
|
if "data" in response_data:
|
||||||
image_urls = []
|
image_urls = []
|
||||||
for item in response_data["data"]:
|
for item in response_data["data"]:
|
||||||
|
|
@ -179,10 +168,8 @@ class AmigoChat(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
yield ImageResponse(image_urls, prompt)
|
yield ImageResponse(image_urls, prompt)
|
||||||
else:
|
else:
|
||||||
yield None
|
yield None
|
||||||
|
|
||||||
break
|
break
|
||||||
|
except (ResponseStatusError, Exception) as e:
|
||||||
except (ClientResponseError, Exception) as e:
|
|
||||||
retry_count += 1
|
retry_count += 1
|
||||||
if retry_count >= max_retries:
|
if retry_count >= max_retries:
|
||||||
raise e
|
raise e
|
||||||
|
|
@ -1,72 +1,52 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from aiohttp import ClientSession
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
import cloudscraper
|
|
||||||
from typing import AsyncGenerator
|
|
||||||
|
|
||||||
from ..typing import AsyncResult, Messages
|
from ..typing import AsyncResult, Messages, Cookies
|
||||||
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, get_running_loop
|
||||||
from .helper import format_prompt
|
from ..requests import Session, StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies
|
||||||
|
|
||||||
class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
|
class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
label = "Cloudflare AI"
|
label = "Cloudflare AI"
|
||||||
url = "https://playground.ai.cloudflare.com"
|
url = "https://playground.ai.cloudflare.com"
|
||||||
api_endpoint = "https://playground.ai.cloudflare.com/api/inference"
|
api_endpoint = "https://playground.ai.cloudflare.com/api/inference"
|
||||||
|
models_url = "https://playground.ai.cloudflare.com/api/models"
|
||||||
working = True
|
working = True
|
||||||
supports_stream = True
|
supports_stream = True
|
||||||
supports_system_message = True
|
supports_system_message = True
|
||||||
supports_message_history = True
|
supports_message_history = True
|
||||||
|
default_model = "@cf/meta/llama-3.1-8b-instruct"
|
||||||
default_model = '@cf/meta/llama-3.1-8b-instruct-awq'
|
|
||||||
models = [
|
|
||||||
'@cf/meta/llama-2-7b-chat-fp16',
|
|
||||||
'@cf/meta/llama-2-7b-chat-int8',
|
|
||||||
|
|
||||||
'@cf/meta/llama-3-8b-instruct',
|
|
||||||
'@cf/meta/llama-3-8b-instruct-awq',
|
|
||||||
'@hf/meta-llama/meta-llama-3-8b-instruct',
|
|
||||||
|
|
||||||
default_model,
|
|
||||||
'@cf/meta/llama-3.1-8b-instruct-fp8',
|
|
||||||
|
|
||||||
'@cf/meta/llama-3.2-1b-instruct',
|
|
||||||
|
|
||||||
'@hf/mistral/mistral-7b-instruct-v0.2',
|
|
||||||
|
|
||||||
'@cf/qwen/qwen1.5-7b-chat-awq',
|
|
||||||
|
|
||||||
'@cf/defog/sqlcoder-7b-2',
|
|
||||||
]
|
|
||||||
|
|
||||||
model_aliases = {
|
model_aliases = {
|
||||||
"llama-2-7b": "@cf/meta/llama-2-7b-chat-fp16",
|
"llama-2-7b": "@cf/meta/llama-2-7b-chat-fp16",
|
||||||
"llama-2-7b": "@cf/meta/llama-2-7b-chat-int8",
|
"llama-2-7b": "@cf/meta/llama-2-7b-chat-int8",
|
||||||
|
|
||||||
"llama-3-8b": "@cf/meta/llama-3-8b-instruct",
|
"llama-3-8b": "@cf/meta/llama-3-8b-instruct",
|
||||||
"llama-3-8b": "@cf/meta/llama-3-8b-instruct-awq",
|
"llama-3-8b": "@cf/meta/llama-3-8b-instruct-awq",
|
||||||
"llama-3-8b": "@hf/meta-llama/meta-llama-3-8b-instruct",
|
"llama-3-8b": "@hf/meta-llama/meta-llama-3-8b-instruct",
|
||||||
|
|
||||||
"llama-3.1-8b": "@cf/meta/llama-3.1-8b-instruct-awq",
|
"llama-3.1-8b": "@cf/meta/llama-3.1-8b-instruct-awq",
|
||||||
"llama-3.1-8b": "@cf/meta/llama-3.1-8b-instruct-fp8",
|
"llama-3.1-8b": "@cf/meta/llama-3.1-8b-instruct-fp8",
|
||||||
|
|
||||||
"llama-3.2-1b": "@cf/meta/llama-3.2-1b-instruct",
|
"llama-3.2-1b": "@cf/meta/llama-3.2-1b-instruct",
|
||||||
|
|
||||||
"qwen-1.5-7b": "@cf/qwen/qwen1.5-7b-chat-awq",
|
"qwen-1.5-7b": "@cf/qwen/qwen1.5-7b-chat-awq",
|
||||||
|
|
||||||
#"sqlcoder-7b": "@cf/defog/sqlcoder-7b-2",
|
|
||||||
}
|
}
|
||||||
|
_args: dict = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_model(cls, model: str) -> str:
|
def get_models(cls) -> str:
|
||||||
if model in cls.models:
|
if not cls.models:
|
||||||
return model
|
if cls._args is None:
|
||||||
elif model in cls.model_aliases:
|
get_running_loop(check_nested=True)
|
||||||
return cls.model_aliases[model]
|
args = get_args_from_nodriver(cls.url, cookies={
|
||||||
else:
|
'__cf_bm': uuid.uuid4().hex,
|
||||||
return cls.default_model
|
})
|
||||||
|
cls._args = asyncio.run(args)
|
||||||
|
with Session(**cls._args) as session:
|
||||||
|
response = session.get(cls.models_url)
|
||||||
|
raise_for_status(response)
|
||||||
|
json_data = response.json()
|
||||||
|
cls.models = [model.get("name") for model in json_data.get("models")]
|
||||||
|
cls._args["cookies"] = merge_cookies(cls._args["cookies"] , response)
|
||||||
|
return cls.models
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create_async_generator(
|
async def create_async_generator(
|
||||||
|
|
@ -75,76 +55,34 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
messages: Messages,
|
messages: Messages,
|
||||||
proxy: str = None,
|
proxy: str = None,
|
||||||
max_tokens: int = 2048,
|
max_tokens: int = 2048,
|
||||||
|
cookies: Cookies = None,
|
||||||
|
timeout: int = 300,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> AsyncResult:
|
) -> AsyncResult:
|
||||||
model = cls.get_model(model)
|
model = cls.get_model(model)
|
||||||
|
if cls._args is None:
|
||||||
headers = {
|
cls._args = await get_args_from_nodriver(cls.url, proxy, timeout, cookies)
|
||||||
'Accept': 'text/event-stream',
|
|
||||||
'Accept-Language': 'en-US,en;q=0.9',
|
|
||||||
'Cache-Control': 'no-cache',
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'Origin': cls.url,
|
|
||||||
'Pragma': 'no-cache',
|
|
||||||
'Referer': f'{cls.url}/',
|
|
||||||
'Sec-Ch-Ua': '"Chromium";v="129", "Not=A?Brand";v="8"',
|
|
||||||
'Sec-Ch-Ua-Mobile': '?0',
|
|
||||||
'Sec-Ch-Ua-Platform': '"Linux"',
|
|
||||||
'Sec-Fetch-Dest': 'empty',
|
|
||||||
'Sec-Fetch-Mode': 'cors',
|
|
||||||
'Sec-Fetch-Site': 'same-origin',
|
|
||||||
'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36',
|
|
||||||
}
|
|
||||||
|
|
||||||
cookies = {
|
|
||||||
'__cf_bm': uuid.uuid4().hex,
|
|
||||||
}
|
|
||||||
|
|
||||||
scraper = cloudscraper.create_scraper()
|
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"messages": [
|
"messages": messages,
|
||||||
{"role": "user", "content": format_prompt(messages)}
|
|
||||||
],
|
|
||||||
"lora": None,
|
"lora": None,
|
||||||
"model": model,
|
"model": model,
|
||||||
"max_tokens": max_tokens,
|
"max_tokens": max_tokens,
|
||||||
"stream": True
|
"stream": True
|
||||||
}
|
}
|
||||||
|
async with StreamSession(**cls._args) as session:
|
||||||
max_retries = 3
|
async with session.post(
|
||||||
full_response = ""
|
|
||||||
|
|
||||||
for attempt in range(max_retries):
|
|
||||||
try:
|
|
||||||
response = scraper.post(
|
|
||||||
cls.api_endpoint,
|
cls.api_endpoint,
|
||||||
headers=headers,
|
|
||||||
cookies=cookies,
|
|
||||||
json=data,
|
json=data,
|
||||||
stream=True,
|
) as response:
|
||||||
proxies={'http': proxy, 'https': proxy} if proxy else None
|
await raise_for_status(response)
|
||||||
)
|
cls._args["cookies"] = merge_cookies(cls._args["cookies"] , response)
|
||||||
|
async for line in response.iter_lines():
|
||||||
if response.status_code == 403:
|
|
||||||
await asyncio.sleep(2 ** attempt)
|
|
||||||
continue
|
|
||||||
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
for line in response.iter_lines():
|
|
||||||
if line.startswith(b'data: '):
|
if line.startswith(b'data: '):
|
||||||
if line == b'data: [DONE]':
|
if line == b'data: [DONE]':
|
||||||
if full_response:
|
|
||||||
yield full_response
|
|
||||||
break
|
break
|
||||||
try:
|
try:
|
||||||
content = json.loads(line[6:].decode('utf-8'))
|
content = json.loads(line[6:].decode())
|
||||||
if 'response' in content and content['response'] != '</s>':
|
if content.get("response") and content.get("response") != '</s>':
|
||||||
yield content['response']
|
yield content['response']
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
if attempt == max_retries - 1:
|
|
||||||
raise
|
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,31 @@ from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession, BaseConnector
|
||||||
|
|
||||||
from ..typing import AsyncResult, Messages
|
from ..typing import AsyncResult, Messages
|
||||||
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, BaseConversation
|
||||||
from .helper import format_prompt
|
from .helper import format_prompt
|
||||||
|
from ..requests.aiohttp import get_connector
|
||||||
|
from ..requests.raise_for_status import raise_for_status
|
||||||
|
from .. import debug
|
||||||
|
|
||||||
|
MODELS = [
|
||||||
|
{"model":"gpt-4o","modelName":"GPT-4o","modelVariant":None,"modelStyleId":"gpt-4o-mini","createdBy":"OpenAI","moderationLevel":"HIGH","isAvailable":1,"inputCharLimit":16e3,"settingId":"4"},
|
||||||
|
{"model":"gpt-4o-mini","modelName":"GPT-4o","modelVariant":"mini","modelStyleId":"gpt-4o-mini","createdBy":"OpenAI","moderationLevel":"HIGH","isAvailable":0,"inputCharLimit":16e3,"settingId":"3"},
|
||||||
|
{"model":"claude-3-5-sonnet-20240620","modelName":"Claude 3.5","modelVariant":"Sonnet","modelStyleId":"claude-3-haiku","createdBy":"Anthropic","moderationLevel":"HIGH","isAvailable":1,"inputCharLimit":16e3,"settingId":"7"},
|
||||||
|
{"model":"claude-3-opus-20240229","modelName":"Claude 3","modelVariant":"Opus","modelStyleId":"claude-3-haiku","createdBy":"Anthropic","moderationLevel":"HIGH","isAvailable":1,"inputCharLimit":16e3,"settingId":"2"},
|
||||||
|
{"model":"claude-3-haiku-20240307","modelName":"Claude 3","modelVariant":"Haiku","modelStyleId":"claude-3-haiku","createdBy":"Anthropic","moderationLevel":"HIGH","isAvailable":0,"inputCharLimit":16e3,"settingId":"1"},
|
||||||
|
{"model":"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo","modelName":"Llama 3.1","modelVariant":"70B","modelStyleId":"llama-3","createdBy":"Meta","moderationLevel":"MEDIUM","isAvailable":0,"isOpenSource":0,"inputCharLimit":16e3,"settingId":"5"},
|
||||||
|
{"model":"mistralai/Mixtral-8x7B-Instruct-v0.1","modelName":"Mixtral","modelVariant":"8x7B","modelStyleId":"mixtral","createdBy":"Mistral AI","moderationLevel":"LOW","isAvailable":0,"isOpenSource":0,"inputCharLimit":16e3,"settingId":"6"}
|
||||||
|
]
|
||||||
|
|
||||||
|
class Conversation(BaseConversation):
|
||||||
|
vqd: str = None
|
||||||
|
message_history: Messages = []
|
||||||
|
|
||||||
|
def __init__(self, model: str):
|
||||||
|
self.model = model
|
||||||
|
|
||||||
class DDG(AsyncGeneratorProvider, ProviderModelMixin):
|
class DDG(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
url = "https://duckduckgo.com"
|
url = "https://duckduckgo.com"
|
||||||
|
|
@ -18,81 +37,74 @@ class DDG(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
supports_message_history = True
|
supports_message_history = True
|
||||||
|
|
||||||
default_model = "gpt-4o-mini"
|
default_model = "gpt-4o-mini"
|
||||||
models = [
|
models = [model.get("model") for model in MODELS]
|
||||||
"gpt-4o-mini",
|
|
||||||
"claude-3-haiku-20240307",
|
|
||||||
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
|
|
||||||
"mistralai/Mixtral-8x7B-Instruct-v0.1"
|
|
||||||
]
|
|
||||||
model_aliases = {
|
model_aliases = {
|
||||||
"claude-3-haiku": "claude-3-haiku-20240307",
|
"claude-3-haiku": "claude-3-haiku-20240307",
|
||||||
"llama-3.1-70b": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
|
"llama-3.1-70b": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
|
||||||
"mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
"mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||||
|
"gpt-4": "gpt-4o-mini"
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_model(cls, model: str) -> str:
|
async def get_vqd(cls, proxy: str, connector: BaseConnector = None):
|
||||||
return cls.model_aliases.get(model, model) if model in cls.model_aliases else cls.default_model
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get_vqd(cls):
|
|
||||||
status_url = "https://duckduckgo.com/duckchat/v1/status"
|
status_url = "https://duckduckgo.com/duckchat/v1/status"
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36',
|
'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36',
|
||||||
'Accept': 'text/event-stream',
|
'Accept': 'text/event-stream',
|
||||||
'x-vqd-accept': '1'
|
'x-vqd-accept': '1'
|
||||||
}
|
}
|
||||||
|
async with aiohttp.ClientSession(connector=get_connector(connector, proxy)) as session:
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
try:
|
|
||||||
async with session.get(status_url, headers=headers) as response:
|
async with session.get(status_url, headers=headers) as response:
|
||||||
if response.status == 200:
|
await raise_for_status(response)
|
||||||
return response.headers.get("x-vqd-4")
|
return response.headers.get("x-vqd-4")
|
||||||
else:
|
|
||||||
print(f"Error: Status code {response.status}")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error getting VQD: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create_async_generator(
|
async def create_async_generator(
|
||||||
cls,
|
cls,
|
||||||
model: str,
|
model: str,
|
||||||
messages: Messages,
|
messages: Messages,
|
||||||
conversation: dict = None,
|
conversation: Conversation = None,
|
||||||
|
return_conversation: bool = False,
|
||||||
proxy: str = None,
|
proxy: str = None,
|
||||||
|
connector: BaseConnector = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> AsyncResult:
|
) -> AsyncResult:
|
||||||
model = cls.get_model(model)
|
model = cls.get_model(model)
|
||||||
|
|
||||||
|
is_new_conversation = False
|
||||||
|
if conversation is None:
|
||||||
|
conversation = Conversation(model)
|
||||||
|
is_new_conversation = True
|
||||||
|
debug.last_model = model
|
||||||
|
if conversation.vqd is None:
|
||||||
|
conversation.vqd = await cls.get_vqd(proxy, connector)
|
||||||
|
if not conversation.vqd:
|
||||||
|
raise Exception("Failed to obtain VQD token")
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
'accept': 'text/event-stream',
|
'accept': 'text/event-stream',
|
||||||
'content-type': 'application/json',
|
'content-type': 'application/json',
|
||||||
'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36',
|
'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36',
|
||||||
|
'x-vqd-4': conversation.vqd,
|
||||||
}
|
}
|
||||||
|
async with ClientSession(headers=headers, connector=get_connector(connector, proxy)) as session:
|
||||||
vqd = conversation.get('vqd') if conversation else await cls.get_vqd()
|
if is_new_conversation:
|
||||||
if not vqd:
|
conversation.message_history = [{"role": "user", "content": format_prompt(messages)}]
|
||||||
raise Exception("Failed to obtain VQD token")
|
|
||||||
|
|
||||||
headers['x-vqd-4'] = vqd
|
|
||||||
|
|
||||||
if conversation:
|
|
||||||
message_history = conversation.get('messages', [])
|
|
||||||
message_history.append({"role": "user", "content": format_prompt(messages)})
|
|
||||||
else:
|
else:
|
||||||
message_history = [{"role": "user", "content": format_prompt(messages)}]
|
conversation.message_history = [
|
||||||
|
*conversation.message_history,
|
||||||
async with ClientSession(headers=headers) as session:
|
messages[-2],
|
||||||
|
messages[-1]
|
||||||
|
]
|
||||||
|
if return_conversation:
|
||||||
|
yield conversation
|
||||||
data = {
|
data = {
|
||||||
"model": model,
|
"model": conversation.model,
|
||||||
"messages": message_history
|
"messages": conversation.message_history
|
||||||
}
|
}
|
||||||
|
async with session.post(cls.api_endpoint, json=data) as response:
|
||||||
async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
|
conversation.vqd = response.headers.get("x-vqd-4")
|
||||||
response.raise_for_status()
|
await raise_for_status(response)
|
||||||
async for line in response.content:
|
async for line in response.content:
|
||||||
if line:
|
if line:
|
||||||
decoded_line = line.decode('utf-8')
|
decoded_line = line.decode('utf-8')
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,13 @@ from __future__ import annotations
|
||||||
import json
|
import json
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
try:
|
||||||
from curl_cffi import requests as cf_reqs
|
from curl_cffi import requests as cf_reqs
|
||||||
|
has_curl_cffi = True
|
||||||
|
except ImportError:
|
||||||
|
has_curl_cffi = False
|
||||||
from ..typing import CreateResult, Messages
|
from ..typing import CreateResult, Messages
|
||||||
|
from ..errors import MissingRequirementsError
|
||||||
from .base_provider import ProviderModelMixin, AbstractProvider
|
from .base_provider import ProviderModelMixin, AbstractProvider
|
||||||
from .helper import format_prompt
|
from .helper import format_prompt
|
||||||
|
|
||||||
|
|
@ -55,6 +60,8 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
|
||||||
stream: bool,
|
stream: bool,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> CreateResult:
|
) -> CreateResult:
|
||||||
|
if not has_curl_cffi:
|
||||||
|
raise MissingRequirementsError('Install "curl_cffi" package | pip install -U curl_cffi')
|
||||||
model = cls.get_model(model)
|
model = cls.get_model(model)
|
||||||
|
|
||||||
if model in cls.models:
|
if model in cls.models:
|
||||||
|
|
|
||||||
|
|
@ -179,7 +179,7 @@ class Liaobots(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
"gpt-4o": "gpt-4o-2024-08-06",
|
"gpt-4o": "gpt-4o-2024-08-06",
|
||||||
|
|
||||||
"gpt-4-turbo": "gpt-4-turbo-2024-04-09",
|
"gpt-4-turbo": "gpt-4-turbo-2024-04-09",
|
||||||
"gpt-4": "gpt-4-0613",
|
"gpt-4": "gpt-4o-mini-free",
|
||||||
|
|
||||||
"claude-3-opus": "claude-3-opus-20240229",
|
"claude-3-opus": "claude-3-opus-20240229",
|
||||||
"claude-3-opus": "claude-3-opus-20240229-aws",
|
"claude-3-opus": "claude-3-opus-20240229-aws",
|
||||||
|
|
|
||||||
|
|
@ -2,20 +2,21 @@ from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from ..typing import CreateResult, Messages
|
from ..typing import AsyncResult, Messages, Cookies
|
||||||
from .base_provider import AbstractProvider, format_prompt
|
from .base_provider import AsyncGeneratorProvider, format_prompt
|
||||||
from ..requests import Session, get_session_from_browser, raise_for_status
|
from ..requests import StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies
|
||||||
|
|
||||||
class Pi(AbstractProvider):
|
class Pi(AsyncGeneratorProvider):
|
||||||
url = "https://pi.ai/talk"
|
url = "https://pi.ai/talk"
|
||||||
working = True
|
working = True
|
||||||
supports_stream = True
|
supports_stream = True
|
||||||
_session = None
|
|
||||||
default_model = "pi"
|
default_model = "pi"
|
||||||
models = [default_model]
|
models = [default_model]
|
||||||
|
_headers: dict = None
|
||||||
|
_cookies: Cookies = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_completion(
|
async def create_async_generator(
|
||||||
cls,
|
cls,
|
||||||
model: str,
|
model: str,
|
||||||
messages: Messages,
|
messages: Messages,
|
||||||
|
|
@ -23,49 +24,52 @@ class Pi(AbstractProvider):
|
||||||
proxy: str = None,
|
proxy: str = None,
|
||||||
timeout: int = 180,
|
timeout: int = 180,
|
||||||
conversation_id: str = None,
|
conversation_id: str = None,
|
||||||
webdriver: WebDriver = None,
|
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> CreateResult:
|
) -> AsyncResult:
|
||||||
if cls._session is None:
|
if cls._headers is None:
|
||||||
cls._session = get_session_from_browser(url=cls.url, proxy=proxy, timeout=timeout)
|
args = await get_args_from_nodriver(cls.url, proxy=proxy, timeout=timeout)
|
||||||
|
cls._cookies = args.get("cookies", {})
|
||||||
|
cls._headers = args.get("headers")
|
||||||
|
async with StreamSession(headers=cls._headers, cookies=cls._cookies, proxy=proxy) as session:
|
||||||
if not conversation_id:
|
if not conversation_id:
|
||||||
conversation_id = cls.start_conversation(cls._session)
|
conversation_id = await cls.start_conversation(session)
|
||||||
prompt = format_prompt(messages)
|
prompt = format_prompt(messages)
|
||||||
else:
|
else:
|
||||||
prompt = messages[-1]["content"]
|
prompt = messages[-1]["content"]
|
||||||
answer = cls.ask(cls._session, prompt, conversation_id)
|
answer = cls.ask(session, prompt, conversation_id)
|
||||||
for line in answer:
|
async for line in answer:
|
||||||
if "text" in line:
|
if "text" in line:
|
||||||
yield line["text"]
|
yield line["text"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def start_conversation(cls, session: Session) -> str:
|
async def start_conversation(cls, session: StreamSession) -> str:
|
||||||
response = session.post('https://pi.ai/api/chat/start', data="{}", headers={
|
async with session.post('https://pi.ai/api/chat/start', data="{}", headers={
|
||||||
'accept': 'application/json',
|
'accept': 'application/json',
|
||||||
'x-api-version': '3'
|
'x-api-version': '3'
|
||||||
})
|
}) as response:
|
||||||
raise_for_status(response)
|
await raise_for_status(response)
|
||||||
return response.json()['conversations'][0]['sid']
|
return (await response.json())['conversations'][0]['sid']
|
||||||
|
|
||||||
def get_chat_history(session: Session, conversation_id: str):
|
async def get_chat_history(session: StreamSession, conversation_id: str):
|
||||||
params = {
|
params = {
|
||||||
'conversation': conversation_id,
|
'conversation': conversation_id,
|
||||||
}
|
}
|
||||||
response = session.get('https://pi.ai/api/chat/history', params=params)
|
async with session.get('https://pi.ai/api/chat/history', params=params) as response:
|
||||||
raise_for_status(response)
|
await raise_for_status(response)
|
||||||
return response.json()
|
return await response.json()
|
||||||
|
|
||||||
def ask(session: Session, prompt: str, conversation_id: str):
|
@classmethod
|
||||||
|
async def ask(cls, session: StreamSession, prompt: str, conversation_id: str):
|
||||||
json_data = {
|
json_data = {
|
||||||
'text': prompt,
|
'text': prompt,
|
||||||
'conversation': conversation_id,
|
'conversation': conversation_id,
|
||||||
'mode': 'BASE',
|
'mode': 'BASE',
|
||||||
}
|
}
|
||||||
response = session.post('https://pi.ai/api/chat', json=json_data, stream=True)
|
async with session.post('https://pi.ai/api/chat', json=json_data) as response:
|
||||||
raise_for_status(response)
|
await raise_for_status(response)
|
||||||
for line in response.iter_lines():
|
cls._cookies = merge_cookies(cls._cookies, response)
|
||||||
|
async for line in response.iter_lines():
|
||||||
if line.startswith(b'data: {"text":'):
|
if line.startswith(b'data: {"text":'):
|
||||||
yield json.loads(line.split(b'data: ')[1])
|
yield json.loads(line.split(b'data: ')[1])
|
||||||
elif line.startswith(b'data: {"title":'):
|
elif line.startswith(b'data: {"title":'):
|
||||||
yield json.loads(line.split(b'data: ')[1])
|
yield json.loads(line.split(b'data: ')[1])
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ from .local import *
|
||||||
|
|
||||||
from .AIUncensored import AIUncensored
|
from .AIUncensored import AIUncensored
|
||||||
from .Airforce import Airforce
|
from .Airforce import Airforce
|
||||||
|
from .AmigoChat import AmigoChat
|
||||||
from .Bing import Bing
|
from .Bing import Bing
|
||||||
from .Blackbox import Blackbox
|
from .Blackbox import Blackbox
|
||||||
from .ChatGpt import ChatGpt
|
from .ChatGpt import ChatGpt
|
||||||
|
|
|
||||||
|
|
@ -132,7 +132,7 @@ async def create_images(session: ClientSession, prompt: str, timeout: int = TIME
|
||||||
|
|
||||||
redirect_url = response.headers["Location"].replace("&nfy=1", "")
|
redirect_url = response.headers["Location"].replace("&nfy=1", "")
|
||||||
redirect_url = f"{BING_URL}{redirect_url}"
|
redirect_url = f"{BING_URL}{redirect_url}"
|
||||||
request_id = redirect_url.split("id=")[1]
|
request_id = redirect_url.split("id=")[-1]
|
||||||
async with session.get(redirect_url) as response:
|
async with session.get(redirect_url) as response:
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -196,7 +196,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
async with session.get(url, headers=headers) as response:
|
async with session.get(url, headers=headers) as response:
|
||||||
cls._update_request_args(session)
|
cls._update_request_args(session)
|
||||||
if response.status == 401:
|
if response.status == 401:
|
||||||
raise MissingAuthError('Add a "api_key" or a .har file' if cls._api_key is None else "Invalid api key")
|
raise MissingAuthError('Add a .har file for OpenaiChat' if cls._api_key is None else "Invalid api key")
|
||||||
await raise_for_status(response)
|
await raise_for_status(response)
|
||||||
data = await response.json()
|
data = await response.json()
|
||||||
if "categories" in data:
|
if "categories" in data:
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ from .AI365VIP import AI365VIP
|
||||||
from .AIChatFree import AIChatFree
|
from .AIChatFree import AIChatFree
|
||||||
from .AiChatOnline import AiChatOnline
|
from .AiChatOnline import AiChatOnline
|
||||||
from .AiChats import AiChats
|
from .AiChats import AiChats
|
||||||
from .AmigoChat import AmigoChat
|
|
||||||
from .Aura import Aura
|
from .Aura import Aura
|
||||||
from .Chatgpt4o import Chatgpt4o
|
from .Chatgpt4o import Chatgpt4o
|
||||||
from .ChatgptFree import ChatgptFree
|
from .ChatgptFree import ChatgptFree
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ import json
|
||||||
import base64
|
import base64
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
|
||||||
def generate_proof_token(required: bool, seed: str = "", difficulty: str = "", user_agent: str = None, proofTokens: list = None):
|
def generate_proof_token(required: bool, seed: str = "", difficulty: str = "", user_agent: str = None, proofTokens: list = None):
|
||||||
if not required:
|
if not required:
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -14,11 +14,12 @@ from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, HTTP_401_UNAUTHORIZE
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Union, Optional, Iterator
|
from typing import Union, Optional
|
||||||
|
|
||||||
import g4f
|
import g4f
|
||||||
import g4f.debug
|
import g4f.debug
|
||||||
from g4f.client import Client, ChatCompletion, ChatCompletionChunk, ImagesResponse
|
from g4f.client import AsyncClient, ChatCompletion
|
||||||
|
from g4f.client.helper import filter_none
|
||||||
from g4f.typing import Messages
|
from g4f.typing import Messages
|
||||||
from g4f.cookies import read_cookie_files
|
from g4f.cookies import read_cookie_files
|
||||||
|
|
||||||
|
|
@ -47,6 +48,10 @@ def create_app(g4f_api_key: str = None):
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
def create_app_debug(g4f_api_key: str = None):
|
||||||
|
g4f.debug.logging = True
|
||||||
|
return create_app(g4f_api_key)
|
||||||
|
|
||||||
class ChatCompletionsConfig(BaseModel):
|
class ChatCompletionsConfig(BaseModel):
|
||||||
messages: Messages
|
messages: Messages
|
||||||
model: str
|
model: str
|
||||||
|
|
@ -62,13 +67,19 @@ class ChatCompletionsConfig(BaseModel):
|
||||||
class ImageGenerationConfig(BaseModel):
|
class ImageGenerationConfig(BaseModel):
|
||||||
prompt: str
|
prompt: str
|
||||||
model: Optional[str] = None
|
model: Optional[str] = None
|
||||||
|
provider: Optional[str] = None
|
||||||
response_format: str = "url"
|
response_format: str = "url"
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
proxy: Optional[str] = None
|
||||||
|
|
||||||
class AppConfig:
|
class AppConfig:
|
||||||
ignored_providers: Optional[list[str]] = None
|
ignored_providers: Optional[list[str]] = None
|
||||||
g4f_api_key: Optional[str] = None
|
g4f_api_key: Optional[str] = None
|
||||||
ignore_cookie_files: bool = False
|
ignore_cookie_files: bool = False
|
||||||
defaults: dict = {}
|
model: str = None,
|
||||||
|
provider: str = None
|
||||||
|
image_provider: str = None
|
||||||
|
proxy: str = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_config(cls, **data):
|
def set_config(cls, **data):
|
||||||
|
|
@ -84,7 +95,7 @@ def set_list_ignored_providers(ignored: list[str]):
|
||||||
class Api:
|
class Api:
|
||||||
def __init__(self, app: FastAPI, g4f_api_key=None) -> None:
|
def __init__(self, app: FastAPI, g4f_api_key=None) -> None:
|
||||||
self.app = app
|
self.app = app
|
||||||
self.client = Client()
|
self.client = AsyncClient()
|
||||||
self.g4f_api_key = g4f_api_key
|
self.g4f_api_key = g4f_api_key
|
||||||
self.get_g4f_api_key = APIKeyHeader(name="g4f-api-key")
|
self.get_g4f_api_key = APIKeyHeader(name="g4f-api-key")
|
||||||
|
|
||||||
|
|
@ -133,8 +144,8 @@ class Api:
|
||||||
@self.app.get("/v1")
|
@self.app.get("/v1")
|
||||||
async def read_root_v1():
|
async def read_root_v1():
|
||||||
return HTMLResponse('g4f API: Go to '
|
return HTMLResponse('g4f API: Go to '
|
||||||
'<a href="/v1/chat/completions">chat/completions</a>, '
|
'<a href="/v1/models">models</a>, '
|
||||||
'<a href="/v1/models">models</a>, or '
|
'<a href="/v1/chat/completions">chat/completions</a>, or '
|
||||||
'<a href="/v1/images/generate">images/generate</a>.')
|
'<a href="/v1/images/generate">images/generate</a>.')
|
||||||
|
|
||||||
@self.app.get("/v1/models")
|
@self.app.get("/v1/models")
|
||||||
|
|
@ -177,31 +188,24 @@ class Api:
|
||||||
|
|
||||||
# Create the completion response
|
# Create the completion response
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
|
**filter_none(
|
||||||
**{
|
**{
|
||||||
**AppConfig.defaults,
|
"model": AppConfig.model,
|
||||||
|
"provider": AppConfig.provider,
|
||||||
|
"proxy": AppConfig.proxy,
|
||||||
**config.dict(exclude_none=True),
|
**config.dict(exclude_none=True),
|
||||||
},
|
},
|
||||||
ignored=AppConfig.ignored_providers
|
ignored=AppConfig.ignored_providers
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if the response is synchronous or asynchronous
|
|
||||||
if isinstance(response, ChatCompletion):
|
|
||||||
# Synchronous response
|
|
||||||
return JSONResponse(response.to_json())
|
|
||||||
|
|
||||||
if not config.stream:
|
if not config.stream:
|
||||||
# If the response is an iterator but not streaming, collect the result
|
response: ChatCompletion = await response
|
||||||
response_list = list(response) if isinstance(response, Iterator) else [response]
|
return JSONResponse(response.to_json())
|
||||||
return JSONResponse(response_list[0].to_json())
|
|
||||||
|
|
||||||
# Streaming response
|
|
||||||
async def async_generator(sync_gen):
|
|
||||||
for item in sync_gen:
|
|
||||||
yield item
|
|
||||||
|
|
||||||
async def streaming():
|
async def streaming():
|
||||||
try:
|
try:
|
||||||
async for chunk in async_generator(response):
|
async for chunk in response:
|
||||||
yield f"data: {json.dumps(chunk.to_json())}\n\n"
|
yield f"data: {json.dumps(chunk.to_json())}\n\n"
|
||||||
except GeneratorExit:
|
except GeneratorExit:
|
||||||
pass
|
pass
|
||||||
|
|
@ -217,30 +221,38 @@ class Api:
|
||||||
return Response(content=format_exception(e, config), status_code=500, media_type="application/json")
|
return Response(content=format_exception(e, config), status_code=500, media_type="application/json")
|
||||||
|
|
||||||
@self.app.post("/v1/images/generate")
|
@self.app.post("/v1/images/generate")
|
||||||
|
@self.app.post("/v1/images/generations")
|
||||||
async def generate_image(config: ImageGenerationConfig):
|
async def generate_image(config: ImageGenerationConfig):
|
||||||
try:
|
try:
|
||||||
response: ImagesResponse = await self.client.images.async_generate(
|
response = await self.client.images.generate(
|
||||||
prompt=config.prompt,
|
prompt=config.prompt,
|
||||||
model=config.model,
|
model=config.model,
|
||||||
response_format=config.response_format
|
provider=AppConfig.image_provider if config.provider is None else config.provider,
|
||||||
|
**filter_none(
|
||||||
|
response_format = config.response_format,
|
||||||
|
api_key = config.api_key,
|
||||||
|
proxy = config.proxy
|
||||||
)
|
)
|
||||||
# Convert Image objects to dictionaries
|
)
|
||||||
response_data = [{"url": image.url, "b64_json": image.b64_json} for image in response.data]
|
return JSONResponse(response.to_json())
|
||||||
return JSONResponse({"data": response_data})
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
return Response(content=format_exception(e, config), status_code=500, media_type="application/json")
|
return Response(content=format_exception(e, config, True), status_code=500, media_type="application/json")
|
||||||
|
|
||||||
@self.app.post("/v1/completions")
|
@self.app.post("/v1/completions")
|
||||||
async def completions():
|
async def completions():
|
||||||
return Response(content=json.dumps({'info': 'Not working yet.'}, indent=4), media_type="application/json")
|
return Response(content=json.dumps({'info': 'Not working yet.'}, indent=4), media_type="application/json")
|
||||||
|
|
||||||
def format_exception(e: Exception, config: Union[ChatCompletionsConfig, ImageGenerationConfig]) -> str:
|
def format_exception(e: Exception, config: Union[ChatCompletionsConfig, ImageGenerationConfig], image: bool = False) -> str:
|
||||||
last_provider = g4f.get_last_provider(True)
|
last_provider = {} if not image else g4f.get_last_provider(True)
|
||||||
|
provider = (AppConfig.image_provider if image else AppConfig.provider) if config.provider is None else config.provider
|
||||||
|
model = AppConfig.model if config.model is None else config.model
|
||||||
return json.dumps({
|
return json.dumps({
|
||||||
"error": {"message": f"{e.__class__.__name__}: {e}"},
|
"error": {"message": f"{e.__class__.__name__}: {e}"},
|
||||||
"model": last_provider.get("model") if last_provider else getattr(config, 'model', None),
|
"model": last_provider.get("model") if model is None else model,
|
||||||
"provider": last_provider.get("name") if last_provider else getattr(config, 'provider', None)
|
**filter_none(
|
||||||
|
provider=last_provider.get("name") if provider is None else provider
|
||||||
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
def run_api(
|
def run_api(
|
||||||
|
|
@ -250,21 +262,19 @@ def run_api(
|
||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
workers: int = None,
|
workers: int = None,
|
||||||
use_colors: bool = None,
|
use_colors: bool = None,
|
||||||
g4f_api_key: str = None
|
reload: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
print(f'Starting server... [g4f v-{g4f.version.utils.current_version}]' + (" (debug)" if debug else ""))
|
print(f'Starting server... [g4f v-{g4f.version.utils.current_version}]' + (" (debug)" if debug else ""))
|
||||||
if use_colors is None:
|
if use_colors is None:
|
||||||
use_colors = debug
|
use_colors = debug
|
||||||
if bind is not None:
|
if bind is not None:
|
||||||
host, port = bind.split(":")
|
host, port = bind.split(":")
|
||||||
if debug:
|
|
||||||
g4f.debug.logging = True
|
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
"g4f.api:create_app",
|
f"g4f.api:create_app{'_debug' if debug else ''}",
|
||||||
host=host,
|
host=host,
|
||||||
port=int(port),
|
port=int(port),
|
||||||
workers=workers,
|
workers=workers,
|
||||||
use_colors=use_colors,
|
use_colors=use_colors,
|
||||||
factory=True,
|
factory=True,
|
||||||
reload=debug
|
reload=reload
|
||||||
)
|
)
|
||||||
27
g4f/cli.py
27
g4f/cli.py
|
|
@ -11,16 +11,19 @@ def main():
|
||||||
api_parser = subparsers.add_parser("api")
|
api_parser = subparsers.add_parser("api")
|
||||||
api_parser.add_argument("--bind", default="0.0.0.0:1337", help="The bind string.")
|
api_parser.add_argument("--bind", default="0.0.0.0:1337", help="The bind string.")
|
||||||
api_parser.add_argument("--debug", action="store_true", help="Enable verbose logging.")
|
api_parser.add_argument("--debug", action="store_true", help="Enable verbose logging.")
|
||||||
api_parser.add_argument("--model", default=None, help="Default model for chat completion. (incompatible with --debug and --workers)")
|
api_parser.add_argument("--model", default=None, help="Default model for chat completion. (incompatible with --reload and --workers)")
|
||||||
api_parser.add_argument("--provider", choices=[provider.__name__ for provider in Provider.__providers__ if provider.working],
|
api_parser.add_argument("--provider", choices=[provider.__name__ for provider in Provider.__providers__ if provider.working],
|
||||||
default=None, help="Default provider for chat completion. (incompatible with --debug and --workers)")
|
default=None, help="Default provider for chat completion. (incompatible with --reload and --workers)")
|
||||||
api_parser.add_argument("--proxy", default=None, help="Default used proxy.")
|
api_parser.add_argument("--image-provider", choices=[provider.__name__ for provider in Provider.__providers__ if provider.working and hasattr(provider, "image_models")],
|
||||||
|
default=None, help="Default provider for image generation. (incompatible with --reload and --workers)"),
|
||||||
|
api_parser.add_argument("--proxy", default=None, help="Default used proxy. (incompatible with --reload and --workers)")
|
||||||
api_parser.add_argument("--workers", type=int, default=None, help="Number of workers.")
|
api_parser.add_argument("--workers", type=int, default=None, help="Number of workers.")
|
||||||
api_parser.add_argument("--disable-colors", action="store_true", help="Don't use colors.")
|
api_parser.add_argument("--disable-colors", action="store_true", help="Don't use colors.")
|
||||||
api_parser.add_argument("--ignore-cookie-files", action="store_true", help="Don't read .har and cookie files.")
|
api_parser.add_argument("--ignore-cookie-files", action="store_true", help="Don't read .har and cookie files. (incompatible with --reload and --workers)")
|
||||||
api_parser.add_argument("--g4f-api-key", type=str, default=None, help="Sets an authentication key for your API. (incompatible with --debug and --workers)")
|
api_parser.add_argument("--g4f-api-key", type=str, default=None, help="Sets an authentication key for your API. (incompatible with --reload and --workers)")
|
||||||
api_parser.add_argument("--ignored-providers", nargs="+", choices=[provider.__name__ for provider in Provider.__providers__ if provider.working],
|
api_parser.add_argument("--ignored-providers", nargs="+", choices=[provider.__name__ for provider in Provider.__providers__ if provider.working],
|
||||||
default=[], help="List of providers to ignore when processing request. (incompatible with --debug and --workers)")
|
default=[], help="List of providers to ignore when processing request. (incompatible with --reload and --workers)")
|
||||||
|
api_parser.add_argument("--reload", action="store_true", help="Enable reloading.")
|
||||||
subparsers.add_parser("gui", parents=[gui_parser()], add_help=False)
|
subparsers.add_parser("gui", parents=[gui_parser()], add_help=False)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
@ -39,17 +42,17 @@ def run_api_args(args):
|
||||||
ignore_cookie_files=args.ignore_cookie_files,
|
ignore_cookie_files=args.ignore_cookie_files,
|
||||||
ignored_providers=args.ignored_providers,
|
ignored_providers=args.ignored_providers,
|
||||||
g4f_api_key=args.g4f_api_key,
|
g4f_api_key=args.g4f_api_key,
|
||||||
defaults={
|
provider=args.provider,
|
||||||
"model": args.model,
|
image_provider=args.image_provider,
|
||||||
"provider": args.provider,
|
proxy=args.proxy,
|
||||||
"proxy": args.proxy
|
model=args.model
|
||||||
}
|
|
||||||
)
|
)
|
||||||
run_api(
|
run_api(
|
||||||
bind=args.bind,
|
bind=args.bind,
|
||||||
debug=args.debug,
|
debug=args.debug,
|
||||||
workers=args.workers,
|
workers=args.workers,
|
||||||
use_colors=not args.disable_colors
|
use_colors=not args.disable_colors,
|
||||||
|
reload=args.reload
|
||||||
)
|
)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -1,2 +1,468 @@
|
||||||
from .stubs import ChatCompletion, ChatCompletionChunk, ImagesResponse
|
from __future__ import annotations
|
||||||
from .client import Client, AsyncClient
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
import string
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import aiohttp
|
||||||
|
import logging
|
||||||
|
from typing import Union, AsyncIterator, Iterator, Coroutine
|
||||||
|
|
||||||
|
from ..providers.base_provider import AsyncGeneratorProvider
|
||||||
|
from ..image import ImageResponse, to_image, to_data_uri, is_accepted_format, EXTENSIONS_MAP
|
||||||
|
from ..typing import Messages, Cookies, Image
|
||||||
|
from ..providers.types import ProviderType, FinishReason, BaseConversation
|
||||||
|
from ..errors import NoImageResponseError
|
||||||
|
from ..providers.retry_provider import IterListProvider
|
||||||
|
from ..Provider.needs_auth.BingCreateImages import BingCreateImages
|
||||||
|
from ..requests.aiohttp import get_connector
|
||||||
|
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
|
||||||
|
from .image_models import ImageModels
|
||||||
|
from .types import IterResponse, ImageProvider, Client as BaseClient
|
||||||
|
from .service import get_model_and_provider, get_last_provider, convert_to_provider
|
||||||
|
from .helper import find_stop, filter_json, filter_none, safe_aclose, to_sync_iter, to_async_iterator
|
||||||
|
|
||||||
|
try:
|
||||||
|
anext # Python 3.8+
|
||||||
|
except NameError:
|
||||||
|
async def anext(aiter):
|
||||||
|
try:
|
||||||
|
return await aiter.__anext__()
|
||||||
|
except StopAsyncIteration:
|
||||||
|
raise StopIteration
|
||||||
|
|
||||||
|
# Synchronous iter_response function
|
||||||
|
def iter_response(
|
||||||
|
response: Union[Iterator[str], AsyncIterator[str]],
|
||||||
|
stream: bool,
|
||||||
|
response_format: dict = None,
|
||||||
|
max_tokens: int = None,
|
||||||
|
stop: list = None
|
||||||
|
) -> Iterator[Union[ChatCompletion, ChatCompletionChunk]]:
|
||||||
|
content = ""
|
||||||
|
finish_reason = None
|
||||||
|
completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
|
||||||
|
idx = 0
|
||||||
|
|
||||||
|
if hasattr(response, '__aiter__'):
|
||||||
|
# It's an async iterator, wrap it into a sync iterator
|
||||||
|
response = to_sync_iter(response)
|
||||||
|
|
||||||
|
for chunk in response:
|
||||||
|
if isinstance(chunk, FinishReason):
|
||||||
|
finish_reason = chunk.reason
|
||||||
|
break
|
||||||
|
elif isinstance(chunk, BaseConversation):
|
||||||
|
yield chunk
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunk = str(chunk)
|
||||||
|
content += chunk
|
||||||
|
|
||||||
|
if max_tokens is not None and idx + 1 >= max_tokens:
|
||||||
|
finish_reason = "length"
|
||||||
|
|
||||||
|
first, content, chunk = find_stop(stop, content, chunk if stream else None)
|
||||||
|
|
||||||
|
if first != -1:
|
||||||
|
finish_reason = "stop"
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
yield ChatCompletionChunk(chunk, None, completion_id, int(time.time()))
|
||||||
|
|
||||||
|
if finish_reason is not None:
|
||||||
|
break
|
||||||
|
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
finish_reason = "stop" if finish_reason is None else finish_reason
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
yield ChatCompletionChunk(None, finish_reason, completion_id, int(time.time()))
|
||||||
|
else:
|
||||||
|
if response_format is not None and "type" in response_format:
|
||||||
|
if response_format["type"] == "json_object":
|
||||||
|
content = filter_json(content)
|
||||||
|
yield ChatCompletion(content, finish_reason, completion_id, int(time.time()))
|
||||||
|
|
||||||
|
# Synchronous iter_append_model_and_provider function
|
||||||
|
def iter_append_model_and_provider(response: Iterator[ChatCompletionChunk]) -> Iterator[ChatCompletionChunk]:
|
||||||
|
last_provider = None
|
||||||
|
|
||||||
|
for chunk in response:
|
||||||
|
last_provider = get_last_provider(True) if last_provider is None else last_provider
|
||||||
|
chunk.model = last_provider.get("model")
|
||||||
|
chunk.provider = last_provider.get("name")
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def async_iter_response(
|
||||||
|
response: AsyncIterator[str],
|
||||||
|
stream: bool,
|
||||||
|
response_format: dict = None,
|
||||||
|
max_tokens: int = None,
|
||||||
|
stop: list = None
|
||||||
|
) -> AsyncIterator[Union[ChatCompletion, ChatCompletionChunk]]:
|
||||||
|
content = ""
|
||||||
|
finish_reason = None
|
||||||
|
completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
|
||||||
|
idx = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for chunk in response:
|
||||||
|
if isinstance(chunk, FinishReason):
|
||||||
|
finish_reason = chunk.reason
|
||||||
|
break
|
||||||
|
elif isinstance(chunk, BaseConversation):
|
||||||
|
yield chunk
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunk = str(chunk)
|
||||||
|
content += chunk
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
if max_tokens is not None and idx >= max_tokens:
|
||||||
|
finish_reason = "length"
|
||||||
|
|
||||||
|
first, content, chunk = find_stop(stop, content, chunk if stream else None)
|
||||||
|
|
||||||
|
if first != -1:
|
||||||
|
finish_reason = "stop"
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
yield ChatCompletionChunk(chunk, None, completion_id, int(time.time()))
|
||||||
|
|
||||||
|
if finish_reason is not None:
|
||||||
|
break
|
||||||
|
|
||||||
|
finish_reason = "stop" if finish_reason is None else finish_reason
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
yield ChatCompletionChunk(None, finish_reason, completion_id, int(time.time()))
|
||||||
|
else:
|
||||||
|
if response_format is not None and "type" in response_format:
|
||||||
|
if response_format["type"] == "json_object":
|
||||||
|
content = filter_json(content)
|
||||||
|
yield ChatCompletion(content, finish_reason, completion_id, int(time.time()))
|
||||||
|
finally:
|
||||||
|
if hasattr(response, 'aclose'):
|
||||||
|
await safe_aclose(response)
|
||||||
|
|
||||||
|
async def async_iter_append_model_and_provider(response: AsyncIterator[ChatCompletionChunk]) -> AsyncIterator:
|
||||||
|
last_provider = None
|
||||||
|
try:
|
||||||
|
async for chunk in response:
|
||||||
|
last_provider = get_last_provider(True) if last_provider is None else last_provider
|
||||||
|
chunk.model = last_provider.get("model")
|
||||||
|
chunk.provider = last_provider.get("name")
|
||||||
|
yield chunk
|
||||||
|
finally:
|
||||||
|
if hasattr(response, 'aclose'):
|
||||||
|
await safe_aclose(response)
|
||||||
|
|
||||||
|
class Client(BaseClient):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
provider: ProviderType = None,
|
||||||
|
image_provider: ImageProvider = None,
|
||||||
|
**kwargs
|
||||||
|
) -> None:
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.chat: Chat = Chat(self, provider)
|
||||||
|
self.images: Images = Images(self, image_provider)
|
||||||
|
|
||||||
|
class Completions:
|
||||||
|
def __init__(self, client: Client, provider: ProviderType = None):
|
||||||
|
self.client: Client = client
|
||||||
|
self.provider: ProviderType = provider
|
||||||
|
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
messages: Messages,
|
||||||
|
model: str,
|
||||||
|
provider: ProviderType = None,
|
||||||
|
stream: bool = False,
|
||||||
|
proxy: str = None,
|
||||||
|
response_format: dict = None,
|
||||||
|
max_tokens: int = None,
|
||||||
|
stop: Union[list[str], str] = None,
|
||||||
|
api_key: str = None,
|
||||||
|
ignored: list[str] = None,
|
||||||
|
ignore_working: bool = False,
|
||||||
|
ignore_stream: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> IterResponse:
|
||||||
|
model, provider = get_model_and_provider(
|
||||||
|
model,
|
||||||
|
self.provider if provider is None else provider,
|
||||||
|
stream,
|
||||||
|
ignored,
|
||||||
|
ignore_working,
|
||||||
|
ignore_stream,
|
||||||
|
)
|
||||||
|
stop = [stop] if isinstance(stop, str) else stop
|
||||||
|
|
||||||
|
response = provider.create_completion(
|
||||||
|
model,
|
||||||
|
messages,
|
||||||
|
stream=stream,
|
||||||
|
**filter_none(
|
||||||
|
proxy=self.client.proxy if proxy is None else proxy,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
stop=stop,
|
||||||
|
api_key=self.client.api_key if api_key is None else api_key
|
||||||
|
),
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
if asyncio.iscoroutinefunction(provider.create_completion):
|
||||||
|
# Run the asynchronous function in an event loop
|
||||||
|
response = asyncio.run(response)
|
||||||
|
if stream and hasattr(response, '__aiter__'):
|
||||||
|
# It's an async generator, wrap it into a sync iterator
|
||||||
|
response = to_sync_iter(response)
|
||||||
|
elif hasattr(response, '__aiter__'):
|
||||||
|
# If response is an async generator, collect it into a list
|
||||||
|
response = list(to_sync_iter(response))
|
||||||
|
response = iter_response(response, stream, response_format, max_tokens, stop)
|
||||||
|
response = iter_append_model_and_provider(response)
|
||||||
|
if stream:
|
||||||
|
return response
|
||||||
|
else:
|
||||||
|
return next(response)
|
||||||
|
|
||||||
|
class Chat:
|
||||||
|
completions: Completions
|
||||||
|
|
||||||
|
def __init__(self, client: Client, provider: ProviderType = None):
|
||||||
|
self.completions = Completions(client, provider)
|
||||||
|
|
||||||
|
class Images:
|
||||||
|
def __init__(self, client: Client, provider: ProviderType = None):
|
||||||
|
self.client: Client = client
|
||||||
|
self.provider: ProviderType = provider
|
||||||
|
self.models: ImageModels = ImageModels(client)
|
||||||
|
|
||||||
|
def generate(self, prompt: str, model: str = None, provider: ProviderType = None, response_format: str = "url", proxy: str = None, **kwargs) -> ImagesResponse:
|
||||||
|
"""
|
||||||
|
Synchronous generate method that runs the async_generate method in an event loop.
|
||||||
|
"""
|
||||||
|
return asyncio.run(self.async_generate(prompt, model, provider, response_format=response_format, proxy=proxy **kwargs))
|
||||||
|
|
||||||
|
async def async_generate(self, prompt: str, model: str = None, provider: ProviderType = None, response_format: str = "url", proxy: str = None, **kwargs) -> ImagesResponse:
|
||||||
|
if provider is None:
|
||||||
|
provider_handler = self.models.get(model, provider or self.provider or BingCreateImages)
|
||||||
|
elif isinstance(provider, str):
|
||||||
|
provider_handler = convert_to_provider(provider)
|
||||||
|
if provider_handler is None:
|
||||||
|
raise ValueError(f"Unknown model: {model}")
|
||||||
|
if proxy is None:
|
||||||
|
proxy = self.client.proxy
|
||||||
|
|
||||||
|
if isinstance(provider_handler, IterListProvider):
|
||||||
|
if provider_handler.providers:
|
||||||
|
provider_handler = provider.providers[0]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"IterListProvider for model {model} has no providers")
|
||||||
|
|
||||||
|
response = None
|
||||||
|
if hasattr(provider_handler, "create_async_generator"):
|
||||||
|
messages = [{"role": "user", "content": prompt}]
|
||||||
|
async for item in provider_handler.create_async_generator(model, messages, **kwargs):
|
||||||
|
if isinstance(item, ImageResponse):
|
||||||
|
response = item
|
||||||
|
break
|
||||||
|
elif hasattr(provider, 'create'):
|
||||||
|
if asyncio.iscoroutinefunction(provider_handler.create):
|
||||||
|
response = await provider_handler.create(prompt)
|
||||||
|
else:
|
||||||
|
response = provider_handler.create(prompt)
|
||||||
|
if isinstance(response, str):
|
||||||
|
response = ImageResponse([response], prompt)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Provider {provider} does not support image generation")
|
||||||
|
if isinstance(response, ImageResponse):
|
||||||
|
return await self._process_image_response(response, response_format, proxy, model=model, provider=provider)
|
||||||
|
|
||||||
|
raise NoImageResponseError(f"Unexpected response type: {type(response)}")
|
||||||
|
|
||||||
|
async def _process_image_response(self, response: ImageResponse, response_format: str, proxy: str = None, model: str = None, provider: str = None) -> ImagesResponse:
|
||||||
|
async def process_image_item(session: aiohttp.ClientSession, image_data: str):
|
||||||
|
if image_data.startswith('http://') or image_data.startswith('https://'):
|
||||||
|
if response_format == "url":
|
||||||
|
return Image(url=image_data, revised_prompt=response.alt)
|
||||||
|
elif response_format == "b64_json":
|
||||||
|
# Fetch the image data and convert it to base64
|
||||||
|
image_content = await self._fetch_image(session, image_data)
|
||||||
|
file_name = self._save_image(image_data_bytes)
|
||||||
|
b64_json = base64.b64encode(image_content).decode('utf-8')
|
||||||
|
return Image(b64_json=b64_json, url=file_name, revised_prompt=response.alt)
|
||||||
|
else:
|
||||||
|
# Assume image_data is base64 data or binary
|
||||||
|
if response_format == "url":
|
||||||
|
if image_data.startswith('data:image'):
|
||||||
|
# Remove the data URL scheme and get the base64 data
|
||||||
|
base64_data = image_data.split(',', 1)[-1]
|
||||||
|
else:
|
||||||
|
base64_data = image_data
|
||||||
|
# Decode the base64 data
|
||||||
|
image_data_bytes = base64.b64decode(base64_data)
|
||||||
|
# Convert bytes to an image
|
||||||
|
file_name = self._save_image(image_data_bytes)
|
||||||
|
return Image(url=file_name, revised_prompt=response.alt)
|
||||||
|
elif response_format == "b64_json":
|
||||||
|
if isinstance(image_data, bytes):
|
||||||
|
file_name = self._save_image(image_data_bytes)
|
||||||
|
b64_json = base64.b64encode(image_data).decode('utf-8')
|
||||||
|
else:
|
||||||
|
b64_json = image_data # If already base64-encoded string
|
||||||
|
return Image(b64_json=b64_json, url=file_name, revised_prompt=response.alt)
|
||||||
|
|
||||||
|
last_provider = get_last_provider(True)
|
||||||
|
async with aiohttp.ClientSession(cookies=response.get("cookies"), connector=get_connector(proxy=proxy)) as session:
|
||||||
|
return ImagesResponse(
|
||||||
|
await asyncio.gather(*[process_image_item(session, image_data) for image_data in response.get_list()]),
|
||||||
|
model=last_provider.get("model") if model is None else model,
|
||||||
|
provider=last_provider.get("name") if provider is None else provider
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _fetch_image(self, session: aiohttp.ClientSession, url: str) -> bytes:
|
||||||
|
# Asynchronously fetch image data from the URL
|
||||||
|
async with session.get(url) as resp:
|
||||||
|
if resp.status == 200:
|
||||||
|
return await resp.read()
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Failed to fetch image from {url}, status code {resp.status}")
|
||||||
|
|
||||||
|
def _save_image(self, image_data_bytes: bytes) -> str:
|
||||||
|
os.makedirs('generated_images', exist_ok=True)
|
||||||
|
image = to_image(image_data_bytes)
|
||||||
|
file_name = f"generated_images/image_{int(time.time())}_{random.randint(0, 10000)}.{EXTENSIONS_MAP[is_accepted_format(image_data_bytes)]}"
|
||||||
|
image.save(file_name)
|
||||||
|
return file_name
|
||||||
|
|
||||||
|
def create_variation(self, image: Union[str, bytes], model: str = None, provider: ProviderType = None, response_format: str = "url", **kwargs) -> ImagesResponse:
|
||||||
|
return asyncio.run(self.async_create_variation(
|
||||||
|
image, model, provider, response_format
|
||||||
|
**kwargs
|
||||||
|
))
|
||||||
|
|
||||||
|
async def async_create_variation(self, image: Union[str, bytes], model: str = None, provider: ProviderType = None, response_format: str = "url", proxy: str = None, **kwargs) -> ImagesResponse:
|
||||||
|
if provider is None:
|
||||||
|
provider = self.models.get(model, provider or self.provider or BingCreateImages)
|
||||||
|
if provider is None:
|
||||||
|
raise ValueError(f"Unknown model: {model}")
|
||||||
|
if isinstance(provider, str):
|
||||||
|
provider = convert_to_provider(provider)
|
||||||
|
if proxy is None:
|
||||||
|
proxy = self.client.proxy
|
||||||
|
|
||||||
|
if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
|
||||||
|
messages = [{"role": "user", "content": "create a variation of this image"}]
|
||||||
|
image_data = to_data_uri(image)
|
||||||
|
generator = None
|
||||||
|
try:
|
||||||
|
generator = provider.create_async_generator(model, messages, image=image_data, response_format=response_format, proxy=proxy, **kwargs)
|
||||||
|
async for response in generator:
|
||||||
|
if isinstance(response, ImageResponse):
|
||||||
|
return self._process_image_response(response)
|
||||||
|
except RuntimeError as e:
|
||||||
|
if "async generator ignored GeneratorExit" in str(e):
|
||||||
|
logging.warning("Generator ignored GeneratorExit in create_variation, handling gracefully")
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
if generator and hasattr(generator, 'aclose'):
|
||||||
|
await safe_aclose(generator)
|
||||||
|
logging.info("AsyncGeneratorProvider processing completed in create_variation")
|
||||||
|
elif hasattr(provider, 'create_variation'):
|
||||||
|
if asyncio.iscoroutinefunction(provider.create_variation):
|
||||||
|
response = await provider.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs)
|
||||||
|
else:
|
||||||
|
response = provider.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs)
|
||||||
|
if isinstance(response, str):
|
||||||
|
response = ImageResponse([response])
|
||||||
|
return self._process_image_response(response)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Provider {provider} does not support image variation")
|
||||||
|
|
||||||
|
class AsyncClient(BaseClient):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
provider: ProviderType = None,
|
||||||
|
image_provider: ImageProvider = None,
|
||||||
|
**kwargs
|
||||||
|
) -> None:
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.chat: AsyncChat = AsyncChat(self, provider)
|
||||||
|
self.images: AsyncImages = AsyncImages(self, image_provider)
|
||||||
|
|
||||||
|
class AsyncChat:
|
||||||
|
completions: AsyncCompletions
|
||||||
|
|
||||||
|
def __init__(self, client: AsyncClient, provider: ProviderType = None):
|
||||||
|
self.completions = AsyncCompletions(client, provider)
|
||||||
|
|
||||||
|
class AsyncCompletions:
|
||||||
|
def __init__(self, client: AsyncClient, provider: ProviderType = None):
|
||||||
|
self.client: AsyncClient = client
|
||||||
|
self.provider: ProviderType = provider
|
||||||
|
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
messages: Messages,
|
||||||
|
model: str,
|
||||||
|
provider: ProviderType = None,
|
||||||
|
stream: bool = False,
|
||||||
|
proxy: str = None,
|
||||||
|
response_format: dict = None,
|
||||||
|
max_tokens: int = None,
|
||||||
|
stop: Union[list[str], str] = None,
|
||||||
|
api_key: str = None,
|
||||||
|
ignored: list[str] = None,
|
||||||
|
ignore_working: bool = False,
|
||||||
|
ignore_stream: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Union[Coroutine[ChatCompletion], AsyncIterator[ChatCompletionChunk]]:
|
||||||
|
model, provider = get_model_and_provider(
|
||||||
|
model,
|
||||||
|
self.provider if provider is None else provider,
|
||||||
|
stream,
|
||||||
|
ignored,
|
||||||
|
ignore_working,
|
||||||
|
ignore_stream,
|
||||||
|
)
|
||||||
|
stop = [stop] if isinstance(stop, str) else stop
|
||||||
|
|
||||||
|
response = provider.create_completion(
|
||||||
|
model,
|
||||||
|
messages,
|
||||||
|
stream=stream,
|
||||||
|
**filter_none(
|
||||||
|
proxy=self.client.proxy if proxy is None else proxy,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
stop=stop,
|
||||||
|
api_key=self.client.api_key if api_key is None else api_key
|
||||||
|
),
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(response, AsyncIterator):
|
||||||
|
response = to_async_iterator(response)
|
||||||
|
response = async_iter_response(response, stream, response_format, max_tokens, stop)
|
||||||
|
response = async_iter_append_model_and_provider(response)
|
||||||
|
return response if stream else anext(response)
|
||||||
|
|
||||||
|
class AsyncImages(Images):
|
||||||
|
def __init__(self, client: AsyncClient, provider: ImageProvider = None):
|
||||||
|
self.client: AsyncClient = client
|
||||||
|
self.provider: ImageProvider = provider
|
||||||
|
self.models: ImageModels = ImageModels(client)
|
||||||
|
|
||||||
|
async def generate(self, prompt: str, model: str = None, provider: ProviderType = None, response_format: str = "url", **kwargs) -> ImagesResponse:
|
||||||
|
return await self.async_generate(prompt, model, provider, response_format, **kwargs)
|
||||||
|
|
||||||
|
async def create_variation(self, image: Union[str, bytes], model: str = None, provider: ProviderType = None, response_format: str = "url", **kwargs) -> ImagesResponse:
|
||||||
|
return await self.async_create_variation(
|
||||||
|
image, model, provider, response_format, **kwargs
|
||||||
|
)
|
||||||
|
|
@ -1,541 +0,0 @@
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
import random
|
|
||||||
import string
|
|
||||||
import threading
|
|
||||||
import asyncio
|
|
||||||
import base64
|
|
||||||
import aiohttp
|
|
||||||
import queue
|
|
||||||
from typing import Union, AsyncIterator, Iterator
|
|
||||||
|
|
||||||
from ..providers.base_provider import AsyncGeneratorProvider
|
|
||||||
from ..image import ImageResponse, to_image, to_data_uri
|
|
||||||
from ..typing import Messages, ImageType
|
|
||||||
from ..providers.types import BaseProvider, ProviderType, FinishReason
|
|
||||||
from ..providers.conversation import BaseConversation
|
|
||||||
from ..image import ImageResponse as ImageProviderResponse
|
|
||||||
from ..errors import NoImageResponseError
|
|
||||||
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
|
|
||||||
from .image_models import ImageModels
|
|
||||||
from .types import IterResponse, ImageProvider
|
|
||||||
from .types import Client as BaseClient
|
|
||||||
from .service import get_model_and_provider, get_last_provider
|
|
||||||
from .helper import find_stop, filter_json, filter_none
|
|
||||||
from ..models import ModelUtils
|
|
||||||
from ..Provider import IterListProvider
|
|
||||||
|
|
||||||
# Helper function to convert an async generator to a synchronous iterator
|
|
||||||
def to_sync_iter(async_gen: AsyncIterator) -> Iterator:
|
|
||||||
q = queue.Queue()
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
done = object()
|
|
||||||
|
|
||||||
def _run():
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
async def iterate():
|
|
||||||
try:
|
|
||||||
async for item in async_gen:
|
|
||||||
q.put(item)
|
|
||||||
finally:
|
|
||||||
q.put(done)
|
|
||||||
|
|
||||||
loop.run_until_complete(iterate())
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
threading.Thread(target=_run).start()
|
|
||||||
|
|
||||||
while True:
|
|
||||||
item = q.get()
|
|
||||||
if item is done:
|
|
||||||
break
|
|
||||||
yield item
|
|
||||||
|
|
||||||
# Helper function to convert a synchronous iterator to an async iterator
|
|
||||||
async def to_async_iterator(iterator):
|
|
||||||
for item in iterator:
|
|
||||||
yield item
|
|
||||||
|
|
||||||
# Synchronous iter_response function
|
|
||||||
def iter_response(
|
|
||||||
response: Union[Iterator[str], AsyncIterator[str]],
|
|
||||||
stream: bool,
|
|
||||||
response_format: dict = None,
|
|
||||||
max_tokens: int = None,
|
|
||||||
stop: list = None
|
|
||||||
) -> Iterator[Union[ChatCompletion, ChatCompletionChunk]]:
|
|
||||||
content = ""
|
|
||||||
finish_reason = None
|
|
||||||
completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
|
|
||||||
idx = 0
|
|
||||||
|
|
||||||
if hasattr(response, '__aiter__'):
|
|
||||||
# It's an async iterator, wrap it into a sync iterator
|
|
||||||
response = to_sync_iter(response)
|
|
||||||
|
|
||||||
for chunk in response:
|
|
||||||
if isinstance(chunk, FinishReason):
|
|
||||||
finish_reason = chunk.reason
|
|
||||||
break
|
|
||||||
elif isinstance(chunk, BaseConversation):
|
|
||||||
yield chunk
|
|
||||||
continue
|
|
||||||
|
|
||||||
content += str(chunk)
|
|
||||||
|
|
||||||
if max_tokens is not None and idx + 1 >= max_tokens:
|
|
||||||
finish_reason = "length"
|
|
||||||
|
|
||||||
first, content, chunk = find_stop(stop, content, chunk if stream else None)
|
|
||||||
|
|
||||||
if first != -1:
|
|
||||||
finish_reason = "stop"
|
|
||||||
|
|
||||||
if stream:
|
|
||||||
yield ChatCompletionChunk(chunk, None, completion_id, int(time.time()))
|
|
||||||
|
|
||||||
if finish_reason is not None:
|
|
||||||
break
|
|
||||||
|
|
||||||
idx += 1
|
|
||||||
|
|
||||||
finish_reason = "stop" if finish_reason is None else finish_reason
|
|
||||||
|
|
||||||
if stream:
|
|
||||||
yield ChatCompletionChunk(None, finish_reason, completion_id, int(time.time()))
|
|
||||||
else:
|
|
||||||
if response_format is not None and "type" in response_format:
|
|
||||||
if response_format["type"] == "json_object":
|
|
||||||
content = filter_json(content)
|
|
||||||
yield ChatCompletion(content, finish_reason, completion_id, int(time.time()))
|
|
||||||
|
|
||||||
# Synchronous iter_append_model_and_provider function
|
|
||||||
def iter_append_model_and_provider(response: Iterator) -> Iterator:
|
|
||||||
last_provider = None
|
|
||||||
|
|
||||||
for chunk in response:
|
|
||||||
last_provider = get_last_provider(True) if last_provider is None else last_provider
|
|
||||||
chunk.model = last_provider.get("model")
|
|
||||||
chunk.provider = last_provider.get("name")
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
class Client(BaseClient):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
provider: ProviderType = None,
|
|
||||||
image_provider: ImageProvider = None,
|
|
||||||
**kwargs
|
|
||||||
) -> None:
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.chat: Chat = Chat(self, provider)
|
|
||||||
self._images: Images = Images(self, image_provider)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def images(self) -> Images:
|
|
||||||
return self._images
|
|
||||||
|
|
||||||
async def async_images(self) -> Images:
|
|
||||||
return self._images
|
|
||||||
|
|
||||||
# For backwards compatibility and legacy purposes, use Client instead
|
|
||||||
class AsyncClient(Client):
|
|
||||||
"""Legacy AsyncClient that redirects to the main Client class.
|
|
||||||
This class exists for backwards compatibility."""
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
import warnings
|
|
||||||
warnings.warn(
|
|
||||||
"AsyncClient is deprecated and will be removed in future versions."
|
|
||||||
"Use Client instead, which now supports both sync and async operations.",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2
|
|
||||||
)
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
async def async_create(self, *args, **kwargs):
|
|
||||||
"""Asynchronous create method that calls the synchronous method."""
|
|
||||||
return await super().async_create(*args, **kwargs)
|
|
||||||
|
|
||||||
async def async_generate(self, *args, **kwargs):
|
|
||||||
"""Asynchronous image generation method."""
|
|
||||||
return await super().async_generate(*args, **kwargs)
|
|
||||||
|
|
||||||
async def async_images(self) -> Images:
|
|
||||||
"""Asynchronous access to images."""
|
|
||||||
return await super().async_images()
|
|
||||||
|
|
||||||
async def async_fetch_image(self, url: str) -> bytes:
|
|
||||||
"""Asynchronous fetching of an image by URL."""
|
|
||||||
return await self._fetch_image(url)
|
|
||||||
|
|
||||||
class Completions:
|
|
||||||
def __init__(self, client: Client, provider: ProviderType = None):
|
|
||||||
self.client: Client = client
|
|
||||||
self.provider: ProviderType = provider
|
|
||||||
|
|
||||||
def create(
|
|
||||||
self,
|
|
||||||
messages: Messages,
|
|
||||||
model: str,
|
|
||||||
provider: ProviderType = None,
|
|
||||||
stream: bool = False,
|
|
||||||
proxy: str = None,
|
|
||||||
response_format: dict = None,
|
|
||||||
max_tokens: int = None,
|
|
||||||
stop: Union[list[str], str] = None,
|
|
||||||
api_key: str = None,
|
|
||||||
ignored: list[str] = None,
|
|
||||||
ignore_working: bool = False,
|
|
||||||
ignore_stream: bool = False,
|
|
||||||
**kwargs
|
|
||||||
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
|
|
||||||
model, provider = get_model_and_provider(
|
|
||||||
model,
|
|
||||||
self.provider if provider is None else provider,
|
|
||||||
stream,
|
|
||||||
ignored,
|
|
||||||
ignore_working,
|
|
||||||
ignore_stream,
|
|
||||||
)
|
|
||||||
|
|
||||||
stop = [stop] if isinstance(stop, str) else stop
|
|
||||||
|
|
||||||
if asyncio.iscoroutinefunction(provider.create_completion):
|
|
||||||
# Run the asynchronous function in an event loop
|
|
||||||
response = asyncio.run(provider.create_completion(
|
|
||||||
model,
|
|
||||||
messages,
|
|
||||||
stream=stream,
|
|
||||||
**filter_none(
|
|
||||||
proxy=self.client.get_proxy() if proxy is None else proxy,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
stop=stop,
|
|
||||||
api_key=self.client.api_key if api_key is None else api_key
|
|
||||||
),
|
|
||||||
**kwargs
|
|
||||||
))
|
|
||||||
else:
|
|
||||||
response = provider.create_completion(
|
|
||||||
model,
|
|
||||||
messages,
|
|
||||||
stream=stream,
|
|
||||||
**filter_none(
|
|
||||||
proxy=self.client.get_proxy() if proxy is None else proxy,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
stop=stop,
|
|
||||||
api_key=self.client.api_key if api_key is None else api_key
|
|
||||||
),
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if stream:
|
|
||||||
if hasattr(response, '__aiter__'):
|
|
||||||
# It's an async generator, wrap it into a sync iterator
|
|
||||||
response = to_sync_iter(response)
|
|
||||||
|
|
||||||
# Now 'response' is an iterator
|
|
||||||
response = iter_response(response, stream, response_format, max_tokens, stop)
|
|
||||||
response = iter_append_model_and_provider(response)
|
|
||||||
return response
|
|
||||||
else:
|
|
||||||
if hasattr(response, '__aiter__'):
|
|
||||||
# If response is an async generator, collect it into a list
|
|
||||||
response = list(to_sync_iter(response))
|
|
||||||
response = iter_response(response, stream, response_format, max_tokens, stop)
|
|
||||||
response = iter_append_model_and_provider(response)
|
|
||||||
return next(response)
|
|
||||||
|
|
||||||
async def async_create(
|
|
||||||
self,
|
|
||||||
messages: Messages,
|
|
||||||
model: str,
|
|
||||||
provider: ProviderType = None,
|
|
||||||
stream: bool = False,
|
|
||||||
proxy: str = None,
|
|
||||||
response_format: dict = None,
|
|
||||||
max_tokens: int = None,
|
|
||||||
stop: Union[list[str], str] = None,
|
|
||||||
api_key: str = None,
|
|
||||||
ignored: list[str] = None,
|
|
||||||
ignore_working: bool = False,
|
|
||||||
ignore_stream: bool = False,
|
|
||||||
**kwargs
|
|
||||||
) -> Union[ChatCompletion, AsyncIterator[ChatCompletionChunk]]:
|
|
||||||
model, provider = get_model_and_provider(
|
|
||||||
model,
|
|
||||||
self.provider if provider is None else provider,
|
|
||||||
stream,
|
|
||||||
ignored,
|
|
||||||
ignore_working,
|
|
||||||
ignore_stream,
|
|
||||||
)
|
|
||||||
|
|
||||||
stop = [stop] if isinstance(stop, str) else stop
|
|
||||||
|
|
||||||
if asyncio.iscoroutinefunction(provider.create_completion):
|
|
||||||
response = await provider.create_completion(
|
|
||||||
model,
|
|
||||||
messages,
|
|
||||||
stream=stream,
|
|
||||||
**filter_none(
|
|
||||||
proxy=self.client.get_proxy() if proxy is None else proxy,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
stop=stop,
|
|
||||||
api_key=self.client.api_key if api_key is None else api_key
|
|
||||||
),
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
response = provider.create_completion(
|
|
||||||
model,
|
|
||||||
messages,
|
|
||||||
stream=stream,
|
|
||||||
**filter_none(
|
|
||||||
proxy=self.client.get_proxy() if proxy is None else proxy,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
stop=stop,
|
|
||||||
api_key=self.client.api_key if api_key is None else api_key
|
|
||||||
),
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
# Removed 'await' here since 'async_iter_response' returns an async generator
|
|
||||||
response = async_iter_response(response, stream, response_format, max_tokens, stop)
|
|
||||||
response = async_iter_append_model_and_provider(response)
|
|
||||||
|
|
||||||
if stream:
|
|
||||||
return response
|
|
||||||
else:
|
|
||||||
async for result in response:
|
|
||||||
return result
|
|
||||||
|
|
||||||
class Chat:
|
|
||||||
completions: Completions
|
|
||||||
|
|
||||||
def __init__(self, client: Client, provider: ProviderType = None):
|
|
||||||
self.completions = Completions(client, provider)
|
|
||||||
|
|
||||||
# Asynchronous versions of the helper functions
|
|
||||||
async def async_iter_response(
|
|
||||||
response: Union[AsyncIterator[str], Iterator[str]],
|
|
||||||
stream: bool,
|
|
||||||
response_format: dict = None,
|
|
||||||
max_tokens: int = None,
|
|
||||||
stop: list = None
|
|
||||||
) -> AsyncIterator[Union[ChatCompletion, ChatCompletionChunk]]:
|
|
||||||
content = ""
|
|
||||||
finish_reason = None
|
|
||||||
completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
|
|
||||||
idx = 0
|
|
||||||
|
|
||||||
if not hasattr(response, '__aiter__'):
|
|
||||||
response = to_async_iterator(response)
|
|
||||||
|
|
||||||
async for chunk in response:
|
|
||||||
if isinstance(chunk, FinishReason):
|
|
||||||
finish_reason = chunk.reason
|
|
||||||
break
|
|
||||||
elif isinstance(chunk, BaseConversation):
|
|
||||||
yield chunk
|
|
||||||
continue
|
|
||||||
|
|
||||||
content += str(chunk)
|
|
||||||
|
|
||||||
if max_tokens is not None and idx + 1 >= max_tokens:
|
|
||||||
finish_reason = "length"
|
|
||||||
|
|
||||||
first, content, chunk = find_stop(stop, content, chunk if stream else None)
|
|
||||||
|
|
||||||
if first != -1:
|
|
||||||
finish_reason = "stop"
|
|
||||||
|
|
||||||
if stream:
|
|
||||||
yield ChatCompletionChunk(chunk, None, completion_id, int(time.time()))
|
|
||||||
|
|
||||||
if finish_reason is not None:
|
|
||||||
break
|
|
||||||
|
|
||||||
idx += 1
|
|
||||||
|
|
||||||
finish_reason = "stop" if finish_reason is None else finish_reason
|
|
||||||
|
|
||||||
if stream:
|
|
||||||
yield ChatCompletionChunk(None, finish_reason, completion_id, int(time.time()))
|
|
||||||
else:
|
|
||||||
if response_format is not None and "type" in response_format:
|
|
||||||
if response_format["type"] == "json_object":
|
|
||||||
content = filter_json(content)
|
|
||||||
yield ChatCompletion(content, finish_reason, completion_id, int(time.time()))
|
|
||||||
|
|
||||||
async def async_iter_append_model_and_provider(response: AsyncIterator) -> AsyncIterator:
|
|
||||||
last_provider = None
|
|
||||||
|
|
||||||
if not hasattr(response, '__aiter__'):
|
|
||||||
response = to_async_iterator(response)
|
|
||||||
|
|
||||||
async for chunk in response:
|
|
||||||
last_provider = get_last_provider(True) if last_provider is None else last_provider
|
|
||||||
chunk.model = last_provider.get("model")
|
|
||||||
chunk.provider = last_provider.get("name")
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
async def iter_image_response(response: AsyncIterator) -> Union[ImagesResponse, None]:
|
|
||||||
response_list = []
|
|
||||||
async for chunk in response:
|
|
||||||
if isinstance(chunk, ImageProviderResponse):
|
|
||||||
response_list.extend(chunk.get_list())
|
|
||||||
elif isinstance(chunk, str):
|
|
||||||
response_list.append(chunk)
|
|
||||||
|
|
||||||
if response_list:
|
|
||||||
return ImagesResponse([Image(image) for image in response_list])
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def create_image(client: Client, provider: ProviderType, prompt: str, model: str = "", **kwargs) -> AsyncIterator:
|
|
||||||
if isinstance(provider, type) and provider.__name__ == "You":
|
|
||||||
kwargs["chat_mode"] = "create"
|
|
||||||
else:
|
|
||||||
prompt = f"create an image with: {prompt}"
|
|
||||||
|
|
||||||
if asyncio.iscoroutinefunction(provider.create_completion):
|
|
||||||
response = await provider.create_completion(
|
|
||||||
model,
|
|
||||||
[{"role": "user", "content": prompt}],
|
|
||||||
stream=True,
|
|
||||||
proxy=client.get_proxy(),
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
response = provider.create_completion(
|
|
||||||
model,
|
|
||||||
[{"role": "user", "content": prompt}],
|
|
||||||
stream=True,
|
|
||||||
proxy=client.get_proxy(),
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
# Wrap synchronous iterator into async iterator if necessary
|
|
||||||
if not hasattr(response, '__aiter__'):
|
|
||||||
response = to_async_iterator(response)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
class Image:
|
|
||||||
def __init__(self, url: str = None, b64_json: str = None):
|
|
||||||
self.url = url
|
|
||||||
self.b64_json = b64_json
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"Image(url={self.url}, b64_json={'<base64 data>' if self.b64_json else None})"
|
|
||||||
|
|
||||||
class ImagesResponse:
|
|
||||||
def __init__(self, data: list[Image]):
|
|
||||||
self.data = data
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"ImagesResponse(data={self.data})"
|
|
||||||
|
|
||||||
class Images:
|
|
||||||
def __init__(self, client: 'Client', provider: 'ImageProvider' = None):
|
|
||||||
self.client: 'Client' = client
|
|
||||||
self.provider: 'ImageProvider' = provider
|
|
||||||
self.models: ImageModels = ImageModels(client)
|
|
||||||
|
|
||||||
def generate(self, prompt: str, model: str = None, response_format: str = "url", **kwargs) -> ImagesResponse:
|
|
||||||
"""
|
|
||||||
Synchronous generate method that runs the async_generate method in an event loop.
|
|
||||||
"""
|
|
||||||
return asyncio.run(self.async_generate(prompt, model, response_format=response_format, **kwargs))
|
|
||||||
|
|
||||||
async def async_generate(self, prompt: str, model: str = None, response_format: str = "url", **kwargs) -> ImagesResponse:
|
|
||||||
provider = self.models.get(model, self.provider)
|
|
||||||
if provider is None:
|
|
||||||
raise ValueError(f"Unknown model: {model}")
|
|
||||||
|
|
||||||
if isinstance(provider, IterListProvider):
|
|
||||||
if provider.providers:
|
|
||||||
provider = provider.providers[0]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"IterListProvider for model {model} has no providers")
|
|
||||||
|
|
||||||
if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
|
|
||||||
messages = [{"role": "user", "content": prompt}]
|
|
||||||
async for response in provider.create_async_generator(model, messages, **kwargs):
|
|
||||||
if isinstance(response, ImageResponse):
|
|
||||||
return await self._process_image_response(response, response_format)
|
|
||||||
elif isinstance(response, str):
|
|
||||||
image_response = ImageResponse([response], prompt)
|
|
||||||
return await self._process_image_response(image_response, response_format)
|
|
||||||
elif hasattr(provider, 'create'):
|
|
||||||
if asyncio.iscoroutinefunction(provider.create):
|
|
||||||
response = await provider.create(prompt)
|
|
||||||
else:
|
|
||||||
response = provider.create(prompt)
|
|
||||||
|
|
||||||
if isinstance(response, ImageResponse):
|
|
||||||
return await self._process_image_response(response, response_format)
|
|
||||||
elif isinstance(response, str):
|
|
||||||
image_response = ImageResponse([response], prompt)
|
|
||||||
return await self._process_image_response(image_response, response_format)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Provider {provider} does not support image generation")
|
|
||||||
|
|
||||||
raise NoImageResponseError(f"Unexpected response type: {type(response)}")
|
|
||||||
|
|
||||||
async def _process_image_response(self, response: ImageResponse, response_format: str) -> ImagesResponse:
|
|
||||||
processed_images = []
|
|
||||||
|
|
||||||
for image_data in response.get_list():
|
|
||||||
if image_data.startswith('http://') or image_data.startswith('https://'):
|
|
||||||
if response_format == "url":
|
|
||||||
processed_images.append(Image(url=image_data))
|
|
||||||
elif response_format == "b64_json":
|
|
||||||
# Fetch the image data and convert it to base64
|
|
||||||
image_content = await self._fetch_image(image_data)
|
|
||||||
b64_json = base64.b64encode(image_content).decode('utf-8')
|
|
||||||
processed_images.append(Image(b64_json=b64_json))
|
|
||||||
else:
|
|
||||||
# Assume image_data is base64 data or binary
|
|
||||||
if response_format == "url":
|
|
||||||
if image_data.startswith('data:image'):
|
|
||||||
# Remove the data URL scheme and get the base64 data
|
|
||||||
header, base64_data = image_data.split(',', 1)
|
|
||||||
else:
|
|
||||||
base64_data = image_data
|
|
||||||
# Decode the base64 data
|
|
||||||
image_data_bytes = base64.b64decode(base64_data)
|
|
||||||
# Convert bytes to an image
|
|
||||||
image = to_image(image_data_bytes)
|
|
||||||
file_name = self._save_image(image)
|
|
||||||
processed_images.append(Image(url=file_name))
|
|
||||||
elif response_format == "b64_json":
|
|
||||||
if isinstance(image_data, bytes):
|
|
||||||
b64_json = base64.b64encode(image_data).decode('utf-8')
|
|
||||||
else:
|
|
||||||
b64_json = image_data # If already base64-encoded string
|
|
||||||
processed_images.append(Image(b64_json=b64_json))
|
|
||||||
|
|
||||||
return ImagesResponse(processed_images)
|
|
||||||
|
|
||||||
async def _fetch_image(self, url: str) -> bytes:
|
|
||||||
# Asynchronously fetch image data from the URL
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get(url) as resp:
|
|
||||||
if resp.status == 200:
|
|
||||||
return await resp.read()
|
|
||||||
else:
|
|
||||||
raise Exception(f"Failed to fetch image from {url}, status code {resp.status}")
|
|
||||||
|
|
||||||
def _save_image(self, image: 'PILImage') -> str:
|
|
||||||
os.makedirs('generated_images', exist_ok=True)
|
|
||||||
file_name = f"generated_images/image_{int(time.time())}_{random.randint(0, 10000)}.png"
|
|
||||||
image.save(file_name)
|
|
||||||
return file_name
|
|
||||||
|
|
||||||
async def create_variation(self, image: Union[str, bytes], model: str = None, response_format: str = "url", **kwargs):
|
|
||||||
# Existing implementation, adjust if you want to support b64_json here as well
|
|
||||||
pass
|
|
||||||
|
|
@ -1,7 +1,12 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import Iterable, AsyncIterator
|
import queue
|
||||||
|
import threading
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from typing import AsyncIterator, Iterator, AsyncGenerator
|
||||||
|
|
||||||
def filter_json(text: str) -> str:
|
def filter_json(text: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
@ -42,6 +47,40 @@ def filter_none(**kwargs) -> dict:
|
||||||
if value is not None
|
if value is not None
|
||||||
}
|
}
|
||||||
|
|
||||||
async def cast_iter_async(iter: Iterable) -> AsyncIterator:
|
async def safe_aclose(generator: AsyncGenerator) -> None:
|
||||||
for chunk in iter:
|
try:
|
||||||
yield chunk
|
await generator.aclose()
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Error while closing generator: {e}")
|
||||||
|
|
||||||
|
# Helper function to convert an async generator to a synchronous iterator
|
||||||
|
def to_sync_iter(async_gen: AsyncIterator) -> Iterator:
|
||||||
|
q = queue.Queue()
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
done = object()
|
||||||
|
|
||||||
|
def _run():
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
async def iterate():
|
||||||
|
try:
|
||||||
|
async for item in async_gen:
|
||||||
|
q.put(item)
|
||||||
|
finally:
|
||||||
|
q.put(done)
|
||||||
|
|
||||||
|
loop.run_until_complete(iterate())
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
threading.Thread(target=_run).start()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
item = q.get()
|
||||||
|
if item is done:
|
||||||
|
break
|
||||||
|
yield item
|
||||||
|
|
||||||
|
# Helper function to convert a synchronous iterator to an async iterator
|
||||||
|
async def to_async_iterator(iterator: Iterator) -> AsyncIterator:
|
||||||
|
for item in iterator:
|
||||||
|
yield item
|
||||||
|
|
@ -55,7 +55,6 @@ def get_model_and_provider(model : Union[Model, str],
|
||||||
provider = convert_to_provider(provider)
|
provider = convert_to_provider(provider)
|
||||||
|
|
||||||
if isinstance(model, str):
|
if isinstance(model, str):
|
||||||
|
|
||||||
if model in ModelUtils.convert:
|
if model in ModelUtils.convert:
|
||||||
model = ModelUtils.convert[model]
|
model = ModelUtils.convert[model]
|
||||||
|
|
||||||
|
|
@ -75,10 +74,10 @@ def get_model_and_provider(model : Union[Model, str],
|
||||||
if not ignore_working and not provider.working:
|
if not ignore_working and not provider.working:
|
||||||
raise ProviderNotWorkingError(f'{provider.__name__} is not working')
|
raise ProviderNotWorkingError(f'{provider.__name__} is not working')
|
||||||
|
|
||||||
if not ignore_working and isinstance(provider, BaseRetryProvider):
|
if isinstance(provider, BaseRetryProvider):
|
||||||
|
if not ignore_working:
|
||||||
provider.providers = [p for p in provider.providers if p.working]
|
provider.providers = [p for p in provider.providers if p.working]
|
||||||
|
if ignored:
|
||||||
if ignored and isinstance(provider, BaseRetryProvider):
|
|
||||||
provider.providers = [p for p in provider.providers if p.__name__ not in ignored]
|
provider.providers = [p for p in provider.providers if p.__name__ not in ignored]
|
||||||
|
|
||||||
if not ignore_stream and not provider.supports_stream and stream:
|
if not ignore_stream and not provider.supports_stream and stream:
|
||||||
|
|
@ -95,7 +94,7 @@ def get_model_and_provider(model : Union[Model, str],
|
||||||
|
|
||||||
return model, provider
|
return model, provider
|
||||||
|
|
||||||
def get_last_provider(as_dict: bool = False) -> Union[ProviderType, dict[str, str]]:
|
def get_last_provider(as_dict: bool = False) -> Union[ProviderType, dict[str, str], None]:
|
||||||
"""
|
"""
|
||||||
Retrieves the last used provider.
|
Retrieves the last used provider.
|
||||||
|
|
||||||
|
|
@ -108,11 +107,14 @@ def get_last_provider(as_dict: bool = False) -> Union[ProviderType, dict[str, st
|
||||||
last = debug.last_provider
|
last = debug.last_provider
|
||||||
if isinstance(last, BaseRetryProvider):
|
if isinstance(last, BaseRetryProvider):
|
||||||
last = last.last_provider
|
last = last.last_provider
|
||||||
if last and as_dict:
|
if as_dict:
|
||||||
|
if last:
|
||||||
return {
|
return {
|
||||||
"name": last.__name__,
|
"name": last.__name__,
|
||||||
"url": last.url,
|
"url": last.url,
|
||||||
"model": debug.last_model,
|
"model": debug.last_model,
|
||||||
"label": last.label if hasattr(last, "label") else None
|
"label": getattr(last, "label", None) if hasattr(last, "label") else None
|
||||||
}
|
}
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
return last
|
return last
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
from time import time
|
||||||
|
|
||||||
class Model():
|
class Model():
|
||||||
...
|
...
|
||||||
|
|
@ -108,8 +109,18 @@ class Image(Model):
|
||||||
return self.__dict__
|
return self.__dict__
|
||||||
|
|
||||||
class ImagesResponse(Model):
|
class ImagesResponse(Model):
|
||||||
def __init__(self, data: list[Image], created: int = 0) -> None:
|
data: list[Image]
|
||||||
|
model: str
|
||||||
|
provider: str
|
||||||
|
created: int
|
||||||
|
|
||||||
|
def __init__(self, data: list[Image], created: int = None, model: str = None, provider: str = None) -> None:
|
||||||
self.data = data
|
self.data = data
|
||||||
|
if created is None:
|
||||||
|
created = int(time())
|
||||||
|
self.model = model
|
||||||
|
if provider is not None:
|
||||||
|
self.provider = provider
|
||||||
self.created = created
|
self.created = created
|
||||||
|
|
||||||
def to_json(self):
|
def to_json(self):
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,17 @@ Proxies = Union[dict, str]
|
||||||
IterResponse = Iterator[Union[ChatCompletion, ChatCompletionChunk]]
|
IterResponse = Iterator[Union[ChatCompletion, ChatCompletionChunk]]
|
||||||
AsyncIterResponse = AsyncIterator[Union[ChatCompletion, ChatCompletionChunk]]
|
AsyncIterResponse = AsyncIterator[Union[ChatCompletion, ChatCompletionChunk]]
|
||||||
|
|
||||||
class ClientProxyMixin():
|
class Client():
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str = None,
|
||||||
|
proxies: Proxies = None,
|
||||||
|
**kwargs
|
||||||
|
) -> None:
|
||||||
|
self.api_key: str = api_key
|
||||||
|
self.proxies= proxies
|
||||||
|
self.proxy: str = self.get_proxy()
|
||||||
|
|
||||||
def get_proxy(self) -> Union[str, None]:
|
def get_proxy(self) -> Union[str, None]:
|
||||||
if isinstance(self.proxies, str):
|
if isinstance(self.proxies, str):
|
||||||
return self.proxies
|
return self.proxies
|
||||||
|
|
@ -21,13 +31,3 @@ class ClientProxyMixin():
|
||||||
return self.proxies["all"]
|
return self.proxies["all"]
|
||||||
elif "https" in self.proxies:
|
elif "https" in self.proxies:
|
||||||
return self.proxies["https"]
|
return self.proxies["https"]
|
||||||
|
|
||||||
class Client(ClientProxyMixin):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
api_key: str = None,
|
|
||||||
proxies: Proxies = None,
|
|
||||||
**kwargs
|
|
||||||
) -> None:
|
|
||||||
self.api_key: str = api_key
|
|
||||||
self.proxies: Proxies = proxies
|
|
||||||
|
|
@ -96,7 +96,7 @@ async def fetch_and_scrape(session: ClientSession, url: str, max_words: int = No
|
||||||
|
|
||||||
async def search(query: str, n_results: int = 5, max_words: int = 2500, add_text: bool = True) -> SearchResults:
|
async def search(query: str, n_results: int = 5, max_words: int = 2500, add_text: bool = True) -> SearchResults:
|
||||||
if not has_requirements:
|
if not has_requirements:
|
||||||
raise MissingRequirementsError('Install "duckduckgo-search" and "beautifulsoup4" package')
|
raise MissingRequirementsError('Install "duckduckgo-search" and "beautifulsoup4" package | pip install -U g4f[search]')
|
||||||
with DDGS() as ddgs:
|
with DDGS() as ddgs:
|
||||||
results = []
|
results = []
|
||||||
for result in ddgs.text(
|
for result in ddgs.text(
|
||||||
|
|
|
||||||
153
g4f/models.py
153
g4f/models.py
|
|
@ -2,8 +2,6 @@ from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from .Provider.not_working import Ai4Chat
|
|
||||||
|
|
||||||
from .Provider import IterListProvider, ProviderType
|
from .Provider import IterListProvider, ProviderType
|
||||||
from .Provider import (
|
from .Provider import (
|
||||||
AIChatFree,
|
AIChatFree,
|
||||||
|
|
@ -19,12 +17,10 @@ from .Provider import (
|
||||||
DDG,
|
DDG,
|
||||||
DeepInfraChat,
|
DeepInfraChat,
|
||||||
Free2GPT,
|
Free2GPT,
|
||||||
FreeGpt,
|
|
||||||
FreeNetfly,
|
FreeNetfly,
|
||||||
|
GigaChat,
|
||||||
Gemini,
|
Gemini,
|
||||||
GeminiPro,
|
GeminiPro,
|
||||||
GizAI,
|
|
||||||
GigaChat,
|
|
||||||
HuggingChat,
|
HuggingChat,
|
||||||
HuggingFace,
|
HuggingFace,
|
||||||
Liaobots,
|
Liaobots,
|
||||||
|
|
@ -42,7 +38,6 @@ from .Provider import (
|
||||||
Upstage,
|
Upstage,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(unsafe_hash=True)
|
@dataclass(unsafe_hash=True)
|
||||||
class Model:
|
class Model:
|
||||||
"""
|
"""
|
||||||
|
|
@ -62,7 +57,6 @@ class Model:
|
||||||
"""Returns a list of all model names."""
|
"""Returns a list of all model names."""
|
||||||
return _all_models
|
return _all_models
|
||||||
|
|
||||||
|
|
||||||
### Default ###
|
### Default ###
|
||||||
default = Model(
|
default = Model(
|
||||||
name = "",
|
name = "",
|
||||||
|
|
@ -85,8 +79,6 @@ default = Model(
|
||||||
])
|
])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
############
|
############
|
||||||
### Text ###
|
### Text ###
|
||||||
############
|
############
|
||||||
|
|
@ -115,29 +107,15 @@ gpt_4o_mini = Model(
|
||||||
gpt_4_turbo = Model(
|
gpt_4_turbo = Model(
|
||||||
name = 'gpt-4-turbo',
|
name = 'gpt-4-turbo',
|
||||||
base_provider = 'OpenAI',
|
base_provider = 'OpenAI',
|
||||||
best_provider = IterListProvider([ChatGpt, Airforce, Liaobots, Bing])
|
best_provider = IterListProvider([Liaobots, Bing])
|
||||||
)
|
)
|
||||||
|
|
||||||
gpt_4 = Model(
|
gpt_4 = Model(
|
||||||
name = 'gpt-4',
|
name = 'gpt-4',
|
||||||
base_provider = 'OpenAI',
|
base_provider = 'OpenAI',
|
||||||
best_provider = IterListProvider([Mhystical, Chatgpt4Online, ChatGpt, Bing, OpenaiChat, gpt_4_turbo.best_provider, gpt_4o.best_provider, gpt_4o_mini.best_provider])
|
best_provider = IterListProvider([Chatgpt4Online, Bing, OpenaiChat, DDG, Liaobots, Airforce])
|
||||||
)
|
)
|
||||||
|
|
||||||
# o1
|
|
||||||
o1 = Model(
|
|
||||||
name = 'o1',
|
|
||||||
base_provider = 'OpenAI',
|
|
||||||
best_provider = None
|
|
||||||
)
|
|
||||||
|
|
||||||
o1_mini = Model(
|
|
||||||
name = 'o1-mini',
|
|
||||||
base_provider = 'OpenAI',
|
|
||||||
best_provider = None
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
### GigaChat ###
|
### GigaChat ###
|
||||||
gigachat = Model(
|
gigachat = Model(
|
||||||
name = 'GigaChat:latest',
|
name = 'GigaChat:latest',
|
||||||
|
|
@ -145,7 +123,6 @@ gigachat = Model(
|
||||||
best_provider = GigaChat
|
best_provider = GigaChat
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
### Meta ###
|
### Meta ###
|
||||||
meta = Model(
|
meta = Model(
|
||||||
name = "meta-ai",
|
name = "meta-ai",
|
||||||
|
|
@ -157,13 +134,13 @@ meta = Model(
|
||||||
llama_2_7b = Model(
|
llama_2_7b = Model(
|
||||||
name = "llama-2-7b",
|
name = "llama-2-7b",
|
||||||
base_provider = "Meta Llama",
|
base_provider = "Meta Llama",
|
||||||
best_provider = IterListProvider([Cloudflare, Airforce])
|
best_provider = Cloudflare
|
||||||
)
|
)
|
||||||
# llama 3
|
# llama 3
|
||||||
llama_3_8b = Model(
|
llama_3_8b = Model(
|
||||||
name = "llama-3-8b",
|
name = "llama-3-8b",
|
||||||
base_provider = "Meta Llama",
|
base_provider = "Meta Llama",
|
||||||
best_provider = IterListProvider([Cloudflare])
|
best_provider = Cloudflare
|
||||||
)
|
)
|
||||||
|
|
||||||
# llama 3.1
|
# llama 3.1
|
||||||
|
|
@ -198,13 +175,6 @@ llama_3_2_11b = Model(
|
||||||
best_provider = IterListProvider([HuggingChat, HuggingFace])
|
best_provider = IterListProvider([HuggingChat, HuggingFace])
|
||||||
)
|
)
|
||||||
|
|
||||||
### Mistral ###
|
|
||||||
mistral_7b = Model(
|
|
||||||
name = "mistral-7b",
|
|
||||||
base_provider = "Mistral",
|
|
||||||
best_provider = IterListProvider([Free2GPT])
|
|
||||||
)
|
|
||||||
|
|
||||||
mixtral_8x7b = Model(
|
mixtral_8x7b = Model(
|
||||||
name = "mixtral-8x7b",
|
name = "mixtral-8x7b",
|
||||||
base_provider = "Mistral",
|
base_provider = "Mistral",
|
||||||
|
|
@ -217,27 +187,12 @@ mistral_nemo = Model(
|
||||||
best_provider = IterListProvider([HuggingChat, HuggingFace])
|
best_provider = IterListProvider([HuggingChat, HuggingFace])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
### NousResearch ###
|
|
||||||
hermes_2_pro = Model(
|
|
||||||
name = "hermes-2-pro",
|
|
||||||
base_provider = "NousResearch",
|
|
||||||
best_provider = Airforce
|
|
||||||
)
|
|
||||||
|
|
||||||
hermes_2_dpo = Model(
|
|
||||||
name = "hermes-2-dpo",
|
|
||||||
base_provider = "NousResearch",
|
|
||||||
best_provider = Airforce
|
|
||||||
)
|
|
||||||
|
|
||||||
hermes_3 = Model(
|
hermes_3 = Model(
|
||||||
name = "hermes-3",
|
name = "hermes-3",
|
||||||
base_provider = "NousResearch",
|
base_provider = "NousResearch",
|
||||||
best_provider = IterListProvider([HuggingChat, HuggingFace])
|
best_provider = IterListProvider([HuggingChat, HuggingFace])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
### Microsoft ###
|
### Microsoft ###
|
||||||
phi_2 = Model(
|
phi_2 = Model(
|
||||||
name = "phi-2",
|
name = "phi-2",
|
||||||
|
|
@ -256,13 +211,13 @@ phi_3_5_mini = Model(
|
||||||
gemini_pro = Model(
|
gemini_pro = Model(
|
||||||
name = 'gemini-pro',
|
name = 'gemini-pro',
|
||||||
base_provider = 'Google DeepMind',
|
base_provider = 'Google DeepMind',
|
||||||
best_provider = IterListProvider([GeminiPro, Blackbox, AIChatFree, FreeGpt, Liaobots])
|
best_provider = IterListProvider([GeminiPro, Blackbox, AIChatFree, Liaobots])
|
||||||
)
|
)
|
||||||
|
|
||||||
gemini_flash = Model(
|
gemini_flash = Model(
|
||||||
name = 'gemini-flash',
|
name = 'gemini-flash',
|
||||||
base_provider = 'Google DeepMind',
|
base_provider = 'Google DeepMind',
|
||||||
best_provider = IterListProvider([Blackbox, GizAI, Liaobots])
|
best_provider = IterListProvider([Blackbox, Liaobots])
|
||||||
)
|
)
|
||||||
|
|
||||||
gemini = Model(
|
gemini = Model(
|
||||||
|
|
@ -278,7 +233,6 @@ gemma_2b = Model(
|
||||||
best_provider = ReplicateHome
|
best_provider = ReplicateHome
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
### Anthropic ###
|
### Anthropic ###
|
||||||
claude_2_1 = Model(
|
claude_2_1 = Model(
|
||||||
name = 'claude-2.1',
|
name = 'claude-2.1',
|
||||||
|
|
@ -290,13 +244,13 @@ claude_2_1 = Model(
|
||||||
claude_3_opus = Model(
|
claude_3_opus = Model(
|
||||||
name = 'claude-3-opus',
|
name = 'claude-3-opus',
|
||||||
base_provider = 'Anthropic',
|
base_provider = 'Anthropic',
|
||||||
best_provider = IterListProvider([Liaobots])
|
best_provider = Liaobots
|
||||||
)
|
)
|
||||||
|
|
||||||
claude_3_sonnet = Model(
|
claude_3_sonnet = Model(
|
||||||
name = 'claude-3-sonnet',
|
name = 'claude-3-sonnet',
|
||||||
base_provider = 'Anthropic',
|
base_provider = 'Anthropic',
|
||||||
best_provider = IterListProvider([Liaobots])
|
best_provider = Liaobots
|
||||||
)
|
)
|
||||||
|
|
||||||
claude_3_haiku = Model(
|
claude_3_haiku = Model(
|
||||||
|
|
@ -312,7 +266,6 @@ claude_3_5_sonnet = Model(
|
||||||
best_provider = IterListProvider([Blackbox, Liaobots])
|
best_provider = IterListProvider([Blackbox, Liaobots])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
### Reka AI ###
|
### Reka AI ###
|
||||||
reka_core = Model(
|
reka_core = Model(
|
||||||
name = 'reka-core',
|
name = 'reka-core',
|
||||||
|
|
@ -320,7 +273,6 @@ reka_core = Model(
|
||||||
best_provider = Reka
|
best_provider = Reka
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
### Blackbox AI ###
|
### Blackbox AI ###
|
||||||
blackboxai = Model(
|
blackboxai = Model(
|
||||||
name = 'blackboxai',
|
name = 'blackboxai',
|
||||||
|
|
@ -341,7 +293,6 @@ command_r_plus = Model(
|
||||||
best_provider = HuggingChat
|
best_provider = HuggingChat
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
### Qwen ###
|
### Qwen ###
|
||||||
# qwen 1_5
|
# qwen 1_5
|
||||||
qwen_1_5_7b = Model(
|
qwen_1_5_7b = Model(
|
||||||
|
|
@ -477,7 +428,6 @@ german_7b = Model(
|
||||||
best_provider = Airforce
|
best_provider = Airforce
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
### HuggingFaceH4 ###
|
### HuggingFaceH4 ###
|
||||||
zephyr_7b = Model(
|
zephyr_7b = Model(
|
||||||
name = 'zephyr-7b',
|
name = 'zephyr-7b',
|
||||||
|
|
@ -492,8 +442,6 @@ neural_7b = Model(
|
||||||
best_provider = Airforce
|
best_provider = Airforce
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#############
|
#############
|
||||||
### Image ###
|
### Image ###
|
||||||
#############
|
#############
|
||||||
|
|
@ -527,66 +475,55 @@ flux = Model(
|
||||||
name = 'flux',
|
name = 'flux',
|
||||||
base_provider = 'Flux AI',
|
base_provider = 'Flux AI',
|
||||||
best_provider = IterListProvider([Blackbox, AIUncensored, Airforce])
|
best_provider = IterListProvider([Blackbox, AIUncensored, Airforce])
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
flux_pro = Model(
|
flux_pro = Model(
|
||||||
name = 'flux-pro',
|
name = 'flux-pro',
|
||||||
base_provider = 'Flux AI',
|
base_provider = 'Flux AI',
|
||||||
best_provider = IterListProvider([Airforce])
|
best_provider = Airforce
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
flux_realism = Model(
|
flux_realism = Model(
|
||||||
name = 'flux-realism',
|
name = 'flux-realism',
|
||||||
base_provider = 'Flux AI',
|
base_provider = 'Flux AI',
|
||||||
best_provider = IterListProvider([Airforce])
|
best_provider = Airforce
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
flux_anime = Model(
|
flux_anime = Model(
|
||||||
name = 'flux-anime',
|
name = 'flux-anime',
|
||||||
base_provider = 'Flux AI',
|
base_provider = 'Flux AI',
|
||||||
best_provider = Airforce
|
best_provider = Airforce
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
flux_3d = Model(
|
flux_3d = Model(
|
||||||
name = 'flux-3d',
|
name = 'flux-3d',
|
||||||
base_provider = 'Flux AI',
|
base_provider = 'Flux AI',
|
||||||
best_provider = Airforce
|
best_provider = Airforce
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
flux_disney = Model(
|
flux_disney = Model(
|
||||||
name = 'flux-disney',
|
name = 'flux-disney',
|
||||||
base_provider = 'Flux AI',
|
base_provider = 'Flux AI',
|
||||||
best_provider = Airforce
|
best_provider = Airforce
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
flux_pixel = Model(
|
flux_pixel = Model(
|
||||||
name = 'flux-pixel',
|
name = 'flux-pixel',
|
||||||
base_provider = 'Flux AI',
|
base_provider = 'Flux AI',
|
||||||
best_provider = Airforce
|
best_provider = Airforce
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
flux_4o = Model(
|
flux_4o = Model(
|
||||||
name = 'flux-4o',
|
name = 'flux-4o',
|
||||||
base_provider = 'Flux AI',
|
base_provider = 'Flux AI',
|
||||||
best_provider = Airforce
|
best_provider = Airforce
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### Other ###
|
### Other ###
|
||||||
any_dark = Model(
|
any_dark = Model(
|
||||||
name = 'any-dark',
|
name = 'any-dark',
|
||||||
base_provider = '',
|
base_provider = '',
|
||||||
best_provider = Airforce
|
best_provider = Airforce
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
class ModelUtils:
|
class ModelUtils:
|
||||||
|
|
@ -597,12 +534,14 @@ class ModelUtils:
|
||||||
convert (dict[str, Model]): Dictionary mapping model string identifiers to Model instances.
|
convert (dict[str, Model]): Dictionary mapping model string identifiers to Model instances.
|
||||||
"""
|
"""
|
||||||
convert: dict[str, Model] = {
|
convert: dict[str, Model] = {
|
||||||
|
|
||||||
############
|
############
|
||||||
### Text ###
|
### Text ###
|
||||||
############
|
############
|
||||||
|
|
||||||
### OpenAI ###
|
### OpenAI ###
|
||||||
|
# gpt-3
|
||||||
|
'gpt-3': gpt_35_turbo,
|
||||||
|
|
||||||
# gpt-3.5
|
# gpt-3.5
|
||||||
'gpt-3.5-turbo': gpt_35_turbo,
|
'gpt-3.5-turbo': gpt_35_turbo,
|
||||||
|
|
||||||
|
|
@ -612,11 +551,6 @@ class ModelUtils:
|
||||||
'gpt-4': gpt_4,
|
'gpt-4': gpt_4,
|
||||||
'gpt-4-turbo': gpt_4_turbo,
|
'gpt-4-turbo': gpt_4_turbo,
|
||||||
|
|
||||||
# o1
|
|
||||||
'o1': o1,
|
|
||||||
'o1-mini': o1_mini,
|
|
||||||
|
|
||||||
|
|
||||||
### Meta ###
|
### Meta ###
|
||||||
"meta-ai": meta,
|
"meta-ai": meta,
|
||||||
|
|
||||||
|
|
@ -636,22 +570,16 @@ class ModelUtils:
|
||||||
'llama-3.2-11b': llama_3_2_11b,
|
'llama-3.2-11b': llama_3_2_11b,
|
||||||
|
|
||||||
### Mistral ###
|
### Mistral ###
|
||||||
'mistral-7b': mistral_7b,
|
|
||||||
'mixtral-8x7b': mixtral_8x7b,
|
'mixtral-8x7b': mixtral_8x7b,
|
||||||
'mistral-nemo': mistral_nemo,
|
'mistral-nemo': mistral_nemo,
|
||||||
|
|
||||||
|
|
||||||
### NousResearch ###
|
### NousResearch ###
|
||||||
'hermes-2-pro': hermes_2_pro,
|
|
||||||
'hermes-2-dpo': hermes_2_dpo,
|
|
||||||
'hermes-3': hermes_3,
|
'hermes-3': hermes_3,
|
||||||
|
|
||||||
|
|
||||||
### Microsoft ###
|
### Microsoft ###
|
||||||
'phi-2': phi_2,
|
'phi-2': phi_2,
|
||||||
'phi-3.5-mini': phi_3_5_mini,
|
'phi-3.5-mini': phi_3_5_mini,
|
||||||
|
|
||||||
|
|
||||||
### Google ###
|
### Google ###
|
||||||
# gemini
|
# gemini
|
||||||
'gemini': gemini,
|
'gemini': gemini,
|
||||||
|
|
@ -661,7 +589,6 @@ class ModelUtils:
|
||||||
# gemma
|
# gemma
|
||||||
'gemma-2b': gemma_2b,
|
'gemma-2b': gemma_2b,
|
||||||
|
|
||||||
|
|
||||||
### Anthropic ###
|
### Anthropic ###
|
||||||
'claude-2.1': claude_2_1,
|
'claude-2.1': claude_2_1,
|
||||||
|
|
||||||
|
|
@ -673,101 +600,52 @@ class ModelUtils:
|
||||||
# claude 3.5
|
# claude 3.5
|
||||||
'claude-3.5-sonnet': claude_3_5_sonnet,
|
'claude-3.5-sonnet': claude_3_5_sonnet,
|
||||||
|
|
||||||
|
|
||||||
### Reka AI ###
|
### Reka AI ###
|
||||||
'reka-core': reka_core,
|
'reka-core': reka_core,
|
||||||
|
|
||||||
|
|
||||||
### Blackbox AI ###
|
### Blackbox AI ###
|
||||||
'blackboxai': blackboxai,
|
'blackboxai': blackboxai,
|
||||||
'blackboxai-pro': blackboxai_pro,
|
'blackboxai-pro': blackboxai_pro,
|
||||||
|
|
||||||
|
|
||||||
### CohereForAI ###
|
### CohereForAI ###
|
||||||
'command-r+': command_r_plus,
|
'command-r+': command_r_plus,
|
||||||
|
|
||||||
|
|
||||||
### GigaChat ###
|
### GigaChat ###
|
||||||
'gigachat': gigachat,
|
'gigachat': gigachat,
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### Qwen ###
|
|
||||||
# qwen 1.5
|
|
||||||
'qwen-1.5-7b': qwen_1_5_7b,
|
'qwen-1.5-7b': qwen_1_5_7b,
|
||||||
|
|
||||||
# qwen 2
|
|
||||||
'qwen-2-72b': qwen_2_72b,
|
'qwen-2-72b': qwen_2_72b,
|
||||||
|
|
||||||
# qwen 2.5
|
|
||||||
'qwen-2.5-coder-32b': qwen_2_5_coder_32b,
|
|
||||||
|
|
||||||
|
|
||||||
### Upstage ###
|
### Upstage ###
|
||||||
'solar-mini': solar_mini,
|
|
||||||
'solar-pro': solar_pro,
|
'solar-pro': solar_pro,
|
||||||
|
|
||||||
|
|
||||||
### Inflection ###
|
### Inflection ###
|
||||||
'pi': pi,
|
'pi': pi,
|
||||||
|
|
||||||
|
|
||||||
### DeepSeek ###
|
|
||||||
'deepseek-coder': deepseek_coder,
|
|
||||||
|
|
||||||
|
|
||||||
### Yorickvp ###
|
### Yorickvp ###
|
||||||
'llava-13b': llava_13b,
|
'llava-13b': llava_13b,
|
||||||
|
|
||||||
|
|
||||||
### WizardLM ###
|
### WizardLM ###
|
||||||
'wizardlm-2-8x22b': wizardlm_2_8x22b,
|
'wizardlm-2-8x22b': wizardlm_2_8x22b,
|
||||||
|
|
||||||
|
|
||||||
### OpenChat ###
|
### OpenChat ###
|
||||||
'openchat-3.5': openchat_3_5,
|
'openchat-3.5': openchat_3_5,
|
||||||
|
|
||||||
|
|
||||||
### x.ai ###
|
### x.ai ###
|
||||||
'grok-2': grok_2,
|
'grok-2': grok_2,
|
||||||
'grok-2-mini': grok_2_mini,
|
'grok-2-mini': grok_2_mini,
|
||||||
'grok-beta': grok_beta,
|
'grok-beta': grok_beta,
|
||||||
|
|
||||||
|
|
||||||
### Perplexity AI ###
|
### Perplexity AI ###
|
||||||
'sonar-online': sonar_online,
|
'sonar-online': sonar_online,
|
||||||
'sonar-chat': sonar_chat,
|
'sonar-chat': sonar_chat,
|
||||||
|
|
||||||
|
|
||||||
### TheBloke ###
|
### TheBloke ###
|
||||||
'german-7b': german_7b,
|
'german-7b': german_7b,
|
||||||
|
|
||||||
|
|
||||||
### Nvidia ###
|
### Nvidia ###
|
||||||
'nemotron-70b': nemotron_70b,
|
'nemotron-70b': nemotron_70b,
|
||||||
|
|
||||||
|
|
||||||
### Teknium ###
|
|
||||||
'openhermes-2.5': openhermes_2_5,
|
|
||||||
|
|
||||||
|
|
||||||
### Liquid ###
|
|
||||||
'lfm-40b': lfm_40b,
|
|
||||||
|
|
||||||
|
|
||||||
### DiscoResearch ###
|
|
||||||
'german-7b': german_7b,
|
|
||||||
|
|
||||||
|
|
||||||
### HuggingFaceH4 ###
|
|
||||||
'zephyr-7b': zephyr_7b,
|
|
||||||
|
|
||||||
|
|
||||||
### Inferless ###
|
|
||||||
'neural-7b': neural_7b,
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#############
|
#############
|
||||||
### Image ###
|
### Image ###
|
||||||
#############
|
#############
|
||||||
|
|
@ -776,11 +654,9 @@ class ModelUtils:
|
||||||
'sdxl': sdxl,
|
'sdxl': sdxl,
|
||||||
'sd-3': sd_3,
|
'sd-3': sd_3,
|
||||||
|
|
||||||
|
|
||||||
### Playground ###
|
### Playground ###
|
||||||
'playground-v2.5': playground_v2_5,
|
'playground-v2.5': playground_v2_5,
|
||||||
|
|
||||||
|
|
||||||
### Flux AI ###
|
### Flux AI ###
|
||||||
'flux': flux,
|
'flux': flux,
|
||||||
'flux-pro': flux_pro,
|
'flux-pro': flux_pro,
|
||||||
|
|
@ -791,7 +667,6 @@ class ModelUtils:
|
||||||
'flux-pixel': flux_pixel,
|
'flux-pixel': flux_pixel,
|
||||||
'flux-4o': flux_4o,
|
'flux-4o': flux_4o,
|
||||||
|
|
||||||
|
|
||||||
### Other ###
|
### Other ###
|
||||||
'any-dark': any_dark,
|
'any-dark': any_dark,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,11 +2,13 @@ from __future__ import annotations
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from asyncio import AbstractEventLoop
|
from asyncio import AbstractEventLoop
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from inspect import signature, Parameter
|
from inspect import signature, Parameter
|
||||||
from typing import Callable, Union
|
from typing import Callable, Union
|
||||||
|
|
||||||
from ..typing import CreateResult, AsyncResult, Messages
|
from ..typing import CreateResult, AsyncResult, Messages
|
||||||
from .types import BaseProvider, FinishReason
|
from .types import BaseProvider, FinishReason
|
||||||
from ..errors import NestAsyncioError, ModelNotSupportedError
|
from ..errors import NestAsyncioError, ModelNotSupportedError
|
||||||
|
|
@ -17,6 +19,17 @@ if sys.version_info < (3, 10):
|
||||||
else:
|
else:
|
||||||
from types import NoneType
|
from types import NoneType
|
||||||
|
|
||||||
|
try:
|
||||||
|
import nest_asyncio
|
||||||
|
has_nest_asyncio = True
|
||||||
|
except ImportError:
|
||||||
|
has_nest_asyncio = False
|
||||||
|
try:
|
||||||
|
import uvloop
|
||||||
|
has_uvloop = True
|
||||||
|
except ImportError:
|
||||||
|
has_uvloop = False
|
||||||
|
|
||||||
# Set Windows event loop policy for better compatibility with asyncio and curl_cffi
|
# Set Windows event loop policy for better compatibility with asyncio and curl_cffi
|
||||||
if sys.platform == 'win32':
|
if sys.platform == 'win32':
|
||||||
try:
|
try:
|
||||||
|
|
@ -31,18 +44,14 @@ def get_running_loop(check_nested: bool) -> Union[AbstractEventLoop, None]:
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
# Do not patch uvloop loop because its incompatible.
|
# Do not patch uvloop loop because its incompatible.
|
||||||
try:
|
if has_uvloop:
|
||||||
import uvloop
|
|
||||||
if isinstance(loop, uvloop.Loop):
|
if isinstance(loop, uvloop.Loop):
|
||||||
return loop
|
return loop
|
||||||
except (ImportError, ModuleNotFoundError):
|
if not hasattr(loop.__class__, "_nest_patched"):
|
||||||
pass
|
if has_nest_asyncio:
|
||||||
if check_nested and not hasattr(loop.__class__, "_nest_patched"):
|
|
||||||
try:
|
|
||||||
import nest_asyncio
|
|
||||||
nest_asyncio.apply(loop)
|
nest_asyncio.apply(loop)
|
||||||
except ImportError:
|
elif check_nested:
|
||||||
raise NestAsyncioError('Install "nest_asyncio" package')
|
raise NestAsyncioError('Install "nest_asyncio" package | pip install -U nest_asyncio')
|
||||||
return loop
|
return loop
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
pass
|
pass
|
||||||
|
|
@ -154,7 +163,7 @@ class AsyncProvider(AbstractProvider):
|
||||||
Returns:
|
Returns:
|
||||||
CreateResult: The result of the completion creation.
|
CreateResult: The result of the completion creation.
|
||||||
"""
|
"""
|
||||||
get_running_loop(check_nested=True)
|
get_running_loop(check_nested=False)
|
||||||
yield asyncio.run(cls.create_async(model, messages, **kwargs))
|
yield asyncio.run(cls.create_async(model, messages, **kwargs))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -208,7 +217,7 @@ class AsyncGeneratorProvider(AsyncProvider):
|
||||||
Returns:
|
Returns:
|
||||||
CreateResult: The result of the streaming completion creation.
|
CreateResult: The result of the streaming completion creation.
|
||||||
"""
|
"""
|
||||||
loop = get_running_loop(check_nested=True)
|
loop = get_running_loop(check_nested=False)
|
||||||
new_loop = False
|
new_loop = False
|
||||||
if loop is None:
|
if loop is None:
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
|
|
@ -222,7 +231,7 @@ class AsyncGeneratorProvider(AsyncProvider):
|
||||||
while True:
|
while True:
|
||||||
yield loop.run_until_complete(await_callback(gen.__anext__))
|
yield loop.run_until_complete(await_callback(gen.__anext__))
|
||||||
except StopAsyncIteration:
|
except StopAsyncIteration:
|
||||||
...
|
pass
|
||||||
finally:
|
finally:
|
||||||
if new_loop:
|
if new_loop:
|
||||||
loop.close()
|
loop.close()
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Union, Dict, Type
|
from typing import Union, Dict, Type
|
||||||
from ..typing import Messages, CreateResult
|
from ..typing import Messages, CreateResult
|
||||||
|
from .conversation import BaseConversation
|
||||||
|
|
||||||
class BaseProvider(ABC):
|
class BaseProvider(ABC):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,8 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from typing import Iterator
|
||||||
|
from http.cookies import Morsel
|
||||||
try:
|
try:
|
||||||
from curl_cffi.requests import Session, Response
|
from curl_cffi.requests import Session, Response
|
||||||
from .curl_cffi import StreamResponse, StreamSession, FormData
|
from .curl_cffi import StreamResponse, StreamSession, FormData
|
||||||
|
|
@ -14,11 +17,19 @@ try:
|
||||||
has_webview = True
|
has_webview = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
has_webview = False
|
has_webview = False
|
||||||
|
try:
|
||||||
|
import nodriver
|
||||||
|
from nodriver.cdp.network import CookieParam
|
||||||
|
has_nodriver = True
|
||||||
|
except ImportError:
|
||||||
|
has_nodriver = False
|
||||||
|
|
||||||
|
from .. import debug
|
||||||
from .raise_for_status import raise_for_status
|
from .raise_for_status import raise_for_status
|
||||||
from ..webdriver import WebDriver, WebDriverSession
|
from ..webdriver import WebDriver, WebDriverSession
|
||||||
from ..webdriver import bypass_cloudflare, get_driver_cookies
|
from ..webdriver import bypass_cloudflare, get_driver_cookies
|
||||||
from ..errors import MissingRequirementsError
|
from ..errors import MissingRequirementsError
|
||||||
|
from ..typing import Cookies
|
||||||
from .defaults import DEFAULT_HEADERS, WEBVIEW_HAEDERS
|
from .defaults import DEFAULT_HEADERS, WEBVIEW_HAEDERS
|
||||||
|
|
||||||
async def get_args_from_webview(url: str) -> dict:
|
async def get_args_from_webview(url: str) -> dict:
|
||||||
|
|
@ -106,3 +117,52 @@ def get_session_from_browser(url: str, webdriver: WebDriver = None, proxy: str =
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
impersonate="chrome"
|
impersonate="chrome"
|
||||||
)
|
)
|
||||||
|
def get_cookie_params_from_dict(cookies: Cookies, url: str = None, domain: str = None) -> list[CookieParam]:
|
||||||
|
[CookieParam.from_json({
|
||||||
|
"name": key,
|
||||||
|
"value": value,
|
||||||
|
"url": url,
|
||||||
|
"domain": domain
|
||||||
|
}) for key, value in cookies.items()]
|
||||||
|
|
||||||
|
async def get_args_from_nodriver(
|
||||||
|
url: str,
|
||||||
|
proxy: str = None,
|
||||||
|
timeout: int = 120,
|
||||||
|
cookies: Cookies = None
|
||||||
|
) -> dict:
|
||||||
|
if not has_nodriver:
|
||||||
|
raise MissingRequirementsError('Install "nodriver" package | pip install -U nodriver')
|
||||||
|
if debug.logging:
|
||||||
|
print(f"Open nodriver with url: {url}")
|
||||||
|
browser = await nodriver.start(
|
||||||
|
browser_args=None if proxy is None else [f"--proxy-server={proxy}"],
|
||||||
|
)
|
||||||
|
domain = urlparse(url).netloc
|
||||||
|
if cookies is None:
|
||||||
|
cookies = {}
|
||||||
|
else:
|
||||||
|
await browser.cookies.set_all(get_cookie_params_from_dict(cookies, url=url, domain=domain))
|
||||||
|
page = await browser.get(url)
|
||||||
|
for c in await browser.cookies.get_all():
|
||||||
|
if c.domain.endswith(domain):
|
||||||
|
cookies[c.name] = c.value
|
||||||
|
user_agent = await page.evaluate("window.navigator.userAgent")
|
||||||
|
await page.wait_for("body:not(.no-js)", timeout=timeout)
|
||||||
|
await page.close()
|
||||||
|
browser.stop()
|
||||||
|
return {
|
||||||
|
"cookies": cookies,
|
||||||
|
"headers": {
|
||||||
|
**DEFAULT_HEADERS,
|
||||||
|
"user-agent": user_agent,
|
||||||
|
"referer": url,
|
||||||
|
},
|
||||||
|
"proxy": proxy
|
||||||
|
}
|
||||||
|
|
||||||
|
def merge_cookies(cookies: Iterator[Morsel], response: Response) -> Cookies:
|
||||||
|
if cookies is None:
|
||||||
|
cookies = {}
|
||||||
|
for cookie in response.cookies.jar:
|
||||||
|
cookies[cookie.name] = cookie.value
|
||||||
|
|
@ -11,6 +11,8 @@ class CloudflareError(ResponseStatusError):
|
||||||
...
|
...
|
||||||
|
|
||||||
def is_cloudflare(text: str) -> bool:
|
def is_cloudflare(text: str) -> bool:
|
||||||
|
if "<title>Attention Required! | Cloudflare</title>" in text:
|
||||||
|
return True
|
||||||
return '<div id="cf-please-wait">' in text or "<title>Just a moment...</title>" in text
|
return '<div id="cf-please-wait">' in text or "<title>Just a moment...</title>" in text
|
||||||
|
|
||||||
def is_openai(text: str) -> bool:
|
def is_openai(text: str) -> bool:
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,4 @@ requests
|
||||||
aiohttp
|
aiohttp
|
||||||
brotli
|
brotli
|
||||||
pycryptodome
|
pycryptodome
|
||||||
curl_cffi>=0.6.2
|
|
||||||
nest_asyncio
|
nest_asyncio
|
||||||
cloudscraper
|
|
||||||
|
|
|
||||||
19
setup.py
19
setup.py
|
|
@ -13,8 +13,7 @@ INSTALL_REQUIRE = [
|
||||||
"aiohttp",
|
"aiohttp",
|
||||||
"brotli",
|
"brotli",
|
||||||
"pycryptodome",
|
"pycryptodome",
|
||||||
"curl_cffi>=0.6.2",
|
"nest_asyncio",
|
||||||
"cloudscraper" # Cloudflare
|
|
||||||
]
|
]
|
||||||
|
|
||||||
EXTRA_REQUIRE = {
|
EXTRA_REQUIRE = {
|
||||||
|
|
@ -22,18 +21,10 @@ EXTRA_REQUIRE = {
|
||||||
"curl_cffi>=0.6.2",
|
"curl_cffi>=0.6.2",
|
||||||
"certifi",
|
"certifi",
|
||||||
"browser_cookie3", # get_cookies
|
"browser_cookie3", # get_cookies
|
||||||
"PyExecJS", # GptForLove, Vercel
|
|
||||||
"duckduckgo-search>=5.0" ,# internet.search
|
"duckduckgo-search>=5.0" ,# internet.search
|
||||||
"beautifulsoup4", # internet.search and bing.create_images
|
"beautifulsoup4", # internet.search and bing.create_images
|
||||||
"brotli", # openai, bing
|
"brotli", # openai, bing
|
||||||
# webdriver
|
|
||||||
#"undetected-chromedriver>=3.5.5",
|
|
||||||
#"setuptools",
|
|
||||||
#"selenium-wire"
|
|
||||||
# webview
|
|
||||||
"pywebview",
|
|
||||||
"platformdirs",
|
"platformdirs",
|
||||||
"plyer",
|
|
||||||
"cryptography",
|
"cryptography",
|
||||||
"aiohttp_socks", # proxy
|
"aiohttp_socks", # proxy
|
||||||
"pillow", # image
|
"pillow", # image
|
||||||
|
|
@ -41,7 +32,8 @@ EXTRA_REQUIRE = {
|
||||||
"werkzeug", "flask", # gui
|
"werkzeug", "flask", # gui
|
||||||
"fastapi", # api
|
"fastapi", # api
|
||||||
"uvicorn", "nest_asyncio", # api
|
"uvicorn", "nest_asyncio", # api
|
||||||
"pycryptodome" # openai
|
"pycryptodome", # openai
|
||||||
|
"nodriver",
|
||||||
],
|
],
|
||||||
"image": [
|
"image": [
|
||||||
"pillow",
|
"pillow",
|
||||||
|
|
@ -60,12 +52,9 @@ EXTRA_REQUIRE = {
|
||||||
"plyer",
|
"plyer",
|
||||||
"cryptography"
|
"cryptography"
|
||||||
],
|
],
|
||||||
"openai": [
|
|
||||||
"pycryptodome"
|
|
||||||
],
|
|
||||||
"api": [
|
"api": [
|
||||||
"loguru", "fastapi",
|
"loguru", "fastapi",
|
||||||
"uvicorn", "nest_asyncio"
|
"uvicorn",
|
||||||
],
|
],
|
||||||
"gui": [
|
"gui": [
|
||||||
"werkzeug", "flask",
|
"werkzeug", "flask",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue