Improve update script

This commit is contained in:
hlohaus 2025-06-19 21:02:52 +02:00
parent c37f5c0912
commit ffb1914cbe
3 changed files with 28 additions and 14 deletions

View file

@ -1,8 +1,6 @@
from __future__ import annotations
import time
import asyncio
import random
from aiohttp import ClientSession, ClientTimeout
from urllib.parse import quote, quote_plus
@ -35,7 +33,6 @@ class RequestConfig:
unique_list = list(set(cls.urls[prompt]))[:10]
return VideoResponse(unique_list, prompt, {
"headers": {"authorization": cls.headers.get("authorization")} if cls.headers.get("authorization") else {},
"preview": [url.replace("md.mp4", "thumb.webp") for url in unique_list]
})
async with ClientSession() as session:
found_urls = []
@ -135,7 +132,7 @@ class Video(AsyncGeneratorProvider, ProviderModelMixin):
page.add_handler(nodriver.cdp.network.RequestWillBeSent, on_request)
if model == "search":
for _ in range(5):
await page.scroll_down(50)
await page.scroll_down(5)
await asyncio.sleep(1)
response = await RequestConfig.get_response(prompt)
if response:

View file

@ -39,7 +39,7 @@ def get_media_extension(media: str) -> str:
if not extension or len(extension) > 4:
return ""
if extension[1:] not in EXTENSIONS_MAP:
raise ValueError(f"Unsupported media extension: {extension} in: {media}")
raise ""
return extension
def ensure_media_dir():
@ -55,11 +55,10 @@ def get_source_url(image: str, default: str = None) -> str:
return decoded_url
return default
def update_filename(response, filename: str) -> str:
    """Re-prefix ``filename`` with the timestamp reported by ``response``.

    The timestamp is read from the ``last-modified`` response header, falling
    back to ``date``. Everything up to and including the first underscore of
    ``filename`` is replaced by ``"<unix timestamp>_"``; a name without an
    underscore is simply prefixed.

    Args:
        response: Object exposing a ``headers`` mapping with a ``get`` method.
        filename: Current file name.

    Returns:
        The re-prefixed file name, or ``filename`` unchanged when neither
        header is present or the header value cannot be parsed as a date.
    """
    from datetime import timezone
    from email.utils import parsedate_to_datetime

    date = response.headers.get("last-modified", response.headers.get("date"))
    if not date:
        return filename
    try:
        moment = parsedate_to_datetime(date)
    except (TypeError, ValueError):
        # Unparsable header: keep the name that was already generated.
        return filename
    if moment.tzinfo is None:
        # A naive datetime would be interpreted as local time by
        # ``timestamp()``; HTTP dates are expressed in GMT, so pin it to UTC.
        moment = moment.replace(tzinfo=timezone.utc)
    return str(int(moment.timestamp())) + "_" + filename.split("_", maxsplit=1)[-1]


def get_target_path(response, filename: str) -> str:
    """Return the path inside the media directory for the re-prefixed ``filename``.

    Kept under its previous name so existing callers keep working; the
    renaming itself is delegated to ``update_filename``.
    """
    return os.path.join(get_media_dir(), update_filename(response, filename))
async def save_response_media(response, prompt: str, tags: list[str]) -> AsyncIterator:
"""Save media from response to local file and return URL"""
@ -71,7 +70,8 @@ async def save_response_media(response, prompt: str, tags: list[str]) -> AsyncIt
raise ValueError(f"Unsupported media type: {content_type}")
filename = get_filename(tags, prompt, f".{extension}", prompt)
target_path = get_target_path(response, filename)
filename = update_filename(response, filename)
target_path = os.path.join(get_media_dir(), filename)
ensure_media_dir()
with open(target_path, 'wb') as f:
if isinstance(response, bytes):
@ -117,6 +117,7 @@ async def copy_media(
tags: list[str] = None,
add_url: Union[bool, str] = True,
target: str = None,
thumbnail: bool = False,
ssl: bool = None,
timeout: Optional[int] = None
) -> list[str]:
@ -127,6 +128,11 @@ async def copy_media(
if add_url:
add_url = not cookies
ensure_media_dir()
media_dir = get_media_dir()
if thumbnail:
media_dir = os.path.join(media_dir, "thumbnails")
if not os.path.exists(media_dir):
os.makedirs(media_dir, exist_ok=True)
async with ClientSession(
connector=get_connector(proxy=proxy),
@ -149,7 +155,7 @@ async def copy_media(
filename = secure_filename(path[len("/media/"):])
else:
filename = get_filename(tags, alt, media_extension, image)
target_path = os.path.join(get_media_dir(), filename)
target_path = os.path.join(media_dir, filename)
try:
# Handle different image types
if image.startswith("data:"):
@ -167,7 +173,8 @@ async def copy_media(
async with session.get(image, ssl=request_ssl, headers=request_headers) as response:
response.raise_for_status()
if target is None:
target_path = get_target_path(response, filename)
filename = update_filename(response, filename)
target_path = os.path.join(media_dir, filename)
media_type = response.headers.get("content-type", "application/octet-stream")
if media_type not in ("application/octet-stream", "binary/octet-stream"):
if media_type not in MEDIA_TYPE_MAP:
@ -190,6 +197,8 @@ async def copy_media(
target_path = f"{target_path}{media_extension}"
except ValueError:
pass
if thumbnail:
return "/thumbnail/" + os.path.basename(target_path)
# Build URL relative to media directory
return f"/media/{os.path.basename(target_path)}" + ('?' + (add_url if isinstance(add_url, str) else '' + 'url=' + quote(image)) if add_url and not image.startswith('data:') else '')

View file

@ -339,9 +339,9 @@ class MediaResponse(ResponseType):
self.alt = alt
self.options = options
def get(self, key: str, default: any = None) -> any:
    """Return the option stored under ``key``, or ``default`` when it is absent."""
    options = self.options
    return options.get(key, default)
def get_list(self) -> List[str]:
"""Return images as a list."""
@ -355,7 +355,15 @@ class ImageResponse(MediaResponse):
class VideoResponse(MediaResponse):
    """Media response whose items are rendered as HTML ``<video>`` elements."""

    def __str__(self) -> str:
        """Return videos as html elements.

        The ``preview`` option supplies the ``poster`` image. When it is a
        list, the entry at the video's index is used; when it is a single
        value, that value is used for every video. A video with no matching
        preview is rendered without a ``poster`` attribute, so a list is
        never formatted into the markup.
        """
        # Read the option once instead of on every iteration.
        preview = self.get("preview")
        elements = []
        for idx, video in enumerate(self.get_list()):
            if isinstance(preview, list):
                # A preview list may be shorter than the video list.
                poster = preview[idx] if idx < len(preview) else None
            else:
                poster = preview
            poster_attr = f' poster="{quote_url(poster)}"' if poster else ""
            elements.append(f'<video controls src="{quote_url(video)}"{poster_attr}></video>')
        return "\n".join(elements)
class ImagePreview(ImageResponse):
def __str__(self) -> str: