diff --git a/docs/file.md b/docs/file.md index ce59bf9d..9e0f0273 100644 --- a/docs/file.md +++ b/docs/file.md @@ -180,23 +180,17 @@ fileInput.addEventListener('change', () => { **Integrating with `ChatCompletion`:** -To incorporate file uploads into your client applications, include the `tool_calls` parameter in your chat completion requests, using the `bucket_tool` function. The `bucket_id` is passed as a JSON object within your prompt. - +To incorporate file uploads into your client applications, include the `bucket` in your chat completion requests, using inline content parts. ```json { "messages": [ { "role": "user", - "content": "Answer this question using the files in the specified bucket: ...your question...\n{\"bucket_id\": \"your_actual_bucket_id\"}" - } - ], - "tool_calls": [ - { - "function": { - "name": "bucket_tool" - }, - "type": "function" + "content": [ + {"type": "text", "text": "Answer this question using the files in the specified bucket: ...your question..."}, + {"bucket_id": "your_actual_bucket_id"} + ] } ] } diff --git a/docs/media.md b/docs/media.md index dbfa3850..0083018a 100644 --- a/docs/media.md +++ b/docs/media.md @@ -30,6 +30,8 @@ asyncio.run(main()) #### **Transcribe an Audio File:** +Some providers in G4F support audio inputs in chat completions, allowing you to transcribe audio files by instructing the model accordingly. This example demonstrates how to use the `AsyncClient` to transcribe an audio file asynchronously: + ```python import asyncio from g4f.client import AsyncClient @@ -41,15 +43,32 @@ async def main(): with open("audio.wav", "rb") as audio_file: response = await client.chat.completions.create( messages="Transcribe this audio", - provider=g4f.Provider.Microsoft_Phi_4, media=[[audio_file, "audio.wav"]], modalities=["text"], ) - print(response.choices[0].message.content) -asyncio.run(main()) + print(response.choices[0].message.content) + +if __name__ == "__main__": + asyncio.run(main()) ``` +#### Explanation +- **Client Initialization**: An `AsyncClient` instance is created with a provider that supports audio inputs, such as `PollinationsAI` or `Microsoft_Phi_4`. +- **File Handling**: The audio file (`audio.wav`) is opened in binary read mode (`"rb"`) using a context manager (`with` statement) to ensure proper file closure after use. +- **API Call**: The `chat.completions.create` method is called with: + - `messages`: Containing a user message instructing the model to transcribe the audio. + - `media`: A list of lists, where each inner list contains the file object and its name (`[[audio_file, "audio.wav"]]`). + - `modalities=["text"]`: Specifies that the output should be text (the transcription). +- **Response**: The transcription is extracted from `response.choices[0].message.content` and printed. + +#### Notes +- **Provider Support**: Ensure the chosen provider (e.g., `PollinationsAI` or `Microsoft_Phi_4`) supports audio inputs in chat completions. Not all providers may offer this functionality. +- **File Path**: Replace `"audio.wav"` with the path to your own audio file. The file format (e.g., WAV) should be compatible with the provider. +- **Model Selection**: If `g4f.models.default` does not support audio transcription, you may need to specify a model that does (consult the provider's documentation for supported models). + +This example complements the guide by showcasing how to handle audio inputs asynchronously, expanding on the multimodal capabilities of the G4F AsyncClient API. + --- ### 2. **Image Generation** diff --git a/g4f/Provider/ARTA.py b/g4f/Provider/ARTA.py index 488540f6..1b6c6650 100644 --- a/g4f/Provider/ARTA.py +++ b/g4f/Provider/ARTA.py @@ -203,7 +203,7 @@ class ARTA(AsyncGeneratorProvider, ProviderModelMixin): else: raise ResponseError(f"Image generation failed with status: {status}") -async def raise_error(response: ClientResponse, message: str): +async def raise_error(message: str, response: ClientResponse): if response.ok: return error_text = await response.text() diff --git a/g4f/Provider/Blackbox.py b/g4f/Provider/Blackbox.py index bff00d1f..eb8b3fd8 100644 --- a/g4f/Provider/Blackbox.py +++ b/g4f/Provider/Blackbox.py @@ -20,7 +20,7 @@ from ..cookies import get_cookies_dir from .helper import format_image_prompt, render_messages from ..providers.response import JsonConversation, ImageResponse from ..tools.media import merge_media -from ..errors import RateLimitError +from ..errors import RateLimitError, NoValidHarFileError from .. import debug class Conversation(JsonConversation): @@ -470,6 +470,8 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin): except Exception as e: debug.log(f"Blackbox: Error reading HAR file {file}: {e}") return None + except NoValidHarFileError: + pass except Exception as e: debug.log(f"Blackbox: Error searching HAR files: {e}") return None diff --git a/g4f/Provider/Cloudflare.py b/g4f/Provider/Cloudflare.py index 3b35ccf3..ee7c8089 100644 --- a/g4f/Provider/Cloudflare.py +++ b/g4f/Provider/Cloudflare.py @@ -8,7 +8,7 @@ from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, AuthFileM from ..requests import Session, StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies from ..requests import DEFAULT_HEADERS, has_nodriver, has_curl_cffi from ..providers.response import FinishReason, Usage -from ..errors import ResponseStatusError, ModelNotFoundError, MissingRequirementsError +from ..errors import ResponseStatusError, ModelNotFoundError from .. import debug from .helper import render_messages @@ -72,11 +72,11 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin): except ResponseStatusError as f: if has_nodriver: get_running_loop(check_nested=True) - args = get_args_from_nodriver(cls.url) try: + args = get_args_from_nodriver(cls.url) cls._args = asyncio.run(args) read_models() - except RuntimeError as e: + except (RuntimeError, FileNotFoundError) as e: cls.models = cls.fallback_models debug.log(f"Nodriver is not available: {type(e).__name__}: {e}") else: diff --git a/g4f/Provider/__init__.py b/g4f/Provider/__init__.py index db03abf0..1c5592ec 100644 --- a/g4f/Provider/__init__.py +++ b/g4f/Provider/__init__.py @@ -28,6 +28,10 @@ try: from .mini_max import HailuoAI, MiniMax except ImportError as e: debug.error("MiniMax providers not loaded:", e) +try: + from .audio import EdgeTTS +except ImportError as e: + debug.error("Audio providers not loaded:", e) try: from .AllenAI import AllenAI diff --git a/g4f/Provider/audio/EdgeTTS.py b/g4f/Provider/audio/EdgeTTS.py new file mode 100644 index 00000000..3c2023a4 --- /dev/null +++ b/g4f/Provider/audio/EdgeTTS.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import os +import random +import asyncio + +try: + import edge_tts + from edge_tts import VoicesManager + has_edge_tts = True +except ImportError: + has_edge_tts = False + +from ...typing import AsyncResult, Messages +from ...providers.response import AudioResponse +from ...image.copy_images import get_filename, get_media_dir, ensure_media_dir +from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin +from ..helper import format_image_prompt + +class EdgeTTS(AsyncGeneratorProvider, ProviderModelMixin): + label = "Edge TTS" + working = has_edge_tts + default_model = "edge-tts" + default_locale = "en-US" + + @classmethod + def get_models(cls) -> list[str]: + if not cls.models: + voices = asyncio.run(VoicesManager.create()) + cls.default_model = voices.find(Locale=cls.default_locale)[0]["Name"] + cls.models = [voice["Name"] for voice in voices.voices] + return cls.models + + @classmethod + async def create_async_generator( + cls, + model: str, + messages: Messages, + proxy: str = None, + prompt: str = None, + language: str = None, + locale: str = None, + audio: dict = {"voice": None, "format": "mp3"}, + extra_parameters: list[str] = ["rate", "volume", "pitch"], + **kwargs + ) -> AsyncResult: + prompt = format_image_prompt(messages, prompt) + if not prompt: + raise ValueError("Prompt is empty.") + voice = audio.get("voice", model) + if not voice: + voices = await VoicesManager.create() + if locale is None: + if language is None: + voices = voices.find(Locale=cls.default_locale) + elif "-" in language: + voices = voices.find(Locale=language) + else: + voices = voices.find(Language=language) + else: + voices = voices.find(Locale=locale) + if not voices: + raise ValueError(f"No voices found for language '{language}' and locale '{locale}'.") + voice = random.choice(voices)["Name"] + + format = audio.get("format", "mp3") + filename = get_filename([cls.default_model], prompt, f".{format}", prompt) + target_path = os.path.join(get_media_dir(), filename) + ensure_media_dir() + + extra_parameters = {param: kwargs[param] for param in extra_parameters if param in kwargs} + communicate = edge_tts.Communicate(prompt, voice=voice, proxy=proxy, **extra_parameters) + + await communicate.save(target_path) + yield AudioResponse(f"/media/{filename}", voice=voice, prompt=prompt) \ No newline at end of file diff --git a/g4f/Provider/audio/__init__.py b/g4f/Provider/audio/__init__.py new file mode 100644 index 00000000..b040b1ac --- /dev/null +++ b/g4f/Provider/audio/__init__.py @@ -0,0 +1 @@ +from .EdgeTTS import EdgeTTS \ No newline at end of file diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py index 99f3f227..81d970c5 100644 --- a/g4f/api/__init__.py +++ b/g4f/api/__init__.py @@ -40,7 +40,7 @@ from g4f.client import AsyncClient, ChatCompletion, ImagesResponse, convert_to_p from g4f.providers.response import BaseConversation, JsonConversation from g4f.client.helper import filter_none from g4f.image import is_data_an_media, EXTENSIONS_MAP -from g4f.image.copy_images import images_dir, copy_media, get_source_url +from g4f.image.copy_images import get_media_dir, copy_media, get_source_url from g4f.errors import ProviderNotFoundError, ModelNotFoundError, MissingAuthError, NoValidHarFileError from g4f.cookies import read_cookie_files, get_cookies_dir from g4f.providers.types import ProviderType @@ -130,7 +130,7 @@ class AppConfig: ignore_cookie_files: bool = False model: str = None provider: str = None - image_provider: str = None + media_provider: str = None proxy: str = None gui: bool = False demo: bool = False @@ -419,12 +419,13 @@ class Api: ): if config.provider is None: config.provider = provider + if config.provider is None: + config.provider = AppConfig.media_provider if credentials is not None and credentials.credentials != "secret": config.api_key = credentials.credentials try: response = await self.client.images.generate( **config.dict(exclude_none=True), - provider=AppConfig.image_provider if config.provider is None else config.provider ) for image in response.data: if hasattr(image, "url") and image.url.startswith("/"): @@ -562,9 +563,9 @@ class Api: HTTP_404_NOT_FOUND: {} }) async def get_media(filename, request: Request): - target = os.path.join(images_dir, os.path.basename(filename)) + target = os.path.join(get_media_dir(), os.path.basename(filename)) if not os.path.isfile(target): - other_name = os.path.join(images_dir, os.path.basename(quote_plus(filename))) + other_name = os.path.join(get_media_dir(), os.path.basename(quote_plus(filename))) if os.path.isfile(other_name): target = other_name ext = os.path.splitext(filename)[1][1:] @@ -627,7 +628,7 @@ class Api: def format_exception(e: Union[Exception, str], config: Union[ChatCompletionsConfig, ImageGenerationConfig] = None, image: bool = False) -> str: last_provider = {} if not image else g4f.get_last_provider(True) - provider = (AppConfig.image_provider if image else AppConfig.provider) + provider = (AppConfig.media_provider if image else AppConfig.provider) model = AppConfig.model if config is not None: if config.provider is not None: diff --git a/g4f/cli.py b/g4f/cli.py index abb63c70..069ebb42 100644 --- a/g4f/cli.py +++ b/g4f/cli.py @@ -16,7 +16,7 @@ def get_api_parser(): api_parser.add_argument("--model", default=None, help="Default model for chat completion. (incompatible with --reload and --workers)") api_parser.add_argument("--provider", choices=[provider.__name__ for provider in Provider.__providers__ if provider.working], default=None, help="Default provider for chat completion. (incompatible with --reload and --workers)") - api_parser.add_argument("--image-provider", choices=[provider.__name__ for provider in Provider.__providers__ if provider.working and hasattr(provider, "image_models")], + api_parser.add_argument("--media-provider", choices=[provider.__name__ for provider in Provider.__providers__ if provider.working and bool(getattr(provider, "image_models", False))], default=None, help="Default provider for image generation. (incompatible with --reload and --workers)"), api_parser.add_argument("--proxy", default=None, help="Default used proxy. (incompatible with --reload and --workers)") api_parser.add_argument("--workers", type=int, default=None, help="Number of workers.") @@ -59,7 +59,7 @@ def run_api_args(args): ignored_providers=args.ignored_providers, g4f_api_key=args.g4f_api_key, provider=args.provider, - image_provider=args.image_provider, + media_provider=args.media_provider, proxy=args.proxy, model=args.model, gui=args.gui, diff --git a/g4f/client/__init__.py b/g4f/client/__init__.py index f350538e..802f1c5e 100644 --- a/g4f/client/__init__.py +++ b/g4f/client/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os import time import random import string @@ -8,7 +9,7 @@ import aiohttp import base64 from typing import Union, AsyncIterator, Iterator, Awaitable, Optional -from ..image.copy_images import copy_media +from ..image.copy_images import copy_media, get_media_dir from ..typing import Messages, ImageType from ..providers.types import ProviderType, BaseRetryProvider from ..providers.response import * @@ -16,11 +17,11 @@ from ..errors import NoMediaResponseError from ..providers.retry_provider import IterListProvider from ..providers.asyncio import to_sync_generator from ..providers.any_provider import AnyProvider -from ..Provider.needs_auth import BingCreateImages, OpenaiAccount +from ..Provider import OpenaiAccount, PollinationsImage from ..tools.run_tools import async_iter_run_tools, iter_run_tools from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse, UsageModel, ToolCallModel from .models import ClientModels -from .types import IterResponse, ImageProvider, Client as BaseClient +from .types import IterResponse, Client as BaseClient from .service import convert_to_provider from .helper import find_stop, filter_json, filter_none, safe_aclose from .. import debug @@ -261,15 +262,15 @@ class Client(BaseClient): def __init__( self, provider: Optional[ProviderType] = None, - image_provider: Optional[ImageProvider] = None, + media_provider: Optional[ProviderType] = None, **kwargs ) -> None: super().__init__(**kwargs) self.chat: Chat = Chat(self, provider) - if image_provider is None: - image_provider = provider - self.models: ClientModels = ClientModels(self, provider, image_provider) - self.images: Images = Images(self, image_provider) + if media_provider is None: + media_provider = kwargs.get("image_provider", provider) + self.models: ClientModels = ClientModels(self, provider, media_provider) + self.images: Images = Images(self, media_provider) self.media: Images = self.images class Completions: @@ -364,7 +365,7 @@ class Images: """ return asyncio.run(self.async_generate(prompt, model, provider, response_format, proxy, **kwargs)) - async def get_provider_handler(self, model: Optional[str], provider: Optional[ImageProvider], default: ImageProvider) -> ImageProvider: + async def get_provider_handler(self, model: Optional[str], provider: Optional[ProviderType], default: ProviderType) -> ProviderType: if provider is None: provider_handler = self.provider if provider_handler is None: @@ -387,7 +388,7 @@ class Images: api_key: Optional[str] = None, **kwargs ) -> ImagesResponse: - provider_handler = await self.get_provider_handler(model, provider, BingCreateImages) + provider_handler = await self.get_provider_handler(model, provider, PollinationsImage) provider_name = provider_handler.__name__ if hasattr(provider_handler, "__name__") else type(provider_handler).__name__ if proxy is None: proxy = self.client.proxy @@ -407,20 +408,17 @@ class Images: debug.error(f"{provider.__name__} {type(e).__name__}: {e}") else: response = await self._generate_image_response(provider_handler, provider_name, model, prompt, proxy=proxy, api_key=api_key, **kwargs) - - if isinstance(response, MediaResponse): - return await self._process_image_response( - response, - model, - provider_name, - response_format, - proxy - ) if response is None: if error is not None: raise error - raise NoMediaResponseError(f"No image response from {provider_name}") - raise NoMediaResponseError(f"Unexpected response type: {type(response)}") + raise NoMediaResponseError(f"No media response from {provider_name}") + return await self._process_image_response( + response, + model, + provider_name, + response_format, + proxy + ) async def _generate_image_response( self, @@ -441,7 +439,7 @@ class Images: prompt=prompt, **kwargs ): - if isinstance(item, MediaResponse): + if isinstance(item, (MediaResponse, AudioResponse)): items.append(item) elif hasattr(provider_handler, "create_completion"): for item in provider_handler.create_completion( @@ -451,13 +449,15 @@ class Images: prompt=prompt, **kwargs ): - if isinstance(item, MediaResponse): + if isinstance(item, (MediaResponse, AudioResponse)): items.append(item) else: raise ValueError(f"Provider {provider_name} does not support image generation") urls = [] for item in items: - if isinstance(item.urls, str): + if isinstance(item, AudioResponse): + urls.append(item.to_uri()) + elif isinstance(item.urls, str): urls.append(item.urls) elif isinstance(item.urls, list): urls.extend(item.urls) @@ -508,14 +508,11 @@ class Images: debug.error(f"{provider.__name__} {type(e).__name__}: {e}") else: response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs) - - if isinstance(response, MediaResponse): - return await self._process_image_response(response, model, provider_name, response_format, proxy) if response is None: if error is not None: raise error - raise NoMediaResponseError(f"No image response from {provider_name}") - raise NoMediaResponseError(f"Unexpected response type: {type(response)}") + raise NoMediaResponseError(f"No media response from {provider_name}") + return await self._process_image_response(response, model, provider_name, response_format, proxy) async def _process_image_response( self, @@ -531,12 +528,16 @@ class Images: elif response_format == "b64_json": # Convert URLs directly to base64 without saving async def get_b64_from_url(url: str) -> Image: + if url.startswith("/media/"): + with open(os.path.join(get_media_dir(), os.path.basename(url)), "wb") as f: + b64_data = base64.b64encode(f.read()).decode() + return Image.model_construct(b64_json=b64_data, revised_prompt=response.alt) async with aiohttp.ClientSession(cookies=response.get("cookies")) as session: async with session.get(url, proxy=proxy) as resp: if resp.status == 200: - image_data = await resp.read() - b64_data = base64.b64encode(image_data).decode() + b64_data = base64.b64encode(await resp.read()).decode() return Image.model_construct(b64_json=b64_data, revised_prompt=response.alt) + return Image.model_construct(url=url, revised_prompt=response.alt) images = await asyncio.gather(*[get_b64_from_url(image) for image in response.get_list()]) else: # Save locally for None (default) case @@ -554,15 +555,15 @@ class AsyncClient(BaseClient): def __init__( self, provider: Optional[ProviderType] = None, - image_provider: Optional[ImageProvider] = None, + media_provider: Optional[ProviderType] = None, **kwargs ) -> None: super().__init__(**kwargs) self.chat: AsyncChat = AsyncChat(self, provider) - if image_provider is None: - image_provider = provider - self.models: ClientModels = ClientModels(self, provider, image_provider) - self.images: AsyncImages = AsyncImages(self, image_provider) + if media_provider is None: + media_provider = kwargs.get("image_provider", provider) + self.models: ClientModels = ClientModels(self, provider, media_provider) + self.images: AsyncImages = AsyncImages(self, media_provider) self.media: AsyncImages = self.images class AsyncChat: diff --git a/g4f/client/stubs.py b/g4f/client/stubs.py index 2ecd77ad..8e183566 100644 --- a/g4f/client/stubs.py +++ b/g4f/client/stubs.py @@ -5,7 +5,7 @@ from typing import Optional, List from time import time from ..image import extract_data_uri -from ..image.copy_images import images_dir +from ..image.copy_images import get_media_dir from ..client.helper import filter_markdown from .helper import filter_none @@ -123,7 +123,7 @@ class ChatCompletionMessage(BaseModel): def save(self, filepath: str, allowd_types = None): if hasattr(self.content, "data"): - os.rename(self.content.data.replace("/media", images_dir), filepath) + os.rename(self.content.data.replace("/media", get_media_dir()), filepath) return if self.content.startswith("data:"): with open(filepath, "wb") as f: diff --git a/g4f/client/types.py b/g4f/client/types.py index 5010e098..785f3d8b 100644 --- a/g4f/client/types.py +++ b/g4f/client/types.py @@ -6,7 +6,6 @@ from .stubs import ChatCompletion, ChatCompletionChunk from ..providers.types import BaseProvider from typing import Union, Iterator, AsyncIterator -ImageProvider = Union[BaseProvider, object] Proxies = Union[dict, str] IterResponse = Iterator[Union[ChatCompletion, ChatCompletionChunk]] AsyncIterResponse = AsyncIterator[Union[ChatCompletion, ChatCompletionChunk]] @@ -19,7 +18,7 @@ class Client(): **kwargs ) -> None: self.api_key: str = api_key - self.proxies= proxies + self.proxies = proxies self.proxy: str = self.get_proxy() def get_proxy(self) -> Union[str, None]: diff --git a/g4f/gui/server/api.py b/g4f/gui/server/api.py index eb165a42..28cf7baa 100644 --- a/g4f/gui/server/api.py +++ b/g4f/gui/server/api.py @@ -8,7 +8,7 @@ from flask import send_from_directory from inspect import signature from ...errors import VersionNotFoundError, MissingAuthError -from ...image.copy_images import copy_media, ensure_images_dir, images_dir +from ...image.copy_images import copy_media, ensure_media_dir, get_media_dir from ...tools.run_tools import iter_run_tools from ... import Provider from ...providers.base_provider import ProviderModelMixin @@ -96,8 +96,8 @@ class Api: } def serve_images(self, name): - ensure_images_dir() - return send_from_directory(os.path.abspath(images_dir), name) + ensure_media_dir() + return send_from_directory(os.path.abspath(get_media_dir()), name) def _prepare_conversation_kwargs(self, json_data: dict): kwargs = {**json_data} diff --git a/g4f/gui/server/backend_api.py b/g4f/gui/server/backend_api.py index db942683..8adaf72f 100644 --- a/g4f/gui/server/backend_api.py +++ b/g4f/gui/server/backend_api.py @@ -25,7 +25,7 @@ from ...tools.run_tools import iter_run_tools from ...errors import ProviderNotFoundError from ...image import is_allowed_extension from ...cookies import get_cookies_dir -from ...image.copy_images import secure_filename, get_source_url, images_dir +from ...image.copy_images import secure_filename, get_source_url, get_media_dir from ... import ChatCompletion from ... import models from .api import Api @@ -346,11 +346,12 @@ class Backend_Api(Api): @app.route('/search/', methods=['GET']) def find_media(search: str): safe_search = [secure_filename(chunk.lower()) for chunk in search.split("+")] - if not os.access(images_dir, os.R_OK): + media_dir = get_media_dir() + if not os.access(media_dir, os.R_OK): return jsonify({"error": {"message": "Not found"}}), 404 if search not in self.match_files: self.match_files[search] = {} - for root, _, files in os.walk(images_dir): + for root, _, files in os.walk(media_dir): for file in files: mime_type = is_allowed_extension(file) if mime_type is not None: @@ -438,7 +439,7 @@ class Backend_Api(Api): def get_provider_models(self, provider: str): api_key = request.headers.get("x_api_key") api_base = request.headers.get("x_api_base") - ignored = request.headers.get("x_ignored").split() + ignored = request.headers.get("x_ignored", "").split() models = super().get_provider_models(provider, api_key, api_base, ignored) if models is None: return "Provider not found", 404 diff --git a/g4f/image.py b/g4f/image.py deleted file mode 100644 index e9341953..00000000 --- a/g4f/image.py +++ /dev/null @@ -1,253 +0,0 @@ -from __future__ import annotations - -import os -import re -import io -import base64 -from urllib.parse import quote_plus -from io import BytesIO -from pathlib import Path -try: - from PIL.Image import open as open_image, new as new_image - from PIL.Image import FLIP_LEFT_RIGHT, ROTATE_180, ROTATE_270, ROTATE_90 - has_requirements = True -except ImportError: - has_requirements = False - -from .typing import ImageType, Union, Image, Optional, Cookies -from .errors import MissingRequirementsError -from .requests.aiohttp import get_connector - -ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp', 'svg'} - -EXTENSIONS_MAP: dict[str, str] = { - "image/png": "png", - "image/jpeg": "jpg", - "image/gif": "gif", - "image/webp": "webp", -} - -# Define the directory for generated images -images_dir = "./generated_images" - -def to_image(image: ImageType, is_svg: bool = False) -> Image: - """ - Converts the input image to a PIL Image object. - - Args: - image (Union[str, bytes, Image]): The input image. - - Returns: - Image: The converted PIL Image object. - """ - if not has_requirements: - raise MissingRequirementsError('Install "pillow" package for images') - - if isinstance(image, str) and image.startswith("data:"): - is_data_uri_an_image(image) - image = extract_data_uri(image) - - if is_svg: - try: - import cairosvg - except ImportError: - raise MissingRequirementsError('Install "cairosvg" package for svg images') - if not isinstance(image, bytes): - image = image.read() - buffer = BytesIO() - cairosvg.svg2png(image, write_to=buffer) - return open_image(buffer) - - if isinstance(image, bytes): - is_accepted_format(image) - return open_image(BytesIO(image)) - elif not isinstance(image, Image): - image = open_image(image) - image.load() - return image - - return image - -def is_allowed_extension(filename: str) -> bool: - """ - Checks if the given filename has an allowed extension. - - Args: - filename (str): The filename to check. - - Returns: - bool: True if the extension is allowed, False otherwise. - """ - return '.' in filename and \ - filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS - -def is_data_uri_an_image(data_uri: str) -> bool: - """ - Checks if the given data URI represents an image. - - Args: - data_uri (str): The data URI to check. - - Raises: - ValueError: If the data URI is invalid or the image format is not allowed. - """ - # Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif) - if not re.match(r'data:image/(\w+);base64,', data_uri): - raise ValueError("Invalid data URI image.") - # Extract the image format from the data URI - image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1).lower() - # Check if the image format is one of the allowed formats (jpg, jpeg, png, gif) - if image_format not in ALLOWED_EXTENSIONS and image_format != "svg+xml": - raise ValueError("Invalid image format (from mime file type).") - -def is_accepted_format(binary_data: bytes) -> str: - """ - Checks if the given binary data represents an image with an accepted format. - - Args: - binary_data (bytes): The binary data to check. - - Raises: - ValueError: If the image format is not allowed. - """ - if binary_data.startswith(b'\xFF\xD8\xFF'): - return "image/jpeg" - elif binary_data.startswith(b'\x89PNG\r\n\x1a\n'): - return "image/png" - elif binary_data.startswith(b'GIF87a') or binary_data.startswith(b'GIF89a'): - return "image/gif" - elif binary_data.startswith(b'\x89JFIF') or binary_data.startswith(b'JFIF\x00'): - return "image/jpeg" - elif binary_data.startswith(b'\xFF\xD8'): - return "image/jpeg" - elif binary_data.startswith(b'RIFF') and binary_data[8:12] == b'WEBP': - return "image/webp" - else: - raise ValueError("Invalid image format (from magic code).") - -def extract_data_uri(data_uri: str) -> bytes: - """ - Extracts the binary data from the given data URI. - - Args: - data_uri (str): The data URI. - - Returns: - bytes: The extracted binary data. - """ - data = data_uri.split(",")[-1] - data = base64.b64decode(data) - return data - -def get_orientation(image: Image) -> int: - """ - Gets the orientation of the given image. - - Args: - image (Image): The image. - - Returns: - int: The orientation value. - """ - exif_data = image.getexif() if hasattr(image, 'getexif') else image._getexif() - if exif_data is not None: - orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF - if orientation is not None: - return orientation - -def process_image(image: Image, new_width: int, new_height: int) -> Image: - """ - Processes the given image by adjusting its orientation and resizing it. - - Args: - image (Image): The image to process. - new_width (int): The new width of the image. - new_height (int): The new height of the image. - - Returns: - Image: The processed image. - """ - # Fix orientation - orientation = get_orientation(image) - if orientation: - if orientation > 4: - image = image.transpose(FLIP_LEFT_RIGHT) - if orientation in [3, 4]: - image = image.transpose(ROTATE_180) - if orientation in [5, 6]: - image = image.transpose(ROTATE_270) - if orientation in [7, 8]: - image = image.transpose(ROTATE_90) - # Resize image - image.thumbnail((new_width, new_height)) - # Remove transparency - if image.mode == "RGBA": - image.load() - white = new_image('RGB', image.size, (255, 255, 255)) - white.paste(image, mask=image.split()[-1]) - return white - # Convert to RGB for jpg format - elif image.mode != "RGB": - image = image.convert("RGB") - return image - - -def to_bytes(image: ImageType) -> bytes: - """ - Converts the given image to bytes. - - Args: - image (ImageType): The image to convert. - - Returns: - bytes: The image as bytes. - """ - if isinstance(image, bytes): - return image - elif isinstance(image, str) and image.startswith("data:"): - is_data_uri_an_image(image) - return extract_data_uri(image) - elif isinstance(image, Image): - bytes_io = BytesIO() - image.save(bytes_io, image.format) - image.seek(0) - return bytes_io.getvalue() - elif isinstance(image, (str, os.PathLike)): - return Path(image).read_bytes() - elif isinstance(image, Path): - return image.read_bytes() - else: - try: - image.seek(0) - except (AttributeError, io.UnsupportedOperation): - pass - return image.read() - -def to_data_uri(image: ImageType) -> str: - if not isinstance(image, str): - data = to_bytes(image) - data_base64 = base64.b64encode(data).decode() - return f"data:{is_accepted_format(data)};base64,{data_base64}" - return image - -class ImageDataResponse(): - def __init__( - self, - images: Union[str, list], - alt: str, - ): - self.images = images - self.alt = alt - - def get_list(self) -> list[str]: - return [self.images] if isinstance(self.images, str) else self.images - -class ImageRequest: - def __init__( - self, - options: dict = {} - ): - self.options = options - - def get(self, key: str): - return self.options.get(key) diff --git a/g4f/image/__init__.py b/g4f/image/__init__.py index cd0abfd3..97f1ad4d 100644 --- a/g4f/image/__init__.py +++ b/g4f/image/__init__.py @@ -17,13 +17,6 @@ except ImportError: from ..typing import ImageType, Union, Image from ..errors import MissingRequirementsError -MEDIA_TYPE_MAP: dict[str, str] = { - "image/png": "png", - "image/jpeg": "jpg", - "image/gif": "gif", - "image/webp": "webp", -} - EXTENSIONS_MAP: dict[str, str] = { # Image "png": "image/png", @@ -44,6 +37,8 @@ EXTENSIONS_MAP: dict[str, str] = { "mp4": "video/mp4", } +MEDIA_TYPE_MAP: dict[str, str] = {value: key for key, value in EXTENSIONS_MAP.items()} + def to_image(image: ImageType, is_svg: bool = False) -> Image: """ Converts the input image to a PIL Image object. @@ -82,6 +77,12 @@ def to_image(image: ImageType, is_svg: bool = False) -> Image: return image +def get_extension(filename: str) -> Optional[str]: + if '.' in filename: + ext = os.path.splitext(filename)[1][1:].lower() + return ext if ext in EXTENSIONS_MAP else None + return None + def is_allowed_extension(filename: str) -> Optional[str]: """ Checks if the given filename has an allowed extension. @@ -92,8 +93,10 @@ def is_allowed_extension(filename: str) -> Optional[str]: Returns: bool: True if the extension is allowed, False otherwise. """ - ext = os.path.splitext(filename)[1][1:].lower() if '.' in filename else None - return EXTENSIONS_MAP[ext] if ext in EXTENSIONS_MAP else None + extension = get_extension(filename) + if extension is None: + return None + return EXTENSIONS_MAP[extension] def is_data_an_media(data, filename: str = None) -> str: content_type = is_data_an_audio(data, filename) @@ -105,12 +108,11 @@ def is_data_an_media(data, filename: str = None) -> str: def is_data_an_audio(data_uri: str = None, filename: str = None) -> str: if filename: - if filename.endswith(".wav"): - return "audio/wav" - elif filename.endswith(".mp3"): - return "audio/mpeg" - elif filename.endswith(".m4a"): - return "audio/m4a" + extension = get_extension(filename) + if extension is not None: + media_type = EXTENSIONS_MAP[extension] + if media_type.startswith("audio/"): + return media_type if isinstance(data_uri, str): audio_format = re.match(r'^data:(audio/\w+);base64,', data_uri) if audio_format: @@ -266,10 +268,13 @@ def to_data_uri(image: ImageType, filename: str = None) -> str: def to_input_audio(audio: ImageType, filename: str = None) -> str: if not isinstance(audio, str): - if filename is not None and (filename.endswith(".wav") or filename.endswith(".mp3")): + if filename is not None: + format = get_extension(filename) + if format is None: + raise ValueError("Invalid input audio") return { "data": base64.b64encode(to_bytes(audio)).decode(), - "format": "wav" if filename.endswith(".wav") else "mp3" + "format": format } raise ValueError("Invalid input audio") audio = re.match(r'^data:audio/(\w+);base64,(.+?)', audio) diff --git a/g4f/image/copy_images.py b/g4f/image/copy_images.py index 426a0c10..9405aaa0 100644 --- a/g4f/image/copy_images.py +++ b/g4f/image/copy_images.py @@ -21,6 +21,13 @@ from .. import debug # Directory for storing generated images images_dir = "./generated_images" +media_dir = "./generated_media" + +def get_media_dir() -> str:# + """Get the directory for storing generated media files""" + if os.access(images_dir, os.R_OK): + return images_dir + return media_dir def get_media_extension(media: str) -> str: """Extract media file extension from URL or filename""" @@ -34,9 +41,10 @@ def get_media_extension(media: str) -> str: raise ValueError(f"Unsupported media extension: {extension} in: {media}") return extension -def ensure_images_dir(): +def ensure_media_dir(): """Create images directory if it doesn't exist""" - os.makedirs(images_dir, exist_ok=True) + if not os.access(images_dir, os.R_OK): + os.makedirs(media_dir, exist_ok=True) def get_source_url(image: str, default: str = None) -> str: """Extract original URL from image parameter if present""" @@ -46,30 +54,27 @@ def get_source_url(image: str, default: str = None) -> str: return decoded_url return default -def is_valid_media_type(content_type: str) -> bool: - return content_type in MEDIA_TYPE_MAP or content_type.startswith("audio/") or content_type.startswith("video/") - async def save_response_media(response: StreamResponse, prompt: str, tags: list[str]) -> AsyncIterator: """Save media from response to local file and return URL""" content_type = response.headers["content-type"] - if is_valid_media_type(content_type): - extension = MEDIA_TYPE_MAP[content_type] if content_type in MEDIA_TYPE_MAP else content_type[6:].replace("mpeg", "mp3") - if extension not in EXTENSIONS_MAP: - raise ValueError(f"Unsupported media type: {content_type}") - filename = get_filename(tags, prompt, f".{extension}", prompt) - target_path = os.path.join(images_dir, filename) - with open(target_path, 'wb') as f: - async for chunk in response.iter_content() if hasattr(response, "iter_content") else response.content.iter_any(): - f.write(chunk) - media_url = f"/media/{filename}" - if response.method == "GET": - media_url = f"{media_url}?url={str(response.url)}" - if content_type.startswith("audio/"): - yield AudioResponse(media_url) - elif content_type.startswith("video/"): - yield VideoResponse(media_url, prompt) - else: - yield ImageResponse(media_url, prompt) + extension = MEDIA_TYPE_MAP.get(content_type) + if extension is None: + raise ValueError(f"Unsupported media type: {content_type}") + filename = get_filename(tags, prompt, f".{extension}", prompt) + target_path = os.path.join(get_media_dir(), filename) + ensure_media_dir() + with open(target_path, 'wb') as f: + async for chunk in response.iter_content() if hasattr(response, "iter_content") else response.content.iter_any(): + f.write(chunk) + media_url = f"/media/{filename}" + if response.method == "GET": + media_url = f"{media_url}?url={str(response.url)}" + if content_type.startswith("audio/"): + yield AudioResponse(media_url) + elif content_type.startswith("video/"): + yield VideoResponse(media_url, prompt) + else: + yield ImageResponse(media_url, prompt) def get_filename(tags: list[str], alt: str, extension: str, image: str) -> str: return "".join(( @@ -97,7 +102,7 @@ async def copy_media( """ if add_url: add_url = not cookies - ensure_images_dir() + ensure_media_dir() async with ClientSession( connector=get_connector(proxy=proxy), @@ -113,7 +118,7 @@ async def copy_media( if target_path is None: # Build safe filename with full Unicode support filename = get_filename(tags, alt, get_media_extension(image), image) - target_path = os.path.join(images_dir, filename) + target_path = os.path.join(get_media_dir(), filename) try: # Handle different image types if image.startswith("data:"): @@ -132,7 +137,7 @@ async def copy_media( response.raise_for_status() media_type = response.headers.get("content-type", "application/octet-stream") if media_type not in ("application/octet-stream", "binary/octet-stream"): - if not is_valid_media_type(media_type): + if media_type not in MEDIA_TYPE_MAP: raise ValueError(f"Unsupported media type: {media_type}") with open(target_path, "wb") as f: async for chunk in response.content.iter_any(): diff --git a/g4f/models.py b/g4f/models.py index c29f7a0e..97585d67 100644 --- a/g4f/models.py +++ b/g4f/models.py @@ -1006,6 +1006,6 @@ __models__ = { if model.best_provider is not None and model.best_provider.working else []) for model in ModelUtils.convert.values()] - if [p for p in providers if p.working] + if model.name and [True for provider in providers if provider.working] } _all_models = list(__models__.keys()) diff --git a/g4f/providers/any_provider.py b/g4f/providers/any_provider.py index a8f87fea..85217196 100644 --- a/g4f/providers/any_provider.py +++ b/g4f/providers/any_provider.py @@ -37,7 +37,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin): cls.models_count = { model: len(providers) for model, providers in model_with_providers.items() if len(providers) > 1 } - all_models = ["default"] + list(model_with_providers.keys()) + all_models = [cls.default_model] + list(model_with_providers.keys()) for provider in [OpenaiChat, PollinationsAI, HuggingSpace, Cloudflare, PerplexityLabs, Gemini, Grok]: if not provider.working or getattr(provider, "parent", provider.__name__) in ignored: continue @@ -63,6 +63,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin): ).replace("-03-2025", "" ).replace("-20250219", "" ).replace("-20241022", "" + ).replace("-20240904", "" ).replace("-2025-04-16", "" ).replace("-2025-04-14", "" ).replace("-0125", "" @@ -72,10 +73,13 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin): ).replace("-2409", "" ).replace("-2410", "" ).replace("-2411", "" + ).replace("-1119", "" + ).replace("-0919", "" ).replace("-02-24", "" ).replace("-03-25", "" ).replace("-03-26", "" ).replace("-01-21", "" + ).replace("-002", "" ).replace(".1-", "-" ).replace("_", "." ).replace("c4ai-", "" @@ -98,8 +102,8 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin): for provider in [Microsoft_Phi_4, PollinationsAI]: if provider.working and getattr(provider, "parent", provider.__name__) not in ignored: cls.audio_models.update(provider.audio_models) - cls.models_count.update({model: all_models.count(model) + cls.models_count.get(model, 0) for model in all_models}) - return list(dict.fromkeys([model if model else "default" for model in all_models])) + cls.models_count.update({model: all_models.count(model) for model in all_models if all_models.count(model) > cls.models_count.get(model, 0)}) + return list(dict.fromkeys([model if model else cls.default_model for model in all_models])) @classmethod async def create_async_generator( @@ -117,7 +121,8 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin): providers = model.split(":") model = providers.pop() providers = [getattr(Provider, provider) for provider in providers] - elif not model or model == "default": + elif not model or model == cls.default_model: + model = "" has_image = False has_audio = "audio" in kwargs if not has_audio and media is not None: @@ -133,11 +138,11 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin): else: providers = models.default.best_provider.providers else: - for provider in [OpenaiChat, HuggingSpace, Cloudflare, LMArenaProvider, PerplexityLabs, Gemini, Grok, DeepSeekAPI, FreeRouter, Blackbox]: - if provider.working and (model if model else "auto") in provider.get_models(): - providers.append(provider) - for provider in [HuggingFace, HuggingFaceMedia, LambdaChat, LMArenaProvider, CopilotAccount, PollinationsAI, DeepInfraChat]: - if model in provider.model_aliases: + for provider in [ + OpenaiChat, Cloudflare, LMArenaProvider, PerplexityLabs, Gemini, Grok, DeepSeekAPI, FreeRouter, Blackbox, + HuggingFace, HuggingFaceMedia, HuggingSpace, LambdaChat, CopilotAccount, PollinationsAI, DeepInfraChat + ]: + if not model or model in provider.get_models() or model in provider.model_aliases: providers.append(provider) if model in models.__models__: for provider in models.__models__[model][1]: diff --git a/g4f/providers/base_provider.py b/g4f/providers/base_provider.py index c4b22bb0..13cb21c1 100644 --- a/g4f/providers/base_provider.py +++ b/g4f/providers/base_provider.py @@ -449,10 +449,12 @@ class AsyncAuthedProvider(AsyncGeneratorProvider, AuthFileMixin): cache_file = cls.get_cache_file() try: if cache_file.exists(): - with cache_file.open("r") as f: - data = f.read() - if data: - auth_result = AuthResult(**json.loads(data)) + try: + with cache_file.open("r") as f: + auth_result = AuthResult(**json.load(f)) + except json.JSONDecodeError: + cache_file.unlink() + raise MissingAuthError(f"Invalid auth file: {cache_file}") else: raise MissingAuthError yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs)) @@ -478,8 +480,12 @@ class AsyncAuthedProvider(AsyncGeneratorProvider, AuthFileMixin): cache_file = cls.get_cache_file() try: if cache_file.exists(): - with cache_file.open("r") as f: - auth_result = AuthResult(**json.load(f)) + try: + with cache_file.open("r") as f: + auth_result = AuthResult(**json.load(f)) + except json.JSONDecodeError: + cache_file.unlink() + raise MissingAuthError(f"Invalid auth file: {cache_file}") else: raise MissingAuthError response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result)) diff --git a/g4f/providers/response.py b/g4f/providers/response.py index bc73ad43..82593a4b 100644 --- a/g4f/providers/response.py +++ b/g4f/providers/response.py @@ -264,9 +264,10 @@ class YouTube(HiddenResponse): ])) class AudioResponse(ResponseType): - def __init__(self, data: Union[bytes, str]) -> None: + def __init__(self, data: Union[bytes, str], **kwargs) -> None: """Initialize with audio data bytes.""" self.data = data + self.options = kwargs def to_uri(self) -> str: if isinstance(self.data, str): diff --git a/generated_images/.gitkeep b/generated_media/.gitkeep similarity index 100% rename from generated_images/.gitkeep rename to generated_media/.gitkeep