docs: update media examples and add parameter details for TTS providers

- Updated EdgeTTS example to mention the additional audio parameters (`rate`, `volume`, `pitch`)
- Updated gTTS example to mention the `tld` and `slow` audio parameters (both sets of parameters are illustrated in the sketch after this list)
- Modified the EdgeTTS provider to use `get_last_message` instead of `format_image_prompt` for prompt handling
- Modified the gTTS provider to use `get_last_message` instead of `format_image_prompt` for prompt handling
- Refactored the audio file generation logic in the gTTS provider to handle the `language` parameter and update the voice model accordingly
- Refactored the backend API code to introduce a `cast_str` function for processing responses, including cache management and response formatting
- Fixed a bug in `AnyProvider` where the model string check would fail if the model was `None`
- Added a check in the `to_string` helper function to handle `None` values correctly
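For illustration, a minimal usage sketch covering the new audio parameters (a sketch only: the import path and the value formats are assumptions based on the upstream edge-tts and gtts packages, not taken from this commit):

```python
from g4f.client import Client
from g4f.Provider import EdgeTTS, gTTS

# EdgeTTS: rate/volume/pitch are passed through the same `audio` dict
# (offset strings such as "+10%" follow the upstream edge-tts conventions).
client = Client(provider=EdgeTTS)
response = client.media.generate(
    "Hello",
    audio={"language": "en", "rate": "+10%", "volume": "+0%", "pitch": "+2Hz"},
)
response.data[0].save("edge-tts.mp3")

# gTTS: `tld` picks the Google Translate host, `slow` toggles slower speech.
client = Client(provider=gTTS)
response = client.media.generate(
    "Hello",
    audio={"language": "en-US", "tld": "com", "slow": True},
)
response.data[0].save("google-tts.mp3")
```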
Author: hlohaus
Date: 2025-04-20 13:54:46 +02:00
Parent: 099d7283ed
Commit: 2e928c3b94
6 changed files with 44 additions and 25 deletions

View file

@@ -51,9 +51,13 @@ client = Client(provider=EdgeTTS)
response = client.media.generate("Hello", audio={"language": "en"})
response.data[0].save("edge-tts.mp3")
# The EdgeTTS provider also supports the audio parameters `rate`, `volume` and `pitch`
client = Client(provider=gTTS)
response = client.media.generate("Hello", audio={"language": "en"})
response = client.media.generate("Hello", audio={"language": "en-US"})
response.data[0].save("google-tts.mp3")
# The gTTS provider also supports the audio parameters `tld` and `slow`
```
#### **Transcribe an Audio File:**

View file

@@ -15,7 +15,7 @@ from ...typing import AsyncResult, Messages
from ...providers.response import AudioResponse
from ...image.copy_images import get_filename, get_media_dir, ensure_media_dir
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_image_prompt
from ..helper import get_last_message
class EdgeTTS(AsyncGeneratorProvider, ProviderModelMixin):
label = "Edge TTS"
@@ -43,7 +43,7 @@ class EdgeTTS(AsyncGeneratorProvider, ProviderModelMixin):
audio: dict = {},
**kwargs
) -> AsyncResult:
prompt = format_image_prompt(messages, prompt)
prompt = get_last_message(messages, prompt)
if not prompt:
raise ValueError("Prompt is empty.")
voice = audio.get("voice", model if model and model != cls.model_id else None)
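For context, once the prompt and voice are resolved, the underlying synthesis step typically uses the edge-tts package roughly like this (a sketch under the assumption that the provider delegates to edge-tts; the default voice and output path here are illustrative):

```python
import asyncio
import edge_tts

async def synthesize(text: str, voice: str = "en-US-AriaNeural", path: str = "edge-tts.mp3") -> None:
    # rate/volume/pitch accept offset strings such as "+10%", "-5%", "+2Hz".
    communicate = edge_tts.Communicate(text, voice, rate="+0%", volume="+0%", pitch="+0Hz")
    await communicate.save(path)

asyncio.run(synthesize("Hello"))
```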

View file

@@ -1,8 +1,6 @@
from __future__ import annotations
import os
import random
import asyncio
try:
from gtts import gTTS as gTTS_Service
@@ -14,7 +12,7 @@ from ...typing import AsyncResult, Messages
from ...providers.response import AudioResponse
from ...image.copy_images import get_filename, get_media_dir, ensure_media_dir
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_image_prompt
from ..helper import get_last_message
locals = {
"en-AU": ["English (Australia)", "en", "com.au"],
@@ -56,7 +54,7 @@ class gTTS(AsyncGeneratorProvider, ProviderModelMixin):
audio: dict = {},
**kwargs
) -> AsyncResult:
prompt = format_image_prompt(messages, prompt)
prompt = get_last_message(messages, prompt)
if not prompt:
raise ValueError("Prompt is empty.")
format = audio.get("format", cls.default_format)
@@ -64,6 +62,9 @@ class gTTS(AsyncGeneratorProvider, ProviderModelMixin):
target_path = os.path.join(get_media_dir(), filename)
ensure_media_dir()
if "language" in audio:
model = locals[audio["language"]][0] if audio["language"] in locals else model
gTTS_Service(
prompt,
**{
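For reference, a self-contained sketch of how such a locale table can drive a gtts call (the table entries and the fallback are illustrative; the `lang`, `tld`, and `slow` parameter names come from the public gtts API):

```python
from gtts import gTTS

# locale -> (display name, gtts `lang`, Google Translate `tld`)
locales = {
    "en-AU": ["English (Australia)", "en", "com.au"],
    "en-US": ["English (United States)", "en", "com"],
}

def synthesize(text: str, language: str, path: str, slow: bool = False) -> None:
    # Fall back to plain English if the locale is unknown.
    _label, lang, tld = locales.get(language, ["English", "en", "com"])
    gTTS(text, lang=lang, tld=tld, slow=slow).save(path)

synthesize("Hello", "en-AU", "google-tts.mp3")
```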

View file

@@ -63,7 +63,7 @@ class Backend_Api(Api):
if app.demo:
@app.route('/', methods=['GET'])
def home():
client_id = os.environ.get("OAUTH_CLIENT_ID", "ed074164-4f8d-4fb2-8bec-44952707965e")
client_id = os.environ.get("OAUTH_CLIENT_ID", "")
backend_url = os.environ.get("G4F_BACKEND_URL", "")
return render_template('demo.html', backend_url=backend_url, client_id=client_id)
else:
@@ -248,6 +248,14 @@ class Backend_Api(Api):
"ignore_stream": not request.args.get("stream"),
"tool_calls": tool_calls,
}
def cast_str(response):
for chunk in response:
if isinstance(chunk, FinishReason):
yield f"[{chunk.reason}]" if chunk.reason != "stop" else ""
elif not isinstance(chunk, Exception):
chunk = str(chunk)
if chunk:
yield chunk
if cache_id:
cache_id = sha256(cache_id.encode() + json.dumps(parameters, sort_keys=True).encode()).hexdigest()
cache_dir = Path(get_cookies_dir()) / ".scrape_cache" / "create"
@@ -255,26 +263,22 @@
if cache_file.exists():
with cache_file.open("r") as f:
response = f.read()
else:
if not response:
response = iter_run_tools(ChatCompletion.create, **parameters)
cache_dir.mkdir(parents=True, exist_ok=True)
copy_response = [chunk for chunk in response]
with cache_file.open("w") as f:
for chunk in copy_response:
f.write(str(chunk))
copy_response = cast_str(response)
if copy_response:
with cache_file.open("w") as f:
for chunk in copy_response:
f.write(chunk)
response = copy_response
else:
response = iter_run_tools(ChatCompletion.create, **parameters)
response = cast_str(iter_run_tools(ChatCompletion.create, **parameters))
if do_filter_markdown:
return Response(filter_markdown("".join([str(chunk) for chunk in response]), do_filter_markdown), mimetype='text/plain')
def cast_str():
for chunk in response:
if isinstance(chunk, FinishReason):
yield f"[{chunk.reason}]" if chunk.reason != "stop" else ""
elif not isinstance(chunk, Exception):
yield str(chunk)
return Response(cast_str(), mimetype='text/plain')
is_true_filter_markdown = do_filter_markdown.lower() in ["true", "1"]
response = "".join(response)
return Response(filter_markdown(response, do_filter_markdown, response if is_true_filter_markdown else ""), mimetype='text/plain')
return Response(response, mimetype='text/plain')
except Exception as e:
logger.exception(e)
return jsonify({"error": {"message": f"{type(e).__name__}: {e}"}}), 500
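To make the refactor easier to follow, here is a standalone sketch of the chunk-to-text generator and the cache-key derivation used above (the `FinishReason` stand-in and the sample values are illustrative):

```python
import hashlib
import json
from dataclasses import dataclass
from typing import Iterable, Iterator

@dataclass
class FinishReason:
    # Stand-in for the response type checked in the backend code above.
    reason: str

def cast_str(response: Iterable) -> Iterator[str]:
    # Convert a stream of mixed chunk objects into plain text pieces.
    for chunk in response:
        if isinstance(chunk, FinishReason):
            # Surface unusual finish reasons, hide the normal "stop".
            yield f"[{chunk.reason}]" if chunk.reason != "stop" else ""
        elif not isinstance(chunk, Exception):
            chunk = str(chunk)
            if chunk:
                yield chunk

def cache_key(cache_id: str, parameters: dict) -> str:
    # Stable key: request id plus a canonical JSON dump of the call parameters.
    payload = cache_id.encode() + json.dumps(parameters, sort_keys=True).encode()
    return hashlib.sha256(payload).hexdigest()

print("".join(cast_str(["Hel", "lo", FinishReason("stop")])))   # -> Hello
print(cache_key("demo", {"model": "gpt-4o", "stream": True})[:12])
```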

View file

@@ -117,7 +117,7 @@ class AnyProvider(AsyncGeneratorProvider, ProviderModelMixin):
**kwargs
) -> AsyncResult:
providers = []
if ":" in model:
if model and ":" in model:
providers = model.split(":")
model = providers.pop()
providers = [getattr(Provider, provider) for provider in providers]
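The `model and` guard can be exercised in isolation; a minimal sketch with a hypothetical helper name:

```python
from typing import Optional

def split_providers(model: Optional[str]) -> tuple[list[str], Optional[str]]:
    """Split "ProviderA:ProviderB:model" strings while tolerating model=None."""
    providers: list[str] = []
    if model and ":" in model:  # `model and` prevents a TypeError when model is None
        providers = model.split(":")
        model = providers.pop()
    return providers, model

assert split_providers(None) == ([], None)
assert split_providers("PollinationsAI:gpt-4o") == (["PollinationsAI"], "gpt-4o")
```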

View file

@@ -22,6 +22,8 @@ def to_string(value) -> str:
return ""
elif isinstance(value, list):
return "".join([to_string(v) for v in value if v.get("type", "text") == "text"])
elif value is None:
return ""
return str(value)
def render_messages(messages: Messages) -> Iterator:
@@ -71,13 +73,21 @@ def get_last_user_message(messages: Messages) -> str:
while last_message is not None and messages:
last_message = messages.pop()
if last_message["role"] == "user":
content = to_string(last_message["content"]).strip()
content = to_string(last_message.get("content")).strip()
if content:
user_messages.append(content)
else:
return "\n".join(user_messages[::-1])
return "\n".join(user_messages[::-1])
def get_last_message(messages: Messages, prompt: str = None) -> str:
if prompt is None:
for message in messages[::-1]:
content = to_string(message.get("content")).strip()
if content:
prompt = content
return prompt
def format_image_prompt(messages, prompt: str = None) -> str:
if prompt is None:
return get_last_user_message(messages)
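To show the effect of the `None` handling and the new helper in isolation, here is a simplified, self-contained sketch (not the library's exact implementation; the list handling assumes OpenAI-style content parts with a `text` key):

```python
def to_string(value) -> str:
    if value is None:               # the new None guard
        return ""
    if isinstance(value, list):     # OpenAI-style list-of-parts content
        return "".join(part.get("text", "") for part in value
                       if part.get("type", "text") == "text")
    return str(value)

def get_last_message(messages, prompt=None) -> str:
    # Without an explicit prompt, fall back to the last non-empty message content.
    if prompt is None:
        for message in reversed(messages):
            content = to_string(message.get("content")).strip()
            if content:
                return content
    return prompt

messages = [
    {"role": "user", "content": "Say hello"},
    {"role": "assistant", "content": [{"type": "text", "text": "Hello!"}]},
]
print(get_last_message(messages))              # -> Hello!
print(get_last_message(messages, "override"))  # -> override
```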