Skip to main content

sqlite_knowledge_graph/
embed.rs

1//! Vector embedding generation module for semantic search.
2
3use crate::error::{Error, Result};
4use crate::vector::VectorStore;
5use rusqlite::Connection;
6use serde::{Deserialize, Serialize};
7use std::io::Write;
8use std::process::{Command, Stdio};
9
/// Embedding model configuration.
///
/// Selects which sentence-transformers model the generator runs and how long every
/// vector it produces must be.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Name of the sentence-transformers model handed to `SentenceTransformer(...)`
    /// in the Python helper script.
    pub model_name: String,
    /// Required length of every generated vector; model output of any other length
    /// is rejected with `Error::InvalidVectorDimension`.
    pub dimension: usize,
}
16
17impl Default for EmbeddingConfig {
18    fn default() -> Self {
19        Self {
20            model_name: "all-MiniLM-L6-v2".to_string(),
21            dimension: 384,
22        }
23    }
24}
25
/// Embedding generator using sentence-transformers.
///
/// Embeddings are produced by running a `python3` subprocess and are persisted
/// through `VectorStore`.
pub struct EmbeddingGenerator {
    /// Model name and the vector dimension every embedding must have.
    config: EmbeddingConfig,
    /// If true, skip entities that already have real (non-zero) embeddings.
    /// Set to `true` by both `new` and `with_config`; cleared by `with_force(true)`.
    pub skip_existing: bool,
}
32
33impl EmbeddingGenerator {
34    /// Create a new embedding generator with default configuration.
35    pub fn new() -> Self {
36        Self {
37            config: EmbeddingConfig::default(),
38            skip_existing: true,
39        }
40    }
41
42    /// Create a new embedding generator with custom configuration.
43    pub fn with_config(config: EmbeddingConfig) -> Self {
44        Self {
45            config,
46            skip_existing: true,
47        }
48    }
49
50    /// Set force mode: if true, regenerate embeddings even for entities that
51    /// already have real (non-zero) vectors.
52    pub fn with_force(mut self, force: bool) -> Self {
53        self.skip_existing = !force;
54        self
55    }
56
57    /// Generate embeddings for a list of texts.
58    pub fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
59        if texts.is_empty() {
60            return Ok(Vec::new());
61        }
62
63        let python_script = self.generate_python_script()?;
64
65        // Serialize texts to JSON
66        let texts_json = serde_json::to_string(&texts)
67            .map_err(|e| Error::Other(format!("Failed to serialize texts: {}", e)))?;
68
69        // Run Python script with stdin
70        let mut child = Command::new("python3")
71            .arg("-c")
72            .arg(&python_script)
73            .stdin(Stdio::piped())
74            .stdout(Stdio::piped())
75            .stderr(Stdio::piped())
76            .spawn()
77            .map_err(|e| Error::Other(format!("Failed to spawn Python: {}", e)))?;
78
79        // Write to stdin
80        if let Some(mut stdin) = child.stdin.take() {
81            stdin
82                .write_all(texts_json.as_bytes())
83                .map_err(|e| Error::Other(format!("Failed to write to stdin: {}", e)))?;
84        }
85
86        // Get the output
87        let output = child
88            .wait_with_output()
89            .map_err(|e| Error::Other(format!("Failed to read Python output: {}", e)))?;
90
91        if !output.status.success() {
92            let stderr = String::from_utf8_lossy(&output.stderr);
93            return Err(Error::Other(format!("Python script failed: {}", stderr)));
94        }
95
96        // Parse output
97        let stdout = String::from_utf8_lossy(&output.stdout);
98        self.parse_embeddings(&stdout)
99    }
100
101    /// Generate Python script for embedding generation.
102    fn generate_python_script(&self) -> Result<String> {
103        let script = format!(
104            r#"
105import sys
106import json
107import numpy as np
108
109try:
110    from sentence_transformers import SentenceTransformer
111
112    # Load model
113    model = SentenceTransformer('{}')
114
115    # Read texts from stdin
116    texts_json = sys.stdin.read()
117    texts = json.loads(texts_json)
118
119    # Generate embeddings
120    embeddings = model.encode(texts, convert_to_numpy=True)
121
122    # Convert to list and print as JSON
123    embeddings_list = embeddings.tolist()
124    print(json.dumps(embeddings_list))
125
126except ImportError:
127    print("{{\"error\": \"sentence-transformers not installed. Run: pip install sentence-transformers\"}}", file=sys.stderr)
128    sys.exit(1)
129except Exception as e:
130    print("{{\"error\": \"{{}}\"}}".format(str(e)), file=sys.stderr)
131    sys.exit(1)
132"#,
133            self.config.model_name
134        );
135
136        Ok(script)
137    }
138
139    /// Parse embeddings from Python output.
140    fn parse_embeddings(&self, output: &str) -> Result<Vec<Vec<f32>>> {
141        let embeddings: Vec<Vec<f32>> = serde_json::from_str(output)
142            .map_err(|e| Error::Other(format!("Failed to parse embeddings: {}", e)))?;
143
144        // Validate dimensions
145        for embedding in &embeddings {
146            if embedding.len() != self.config.dimension {
147                return Err(Error::InvalidVectorDimension {
148                    expected: self.config.dimension,
149                    actual: embedding.len(),
150                });
151            }
152        }
153
154        Ok(embeddings)
155    }
156
157    /// Generate embeddings for all paper entities in the database.
158    pub fn generate_for_papers(&self, conn: &Connection) -> Result<EmbeddingStats> {
159        let entities = get_entities_needing_embedding(conn, "paper", !self.skip_existing)?;
160        let total_count = count_entities(conn, "paper")?;
161        let skipped_count = total_count - entities.len() as i64;
162
163        self.generate_and_store(conn, entities, total_count, skipped_count, "paper")
164    }
165
166    /// Generate embeddings for all skill entities in the database.
167    pub fn generate_for_skills(&self, conn: &Connection) -> Result<EmbeddingStats> {
168        let entities = get_entities_needing_embedding(conn, "skill", !self.skip_existing)?;
169        let total_count = count_entities(conn, "skill")?;
170        let skipped_count = total_count - entities.len() as i64;
171
172        self.generate_and_store(conn, entities, total_count, skipped_count, "skill")
173    }
174
175    /// Generate embeddings for all entities in the database.
176    pub fn generate_for_all(&self, conn: &Connection) -> Result<EmbeddingStats> {
177        let papers_stats = self.generate_for_papers(conn)?;
178        let skills_stats = self.generate_for_skills(conn)?;
179
180        Ok(EmbeddingStats {
181            total_count: papers_stats.total_count + skills_stats.total_count,
182            processed_count: papers_stats.processed_count + skills_stats.processed_count,
183            skipped_count: papers_stats.skipped_count + skills_stats.skipped_count,
184            dimension: self.config.dimension,
185        })
186    }
187
188    /// Internal: batch-generate embeddings for a list of (id, text) pairs and store them.
189    fn generate_and_store(
190        &self,
191        conn: &Connection,
192        entities: Vec<(i64, String)>,
193        total_count: i64,
194        skipped_count: i64,
195        label: &str,
196    ) -> Result<EmbeddingStats> {
197        if entities.is_empty() {
198            println!(
199                "All {} entities already have real embeddings, skipping.",
200                label
201            );
202            return Ok(EmbeddingStats {
203                total_count,
204                processed_count: 0,
205                skipped_count,
206                dimension: self.config.dimension,
207            });
208        }
209
210        let (entity_ids, texts): (Vec<i64>, Vec<String>) = entities.into_iter().unzip();
211
212        println!(
213            "Generating embeddings for {} {} titles ({} already have real embeddings, skipping)...",
214            texts.len(),
215            label,
216            skipped_count
217        );
218
219        let batch_size = 100;
220        let mut processed_count = 0;
221
222        let store = VectorStore::new();
223        let tx = conn.unchecked_transaction()?;
224
225        for batch_start in (0..texts.len()).step_by(batch_size) {
226            let batch_end = (batch_start + batch_size).min(texts.len());
227            let batch_texts = texts[batch_start..batch_end].to_vec();
228            let batch_ids = entity_ids[batch_start..batch_end].to_vec();
229
230            println!(
231                "Processing batch: {}s {}-{}",
232                label,
233                batch_start + 1,
234                batch_end
235            );
236
237            let embeddings = self.generate_embeddings(batch_texts)?;
238
239            for (entity_id, embedding) in batch_ids.iter().zip(embeddings.iter()) {
240                store.insert_vector(&tx, *entity_id, embedding.clone())?;
241            }
242
243            processed_count += embeddings.len();
244            println!("  Generated {} embeddings", embeddings.len());
245        }
246
247        tx.commit()?;
248
249        println!("✓ Generated {} embeddings for {}s", processed_count, label);
250
251        Ok(EmbeddingStats {
252            total_count,
253            processed_count: processed_count as i64,
254            skipped_count,
255            dimension: self.config.dimension,
256        })
257    }
258}
259
260impl Default for EmbeddingGenerator {
261    fn default() -> Self {
262        Self::new()
263    }
264}
265
266/// Get entity (id, name) pairs that need embeddings generated.
267///
268/// - If `force` is true, returns all entities of the given type.
269/// - Otherwise, returns only entities with missing or placeholder (all-zero) vectors.
270pub fn get_entities_needing_embedding(
271    conn: &Connection,
272    entity_type: &str,
273    force: bool,
274) -> Result<Vec<(i64, String)>> {
275    let mut stmt = conn.prepare(
276        r#"
277        SELECT e.id, e.name, v.vector
278        FROM kg_entities e
279        LEFT JOIN kg_vectors v ON e.id = v.entity_id
280        WHERE e.entity_type = ?1
281        ORDER BY e.id
282        "#,
283    )?;
284
285    let rows = stmt.query_map([entity_type], |row| {
286        Ok((
287            row.get::<_, i64>(0)?,
288            row.get::<_, String>(1)?,
289            row.get::<_, Option<Vec<u8>>>(2)?,
290        ))
291    })?;
292
293    let mut result = Vec::new();
294    for row in rows {
295        let (id, name, blob) = row?;
296        let needs_embedding = force || is_placeholder_or_missing(blob.as_deref());
297        if needs_embedding {
298            result.push((id, name));
299        }
300    }
301
302    Ok(result)
303}
304
/// Reports whether a stored vector blob still needs a real embedding.
///
/// That is the case when there is no blob at all (`None`) or when every byte of it
/// is zero (the placeholder). An empty blob also counts as a placeholder.
fn is_placeholder_or_missing(blob: Option<&[u8]>) -> bool {
    // Only a present blob with at least one non-zero byte is a real embedding.
    !matches!(blob, Some(bytes) if bytes.iter().any(|&byte| byte != 0))
}
312
313/// Count total entities of a given type.
314fn count_entities(conn: &Connection, entity_type: &str) -> Result<i64> {
315    let count: i64 = conn.query_row(
316        "SELECT COUNT(*) FROM kg_entities WHERE entity_type = ?1",
317        [entity_type],
318        |row| row.get(0),
319    )?;
320    Ok(count)
321}
322
/// Statistics from embedding generation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingStats {
    /// Number of entities of the processed type(s) found in `kg_entities`.
    pub total_count: i64,
    /// Number of embeddings generated and stored during this run.
    pub processed_count: i64,
    /// Entities not selected for embedding, computed as `total_count` minus the
    /// number of entities that were selected (always 0 in force mode).
    pub skipped_count: i64,
    /// Vector dimension, taken from the generator's configuration.
    pub dimension: usize,
}
331
332/// Check if sentence-transformers is available.
333pub fn check_dependencies() -> Result<bool> {
334    let output = Command::new("python3")
335        .arg("-c")
336        .arg("import sentence_transformers")
337        .stdout(Stdio::piped())
338        .stderr(Stdio::piped())
339        .output()
340        .map_err(|e| Error::Other(format!("Failed to check Python dependencies: {}", e)))?;
341
342    Ok(output.status.success())
343}
344
// Unit tests for configuration, JSON parsing, placeholder detection and entity
// selection. None of these tests spawn Python: the `generate_for_*` tests below only
// cover cases where no entity needs an embedding, so generation returns early.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::insert_entity;
    use crate::graph::Entity;
    use crate::schema::create_schema;

    /// Open a fresh in-memory SQLite database with the knowledge-graph schema applied.
    fn make_in_memory_conn() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        create_schema(&conn).unwrap();
        conn
    }

    // ── Config & constructor ──────────────────────────────────────────────

    #[test]
    fn test_embedding_config_default() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.model_name, "all-MiniLM-L6-v2");
        assert_eq!(config.dimension, 384);
    }

    #[test]
    fn test_embedding_generator_new() {
        let generator = EmbeddingGenerator::new();
        assert_eq!(generator.config.model_name, "all-MiniLM-L6-v2");
        assert_eq!(generator.config.dimension, 384);
        assert!(generator.skip_existing);
    }

    #[test]
    fn test_with_force_sets_skip_existing() {
        let gen = EmbeddingGenerator::new().with_force(true);
        assert!(!gen.skip_existing);

        let gen2 = EmbeddingGenerator::new().with_force(false);
        assert!(gen2.skip_existing);
    }

    // ── parse_embeddings ─────────────────────────────────────────────────

    #[test]
    fn test_parse_embeddings_dimension_mismatch() {
        let generator = EmbeddingGenerator::new();
        // 3-element vectors don't match expected 384
        let result = generator.parse_embeddings("[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_embeddings_valid_384() {
        let generator = EmbeddingGenerator::new();
        let vec384: Vec<f32> = (0..384).map(|i| i as f32 / 1000.0).collect();
        let json = serde_json::to_string(&[&vec384]).unwrap();
        let result = generator.parse_embeddings(&json).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 384);
        // Values survive the JSON round-trip (compared with a float tolerance).
        assert!((result[0][0] - 0.0).abs() < 1e-6);
        assert!((result[0][1] - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_parse_embeddings_batch_of_three() {
        let generator = EmbeddingGenerator::new();
        let vec384: Vec<f32> = vec![0.5f32; 384];
        let batch = vec![vec384.clone(), vec384.clone(), vec384.clone()];
        let json = serde_json::to_string(&batch).unwrap();
        let result = generator.parse_embeddings(&json).unwrap();
        assert_eq!(result.len(), 3);
        for emb in &result {
            assert_eq!(emb.len(), 384);
        }
    }

    #[test]
    fn test_parse_embeddings_invalid_json() {
        let generator = EmbeddingGenerator::new();
        let result = generator.parse_embeddings("not valid json");
        assert!(result.is_err());
    }

    // ── is_placeholder_or_missing ─────────────────────────────────────────

    #[test]
    fn test_is_placeholder_missing() {
        assert!(is_placeholder_or_missing(None));
    }

    #[test]
    fn test_is_placeholder_zero_bytes() {
        // 384 f32 values of 0.0, i.e. 384 * 4 zero bytes.
        let blob = vec![0u8; 384 * 4];
        assert!(is_placeholder_or_missing(Some(&blob)));
    }

    #[test]
    fn test_is_placeholder_real_vector() {
        // Non-zero blob (real embedding)
        let v: Vec<f32> = vec![0.1f32; 384];
        let mut blob = Vec::with_capacity(384 * 4);
        for &val in &v {
            blob.extend_from_slice(&val.to_le_bytes());
        }
        assert!(!is_placeholder_or_missing(Some(&blob)));
    }

    // ── get_entities_needing_embedding ────────────────────────────────────

    #[test]
    fn test_get_entities_needing_embedding_no_vector() {
        let conn = make_in_memory_conn();

        let e1 = Entity::new("paper", "Paper Without Vector");
        let id1 = insert_entity(&conn, &e1).unwrap();
        // The id is not needed here; this binding only marks it as deliberately unused.
        let _ = id1;

        let result = get_entities_needing_embedding(&conn, "paper", false).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].1, "Paper Without Vector");
    }

    #[test]
    fn test_get_entities_needing_embedding_placeholder_vector() {
        let conn = make_in_memory_conn();

        let e1 = Entity::new("paper", "Paper With Placeholder");
        let id1 = insert_entity(&conn, &e1).unwrap();

        // Insert placeholder zero vector
        let placeholder = vec![0.0f32; 384];
        VectorStore::new()
            .insert_vector(&conn, id1, placeholder)
            .unwrap();

        let result = get_entities_needing_embedding(&conn, "paper", false).unwrap();
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn test_get_entities_needing_embedding_skip_real_vector() {
        let conn = make_in_memory_conn();

        let e1 = Entity::new("paper", "Paper With Real Embedding");
        let id1 = insert_entity(&conn, &e1).unwrap();

        // Insert real non-zero embedding
        let real_embedding = vec![0.1f32; 384];
        VectorStore::new()
            .insert_vector(&conn, id1, real_embedding)
            .unwrap();

        let result = get_entities_needing_embedding(&conn, "paper", false).unwrap();
        // Should be empty: the paper already has a real embedding
        assert!(result.is_empty());
    }

    #[test]
    fn test_get_entities_needing_embedding_force_returns_all() {
        let conn = make_in_memory_conn();

        let e1 = Entity::new("paper", "Paper With Real Embedding");
        let id1 = insert_entity(&conn, &e1).unwrap();

        let real_embedding = vec![0.1f32; 384];
        VectorStore::new()
            .insert_vector(&conn, id1, real_embedding)
            .unwrap();

        // force=true should return all entities regardless of existing vectors
        let result = get_entities_needing_embedding(&conn, "paper", true).unwrap();
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn test_get_entities_needing_embedding_mixed() {
        let conn = make_in_memory_conn();

        let e1 = Entity::new("paper", "Has Real Embedding");
        let id1 = insert_entity(&conn, &e1).unwrap();
        VectorStore::new()
            .insert_vector(&conn, id1, vec![0.1f32; 384])
            .unwrap();

        let e2 = Entity::new("paper", "Has Placeholder");
        let id2 = insert_entity(&conn, &e2).unwrap();
        VectorStore::new()
            .insert_vector(&conn, id2, vec![0.0f32; 384])
            .unwrap();

        let e3 = Entity::new("paper", "No Vector");
        insert_entity(&conn, &e3).unwrap();

        // Without force: only placeholder and missing should be returned
        let result = get_entities_needing_embedding(&conn, "paper", false).unwrap();
        assert_eq!(result.len(), 2);
        let names: Vec<&str> = result.iter().map(|(_, n)| n.as_str()).collect();
        assert!(names.contains(&"Has Placeholder"));
        assert!(names.contains(&"No Vector"));
        assert!(!names.contains(&"Has Real Embedding"));
    }

    // ── generate_for_papers / generate_for_skills empty DB ───────────────
    // These cases select no entities, so generation returns before invoking Python.

    #[test]
    fn test_generate_for_papers_empty() {
        let conn = make_in_memory_conn();
        let generator = EmbeddingGenerator::new();
        let stats = generator.generate_for_papers(&conn).unwrap();
        assert_eq!(stats.total_count, 0);
        assert_eq!(stats.processed_count, 0);
        assert_eq!(stats.skipped_count, 0);
    }

    #[test]
    fn test_generate_for_skills_empty() {
        let conn = make_in_memory_conn();
        let generator = EmbeddingGenerator::new();
        let stats = generator.generate_for_skills(&conn).unwrap();
        assert_eq!(stats.total_count, 0);
        assert_eq!(stats.processed_count, 0);
        assert_eq!(stats.skipped_count, 0);
    }

    #[test]
    fn test_generate_for_papers_all_real_embeddings_are_skipped() {
        let conn = make_in_memory_conn();

        // Insert papers with real embeddings
        for i in 0..3 {
            let e = Entity::new("paper", format!("Paper {}", i));
            let id = insert_entity(&conn, &e).unwrap();
            VectorStore::new()
                .insert_vector(&conn, id, vec![0.1f32; 384])
                .unwrap();
        }

        let generator = EmbeddingGenerator::new(); // skip_existing = true
        let stats = generator.generate_for_papers(&conn).unwrap();

        assert_eq!(stats.total_count, 3);
        assert_eq!(stats.processed_count, 0);
        assert_eq!(stats.skipped_count, 3);
    }

    // ── batch boundary test ───────────────────────────────────────────────
    // Note: this only checks entity selection; the batching loop itself needs Python
    // and is not exercised here.

    #[test]
    fn test_get_entities_batch_boundary() {
        let conn = make_in_memory_conn();

        // Insert 105 papers (crosses the 100-item batch boundary)
        for i in 0..105 {
            let e = Entity::new("paper", format!("Paper {}", i));
            insert_entity(&conn, &e).unwrap();
        }

        let result = get_entities_needing_embedding(&conn, "paper", false).unwrap();
        assert_eq!(result.len(), 105);
    }

    // ── EmbeddingStats ────────────────────────────────────────────────────

    #[test]
    fn test_embedding_stats_fields() {
        let stats = EmbeddingStats {
            total_count: 100,
            processed_count: 80,
            skipped_count: 20,
            dimension: 384,
        };
        assert_eq!(stats.total_count, 100);
        assert_eq!(stats.processed_count, 80);
        assert_eq!(stats.skipped_count, 20);
        assert_eq!(stats.dimension, 384);
    }
}