// the_code_graph_domain/use_cases/embed.rs
use std::collections::{HashMap, HashSet};
2
3use sha2::{Digest, Sha256};
4
5use crate::analysis::search::symbol_to_text;
6use crate::error::Result;
7use crate::model::{Edge, EmbedStats, EmbeddingConfig, EmbeddingEntry};
8use crate::ports::{EmbeddingProvider, GraphStore, VectorStore};
9
/// Use case that embeds graph symbols and keeps the vector store in step with the graph.
///
/// The type parameters are left unbounded here on purpose: the trait bounds
/// (`GraphStore`, `EmbeddingProvider`, `VectorStore`) are required only by the
/// methods, so they are declared on the `impl` block rather than repeated on the
/// data definition.
pub struct EmbedUseCase<S, E, V> {
    /// Source of the symbols and edges to embed.
    store: S,
    /// Turns batches of symbol text into vectors.
    provider: E,
    /// Holds the stored embeddings together with the hash of the text they were built from.
    vector_store: V,
}
19
20impl<S: GraphStore, E: EmbeddingProvider, V: VectorStore> EmbedUseCase<S, E, V> {
21 pub fn new(store: S, provider: E, vector_store: V) -> Self {
22 Self {
23 store,
24 provider,
25 vector_store,
26 }
27 }
28
29 pub fn embed_all(
32 &self,
33 config: &EmbeddingConfig,
34 on_batch: impl Fn(usize, usize),
35 ) -> Result<EmbedStats> {
36 let symbols = self.store.all_symbols()?;
37 let edges = self.store.all_edges()?;
38 let edge_map = build_edge_map(&edges);
39
40 let stored: HashMap<String, String> =
42 self.vector_store.get_stored_hashes()?.into_iter().collect();
43
44 let mut to_embed: Vec<(String, String, String)> = Vec::new(); let mut skipped = 0usize;
46
47 for sym in &symbols {
48 let sym_edges = edge_map
49 .get(&sym.qualified_name)
50 .cloned()
51 .unwrap_or_default();
52 let text = symbol_to_text(sym, &sym_edges);
53 let hash = sha256_hex(&text);
54
55 if stored
56 .get(&sym.qualified_name)
57 .map(|h| h == &hash)
58 .unwrap_or(false)
59 {
60 skipped += 1;
61 continue;
62 }
63 to_embed.push((sym.qualified_name.clone(), text, hash));
64 }
65
66 let total_to_embed = to_embed.len();
67 let mut embedded = 0usize;
68 for chunk in to_embed.chunks(config.batch_size) {
69 let texts: Vec<String> = chunk.iter().map(|(_, t, _)| t.clone()).collect();
70 let vectors = self.provider.embed_batch(&texts)?;
71 let entries: Vec<EmbeddingEntry> = chunk
72 .iter()
73 .zip(vectors)
74 .map(|((qn, _, hash), vec)| EmbeddingEntry {
75 qualified_name: qn.clone(),
76 vector: vec,
77 text_hash: hash.clone(),
78 })
79 .collect();
80 self.vector_store.store_embeddings(&entries)?;
81 embedded += entries.len();
82 on_batch(embedded, total_to_embed);
83 }
84
85 Ok(EmbedStats {
86 total_symbols: symbols.len(),
87 embedded,
88 skipped,
89 removed: 0,
90 })
91 }
92
93 pub fn cleanup_orphans(&self) -> Result<usize> {
95 let symbols: HashSet<String> = self
96 .store
97 .all_symbols()?
98 .iter()
99 .map(|s| s.qualified_name.clone())
100 .collect();
101
102 let stored = self.vector_store.get_stored_hashes()?;
103 let orphans: Vec<&str> = stored
104 .iter()
105 .filter(|(qn, _)| !symbols.contains(qn.as_str()))
106 .map(|(qn, _)| qn.as_str())
107 .collect();
108
109 let count = orphans.len();
110 if !orphans.is_empty() {
111 self.vector_store.remove_embeddings(&orphans)?;
112 }
113 Ok(count)
114 }
115}
116
117fn build_edge_map(edges: &[Edge]) -> HashMap<String, Vec<Edge>> {
122 let mut map: HashMap<String, Vec<Edge>> = HashMap::new();
123 for edge in edges {
124 map.entry(edge.source.clone())
125 .or_default()
126 .push(edge.clone());
127 map.entry(edge.target.clone())
128 .or_default()
129 .push(edge.clone());
130 }
131 map
132}
133
134fn sha256_hex(text: &str) -> String {
135 let mut hasher = Sha256::new();
136 hasher.update(text.as_bytes());
137 format!("{:x}", hasher.finalize())
138}
139
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::model::*;
    use crate::test_support::*;

    /// In-memory collaborators shared by the tests below.
    type Fixture = (
        InMemoryGraphStore,
        InMemoryEmbeddingProvider,
        Arc<InMemoryVectorStore>,
    );

    /// Builds a fresh, empty set of in-memory collaborators.
    ///
    /// The vector store is wrapped in an `Arc` so a test can keep a handle to it
    /// after handing a clone to the use case.
    fn setup() -> Fixture {
        (
            InMemoryGraphStore::new(),
            InMemoryEmbeddingProvider::new(4),
            Arc::new(InMemoryVectorStore::new()),
        )
    }

    /// Builds a public, exported function symbol named `name` located in `test.rs`.
    fn make_symbol(name: &str) -> SymbolNode {
        let location = Location {
            file: "test.rs".into(),
            line_start: 1,
            line_end: 5,
            col_start: 0,
            col_end: 0,
        };
        SymbolNode {
            name: name.to_string(),
            qualified_name: format!("test.rs::{name}"),
            kind: SymbolKind::Function,
            location,
            visibility: Visibility::Public,
            is_exported: true,
            is_async: false,
            is_test: false,
            decorators: vec![],
            signature: None,
        }
    }

    /// Runs `embed_all` with the default config and a no-op progress callback,
    /// panicking if it fails.
    fn run_embed<S, E, V>(uc: &EmbedUseCase<S, E, V>) -> EmbedStats
    where
        S: GraphStore,
        E: EmbeddingProvider,
        V: VectorStore,
    {
        uc.embed_all(&EmbeddingConfig::default(), |_, _| {})
            .unwrap()
    }

    #[test]
    fn embed_all_embeds_symbols() {
        let (mut store, provider, vs) = setup();
        for name in ["foo", "bar"] {
            store.insert_symbol(make_symbol(name));
        }

        let stats = run_embed(&EmbedUseCase::new(store, provider, vs));

        assert_eq!(stats.total_symbols, 2);
        assert_eq!(stats.embedded, 2);
        assert_eq!(stats.skipped, 0);
    }

    #[test]
    fn embed_incremental_skips_unchanged() {
        let (mut store, provider, vs) = setup();
        store.insert_symbol(make_symbol("foo"));
        let uc = EmbedUseCase::new(store, provider, Arc::clone(&vs));

        // First pass embeds the symbol.
        let first = run_embed(&uc);
        assert_eq!(first.embedded, 1);

        // Second pass sees an unchanged hash and skips it.
        let second = run_embed(&uc);
        assert_eq!(second.skipped, 1);
        assert_eq!(second.embedded, 0);
    }

    #[test]
    fn edge_change_triggers_reembed() {
        let (mut store, provider, vs) = setup();
        for name in ["foo", "bar"] {
            store.insert_symbol(make_symbol(name));
        }

        // Embed once against a snapshot of the graph that has no edges.
        let before = EmbedUseCase::new(store.clone(), provider.clone(), Arc::clone(&vs));
        let _ = run_embed(&before);

        // Connect the two symbols, then embed again against the same vector store.
        store.insert_edge(Edge {
            kind: EdgeKind::Calls,
            source: "test.rs::foo".into(),
            target: "test.rs::bar".into(),
            metadata: None,
        });

        let after = EmbedUseCase::new(store, provider, Arc::clone(&vs));
        let stats = run_embed(&after);
        assert!(stats.embedded > 0);
    }

    #[test]
    fn cleanup_orphans_removes_stale() {
        let (mut store, provider, vs) = setup();
        for name in ["foo", "bar"] {
            store.insert_symbol(make_symbol(name));
        }

        let populated = EmbedUseCase::new(store, provider.clone(), Arc::clone(&vs));
        run_embed(&populated);
        assert_eq!(vs.count().unwrap(), 2);

        // A second graph that only knows `foo` makes `bar`'s embedding an orphan.
        let mut shrunk = InMemoryGraphStore::new();
        shrunk.insert_symbol(make_symbol("foo"));

        let cleaner = EmbedUseCase::new(shrunk, provider, Arc::clone(&vs));
        let removed = cleaner.cleanup_orphans().unwrap();
        assert_eq!(removed, 1);
        assert_eq!(vs.count().unwrap(), 1);
    }

    #[test]
    fn embed_empty_store_returns_zero() {
        let (store, provider, vs) = setup();

        let stats = run_embed(&EmbedUseCase::new(store, provider, vs));

        assert_eq!(stats.total_symbols, 0);
        assert_eq!(stats.embedded, 0);
    }
}