1use crate::errors::AppError;
40use crate::extract::llm_embedding::LlmEmbedding;
41use parking_lot::Mutex;
42use std::path::Path;
43use std::sync::Arc;
44use std::sync::OnceLock;
45use tokio::sync::{mpsc, Semaphore};
46use tokio::task::JoinSet;
47use tokio_util::sync::CancellationToken;
48
49static CLAUDE_EMBEDDER: OnceLock<Mutex<LlmEmbedding>> = OnceLock::new();
59static EMBEDDER: OnceLock<Mutex<LlmEmbedding>> = OnceLock::new();
60
61static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
67
/// Base number of chunk texts sent per embedding batch, before the
/// dimension-based scaling done by `chunk_embed_batch_size`.
pub const CHUNK_EMBED_BATCH_SIZE: usize = 8;

/// Base number of entity texts sent per embedding batch, before the
/// dimension-based scaling done by `entity_embed_batch_size`.
pub const ENTITY_EMBED_BATCH_SIZE: usize = 25;
77
/// Embedding dimensionality at which the base batch sizes apply unchanged.
pub const EMBED_BATCH_CALIBRATION_DIM: usize = 64;

/// Scales a base batch size inversely with the embedding dimension.
///
/// Returns `base * EMBED_BATCH_CALIBRATION_DIM / dim`, clamped to `1..=base`.
/// A `base` of 0 is treated as 1 and a `dim` of 0 is treated as 1, so the
/// function is total and never returns 0.
fn adaptive_batch_for_dim(base: usize, dim: usize) -> usize {
    let base = base.max(1);
    // Widen to u128 so `base * EMBED_BATCH_CALIBRATION_DIM` cannot overflow
    // (the usize product panics in debug builds and wraps in release builds).
    let scaled = (base as u128 * EMBED_BATCH_CALIBRATION_DIM as u128) / dim.max(1) as u128;
    // After capping at `base` the value is guaranteed to fit in usize again.
    (scaled.min(base as u128) as usize).max(1)
}
92
93pub fn chunk_embed_batch_size() -> usize {
95 let dim = crate::constants::embedding_dim();
96 let batch = adaptive_batch_for_dim(CHUNK_EMBED_BATCH_SIZE, dim);
97 tracing::debug!(
98 dim,
99 base = CHUNK_EMBED_BATCH_SIZE,
100 batch,
101 "adaptive chunk batch size (G44)"
102 );
103 batch
104}
105
106pub fn entity_embed_batch_size() -> usize {
108 let dim = crate::constants::embedding_dim();
109 let batch = adaptive_batch_for_dim(ENTITY_EMBED_BATCH_SIZE, dim);
110 tracing::debug!(
111 dim,
112 base = ENTITY_EMBED_BATCH_SIZE,
113 batch,
114 "adaptive entity batch size (G44)"
115 );
116 batch
117}
118
119pub(crate) fn shared_runtime() -> Result<&'static tokio::runtime::Runtime, AppError> {
121 if let Some(rt) = RUNTIME.get() {
122 return Ok(rt);
123 }
124 let rt = tokio::runtime::Builder::new_multi_thread()
125 .worker_threads(2)
126 .enable_all()
127 .build()
128 .map_err(|e| AppError::Embedding(format!("tokio runtime init failed: {e}")))?;
129 let _ = RUNTIME.set(rt);
130 Ok(RUNTIME.get().expect("RUNTIME initialised above"))
131}
132
133pub fn get_embedder(_models_dir: &Path) -> Result<&'static Mutex<LlmEmbedding>, AppError> {
135 if let Some(e) = EMBEDDER.get() {
136 return Ok(e);
137 }
138 let backend = LlmEmbedding::detect_available()?;
139 let _ = EMBEDDER.set(Mutex::new(backend));
140 Ok(EMBEDDER.get().expect("EMBEDDER initialised above"))
141}
142
143pub fn get_claude_embedder(
148 claude_binary: Option<&Path>,
149 claude_model: Option<&str>,
150) -> Result<&'static Mutex<LlmEmbedding>, AppError> {
151 if let Some(e) = CLAUDE_EMBEDDER.get() {
152 return Ok(e);
153 }
154 let mut builder = LlmEmbedding::with_claude_builder();
155 if let Some(b) = claude_binary {
156 builder = builder.override_binary(b.to_path_buf());
157 }
158 if let Some(m) = claude_model {
159 builder = builder.override_model(m.to_string());
160 }
161 let backend = builder.build()?;
162 let _ = CLAUDE_EMBEDDER.set(Mutex::new(backend));
163 Ok(CLAUDE_EMBEDDER
164 .get()
165 .expect("CLAUDE_EMBEDDER initialised above"))
166}
167
168pub fn embed_via_claude_local(
172 _models_dir: &Path,
173 text: &str,
174 claude_binary: Option<&Path>,
175 claude_model: Option<&str>,
176) -> Result<Vec<f32>, AppError> {
177 let _slot_guard = acquire_llm_slot_for_embedding()?;
178 let embedder = get_claude_embedder(claude_binary, claude_model)?;
179 embed_passage(embedder, text)
180}
181
182pub fn embed_via_claude_local_resolved(
187 _models_dir: &Path,
188 text: &str,
189 claude_binary: Option<&Path>,
190 claude_model: Option<&str>,
191) -> Result<(Vec<f32>, LlmBackendKind), AppError> {
192 let _slot_guard = acquire_llm_slot_for_embedding()?;
193 let embedder = get_claude_embedder(claude_binary, claude_model)?;
194 let v = embed_passage(embedder, text)?;
195 Ok((v, LlmBackendKind::Claude))
196}
197fn clone_client(embedder: &Mutex<LlmEmbedding>) -> LlmEmbedding {
200 embedder.lock().clone()
201}
202
203pub fn embed_passage(embedder: &Mutex<LlmEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
207 let client = clone_client(embedder);
208 let result = client.embed_passage(text)?;
209 validate_dim(result)
210}
211
212pub fn embed_query(embedder: &Mutex<LlmEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
216 let client = clone_client(embedder);
217 let result = client.embed_query(text)?;
218 validate_dim(result)
219}
220
221pub fn embed_passages_controlled(
226 embedder: &Mutex<LlmEmbedding>,
227 texts: &[&str],
228 _token_counts: &[usize],
229) -> Result<Vec<Vec<f32>>, AppError> {
230 if texts.is_empty() {
231 return Ok(Vec::new());
232 }
233 let owned: Vec<String> = texts.iter().map(|t| t.to_string()).collect();
234 embed_texts_parallel(embedder, &owned, 1, chunk_embed_batch_size())
235}
236
237pub fn embed_passage_local(models_dir: &Path, text: &str) -> Result<Vec<f32>, AppError> {
238 let _slot_guard = acquire_llm_slot_for_embedding()?;
239 let embedder = get_embedder(models_dir)?;
240 embed_passage(embedder, text)
241}
242
243pub fn embed_passage_local_resolved(
249 models_dir: &Path,
250 text: &str,
251) -> Result<(Vec<f32>, LlmBackendKind), AppError> {
252 let _slot_guard = acquire_llm_slot_for_embedding()?;
253 let embedder = get_embedder(models_dir)?;
254 let v = embed_passage(embedder, text)?;
255 let kind = match embedder.lock().flavour() {
256 crate::extract::llm_embedding::EmbeddingFlavour::Codex => LlmBackendKind::Codex,
257 crate::extract::llm_embedding::EmbeddingFlavour::Claude => LlmBackendKind::Claude,
258 };
259 Ok((v, kind))
260}
261
262pub fn embed_query_local(models_dir: &Path, text: &str) -> Result<Vec<f32>, AppError> {
263 let _slot_guard = acquire_llm_slot_for_embedding()?;
264 let embedder = get_embedder(models_dir)?;
265 embed_query(embedder, text)
266}
267
268pub fn embed_passage_with_choice(
285 models_dir: &Path,
286 text: &str,
287 choice: Option<crate::cli::LlmBackendChoice>,
288) -> Result<(Vec<f32>, LlmBackendKind), AppError> {
289 let _slot_guard = acquire_llm_slot_for_embedding()?;
290 match choice {
291 None => {
292 let embedder = get_embedder(models_dir)?;
293 embed_passage(embedder, text).map(|v| (v, LlmBackendKind::None))
294 }
295 Some(choice) => embed_with_fallback(models_dir, text, &choice.to_chain(), false),
296 }
297}
298pub fn try_embed_query_with_choice(
304 models_dir: &Path,
305 text: &str,
306 choice: Option<crate::cli::LlmBackendChoice>,
307) -> Result<(Vec<f32>, LlmBackendKind), FallbackReason> {
308 match embed_passage_with_choice(models_dir, text, choice) {
309 Ok((v, _backend)) if v.is_empty() => Err(FallbackReason::DimZero),
322 Ok((v, backend)) => Ok((v, backend)),
323 Err(e) => Err(classify_embedding_error(e)),
324 }
325}
326fn acquire_llm_slot_for_embedding() -> Result<crate::llm_slots::LlmSlotGuard, AppError> {
338 use crate::constants::{CLI_LOCK_DEFAULT_WAIT_SECS, LLM_WORKER_RSS_MB};
339 let max = std::env::var("SQLITE_GRAPHRAG_LLM_MAX_HOST_CONCURRENCY")
340 .ok()
341 .and_then(|s| s.parse::<u32>().ok())
342 .filter(|n| *n >= 1)
343 .unwrap_or_else(crate::llm_slots::default_max_concurrency);
344 let wait_secs = if std::env::var("SQLITE_GRAPHRAG_LLM_SLOT_NO_WAIT").is_ok() {
345 0
346 } else {
347 std::env::var("SQLITE_GRAPHRAG_LLM_SLOT_WAIT_SECS")
348 .ok()
349 .and_then(|s| s.parse::<u64>().ok())
350 .unwrap_or(CLI_LOCK_DEFAULT_WAIT_SECS)
351 };
352 let _ = LLM_WORKER_RSS_MB; match crate::llm_slots::acquire_llm_slot(max, wait_secs) {
360 Ok(guard) => Ok(guard),
361 Err(e @ AppError::LockBusy { .. }) if wait_secs > 0 => Err(AppError::Embedding(format!(
362 "slot exhausted: {e} (fall back to FTS5)"
363 ))),
364 Err(e) => Err(e),
365 }
366}
/// Why an embedding attempt did not yield a usable vector.
#[derive(Debug, Clone, PartialEq)]
pub enum FallbackReason {
    /// Generic failure, carrying the underlying error message.
    EmbeddingFailed(String),
    /// No LLM slot could be acquired.
    SlotExhausted,
    /// The named backend reported an OAuth or quota problem.
    OAuthQuota { backend: &'static str },
    /// A different backend was invoked than the one requested.
    BackendMismatch {
        requested: &'static str,
        resolved: &'static str,
    },
    /// The backend returned an empty vector.
    DimZero,
    /// The embedding was cancelled.
    Cancelled,
    /// The named operation timed out after `duration_secs` seconds.
    Timeout {
        operation: String,
        duration_secs: u64,
    },
}

impl FallbackReason {
    /// Stable snake_case identifier for the variant.
    pub fn reason_code(&self) -> &'static str {
        let code = match self {
            FallbackReason::EmbeddingFailed(_) => "embedding_failed",
            FallbackReason::SlotExhausted => "slot_exhausted",
            FallbackReason::OAuthQuota { .. } => "oauth_quota",
            FallbackReason::BackendMismatch { .. } => "backend_mismatch",
            FallbackReason::DimZero => "dim_zero",
            FallbackReason::Cancelled => "cancelled",
            FallbackReason::Timeout { .. } => "timeout",
        };
        code
    }
}

impl std::fmt::Display for FallbackReason {
    /// Human-readable description of the reason.
    // NOTE(review): the `SlotExhausted` text hard-codes "max=8" and
    // "750ms". Confirm these still match the configured limits.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            FallbackReason::EmbeddingFailed(msg) => write!(f, "embedding failed: {msg}"),
            FallbackReason::SlotExhausted => f.write_str(
                "slot exhausted: failed to acquire LLM slot after backoff window (max=8 concurrent, total backoff=750ms)",
            ),
            FallbackReason::OAuthQuota { backend } => {
                write!(f, "OAuth usage quota exhausted on backend '{backend}'")
            }
            FallbackReason::BackendMismatch {
                requested,
                resolved,
            } => write!(
                f,
                "backend mismatch: user requested '{requested}' but '{resolved}' was invoked"
            ),
            FallbackReason::DimZero => f.write_str("embedding returned zero-dimensional vector"),
            FallbackReason::Cancelled => f.write_str("embedding cancelled by external signal"),
            FallbackReason::Timeout {
                operation,
                duration_secs,
            } => write!(
                f,
                "embedding timed out after {duration_secs}s during {operation}"
            ),
        }
    }
}

impl std::error::Error for FallbackReason {}
463
464pub fn try_embed_query_with_fallback(
472 models_dir: &Path,
473 query: &str,
474) -> Result<(Vec<f32>, LlmBackendKind), FallbackReason> {
475 match embed_query_local(models_dir, query) {
476 Ok(v) => Ok((v, LlmBackendKind::None)),
477 Err(e) => Err(classify_embedding_error(e)),
478 }
479}
480
481pub fn try_embed_query_with_deterministic_fallback(
490 models_dir: &Path,
491 query: &str,
492 choice: Option<crate::cli::LlmBackendChoice>,
493) -> Result<(Vec<f32>, LlmBackendKind), FallbackReason> {
494 match try_embed_query_with_choice(models_dir, query, choice) {
495 Ok(t) => Ok(t),
496 Err(reason @ FallbackReason::OAuthQuota { backend }) => {
497 let alt = match backend {
498 "codex" => Some(crate::cli::LlmBackendChoice::Claude),
499 "claude" => Some(crate::cli::LlmBackendChoice::Codex),
500 _ => None,
501 };
502 if let Some(alt_choice) = alt {
503 try_embed_query_with_choice(models_dir, query, Some(alt_choice))
504 } else {
505 Err(reason)
506 }
507 }
508 Err(reason @ FallbackReason::SlotExhausted) => {
509 std::thread::sleep(std::time::Duration::from_millis(750));
510 try_embed_query_with_choice(models_dir, query, choice).or(Err(reason))
511 }
512 Err(other) => Err(other),
513 }
514}
515
516pub fn classify_embedding_error(err: AppError) -> FallbackReason {
524 match err {
525 AppError::Embedding(msg) if msg.contains("cancelled") => FallbackReason::Cancelled,
526 AppError::Embedding(msg) if msg.contains("slot exhausted") => FallbackReason::SlotExhausted,
527 AppError::Embedding(msg) if msg.contains("OAuth") || msg.contains("quota") => {
528 let backend = if msg.contains("codex") {
529 "codex"
530 } else if msg.contains("claude") || msg.contains("anthropic-ratelimit") {
531 "claude"
536 } else {
537 "unknown"
538 };
539 FallbackReason::OAuthQuota { backend }
540 }
541 AppError::Embedding(msg) if msg.contains("backend mismatch") => {
542 let (requested, resolved) =
548 if msg.contains("requested claude") && msg.contains("but codex") {
549 ("claude", "codex")
550 } else if msg.contains("requested codex") && msg.contains("but claude") {
551 ("codex", "claude")
552 } else if msg.contains("requested claude") {
553 ("claude", "unknown")
554 } else if msg.contains("requested codex") {
555 ("codex", "unknown")
556 } else {
557 ("unknown", "unknown")
558 };
559 FallbackReason::BackendMismatch {
560 requested,
561 resolved,
562 }
563 }
564 AppError::Embedding(msg) if msg.contains("dim") && msg.contains("zero") => {
565 FallbackReason::DimZero
566 }
567 AppError::Timeout {
568 operation,
569 duration_secs,
570 } => FallbackReason::Timeout {
571 operation,
572 duration_secs,
573 },
574 AppError::Embedding(msg) => FallbackReason::EmbeddingFailed(msg),
575 e => FallbackReason::EmbeddingFailed(e.to_string()),
576 }
577}
578pub fn embed_with_fallback(
597 models_dir: &Path,
598 text: &str,
599 chain: &[LlmBackendKind],
600 skip_on_failure: bool,
601) -> Result<(Vec<f32>, LlmBackendKind), AppError> {
602 use crate::llm::exit_code_hints::LlmBackendError;
603 let effective: Vec<LlmBackendKind> = if chain.is_empty() {
604 vec![
605 LlmBackendKind::Codex,
606 LlmBackendKind::Claude,
607 LlmBackendKind::None,
608 ]
609 } else {
610 chain.to_vec()
611 };
612
613 let mut last_err: Option<AppError> = None;
614 for backend in &effective {
615 match embed_via_backend(models_dir, text, backend) {
621 Ok((v, resolved_kind)) => return Ok((v, resolved_kind)),
622 Err(e) => {
623 tracing::warn!(
624 target: "embedding",
625 backend = ?backend,
626 error = %e,
627 "embed_with_fallback: backend failed, trying next"
628 );
629 last_err = Some(e);
630 }
631 }
632 }
633 if skip_on_failure {
634 return Ok((Vec::new(), LlmBackendKind::None));
637 }
638 Err(last_err
639 .unwrap_or_else(|| AppError::Embedding(LlmBackendError::NoBackendsAvailable.to_string())))
640}
641
/// Identifies which backend produced, or should produce, an embedding.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LlmBackendKind {
    /// The Codex backend.
    Codex,
    /// The Claude backend.
    Claude,
    /// No backend.
    None,
}

impl LlmBackendKind {
    /// Lowercase label for the backend: "codex", "claude" or "none".
    pub fn as_str(self) -> &'static str {
        let label = match self {
            LlmBackendKind::Codex => "codex",
            LlmBackendKind::Claude => "claude",
            LlmBackendKind::None => "none",
        };
        label
    }
}
666
667pub fn embed_via_backend(
682 models_dir: &Path,
683 text: &str,
684 backend: &LlmBackendKind,
685) -> Result<(Vec<f32>, LlmBackendKind), AppError> {
686 match backend {
687 LlmBackendKind::None => Ok((Vec::new(), LlmBackendKind::None)),
688 LlmBackendKind::Codex => embed_passage_local_resolved(models_dir, text),
689 LlmBackendKind::Claude => {
690 tracing::debug!(
694 target: "embedder",
695 backend = "claude",
696 "embed_via_backend: forcing claude (ADR-0042 / GAP-002 fix)"
697 );
698 embed_via_claude_local_resolved(models_dir, text, None, None)
699 }
700 }
701}
702
703pub fn embed_via_backend_legacy(
708 models_dir: &Path,
709 text: &str,
710 backend: &LlmBackendKind,
711) -> Result<Vec<f32>, AppError> {
712 embed_via_backend(models_dir, text, backend).map(|(v, _)| v)
713}
714
715pub fn embed_passages_controlled_local(
716 models_dir: &Path,
717 texts: &[&str],
718 token_counts: &[usize],
719) -> Result<Vec<Vec<f32>>, AppError> {
720 let embedder = get_embedder(models_dir)?;
721 embed_passages_controlled(embedder, texts, token_counts)
722}
723
724pub fn embed_passages_parallel_local(
727 models_dir: &Path,
728 texts: &[String],
729 parallelism: usize,
730 batch_size: usize,
731) -> Result<Vec<Vec<f32>>, AppError> {
732 let embedder = get_embedder(models_dir)?;
733 embed_texts_parallel(embedder, texts, parallelism, batch_size)
734}
735
736type EntityEmbedCacheMap = std::collections::HashMap<u64, Arc<Vec<f32>>>;
748
749static ENTITY_EMBED_CACHE: OnceLock<parking_lot::Mutex<EntityEmbedCacheMap>> = OnceLock::new();
750
751fn entity_embed_cache() -> &'static parking_lot::Mutex<EntityEmbedCacheMap> {
752 ENTITY_EMBED_CACHE.get_or_init(|| parking_lot::Mutex::new(std::collections::HashMap::new()))
753}
754
755fn entity_cache_key(model: &str, text: &str) -> u64 {
756 let mut hasher = blake3::Hasher::new();
757 hasher.update(model.as_bytes());
758 hasher.update(b"\0");
759 hasher.update(text.as_bytes());
760 let h = hasher.finalize();
761 let bytes = h.as_bytes();
762 u64::from_le_bytes([
763 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
764 ])
765}
766
767pub fn embed_entity_texts_cached(
777 models_dir: &Path,
778 texts: &[String],
779 parallelism: usize,
780) -> Result<(Vec<Vec<f32>>, EmbedCacheStats), AppError> {
781 if texts.is_empty() {
782 return Ok((Vec::new(), EmbedCacheStats::default()));
783 }
784 let embedder = get_embedder(models_dir)?;
785 let model = embedder.lock().model_label();
786 let cache = entity_embed_cache();
787 let mut hits: Vec<Option<Arc<Vec<f32>>>> = vec![None; texts.len()];
788 let mut miss_indices: Vec<usize> = Vec::with_capacity(texts.len());
789 {
790 let guard = cache.lock();
791 for (i, text) in texts.iter().enumerate() {
792 let key = entity_cache_key(&model, text);
793 if let Some(v) = guard.get(&key) {
794 hits[i] = Some(Arc::clone(v));
795 } else {
796 miss_indices.push(i);
797 }
798 }
799 }
800 let miss_count = miss_indices.len();
801 if miss_count > 0 {
802 let miss_texts: Vec<String> = miss_indices.iter().map(|&i| texts[i].clone()).collect();
803 let miss_vecs = embed_texts_parallel(
804 embedder,
805 &miss_texts,
806 parallelism,
807 entity_embed_batch_size(),
808 )?;
809 let mut guard = cache.lock();
810 for (slot, &orig_idx) in miss_indices.iter().enumerate() {
811 let vec = Arc::new(miss_vecs[slot].clone());
812 let key = entity_cache_key(&model, &texts[orig_idx]);
813 guard.insert(key, Arc::clone(&vec));
814 hits[orig_idx] = Some(vec);
815 }
816 }
817 let mut out = Vec::with_capacity(texts.len());
818 for hit in hits.into_iter() {
819 let v = hit.ok_or_else(|| {
820 AppError::Embedding("entity embed cache produced null result".to_string())
821 })?;
822 out.push((*v).clone());
823 }
824 Ok((
825 out,
826 EmbedCacheStats {
827 requested: texts.len(),
828 hits: texts.len() - miss_count,
829 misses: miss_count,
830 },
831 ))
832}
833
834#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, serde::Serialize)]
836pub struct EmbedCacheStats {
837 pub requested: usize,
838 pub hits: usize,
839 pub misses: usize,
840}
841
842impl EmbedCacheStats {
843 pub fn hit_rate(&self) -> f64 {
845 if self.requested == 0 {
846 0.0
847 } else {
848 self.hits as f64 / self.requested as f64
849 }
850 }
851}
852
853pub fn embed_texts_parallel(
866 embedder: &Mutex<LlmEmbedding>,
867 texts: &[String],
868 parallelism: usize,
869 batch_size: usize,
870) -> Result<Vec<Vec<f32>>, AppError> {
871 let mut slots: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
872 embed_texts_parallel_with(embedder, texts, parallelism, batch_size, |idx, v| {
873 slots[idx] = Some(v.to_vec());
874 Ok(())
875 })?;
876 let mut out = Vec::with_capacity(slots.len());
877 for (idx, slot) in slots.into_iter().enumerate() {
878 out.push(slot.ok_or_else(|| {
879 AppError::Embedding(format!("embedding fan-out lost item index {idx}"))
880 })?);
881 }
882 Ok(out)
883}
884
885pub fn embed_texts_parallel_with(
889 embedder: &Mutex<LlmEmbedding>,
890 texts: &[String],
891 parallelism: usize,
892 batch_size: usize,
893 mut on_result: impl FnMut(usize, &[f32]) -> Result<(), AppError>,
894) -> Result<(), AppError> {
895 if texts.is_empty() {
896 return Ok(());
897 }
898 let dim = crate::constants::embedding_dim();
899 if texts.len() == 1 {
900 let v = embed_passage(embedder, &texts[0])?;
901 return on_result(0, &v);
902 }
903
904 let client = clone_client(embedder);
905 let permits = effective_permits(parallelism);
906 let batches = build_batches(texts, batch_size.max(1));
907 let token = crate::cancel_token().clone();
908
909 let work = move |batch: Vec<(usize, String)>| {
910 let client = client.clone();
911 async move {
912 client
913 .embed_batch_async(crate::constants::PASSAGE_PREFIX, &batch)
914 .await
915 }
916 };
917
918 let fan_out = run_bounded(batches, permits, dim, token, work, &mut on_result);
919 match tokio::runtime::Handle::try_current() {
920 Ok(handle) => tokio::task::block_in_place(|| handle.block_on(fan_out)),
921 Err(_) => shared_runtime()?.block_on(fan_out),
922 }
923}
924
/// Splits `texts` into consecutive batches of at most `batch_size` items,
/// pairing every text with its index in the original slice.
///
/// A `batch_size` of 0 is treated as 1, because `chunks(0)` would panic.
/// Each text is cloned exactly once.
fn build_batches(texts: &[String], batch_size: usize) -> Vec<Vec<(usize, String)>> {
    let batch_size = batch_size.max(1);
    texts
        .chunks(batch_size)
        .enumerate()
        .map(|(chunk_idx, chunk)| {
            // Global index of the first item in this chunk.
            let base = chunk_idx * batch_size;
            chunk
                .iter()
                .enumerate()
                .map(|(offset, text)| (base + offset, text.clone()))
                .collect()
        })
        .collect()
}
936
937pub fn effective_permits(requested: usize) -> usize {
942 let cpus = std::thread::available_parallelism()
943 .map(|n| n.get())
944 .unwrap_or(4);
945 let by_ram = ((crate::memory_guard::available_memory_mb() / 2)
946 / crate::constants::LLM_WORKER_RSS_MB)
947 .max(1) as usize;
948 requested.clamp(1, 32).min(cpus).min(by_ram).max(1)
949}
950
951async fn run_bounded<F, Fut>(
961 batches: Vec<Vec<(usize, String)>>,
962 permits: usize,
963 dim: usize,
964 token: CancellationToken,
965 work: F,
966 on_result: &mut impl FnMut(usize, &[f32]) -> Result<(), AppError>,
967) -> Result<(), AppError>
968where
969 F: Fn(Vec<(usize, String)>) -> Fut + Clone + Send + 'static,
970 Fut: std::future::Future<Output = Result<Vec<(usize, Vec<f32>)>, AppError>> + Send,
971{
972 let total_batches = batches.len();
973 let semaphore = Arc::new(Semaphore::new(permits));
974 let (tx, mut rx) = mpsc::channel::<Result<Vec<(usize, Vec<f32>)>, AppError>>(permits * 2);
977 let mut set: JoinSet<()> = JoinSet::new();
978
979 for (batch_idx, batch) in batches.into_iter().enumerate() {
980 let sem = Arc::clone(&semaphore);
981 let token = token.clone();
982 let tx = tx.clone();
983 let work = work.clone();
984 set.spawn(async move {
985 let wait_start = std::time::Instant::now();
986 let Ok(_permit) = sem.acquire_owned().await else {
989 let _ = tx
990 .send(Err(AppError::Embedding("semaphore closed".to_string())))
991 .await;
992 return;
993 };
994 let permit_wait_ms = wait_start.elapsed().as_millis() as u64;
995 let work_start = std::time::Instant::now();
996 let outcome = if crate::should_obey_shutdown() {
1002 tokio::select! {
1003 res = work(batch) => res,
1004 _ = token.cancelled() => Err(AppError::Embedding(
1005 "embedding cancelled by shutdown signal".to_string(),
1006 )),
1007 }
1008 } else {
1009 work(batch).await
1010 };
1011 tracing::debug!(
1013 target: "embedding",
1014 batch_idx,
1015 permit_wait_ms,
1016 work_ms = work_start.elapsed().as_millis() as u64,
1017 ok = outcome.is_ok(),
1018 "embedding batch finished"
1019 );
1020 let _ = tx.send(outcome).await;
1021 });
1022 }
1023 drop(tx);
1024
1025 let mut completed = 0usize;
1026 let mut failed = 0usize;
1027 let mut cancelled = 0usize;
1028 let mut first_error: Option<AppError> = None;
1029
1030 while let Some(message) = rx.recv().await {
1031 match message {
1032 Ok(items) => {
1033 completed += 1;
1034 if first_error.is_none() {
1035 for (idx, v) in items {
1036 if v.len() != dim {
1037 first_error = Some(AppError::Embedding(format!(
1038 "LLM returned {} dims for item {idx}, expected {dim}; \
1039 refusing to truncate or pad silently (G42/C5)",
1040 v.len()
1041 )));
1042 break;
1043 }
1044 if let Err(e) = on_result(idx, &v) {
1045 first_error = Some(e);
1046 break;
1047 }
1048 }
1049 if first_error.is_some() {
1050 set.shutdown().await;
1053 }
1054 }
1055 }
1056 Err(e) => {
1057 if matches!(&e, AppError::Embedding(msg) if msg.contains("cancelled")) {
1058 cancelled += 1;
1059 } else {
1060 failed += 1;
1061 }
1062 if first_error.is_none() {
1063 first_error = Some(e);
1064 set.shutdown().await;
1065 }
1066 }
1067 }
1068 }
1069
1070 while let Some(join_result) = set.join_next().await {
1073 if let Err(join_err) = join_result {
1074 if join_err.is_panic() {
1075 failed += 1;
1076 if first_error.is_none() {
1077 first_error = Some(AppError::Embedding(format!(
1078 "embedding task panicked: {join_err}"
1079 )));
1080 }
1081 } else {
1082 cancelled += 1;
1083 }
1084 }
1085 }
1086
1087 tracing::debug!(
1097 target: "embedding",
1098 total_batches,
1099 completed,
1100 failed,
1101 cancelled,
1102 "embedding fan-out finished"
1103 );
1104
1105 match first_error {
1106 Some(e) => Err(e),
1107 None => Ok(()),
1108 }
1109}
1110
/// Serialises a slice of `f32` values as consecutive little-endian 4-byte
/// groups.
pub fn f32_to_bytes(v: &[f32]) -> Vec<u8> {
    let mut encoded = Vec::with_capacity(v.len() * 4);
    encoded.extend(v.iter().flat_map(|value| value.to_le_bytes()));
    encoded
}
1118
/// Decodes consecutive little-endian 4-byte groups into `f32` values.
///
/// Trailing bytes that do not form a complete group of four are ignored.
pub fn bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|quad| f32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]))
        .collect()
}
1126
1127pub fn embedding_dim() -> usize {
1130 crate::constants::embedding_dim()
1131}
1132
1133fn validate_dim(v: Vec<f32>) -> Result<Vec<f32>, AppError> {
1137 let dim = crate::constants::embedding_dim();
1138 if v.len() != dim {
1139 return Err(AppError::Embedding(format!(
1140 "embedding has {} dims, expected {dim}; \
1141 refusing to truncate or pad silently (G42/C5)",
1142 v.len()
1143 )));
1144 }
1145 Ok(v)
1146}
1147
1148#[cfg(test)]
1149mod tests {
1150 use super::*;
1151 use std::sync::atomic::{AtomicUsize, Ordering};
1152
1153 #[test]
1154 fn f32_to_bytes_roundtrip() {
1155 let input = vec![0.0_f32, 1.5, -2.25, f32::MIN, f32::MAX];
1156 let bytes = f32_to_bytes(&input);
1157 assert_eq!(bytes.len(), input.len() * 4);
1158 let out = bytes_to_f32(&bytes);
1159 assert_eq!(out, input);
1160 }
1161
1162 #[test]
1163 fn validate_dim_rejects_divergent_vectors() {
1164 let dim = crate::constants::embedding_dim();
1167 let long = vec![0.0; dim + 10];
1168 assert!(validate_dim(long).is_err(), "longer vector must error");
1169 let short = vec![0.0; dim.saturating_sub(1).max(1)];
1170 assert!(validate_dim(short).is_err(), "shorter vector must error");
1171 let exact = vec![0.0; dim];
1172 assert_eq!(validate_dim(exact).expect("exact dim must pass").len(), dim);
1173 }
1174
1175 #[test]
1176 fn embedding_dim_matches_constants_source() {
1177 assert_eq!(embedding_dim(), crate::constants::embedding_dim());
1178 }
1179
1180 #[test]
1181 fn build_batches_preserves_global_indices() {
1182 let texts: Vec<String> = (0..10).map(|i| format!("t{i}")).collect();
1183 let batches = build_batches(&texts, 4);
1184 assert_eq!(batches.len(), 3);
1185 assert_eq!(batches[0].len(), 4);
1186 assert_eq!(batches[2].len(), 2);
1187 assert_eq!(batches[2][1].0, 9);
1188 assert_eq!(batches[2][1].1, "t9");
1189 }
1190
1191 #[test]
1192 fn effective_permits_clamps_to_bounds() {
1193 assert!(effective_permits(0) >= 1);
1194 assert!(effective_permits(1000) <= 32);
1195 }
1196
1197 fn test_batches(n: usize) -> Vec<Vec<(usize, String)>> {
1198 (0..n).map(|i| vec![(i, format!("t{i}"))]).collect()
1199 }
1200
1201 fn dummy_vec(dim: usize) -> Vec<f32> {
1202 vec![0.0; dim]
1203 }
1204
1205 #[test]
1208 fn concurrency_peak_never_exceeds_permits() {
1209 let permits = 4usize;
1210 let batches = test_batches(permits * 10);
1211 let dim = crate::constants::embedding_dim();
1212 let current = Arc::new(AtomicUsize::new(0));
1213 let peak = Arc::new(AtomicUsize::new(0));
1214
1215 let current_c = Arc::clone(¤t);
1216 let peak_c = Arc::clone(&peak);
1217 let work = move |batch: Vec<(usize, String)>| {
1218 let current = Arc::clone(¤t_c);
1219 let peak = Arc::clone(&peak_c);
1220 async move {
1221 let now = current.fetch_add(1, Ordering::SeqCst) + 1;
1222 peak.fetch_max(now, Ordering::SeqCst);
1223 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
1224 current.fetch_sub(1, Ordering::SeqCst);
1225 Ok(batch
1226 .into_iter()
1227 .map(|(i, _)| (i, dummy_vec(crate::constants::embedding_dim())))
1228 .collect())
1229 }
1230 };
1231
1232 let mut delivered = 0usize;
1233 let rt = tokio::runtime::Builder::new_multi_thread()
1234 .worker_threads(4)
1235 .enable_all()
1236 .build()
1237 .expect("test runtime");
1238 rt.block_on(run_bounded(
1239 batches,
1240 permits,
1241 dim,
1242 CancellationToken::new(),
1243 work,
1244 &mut |_idx, _v| {
1245 delivered += 1;
1246 Ok(())
1247 },
1248 ))
1249 .expect("fan-out must succeed");
1250
1251 assert_eq!(delivered, permits * 10, "every item must be delivered");
1252 assert!(
1253 peak.load(Ordering::SeqCst) <= permits,
1254 "peak concurrency {} exceeded permits {permits}",
1255 peak.load(Ordering::SeqCst)
1256 );
1257 }
1258
1259 #[test]
1262 fn panicking_task_returns_permit_and_surfaces_error() {
1263 let permits = 2usize;
1264 let batches = test_batches(4);
1265 let dim = crate::constants::embedding_dim();
1266
1267 let work = move |batch: Vec<(usize, String)>| async move {
1268 if batch[0].0 == 1 {
1269 panic!("intentional test panic");
1270 }
1271 Ok(batch
1272 .into_iter()
1273 .map(|(i, _)| (i, dummy_vec(crate::constants::embedding_dim())))
1274 .collect())
1275 };
1276
1277 let rt = tokio::runtime::Builder::new_multi_thread()
1278 .worker_threads(2)
1279 .enable_all()
1280 .build()
1281 .expect("test runtime");
1282 let result = rt.block_on(run_bounded(
1283 batches,
1284 permits,
1285 dim,
1286 CancellationToken::new(),
1287 work,
1288 &mut |_idx, _v| Ok(()),
1289 ));
1290
1291 let err = result.expect_err("panic must surface as an error");
1292 assert!(
1293 err.to_string().contains("panicked"),
1294 "error must mention the panic: {err}"
1295 );
1296 }
1297
1298 #[test]
1301 fn cancellation_terminates_fan_out_quickly() {
1302 let permits = 2usize;
1303 let batches = test_batches(8);
1304 let dim = crate::constants::embedding_dim();
1305 let token = CancellationToken::new();
1306
1307 let work = move |batch: Vec<(usize, String)>| async move {
1308 tokio::time::sleep(std::time::Duration::from_secs(30)).await;
1310 Ok(batch
1311 .into_iter()
1312 .map(|(i, _)| (i, dummy_vec(crate::constants::embedding_dim())))
1313 .collect())
1314 };
1315
1316 let rt = tokio::runtime::Builder::new_multi_thread()
1317 .worker_threads(2)
1318 .enable_all()
1319 .build()
1320 .expect("test runtime");
1321 let cancel = token.clone();
1322 let start = std::time::Instant::now();
1323 let result = rt.block_on(async move {
1324 tokio::spawn(async move {
1325 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1326 cancel.cancel();
1327 });
1328 run_bounded(batches, permits, dim, token, work, &mut |_idx, _v| Ok(())).await
1329 });
1330
1331 assert!(result.is_err(), "cancelled fan-out must report an error");
1332 assert!(
1333 start.elapsed() < std::time::Duration::from_secs(10),
1334 "graceful shutdown must finish well under the work duration"
1335 );
1336 }
1337
1338 #[test]
1341 fn fan_out_rejects_divergent_dim() {
1342 let permits = 2usize;
1343 let batches = test_batches(2);
1344 let dim = crate::constants::embedding_dim();
1345
1346 let work = move |batch: Vec<(usize, String)>| async move {
1347 Ok(batch
1348 .into_iter()
1349 .map(|(i, _)| (i, vec![0.0f32; 3]))
1350 .collect::<Vec<(usize, Vec<f32>)>>())
1351 };
1352
1353 let rt = tokio::runtime::Builder::new_multi_thread()
1354 .worker_threads(2)
1355 .enable_all()
1356 .build()
1357 .expect("test runtime");
1358 let result = rt.block_on(run_bounded(
1359 batches,
1360 permits,
1361 dim,
1362 CancellationToken::new(),
1363 work,
1364 &mut |_idx, _v| Ok(()),
1365 ));
1366
1367 let err = result.expect_err("divergent dim must fail the fan-out");
1368 assert!(err.to_string().contains("G42/C5"), "error cites C5: {err}");
1369 }
1370
1371 #[test]
1373 fn adaptive_batch_dim64_keeps_calibrated_sizes() {
1374 assert_eq!(adaptive_batch_for_dim(CHUNK_EMBED_BATCH_SIZE, 64), 8);
1375 assert_eq!(adaptive_batch_for_dim(ENTITY_EMBED_BATCH_SIZE, 64), 25);
1376 }
1377
1378 #[test]
1380 fn adaptive_batch_dim384_shrinks() {
1381 assert_eq!(adaptive_batch_for_dim(CHUNK_EMBED_BATCH_SIZE, 384), 1);
1382 assert_eq!(adaptive_batch_for_dim(ENTITY_EMBED_BATCH_SIZE, 384), 4);
1383 }
1384
1385 #[test]
1387 fn adaptive_batch_intermediate_dims() {
1388 assert_eq!(adaptive_batch_for_dim(8, 128), 4);
1389 assert_eq!(adaptive_batch_for_dim(8, 256), 2);
1390 }
1391
1392 #[test]
1394 fn adaptive_batch_small_dim_clamps_to_base() {
1395 assert_eq!(adaptive_batch_for_dim(8, 8), 8);
1396 }
1397
1398 #[test]
1400 fn adaptive_batch_total_function() {
1401 assert_eq!(adaptive_batch_for_dim(8, 4096), 1);
1402 assert_eq!(adaptive_batch_for_dim(8, 0), 8);
1403 assert_eq!(adaptive_batch_for_dim(0, 64), 1);
1404 }
1405
1406 #[test]
1408 #[serial_test::serial(env)]
1409 fn adaptive_wrappers_follow_env_dim() {
1410 std::env::set_var("SQLITE_GRAPHRAG_EMBEDDING_DIM", "384");
1411 let chunk = chunk_embed_batch_size();
1412 let entity = entity_embed_batch_size();
1413 std::env::remove_var("SQLITE_GRAPHRAG_EMBEDDING_DIM");
1414 crate::constants::set_active_embedding_dim(crate::constants::DEFAULT_EMBEDDING_DIM);
1415 assert_eq!(chunk, 1, "384-dim chunk batch must shrink to 1 (G44)");
1416 assert_eq!(entity, 4, "384-dim entity batch must shrink to 4 (G44)");
1417 }
1418
/// Smoke test: rendering these `FallbackReason` variants through `to_string`
/// must not panic. The rendered text itself is not inspected.
#[test]
fn fallback_reason_display_does_not_panic() {
    let reasons = [
        FallbackReason::EmbeddingFailed("rate limit".into()),
        FallbackReason::Cancelled,
        FallbackReason::Timeout {
            operation: "embed_query".into(),
            duration_secs: 30,
        },
    ];

    for reason in &reasons {
        let _ = reason.to_string();
    }
}
1434
/// `FallbackReason` equality compares both the variant and its payload.
#[test]
fn fallback_reason_is_partial_eq() {
    let failed_a = FallbackReason::EmbeddingFailed("a".into());
    let failed_a_again = FallbackReason::EmbeddingFailed("a".into());
    let failed_b = FallbackReason::EmbeddingFailed("b".into());
    let timeout = FallbackReason::Timeout {
        operation: "x".into(),
        duration_secs: 1,
    };

    // Same variant, same payload.
    assert_eq!(failed_a, failed_a_again);
    assert_eq!(FallbackReason::Cancelled, FallbackReason::Cancelled);

    // Same variant with a different payload, then different variants.
    assert_ne!(failed_a, failed_b);
    assert_ne!(FallbackReason::Cancelled, timeout);
}
1456
/// The `Timeout` variant hands back exactly the operation name and duration
/// it was constructed with.
#[test]
fn fallback_reason_timeout_preserves_fields() {
    let reason = FallbackReason::Timeout {
        operation: "embed_query_local".into(),
        duration_secs: 300,
    };

    // Pull the payload out first; any other variant is a test failure.
    let (operation, duration_secs) = match reason {
        FallbackReason::Timeout {
            operation,
            duration_secs,
        } => (operation, duration_secs),
        other => panic!("expected Timeout, got {other:?}"),
    };

    assert_eq!(operation, "embed_query_local");
    assert_eq!(duration_secs, 300);
}
1476
/// Expects `try_embed_query_with_fallback` to fail with a non-empty
/// `EmbeddingFailed` reason when pointed at a models directory that does not
/// exist. Ignored by default; see the `#[ignore]` reason.
///
/// The match lists every variant explicitly (no catch-all) so adding a new
/// `FallbackReason` variant forces this test to be revisited.
#[test]
#[ignore = "G58 S1 stub: requires env without codex/claude on PATH; tracked as T5 of Fase 2"]
fn try_embed_query_with_fallback_surfaces_embedding_failed_for_missing_binary() {
    let bogus = std::path::Path::new("/nonexistent-models-dir-for-g58-fallback-test");

    let Err(reason) = try_embed_query_with_fallback(bogus, "hello world") else {
        panic!("expected an error, got Ok — embedder must fail for bogus path");
    };

    match reason {
        FallbackReason::EmbeddingFailed(msg) => {
            assert!(!msg.is_empty(), "fallback message must not be empty");
        }
        FallbackReason::Cancelled => {
            panic!("expected EmbeddingFailed, got Cancelled");
        }
        FallbackReason::Timeout { .. } => {
            panic!("expected EmbeddingFailed, got Timeout");
        }
        FallbackReason::SlotExhausted => {
            panic!("expected EmbeddingFailed, got SlotExhausted");
        }
        FallbackReason::OAuthQuota { .. } => {
            panic!("expected EmbeddingFailed, got OAuthQuota");
        }
        FallbackReason::BackendMismatch { .. } => {
            panic!("expected EmbeddingFailed, got BackendMismatch");
        }
        FallbackReason::DimZero => {
            panic!("expected EmbeddingFailed, got DimZero");
        }
    }
}
1517
/// The entity cache key is the same for repeated (model, text) pairs and
/// changes when either the text or the model changes.
#[test]
fn g56_entity_cache_key_is_stable_and_distinct() {
    let baseline = entity_cache_key("codex:default", "sqlite-graphrag");
    let repeated = entity_cache_key("codex:default", "sqlite-graphrag");
    let other_text = entity_cache_key("codex:default", "claude-code");
    let other_model = entity_cache_key("claude:default", "sqlite-graphrag");

    assert_eq!(baseline, repeated, "same model+text must hash identically");
    assert_ne!(baseline, other_text, "different text must hash differently");
    assert_ne!(baseline, other_model, "different model must hash differently");
}
1529
/// `EmbedCacheStats::hit_rate` is 0.0 for default stats, 0.5 when half of the
/// requests hit, and 1.0 when all of them hit.
#[test]
fn g56_entity_embed_cache_stats_hit_rate() {
    // Default stats.
    let empty_rate = EmbedCacheStats::default().hit_rate();
    assert_eq!(empty_rate, 0.0);

    // Two hits out of four requests.
    let half_rate = EmbedCacheStats {
        requested: 4,
        hits: 2,
        misses: 2,
    }
    .hit_rate();
    assert!((half_rate - 0.5).abs() < 1e-9);

    // Every request was a hit.
    let full_rate = EmbedCacheStats {
        requested: 7,
        hits: 7,
        misses: 0,
    }
    .hit_rate();
    assert!((full_rate - 1.0).abs() < 1e-9);
}
1547
/// Inserting a vector into the shared entity-embedding cache makes it
/// retrievable under the same key, with its length and contents intact.
///
/// The embedding dimension is read once and reused for both the stored vector
/// and the length assertion. This test is not serialized against tests that
/// change the active dimension, so two separate reads could disagree and make
/// the assertion fail spuriously.
#[test]
fn g56_entity_embed_cache_populates_and_hits() {
    let dim = crate::constants::embedding_dim();
    let cache = entity_embed_cache();

    let model = "test-model";
    let text = "sqlite-graphrag";
    let key = entity_cache_key(model, text);

    let stored = Arc::new(vec![0.42_f32; dim]);
    cache.lock().insert(key, Arc::clone(&stored));

    // The insert guard above is a temporary dropped at the end of that
    // statement, so locking again here does not deadlock.
    let guard = cache.lock();
    let hit = guard.get(&key).expect("cache must return stored value");
    assert_eq!(hit.len(), dim);
    assert!((hit[0] - 0.42).abs() < 1e-6);
}
1564
/// Default `EmbedCacheStats` has every counter at zero and a hit rate of 0.0.
/// Only the default value is checked; no embedding function is called.
#[test]
fn g56_empty_texts_short_circuits_with_zero_stats() {
    let zeroed = EmbedCacheStats::default();

    assert_eq!(zeroed.requested, 0);
    assert_eq!(zeroed.hits, 0);
    assert_eq!(zeroed.misses, 0);

    let rate = zeroed.hit_rate();
    assert_eq!(rate, 0.0);
}
1575}
1576
/// Tests for the backend fallback helpers (`embed_via_backend`,
/// `embed_with_fallback`) and for the stable string forms of
/// `LlmBackendKind`, `FallbackReason` and `LlmBackendError`.
#[cfg(test)]
mod embed_with_fallback_tests {
    use super::*;
    use crate::llm::exit_code_hints::LlmBackendError;

    /// The `None` backend returns an empty vector and reports itself as the
    /// backend that produced it, even for a models directory that does not exist.
    #[test]
    fn none_backend_returns_empty_vector_without_calling_llm() {
        let (v, kind) = embed_via_backend(
            std::path::Path::new("/nonexistent"),
            "any text",
            &LlmBackendKind::None,
        )
        .expect("None backend never fails");
        assert!(v.is_empty());
        assert_eq!(kind, LlmBackendKind::None, "None backend must report None");
    }

    /// Pins the number of entries in the expected default chain
    /// (Codex, Claude, None).
    #[test]
    fn empty_chain_defaults_to_codex_claude_none() {
        let defaults = [
            LlmBackendKind::Codex,
            LlmBackendKind::Claude,
            LlmBackendKind::None,
        ];
        assert_eq!(defaults.len(), 3);
    }

    /// `LlmBackendKind::as_str` must keep returning the lowercase backend names.
    ///
    /// This was previously an inner function of
    /// `empty_chain_defaults_to_codex_claude_none` marked `#[allow(dead_code)]`,
    /// so its assertions never ran. It is now a module-level test.
    #[test]
    fn llm_backend_kind_as_str_is_stable() {
        assert_eq!(LlmBackendKind::Codex.as_str(), "codex");
        assert_eq!(LlmBackendKind::Claude.as_str(), "claude");
        assert_eq!(LlmBackendKind::None.as_str(), "none");
    }

    /// `FallbackReason::reason_code` must keep returning the snake_case codes.
    ///
    /// Hoisted out of `empty_chain_defaults_to_codex_claude_none` for the same
    /// reason as `llm_backend_kind_as_str_is_stable`: as an inner function it
    /// was never executed.
    #[test]
    fn fallback_reason_reason_code_is_stable() {
        assert_eq!(
            FallbackReason::EmbeddingFailed("any".into()).reason_code(),
            "embedding_failed"
        );
        assert_eq!(FallbackReason::Cancelled.reason_code(), "cancelled");
        assert_eq!(
            FallbackReason::Timeout {
                operation: "embed_query".into(),
                duration_secs: 30
            }
            .reason_code(),
            "timeout"
        );
    }

    /// A chain consisting only of `None` succeeds with an empty vector when
    /// the final flag is `false`.
    #[test]
    fn embed_with_fallback_succeeds_via_none_when_chain_exhausts() {
        let chain = vec![LlmBackendKind::None];
        let v = embed_with_fallback(
            std::path::Path::new("/nonexistent-models-dir-for-gap005-test"),
            "hello",
            &chain,
            false,
        )
        .expect("chain ending in None must always succeed");
        assert!(v.0.is_empty(), "vector must be empty");
        assert_eq!(v.1, LlmBackendKind::None);
    }

    /// Same `None`-only chain with the final flag set to `true`: the outcome
    /// is identical.
    #[test]
    fn embed_with_fallback_skip_on_failure_with_only_none_returns_empty() {
        let chain = vec![LlmBackendKind::None];
        let v = embed_with_fallback(
            std::path::Path::new("/nonexistent-models-dir-for-gap005-test"),
            "hello",
            &chain,
            true,
        )
        .expect("None chain is always Ok");
        assert!(v.0.is_empty(), "vector must be empty");
        assert_eq!(v.1, LlmBackendKind::None);
    }

    /// Checks that the hint for `NoBackendsAvailable` mentions `--llm-fallback`.
    ///
    /// NOTE(review): this function is not registered as a test. It is only
    /// compiled, and `dead_code` is silenced because nothing calls it.
    /// Confirm whether it was disabled on purpose; if not, replace the
    /// attribute with `#[test]`.
    #[allow(dead_code)]
    fn llm_backend_error_no_backends_default_message() {
        let e = LlmBackendError::NoBackendsAvailable;
        let h = e.hint();
        assert!(h.contains("--llm-fallback"));
    }

    /// The rendered `NonZeroExit` error names the binary, includes the stderr
    /// tail, and mentions either the signal or the exit code.
    #[test]
    fn llm_backend_error_nonzero_exit_carries_stderr_tail() {
        let e = LlmBackendError::NonZeroExit {
            exit_code: Some(137),
            signal: Some(9),
            stdout_tail: "out".into(),
            stderr_tail: "OOM killed".into(),
            binary: "codex".into(),
            hint: "OOM".into(),
        };
        let rendered = e.to_string();
        assert!(rendered.contains("codex"));
        assert!(rendered.contains("OOM killed"));
        assert!(rendered.contains("signal 9") || rendered.contains("exit 137"));
    }
}