Add more flux dev image providers

Heiner Lohaus 2024-12-08 04:13:09 +01:00
parent 54a6d91cfc
commit 1bfb36176c
9 changed files with 132 additions and 39 deletions


@@ -12,6 +12,7 @@ from ...typing import CreateResult, Messages, Cookies
from ...errors import MissingRequirementsError
from ...requests.raise_for_status import raise_for_status
from ...cookies import get_cookies
+from ...image import ImageResponse
from ..base_provider import ProviderModelMixin, AbstractProvider, BaseConversation
from ..helper import format_prompt
from ... import debug
@@ -26,10 +27,12 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
working = True
supports_stream = True
needs_auth = True
default_model = "meta-llama/Meta-Llama-3.1-70B-Instruct"
default_model = "Qwen/Qwen2.5-72B-Instruct"
image_models = [
"black-forest-labs/FLUX.1-dev"
]
models = [
-'Qwen/Qwen2.5-72B-Instruct',
+default_model,
'meta-llama/Meta-Llama-3.1-70B-Instruct',
'CohereForAI/c4ai-command-r-plus-08-2024',
'Qwen/QwQ-32B-Preview',
@@ -39,8 +42,8 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
'NousResearch/Hermes-3-Llama-3.1-8B',
'mistralai/Mistral-Nemo-Instruct-2407',
'microsoft/Phi-3.5-mini-instruct',
+*image_models
]
model_aliases = {
"qwen-2.5-72b": "Qwen/Qwen2.5-72B-Instruct",
"llama-3.1-70b": "meta-llama/Meta-Llama-3.1-70B-Instruct",
@@ -52,6 +55,7 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
"hermes-3": "NousResearch/Hermes-3-Llama-3.1-8B",
"mistral-nemo": "mistralai/Mistral-Nemo-Instruct-2407",
"phi-3.5-mini": "microsoft/Phi-3.5-mini-instruct",
"flux-dev": "black-forest-labs/FLUX.1-dev",
}
@classmethod
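
Note (not part of the diff): the "flux-dev" alias gives callers a short name for the FLUX.1-dev model. The resolver below is only an illustration of how such an alias map is typically consumed, not the library's actual lookup code:

model_aliases = {"flux-dev": "black-forest-labs/FLUX.1-dev"}

def resolve_model(requested: str) -> str:
    # Hypothetical helper: fall back to the requested name when no alias exists.
    return model_aliases.get(requested, requested)

print(resolve_model("flux-dev"))  # black-forest-labs/FLUX.1-dev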
@@ -109,7 +113,7 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
"is_retry": False,
"is_continue": False,
"web_search": web_search,
"tools": []
"tools": ["000000000000000000000001"] if model in cls.image_models else [],
}
headers = {
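
Note (not part of the diff): with the "tools" change above, the request settings carry the fixed image-generation tool id only when the selected model is an image model. A rough sketch of the resulting payload, using the field names from the diff and illustrative values elsewhere:

def build_settings(model: str, image_models: list, web_search: bool = False) -> dict:
    # Mirrors the diff: attach the tool id only for image models (sketch, not the provider's code).
    return {
        "is_retry": False,
        "is_continue": False,
        "web_search": web_search,
        "tools": ["000000000000000000000001"] if model in image_models else [],
    }

print(build_settings("black-forest-labs/FLUX.1-dev", ["black-forest-labs/FLUX.1-dev"]))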
@@ -162,14 +166,18 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
elif line["type"] == "finalAnswer":
break
-full_response = full_response.replace('<|im_end|', '').replace('\u0000', '').strip()
+elif line["type"] == "file":
+url = f"https://huggingface.co/chat/conversation/{conversation.conversation_id}/output/{line['sha']}"
+yield ImageResponse(url, alt=messages[-1]["content"], options={"cookies": cookies})
+full_response = full_response.replace('<|im_end|', '').replace('\u0000', '').strip()
if not stream:
yield full_response
@classmethod
def create_conversation(cls, session: Session, model: str):
+if model in cls.image_models:
+model = cls.default_model
json_data = {
'model': model,
}
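
Note (not part of the diff): for image models the streamed "file" event is turned into an ImageResponse that points at the conversation's output file on huggingface.co. A sketch of that mapping with placeholder values; the helper name is invented here, only the URL shape and the ImageResponse call mirror the diff:

from g4f.image import ImageResponse  # same class as the relative ...image import in the diff

def file_event_to_image(conversation_id: str, sha: str, prompt: str, cookies: dict) -> ImageResponse:
    # Build the conversation output URL and wrap it, as the added branch does.
    url = f"https://huggingface.co/chat/conversation/{conversation_id}/output/{sha}"
    return ImageResponse(url, alt=prompt, options={"cookies": cookies})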