Skip to main content

the_code_graph_domain/use_cases/
embed.rs

1use std::collections::{HashMap, HashSet};
2
3use sha2::{Digest, Sha256};
4
5use crate::analysis::search::symbol_to_text;
6use crate::error::Result;
7use crate::model::{Edge, EmbedStats, EmbeddingConfig, EmbeddingEntry};
8use crate::ports::{EmbeddingProvider, GraphStore, VectorStore};
9
10// ---------------------------------------------------------------------------
11// EmbedUseCase
12// ---------------------------------------------------------------------------
13
14pub struct EmbedUseCase<S: GraphStore, E: EmbeddingProvider, V: VectorStore> {
15    store: S,
16    provider: E,
17    vector_store: V,
18}
19
20impl<S: GraphStore, E: EmbeddingProvider, V: VectorStore> EmbedUseCase<S, E, V> {
21    pub fn new(store: S, provider: E, vector_store: V) -> Self {
22        Self {
23            store,
24            provider,
25            vector_store,
26        }
27    }
28
29    /// Embed all symbols, skipping those whose text representation is unchanged.
30    /// Calls `on_batch(embedded_so_far, total_to_embed)` after each batch.
31    pub fn embed_all(
32        &self,
33        config: &EmbeddingConfig,
34        on_batch: impl Fn(usize, usize),
35    ) -> Result<EmbedStats> {
36        let symbols = self.store.all_symbols()?;
37        let edges = self.store.all_edges()?;
38        let edge_map = build_edge_map(&edges);
39
40        // Build a map of already-stored text hashes for incremental skipping.
41        let stored: HashMap<String, String> =
42            self.vector_store.get_stored_hashes()?.into_iter().collect();
43
44        let mut to_embed: Vec<(String, String, String)> = Vec::new(); // (qn, text, hash)
45        let mut skipped = 0usize;
46
47        for sym in &symbols {
48            let sym_edges = edge_map
49                .get(&sym.qualified_name)
50                .cloned()
51                .unwrap_or_default();
52            let text = symbol_to_text(sym, &sym_edges);
53            let hash = sha256_hex(&text);
54
55            if stored
56                .get(&sym.qualified_name)
57                .map(|h| h == &hash)
58                .unwrap_or(false)
59            {
60                skipped += 1;
61                continue;
62            }
63            to_embed.push((sym.qualified_name.clone(), text, hash));
64        }
65
66        let total_to_embed = to_embed.len();
67        let mut embedded = 0usize;
68        for chunk in to_embed.chunks(config.batch_size) {
69            let texts: Vec<String> = chunk.iter().map(|(_, t, _)| t.clone()).collect();
70            let vectors = self.provider.embed_batch(&texts)?;
71            let entries: Vec<EmbeddingEntry> = chunk
72                .iter()
73                .zip(vectors)
74                .map(|((qn, _, hash), vec)| EmbeddingEntry {
75                    qualified_name: qn.clone(),
76                    vector: vec,
77                    text_hash: hash.clone(),
78                })
79                .collect();
80            self.vector_store.store_embeddings(&entries)?;
81            embedded += entries.len();
82            on_batch(embedded, total_to_embed);
83        }
84
85        Ok(EmbedStats {
86            total_symbols: symbols.len(),
87            embedded,
88            skipped,
89            removed: 0,
90        })
91    }
92
93    /// Remove embeddings whose qualified name no longer exists in the graph.
94    pub fn cleanup_orphans(&self) -> Result<usize> {
95        let symbols: HashSet<String> = self
96            .store
97            .all_symbols()?
98            .iter()
99            .map(|s| s.qualified_name.clone())
100            .collect();
101
102        let stored = self.vector_store.get_stored_hashes()?;
103        let orphans: Vec<&str> = stored
104            .iter()
105            .filter(|(qn, _)| !symbols.contains(qn.as_str()))
106            .map(|(qn, _)| qn.as_str())
107            .collect();
108
109        let count = orphans.len();
110        if !orphans.is_empty() {
111            self.vector_store.remove_embeddings(&orphans)?;
112        }
113        Ok(count)
114    }
115}
116
117// ---------------------------------------------------------------------------
118// Helpers
119// ---------------------------------------------------------------------------
120
121fn build_edge_map(edges: &[Edge]) -> HashMap<String, Vec<Edge>> {
122    let mut map: HashMap<String, Vec<Edge>> = HashMap::new();
123    for edge in edges {
124        map.entry(edge.source.clone())
125            .or_default()
126            .push(edge.clone());
127        map.entry(edge.target.clone())
128            .or_default()
129            .push(edge.clone());
130    }
131    map
132}
133
134fn sha256_hex(text: &str) -> String {
135    let mut hasher = Sha256::new();
136    hasher.update(text.as_bytes());
137    format!("{:x}", hasher.finalize())
138}
139
140// ---------------------------------------------------------------------------
141// Tests
142// ---------------------------------------------------------------------------
143
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::model::*;
    use crate::test_support::*;

    /// Builds fresh in-memory collaborators: graph store, embedding provider and
    /// a shareable vector store.
    fn setup() -> (
        InMemoryGraphStore,
        InMemoryEmbeddingProvider,
        Arc<InMemoryVectorStore>,
    ) {
        (
            InMemoryGraphStore::new(),
            // NOTE(review): 4 is presumably the embedding dimension — confirm.
            InMemoryEmbeddingProvider::new(4),
            Arc::new(InMemoryVectorStore::new()),
        )
    }

    /// Builds a public, exported function symbol named `name` located in `test.rs`.
    fn make_symbol(name: &str) -> SymbolNode {
        let location = Location {
            file: "test.rs".into(),
            line_start: 1,
            line_end: 5,
            col_start: 0,
            col_end: 0,
        };
        SymbolNode {
            name: name.to_string(),
            qualified_name: format!("test.rs::{name}"),
            kind: SymbolKind::Function,
            location,
            visibility: Visibility::Public,
            is_exported: true,
            is_async: false,
            is_test: false,
            decorators: vec![],
            signature: None,
        }
    }

    /// Runs `embed_all` with the default config and a no-op progress callback,
    /// panicking on error.
    fn run_embed<S, E, V>(uc: &EmbedUseCase<S, E, V>) -> EmbedStats
    where
        S: GraphStore,
        E: EmbeddingProvider,
        V: VectorStore,
    {
        uc.embed_all(&EmbeddingConfig::default(), |_, _| {})
            .unwrap()
    }

    #[test]
    fn embed_all_embeds_symbols() {
        let (mut graph, provider, vectors) = setup();
        graph.insert_symbol(make_symbol("foo"));
        graph.insert_symbol(make_symbol("bar"));

        let use_case = EmbedUseCase::new(graph, provider, vectors);
        let stats = run_embed(&use_case);

        assert_eq!(stats.total_symbols, 2);
        assert_eq!(stats.embedded, 2);
        assert_eq!(stats.skipped, 0);
    }

    #[test]
    fn embed_incremental_skips_unchanged() {
        let (mut graph, provider, vectors) = setup();
        graph.insert_symbol(make_symbol("foo"));
        let use_case = EmbedUseCase::new(graph, provider, Arc::clone(&vectors));

        // First pass embeds the only symbol.
        let first = run_embed(&use_case);
        assert_eq!(first.embedded, 1);

        // Second pass sees the same text hash and must skip it.
        let second = run_embed(&use_case);
        assert_eq!(second.skipped, 1);
        assert_eq!(second.embedded, 0);
    }

    #[test]
    fn edge_change_triggers_reembed() {
        let (mut graph, provider, vectors) = setup();
        graph.insert_symbol(make_symbol("foo"));
        graph.insert_symbol(make_symbol("bar"));

        // Initial embedding against a snapshot of the graph without edges.
        let before = EmbedUseCase::new(graph.clone(), provider.clone(), Arc::clone(&vectors));
        let _ = run_embed(&before);

        // Add an edge: foo calls bar — changes foo's text representation.
        graph.insert_edge(Edge {
            kind: EdgeKind::Calls,
            source: "test.rs::foo".into(),
            target: "test.rs::bar".into(),
            metadata: None,
        });

        // Re-embed with the updated graph but the same vector store.
        let after = EmbedUseCase::new(graph, provider, Arc::clone(&vectors));
        let stats = run_embed(&after);

        // At least one symbol's text changed, so something must be re-embedded.
        assert!(stats.embedded > 0);
    }

    #[test]
    fn cleanup_orphans_removes_stale() {
        let (mut graph, provider, vectors) = setup();
        graph.insert_symbol(make_symbol("foo"));
        graph.insert_symbol(make_symbol("bar"));

        let seeded = EmbedUseCase::new(graph, provider.clone(), Arc::clone(&vectors));
        run_embed(&seeded);
        assert_eq!(vectors.count().unwrap(), 2);

        // A new graph containing only "foo" — "bar" becomes an orphan.
        let mut shrunk = InMemoryGraphStore::new();
        shrunk.insert_symbol(make_symbol("foo"));

        let cleaner = EmbedUseCase::new(shrunk, provider, Arc::clone(&vectors));
        let removed = cleaner.cleanup_orphans().unwrap();
        assert_eq!(removed, 1);
        assert_eq!(vectors.count().unwrap(), 1);
    }

    #[test]
    fn embed_empty_store_returns_zero() {
        let (graph, provider, vectors) = setup();
        let use_case = EmbedUseCase::new(graph, provider, vectors);

        let stats = run_embed(&use_case);

        assert_eq!(stats.total_symbols, 0);
        assert_eq!(stats.embedded, 0);
    }
}
275}