sqlite_graphrag/
embedder.rs1use crate::constants::EMBEDDING_DIM;
15use crate::errors::AppError;
16use crate::extract::llm_embedding::LlmEmbedding;
17use parking_lot::Mutex;
18use std::path::Path;
19use std::sync::OnceLock;
20
21static EMBEDDER: OnceLock<Mutex<LlmEmbedding>> = OnceLock::new();
27
28pub fn get_embedder(_models_dir: &Path) -> Result<&'static Mutex<LlmEmbedding>, AppError> {
30 if let Some(e) = EMBEDDER.get() {
31 return Ok(e);
32 }
33 let backend = LlmEmbedding::detect_available()?;
34 let _ = EMBEDDER.set(Mutex::new(backend));
35 Ok(EMBEDDER.get().expect("EMBEDDER initialised above"))
36}
37
38pub fn embed_passage(embedder: &Mutex<LlmEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
41 let mut guard = embedder.lock();
42 let result = guard.embed_passage(text)?;
43 Ok(normalise_dim(result))
44}
45
46pub fn embed_query(embedder: &Mutex<LlmEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
50 let mut guard = embedder.lock();
51 let result = guard.embed_query(text)?;
52 Ok(normalise_dim(result))
53}
54
55pub fn embed_passages_controlled(
60 embedder: &Mutex<LlmEmbedding>,
61 texts: &[&str],
62 token_counts: &[usize],
63) -> Result<Vec<Vec<f32>>, AppError> {
64 if texts.is_empty() {
65 return Ok(Vec::new());
66 }
67 let mut output: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
68 let mut group: Vec<&str> = Vec::new();
69 let mut current_padded = 0usize;
70 for (text, &tokens) in texts.iter().zip(token_counts.iter()) {
71 let padded = tokens.saturating_add(8);
72 if (current_padded + padded > crate::constants::REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS
73 || group.len() >= crate::constants::REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS)
74 && !group.is_empty()
75 {
76 flush_group(&mut output, &mut group, embedder)?;
77 current_padded = 0;
78 }
79 group.push(text);
80 current_padded += padded;
81 }
82 if !group.is_empty() {
83 flush_group(&mut output, &mut group, embedder)?;
84 }
85 Ok(output)
86}
87
88fn flush_group(
89 output: &mut Vec<Vec<f32>>,
90 group: &mut Vec<&str>,
91 embedder: &Mutex<LlmEmbedding>,
92) -> Result<(), AppError> {
93 let mut guard = embedder.lock();
94 for text in group.iter() {
95 let v = guard.embed_passage(text)?;
96 output.push(normalise_dim(v));
97 }
98 group.clear();
99 Ok(())
100}
101
102pub fn embed_passage_local(models_dir: &Path, text: &str) -> Result<Vec<f32>, AppError> {
103 let embedder = get_embedder(models_dir)?;
104 embed_passage(embedder, text)
105}
106
107pub fn embed_query_local(models_dir: &Path, text: &str) -> Result<Vec<f32>, AppError> {
108 let embedder = get_embedder(models_dir)?;
109 embed_query(embedder, text)
110}
111
112pub fn embed_passages_controlled_local(
113 models_dir: &Path,
114 texts: &[&str],
115 token_counts: &[usize],
116) -> Result<Vec<Vec<f32>>, AppError> {
117 let embedder = get_embedder(models_dir)?;
118 embed_passages_controlled(embedder, texts, token_counts)
119}
120
/// Encodes `v` as one contiguous little-endian byte buffer, 4 bytes per element.
///
/// The output is decoded back by [`bytes_to_f32`].
pub fn f32_to_bytes(v: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(v.len() * 4);
    bytes.extend(v.iter().flat_map(|value| value.to_le_bytes()));
    bytes
}
128
/// Decodes a little-endian byte buffer into `f32` values, 4 bytes per element.
///
/// Any trailing bytes that do not form a complete 4-byte group are ignored.
pub fn bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|quad| f32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]))
        .collect()
}
136
137pub fn embedding_dim() -> usize {
140 EMBEDDING_DIM
141}
142
143fn normalise_dim(mut v: Vec<f32>) -> Vec<f32> {
144 if v.len() == EMBEDDING_DIM {
145 return v;
146 }
147 if v.len() > EMBEDDING_DIM {
148 v.truncate(EMBEDDING_DIM);
149 } else {
150 v.resize(EMBEDDING_DIM, 0.0);
151 }
152 v
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn f32_to_bytes_roundtrip() {
        let input = vec![0.0_f32, 1.5, -2.25, f32::MIN, f32::MAX];
        let bytes = f32_to_bytes(&input);
        assert_eq!(bytes.len(), input.len() * 4);
        let out = bytes_to_f32(&bytes);
        assert_eq!(out, input);
    }

    #[test]
    fn f32_to_bytes_is_little_endian() {
        // 1.0_f32 has bit pattern 0x3F80_0000.
        assert_eq!(f32_to_bytes(&[1.0_f32]), vec![0x00_u8, 0x00, 0x80, 0x3F]);
        assert!(f32_to_bytes(&[]).is_empty());
    }

    #[test]
    fn bytes_to_f32_ignores_trailing_partial_chunk() {
        let mut bytes = f32_to_bytes(&[1.0_f32, -2.0]);
        bytes.extend_from_slice(&[0xAA, 0xBB, 0xCC]);
        assert_eq!(bytes_to_f32(&bytes), vec![1.0_f32, -2.0]);
        assert!(bytes_to_f32(&[0x01, 0x02, 0x03]).is_empty());
    }

    #[test]
    fn roundtrip_preserves_special_values_bitwise() {
        // NaN != NaN, so compare raw bit patterns rather than float values.
        let input = [f32::NAN, f32::INFINITY, f32::NEG_INFINITY, -0.0_f32];
        let out = bytes_to_f32(&f32_to_bytes(&input));
        let in_bits: Vec<u32> = input.iter().map(|f| f.to_bits()).collect();
        let out_bits: Vec<u32> = out.iter().map(|f| f.to_bits()).collect();
        assert_eq!(out_bits, in_bits);
    }

    #[test]
    fn normalise_dim_truncates_and_pads() {
        // Truncation keeps the leading components.
        let long: Vec<f32> = (0..EMBEDDING_DIM + 10).map(|i| i as f32).collect();
        let truncated = normalise_dim(long.clone());
        assert_eq!(truncated.len(), EMBEDDING_DIM);
        assert_eq!(truncated[..], long[..EMBEDDING_DIM]);

        // Padding keeps the original prefix and fills the rest with zeros.
        let short_len = EMBEDDING_DIM / 2;
        let short = vec![1.0_f32; short_len];
        let padded = normalise_dim(short);
        assert_eq!(padded.len(), EMBEDDING_DIM);
        assert!(padded[..short_len].iter().all(|&x| x == 1.0));
        assert!(padded[short_len..].iter().all(|&x| x == 0.0));

        // Exact-length input passes through unchanged.
        let exact = vec![0.5_f32; EMBEDDING_DIM];
        assert_eq!(normalise_dim(exact.clone()), exact);
    }

    #[test]
    fn embedding_dim_matches_constant() {
        assert_eq!(embedding_dim(), crate::constants::EMBEDDING_DIM);
    }
}