Add example for video generation

Add support for images in messages
This commit is contained in:
hlohaus 2025-03-27 09:38:31 +01:00
parent db1cfc48bc
commit 46d0b87008
26 changed files with 410 additions and 230 deletions

View file

@ -6,6 +6,15 @@ import string
from ..typing import Messages, Cookies, AsyncIterator, Iterator
from .. import debug
def to_string(value) -> str:
if isinstance(value, str):
return value
elif isinstance(value, dict):
return value.get("text")
elif isinstance(value, list):
return "".join([to_string(v) for v in value if v.get("type") == "text"])
return str(value)
def format_prompt(messages: Messages, add_special_tokens: bool = False, do_continue: bool = False, include_system: bool = True) -> str:
"""
Format a series of messages into a single string, optionally adding special tokens.
@ -18,11 +27,16 @@ def format_prompt(messages: Messages, add_special_tokens: bool = False, do_conti
str: A formatted string containing all messages.
"""
if not add_special_tokens and len(messages) <= 1:
return messages[0]["content"]
formatted = "\n".join([
f'{message["role"].capitalize()}: {message["content"]}'
return to_string(messages[0]["content"])
messages = [
(message["role"], to_string(message["content"]))
for message in messages
if include_system or message["role"] != "system"
if include_system or message.get("role") != "system"
]
formatted = "\n".join([
f'{role.capitalize()}: {content}'
for role, content in messages
if content.strip()
])
if do_continue:
return formatted
@ -34,11 +48,13 @@ def get_system_prompt(messages: Messages) -> str:
def get_last_user_message(messages: Messages) -> str:
user_messages = []
last_message = None if len(messages) == 0 else messages[-1]
messages = messages.copy()
while last_message is not None and messages:
last_message = messages.pop()
if last_message["role"] == "user":
if isinstance(last_message["content"], str):
user_messages.append(last_message["content"].strip())
content = to_string(last_message["content"]).strip()
if content:
user_messages.append(content)
else:
return "\n".join(user_messages[::-1])
return "\n".join(user_messages[::-1])