# services/law_db.py
import asyncio
import os
from typing import Any

import httpx

# Support absolute import when run as module, and fallback when run as script
try:
    from common.supabase_client import (
        close_supabase_client,
        init_supabase_client,
        supabase_rpc_call,
    )

    from common.openai_client import get_openai_client
except ImportError:
    import os as _os
    import sys

    _ROOT = _os.path.dirname(_os.path.dirname(_os.path.dirname(_os.path.abspath(__file__))))
    if _ROOT not in sys.path:
        sys.path.append(_ROOT)
    from common.supabase_client import (
        init_supabase_client,
        supabase_rpc_call,
    )

    from common.openai_client import get_openai_client


# ====================== CONFIG ======================
def get_supabase_config():
    """Lazy load config để tránh lỗi khi chưa có env vars"""
    return {
        "url": f"{os.environ['SUPABASE_URL']}/rest/v1/rpc/hoi_phap_luat_all_in_one",
        "headers": {
            "apikey": os.environ["SUPABASE_ANON_KEY"],
            "Authorization": f"Bearer {os.environ['SUPABASE_ANON_KEY']}",
            "Content-Type": "application/json; charset=utf-8",
            "Accept": "application/json",
        },
    }


# ====================== HTTP HELPERS ======================
async def _post_with_retry(client: httpx.AsyncClient, url: str, **kw) -> httpx.Response:
    """POST với retry/backoff."""
    last_exc: Exception | None = None
    for i in range(3):
        try:
            r = await client.post(url, **kw)
            r.raise_for_status()
            return r
        except Exception as e:
            last_exc = e
            if i == 2:
                raise
            await asyncio.sleep(0.4 * (2**i))
    raise last_exc  # về lý thuyết không tới đây


# ====================== LABEL HELPER ======================
def _label_for_call(call: dict[str, Any], index: int) -> str:
    """Tạo nhãn hiển thị cho mỗi lệnh gọi dựa trên tham số (để phân biệt kết quả)."""
    params = call.get("params") or {}
    if params.get("p_so_hieu"):
        return str(params["p_so_hieu"])
    if params.get("p_vb_pattern"):
        return str(params["p_vb_pattern"])
    return f"Truy vấn {index}"


# ====================== RAW FETCHERS ======================
async def _get_embedding(text: str) -> list[float]:
    """Gọi OpenAI API để lấy embedding vector của đoạn văn bản."""
    if not text:
        return []

    # already imported get_openai_client above with fallback
    OAI = get_openai_client()

    normalized_text = (text or "").strip().lower()
    input_text = normalized_text[:8000]  # Giới hạn độ dài

    try:
        resp = await OAI.embeddings.create(
            model="text-embedding-3-small",
            input=input_text,
        )
        return resp.data[0].embedding
    except Exception as e:
        print(f"❌ Lỗi OpenAI embedding: {e}")
        return []


async def law_db_fetch_one(params: dict[str, Any]) -> list[dict[str, Any]]:
    """
    Gọi trực tiếp bằng httpx thay vì RPC
    """

    # Xử lý p_vector_text thành embedding
    processed_params = dict(params)
    if processed_params.get("p_vector_text"):
        vector_text = processed_params.pop("p_vector_text")
        embedding = await _get_embedding(vector_text)
        if embedding and len(embedding) > 0:
            processed_params["p_vector"] = embedding
            # Set mode to semantic nếu chưa có
            if "p_mode" not in processed_params and "_mode" not in processed_params:
                processed_params["p_mode"] = "semantic"

    # Map parameters to correct function signature
    mapped_params = {}
    for key, value in processed_params.items():
        # Skip None values
        if value is None:
            continue

        # Validate p_vector format
        if key == "p_vector":
            if isinstance(value, list) and len(value) > 0:
                # Check for NaN or invalid values
                import math

                valid_vector = [v for v in value if isinstance(v, (int, float)) and not math.isnan(v)]
                if len(valid_vector) == len(value):
                    mapped_params["p_vector"] = value
                else:
                    print("⚠️ Warning: p_vector contains invalid values (NaN/inf), skipping")
            elif isinstance(value, list) and len(value) == 0:
                print("⚠️ Warning: p_vector is empty, skipping")
            else:
                mapped_params["p_vector"] = value
        elif key in {"_mode", "mode"}:
            mapped_params["p_mode"] = value
        elif key == "p_vb_pattern":
            mapped_params["p_vb_pattern"] = value
        elif key == "p_so_hieu":
            mapped_params["p_so_hieu"] = value
        elif key == "p_trang_thai":
            mapped_params["p_trang_thai"] = value
        elif key == "p_co_quan":
            mapped_params["p_co_quan"] = value
        elif key == "p_loai_vb":
            mapped_params["p_loai_vb"] = value
        elif key == "p_nam_from":
            # Validate integer
            if isinstance(value, (int, float)) and not (isinstance(value, float) and value != int(value)):
                mapped_params["p_nam_from"] = int(value)
            elif value is not None:
                print(f"⚠️ Warning: p_nam_from has invalid type/value: {value}, skipping")
        elif key == "p_nam_to":
            # Validate integer
            if isinstance(value, (int, float)) and not (isinstance(value, float) and value != int(value)):
                mapped_params["p_nam_to"] = int(value)
            elif value is not None:
                print(f"⚠️ Warning: p_nam_to has invalid type/value: {value}, skipping")
        elif key == "p_only_source":
            mapped_params["p_only_source"] = value
        elif key == "p_chapter":
            mapped_params["p_chapter"] = value
        elif key == "p_article":
            mapped_params["p_article"] = value
        elif key == "p_phu_luc":
            mapped_params["p_phu_luc"] = value
        elif key == "p_limit":
            # Validate integer
            if isinstance(value, (int, float)) and not (isinstance(value, float) and value != int(value)):
                mapped_params["p_limit"] = int(value)
            elif value is not None:
                print(f"⚠️ Warning: p_limit has invalid type/value: {value}, skipping")
        elif key == "p_ef_search":
            # Validate integer
            if isinstance(value, (int, float)) and not (isinstance(value, float) and value != int(value)):
                mapped_params["p_ef_search"] = int(value)
            elif value is not None:
                print(f"⚠️ Warning: p_ef_search has invalid type/value: {value}, skipping")

    # Gọi qua Supabase shared client (đảm bảo đã init)
    try:
        await init_supabase_client()

        # DEBUG: Print ra JSON sẽ gửi
        print("📤 GỬI JSON PAYLOAD:")
        import json

        debug_payload = json.dumps(mapped_params, ensure_ascii=False, indent=2)
        print(debug_payload[:500])  # Print 500 chars đầu để check

        rows = await supabase_rpc_call("hoi_phap_luat_all_in_one", mapped_params)

        print(f"✅ NHẬN RESULT: {len(rows)} rows")
        if rows:
            print(f"First row keys: {list(rows[0].keys())}")
            _nd = rows[0].get("NoiDung") or rows[0].get("NoiDungDieu") or rows[0].get("NoiDungPhuLuc") or ""
            try:
                print(f"NoiDung length: {len(_nd)}")
            except Exception:
                print("NoiDung length: (unavailable)")
        return rows or []
    except Exception as e:
        print(f"❌ HTTPX call failed: {e}")
        import traceback

        traceback.print_exc()
        return []


async def law_db_fetch_plan(calls: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """
    Nhận danh sách calls (mỗi call có .params) -> chạy song song -> trả:
    [
      {"label": "...", "rows": [ ... ]},
      ...
    ]
    """
    if not calls:
        return []

    async def run_one(call: dict[str, Any], idx: int) -> dict[str, Any]:
        label = _label_for_call(call, idx)
        params = call.get("params", {})
        # Truyền mode xuống để map thành p_mode trong law_db_fetch_one
        mode_value = call.get("mode")
        print(f"DEBUG: call.mode = {mode_value}")
        if mode_value is not None and "_mode" not in params and "mode" not in params:
            # Ưu tiên dùng _mode để tránh đè lên tên trường khác
            params = dict(params)
            params["_mode"] = mode_value
            print(f"DEBUG: added _mode = {mode_value} to params")
        print(f"DEBUG: final params keys = {list(params.keys())}")
        try:
            rows = await law_db_fetch_one(params)
            return {"label": label, "rows": rows}
        except Exception:
            return {"label": label, "rows": []}

    tasks = [run_one(c, i) for i, c in enumerate(calls, start=1)]
    return await asyncio.gather(*tasks)


# ====================== PREVIEW BUILDERS ======================
def build_db_preview(rows: list[dict[str, Any]]) -> str:
    """
    Xây dựng chuỗi văn bản nội dung từ danh sách hàng kết quả (có nội dung văn bản).
    Nhóm theo văn bản pháp luật và các đơn vị (điều, phụ lục) bên trong.
    """
    if not rows:
        return ""
    docs: dict[str, dict[str, Any]] = {}
    for r in rows:
        so_hieu = (r.get("SoHieu") or "").strip() or "(Không rõ số hiệu)"
        title = (r.get("TieuDe") or r.get("TieuDeVanBan") or r.get("TieuDeDieu") or so_hieu).strip()
        docs.setdefault(so_hieu, {"title": title, "groups": {}})
        content = (r.get("NoiDung") or r.get("NoiDungDieu") or r.get("NoiDungPhuLuc") or "").strip()
        if not content:
            continue
        chunk_idx = int(r.get("chunk_idx") or r.get("ChunkIdx") or r.get("ChunkIndex") or 1)
        phu_luc = r.get("PhuLuc") or r.get("Phu_Luc") or r.get("Phu_luc")
        dieu = r.get("Dieu")
        chuong = r.get("Chuong")
        tieu_de_dieu = r.get("TieuDeDieu") or ""
        tieu_de_pl = r.get("TieuDePhuLuc") or r.get("TenPhuLuc") or ""
        if phu_luc is not None:
            group_key = ("PL", str(phu_luc))
            group_title = f"Phụ lục {phu_luc}"
            subtitle = tieu_de_pl
        elif dieu is not None:
            group_key = ("DIEU", str(dieu))
            group_title = f"Điều {dieu}"
            subtitle = tieu_de_dieu
        else:
            group_key = ("KHAC", f"Chương {chuong}" if chuong is not None else "Khác")
            group_title = group_key[1]
            subtitle = ""
        groups = docs[so_hieu]["groups"]
        groups.setdefault(group_key, {"title": group_title, "subtitle": subtitle, "segs": []})
        groups[group_key]["segs"].append((chunk_idx, content))

    parts: list[str] = []
    for so_hieu, doc in docs.items():
        header = f"=== {doc['title']} ({so_hieu}) ==="
        parts.append(header)
        for group_key, group in sorted(
            doc["groups"].items(),
            key=lambda x: (
                {"PL": 0, "DIEU": 1, "KHAC": 2}.get(x[0][0], 3),
                int(x[0][1]) if x[0][1].isdigit() else float("inf"),
            ),
        ):
            title_line = group["title"] + (f" — {group['subtitle']}" if group["subtitle"] else "")
            parts.append(title_line)
            for _, text in sorted(group["segs"], key=lambda seg: seg[0]):
                parts.append(text)
            parts.append("")  # ngắt giữa các nhóm
        parts.append("")  # ngắt giữa các văn bản
    return "\n".join(parts).strip()


def build_meta_preview(rows: list[dict[str, Any]]) -> str:
    """
    Xây dựng chuỗi văn bản liệt kê các văn bản (chỉ meta, không có nội dung chi tiết).
    """
    if not rows:
        return ""
    lines = []
    for i, r in enumerate(rows[:50], start=1):
        title = r.get("TieuDe") or r.get("TieuDeVanBan") or "(Không rõ tiêu đề)"
        so_hieu = r.get("SoHieu") or "—"
        loai = r.get("LoaiVanBan") or ""
        cq = r.get("CoQuanBanHanh") or ""
        nam = r.get("NamBanHanh") or ""
        trang_thai = r.get("TrangThaiVB") or r.get("TrangThai") or ""
        meta = " • ".join(filter(None, [loai, cq, f"Năm {nam}" if nam else "", trang_thai]))
        lines.append(f"{i}. {title} — {so_hieu}{(' (' + meta + ')') if meta else ''}")
    return "\n".join(lines)


def _build_multi_preview(labeled_results: list[dict[str, Any]]) -> str:
    """
    Ghép nội dung từ nhiều kết quả truy vấn thành một chuỗi,
    có tiêu đề cho từng nhóm kết quả tương ứng với từng truy vấn.
    """
    sections: list[str] = []
    for item in labeled_results:
        label = item.get("label", "Truy vấn")
        rows = item.get("rows") or []
        if any(r.get("NoiDung") or r.get("NoiDungDieu") or r.get("NoiDungPhuLuc") for r in rows):
            preview_text = build_db_preview(rows)
        else:
            preview_text = build_meta_preview(rows)
        if not preview_text:
            preview_text = "(Không có dữ liệu)"
        sections.append(f"### {label}\n{preview_text}")
    return "\n\n".join(sections).strip()


# ====================== PUBLIC APIS ======================
async def fetch_data_db_law(calls: list[dict[str, Any]]) -> str:
    """
    Hàm chính: nhận calls -> fetch raw data -> build preview theo mode -> return preview
    """
    # Bước 1: Fetch raw data song song
    labeled = await law_db_fetch_plan(calls or [])

    # Bước 2: Build preview theo mode của từng call
    sections: list[str] = []
    for item in labeled:
        label = item.get("label", "Truy vấn")
        rows = item.get("rows") or []

        # Tìm mode từ calls tương ứng
        call_mode = "content"  # default
        for call in calls:
            if (call.get("params", {}).get("p_so_hieu") and str(call.get("params", {}).get("p_so_hieu")) in label) or (
                call.get("params", {}).get("p_vb_pattern") and str(call.get("params", {}).get("p_vb_pattern")) in label
            ):
                call_mode = call.get("mode", "content")
                break

        # Build preview theo mode
        if call_mode == "content":
            preview_text = build_db_preview(rows)
        elif call_mode == "meta":
            preview_text = build_meta_preview(rows)
        elif call_mode == "semantic":
            preview_text = build_db_preview(rows)
        else:
            preview_text = build_db_preview(rows)

        if not preview_text:
            preview_text = "(Không có dữ liệu)"

        sections.append(f"### {label}\n{preview_text}")

    return "\n\n".join(sections).strip()


__all__ = [
    "build_db_preview",
    "build_meta_preview",
    "fetch_data_db_law",
    "law_db_fetch_one",
    "law_db_fetch_plan",
]

# if __name__ == "__main__":
#     import argparse
#     import json as _json

#     async def main():
#         parser = argparse.ArgumentParser(description="Test fetch_data_db_law nhanh qua CLI")
#         parser.add_argument("--mode", "-m", type=str, default="content", help="content|semantic|meta (có thể phân tách bằng dấu phẩy)")
#         parser.add_argument("--so_hieu", type=str, default=None, help="Giá trị cho p_so_hieu")
#         parser.add_argument("--vb_pattern", type=str, default=None, help="Regex/từ khóa cho p_vb_pattern")
#         parser.add_argument("--vector_text", type=str, default=None, help="Văn bản để embedding semantic (p_vector_text)")
#         parser.add_argument("--loai_vb", type=str, default=None)
#         parser.add_argument("--co_quan", type=str, default=None)
#         parser.add_argument("--nam_from", type=int, default=None)
#         parser.add_argument("--nam_to", type=int, default=None)
#         parser.add_argument("--only_source", type=str, default=None, help="chinh_thong|dia_phuong")
#         parser.add_argument("--limit", type=int, default=50)
#         parser.add_argument("--article", type=int, default=None)
#         parser.add_argument("--chapter", type=int, default=None)
#         parser.add_argument("--phu_luc", type=str, default=None)
#         parser.add_argument("--multi", action="store_true", help="Nếu bật, tạo nhiều calls mẫu để so sánh")
#         args = parser.parse_args()

#         modes = [m.strip() for m in (args.mode or "content").split(",") if m.strip()]

#         def _build_params():
#             return {
#                 "p_so_hieu": args.so_hieu,
#                 "p_vb_pattern": args.vb_pattern,
#                 "p_co_quan": args.co_quan,
#                 "p_loai_vb": args.loai_vb,
#                 "p_nam_from": args.nam_from,
#                 "p_nam_to": args.nam_to,
#                 "p_only_source": args.only_source,
#                 "p_article": args.article,
#                 "p_chapter": args.chapter,
#                 "p_phu_luc": args.phu_luc,
#                 "p_limit": args.limit,
#                 "p_vector_text": args.vector_text,
#             }

#         calls = []
#         if args.multi:
#             # Tạo một số tổ hợp mẫu để tiện so sánh
#             sample_targets = []
#             if args.so_hieu:
#                 sample_targets.append({"p_so_hieu": args.so_hieu})
#             if args.vb_pattern:
#                 sample_targets.append({"p_vb_pattern": args.vb_pattern})
#             if args.vector_text:
#                 sample_targets.append({"p_vector_text": args.vector_text})
#             if not sample_targets:
#                 # nếu không có gì, dùng một pattern mặc định
#                 sample_targets = [
#                     {"p_vb_pattern": "BVMT|bảo vệ môi trường|môi trường"},
#                     {"p_vector_text": "nội dung nghị định chưa được xác định theo yêu cầu của người hỏi"},
#                 ]
#             base_params = _build_params()
#             for mode in modes:
#                 for target in sample_targets:
#                     params = dict(base_params)
#                     params.update({k: v for k, v in target.items() if v is not None})
#                     calls.append({"mode": mode, "params": {k: v for k, v in params.items() if v is not None}})
#         else:
#             params = {k: v for k, v in _build_params().items() if v is not None}
#             if not params:
#                 # nếu không truyền gì, tạo ví dụ tối thiểu
#                 params = {"p_vb_pattern": "BVMT|bảo vệ môi trường|môi trường", "p_limit": 30}
#             for mode in modes:
#                 calls.append({"mode": mode, "params": params})

#         print("CLI calls:")
#         print(_json.dumps(calls, ensure_ascii=False, indent=2))

#         # Init Supabase client trước khi gọi
#         await init_supabase_client()
#         try:
#             preview = await fetch_data_db_law(calls)
#             print("\n===== PREVIEW =====\n")
#             print(preview)
#         finally:
#             await close_supabase_client()

#     asyncio.run(main())
