Skip to main content

sqlite_knowledge_graph/rag/
embedder.rs

//! Embedding abstraction for the RAG pipeline.
//!
//! Provides a trait-based interface so the engine is not coupled to any
//! specific embedding backend. In production you'll typically wrap a
//! Python subprocess (see [`SubprocessEmbedder`]) or an HTTP API; in tests
//! you can use [`FixedEmbedder`].
6
7use crate::error::{Error, Result};
8use std::io::{BufRead, Write};
9use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio};
10
11/// Trait for converting text into dense float vectors.
12pub trait Embedder: Send + Sync {
13    /// Embed a single text query, returning a normalised float vector.
14    fn embed(&self, text: &str) -> Result<Vec<f32>>;
15}
16
17// ─────────────────────────────────────────────────────────────────────────────
18// SubprocessEmbedder
19// ─────────────────────────────────────────────────────────────────────────────
20
21/// Embedder that talks to a long-lived Python subprocess over stdin/stdout.
22///
23/// The subprocess must implement a simple line protocol:
24/// - stdin:  one text per line
25/// - stdout: space-separated floats per line (same order)
26///
27/// Example Python server (`embed_server.py`):
28/// ```python
29/// import sys
30/// from sentence_transformers import SentenceTransformer
31/// model = SentenceTransformer("all-MiniLM-L6-v2")
32/// for line in sys.stdin:
33///     vec = model.encode(line.strip()).tolist()
34///     print(" ".join(map(str, vec)), flush=True)
35/// ```
36pub struct SubprocessEmbedder {
37    child: std::sync::Mutex<SubprocessState>,
38}
39
40struct SubprocessState {
41    _child: Child,
42    stdin: ChildStdin,
43    stdout: std::io::BufReader<ChildStdout>,
44}
45
46impl SubprocessEmbedder {
47    /// Spawn the subprocess.  `program` is e.g. `"python3"`,
48    /// `args` is e.g. `&["embed_server.py"]`.
49    pub fn new(program: &str, args: &[&str]) -> Result<Self> {
50        let mut child = Command::new(program)
51            .args(args)
52            .stdin(Stdio::piped())
53            .stdout(Stdio::piped())
54            .stderr(Stdio::inherit())
55            .spawn()
56            .map_err(|e| Error::InvalidInput(format!("failed to spawn embedder: {e}")))?;
57
58        let stdin = child
59            .stdin
60            .take()
61            .ok_or_else(|| Error::InvalidInput("no stdin handle".into()))?;
62        let stdout = child
63            .stdout
64            .take()
65            .ok_or_else(|| Error::InvalidInput("no stdout handle".into()))?;
66
67        Ok(Self {
68            child: std::sync::Mutex::new(SubprocessState {
69                _child: child,
70                stdin,
71                stdout: std::io::BufReader::new(stdout),
72            }),
73        })
74    }
75}
76
77impl Embedder for SubprocessEmbedder {
78    fn embed(&self, text: &str) -> Result<Vec<f32>> {
79        let mut state = self
80            .child
81            .lock()
82            .map_err(|_| Error::InvalidInput("embedder mutex poisoned".into()))?;
83
84        // Send the text (replace newlines so the protocol stays line-based)
85        let sanitised = text.replace('\n', " ");
86        writeln!(state.stdin, "{sanitised}")
87            .map_err(|e| Error::InvalidInput(format!("write to embedder: {e}")))?;
88
89        // Read one line of floats back
90        let mut line = String::new();
91        state
92            .stdout
93            .read_line(&mut line)
94            .map_err(|e| Error::InvalidInput(format!("read from embedder: {e}")))?;
95
96        line.split_whitespace()
97            .map(|s| {
98                s.parse::<f32>()
99                    .map_err(|e| Error::InvalidInput(format!("bad float from embedder: {e}")))
100            })
101            .collect()
102    }
103}
104
105// ─────────────────────────────────────────────────────────────────────────────
106// FixedEmbedder (testing)
107// ─────────────────────────────────────────────────────────────────────────────
108
109/// Deterministic embedder that always returns the same vector.
110/// Useful in unit tests that need an `Embedder` but don't care about the values.
111pub struct FixedEmbedder(pub Vec<f32>);
112
113impl Embedder for FixedEmbedder {
114    fn embed(&self, _text: &str) -> Result<Vec<f32>> {
115        Ok(self.0.clone())
116    }
117}