Skip to main content

skill_embeddings/
lib.rs

1use crate::error::EmbeddingError;
2use rusqlite::{params, Connection};
3use serde::Deserialize;
4use skill_core::Skill;
5use std::path::PathBuf;
6use tracing::{debug, info};
7
8pub mod error;
9
/// Handle for computing text embeddings and persisting them in SQLite.
///
/// The value only holds configuration (which provider to call and where the
/// database file lives), so cloning it copies those settings.
#[derive(Debug, Clone)]
pub struct EmbeddingService {
    /// Backend used to turn text into an embedding vector.
    provider: Provider,
    /// Location of the SQLite database file that stores the embeddings.
    db_path: PathBuf,
}

/// Embedding backends the service can talk to.
#[derive(Debug, Clone)]
enum Provider {
    /// An Ollama server reachable at `url`, asked to embed with `model`.
    Ollama { url: String, model: String },
}
20
21impl EmbeddingService {
22    pub fn new_ollama(url: String, model: String, db_path: PathBuf) -> Self {
23        Self {
24            provider: Provider::Ollama { url, model },
25            db_path,
26        }
27    }
28
29    pub async fn init_db(&self) -> Result<(), EmbeddingError> {
30        let conn = Connection::open(&self.db_path)
31            .map_err(|e| EmbeddingError::DatabaseError(e.to_string()))?;
32
33        conn.execute(
34            "CREATE TABLE IF NOT EXISTS embeddings (
35                skill_id TEXT PRIMARY KEY,
36                embedding BLOB NOT NULL,
37                updated_at INTEGER NOT NULL
38            )",
39            [],
40        )
41        .map_err(|e| EmbeddingError::DatabaseError(e.to_string()))?;
42
43        info!("Initialized embeddings database at {:?}", self.db_path);
44        Ok(())
45    }
46
47    pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
48        match &self.provider {
49            Provider::Ollama { url, model } => self.embed_ollama(url, model, text).await,
50        }
51    }
52
53    async fn embed_ollama(
54        &self,
55        url: &str,
56        model: &str,
57        text: &str,
58    ) -> Result<Vec<f32>, EmbeddingError> {
59        let client = reqwest::Client::new();
60
61        let request = serde_json::json!({
62            "model": model,
63            "input": text,
64        });
65
66        let response = client
67            .post(format!("{}/api/embeddings", url))
68            .json(&request)
69            .send()
70            .await
71            .map_err(|e| EmbeddingError::RequestError(e.to_string()))?;
72
73        if !response.status().is_success() {
74            let status = response.status();
75            let body = response.text().await.unwrap_or_default();
76            return Err(EmbeddingError::ApiError(format!(
77                "Status: {}, Body: {}",
78                status, body
79            )));
80        }
81
82        let result: EmbedResponse = response
83            .json()
84            .await
85            .map_err(|e| EmbeddingError::ParseError(e.to_string()))?;
86
87        Ok(result.embedding)
88    }
89
90    pub async fn embed_skill(&self, skill: &Skill) -> Result<Vec<f32>, EmbeddingError> {
91        let text = skill.search_text();
92        debug!("Embedding skill: {} ({} chars)", skill.name, text.len());
93        self.embed_text(&text).await
94    }
95
96    pub async fn store_embedding(
97        &self,
98        skill_id: &str,
99        embedding: &[f32],
100    ) -> Result<(), EmbeddingError> {
101        let conn = Connection::open(&self.db_path)
102            .map_err(|e| EmbeddingError::DatabaseError(e.to_string()))?;
103
104        let embedding_bytes: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
105
106        let now = std::time::SystemTime::now()
107            .duration_since(std::time::UNIX_EPOCH)
108            .unwrap()
109            .as_secs() as i64;
110
111        conn.execute(
112            "INSERT OR REPLACE INTO embeddings (skill_id, embedding, updated_at) VALUES (?1, ?2, ?3)",
113            params![skill_id, embedding_bytes, now],
114        )
115        .map_err(|e| EmbeddingError::DatabaseError(e.to_string()))?;
116
117        debug!("Stored embedding for skill: {}", skill_id);
118        Ok(())
119    }
120
121    pub async fn get_embedding(&self, skill_id: &str) -> Result<Option<Vec<f32>>, EmbeddingError> {
122        let conn = Connection::open(&self.db_path)
123            .map_err(|e| EmbeddingError::DatabaseError(e.to_string()))?;
124
125        let mut stmt = conn
126            .prepare("SELECT embedding FROM embeddings WHERE skill_id = ?1")
127            .map_err(|e| EmbeddingError::DatabaseError(e.to_string()))?;
128
129        let mut rows = stmt
130            .query(params![skill_id])
131            .map_err(|e| EmbeddingError::DatabaseError(e.to_string()))?;
132
133        if let Some(row) = rows
134            .next()
135            .map_err(|e| EmbeddingError::DatabaseError(e.to_string()))?
136        {
137            let embedding_bytes: Vec<u8> = row
138                .get(0)
139                .map_err(|e| EmbeddingError::DatabaseError(e.to_string()))?;
140            let embedding: Vec<f32> = embedding_bytes
141                .chunks_exact(4)
142                .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
143                .collect();
144            Ok(Some(embedding))
145        } else {
146            Ok(None)
147        }
148    }
149
150    pub async fn index_skill(&self, skill: &Skill) -> Result<(), EmbeddingError> {
151        let embedding = self.embed_skill(skill).await?;
152        self.store_embedding(&skill.id, &embedding).await?;
153        info!("Indexed skill: {}", skill.name);
154        Ok(())
155    }
156
157    pub async fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f64 {
158        if a.len() != b.len() {
159            return 0.0;
160        }
161
162        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
163        let mag_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
164        let mag_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
165
166        if mag_a == 0.0 || mag_b == 0.0 {
167            return 0.0;
168        }
169
170        (dot_product / (mag_a * mag_b)) as f64
171    }
172
173    pub async fn search_similar(
174        &self,
175        query_embedding: &[f32],
176        skill_ids: &[String],
177        limit: usize,
178    ) -> Result<Vec<(String, f64)>, EmbeddingError> {
179        let mut scores = Vec::new();
180
181        for skill_id in skill_ids {
182            if let Some(embedding) = self.get_embedding(skill_id).await? {
183                let similarity = self.cosine_similarity(query_embedding, &embedding).await;
184                scores.push((skill_id.clone(), similarity));
185            }
186        }
187
188        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
189        scores.truncate(limit);
190
191        Ok(scores)
192    }
193}
194
195#[derive(Deserialize)]
196struct EmbedResponse {
197    embedding: Vec<f32>,
198}