IterListProvider support for generating images (#2441)

* IterListProvider support for generating images
* Add missing get_har_files import in Copilot
* Fix typo in dall-e-3 model name
* Add image client unittests
* Add MicrosoftDesigner provider
* Import MicrosoftDesigner and add it to the model list
This commit is contained in:
H Lohaus 2024-11-29 13:56:11 +01:00 committed by GitHub
parent 8d5d522c4e
commit 79c407b939
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 391 additions and 135 deletions

View file

@ -5,6 +5,7 @@ from .backend import *
from .main import * from .main import *
from .model import * from .model import *
from .client import * from .client import *
from .image_client import *
from .include import * from .include import *
from .retry_provider import * from .retry_provider import *

View file

@ -0,0 +1,44 @@
from __future__ import annotations
import asyncio
import unittest
from g4f.client import AsyncClient, ImagesResponse
from g4f.providers.retry_provider import IterListProvider
from .mocks import (
YieldImageResponseProviderMock,
MissingAuthProviderMock,
AsyncRaiseExceptionProviderMock,
YieldNoneProviderMock
)
DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]
class TestIterListProvider(unittest.IsolatedAsyncioTestCase):
async def test_skip_provider(self):
client = AsyncClient(image_provider=IterListProvider([MissingAuthProviderMock, YieldImageResponseProviderMock], False))
response = await client.images.generate("Hello", "", response_format="orginal")
self.assertIsInstance(response, ImagesResponse)
self.assertEqual("Hello", response.data[0].url)
async def test_only_one_result(self):
client = AsyncClient(image_provider=IterListProvider([YieldImageResponseProviderMock, YieldImageResponseProviderMock], False))
response = await client.images.generate("Hello", "", response_format="orginal")
self.assertIsInstance(response, ImagesResponse)
self.assertEqual("Hello", response.data[0].url)
async def test_skip_none(self):
client = AsyncClient(image_provider=IterListProvider([YieldNoneProviderMock, YieldImageResponseProviderMock], False))
response = await client.images.generate("Hello", "", response_format="orginal")
self.assertIsInstance(response, ImagesResponse)
self.assertEqual("Hello", response.data[0].url)
def test_raise_exception(self):
async def run_exception():
client = AsyncClient(image_provider=IterListProvider([YieldNoneProviderMock, AsyncRaiseExceptionProviderMock], False))
await client.images.generate("Hello", "")
self.assertRaises(RuntimeError, asyncio.run, run_exception())
if __name__ == '__main__':
unittest.main()

View file

@ -1,4 +1,6 @@
from g4f.providers.base_provider import AbstractProvider, AsyncProvider, AsyncGeneratorProvider from g4f.providers.base_provider import AbstractProvider, AsyncProvider, AsyncGeneratorProvider
from g4f.image import ImageResponse
from g4f.errors import MissingAuthError
class ProviderMock(AbstractProvider): class ProviderMock(AbstractProvider):
working = True working = True
@ -41,6 +43,25 @@ class YieldProviderMock(AsyncGeneratorProvider):
for message in messages: for message in messages:
yield message["content"] yield message["content"]
class YieldImageResponseProviderMock(AsyncGeneratorProvider):
working = True
@classmethod
async def create_async_generator(
cls, model, messages, stream, prompt: str, **kwargs
):
yield ImageResponse(prompt, "")
class MissingAuthProviderMock(AbstractProvider):
working = True
@classmethod
def create_completion(
cls, model, messages, stream, **kwargs
):
raise MissingAuthError(cls.__name__)
yield cls.__name__
class RaiseExceptionProviderMock(AbstractProvider): class RaiseExceptionProviderMock(AbstractProvider):
working = True working = True

View file

@ -66,7 +66,7 @@ MODELS = {
'flux-pro/v1.1-ultra-raw': {'persona_id': "flux-pro-v1.1-ultra-raw"}, # Amigo, your balance is not enough to make the request, wait until 12 UTC or upgrade your plan 'flux-pro/v1.1-ultra-raw': {'persona_id': "flux-pro-v1.1-ultra-raw"}, # Amigo, your balance is not enough to make the request, wait until 12 UTC or upgrade your plan
'flux/dev': {'persona_id': "flux-dev"}, 'flux/dev': {'persona_id': "flux-dev"},
'dalle-e-3': {'persona_id': "dalle-three"}, 'dall-e-3': {'persona_id': "dalle-three"},
'recraft-v3': {'persona_id': "recraft"} 'recraft-v3': {'persona_id': "recraft"}
} }
@ -130,7 +130,7 @@ class AmigoChat(AsyncGeneratorProvider, ProviderModelMixin):
"flux-realism": "flux-realism", "flux-realism": "flux-realism",
"flux-dev": "flux/dev", "flux-dev": "flux/dev",
"dalle-3": "dalle-e-3", "dalle-3": "dall-e-3",
} }
@classmethod @classmethod

View file

@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
import os
import json import json
import asyncio import asyncio
from http.cookiejar import CookieJar from http.cookiejar import CookieJar
@ -20,10 +19,10 @@ except ImportError:
from .base_provider import AbstractProvider, ProviderModelMixin, BaseConversation from .base_provider import AbstractProvider, ProviderModelMixin, BaseConversation
from .helper import format_prompt from .helper import format_prompt
from ..typing import CreateResult, Messages, ImageType from ..typing import CreateResult, Messages, ImageType
from ..errors import MissingRequirementsError from ..errors import MissingRequirementsError, NoValidHarFileError
from ..requests.raise_for_status import raise_for_status from ..requests.raise_for_status import raise_for_status
from ..providers.asyncio import get_running_loop from ..providers.asyncio import get_running_loop
from .openai.har_file import NoValidHarFileError, get_headers from .openai.har_file import get_headers, get_har_files
from ..requests import get_nodriver from ..requests import get_nodriver
from ..image import ImageResponse, to_bytes, is_accepted_format from ..image import ImageResponse, to_bytes, is_accepted_format
from .. import debug from .. import debug
@ -76,12 +75,12 @@ class Copilot(AbstractProvider, ProviderModelMixin):
if cls.needs_auth or image is not None: if cls.needs_auth or image is not None:
if conversation is None or conversation.access_token is None: if conversation is None or conversation.access_token is None:
try: try:
access_token, cookies = readHAR() access_token, cookies = readHAR(cls.url)
except NoValidHarFileError as h: except NoValidHarFileError as h:
debug.log(f"Copilot: {h}") debug.log(f"Copilot: {h}")
try: try:
get_running_loop(check_nested=True) get_running_loop(check_nested=True)
access_token, cookies = asyncio.run(cls.get_access_token_and_cookies(proxy)) access_token, cookies = asyncio.run(get_access_token_and_cookies(cls.url, proxy))
except MissingRequirementsError: except MissingRequirementsError:
raise h raise h
else: else:
@ -162,35 +161,34 @@ class Copilot(AbstractProvider, ProviderModelMixin):
if not is_started: if not is_started:
raise RuntimeError(f"Invalid response: {last_msg}") raise RuntimeError(f"Invalid response: {last_msg}")
@classmethod async def get_access_token_and_cookies(url: str, proxy: str = None, target: str = "ChatAI",):
async def get_access_token_and_cookies(cls, proxy: str = None): browser = await get_nodriver(proxy=proxy)
browser = await get_nodriver(proxy=proxy) page = await browser.get(url)
page = await browser.get(cls.url) access_token = None
access_token = None while access_token is None:
while access_token is None: access_token = await page.evaluate("""
access_token = await page.evaluate(""" (() => {
(() => { for (var i = 0; i < localStorage.length; i++) {
for (var i = 0; i < localStorage.length; i++) { try {
try { item = JSON.parse(localStorage.getItem(localStorage.key(i)));
item = JSON.parse(localStorage.getItem(localStorage.key(i))); if (item.credentialType == "AccessToken"
if (item.credentialType == "AccessToken" && item.expiresOn > Math.floor(Date.now() / 1000)
&& item.expiresOn > Math.floor(Date.now() / 1000) && item.target.includes("target")) {
&& item.target.includes("ChatAI")) { return item.secret;
return item.secret; }
} } catch(e) {}
} catch(e) {} }
} })()
})() """.replace('"target"', json.dumps(target)))
""") if access_token is None:
if access_token is None: await asyncio.sleep(1)
await asyncio.sleep(1) cookies = {}
cookies = {} for c in await page.send(nodriver.cdp.network.get_cookies([url])):
for c in await page.send(nodriver.cdp.network.get_cookies([cls.url])): cookies[c.name] = c.value
cookies[c.name] = c.value await page.close()
await page.close() return access_token, cookies
return access_token, cookies
def readHAR(): def readHAR(url: str):
api_key = None api_key = None
cookies = None cookies = None
for path in get_har_files(): for path in get_har_files():
@ -201,13 +199,10 @@ def readHAR():
# Error: not a HAR file! # Error: not a HAR file!
continue continue
for v in harFile['log']['entries']: for v in harFile['log']['entries']:
v_headers = get_headers(v) if v['request']['url'].startswith(url):
if v['request']['url'].startswith(Copilot.url): v_headers = get_headers(v)
try: if "authorization" in v_headers:
if "authorization" in v_headers: api_key = v_headers["authorization"].split(maxsplit=1).pop()
api_key = v_headers["authorization"].split(maxsplit=1).pop()
except Exception as e:
debug.log(f"Error on read headers: {e}")
if v['request']['cookies']: if v['request']['cookies']:
cookies = {c['name']: c['value'] for c in v['request']['cookies']} cookies = {c['name']: c['value'] for c in v['request']['cookies']}
if api_key is None: if api_key is None:

View file

@ -0,0 +1,167 @@
from __future__ import annotations
import uuid
import aiohttp
import random
import asyncio
import json
from ...image import ImageResponse
from ...errors import MissingRequirementsError, NoValidHarFileError
from ...typing import AsyncResult, Messages
from ...requests.raise_for_status import raise_for_status
from ...requests.aiohttp import get_connector
from ...requests import get_nodriver
from ..Copilot import get_headers, get_har_files
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import get_random_hex
from ... import debug
class MicrosoftDesigner(AsyncGeneratorProvider, ProviderModelMixin):
label = "Microsoft Designer"
url = "https://designer.microsoft.com"
working = True
needs_auth = True
default_image_model = "dall-e-3"
image_models = [default_image_model, "1024x1024", "1024x1792", "1792x1024"]
models = image_models
@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
prompt: str = None,
proxy: str = None,
**kwargs
) -> AsyncResult:
image_size = "1024x1024"
if model != cls.default_image_model and model in cls.image_models:
image_size = model
yield await cls.generate(messages[-1]["content"] if prompt is None else prompt, image_size, proxy)
@classmethod
async def generate(cls, prompt: str, image_size: str, proxy: str = None) -> ImageResponse:
try:
access_token, user_agent = readHAR("https://designerapp.officeapps.live.com")
except NoValidHarFileError as h:
debug.log(f"{cls.__name__}: {h}")
try:
access_token, user_agent = await get_access_token_and_user_agent(cls.url, proxy)
except MissingRequirementsError:
raise h
images = await create_images(prompt, access_token, user_agent, image_size, proxy)
return ImageResponse(images, prompt)
async def create_images(prompt: str, access_token: str, user_agent: str, image_size: str, proxy: str = None, seed: int = None):
url = 'https://designerapp.officeapps.live.com/designerapp/DallE.ashx?action=GetDallEImagesCogSci'
if seed is None:
seed = random.randint(0, 10000)
headers = {
"User-Agent": user_agent,
"Accept": "application/json, text/plain, */*",
"Accept-Language": "en-US",
'Authorization': f'Bearer {access_token}',
"AudienceGroup": "Production",
"Caller": "DesignerApp",
"ClientId": "b5c2664a-7e9b-4a7a-8c9a-cd2c52dcf621",
"SessionId": str(uuid.uuid4()),
"UserId": get_random_hex(16),
"ContainerId": "1e2843a7-2a98-4a6c-93f2-42002de5c478",
"FileToken": "9f1a4cb7-37e7-4c90-b44d-cb61cfda4bb8",
"x-upload-to-storage-das": "1",
"traceparent": "",
"X-DC-Hint": "FranceCentral",
"Platform": "Web",
"HostApp": "DesignerApp",
"ReleaseChannel": "",
"IsSignedInUser": "true",
"Locale": "de-DE",
"UserType": "MSA",
"x-req-start": "2615401",
"ClientBuild": "1.0.20241120.9",
"ClientName": "DesignerApp",
"Sec-Fetch-Dest": "empty",
"Sec-Fetch-Mode": "cors",
"Sec-Fetch-Site": "cross-site",
"Pragma": "no-cache",
"Cache-Control": "no-cache",
"Referer": "https://designer.microsoft.com/"
}
form_data = aiohttp.FormData()
form_data.add_field('dalle-caption', prompt)
form_data.add_field('dalle-scenario-name', 'TextToImage')
form_data.add_field('dalle-batch-size', '4')
form_data.add_field('dalle-image-response-format', 'UrlWithBase64Thumbnail')
form_data.add_field('dalle-seed', seed)
form_data.add_field('ClientFlights', 'EnableBICForDALLEFlight')
form_data.add_field('dalle-hear-back-in-ms', 1000)
form_data.add_field('dalle-include-b64-thumbnails', 'true')
form_data.add_field('dalle-aspect-ratio-scaling-factor-b64-thumbnails', 0.3)
form_data.add_field('dalle-image-size', image_size)
async with aiohttp.ClientSession(connector=get_connector(proxy=proxy)) as session:
async with session.post(url, headers=headers, data=form_data) as response:
await raise_for_status(response)
response_data = await response.json()
form_data.add_field('dalle-boost-count', response_data.get('dalle-boost-count', 0))
polling_meta_data = response_data.get('polling_response', {}).get('polling_meta_data', {})
form_data.add_field('dalle-poll-url', polling_meta_data.get('poll_url', ''))
while True:
await asyncio.sleep(polling_meta_data.get('poll_interval', 1000) / 1000)
async with session.post(url, headers=headers, data=form_data) as response:
await raise_for_status(response)
response_data = await response.json()
images = [image["ImageUrl"] for image in response_data.get('image_urls_thumbnail', [])]
if images:
return images
def readHAR(url: str) -> tuple[str, str]:
api_key = None
user_agent = None
for path in get_har_files():
with open(path, 'rb') as file:
try:
harFile = json.loads(file.read())
except json.JSONDecodeError:
# Error: not a HAR file!
continue
for v in harFile['log']['entries']:
if v['request']['url'].startswith(url):
v_headers = get_headers(v)
if "authorization" in v_headers:
api_key = v_headers["authorization"].split(maxsplit=1).pop()
if "user-agent" in v_headers:
user_agent = v_headers["user-agent"]
if api_key is None:
raise NoValidHarFileError("No access token found in .har files")
return api_key, user_agent
async def get_access_token_and_user_agent(url: str, proxy: str = None):
browser = await get_nodriver(proxy=proxy)
page = await browser.get(url)
user_agent = await page.evaluate("navigator.userAgent")
access_token = None
while access_token is None:
access_token = await page.evaluate("""
(() => {
for (var i = 0; i < localStorage.length; i++) {
try {
item = JSON.parse(localStorage.getItem(localStorage.key(i)));
if (item.credentialType == "AccessToken"
&& item.expiresOn > Math.floor(Date.now() / 1000)
&& item.target.includes("designerappservice")) {
return item.secret;
}
} catch(e) {}
}
})()
""")
if access_token is None:
await asyncio.sleep(1)
await page.close()
return access_token, user_agent

View file

@ -22,10 +22,10 @@ from ...requests.raise_for_status import raise_for_status
from ...requests import StreamSession from ...requests import StreamSession
from ...requests import get_nodriver from ...requests import get_nodriver
from ...image import ImageResponse, ImageRequest, to_image, to_bytes, is_accepted_format from ...image import ImageResponse, ImageRequest, to_image, to_bytes, is_accepted_format
from ...errors import MissingAuthError from ...errors import MissingAuthError, NoValidHarFileError
from ...providers.response import BaseConversation, FinishReason, SynthesizeData from ...providers.response import BaseConversation, FinishReason, SynthesizeData
from ..helper import format_cookies from ..helper import format_cookies
from ..openai.har_file import get_request_config, NoValidHarFileError from ..openai.har_file import get_request_config
from ..openai.har_file import RequestConfig, arkReq, arkose_url, start_url, conversation_url, backend_url, backend_anon_url from ..openai.har_file import RequestConfig, arkReq, arkose_url, start_url, conversation_url, backend_url, backend_anon_url
from ..openai.proofofwork import generate_proof_token from ..openai.proofofwork import generate_proof_token
from ..openai.new import get_requirements_token from ..openai.new import get_requirements_token

View file

@ -1,25 +1,27 @@
from .gigachat import * from .gigachat import *
from .BingCreateImages import BingCreateImages from .BingCreateImages import BingCreateImages
from .Cerebras import Cerebras from .Cerebras import Cerebras
from .CopilotAccount import CopilotAccount from .CopilotAccount import CopilotAccount
from .DeepInfra import DeepInfra from .DeepInfra import DeepInfra
from .DeepInfraImage import DeepInfraImage from .DeepInfraImage import DeepInfraImage
from .Gemini import Gemini from .Gemini import Gemini
from .GeminiPro import GeminiPro from .GeminiPro import GeminiPro
from .GithubCopilot import GithubCopilot from .GithubCopilot import GithubCopilot
from .Groq import Groq from .Groq import Groq
from .HuggingFace import HuggingFace from .HuggingFace import HuggingFace
from .HuggingFace2 import HuggingFace2 from .HuggingFace2 import HuggingFace2
from .MetaAI import MetaAI from .MetaAI import MetaAI
from .MetaAIAccount import MetaAIAccount from .MetaAIAccount import MetaAIAccount
from .OpenaiAPI import OpenaiAPI from .MicrosoftDesigner import MicrosoftDesigner
from .OpenaiChat import OpenaiChat from .OpenaiAccount import OpenaiAccount
from .PerplexityApi import PerplexityApi from .OpenaiAPI import OpenaiAPI
from .Poe import Poe from .OpenaiChat import OpenaiChat
from .PollinationsAI import PollinationsAI from .PerplexityApi import PerplexityApi
from .Raycast import Raycast from .Poe import Poe
from .Replicate import Replicate from .PollinationsAI import PollinationsAI
from .Theb import Theb from .Raycast import Raycast
from .ThebApi import ThebApi from .Replicate import Replicate
from .WhiteRabbitNeo import WhiteRabbitNeo from .Theb import Theb
from .ThebApi import ThebApi
from .WhiteRabbitNeo import WhiteRabbitNeo

View file

@ -13,6 +13,7 @@ from copy import deepcopy
from .crypt import decrypt, encrypt from .crypt import decrypt, encrypt
from ...requests import StreamSession from ...requests import StreamSession
from ...cookies import get_cookies_dir from ...cookies import get_cookies_dir
from ...errors import NoValidHarFileError
from ... import debug from ... import debug
arkose_url = "https://tcr9i.chat.openai.com/fc/gt2/public_key/35536E1E-65B4-4D96-9D97-6ADB7EFF8147" arkose_url = "https://tcr9i.chat.openai.com/fc/gt2/public_key/35536E1E-65B4-4D96-9D97-6ADB7EFF8147"
@ -21,9 +22,6 @@ backend_anon_url = "https://chatgpt.com/backend-anon/conversation"
start_url = "https://chatgpt.com/" start_url = "https://chatgpt.com/"
conversation_url = "https://chatgpt.com/c/" conversation_url = "https://chatgpt.com/c/"
class NoValidHarFileError(Exception):
pass
class RequestConfig: class RequestConfig:
cookies: dict = None cookies: dict = None
headers: dict = None headers: dict = None

View file

@ -8,14 +8,11 @@ import logging
from ...requests import StreamSession, raise_for_status from ...requests import StreamSession, raise_for_status
from ...cookies import get_cookies_dir from ...cookies import get_cookies_dir
from ...errors import MissingRequirementsError from ...errors import MissingRequirementsError, NoValidHarFileError
from ... import debug from ... import debug
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class NoValidHarFileError(Exception):
...
class arkReq: class arkReq:
def __init__(self, arkURL, arkHeaders, arkBody, arkCookies, userAgent): def __init__(self, arkURL, arkHeaders, arkBody, arkCookies, userAgent):
self.arkURL = arkURL self.arkURL = arkURL

View file

@ -12,15 +12,17 @@ from ..image import ImageResponse, copy_images, images_dir
from ..typing import Messages, ImageType from ..typing import Messages, ImageType
from ..providers.types import ProviderType from ..providers.types import ProviderType
from ..providers.response import ResponseType, FinishReason, BaseConversation, SynthesizeData from ..providers.response import ResponseType, FinishReason, BaseConversation, SynthesizeData
from ..errors import NoImageResponseError, ModelNotFoundError from ..errors import NoImageResponseError, MissingAuthError, NoValidHarFileError
from ..providers.retry_provider import IterListProvider from ..providers.retry_provider import IterListProvider
from ..providers.asyncio import get_running_loop, to_sync_generator, async_generator_to_list from ..providers.asyncio import to_sync_generator, async_generator_to_list
from ..Provider.needs_auth import BingCreateImages, OpenaiAccount from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
from ..image import to_bytes
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
from .image_models import ImageModels from .image_models import ImageModels
from .types import IterResponse, ImageProvider, Client as BaseClient from .types import IterResponse, ImageProvider, Client as BaseClient
from .service import get_model_and_provider, get_last_provider, convert_to_provider from .service import get_model_and_provider, get_last_provider, convert_to_provider
from .helper import find_stop, filter_json, filter_none, safe_aclose, to_async_iterator from .helper import find_stop, filter_json, filter_none, safe_aclose, to_async_iterator
from .. import debug
ChatCompletionResponseType = Iterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]] ChatCompletionResponseType = Iterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
AsyncChatCompletionResponseType = AsyncIterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]] AsyncChatCompletionResponseType = AsyncIterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
@ -274,11 +276,6 @@ class Images:
provider_handler = provider provider_handler = provider
if provider_handler is None: if provider_handler is None:
return default return default
if isinstance(provider_handler, IterListProvider):
if provider_handler.providers:
provider_handler = provider_handler.providers[0]
else:
raise ModelNotFoundError(f"IterListProvider for model {model} has no providers")
return provider_handler return provider_handler
async def async_generate( async def async_generate(
@ -291,33 +288,23 @@ class Images:
**kwargs **kwargs
) -> ImagesResponse: ) -> ImagesResponse:
provider_handler = await self.get_provider_handler(model, provider, BingCreateImages) provider_handler = await self.get_provider_handler(model, provider, BingCreateImages)
provider_name = provider.__name__ if hasattr(provider, "__name__") else type(provider).__name__ provider_name = provider_handler.__name__ if hasattr(provider_handler, "__name__") else type(provider_handler).__name__
if proxy is None: if proxy is None:
proxy = self.client.proxy proxy = self.client.proxy
response = None response = None
if hasattr(provider_handler, "create_async_generator"): if isinstance(provider_handler, IterListProvider):
messages = [{"role": "user", "content": f"Generate a image: {prompt}"}] for provider in provider_handler.providers:
async for item in provider_handler.create_async_generator(model, messages, prompt=prompt, **kwargs): try:
if isinstance(item, ImageResponse): response = await self._generate_image_response(provider, provider.__name__, model, prompt, **kwargs)
response = item if response is not None:
break provider_name = provider.__name__
elif hasattr(provider_handler, 'create'): break
if asyncio.iscoroutinefunction(provider_handler.create): except (MissingAuthError, NoValidHarFileError) as e:
response = await provider_handler.create(prompt) debug.log(f"Image provider {provider.__name__}: {e}")
else:
response = provider_handler.create(prompt)
if isinstance(response, str):
response = ImageResponse([response], prompt)
elif hasattr(provider_handler, "create_completion"):
get_running_loop(check_nested=True)
messages = [{"role": "user", "content": f"Generate a image: {prompt}"}]
for item in provider_handler.create_completion(model, messages, prompt=prompt, **kwargs):
if isinstance(item, ImageResponse):
response = item
break
else: else:
raise ValueError(f"Provider {provider_name} does not support image generation") response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)
if isinstance(response, ImageResponse): if isinstance(response, ImageResponse):
return await self._process_image_response( return await self._process_image_response(
response, response,
@ -330,6 +317,46 @@ class Images:
raise NoImageResponseError(f"No image response from {provider_name}") raise NoImageResponseError(f"No image response from {provider_name}")
raise NoImageResponseError(f"Unexpected response type: {type(response)}") raise NoImageResponseError(f"Unexpected response type: {type(response)}")
async def _generate_image_response(
self,
provider_handler,
provider_name,
model: str,
prompt: str,
prompt_prefix: str = "Generate a image: ",
image: ImageType = None,
**kwargs
) -> ImageResponse:
messages = [{"role": "user", "content": f"{prompt_prefix}{prompt}"}]
response = None
if hasattr(provider_handler, "create_async_generator"):
async for item in provider_handler.create_async_generator(
model,
messages,
stream=True,
prompt=prompt,
image=image,
**kwargs
):
if isinstance(item, ImageResponse):
response = item
break
elif hasattr(provider_handler, "create_completion"):
for item in provider_handler.create_completion(
model,
messages,
True,
prompt=prompt,
image=image,
**kwargs
):
if isinstance(item, ImageResponse):
response = item
break
else:
raise ValueError(f"Provider {provider_name} does not support image generation")
return response
def create_variation( def create_variation(
self, self,
image: ImageType, image: ImageType,
@ -352,33 +379,28 @@ class Images:
**kwargs **kwargs
) -> ImagesResponse: ) -> ImagesResponse:
provider_handler = await self.get_provider_handler(model, provider, OpenaiAccount) provider_handler = await self.get_provider_handler(model, provider, OpenaiAccount)
provider_name = provider.__name__ if hasattr(provider, "__name__") else type(provider).__name__ provider_name = provider_handler.__name__ if hasattr(provider_handler, "__name__") else type(provider_handler).__name__
if proxy is None: if proxy is None:
proxy = self.client.proxy proxy = self.client.proxy
prompt = "create a variation of this image"
if hasattr(provider_handler, "create_async_generator"): response = None
messages = [{"role": "user", "content": "create a variation of this image"}] if isinstance(provider_handler, IterListProvider):
generator = None # File pointer can be read only once, so we need to convert it to bytes
try: image = to_bytes(image)
generator = provider_handler.create_async_generator(model, messages, image=image, response_format=response_format, proxy=proxy, **kwargs) for provider in provider_handler.providers:
async for chunk in generator: try:
if isinstance(chunk, ImageResponse): response = await self._generate_image_response(provider, provider.__name__, model, prompt, image=image, **kwargs)
response = chunk if response is not None:
provider_name = provider.__name__
break break
finally: except (MissingAuthError, NoValidHarFileError) as e:
await safe_aclose(generator) debug.log(f"Image provider {provider.__name__}: {e}")
elif hasattr(provider_handler, 'create_variation'):
if asyncio.iscoroutinefunction(provider.provider_handler):
response = await provider_handler.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs)
else:
response = provider_handler.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs)
else: else:
raise NoImageResponseError(f"Provider {provider_name} does not support image variation") response = await self._generate_image_response(provider_handler, provider_name, model, prompt, image=image, **kwargs)
if isinstance(response, str):
response = ImageResponse([response])
if isinstance(response, ImageResponse): if isinstance(response, ImageResponse):
return self._process_image_response(response, response_format, proxy, model, provider_name) return await self._process_image_response(response, response_format, proxy, model, provider_name)
if response is None: if response is None:
raise NoImageResponseError(f"No image response from {provider_name}") raise NoImageResponseError(f"No image response from {provider_name}")
raise NoImageResponseError(f"Unexpected response type: {type(response)}") raise NoImageResponseError(f"Unexpected response type: {type(response)}")

View file

@ -45,3 +45,6 @@ class ResponseError(Exception):
class ResponseStatusError(Exception): class ResponseStatusError(Exception):
... ...
class NoValidHarFileError(Exception):
...

View file

@ -7,10 +7,12 @@ from .Provider import (
AIChatFree, AIChatFree,
AmigoChat, AmigoChat,
Blackbox, Blackbox,
BingCreateImages,
ChatGpt, ChatGpt,
ChatGptEs, ChatGptEs,
Cloudflare, Cloudflare,
Copilot, Copilot,
CopilotAccount,
DarkAI, DarkAI,
DDG, DDG,
DeepInfraChat, DeepInfraChat,
@ -25,7 +27,9 @@ from .Provider import (
MagickPen, MagickPen,
Mhystical, Mhystical,
MetaAI, MetaAI,
MicrosoftDesigner,
OpenaiChat, OpenaiChat,
OpenaiAccount,
PerplexityLabs, PerplexityLabs,
Pi, Pi,
Pizzagpt, Pizzagpt,
@ -629,9 +633,9 @@ flux_4o = Model(
### OpenAI ### ### OpenAI ###
dalle_3 = Model( dalle_3 = Model(
name = 'dalle-3', name = 'dall-e-3',
base_provider = 'OpenAI', base_provider = 'OpenAI',
best_provider = AmigoChat best_provider = IterListProvider([CopilotAccount, OpenaiAccount, MicrosoftDesigner, BingCreateImages])
) )
### Recraft ### ### Recraft ###
@ -828,6 +832,7 @@ class ModelUtils:
### OpenAI ### ### OpenAI ###
'dalle-3': dalle_3, 'dalle-3': dalle_3,
'dall-e-3': dalle_3,
### Recraft ### ### Recraft ###
'recraft-v3': recraft_v3, 'recraft-v3': recraft_v3,

View file

@ -2,7 +2,7 @@ from __future__ import annotations
import asyncio import asyncio
from asyncio import AbstractEventLoop, runners from asyncio import AbstractEventLoop, runners
from typing import Union, Callable, AsyncGenerator, Generator from typing import Optional, Callable, AsyncGenerator, Generator
from ..errors import NestAsyncioError from ..errors import NestAsyncioError
@ -17,7 +17,7 @@ try:
except ImportError: except ImportError:
has_uvloop = False has_uvloop = False
def get_running_loop(check_nested: bool) -> Union[AbstractEventLoop, None]: def get_running_loop(check_nested: bool) -> Optional[AbstractEventLoop]:
try: try:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
# Do not patch uvloop loop because its incompatible. # Do not patch uvloop loop because its incompatible.

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
import random import random
from ..typing import Type, List, CreateResult, Messages, Iterator, AsyncResult from ..typing import Type, List, CreateResult, Messages, AsyncResult
from .types import BaseProvider, BaseRetryProvider, ProviderType from .types import BaseProvider, BaseRetryProvider, ProviderType
from .. import debug from .. import debug
from ..errors import RetryProviderError, RetryNoProviderError from ..errors import RetryProviderError, RetryNoProviderError

View file

@ -4,7 +4,8 @@ from typing import Any, AsyncGenerator, Generator, AsyncIterator, Iterator, NewT
try: try:
from PIL.Image import Image from PIL.Image import Image
except ImportError: except ImportError:
from typing import Type as Image class Image:
pass
if sys.version_info >= (3, 8): if sys.version_info >= (3, 8):
from typing import TypedDict from typing import TypedDict