semantic_commands/
semantic_commands.rs1use anyhow::{Result, bail};
2use futures::future::join_all;
3use log::info;
4use std::{any::Any, sync::Arc};
5
6use crate::{Command, cache::Cache, embedder::Embedder, input::Input};
7
/// Cosine similarity of `a` and `b`, clamped to the range `[0.0, 1.0]`.
///
/// Returns `0.0` when the slices differ in length or when either vector has zero
/// magnitude. Negative similarities (vectors pointing in opposing directions) are
/// clamped up to `0.0`, so callers only ever see "not similar" through "identical".
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Euclidean length of a vector.
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    (dot / (norm_a * norm_b)).clamp(0.0, 1.0)
}
26pub struct SemanticCommands<E: Embedder, Ch: Cache, C> {
27 embedder: Arc<E>,
28 cache: Arc<Ch>,
29 context: Arc<C>,
30 threshold: f32,
31 entries: Vec<(Vec<Input>, Command<C>)>,
32}
33impl<E: Embedder, Ch: Cache, C> SemanticCommands<E, Ch, C> {
34 pub async fn get_embedding(&self, input: &str) -> Result<Vec<f32>> {
35 match self.cache.get(input).await? {
36 Some(embedding) => Ok(embedding),
37 None => {
38 info!("embedding not found in cache, generating new one");
39 let embedding = self.embedder.as_ref().embed(input).await?;
40 self.cache.put(input, embedding.clone()).await?;
41 Ok(embedding)
42 }
43 }
44 }
45
46 pub fn new(embedder: E, cache: Ch, context: C) -> Self {
47 Self {
48 embedder: Arc::new(embedder),
49 cache: Arc::new(cache),
50 context: Arc::new(context),
51 threshold: 0.8,
52 entries: vec![],
53 }
54 }
55
56 pub fn threshold(mut self, threshold: f32) -> Self {
58 self.threshold = threshold;
59 self
60 }
61
62 async fn find_similar(&mut self, embedding: Vec<f32>, threshold: f32) -> Result<Option<(&Input, &Command<C>)>> {
63 let missing_embeddings: Vec<_> = self
65 .entries
66 .iter()
67 .flat_map(|(inputs, _)| inputs)
68 .filter(|input| input.embedding.is_none())
69 .map(|input| input.text.clone())
70 .collect();
71
72 let embeddings: Vec<_> = join_all(missing_embeddings.iter().map(|text| async { self.get_embedding(text).await }))
74 .await
75 .into_iter()
76 .filter_map(Result::ok)
77 .collect();
78
79 let mut emb_iter = embeddings.into_iter();
81 for (inputs, _) in &mut self.entries {
82 for input in inputs {
83 if input.embedding.is_none() {
84 input.embedding = emb_iter.next();
85 }
86 }
87 }
88
89 let res = self
90 .entries
91 .iter()
92 .flat_map(|(inputs, command)| {
93 let emb = embedding.clone();
94 inputs.iter().filter_map(move |input| {
95 let similarity = cosine_similarity(&emb, input.embedding.as_ref()?);
96 (similarity >= threshold).then_some((similarity, input, command))
97 })
98 })
99 .collect::<Vec<_>>();
100
101 Ok(res
102 .into_iter()
103 .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap())
104 .map(|(_similarity, input, command)| (input, command)))
105 }
106
107 pub async fn execute(&mut self, input: &str) -> Result<Box<dyn Any + Send>> {
108 let input_embedding = self.get_embedding(input).await?;
109 let context = self.context.clone();
110 let similar = self.find_similar(input_embedding, self.threshold).await?;
111 match similar {
112 Some((_input, command)) => {
113 info!("command recognized as: {:?}", command.name);
114 let result = (command.executor)(context).await;
115 Ok(result)
116 }
117 None => {
118 bail!("no similar command found");
119 }
120 }
121 }
122
123 pub fn add_command(&mut self, command: Command<C>, inputs: Vec<Input>) -> &mut Self {
124 self.entries.push((inputs, command));
125 self
126 }
127
128 pub fn add_commands(&mut self, commands: Vec<(Command<C>, Vec<Input>)>) -> &mut Self {
129 commands.into_iter().for_each(|(command, inputs)| {
130 self.entries.push((inputs, command));
131 });
132 self
133 }
134
135 pub async fn init(&mut self) -> Result<&mut Self> {
136 self.cache.init().await?;
137 Ok(self)
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` is within `1e-6` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-6,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let c = vec![1.0, 0.0, 0.0];
        assert_close(cosine_similarity(&a, &b), 0.0);
        assert_close(cosine_similarity(&a, &c), 1.0);
    }

    #[test]
    fn test_cosine_similarity_scale_and_angle() {
        // Magnitude does not matter, only direction.
        assert_close(cosine_similarity(&[3.0, 4.0], &[6.0, 8.0]), 1.0);
        // (3,4)·(4,3) = 24, |a||b| = 25.
        assert_close(cosine_similarity(&[3.0, 4.0], &[4.0, 3.0]), 0.96);
    }

    #[test]
    fn test_cosine_similarity_degenerate_inputs() {
        // Mismatched lengths, zero-magnitude and empty vectors are all "not similar".
        assert_close(cosine_similarity(&[1.0, 2.0], &[1.0, 2.0, 3.0]), 0.0);
        assert_close(cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
        let empty: [f32; 0] = [];
        assert_close(cosine_similarity(&empty, &empty), 0.0);
    }

    #[test]
    fn test_cosine_similarity_clamps_negative_to_zero() {
        assert_close(cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]), 0.0);
    }
}