Add bulk spot upload script

This commit is contained in:
2026-05-23 17:21:00 +08:00
parent f1ce992d69
commit be09bf6f0a
3 changed files with 403 additions and 1 deletions
+31 -1
View File
@@ -7,7 +7,7 @@ from app.core.deps import get_current_active_user, get_db
from app.models.spot import Spot
from app.models.tag import SpotTag, Tag
from app.models.user import User
from app.schemas.tag import TagOut
from app.schemas.tag import TagCreate, TagOut
router = APIRouter()
@@ -16,6 +16,14 @@ class TagAttach(BaseModel):
tag_id: int
def _assert_admin_role(user: User) -> None:
if user.role not in ("admin", "moderator"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin permission required",
)
@router.get("/tags", response_model=list[TagOut])
async def list_tags(
sort: str = Query(default="hot", regex="^(hot|name)$"),
@@ -31,6 +39,28 @@ async def list_tags(
return result.scalars().all()
@router.post("/tags", response_model=TagOut, status_code=status.HTTP_201_CREATED)
async def create_tag(
    payload: TagCreate,
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a new active tag and return it.

    The caller must pass ``_assert_admin_role``. The submitted name is stripped
    of surrounding whitespace before it is validated and stored.

    Raises:
        HTTPException: 400 when the stripped name is empty, 409 when a tag with
            exactly that name already exists.
    """
    _assert_admin_role(current_user)

    tag_name = payload.name.strip()
    if not tag_name:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Tag name is required")

    # NOTE(review): this lookup and the insert below are separate statements, so
    # two concurrent requests could both pass the check — confirm whether
    # Tag.name carries a unique constraint that would catch that at commit time.
    lookup = await db.execute(select(Tag).where(Tag.name == tag_name))
    duplicate = lookup.scalar_one_or_none()
    if duplicate:
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Tag already exists")

    new_tag = Tag(name=tag_name, category=payload.category, is_active=True)
    db.add(new_tag)
    await db.commit()
    # Reload the row so the returned object reflects what is stored after commit.
    await db.refresh(new_tag)
    return new_tag
@router.post("/spots/{spot_id}/tags", status_code=status.HTTP_201_CREATED)
async def add_tag_to_spot(
spot_id: int,
+336
View File
@@ -0,0 +1,336 @@
"""
Bulk upload spots through the admin API.
Input JSON shape:
[
{
"title": "示例地点",
"city": "上海",
"longitude": 121.4737,
"latitude": 31.2304,
"description": "地点描述",
"transport": "地铁可达",
"best_time": "傍晚",
"difficulty": "",
"is_free": true,
"audit_status": "approved",
"tag_ids": ["街拍", "城市"],
"images": ["./images/a.jpg", "https://example.com/b.jpg"]
}
]
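CSV uses the same field names. List fields can be separated by semicolon, pipe,
or comma, for example: images=./a.jpg;./b.jpg and tag_ids=街拍,城市. Only one
separator is applied per value: the first of ";", "|", "," that occurs in it.
Despite its name, tag_ids (or its alias tags) holds tag names; outside --dry-run
they are resolved to numeric ids through the /tags API, creating missing tags.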
"""
from __future__ import annotations
import argparse
import csv
import json
import mimetypes
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from urllib.parse import urlparse
import httpx
# Connection defaults. Each one is read from the named environment variable and
# falls back to the literal shown; parse_args exposes them as option defaults.
# NOTE(review): an account and a password are hard-coded as fallbacks here —
# confirm they are throwaway development credentials before sharing this file.
DEFAULT_BASE_URL = os.getenv("CIYUAN_API_BASE_URL", "http://localhost:8000/api/v1")
DEFAULT_ACCOUNT = os.getenv("CIYUAN_ADMIN_ACCOUNT", "13900000001")
DEFAULT_PASSWORD = os.getenv("CIYUAN_ADMIN_PASSWORD", "admin123456")
# Item keys that build_payload copies verbatim from a normalized input item into
# the request body; keys outside this set are not copied. build_payload always
# overwrites "image_urls" and "tag_ids" with the values it computes itself.
SPOT_FIELDS = {
"title",
"city",
"longitude",
"latitude",
"description",
"transport",
"best_time",
"difficulty",
"is_free",
"price_min",
"price_max",
"audit_status",
"reject_reason",
"creator_id",
"image_urls",
"tag_ids",
}
@dataclass
class UploadResult:
    """Outcome recorded for a single input item."""

    # Position of the item in the input, as reported in the log lines.
    index: int
    # Item title, or a "row-N" placeholder when the item has none.
    title: str
    # True when the item was processed without raising.
    ok: bool
    # Id returned by the API for a created spot; stays None otherwise.
    spot_id: int | None = None
    # Formatted failure message; None on success.
    error: str | None = None
class TagResolver:
    """Translate tag names into tag ids through the API, creating tags on demand.

    The tag list is fetched lazily with ``GET /tags`` on first use and then kept
    in memory, keyed by ``normalize_tag_name(name)``.
    """

    def __init__(self, client: httpx.Client, *, create_missing: bool = True) -> None:
        """Store the HTTP client and whether unknown tags may be created."""
        self.client = client
        self.create_missing = create_missing
        # None means "not fetched yet"; an empty dict is a valid, loaded cache.
        self._cache: dict[str, int] | None = None

    def resolve_many(self, names: list[str]) -> list[int]:
        """Return the ids for *names*, de-duplicated in first-seen order."""
        return list(dict.fromkeys(self.resolve_one(name) for name in names))

    def resolve_one(self, name: str) -> int:
        """Return the id of the tag called *name*, creating it when allowed.

        Raises:
            ValueError: the tag is unknown and ``create_missing`` is false.
        """
        normalized = normalize_tag_name(name)
        cache = self._load_cache()
        if normalized in cache:
            return cache[normalized]
        if not self.create_missing:
            raise ValueError(f"tag not found: {name}")
        tag_id = self._create_tag(name.strip())
        # _create_tag replaces self._cache with a fresh dict when it hits a 409,
        # so write through the live cache instead of the local taken above,
        # which may by now be a discarded object.
        self._load_cache()[normalized] = tag_id
        return tag_id

    def _load_cache(self) -> dict[str, int]:
        """Return the name-to-id map, fetching it from ``GET /tags`` if needed."""
        if self._cache is not None:
            return self._cache
        response = self.client.get("/tags", params={"sort": "name"})
        response.raise_for_status()
        # Entries lacking an id or a name/title are skipped; "title" is accepted
        # as a fallback for the name field.
        self._cache = {
            normalize_tag_name(str(item.get("name") or item.get("title") or "")): int(item["id"])
            for item in response.json()
            if item.get("id") and (item.get("name") or item.get("title"))
        }
        return self._cache

    def _create_tag(self, name: str) -> int:
        """Create a tag with ``POST /tags`` and return its id.

        A 409 response means the tag already exists, so the cache is rebuilt and
        the existing id is returned when it can be found there. Any other error
        status (or a 409 that cannot be matched) raises via ``raise_for_status``.
        """
        response = self.client.post("/tags", json={"name": name})
        if response.status_code == 409:
            self._cache = None
            refreshed = self._load_cache()
            normalized = normalize_tag_name(name)
            if normalized in refreshed:
                return refreshed[normalized]
        response.raise_for_status()
        data = response.json()
        print(f"[TAG] created {name} -> tag_id={data['id']}")
        return int(data["id"])
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the bulk upload script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what the script entry point uses;
            passing a list allows the parser to be driven programmatically.

    Returns:
        The populated namespace. ``input`` is a ``Path``; ``creator_id`` is
        ``None`` unless ``--creator-id`` was given.
    """
    parser = argparse.ArgumentParser(description="Bulk upload spots via admin API.")
    parser.add_argument("input", type=Path, help="JSON or CSV data file.")
    parser.add_argument("--base-url", default=DEFAULT_BASE_URL, help="API base URL, default: %(default)s")
    parser.add_argument("--account", default=DEFAULT_ACCOUNT, help="Admin phone/email.")
    parser.add_argument("--password", default=DEFAULT_PASSWORD, help="Admin password.")
    parser.add_argument("--creator-id", type=int, help="Default creator_id when an item omits it.")
    parser.add_argument("--audit-status", default="approved", choices=["pending", "approved", "rejected", "deleted"])
    parser.add_argument("--timeout", type=float, default=30.0)
    parser.add_argument("--dry-run", action="store_true", help="Validate and print payloads without sending requests.")
    parser.add_argument("--stop-on-error", action="store_true", help="Abort on the first failed item.")
    return parser.parse_args(argv)
def load_items(path: Path) -> list[dict[str, Any]]:
    """Read the input file and return its rows as normalized item dicts.

    ``.json`` files may contain either a list of items or an object wrapping
    that list under ``items`` or ``spots``. ``.csv`` files are read with
    ``csv.DictReader`` (a UTF-8 BOM is tolerated). Every row is passed through
    ``normalize_item`` with the input file's directory as the base for relative
    image paths.

    Raises:
        FileNotFoundError: *path* does not exist.
        ValueError: the suffix is neither ``.json`` nor ``.csv``, or the JSON
            document does not contain a list of items.
    """
    if not path.exists():
        raise FileNotFoundError(f"input file not found: {path}")

    suffix = path.suffix.lower()
    if suffix == ".json":
        data = json.loads(path.read_text(encoding="utf-8"))
        if isinstance(data, dict):
            wrapped = data.get("items") or data.get("spots")
            # `[] or None` collapses to None, which used to make an empty
            # "items"/"spots" list look like a malformed document. Treat an
            # explicitly empty list as a valid, empty input instead.
            if not wrapped and any(isinstance(data.get(key), list) for key in ("items", "spots")):
                wrapped = []
            data = wrapped
        if not isinstance(data, list):
            raise ValueError("JSON input must be a list or an object with items/spots list")
        return [normalize_item(item, path.parent) for item in data]

    if suffix == ".csv":
        with path.open("r", encoding="utf-8-sig", newline="") as f:
            return [normalize_item(row, path.parent) for row in csv.DictReader(f)]

    raise ValueError("unsupported input format, use .json or .csv")
def normalize_item(raw: dict[str, Any], input_dir: Path) -> dict[str, Any]:
    """Turn one raw JSON/CSV row into the item shape used by ``build_payload``.

    Empty-string and ``None`` values are dropped, numeric and boolean fields are
    coerced, tags are collected under ``tag_names`` (``tag_ids`` wins over
    ``tags``; both source keys are removed), and ``image_urls`` plus ``images``
    are merged into ``image_sources`` with relative paths anchored at
    *input_dir*.

    Raises:
        ValueError: a numeric field cannot be converted by ``float``/``int``.
    """
    item = {field: value for field, value in raw.items() if value not in ("", None)}

    # Applied in insertion order: coordinates and prices, then creator_id, then is_free.
    converters = {
        "longitude": float,
        "latitude": float,
        "price_min": float,
        "price_max": float,
        "creator_id": int,
        "is_free": parse_bool,
    }
    for field, convert in converters.items():
        if field in item:
            item[field] = convert(item[field])

    # "tag_ids" takes precedence; "tags" is only consulted when it is absent.
    fallback_tags = item.pop("tags", [])
    item["tag_names"] = parse_string_list(item.pop("tag_ids", fallback_tags))

    sources = parse_string_list(item.get("image_urls", [])) + parse_string_list(item.get("images", []))
    item["image_sources"] = [resolve_local_path(source, input_dir) for source in sources]
    return item
def parse_bool(value: Any) -> bool:
    """Interpret *value* as a boolean flag.

    Real booleans are returned unchanged. Anything else is converted with
    ``str``, stripped and lower-cased, and is true only when it matches one of
    the accepted spellings; every other text is false.
    """
    if isinstance(value, bool):
        return value
    # NOTE(review): blank text counts as true here — presumably so an empty cell
    # means "free"; confirm that is intended.
    truthy_spellings = ("1", "true", "yes", "y", "on", "", "免费")
    text = str(value).strip().lower()
    return text in truthy_spellings
def parse_string_list(value: Any) -> list[str]:
    """Coerce *value* into a list of non-empty, stripped strings.

    ``None`` and ``""`` give an empty list. A list has each element converted
    with ``str`` and stripped, dropping blanks. Any other value is treated as
    text and split on a single separator: the first of ``;``, ``|``, ``,`` that
    occurs in it. Text without a separator becomes a one-element list.
    """
    if value is None or value == "":
        return []

    if isinstance(value, list):
        stripped = (str(element).strip() for element in value)
        return [entry for entry in stripped if entry]

    text = str(value).strip()
    # Separators are tried in priority order and only the first hit is used.
    separator = next((sep for sep in (";", "|", ",") if sep in text), None)
    if separator is None:
        return [text] if text else []
    return [part.strip() for part in text.split(separator) if part.strip()]
def normalize_tag_name(value: str) -> str:
    """Return *value* stripped of surrounding whitespace and lower-cased."""
    trimmed = value.strip()
    return trimmed.lower()
def resolve_local_path(source: str, input_dir: Path) -> str:
    """Return *source* unchanged if it is a URL, else as a filesystem path string.

    Relative paths are joined onto *input_dir*; absolute paths are kept.
    """
    if is_url(source):
        return source
    candidate = Path(source)
    resolved = candidate if candidate.is_absolute() else input_dir / candidate
    return str(resolved)
def is_url(value: str) -> bool:
    """Return True when *value* parses as an http(s) URL that has a host part."""
    parts = urlparse(value)
    if parts.scheme not in ("http", "https"):
        return False
    return bool(parts.netloc)
def login(client: httpx.Client, account: str, password: str) -> tuple[str, int]:
    """Authenticate against ``POST /admin/auth/login``.

    Returns:
        ``(access_token, user_id)`` taken from the response body's
        ``access_token`` and ``user.id`` fields.

    Raises:
        httpx.HTTPStatusError: the server answered with an error status.
    """
    credentials = {"account": account, "password": password}
    response = client.post("/admin/auth/login", json=credentials)
    response.raise_for_status()
    body = response.json()
    access_token = body["access_token"]
    user_id = int(body["user"]["id"])
    return access_token, user_id
def upload_image(client: httpx.Client, path: Path) -> str:
    """Send a local image to ``POST /upload/image`` and return its ``url`` field.

    The content type is guessed from the file name and falls back to
    ``image/jpeg`` when it cannot be determined.

    Raises:
        FileNotFoundError: *path* does not exist.
        httpx.HTTPStatusError: the server answered with an error status.
    """
    if not path.exists():
        raise FileNotFoundError(f"image file not found: {path}")

    guessed_type, _encoding = mimetypes.guess_type(path.name)
    content_type = guessed_type or "image/jpeg"

    with path.open("rb") as handle:
        multipart = {"file": (path.name, handle, content_type)}
        response = client.post("/upload/image", files=multipart)
    response.raise_for_status()
    return str(response.json()["url"])
def build_payload(
    client: httpx.Client,
    item: dict[str, Any],
    creator_id: int,
    audit_status: str,
    *,
    upload_files: bool,
    tag_resolver: TagResolver | None,
) -> dict[str, Any]:
    """Build the request body for one spot from a normalized item.

    Args:
        client: HTTP client used to upload local image files.
        item: Item produced by ``normalize_item``.
        creator_id: Fallback ``creator_id`` when the item has none.
        audit_status: Fallback ``audit_status`` when the item has none.
        upload_files: When true, local image paths are uploaded and replaced by
            the returned URLs; when false they are passed through untouched.
        tag_resolver: Resolves tag names to ids. When ``None``, ``tag_ids`` is
            left empty and the names are carried under ``_tag_names`` instead.

    Returns:
        The payload dict, always containing ``tag_ids`` and ``image_urls``.

    Raises:
        ValueError: a required field is missing.
    """
    payload = {key: item[key] for key in SPOT_FIELDS if key in item}
    payload.setdefault("creator_id", creator_id)
    payload.setdefault("audit_status", audit_status)
    payload.setdefault("is_free", True)

    # Validate before touching the network: resolving tags can create tags and
    # image handling uploads files, and neither should happen for an item that
    # is going to be rejected anyway.
    missing = [key for key in ("title", "city", "longitude", "latitude", "creator_id") if key not in payload]
    if missing:
        raise ValueError(f"missing required fields: {', '.join(missing)}")

    if tag_resolver is None:
        payload["tag_ids"] = []
        if item.get("tag_names"):
            payload["_tag_names"] = item["tag_names"]
    else:
        payload["tag_ids"] = tag_resolver.resolve_many(item.get("tag_names", []))

    # URLs are kept as they are; local paths are uploaded only when requested.
    payload["image_urls"] = [
        source if is_url(source) or not upload_files else upload_image(client, Path(source))
        for source in item.get("image_sources", [])
    ]
    return payload
def create_spot(client: httpx.Client, payload: dict[str, Any]) -> int:
    """Submit *payload* to ``POST /admin/spots`` and return the new spot's id.

    Raises:
        httpx.HTTPStatusError: the server answered with an error status.
    """
    response = client.post("/admin/spots", json=payload)
    response.raise_for_status()
    created = response.json()
    return int(created["id"])
def run() -> int:
    """Run the bulk upload and return the process exit status.

    Parses the command line, loads the input items and handles them one by one.
    In dry-run mode nothing is sent: each payload is printed as JSON with the
    tag names shown in place of ids. Otherwise the script logs in, attaches the
    bearer token to the client and creates one spot per item. A failing item is
    reported and recorded; processing continues unless ``--stop-on-error`` is set.

    Returns:
        1 when at least one item failed, otherwise 0.
    """
    args = parse_args()
    items = load_items(args.input)
    dry_run = args.dry_run
    results: list[UploadResult] = []

    with httpx.Client(base_url=args.base_url.rstrip("/"), timeout=args.timeout) as client:
        creator_id = args.creator_id
        if dry_run:
            # No login happens in a dry run, so use a placeholder creator id.
            if creator_id is None:
                creator_id = 0
        else:
            token, admin_id = login(client, args.account, args.password)
            creator_id = creator_id or admin_id
            client.headers.update({"Authorization": f"Bearer {token}"})

        tag_resolver = None if dry_run else TagResolver(client)

        for index, item in enumerate(items, start=1):
            title = str(item.get("title", f"row-{index}"))
            try:
                payload = build_payload(
                    client,
                    item,
                    creator_id,
                    args.audit_status,
                    upload_files=not dry_run,
                    tag_resolver=tag_resolver,
                )
                if dry_run:
                    # Show the unresolved names where the ids would go.
                    pending_names = payload.pop("_tag_names", None)
                    if pending_names:
                        payload["tag_ids"] = pending_names
                    print(json.dumps(payload, ensure_ascii=False, indent=2))
                    outcome = UploadResult(index=index, title=title, ok=True)
                else:
                    spot_id = create_spot(client, payload)
                    print(f"[OK] #{index} {title} -> spot_id={spot_id}")
                    outcome = UploadResult(index=index, title=title, ok=True, spot_id=spot_id)
                results.append(outcome)
            except Exception as exc:
                # Per-item boundary: record the failure so the rest of the batch can proceed.
                message = format_error(exc)
                print(f"[FAIL] #{index} {title}: {message}")
                results.append(UploadResult(index=index, title=title, ok=False, error=message))
                if args.stop_on_error:
                    break

    ok_count = sum(result.ok for result in results)
    fail_count = len(results) - ok_count
    print(f"Done. total={len(results)} ok={ok_count} failed={fail_count}")
    return 1 if fail_count else 0
def format_error(exc: Exception) -> str:
    """Render *exc* as a one-line message for the log.

    HTTP status errors include the status code and the response body; every
    other exception is rendered with ``str``.
    """
    if not isinstance(exc, httpx.HTTPStatusError):
        return str(exc)
    response = exc.response
    body = response.text
    return f"HTTP {response.status_code}: {body}"
# Script entry point: exit with run()'s return value (1 if any item failed, else 0).
if __name__ == "__main__":
    raise SystemExit(run())
+36
View File
@@ -0,0 +1,36 @@
[
{
"title": "示例取景地",
"city": "上海",
"longitude": 121.4737,
"latitude": 31.2304,
"description": "适合街拍与城市感构图。",
"transport": "地铁到站后步行约 10 分钟。",
"best_time": "傍晚蓝调时刻",
"difficulty": "低",
"is_free": true,
"audit_status": "approved",
"tag_ids": ["街拍", "城市夜景"],
"images": [
"./images/example.jpg"
]
},
{
"title": "收费影棚示例",
"city": "上海",
"longitude": 121.4801,
"latitude": 31.2352,
"description": "适合室内棚拍和商业主题拍摄。",
"transport": "地铁到站后打车约 8 分钟。",
"best_time": "预约时段",
"difficulty": "中",
"is_free": false,
"price_min": 199,
"price_max": 499,
"audit_status": "approved",
"tag_ids": ["影棚", "室内"],
"images": [
"./images/studio.jpg"
]
}
]