Add bulk spot upload script
This commit is contained in:
@@ -7,7 +7,7 @@ from app.core.deps import get_current_active_user, get_db
|
||||
from app.models.spot import Spot
|
||||
from app.models.tag import SpotTag, Tag
|
||||
from app.models.user import User
|
||||
from app.schemas.tag import TagOut
|
||||
from app.schemas.tag import TagCreate, TagOut
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -16,6 +16,14 @@ class TagAttach(BaseModel):
|
||||
tag_id: int
|
||||
|
||||
|
||||
def _assert_admin_role(user: User) -> None:
|
||||
if user.role not in ("admin", "moderator"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Admin permission required",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tags", response_model=list[TagOut])
|
||||
async def list_tags(
|
||||
sort: str = Query(default="hot", regex="^(hot|name)$"),
|
||||
@@ -31,6 +39,28 @@ async def list_tags(
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.post("/tags", response_model=TagOut, status_code=status.HTTP_201_CREATED)
|
||||
async def create_tag(
|
||||
payload: TagCreate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
_assert_admin_role(current_user)
|
||||
name = payload.name.strip()
|
||||
if not name:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Tag name is required")
|
||||
|
||||
existing = await db.execute(select(Tag).where(Tag.name == name))
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Tag already exists")
|
||||
|
||||
tag = Tag(name=name, category=payload.category, is_active=True)
|
||||
db.add(tag)
|
||||
await db.commit()
|
||||
await db.refresh(tag)
|
||||
return tag
|
||||
|
||||
|
||||
@router.post("/spots/{spot_id}/tags", status_code=status.HTTP_201_CREATED)
|
||||
async def add_tag_to_spot(
|
||||
spot_id: int,
|
||||
|
||||
@@ -0,0 +1,336 @@
|
||||
"""
|
||||
Bulk upload spots through the admin API.
|
||||
|
||||
Input JSON shape:
|
||||
[
|
||||
{
|
||||
"title": "示例地点",
|
||||
"city": "上海",
|
||||
"longitude": 121.4737,
|
||||
"latitude": 31.2304,
|
||||
"description": "地点描述",
|
||||
"transport": "地铁可达",
|
||||
"best_time": "傍晚",
|
||||
"difficulty": "低",
|
||||
"is_free": true,
|
||||
"audit_status": "approved",
|
||||
"tag_ids": ["街拍", "城市"],
|
||||
"images": ["./images/a.jpg", "https://example.com/b.jpg"]
|
||||
}
|
||||
]
|
||||
|
||||
CSV uses the same field names. List fields can be separated by semicolon, comma,
|
||||
or pipe, for example: images=./a.jpg;./b.jpg and tag_ids=街拍,城市.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
DEFAULT_BASE_URL = os.getenv("CIYUAN_API_BASE_URL", "http://localhost:8000/api/v1")
|
||||
DEFAULT_ACCOUNT = os.getenv("CIYUAN_ADMIN_ACCOUNT", "13900000001")
|
||||
DEFAULT_PASSWORD = os.getenv("CIYUAN_ADMIN_PASSWORD", "admin123456")
|
||||
|
||||
SPOT_FIELDS = {
|
||||
"title",
|
||||
"city",
|
||||
"longitude",
|
||||
"latitude",
|
||||
"description",
|
||||
"transport",
|
||||
"best_time",
|
||||
"difficulty",
|
||||
"is_free",
|
||||
"price_min",
|
||||
"price_max",
|
||||
"audit_status",
|
||||
"reject_reason",
|
||||
"creator_id",
|
||||
"image_urls",
|
||||
"tag_ids",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class UploadResult:
|
||||
index: int
|
||||
title: str
|
||||
ok: bool
|
||||
spot_id: int | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class TagResolver:
|
||||
def __init__(self, client: httpx.Client, *, create_missing: bool = True) -> None:
|
||||
self.client = client
|
||||
self.create_missing = create_missing
|
||||
self._cache: dict[str, int] | None = None
|
||||
|
||||
def resolve_many(self, names: list[str]) -> list[int]:
|
||||
tag_ids: list[int] = []
|
||||
for name in names:
|
||||
tag_ids.append(self.resolve_one(name))
|
||||
return list(dict.fromkeys(tag_ids))
|
||||
|
||||
def resolve_one(self, name: str) -> int:
|
||||
normalized = normalize_tag_name(name)
|
||||
cache = self._load_cache()
|
||||
if normalized in cache:
|
||||
return cache[normalized]
|
||||
if not self.create_missing:
|
||||
raise ValueError(f"tag not found: {name}")
|
||||
tag_id = self._create_tag(name.strip())
|
||||
cache[normalized] = tag_id
|
||||
return tag_id
|
||||
|
||||
def _load_cache(self) -> dict[str, int]:
|
||||
if self._cache is not None:
|
||||
return self._cache
|
||||
response = self.client.get("/tags", params={"sort": "name"})
|
||||
response.raise_for_status()
|
||||
self._cache = {
|
||||
normalize_tag_name(str(item.get("name") or item.get("title") or "")): int(item["id"])
|
||||
for item in response.json()
|
||||
if item.get("id") and (item.get("name") or item.get("title"))
|
||||
}
|
||||
return self._cache
|
||||
|
||||
def _create_tag(self, name: str) -> int:
|
||||
response = self.client.post("/tags", json={"name": name})
|
||||
if response.status_code == 409:
|
||||
self._cache = None
|
||||
cache = self._load_cache()
|
||||
normalized = normalize_tag_name(name)
|
||||
if normalized in cache:
|
||||
return cache[normalized]
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
print(f"[TAG] created {name} -> tag_id={data['id']}")
|
||||
return int(data["id"])
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Bulk upload spots via admin API.")
|
||||
parser.add_argument("input", type=Path, help="JSON or CSV data file.")
|
||||
parser.add_argument("--base-url", default=DEFAULT_BASE_URL, help="API base URL, default: %(default)s")
|
||||
parser.add_argument("--account", default=DEFAULT_ACCOUNT, help="Admin phone/email.")
|
||||
parser.add_argument("--password", default=DEFAULT_PASSWORD, help="Admin password.")
|
||||
parser.add_argument("--creator-id", type=int, help="Default creator_id when an item omits it.")
|
||||
parser.add_argument("--audit-status", default="approved", choices=["pending", "approved", "rejected", "deleted"])
|
||||
parser.add_argument("--timeout", type=float, default=30.0)
|
||||
parser.add_argument("--dry-run", action="store_true", help="Validate and print payloads without sending requests.")
|
||||
parser.add_argument("--stop-on-error", action="store_true", help="Abort on the first failed item.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_items(path: Path) -> list[dict[str, Any]]:
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"input file not found: {path}")
|
||||
suffix = path.suffix.lower()
|
||||
if suffix == ".json":
|
||||
data = json.loads(path.read_text(encoding="utf-8"))
|
||||
if isinstance(data, dict):
|
||||
data = data.get("items") or data.get("spots")
|
||||
if not isinstance(data, list):
|
||||
raise ValueError("JSON input must be a list or an object with items/spots list")
|
||||
return [normalize_item(item, path.parent) for item in data]
|
||||
if suffix == ".csv":
|
||||
with path.open("r", encoding="utf-8-sig", newline="") as f:
|
||||
return [normalize_item(row, path.parent) for row in csv.DictReader(f)]
|
||||
raise ValueError("unsupported input format, use .json or .csv")
|
||||
|
||||
|
||||
def normalize_item(raw: dict[str, Any], input_dir: Path) -> dict[str, Any]:
|
||||
item = {k: v for k, v in raw.items() if v not in ("", None)}
|
||||
|
||||
for key in ("longitude", "latitude", "price_min", "price_max"):
|
||||
if key in item:
|
||||
item[key] = float(item[key])
|
||||
|
||||
for key in ("creator_id",):
|
||||
if key in item:
|
||||
item[key] = int(item[key])
|
||||
|
||||
if "is_free" in item:
|
||||
item["is_free"] = parse_bool(item["is_free"])
|
||||
|
||||
tag_values = item.get("tag_ids", item.get("tags", []))
|
||||
item["tag_names"] = parse_string_list(tag_values)
|
||||
item.pop("tags", None)
|
||||
item.pop("tag_ids", None)
|
||||
|
||||
image_urls = parse_string_list(item.get("image_urls", []))
|
||||
images = parse_string_list(item.get("images", []))
|
||||
item["image_sources"] = [resolve_local_path(src, input_dir) for src in [*image_urls, *images]]
|
||||
|
||||
return item
|
||||
|
||||
|
||||
def parse_bool(value: Any) -> bool:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
return str(value).strip().lower() in {"1", "true", "yes", "y", "on", "是", "免费"}
|
||||
|
||||
|
||||
def parse_string_list(value: Any) -> list[str]:
|
||||
if value is None or value == "":
|
||||
return []
|
||||
if isinstance(value, list):
|
||||
return [str(v).strip() for v in value if str(v).strip()]
|
||||
text = str(value).strip()
|
||||
for sep in (";", "|", ","):
|
||||
if sep in text:
|
||||
return [part.strip() for part in text.split(sep) if part.strip()]
|
||||
return [text] if text else []
|
||||
|
||||
|
||||
def normalize_tag_name(value: str) -> str:
|
||||
return value.strip().lower()
|
||||
|
||||
|
||||
def resolve_local_path(source: str, input_dir: Path) -> str:
|
||||
if is_url(source):
|
||||
return source
|
||||
path = Path(source)
|
||||
if not path.is_absolute():
|
||||
path = input_dir / path
|
||||
return str(path)
|
||||
|
||||
|
||||
def is_url(value: str) -> bool:
|
||||
parsed = urlparse(value)
|
||||
return parsed.scheme in {"http", "https"} and bool(parsed.netloc)
|
||||
|
||||
|
||||
def login(client: httpx.Client, account: str, password: str) -> tuple[str, int]:
|
||||
response = client.post("/admin/auth/login", json={"account": account, "password": password})
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data["access_token"], int(data["user"]["id"])
|
||||
|
||||
|
||||
def upload_image(client: httpx.Client, path: Path) -> str:
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"image file not found: {path}")
|
||||
content_type = mimetypes.guess_type(path.name)[0] or "image/jpeg"
|
||||
with path.open("rb") as f:
|
||||
response = client.post(
|
||||
"/upload/image",
|
||||
files={"file": (path.name, f, content_type)},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return str(response.json()["url"])
|
||||
|
||||
|
||||
def build_payload(
|
||||
client: httpx.Client,
|
||||
item: dict[str, Any],
|
||||
creator_id: int,
|
||||
audit_status: str,
|
||||
*,
|
||||
upload_files: bool,
|
||||
tag_resolver: TagResolver | None,
|
||||
) -> dict[str, Any]:
|
||||
payload = {key: item[key] for key in SPOT_FIELDS if key in item}
|
||||
payload.setdefault("creator_id", creator_id)
|
||||
payload.setdefault("audit_status", audit_status)
|
||||
payload.setdefault("is_free", True)
|
||||
if tag_resolver is None:
|
||||
payload["tag_ids"] = []
|
||||
if item.get("tag_names"):
|
||||
payload["_tag_names"] = item["tag_names"]
|
||||
else:
|
||||
payload["tag_ids"] = tag_resolver.resolve_many(item.get("tag_names", []))
|
||||
|
||||
image_urls: list[str] = []
|
||||
for source in item.get("image_sources", []):
|
||||
if is_url(source):
|
||||
image_urls.append(source)
|
||||
elif not upload_files:
|
||||
image_urls.append(source)
|
||||
else:
|
||||
image_urls.append(upload_image(client, Path(source)))
|
||||
payload["image_urls"] = image_urls
|
||||
|
||||
missing = [key for key in ("title", "city", "longitude", "latitude", "creator_id") if key not in payload]
|
||||
if missing:
|
||||
raise ValueError(f"missing required fields: {', '.join(missing)}")
|
||||
return payload
|
||||
|
||||
|
||||
def create_spot(client: httpx.Client, payload: dict[str, Any]) -> int:
|
||||
response = client.post("/admin/spots", json=payload)
|
||||
response.raise_for_status()
|
||||
return int(response.json()["id"])
|
||||
|
||||
|
||||
def run() -> int:
|
||||
args = parse_args()
|
||||
items = load_items(args.input)
|
||||
results: list[UploadResult] = []
|
||||
|
||||
with httpx.Client(base_url=args.base_url.rstrip("/"), timeout=args.timeout) as client:
|
||||
token: str | None = None
|
||||
current_admin_id = args.creator_id
|
||||
if not args.dry_run:
|
||||
token, admin_id = login(client, args.account, args.password)
|
||||
current_admin_id = current_admin_id or admin_id
|
||||
client.headers.update({"Authorization": f"Bearer {token}"})
|
||||
elif current_admin_id is None:
|
||||
current_admin_id = 0
|
||||
|
||||
tag_resolver = None if args.dry_run else TagResolver(client)
|
||||
for index, item in enumerate(items, start=1):
|
||||
title = str(item.get("title", f"row-{index}"))
|
||||
try:
|
||||
payload = build_payload(
|
||||
client,
|
||||
item,
|
||||
current_admin_id,
|
||||
args.audit_status,
|
||||
upload_files=not args.dry_run,
|
||||
tag_resolver=tag_resolver,
|
||||
)
|
||||
if args.dry_run:
|
||||
tag_names = payload.pop("_tag_names", None)
|
||||
if tag_names:
|
||||
payload["tag_ids"] = tag_names
|
||||
print(json.dumps(payload, ensure_ascii=False, indent=2))
|
||||
results.append(UploadResult(index=index, title=title, ok=True))
|
||||
continue
|
||||
spot_id = create_spot(client, payload)
|
||||
print(f"[OK] #{index} {title} -> spot_id={spot_id}")
|
||||
results.append(UploadResult(index=index, title=title, ok=True, spot_id=spot_id))
|
||||
except Exception as exc:
|
||||
message = format_error(exc)
|
||||
print(f"[FAIL] #{index} {title}: {message}")
|
||||
results.append(UploadResult(index=index, title=title, ok=False, error=message))
|
||||
if args.stop_on_error:
|
||||
break
|
||||
|
||||
ok_count = sum(1 for item in results if item.ok)
|
||||
fail_count = len(results) - ok_count
|
||||
print(f"Done. total={len(results)} ok={ok_count} failed={fail_count}")
|
||||
return 1 if fail_count else 0
|
||||
|
||||
|
||||
def format_error(exc: Exception) -> str:
|
||||
if isinstance(exc, httpx.HTTPStatusError):
|
||||
body = exc.response.text
|
||||
return f"HTTP {exc.response.status_code}: {body}"
|
||||
return str(exc)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(run())
|
||||
@@ -0,0 +1,36 @@
|
||||
[
|
||||
{
|
||||
"title": "示例取景地",
|
||||
"city": "上海",
|
||||
"longitude": 121.4737,
|
||||
"latitude": 31.2304,
|
||||
"description": "适合街拍与城市感构图。",
|
||||
"transport": "地铁到站后步行约 10 分钟。",
|
||||
"best_time": "傍晚蓝调时刻",
|
||||
"difficulty": "低",
|
||||
"is_free": true,
|
||||
"audit_status": "approved",
|
||||
"tag_ids": ["街拍", "城市夜景"],
|
||||
"images": [
|
||||
"./images/example.jpg"
|
||||
]
|
||||
},
|
||||
{
|
||||
"title": "收费影棚示例",
|
||||
"city": "上海",
|
||||
"longitude": 121.4801,
|
||||
"latitude": 31.2352,
|
||||
"description": "适合室内棚拍和商业主题拍摄。",
|
||||
"transport": "地铁到站后打车约 8 分钟。",
|
||||
"best_time": "预约时段",
|
||||
"difficulty": "中",
|
||||
"is_free": false,
|
||||
"price_min": 199,
|
||||
"price_max": 499,
|
||||
"audit_status": "approved",
|
||||
"tag_ids": ["影棚", "室内"],
|
||||
"images": [
|
||||
"./images/studio.jpg"
|
||||
]
|
||||
}
|
||||
]
|
||||
Reference in New Issue
Block a user