#!/usr/bin/env python3
"""
Gemini API Client - Local Network Master API wrapper.

Supports LLM chat and LMM image generation with concurrency control.

Usage:
    python gemini_client.py chat "your question" [--model pro|flash] [--temperature 0.7]
    python gemini_client.py image --prompt "description" [--output result.png]
    python gemini_client.py image --prompt "description" --output icon.png --remove-watermark
    python gemini_client.py image --event-data event.json --output result.png
    python gemini_client.py batch requests.jsonl [--output-dir ./results/]
"""
import sys
import os
import json
import argparse
import base64
import time
import threading
from urllib.request import Request, urlopen, ProxyHandler, build_opener
from urllib.error import HTTPError, URLError
from concurrent.futures import ThreadPoolExecutor, as_completed

# -- Configuration ----------------------------------------------------------
MASTER_URL = "http://192.168.1.5:10900"
API_KEYS = [
    "sk-gemini-api-key-002",  # primary — routes to asus2023_chat (verified working)
    "sk-gemini-api-key-001",
    "test-api-key-001",
]
MAX_CONCURRENCY = 2
DEFAULT_LLM_MODEL = "pro"
DEFAULT_LMM_MODEL = "nano-bananapro"
CHAT_TIMEOUT = 200   # seconds
IMAGE_TIMEOUT = 350  # seconds

# -- Proxy bypass for LAN ---------------------------------------------------
# The Master API lives on the local network; a system-wide HTTP(S)_PROXY
# would route requests away from it, so use an opener with no proxies.
_opener = build_opener(ProxyHandler({}))  # bypass all proxies

# -- API Key rotation (thread-safe) -----------------------------------------
_key_index = 0
_key_lock = threading.Lock()


def _next_api_key() -> str:
    """Return the next API key, round-robin across API_KEYS (thread-safe)."""
    global _key_index
    with _key_lock:
        key = API_KEYS[_key_index % len(API_KEYS)]
        _key_index += 1
        return key


# -- Low-level HTTP ----------------------------------------------------------
def _post_json_once(endpoint: str, payload: dict, api_key: str,
                    timeout: int = CHAT_TIMEOUT) -> dict:
    """Single POST attempt with a specific API key.

    Never raises: network/HTTP failures are folded into a dict of the form
    {"error": "...", "status": "failed"} so callers can retry uniformly.
    """
    url = f"{MASTER_URL}{endpoint}"
    data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
    req = Request(url, data=data, method="POST")
    req.add_header("Content-Type", "application/json")
    req.add_header("X-API-Key", api_key)
    try:
        with _opener.open(req, timeout=timeout) as resp:
            body = resp.read().decode("utf-8")
            return json.loads(body)
    except HTTPError as e:
        body = e.read().decode("utf-8", errors="replace")
        return {"error": f"HTTP {e.code}: {body}", "status": "failed"}
    except URLError as e:
        return {"error": f"Connection failed: {e.reason}", "status": "failed"}
    except Exception as e:
        return {"error": str(e), "status": "failed"}


def _post_json(endpoint: str, payload: dict, timeout: int = CHAT_TIMEOUT) -> dict:
    """POST JSON with auto-retry across all API keys on failure.

    Retries only on server-side errors known to be key-specific ("未登录" /
    "All retries failed"); any other error is returned to the caller as-is.
    """
    last_resp = None
    for _ in range(len(API_KEYS)):
        api_key = _next_api_key()
        resp = _post_json_once(endpoint, payload, api_key, timeout)
        if resp.get("status") == "completed":
            return resp
        # Retryable server-side errors (not logged in, tool not activated, etc.)
        err = resp.get("error", "")
        if "未登录" in err or "All retries failed" in err:
            print(f"[retry] key {api_key[:12]}... failed: {err[:60]}, trying next key",
                  file=sys.stderr)
            last_resp = resp
            continue
        # Non-retryable or unknown — return as-is
        return resp
    return last_resp or {"error": "All API keys exhausted", "status": "failed"}


def _get(endpoint: str, timeout: int = 30) -> bytes:
    """GET raw bytes from Master API. Raises on network/HTTP errors."""
    url = f"{MASTER_URL}{endpoint}"
    api_key = _next_api_key()
    req = Request(url, method="GET")
    req.add_header("X-API-Key", api_key)
    with _opener.open(req, timeout=timeout) as resp:
        return resp.read()


# -- High-level API ----------------------------------------------------------
def chat_sync(messages: list, model: str = DEFAULT_LLM_MODEL,
              temperature: float = 0.7, max_tokens: int = 4096,
              new_conversation: bool = True) -> dict:
    """Synchronous LLM chat. Returns full response dict."""
    payload = {
        "messages": messages,
        "model": model,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "new_conversation": new_conversation,
    }
    return _post_json("/api/v1/llm/chat/sync", payload, timeout=CHAT_TIMEOUT)


def image_sync(prompt: str = None, event_data: dict = None,
               attachments: dict = None, model: str = DEFAULT_LMM_MODEL,
               output_format: str = "base64") -> dict:
    """Synchronous image generation. Returns full response dict.

    IMPORTANT: Use output_format="base64" (default). The "url" format returns
    URLs that are NOT downloadable (404). base64 is the only reliable method.
    """
    payload = {"model": model, "output_format": output_format}
    if prompt:
        payload["prompt"] = prompt
    if event_data:
        payload["event_data"] = event_data
    if attachments:
        payload["attachments"] = attachments
    return _post_json("/api/v1/lmm/image/sync", payload, timeout=IMAGE_TIMEOUT)


def save_image_from_response(resp: dict, output_path: str) -> str:
    """Extract image from API response and save to file.

    Handles both base64 and URL formats (base64 preferred).
    Returns the saved file path, or raises RuntimeError on failure.
    """
    if resp.get("status") != "completed":
        raise RuntimeError(f"Image generation failed: {resp.get('error', resp.get('status'))}")
    result = resp.get("result", {})
    # Prefer base64 (reliable); servers have used several key names.
    b64_data = result.get("image_base64") or result.get("base64") or result.get("image_data")
    if b64_data:
        img_bytes = base64.b64decode(b64_data)
        with open(output_path, "wb") as f:
            f.write(img_bytes)
        return output_path
    # Fallback: try URL download (usually fails with 404)
    image_url = result.get("image_url", "")
    if image_url:
        try:
            return download_image_by_url(image_url, output_path)
        except Exception as e:
            raise RuntimeError(
                f"base64 not in response and URL download failed: {e}. "
                f"URL was: {MASTER_URL}{image_url}"
            )
    raise RuntimeError(f"No image data in response. Keys: {list(result.keys())}")


def download_image(task_id: str, output_path: str) -> str:
    """Download generated image by task_id. Returns saved path."""
    data = _get(f"/api/v1/lmm/image/{task_id}", timeout=60)
    with open(output_path, "wb") as f:
        f.write(data)
    return output_path


def download_image_by_url(image_url: str, output_path: str) -> str:
    """Download image from the result's image_url field.

    WARNING: URL downloads frequently return 404. Prefer base64 format.
    """
    # Relative URLs are resolved against the Master API host.
    if image_url.startswith("/"):
        url = f"{MASTER_URL}{image_url}"
    else:
        url = image_url
    req = Request(url, method="GET")
    req.add_header("X-API-Key", _next_api_key())
    with _opener.open(req, timeout=60) as resp:
        data = resp.read()
    with open(output_path, "wb") as f:
        f.write(data)
    return output_path


# -- Watermark removal -------------------------------------------------------
def remove_watermark(input_path: str, output_path: str = None,
                     corner: str = "bottom_right", region_w: int = 260,
                     region_h: int = 200, inpaint_radius: int = 12) -> str:
    """Remove Gemini watermark (star logo) from generated images.

    Uses OpenCV inpainting to cleanly erase the bottom-right watermark that
    nano-bananapro adds to all generated images.

    Args:
        input_path: Path to the watermarked image.
        output_path: Where to save. Defaults to overwriting input_path.
        corner: Which corner has the watermark (default "bottom_right").
        region_w: Width of the watermark region in pixels (for 2048px image).
        region_h: Height of the watermark region in pixels.
        inpaint_radius: Radius for cv2.inpaint (larger = smoother but slower).

    Returns:
        The output file path.

    Raises:
        RuntimeError: if opencv-python / numpy are not installed.
        FileNotFoundError: if the input image cannot be read.
        ValueError: if `corner` is not one of the four supported values.

    Requires: pip install opencv-python numpy
    """
    try:
        import cv2
        import numpy as np
    except ImportError:
        raise RuntimeError(
            "Watermark removal requires opencv-python and numpy. "
            "Install with: pip install opencv-python numpy"
        )
    if output_path is None:
        output_path = input_path
    img = cv2.imread(input_path)
    if img is None:
        raise FileNotFoundError(f"Cannot read image: {input_path}")
    h, w = img.shape[:2]
    # Scale region size proportionally to image dimensions (calibrated for 2048px)
    scale = max(w, h) / 2048.0
    rw = int(region_w * scale)
    rh = int(region_h * scale)
    # Build inpainting mask for the specified corner
    mask = np.zeros((h, w), dtype=np.uint8)
    if corner == "bottom_right":
        cv2.rectangle(mask, (w - rw, h - rh), (w, h), 255, -1)
    elif corner == "bottom_left":
        cv2.rectangle(mask, (0, h - rh), (rw, h), 255, -1)
    elif corner == "top_right":
        cv2.rectangle(mask, (w - rw, 0), (w, rh), 255, -1)
    elif corner == "top_left":
        cv2.rectangle(mask, (0, 0), (rw, rh), 255, -1)
    else:
        # BUG FIX: previously an unknown corner silently inpainted an empty
        # mask (a no-op) and reported success. Fail loudly instead.
        raise ValueError(f"Unknown corner: {corner!r}")
    result = cv2.inpaint(img, mask, inpaint_radius, cv2.INPAINT_TELEA)
    # Determine output format from extension
    ext = os.path.splitext(output_path)[1].lower()
    if ext in (".jpg", ".jpeg"):
        cv2.imwrite(output_path, result, [cv2.IMWRITE_JPEG_QUALITY, 95])
    else:
        cv2.imwrite(output_path, result)
    return output_path


# -- Batch execution with concurrency control --------------------------------
# NOTE: the semaphore is belt-and-braces — ThreadPoolExecutor(max_workers=
# MAX_CONCURRENCY) already bounds concurrency; the semaphore keeps the bound
# even if a future caller submits through a larger pool.
_semaphore = threading.Semaphore(MAX_CONCURRENCY)


def _run_with_semaphore(fn, *args, **kwargs):
    """Run fn under the global concurrency semaphore."""
    with _semaphore:
        return fn(*args, **kwargs)


def batch_execute(requests_list: list, output_dir: str = ".",
                  remove_wm: bool = False) -> list:
    """Execute a list of requests with max concurrency = 2.

    Each item: {"type": "chat"|"image", ...params}
    Returns list of results, ordered by input index.
    """
    os.makedirs(output_dir, exist_ok=True)
    results = []

    def _execute_one(idx, item):
        req_type = item.get("type", "chat")
        if req_type == "chat":
            messages = item.get("messages",
                                [{"role": "user", "content": item.get("content", "")}])
            resp = chat_sync(
                messages=messages,
                model=item.get("model", DEFAULT_LLM_MODEL),
                temperature=item.get("temperature", 0.7),
            )
            return {"index": idx, "type": "chat", "response": resp}
        elif req_type == "image":
            resp = image_sync(
                prompt=item.get("prompt"),
                event_data=item.get("event_data"),
                model=item.get("model", DEFAULT_LMM_MODEL),
                output_format="base64",
            )
            if resp.get("status") == "completed":
                out_path = os.path.join(output_dir, f"image_{idx}.png")
                try:
                    save_image_from_response(resp, out_path)
                    if remove_wm:
                        remove_watermark(out_path)
                    resp["local_path"] = out_path
                except Exception as e:
                    resp["save_error"] = str(e)
            return {"index": idx, "type": "image", "response": resp}
        # BUG FIX: unknown types previously fell through and returned None,
        # which crashed the final sort. Report them as failed results instead.
        return {
            "index": idx,
            "type": req_type,
            "response": {"error": f"Unknown request type: {req_type}",
                         "status": "failed"},
        }

    with ThreadPoolExecutor(max_workers=MAX_CONCURRENCY) as pool:
        futures = {
            pool.submit(_run_with_semaphore, _execute_one, i, item): i
            for i, item in enumerate(requests_list)
        }
        for future in as_completed(futures):
            results.append(future.result())
    results.sort(key=lambda r: r["index"])
    return results


# -- CLI ---------------------------------------------------------------------
def main():
    # Ensure UTF-8 output on Windows
    if sys.platform == "win32":
        sys.stdout.reconfigure(encoding="utf-8")
        sys.stderr.reconfigure(encoding="utf-8")

    parser = argparse.ArgumentParser(description="Gemini API Client")
    sub = parser.add_subparsers(dest="command", required=True)

    # chat
    p_chat = sub.add_parser("chat", help="LLM text generation")
    p_chat.add_argument("content", help="User message content")
    p_chat.add_argument("--model", default=DEFAULT_LLM_MODEL, choices=["pro", "flash"])
    p_chat.add_argument("--temperature", type=float, default=0.7)
    p_chat.add_argument("--max-tokens", type=int, default=4096)
    p_chat.add_argument("--system", default=None, help="System prompt")

    # image
    p_img = sub.add_parser("image", help="LMM image generation")
    p_img.add_argument("--prompt", default=None, help="Direct image prompt")
    p_img.add_argument("--event-data", default=None, help="Path to event_data JSON file")
    p_img.add_argument("--output", "-o", default=None, help="Output file path")
    p_img.add_argument("--model", default=DEFAULT_LMM_MODEL)
    p_img.add_argument("--remove-watermark", "--rw", action="store_true",
                       help="Remove Gemini watermark (bottom-right star) after generation")

    # batch
    p_batch = sub.add_parser("batch", help="Batch requests (max concurrency=2)")
    p_batch.add_argument("input_file", help="JSONL file with requests")
    p_batch.add_argument("--output-dir", default="./batch_results")
    p_batch.add_argument("--remove-watermark", "--rw", action="store_true",
                         help="Remove watermark from all generated images")

    # health
    sub.add_parser("health", help="Check API health")

    args = parser.parse_args()

    if args.command == "chat":
        messages = []
        if args.system:
            # NOTE(review): the "system" prompt is deliberately sent with role
            # "user" — presumably the Master API does not accept a "system"
            # role. Confirm against the backend before changing.
            messages.append({"role": "user", "content": args.system})
        messages.append({"role": "user", "content": args.content})
        resp = chat_sync(messages, model=args.model,
                         temperature=args.temperature, max_tokens=args.max_tokens)
        if resp.get("error"):
            print(f"ERROR: {resp['error']}", file=sys.stderr)
            sys.exit(1)
        # Extract content
        result = resp.get("result", {})
        content = result.get("content", "")
        print(content)
        # Print metadata to stderr
        model_used = result.get("model", "unknown")
        usage = result.get("usage", {})
        fallback = result.get("fallback_reason")
        meta = f"[model={model_used}, tokens={usage.get('prompt_tokens',0)}+{usage.get('completion_tokens',0)}"
        if fallback:
            meta += f", fallback={fallback}"
        meta += "]"
        print(meta, file=sys.stderr)

    elif args.command == "image":
        event_data = None
        if args.event_data:
            with open(args.event_data, "r", encoding="utf-8") as f:
                event_data = json.load(f)
        if not args.prompt and not event_data:
            print("ERROR: --prompt or --event-data required", file=sys.stderr)
            sys.exit(1)
        # Always use base64 — URL download is broken (returns 404)
        resp = image_sync(prompt=args.prompt, event_data=event_data,
                          model=args.model, output_format="base64")
        if resp.get("error"):
            print(f"ERROR: {resp['error']}", file=sys.stderr)
            print(json.dumps(resp, ensure_ascii=False, indent=2), file=sys.stderr)
            sys.exit(1)
        if resp.get("status") != "completed":
            print(f"Status: {resp.get('status')}", file=sys.stderr)
            print(json.dumps(resp, ensure_ascii=False, indent=2))
            sys.exit(1)
        result = resp.get("result", {})
        prompt_used = result.get("prompt_used", "")
        gen_time = result.get("generation_time_seconds", 0)
        print(f"Prompt: {prompt_used[:150]}...", file=sys.stderr)
        print(f"Generation time: {gen_time:.1f}s", file=sys.stderr)
        if args.output:
            try:
                save_image_from_response(resp, args.output)
                print(f"Saved to: {args.output}")
                if args.remove_watermark:
                    print("Removing watermark...", file=sys.stderr)
                    remove_watermark(args.output)
                    print(f"Watermark removed: {args.output}")
            except Exception as e:
                print(f"Save failed: {e}", file=sys.stderr)
                sys.exit(1)
        else:
            # No output path — print base64 data info or URL
            b64 = result.get("image_base64")
            if b64:
                print(f"[base64 image data: {len(b64)} chars, decode with base64.b64decode()]")
            else:
                image_url = result.get("image_url", "")
                if image_url:
                    print(f"Image URL (may 404): {MASTER_URL}{image_url}")
                else:
                    print(json.dumps(result, ensure_ascii=False, indent=2))

    elif args.command == "batch":
        requests_list = []
        with open(args.input_file, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    requests_list.append(json.loads(line))
        print(f"Processing {len(requests_list)} requests (max concurrency={MAX_CONCURRENCY})...",
              file=sys.stderr)
        results = batch_execute(requests_list, args.output_dir,
                                remove_wm=args.remove_watermark)
        print(json.dumps(results, ensure_ascii=False, indent=2))

    elif args.command == "health":
        try:
            url = f"{MASTER_URL}/health"
            req = Request(url, method="GET")
            with _opener.open(req, timeout=10) as resp:
                body = resp.read().decode("utf-8")
                print(body)
        except Exception as e:
            print(f"Health check failed: {e}", file=sys.stderr)
            sys.exit(1)


if __name__ == "__main__":
    main()