1use crate::errors::AppError;
40use crate::extract::llm_embedding::LlmEmbedding;
41use parking_lot::Mutex;
42use std::path::Path;
43use std::sync::Arc;
44use std::sync::OnceLock;
45use tokio::sync::{mpsc, Semaphore};
46use tokio::task::JoinSet;
47use tokio_util::sync::CancellationToken;
48
/// Process-wide embedding backend, filled by the first successful `get_embedder` call.
static EMBEDDER: OnceLock<Mutex<LlmEmbedding>> = OnceLock::new();

/// Fallback tokio runtime (multi-thread, 2 workers) built lazily by `shared_runtime`
/// for callers that are not already running inside a tokio runtime.
static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();

/// Chunk-embedding batch size calibrated for `EMBED_BATCH_CALIBRATION_DIM`-dim vectors;
/// scaled down for wider vectors by `adaptive_batch_for_dim` (see `chunk_embed_batch_size`).
pub const CHUNK_EMBED_BATCH_SIZE: usize = 8;

/// Entity-embedding batch size calibrated for `EMBED_BATCH_CALIBRATION_DIM`-dim vectors;
/// scaled down for wider vectors by `adaptive_batch_for_dim` (see `entity_embed_batch_size`).
pub const ENTITY_EMBED_BATCH_SIZE: usize = 25;

/// Embedding dimension at which the base batch sizes above are used unchanged.
pub const EMBED_BATCH_CALIBRATION_DIM: usize = 64;
74
75fn adaptive_batch_for_dim(base: usize, dim: usize) -> usize {
83 let base = base.max(1);
84 (base * EMBED_BATCH_CALIBRATION_DIM / dim.max(1)).clamp(1, base)
85}
86
87pub fn chunk_embed_batch_size() -> usize {
89 let dim = crate::constants::embedding_dim();
90 let batch = adaptive_batch_for_dim(CHUNK_EMBED_BATCH_SIZE, dim);
91 tracing::debug!(
92 dim,
93 base = CHUNK_EMBED_BATCH_SIZE,
94 batch,
95 "adaptive chunk batch size (G44)"
96 );
97 batch
98}
99
100pub fn entity_embed_batch_size() -> usize {
102 let dim = crate::constants::embedding_dim();
103 let batch = adaptive_batch_for_dim(ENTITY_EMBED_BATCH_SIZE, dim);
104 tracing::debug!(
105 dim,
106 base = ENTITY_EMBED_BATCH_SIZE,
107 batch,
108 "adaptive entity batch size (G44)"
109 );
110 batch
111}
112
113pub(crate) fn shared_runtime() -> Result<&'static tokio::runtime::Runtime, AppError> {
115 if let Some(rt) = RUNTIME.get() {
116 return Ok(rt);
117 }
118 let rt = tokio::runtime::Builder::new_multi_thread()
119 .worker_threads(2)
120 .enable_all()
121 .build()
122 .map_err(|e| AppError::Embedding(format!("tokio runtime init failed: {e}")))?;
123 let _ = RUNTIME.set(rt);
124 Ok(RUNTIME.get().expect("RUNTIME initialised above"))
125}
126
127pub fn get_embedder(_models_dir: &Path) -> Result<&'static Mutex<LlmEmbedding>, AppError> {
129 if let Some(e) = EMBEDDER.get() {
130 return Ok(e);
131 }
132 let backend = LlmEmbedding::detect_available()?;
133 let _ = EMBEDDER.set(Mutex::new(backend));
134 Ok(EMBEDDER.get().expect("EMBEDDER initialised above"))
135}
136
137fn clone_client(embedder: &Mutex<LlmEmbedding>) -> LlmEmbedding {
140 embedder.lock().clone()
141}
142
143pub fn embed_passage(embedder: &Mutex<LlmEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
147 let client = clone_client(embedder);
148 let result = client.embed_passage(text)?;
149 validate_dim(result)
150}
151
152pub fn embed_query(embedder: &Mutex<LlmEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
156 let client = clone_client(embedder);
157 let result = client.embed_query(text)?;
158 validate_dim(result)
159}
160
161pub fn embed_passages_controlled(
166 embedder: &Mutex<LlmEmbedding>,
167 texts: &[&str],
168 _token_counts: &[usize],
169) -> Result<Vec<Vec<f32>>, AppError> {
170 if texts.is_empty() {
171 return Ok(Vec::new());
172 }
173 let owned: Vec<String> = texts.iter().map(|t| t.to_string()).collect();
174 embed_texts_parallel(embedder, &owned, 1, chunk_embed_batch_size())
175}
176
177pub fn embed_passage_local(models_dir: &Path, text: &str) -> Result<Vec<f32>, AppError> {
178 let embedder = get_embedder(models_dir)?;
179 embed_passage(embedder, text)
180}
181
182pub fn embed_query_local(models_dir: &Path, text: &str) -> Result<Vec<f32>, AppError> {
183 let embedder = get_embedder(models_dir)?;
184 embed_query(embedder, text)
185}
186
187pub fn embed_passages_controlled_local(
188 models_dir: &Path,
189 texts: &[&str],
190 token_counts: &[usize],
191) -> Result<Vec<Vec<f32>>, AppError> {
192 let embedder = get_embedder(models_dir)?;
193 embed_passages_controlled(embedder, texts, token_counts)
194}
195
196pub fn embed_passages_parallel_local(
199 models_dir: &Path,
200 texts: &[String],
201 parallelism: usize,
202 batch_size: usize,
203) -> Result<Vec<Vec<f32>>, AppError> {
204 let embedder = get_embedder(models_dir)?;
205 embed_texts_parallel(embedder, texts, parallelism, batch_size)
206}
207
208pub fn embed_texts_parallel(
221 embedder: &Mutex<LlmEmbedding>,
222 texts: &[String],
223 parallelism: usize,
224 batch_size: usize,
225) -> Result<Vec<Vec<f32>>, AppError> {
226 let mut slots: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
227 embed_texts_parallel_with(embedder, texts, parallelism, batch_size, |idx, v| {
228 slots[idx] = Some(v.to_vec());
229 Ok(())
230 })?;
231 let mut out = Vec::with_capacity(slots.len());
232 for (idx, slot) in slots.into_iter().enumerate() {
233 out.push(slot.ok_or_else(|| {
234 AppError::Embedding(format!("embedding fan-out lost item index {idx}"))
235 })?);
236 }
237 Ok(out)
238}
239
240pub fn embed_texts_parallel_with(
244 embedder: &Mutex<LlmEmbedding>,
245 texts: &[String],
246 parallelism: usize,
247 batch_size: usize,
248 mut on_result: impl FnMut(usize, &[f32]) -> Result<(), AppError>,
249) -> Result<(), AppError> {
250 if texts.is_empty() {
251 return Ok(());
252 }
253 let dim = crate::constants::embedding_dim();
254 if texts.len() == 1 {
255 let v = embed_passage(embedder, &texts[0])?;
256 return on_result(0, &v);
257 }
258
259 let client = clone_client(embedder);
260 let permits = effective_permits(parallelism);
261 let batches = build_batches(texts, batch_size.max(1));
262 let token = crate::cancel_token().clone();
263
264 let work = move |batch: Vec<(usize, String)>| {
265 let client = client.clone();
266 async move {
267 client
268 .embed_batch_async(crate::constants::PASSAGE_PREFIX, &batch)
269 .await
270 }
271 };
272
273 let fan_out = run_bounded(batches, permits, dim, token, work, &mut on_result);
274 match tokio::runtime::Handle::try_current() {
275 Ok(handle) => tokio::task::block_in_place(|| handle.block_on(fan_out)),
276 Err(_) => shared_runtime()?.block_on(fan_out),
277 }
278}
279
/// Splits `texts` into consecutive batches of at most `batch_size` items.
///
/// Each item is paired with its index in `texts`, so results can be placed back in
/// input order no matter which batch finishes first. A `batch_size` of 0 is treated
/// as 1 instead of panicking. Each string is cloned exactly once.
fn build_batches(texts: &[String], batch_size: usize) -> Vec<Vec<(usize, String)>> {
    let size = batch_size.max(1);
    texts
        .chunks(size)
        .enumerate()
        .map(|(batch_no, chunk)| {
            // Global index of this chunk's first element.
            let first = batch_no * size;
            chunk
                .iter()
                .enumerate()
                .map(|(offset, text)| (first + offset, text.clone()))
                .collect()
        })
        .collect()
}
291
292pub fn effective_permits(requested: usize) -> usize {
297 let cpus = std::thread::available_parallelism()
298 .map(|n| n.get())
299 .unwrap_or(4);
300 let by_ram = ((crate::memory_guard::available_memory_mb() / 2)
301 / crate::constants::LLM_WORKER_RSS_MB)
302 .max(1) as usize;
303 requested.clamp(1, 32).min(cpus).min(by_ram).max(1)
304}
305
/// Runs `work` over every batch with at most `permits` batches in flight. Each returned
/// `(index, vector)` pair is fed to `on_result` on the calling task.
///
/// One task per batch is spawned into a `JoinSet` up front. A `Semaphore` with
/// `permits` slots bounds how many run `work` concurrently. Finished batches are sent
/// over a bounded `mpsc` channel (capacity `permits * 2`) and consumed here in
/// completion order, so `on_result` sees indices in arbitrary order, not input order.
///
/// # Parameters
/// - `batches`: groups of `(global_index, text)` pairs.
/// - `permits`: maximum concurrent `work` calls. NOTE(review): callers in this file pass
///   values >= 1. tokio documents that `mpsc::channel(0)` panics, and `Semaphore::new(0)`
///   would never grant a permit. Confirm for any new caller.
/// - `dim`: required length of every returned vector. Any other length fails the run.
/// - `token`: cancellation signal. Each batch's `work` races `token.cancelled()`. If the
///   token wins, that batch yields an `AppError::Embedding` whose message contains
///   "cancelled".
/// - `work`: async embedding call for a single batch.
/// - `on_result`: sink invoked per successfully embedded item until the first error.
///
/// # Errors
/// Returns the first error observed. That can be a failed or cancelled batch, a
/// dimension mismatch (message cites G42/C5), an error from `on_result`, or a panic
/// inside a batch task. Once the first error is recorded, remaining tasks are aborted
/// with `JoinSet::shutdown`, and later items are not delivered to `on_result`.
async fn run_bounded<F, Fut>(
    batches: Vec<Vec<(usize, String)>>,
    permits: usize,
    dim: usize,
    token: CancellationToken,
    work: F,
    on_result: &mut impl FnMut(usize, &[f32]) -> Result<(), AppError>,
) -> Result<(), AppError>
where
    F: Fn(Vec<(usize, String)>) -> Fut + Clone + Send + 'static,
    Fut: std::future::Future<Output = Result<Vec<(usize, Vec<f32>)>, AppError>> + Send,
{
    let total_batches = batches.len();
    let semaphore = Arc::new(Semaphore::new(permits));
    let (tx, mut rx) = mpsc::channel::<Result<Vec<(usize, Vec<f32>)>, AppError>>(permits * 2);
    let mut set: JoinSet<()> = JoinSet::new();

    for (batch_idx, batch) in batches.into_iter().enumerate() {
        let sem = Arc::clone(&semaphore);
        let token = token.clone();
        let tx = tx.clone();
        let work = work.clone();
        set.spawn(async move {
            let wait_start = std::time::Instant::now();
            // `_permit` lives until the task ends, including while blocked on `tx.send`,
            // so a full channel also throttles how many new batches start.
            let Ok(_permit) = sem.acquire_owned().await else {
                // Defensive: this function never closes the semaphore.
                let _ = tx
                    .send(Err(AppError::Embedding("semaphore closed".to_string())))
                    .await;
                return;
            };
            let permit_wait_ms = wait_start.elapsed().as_millis() as u64;
            let work_start = std::time::Instant::now();
            let outcome = tokio::select! {
                res = work(batch) => res,
                _ = token.cancelled() => Err(AppError::Embedding(
                    "embedding cancelled by shutdown signal".to_string(),
                )),
            };
            tracing::debug!(
                target: "embedding",
                batch_idx,
                permit_wait_ms,
                work_ms = work_start.elapsed().as_millis() as u64,
                ok = outcome.is_ok(),
                "embedding batch finished"
            );
            // A send error means the receiver is gone; nothing useful to do with it.
            let _ = tx.send(outcome).await;
        });
    }
    // Drop the original sender so `rx.recv()` returns `None` once every task has
    // finished or been aborted (each task owns its own clone).
    drop(tx);

    let mut completed = 0usize;
    let mut failed = 0usize;
    let mut cancelled = 0usize;
    let mut first_error: Option<AppError> = None;

    while let Some(message) = rx.recv().await {
        match message {
            Ok(items) => {
                completed += 1;
                // After the first error, later successful batches are drained but discarded.
                if first_error.is_none() {
                    for (idx, v) in items {
                        if v.len() != dim {
                            first_error = Some(AppError::Embedding(format!(
                                "LLM returned {} dims for item {idx}, expected {dim}; \
                                 refusing to truncate or pad silently (G42/C5)",
                                v.len()
                            )));
                            break;
                        }
                        if let Err(e) = on_result(idx, &v) {
                            first_error = Some(e);
                            break;
                        }
                    }
                    if first_error.is_some() {
                        set.shutdown().await;
                    }
                }
            }
            Err(e) => {
                // Cancellation is recognised by message text, matching the error built
                // in the `select!` above.
                if matches!(&e, AppError::Embedding(msg) if msg.contains("cancelled")) {
                    cancelled += 1;
                } else {
                    failed += 1;
                }
                if first_error.is_none() {
                    first_error = Some(e);
                    set.shutdown().await;
                }
            }
        }
    }

    // A panicking task never sends on the channel, so its panic only surfaces here.
    // NOTE(review): after `set.shutdown()` the set is presumably already empty, so tasks
    // aborted there are likely not counted in `cancelled`. Confirm if the stats matter.
    while let Some(join_result) = set.join_next().await {
        if let Err(join_err) = join_result {
            if join_err.is_panic() {
                failed += 1;
                if first_error.is_none() {
                    first_error = Some(AppError::Embedding(format!(
                        "embedding task panicked: {join_err}"
                    )));
                }
            } else {
                cancelled += 1;
            }
        }
    }

    tracing::info!(
        target: "embedding",
        total_batches,
        completed,
        failed,
        cancelled,
        available_permits = semaphore.available_permits(),
        "embedding fan-out finished"
    );

    match first_error {
        Some(e) => Err(e),
        None => Ok(()),
    }
}
450
/// Serialises a vector of `f32` into little-endian bytes, 4 bytes per element.
pub fn f32_to_bytes(v: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(v.len() * 4);
    bytes.extend(v.iter().flat_map(|value| value.to_le_bytes()));
    bytes
}
458
/// Decodes little-endian bytes back into `f32` values, 4 bytes per element.
///
/// Trailing bytes that do not form a complete 4-byte group are ignored.
pub fn bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|quad| f32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]))
        .collect()
}
466
467pub fn embedding_dim() -> usize {
470 crate::constants::embedding_dim()
471}
472
473fn validate_dim(v: Vec<f32>) -> Result<Vec<f32>, AppError> {
477 let dim = crate::constants::embedding_dim();
478 if v.len() != dim {
479 return Err(AppError::Embedding(format!(
480 "embedding has {} dims, expected {dim}; \
481 refusing to truncate or pad silently (G42/C5)",
482 v.len()
483 )));
484 }
485 Ok(v)
486}
487
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn f32_to_bytes_roundtrip() {
        let input = vec![0.0_f32, 1.5, -2.25, f32::MIN, f32::MAX];
        let bytes = f32_to_bytes(&input);
        assert_eq!(bytes.len(), input.len() * 4);
        let out = bytes_to_f32(&bytes);
        assert_eq!(out, input);
    }

    #[test]
    fn validate_dim_rejects_divergent_vectors() {
        let dim = crate::constants::embedding_dim();
        let long = vec![0.0; dim + 10];
        assert!(validate_dim(long).is_err(), "longer vector must error");
        let short = vec![0.0; dim.saturating_sub(1).max(1)];
        assert!(validate_dim(short).is_err(), "shorter vector must error");
        let exact = vec![0.0; dim];
        assert_eq!(validate_dim(exact).expect("exact dim must pass").len(), dim);
    }

    #[test]
    fn embedding_dim_matches_constants_source() {
        assert_eq!(embedding_dim(), crate::constants::embedding_dim());
    }

    #[test]
    fn build_batches_preserves_global_indices() {
        let texts: Vec<String> = (0..10).map(|i| format!("t{i}")).collect();
        let batches = build_batches(&texts, 4);
        assert_eq!(batches.len(), 3);
        assert_eq!(batches[0].len(), 4);
        assert_eq!(batches[2].len(), 2);
        assert_eq!(batches[2][1].0, 9);
        assert_eq!(batches[2][1].1, "t9");
    }

    #[test]
    fn effective_permits_clamps_to_bounds() {
        assert!(effective_permits(0) >= 1);
        assert!(effective_permits(1000) <= 32);
    }

    /// One single-item batch per index `0..n`.
    fn test_batches(n: usize) -> Vec<Vec<(usize, String)>> {
        (0..n).map(|i| vec![(i, format!("t{i}"))]).collect()
    }

    /// Zero vector of the requested length.
    fn dummy_vec(dim: usize) -> Vec<f32> {
        vec![0.0; dim]
    }

    #[test]
    fn concurrency_peak_never_exceeds_permits() {
        let permits = 4usize;
        let batches = test_batches(permits * 10);
        let dim = crate::constants::embedding_dim();
        let current = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));

        let current_c = Arc::clone(&current);
        let peak_c = Arc::clone(&peak);
        let work = move |batch: Vec<(usize, String)>| {
            let current = Arc::clone(&current_c);
            let peak = Arc::clone(&peak_c);
            async move {
                let now = current.fetch_add(1, Ordering::SeqCst) + 1;
                peak.fetch_max(now, Ordering::SeqCst);
                tokio::time::sleep(std::time::Duration::from_millis(20)).await;
                current.fetch_sub(1, Ordering::SeqCst);
                Ok(batch
                    .into_iter()
                    .map(|(i, _)| (i, dummy_vec(crate::constants::embedding_dim())))
                    .collect())
            }
        };

        let mut delivered = 0usize;
        let rt = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(4)
            .enable_all()
            .build()
            .expect("test runtime");
        rt.block_on(run_bounded(
            batches,
            permits,
            dim,
            CancellationToken::new(),
            work,
            &mut |_idx, _v| {
                delivered += 1;
                Ok(())
            },
        ))
        .expect("fan-out must succeed");

        assert_eq!(delivered, permits * 10, "every item must be delivered");
        assert!(
            peak.load(Ordering::SeqCst) <= permits,
            "peak concurrency {} exceeded permits {permits}",
            peak.load(Ordering::SeqCst)
        );
    }

    #[test]
    fn panicking_task_returns_permit_and_surfaces_error() {
        let permits = 2usize;
        let batches = test_batches(4);
        let dim = crate::constants::embedding_dim();

        let work = move |batch: Vec<(usize, String)>| async move {
            if batch[0].0 == 1 {
                panic!("intentional test panic");
            }
            Ok(batch
                .into_iter()
                .map(|(i, _)| (i, dummy_vec(crate::constants::embedding_dim())))
                .collect())
        };

        let rt = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(2)
            .enable_all()
            .build()
            .expect("test runtime");
        let result = rt.block_on(run_bounded(
            batches,
            permits,
            dim,
            CancellationToken::new(),
            work,
            &mut |_idx, _v| Ok(()),
        ));

        let err = result.expect_err("panic must surface as an error");
        assert!(
            err.to_string().contains("panicked"),
            "error must mention the panic: {err}"
        );
    }

    #[test]
    fn cancellation_terminates_fan_out_quickly() {
        let permits = 2usize;
        let batches = test_batches(8);
        let dim = crate::constants::embedding_dim();
        let token = CancellationToken::new();

        let work = move |batch: Vec<(usize, String)>| async move {
            tokio::time::sleep(std::time::Duration::from_secs(30)).await;
            Ok(batch
                .into_iter()
                .map(|(i, _)| (i, dummy_vec(crate::constants::embedding_dim())))
                .collect())
        };

        let rt = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(2)
            .enable_all()
            .build()
            .expect("test runtime");
        let cancel = token.clone();
        let start = std::time::Instant::now();
        let result = rt.block_on(async move {
            tokio::spawn(async move {
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                cancel.cancel();
            });
            run_bounded(batches, permits, dim, token, work, &mut |_idx, _v| Ok(())).await
        });

        assert!(result.is_err(), "cancelled fan-out must report an error");
        assert!(
            start.elapsed() < std::time::Duration::from_secs(10),
            "graceful shutdown must finish well under the work duration"
        );
    }

    #[test]
    fn fan_out_rejects_divergent_dim() {
        let permits = 2usize;
        let batches = test_batches(2);
        let dim = crate::constants::embedding_dim();

        let work = move |batch: Vec<(usize, String)>| async move {
            Ok(batch
                .into_iter()
                .map(|(i, _)| (i, vec![0.0f32; 3]))
                .collect::<Vec<(usize, Vec<f32>)>>())
        };

        let rt = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(2)
            .enable_all()
            .build()
            .expect("test runtime");
        let result = rt.block_on(run_bounded(
            batches,
            permits,
            dim,
            CancellationToken::new(),
            work,
            &mut |_idx, _v| Ok(()),
        ));

        let err = result.expect_err("divergent dim must fail the fan-out");
        assert!(err.to_string().contains("G42/C5"), "error cites C5: {err}");
    }

    #[test]
    fn adaptive_batch_dim64_keeps_calibrated_sizes() {
        assert_eq!(adaptive_batch_for_dim(CHUNK_EMBED_BATCH_SIZE, 64), 8);
        assert_eq!(adaptive_batch_for_dim(ENTITY_EMBED_BATCH_SIZE, 64), 25);
    }

    #[test]
    fn adaptive_batch_dim384_shrinks() {
        assert_eq!(adaptive_batch_for_dim(CHUNK_EMBED_BATCH_SIZE, 384), 1);
        assert_eq!(adaptive_batch_for_dim(ENTITY_EMBED_BATCH_SIZE, 384), 4);
    }

    #[test]
    fn adaptive_batch_intermediate_dims() {
        assert_eq!(adaptive_batch_for_dim(8, 128), 4);
        assert_eq!(adaptive_batch_for_dim(8, 256), 2);
    }

    #[test]
    fn adaptive_batch_small_dim_clamps_to_base() {
        assert_eq!(adaptive_batch_for_dim(8, 8), 8);
    }

    #[test]
    fn adaptive_batch_total_function() {
        assert_eq!(adaptive_batch_for_dim(8, 4096), 1);
        assert_eq!(adaptive_batch_for_dim(8, 0), 8);
        assert_eq!(adaptive_batch_for_dim(0, 64), 1);
    }

    #[test]
    #[serial_test::serial(env)]
    fn adaptive_wrappers_follow_env_dim() {
        std::env::set_var("SQLITE_GRAPHRAG_EMBEDDING_DIM", "384");
        let chunk = chunk_embed_batch_size();
        let entity = entity_embed_batch_size();
        std::env::remove_var("SQLITE_GRAPHRAG_EMBEDDING_DIM");
        crate::constants::set_active_embedding_dim(crate::constants::DEFAULT_EMBEDDING_DIM);
        assert_eq!(chunk, 1, "384-dim chunk batch must shrink to 1 (G44)");
        assert_eq!(entity, 4, "384-dim entity batch must shrink to 4 (G44)");
    }
}