1use crate::error::EmbeddingError;
2use rusqlite::{params, Connection};
3use serde::Deserialize;
4use skill_core::Skill;
5use std::path::PathBuf;
6use tracing::{debug, info};
7
8pub mod error;
9
/// Computes text embeddings through a configured provider and persists them
/// in a SQLite database.
///
/// Cloning is cheap-ish: it copies the provider configuration and the
/// database path, not any open connection.
#[derive(Debug, Clone)]
pub struct EmbeddingService {
    /// Backend used to turn text into an embedding vector.
    provider: Provider,
    /// Path of the SQLite file that holds the `embeddings` table.
    db_path: PathBuf,
}

/// Embedding backends supported by [`EmbeddingService`].
#[derive(Debug, Clone)]
enum Provider {
    /// An Ollama server reachable over HTTP.
    Ollama {
        /// Base URL of the server; the API path is appended to it.
        url: String,
        /// Name of the embedding model sent with each request.
        model: String,
    },
}
20
21impl EmbeddingService {
22 pub fn new_ollama(url: String, model: String, db_path: PathBuf) -> Self {
23 Self {
24 provider: Provider::Ollama { url, model },
25 db_path,
26 }
27 }
28
29 pub async fn init_db(&self) -> Result<(), EmbeddingError> {
30 let conn = Connection::open(&self.db_path)
31 .map_err(|e| EmbeddingError::DatabaseError(e.to_string()))?;
32
33 conn.execute(
34 "CREATE TABLE IF NOT EXISTS embeddings (
35 skill_id TEXT PRIMARY KEY,
36 embedding BLOB NOT NULL,
37 updated_at INTEGER NOT NULL
38 )",
39 [],
40 )
41 .map_err(|e| EmbeddingError::DatabaseError(e.to_string()))?;
42
43 info!("Initialized embeddings database at {:?}", self.db_path);
44 Ok(())
45 }
46
47 pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
48 match &self.provider {
49 Provider::Ollama { url, model } => self.embed_ollama(url, model, text).await,
50 }
51 }
52
53 async fn embed_ollama(
54 &self,
55 url: &str,
56 model: &str,
57 text: &str,
58 ) -> Result<Vec<f32>, EmbeddingError> {
59 let client = reqwest::Client::new();
60
61 let request = serde_json::json!({
62 "model": model,
63 "input": text,
64 });
65
66 let response = client
67 .post(format!("{}/api/embeddings", url))
68 .json(&request)
69 .send()
70 .await
71 .map_err(|e| EmbeddingError::RequestError(e.to_string()))?;
72
73 if !response.status().is_success() {
74 let status = response.status();
75 let body = response.text().await.unwrap_or_default();
76 return Err(EmbeddingError::ApiError(format!(
77 "Status: {}, Body: {}",
78 status, body
79 )));
80 }
81
82 let result: EmbedResponse = response
83 .json()
84 .await
85 .map_err(|e| EmbeddingError::ParseError(e.to_string()))?;
86
87 Ok(result.embedding)
88 }
89
90 pub async fn embed_skill(&self, skill: &Skill) -> Result<Vec<f32>, EmbeddingError> {
91 let text = skill.search_text();
92 debug!("Embedding skill: {} ({} chars)", skill.name, text.len());
93 self.embed_text(&text).await
94 }
95
96 pub async fn store_embedding(
97 &self,
98 skill_id: &str,
99 embedding: &[f32],
100 ) -> Result<(), EmbeddingError> {
101 let conn = Connection::open(&self.db_path)
102 .map_err(|e| EmbeddingError::DatabaseError(e.to_string()))?;
103
104 let embedding_bytes: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
105
106 let now = std::time::SystemTime::now()
107 .duration_since(std::time::UNIX_EPOCH)
108 .unwrap()
109 .as_secs() as i64;
110
111 conn.execute(
112 "INSERT OR REPLACE INTO embeddings (skill_id, embedding, updated_at) VALUES (?1, ?2, ?3)",
113 params![skill_id, embedding_bytes, now],
114 )
115 .map_err(|e| EmbeddingError::DatabaseError(e.to_string()))?;
116
117 debug!("Stored embedding for skill: {}", skill_id);
118 Ok(())
119 }
120
121 pub async fn get_embedding(&self, skill_id: &str) -> Result<Option<Vec<f32>>, EmbeddingError> {
122 let conn = Connection::open(&self.db_path)
123 .map_err(|e| EmbeddingError::DatabaseError(e.to_string()))?;
124
125 let mut stmt = conn
126 .prepare("SELECT embedding FROM embeddings WHERE skill_id = ?1")
127 .map_err(|e| EmbeddingError::DatabaseError(e.to_string()))?;
128
129 let mut rows = stmt
130 .query(params![skill_id])
131 .map_err(|e| EmbeddingError::DatabaseError(e.to_string()))?;
132
133 if let Some(row) = rows
134 .next()
135 .map_err(|e| EmbeddingError::DatabaseError(e.to_string()))?
136 {
137 let embedding_bytes: Vec<u8> = row
138 .get(0)
139 .map_err(|e| EmbeddingError::DatabaseError(e.to_string()))?;
140 let embedding: Vec<f32> = embedding_bytes
141 .chunks_exact(4)
142 .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
143 .collect();
144 Ok(Some(embedding))
145 } else {
146 Ok(None)
147 }
148 }
149
150 pub async fn index_skill(&self, skill: &Skill) -> Result<(), EmbeddingError> {
151 let embedding = self.embed_skill(skill).await?;
152 self.store_embedding(&skill.id, &embedding).await?;
153 info!("Indexed skill: {}", skill.name);
154 Ok(())
155 }
156
157 pub async fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f64 {
158 if a.len() != b.len() {
159 return 0.0;
160 }
161
162 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
163 let mag_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
164 let mag_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
165
166 if mag_a == 0.0 || mag_b == 0.0 {
167 return 0.0;
168 }
169
170 (dot_product / (mag_a * mag_b)) as f64
171 }
172
173 pub async fn search_similar(
174 &self,
175 query_embedding: &[f32],
176 skill_ids: &[String],
177 limit: usize,
178 ) -> Result<Vec<(String, f64)>, EmbeddingError> {
179 let mut scores = Vec::new();
180
181 for skill_id in skill_ids {
182 if let Some(embedding) = self.get_embedding(skill_id).await? {
183 let similarity = self.cosine_similarity(query_embedding, &embedding).await;
184 scores.push((skill_id.clone(), similarity));
185 }
186 }
187
188 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
189 scores.truncate(limit);
190
191 Ok(scores)
192 }
193}
194
195#[derive(Deserialize)]
196struct EmbedResponse {
197 embedding: Vec<f32>,
198}