Files
CosScene/server/scripts/bulk_upload_spots.py
T

345 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
通过管理端 API 批量上传地点。
JSON 输入结构:
[
{
"title": "示例地点",
"city": "上海",
"longitude": 121.4737,
"latitude": 31.2304,
"description": "地点描述",
"transport": "地铁可达",
"best_time": "傍晚",
"difficulty": "",
"is_free": true,
"audit_status": "approved",
"tag_ids": ["街拍", "城市"],
"images": ["./images/a.jpg", "https://example.com/b.jpg"]
}
]
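CSV 使用相同字段名。列表字段可用分号、逗号或竖线分隔(同一单元格只按一种分隔符拆分,
优先级为分号 > 竖线 > 逗号),例如:images=./a.jpg;./b.jpg 与 tag_ids=街拍,城市。
tag_ids 填写的是标签名称,脚本会将其解析为标签 ID,不存在的标签会被自动创建;
images 中的相对路径按数据文件所在目录解析。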
"""
from __future__ import annotations
import argparse
import csv
import json
import mimetypes
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from urllib.parse import urlparse
import httpx
# Connection and credential defaults. Each one can be overridden through the
# environment variable named here or through the matching CLI option.
# NOTE(review): the account/password fallbacks look like development seed
# values; real deployments should set the environment variables instead.
DEFAULT_BASE_URL = os.environ.get("CIYUAN_API_BASE_URL", "http://localhost:8000/api/v1")
DEFAULT_ACCOUNT = os.environ.get("CIYUAN_ADMIN_ACCOUNT", "13900000001")
DEFAULT_PASSWORD = os.environ.get("CIYUAN_ADMIN_PASSWORD", "admin123456")

# Keys copied from a normalized row into the POST /admin/spots body.
# build_payload overwrites tag_ids and image_urls with resolved values.
SPOT_FIELDS = set(
    "title city longitude latitude description transport best_time difficulty "
    "is_free price_min price_max audit_status reject_reason creator_id "
    "image_urls tag_ids".split()
)
@dataclass
class UploadResult:
    """Outcome of processing one input row.

    ``spot_id`` is set only when a spot was actually created. ``error``
    holds the formatted failure message when ``ok`` is False.
    """

    index: int  # 1-based position of the row in the input
    title: str  # row title, or a "row-N" placeholder when the row has none
    ok: bool  # True when the row was created (or validated, in dry-run mode)
    spot_id: int | None = None
    error: str | None = None
class TagResolver:
    """Map tag names to server-side tag IDs, creating missing tags on demand.

    The full tag list is fetched once from ``GET /tags`` and held in a dict
    keyed by ``normalize_tag_name``. Tags created through this resolver are
    added to that dict so later rows reuse them.
    """

    def __init__(self, client: httpx.Client, *, create_missing: bool = True) -> None:
        self.client = client
        self.create_missing = create_missing
        # Filled lazily by _load_cache(); None means "not fetched yet".
        self._cache: dict[str, int] | None = None

    def resolve_many(self, names: list[str]) -> list[int]:
        """Resolve every name in order; return the IDs with repeats removed."""
        # dict.fromkeys keeps first-seen order while dropping duplicates.
        return list(dict.fromkeys(self.resolve_one(name) for name in names))

    def resolve_one(self, name: str) -> int:
        """Return the ID for *name*, creating the tag when allowed.

        Raises:
            ValueError: if the tag is unknown and ``create_missing`` is False.
        """
        key = normalize_tag_name(name)
        known = self._load_cache()
        if key not in known:
            if not self.create_missing:
                raise ValueError(f"tag not found: {name}")
            known[key] = self._create_tag(name.strip())
        return known[key]

    def _load_cache(self) -> dict[str, int]:
        """Fetch all tags once and index them by normalized name."""
        if self._cache is None:
            response = self.client.get("/tags", params={"sort": "name"})
            response.raise_for_status()
            table: dict[str, int] = {}
            for entry in response.json():
                label = entry.get("name") or entry.get("title")
                # Skip entries lacking an id or any usable name.
                if entry.get("id") and label:
                    table[normalize_tag_name(str(label))] = int(entry["id"])
            self._cache = table
        return self._cache

    def _create_tag(self, name: str) -> int:
        """Create tag *name* on the server and return its ID.

        A 409 response triggers a cache refresh. If the refreshed tag list
        contains the name, that existing ID is returned. Otherwise the 409 is
        raised by ``raise_for_status()``.
        """
        response = self.client.post("/tags", json={"name": name})
        if response.status_code == 409:
            self._cache = None
            refreshed = self._load_cache()
            existing = refreshed.get(normalize_tag_name(name))
            if existing is not None:
                return existing
        response.raise_for_status()
        created_id = response.json()["id"]
        print(f"[TAG] created {name} -> tag_id={created_id}")
        return int(created_id)
def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``.

    The automatic help option is disabled (``add_help=False``) and re-added
    manually so it, too, carries a localized help string. Connection and
    credential options default to the module-level DEFAULT_* values.
    """
    cli = argparse.ArgumentParser(description="通过管理端 API 批量上传地点。", add_help=False)
    # Rename the default group headings. These are private argparse
    # attributes, so this may need revisiting on future Python versions.
    cli._positionals.title = "位置参数"
    cli._optionals.title = "可选参数"
    add = cli.add_argument
    add("-h", "--help", action="help", help="显示帮助信息并退出。")
    add("input", type=Path, help="JSON 或 CSV 数据文件路径。")
    add("--base-url", default=DEFAULT_BASE_URL, help="API 基础地址,默认:%(default)s")
    add("--account", default=DEFAULT_ACCOUNT, help="管理员手机号或邮箱。")
    add("--password", default=DEFAULT_PASSWORD, help="管理员密码。")
    add("--creator-id", type=int, help="数据中未填写 creator_id 时使用的默认创建者用户 ID。")
    add(
        "--audit-status",
        default="approved",
        choices=["pending", "approved", "rejected", "deleted"],
        help="默认审核状态:pending=待审核,approved=已通过,rejected=已驳回,deleted=已删除。默认:%(default)s",
    )
    add("--timeout", type=float, default=30.0, help="单次请求超时时间,单位秒。默认:%(default)s")
    add("--dry-run", action="store_true", help="只校验并打印提交内容,不登录、不上传图片、不创建地点。")
    add("--stop-on-error", action="store_true", help="任意一条数据失败后立即停止。")
    return cli.parse_args()
def load_items(path: Path) -> list[dict[str, Any]]:
    """Read spot rows from a JSON or CSV file and normalize each one.

    JSON input may be a top-level list, or an object whose ``items`` key
    (falling back to ``spots`` when ``items`` is absent or empty) holds the
    list. CSV input is read with ``csv.DictReader``, so the header row
    supplies the field names. Relative image paths are resolved against the
    input file's directory.

    Raises:
        FileNotFoundError: if *path* does not exist.
        ValueError: for an unsupported extension, a JSON document without a
            list, a row that is not an object, or a row that fails
            normalization. Row errors are prefixed with the 1-based row
            number.
    """
    if not path.exists():
        raise FileNotFoundError(f"input file not found: {path}")
    suffix = path.suffix.lower()
    if suffix == ".json":
        # utf-8-sig also accepts files saved with a BOM (common with Windows
        # editors). With plain "utf-8", json.loads rejects them.
        data = json.loads(path.read_text(encoding="utf-8-sig"))
        if isinstance(data, dict):
            # An explicitly empty "items" list is valid (empty) input. Only
            # fall back to "spots" when "items" is missing or empty.
            items = data.get("items")
            data = items if items else data.get("spots", items)
        if not isinstance(data, list):
            raise ValueError("JSON input must be a list or an object with items/spots list")
        rows: list[Any] = data
    elif suffix == ".csv":
        with path.open("r", encoding="utf-8-sig", newline="") as f:
            rows = list(csv.DictReader(f))
    else:
        raise ValueError("unsupported input format, use .json or .csv")
    return _normalize_rows(rows, path.parent)


def _normalize_rows(rows: list[Any], input_dir: Path) -> list[dict[str, Any]]:
    """Normalize each raw row, tagging failures with the 1-based row number."""
    normalized: list[dict[str, Any]] = []
    for index, raw in enumerate(rows, start=1):
        if not isinstance(raw, dict):
            raise ValueError(f"item #{index}: expected an object, got {type(raw).__name__}")
        try:
            normalized.append(normalize_item(raw, input_dir))
        except ValueError as exc:
            raise ValueError(f"item #{index}: {exc}") from exc
    return normalized
def normalize_item(raw: dict[str, Any], input_dir: Path) -> dict[str, Any]:
    """Clean one raw JSON object or CSV row into the shape build_payload expects.

    - Blank values (None, "" or whitespace-only strings) are dropped, so the
      field is treated as absent.
    - longitude/latitude/price_min/price_max become floats; creator_id
      becomes an int.
    - ``is_free`` is parsed with ``parse_bool``.
    - ``tag_ids`` (or, if absent, ``tags``) is split into ``tag_names``; both
      source keys are removed.
    - ``image_urls`` and ``images`` are merged, in that order, into
      ``image_sources``. URLs are kept as-is; local paths are resolved
      against *input_dir*.

    Raises:
        ValueError: if a numeric field cannot be converted. The message names
            the field and the offending value.
    """
    # Hand-edited CSV often has cells containing only spaces. Treat them like
    # empty cells instead of failing later in float()/int().
    item = {
        key: value
        for key, value in raw.items()
        if value is not None and not (isinstance(value, str) and not value.strip())
    }
    numeric_fields = (
        ("longitude", float),
        ("latitude", float),
        ("price_min", float),
        ("price_max", float),
        ("creator_id", int),
    )
    for key, cast in numeric_fields:
        if key not in item:
            continue
        try:
            item[key] = cast(item[key])
        except (TypeError, ValueError) as exc:
            raise ValueError(f"invalid {key}: {item[key]!r}") from exc
    if "is_free" in item:
        item["is_free"] = parse_bool(item["is_free"])
    tag_values = item.get("tag_ids", item.get("tags", []))
    item["tag_names"] = parse_string_list(tag_values)
    item.pop("tags", None)
    item.pop("tag_ids", None)
    image_urls = parse_string_list(item.get("image_urls", []))
    images = parse_string_list(item.get("images", []))
    item["image_sources"] = [resolve_local_path(src, input_dir) for src in [*image_urls, *images]]
    return item
def parse_bool(value: Any) -> bool:
    """Interpret a JSON/CSV value as a boolean flag.

    Real booleans pass through unchanged. Any other value is stringified,
    stripped and lower-cased. It counts as True only if it is one of the
    accepted "yes" spellings (English, or Chinese 是/免费); everything else,
    including blank text, is False.
    """
    if isinstance(value, bool):
        return value
    # Blank text is deliberately not in this set, so a whitespace-only cell
    # reads as False rather than True.
    return str(value).strip().lower() in {"1", "true", "yes", "y", "on", "是", "免费"}
def parse_string_list(value: Any) -> list[str]:
    """Split a list-valued field into stripped, non-empty strings.

    Lists and tuples are used element-wise. Text is split on the first
    separator kind present, in priority order: semicolon, vertical bar,
    comma. The full-width forms (；｜，) produced by Chinese input methods
    count as the same separator kind as their ASCII counterparts.
    Otherwise "街拍，城市" would become a single tag name, which the tag
    resolver may then create on the server.
    """
    if value is None or value == "":
        return []
    if isinstance(value, (list, tuple)):
        return [str(v).strip() for v in value if str(v).strip()]
    text = str(value).strip()
    for ascii_sep, wide_sep in ((";", "；"), ("|", "｜"), (",", "，")):
        if ascii_sep in text or wide_sep in text:
            parts = text.replace(wide_sep, ascii_sep).split(ascii_sep)
            return [part.strip() for part in parts if part.strip()]
    return [text] if text else []
def normalize_tag_name(value: str) -> str:
    """Return the key used to match tag names: trimmed and lower-cased."""
    trimmed = value.strip()
    return trimmed.lower()
def resolve_local_path(source: str, input_dir: Path) -> str:
    """Return *source* unchanged if it is an http(s) URL, else a resolved path.

    A leading ``~`` is expanded to the user's home directory. Without this,
    "~/photos/a.jpg" would be joined under *input_dir* as a directory
    literally named "~". Any remaining relative path is taken relative to
    *input_dir*.
    """
    if is_url(source):
        return source
    path = Path(source).expanduser()
    if not path.is_absolute():
        path = input_dir / path
    return str(path)
def is_url(value: str) -> bool:
    """Return True only for http/https URLs that include a host part."""
    scheme, netloc, *_rest = urlparse(value)
    if scheme not in ("http", "https"):
        return False
    return netloc != ""
def login(client: httpx.Client, account: str, password: str) -> tuple[str, int]:
    """Authenticate via POST /admin/auth/login.

    Returns:
        ``(access_token, user_id)`` taken from the response body.
    """
    credentials = {"account": account, "password": password}
    reply = client.post("/admin/auth/login", json=credentials)
    reply.raise_for_status()
    body = reply.json()
    token = body["access_token"]
    user_id = int(body["user"]["id"])
    return token, user_id
def upload_image(client: httpx.Client, path: Path) -> str:
    """Upload a local image file via POST /upload/image.

    The content type is guessed from the file name, falling back to
    image/jpeg when it cannot be guessed.

    Returns:
        The ``url`` field of the upload response.

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    if not path.exists():
        raise FileNotFoundError(f"image file not found: {path}")
    guessed, _encoding = mimetypes.guess_type(path.name)
    content_type = guessed or "image/jpeg"
    with path.open("rb") as handle:
        upload = {"file": (path.name, handle, content_type)}
        response = client.post("/upload/image", files=upload)
    response.raise_for_status()
    return str(response.json()["url"])
def build_payload(
    client: httpx.Client,
    item: dict[str, Any],
    creator_id: int,
    audit_status: str,
    *,
    upload_files: bool,
    tag_resolver: TagResolver | None,
) -> dict[str, Any]:
    """Turn a normalized row into the JSON body for POST /admin/spots.

    Fields listed in SPOT_FIELDS are copied from *item*. creator_id,
    audit_status and is_free default to *creator_id*, *audit_status* and
    True when the row does not set them.

    Tags: with a *tag_resolver*, ``tag_names`` are resolved to IDs, which may
    create tags on the server. Without one (dry run), ``tag_ids`` is empty
    and the raw names are kept under the private ``_tag_names`` key for the
    caller to display.

    Images: URLs are passed through. Local paths are uploaded when
    *upload_files* is True; otherwise they are passed through as-is.

    Raises:
        ValueError: if a required field is missing. This is checked before
            any tag is created or image uploaded, so an invalid row leaves no
            server-side residue.
    """
    payload = {key: item[key] for key in SPOT_FIELDS if key in item}
    payload.setdefault("creator_id", creator_id)
    payload.setdefault("audit_status", audit_status)
    payload.setdefault("is_free", True)

    # Validate before the steps below, which have server-side side effects
    # (TagResolver may create tags; upload_image stores files).
    required = ("title", "city", "longitude", "latitude", "creator_id")
    missing = [key for key in required if key not in payload]
    if missing:
        raise ValueError(f"missing required fields: {', '.join(missing)}")

    tag_names = item.get("tag_names", [])
    if tag_resolver is None:
        payload["tag_ids"] = []
        if tag_names:
            payload["_tag_names"] = tag_names
    else:
        payload["tag_ids"] = tag_resolver.resolve_many(tag_names)

    image_urls: list[str] = []
    for source in item.get("image_sources", []):
        needs_upload = upload_files and not is_url(source)
        image_urls.append(upload_image(client, Path(source)) if needs_upload else source)
    payload["image_urls"] = image_urls
    return payload
def create_spot(client: httpx.Client, payload: dict[str, Any]) -> int:
    """POST *payload* to /admin/spots and return the new spot's ``id``."""
    reply = client.post("/admin/spots", json=payload)
    reply.raise_for_status()
    created = reply.json()
    return int(created["id"])
def run() -> int:
    """Command-line entry point: upload every input row and print a summary.

    Rows are processed independently. A failing row is reported with a
    ``[FAIL]`` line and, unless ``--stop-on-error`` is given, the remaining
    rows are still attempted. In ``--dry-run`` mode no login happens and the
    payloads are printed instead of sent.

    Returns:
        Process exit status: 0 if every processed row succeeded; 1 if the
        input could not be loaded, login failed, or any row failed.
    """
    args = parse_args()
    try:
        items = load_items(args.input)
    except (OSError, ValueError, csv.Error) as exc:
        # Report bad input files as a one-line message instead of a traceback.
        print(f"[FAIL] cannot load {args.input}: {format_error(exc)}")
        return 1

    results: list[UploadResult] = []
    with httpx.Client(base_url=args.base_url.rstrip("/"), timeout=args.timeout) as client:
        current_admin_id = args.creator_id
        if not args.dry_run:
            try:
                token, admin_id = login(client, args.account, args.password)
            except (httpx.HTTPError, KeyError, TypeError, ValueError) as exc:
                print(f"[FAIL] admin login failed: {format_error(exc)}")
                return 1
            # Fall back to the logged-in admin only when --creator-id was
            # omitted. `or` would also discard an explicit --creator-id 0,
            # unlike the dry-run branch below.
            if current_admin_id is None:
                current_admin_id = admin_id
            client.headers.update({"Authorization": f"Bearer {token}"})
        elif current_admin_id is None:
            # Dry run without --creator-id: placeholder shown in the printed payload.
            current_admin_id = 0
        # Dry runs must not touch the server, so tag names stay unresolved.
        tag_resolver = None if args.dry_run else TagResolver(client)
        for index, item in enumerate(items, start=1):
            title = str(item.get("title", f"row-{index}"))
            try:
                payload = build_payload(
                    client,
                    item,
                    current_admin_id,
                    args.audit_status,
                    upload_files=not args.dry_run,
                    tag_resolver=tag_resolver,
                )
                if args.dry_run:
                    # Show the unresolved tag names where the IDs would go.
                    tag_names = payload.pop("_tag_names", None)
                    if tag_names:
                        payload["tag_ids"] = tag_names
                    print(json.dumps(payload, ensure_ascii=False, indent=2))
                    results.append(UploadResult(index=index, title=title, ok=True))
                    continue
                spot_id = create_spot(client, payload)
                print(f"[OK] #{index} {title} -> spot_id={spot_id}")
                results.append(UploadResult(index=index, title=title, ok=True, spot_id=spot_id))
            except Exception as exc:  # per-row boundary: record the failure and move on
                message = format_error(exc)
                print(f"[FAIL] #{index} {title}: {message}")
                results.append(UploadResult(index=index, title=title, ok=False, error=message))
                if args.stop_on_error:
                    break
    ok_count = sum(1 for result in results if result.ok)
    fail_count = len(results) - ok_count
    print(f"Done. total={len(results)} ok={ok_count} failed={fail_count}")
    return 1 if fail_count else 0
def format_error(exc: Exception) -> str:
    """Render *exc* as a message for the ``[FAIL]`` log lines.

    HTTP status errors include the request method and URL, which shows
    whether a tag, image or spot request failed, followed by the response
    body. Other exceptions use ``str(exc)``, falling back to the exception
    class name when that message is empty.
    """
    if isinstance(exc, httpx.HTTPStatusError):
        request = exc.request
        return f"HTTP {exc.response.status_code} {request.method} {request.url}: {exc.response.text}"
    return str(exc) or type(exc).__name__
if __name__ == "__main__":
    # run() returns the process exit status; SystemExit hands it to the shell.
    exit_status = run()
    raise SystemExit(exit_status)