mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-15 14:51:19 -08:00
fix: replace format_image_prompt with format_media_prompt across multiple files
- Updated imports to use format_media_prompt in g4f/Provider/ARTA.py, PollinationsAI.py, PollinationsImage.py, Websim.py, audio/OpenAIFM.py, hf_space/BlackForestLabs_Flux1Dev.py, hf_space/DeepseekAI_JanusPro7b.py, hf_space/G4F.py, hf_space/Microsoft_Phi_4_Multimodal.py, hf_space/StabilityAI_SD35Large.py, needs_auth/BingCreateImages.py, needs_auth/BlackboxPro.py, needs_auth/DeepInfra.py, needs_auth/Gemini.py, needs_auth/MicrosoftDesigner.py, needs_auth/OpenaiChat.py, needs_auth/hf/HuggingChat.py, needs_auth/hf/HuggingFaceInference.py, needs_auth/hf/HuggingFaceMedia.py, not_working/AllenAI.py, template/OpenaiTemplate.py, api.py, and gui/server/api.py - Replaced calls to format_image_prompt with format_media_prompt in relevant locations - Changed media prompt handling in various providers to ensure consistent usage of format_media_prompt - Modified the __aenter__ and __aexit__ methods of requests/aiohttp.py to properly manage ClientSession lifecycle
This commit is contained in:
parent
bf4ed09ab9
commit
f96ea67f50
23 changed files with 66 additions and 59 deletions
|
|
@ -13,7 +13,7 @@ from ..providers.response import ImageResponse, Reasoning
|
|||
from ..errors import ResponseError, ModelNotFoundError
|
||||
from ..cookies import get_cookies_dir
|
||||
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from .helper import format_image_prompt
|
||||
from .helper import format_media_prompt
|
||||
from .. import debug
|
||||
|
||||
class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
|
|
@ -177,7 +177,7 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
**kwargs
|
||||
) -> AsyncResult:
|
||||
model = cls.get_model(model)
|
||||
prompt = format_image_prompt(messages, prompt)
|
||||
prompt = format_media_prompt(messages, prompt)
|
||||
|
||||
# Generate a random seed if not provided
|
||||
if seed is None:
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from urllib.parse import quote_plus
|
|||
from typing import Optional
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from .helper import filter_none, format_image_prompt
|
||||
from .helper import filter_none, format_media_prompt
|
||||
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ..typing import AsyncResult, Messages, MediaListType
|
||||
from ..image import is_data_an_audio
|
||||
|
|
@ -132,6 +132,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
### Image Models ###
|
||||
"sdxl-turbo": "turbo",
|
||||
"gpt-image": "gptimage",
|
||||
"dall-e-3": "gptimage",
|
||||
"flux-pro": "flux",
|
||||
"flux-dev": "flux",
|
||||
"flux-schnell": "flux"
|
||||
|
|
@ -292,7 +293,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
if model in cls.image_models:
|
||||
async for chunk in cls._generate_image(
|
||||
model=model,
|
||||
prompt=format_image_prompt(messages, prompt),
|
||||
prompt=format_media_prompt(messages, prompt),
|
||||
proxy=proxy,
|
||||
aspect_ratio=aspect_ratio,
|
||||
width=width,
|
||||
|
|
@ -524,7 +525,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
debug.error("Error generating title and followups")
|
||||
debug.error(e)
|
||||
elif response.headers["content-type"].startswith("application/json"):
|
||||
prompt = format_image_prompt(messages)
|
||||
prompt = format_media_prompt(messages)
|
||||
result = await response.json()
|
||||
if result.get("model"):
|
||||
yield ProviderInfo(**cls.get_dict(), model=result.get("model"))
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||
|
||||
from typing import Optional
|
||||
|
||||
from .helper import format_image_prompt
|
||||
from .helper import format_media_prompt
|
||||
from ..typing import AsyncResult, Messages
|
||||
from ..constants import STATIC_URL
|
||||
from .PollinationsAI import PollinationsAI
|
||||
|
|
@ -54,7 +54,7 @@ class PollinationsImage(PollinationsAI):
|
|||
cls.get_models()
|
||||
async for chunk in cls._generate_image(
|
||||
model=model,
|
||||
prompt=format_image_prompt(messages, prompt),
|
||||
prompt=format_media_prompt(messages, prompt),
|
||||
proxy=proxy,
|
||||
aspect_ratio=aspect_ratio,
|
||||
width=width,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
|||
from ..requests.raise_for_status import raise_for_status
|
||||
from ..errors import ResponseStatusError
|
||||
from ..providers.response import ImageResponse
|
||||
from .helper import format_prompt, format_image_prompt
|
||||
from .helper import format_prompt, format_media_prompt
|
||||
|
||||
|
||||
class Websim(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
|
|
@ -110,7 +110,7 @@ class Websim(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
proxy: str = None,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
used_prompt = format_image_prompt(messages, prompt)
|
||||
used_prompt = format_media_prompt(messages, prompt)
|
||||
|
||||
async with ClientSession(headers=headers) as session:
|
||||
data = {
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from aiohttp import ClientSession
|
|||
|
||||
from ...typing import AsyncResult, Messages
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ..helper import get_last_user_message, get_system_prompt
|
||||
from ..helper import format_media_prompt, get_system_prompt
|
||||
from ...image.copy_images import save_response_media
|
||||
from ...requests.raise_for_status import raise_for_status
|
||||
from ...requests.aiohttp import get_connector
|
||||
|
|
@ -16,11 +16,10 @@ class OpenAIFM(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
api_endpoint = "https://www.openai.fm/api/generate"
|
||||
working = True
|
||||
|
||||
default_model = 'gpt-4o-mini-tts'
|
||||
default_audio_model = default_model
|
||||
default_voice = 'coral'
|
||||
voices = ['alloy', 'ash', 'ballad', default_voice, 'echo', 'fable', 'onyx', 'nova', 'sage', 'shimmer', 'verse']
|
||||
audio_models = {default_audio_model: voices}
|
||||
default_model = 'coral'
|
||||
voices = ['alloy', 'ash', 'ballad', default_model, 'echo', 'fable', 'onyx', 'nova', 'sage', 'shimmer', 'verse']
|
||||
audio_models = {"gpt-4o-mini-tts": voices}
|
||||
model_aliases = {"gpt-4o-mini-tts": default_model}
|
||||
models = voices
|
||||
|
||||
friendly = """Affect/personality: A cheerful guide
|
||||
|
|
@ -99,16 +98,17 @@ Emotion: Restrained enthusiasm for discoveries and findings, conveying intellect
|
|||
audio: dict = {},
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
# Retrieve parameters from the audio dictionary
|
||||
voice = audio.get("voice", kwargs.get("voice", cls.default_voice))
|
||||
instructions = audio.get("instructions", kwargs.get("instructions", get_system_prompt(messages) or cls.friendly))
|
||||
model = cls.get_model(model)
|
||||
voice = audio.get("voice", kwargs.get("voice", model))
|
||||
default_instructions = get_system_prompt(messages) or cls.friendly
|
||||
instructions = audio.get("instructions", kwargs.get("instructions", default_instructions))
|
||||
headers = {
|
||||
**DEFAULT_HEADERS,
|
||||
"referer": f"{cls.url}/"
|
||||
}
|
||||
text = get_last_user_message(messages, prompt)
|
||||
prompt = format_media_prompt(messages, prompt)
|
||||
params = {
|
||||
"input": text,
|
||||
"input": prompt,
|
||||
"prompt": instructions,
|
||||
"voice": voice
|
||||
}
|
||||
|
|
@ -118,5 +118,5 @@ Emotion: Restrained enthusiasm for discoveries and findings, conveying intellect
|
|||
params=params
|
||||
) as response:
|
||||
await raise_for_status(response)
|
||||
async for chunk in save_response_media(response, text, [model, voice]):
|
||||
async for chunk in save_response_media(response, prompt, [model, voice]):
|
||||
yield chunk
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from ...requests import StreamSession
|
|||
from ...image import use_aspect_ratio
|
||||
from ...errors import ResponseError
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ..helper import format_image_prompt
|
||||
from ..helper import format_media_prompt
|
||||
from .DeepseekAI_JanusPro7b import get_zerogpu_token
|
||||
from .raise_for_status import raise_for_status
|
||||
|
||||
|
|
@ -70,7 +70,7 @@ class BlackForestLabs_Flux1Dev(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
**kwargs
|
||||
) -> AsyncResult:
|
||||
async with StreamSession(impersonate="chrome", proxy=proxy) as session:
|
||||
prompt = format_image_prompt(messages, prompt)
|
||||
prompt = format_media_prompt(messages, prompt)
|
||||
data = use_aspect_ratio({"width": width, "height": height}, aspect_ratio)
|
||||
data = [prompt, seed, randomize_seed, data.get("width"), data.get("height"), guidance_scale, num_inference_steps]
|
||||
conversation = JsonConversation(zerogpu_token=api_key, zerogpu_uuid=zerogpu_uuid, session_hash=uuid.uuid4().hex)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import urllib.parse
|
|||
|
||||
from ...typing import AsyncResult, Messages, Cookies, MediaListType
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ..helper import format_prompt, format_image_prompt
|
||||
from ..helper import format_prompt, format_media_prompt
|
||||
from ...providers.response import JsonConversation, ImageResponse, Reasoning
|
||||
from ...requests.aiohttp import StreamSession, StreamResponse, FormData
|
||||
from ...requests.raise_for_status import raise_for_status
|
||||
|
|
@ -85,7 +85,7 @@ class DeepseekAI_JanusPro7b(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
if model == cls.default_image_model or prompt is not None:
|
||||
method = "image"
|
||||
prompt = format_prompt(messages) if prompt is None and conversation is None else prompt
|
||||
prompt = format_image_prompt(messages, prompt)
|
||||
prompt = format_media_prompt(messages, prompt)
|
||||
if seed is None:
|
||||
seed = random.randint(1000, 999999)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import asyncio
|
|||
|
||||
from ...typing import AsyncResult, Messages
|
||||
from ...providers.response import ImageResponse, Reasoning, JsonConversation
|
||||
from ..helper import format_image_prompt, get_random_string
|
||||
from ..helper import format_media_prompt, get_random_string
|
||||
from .DeepseekAI_JanusPro7b import DeepseekAI_JanusPro7b, get_zerogpu_token
|
||||
from .BlackForestLabs_Flux1Dev import BlackForestLabs_Flux1Dev
|
||||
from .raise_for_status import raise_for_status
|
||||
|
|
@ -80,7 +80,7 @@ class G4F(DeepseekAI_JanusPro7b):
|
|||
width = max(32, width - (width % 8))
|
||||
height = max(32, height - (height % 8))
|
||||
if prompt is None:
|
||||
prompt = format_image_prompt(messages)
|
||||
prompt = format_media_prompt(messages)
|
||||
if seed is None:
|
||||
seed = random.randint(9999, 2**32 - 1)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import uuid
|
|||
|
||||
from ...typing import AsyncResult, Messages, Cookies, MediaListType
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ..helper import format_prompt, format_image_prompt
|
||||
from ..helper import format_prompt, format_media_prompt
|
||||
from ...providers.response import JsonConversation
|
||||
from ...requests.aiohttp import StreamSession, StreamResponse, FormData
|
||||
from ...requests.raise_for_status import raise_for_status
|
||||
|
|
@ -104,7 +104,7 @@ class Microsoft_Phi_4_Multimodal(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
**kwargs
|
||||
) -> AsyncResult:
|
||||
prompt = format_prompt(messages) if prompt is None and conversation is None else prompt
|
||||
prompt = format_image_prompt(messages, prompt)
|
||||
prompt = format_media_prompt(messages, prompt)
|
||||
|
||||
session_hash = uuid.uuid4().hex if conversation is None else getattr(conversation, "session_hash", uuid.uuid4().hex)
|
||||
async with StreamSession(proxy=proxy, impersonate="chrome") as session:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from ...providers.response import ImageResponse, ImagePreview
|
|||
from ...image import use_aspect_ratio
|
||||
from ...errors import ResponseError
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ..helper import format_image_prompt
|
||||
from ..helper import format_media_prompt
|
||||
|
||||
class StabilityAI_SD35Large(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
label = "StabilityAI SD-3.5-Large"
|
||||
|
|
@ -46,7 +46,7 @@ class StabilityAI_SD35Large(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
if api_key is not None:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
async with ClientSession(headers=headers) as session:
|
||||
prompt = format_image_prompt(messages, prompt)
|
||||
prompt = format_media_prompt(messages, prompt)
|
||||
data = use_aspect_ratio({"width": width, "height": height}, aspect_ratio)
|
||||
data = {
|
||||
"data": [prompt, negative_prompt, seed, randomize_seed, data.get("width"), data.get("height"), guidance_scale, num_inference_steps]
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from ...errors import MissingAuthError
|
|||
from ...typing import AsyncResult, Messages, Cookies
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from .bing.create_images import create_images, create_session
|
||||
from ..helper import format_image_prompt
|
||||
from ..helper import format_media_prompt
|
||||
|
||||
class BingCreateImages(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
label = "Microsoft Designer in Bing"
|
||||
|
|
@ -36,7 +36,7 @@ class BingCreateImages(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
**kwargs
|
||||
) -> AsyncResult:
|
||||
session = BingCreateImages(cookies, proxy, api_key)
|
||||
yield await session.generate(format_image_prompt(messages, prompt))
|
||||
yield await session.generate(format_media_prompt(messages, prompt))
|
||||
|
||||
async def generate(self, prompt: str) -> ImageResponse:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
|||
from ..openai.har_file import get_har_files
|
||||
from ...image import to_data_uri
|
||||
from ...cookies import get_cookies_dir
|
||||
from ..helper import format_image_prompt, render_messages
|
||||
from ..helper import format_media_prompt, render_messages
|
||||
from ...providers.response import JsonConversation, ImageResponse
|
||||
from ...tools.media import merge_media
|
||||
from ...errors import RateLimitError, NoValidHarFileError
|
||||
|
|
@ -1343,7 +1343,7 @@ class BlackboxPro(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
image_url_match = re.search(r'!\[.*?\]\((.*?)\)', full_response_text)
|
||||
if image_url_match:
|
||||
image_url = image_url_match.group(1)
|
||||
yield ImageResponse(urls=[image_url], alt=format_image_prompt(messages, prompt))
|
||||
yield ImageResponse(urls=[image_url], alt=format_media_prompt(messages, prompt))
|
||||
return
|
||||
|
||||
# Handle conversation history once, in one place
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from ...typing import AsyncResult, Messages
|
|||
from ...requests import StreamSession, raise_for_status
|
||||
from ...providers.response import ImageResponse
|
||||
from ..template import OpenaiTemplate
|
||||
from ..helper import format_image_prompt
|
||||
from ..helper import format_media_prompt
|
||||
|
||||
class DeepInfra(OpenaiTemplate):
|
||||
url = "https://deepinfra.com"
|
||||
|
|
@ -56,7 +56,7 @@ class DeepInfra(OpenaiTemplate):
|
|||
) -> AsyncResult:
|
||||
if model in cls.get_image_models():
|
||||
yield cls.create_async_image(
|
||||
format_image_prompt(messages, prompt),
|
||||
format_media_prompt(messages, prompt),
|
||||
model,
|
||||
**kwargs
|
||||
)
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ from ...image import to_bytes
|
|||
from ...cookies import get_cookies_dir
|
||||
from ...tools.media import merge_media
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ..helper import format_prompt, get_cookies, get_last_user_message, format_image_prompt
|
||||
from ..helper import format_prompt, get_cookies, get_last_user_message, format_media_prompt
|
||||
from ... import debug
|
||||
|
||||
REQUEST_HEADERS = {
|
||||
|
|
@ -187,7 +187,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
if model in cls.model_aliases:
|
||||
model = cls.model_aliases[model]
|
||||
if audio is not None or model == "gemini-audio":
|
||||
prompt = format_image_prompt(messages, prompt)
|
||||
prompt = format_media_prompt(messages, prompt)
|
||||
filename = get_filename(["gemini"], prompt, ".ogx", prompt)
|
||||
ensure_media_dir()
|
||||
path = os.path.join(get_media_dir(), filename)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from ...requests.aiohttp import get_connector
|
|||
from ...requests import get_nodriver
|
||||
from ..Copilot import get_headers, get_har_files
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ..helper import get_random_hex, format_image_prompt
|
||||
from ..helper import get_random_hex, format_media_prompt
|
||||
from ... import debug
|
||||
|
||||
class MicrosoftDesigner(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
|
|
@ -39,7 +39,7 @@ class MicrosoftDesigner(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
image_size = "1024x1024"
|
||||
if model != cls.default_image_model and model in cls.image_models:
|
||||
image_size = model
|
||||
yield await cls.generate(format_image_prompt(messages, prompt), image_size, proxy)
|
||||
yield await cls.generate(format_media_prompt(messages, prompt), image_size, proxy)
|
||||
|
||||
@classmethod
|
||||
async def generate(cls, prompt: str, image_size: str, proxy: str = None) -> ImageResponse:
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ from ...errors import MissingAuthError, NoValidHarFileError, ModelNotFoundError
|
|||
from ...providers.response import JsonConversation, FinishReason, SynthesizeData, AuthResult, ImageResponse, ImagePreview
|
||||
from ...providers.response import Sources, TitleGeneration, RequestLogin, Reasoning
|
||||
from ...tools.media import merge_media
|
||||
from ..helper import format_cookies, format_image_prompt, to_string
|
||||
from ..helper import format_cookies, format_media_prompt, to_string
|
||||
from ..openai.models import default_model, default_image_model, models, image_models, text_models
|
||||
from ..openai.har_file import get_request_config
|
||||
from ..openai.har_file import RequestConfig, arkReq, arkose_url, start_url, conversation_url, backend_url, backend_anon_url
|
||||
|
|
@ -426,7 +426,7 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
if conversation.conversation_id is not None:
|
||||
data["conversation_id"] = conversation.conversation_id
|
||||
debug.log(f"OpenaiChat: Use conversation: {conversation.conversation_id}")
|
||||
prompt = conversation.prompt = format_image_prompt(messages, prompt)
|
||||
prompt = conversation.prompt = format_media_prompt(messages, prompt)
|
||||
if action != "continue":
|
||||
data["parent_message_id"] = getattr(conversation, "parent_message_id", conversation.message_id)
|
||||
conversation.parent_message_id = None
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ except ImportError:
|
|||
has_curl_cffi = False
|
||||
|
||||
from ...base_provider import ProviderModelMixin, AsyncAuthedProvider, AuthResult
|
||||
from ...helper import format_prompt, format_image_prompt, get_last_user_message
|
||||
from ...helper import format_prompt, format_media_prompt, get_last_user_message
|
||||
from ....typing import AsyncResult, Messages, Cookies, MediaListType
|
||||
from ....errors import MissingRequirementsError, MissingAuthError, ResponseError
|
||||
from ....image import to_bytes
|
||||
|
|
@ -184,7 +184,7 @@ class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
|
|||
break
|
||||
elif line["type"] == "file":
|
||||
url = f"{cls.url}/conversation/{conversationId}/output/{line['sha']}"
|
||||
yield ImageResponse(url, format_image_prompt(messages, prompt), options={"cookies": auth_result.cookies})
|
||||
yield ImageResponse(url, format_media_prompt(messages, prompt), options={"cookies": auth_result.cookies})
|
||||
elif line["type"] == "webSearch" and "sources" in line:
|
||||
sources = Sources(line["sources"])
|
||||
elif line["type"] == "title":
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from ....requests import StreamSession, raise_for_status
|
|||
from ....providers.response import FinishReason, ImageResponse
|
||||
from ....image.copy_images import save_response_media
|
||||
from ....image import use_aspect_ratio
|
||||
from ...helper import format_image_prompt, get_last_user_message
|
||||
from ...helper import format_media_prompt, get_last_user_message
|
||||
from .models import default_model, default_image_model, model_aliases, text_models, image_models, vision_models
|
||||
from .... import debug
|
||||
|
||||
|
|
@ -110,7 +110,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
if model in provider_together_urls:
|
||||
data = {
|
||||
"response_format": "url",
|
||||
"prompt": format_image_prompt(messages, prompt),
|
||||
"prompt": format_media_prompt(messages, prompt),
|
||||
"model": model,
|
||||
**image_extra_body
|
||||
}
|
||||
|
|
@ -136,7 +136,7 @@ class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
pipeline_tag = model_data.get("pipeline_tag")
|
||||
if pipeline_tag == "text-to-image":
|
||||
stream = False
|
||||
inputs = format_image_prompt(messages, prompt)
|
||||
inputs = format_media_prompt(messages, prompt)
|
||||
payload = {"inputs": inputs, "parameters": {"seed": random.randint(0, 2**32) if seed is None else seed, **image_extra_body}}
|
||||
elif pipeline_tag in ("text-generation", "image-text-to-text"):
|
||||
model_type = None
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import requests
|
|||
from ....providers.types import Messages
|
||||
from ....requests import StreamSession, raise_for_status
|
||||
from ....errors import ModelNotFoundError
|
||||
from ....providers.helper import format_image_prompt
|
||||
from ....providers.helper import format_media_prompt
|
||||
from ....providers.base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ....providers.response import ProviderInfo, ImageResponse, VideoResponse, Reasoning
|
||||
from ....image.copy_images import save_response_media
|
||||
|
|
@ -119,7 +119,7 @@ class HuggingFaceMedia(AsyncGeneratorProvider, ProviderModelMixin):
|
|||
model, selected_provider = model.split(":", 1)
|
||||
elif not model:
|
||||
model = cls.get_models()[0]
|
||||
prompt = format_image_prompt(messages, prompt)
|
||||
prompt = format_media_prompt(messages, prompt)
|
||||
provider_mapping = await cls.get_mapping(model, api_key)
|
||||
headers = {
|
||||
'Accept-Encoding': 'gzip, deflate',
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from ...image import to_bytes, is_accepted_format, to_data_uri
|
|||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ...requests.raise_for_status import raise_for_status
|
||||
from ...providers.response import FinishReason, JsonConversation
|
||||
from ..helper import format_prompt, get_last_user_message, format_image_prompt
|
||||
from ..helper import format_prompt, get_last_user_message, format_media_prompt
|
||||
from ...tools.media import merge_media
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||
|
||||
import requests
|
||||
|
||||
from ..helper import filter_none, format_image_prompt
|
||||
from ..helper import filter_none, format_media_prompt
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
||||
from ...typing import Union, AsyncResult, Messages, MediaListType
|
||||
from ...requests import StreamSession, raise_for_status
|
||||
|
|
@ -85,7 +85,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
|||
|
||||
# Proxy for image generation feature
|
||||
if model and model in cls.image_models:
|
||||
prompt = format_image_prompt(messages, prompt)
|
||||
prompt = format_media_prompt(messages, prompt)
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"model": model,
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from ...tools.run_tools import iter_run_tools
|
|||
from ... import Provider
|
||||
from ...providers.base_provider import ProviderModelMixin
|
||||
from ...providers.retry_provider import BaseRetryProvider
|
||||
from ...providers.helper import format_image_prompt
|
||||
from ...providers.helper import format_media_prompt
|
||||
from ...providers.response import *
|
||||
from ... import version, models
|
||||
from ... import ChatCompletion, get_model_and_provider
|
||||
|
|
@ -194,7 +194,7 @@ class Api:
|
|||
elif isinstance(chunk, MediaResponse):
|
||||
media = chunk
|
||||
if download_media or chunk.get("cookies"):
|
||||
chunk.alt = format_image_prompt(kwargs.get("messages"), chunk.alt)
|
||||
chunk.alt = format_media_prompt(kwargs.get("messages"), chunk.alt)
|
||||
tags = [model, kwargs.get("aspect_ratio"), kwargs.get("resolution"), kwargs.get("width"), kwargs.get("height")]
|
||||
media = asyncio.run(copy_media(chunk.get_list(), chunk.get("cookies"), chunk.get("headers"), proxy=proxy, alt=chunk.alt, tags=tags))
|
||||
media = ImageResponse(media, chunk.alt) if isinstance(chunk, ImageResponse) else VideoResponse(media, chunk.alt)
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ class StreamResponse(ClientResponse):
|
|||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
class StreamSession(ClientSession):
|
||||
class StreamSession():
|
||||
def __init__(
|
||||
self,
|
||||
headers: dict = {},
|
||||
|
|
@ -54,7 +54,7 @@ class StreamSession(ClientSession):
|
|||
timeout = ClientTimeout(timeout, connect)
|
||||
if proxy is None:
|
||||
proxy = proxies.get("all", proxies.get("https"))
|
||||
super().__init__(
|
||||
self.inner = ClientSession(
|
||||
**kwargs,
|
||||
timeout=timeout,
|
||||
response_class=StreamResponse,
|
||||
|
|
@ -62,6 +62,12 @@ class StreamSession(ClientSession):
|
|||
headers=headers
|
||||
)
|
||||
|
||||
async def __aenter__(self) -> "ClientSession":
|
||||
return self.inner
|
||||
|
||||
async def __aexit__(self, **kwargs) -> None:
|
||||
await self.inner.close()
|
||||
|
||||
def get_connector(connector: BaseConnector = None, proxy: str = None, rdns: bool = False) -> Optional[BaseConnector]:
|
||||
if proxy and not connector:
|
||||
try:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue