diff --git a/server/app/api/v1/endpoints/tags.py b/server/app/api/v1/endpoints/tags.py index 022d285..438d896 100644 --- a/server/app/api/v1/endpoints/tags.py +++ b/server/app/api/v1/endpoints/tags.py @@ -7,7 +7,7 @@ from app.core.deps import get_current_active_user, get_db from app.models.spot import Spot from app.models.tag import SpotTag, Tag from app.models.user import User -from app.schemas.tag import TagOut +from app.schemas.tag import TagCreate, TagOut router = APIRouter() @@ -16,6 +16,14 @@ class TagAttach(BaseModel): tag_id: int +def _assert_admin_role(user: User) -> None: + if user.role not in ("admin", "moderator"): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Admin permission required", + ) + + @router.get("/tags", response_model=list[TagOut]) async def list_tags( sort: str = Query(default="hot", regex="^(hot|name)$"), @@ -31,6 +39,28 @@ async def list_tags( return result.scalars().all() +@router.post("/tags", response_model=TagOut, status_code=status.HTTP_201_CREATED) +async def create_tag( + payload: TagCreate, + current_user: User = Depends(get_current_active_user), + db: AsyncSession = Depends(get_db), +): + _assert_admin_role(current_user) + name = payload.name.strip() + if not name: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Tag name is required") + + existing = await db.execute(select(Tag).where(Tag.name == name)) + if existing.scalar_one_or_none(): + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Tag already exists") + + tag = Tag(name=name, category=payload.category, is_active=True) + db.add(tag) + await db.commit() + await db.refresh(tag) + return tag + + @router.post("/spots/{spot_id}/tags", status_code=status.HTTP_201_CREATED) async def add_tag_to_spot( spot_id: int, diff --git a/server/scripts/bulk_upload_spots.py b/server/scripts/bulk_upload_spots.py new file mode 100644 index 0000000..27f0141 --- /dev/null +++ b/server/scripts/bulk_upload_spots.py @@ -0,0 +1,336 @@ +""" +Bulk upload spots through the admin API. + +Input JSON shape: +[ + { + "title": "示例地点", + "city": "上海", + "longitude": 121.4737, + "latitude": 31.2304, + "description": "地点描述", + "transport": "地铁可达", + "best_time": "傍晚", + "difficulty": "低", + "is_free": true, + "audit_status": "approved", + "tag_ids": ["街拍", "城市"], + "images": ["./images/a.jpg", "https://example.com/b.jpg"] + } +] + +CSV uses the same field names. List fields can be separated by semicolon, comma, +or pipe, for example: images=./a.jpg;./b.jpg and tag_ids=街拍,城市. +""" + +from __future__ import annotations + +import argparse +import csv +import json +import mimetypes +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Any +from urllib.parse import urlparse + +import httpx + + +DEFAULT_BASE_URL = os.getenv("CIYUAN_API_BASE_URL", "http://localhost:8000/api/v1") +DEFAULT_ACCOUNT = os.getenv("CIYUAN_ADMIN_ACCOUNT", "13900000001") +DEFAULT_PASSWORD = os.getenv("CIYUAN_ADMIN_PASSWORD", "admin123456") + +SPOT_FIELDS = { + "title", + "city", + "longitude", + "latitude", + "description", + "transport", + "best_time", + "difficulty", + "is_free", + "price_min", + "price_max", + "audit_status", + "reject_reason", + "creator_id", + "image_urls", + "tag_ids", +} + + +@dataclass +class UploadResult: + index: int + title: str + ok: bool + spot_id: int | None = None + error: str | None = None + + +class TagResolver: + def __init__(self, client: httpx.Client, *, create_missing: bool = True) -> None: + self.client = client + self.create_missing = create_missing + self._cache: dict[str, int] | None = None + + def resolve_many(self, names: list[str]) -> list[int]: + tag_ids: list[int] = [] + for name in names: + tag_ids.append(self.resolve_one(name)) + return list(dict.fromkeys(tag_ids)) + + def resolve_one(self, name: str) -> int: + normalized = normalize_tag_name(name) + cache = self._load_cache() + if normalized in cache: + return cache[normalized] + if not self.create_missing: + raise ValueError(f"tag not found: {name}") + tag_id = self._create_tag(name.strip()) + cache[normalized] = tag_id + return tag_id + + def _load_cache(self) -> dict[str, int]: + if self._cache is not None: + return self._cache + response = self.client.get("/tags", params={"sort": "name"}) + response.raise_for_status() + self._cache = { + normalize_tag_name(str(item.get("name") or item.get("title") or "")): int(item["id"]) + for item in response.json() + if item.get("id") and (item.get("name") or item.get("title")) + } + return self._cache + + def _create_tag(self, name: str) -> int: + response = self.client.post("/tags", json={"name": name}) + if response.status_code == 409: + self._cache = None + cache = self._load_cache() + normalized = normalize_tag_name(name) + if normalized in cache: + return cache[normalized] + response.raise_for_status() + data = response.json() + print(f"[TAG] created {name} -> tag_id={data['id']}") + return int(data["id"]) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Bulk upload spots via admin API.") + parser.add_argument("input", type=Path, help="JSON or CSV data file.") + parser.add_argument("--base-url", default=DEFAULT_BASE_URL, help="API base URL, default: %(default)s") + parser.add_argument("--account", default=DEFAULT_ACCOUNT, help="Admin phone/email.") + parser.add_argument("--password", default=DEFAULT_PASSWORD, help="Admin password.") + parser.add_argument("--creator-id", type=int, help="Default creator_id when an item omits it.") + parser.add_argument("--audit-status", default="approved", choices=["pending", "approved", "rejected", "deleted"]) + parser.add_argument("--timeout", type=float, default=30.0) + parser.add_argument("--dry-run", action="store_true", help="Validate and print payloads without sending requests.") + parser.add_argument("--stop-on-error", action="store_true", help="Abort on the first failed item.") + return parser.parse_args() + + +def load_items(path: Path) -> list[dict[str, Any]]: + if not path.exists(): + raise FileNotFoundError(f"input file not found: {path}") + suffix = path.suffix.lower() + if suffix == ".json": + data = json.loads(path.read_text(encoding="utf-8")) + if isinstance(data, dict): + data = data.get("items") or data.get("spots") + if not isinstance(data, list): + raise ValueError("JSON input must be a list or an object with items/spots list") + return [normalize_item(item, path.parent) for item in data] + if suffix == ".csv": + with path.open("r", encoding="utf-8-sig", newline="") as f: + return [normalize_item(row, path.parent) for row in csv.DictReader(f)] + raise ValueError("unsupported input format, use .json or .csv") + + +def normalize_item(raw: dict[str, Any], input_dir: Path) -> dict[str, Any]: + item = {k: v for k, v in raw.items() if v not in ("", None)} + + for key in ("longitude", "latitude", "price_min", "price_max"): + if key in item: + item[key] = float(item[key]) + + for key in ("creator_id",): + if key in item: + item[key] = int(item[key]) + + if "is_free" in item: + item["is_free"] = parse_bool(item["is_free"]) + + tag_values = item.get("tag_ids", item.get("tags", [])) + item["tag_names"] = parse_string_list(tag_values) + item.pop("tags", None) + item.pop("tag_ids", None) + + image_urls = parse_string_list(item.get("image_urls", [])) + images = parse_string_list(item.get("images", [])) + item["image_sources"] = [resolve_local_path(src, input_dir) for src in [*image_urls, *images]] + + return item + + +def parse_bool(value: Any) -> bool: + if isinstance(value, bool): + return value + return str(value).strip().lower() in {"1", "true", "yes", "y", "on", "是", "免费"} + + +def parse_string_list(value: Any) -> list[str]: + if value is None or value == "": + return [] + if isinstance(value, list): + return [str(v).strip() for v in value if str(v).strip()] + text = str(value).strip() + for sep in (";", "|", ","): + if sep in text: + return [part.strip() for part in text.split(sep) if part.strip()] + return [text] if text else [] + + +def normalize_tag_name(value: str) -> str: + return value.strip().lower() + + +def resolve_local_path(source: str, input_dir: Path) -> str: + if is_url(source): + return source + path = Path(source) + if not path.is_absolute(): + path = input_dir / path + return str(path) + + +def is_url(value: str) -> bool: + parsed = urlparse(value) + return parsed.scheme in {"http", "https"} and bool(parsed.netloc) + + +def login(client: httpx.Client, account: str, password: str) -> tuple[str, int]: + response = client.post("/admin/auth/login", json={"account": account, "password": password}) + response.raise_for_status() + data = response.json() + return data["access_token"], int(data["user"]["id"]) + + +def upload_image(client: httpx.Client, path: Path) -> str: + if not path.exists(): + raise FileNotFoundError(f"image file not found: {path}") + content_type = mimetypes.guess_type(path.name)[0] or "image/jpeg" + with path.open("rb") as f: + response = client.post( + "/upload/image", + files={"file": (path.name, f, content_type)}, + ) + response.raise_for_status() + return str(response.json()["url"]) + + +def build_payload( + client: httpx.Client, + item: dict[str, Any], + creator_id: int, + audit_status: str, + *, + upload_files: bool, + tag_resolver: TagResolver | None, +) -> dict[str, Any]: + payload = {key: item[key] for key in SPOT_FIELDS if key in item} + payload.setdefault("creator_id", creator_id) + payload.setdefault("audit_status", audit_status) + payload.setdefault("is_free", True) + if tag_resolver is None: + payload["tag_ids"] = [] + if item.get("tag_names"): + payload["_tag_names"] = item["tag_names"] + else: + payload["tag_ids"] = tag_resolver.resolve_many(item.get("tag_names", [])) + + image_urls: list[str] = [] + for source in item.get("image_sources", []): + if is_url(source): + image_urls.append(source) + elif not upload_files: + image_urls.append(source) + else: + image_urls.append(upload_image(client, Path(source))) + payload["image_urls"] = image_urls + + missing = [key for key in ("title", "city", "longitude", "latitude", "creator_id") if key not in payload] + if missing: + raise ValueError(f"missing required fields: {', '.join(missing)}") + return payload + + +def create_spot(client: httpx.Client, payload: dict[str, Any]) -> int: + response = client.post("/admin/spots", json=payload) + response.raise_for_status() + return int(response.json()["id"]) + + +def run() -> int: + args = parse_args() + items = load_items(args.input) + results: list[UploadResult] = [] + + with httpx.Client(base_url=args.base_url.rstrip("/"), timeout=args.timeout) as client: + token: str | None = None + current_admin_id = args.creator_id + if not args.dry_run: + token, admin_id = login(client, args.account, args.password) + current_admin_id = current_admin_id or admin_id + client.headers.update({"Authorization": f"Bearer {token}"}) + elif current_admin_id is None: + current_admin_id = 0 + + tag_resolver = None if args.dry_run else TagResolver(client) + for index, item in enumerate(items, start=1): + title = str(item.get("title", f"row-{index}")) + try: + payload = build_payload( + client, + item, + current_admin_id, + args.audit_status, + upload_files=not args.dry_run, + tag_resolver=tag_resolver, + ) + if args.dry_run: + tag_names = payload.pop("_tag_names", None) + if tag_names: + payload["tag_ids"] = tag_names + print(json.dumps(payload, ensure_ascii=False, indent=2)) + results.append(UploadResult(index=index, title=title, ok=True)) + continue + spot_id = create_spot(client, payload) + print(f"[OK] #{index} {title} -> spot_id={spot_id}") + results.append(UploadResult(index=index, title=title, ok=True, spot_id=spot_id)) + except Exception as exc: + message = format_error(exc) + print(f"[FAIL] #{index} {title}: {message}") + results.append(UploadResult(index=index, title=title, ok=False, error=message)) + if args.stop_on_error: + break + + ok_count = sum(1 for item in results if item.ok) + fail_count = len(results) - ok_count + print(f"Done. total={len(results)} ok={ok_count} failed={fail_count}") + return 1 if fail_count else 0 + + +def format_error(exc: Exception) -> str: + if isinstance(exc, httpx.HTTPStatusError): + body = exc.response.text + return f"HTTP {exc.response.status_code}: {body}" + return str(exc) + + +if __name__ == "__main__": + raise SystemExit(run()) diff --git a/server/scripts/sample_spots.json b/server/scripts/sample_spots.json new file mode 100644 index 0000000..187e20b --- /dev/null +++ b/server/scripts/sample_spots.json @@ -0,0 +1,36 @@ +[ + { + "title": "示例取景地", + "city": "上海", + "longitude": 121.4737, + "latitude": 31.2304, + "description": "适合街拍与城市感构图。", + "transport": "地铁到站后步行约 10 分钟。", + "best_time": "傍晚蓝调时刻", + "difficulty": "低", + "is_free": true, + "audit_status": "approved", + "tag_ids": ["街拍", "城市夜景"], + "images": [ + "./images/example.jpg" + ] + }, + { + "title": "收费影棚示例", + "city": "上海", + "longitude": 121.4801, + "latitude": 31.2352, + "description": "适合室内棚拍和商业主题拍摄。", + "transport": "地铁到站后打车约 8 分钟。", + "best_time": "预约时段", + "difficulty": "中", + "is_free": false, + "price_min": 199, + "price_max": 499, + "audit_status": "approved", + "tag_ids": ["影棚", "室内"], + "images": [ + "./images/studio.jpg" + ] + } +]