mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-06 02:30:41 -08:00
62 lines
No EOL
2.3 KiB
Python
62 lines
No EOL
2.3 KiB
Python
from __future__ import annotations
|
|
|
|
from ..models import ModelUtils, ImageModel, VisionModel
|
|
from ..Provider import ProviderUtils
|
|
from ..providers.types import ProviderType
|
|
|
|
class ClientModels:
    """Resolve providers by name and list the models offered by a client's providers.

    The instance keeps a reference to the owning client (read for its
    ``api_key`` attribute) together with two optional providers: one queried
    by ``get_all``/``get_vision`` and one queried by
    ``get_media``/``get_image``/``get_video``.
    """

    def __init__(self, client, provider: ProviderType = None, media_provider: ProviderType = None):
        """Store the client and the optional default / media providers."""
        self.client = client
        self.provider = provider
        self.media_provider = media_provider

    def _models_from(self, source, api_key, options):
        """Return ``source.get_models(...)``, or ``[]`` when ``source`` is None.

        An explicit ``api_key`` wins; otherwise ``self.client.api_key`` is
        used. The key is forwarded as the ``api_key`` keyword only when the
        resolved value is not None.
        """
        if source is None:
            return []
        resolved_key = self.client.api_key if api_key is None else api_key
        if resolved_key is not None:
            # Build a new dict so the caller's keyword mapping is not mutated.
            options = {**options, "api_key": resolved_key}
        return source.get_models(**options)

    def get(self, name, default=None) -> ProviderType:
        """Look ``name`` up as a model first, then as a provider.

        Returns the model's ``best_provider`` when ``name`` is a key of
        ``ModelUtils.convert``, the matching entry of
        ``ProviderUtils.convert`` when it is a key there, and ``default``
        otherwise.
        """
        known_models = ModelUtils.convert
        if name in known_models:
            return known_models[name].best_provider
        known_providers = ProviderUtils.convert
        return known_providers[name] if name in known_providers else default

    def get_all(self, api_key: str = None, **kwargs) -> list[str]:
        """Return the default provider's ``get_models`` result (``[]`` without a provider)."""
        return self._models_from(self.provider, api_key, kwargs)

    def get_vision(self, **kwargs) -> list[str]:
        """Return vision model ids.

        Without a default provider, this is every id in ``ModelUtils.convert``
        whose model is a ``VisionModel``. With one, it is the provider's
        ``vision_models`` attribute, or ``[]`` when the attribute is absent.
        """
        backend = self.provider
        if backend is None:
            return [
                model_id
                for model_id, model in ModelUtils.convert.items()
                if isinstance(model, VisionModel)
            ]
        # The model listing runs before the attribute is read; presumably the
        # provider fills ``vision_models`` while listing -- confirm per provider.
        self.get_all(**kwargs)
        return getattr(backend, "vision_models", [])

    def get_media(self, api_key: str = None, **kwargs) -> list[str]:
        """Return the media provider's ``get_models`` result (``[]`` without a provider)."""
        return self._models_from(self.media_provider, api_key, kwargs)

    def get_image(self, **kwargs) -> list[str]:
        """Return image model ids.

        Without a media provider, this is every id in ``ModelUtils.convert``
        whose model is an ``ImageModel``. With one, it is the provider's
        ``image_models`` attribute, or ``[]`` when the attribute is absent.
        """
        backend = self.media_provider
        if backend is None:
            return [
                model_id
                for model_id, model in ModelUtils.convert.items()
                if isinstance(model, ImageModel)
            ]
        # Listing happens first, mirroring ``get_vision``.
        self.get_media(**kwargs)
        return getattr(backend, "image_models", [])

    def get_video(self, **kwargs) -> list[str]:
        """Return the media provider's ``video_models``; ``[]`` if there is no provider or attribute."""
        backend = self.media_provider
        if backend is None:
            return []
        self.get_media(**kwargs)
        return getattr(backend, "video_models", [])