Add CORS support to MCP server and tools; enhance image URL handling in tools

This commit is contained in:
hlohaus 2025-11-02 07:31:28 +01:00
parent 5d53e58d2c
commit a492352901
4 changed files with 29 additions and 12 deletions

View file

@ -84,11 +84,12 @@ def get_mcp_parser():
mcp_parser.add_argument("--http", action="store_true", help="Use HTTP transport instead of stdio.")
mcp_parser.add_argument("--host", default="0.0.0.0", help="Host to bind HTTP server to (default: 0.0.0.0)")
mcp_parser.add_argument("--port", type=int, default=8765, help="Port to bind HTTP server to (default: 8765)")
mcp_parser.add_argument("--origin", type=str, default=None, help="Origin URL for CORS (default: None)")
return mcp_parser
def run_mcp_args(args):
from ..mcp.server import main as mcp_main
mcp_main(http=args.http, host=args.host, port=args.port)
mcp_main(http=args.http, host=args.host, port=args.port, origin=args.origin)
def main():
parser = argparse.ArgumentParser(description="Run gpt4free", exit_on_error=False)

View file

@ -356,11 +356,13 @@ class Backend_Api(Api):
if response.startswith("/media/"):
media_dir = get_media_dir()
filename = os.path.basename(response.split("?")[0])
if not cache_id:
try:
return send_from_directory(os.path.abspath(media_dir), filename)
finally:
if not cache_id:
os.remove(os.path.join(media_dir, filename))
else:
return redirect(response)
elif response.startswith("https://") or response.startswith("http://"):
return redirect(response)
if do_filter:

View file

@ -31,6 +31,7 @@ class MCPRequest:
id: Optional[Union[int, str]] = None
method: Optional[str] = None
params: Optional[Dict[str, Any]] = None
origin: Optional[str] = None
@dataclass
@ -101,6 +102,7 @@ class MCPServer:
elif method == "tools/call":
tool_name = params.get("name")
tool_arguments = params.get("arguments", {})
tool_arguments.setdefault("origin", request.origin)
if tool_name not in self.tools:
return MCPResponse(
@ -173,7 +175,7 @@ class MCPServer:
jsonrpc=request_data.get("jsonrpc", "2.0"),
id=request_data.get("id"),
method=request_data.get("method"),
params=request_data.get("params")
params=request_data.get("params"),
)
# Handle request
@ -199,7 +201,7 @@ class MCPServer:
sys.stderr.write(f"Error: {e}\n")
sys.stderr.flush()
async def run_http(self, host: str = "0.0.0.0", port: int = 8765):
async def run_http(self, host: str = "0.0.0.0", port: int = 8765, origin: Optional[str] = None):
"""Run the MCP server with HTTP transport
Args:
@ -218,12 +220,15 @@ class MCPServer:
try:
# Parse JSON-RPC request from POST body
request_data = await request.json()
if origin is None:
request_origin = request.headers.get("origin")
mcp_request = MCPRequest(
jsonrpc=request_data.get("jsonrpc", "2.0"),
id=request_data.get("id"),
method=request_data.get("method"),
params=request_data.get("params")
params=request_data.get("params"),
origin=request_origin
)
# Handle request
@ -295,7 +300,7 @@ class MCPServer:
await runner.cleanup()
def main(http: bool = False, host: str = "0.0.0.0", port: int = 8765):
def main(http: bool = False, host: str = "0.0.0.0", port: int = 8765, origin: Optional[str] = None):
"""Main entry point for MCP server
Args:
@ -305,7 +310,7 @@ def main(http: bool = False, host: str = "0.0.0.0", port: int = 8765):
"""
server = MCPServer()
if http:
asyncio.run(server.run_http(host, port))
asyncio.run(server.run_http(host, port, origin))
else:
asyncio.run(server.run())

View file

@ -11,6 +11,8 @@ from __future__ import annotations
from typing import Any, Dict
from abc import ABC, abstractmethod
from aiohttp import ClientSession
class MCPTool(ABC):
"""Base class for MCP tools"""
@ -278,6 +280,8 @@ class ImageGenerationTool(MCPTool):
"image": image_url
}
else:
if arguments.get("origin") and image_url.startswith("/media/"):
image_url = f"{arguments.get('origin')}{image_url}"
return {
"prompt": prompt,
"model": model,
@ -437,8 +441,13 @@ class TextToAudioTool(MCPTool):
encoded_prompt = prompt.replace(" ", "%20") # Basic space encoding
# Construct the Pollinations AI text-to-speech URL
base_url = "https://text.pollinations.ai"
audio_url = f"{base_url}/{encoded_prompt}?voice={voice}"
audio_url = f"/backend-api/v2/create?provider=Gemini&model=gemini-audio&cache=true&prompt={encoded_prompt}"
if arguments.get("origin"):
audio_url = f"{arguments.get('origin')}{audio_url}"
async with ClientSession() as session:
async with session.get(audio_url, max_redirects=0) as resp:
audio_url = str(resp.url)
return {
"prompt": prompt,