345 lines
12 KiB
Python
345 lines
12 KiB
Python
"""
|
||
Bulk-upload spots through the admin API.

JSON input structure:
[
  {
    "title": "示例地点",
    "city": "上海",
    "longitude": 121.4737,
    "latitude": 31.2304,
    "description": "地点描述",
    "transport": "地铁可达",
    "best_time": "傍晚",
    "difficulty": "低",
    "is_free": true,
    "audit_status": "approved",
    "tag_ids": ["街拍", "城市"],
    "images": ["./images/a.jpg", "https://example.com/b.jpg"]
  }
]

CSV input uses the same field names. A list field may be separated with
semicolons, commas, or pipes (use one separator kind per field),
for example: images=./a.jpg;./b.jpg, tag_ids=街拍,城市.
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import csv
|
||
import json
|
||
import mimetypes
|
||
import os
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
from typing import Any
|
||
from urllib.parse import urlparse
|
||
|
||
import httpx
|
||
|
||
|
||
# Connection defaults. Each value is read from an environment variable when
# set and is used as the default of the matching command-line option.
# NOTE(review): the account/password fallbacks are hard-coded credentials,
# presumably intended for a local development server -- confirm they are
# never valid anywhere else.
DEFAULT_BASE_URL = os.environ.get("CIYUAN_API_BASE_URL", "http://localhost:8000/api/v1")
DEFAULT_ACCOUNT = os.environ.get("CIYUAN_ADMIN_ACCOUNT", "13900000001")
DEFAULT_PASSWORD = os.environ.get("CIYUAN_ADMIN_PASSWORD", "admin123456")
|
||
|
||
# Whitelist of item keys that build_payload copies into the request body;
# item keys outside this set are not forwarded.
SPOT_FIELDS = {
    # identity and location
    "title", "city", "longitude", "latitude",
    # descriptive text
    "description", "transport", "best_time", "difficulty",
    # pricing
    "is_free", "price_min", "price_max",
    # review state and ownership
    "audit_status", "reject_reason", "creator_id",
    # related images and tags
    "image_urls", "tag_ids",
}
|
||
|
||
|
||
@dataclass
class UploadResult:
    """Outcome of processing a single input item.

    Attributes:
        index: 1-based position of the item in the input file.
        title: Item title, or a ``row-N`` placeholder when the item has none.
        ok: Whether the item was processed without an error.
        spot_id: Id of the created spot; stays ``None`` for dry runs and failures.
        error: Formatted error message; stays ``None`` on success.
    """

    # Always supplied by the caller.
    index: int
    title: str
    ok: bool

    # Filled in only when the corresponding information exists.
    spot_id: int | None = None
    error: str | None = None
|
||
|
||
|
||
class TagResolver:
    """Translate tag names into tag ids, optionally creating unknown tags.

    The name -> id table is fetched from ``GET /tags`` the first time it is
    needed and then kept in memory for later lookups.
    """

    def __init__(self, client: httpx.Client, *, create_missing: bool = True) -> None:
        """Store the HTTP client and the policy for unknown tag names.

        Args:
            client: Client used for all ``/tags`` requests.
            create_missing: When false, an unknown name raises ``ValueError``
                instead of being created through the API.
        """
        self.client = client
        self.create_missing = create_missing
        # Filled lazily by _load_cache(); None means "not fetched yet".
        self._cache: dict[str, int] | None = None

    def resolve_many(self, names: list[str]) -> list[int]:
        """Resolve every name and return the ids, de-duplicated in first-seen order."""
        # dict.fromkeys keeps insertion order, so duplicates collapse onto
        # their first occurrence.
        return list(dict.fromkeys(self.resolve_one(name) for name in names))

    def resolve_one(self, name: str) -> int:
        """Return the id for ``name``, creating the tag when allowed.

        Raises:
            ValueError: The tag is unknown and ``create_missing`` is false.
        """
        key = normalize_tag_name(name)
        known = self._load_cache()
        existing = known.get(key)
        if existing is not None:
            return existing
        if not self.create_missing:
            raise ValueError(f"tag not found: {name}")
        new_id = self._create_tag(name.strip())
        known[key] = new_id
        return new_id

    def _load_cache(self) -> dict[str, int]:
        """Return the normalized-name -> id table, fetching it on first use."""
        if self._cache is not None:
            return self._cache
        response = self.client.get("/tags", params={"sort": "name"})
        response.raise_for_status()
        table: dict[str, int] = {}
        for entry in response.json():
            # Entries without an id or without a name/title are skipped.
            if not entry.get("id"):
                continue
            label = entry.get("name") or entry.get("title")
            if not label:
                continue
            table[normalize_tag_name(str(label))] = int(entry["id"])
        self._cache = table
        return table

    def _create_tag(self, name: str) -> int:
        """Create a tag through ``POST /tags`` and return its id."""
        response = self.client.post("/tags", json={"name": name})
        if response.status_code == 409:
            # Presumably the tag already exists (e.g. created concurrently);
            # refresh the table and reuse the existing id when it is there.
            self._cache = None
            refreshed = self._load_cache()
            existing = refreshed.get(normalize_tag_name(name))
            if existing is not None:
                return existing
        response.raise_for_status()
        created = response.json()
        print(f"[TAG] created {name} -> tag_id={created['id']}")
        return int(created["id"])
|
||
|
||
|
||
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Argument list to parse instead of the process arguments.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``,
            which is what a plain ``parse_args()`` call always did.

    Returns:
        The parsed namespace with ``input``, ``base_url``, ``account``,
        ``password``, ``creator_id``, ``audit_status``, ``timeout``,
        ``dry_run`` and ``stop_on_error``.
    """
    # add_help=False so the -h/--help option can be re-added below with a
    # localized help text.
    parser = argparse.ArgumentParser(description="通过管理端 API 批量上传地点。", add_help=False)
    # These are private argparse attributes; they are used here to localize
    # the section headings of the generated help output.
    parser._positionals.title = "位置参数"
    parser._optionals.title = "可选参数"
    parser.add_argument("-h", "--help", action="help", help="显示帮助信息并退出。")
    parser.add_argument("input", type=Path, help="JSON 或 CSV 数据文件路径。")
    parser.add_argument("--base-url", default=DEFAULT_BASE_URL, help="API 基础地址,默认:%(default)s")
    parser.add_argument("--account", default=DEFAULT_ACCOUNT, help="管理员手机号或邮箱。")
    parser.add_argument("--password", default=DEFAULT_PASSWORD, help="管理员密码。")
    parser.add_argument("--creator-id", type=int, help="数据中未填写 creator_id 时使用的默认创建者用户 ID。")
    parser.add_argument(
        "--audit-status",
        default="approved",
        choices=["pending", "approved", "rejected", "deleted"],
        help="默认审核状态:pending=待审核,approved=已通过,rejected=已驳回,deleted=已删除。默认:%(default)s",
    )
    parser.add_argument("--timeout", type=float, default=30.0, help="单次请求超时时间,单位秒。默认:%(default)s")
    parser.add_argument("--dry-run", action="store_true", help="只校验并打印提交内容,不登录、不上传图片、不创建地点。")
    parser.add_argument("--stop-on-error", action="store_true", help="任意一条数据失败后立即停止。")
    return parser.parse_args(argv)
|
||
|
||
|
||
def load_items(path: Path) -> list[dict[str, Any]]:
    """Read a JSON or CSV input file and return its normalized items.

    A JSON file may hold either a list of objects or an object wrapping that
    list under ``items`` or ``spots``. A CSV file is read with a header row.
    Every entry is passed through ``normalize_item`` together with the
    directory of the input file.

    Args:
        path: Input file; the suffix (case-insensitive) selects the format.

    Returns:
        The normalized items, possibly an empty list.

    Raises:
        FileNotFoundError: ``path`` does not exist.
        ValueError: The suffix is neither ``.json`` nor ``.csv``, the JSON
            document has no usable list, or a JSON entry is not an object.
    """
    if not path.exists():
        raise FileNotFoundError(f"input file not found: {path}")

    suffix = path.suffix.lower()
    if suffix == ".json":
        data = json.loads(path.read_text(encoding="utf-8"))
        if isinstance(data, dict):
            items, spots = data.get("items"), data.get("spots")
            data = items or spots
            if data is None:
                # ``items`` was falsy and ``spots`` absent: keep an explicit
                # empty ``items`` list instead of rejecting the whole file.
                data = items
        if not isinstance(data, list):
            raise ValueError("JSON input must be a list or an object with items/spots list")
        # Fail with a clear message rather than an AttributeError raised
        # from inside normalize_item.
        for position, entry in enumerate(data, start=1):
            if not isinstance(entry, dict):
                raise ValueError(
                    f"JSON item #{position} must be an object, got {type(entry).__name__}"
                )
        return [normalize_item(entry, path.parent) for entry in data]

    if suffix == ".csv":
        # utf-8-sig drops a leading BOM if present; newline="" is what the
        # csv module expects for files it reads.
        with path.open("r", encoding="utf-8-sig", newline="") as f:
            return [normalize_item(row, path.parent) for row in csv.DictReader(f)]

    raise ValueError("unsupported input format, use .json or .csv")
|
||
|
||
|
||
def normalize_item(raw: dict[str, Any], input_dir: Path) -> dict[str, Any]:
    """Convert one raw JSON/CSV record into the internal item shape.

    Empty-string and ``None`` values are dropped, numeric and boolean fields
    are converted, tags are collected under ``tag_names`` and every image
    reference under ``image_sources``.

    Args:
        raw: Record as read from the input file.
        input_dir: Directory that relative image paths are joined onto.

    Returns:
        A new dict; ``raw`` itself is left untouched.
    """
    blanks = ("", None)
    item: dict[str, Any] = {}
    for field, value in raw.items():
        if value not in blanks:
            item[field] = value

    # Converters are applied in this order, only to fields that are present.
    converters = {
        "longitude": float,
        "latitude": float,
        "price_min": float,
        "price_max": float,
        "creator_id": int,
        "is_free": parse_bool,
    }
    for field, convert in converters.items():
        if field in item:
            item[field] = convert(item[field])

    # "tag_ids" wins over "tags"; both keys are removed and replaced by the
    # parsed name list.
    fallback_tags = item.pop("tags", [])
    raw_tags = item.pop("tag_ids", fallback_tags)
    item["tag_names"] = parse_string_list(raw_tags)

    # Entries from "image_urls" come first, followed by those from "images".
    sources = parse_string_list(item.get("image_urls", []))
    sources += parse_string_list(item.get("images", []))
    item["image_sources"] = [resolve_local_path(source, input_dir) for source in sources]

    return item
|
||
|
||
|
||
def parse_bool(value: Any) -> bool:
    """Interpret a JSON/CSV value as a boolean.

    Args:
        value: A bool, a number, or anything whose string form may be one of
            the accepted "true" spellings.

    Returns:
        ``value`` itself for bools; ``value != 0`` for other numbers; for
        everything else ``True`` only when the trimmed, lower-cased string
        form is a recognized "true" token. Unrecognized text yields ``False``.
    """
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        # Numbers used to go through str(), so 1.0 became "1.0" and was
        # wrongly read as False.
        return value != 0
    return str(value).strip().lower() in {"1", "true", "yes", "y", "on", "是", "免费"}
|
||
|
||
|
||
def parse_string_list(value: Any) -> list[str]:
    """Turn a list-like field into a list of non-empty, trimmed strings.

    Args:
        value: ``None``, a list/tuple of values, or a scalar whose string
            form may contain several entries separated by ``;``, ``|`` or
            ``,``.

    Returns:
        The entries in their original order. ``None`` entries and entries
        that are empty after trimming are dropped.
    """
    if value is None or value == "":
        return []
    if isinstance(value, (list, tuple)):
        # Skip None explicitly: str(None) would otherwise add the literal
        # text "None" as an entry.
        trimmed = (str(entry).strip() for entry in value if entry is not None)
        return [entry for entry in trimmed if entry]
    text = str(value).strip()
    # Only the first separator found (in this priority order) is used, so a
    # ";"-separated list may still contain commas inside its entries.
    for sep in (";", "|", ","):
        if sep in text:
            return [part.strip() for part in text.split(sep) if part.strip()]
    return [text] if text else []
|
||
|
||
|
||
def normalize_tag_name(value: str) -> str:
    """Return the lookup key for a tag name: trimmed and lower-cased."""
    trimmed = value.strip()
    return trimmed.lower()
|
||
|
||
|
||
def resolve_local_path(source: str, input_dir: Path) -> str:
    """Return ``source`` unchanged for URLs, otherwise as a filesystem path string.

    Args:
        source: An http(s) URL or a local file path.
        input_dir: Directory that a relative path is joined onto.

    Returns:
        The URL as given, or the path with relative paths anchored at
        ``input_dir``.
    """
    if is_url(source):
        return source
    candidate = Path(source)
    anchored = candidate if candidate.is_absolute() else input_dir / candidate
    return str(anchored)
|
||
|
||
|
||
def is_url(value: str) -> bool:
    """Return True when ``value`` is an http or https URL with a host part."""
    parts = urlparse(value)
    if parts.scheme not in ("http", "https"):
        return False
    return bool(parts.netloc)
|
||
|
||
|
||
def login(client: httpx.Client, account: str, password: str) -> tuple[str, int]:
    """Authenticate against ``POST /admin/auth/login``.

    Args:
        client: HTTP client used for the request.
        account: Admin account identifier sent as ``account``.
        password: Admin password sent as ``password``.

    Returns:
        ``(access_token, user_id)`` taken from the JSON response.

    Raises:
        httpx.HTTPStatusError: The server answered with an error status.
    """
    credentials = {"account": account, "password": password}
    response = client.post("/admin/auth/login", json=credentials)
    response.raise_for_status()
    body = response.json()
    access_token = body["access_token"]
    user_id = int(body["user"]["id"])
    return access_token, user_id
|
||
|
||
|
||
def upload_image(client: httpx.Client, path: Path) -> str:
    """Upload a local image through ``POST /upload/image`` and return its URL.

    Args:
        client: HTTP client used for the request.
        path: Image file to send as the multipart field ``file``.

    Returns:
        The ``url`` value of the JSON response, as a string.

    Raises:
        FileNotFoundError: ``path`` does not exist.
        httpx.HTTPStatusError: The server answered with an error status.
    """
    if not path.exists():
        raise FileNotFoundError(f"image file not found: {path}")

    # Fall back to JPEG when the type cannot be guessed from the file name.
    guessed_type, _encoding = mimetypes.guess_type(path.name)
    content_type = guessed_type or "image/jpeg"

    with path.open("rb") as handle:
        multipart = {"file": (path.name, handle, content_type)}
        response = client.post("/upload/image", files=multipart)
    response.raise_for_status()
    uploaded = response.json()
    return str(uploaded["url"])
|
||
|
||
|
||
def build_payload(
    client: httpx.Client,
    item: dict[str, Any],
    creator_id: int,
    audit_status: str,
    *,
    upload_files: bool,
    tag_resolver: TagResolver | None,
) -> dict[str, Any]:
    """Build the request body for one spot.

    Args:
        client: HTTP client used to upload local image files.
        item: Normalized item (see ``normalize_item``).
        creator_id: Used when the item has no ``creator_id`` of its own.
        audit_status: Used when the item has no ``audit_status`` of its own.
        upload_files: When false, local image paths are kept as they are
            instead of being uploaded.
        tag_resolver: Resolves tag names into ids. When ``None``, ``tag_ids``
            is left empty and the names, if any, are returned under the
            extra key ``_tag_names`` for the caller to handle.

    Returns:
        The payload dict, always containing ``creator_id``, ``audit_status``,
        ``is_free``, ``tag_ids`` and ``image_urls``.

    Raises:
        ValueError: A required field is missing from ``item``.
    """
    # Validate before any network call so that an incomplete item never
    # leaves freshly created tags or uploaded images behind. ``creator_id``
    # always receives a default below, so only these fields can be missing.
    missing = [key for key in ("title", "city", "longitude", "latitude") if key not in item]
    if missing:
        raise ValueError(f"missing required fields: {', '.join(missing)}")

    # Follow the item's own key order so the payload (and the dry-run
    # printout) is stable between runs instead of depending on set ordering.
    payload = {key: value for key, value in item.items() if key in SPOT_FIELDS}
    payload.setdefault("creator_id", creator_id)
    payload.setdefault("audit_status", audit_status)
    payload.setdefault("is_free", True)

    tag_names = item.get("tag_names", [])
    if tag_resolver is None:
        payload["tag_ids"] = []
        if tag_names:
            payload["_tag_names"] = tag_names
    else:
        payload["tag_ids"] = tag_resolver.resolve_many(tag_names)

    image_urls: list[str] = []
    for source in item.get("image_sources", []):
        if is_url(source) or not upload_files:
            # Remote URLs are passed through; so are local paths when
            # uploading is disabled.
            image_urls.append(source)
        else:
            image_urls.append(upload_image(client, Path(source)))
    payload["image_urls"] = image_urls

    return payload
|
||
|
||
|
||
def create_spot(client: httpx.Client, payload: dict[str, Any]) -> int:
    """Create a spot through ``POST /admin/spots`` and return its id.

    Raises:
        httpx.HTTPStatusError: The server answered with an error status.
    """
    response = client.post("/admin/spots", json=payload)
    response.raise_for_status()
    created = response.json()
    return int(created["id"])
|
||
|
||
|
||
def run() -> int:
    """Run the uploader end to end.

    Parses the command line, loads the input file and processes each item.
    In a normal run it logs in, builds each payload and creates the spot.
    In a dry run it only builds each payload and prints it as JSON. A status
    line is printed per item and a summary at the end.

    Returns:
        ``1`` when at least one processed item failed, otherwise ``0``.
    """
    args = parse_args()
    items = load_items(args.input)
    dry_run = args.dry_run
    results: list[UploadResult] = []

    with httpx.Client(base_url=args.base_url.rstrip("/"), timeout=args.timeout) as http:
        creator_id = args.creator_id
        if dry_run:
            # No login and no tag lookups; 0 stands in for the creator id
            # when none was given.
            resolver = None
            if creator_id is None:
                creator_id = 0
        else:
            access_token, admin_id = login(http, args.account, args.password)
            # Without --creator-id the logged-in admin becomes the creator.
            creator_id = creator_id or admin_id
            http.headers.update({"Authorization": f"Bearer {access_token}"})
            resolver = TagResolver(http)

        for position, entry in enumerate(items, start=1):
            label = str(entry.get("title", f"row-{position}"))
            try:
                payload = build_payload(
                    http,
                    entry,
                    creator_id,
                    args.audit_status,
                    upload_files=not dry_run,
                    tag_resolver=resolver,
                )
                if dry_run:
                    # Show the unresolved tag names in place of ids.
                    preview_tags = payload.pop("_tag_names", None)
                    if preview_tags:
                        payload["tag_ids"] = preview_tags
                    print(json.dumps(payload, ensure_ascii=False, indent=2))
                    outcome = UploadResult(index=position, title=label, ok=True)
                else:
                    spot_id = create_spot(http, payload)
                    print(f"[OK] #{position} {label} -> spot_id={spot_id}")
                    outcome = UploadResult(index=position, title=label, ok=True, spot_id=spot_id)
                results.append(outcome)
            except Exception as exc:
                # Per-item boundary: record the failure and carry on unless
                # --stop-on-error was given.
                message = format_error(exc)
                print(f"[FAIL] #{position} {label}: {message}")
                results.append(UploadResult(index=position, title=label, ok=False, error=message))
                if args.stop_on_error:
                    break

    failures = [outcome for outcome in results if not outcome.ok]
    fail_count = len(failures)
    ok_count = len(results) - fail_count
    print(f"Done. total={len(results)} ok={ok_count} failed={fail_count}")
    return 1 if fail_count else 0
|
||
|
||
|
||
def format_error(exc: Exception) -> str:
    """Return a one-line description of ``exc`` for the status output.

    HTTP status errors are rendered as ``HTTP <status>: <response body>``;
    every other exception is rendered through ``str()``.
    """
    if not isinstance(exc, httpx.HTTPStatusError):
        return str(exc)
    response = exc.response
    body = response.text
    return f"HTTP {response.status_code}: {body}"
|
||
|
||
|
||
# Script entry point: the return value of run() becomes the process exit status.
if __name__ == "__main__":
    exit_code = run()
    raise SystemExit(exit_code)
|