1use crate::errors::AppError;
40use crate::extract::llm_embedding::LlmEmbedding;
41use parking_lot::Mutex;
42use std::path::Path;
43use std::sync::Arc;
44use std::sync::OnceLock;
45use tokio::sync::{mpsc, Semaphore};
46use tokio::task::JoinSet;
47use tokio_util::sync::CancellationToken;
48
/// Process-wide embedding backend, created lazily by [`get_embedder`].
static EMBEDDER: OnceLock<Mutex<LlmEmbedding>> = OnceLock::new();

/// Process-wide Tokio runtime, created lazily by [`shared_runtime`].
static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();

/// Base batch size for chunk embeddings; [`chunk_embed_batch_size`] scales it
/// according to the active embedding dimension.
pub const CHUNK_EMBED_BATCH_SIZE: usize = 8;

/// Base batch size for entity embeddings; [`entity_embed_batch_size`] scales it
/// according to the active embedding dimension.
pub const ENTITY_EMBED_BATCH_SIZE: usize = 25;

/// Embedding dimension at which the base batch sizes above are used unchanged;
/// larger dimensions shrink the batch proportionally.
pub const EMBED_BATCH_CALIBRATION_DIM: usize = 64;
74
75fn adaptive_batch_for_dim(base: usize, dim: usize) -> usize {
83 let base = base.max(1);
84 (base * EMBED_BATCH_CALIBRATION_DIM / dim.max(1)).clamp(1, base)
85}
86
87pub fn chunk_embed_batch_size() -> usize {
89 let dim = crate::constants::embedding_dim();
90 let batch = adaptive_batch_for_dim(CHUNK_EMBED_BATCH_SIZE, dim);
91 tracing::debug!(
92 dim,
93 base = CHUNK_EMBED_BATCH_SIZE,
94 batch,
95 "adaptive chunk batch size (G44)"
96 );
97 batch
98}
99
100pub fn entity_embed_batch_size() -> usize {
102 let dim = crate::constants::embedding_dim();
103 let batch = adaptive_batch_for_dim(ENTITY_EMBED_BATCH_SIZE, dim);
104 tracing::debug!(
105 dim,
106 base = ENTITY_EMBED_BATCH_SIZE,
107 batch,
108 "adaptive entity batch size (G44)"
109 );
110 batch
111}
112
113pub(crate) fn shared_runtime() -> Result<&'static tokio::runtime::Runtime, AppError> {
115 if let Some(rt) = RUNTIME.get() {
116 return Ok(rt);
117 }
118 let rt = tokio::runtime::Builder::new_multi_thread()
119 .worker_threads(2)
120 .enable_all()
121 .build()
122 .map_err(|e| AppError::Embedding(format!("tokio runtime init failed: {e}")))?;
123 let _ = RUNTIME.set(rt);
124 Ok(RUNTIME.get().expect("RUNTIME initialised above"))
125}
126
127pub fn get_embedder(_models_dir: &Path) -> Result<&'static Mutex<LlmEmbedding>, AppError> {
129 if let Some(e) = EMBEDDER.get() {
130 return Ok(e);
131 }
132 let backend = LlmEmbedding::detect_available()?;
133 let _ = EMBEDDER.set(Mutex::new(backend));
134 Ok(EMBEDDER.get().expect("EMBEDDER initialised above"))
135}
136
137fn clone_client(embedder: &Mutex<LlmEmbedding>) -> LlmEmbedding {
140 embedder.lock().clone()
141}
142
143pub fn embed_passage(embedder: &Mutex<LlmEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
147 let client = clone_client(embedder);
148 let result = client.embed_passage(text)?;
149 validate_dim(result)
150}
151
152pub fn embed_query(embedder: &Mutex<LlmEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
156 let client = clone_client(embedder);
157 let result = client.embed_query(text)?;
158 validate_dim(result)
159}
160
161pub fn embed_passages_controlled(
166 embedder: &Mutex<LlmEmbedding>,
167 texts: &[&str],
168 _token_counts: &[usize],
169) -> Result<Vec<Vec<f32>>, AppError> {
170 if texts.is_empty() {
171 return Ok(Vec::new());
172 }
173 let owned: Vec<String> = texts.iter().map(|t| t.to_string()).collect();
174 embed_texts_parallel(embedder, &owned, 1, chunk_embed_batch_size())
175}
176
177pub fn embed_passage_local(models_dir: &Path, text: &str) -> Result<Vec<f32>, AppError> {
178 let embedder = get_embedder(models_dir)?;
179 embed_passage(embedder, text)
180}
181
182pub fn embed_query_local(models_dir: &Path, text: &str) -> Result<Vec<f32>, AppError> {
183 let embedder = get_embedder(models_dir)?;
184 embed_query(embedder, text)
185}
/// Why a query embedding could not be produced, as reported by
/// `try_embed_query_with_fallback`.
#[derive(Debug, Clone, PartialEq)]
pub enum FallbackReason {
    /// The embedding call failed; the payload is the error message.
    EmbeddingFailed(String),
    /// The embedding call was cancelled.
    Cancelled,
    /// The embedding call timed out.
    Timeout {
        /// Name of the operation that timed out.
        operation: String,
        /// Timeout duration in seconds.
        duration_secs: u64,
    },
}

impl std::fmt::Display for FallbackReason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Cancelled => f.write_str("embedding cancelled by external signal"),
            Self::EmbeddingFailed(msg) => write!(f, "embedding failed: {msg}"),
            Self::Timeout {
                operation,
                duration_secs,
            } => write!(
                f,
                "embedding timed out after {duration_secs}s during {operation}"
            ),
        }
    }
}

impl std::error::Error for FallbackReason {}
227
228pub fn try_embed_query_with_fallback(
236 models_dir: &Path,
237 query: &str,
238) -> Result<Vec<f32>, FallbackReason> {
239 match embed_query_local(models_dir, query) {
240 Ok(v) => Ok(v),
241 Err(AppError::Embedding(msg)) if msg.contains("cancelled") => {
242 Err(FallbackReason::Cancelled)
243 }
244 Err(AppError::Embedding(msg)) => Err(FallbackReason::EmbeddingFailed(msg)),
245 Err(AppError::Timeout {
246 operation,
247 duration_secs,
248 }) => Err(FallbackReason::Timeout {
249 operation,
250 duration_secs,
251 }),
252 Err(e) => Err(FallbackReason::EmbeddingFailed(e.to_string())),
253 }
254}
255
256pub fn embed_passages_controlled_local(
257 models_dir: &Path,
258 texts: &[&str],
259 token_counts: &[usize],
260) -> Result<Vec<Vec<f32>>, AppError> {
261 let embedder = get_embedder(models_dir)?;
262 embed_passages_controlled(embedder, texts, token_counts)
263}
264
265pub fn embed_passages_parallel_local(
268 models_dir: &Path,
269 texts: &[String],
270 parallelism: usize,
271 batch_size: usize,
272) -> Result<Vec<Vec<f32>>, AppError> {
273 let embedder = get_embedder(models_dir)?;
274 embed_texts_parallel(embedder, texts, parallelism, batch_size)
275}
276
/// In-memory entity embedding cache: maps the 64-bit key produced by
/// `entity_cache_key` to a shared embedding vector.
type EntityEmbedCacheMap = std::collections::HashMap<u64, Arc<Vec<f32>>>;

/// Process-wide entity embedding cache, created lazily by `entity_embed_cache`.
/// NOTE(review): nothing in this file evicts entries — confirm unbounded growth
/// is acceptable for the process lifetime.
static ENTITY_EMBED_CACHE: OnceLock<parking_lot::Mutex<EntityEmbedCacheMap>> = OnceLock::new();
291
292fn entity_embed_cache() -> &'static parking_lot::Mutex<EntityEmbedCacheMap> {
293 ENTITY_EMBED_CACHE.get_or_init(|| parking_lot::Mutex::new(std::collections::HashMap::new()))
294}
295
296fn entity_cache_key(model: &str, text: &str) -> u64 {
297 let mut hasher = blake3::Hasher::new();
298 hasher.update(model.as_bytes());
299 hasher.update(b"\0");
300 hasher.update(text.as_bytes());
301 let h = hasher.finalize();
302 let bytes = h.as_bytes();
303 u64::from_le_bytes([
304 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
305 ])
306}
307
308pub fn embed_entity_texts_cached(
318 models_dir: &Path,
319 texts: &[String],
320 parallelism: usize,
321) -> Result<(Vec<Vec<f32>>, EmbedCacheStats), AppError> {
322 if texts.is_empty() {
323 return Ok((Vec::new(), EmbedCacheStats::default()));
324 }
325 let embedder = get_embedder(models_dir)?;
326 let model = embedder.lock().model_label();
327 let cache = entity_embed_cache();
328 let mut hits: Vec<Option<Arc<Vec<f32>>>> = vec![None; texts.len()];
329 let mut miss_indices: Vec<usize> = Vec::with_capacity(texts.len());
330 {
331 let guard = cache.lock();
332 for (i, text) in texts.iter().enumerate() {
333 let key = entity_cache_key(&model, text);
334 if let Some(v) = guard.get(&key) {
335 hits[i] = Some(Arc::clone(v));
336 } else {
337 miss_indices.push(i);
338 }
339 }
340 }
341 let miss_count = miss_indices.len();
342 if miss_count > 0 {
343 let miss_texts: Vec<String> = miss_indices.iter().map(|&i| texts[i].clone()).collect();
344 let miss_vecs = embed_texts_parallel(
345 embedder,
346 &miss_texts,
347 parallelism,
348 entity_embed_batch_size(),
349 )?;
350 let mut guard = cache.lock();
351 for (slot, &orig_idx) in miss_indices.iter().enumerate() {
352 let vec = Arc::new(miss_vecs[slot].clone());
353 let key = entity_cache_key(&model, &texts[orig_idx]);
354 guard.insert(key, Arc::clone(&vec));
355 hits[orig_idx] = Some(vec);
356 }
357 }
358 let mut out = Vec::with_capacity(texts.len());
359 for hit in hits.into_iter() {
360 let v = hit.ok_or_else(|| {
361 AppError::Embedding("entity embed cache produced null result".to_string())
362 })?;
363 out.push((*v).clone());
364 }
365 Ok((
366 out,
367 EmbedCacheStats {
368 requested: texts.len(),
369 hits: texts.len() - miss_count,
370 misses: miss_count,
371 },
372 ))
373}
374
/// Cache statistics returned by `embed_entity_texts_cached`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, serde::Serialize)]
pub struct EmbedCacheStats {
    /// Number of texts requested.
    pub requested: usize,
    /// Number of texts served from the cache.
    pub hits: usize,
    /// Number of texts that had to be embedded.
    pub misses: usize,
}
382
383impl EmbedCacheStats {
384 pub fn hit_rate(&self) -> f64 {
386 if self.requested == 0 {
387 0.0
388 } else {
389 self.hits as f64 / self.requested as f64
390 }
391 }
392}
393
394pub fn embed_texts_parallel(
407 embedder: &Mutex<LlmEmbedding>,
408 texts: &[String],
409 parallelism: usize,
410 batch_size: usize,
411) -> Result<Vec<Vec<f32>>, AppError> {
412 let mut slots: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
413 embed_texts_parallel_with(embedder, texts, parallelism, batch_size, |idx, v| {
414 slots[idx] = Some(v.to_vec());
415 Ok(())
416 })?;
417 let mut out = Vec::with_capacity(slots.len());
418 for (idx, slot) in slots.into_iter().enumerate() {
419 out.push(slot.ok_or_else(|| {
420 AppError::Embedding(format!("embedding fan-out lost item index {idx}"))
421 })?);
422 }
423 Ok(out)
424}
425
426pub fn embed_texts_parallel_with(
430 embedder: &Mutex<LlmEmbedding>,
431 texts: &[String],
432 parallelism: usize,
433 batch_size: usize,
434 mut on_result: impl FnMut(usize, &[f32]) -> Result<(), AppError>,
435) -> Result<(), AppError> {
436 if texts.is_empty() {
437 return Ok(());
438 }
439 let dim = crate::constants::embedding_dim();
440 if texts.len() == 1 {
441 let v = embed_passage(embedder, &texts[0])?;
442 return on_result(0, &v);
443 }
444
445 let client = clone_client(embedder);
446 let permits = effective_permits(parallelism);
447 let batches = build_batches(texts, batch_size.max(1));
448 let token = crate::cancel_token().clone();
449
450 let work = move |batch: Vec<(usize, String)>| {
451 let client = client.clone();
452 async move {
453 client
454 .embed_batch_async(crate::constants::PASSAGE_PREFIX, &batch)
455 .await
456 }
457 };
458
459 let fan_out = run_bounded(batches, permits, dim, token, work, &mut on_result);
460 match tokio::runtime::Handle::try_current() {
461 Ok(handle) => tokio::task::block_in_place(|| handle.block_on(fan_out)),
462 Err(_) => shared_runtime()?.block_on(fan_out),
463 }
464}
465
/// Splits `texts` into consecutive batches of at most `batch_size` items,
/// pairing every text with its index in the original slice.
///
/// A `batch_size` of 0 is treated as 1 instead of panicking. Each text is
/// cloned exactly once.
fn build_batches(texts: &[String], batch_size: usize) -> Vec<Vec<(usize, String)>> {
    let batch_size = batch_size.max(1);
    texts
        .chunks(batch_size)
        .enumerate()
        .map(|(chunk_idx, chunk)| {
            // Global index of the first item in this chunk.
            let first = chunk_idx * batch_size;
            chunk
                .iter()
                .enumerate()
                .map(|(offset, text)| (first + offset, text.clone()))
                .collect()
        })
        .collect()
}
477
478pub fn effective_permits(requested: usize) -> usize {
483 let cpus = std::thread::available_parallelism()
484 .map(|n| n.get())
485 .unwrap_or(4);
486 let by_ram = ((crate::memory_guard::available_memory_mb() / 2)
487 / crate::constants::LLM_WORKER_RSS_MB)
488 .max(1) as usize;
489 requested.clamp(1, 32).min(cpus).min(by_ram).max(1)
490}
491
492async fn run_bounded<F, Fut>(
502 batches: Vec<Vec<(usize, String)>>,
503 permits: usize,
504 dim: usize,
505 token: CancellationToken,
506 work: F,
507 on_result: &mut impl FnMut(usize, &[f32]) -> Result<(), AppError>,
508) -> Result<(), AppError>
509where
510 F: Fn(Vec<(usize, String)>) -> Fut + Clone + Send + 'static,
511 Fut: std::future::Future<Output = Result<Vec<(usize, Vec<f32>)>, AppError>> + Send,
512{
513 let total_batches = batches.len();
514 let semaphore = Arc::new(Semaphore::new(permits));
515 let (tx, mut rx) = mpsc::channel::<Result<Vec<(usize, Vec<f32>)>, AppError>>(permits * 2);
518 let mut set: JoinSet<()> = JoinSet::new();
519
520 for (batch_idx, batch) in batches.into_iter().enumerate() {
521 let sem = Arc::clone(&semaphore);
522 let token = token.clone();
523 let tx = tx.clone();
524 let work = work.clone();
525 set.spawn(async move {
526 let wait_start = std::time::Instant::now();
527 let Ok(_permit) = sem.acquire_owned().await else {
530 let _ = tx
531 .send(Err(AppError::Embedding("semaphore closed".to_string())))
532 .await;
533 return;
534 };
535 let permit_wait_ms = wait_start.elapsed().as_millis() as u64;
536 let work_start = std::time::Instant::now();
537 let outcome = if crate::should_obey_shutdown() {
543 tokio::select! {
544 res = work(batch) => res,
545 _ = token.cancelled() => Err(AppError::Embedding(
546 "embedding cancelled by shutdown signal".to_string(),
547 )),
548 }
549 } else {
550 work(batch).await
551 };
552 tracing::debug!(
554 target: "embedding",
555 batch_idx,
556 permit_wait_ms,
557 work_ms = work_start.elapsed().as_millis() as u64,
558 ok = outcome.is_ok(),
559 "embedding batch finished"
560 );
561 let _ = tx.send(outcome).await;
562 });
563 }
564 drop(tx);
565
566 let mut completed = 0usize;
567 let mut failed = 0usize;
568 let mut cancelled = 0usize;
569 let mut first_error: Option<AppError> = None;
570
571 while let Some(message) = rx.recv().await {
572 match message {
573 Ok(items) => {
574 completed += 1;
575 if first_error.is_none() {
576 for (idx, v) in items {
577 if v.len() != dim {
578 first_error = Some(AppError::Embedding(format!(
579 "LLM returned {} dims for item {idx}, expected {dim}; \
580 refusing to truncate or pad silently (G42/C5)",
581 v.len()
582 )));
583 break;
584 }
585 if let Err(e) = on_result(idx, &v) {
586 first_error = Some(e);
587 break;
588 }
589 }
590 if first_error.is_some() {
591 set.shutdown().await;
594 }
595 }
596 }
597 Err(e) => {
598 if matches!(&e, AppError::Embedding(msg) if msg.contains("cancelled")) {
599 cancelled += 1;
600 } else {
601 failed += 1;
602 }
603 if first_error.is_none() {
604 first_error = Some(e);
605 set.shutdown().await;
606 }
607 }
608 }
609 }
610
611 while let Some(join_result) = set.join_next().await {
614 if let Err(join_err) = join_result {
615 if join_err.is_panic() {
616 failed += 1;
617 if first_error.is_none() {
618 first_error = Some(AppError::Embedding(format!(
619 "embedding task panicked: {join_err}"
620 )));
621 }
622 } else {
623 cancelled += 1;
624 }
625 }
626 }
627
628 tracing::info!(
631 target: "embedding",
632 total_batches,
633 completed,
634 failed,
635 cancelled,
636 available_permits = semaphore.available_permits(),
637 "embedding fan-out finished"
638 );
639
640 match first_error {
641 Some(e) => Err(e),
642 None => Ok(()),
643 }
644}
645
/// Serialises a slice of `f32` values as consecutive little-endian 4-byte
/// groups.
pub fn f32_to_bytes(v: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(v.len() * 4);
    bytes.extend(v.iter().flat_map(|value| value.to_le_bytes()));
    bytes
}
653
/// Decodes consecutive little-endian 4-byte groups into `f32` values.
///
/// Trailing bytes that do not form a complete group are ignored.
pub fn bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|group| f32::from_le_bytes([group[0], group[1], group[2], group[3]]))
        .collect()
}
661
/// Returns the active embedding dimension; forwards to
/// `crate::constants::embedding_dim`.
pub fn embedding_dim() -> usize {
    crate::constants::embedding_dim()
}
667
668fn validate_dim(v: Vec<f32>) -> Result<Vec<f32>, AppError> {
672 let dim = crate::constants::embedding_dim();
673 if v.len() != dim {
674 return Err(AppError::Embedding(format!(
675 "embedding has {} dims, expected {dim}; \
676 refusing to truncate or pad silently (G42/C5)",
677 v.len()
678 )));
679 }
680 Ok(v)
681}
682
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Encoding then decoding must return the original values.
    #[test]
    fn f32_to_bytes_roundtrip() {
        let input = vec![0.0_f32, 1.5, -2.25, f32::MIN, f32::MAX];
        let bytes = f32_to_bytes(&input);
        assert_eq!(bytes.len(), input.len() * 4);
        let out = bytes_to_f32(&bytes);
        assert_eq!(out, input);
    }

    /// Vectors longer or shorter than the active dimension are rejected.
    #[test]
    fn validate_dim_rejects_divergent_vectors() {
        let dim = crate::constants::embedding_dim();
        let long = vec![0.0; dim + 10];
        assert!(validate_dim(long).is_err(), "longer vector must error");
        let short = vec![0.0; dim.saturating_sub(1).max(1)];
        assert!(validate_dim(short).is_err(), "shorter vector must error");
        let exact = vec![0.0; dim];
        assert_eq!(validate_dim(exact).expect("exact dim must pass").len(), dim);
    }

    #[test]
    fn embedding_dim_matches_constants_source() {
        assert_eq!(embedding_dim(), crate::constants::embedding_dim());
    }

    /// Indices inside each batch refer to positions in the original slice.
    #[test]
    fn build_batches_preserves_global_indices() {
        let texts: Vec<String> = (0..10).map(|i| format!("t{i}")).collect();
        let batches = build_batches(&texts, 4);
        assert_eq!(batches.len(), 3);
        assert_eq!(batches[0].len(), 4);
        assert_eq!(batches[2].len(), 2);
        assert_eq!(batches[2][1].0, 9);
        assert_eq!(batches[2][1].1, "t9");
    }

    #[test]
    fn effective_permits_clamps_to_bounds() {
        assert!(effective_permits(0) >= 1);
        assert!(effective_permits(1000) <= 32);
    }

    /// Builds `n` single-item batches whose item index equals the batch index.
    fn test_batches(n: usize) -> Vec<Vec<(usize, String)>> {
        (0..n).map(|i| vec![(i, format!("t{i}"))]).collect()
    }

    /// Zero-filled vector of the requested length.
    fn dummy_vec(dim: usize) -> Vec<f32> {
        vec![0.0; dim]
    }

    /// The number of batches running at once must never exceed `permits`.
    #[test]
    fn concurrency_peak_never_exceeds_permits() {
        let permits = 4usize;
        let batches = test_batches(permits * 10);
        let dim = crate::constants::embedding_dim();
        let current = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));

        let current_c = Arc::clone(&current);
        let peak_c = Arc::clone(&peak);
        let work = move |batch: Vec<(usize, String)>| {
            let current = Arc::clone(&current_c);
            let peak = Arc::clone(&peak_c);
            async move {
                let now = current.fetch_add(1, Ordering::SeqCst) + 1;
                peak.fetch_max(now, Ordering::SeqCst);
                tokio::time::sleep(std::time::Duration::from_millis(20)).await;
                current.fetch_sub(1, Ordering::SeqCst);
                Ok(batch
                    .into_iter()
                    .map(|(i, _)| (i, dummy_vec(crate::constants::embedding_dim())))
                    .collect())
            }
        };

        let mut delivered = 0usize;
        let rt = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(4)
            .enable_all()
            .build()
            .expect("test runtime");
        rt.block_on(run_bounded(
            batches,
            permits,
            dim,
            CancellationToken::new(),
            work,
            &mut |_idx, _v| {
                delivered += 1;
                Ok(())
            },
        ))
        .expect("fan-out must succeed");

        assert_eq!(delivered, permits * 10, "every item must be delivered");
        assert!(
            peak.load(Ordering::SeqCst) <= permits,
            "peak concurrency {} exceeded permits {permits}",
            peak.load(Ordering::SeqCst)
        );
    }

    /// A panicking batch must not hang the fan-out and must surface as an error.
    #[test]
    fn panicking_task_returns_permit_and_surfaces_error() {
        let permits = 2usize;
        let batches = test_batches(4);
        let dim = crate::constants::embedding_dim();

        let work = move |batch: Vec<(usize, String)>| async move {
            if batch[0].0 == 1 {
                panic!("intentional test panic");
            }
            Ok(batch
                .into_iter()
                .map(|(i, _)| (i, dummy_vec(crate::constants::embedding_dim())))
                .collect())
        };

        let rt = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(2)
            .enable_all()
            .build()
            .expect("test runtime");
        let result = rt.block_on(run_bounded(
            batches,
            permits,
            dim,
            CancellationToken::new(),
            work,
            &mut |_idx, _v| Ok(()),
        ));

        let err = result.expect_err("panic must surface as an error");
        assert!(
            err.to_string().contains("panicked"),
            "error must mention the panic: {err}"
        );
    }

    /// Cancelling the token must end the fan-out long before the work would finish.
    #[test]
    fn cancellation_terminates_fan_out_quickly() {
        let permits = 2usize;
        let batches = test_batches(8);
        let dim = crate::constants::embedding_dim();
        let token = CancellationToken::new();

        let work = move |batch: Vec<(usize, String)>| async move {
            tokio::time::sleep(std::time::Duration::from_secs(30)).await;
            Ok(batch
                .into_iter()
                .map(|(i, _)| (i, dummy_vec(crate::constants::embedding_dim())))
                .collect())
        };

        let rt = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(2)
            .enable_all()
            .build()
            .expect("test runtime");
        let cancel = token.clone();
        let start = std::time::Instant::now();
        let result = rt.block_on(async move {
            tokio::spawn(async move {
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                cancel.cancel();
            });
            run_bounded(batches, permits, dim, token, work, &mut |_idx, _v| Ok(())).await
        });

        assert!(result.is_err(), "cancelled fan-out must report an error");
        assert!(
            start.elapsed() < std::time::Duration::from_secs(10),
            "graceful shutdown must finish well under the work duration"
        );
    }

    /// A vector with the wrong length must fail the whole fan-out.
    #[test]
    fn fan_out_rejects_divergent_dim() {
        let permits = 2usize;
        let batches = test_batches(2);
        let dim = crate::constants::embedding_dim();

        let work = move |batch: Vec<(usize, String)>| async move {
            Ok(batch
                .into_iter()
                .map(|(i, _)| (i, vec![0.0f32; 3]))
                .collect::<Vec<(usize, Vec<f32>)>>())
        };

        let rt = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(2)
            .enable_all()
            .build()
            .expect("test runtime");
        let result = rt.block_on(run_bounded(
            batches,
            permits,
            dim,
            CancellationToken::new(),
            work,
            &mut |_idx, _v| Ok(()),
        ));

        let err = result.expect_err("divergent dim must fail the fan-out");
        assert!(err.to_string().contains("G42/C5"), "error cites C5: {err}");
    }

    #[test]
    fn adaptive_batch_dim64_keeps_calibrated_sizes() {
        assert_eq!(adaptive_batch_for_dim(CHUNK_EMBED_BATCH_SIZE, 64), 8);
        assert_eq!(adaptive_batch_for_dim(ENTITY_EMBED_BATCH_SIZE, 64), 25);
    }

    #[test]
    fn adaptive_batch_dim384_shrinks() {
        assert_eq!(adaptive_batch_for_dim(CHUNK_EMBED_BATCH_SIZE, 384), 1);
        assert_eq!(adaptive_batch_for_dim(ENTITY_EMBED_BATCH_SIZE, 384), 4);
    }

    #[test]
    fn adaptive_batch_intermediate_dims() {
        assert_eq!(adaptive_batch_for_dim(8, 128), 4);
        assert_eq!(adaptive_batch_for_dim(8, 256), 2);
    }

    #[test]
    fn adaptive_batch_small_dim_clamps_to_base() {
        assert_eq!(adaptive_batch_for_dim(8, 8), 8);
    }

    /// Extreme and zero inputs still produce a batch size of at least 1.
    #[test]
    fn adaptive_batch_total_function() {
        assert_eq!(adaptive_batch_for_dim(8, 4096), 1);
        assert_eq!(adaptive_batch_for_dim(8, 0), 8);
        assert_eq!(adaptive_batch_for_dim(0, 64), 1);
    }

    #[test]
    #[serial_test::serial(env)]
    fn adaptive_wrappers_follow_env_dim() {
        std::env::set_var("SQLITE_GRAPHRAG_EMBEDDING_DIM", "384");
        let chunk = chunk_embed_batch_size();
        let entity = entity_embed_batch_size();
        // Restore global state before asserting so a failure cannot leak it.
        std::env::remove_var("SQLITE_GRAPHRAG_EMBEDDING_DIM");
        crate::constants::set_active_embedding_dim(crate::constants::DEFAULT_EMBEDDING_DIM);
        assert_eq!(chunk, 1, "384-dim chunk batch must shrink to 1 (G44)");
        assert_eq!(entity, 4, "384-dim entity batch must shrink to 4 (G44)");
    }

    #[test]
    fn fallback_reason_display_does_not_panic() {
        let _ = FallbackReason::EmbeddingFailed("rate limit".into()).to_string();
        let _ = FallbackReason::Cancelled.to_string();
        let _ = FallbackReason::Timeout {
            operation: "embed_query".into(),
            duration_secs: 30,
        }
        .to_string();
    }

    #[test]
    fn fallback_reason_is_partial_eq() {
        assert_eq!(
            FallbackReason::EmbeddingFailed("a".into()),
            FallbackReason::EmbeddingFailed("a".into())
        );
        assert_eq!(FallbackReason::Cancelled, FallbackReason::Cancelled);
        assert_ne!(
            FallbackReason::EmbeddingFailed("a".into()),
            FallbackReason::EmbeddingFailed("b".into())
        );
        assert_ne!(
            FallbackReason::Cancelled,
            FallbackReason::Timeout {
                operation: "x".into(),
                duration_secs: 1
            }
        );
    }

    #[test]
    fn fallback_reason_timeout_preserves_fields() {
        let r = FallbackReason::Timeout {
            operation: "embed_query_local".into(),
            duration_secs: 300,
        };
        match r {
            FallbackReason::Timeout {
                operation,
                duration_secs,
            } => {
                assert_eq!(operation, "embed_query_local");
                assert_eq!(duration_secs, 300);
            }
            other => panic!("expected Timeout, got {other:?}"),
        }
    }

    #[test]
    #[ignore = "G58 S1 stub: requires env without codex/claude on PATH; tracked as T5 of Fase 2"]
    fn try_embed_query_with_fallback_surfaces_embedding_failed_for_missing_binary() {
        let bogus = std::path::Path::new("/nonexistent-models-dir-for-g58-fallback-test");
        let result = try_embed_query_with_fallback(bogus, "hello world");
        match result {
            Err(FallbackReason::EmbeddingFailed(msg)) => {
                assert!(!msg.is_empty(), "fallback message must not be empty");
            }
            Err(FallbackReason::Cancelled) => {
                panic!("expected EmbeddingFailed, got Cancelled");
            }
            Err(FallbackReason::Timeout { .. }) => {
                panic!("expected EmbeddingFailed, got Timeout");
            }
            Ok(_) => {
                panic!("expected an error, got Ok — embedder must fail for bogus path");
            }
        }
    }

    #[test]
    fn g56_entity_cache_key_is_stable_and_distinct() {
        let k1 = entity_cache_key("codex:default", "sqlite-graphrag");
        let k2 = entity_cache_key("codex:default", "sqlite-graphrag");
        let k3 = entity_cache_key("codex:default", "claude-code");
        let k4 = entity_cache_key("claude:default", "sqlite-graphrag");
        assert_eq!(k1, k2, "same model+text must hash identically");
        assert_ne!(k1, k3, "different text must hash differently");
        assert_ne!(k1, k4, "different model must hash differently");
    }

    #[test]
    fn g56_entity_embed_cache_stats_hit_rate() {
        let zero = EmbedCacheStats::default();
        assert_eq!(zero.hit_rate(), 0.0);
        let half = EmbedCacheStats {
            requested: 4,
            hits: 2,
            misses: 2,
        };
        assert!((half.hit_rate() - 0.5).abs() < 1e-9);
        let all = EmbedCacheStats {
            requested: 7,
            hits: 7,
            misses: 0,
        };
        assert!((all.hit_rate() - 1.0).abs() < 1e-9);
    }

    /// A value inserted into the shared cache can be read back under the same key.
    #[test]
    fn g56_entity_embed_cache_populates_and_hits() {
        let cache = entity_embed_cache();
        let model = "test-model";
        let text = "sqlite-graphrag";
        let key = entity_cache_key(model, text);
        let stored = Arc::new(vec![0.42_f32; crate::constants::embedding_dim()]);
        cache.lock().insert(key, Arc::clone(&stored));
        let guard = cache.lock();
        let hit = guard.get(&key).expect("cache must return stored value");
        assert_eq!(hit.len(), crate::constants::embedding_dim());
        assert!((hit[0] - 0.42).abs() < 1e-6);
    }

    #[test]
    fn g56_empty_texts_short_circuits_with_zero_stats() {
        let stats = EmbedCacheStats::default();
        assert_eq!(stats.requested, 0);
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.hit_rate(), 0.0);
    }
}