mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-05 18:20:35 -08:00
Address code review feedback: improve type hints, validation, and documentation
Co-authored-by: hlohaus <983577+hlohaus@users.noreply.github.com>
This commit is contained in:
parent
1e895fbb6a
commit
0a72ce961c
2 changed files with 71 additions and 37 deletions
|
|
@ -12,7 +12,7 @@ from __future__ import annotations
|
|||
import sys
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
from .tools import WebSearchTool, WebScrapeTool, ImageGenerationTool
|
||||
|
|
@ -22,7 +22,7 @@ from .tools import WebSearchTool, WebScrapeTool, ImageGenerationTool
|
|||
class MCPRequest:
|
||||
"""MCP request following JSON-RPC 2.0 format"""
|
||||
jsonrpc: str = "2.0"
|
||||
id: Optional[int | str] = None
|
||||
id: Optional[Union[int, str]] = None
|
||||
method: Optional[str] = None
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
|
@ -31,7 +31,7 @@ class MCPRequest:
|
|||
class MCPResponse:
|
||||
"""MCP response following JSON-RPC 2.0 format"""
|
||||
jsonrpc: str = "2.0"
|
||||
id: Optional[int | str] = None
|
||||
id: Optional[Union[int, str]] = None
|
||||
result: Optional[Any] = None
|
||||
error: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
|
|
|||
|
|
@ -29,8 +29,15 @@ class MCPTool(ABC):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, arguments: Dict[str, Any]) -> Any:
|
||||
"""Execute the tool with given arguments"""
|
||||
async def execute(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Execute the tool with given arguments
|
||||
|
||||
Args:
|
||||
arguments: Tool input arguments matching the input_schema
|
||||
|
||||
Returns:
|
||||
Dict containing either results or an error key with error message
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
|
|
@ -60,7 +67,11 @@ class WebSearchTool(MCPTool):
|
|||
}
|
||||
|
||||
async def execute(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Execute web search"""
|
||||
"""Execute web search
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Search results or error message
|
||||
"""
|
||||
from ..tools.web_search import do_search
|
||||
|
||||
query = arguments.get("query", "")
|
||||
|
|
@ -72,7 +83,8 @@ class WebSearchTool(MCPTool):
|
|||
}
|
||||
|
||||
try:
|
||||
# Perform search
|
||||
# Perform search - query parameter is used for search execution
|
||||
# and prompt parameter holds the content to be searched
|
||||
result, sources = await do_search(
|
||||
prompt=query,
|
||||
query=query,
|
||||
|
|
@ -127,7 +139,11 @@ class WebScrapeTool(MCPTool):
|
|||
}
|
||||
|
||||
async def execute(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Execute web scraping"""
|
||||
"""Execute web scraping
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Scraped content or error message
|
||||
"""
|
||||
from ..tools.fetch_and_scrape import fetch_and_scrape
|
||||
from aiohttp import ClientSession
|
||||
|
||||
|
|
@ -202,7 +218,11 @@ class ImageGenerationTool(MCPTool):
|
|||
}
|
||||
|
||||
async def execute(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Execute image generation"""
|
||||
"""Execute image generation
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Generated image data or error message
|
||||
"""
|
||||
from ..client import AsyncClient
|
||||
from ..image import to_data_uri
|
||||
import base64
|
||||
|
|
@ -228,35 +248,49 @@ class ImageGenerationTool(MCPTool):
|
|||
height=height
|
||||
)
|
||||
|
||||
# Get the image data
|
||||
if response and hasattr(response, 'data') and response.data:
|
||||
image_data = response.data[0]
|
||||
# Get the image data with proper validation
|
||||
if not response:
|
||||
return {
|
||||
"error": "Image generation failed: No response from provider"
|
||||
}
|
||||
|
||||
# Convert to base64 if needed
|
||||
if hasattr(image_data, 'url'):
|
||||
image_url = image_data.url
|
||||
if not hasattr(response, 'data') or not response.data:
|
||||
return {
|
||||
"error": "Image generation failed: No image data in response"
|
||||
}
|
||||
|
||||
# Check if it's already a data URI
|
||||
if image_url.startswith('data:'):
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"model": model,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"image": image_url
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"model": model,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"image_url": image_url
|
||||
}
|
||||
if len(response.data) == 0:
|
||||
return {
|
||||
"error": "Image generation failed: Empty image data array"
|
||||
}
|
||||
|
||||
return {
|
||||
"error": "Image generation failed: No image data in response"
|
||||
}
|
||||
image_data = response.data[0]
|
||||
|
||||
# Check if image_data has url attribute
|
||||
if not hasattr(image_data, 'url'):
|
||||
return {
|
||||
"error": "Image generation failed: No URL in image data"
|
||||
}
|
||||
|
||||
image_url = image_data.url
|
||||
|
||||
# Return result based on URL type
|
||||
if image_url.startswith('data:'):
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"model": model,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"image": image_url
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"model": model,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"image_url": image_url
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue