Skip to main content

synaptic_eval/
embedding_distance.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use synaptic_core::SynapticError;
5use synaptic_embeddings::Embeddings;
6
7use crate::evaluator::{EvalResult, Evaluator};
8
9/// Evaluator that computes cosine similarity between embeddings of prediction and reference.
10pub struct EmbeddingDistanceEvaluator {
11    embeddings: Arc<dyn Embeddings>,
12    threshold: f64,
13}
14
15impl EmbeddingDistanceEvaluator {
16    /// Create a new embedding distance evaluator.
17    ///
18    /// - `embeddings`: The embeddings model to use.
19    /// - `threshold`: Minimum cosine similarity to pass (default suggestion: 0.8).
20    pub fn new(embeddings: Arc<dyn Embeddings>, threshold: f64) -> Self {
21        Self {
22            embeddings,
23            threshold,
24        }
25    }
26}
27
/// Compute the cosine similarity between two vectors.
///
/// All products and sums are accumulated in `f64`. Squaring an `f32`
/// component in `f32` overflows to infinity for magnitudes around `1e20`
/// (yielding NaN) and underflows to zero around `1e-30` (yielding a bogus
/// `0.0`); widening first avoids both.
///
/// Returns `0.0` when either vector has zero magnitude, which includes empty
/// slices. Otherwise the result is clamped to `[-1.0, 1.0]` so rounding can
/// never push it outside the mathematically valid range.
///
/// If the slices differ in length, the dot product covers only the common
/// prefix while each norm covers its whole vector, which is equivalent to
/// zero-padding the shorter vector.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    // Euclidean norm computed entirely in f64.
    let norm = |v: &[f32]| -> f64 {
        v.iter()
            .map(|&x| f64::from(x) * f64::from(x))
            .sum::<f64>()
            .sqrt()
    };

    let dot: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();
    let mag_a = norm(a);
    let mag_b = norm(b);

    // A zero vector has no direction; report "no similarity" instead of dividing by zero.
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    (dot / (mag_a * mag_b)).clamp(-1.0, 1.0)
}
38
39#[async_trait]
40impl Evaluator for EmbeddingDistanceEvaluator {
41    async fn evaluate(
42        &self,
43        prediction: &str,
44        reference: &str,
45        _input: &str,
46    ) -> Result<EvalResult, SynapticError> {
47        let pred_embedding = self.embeddings.embed_query(prediction).await?;
48        let ref_embedding = self.embeddings.embed_query(reference).await?;
49
50        let similarity = cosine_similarity(&pred_embedding, &ref_embedding);
51
52        let passed = similarity >= self.threshold;
53        let result = EvalResult {
54            score: similarity,
55            passed,
56            reasoning: Some(format!(
57                "Cosine similarity: {:.4}, threshold: {:.4}",
58                similarity, self.threshold
59            )),
60        };
61
62        Ok(result)
63    }
64}