//! LLM-as-judge evaluator (`synaptic_eval/llm_judge.rs`): scores a
//! prediction against a reference answer by asking a chat model for a
//! 0-10 rating and normalizing it to `0.0..=1.0`.

use std::sync::Arc;

use async_trait::async_trait;
use synaptic_core::{ChatModel, ChatRequest, Message, SynapticError};

use crate::evaluator::{EvalResult, Evaluator};

/// Default judging prompt. The `{input}`, `{reference}`, and `{prediction}`
/// placeholders are substituted in [`LLMJudgeEvaluator::evaluate`] before the
/// request is sent; the final line pins the model to a bare 0-10 integer so
/// `parse_score` can extract it reliably.
const DEFAULT_PROMPT_TEMPLATE: &str = r#"You are an impartial judge evaluating the quality of an AI response.

Input: {input}
Expected answer: {reference}
AI response: {prediction}

Rate the AI response on a scale of 0 to 10, where 0 means completely wrong and 10 means perfect.
Respond with ONLY a single integer between 0 and 10."#;
17/// Evaluator that uses an LLM to judge prediction quality.
18pub struct LLMJudgeEvaluator {
19    model: Arc<dyn ChatModel>,
20    prompt_template: String,
21}
22
23impl LLMJudgeEvaluator {
24    /// Create a new LLM judge evaluator with the default prompt template.
25    pub fn new(model: Arc<dyn ChatModel>) -> Self {
26        Self {
27            model,
28            prompt_template: DEFAULT_PROMPT_TEMPLATE.to_string(),
29        }
30    }
31
32    /// Create a new LLM judge evaluator with a custom prompt template.
33    ///
34    /// The template should contain `{input}`, `{prediction}`, and `{reference}` placeholders.
35    pub fn with_prompt(model: Arc<dyn ChatModel>, template: impl Into<String>) -> Self {
36        Self {
37            model,
38            prompt_template: template.into(),
39        }
40    }
41}
42
/// Parse a score (0-10) from the model's response text, normalized to
/// `0.0..=1.0`.
///
/// Splits the text on every character that is not an ASCII digit or `.`,
/// so surrounding punctuation is tolerated: "8", "Score: 7", and "8/10"
/// all parse (the previous whitespace-based split rejected "8/10" because
/// the `/` sat in the middle of the word). The first numeric fragment that
/// falls in `0..=10` wins; returns `None` when no fragment qualifies.
fn parse_score(text: &str) -> Option<f64> {
    text.split(|c: char| !c.is_ascii_digit() && c != '.')
        // Empty fragments and stray dot runs fail to parse and are skipped.
        .filter_map(|fragment| fragment.parse::<f64>().ok())
        .find(|num| (0.0..=10.0).contains(num))
        .map(|num| num / 10.0)
}
57#[async_trait]
58impl Evaluator for LLMJudgeEvaluator {
59    async fn evaluate(
60        &self,
61        prediction: &str,
62        reference: &str,
63        input: &str,
64    ) -> Result<EvalResult, SynapticError> {
65        let prompt = self
66            .prompt_template
67            .replace("{input}", input)
68            .replace("{prediction}", prediction)
69            .replace("{reference}", reference);
70
71        let request = ChatRequest::new(vec![Message::human(prompt)]);
72        let response = self.model.chat(request).await?;
73        let response_text = response.message.content();
74
75        match parse_score(response_text) {
76            Some(score) => Ok(EvalResult::with_score(score)
77                .with_reasoning(format!("LLM judge score: {:.1}/10", score * 10.0))),
78            None => Err(SynapticError::Parsing(format!(
79                "Could not parse score from LLM response: {:?}",
80                response_text
81            ))),
82        }
83    }
84}