"""
|
|
AI Quality Evaluator for Norda Biznes Hub
|
|
|
|
Evaluates the quality of AI chat responses by:
|
|
1. Loading test cases from JSON
|
|
2. Sending queries to the chat engine
|
|
3. Extracting company names from responses
|
|
4. Comparing with expected companies
|
|
5. Calculating pass rate and metrics
|
|
"""
|
|
|
|
import argparse
import json
import os
import re
import sys
import time
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Optional, Set, Tuple

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))

from database import SessionLocal, init_db, Company
from nordabiz_chat import NordaBizChatEngine
import gemini_service
|
|
|
|
|
@dataclass
class TestResult:
    """Result of a single test case evaluation"""
    # Identifier of the test case (taken from the test-case JSON; "unknown" if absent)
    test_id: str
    # The natural-language query that was sent to the chat engine
    query: str
    # Company names the test case expects to appear in the response
    expected_companies: List[str]
    # Company names actually detected in the AI response text
    found_companies: List[str]
    # Subset of expected companies that were matched against the found ones
    matched_companies: List[str]
    # AI response text (truncated by the evaluator before storage)
    response_text: str
    score: float  # 0.0 to 1.0
    # True when score reached the configured pass threshold
    passed: bool
    # Wall-clock time spent evaluating this case, in milliseconds
    execution_time_ms: int
    # Error message when evaluation raised; None on success
    error: Optional[str] = None
|
|
|
|
|
|
@dataclass
class EvaluationReport:
    """Full evaluation report"""
    # When the report was assembled (end of the evaluation run)
    timestamp: datetime
    # Number of test cases executed
    total_tests: int
    # Number of cases whose score met the pass threshold
    passed_tests: int
    # Number of cases that failed (total_tests - passed_tests)
    failed_tests: int
    # passed_tests / total_tests (0.0 when no tests ran)
    pass_rate: float
    # Mean of per-test scores (0.0 when no tests ran)
    average_score: float
    # Individual per-test results, in execution order
    results: List[TestResult] = field(default_factory=list)
    # Per-category aggregates: {"total", "passed", "pass_rate", "avg_score"}
    summary_by_category: Dict[str, Dict] = field(default_factory=dict)
|
|
|
|
|
|
class AIQualityEvaluator:
    """Evaluates AI chat quality using test cases.

    Loads test cases from JSON, sends each query through the chat
    engine, extracts company names from the responses, and scores
    them against the expected companies.
    """

    def __init__(self, test_cases_path: str = None):
        """
        Initialize evaluator.

        Args:
            test_cases_path: Path to test cases JSON file. Defaults to
                ``ai_quality_test_cases.json`` next to this module.

        Raises:
            FileNotFoundError: If the test cases file does not exist.
        """
        if test_cases_path is None:
            test_cases_path = Path(__file__).parent / "ai_quality_test_cases.json"

        self.test_cases_path = Path(test_cases_path)
        self.test_cases = self._load_test_cases()
        self.config = self.test_cases.get("evaluation_config", {})

        # Initialize database
        init_db()
        self.db = SessionLocal()

        # Initialize Gemini service
        gemini_service.init_gemini_service(model='flash-2.0')

        # Load all company names for matching
        self.all_companies = self._load_company_names()

    def _load_test_cases(self) -> Dict:
        """Load test cases from JSON file."""
        if not self.test_cases_path.exists():
            raise FileNotFoundError(f"Test cases file not found: {self.test_cases_path}")

        with open(self.test_cases_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _load_company_names(self) -> Set[str]:
        """Load all active company names from the database."""
        companies = self.db.query(Company).filter_by(status='active').all()
        return {c.name for c in companies}

    def _extract_companies_from_response(self, response: str) -> List[str]:
        """
        Extract company names mentioned in the AI response.

        Uses multiple strategies:
        1. Direct name matching (full company name is a substring of
           the response, case-insensitive)
        2. Partial matching (distinctive first word of the company name
           appears as a standalone word in the response)
        """
        found = []

        # Normalize the response once for matching.
        response_lower = response.lower()
        # Hoisted: standalone lowercase words of the response, built once
        # instead of re-splitting the response for every company.
        response_words = {word.lower() for word in response.split()}

        for company_name in self.all_companies:
            # Strategy 1: full company name appears in the response.
            if company_name.lower() in response_lower:
                found.append(company_name)
                continue

            # Strategy 2: distinctive first word of the company name.
            # Guard: an empty/whitespace-only name from the DB would make
            # split()[0] raise IndexError.
            name_words = company_name.split()
            if not name_words:
                continue
            first_word = name_words[0].lower()
            # Only words longer than 3 chars count as distinctive; require
            # an exact standalone-word hit to reduce substring false
            # positives (e.g. "acme" inside another word).
            if len(first_word) > 3 and first_word in response_lower:
                if first_word in response_words:
                    found.append(company_name)

        return found

    def _calculate_score(
        self,
        expected: List[str],
        found: List[str]
    ) -> Tuple[float, List[str]]:
        """
        Calculate match score.

        Score is the proportion of expected companies present in the
        found list. Matching is case-insensitive; partial containment in
        either direction also counts as a match.

        Returns:
            Tuple of (score, list of matched companies)
        """
        # No expectations: trivially perfect.
        if not expected:
            return 1.0, []

        found_lower = {f.lower() for f in found}

        # Find matches (case-insensitive, partial matching)
        matched = []
        for exp in expected:
            exp_lower = exp.lower()
            # Check exact (case-insensitive) match first.
            if exp_lower in found_lower:
                matched.append(exp)
                continue
            # Partial match: one name contains the other.
            for f in found:
                if exp_lower in f.lower() or f.lower() in exp_lower:
                    matched.append(exp)
                    break

        # Score: proportion of expected companies found
        score = len(matched) / len(expected)

        return score, matched

    def evaluate_single(self, test_case: Dict, verbose: bool = False) -> "TestResult":
        """
        Evaluate a single test case.

        Starts a fresh conversation with the chat engine, sends the
        query, extracts companies from the response, and scores the
        result. Any exception is captured as a failed TestResult rather
        than propagated, so one broken case cannot abort a full run.

        Args:
            test_case: Test case dictionary; requires "query" and
                "expected_companies", optional "id".
            verbose: Print details during evaluation.
        """
        test_id = test_case.get("id", "unknown")
        query = test_case["query"]
        expected = test_case["expected_companies"]

        if verbose:
            print(f"\n [{test_id}] Query: {query}")

        start_time = time.time()

        try:
            # Create chat engine and start a test conversation
            engine = NordaBizChatEngine(self.db)

            # Use a test user ID (1 = system user for testing)
            test_user_id = 1
            conversation = engine.start_conversation(
                user_id=test_user_id,
                title=f"Test: {test_id}",
                conversation_type='test'
            )

            # Send message to the conversation
            response = engine.send_message(conversation.id, query)
            response_text = response.get("response", "")

            execution_time_ms = int((time.time() - start_time) * 1000)

            # Extract companies from response
            found = self._extract_companies_from_response(response_text)

            # Calculate score
            score, matched = self._calculate_score(expected, found)

            # Determine pass/fail against the configured threshold.
            threshold = self.config.get("pass_threshold", 0.7)
            passed = score >= threshold

            if verbose:
                status = "PASS" if passed else "FAIL"
                print(f" Status: {status} (score: {score:.2f})")
                print(f" Expected: {expected}")
                print(f" Found: {found}")
                if matched:
                    print(f" Matched: {matched}")

            return TestResult(
                test_id=test_id,
                query=query,
                expected_companies=expected,
                found_companies=found,
                matched_companies=matched,
                response_text=response_text[:500],  # Truncate for storage
                score=score,
                passed=passed,
                execution_time_ms=execution_time_ms
            )

        except Exception as e:
            # Record the failure as a zero-score result with the error
            # message attached instead of aborting the evaluation run.
            execution_time_ms = int((time.time() - start_time) * 1000)
            if verbose:
                print(f" ERROR: {str(e)}")

            return TestResult(
                test_id=test_id,
                query=query,
                expected_companies=expected,
                found_companies=[],
                matched_companies=[],
                response_text="",
                score=0.0,
                passed=False,
                execution_time_ms=execution_time_ms,
                error=str(e)
            )

    def evaluate_all(self, verbose: bool = False) -> "EvaluationReport":
        """
        Run all test cases and generate report.

        Args:
            verbose: Print details during evaluation (per-test status
                plus a final summary).
        """
        test_cases = self.test_cases.get("test_cases", [])

        if verbose:
            print(f"\n{'='*60}")
            print(f"AI Quality Evaluation - {len(test_cases)} test cases")
            print(f"{'='*60}")

        results = []
        category_stats = {}

        for tc in test_cases:
            result = self.evaluate_single(tc, verbose)
            results.append(result)

            # Update per-category running totals.
            category = tc.get("category", "Other")
            if category not in category_stats:
                category_stats[category] = {"total": 0, "passed": 0, "scores": []}

            category_stats[category]["total"] += 1
            if result.passed:
                category_stats[category]["passed"] += 1
            category_stats[category]["scores"].append(result.score)

        # Calculate overall metrics (guard against an empty test list).
        passed = sum(1 for r in results if r.passed)
        total = len(results)
        pass_rate = passed / total if total > 0 else 0.0
        avg_score = sum(r.score for r in results) / total if total > 0 else 0.0

        # Finalize category stats
        summary_by_category = {}
        for cat, stats in category_stats.items():
            summary_by_category[cat] = {
                "total": stats["total"],
                "passed": stats["passed"],
                "pass_rate": stats["passed"] / stats["total"] if stats["total"] > 0 else 0.0,
                "avg_score": sum(stats["scores"]) / len(stats["scores"]) if stats["scores"] else 0.0
            }

        report = EvaluationReport(
            timestamp=datetime.now(),
            total_tests=total,
            passed_tests=passed,
            failed_tests=total - passed,
            pass_rate=pass_rate,
            average_score=avg_score,
            results=results,
            summary_by_category=summary_by_category
        )

        if verbose:
            self._print_summary(report)

        return report

    def _print_summary(self, report: "EvaluationReport"):
        """Print evaluation summary: totals, per-category stats, and failed tests."""
        print(f"\n{'='*60}")
        print("EVALUATION SUMMARY")
        print(f"{'='*60}")
        print(f"Total tests: {report.total_tests}")
        print(f"Passed: {report.passed_tests}")
        print(f"Failed: {report.failed_tests}")
        print(f"Pass rate: {report.pass_rate:.1%}")
        print(f"Average score: {report.average_score:.2f}")

        print(f"\nBy Category:")
        for cat, stats in sorted(report.summary_by_category.items()):
            print(f"  {cat}: {stats['passed']}/{stats['total']} ({stats['pass_rate']:.0%})")

        # Show failed tests with what was expected vs. actually found.
        failed = [r for r in report.results if not r.passed]
        if failed:
            print(f"\nFailed Tests:")
            for r in failed:
                print(f"  [{r.test_id}] {r.query}")
                print(f"    Expected: {r.expected_companies}")
                print(f"    Found: {r.found_companies}")

    def close(self):
        """Close database connection"""
        self.db.close()

    def __enter__(self) -> "AIQualityEvaluator":
        """Support use as a context manager."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the database session on exit; never suppresses exceptions."""
        self.close()
        return False
|
|
|
|
|
|
def main():
    """Run evaluation from command line.

    Exits with status 0 when the overall pass rate meets the pass
    threshold, 1 otherwise.
    """
    import argparse

    parser = argparse.ArgumentParser(description="AI Quality Evaluator for Norda Biznes")
    parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
    parser.add_argument("--test-cases", "-t", type=str, help="Path to test cases JSON")
    args = parser.parse_args()

    evaluator = AIQualityEvaluator(args.test_cases)

    try:
        report = evaluator.evaluate_all(verbose=args.verbose)

        # Use the same configurable threshold the individual test cases
        # are judged by (was hard-coded here, which could disagree with
        # evaluation_config's pass_threshold); 0.7 is the shared default.
        threshold = evaluator.config.get("pass_threshold", 0.7)
        if report.pass_rate >= threshold:
            print(f"\nEvaluation PASSED (pass rate: {report.pass_rate:.1%} >= {threshold:.0%})")
            sys.exit(0)
        else:
            print(f"\nEvaluation FAILED (pass rate: {report.pass_rate:.1%} < {threshold:.0%})")
            sys.exit(1)

    finally:
        # Always release the DB session, even when sys.exit raises SystemExit.
        evaluator.close()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|