semantic_commands/
semantic_commands.rs

1use anyhow::{Result, bail};
2use futures::future::join_all;
3use log::info;
4use std::{any::Any, sync::Arc};
5
6use crate::{Command, cache::Cache, embedder::Embedder, input::Input};
7
8/// Calculate cosine similarity between two vectors
9///
10/// Returns a value between 0.0 (completely different) and 1.0 (identical)
/// Cosine similarity of two vectors, clamped to the range `[0.0, 1.0]`.
///
/// `1.0` means the vectors point the same way. `0.0` means they are orthogonal, or
/// that the raw cosine was negative: anti-correlated vectors are clamped up to `0.0`
/// rather than reported as `-1.0`.
///
/// Degenerate inputs also yield `0.0`: slices of different lengths, or a vector
/// whose magnitude is zero (this includes empty slices).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
	// Euclidean length of a vector.
	let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();

	if a.len() == b.len() {
		let (len_a, len_b) = (norm(a), norm(b));
		if len_a != 0.0 && len_b != 0.0 {
			let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
			return (dot / (len_a * len_b)).clamp(0.0, 1.0);
		}
	}

	0.0
}
26pub struct SemanticCommands<E: Embedder, Ch: Cache, C> {
27	embedder: Arc<E>,
28	cache: Arc<Ch>,
29	context: Arc<C>,
30	threshold: f32,
31	entries: Vec<(Vec<Input>, Command<C>)>,
32}
33impl<E: Embedder, Ch: Cache, C> SemanticCommands<E, Ch, C> {
34	pub async fn get_embedding(&self, input: &str) -> Result<Vec<f32>> {
35		match self.cache.get(input).await? {
36			Some(embedding) => Ok(embedding),
37			None => {
38				info!("embedding not found in cache, generating new one");
39				let embedding = self.embedder.as_ref().embed(input).await?;
40				self.cache.put(input, embedding.clone()).await?;
41				Ok(embedding)
42			}
43		}
44	}
45
46	pub fn new(embedder: E, cache: Ch, context: C) -> Self {
47		Self {
48			embedder: Arc::new(embedder),
49			cache: Arc::new(cache),
50			context: Arc::new(context),
51			threshold: 0.8,
52			entries: vec![],
53		}
54	}
55
56	/// Set the similarity threshold (default is 0.2)
57	pub fn threshold(mut self, threshold: f32) -> Self {
58		self.threshold = threshold;
59		self
60	}
61
62	async fn find_similar(&mut self, embedding: Vec<f32>, threshold: f32) -> Result<Option<(&Input, &Command<C>)>> {
63		// Pre-calculate all missing embeddings in one batch
64		let missing_embeddings: Vec<_> = self
65			.entries
66			.iter()
67			.flat_map(|(inputs, _)| inputs)
68			.filter(|input| input.embedding.is_none())
69			.map(|input| input.text.clone())
70			.collect();
71
72		// Get all embeddings in parallel using existing get_embedding method
73		let embeddings: Vec<_> = join_all(missing_embeddings.iter().map(|text| async { self.get_embedding(text).await }))
74			.await
75			.into_iter()
76			.filter_map(Result::ok)
77			.collect();
78
79		// Update inputs with new embeddings
80		let mut emb_iter = embeddings.into_iter();
81		for (inputs, _) in &mut self.entries {
82			for input in inputs {
83				if input.embedding.is_none() {
84					input.embedding = emb_iter.next();
85				}
86			}
87		}
88
89		let res = self
90			.entries
91			.iter()
92			.flat_map(|(inputs, command)| {
93				let emb = embedding.clone();
94				inputs.iter().filter_map(move |input| {
95					let similarity = cosine_similarity(&emb, input.embedding.as_ref()?);
96					(similarity >= threshold).then_some((similarity, input, command))
97				})
98			})
99			.collect::<Vec<_>>();
100
101		Ok(res
102			.into_iter()
103			.max_by(|a, b| a.0.partial_cmp(&b.0).unwrap())
104			.map(|(_similarity, input, command)| (input, command)))
105	}
106
107	pub async fn execute(&mut self, input: &str) -> Result<Box<dyn Any + Send>> {
108		let input_embedding = self.get_embedding(input).await?;
109		let context = self.context.clone();
110		let similar = self.find_similar(input_embedding, self.threshold).await?;
111		match similar {
112			Some((_input, command)) => {
113				info!("command recognized as: {:?}", command.name);
114				let result = (command.executor)(context).await;
115				Ok(result)
116			}
117			None => {
118				bail!("no similar command found");
119			}
120		}
121	}
122
123	pub fn add_command(&mut self, command: Command<C>, inputs: Vec<Input>) -> &mut Self {
124		self.entries.push((inputs, command));
125		self
126	}
127
128	pub fn add_commands(&mut self, commands: Vec<(Command<C>, Vec<Input>)>) -> &mut Self {
129		commands.into_iter().for_each(|(command, inputs)| {
130			self.entries.push((inputs, command));
131		});
132		self
133	}
134
135	pub async fn init(&mut self) -> Result<&mut Self> {
136		self.cache.init().await?;
137		Ok(self)
138	}
139}
140
#[cfg(test)]
mod tests {
	use super::*;

	#[test]
	fn test_cosine_similarity() {
		let a = vec![1.0, 0.0, 0.0];
		let b = vec![0.0, 1.0, 0.0];
		let c = vec![1.0, 0.0, 0.0];
		assert!((cosine_similarity(&a, &b) - 0.0).abs() < 1e-6);
		assert!((cosine_similarity(&a, &c) - 1.0).abs() < 1e-6);
	}

	/// Anti-correlated vectors have a raw cosine of -1, which must be clamped to 0.
	#[test]
	fn test_cosine_similarity_clamps_negative_to_zero() {
		assert_eq!(cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]), 0.0);
	}

	/// Mismatched lengths, zero vectors and empty slices all yield 0.
	#[test]
	fn test_cosine_similarity_degenerate_inputs() {
		assert_eq!(cosine_similarity(&[1.0, 2.0], &[1.0]), 0.0);
		assert_eq!(cosine_similarity(&[0.0, 0.0], &[3.0, 4.0]), 0.0);
		assert_eq!(cosine_similarity(&[], &[]), 0.0);
	}

	/// Similarity depends only on direction, not on magnitude.
	#[test]
	fn test_cosine_similarity_is_scale_invariant() {
		assert!((cosine_similarity(&[2.0, 0.0], &[5.0, 0.0]) - 1.0).abs() < 1e-6);
	}
}
152}