jira-webhook-llm/jira_webhook_llm.py

import os
import json
import uuid
import signal
import asyncio
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from functools import partial, wraps
from typing import Optional

from dotenv import load_dotenv

# Load environment variables before importing modules that read them.
load_dotenv()

from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from loguru import logger
from sqlalchemy.orm import Session

from database.database import init_db, get_db, SessionLocal
from database.crud import (
    create_analysis_record,
    delete_all_analysis_records,
    get_all_analysis_records,
    get_analysis_by_id,
    get_analysis_record,
    update_record_status,
)
from database.models import JiraAnalysis, AnalysisFlags
from llm.models import JiraWebhookPayload
from llm.chains import analysis_chain, validate_response
from api.handlers import router
from webhooks.handlers import webhook_router
from logging_config import configure_logging
from config import settings

# Tasks to await during shutdown cleanup.
cleanup_tasks = []


def calculate_next_retry_time(retry_count: int) -> datetime:
    """Calculate the next retry time using exponential backoff."""
    delay = settings.processor.initial_retry_delay_seconds * (2 ** retry_count)
    return datetime.now(timezone.utc) + timedelta(seconds=delay)
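
# Example of the resulting backoff schedule, assuming
# initial_retry_delay_seconds is 5 (illustrative value):
#     retry_count 0 -> 5s, 1 -> 10s, 2 -> 20s, 3 -> 40s after "now".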


def retry(max_retries: int = 3, initial_delay: float = 1.0):
    """Decorator that retries an async function with exponential backoff."""
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # One initial attempt plus up to max_retries retries.
            for i in range(max_retries + 1):
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    if i == max_retries:
                        logger.error(f"Function {func.__name__} failed after {max_retries} retries: {e}")
                        raise
                    delay = initial_delay * (2 ** i)
                    logger.warning(
                        f"Function {func.__name__} failed, retrying in {delay:.2f} seconds "
                        f"(retry {i + 1}/{max_retries})..."
                    )
                    await asyncio.sleep(delay)
        return wrapper
    return decorator
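
# Usage sketch (fetch_remote_status is a hypothetical coroutine, not part of
# this module):
#
#     @retry(max_retries=3, initial_delay=1.0)
#     async def fetch_remote_status(issue_key: str) -> dict:
#         ...
#
# The wrapped coroutine runs up to max_retries + 1 times, sleeping 1s, 2s,
# then 4s between attempts before the final failure is re-raised.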


async def process_single_jira_request(db: Session, record: JiraAnalysis):
    """Process a single Jira webhook request using the LLM."""
    issue_key = record.issue_key
    record_id = record.id
    payload = JiraWebhookPayload.model_validate(record.request_payload)
    logger.bind(
        issue_key=issue_key,
        record_id=record_id,
        timestamp=datetime.now(timezone.utc).isoformat()
    ).info(f"[{issue_key}] Processing webhook request.")

    # Create a Langfuse trace if tracing is enabled.
    trace = None
    if settings.langfuse.enabled:
        trace = settings.langfuse_client.start_span(
            name="Jira Webhook Processing",
            input=payload.model_dump(),
            metadata={
                "trace_id": f"processor-{issue_key}-{record_id}",
                "issue_key": issue_key,
                "record_id": record_id,
                "timestamp": datetime.now(timezone.utc).isoformat()
            }
        )

    llm_input = {
        "issueKey": payload.issueKey,
        "summary": payload.summary,
        "description": payload.description or "No description provided.",
        "status": payload.status or "Unknown",
        "labels": ", ".join(payload.labels) if payload.labels else "None",
        "assignee": payload.assignee or "Unassigned",
        "updated": payload.updated or "Unknown",
        "comment": payload.comment or "No new comment provided.",
    }
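
    # Illustrative shape of the stored webhook payload (field names taken from
    # the JiraWebhookPayload usage above; the values are made up):
    #     {"issueKey": "PROJ-123", "summary": "Checkout fails", "status": "Open",
    #      "labels": ["escalation"], "assignee": "jdoe", "comment": "..."}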

    llm_span = None
    if settings.langfuse.enabled and trace:
        llm_span = trace.start_span(
            name="LLM Processing",
            input=llm_input,
            metadata={
                "model": settings.llm.model if settings.llm.mode == 'openai' else settings.llm.ollama_model
            }
        )

    try:
        raw_llm_response = await analysis_chain.ainvoke(llm_input)
        if settings.langfuse.enabled and llm_span:
            llm_span.update(output=raw_llm_response)
            llm_span.end()

        # Validate the response structure before processing.
        if not validate_response(raw_llm_response):
            error_msg = f"Invalid LLM response structure: {raw_llm_response}"
            logger.error(f"[{issue_key}] {error_msg}")
            update_record_status(
                db=db,
                record_id=record_id,
                analysis_result={"hasMultipleEscalations": False, "customerSentiment": "neutral"},
                raw_response=json.dumps(raw_llm_response),
                status="validation_failed",
                error_message=error_msg,
                last_processed_at=datetime.now(timezone.utc),
                retry_count_increment=1,
                next_retry_at=calculate_next_retry_time(record.retry_count + 1) if record.retry_count < settings.processor.max_retries else None
            )
            if settings.langfuse.enabled and trace:
                trace.end()
            raise ValueError(error_msg)

        try:
            # Parse into AnalysisFlags to confirm the fields have usable types.
            AnalysisFlags(
                hasMultipleEscalations=raw_llm_response.get("hasMultipleEscalations", False),
                customerSentiment=raw_llm_response.get("customerSentiment", "neutral")
            )
        except Exception as e:
            logger.error(f"[{issue_key}] Invalid LLM response structure: {e}", exc_info=True)
            update_record_status(
                db=db,
                record_id=record_id,
                analysis_result={"hasMultipleEscalations": False, "customerSentiment": "neutral"},
                raw_response=json.dumps(raw_llm_response),
                status="validation_failed",
                error_message=f"LLM response validation failed: {e}",
                last_processed_at=datetime.now(timezone.utc),
                retry_count_increment=1,
                next_retry_at=calculate_next_retry_time(record.retry_count + 1) if record.retry_count < settings.processor.max_retries else None
            )
            if settings.langfuse.enabled and trace:
                trace.end()
            raise ValueError(f"Invalid LLM response format: {e}") from e

        logger.debug(f"[{issue_key}] LLM Analysis Result: {json.dumps(raw_llm_response, indent=2)}")
        update_record_status(
            db=db,
            record_id=record_id,
            analysis_result=raw_llm_response,
            raw_response=json.dumps(raw_llm_response),
            status="completed",
            last_processed_at=datetime.now(timezone.utc),
            next_retry_at=None  # No retry needed on success.
        )
        if settings.langfuse.enabled and trace:
            trace.end()
        logger.info(f"[{issue_key}] Successfully processed and updated record {record_id}.")
    except Exception as e:
        logger.error(f"[{issue_key}] LLM processing failed for record {record_id}: {str(e)}")
        if settings.langfuse.enabled and llm_span:
            llm_span.end()

        new_retry_count = record.retry_count + 1
        new_status = "failed"
        next_retry = None
        if new_retry_count <= settings.processor.max_retries:
            next_retry = calculate_next_retry_time(new_retry_count)
            new_status = "retrying"  # Indicate that it will be retried.
        update_record_status(
            db=db,
            record_id=record_id,
            status=new_status,
            error_message=f"LLM processing failed: {str(e)}",
            last_processed_at=datetime.now(timezone.utc),
            retry_count_increment=1,
            next_retry_at=next_retry
        )
        if settings.langfuse.enabled and trace:
            trace.end()
        logger.error(f"[{issue_key}] Record {record_id} status updated to '{new_status}'. Retry count: {new_retry_count}")


async def main_processor_loop():
    """Main loop for the Jira webhook processor."""
    logger.info("Starting Jira webhook processor.")
    while True:  # Runs until the application shuts down.
        db: Session = SessionLocal()  # New session for each loop iteration.
        try:
            # Fetch records that are 'pending', or 'retrying' and past their
            # next_retry_at. Order by created_at to process older requests first.
            pending_records = db.query(JiraAnalysis).filter(
                (JiraAnalysis.status == "pending") |
                ((JiraAnalysis.status == "retrying") & (JiraAnalysis.next_retry_at <= datetime.now(timezone.utc)))
            ).order_by(JiraAnalysis.created_at.asc()).all()
            if not pending_records:
                logger.debug(f"No pending or retrying records found. Sleeping for {settings.processor.poll_interval_seconds} seconds.")
            for record in pending_records:
                # Mark as 'processing' immediately so other workers skip it.
                update_record_status(db, record.id, status="processing", last_processed_at=datetime.now(timezone.utc))
                db.refresh(record)  # Refresh to get the latest state.
                await process_single_jira_request(db, record)
        except Exception as e:
            logger.error(f"Error in main processor loop: {str(e)}", exc_info=True)
        finally:
            db.close()  # Ensure the session is closed.
        await asyncio.sleep(settings.processor.poll_interval_seconds)  # Non-blocking sleep.
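
# Processor settings consumed above (names as referenced in this module):
#     settings.processor.poll_interval_seconds
#     settings.processor.initial_retry_delay_seconds
#     settings.processor.max_retries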


# Async-compatible signal handling.
def handle_shutdown_signal(signum, loop):
    """Graceful shutdown signal handler."""
    logger.info(f"Received signal {signum}, initiating shutdown...")
    # Set a shutdown flag and remove the handlers to prevent reentrancy.
    if not hasattr(loop, '_shutdown'):
        loop._shutdown = True
        # Prevent further signal handling.
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.remove_signal_handler(sig)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Manage the lifespan of the FastAPI application.

    Initializes the database, sets up signal handlers, starts the background
    processor, and handles cleanup on shutdown.
    """
    init_success = False   # Tracks whether initialization succeeded.
    processor_task = None  # Defined up front so the finally block can reference it.
    try:
        logger.info("Initializing application...")
        init_db()  # Initialize the database.

        # Only set up signal handlers outside of test environments.
        if os.getenv("IS_TEST_ENV") != "true":
            loop = asyncio.get_running_loop()
            for sig in (signal.SIGTERM, signal.SIGINT):
                loop.add_signal_handler(sig, partial(handle_shutdown_signal, sig, loop))
            logger.info("Signal handlers configured successfully")
        else:
            logger.info("Skipping signal handler configuration in test environment.")

        # Start the background processor task only outside of test environments.
        if os.getenv("IS_TEST_ENV") != "true":
            processor_task = asyncio.create_task(main_processor_loop())
            cleanup_tasks.append(processor_task)
            logger.info("Background Jira processor started.")
        else:
            logger.info("Skipping background Jira processor in test environment.")

        # Verify critical components.
        if not hasattr(settings, 'langfuse_handler'):
            logger.error("Langfuse handler not found in settings")
            raise RuntimeError("Langfuse handler not initialized")

        logger.info("Application initialized successfully")
        init_success = True
    except Exception as e:
        logger.critical(f"Application initialization failed: {str(e)}. Exiting.")
        # Not re-raised so the finally block below still runs cleanup.

    try:
        # Yield control to the application.
        yield
    finally:
        # Cleanup logic runs after the application finishes.
        if init_success:
            # Log if the shutdown flag was set by the signal handler.
            loop = asyncio.get_running_loop()
            if hasattr(loop, '_shutdown'):
                logger.info("Shutdown initiated, starting cleanup...")

            # Cancel the processor task.
            if processor_task:
                logger.info("Cancelling background Jira processor task...")
                processor_task.cancel()
                try:
                    await processor_task
                except asyncio.CancelledError:
                    logger.info("Background Jira processor task cancelled.")
                except Exception as e:
                    logger.error(f"Error cancelling processor task: {str(e)}")

            # Close the Langfuse client, with a timeout.
            if hasattr(settings, 'langfuse_handler') and hasattr(settings.langfuse_handler, 'close'):
                try:
                    await asyncio.wait_for(settings.langfuse_handler.close(), timeout=5.0)
                    logger.info("Langfuse client closed successfully")
                except asyncio.TimeoutError:
                    logger.warning("Timeout while closing Langfuse client")
                except Exception as e:
                    logger.error(f"Error closing Langfuse client: {str(e)}")

            # Execute any other cleanup tasks.
            if cleanup_tasks:
                try:
                    # Filter out the processor task, which is already handled.
                    remaining_cleanup_tasks = [task for task in cleanup_tasks if task != processor_task]
                    if remaining_cleanup_tasks:
                        await asyncio.gather(*remaining_cleanup_tasks)
                except Exception as e:
                    logger.error(f"Error during additional cleanup tasks: {str(e)}")


def create_app():
    """Factory function to create the FastAPI app instance."""
    configure_logging(log_level="DEBUG")
    _app = FastAPI(lifespan=lifespan)

    # Include routers without prefixes to match test expectations.
    _app.include_router(webhook_router)
    _app.include_router(router)

    # Health check endpoint.
    @_app.get("/health")
    async def health_check():
        return {"status": "healthy"}

    # Error-handling middleware: tags each request with an ID and converts
    # uncaught exceptions into structured JSON error responses.
    @_app.middleware("http")
    async def error_handling_middleware(request: Request, call_next):
        request_id = str(uuid.uuid4())
        logger.bind(request_id=request_id).info(f"Request started: {request.method} {request.url}")
        try:
            return await call_next(request)
        except HTTPException as e:
            logger.error(f"HTTP Error: {e.status_code} - {e.detail}")
            error_response = ErrorResponse(
                error_id=request_id,
                timestamp=datetime.now(timezone.utc).isoformat(),
                status_code=e.status_code,
                message=e.detail,
                details=str(e)
            )
            return JSONResponse(status_code=e.status_code, content=error_response.model_dump())
        except Exception as e:
            logger.error(f"Unexpected error: {str(e)}")
            error_response = ErrorResponse(
                error_id=request_id,
                timestamp=datetime.now(timezone.utc).isoformat(),
                status_code=500,
                message="Internal Server Error",
                details=str(e)
            )
            return JSONResponse(status_code=500, content=error_response.model_dump())

    return _app


class ErrorResponse(BaseModel):
    """Structured error payload returned by the error-handling middleware."""
    error_id: str
    timestamp: str
    status_code: int
    message: str
    details: Optional[str] = None


from api.handlers import test_llm_endpoint  # noqa: E402

app = create_app()
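

# Optional local-run entry point: a minimal sketch assuming uvicorn is
# installed; the module path is inferred from the filename above, and the
# host/port are illustrative defaults.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("jira_webhook_llm:app", host="0.0.0.0", port=8000)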