mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-06 02:30:41 -08:00
IterListProvider support for generating images (#2441)
* IterListProvider support for generating images * Add missing get_har_files import in Copilot * Fix typo in dall-e-3 model name * Add image client unittests * Add MicrosoftDesigner provider * Import MicrosoftDesigner and add it to the model list
This commit is contained in:
parent
8d5d522c4e
commit
79c407b939
16 changed files with 391 additions and 135 deletions
|
|
@ -5,6 +5,7 @@ from .backend import *
|
|||
from .main import *
|
||||
from .model import *
|
||||
from .client import *
|
||||
from .image_client import *
|
||||
from .include import *
|
||||
from .retry_provider import *
|
||||
|
||||
|
|
|
|||
44
etc/unittest/image_client.py
Normal file
44
etc/unittest/image_client.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
|
||||
from g4f.client import AsyncClient, ImagesResponse
|
||||
from g4f.providers.retry_provider import IterListProvider
|
||||
from .mocks import (
|
||||
YieldImageResponseProviderMock,
|
||||
MissingAuthProviderMock,
|
||||
AsyncRaiseExceptionProviderMock,
|
||||
YieldNoneProviderMock
|
||||
)
|
||||
|
||||
DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]
|
||||
|
||||
class TestIterListProvider(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_skip_provider(self):
|
||||
client = AsyncClient(image_provider=IterListProvider([MissingAuthProviderMock, YieldImageResponseProviderMock], False))
|
||||
response = await client.images.generate("Hello", "", response_format="orginal")
|
||||
self.assertIsInstance(response, ImagesResponse)
|
||||
self.assertEqual("Hello", response.data[0].url)
|
||||
|
||||
async def test_only_one_result(self):
|
||||
client = AsyncClient(image_provider=IterListProvider([YieldImageResponseProviderMock, YieldImageResponseProviderMock], False))
|
||||
response = await client.images.generate("Hello", "", response_format="orginal")
|
||||
self.assertIsInstance(response, ImagesResponse)
|
||||
self.assertEqual("Hello", response.data[0].url)
|
||||
|
||||
async def test_skip_none(self):
|
||||
client = AsyncClient(image_provider=IterListProvider([YieldNoneProviderMock, YieldImageResponseProviderMock], False))
|
||||
response = await client.images.generate("Hello", "", response_format="orginal")
|
||||
self.assertIsInstance(response, ImagesResponse)
|
||||
self.assertEqual("Hello", response.data[0].url)
|
||||
|
||||
def test_raise_exception(self):
|
||||
async def run_exception():
|
||||
client = AsyncClient(image_provider=IterListProvider([YieldNoneProviderMock, AsyncRaiseExceptionProviderMock], False))
|
||||
await client.images.generate("Hello", "")
|
||||
self.assertRaises(RuntimeError, asyncio.run, run_exception())
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -1,4 +1,6 @@
|
|||
from g4f.providers.base_provider import AbstractProvider, AsyncProvider, AsyncGeneratorProvider
|
||||
from g4f.image import ImageResponse
|
||||
from g4f.errors import MissingAuthError
|
||||
|
||||
class ProviderMock(AbstractProvider):
|
||||
working = True
|
||||
|
|
@ -41,6 +43,25 @@ class YieldProviderMock(AsyncGeneratorProvider):
|
|||
for message in messages:
|
||||
yield message["content"]
|
||||
|
||||
class YieldImageResponseProviderMock(AsyncGeneratorProvider):
|
||||
working = True
|
||||
|
||||
@classmethod
|
||||
async def create_async_generator(
|
||||
cls, model, messages, stream, prompt: str, **kwargs
|
||||
):
|
||||
yield ImageResponse(prompt, "")
|
||||
|
||||
class MissingAuthProviderMock(AbstractProvider):
|
||||
working = True
|
||||
|
||||
@classmethod
|
||||
def create_completion(
|
||||
cls, model, messages, stream, **kwargs
|
||||
):
|
||||
raise MissingAuthError(cls.__name__)
|
||||
yield cls.__name__
|
||||
|
||||
class RaiseExceptionProviderMock(AbstractProvider):
|
||||
working = True
|
||||
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ MODELS = {
|
|||
'flux-pro/v1.1-ultra-raw': {'persona_id': "flux-pro-v1.1-ultra-raw"}, # Amigo, your balance is not enough to make the request, wait until 12 UTC or upgrade your plan
|
||||
'flux/dev': {'persona_id': "flux-dev"},
|
||||
|
||||
'dalle-e-3': {'persona_id': "dalle-three"},
|
||||
'dall-e-3': {'persona_id': "dalle-three"},
|
||||
|
||||
'recraft-v3': {'persona_id': "recraft"}
|
||||
}
|
||||
|
|
@ -130,7 +130,7 @@ class AmigoChat(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
"flux-realism": "flux-realism",
|
||||
"flux-dev": "flux/dev",
|
||||
|
||||
"dalle-3": "dalle-e-3",
|
||||
"dalle-3": "dall-e-3",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
from http.cookiejar import CookieJar
|
||||
|
|
@ -20,10 +19,10 @@ except ImportError:
|
|||
from .base_provider import AbstractProvider, ProviderModelMixin, BaseConversation
|
||||
from .helper import format_prompt
|
||||
from ..typing import CreateResult, Messages, ImageType
|
||||
from ..errors import MissingRequirementsError
|
||||
from ..errors import MissingRequirementsError, NoValidHarFileError
|
||||
from ..requests.raise_for_status import raise_for_status
|
||||
from ..providers.asyncio import get_running_loop
|
||||
from .openai.har_file import NoValidHarFileError, get_headers
|
||||
from .openai.har_file import get_headers, get_har_files
|
||||
from ..requests import get_nodriver
|
||||
from ..image import ImageResponse, to_bytes, is_accepted_format
|
||||
from .. import debug
|
||||
|
|
@ -76,12 +75,12 @@ class Copilot(AbstractProvider, ProviderModelMixin):
|
|||
if cls.needs_auth or image is not None:
|
||||
if conversation is None or conversation.access_token is None:
|
||||
try:
|
||||
access_token, cookies = readHAR()
|
||||
access_token, cookies = readHAR(cls.url)
|
||||
except NoValidHarFileError as h:
|
||||
debug.log(f"Copilot: {h}")
|
||||
try:
|
||||
get_running_loop(check_nested=True)
|
||||
access_token, cookies = asyncio.run(cls.get_access_token_and_cookies(proxy))
|
||||
access_token, cookies = asyncio.run(get_access_token_and_cookies(cls.url, proxy))
|
||||
except MissingRequirementsError:
|
||||
raise h
|
||||
else:
|
||||
|
|
@ -162,10 +161,9 @@ class Copilot(AbstractProvider, ProviderModelMixin):
|
|||
if not is_started:
|
||||
raise RuntimeError(f"Invalid response: {last_msg}")
|
||||
|
||||
@classmethod
|
||||
async def get_access_token_and_cookies(cls, proxy: str = None):
|
||||
async def get_access_token_and_cookies(url: str, proxy: str = None, target: str = "ChatAI",):
|
||||
browser = await get_nodriver(proxy=proxy)
|
||||
page = await browser.get(cls.url)
|
||||
page = await browser.get(url)
|
||||
access_token = None
|
||||
while access_token is None:
|
||||
access_token = await page.evaluate("""
|
||||
|
|
@ -175,22 +173,22 @@ class Copilot(AbstractProvider, ProviderModelMixin):
|
|||
item = JSON.parse(localStorage.getItem(localStorage.key(i)));
|
||||
if (item.credentialType == "AccessToken"
|
||||
&& item.expiresOn > Math.floor(Date.now() / 1000)
|
||||
&& item.target.includes("ChatAI")) {
|
||||
&& item.target.includes("target")) {
|
||||
return item.secret;
|
||||
}
|
||||
} catch(e) {}
|
||||
}
|
||||
})()
|
||||
""")
|
||||
""".replace('"target"', json.dumps(target)))
|
||||
if access_token is None:
|
||||
await asyncio.sleep(1)
|
||||
cookies = {}
|
||||
for c in await page.send(nodriver.cdp.network.get_cookies([cls.url])):
|
||||
for c in await page.send(nodriver.cdp.network.get_cookies([url])):
|
||||
cookies[c.name] = c.value
|
||||
await page.close()
|
||||
return access_token, cookies
|
||||
|
||||
def readHAR():
|
||||
def readHAR(url: str):
|
||||
api_key = None
|
||||
cookies = None
|
||||
for path in get_har_files():
|
||||
|
|
@ -201,13 +199,10 @@ def readHAR():
|
|||
# Error: not a HAR file!
|
||||
continue
|
||||
for v in harFile['log']['entries']:
|
||||
if v['request']['url'].startswith(url):
|
||||
v_headers = get_headers(v)
|
||||
if v['request']['url'].startswith(Copilot.url):
|
||||
try:
|
||||
if "authorization" in v_headers:
|
||||
api_key = v_headers["authorization"].split(maxsplit=1).pop()
|
||||
except Exception as e:
|
||||
debug.log(f"Error on read headers: {e}")
|
||||
if v['request']['cookies']:
|
||||
cookies = {c['name']: c['value'] for c in v['request']['cookies']}
|
||||
if api_key is None:
|
||||
|
|
|
|||
167
g4f/Provider/needs_auth/MicrosoftDesigner.py
Normal file
167
g4f/Provider/needs_auth/MicrosoftDesigner.py
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
import aiohttp
|
||||
import random
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from ...image import ImageResponse
|
||||
from ...errors import MissingRequirementsError, NoValidHarFileError
|
||||
from ...typing import AsyncResult, Messages
|
||||
from ...requests.raise_for_status import raise_for_status
|
||||
from ...requests.aiohttp import get_connector
|
||||
from ...requests import get_nodriver
|
||||
from ..Copilot import get_headers, get_har_files
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ..helper import get_random_hex
|
||||
from ... import debug
|
||||
|
||||
class MicrosoftDesigner(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
label = "Microsoft Designer"
|
||||
url = "https://designer.microsoft.com"
|
||||
working = True
|
||||
needs_auth = True
|
||||
default_image_model = "dall-e-3"
|
||||
image_models = [default_image_model, "1024x1024", "1024x1792", "1792x1024"]
|
||||
models = image_models
|
||||
|
||||
@classmethod
|
||||
async def create_async_generator(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
prompt: str = None,
|
||||
proxy: str = None,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
image_size = "1024x1024"
|
||||
if model != cls.default_image_model and model in cls.image_models:
|
||||
image_size = model
|
||||
yield await cls.generate(messages[-1]["content"] if prompt is None else prompt, image_size, proxy)
|
||||
|
||||
@classmethod
|
||||
async def generate(cls, prompt: str, image_size: str, proxy: str = None) -> ImageResponse:
|
||||
try:
|
||||
access_token, user_agent = readHAR("https://designerapp.officeapps.live.com")
|
||||
except NoValidHarFileError as h:
|
||||
debug.log(f"{cls.__name__}: {h}")
|
||||
try:
|
||||
access_token, user_agent = await get_access_token_and_user_agent(cls.url, proxy)
|
||||
except MissingRequirementsError:
|
||||
raise h
|
||||
images = await create_images(prompt, access_token, user_agent, image_size, proxy)
|
||||
return ImageResponse(images, prompt)
|
||||
|
||||
async def create_images(prompt: str, access_token: str, user_agent: str, image_size: str, proxy: str = None, seed: int = None):
|
||||
url = 'https://designerapp.officeapps.live.com/designerapp/DallE.ashx?action=GetDallEImagesCogSci'
|
||||
if seed is None:
|
||||
seed = random.randint(0, 10000)
|
||||
|
||||
headers = {
|
||||
"User-Agent": user_agent,
|
||||
"Accept": "application/json, text/plain, */*",
|
||||
"Accept-Language": "en-US",
|
||||
'Authorization': f'Bearer {access_token}',
|
||||
"AudienceGroup": "Production",
|
||||
"Caller": "DesignerApp",
|
||||
"ClientId": "b5c2664a-7e9b-4a7a-8c9a-cd2c52dcf621",
|
||||
"SessionId": str(uuid.uuid4()),
|
||||
"UserId": get_random_hex(16),
|
||||
"ContainerId": "1e2843a7-2a98-4a6c-93f2-42002de5c478",
|
||||
"FileToken": "9f1a4cb7-37e7-4c90-b44d-cb61cfda4bb8",
|
||||
"x-upload-to-storage-das": "1",
|
||||
"traceparent": "",
|
||||
"X-DC-Hint": "FranceCentral",
|
||||
"Platform": "Web",
|
||||
"HostApp": "DesignerApp",
|
||||
"ReleaseChannel": "",
|
||||
"IsSignedInUser": "true",
|
||||
"Locale": "de-DE",
|
||||
"UserType": "MSA",
|
||||
"x-req-start": "2615401",
|
||||
"ClientBuild": "1.0.20241120.9",
|
||||
"ClientName": "DesignerApp",
|
||||
"Sec-Fetch-Dest": "empty",
|
||||
"Sec-Fetch-Mode": "cors",
|
||||
"Sec-Fetch-Site": "cross-site",
|
||||
"Pragma": "no-cache",
|
||||
"Cache-Control": "no-cache",
|
||||
"Referer": "https://designer.microsoft.com/"
|
||||
}
|
||||
|
||||
form_data = aiohttp.FormData()
|
||||
form_data.add_field('dalle-caption', prompt)
|
||||
form_data.add_field('dalle-scenario-name', 'TextToImage')
|
||||
form_data.add_field('dalle-batch-size', '4')
|
||||
form_data.add_field('dalle-image-response-format', 'UrlWithBase64Thumbnail')
|
||||
form_data.add_field('dalle-seed', seed)
|
||||
form_data.add_field('ClientFlights', 'EnableBICForDALLEFlight')
|
||||
form_data.add_field('dalle-hear-back-in-ms', 1000)
|
||||
form_data.add_field('dalle-include-b64-thumbnails', 'true')
|
||||
form_data.add_field('dalle-aspect-ratio-scaling-factor-b64-thumbnails', 0.3)
|
||||
form_data.add_field('dalle-image-size', image_size)
|
||||
|
||||
async with aiohttp.ClientSession(connector=get_connector(proxy=proxy)) as session:
|
||||
async with session.post(url, headers=headers, data=form_data) as response:
|
||||
await raise_for_status(response)
|
||||
response_data = await response.json()
|
||||
form_data.add_field('dalle-boost-count', response_data.get('dalle-boost-count', 0))
|
||||
polling_meta_data = response_data.get('polling_response', {}).get('polling_meta_data', {})
|
||||
form_data.add_field('dalle-poll-url', polling_meta_data.get('poll_url', ''))
|
||||
|
||||
while True:
|
||||
await asyncio.sleep(polling_meta_data.get('poll_interval', 1000) / 1000)
|
||||
async with session.post(url, headers=headers, data=form_data) as response:
|
||||
await raise_for_status(response)
|
||||
response_data = await response.json()
|
||||
images = [image["ImageUrl"] for image in response_data.get('image_urls_thumbnail', [])]
|
||||
if images:
|
||||
return images
|
||||
|
||||
def readHAR(url: str) -> tuple[str, str]:
|
||||
api_key = None
|
||||
user_agent = None
|
||||
for path in get_har_files():
|
||||
with open(path, 'rb') as file:
|
||||
try:
|
||||
harFile = json.loads(file.read())
|
||||
except json.JSONDecodeError:
|
||||
# Error: not a HAR file!
|
||||
continue
|
||||
for v in harFile['log']['entries']:
|
||||
if v['request']['url'].startswith(url):
|
||||
v_headers = get_headers(v)
|
||||
if "authorization" in v_headers:
|
||||
api_key = v_headers["authorization"].split(maxsplit=1).pop()
|
||||
if "user-agent" in v_headers:
|
||||
user_agent = v_headers["user-agent"]
|
||||
if api_key is None:
|
||||
raise NoValidHarFileError("No access token found in .har files")
|
||||
|
||||
return api_key, user_agent
|
||||
|
||||
async def get_access_token_and_user_agent(url: str, proxy: str = None):
|
||||
browser = await get_nodriver(proxy=proxy)
|
||||
page = await browser.get(url)
|
||||
user_agent = await page.evaluate("navigator.userAgent")
|
||||
access_token = None
|
||||
while access_token is None:
|
||||
access_token = await page.evaluate("""
|
||||
(() => {
|
||||
for (var i = 0; i < localStorage.length; i++) {
|
||||
try {
|
||||
item = JSON.parse(localStorage.getItem(localStorage.key(i)));
|
||||
if (item.credentialType == "AccessToken"
|
||||
&& item.expiresOn > Math.floor(Date.now() / 1000)
|
||||
&& item.target.includes("designerappservice")) {
|
||||
return item.secret;
|
||||
}
|
||||
} catch(e) {}
|
||||
}
|
||||
})()
|
||||
""")
|
||||
if access_token is None:
|
||||
await asyncio.sleep(1)
|
||||
await page.close()
|
||||
return access_token, user_agent
|
||||
|
|
@ -22,10 +22,10 @@ from ...requests.raise_for_status import raise_for_status
|
|||
from ...requests import StreamSession
|
||||
from ...requests import get_nodriver
|
||||
from ...image import ImageResponse, ImageRequest, to_image, to_bytes, is_accepted_format
|
||||
from ...errors import MissingAuthError
|
||||
from ...errors import MissingAuthError, NoValidHarFileError
|
||||
from ...providers.response import BaseConversation, FinishReason, SynthesizeData
|
||||
from ..helper import format_cookies
|
||||
from ..openai.har_file import get_request_config, NoValidHarFileError
|
||||
from ..openai.har_file import get_request_config
|
||||
from ..openai.har_file import RequestConfig, arkReq, arkose_url, start_url, conversation_url, backend_url, backend_anon_url
|
||||
from ..openai.proofofwork import generate_proof_token
|
||||
from ..openai.new import get_requirements_token
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ from .HuggingFace import HuggingFace
|
|||
from .HuggingFace2 import HuggingFace2
|
||||
from .MetaAI import MetaAI
|
||||
from .MetaAIAccount import MetaAIAccount
|
||||
from .MicrosoftDesigner import MicrosoftDesigner
|
||||
from .OpenaiAccount import OpenaiAccount
|
||||
from .OpenaiAPI import OpenaiAPI
|
||||
from .OpenaiChat import OpenaiChat
|
||||
from .PerplexityApi import PerplexityApi
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from copy import deepcopy
|
|||
from .crypt import decrypt, encrypt
|
||||
from ...requests import StreamSession
|
||||
from ...cookies import get_cookies_dir
|
||||
from ...errors import NoValidHarFileError
|
||||
from ... import debug
|
||||
|
||||
arkose_url = "https://tcr9i.chat.openai.com/fc/gt2/public_key/35536E1E-65B4-4D96-9D97-6ADB7EFF8147"
|
||||
|
|
@ -21,9 +22,6 @@ backend_anon_url = "https://chatgpt.com/backend-anon/conversation"
|
|||
start_url = "https://chatgpt.com/"
|
||||
conversation_url = "https://chatgpt.com/c/"
|
||||
|
||||
class NoValidHarFileError(Exception):
|
||||
pass
|
||||
|
||||
class RequestConfig:
|
||||
cookies: dict = None
|
||||
headers: dict = None
|
||||
|
|
|
|||
|
|
@ -8,14 +8,11 @@ import logging
|
|||
|
||||
from ...requests import StreamSession, raise_for_status
|
||||
from ...cookies import get_cookies_dir
|
||||
from ...errors import MissingRequirementsError
|
||||
from ...errors import MissingRequirementsError, NoValidHarFileError
|
||||
from ... import debug
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class NoValidHarFileError(Exception):
|
||||
...
|
||||
|
||||
class arkReq:
|
||||
def __init__(self, arkURL, arkHeaders, arkBody, arkCookies, userAgent):
|
||||
self.arkURL = arkURL
|
||||
|
|
|
|||
|
|
@ -12,15 +12,17 @@ from ..image import ImageResponse, copy_images, images_dir
|
|||
from ..typing import Messages, ImageType
|
||||
from ..providers.types import ProviderType
|
||||
from ..providers.response import ResponseType, FinishReason, BaseConversation, SynthesizeData
|
||||
from ..errors import NoImageResponseError, ModelNotFoundError
|
||||
from ..errors import NoImageResponseError, MissingAuthError, NoValidHarFileError
|
||||
from ..providers.retry_provider import IterListProvider
|
||||
from ..providers.asyncio import get_running_loop, to_sync_generator, async_generator_to_list
|
||||
from ..providers.asyncio import to_sync_generator, async_generator_to_list
|
||||
from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
|
||||
from ..image import to_bytes
|
||||
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
|
||||
from .image_models import ImageModels
|
||||
from .types import IterResponse, ImageProvider, Client as BaseClient
|
||||
from .service import get_model_and_provider, get_last_provider, convert_to_provider
|
||||
from .helper import find_stop, filter_json, filter_none, safe_aclose, to_async_iterator
|
||||
from .. import debug
|
||||
|
||||
ChatCompletionResponseType = Iterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
|
||||
AsyncChatCompletionResponseType = AsyncIterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
|
||||
|
|
@ -274,11 +276,6 @@ class Images:
|
|||
provider_handler = provider
|
||||
if provider_handler is None:
|
||||
return default
|
||||
if isinstance(provider_handler, IterListProvider):
|
||||
if provider_handler.providers:
|
||||
provider_handler = provider_handler.providers[0]
|
||||
else:
|
||||
raise ModelNotFoundError(f"IterListProvider for model {model} has no providers")
|
||||
return provider_handler
|
||||
|
||||
async def async_generate(
|
||||
|
|
@ -291,33 +288,23 @@ class Images:
|
|||
**kwargs
|
||||
) -> ImagesResponse:
|
||||
provider_handler = await self.get_provider_handler(model, provider, BingCreateImages)
|
||||
provider_name = provider.__name__ if hasattr(provider, "__name__") else type(provider).__name__
|
||||
provider_name = provider_handler.__name__ if hasattr(provider_handler, "__name__") else type(provider_handler).__name__
|
||||
if proxy is None:
|
||||
proxy = self.client.proxy
|
||||
|
||||
response = None
|
||||
if hasattr(provider_handler, "create_async_generator"):
|
||||
messages = [{"role": "user", "content": f"Generate a image: {prompt}"}]
|
||||
async for item in provider_handler.create_async_generator(model, messages, prompt=prompt, **kwargs):
|
||||
if isinstance(item, ImageResponse):
|
||||
response = item
|
||||
if isinstance(provider_handler, IterListProvider):
|
||||
for provider in provider_handler.providers:
|
||||
try:
|
||||
response = await self._generate_image_response(provider, provider.__name__, model, prompt, **kwargs)
|
||||
if response is not None:
|
||||
provider_name = provider.__name__
|
||||
break
|
||||
elif hasattr(provider_handler, 'create'):
|
||||
if asyncio.iscoroutinefunction(provider_handler.create):
|
||||
response = await provider_handler.create(prompt)
|
||||
except (MissingAuthError, NoValidHarFileError) as e:
|
||||
debug.log(f"Image provider {provider.__name__}: {e}")
|
||||
else:
|
||||
response = provider_handler.create(prompt)
|
||||
if isinstance(response, str):
|
||||
response = ImageResponse([response], prompt)
|
||||
elif hasattr(provider_handler, "create_completion"):
|
||||
get_running_loop(check_nested=True)
|
||||
messages = [{"role": "user", "content": f"Generate a image: {prompt}"}]
|
||||
for item in provider_handler.create_completion(model, messages, prompt=prompt, **kwargs):
|
||||
if isinstance(item, ImageResponse):
|
||||
response = item
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"Provider {provider_name} does not support image generation")
|
||||
response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)
|
||||
|
||||
if isinstance(response, ImageResponse):
|
||||
return await self._process_image_response(
|
||||
response,
|
||||
|
|
@ -330,6 +317,46 @@ class Images:
|
|||
raise NoImageResponseError(f"No image response from {provider_name}")
|
||||
raise NoImageResponseError(f"Unexpected response type: {type(response)}")
|
||||
|
||||
async def _generate_image_response(
|
||||
self,
|
||||
provider_handler,
|
||||
provider_name,
|
||||
model: str,
|
||||
prompt: str,
|
||||
prompt_prefix: str = "Generate a image: ",
|
||||
image: ImageType = None,
|
||||
**kwargs
|
||||
) -> ImageResponse:
|
||||
messages = [{"role": "user", "content": f"{prompt_prefix}{prompt}"}]
|
||||
response = None
|
||||
if hasattr(provider_handler, "create_async_generator"):
|
||||
async for item in provider_handler.create_async_generator(
|
||||
model,
|
||||
messages,
|
||||
stream=True,
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
**kwargs
|
||||
):
|
||||
if isinstance(item, ImageResponse):
|
||||
response = item
|
||||
break
|
||||
elif hasattr(provider_handler, "create_completion"):
|
||||
for item in provider_handler.create_completion(
|
||||
model,
|
||||
messages,
|
||||
True,
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
**kwargs
|
||||
):
|
||||
if isinstance(item, ImageResponse):
|
||||
response = item
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"Provider {provider_name} does not support image generation")
|
||||
return response
|
||||
|
||||
def create_variation(
|
||||
self,
|
||||
image: ImageType,
|
||||
|
|
@ -352,33 +379,28 @@ class Images:
|
|||
**kwargs
|
||||
) -> ImagesResponse:
|
||||
provider_handler = await self.get_provider_handler(model, provider, OpenaiAccount)
|
||||
provider_name = provider.__name__ if hasattr(provider, "__name__") else type(provider).__name__
|
||||
provider_name = provider_handler.__name__ if hasattr(provider_handler, "__name__") else type(provider_handler).__name__
|
||||
if proxy is None:
|
||||
proxy = self.client.proxy
|
||||
prompt = "create a variation of this image"
|
||||
|
||||
if hasattr(provider_handler, "create_async_generator"):
|
||||
messages = [{"role": "user", "content": "create a variation of this image"}]
|
||||
generator = None
|
||||
response = None
|
||||
if isinstance(provider_handler, IterListProvider):
|
||||
# File pointer can be read only once, so we need to convert it to bytes
|
||||
image = to_bytes(image)
|
||||
for provider in provider_handler.providers:
|
||||
try:
|
||||
generator = provider_handler.create_async_generator(model, messages, image=image, response_format=response_format, proxy=proxy, **kwargs)
|
||||
async for chunk in generator:
|
||||
if isinstance(chunk, ImageResponse):
|
||||
response = chunk
|
||||
response = await self._generate_image_response(provider, provider.__name__, model, prompt, image=image, **kwargs)
|
||||
if response is not None:
|
||||
provider_name = provider.__name__
|
||||
break
|
||||
finally:
|
||||
await safe_aclose(generator)
|
||||
elif hasattr(provider_handler, 'create_variation'):
|
||||
if asyncio.iscoroutinefunction(provider.provider_handler):
|
||||
response = await provider_handler.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs)
|
||||
except (MissingAuthError, NoValidHarFileError) as e:
|
||||
debug.log(f"Image provider {provider.__name__}: {e}")
|
||||
else:
|
||||
response = provider_handler.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs)
|
||||
else:
|
||||
raise NoImageResponseError(f"Provider {provider_name} does not support image variation")
|
||||
response = await self._generate_image_response(provider_handler, provider_name, model, prompt, image=image, **kwargs)
|
||||
|
||||
if isinstance(response, str):
|
||||
response = ImageResponse([response])
|
||||
if isinstance(response, ImageResponse):
|
||||
return self._process_image_response(response, response_format, proxy, model, provider_name)
|
||||
return await self._process_image_response(response, response_format, proxy, model, provider_name)
|
||||
if response is None:
|
||||
raise NoImageResponseError(f"No image response from {provider_name}")
|
||||
raise NoImageResponseError(f"Unexpected response type: {type(response)}")
|
||||
|
|
|
|||
|
|
@ -45,3 +45,6 @@ class ResponseError(Exception):
|
|||
|
||||
class ResponseStatusError(Exception):
|
||||
...
|
||||
|
||||
class NoValidHarFileError(Exception):
|
||||
...
|
||||
|
|
@ -7,10 +7,12 @@ from .Provider import (
|
|||
AIChatFree,
|
||||
AmigoChat,
|
||||
Blackbox,
|
||||
BingCreateImages,
|
||||
ChatGpt,
|
||||
ChatGptEs,
|
||||
Cloudflare,
|
||||
Copilot,
|
||||
CopilotAccount,
|
||||
DarkAI,
|
||||
DDG,
|
||||
DeepInfraChat,
|
||||
|
|
@ -25,7 +27,9 @@ from .Provider import (
|
|||
MagickPen,
|
||||
Mhystical,
|
||||
MetaAI,
|
||||
MicrosoftDesigner,
|
||||
OpenaiChat,
|
||||
OpenaiAccount,
|
||||
PerplexityLabs,
|
||||
Pi,
|
||||
Pizzagpt,
|
||||
|
|
@ -629,9 +633,9 @@ flux_4o = Model(
|
|||
|
||||
### OpenAI ###
|
||||
dalle_3 = Model(
|
||||
name = 'dalle-3',
|
||||
name = 'dall-e-3',
|
||||
base_provider = 'OpenAI',
|
||||
best_provider = AmigoChat
|
||||
best_provider = IterListProvider([CopilotAccount, OpenaiAccount, MicrosoftDesigner, BingCreateImages])
|
||||
)
|
||||
|
||||
### Recraft ###
|
||||
|
|
@ -828,6 +832,7 @@ class ModelUtils:
|
|||
|
||||
### OpenAI ###
|
||||
'dalle-3': dalle_3,
|
||||
'dall-e-3': dalle_3,
|
||||
|
||||
### Recraft ###
|
||||
'recraft-v3': recraft_v3,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
from asyncio import AbstractEventLoop, runners
|
||||
from typing import Union, Callable, AsyncGenerator, Generator
|
||||
from typing import Optional, Callable, AsyncGenerator, Generator
|
||||
|
||||
from ..errors import NestAsyncioError
|
||||
|
||||
|
|
@ -17,7 +17,7 @@ try:
|
|||
except ImportError:
|
||||
has_uvloop = False
|
||||
|
||||
def get_running_loop(check_nested: bool) -> Union[AbstractEventLoop, None]:
|
||||
def get_running_loop(check_nested: bool) -> Optional[AbstractEventLoop]:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
# Do not patch uvloop loop because its incompatible.
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import random
|
||||
|
||||
from ..typing import Type, List, CreateResult, Messages, Iterator, AsyncResult
|
||||
from ..typing import Type, List, CreateResult, Messages, AsyncResult
|
||||
from .types import BaseProvider, BaseRetryProvider, ProviderType
|
||||
from .. import debug
|
||||
from ..errors import RetryProviderError, RetryNoProviderError
|
||||
|
|
|
|||
|
|
@ -4,7 +4,8 @@ from typing import Any, AsyncGenerator, Generator, AsyncIterator, Iterator, NewT
|
|||
try:
|
||||
from PIL.Image import Image
|
||||
except ImportError:
|
||||
from typing import Type as Image
|
||||
class Image:
|
||||
pass
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
from typing import TypedDict
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue