synaptic_prompts/
example_selector.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use synaptic_core::SynapticError;
5use synaptic_embeddings::Embeddings;
6use tokio::sync::RwLock;
7
8use crate::FewShotExample;
9
10#[async_trait]
12pub trait ExampleSelector: Send + Sync {
13 async fn select_examples(&self, input: &str) -> Result<Vec<FewShotExample>, SynapticError>;
15
16 async fn add_example(&self, example: FewShotExample) -> Result<(), SynapticError>;
18}
19
20pub struct SemanticSimilarityExampleSelector {
22 #[expect(clippy::type_complexity)]
23 examples: Arc<RwLock<Vec<(FewShotExample, Vec<f32>)>>>,
24 embeddings: Arc<dyn Embeddings>,
25 k: usize,
26}
27
28impl SemanticSimilarityExampleSelector {
29 pub fn new(embeddings: Arc<dyn Embeddings>, k: usize) -> Self {
31 Self {
32 examples: Arc::new(RwLock::new(Vec::new())),
33 embeddings,
34 k,
35 }
36 }
37}
38
39#[async_trait]
40impl ExampleSelector for SemanticSimilarityExampleSelector {
41 async fn add_example(&self, example: FewShotExample) -> Result<(), SynapticError> {
42 let embedding = self.embeddings.embed_query(&example.input).await?;
43 let mut examples = self.examples.write().await;
44 examples.push((example, embedding));
45 Ok(())
46 }
47
48 async fn select_examples(&self, input: &str) -> Result<Vec<FewShotExample>, SynapticError> {
49 let query_embedding = self.embeddings.embed_query(input).await?;
50 let examples = self.examples.read().await;
51
52 if examples.is_empty() {
53 return Ok(Vec::new());
54 }
55
56 let mut scored: Vec<(usize, f32)> = examples
58 .iter()
59 .enumerate()
60 .map(|(i, (_, emb))| (i, cosine_similarity(&query_embedding, emb)))
61 .collect();
62
63 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
65
66 let result = scored
68 .iter()
69 .take(self.k)
70 .map(|(i, _)| examples[*i].0.clone())
71 .collect();
72
73 Ok(result)
74 }
75}
76
/// Cosine similarity of `a` and `b`, always within `[-1.0, 1.0]`.
///
/// Returns `0.0` ("no similarity signal") instead of a meaningless value when:
/// * the vectors have different lengths (`zip` would otherwise compare only
///   the shared prefix while the norms cover the full vectors);
/// * either vector has zero magnitude, including empty vectors (the ratio is
///   undefined);
/// * the computation does not yield a finite number (e.g. NaN or infinite
///   components).
///
/// Floating-point rounding can push the ratio marginally past ±1, so the
/// result is clamped.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Single pass: dot product and both squared norms together.
    let (dot, norm_a_sq, norm_b_sq) = a
        .iter()
        .zip(b)
        .fold((0.0_f32, 0.0_f32, 0.0_f32), |(dot, na, nb), (&x, &y)| {
            (dot + x * y, na + x * x, nb + y * y)
        });

    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }

    let sim = dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt());
    if sim.is_finite() {
        sim.clamp(-1.0, 1.0)
    } else {
        0.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two floats agree to within single-precision tolerance.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-6,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn cosine_similarity_identical_vectors() {
        let a = [1.0, 2.0, 3.0];
        assert_close(cosine_similarity(&a, &a), 1.0);
    }

    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        assert_close(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]), 0.0);
    }

    #[test]
    fn cosine_similarity_opposite_vectors() {
        // Anti-parallel vectors hit the negative bound of the range.
        assert_close(
            cosine_similarity(&[1.0, 2.0, 3.0], &[-1.0, -2.0, -3.0]),
            -1.0,
        );
    }

    #[test]
    fn cosine_similarity_ignores_magnitude() {
        assert_close(cosine_similarity(&[1.0, 2.0], &[2.0, 4.0]), 1.0);
    }

    #[test]
    fn cosine_similarity_zero_vector() {
        let a = [1.0, 2.0];
        let zero = [0.0, 0.0];
        assert_eq!(cosine_similarity(&a, &zero), 0.0);
        assert_eq!(cosine_similarity(&zero, &zero), 0.0);
    }
}