1.0 init
This commit is contained in:
gemini-api/SKILL.md (new file, 133 lines)
---
name: gemini-api
description: Use when you need to call the Gemini LLM for text generation or the Gemini LMM for image generation via the local-network Master API at 192.168.1.5:10900. Supports chat completion and image generation with concurrency control.
---

# Gemini API (Local Network)

## Overview

Wraps the local Gemini Master API (192.168.1.5:10900) for LLM text and LMM image generation. Maximum concurrency is 2. All calls go through the `gemini_client.py` helper script.

## When to Use

- Need LLM text generation (chat/completion) via Gemini
- Need image generation via the Gemini LMM
- Any task requiring Gemini model capabilities on the local network

## Quick Reference

| Capability | Endpoint | Model Options | Timeout |
|------------|----------|---------------|---------|
| LLM Chat (sync) | `/api/v1/llm/chat/sync` | `pro` (default), `flash` | ~200s |
| LMM Image (sync) | `/api/v1/lmm/image/sync` | `nano-bananapro` (default) | ~350s |
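
What the helper sends over the wire can be sketched with the standard library alone. The endpoint, `X-API-Key` header, and payload fields below are taken from `chat_sync` in `gemini_client.py`; `build_chat_request` is an illustrative name, and `temperature`/`max_tokens` are omitted for brevity:

```python
import json
from urllib.request import Request, ProxyHandler, build_opener

MASTER_URL = "http://192.168.1.5:10900"

def build_chat_request(content, model="pro", api_key="sk-gemini-api-key-002"):
    """Build a sync-chat POST matching /api/v1/llm/chat/sync (sketch)."""
    payload = {"messages": [{"role": "user", "content": content}], "model": model}
    req = Request(f"{MASTER_URL}/api/v1/llm/chat/sync",
                  data=json.dumps(payload).encode("utf-8"), method="POST")
    req.add_header("Content-Type", "application/json")
    req.add_header("X-API-Key", api_key)  # auth header used by the helper
    return req

# Open with an empty ProxyHandler so the system proxy never intercepts LAN traffic:
opener = build_opener(ProxyHandler({}))
# resp = opener.open(build_chat_request("your question"), timeout=200)  # only works on the LAN
```

The empty `ProxyHandler({})` is the same proxy-bypass trick the helper uses (see Constraints below).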

## Usage

### Helper Script

All calls go through `~/.claude/skills/gemini-api/gemini_client.py`:

```bash
# LLM: text generation
python ~/.claude/skills/gemini-api/gemini_client.py chat "your question"

# LLM: with model selection
python ~/.claude/skills/gemini-api/gemini_client.py chat "your question" --model flash

# LMM: image generation (saves to file)
python ~/.claude/skills/gemini-api/gemini_client.py image --prompt "image description" --output result.png

# LMM: image + auto-remove Gemini watermark
python ~/.claude/skills/gemini-api/gemini_client.py image --prompt "image description" --output result.png --remove-watermark

# LMM: image from event data (JSON file)
python ~/.claude/skills/gemini-api/gemini_client.py image --event-data event.json --output result.png

# Batch: multiple requests with concurrency=2
python ~/.claude/skills/gemini-api/gemini_client.py batch requests.jsonl --output-dir ./results/
python ~/.claude/skills/gemini-api/gemini_client.py batch requests.jsonl --output-dir ./results/ --remove-watermark
```
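
The `batch` subcommand reads one JSON object per line. The field names below follow `batch_execute` in `gemini_client.py` (`"type"` selects chat vs. image; the example prompts are placeholders); writing the file via Python keeps the quoting unambiguous:

```python
import json

# One request per line; "type" selects the call, remaining fields are per-type params
batch = [
    {"type": "chat", "content": "Summarize the release notes", "model": "flash"},
    {"type": "image", "prompt": "a minimalist app icon", "model": "nano-bananapro"},
]
with open("requests.jsonl", "w", encoding="utf-8") as f:
    for item in batch:
        f.write(json.dumps(item, ensure_ascii=False) + "\n")
```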

### Python API (recommended for agents)

When writing inline Python scripts (not via the CLI), import the module directly:

```python
import os
import sys
sys.path.insert(0, r'C:\Users\ZhuanZ(无密码)\.claude\skills\gemini-api')
# Or on Linux/macOS: sys.path.insert(0, os.path.expanduser('~/.claude/skills/gemini-api'))
from gemini_client import image_sync, save_image_from_response, remove_watermark

# Generate image (always use base64 format)
resp = image_sync(prompt="your prompt here")

# Save to file
save_image_from_response(resp, "output.png")

# Remove watermark (optional; overwrites in place)
remove_watermark("output.png")
```

## Image Generation: Critical Notes

### MUST use base64 format (NOT url)

The `output_format="url"` mode is **broken**: returned URLs consistently 404. The client defaults to `base64`, which works reliably.

```python
# CORRECT — base64 (default, works)
resp = image_sync(prompt="...", output_format="base64")
save_image_from_response(resp, "out.png")

# WRONG — url download will fail with 404
resp = image_sync(prompt="...", output_format="url")
download_image_by_url(resp["result"]["image_url"], "out.png")  # 404!
```

### Watermark Removal

Gemini `nano-bananapro` adds a small **star/sparkle watermark** in the bottom-right corner of every generated image. Use `--remove-watermark` (CLI) or `remove_watermark()` (Python) to clean it up.

**Requires**: `pip install opencv-python numpy` (one-time setup).
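
Since the dependency is optional, it can help to check it up front rather than fail mid-run. A minimal sketch (`have_watermark_deps` is an illustrative name, not part of the helper):

```python
def have_watermark_deps():
    """True if opencv-python and numpy are importable (needed by remove_watermark)."""
    try:
        import cv2, numpy  # noqa: F401
        return True
    except ImportError:
        return False

# if not have_watermark_deps():
#     print("run: pip install opencv-python numpy")
```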

```python
from gemini_client import remove_watermark

# Remove watermark in-place
remove_watermark("image.png")

# Or save to a different file
remove_watermark("input.png", "clean.png")

# Custom region size (for non-standard watermark placement)
remove_watermark("input.png", region_w=300, region_h=250)
```

**How it works**: OpenCV `cv2.inpaint` with the TELEA algorithm. The helper builds a mask over the watermark region (located by corner position) and fills it in from the surrounding pixels. Works well on both solid and complex backgrounds.
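
The mask geometry can be sketched without OpenCV. This mirrors the corner/scale logic inside `remove_watermark` (the default region is calibrated for a 2048 px image and scales with the longer side); `watermark_rect` is an illustrative name:

```python
def watermark_rect(w, h, corner="bottom_right", region_w=260, region_h=200):
    """Return (x0, y0, x1, y1) of the corner region that gets inpainted."""
    scale = max(w, h) / 2048.0          # calibration baseline: 2048 px
    rw, rh = int(region_w * scale), int(region_h * scale)
    return {
        "bottom_right": (w - rw, h - rh, w, h),
        "bottom_left":  (0, h - rh, rw, h),
        "top_right":    (w - rw, 0, w, rh),
        "top_left":     (0, 0, rw, rh),
    }[corner]
```

For a 1024 px image the region shrinks to half size, so the mask stays proportional to the watermark.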

## Constraints

- **Concurrency**: Max 2 simultaneous requests (enforced by the helper script)
- **API Key**: Uses one key per request, rotated from a pool of 3
- **Rate Limit**: 10 req/min, 400 req/hour (server-side)
- **Proxy**: Must bypass the system proxy for 192.168.1.5
- **Image format**: Always use `output_format="base64"`, not `"url"`
- **Watermark deps**: `remove_watermark()` needs `opencv-python` and `numpy`
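
The concurrency cap from the first bullet takes only a semaphore plus a thread pool, which is how the helper enforces it. In this stdlib-only sketch, `limited_call` is an illustrative stand-in for `chat_sync`/`image_sync`:

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor

MAX_CONCURRENCY = 2
sem = threading.Semaphore(MAX_CONCURRENCY)
lock = threading.Lock()
active = peak = 0

def limited_call(i):
    """Stand-in for an API call; tracks how many run simultaneously."""
    global active, peak
    with sem:                 # at most 2 callers hold the semaphore at once
        with lock:
            active += 1
            peak = max(peak, active)
        time.sleep(0.02)      # simulate API latency
        with lock:
            active -= 1
    return i

with ThreadPoolExecutor(max_workers=8) as pool:
    results = list(pool.map(limited_call, range(8)))
# peak never exceeds 2, even with 8 pool workers
```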

## Common Mistakes

| Mistake | Consequence | Fix |
|---------|-------------|-----|
| Using `output_format="url"` | URL downloads return 404 | Use `"base64"` (default) |
| Forgetting `--noproxy '*'` with curl | Request hangs (proxy intercepts LAN traffic) | Always add `--noproxy '*'` |
| Using `127.0.0.1` instead of `192.168.1.5` | Wrong host | Use `192.168.1.5` |
| Calling `download_image_by_url()` | 404 error | Use `save_image_from_response()` |
| Exceeding concurrency=2 | Queuing delays, timeouts | Use batch mode |
| Not checking the `status` field | Errors missed silently | Check `resp.get("status") == "completed"` |
| Forgetting watermark removal | Star logo in bottom-right corner | Add `--remove-watermark` or call `remove_watermark()` |
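
The last two rows suggest one guard before touching `result`. A minimal sketch against the response shape used throughout this doc (`unwrap` is an illustrative name):

```python
def unwrap(resp):
    """Raise on any failed response; otherwise return the result payload."""
    if resp.get("error"):
        raise RuntimeError(resp["error"])
    if resp.get("status") != "completed":
        raise RuntimeError(f"unexpected status: {resp.get('status')!r}")
    return resp.get("result", {})
```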

gemini-api/__pycache__/gemini_client.cpython-310.pyc (binary file, not shown)

gemini-api/gemini_client.py (new file, 474 lines)

#!/usr/bin/env python3
"""
Gemini API Client - Local Network Master API wrapper.
Supports LLM chat and LMM image generation with concurrency control.

Usage:
    python gemini_client.py chat "your question" [--model pro|flash] [--temperature 0.7]
    python gemini_client.py image --prompt "description" [--output result.png]
    python gemini_client.py image --prompt "description" --output icon.png --remove-watermark
    python gemini_client.py image --event-data event.json --output result.png
    python gemini_client.py batch requests.jsonl [--output-dir ./results/]
"""

import sys
import os
import json
import argparse
import base64
import time
import threading
from urllib.request import Request, urlopen, ProxyHandler, build_opener
from urllib.error import HTTPError, URLError
from concurrent.futures import ThreadPoolExecutor, as_completed

# -- Configuration ----------------------------------------------------------

MASTER_URL = "http://192.168.1.5:10900"
API_KEYS = [
    "sk-gemini-api-key-002",  # primary — routes to asus2023_chat (verified working)
    "sk-gemini-api-key-001",
    "test-api-key-001",
]
MAX_CONCURRENCY = 2
DEFAULT_LLM_MODEL = "pro"
DEFAULT_LMM_MODEL = "nano-bananapro"
CHAT_TIMEOUT = 200   # seconds
IMAGE_TIMEOUT = 350  # seconds

# -- Proxy bypass for LAN ---------------------------------------------------

_opener = build_opener(ProxyHandler({}))  # bypass all proxies

# -- API Key rotation (thread-safe) -----------------------------------------

_key_index = 0
_key_lock = threading.Lock()


def _next_api_key():
    global _key_index
    with _key_lock:
        key = API_KEYS[_key_index % len(API_KEYS)]
        _key_index += 1
        return key


# -- Low-level HTTP ----------------------------------------------------------

def _post_json_once(endpoint: str, payload: dict, api_key: str, timeout: int = CHAT_TIMEOUT) -> dict:
    """Single POST attempt with a specific API key."""
    url = f"{MASTER_URL}{endpoint}"
    data = json.dumps(payload, ensure_ascii=False).encode("utf-8")

    req = Request(url, data=data, method="POST")
    req.add_header("Content-Type", "application/json")
    req.add_header("X-API-Key", api_key)

    try:
        with _opener.open(req, timeout=timeout) as resp:
            body = resp.read().decode("utf-8")
            return json.loads(body)
    except HTTPError as e:
        body = e.read().decode("utf-8", errors="replace")
        return {"error": f"HTTP {e.code}: {body}", "status": "failed"}
    except URLError as e:
        return {"error": f"Connection failed: {e.reason}", "status": "failed"}
    except Exception as e:
        return {"error": str(e), "status": "failed"}


def _post_json(endpoint: str, payload: dict, timeout: int = CHAT_TIMEOUT) -> dict:
    """POST JSON with auto-retry across all API keys on failure."""
    last_resp = None
    for _ in range(len(API_KEYS)):
        api_key = _next_api_key()
        resp = _post_json_once(endpoint, payload, api_key, timeout)
        if resp.get("status") == "completed":
            return resp
        # Retryable server-side errors ("未登录" = not logged in, tool not activated, etc.)
        err = resp.get("error", "")
        if "未登录" in err or "All retries failed" in err:
            print(f"[retry] key {api_key[:12]}... failed: {err[:60]}, trying next key", file=sys.stderr)
            last_resp = resp
            continue
        # Non-retryable or unknown — return as-is
        return resp
    return last_resp or {"error": "All API keys exhausted", "status": "failed"}


def _get(endpoint: str, timeout: int = 30) -> bytes:
    """GET raw bytes from the Master API."""
    url = f"{MASTER_URL}{endpoint}"
    api_key = _next_api_key()
    req = Request(url, method="GET")
    req.add_header("X-API-Key", api_key)
    with _opener.open(req, timeout=timeout) as resp:
        return resp.read()


# -- High-level API ----------------------------------------------------------

def chat_sync(messages: list, model: str = DEFAULT_LLM_MODEL,
              temperature: float = 0.7, max_tokens: int = 4096,
              new_conversation: bool = True) -> dict:
    """Synchronous LLM chat. Returns the full response dict."""
    payload = {
        "messages": messages,
        "model": model,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "new_conversation": new_conversation,
    }
    return _post_json("/api/v1/llm/chat/sync", payload, timeout=CHAT_TIMEOUT)


def image_sync(prompt: str = None, event_data: dict = None,
               attachments: dict = None,
               model: str = DEFAULT_LMM_MODEL,
               output_format: str = "base64") -> dict:
    """Synchronous image generation. Returns the full response dict.

    IMPORTANT: Use output_format="base64" (default). The "url" format returns
    URLs that are NOT downloadable (404). base64 is the only reliable method.
    """
    payload = {"model": model, "output_format": output_format}
    if prompt:
        payload["prompt"] = prompt
    if event_data:
        payload["event_data"] = event_data
    if attachments:
        payload["attachments"] = attachments
    return _post_json("/api/v1/lmm/image/sync", payload, timeout=IMAGE_TIMEOUT)


def save_image_from_response(resp: dict, output_path: str) -> str:
    """Extract the image from an API response and save it to a file.

    Handles both base64 and URL formats (base64 preferred).
    Returns the saved file path, or raises on failure.
    """
    if resp.get("status") != "completed":
        raise RuntimeError(f"Image generation failed: {resp.get('error', resp.get('status'))}")

    result = resp.get("result", {})

    # Prefer base64 (reliable)
    b64_data = result.get("image_base64") or result.get("base64") or result.get("image_data")
    if b64_data:
        img_bytes = base64.b64decode(b64_data)
        with open(output_path, "wb") as f:
            f.write(img_bytes)
        return output_path

    # Fallback: try URL download (usually fails with 404)
    image_url = result.get("image_url", "")
    if image_url:
        try:
            return download_image_by_url(image_url, output_path)
        except Exception as e:
            raise RuntimeError(
                f"base64 not in response and URL download failed: {e}. "
                f"URL was: {MASTER_URL}{image_url}"
            )

    raise RuntimeError(f"No image data in response. Keys: {list(result.keys())}")


def download_image(task_id: str, output_path: str) -> str:
    """Download a generated image by task_id. Returns the saved path."""
    data = _get(f"/api/v1/lmm/image/{task_id}", timeout=60)
    with open(output_path, "wb") as f:
        f.write(data)
    return output_path


def download_image_by_url(image_url: str, output_path: str) -> str:
    """Download an image from the result's image_url field.

    WARNING: URL downloads frequently return 404. Prefer base64 format.
    """
    if image_url.startswith("/"):
        url = f"{MASTER_URL}{image_url}"
    else:
        url = image_url
    req = Request(url, method="GET")
    req.add_header("X-API-Key", _next_api_key())
    with _opener.open(req, timeout=60) as resp:
        data = resp.read()
    with open(output_path, "wb") as f:
        f.write(data)
    return output_path


# -- Watermark removal -------------------------------------------------------

def remove_watermark(input_path: str, output_path: str = None,
                     corner: str = "bottom_right",
                     region_w: int = 260, region_h: int = 200,
                     inpaint_radius: int = 12) -> str:
    """Remove the Gemini watermark (star logo) from generated images.

    Uses OpenCV inpainting to cleanly erase the bottom-right watermark
    that nano-bananapro adds to all generated images.

    Args:
        input_path: Path to the watermarked image.
        output_path: Where to save. Defaults to overwriting input_path.
        corner: Which corner has the watermark (default "bottom_right").
        region_w: Width of the watermark region in pixels (for a 2048px image).
        region_h: Height of the watermark region in pixels.
        inpaint_radius: Radius for cv2.inpaint (larger = smoother but slower).

    Returns:
        The output file path.

    Requires: pip install opencv-python numpy
    """
    try:
        import cv2
        import numpy as np
    except ImportError:
        raise RuntimeError(
            "Watermark removal requires opencv-python and numpy. "
            "Install with: pip install opencv-python numpy"
        )

    if output_path is None:
        output_path = input_path

    img = cv2.imread(input_path)
    if img is None:
        raise FileNotFoundError(f"Cannot read image: {input_path}")

    h, w = img.shape[:2]

    # Scale the region size proportionally to the image dimensions (calibrated for 2048px)
    scale = max(w, h) / 2048.0
    rw = int(region_w * scale)
    rh = int(region_h * scale)

    # Build the inpainting mask for the specified corner
    mask = np.zeros((h, w), dtype=np.uint8)
    if corner == "bottom_right":
        cv2.rectangle(mask, (w - rw, h - rh), (w, h), 255, -1)
    elif corner == "bottom_left":
        cv2.rectangle(mask, (0, h - rh), (rw, h), 255, -1)
    elif corner == "top_right":
        cv2.rectangle(mask, (w - rw, 0), (w, rh), 255, -1)
    elif corner == "top_left":
        cv2.rectangle(mask, (0, 0), (rw, rh), 255, -1)

    result = cv2.inpaint(img, mask, inpaint_radius, cv2.INPAINT_TELEA)

    # Determine the output format from the extension
    ext = os.path.splitext(output_path)[1].lower()
    if ext in (".jpg", ".jpeg"):
        cv2.imwrite(output_path, result, [cv2.IMWRITE_JPEG_QUALITY, 95])
    else:
        cv2.imwrite(output_path, result)

    return output_path


# -- Batch execution with concurrency control --------------------------------

_semaphore = threading.Semaphore(MAX_CONCURRENCY)


def _run_with_semaphore(fn, *args, **kwargs):
    with _semaphore:
        return fn(*args, **kwargs)


def batch_execute(requests_list: list, output_dir: str = ".",
                  remove_wm: bool = False) -> list:
    """Execute a list of requests with max concurrency = 2.

    Each item: {"type": "chat"|"image", ...params}
    Returns a list of results.
    """
    os.makedirs(output_dir, exist_ok=True)
    results = []

    def _execute_one(idx, item):
        req_type = item.get("type", "chat")
        if req_type == "chat":
            messages = item.get("messages", [{"role": "user", "content": item.get("content", "")}])
            resp = chat_sync(
                messages=messages,
                model=item.get("model", DEFAULT_LLM_MODEL),
                temperature=item.get("temperature", 0.7),
            )
            return {"index": idx, "type": "chat", "response": resp}
        elif req_type == "image":
            resp = image_sync(
                prompt=item.get("prompt"),
                event_data=item.get("event_data"),
                model=item.get("model", DEFAULT_LMM_MODEL),
                output_format="base64",
            )
            if resp.get("status") == "completed":
                out_path = os.path.join(output_dir, f"image_{idx}.png")
                try:
                    save_image_from_response(resp, out_path)
                    if remove_wm:
                        remove_watermark(out_path)
                    resp["local_path"] = out_path
                except Exception as e:
                    resp["save_error"] = str(e)
            return {"index": idx, "type": "image", "response": resp}
        # Unknown type — report it instead of silently returning None
        return {"index": idx, "type": req_type,
                "response": {"error": f"unknown request type: {req_type}", "status": "failed"}}

    with ThreadPoolExecutor(max_workers=MAX_CONCURRENCY) as pool:
        futures = {
            pool.submit(_run_with_semaphore, _execute_one, i, item): i
            for i, item in enumerate(requests_list)
        }
        for future in as_completed(futures):
            results.append(future.result())

    results.sort(key=lambda r: r["index"])
    return results


# -- CLI ---------------------------------------------------------------------

def main():
    # Ensure UTF-8 output on Windows
    if sys.platform == "win32":
        sys.stdout.reconfigure(encoding="utf-8")
        sys.stderr.reconfigure(encoding="utf-8")

    parser = argparse.ArgumentParser(description="Gemini API Client")
    sub = parser.add_subparsers(dest="command", required=True)

    # chat
    p_chat = sub.add_parser("chat", help="LLM text generation")
    p_chat.add_argument("content", help="User message content")
    p_chat.add_argument("--model", default=DEFAULT_LLM_MODEL, choices=["pro", "flash"])
    p_chat.add_argument("--temperature", type=float, default=0.7)
    p_chat.add_argument("--max-tokens", type=int, default=4096)
    p_chat.add_argument("--system", default=None, help="System prompt")

    # image
    p_img = sub.add_parser("image", help="LMM image generation")
    p_img.add_argument("--prompt", default=None, help="Direct image prompt")
    p_img.add_argument("--event-data", default=None, help="Path to event_data JSON file")
    p_img.add_argument("--output", "-o", default=None, help="Output file path")
    p_img.add_argument("--model", default=DEFAULT_LMM_MODEL)
    p_img.add_argument("--remove-watermark", "--rw", action="store_true",
                       help="Remove Gemini watermark (bottom-right star) after generation")

    # batch
    p_batch = sub.add_parser("batch", help="Batch requests (max concurrency=2)")
    p_batch.add_argument("input_file", help="JSONL file with requests")
    p_batch.add_argument("--output-dir", default="./batch_results")
    p_batch.add_argument("--remove-watermark", "--rw", action="store_true",
                         help="Remove watermark from all generated images")

    # health
    sub.add_parser("health", help="Check API health")

    args = parser.parse_args()

    if args.command == "chat":
        messages = []
        if args.system:
            # Note: the system prompt is sent as a leading user-role message
            messages.append({"role": "user", "content": args.system})
        messages.append({"role": "user", "content": args.content})
        resp = chat_sync(messages, model=args.model,
                         temperature=args.temperature,
                         max_tokens=args.max_tokens)
        if resp.get("error"):
            print(f"ERROR: {resp['error']}", file=sys.stderr)
            sys.exit(1)
        # Extract content
        result = resp.get("result", {})
        content = result.get("content", "")
        print(content)
        # Print metadata to stderr
        model_used = result.get("model", "unknown")
        usage = result.get("usage", {})
        fallback = result.get("fallback_reason")
        meta = f"[model={model_used}, tokens={usage.get('prompt_tokens', 0)}+{usage.get('completion_tokens', 0)}"
        if fallback:
            meta += f", fallback={fallback}"
        meta += "]"
        print(meta, file=sys.stderr)

    elif args.command == "image":
        event_data = None
        if args.event_data:
            with open(args.event_data, "r", encoding="utf-8") as f:
                event_data = json.load(f)

        if not args.prompt and not event_data:
            print("ERROR: --prompt or --event-data required", file=sys.stderr)
            sys.exit(1)

        # Always use base64 — URL download is broken (returns 404)
        resp = image_sync(prompt=args.prompt, event_data=event_data,
                          model=args.model, output_format="base64")

        if resp.get("error"):
            print(f"ERROR: {resp['error']}", file=sys.stderr)
            print(json.dumps(resp, ensure_ascii=False, indent=2), file=sys.stderr)
            sys.exit(1)

        if resp.get("status") != "completed":
            print(f"Status: {resp.get('status')}", file=sys.stderr)
            print(json.dumps(resp, ensure_ascii=False, indent=2))
            sys.exit(1)

        result = resp.get("result", {})
        prompt_used = result.get("prompt_used", "")
        gen_time = result.get("generation_time_seconds", 0)

        print(f"Prompt: {prompt_used[:150]}...", file=sys.stderr)
        print(f"Generation time: {gen_time:.1f}s", file=sys.stderr)

        if args.output:
            try:
                save_image_from_response(resp, args.output)
                print(f"Saved to: {args.output}")

                if args.remove_watermark:
                    print("Removing watermark...", file=sys.stderr)
                    remove_watermark(args.output)
                    print(f"Watermark removed: {args.output}")

            except Exception as e:
                print(f"Save failed: {e}", file=sys.stderr)
                sys.exit(1)
        else:
            # No output path — print base64 data info or the URL
            b64 = result.get("image_base64")
            if b64:
                print(f"[base64 image data: {len(b64)} chars, decode with base64.b64decode()]")
            else:
                image_url = result.get("image_url", "")
                if image_url:
                    print(f"Image URL (may 404): {MASTER_URL}{image_url}")
                else:
                    print(json.dumps(result, ensure_ascii=False, indent=2))
    elif args.command == "batch":
        requests_list = []
        with open(args.input_file, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    requests_list.append(json.loads(line))

        print(f"Processing {len(requests_list)} requests (max concurrency={MAX_CONCURRENCY})...",
              file=sys.stderr)
        results = batch_execute(requests_list, args.output_dir,
                                remove_wm=args.remove_watermark)
        print(json.dumps(results, ensure_ascii=False, indent=2))

    elif args.command == "health":
        try:
            url = f"{MASTER_URL}/health"
            req = Request(url, method="GET")
            with _opener.open(req, timeout=10) as resp:
                body = resp.read().decode("utf-8")
                print(body)
        except Exception as e:
            print(f"Health check failed: {e}", file=sys.stderr)
            sys.exit(1)


if __name__ == "__main__":
    main()