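"""Jira webhook analysis service.

FastAPI application that accepts Jira webhook payloads, stores them as
analysis records, and processes them in a background loop: each record is
sent through an LLM analysis chain, validated, and either marked completed
or scheduled for retry with exponential backoff. Traces are optionally
reported to Langfuse.
"""
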
import os
import json
import time
from dotenv import load_dotenv

load_dotenv()

from fastapi import FastAPI, Request, HTTPException
from pydantic import BaseModel
from fastapi.responses import JSONResponse
from http import HTTPStatus
from loguru import logger
import uuid
import sys
from typing import Optional
from datetime import datetime, timedelta, timezone
import asyncio
from functools import partial, wraps
from sqlalchemy.orm import Session

from database.database import init_db, get_db, SessionLocal
from database.crud import (
    create_analysis_record,
    delete_all_analysis_records,
    get_all_analysis_records,
    get_analysis_by_id,
    get_analysis_record,
    update_record_status,
)
from database.models import JiraAnalysis, AnalysisFlags
from llm.models import JiraWebhookPayload
from llm.chains import analysis_chain, validate_response
from api.handlers import router, test_llm_endpoint
from webhooks.handlers import webhook_router
from logging_config import configure_logging

# Load application settings early; logging itself is configured in create_app()
from config import settings

import signal

from contextlib import asynccontextmanager

cleanup_tasks = []  # Tasks awaited during shutdown (populated in lifespan)


def calculate_next_retry_time(retry_count: int) -> datetime:
    """Calculates the next retry time using exponential backoff."""
    delay = settings.processor.initial_retry_delay_seconds * (2 ** retry_count)
    return datetime.now(timezone.utc) + timedelta(seconds=delay)
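
# For example, with a hypothetical initial_retry_delay_seconds of 30, the
# call sites below (which pass retry_count + 1) schedule retries roughly
# 60s, 120s, then 240s after each successive failure.
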
def retry(max_retries: int = 3, initial_delay: float = 1.0):
    """Decorator that retries an async function with exponential backoff."""
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            for i in range(max_retries + 1):
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    if i == max_retries:
                        logger.error(f"Function {func.__name__} failed after {max_retries} retries: {e}")
                        raise
                    delay = initial_delay * (2 ** i)
                    logger.warning(f"Function {func.__name__} failed, retrying in {delay:.2f} seconds (attempt {i + 1}/{max_retries})...")
                    await asyncio.sleep(delay)
        return wrapper
    return decorator
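
# Illustrative usage of the retry decorator (nothing in this module applies
# it yet; the function name below is hypothetical):
#
#     @retry(max_retries=3, initial_delay=1.0)
#     async def call_external_service():
#         ...
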
async def process_single_jira_request(db: Session, record: JiraAnalysis):
    """Processes a single Jira webhook request using the LLM."""
    issue_key = record.issue_key
    record_id = record.id
    payload = JiraWebhookPayload.model_validate(record.request_payload)

    logger.bind(
        issue_key=issue_key,
        record_id=record_id,
        timestamp=datetime.now(timezone.utc).isoformat()
    ).info(f"[{issue_key}] Processing webhook request.")

    # Create Langfuse trace if enabled
    trace = None
    if settings.langfuse.enabled:
        trace = settings.langfuse_client.start_span(
            name="Jira Webhook Processing",
            input=payload.model_dump(),
            metadata={
                "trace_id": f"processor-{issue_key}-{record_id}",
                "issue_key": issue_key,
                "record_id": record_id,
                "timestamp": datetime.now(timezone.utc).isoformat()
            }
        )

    llm_input = {
        "issueKey": payload.issueKey,
        "summary": payload.summary,
        "description": payload.description if payload.description else "No description provided.",
        "status": payload.status if payload.status else "Unknown",
        "labels": ", ".join(payload.labels) if payload.labels else "None",
        "assignee": payload.assignee if payload.assignee else "Unassigned",
        "updated": payload.updated if payload.updated else "Unknown",
        "comment": payload.comment if payload.comment else "No new comment provided."
    }

    llm_span = None
    if settings.langfuse.enabled and trace:
        llm_span = trace.start_span(
            name="LLM Processing",
            input=llm_input,
            metadata={
                "model": settings.llm.model if settings.llm.mode == 'openai' else settings.llm.ollama_model
            }
        )

    try:
        raw_llm_response = await analysis_chain.ainvoke(llm_input)

        if settings.langfuse.enabled and llm_span:
            llm_span.update(output=raw_llm_response)
            llm_span.end()

        # Validate response structure before processing
        if not validate_response(raw_llm_response):
            error_msg = f"Invalid LLM response structure: {raw_llm_response}"
            logger.error(f"[{issue_key}] {error_msg}")
            can_retry = record.retry_count < settings.processor.max_retries
            update_record_status(
                db=db,
                record_id=record_id,
                analysis_result={"hasMultipleEscalations": False, "customerSentiment": "neutral"},
                raw_response=json.dumps(raw_llm_response),
                # 'retrying' keeps the record visible to the polling loop;
                # 'validation_failed' is terminal once retries are exhausted.
                status="retrying" if can_retry else "validation_failed",
                error_message=error_msg,
                last_processed_at=datetime.now(timezone.utc),
                retry_count_increment=1,
                next_retry_at=calculate_next_retry_time(record.retry_count + 1) if can_retry else None
            )
            if settings.langfuse.enabled and trace:
                trace.end()
            # Return instead of raising: a raise here would land in the outer
            # except handler below and update the record a second time.
            return

        # Parse into AnalysisFlags purely to validate the response shape;
        # the instance itself is discarded.
        try:
            AnalysisFlags(
                hasMultipleEscalations=raw_llm_response.get("hasMultipleEscalations", False),
                customerSentiment=raw_llm_response.get("customerSentiment", "neutral")
            )
        except Exception as e:
            logger.error(f"[{issue_key}] Invalid LLM response structure: {e}", exc_info=True)
            can_retry = record.retry_count < settings.processor.max_retries
            update_record_status(
                db=db,
                record_id=record_id,
                analysis_result={"hasMultipleEscalations": False, "customerSentiment": "neutral"},
                raw_response=json.dumps(raw_llm_response),
                status="retrying" if can_retry else "validation_failed",
                error_message=f"LLM response validation failed: {e}",
                last_processed_at=datetime.now(timezone.utc),
                retry_count_increment=1,
                next_retry_at=calculate_next_retry_time(record.retry_count + 1) if can_retry else None
            )
            if settings.langfuse.enabled and trace:
                trace.end()
            # Same rationale as above: avoid double-counting the failure.
            return

        logger.debug(f"[{issue_key}] LLM Analysis Result: {json.dumps(raw_llm_response, indent=2)}")
        update_record_status(
            db=db,
            record_id=record_id,
            analysis_result=raw_llm_response,
            raw_response=json.dumps(raw_llm_response),
            status="completed",
            last_processed_at=datetime.now(timezone.utc),
            next_retry_at=None  # No retry needed on success
        )
        if settings.langfuse.enabled and trace:
            trace.end()
        logger.info(f"[{issue_key}] Successfully processed and updated record {record_id}.")

    except Exception as e:
        logger.error(f"[{issue_key}] LLM processing failed for record {record_id}: {str(e)}")
        if settings.langfuse.enabled and llm_span:
            llm_span.end()

        new_retry_count = record.retry_count + 1
        new_status = "failed"
        next_retry = None
        if new_retry_count <= settings.processor.max_retries:
            next_retry = calculate_next_retry_time(new_retry_count)
            new_status = "retrying"  # Indicate that it will be retried

        update_record_status(
            db=db,
            record_id=record_id,
            status=new_status,
            error_message=f"LLM processing failed: {str(e)}",
            last_processed_at=datetime.now(timezone.utc),
            retry_count_increment=1,
            next_retry_at=next_retry
        )
        if settings.langfuse.enabled and trace:
            trace.end()
        logger.error(f"[{issue_key}] Record {record_id} status updated to '{new_status}'. Retry count: {new_retry_count}")


async def main_processor_loop():
    """Main loop for the Jira webhook processor."""
    logger.info("Starting Jira webhook processor.")
    while True:  # Runs until the surrounding task is cancelled at shutdown
        db: Session = SessionLocal()  # Get a new session for each loop iteration
        try:
            # Fetch records that are 'pending', or 'retrying' and past their
            # next_retry_at. Order by created_at to process older requests first.
            pending_records = db.query(JiraAnalysis).filter(
                (JiraAnalysis.status == "pending") |
                ((JiraAnalysis.status == "retrying") & (JiraAnalysis.next_retry_at <= datetime.now(timezone.utc)))
            ).order_by(JiraAnalysis.created_at.asc()).all()

            if not pending_records:
                logger.debug(f"No pending or retrying records found. Sleeping for {settings.processor.poll_interval_seconds} seconds.")

            for record in pending_records:
                # Mark as 'processing' immediately so other workers skip it.
                # status is passed as a keyword argument: the other call sites
                # show update_record_status takes optional fields before it.
                update_record_status(db, record.id, status="processing", last_processed_at=datetime.now(timezone.utc))
                db.refresh(record)  # Refresh to get the latest state
                await process_single_jira_request(db, record)
        except Exception as e:
            logger.error(f"Error in main processor loop: {str(e)}", exc_info=True)
        finally:
            db.close()  # Ensure the session is closed

        await asyncio.sleep(settings.processor.poll_interval_seconds)  # Non-blocking sleep


# Setup async-compatible signal handling
def handle_shutdown_signal(signum, loop):
    """Graceful shutdown signal handler."""
    logger.info(f"Received signal {signum}, initiating shutdown...")
    # Set the shutdown flag (an ad-hoc attribute read by lifespan's cleanup)
    # and remove the signal handlers to prevent reentrancy
    if not hasattr(loop, '_shutdown'):
        loop._shutdown = True

        # Prevent further signal handling
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.remove_signal_handler(sig)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Context manager for managing the lifespan of the FastAPI application.

    Initializes the database, sets up signal handlers, and handles cleanup.
    """
    # Flag to track if initialization succeeded
    init_success = False
    # Defined before the try block so the finally clause below can safely
    # reference it even if initialization fails early.
    processor_task = None

    try:
        logger.info("Initializing application...")
        init_db()  # Initialize the database

        # Set up signal handlers, but only outside of test environments
        if os.getenv("IS_TEST_ENV") != "true":
            loop = asyncio.get_running_loop()
            for sig in (signal.SIGTERM, signal.SIGINT):
                loop.add_signal_handler(sig, partial(handle_shutdown_signal, sig, loop))
            logger.info("Signal handlers configured successfully")
        else:
            logger.info("Skipping signal handler configuration in test environment.")

        # Start the background processor task only if not in a test environment
        if os.getenv("IS_TEST_ENV") != "true":
            processor_task = asyncio.create_task(main_processor_loop())
            cleanup_tasks.append(processor_task)
            logger.info("Background Jira processor started.")
        else:
            logger.info("Skipping background Jira processor in test environment.")

        # Verify critical components
        if not hasattr(settings, 'langfuse_handler'):
            logger.error("Langfuse handler not found in settings")
            raise RuntimeError("Langfuse handler not initialized")

        logger.info("Application initialized successfully")
        init_success = True
    except Exception as e:
        logger.critical(f"Application initialization failed: {str(e)}. Exiting.")
        # Don't re-raise, so the finally block below still runs its cleanup

    try:
        # Yield control to the application
        yield
    finally:
        # Cleanup logic runs after the application finishes
        if init_success:
            # Check the shutdown flag set by handle_shutdown_signal
            loop = asyncio.get_running_loop()
            if hasattr(loop, '_shutdown'):
                logger.info("Shutdown initiated, starting cleanup...")

            # Cancel the processor task
            if processor_task:
                logger.info("Cancelling background Jira processor task...")
                processor_task.cancel()
                try:
                    await processor_task
                except asyncio.CancelledError:
                    logger.info("Background Jira processor task cancelled.")
                except Exception as e:
                    logger.error(f"Error cancelling processor task: {str(e)}")

            # Close langfuse with a timeout
            if hasattr(settings, 'langfuse_handler') and hasattr(settings.langfuse_handler, 'close'):
                try:
                    await asyncio.wait_for(settings.langfuse_handler.close(), timeout=5.0)
                    logger.info("Langfuse client closed successfully")
                except asyncio.TimeoutError:
                    logger.warning("Timeout while closing Langfuse client")
                except Exception as e:
                    logger.error(f"Error closing Langfuse client: {str(e)}")

            # Execute any other cleanup tasks
            if cleanup_tasks:
                try:
                    # Filter out the processor_task, which is already handled
                    remaining_cleanup_tasks = [task for task in cleanup_tasks if task is not processor_task]
                    if remaining_cleanup_tasks:
                        await asyncio.gather(*remaining_cleanup_tasks)
                except Exception as e:
                    logger.error(f"Error during additional cleanup tasks: {str(e)}")


class ErrorResponse(BaseModel):
    """JSON error body returned by the error-handling middleware below."""

    error_id: str
    timestamp: str
    status_code: int
    message: str
    details: Optional[str] = None


def create_app():
    """Factory function to create FastAPI app instance."""
    configure_logging(log_level="DEBUG")
    _app = FastAPI(lifespan=lifespan)

    # Include routers without prefixes to match test expectations
    _app.include_router(webhook_router)
    _app.include_router(router)

    # Add health check endpoint (GET /health -> {"status": "healthy"})
    @_app.get("/health")
    async def health_check():
        return {"status": "healthy"}

    # Add error handling middleware
    @_app.middleware("http")
    async def error_handling_middleware(request: Request, call_next):
        request_id = str(uuid.uuid4())
        logger.bind(request_id=request_id).info(f"Request started: {request.method} {request.url}")

        try:
            return await call_next(request)
        except HTTPException as e:
            logger.error(f"HTTP Error: {e.status_code} - {e.detail}")
            error_response = ErrorResponse(
                error_id=request_id,
                timestamp=datetime.now(timezone.utc).isoformat(),
                status_code=e.status_code,
                message=str(e.detail),  # detail may be a dict; message must be a str
                details=str(e)
            )
            return JSONResponse(status_code=e.status_code, content=error_response.model_dump())
        except Exception as e:
            logger.error(f"Unexpected error: {str(e)}")
            error_response = ErrorResponse(
                error_id=request_id,
                timestamp=datetime.now(timezone.utc).isoformat(),
                status_code=500,
                message="Internal Server Error",
                details=str(e)
            )
            return JSONResponse(status_code=500, content=error_response.model_dump())

    return _app


app = create_app()
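
# A minimal local entry point, assuming this module is named `main` and that
# uvicorn is installed; the project may use a different ASGI server or launcher.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=8000)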