""" 通过管理端 API 批量上传地点。 JSON 输入结构: [ { "title": "示例地点", "city": "上海", "longitude": 121.4737, "latitude": 31.2304, "description": "地点描述", "transport": "地铁可达", "best_time": "傍晚", "difficulty": "低", "is_free": true, "audit_status": "approved", "tag_ids": ["街拍", "城市"], "images": ["./images/a.jpg", "https://example.com/b.jpg"] } ] CSV 使用相同字段名。列表字段可用分号、逗号或竖线分隔, 例如:images=./a.jpg;./b.jpg,tag_ids=街拍,城市。 """ from __future__ import annotations import argparse import csv import json import mimetypes import os from dataclasses import dataclass from pathlib import Path from typing import Any from urllib.parse import urlparse import httpx DEFAULT_BASE_URL = os.getenv("CIYUAN_API_BASE_URL", "http://localhost:8000/api/v1") DEFAULT_ACCOUNT = os.getenv("CIYUAN_ADMIN_ACCOUNT", "13900000001") DEFAULT_PASSWORD = os.getenv("CIYUAN_ADMIN_PASSWORD", "admin123456") SPOT_FIELDS = { "title", "city", "longitude", "latitude", "description", "transport", "best_time", "difficulty", "is_free", "price_min", "price_max", "audit_status", "reject_reason", "creator_id", "image_urls", "tag_ids", } @dataclass class UploadResult: index: int title: str ok: bool spot_id: int | None = None error: str | None = None class TagResolver: def __init__(self, client: httpx.Client, *, create_missing: bool = True) -> None: self.client = client self.create_missing = create_missing self._cache: dict[str, int] | None = None def resolve_many(self, names: list[str]) -> list[int]: tag_ids: list[int] = [] for name in names: tag_ids.append(self.resolve_one(name)) return list(dict.fromkeys(tag_ids)) def resolve_one(self, name: str) -> int: normalized = normalize_tag_name(name) cache = self._load_cache() if normalized in cache: return cache[normalized] if not self.create_missing: raise ValueError(f"tag not found: {name}") tag_id = self._create_tag(name.strip()) cache[normalized] = tag_id return tag_id def _load_cache(self) -> dict[str, int]: if self._cache is not None: return self._cache response = self.client.get("/tags", params={"sort": "name"}) response.raise_for_status() self._cache = { normalize_tag_name(str(item.get("name") or item.get("title") or "")): int(item["id"]) for item in response.json() if item.get("id") and (item.get("name") or item.get("title")) } return self._cache def _create_tag(self, name: str) -> int: response = self.client.post("/tags", json={"name": name}) if response.status_code == 409: self._cache = None cache = self._load_cache() normalized = normalize_tag_name(name) if normalized in cache: return cache[normalized] response.raise_for_status() data = response.json() print(f"[TAG] created {name} -> tag_id={data['id']}") return int(data["id"]) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="通过管理端 API 批量上传地点。", add_help=False) parser._positionals.title = "位置参数" parser._optionals.title = "可选参数" parser.add_argument("-h", "--help", action="help", help="显示帮助信息并退出。") parser.add_argument("input", type=Path, help="JSON 或 CSV 数据文件路径。") parser.add_argument("--base-url", default=DEFAULT_BASE_URL, help="API 基础地址,默认:%(default)s") parser.add_argument("--account", default=DEFAULT_ACCOUNT, help="管理员手机号或邮箱。") parser.add_argument("--password", default=DEFAULT_PASSWORD, help="管理员密码。") parser.add_argument("--creator-id", type=int, help="数据中未填写 creator_id 时使用的默认创建者用户 ID。") parser.add_argument( "--audit-status", default="approved", choices=["pending", "approved", "rejected", "deleted"], help="默认审核状态:pending=待审核,approved=已通过,rejected=已驳回,deleted=已删除。默认:%(default)s", ) parser.add_argument("--timeout", type=float, default=30.0, help="单次请求超时时间,单位秒。默认:%(default)s") parser.add_argument("--dry-run", action="store_true", help="只校验并打印提交内容,不登录、不上传图片、不创建地点。") parser.add_argument("--stop-on-error", action="store_true", help="任意一条数据失败后立即停止。") return parser.parse_args() def load_items(path: Path) -> list[dict[str, Any]]: if not path.exists(): raise FileNotFoundError(f"input file not found: {path}") suffix = path.suffix.lower() if suffix == ".json": data = json.loads(path.read_text(encoding="utf-8")) if isinstance(data, dict): data = data.get("items") or data.get("spots") if not isinstance(data, list): raise ValueError("JSON input must be a list or an object with items/spots list") return [normalize_item(item, path.parent) for item in data] if suffix == ".csv": with path.open("r", encoding="utf-8-sig", newline="") as f: return [normalize_item(row, path.parent) for row in csv.DictReader(f)] raise ValueError("unsupported input format, use .json or .csv") def normalize_item(raw: dict[str, Any], input_dir: Path) -> dict[str, Any]: item = {k: v for k, v in raw.items() if v not in ("", None)} for key in ("longitude", "latitude", "price_min", "price_max"): if key in item: item[key] = float(item[key]) for key in ("creator_id",): if key in item: item[key] = int(item[key]) if "is_free" in item: item["is_free"] = parse_bool(item["is_free"]) tag_values = item.get("tag_ids", item.get("tags", [])) item["tag_names"] = parse_string_list(tag_values) item.pop("tags", None) item.pop("tag_ids", None) image_urls = parse_string_list(item.get("image_urls", [])) images = parse_string_list(item.get("images", [])) item["image_sources"] = [resolve_local_path(src, input_dir) for src in [*image_urls, *images]] return item def parse_bool(value: Any) -> bool: if isinstance(value, bool): return value return str(value).strip().lower() in {"1", "true", "yes", "y", "on", "是", "免费"} def parse_string_list(value: Any) -> list[str]: if value is None or value == "": return [] if isinstance(value, list): return [str(v).strip() for v in value if str(v).strip()] text = str(value).strip() for sep in (";", "|", ","): if sep in text: return [part.strip() for part in text.split(sep) if part.strip()] return [text] if text else [] def normalize_tag_name(value: str) -> str: return value.strip().lower() def resolve_local_path(source: str, input_dir: Path) -> str: if is_url(source): return source path = Path(source) if not path.is_absolute(): path = input_dir / path return str(path) def is_url(value: str) -> bool: parsed = urlparse(value) return parsed.scheme in {"http", "https"} and bool(parsed.netloc) def login(client: httpx.Client, account: str, password: str) -> tuple[str, int]: response = client.post("/admin/auth/login", json={"account": account, "password": password}) response.raise_for_status() data = response.json() return data["access_token"], int(data["user"]["id"]) def upload_image(client: httpx.Client, path: Path) -> str: if not path.exists(): raise FileNotFoundError(f"image file not found: {path}") content_type = mimetypes.guess_type(path.name)[0] or "image/jpeg" with path.open("rb") as f: response = client.post( "/upload/image", files={"file": (path.name, f, content_type)}, ) response.raise_for_status() return str(response.json()["url"]) def build_payload( client: httpx.Client, item: dict[str, Any], creator_id: int, audit_status: str, *, upload_files: bool, tag_resolver: TagResolver | None, ) -> dict[str, Any]: payload = {key: item[key] for key in SPOT_FIELDS if key in item} payload.setdefault("creator_id", creator_id) payload.setdefault("audit_status", audit_status) payload.setdefault("is_free", True) if tag_resolver is None: payload["tag_ids"] = [] if item.get("tag_names"): payload["_tag_names"] = item["tag_names"] else: payload["tag_ids"] = tag_resolver.resolve_many(item.get("tag_names", [])) image_urls: list[str] = [] for source in item.get("image_sources", []): if is_url(source): image_urls.append(source) elif not upload_files: image_urls.append(source) else: image_urls.append(upload_image(client, Path(source))) payload["image_urls"] = image_urls missing = [key for key in ("title", "city", "longitude", "latitude", "creator_id") if key not in payload] if missing: raise ValueError(f"missing required fields: {', '.join(missing)}") return payload def create_spot(client: httpx.Client, payload: dict[str, Any]) -> int: response = client.post("/admin/spots", json=payload) response.raise_for_status() return int(response.json()["id"]) def run() -> int: args = parse_args() items = load_items(args.input) results: list[UploadResult] = [] with httpx.Client(base_url=args.base_url.rstrip("/"), timeout=args.timeout) as client: token: str | None = None current_admin_id = args.creator_id if not args.dry_run: token, admin_id = login(client, args.account, args.password) current_admin_id = current_admin_id or admin_id client.headers.update({"Authorization": f"Bearer {token}"}) elif current_admin_id is None: current_admin_id = 0 tag_resolver = None if args.dry_run else TagResolver(client) for index, item in enumerate(items, start=1): title = str(item.get("title", f"row-{index}")) try: payload = build_payload( client, item, current_admin_id, args.audit_status, upload_files=not args.dry_run, tag_resolver=tag_resolver, ) if args.dry_run: tag_names = payload.pop("_tag_names", None) if tag_names: payload["tag_ids"] = tag_names print(json.dumps(payload, ensure_ascii=False, indent=2)) results.append(UploadResult(index=index, title=title, ok=True)) continue spot_id = create_spot(client, payload) print(f"[OK] #{index} {title} -> spot_id={spot_id}") results.append(UploadResult(index=index, title=title, ok=True, spot_id=spot_id)) except Exception as exc: message = format_error(exc) print(f"[FAIL] #{index} {title}: {message}") results.append(UploadResult(index=index, title=title, ok=False, error=message)) if args.stop_on_error: break ok_count = sum(1 for item in results if item.ok) fail_count = len(results) - ok_count print(f"Done. total={len(results)} ok={ok_count} failed={fail_count}") return 1 if fail_count else 0 def format_error(exc: Exception) -> str: if isinstance(exc, httpx.HTTPStatusError): body = exc.response.text return f"HTTP {exc.response.status_code}: {body}" return str(exc) if __name__ == "__main__": raise SystemExit(run())