jira-webhook-llm/llm/chains.py
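"""LLM chain setup for jira-webhook-llm.

Initializes the configured LLM backend (ChatOpenAI or OllamaLLM), loads the
Jira-analysis prompt template, builds the analysis chain with a JSON output
parser, and provides response validation against the AnalysisFlags schema.
"""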

from typing import Union
import json
import sys

from langchain_ollama import OllamaLLM
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from loguru import logger

from config import settings
from .models import AnalysisFlags


class LLMInitializationError(Exception):
    """Custom exception for LLM initialization errors"""

    def __init__(self, message, details=None):
        super().__init__(message)
        self.details = details
# Initialize LLM
llm = None

if settings.llm.mode == 'openai':
    logger.info(f"Initializing ChatOpenAI with model: {settings.openai_model}")
    llm = ChatOpenAI(
        model=settings.openai_model,
        temperature=0.7,
        max_tokens=2000,
        api_key=settings.openai_api_key,
        base_url=settings.openai_api_base_url
    )
elif settings.llm.mode == 'ollama':
    logger.info(f"Initializing OllamaLLM with model: {settings.llm.ollama_model} at {settings.llm.ollama_base_url}")
    try:
        # Verify connection parameters
        if not settings.llm.ollama_base_url:
            raise ValueError("Ollama base URL is not configured")
        if not settings.llm.ollama_model:
            raise ValueError("Ollama model is not specified")

        logger.debug(f"Attempting to connect to Ollama at {settings.llm.ollama_base_url}")

        # Normalize the base URL (strip any trailing slash)
        base_url = settings.llm.ollama_base_url.rstrip('/')

        llm = OllamaLLM(
            model=settings.llm.ollama_model,
            base_url=base_url,
            streaming=False,
            timeout=30,      # 30 second timeout
            max_retries=3    # Retry up to 3 times
            # temperature=0.1,
            # top_p=0.2
        )
        # Optional connection test (disabled by default; uncomment to verify at startup)
        # logger.debug("Testing Ollama connection...")
        # llm.invoke("test")  # Simple test request
        logger.info("OllamaLLM client initialized")
    except Exception as e:
        error_msg = f"Failed to initialize Ollama: {str(e)}"
        details = {
            'model': settings.llm.ollama_model,
            'url': settings.llm.ollama_base_url,
            'error_type': type(e).__name__
        }
        logger.error(error_msg)
        logger.debug(f"Connection details: {details}")
        raise LLMInitializationError(
            "Failed to connect to Ollama service. Please check:"
            "\n1. Ollama is installed and running"
            "\n2. The base URL is correct"
            "\n3. The model is available",
            details=details
        ) from e

if llm is None:
    logger.error("LLM could not be initialized. Exiting.")
    print("\nERROR: Unable to initialize LLM. Check logs for details.", file=sys.stderr)
    sys.exit(1)
# Set up Output Parser for structured JSON
parser = JsonOutputParser(pydantic_object=AnalysisFlags)
# Load prompt template from file
def load_prompt_template(version="v1.1.0"):
    try:
        with open(f"llm/prompts/jira_analysis_{version}.txt", "r") as f:
            template_content = f.read()

        # Split system and user parts
        system_template, user_template = template_content.split("\n\nUSER:\n")
        system_template = system_template.replace("SYSTEM:\n", "").strip()

        return ChatPromptTemplate.from_messages([
            SystemMessagePromptTemplate.from_template(system_template),
            HumanMessagePromptTemplate.from_template(user_template)
        ])
    except Exception as e:
        logger.error(f"Failed to load prompt template: {str(e)}")
        raise
# Fallback prompt template
FALLBACK_PROMPT = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(
        "Analyze Jira tickets and output JSON with hasMultipleEscalations, customerSentiment"
    ),
    HumanMessagePromptTemplate.from_template(
        "Issue Key: {issueKey}\nSummary: {summary}"
    )
])
# Create chain with fallback mechanism
def create_analysis_chain():
    try:
        prompt_template = load_prompt_template()
        chain = (
            {
                "issueKey": lambda x: x["issueKey"],
                "summary": lambda x: x["summary"],
                "description": lambda x: x["description"],
                "status": lambda x: x["status"],
                "labels": lambda x: x["labels"],
                "assignee": lambda x: x["assignee"],
                "updated": lambda x: x["updated"],
                "comment": lambda x: x["comment"],
                "format_instructions": lambda _: parser.get_format_instructions()
            }
            | prompt_template
            | llm
            | parser
        )
        # Add langfuse handler if enabled
        if settings.langfuse.enabled:
            chain = chain.with_config(
                callbacks=[settings.langfuse_handler]
            )
        return chain
    except Exception as e:
        logger.warning(f"Using fallback prompt due to error: {str(e)}")
        chain = FALLBACK_PROMPT | llm | parser
        if settings.langfuse.enabled:
            chain = chain.with_config(
                callbacks=[settings.langfuse_handler]
            )
        return chain
# Initialize analysis chain
analysis_chain = create_analysis_chain()
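# Example usage (illustrative only; field values below are hypothetical):
#     flags = analysis_chain.invoke({
#         "issueKey": "PROJ-123",
#         "summary": "Customer reports repeated outages",
#         "description": "Service has gone down three times this week.",
#         "status": "Open",
#         "labels": ["escalation"],
#         "assignee": "jdoe",
#         "updated": "2024-01-01T00:00:00Z",
#         "comment": "Customer is asking for an update again.",
#     })
#     # `flags` is the parsed JSON dict; see validate_response() below.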
# Enhanced response validation function
def validate_response(response: Union[dict, str]) -> bool:
    """Validate the JSON response structure and content"""
    try:
        # If response is a string, attempt to parse it as JSON
        if isinstance(response, str):
            try:
                response = json.loads(response)
            except json.JSONDecodeError:
                logger.error(f"Invalid JSON response: {response}")
                raise ValueError("Invalid JSON response format")

        # Ensure response is a dictionary
        if not isinstance(response, dict):
            return False

        # Check required fields
        required_fields = ["hasMultipleEscalations", "customerSentiment"]
        if not all(field in response for field in required_fields):
            return False

        # Validate field types
        if not isinstance(response["hasMultipleEscalations"], bool):
            return False
        if response["customerSentiment"] is not None:
            if not isinstance(response["customerSentiment"], str):
                return False

        # Validate against schema using AnalysisFlags model
        try:
            AnalysisFlags.model_validate(response)
            return True
        except Exception:
            return False
    except Exception:
        return False