250 lines
11 KiB
Python
250 lines
11 KiB
Python
import pytest
|
|
from fastapi import HTTPException
|
|
from jira_webhook_llm import app
|
|
from llm.models import JiraWebhookPayload
|
|
from database.crud import create_analysis_record, get_analysis_by_id
|
|
from database.models import JiraAnalysis
|
|
from database.database import get_db
|
|
from database.models import JiraAnalysis
|
|
from database.crud import create_analysis_record, update_record_status, get_analysis_by_id
|
|
from unittest.mock import MagicMock # Import MagicMock
|
|
from datetime import datetime, timezone
|
|
|
|
def test_error_handling_middleware(test_client, mock_jira_payload):
    """Unknown routes answer 404 and malformed webhook bodies answer 422.

    Both error responses are expected to expose FastAPI's standard
    ``detail`` key.
    """
    not_found = test_client.post("/nonexistent-endpoint", json={})
    assert not_found.status_code == 404
    # FastAPI's default 404 body reports the reason under "detail".
    assert "detail" in not_found.json()

    # Strip a required field so request validation rejects the body.
    broken_payload = mock_jira_payload.copy()
    del broken_payload["issueKey"]

    rejected = test_client.post("/api/jira-webhook", json=broken_payload)
    assert rejected.status_code == 422
    # FastAPI's default 422 body lists validation errors under "detail".
    assert "detail" in rejected.json()
|
|
|
|
def test_webhook_handler(setup_db, test_client, mock_full_jira_payload, monkeypatch):
    """Post a full Jira payload to the webhook and verify the stored record.

    The LLM analysis chain is replaced with a stub so no external call is made.
    """
    from unittest.mock import AsyncMock

    # The handler uses ``ainvoke`` (as per webhooks/handlers.py), i.e. it awaits
    # the call. A plain MagicMock attribute would return the dict itself, and
    # awaiting a dict raises TypeError, so the method must be an AsyncMock.
    mock_chain = MagicMock()
    mock_chain.ainvoke = AsyncMock(
        return_value={
            "hasMultipleEscalations": False,
            "customerSentiment": "neutral",
            "analysisSummary": "Mock analysis summary.",
            "actionableItems": ["Mock action item 1", "Mock action item 2"],
            "analysisFlags": ["mock_flag"],
        }
    )

    # NOTE(review): this patches the attribute on ``llm.chains``; if the handler
    # module does ``from llm.chains import analysis_chain`` the patch will not
    # reach it and the target should be the handler module instead — confirm.
    monkeypatch.setattr("llm.chains.analysis_chain", mock_chain)

    # Test successful webhook handling with full payload
    response = test_client.post("/api/jira-webhook", json=mock_full_jira_payload)
    assert response.status_code == 202
    response_data = response.json()
    assert "status" in response_data
    assert response_data["status"] in ["success", "skipped", "queued"]
    if response_data["status"] == "success":
        assert "analysis_flags" in response_data

    # Validate database storage (JiraAnalysis and get_db are imported at module level).
    with get_db() as db:
        record = (
            db.query(JiraAnalysis)
            .filter_by(issue_key=mock_full_jira_payload["issueKey"])
            .first()
        )
        assert record is not None
        assert record.issue_summary == mock_full_jira_payload["summary"]
        assert record.request_payload == mock_full_jira_payload
|
|
|
|
def test_llm_test_endpoint(test_client):
    """The LLM smoke-test endpoint answers 200 with a ``response`` field."""
    reply = test_client.post("/api/test-llm")
    assert reply.status_code == 200
    body = reply.json()
    assert "response" in body
|
|
|
|
def test_create_analysis_record_endpoint(setup_db, test_client, mock_full_jira_payload):
    """Creating a record through ``/api/request`` returns 201 and persists it."""
    created = test_client.post("/api/request", json=mock_full_jira_payload)
    assert created.status_code == 201

    body = created.json()
    assert "message" in body
    assert body["message"] == "Record created successfully"
    assert "record_id" in body

    # The id handed back by the API must resolve to a stored row that
    # mirrors the submitted payload.
    expected = mock_full_jira_payload
    with get_db() as session:
        stored = get_analysis_by_id(session, body["record_id"])
        assert stored is not None
        assert stored.issue_key == expected["issueKey"]
        assert stored.issue_summary == expected["summary"]
        assert stored.request_payload == expected
|
|
|
|
@pytest.mark.asyncio
async def test_retry_decorator():
    """A coroutine that always fails is retried before its error propagates.

    Counting invocations ensures the decorator actually re-invokes the
    wrapped coroutine, rather than merely letting the first error escape.
    """
    from jira_webhook_llm import retry  # Import decorator from main module

    attempts = 0

    @retry(max_retries=3)  # Use imported decorator
    async def failing_function():
        nonlocal attempts
        attempts += 1
        raise Exception("Test error")

    with pytest.raises(Exception):
        await failing_function()

    # NOTE(review): assumes `retry` re-invokes on a bare Exception. Whether
    # max_retries=3 means 3 or 4 total calls is not visible here, so only
    # "more than one call" is asserted.
    assert attempts > 1, f"expected retries, but failing_function ran {attempts} time(s)"
|
|
|
|
def test_get_pending_queue_records_endpoint(setup_db, test_client, mock_full_jira_payload):
    """A freshly created record is listed by ``/api/queue/pending``."""
    # Seed exactly one record; newly created records are expected to be pending.
    with get_db() as session:
        seeded = create_analysis_record(session, JiraWebhookPayload(**mock_full_jira_payload))
        session.commit()
        session.refresh(seeded)

    response = test_client.get("/api/queue/pending")
    assert response.status_code == 200, f"Expected 200 but got {response.status_code}. Response: {response.text}"

    pending = response.json()["data"]
    assert len(pending) == 1
    only = pending[0]
    assert only["issue_key"] == mock_full_jira_payload["issueKey"]
    assert only["status"] == "pending"
|
|
|
|
def test_get_pending_queue_records_endpoint_empty(setup_db, test_client):
    """With the analysis table emptied, the pending queue is an empty list."""
    # Wipe every JiraAnalysis row so nothing can be pending.
    with get_db() as session:
        session.query(JiraAnalysis).delete()
        session.commit()

    response = test_client.get("/api/queue/pending")
    assert response.status_code == 200
    assert len(response.json()["data"]) == 0
|
|
|
|
def test_get_pending_queue_records_endpoint_error(test_client, monkeypatch):
    """A failing DB lookup surfaces as a 500 whose detail wraps the error text."""

    # Stand-in for the CRUD helper; the parameter keeps the name `db` so it
    # accepts the same call shape as the real function.
    def exploding_lookup(db):
        raise Exception("Database error")

    monkeypatch.setattr("api.handlers.get_pending_analysis_records", exploding_lookup)

    response = test_client.get("/api/queue/pending")
    assert response.status_code == 500, f"Expected 500 but got {response.status_code}. Response: {response.text}"

    # FastAPI's HTTPException serialises its message under "detail".
    payload = response.json()
    assert "detail" in payload
    assert payload["detail"] == "Database error: Database error"
|
|
|
|
def test_retry_analysis_record_endpoint_success(setup_db, test_client, mock_full_jira_payload):
    """Retrying a failed record resets it to a clean pending state."""
    # Arrange: one record that has already failed.
    with get_db() as session:
        record = create_analysis_record(session, JiraWebhookPayload(**mock_full_jira_payload))
        update_record_status(session, record.id, "failed", error_message="LLM failed")
        session.commit()
        session.refresh(record)

    record_id = record.id

    response = test_client.post(f"/api/queue/{record_id}/retry")
    assert response.status_code == 200
    assert response.json()["message"] == f"Record {record_id} marked for retry."

    # Every failure artefact must be cleared by the retry.
    with get_db() as session:
        refreshed = get_analysis_by_id(session, record_id)
        assert refreshed.status == "pending"
        for cleared in ("error_message", "analysis_result", "raw_response", "next_retry_at"):
            assert getattr(refreshed, cleared) is None
|
|
|
|
def test_retry_analysis_record_endpoint_not_found(setup_db, test_client):
    """Retrying an id with no analysis record returns 404.

    ``setup_db`` is requested like the other DB-backed tests in this module,
    because the endpoint looks the id up in the database. Without it, the
    result depends on whether an earlier test happened to initialise the
    test database.
    """
    response = test_client.post("/api/queue/99999/retry")
    assert response.status_code == 404

    # FastAPI's HTTPException puts the message under "detail".
    response_json = response.json()
    assert "detail" in response_json
    assert response_json["detail"] == "Analysis record not found"
|
|
|
|
def test_retry_analysis_record_endpoint_invalid_status(setup_db, test_client, mock_full_jira_payload):
    """Retrying a record whose status is 'success' is rejected with 400.

    Only 'failed' or 'validation_failed' records are retryable, per the
    expected error message.
    """
    # Create a successful record
    with get_db() as db:
        payload_model = JiraWebhookPayload(**mock_full_jira_payload)
        successful_record = create_analysis_record(db, payload_model)
        update_record_status(db, successful_record.id, "success")
        db.commit()
        db.refresh(successful_record)

    response = test_client.post(f"/api/queue/{successful_record.id}/retry")
    assert response.status_code == 400
    # Plain literal: the expected message contains no interpolated values.
    assert response.json()["detail"] == (
        "Record status is 'success'. Only 'failed' or 'validation_failed' records can be retried."
    )
|
|
|
|
def test_retry_analysis_record_endpoint_db_update_failure(setup_db, test_client, mock_full_jira_payload, monkeypatch):
    """If the status update yields nothing, the retry endpoint answers 500."""
    # Arrange: a failed record that would normally be retryable.
    with get_db() as session:
        record = create_analysis_record(session, JiraWebhookPayload(**mock_full_jira_payload))
        update_record_status(session, record.id, "failed", error_message="LLM failed")
        session.commit()
        session.refresh(record)

    # Returning None from the patched helper simulates a failed update.
    monkeypatch.setattr(
        "api.handlers.update_record_status",
        lambda *args, **kwargs: None,
    )

    response = test_client.post(f"/api/queue/{record.id}/retry")
    assert response.status_code == 500, f"Expected 500 but got {response.status_code}. Response: {response.text}"
    assert response.json()["detail"] == "Failed to update record for retry."
|
|
|
|
def test_retry_analysis_record_endpoint_retry_count_and_next_retry_at(setup_db, test_client, mock_full_jira_payload):
    """Manual retry clears the schedule but leaves ``retry_count`` untouched.

    The retry endpoint only resets status-related fields; incrementing the
    retry counter is the job of the processing step, not of this endpoint.
    """
    # Arrange: a failed record that has already been retried once and has a
    # scheduled next attempt.
    with get_db() as session:
        record = create_analysis_record(session, JiraWebhookPayload(**mock_full_jira_payload))
        scheduled_at = datetime.now(timezone.utc)
        update_record_status(
            session,
            record.id,
            "failed",
            error_message="LLM failed",
            retry_count_increment=1,
            next_retry_at=scheduled_at,
        )
        session.commit()
        session.refresh(record)
        count_before = record.retry_count

    response = test_client.post(f"/api/queue/{record.id}/retry")
    assert response.status_code == 200

    with get_db() as session:
        refreshed = get_analysis_by_id(session, record.id)
        assert refreshed.status == "pending"
        assert refreshed.error_message is None
        assert refreshed.next_retry_at is None  # schedule is dropped
        assert refreshed.retry_count == count_before  # counter is preserved
|
|
|
|
@pytest.mark.asyncio
async def test_concurrent_retry_operations(setup_db, test_client, mock_full_jira_payload):
    """Retry several failed records at once and check each ends up pending.

    ``TestClient.post`` is synchronous, so wrapping it in a coroutine and
    gathering those coroutines would still issue the requests one after
    another. Each request is therefore dispatched to the event loop's default
    thread-pool executor, so the retries genuinely overlap.

    NOTE(review): this requires the test database engine configured by
    ``setup_db`` to tolerate access from several threads — confirm.
    """
    import asyncio

    # Create multiple failed records, each with a distinct issue key.
    record_ids = []
    with get_db() as db:
        for i in range(5):
            payload = mock_full_jira_payload.copy()
            payload["issueKey"] = f"TEST-{i}"
            payload_model = JiraWebhookPayload(**payload)
            failed_record = create_analysis_record(db, payload_model)
            update_record_status(db, failed_record.id, "failed", error_message=f"LLM failed {i}")
            db.commit()
            db.refresh(failed_record)
            record_ids.append(failed_record.id)

    # Fire all retry requests in parallel worker threads.
    loop = asyncio.get_running_loop()
    responses = await asyncio.gather(
        *(
            loop.run_in_executor(None, test_client.post, f"/api/queue/{record_id}/retry")
            for record_id in record_ids
        )
    )

    for response in responses:
        assert response.status_code == 200
        assert "message" in response.json()

    # Verify all records are marked as pending
    with get_db() as db:
        for record_id in record_ids:
            updated_record = get_analysis_by_id(db, record_id)
            assert updated_record.status == "pending"
            assert updated_record.error_message is None
            assert updated_record.next_retry_at is None