Skip to main content

synaptic_prompts/
example_selector.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use synaptic_core::SynapticError;
5use synaptic_embeddings::Embeddings;
6use tokio::sync::RwLock;
7
8use crate::FewShotExample;
9
10/// Trait for selecting examples for few-shot prompting.
11#[async_trait]
12pub trait ExampleSelector: Send + Sync {
13    /// Select examples most relevant to the input.
14    async fn select_examples(&self, input: &str) -> Result<Vec<FewShotExample>, SynapticError>;
15
16    /// Add a new example to the selector's pool.
17    async fn add_example(&self, example: FewShotExample) -> Result<(), SynapticError>;
18}
19
20/// Selects examples based on semantic similarity using embeddings.
21pub struct SemanticSimilarityExampleSelector {
22    #[expect(clippy::type_complexity)]
23    examples: Arc<RwLock<Vec<(FewShotExample, Vec<f32>)>>>,
24    embeddings: Arc<dyn Embeddings>,
25    k: usize,
26}
27
28impl SemanticSimilarityExampleSelector {
29    /// Create a new selector that returns the top-k most similar examples.
30    pub fn new(embeddings: Arc<dyn Embeddings>, k: usize) -> Self {
31        Self {
32            examples: Arc::new(RwLock::new(Vec::new())),
33            embeddings,
34            k,
35        }
36    }
37}
38
39#[async_trait]
40impl ExampleSelector for SemanticSimilarityExampleSelector {
41    async fn add_example(&self, example: FewShotExample) -> Result<(), SynapticError> {
42        let embedding = self.embeddings.embed_query(&example.input).await?;
43        let mut examples = self.examples.write().await;
44        examples.push((example, embedding));
45        Ok(())
46    }
47
48    async fn select_examples(&self, input: &str) -> Result<Vec<FewShotExample>, SynapticError> {
49        let query_embedding = self.embeddings.embed_query(input).await?;
50        let examples = self.examples.read().await;
51
52        if examples.is_empty() {
53            return Ok(Vec::new());
54        }
55
56        // Compute similarities and collect with indices
57        let mut scored: Vec<(usize, f32)> = examples
58            .iter()
59            .enumerate()
60            .map(|(i, (_, emb))| (i, cosine_similarity(&query_embedding, emb)))
61            .collect();
62
63        // Sort by similarity descending
64        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
65
66        // Take top-k
67        let result = scored
68            .iter()
69            .take(self.k)
70            .map(|(i, _)| examples[*i].0.clone())
71            .collect();
72
73        Ok(result)
74    }
75}
76
/// Compute the cosine similarity between two embedding vectors.
///
/// Returns a value in `[-1.0, 1.0]`: `1.0` for vectors pointing the same way,
/// `0.0` for orthogonal vectors, and `-1.0` for opposite ones.
///
/// Returns `0.0` in two cases where no meaningful angle exists:
/// * either vector has zero magnitude, which includes empty input;
/// * the vectors have different lengths.
///
/// If either input contains NaN, the result is NaN.
///
/// Sums are accumulated in `f64` to limit rounding error on high-dimensional
/// embeddings; the result is narrowed back to `f32`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Mismatched dimensions usually mean the vectors come from different
    // embedding models. Any score over a truncated prefix would be
    // meaningless, so treat the pair as unrelated.
    if a.len() != b.len() {
        return 0.0;
    }

    let (dot, norm_a_sq, norm_b_sq) =
        a.iter()
            .zip(b)
            .fold((0.0_f64, 0.0_f64, 0.0_f64), |(dot, na, nb), (&x, &y)| {
                let (x, y) = (f64::from(x), f64::from(y));
                (dot + x * y, na + x * x, nb + y * y)
            });

    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }

    // Rounding can push the ratio marginally outside [-1, 1]; clamp it back.
    // `clamp` leaves NaN untouched, so NaN inputs still yield NaN.
    (dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt())).clamp(-1.0, 1.0) as f32
}
87
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that a similarity score matches `expected` up to f32 rounding noise.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-6,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn cosine_similarity_identical_vectors() {
        let a = [1.0, 2.0, 3.0];
        assert_close(cosine_similarity(&a, &a), 1.0);
    }

    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        assert_close(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]), 0.0);
    }

    #[test]
    fn cosine_similarity_opposite_vectors() {
        assert_close(
            cosine_similarity(&[1.0, 2.0, 3.0], &[-1.0, -2.0, -3.0]),
            -1.0,
        );
    }

    #[test]
    fn cosine_similarity_is_scale_invariant() {
        assert_close(cosine_similarity(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]), 1.0);
    }

    #[test]
    fn cosine_similarity_zero_vector() {
        assert_eq!(cosine_similarity(&[1.0, 2.0], &[0.0, 0.0]), 0.0);
    }
}