Mirror of https://github.com/xtekky/gpt4free.git (synced 2025-12-15 14:51:19 -08:00)
Add CORS support to MCP server and tools; enhance image URL handling in tools
Parent: 5d53e58d2c
Commit: a492352901

4 changed files with 29 additions and 12 deletions
@@ -84,11 +84,12 @@ def get_mcp_parser():
     mcp_parser.add_argument("--http", action="store_true", help="Use HTTP transport instead of stdio.")
     mcp_parser.add_argument("--host", default="0.0.0.0", help="Host to bind HTTP server to (default: 0.0.0.0)")
     mcp_parser.add_argument("--port", type=int, default=8765, help="Port to bind HTTP server to (default: 8765)")
+    mcp_parser.add_argument("--origin", type=str, default=None, help="Origin URL for CORS (default: None)")
     return mcp_parser

 def run_mcp_args(args):
     from ..mcp.server import main as mcp_main
-    mcp_main(http=args.http, host=args.host, port=args.port)
+    mcp_main(http=args.http, host=args.host, port=args.port, origin=args.origin)

 def main():
     parser = argparse.ArgumentParser(description="Run gpt4free", exit_on_error=False)
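Note on the new flag: the excerpt does not show how the MCP subcommand is registered, so the snippet below is only a minimal sketch of how --origin travels from argparse into the server entry point; the parser description and example values are placeholders, not taken from the commit.

# Minimal sketch, not the project's actual CLI wiring: --origin defaults to
# None, so existing invocations behave exactly as before.
import argparse

parser = argparse.ArgumentParser(description="Run the MCP server")  # placeholder description
parser.add_argument("--http", action="store_true")
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", type=int, default=8765)
parser.add_argument("--origin", type=str, default=None, help="Origin URL for CORS")

args = parser.parse_args(["--http", "--origin", "http://localhost:8080"])
# run_mcp_args(args) then forwards everything to the server entry point:
# mcp_main(http=args.http, host=args.host, port=args.port, origin=args.origin)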
@@ -356,11 +356,13 @@ class Backend_Api(Api):
         if response.startswith("/media/"):
             media_dir = get_media_dir()
             filename = os.path.basename(response.split("?")[0])
+            if not cache_id:
                 try:
                     return send_from_directory(os.path.abspath(media_dir), filename)
                 finally:
                     if not cache_id:
                         os.remove(os.path.join(media_dir, filename))
+            else:
+                return redirect(response)
         elif response.startswith("https://") or response.startswith("http://"):
             return redirect(response)
         if do_filter:
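For context, the media branch above streams a freshly generated file once and deletes it afterwards unless a cache_id keeps it around, in which case the client is redirected to the stored /media/ URL instead. A hedged, self-contained sketch of that pattern follows; get_media_dir, cache_id and the surrounding Flask view belong to Backend_Api and are stand-ins here.

import os
from typing import Optional
from flask import send_from_directory, redirect

MEDIA_DIR = "/tmp/media"  # stand-in for get_media_dir()

def serve_media(response: str, cache_id: Optional[str]):
    """Called from a Flask view; mirrors the branch shown in the diff."""
    filename = os.path.basename(response.split("?")[0])
    if not cache_id:
        try:
            # Stream the temporary file to the client ...
            return send_from_directory(os.path.abspath(MEDIA_DIR), filename)
        finally:
            # ... then delete it, since nothing references it afterwards.
            os.remove(os.path.join(MEDIA_DIR, filename))
    # A cached file stays on disk, so a redirect to its /media/ URL is enough.
    return redirect(response)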
@@ -31,6 +31,7 @@ class MCPRequest:
     id: Optional[Union[int, str]] = None
     method: Optional[str] = None
     params: Optional[Dict[str, Any]] = None
+    origin: Optional[str] = None


 @dataclass
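The origin field is additive, so stdio requests (which never see HTTP headers) keep their current behavior. A small sketch of the extended dataclass in use; the field order and the jsonrpc default are assumed from the constructor calls later in the diff, and the tool name is a placeholder.

from dataclasses import dataclass
from typing import Any, Dict, Optional, Union

@dataclass
class MCPRequest:
    jsonrpc: str = "2.0"
    id: Optional[Union[int, str]] = None
    method: Optional[str] = None
    params: Optional[Dict[str, Any]] = None
    origin: Optional[str] = None  # new: the HTTP caller's Origin, if any

# HTTP transport: origin comes from the request headers.
req = MCPRequest(id=1, method="tools/call",
                 params={"name": "some_tool", "arguments": {}},
                 origin="http://localhost:8080")
# stdio transport: origin simply stays None.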
@@ -101,6 +102,7 @@ class MCPServer:
         elif method == "tools/call":
             tool_name = params.get("name")
             tool_arguments = params.get("arguments", {})
+            tool_arguments.setdefault("origin", request.origin)

             if tool_name not in self.tools:
                 return MCPResponse(
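Because the server uses dict.setdefault, the transport-level origin is only injected when the tool call did not already carry one; an explicit origin argument from the client always wins, and a missing request.origin just leaves the key set to None. For example:

tool_arguments = {"prompt": "a red fox"}
tool_arguments.setdefault("origin", "http://localhost:8080")
# {'prompt': 'a red fox', 'origin': 'http://localhost:8080'}

tool_arguments = {"prompt": "a red fox", "origin": "https://example.com"}
tool_arguments.setdefault("origin", "http://localhost:8080")
# the existing value is kept: {'prompt': 'a red fox', 'origin': 'https://example.com'}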
@@ -173,7 +175,7 @@ class MCPServer:
                     jsonrpc=request_data.get("jsonrpc", "2.0"),
                     id=request_data.get("id"),
                     method=request_data.get("method"),
-                    params=request_data.get("params")
+                    params=request_data.get("params"),
                 )

                 # Handle request
@@ -199,7 +201,7 @@ class MCPServer:
             sys.stderr.write(f"Error: {e}\n")
             sys.stderr.flush()

-    async def run_http(self, host: str = "0.0.0.0", port: int = 8765):
+    async def run_http(self, host: str = "0.0.0.0", port: int = 8765, origin: Optional[str] = None):
        """Run the MCP server with HTTP transport

         Args:
@@ -218,12 +220,15 @@ class MCPServer:
             try:
                 # Parse JSON-RPC request from POST body
                 request_data = await request.json()
+                if origin is None:
+                    request_origin = request.headers.get("origin")

                 mcp_request = MCPRequest(
                     jsonrpc=request_data.get("jsonrpc", "2.0"),
                     id=request_data.get("id"),
                     method=request_data.get("method"),
-                    params=request_data.get("params")
+                    params=request_data.get("params"),
+                    origin=request_origin
                 )

                 # Handle request
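This hunk only shows where the caller's Origin header is captured (or overridden by the --origin flag); the excerpt does not include the code that attaches CORS response headers. For orientation, a generic aiohttp pattern for echoing an origin back looks like the following. It is illustrative only and not taken from this commit.

from typing import Optional
from aiohttp import web

def cors_headers(request: web.Request, forced_origin: Optional[str] = None) -> dict:
    # Prefer a configured origin; otherwise echo whatever the browser sent.
    origin = forced_origin or request.headers.get("Origin", "*")
    return {
        "Access-Control-Allow-Origin": origin,
        "Access-Control-Allow-Methods": "POST, OPTIONS",
        "Access-Control-Allow-Headers": "Content-Type",
    }

async def handle_rpc(request: web.Request) -> web.Response:
    data = await request.json()
    # ... build MCPRequest(..., origin=...) and dispatch it ...
    return web.json_response({"jsonrpc": "2.0", "id": data.get("id"), "result": {}},
                             headers=cors_headers(request))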
@@ -295,7 +300,7 @@ class MCPServer:
         await runner.cleanup()


-def main(http: bool = False, host: str = "0.0.0.0", port: int = 8765):
+def main(http: bool = False, host: str = "0.0.0.0", port: int = 8765, origin: Optional[str] = None):
     """Main entry point for MCP server

     Args:
@@ -305,7 +310,7 @@ def main(http: bool = False, host: str = "0.0.0.0", port: int = 8765):
     """
     server = MCPServer()
     if http:
-        asyncio.run(server.run_http(host, port))
+        asyncio.run(server.run_http(host, port, origin))
     else:
         asyncio.run(server.run())

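With the signature change, the HTTP transport can be started with a fixed CORS origin directly from Python as well as from the CLI. A hedged usage example follows; the g4f.mcp.server import path is inferred from the relative import in run_mcp_args and is an assumption.

from g4f.mcp.server import main as mcp_main  # import path assumed, see run_mcp_args above

# Serve MCP over HTTP on localhost and allow the local web UI's origin.
mcp_main(http=True, host="127.0.0.1", port=8765, origin="http://localhost:8080")

# Default stdio transport, unchanged:
# mcp_main()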
@@ -11,6 +11,8 @@ from __future__ import annotations
 from typing import Any, Dict
 from abc import ABC, abstractmethod

+from aiohttp import ClientSession
+

 class MCPTool(ABC):
     """Base class for MCP tools"""
@@ -278,6 +280,8 @@ class ImageGenerationTool(MCPTool):
                 "image": image_url
             }
         else:
+            if arguments.get("origin") and image_url.startswith("/media/"):
+                image_url = f"{arguments.get('origin')}{image_url}"
             return {
                 "prompt": prompt,
                 "model": model,
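The effect of the new branch is simply to turn relative /media/ paths into absolute URLs when the transport knows where the caller lives. A small stand-alone sketch; absolutize_media_url is a hypothetical helper, the tool does this check inline.

from typing import Optional

def absolutize_media_url(image_url: str, origin: Optional[str]) -> str:
    # Mirrors the inline check in ImageGenerationTool.
    if origin and image_url.startswith("/media/"):
        return f"{origin}{image_url}"
    return image_url

print(absolutize_media_url("/media/fox.png", "http://localhost:8080"))
# http://localhost:8080/media/fox.png
print(absolutize_media_url("/media/fox.png", None))
# /media/fox.png  (left relative when no origin is available)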
@@ -437,8 +441,13 @@ class TextToAudioTool(MCPTool):
         encoded_prompt = prompt.replace(" ", "%20")  # Basic space encoding

         # Construct the Pollinations AI text-to-speech URL
-        base_url = "https://text.pollinations.ai"
-        audio_url = f"{base_url}/{encoded_prompt}?voice={voice}"
+        audio_url = f"/backend-api/v2/create?provider=Gemini&model=gemini-audio&cache=true&prompt={encoded_prompt}"
+
+        if arguments.get("origin"):
+            audio_url = f"{arguments.get('origin')}{audio_url}"
+            async with ClientSession() as session:
+                async with session.get(audio_url, max_redirects=0) as resp:
+                    audio_url = str(resp.url)

         return {
             "prompt": prompt,
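The audio tool now builds a relative backend URL, prefixes the caller's origin, and uses aiohttp to resolve the backend's redirect so the returned link points at the generated audio file. A hedged, self-contained sketch of that flow follows; it assumes a running gpt4free web backend at the given origin, and the query parameters mirror the diff.

import asyncio
from aiohttp import ClientSession

async def resolve_audio_url(prompt: str, origin: str) -> str:
    encoded_prompt = prompt.replace(" ", "%20")  # basic space encoding, as in the diff
    audio_url = f"/backend-api/v2/create?provider=Gemini&model=gemini-audio&cache=true&prompt={encoded_prompt}"
    audio_url = f"{origin}{audio_url}"
    async with ClientSession() as session:
        async with session.get(audio_url) as resp:
            # resp.url is the final URL after the backend's redirect.
            return str(resp.url)

# Example (requires the backend to be running):
# asyncio.run(resolve_audio_url("hello world", "http://localhost:8080"))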