Commit a4083b51 authored by Vũ Hoàng Anh's avatar Vũ Hoàng Anh

chore: update batch testing tool — excel output, concurrency, UUID user_id,...

chore: update batch testing tool — excel output, concurrency, UUID user_id, frontend default API URL
parent c2eaaeb7
...@@ -8,9 +8,8 @@ from typing import Any ...@@ -8,9 +8,8 @@ from typing import Any
from fastapi import APIRouter, BackgroundTasks, File, Form, HTTPException, UploadFile from fastapi import APIRouter, BackgroundTasks, File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse, JSONResponse from fastapi.responses import FileResponse, JSONResponse
from services.batch_processor import BatchProcessor from services.batch_processor import BatchProcessor
from utils.excel_handler import read_excel, create_results_excel from utils.excel_handler import create_results_excel, read_excel
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
...@@ -24,6 +23,7 @@ async def upload_and_test( ...@@ -24,6 +23,7 @@ async def upload_and_test(
file: UploadFile = File(...), file: UploadFile = File(...),
num_tests: int = Form(1), num_tests: int = Form(1),
question_column: str = Form("Câu hỏi"), question_column: str = Form("Câu hỏi"),
api_url: str = Form("http://localhost:8000"),
): ):
""" """
Upload Excel file và bắt đầu batch testing Upload Excel file và bắt đầu batch testing
...@@ -39,14 +39,18 @@ async def upload_and_test( ...@@ -39,14 +39,18 @@ async def upload_and_test(
try: try:
# Validate file # Validate file
if not file.filename or not file.filename.endswith((".xlsx", ".xls")): if not file.filename or not file.filename.endswith((".xlsx", ".xls")):
raise HTTPException(status_code=400, detail="File phải là Excel (.xlsx hoặc .xls)") raise HTTPException(
status_code=400, detail="File phải là Excel (.xlsx hoặc .xls)"
)
# Đọc file # Đọc file
file_content = await file.read() file_content = await file.read()
questions = read_excel(file_content, question_column=question_column) questions = read_excel(file_content, question_column=question_column)
if not questions: if not questions:
raise HTTPException(status_code=400, detail="Không tìm thấy câu hỏi nào trong file") raise HTTPException(
status_code=400, detail="Không tìm thấy câu hỏi nào trong file"
)
# Tạo task ID # Tạo task ID
task_id = str(uuid.uuid4()) task_id = str(uuid.uuid4())
...@@ -63,10 +67,11 @@ async def upload_and_test( ...@@ -63,10 +67,11 @@ async def upload_and_test(
} }
# Start batch processing (async) # Start batch processing (async)
processor = BatchProcessor() processor = BatchProcessor(api_url=api_url)
async def process_task(): async def process_task():
try: try:
def progress_callback(progress: dict[str, Any]): def progress_callback(progress: dict[str, Any]):
progress_store[task_id].update(progress) progress_store[task_id].update(progress)
...@@ -192,4 +197,3 @@ async def download_results(task_id: str, background_tasks: BackgroundTasks): ...@@ -192,4 +197,3 @@ async def download_results(task_id: str, background_tasks: BackgroundTasks):
async def health_check(): async def health_check():
"""Health check endpoint""" """Health check endpoint"""
return JSONResponse({"status": "ok", "message": "Batch testing tool is running"}) return JSONResponse({"status": "ok", "message": "Batch testing tool is running"})
...@@ -80,13 +80,17 @@ JWT_SECRET: str | None = os.getenv("JWT_SECRET") ...@@ -80,13 +80,17 @@ JWT_SECRET: str | None = os.getenv("JWT_SECRET")
JWT_ALGORITHM: str | None = os.getenv("JWT_ALGORITHM") JWT_ALGORITHM: str | None = os.getenv("JWT_ALGORITHM")
# ====================== SERVER CONFIG ====================== # ====================== SERVER CONFIG ======================
PORT: int = int(os.getenv("PORT", "5000")) # Lấy PORT từ environment variable, mặc định 5002
# Có thể thay đổi bằng cách set PORT=xxxx trong .env hoặc system env
PORT: int = int(os.getenv("PORT", "5002"))
FIRECRAWL_API_KEY: str | None = os.getenv("FIRECRAWL_API_KEY") FIRECRAWL_API_KEY: str | None = os.getenv("FIRECRAWL_API_KEY")
# ====================== LANGFUSE CONFIGURATION (DEPRECATED) ====================== # ====================== LANGFUSE CONFIGURATION (DEPRECATED) ======================
LANGFUSE_SECRET_KEY: str | None = os.getenv("LANGFUSE_SECRET_KEY") LANGFUSE_SECRET_KEY: str | None = os.getenv("LANGFUSE_SECRET_KEY")
LANGFUSE_PUBLIC_KEY: str | None = os.getenv("LANGFUSE_PUBLIC_KEY") LANGFUSE_PUBLIC_KEY: str | None = os.getenv("LANGFUSE_PUBLIC_KEY")
LANGFUSE_BASE_URL: str | None = os.getenv("LANGFUSE_BASE_URL", "https://cloud.langfuse.com") LANGFUSE_BASE_URL: str | None = os.getenv(
"LANGFUSE_BASE_URL", "https://cloud.langfuse.com"
)
# ====================== LANGSMITH CONFIGURATION (TẮT VÌ RATE LIMIT) ====================== # ====================== LANGSMITH CONFIGURATION (TẮT VÌ RATE LIMIT) ======================
# LANGSMITH_TRACING = os.getenv("LANGSMITH_TRACING", "false") # LANGSMITH_TRACING = os.getenv("LANGSMITH_TRACING", "false")
...@@ -107,7 +111,9 @@ CONV_DATABASE_URL: str | None = os.getenv("CONV_DATABASE_URL") ...@@ -107,7 +111,9 @@ CONV_DATABASE_URL: str | None = os.getenv("CONV_DATABASE_URL")
# ====================== MONGO CONFIGURATION ====================== # ====================== MONGO CONFIGURATION ======================
MONGODB_URI: str | None = os.getenv("MONGODB_URI", "mongodb://localhost:27017") MONGODB_URI: str | None = os.getenv("MONGODB_URI", "mongodb://localhost:27017")
MONGODB_DB_NAME: str | None = os.getenv("MONGODB_DB_NAME", "ai_law") MONGODB_DB_NAME: str | None = os.getenv("MONGODB_DB_NAME", "ai_law")
USE_MONGO_CONVERSATION: bool = os.getenv("USE_MONGO_CONVERSATION", "true").lower() == "true" USE_MONGO_CONVERSATION: bool = (
os.getenv("USE_MONGO_CONVERSATION", "true").lower() == "true"
)
# ====================== CANIFA INTERNAL POSTGRES ====================== # ====================== CANIFA INTERNAL POSTGRES ======================
CHECKPOINT_POSTGRES_URL: str | None = os.getenv("CHECKPOINT_POSTGRES_URL") CHECKPOINT_POSTGRES_URL: str | None = os.getenv("CHECKPOINT_POSTGRES_URL")
...@@ -127,11 +133,13 @@ OTEL_EXPORTER_JAEGER_AGENT_HOST = os.getenv("OTEL_EXPORTER_JAEGER_AGENT_HOST") ...@@ -127,11 +133,13 @@ OTEL_EXPORTER_JAEGER_AGENT_HOST = os.getenv("OTEL_EXPORTER_JAEGER_AGENT_HOST")
OTEL_EXPORTER_JAEGER_AGENT_PORT = os.getenv("OTEL_EXPORTER_JAEGER_AGENT_PORT") OTEL_EXPORTER_JAEGER_AGENT_PORT = os.getenv("OTEL_EXPORTER_JAEGER_AGENT_PORT")
OTEL_SERVICE_NAME = os.getenv("OTEL_SERVICE_NAME") OTEL_SERVICE_NAME = os.getenv("OTEL_SERVICE_NAME")
OTEL_TRACES_EXPORTER = os.getenv("OTEL_TRACES_EXPORTER") OTEL_TRACES_EXPORTER = os.getenv("OTEL_TRACES_EXPORTER")
OTEL_EXPORTER_JAEGER_AGENT_SPLIT_OVERSIZED_BATCHES = os.getenv("OTEL_EXPORTER_JAEGER_AGENT_SPLIT_OVERSIZED_BATCHES") OTEL_EXPORTER_JAEGER_AGENT_SPLIT_OVERSIZED_BATCHES = os.getenv(
"OTEL_EXPORTER_JAEGER_AGENT_SPLIT_OVERSIZED_BATCHES"
)
# ====================== BATCH TESTING TOOL CONFIGURATION ====================== # ====================== BATCH TESTING TOOL CONFIGURATION ======================
CHATBOT_API_URL: str = os.getenv("CHATBOT_API_URL", "http://localhost:8000") CHATBOT_API_URL: str = os.getenv("CHATBOT_API_URL", "http://localhost:8000")
CHATBOT_API_ENDPOINT: str = os.getenv("CHATBOT_API_ENDPOINT", "/api/agent/chat") CHATBOT_API_ENDPOINT: str = os.getenv("CHATBOT_API_ENDPOINT", "/api/agent/chat")
TOOL_PORT: int = int(os.getenv("TOOL_PORT", "5001")) TOOL_PORT: int = int(os.getenv("TOOL_PORT", "5002"))
MAX_CONCURRENT_REQUESTS: int = int(os.getenv("MAX_CONCURRENT_REQUESTS", "5")) MAX_CONCURRENT_REQUESTS: int = int(os.getenv("MAX_CONCURRENT_REQUESTS", "5"))
REQUEST_TIMEOUT: int = int(os.getenv("REQUEST_TIMEOUT", "60")) REQUEST_TIMEOUT: int = int(os.getenv("REQUEST_TIMEOUT", "60"))
No preview for this file type
fastapi==0.104.1
uvicorn[standard]==0.24.0
python-multipart==0.0.6
httpx==0.25.2
pandas==2.1.3
openpyxl==3.1.2
python-dotenv==1.0.0
.\.venv\Scripts\activate
\ No newline at end of file
...@@ -8,7 +8,6 @@ import time ...@@ -8,7 +8,6 @@ import time
from typing import Any from typing import Any
import httpx import httpx
from config import CHATBOT_API_ENDPOINT, CHATBOT_API_URL, REQUEST_TIMEOUT from config import CHATBOT_API_ENDPOINT, CHATBOT_API_URL, REQUEST_TIMEOUT
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -17,8 +16,8 @@ logger = logging.getLogger(__name__) ...@@ -17,8 +16,8 @@ logger = logging.getLogger(__name__)
class ChatbotAPIClient: class ChatbotAPIClient:
"""Client để gọi chatbot API""" """Client để gọi chatbot API"""
def __init__(self): def __init__(self, api_url: str | None = None):
self.base_url = CHATBOT_API_URL self.base_url = api_url or CHATBOT_API_URL
self.endpoint = CHATBOT_API_ENDPOINT self.endpoint = CHATBOT_API_ENDPOINT
self.timeout = REQUEST_TIMEOUT self.timeout = REQUEST_TIMEOUT
...@@ -50,9 +49,17 @@ class ChatbotAPIClient: ...@@ -50,9 +49,17 @@ class ChatbotAPIClient:
} }
""" """
if not user_id: if not user_id:
user_id = f"batch_test_{test_id}_{test_attempt}" if test_id else "batch_test_user" user_id = (
f"batch_test_{test_id}_{test_attempt}" if test_id else "batch_test_user"
)
# Nếu api_url đã chứa endpoint (kết thúc bằng /chat), dùng trực tiếp
# Nếu không, thêm endpoint vào
if self.base_url.endswith("/chat") or self.base_url.endswith("/api/agent/chat"):
url = self.base_url
else:
url = f"{self.base_url}{self.endpoint}"
url = f"{self.base_url}{self.endpoint}"
payload = { payload = {
"user_query": query, "user_query": query,
"user_id": user_id, "user_id": user_id,
...@@ -75,7 +82,9 @@ class ChatbotAPIClient: ...@@ -75,7 +82,9 @@ class ChatbotAPIClient:
ai_response = data.get("ai_response", "") ai_response = data.get("ai_response", "")
product_ids = data.get("product_ids", []) product_ids = data.get("product_ids", [])
else: else:
error_message = f"API returned status {response.status_code}: {response.text}" error_message = (
f"API returned status {response.status_code}: {response.text}"
)
logger.error(error_message) logger.error(error_message)
except httpx.TimeoutException: except httpx.TimeoutException:
...@@ -123,4 +132,3 @@ class ChatbotAPIClient: ...@@ -123,4 +132,3 @@ class ChatbotAPIClient:
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
return results return results
...@@ -5,6 +5,7 @@ Batch Processor để xử lý batch testing ...@@ -5,6 +5,7 @@ Batch Processor để xử lý batch testing
import asyncio import asyncio
import logging import logging
import statistics import statistics
import uuid
from datetime import datetime from datetime import datetime
from typing import Any, Callable from typing import Any, Callable
...@@ -17,8 +18,8 @@ logger = logging.getLogger(__name__) ...@@ -17,8 +18,8 @@ logger = logging.getLogger(__name__)
class BatchProcessor: class BatchProcessor:
"""Processor để xử lý batch testing""" """Processor để xử lý batch testing"""
def __init__(self, max_concurrent: int = 5): def __init__(self, max_concurrent: int = 5, api_url: str | None = None):
self.api_client = ChatbotAPIClient() self.api_client = ChatbotAPIClient(api_url=api_url)
self.langfuse_client = LangfuseClient() self.langfuse_client = LangfuseClient()
self.max_concurrent = max_concurrent self.max_concurrent = max_concurrent
...@@ -45,35 +46,33 @@ class BatchProcessor: ...@@ -45,35 +46,33 @@ class BatchProcessor:
total_questions = len(questions) total_questions = len(questions)
total_tests = total_questions * num_tests_per_question total_tests = total_questions * num_tests_per_question
detailed_results = [] detailed_results: list[dict[str, Any]] = []
aggregated_results = [] aggregated_results: list[dict[str, Any]] = []
start_time = datetime.now() start_time = datetime.now()
processed = 0 processed = 0
successful = 0 successful = 0
failed = 0 failed = 0
total_cost = 0.0
# Process từng câu hỏi semaphore = asyncio.Semaphore(self.max_concurrent)
for question_data in questions: lock = asyncio.Lock()
async def run_test(question_data: dict[str, Any], test_attempt: int):
nonlocal processed, successful, failed
question_id = question_data["id"] question_id = question_data["id"]
question = question_data["question"] question = question_data["question"]
row_data = question_data.get("row_data", {}) row_data = question_data.get("row_data", {})
test_id = f"q{question_id}_t{test_attempt}"
try:
async with semaphore:
result = await self.api_client.chat(
query=question,
user_id=str(uuid.uuid4()),
test_id=test_id,
test_attempt=test_attempt,
)
# Test câu hỏi này N lần
test_results = []
for test_attempt in range(1, num_tests_per_question + 1):
test_id = f"q{question_id}_t{test_attempt}"
# Gọi API
result = await self.api_client.chat(
query=question,
user_id=f"batch_test_{question_id}",
test_id=str(question_id),
test_attempt=test_attempt,
)
# Thêm thông tin vào result
result["question_id"] = question_id result["question_id"] = question_id
result["question"] = question result["question"] = question
result["test_attempt"] = test_attempt result["test_attempt"] = test_attempt
...@@ -81,43 +80,67 @@ class BatchProcessor: ...@@ -81,43 +80,67 @@ class BatchProcessor:
# Lấy thêm metrics từ Langfuse (optional) # Lấy thêm metrics từ Langfuse (optional)
if self.langfuse_client.enabled: if self.langfuse_client.enabled:
# Tìm trace từ Langfuse (có thể cần delay để Langfuse sync) await asyncio.sleep(1)
await asyncio.sleep(1) # Đợi Langfuse sync
langfuse_metrics = await self.langfuse_client.get_trace_metrics( langfuse_metrics = await self.langfuse_client.get_trace_metrics(
trace_id=test_id trace_id=test_id
) )
if langfuse_metrics: if langfuse_metrics:
result.update(langfuse_metrics) result.update(langfuse_metrics)
test_results.append(result) except Exception as e:
result = {
"status": "error",
"ai_response": "",
"product_ids": [],
"latency_ms": 0,
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"error_message": str(e),
"question_id": question_id,
"question": question,
"test_attempt": test_attempt,
"product_count": 0,
}
async with lock:
detailed_results.append(result) detailed_results.append(result)
# Update counters
processed += 1 processed += 1
if result["status"] == "success": if result.get("status") == "success":
successful += 1 successful += 1
else: else:
failed += 1 failed += 1
# Update progress
if progress_callback: if progress_callback:
progress_callback( progress_callback(
{ {
"processed": processed, "processed": processed,
"total": total_tests, "total": total_tests,
"current_question": question_id, "current_question": result.get("question_id"),
"current_attempt": test_attempt, "current_attempt": result.get("test_attempt"),
"successful": successful, "successful": successful,
"failed": failed, "failed": failed,
} }
) )
# Tính toán aggregated metrics cho câu hỏi này # Tạo và chạy tasks cho tất cả tests (các question được xử lý song song)
aggregated = self._aggregate_test_results(question_id, question, test_results, row_data) tasks = []
for q in questions:
for attempt in range(1, num_tests_per_question + 1):
tasks.append(asyncio.create_task(run_test(q, attempt)))
if tasks:
await asyncio.gather(*tasks)
# Tính aggregated results per question
for q in questions:
qid = q["id"]
q_results = [r for r in detailed_results if r.get("question_id") == qid]
aggregated = self._aggregate_test_results(
qid, q.get("question", ""), q_results, q.get("row_data", {})
)
aggregated_results.append(aggregated) aggregated_results.append(aggregated)
# Update total cost # Tính tổng chi phí từ detailed results
total_cost += aggregated.get("avg_cost", 0.0) * num_tests_per_question total_cost = round(sum(r.get("cost", 0) for r in detailed_results), 4)
# Tính tổng kết # Tính tổng kết
end_time = datetime.now() end_time = datetime.now()
...@@ -129,9 +152,13 @@ class BatchProcessor: ...@@ -129,9 +152,13 @@ class BatchProcessor:
"total_tests": total_tests, "total_tests": total_tests,
"successful": successful, "successful": successful,
"failed": failed, "failed": failed,
"success_rate": round((successful / total_tests * 100) if total_tests > 0 else 0, 2), "success_rate": round(
(successful / total_tests * 100) if total_tests > 0 else 0, 2
),
"total_cost_usd": round(total_cost, 4), "total_cost_usd": round(total_cost, 4),
"avg_cost_per_test": round(total_cost / total_tests if total_tests > 0 else 0, 4), "avg_cost_per_test": round(
total_cost / total_tests if total_tests > 0 else 0, 4
),
"duration_seconds": round(duration_seconds, 2), "duration_seconds": round(duration_seconds, 2),
"start_time": start_time.strftime("%Y-%m-%d %H:%M:%S"), "start_time": start_time.strftime("%Y-%m-%d %H:%M:%S"),
"end_time": end_time.strftime("%Y-%m-%d %H:%M:%S"), "end_time": end_time.strftime("%Y-%m-%d %H:%M:%S"),
...@@ -181,15 +208,21 @@ class BatchProcessor: ...@@ -181,15 +208,21 @@ class BatchProcessor:
"total_tests": total_count, "total_tests": total_count,
"successful_tests": success_count, "successful_tests": success_count,
"failed_tests": total_count - success_count, "failed_tests": total_count - success_count,
"success_rate": round((success_count / total_count * 100) if total_count > 0 else 0, 2), "success_rate": round(
(success_count / total_count * 100) if total_count > 0 else 0, 2
),
"avg_latency_ms": round(statistics.mean(latencies), 2) if latencies else 0, "avg_latency_ms": round(statistics.mean(latencies), 2) if latencies else 0,
"min_latency_ms": round(min(latencies), 2) if latencies else 0, "min_latency_ms": round(min(latencies), 2) if latencies else 0,
"max_latency_ms": round(max(latencies), 2) if latencies else 0, "max_latency_ms": round(max(latencies), 2) if latencies else 0,
"avg_cost_usd": round(statistics.mean(costs), 4) if costs else 0, "avg_cost_usd": round(statistics.mean(costs), 4) if costs else 0,
"total_cost_usd": round(sum(costs), 4), "total_cost_usd": round(sum(costs), 4),
"avg_product_count": round(statistics.mean(product_counts), 2) if product_counts else 0, "avg_product_count": round(statistics.mean(product_counts), 2)
if product_counts
else 0,
"response_consistency": consistency, "response_consistency": consistency,
"sample_response": successful_results[0].get("ai_response", "") if successful_results else "", "sample_response": successful_results[0].get("ai_response", "")
if successful_results
else "",
} }
# Thêm các cột gốc từ Excel # Thêm các cột gốc từ Excel
...@@ -224,7 +257,9 @@ class BatchProcessor: ...@@ -224,7 +257,9 @@ class BatchProcessor:
similarities = [] similarities = []
for i in range(len(responses)): for i in range(len(responses)):
for j in range(i + 1, len(responses)): for j in range(i + 1, len(responses)):
similarities.append(self._simple_similarity(responses[i], responses[j])) similarities.append(
self._simple_similarity(responses[i], responses[j])
)
similarity = statistics.mean(similarities) if similarities else 0 similarity = statistics.mean(similarities) if similarities else 0
# Đánh giá # Đánh giá
...@@ -251,4 +286,3 @@ class BatchProcessor: ...@@ -251,4 +286,3 @@ class BatchProcessor:
union = len(words1 | words2) union = len(words1 | words2)
return intersection / union if union > 0 else 0.0 return intersection / union if union > 0 else 0.0
...@@ -48,6 +48,15 @@ ...@@ -48,6 +48,15 @@
padding: 40px; padding: 40px;
} }
.section-title {
font-size: 1.3em;
font-weight: 700;
margin-bottom: 20px;
color: #333;
border-bottom: 3px solid #667eea;
padding-bottom: 10px;
}
.form-section { .form-section {
background: #f8f9fa; background: #f8f9fa;
padding: 30px; padding: 30px;
...@@ -55,8 +64,21 @@ ...@@ -55,8 +64,21 @@
margin-bottom: 30px; margin-bottom: 30px;
} }
.form-row {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 20px;
margin-bottom: 20px;
}
@media (max-width: 768px) {
.form-row {
grid-template-columns: 1fr;
}
}
.form-group { .form-group {
margin-bottom: 25px; margin-bottom: 20px;
} }
.form-group label { .form-group label {
...@@ -64,6 +86,15 @@ ...@@ -64,6 +86,15 @@
margin-bottom: 8px; margin-bottom: 8px;
font-weight: 600; font-weight: 600;
color: #333; color: #333;
font-size: 0.95em;
}
.form-group .hint {
display: block;
font-size: 0.85em;
color: #666;
margin-top: 5px;
font-weight: normal;
} }
.form-group input, .form-group input,
...@@ -248,28 +279,41 @@ ...@@ -248,28 +279,41 @@
<div class="content"> <div class="content">
<form id="uploadForm" class="form-section"> <form id="uploadForm" class="form-section">
<!-- Section 1: API Configuration -->
<div class="section-title">⚙️ API Configuration</div>
<div class="form-group">
<label for="apiUrl">🔗 Chatbot API URL</label>
<input type="text" id="apiUrl" value="http://localhost:5000/api/agent/chat" placeholder="http://localhost:5000/api/agent/chat" required>
<span class="hint">URL của API chatbot để test (ví dụ: http://localhost:8000/api/agent/chat hoặc http://localhost:5000)</span>
</div>
<!-- Section 2: File Upload -->
<div class="section-title">📁 Upload File</div>
<div class="form-group"> <div class="form-group">
<label for="excelFile">📁 Chọn file Excel:</label> <label for="excelFile">Chọn file Excel</label>
<div class="file-upload"> <div class="file-upload">
<input type="file" id="excelFile" accept=".xlsx,.xls" required> <input type="file" id="excelFile" accept=".xlsx,.xls" required>
<label for="excelFile" class="file-upload-label"> <label for="excelFile" class="file-upload-label">
<span>📎 Click để chọn file hoặc kéo thả file vào đây</span> <span>📎 Click để chọn hoặc kéo thả file Excel vào đây</span>
</label> </label>
</div> </div>
<div id="fileName" class="file-name"></div> <div id="fileName" class="file-name"></div>
<span class="hint">File Excel phải chứa ít nhất 1 cột câu hỏi</span>
</div> </div>
<div class="form-group"> <!-- Section 3: Test Configuration -->
<label for="questionColumn">🔤 Tên cột chứa câu hỏi:</label> <div class="section-title">⚙️ Test Configuration</div>
<input type="text" id="questionColumn" value="Câu hỏi" placeholder="Câu hỏi, Question, Query..."> <div class="form-row">
</div> <div class="form-group">
<label for="questionColumn">🔤 Tên cột chứa câu hỏi</label>
<div class="form-group"> <input type="text" id="questionColumn" value="Câu hỏi" placeholder="Câu hỏi, Question, Query..." required>
<label for="numTests">🔄 Số lần test mỗi câu hỏi:</label> <span class="hint">Tên cột trong Excel chứa câu hỏi</span>
<input type="number" id="numTests" value="1" min="1" max="10" required> </div>
<small style="color: #666; margin-top: 5px; display: block;"> <div class="form-group">
Mỗi câu hỏi sẽ được test số lần này để đánh giá consistency <label for="numTests">🔄 Số lần test mỗi câu hỏi</label>
</small> <input type="number" id="numTests" value="1" min="1" max="10" required>
<span class="hint">Mỗi câu hỏi sẽ được test N lần để đánh giá consistency</span>
</div>
</div> </div>
<button type="submit" class="btn" id="submitBtn"> <button type="submit" class="btn" id="submitBtn">
...@@ -322,6 +366,7 @@ ...@@ -322,6 +366,7 @@
const errorMessage = document.getElementById('errorMessage'); const errorMessage = document.getElementById('errorMessage');
const infoMessage = document.getElementById('infoMessage'); const infoMessage = document.getElementById('infoMessage');
let progressInterval = null; let progressInterval = null;
let lastProcessed = -1;
let currentTaskId = null; let currentTaskId = null;
// File upload handler // File upload handler
...@@ -335,6 +380,7 @@ ...@@ -335,6 +380,7 @@
e.preventDefault(); e.preventDefault();
const fileInput = document.getElementById('excelFile'); const fileInput = document.getElementById('excelFile');
const apiUrl = document.getElementById('apiUrl').value;
const questionColumn = document.getElementById('questionColumn').value; const questionColumn = document.getElementById('questionColumn').value;
const numTests = parseInt(document.getElementById('numTests').value); const numTests = parseInt(document.getElementById('numTests').value);
...@@ -343,6 +389,11 @@ ...@@ -343,6 +389,11 @@
return; return;
} }
if (!apiUrl) {
showError('Vui lòng nhập API URL');
return;
}
// Hide previous messages // Hide previous messages
hideMessages(); hideMessages();
resultSection.classList.remove('active'); resultSection.classList.remove('active');
...@@ -356,6 +407,7 @@ ...@@ -356,6 +407,7 @@
formData.append('file', fileInput.files[0]); formData.append('file', fileInput.files[0]);
formData.append('num_tests', numTests); formData.append('num_tests', numTests);
formData.append('question_column', questionColumn); formData.append('question_column', questionColumn);
formData.append('api_url', apiUrl);
try { try {
const response = await fetch('/api/batch-test/upload', { const response = await fetch('/api/batch-test/upload', {
...@@ -387,12 +439,20 @@ ...@@ -387,12 +439,20 @@
clearInterval(progressInterval); clearInterval(progressInterval);
} }
// reset last processed counter for this task
lastProcessed = -1;
progressInterval = setInterval(async () => { progressInterval = setInterval(async () => {
try { try {
const response = await fetch(`/api/batch-test/progress/${taskId}`); const response = await fetch(`/api/batch-test/progress/${taskId}`);
const progress = await response.json(); const progress = await response.json();
updateProgress(progress); // Only update UI when processed count changes or status not 'processing'
const processedNow = progress.processed || 0;
if (processedNow !== lastProcessed || progress.status !== 'processing') {
updateProgress(progress);
lastProcessed = processedNow;
}
if (progress.status === 'completed') { if (progress.status === 'completed') {
clearInterval(progressInterval); clearInterval(progressInterval);
......
...@@ -17,11 +17,14 @@ def _get_pandas(): ...@@ -17,11 +17,14 @@ def _get_pandas():
global _pandas global _pandas
if _pandas is None: if _pandas is None:
import pandas as pd import pandas as pd
_pandas = pd _pandas = pd
return _pandas return _pandas
def read_excel(file_content: bytes, question_column: str = "Câu hỏi") -> list[dict[str, Any]]: def read_excel(
file_content: bytes, question_column: str = "Câu hỏi"
) -> list[dict[str, Any]]:
""" """
Đọc Excel file và extract câu hỏi Đọc Excel file và extract câu hỏi
...@@ -39,12 +42,22 @@ def read_excel(file_content: bytes, question_column: str = "Câu hỏi") -> list ...@@ -39,12 +42,22 @@ def read_excel(file_content: bytes, question_column: str = "Câu hỏi") -> list
# Tìm cột câu hỏi (case-insensitive) # Tìm cột câu hỏi (case-insensitive)
question_col = None question_col = None
for col in df.columns: for col in df.columns:
if question_column.lower() in col.lower() or "question" in col.lower() or "query" in col.lower(): if (
question_column.lower() in col.lower()
or "question" in col.lower()
or "query" in col.lower()
):
question_col = col question_col = col
break break
# Nếu chỉ có 1 cột trong file, mặc định dùng cột đó làm question
if question_col is None: if question_col is None:
raise ValueError(f"Không tìm thấy cột '{question_column}' trong file Excel") if df.shape[1] == 1:
question_col = df.columns[0]
else:
raise ValueError(
f"Không tìm thấy cột '{question_column}' trong file Excel"
)
results = [] results = []
for idx, row in df.iterrows(): for idx, row in df.iterrows():
...@@ -78,7 +91,9 @@ def create_results_excel( ...@@ -78,7 +91,9 @@ def create_results_excel(
aggregated_results: list[dict[str, Any]], aggregated_results: list[dict[str, Any]],
) -> bytes: ) -> bytes:
""" """
Tạo Excel file với 3 sheets: Summary, Results, Aggregated Tạo Excel file với 1 sheet duy nhất:
- Cột Question (câu hỏi)
- Cột Answer1, Answer2, ... (theo số lần test)
Args: Args:
summary_data: Dict tổng kết summary_data: Dict tổng kết
...@@ -92,21 +107,51 @@ def create_results_excel( ...@@ -92,21 +107,51 @@ def create_results_excel(
pd = _get_pandas() pd = _get_pandas()
output = BytesIO() output = BytesIO()
# Type ignore for BytesIO - pandas accepts it at runtime # Lấy số lần test từ summary
with pd.ExcelWriter(output, engine="openpyxl") as writer: # type: ignore num_tests = int(summary_data.get("num_tests_per_question", 1))
# Sheet 1: Summary
summary_df = pd.DataFrame([summary_data]) # Xây dựng dữ liệu cho sheet: mỗi row là 1 câu hỏi + các answers
summary_df.to_excel(writer, sheet_name="Summary", index=False) sheet_data: list[dict[str, Any]] = []
for agg in aggregated_results:
row: dict[str, Any] = {"Question": agg.get("question", "")}
question_id = agg.get("question_id")
for test_num in range(1, num_tests + 1):
# Tìm kết quả của lần test này trong detailed_results
result = next(
(
r
for r in detailed_results
if r.get("question_id") == question_id
and r.get("test_attempt") == test_num
),
None,
)
# Sheet 2: Results (chi tiết từng lần test) answer = result.get("ai_response", "") if result else ""
if detailed_results: row[f"Answer {test_num}"] = answer
results_df = pd.DataFrame(detailed_results)
results_df.to_excel(writer, sheet_name="Results", index=False)
# Sheet 3: Aggregated (tổng hợp theo câu hỏi) sheet_data.append(row)
if aggregated_results:
aggregated_df = pd.DataFrame(aggregated_results) # Tạo DataFrame và ghi ra Excel
aggregated_df.to_excel(writer, sheet_name="Aggregated", index=False) results_df = pd.DataFrame(sheet_data)
with pd.ExcelWriter(output, engine="openpyxl") as writer: # type: ignore
results_df.to_excel(writer, sheet_name="Results", index=False)
# Format column width
worksheet = writer.sheets["Results"]
for column_cells in worksheet.columns:
max_length = 0
column_letter = column_cells[0].column_letter
for cell in column_cells:
try:
if cell.value:
max_length = max(max_length, len(str(cell.value)))
except Exception:
pass
adjusted_width = min(max_length + 2, 50)
worksheet.column_dimensions[column_letter].width = adjusted_width
output.seek(0) output.seek(0)
return output.getvalue() return output.getvalue()
...@@ -114,4 +159,3 @@ def create_results_excel( ...@@ -114,4 +159,3 @@ def create_results_excel(
except Exception as e: except Exception as e:
logger.error(f"Error creating Excel: {e}", exc_info=True) logger.error(f"Error creating Excel: {e}", exc_info=True)
raise raise
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment