sqlite_knowledge_graph/rag/embedder.rs
1//! Embedding abstraction for the RAG pipeline.
2//!
3//! Provides a trait-based interface so the engine is not coupled to any
4//! specific embedding backend. In production you'll typically wrap a
5//! Python subprocess or an HTTP API; in tests you can use `FixedEmbedder`.
6
7use crate::error::{Error, Result};
8use std::io::{BufRead, Write};
9use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio};
10
11/// Trait for converting text into dense float vectors.
12pub trait Embedder: Send + Sync {
13 /// Embed a single text query, returning a normalised float vector.
14 fn embed(&self, text: &str) -> Result<Vec<f32>>;
15}
16
17// ─────────────────────────────────────────────────────────────────────────────
18// SubprocessEmbedder
19// ─────────────────────────────────────────────────────────────────────────────
20
21/// Embedder that talks to a long-lived Python subprocess over stdin/stdout.
22///
23/// The subprocess must implement a simple line protocol:
24/// - stdin: one text per line
25/// - stdout: space-separated floats per line (same order)
26///
27/// Example Python server (`embed_server.py`):
28/// ```python
29/// import sys
30/// from sentence_transformers import SentenceTransformer
31/// model = SentenceTransformer("all-MiniLM-L6-v2")
32/// for line in sys.stdin:
33/// vec = model.encode(line.strip()).tolist()
34/// print(" ".join(map(str, vec)), flush=True)
35/// ```
36pub struct SubprocessEmbedder {
37 child: std::sync::Mutex<SubprocessState>,
38}
39
/// The spawned subprocess together with both ends of its pipe pair.
struct SubprocessState {
    /// Handle to the spawned process; held so it can be terminated and
    /// reaped when this state is dropped.
    _child: Child,
    /// Write end of the child's stdin.
    stdin: ChildStdin,
    /// Buffered read end of the child's stdout.
    stdout: std::io::BufReader<ChildStdout>,
}

impl Drop for SubprocessState {
    /// Terminate the subprocess and collect its exit status.
    fn drop(&mut self) {
        // `std::process::Child` has no `Drop` impl of its own: without an
        // explicit wait the process is never reaped and may keep running.
        //
        // Kill before waiting: the `stdin` field is still open while this
        // runs (fields are dropped afterwards), so a child blocked reading
        // its stdin would otherwise never exit and `wait` would hang.
        //
        // Both results are ignored on purpose: a destructor cannot propagate
        // errors, and the child may already have exited.
        let _ = self._child.kill();
        let _ = self._child.wait();
    }
}
45
46impl SubprocessEmbedder {
47 /// Spawn the subprocess. `program` is e.g. `"python3"`,
48 /// `args` is e.g. `&["embed_server.py"]`.
49 pub fn new(program: &str, args: &[&str]) -> Result<Self> {
50 let mut child = Command::new(program)
51 .args(args)
52 .stdin(Stdio::piped())
53 .stdout(Stdio::piped())
54 .stderr(Stdio::inherit())
55 .spawn()
56 .map_err(|e| Error::InvalidInput(format!("failed to spawn embedder: {e}")))?;
57
58 let stdin = child
59 .stdin
60 .take()
61 .ok_or_else(|| Error::InvalidInput("no stdin handle".into()))?;
62 let stdout = child
63 .stdout
64 .take()
65 .ok_or_else(|| Error::InvalidInput("no stdout handle".into()))?;
66
67 Ok(Self {
68 child: std::sync::Mutex::new(SubprocessState {
69 _child: child,
70 stdin,
71 stdout: std::io::BufReader::new(stdout),
72 }),
73 })
74 }
75}
76
77impl Embedder for SubprocessEmbedder {
78 fn embed(&self, text: &str) -> Result<Vec<f32>> {
79 let mut state = self
80 .child
81 .lock()
82 .map_err(|_| Error::InvalidInput("embedder mutex poisoned".into()))?;
83
84 // Send the text (replace newlines so the protocol stays line-based)
85 let sanitised = text.replace('\n', " ");
86 writeln!(state.stdin, "{sanitised}")
87 .map_err(|e| Error::InvalidInput(format!("write to embedder: {e}")))?;
88
89 // Read one line of floats back
90 let mut line = String::new();
91 state
92 .stdout
93 .read_line(&mut line)
94 .map_err(|e| Error::InvalidInput(format!("read from embedder: {e}")))?;
95
96 line.split_whitespace()
97 .map(|s| {
98 s.parse::<f32>()
99 .map_err(|e| Error::InvalidInput(format!("bad float from embedder: {e}")))
100 })
101 .collect()
102 }
103}
104
105// ─────────────────────────────────────────────────────────────────────────────
106// FixedEmbedder (testing)
107// ─────────────────────────────────────────────────────────────────────────────
108
/// Deterministic embedder that always returns the same vector.
///
/// Useful in unit tests that need an `Embedder` but don't care about the
/// values. The wrapped vector is handed back as-is on every call.
///
/// `Debug`, `Clone` and `PartialEq` are derived so instances can be printed,
/// duplicated and compared in test assertions.
#[derive(Debug, Clone, PartialEq)]
pub struct FixedEmbedder(pub Vec<f32>);
112
113impl Embedder for FixedEmbedder {
114 fn embed(&self, _text: &str) -> Result<Vec<f32>> {
115 Ok(self.0.clone())
116 }
117}