1use anyhow::{Result, anyhow};
27use reqwest::Client;
28use serde::{Deserialize, Serialize};
29use std::time::Duration;
30
/// Embedding vector width the client validates providers against by default.
pub const DEFAULT_REQUIRED_DIMENSION: usize = 2560;
/// Default model name for a local Ollama instance.
pub const DEFAULT_OLLAMA_EMBEDDING_MODEL: &str = "qwen3-embedding:4b";
/// Ceiling on whole-batch retries in `embed_batch_internal` (env-overridable).
const DEFAULT_MAX_BATCH_RETRIES: usize = 10;
/// Cap on the exponential backoff between batch retries, in seconds (env-overridable).
const DEFAULT_MAX_BATCH_BACKOFF_SECS: u64 = 30;
35
/// Request body for an OpenAI-compatible `/v1/embeddings` endpoint.
#[derive(Debug, Serialize)]
struct EmbeddingRequest {
    /// Texts to embed in a single call.
    input: Vec<String>,
    /// Model identifier the server should use.
    model: String,
}
45
/// Response body from an OpenAI-compatible embeddings endpoint.
/// Fields other than `data` are ignored during deserialization.
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
    data: Vec<EmbeddingData>,
}
50
/// One embedding vector within an `EmbeddingResponse`.
#[derive(Debug, Deserialize)]
struct EmbeddingData {
    embedding: Vec<f32>,
}
55
/// Request body for a rerank endpoint (query scored against documents).
#[derive(Debug, Serialize)]
struct RerankRequest {
    query: String,
    documents: Vec<String>,
    model: String,
}
62
/// Response body from the rerank endpoint.
#[derive(Debug, Deserialize)]
struct RerankResponse {
    results: Vec<RerankResult>,
}
67
/// One scored document: `index` refers back into the submitted document list.
#[derive(Debug, Deserialize)]
struct RerankResult {
    index: usize,
    score: f32,
}
73
/// One embedding backend in the failover list.
///
/// At startup providers are probed in ascending `priority` order; the first
/// one that responds with the required dimension is used.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct ProviderConfig {
    /// Label used in logs and error messages.
    #[serde(default)]
    pub name: String,
    /// Server root, e.g. "http://localhost:11434" (trailing slash tolerated).
    #[serde(default)]
    pub base_url: String,
    /// Model identifier sent with every embedding request.
    #[serde(default)]
    pub model: String,
    /// Lower values are tried first; serde default is 10.
    #[serde(default = "default_priority")]
    pub priority: u8,
    /// Path of the embeddings endpoint; serde default is "/v1/embeddings".
    #[serde(default = "default_embeddings_endpoint")]
    pub endpoint: String,
}
97
/// serde default for `ProviderConfig::priority` (tried after explicitly ranked providers).
fn default_priority() -> u8 { 10 }
101
/// serde default for `ProviderConfig::endpoint`.
fn default_embeddings_endpoint() -> String {
    String::from("/v1/embeddings")
}
105
/// Read a positive `usize` from the environment, falling back to `default`
/// when the variable is unset, unparsable, or zero.
fn env_usize(name: &str, default: usize) -> usize {
    match std::env::var(name).ok().and_then(|raw| raw.parse::<usize>().ok()) {
        Some(parsed) if parsed > 0 => parsed,
        _ => default,
    }
}
113
/// Read a positive `u64` from the environment, falling back to `default`
/// when the variable is unset, unparsable, or zero.
fn env_u64(name: &str, default: u64) -> u64 {
    match std::env::var(name).ok().and_then(|raw| raw.parse::<u64>().ok()) {
        Some(parsed) if parsed > 0 => parsed,
        _ => default,
    }
}
121
/// Optional reranker backend; reranking is disabled when `base_url` is `None`.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct RerankerConfig {
    /// Server root of the rerank service, if any.
    pub base_url: Option<String>,
    /// Model identifier sent with rerank requests.
    pub model: Option<String>,
    /// Path of the rerank endpoint; serde default is "/v1/rerank".
    #[serde(default = "default_rerank_endpoint")]
    pub endpoint: String,
}
133
/// serde default for `RerankerConfig::endpoint`.
fn default_rerank_endpoint() -> String {
    String::from("/v1/rerank")
}
137
/// serde default for `EmbeddingConfig::required_dimension`.
fn default_dimension() -> usize {
    DEFAULT_REQUIRED_DIMENSION
}
141
/// serde default for `EmbeddingConfig::max_batch_chars` (~128k chars per request).
fn default_max_batch_chars() -> usize {
    128_000
}
145
/// serde default for `EmbeddingConfig::max_batch_items`.
fn default_max_batch_items() -> usize {
    64
}
149
/// Join a provider base URL and endpoint path into a full URL, tolerating a
/// trailing slash on the base, surrounding whitespace on the endpoint, and a
/// missing leading slash on the endpoint.
fn build_provider_endpoint(base_url: &str, endpoint: &str) -> String {
    let base = base_url.trim_end_matches('/');
    let path = endpoint.trim();
    let separator = if path.starts_with('/') { "" } else { "/" };
    format!("{}{}{}", base, separator, path)
}
159
/// Top-level embedding configuration: validation dimension, batching limits,
/// the provider failover list, and the optional reranker.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct EmbeddingConfig {
    /// Vector width every provider must produce; mismatches fail at startup.
    #[serde(default = "default_dimension")]
    pub required_dimension: usize,
    /// Max total characters per batched embedding request.
    #[serde(default = "default_max_batch_chars")]
    pub max_batch_chars: usize,
    /// Max number of texts per batched embedding request.
    #[serde(default = "default_max_batch_items")]
    pub max_batch_items: usize,
    /// Candidate backends, tried in ascending `priority` order.
    #[serde(default)]
    pub providers: Vec<ProviderConfig>,
    /// Optional rerank service.
    #[serde(default)]
    pub reranker: RerankerConfig,
}
179
180impl Default for EmbeddingConfig {
181 fn default() -> Self {
182 Self {
183 required_dimension: default_dimension(),
184 max_batch_chars: default_max_batch_chars(),
185 max_batch_items: default_max_batch_items(),
186 providers: vec![
187 ProviderConfig {
188 name: "ollama-local".to_string(),
189 base_url: "http://localhost:11434".to_string(),
190 model: DEFAULT_OLLAMA_EMBEDDING_MODEL.to_string(),
191 priority: 1,
192 endpoint: default_embeddings_endpoint(),
193 },
194 ProviderConfig {
195 name: "dragon".to_string(),
196 base_url: "http://dragon:12345".to_string(),
197 model: "Qwen/Qwen3-Embedding-4B".to_string(),
198 priority: 2,
199 endpoint: default_embeddings_endpoint(),
200 },
201 ],
202 reranker: RerankerConfig::default(),
203 }
204 }
205}
206
207impl EmbeddingConfig {
208 pub fn provider_name(&self) -> String {
210 self.providers
211 .first()
212 .map(|p| p.name.clone())
213 .unwrap_or_else(|| "none".to_string())
214 }
215
216 pub fn model_name(&self) -> String {
218 self.providers
219 .first()
220 .map(|p| p.model.clone())
221 .unwrap_or_else(|| "none".to_string())
222 }
223
224 pub fn dimension(&self) -> usize {
226 self.required_dimension
227 }
228}
229
/// Legacy `[mlx]`-style configuration, kept for backwards compatibility and
/// convertible to `EmbeddingConfig` via `to_embedding_config`.
#[derive(Debug, Clone)]
pub struct MlxConfig {
    /// When true, embedding is disabled entirely (construction fails).
    pub disabled: bool,
    /// Port of the local embedder.
    pub local_port: u16,
    /// Base URL of the remote "dragon" host.
    pub dragon_url: String,
    /// Embedder port on the dragon host.
    pub dragon_port: u16,
    pub embedder_model: String,
    pub reranker_model: String,
    /// Reranker port is derived as `local_port + reranker_port_offset`.
    pub reranker_port_offset: u16,
    pub max_batch_chars: usize,
    pub max_batch_items: usize,
}
247
/// Optional overrides merged into an `MlxConfig` by `merge_file_config`;
/// `None` fields leave the existing value untouched.
#[derive(Debug, Clone, Default)]
pub struct MlxMergeOptions {
    pub disabled: Option<bool>,
    pub local_port: Option<u16>,
    pub dragon_url: Option<String>,
    pub dragon_port: Option<u16>,
    pub embedder_model: Option<String>,
    pub reranker_model: Option<String>,
    pub reranker_port_offset: Option<u16>,
}
259
260impl Default for MlxConfig {
261 fn default() -> Self {
262 Self {
263 disabled: false,
264 local_port: 12345,
265 dragon_url: "http://dragon".to_string(),
266 dragon_port: 12345,
267 embedder_model: "Qwen/Qwen3-Embedding-4B".to_string(),
268 reranker_model: "Qwen/Qwen3-Reranker-4B".to_string(),
269 reranker_port_offset: 1,
270 max_batch_chars: default_max_batch_chars(),
271 max_batch_items: default_max_batch_items(),
272 }
273 }
274}
275
276impl MlxConfig {
277 pub fn from_env() -> Self {
279 let disabled = std::env::var("DISABLE_MLX")
280 .map(|v| v == "1" || v.to_lowercase() == "true")
281 .unwrap_or(false);
282
283 let local_port = std::env::var("EMBEDDER_PORT")
284 .ok()
285 .and_then(|s| s.parse().ok())
286 .unwrap_or(12345);
287
288 let dragon_url =
289 std::env::var("DRAGON_BASE_URL").unwrap_or_else(|_| "http://dragon".to_string());
290
291 let dragon_port = std::env::var("DRAGON_EMBEDDER_PORT")
292 .ok()
293 .and_then(|s| s.parse().ok())
294 .unwrap_or(local_port);
295
296 let reranker_port_offset = std::env::var("RERANKER_PORT")
297 .ok()
298 .and_then(|s| s.parse::<u16>().ok())
299 .map(|rp| rp.saturating_sub(local_port))
300 .unwrap_or(1);
301
302 let embedder_model = std::env::var("EMBEDDER_MODEL")
303 .unwrap_or_else(|_| "Qwen/Qwen3-Embedding-4B".to_string());
304
305 let reranker_model = std::env::var("RERANKER_MODEL")
306 .unwrap_or_else(|_| "Qwen/Qwen3-Reranker-4B".to_string());
307
308 let max_batch_chars = std::env::var("MLX_MAX_BATCH_CHARS")
309 .ok()
310 .and_then(|s| s.parse().ok())
311 .unwrap_or(32000);
312
313 let max_batch_items = std::env::var("MLX_MAX_BATCH_ITEMS")
314 .ok()
315 .and_then(|s| s.parse().ok())
316 .unwrap_or(16);
317
318 Self {
319 disabled,
320 local_port,
321 dragon_url,
322 dragon_port,
323 embedder_model,
324 reranker_model,
325 reranker_port_offset,
326 max_batch_chars,
327 max_batch_items,
328 }
329 }
330
331 pub fn merge_file_config(&mut self, opts: MlxMergeOptions) {
333 if let Some(v) = opts.disabled {
334 self.disabled = v;
335 }
336 if let Some(v) = opts.local_port {
337 self.local_port = v;
338 }
339 if let Some(v) = opts.dragon_url {
340 self.dragon_url = v;
341 }
342 if let Some(v) = opts.dragon_port {
343 self.dragon_port = v;
344 }
345 if let Some(v) = opts.embedder_model {
346 self.embedder_model = v;
347 }
348 if let Some(v) = opts.reranker_model {
349 self.reranker_model = v;
350 }
351 if let Some(v) = opts.reranker_port_offset {
352 self.reranker_port_offset = v;
353 }
354 }
355
356 pub fn to_embedding_config(&self) -> EmbeddingConfig {
358 let reranker_port = self.local_port + self.reranker_port_offset;
359 let required_dimension = DEFAULT_REQUIRED_DIMENSION;
360
361 EmbeddingConfig {
362 required_dimension,
363 max_batch_chars: self.max_batch_chars,
364 max_batch_items: self.max_batch_items,
365 providers: vec![
366 ProviderConfig {
367 name: "local".to_string(),
368 base_url: format!("http://localhost:{}", self.local_port),
369 model: self.embedder_model.clone(),
370 priority: 1,
371 endpoint: default_embeddings_endpoint(),
372 },
373 ProviderConfig {
374 name: "dragon".to_string(),
375 base_url: format!("{}:{}", self.dragon_url, self.dragon_port),
376 model: self.embedder_model.clone(),
377 priority: 2,
378 endpoint: default_embeddings_endpoint(),
379 },
380 ],
381 reranker: RerankerConfig {
382 base_url: Some(format!("{}:{}", self.dragon_url, reranker_port)),
383 model: Some(self.reranker_model.clone()),
384 endpoint: default_rerank_endpoint(),
385 },
386 }
387 }
388
389 pub fn with_batch_limits(mut self, max_chars: usize, max_items: usize) -> Self {
391 self.max_batch_chars = max_chars;
392 self.max_batch_items = max_items;
393 self
394 }
395}
396
/// HTTP client bound to one validated embedding provider (and optionally a
/// reranker). Constructed by `EmbeddingClient::new`, which probes providers
/// in priority order and keeps the first one with the right dimension.
#[derive(Clone)]
pub struct EmbeddingClient {
    client: Client,
    /// Full URL of the chosen provider's embeddings endpoint.
    embedder_url: String,
    embedder_model: String,
    /// Full rerank URL, or `None` when no reranker is configured.
    reranker_url: Option<String>,
    reranker_model: Option<String>,
    /// Name of the provider that passed validation (for logs/diagnostics).
    connected_to: String,
    /// Every returned vector must have exactly this length.
    required_dimension: usize,
    /// Max total characters per batched request.
    max_batch_chars: usize,
    /// Max number of texts per batched request.
    max_batch_items: usize,
}
418
/// Backwards-compatible alias from when this client only spoke to MLX servers.
pub type MLXBridge = EmbeddingClient;
421
422impl EmbeddingClient {
423 pub async fn new(config: &EmbeddingConfig) -> Result<Self> {
425 if config.providers.is_empty() {
426 return Err(anyhow!(
427 "No embedding providers configured! Add providers to [embeddings.providers]"
428 ));
429 }
430
431 let client = Client::builder()
433 .timeout(Duration::from_secs(300))
434 .connect_timeout(Duration::from_secs(10))
435 .build()?;
436
437 let mut providers = config.providers.clone();
439 providers.sort_by_key(|p| p.priority);
440
441 let mut tried = Vec::new();
443 for provider in &providers {
444 let base_url = provider.base_url.trim_end_matches('/');
445 let provider_name = if provider.name.trim().is_empty() {
446 "<unnamed-provider>"
447 } else {
448 provider.name.as_str()
449 };
450 let model = provider.model.trim();
451 let embedder_url = build_provider_endpoint(base_url, &provider.endpoint);
452
453 match probe_provider_dimension(&client, provider).await {
454 Ok(actual_dim) if actual_dim == config.required_dimension => {
455 tracing::info!(
456 "Embedding: Connected to {} ({}) with model '{}' [{} dims]",
457 provider_name,
458 embedder_url,
459 model,
460 actual_dim
461 );
462
463 let (reranker_url, reranker_model) =
465 if let Some(ref rr_base) = config.reranker.base_url {
466 (
467 Some(format!(
468 "{}{}",
469 rr_base.trim_end_matches('/'),
470 config.reranker.endpoint
471 )),
472 config.reranker.model.clone(),
473 )
474 } else {
475 (None, None)
476 };
477
478 return Ok(Self {
479 client,
480 embedder_url,
481 embedder_model: provider.model.clone(),
482 reranker_url,
483 reranker_model,
484 connected_to: provider.name.clone(),
485 required_dimension: config.required_dimension,
486 max_batch_chars: config.max_batch_chars,
487 max_batch_items: config.max_batch_items,
488 });
489 }
490 Ok(actual_dim) => {
491 let failure = format!(
492 "- {} ({} model='{}'): the configured embedding endpoint returned {} dims, but config.required_dimension={}.\n Action: set [embeddings].required_dimension = {} or choose a {}-dim model.",
493 provider_name,
494 embedder_url,
495 model,
496 actual_dim,
497 config.required_dimension,
498 actual_dim,
499 config.required_dimension
500 );
501 tracing::error!("Embedding: validation failed: {}", failure);
502 tried.push(failure);
503 }
504 Err(e) => {
505 let failure = format!(
506 "- {} ({} model='{}'): {}",
507 provider_name, embedder_url, model, e
508 );
509 tracing::warn!("Embedding: provider probe failed: {}", failure);
510 tried.push(failure);
511 }
512 }
513 }
514
515 Err(anyhow!(
517 "No embedding provider passed validation for required_dimension={}. \
518 Each provider must succeed on its configured embedding endpoint before rust-memex will start.\nTried:\n{}",
519 config.required_dimension,
520 tried.join("\n")
521 ))
522 }
523
524 pub async fn from_legacy(config: &MlxConfig) -> Result<Self> {
526 if config.disabled {
527 return Err(anyhow!(
528 "Embedding disabled via config. No fallback available!"
529 ));
530 }
531 tracing::warn!("Using legacy [mlx] config - please migrate to [embeddings.providers]");
532 let embedding_config = config.to_embedding_config();
533 Self::new(&embedding_config).await
534 }
535
536 pub async fn from_env() -> Result<Self> {
538 let config = MlxConfig::from_env();
539 Self::from_legacy(&config).await
540 }
541
    /// Name of the provider that passed validation at construction time.
    pub fn connected_to(&self) -> &str {
        &self.connected_to
    }
546
    /// Vector width every embedding returned by this client is validated against.
    pub fn required_dimension(&self) -> usize {
        self.required_dimension
    }
551
    /// Current batching limits as `(max_chars, max_items)`.
    pub fn batch_limits(&self) -> (usize, usize) {
        (self.max_batch_chars, self.max_batch_items)
    }
556
557 pub fn clone_with_batch_limits(&self, max_chars: usize, max_items: usize) -> Self {
559 let mut cloned = self.clone();
560 cloned.max_batch_chars = max_chars.max(1);
561 cloned.max_batch_items = max_items.max(1);
562 cloned
563 }
564
565 #[doc(hidden)]
569 pub fn stub_for_tests() -> Self {
570 Self {
571 client: reqwest::Client::new(),
572 embedder_url: "http://stub:0/v1/embeddings".to_string(),
573 embedder_model: "stub".to_string(),
574 reranker_url: None,
575 reranker_model: None,
576 connected_to: "stub-test".to_string(),
577 required_dimension: 4096,
578 max_batch_chars: 32000,
579 max_batch_items: 16,
580 }
581 }
582
    /// Embed a single text via the connected provider.
    ///
    /// Fails if the HTTP call errors, the server returns a non-2xx status,
    /// the body cannot be parsed, no embedding is returned, or the vector
    /// length differs from `required_dimension`.
    // NOTE(review): takes `&mut self` although no field is written here —
    // confirm whether `&self` would suffice for callers.
    pub async fn embed(&mut self, text: &str) -> Result<Vec<f32>> {
        // Log only a short preview; inputs can be arbitrarily large.
        let text_preview: String = text.chars().take(100).collect();
        tracing::debug!(
            "Embedding single text ({} chars): {}{}",
            text.chars().count(),
            text_preview,
            if text.chars().count() > 100 {
                "..."
            } else {
                ""
            }
        );

        let request = EmbeddingRequest {
            input: vec![text.to_string()],
            model: self.embedder_model.clone(),
        };

        let response = match self
            .client
            .post(&self.embedder_url)
            .json(&request)
            .send()
            .await
        {
            Ok(resp) => resp,
            Err(e) => {
                tracing::error!(
                    "Embedding request failed: {:?}\n URL: {}\n Model: {}",
                    e,
                    self.embedder_url,
                    self.embedder_model
                );
                return Err(anyhow!("Embedding request failed: {}", e));
            }
        };

        // Read the body as text first so it can be included in error logs
        // even when the status is an error or the JSON is malformed.
        let status = response.status();
        let response_text = response.text().await.unwrap_or_else(|e| {
            tracing::warn!("Failed to read response body: {:?}", e);
            "<failed to read body>".to_string()
        });

        if !status.is_success() {
            tracing::error!(
                "Embedding API error (HTTP {}):\n URL: {}\n Model: {}\n Response: {}",
                status,
                self.embedder_url,
                self.embedder_model,
                response_text
            );
            return Err(anyhow!(
                "Embedding API error (HTTP {}): {}",
                status,
                response_text
            ));
        }

        let parsed: EmbeddingResponse = match serde_json::from_str(&response_text) {
            Ok(r) => r,
            Err(e) => {
                tracing::error!(
                    "Failed to parse embedding response: {:?}\n Response body: {}",
                    e,
                    response_text
                );
                return Err(anyhow!("Failed to parse embedding response: {}", e));
            }
        };

        // Only the first vector is used; a single-input request should
        // produce exactly one.
        let embedding = parsed
            .data
            .into_iter()
            .next()
            .map(|d| d.embedding)
            .ok_or_else(|| {
                tracing::error!("No embedding returned in response: {}", response_text);
                anyhow!("No embedding returned")
            })?;

        // Reject wrong-size vectors outright rather than storing them.
        if embedding.len() != self.required_dimension {
            tracing::error!(
                "Dimension mismatch! Expected {}, got {}. Model: {}",
                self.required_dimension,
                embedding.len(),
                self.embedder_model
            );
            return Err(anyhow!(
                "Dimension mismatch! Expected {}, got {}. This would corrupt the database!",
                self.required_dimension,
                embedding.len()
            ));
        }

        tracing::debug!("Successfully embedded text ({} dims)", embedding.len());
        Ok(embedding)
    }
681
    /// Embed many texts, batching by both character count and item count and
    /// retrying any failed chunks individually with backoff.
    ///
    /// Texts longer than half the batch character budget are truncated at a
    /// boundary first. Returns one vector per input, in input order, or an
    /// error if any chunk still fails after its individual retries.
    pub async fn embed_batch(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(vec![]);
        }

        let mut all_embeddings = Vec::with_capacity(texts.len());
        let mut current_batch: Vec<String> = Vec::new();
        let mut current_batch_indices: Vec<usize> = Vec::new();
        let mut current_chars = 0;

        // Cap any single text to half the batch budget so one text can never
        // monopolize or overflow a request on its own.
        let max_text_chars = self.max_batch_chars / 2;

        let prepared_texts: Vec<String> = texts
            .iter()
            .map(|text| {
                let char_count = text.chars().count();
                if char_count > max_text_chars {
                    tracing::debug!(
                        "Text too large ({} chars), truncating to {} chars",
                        char_count,
                        max_text_chars
                    );
                    truncate_at_boundary(text, max_text_chars)
                } else {
                    text.clone()
                }
            })
            .collect();

        // results[i] is filled as chunk i succeeds; chunks from failed
        // batches are queued for individual retry below.
        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
        let mut failed_indices: Vec<usize> = Vec::new();

        for (idx, text_to_embed) in prepared_texts.iter().enumerate() {
            let text_len = text_to_embed.chars().count();

            // Flush the pending batch before adding this text would exceed
            // either the char or the item limit.
            if !current_batch.is_empty()
                && (current_chars + text_len > self.max_batch_chars
                    || current_batch.len() >= self.max_batch_items)
            {
                match self.embed_batch_internal(&current_batch).await {
                    Ok(batch_embeddings) => {
                        for (i, emb) in batch_embeddings.into_iter().enumerate() {
                            if let Some(orig_idx) = current_batch_indices.get(i) {
                                results[*orig_idx] = Some(emb);
                            }
                        }
                    }
                    Err(e) => {
                        tracing::warn!(
                            "Batch embedding failed for {} texts, will retry individually: {}",
                            current_batch.len(),
                            e
                        );
                        failed_indices.extend(current_batch_indices.iter().copied());
                    }
                }
                current_batch.clear();
                current_batch_indices.clear();
                current_chars = 0;
            }

            current_batch.push(text_to_embed.clone());
            current_batch_indices.push(idx);
            current_chars += text_len;
        }

        // Flush the final partial batch.
        if !current_batch.is_empty() {
            match self.embed_batch_internal(&current_batch).await {
                Ok(batch_embeddings) => {
                    for (i, emb) in batch_embeddings.into_iter().enumerate() {
                        if let Some(orig_idx) = current_batch_indices.get(i) {
                            results[*orig_idx] = Some(emb);
                        }
                    }
                }
                Err(e) => {
                    tracing::warn!(
                        "Batch embedding failed for {} texts, will retry individually: {}",
                        current_batch.len(),
                        e
                    );
                    failed_indices.extend(current_batch_indices.iter().copied());
                }
            }
        }

        // Retry each failed chunk individually with exponential backoff.
        const MAX_RETRIES: usize = 3;
        for idx in failed_indices {
            let text = &prepared_texts[idx];
            let mut attempts = 0;
            let mut last_error = String::new();

            while attempts < MAX_RETRIES {
                match self.embed(text).await {
                    Ok(embedding) => {
                        results[idx] = Some(embedding);
                        tracing::info!(
                            "Retry succeeded for chunk {} after {} attempts",
                            idx,
                            attempts + 1
                        );
                        break;
                    }
                    Err(e) => {
                        attempts += 1;
                        last_error = e.to_string();
                        tracing::warn!(
                            "Embed attempt {}/{} failed for chunk {}: {}",
                            attempts,
                            MAX_RETRIES,
                            idx,
                            e
                        );
                        if attempts < MAX_RETRIES {
                            // 200ms, 400ms, ... doubling backoff.
                            let delay_ms = 100 * (1 << attempts);
                            tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
                        }
                    }
                }
            }

            if results[idx].is_none() {
                tracing::error!(
                    "Chunk {} failed after {} retries: {}",
                    idx,
                    MAX_RETRIES,
                    last_error
                );
                return Err(anyhow!(
                    "Failed to embed chunk {} after {} retries: {}",
                    idx,
                    MAX_RETRIES,
                    last_error
                ));
            }
        }

        // Every slot must be filled by now; assemble in input order.
        for (idx, opt) in results.iter().enumerate() {
            match opt {
                Some(emb) => all_embeddings.push(emb.clone()),
                None => {
                    return Err(anyhow!(
                        "Internal error: missing embedding for chunk {}",
                        idx
                    ));
                }
            }
        }

        Ok(all_embeddings)
    }
847
    /// Send one batch to the embeddings endpoint, retrying the whole request
    /// with capped exponential backoff on transport, HTTP, or parse errors.
    ///
    /// Retry count and backoff cap are tunable via the
    /// RUST_MEMEX_EMBED_BATCH_MAX_RETRIES and
    /// RUST_MEMEX_EMBED_BATCH_MAX_BACKOFF_SECS environment variables.
    async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let total_chars: usize = texts.iter().map(|t| t.chars().count()).sum();

        tracing::debug!(
            "Embedding batch: {} texts, {} chars total",
            texts.len(),
            total_chars
        );

        // Per-text previews at trace level only.
        for (i, text) in texts.iter().enumerate() {
            let preview: String = text.chars().take(50).collect();
            tracing::trace!(
                " Batch[{}]: {} chars - {}{}",
                i,
                text.chars().count(),
                preview,
                if text.chars().count() > 50 { "..." } else { "" }
            );
        }

        let request = EmbeddingRequest {
            input: texts.to_vec(),
            model: self.embedder_model.clone(),
        };

        // Retry policy is environment-tunable for long-running indexers.
        let max_batch_retries = env_usize(
            "RUST_MEMEX_EMBED_BATCH_MAX_RETRIES",
            DEFAULT_MAX_BATCH_RETRIES,
        );
        let max_backoff_secs = env_u64(
            "RUST_MEMEX_EMBED_BATCH_MAX_BACKOFF_SECS",
            DEFAULT_MAX_BATCH_BACKOFF_SECS,
        );
        let mut attempt = 0;

        loop {
            attempt += 1;
            let response = match self
                .client
                .post(&self.embedder_url)
                .json(&request)
                .send()
                .await
            {
                Ok(resp) => resp,
                Err(e) => {
                    if attempt >= max_batch_retries {
                        tracing::error!(
                            "Batch embedding failed after {} retries: {:?}\n URL: {}\n Model: {}",
                            max_batch_retries,
                            e,
                            self.embedder_url,
                            self.embedder_model
                        );
                        return Err(anyhow!(
                            "Embedding request failed after {} retries: {}",
                            max_batch_retries,
                            e
                        ));
                    }

                    // Backoff doubles per attempt, capped at 2^5 seconds and
                    // at max_backoff_secs.
                    let backoff_secs = (1u64 << attempt.min(5)).min(max_backoff_secs);
                    tracing::warn!(
                        "Embedding request failed (attempt {}/{}), retrying in {}s: {}",
                        attempt,
                        max_batch_retries,
                        backoff_secs,
                        e
                    );
                    tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
                    continue;
                }
            };

            // Non-2xx statuses are retried the same way as transport errors.
            if !response.status().is_success() {
                let status = response.status();
                let body = response.text().await.unwrap_or_default();

                if attempt >= max_batch_retries {
                    tracing::error!(
                        "Embedding API error after {} retries: {} - {}",
                        max_batch_retries,
                        status,
                        body
                    );
                    return Err(anyhow!("Embedding API error: {} - {}", status, body));
                }

                let backoff_secs = (1u64 << attempt.min(5)).min(max_backoff_secs);
                tracing::warn!(
                    "Embedding API error (attempt {}/{}), retrying in {}s: {} - {}",
                    attempt,
                    max_batch_retries,
                    backoff_secs,
                    status,
                    body
                );
                tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
                continue;
            }

            // Parse errors are also retried; some servers return transient
            // malformed bodies under load.
            let embedding_response: EmbeddingResponse = match response.json().await {
                Ok(r) => r,
                Err(e) => {
                    if attempt >= max_batch_retries {
                        return Err(anyhow!("Failed to parse embedding response: {}", e));
                    }
                    let backoff_secs = (1u64 << attempt.min(5)).min(max_backoff_secs);
                    tracing::warn!(
                        "Failed to parse response (attempt {}/{}), retrying in {}s: {}",
                        attempt,
                        max_batch_retries,
                        backoff_secs,
                        e
                    );
                    tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
                    continue;
                }
            };

            let embeddings: Vec<Vec<f32>> = embedding_response
                .data
                .into_iter()
                .map(|d| d.embedding)
                .collect();

            if embeddings.len() != texts.len() {
                return Err(anyhow!(
                    "Embedding count mismatch: got {} embeddings for {} texts",
                    embeddings.len(),
                    texts.len()
                ));
            }

            // Only the first vector's length is checked here; the single-text
            // path in `embed()` checks every vector it returns.
            if let Some(first) = embeddings.first()
                && first.len() != self.required_dimension
            {
                return Err(anyhow!(
                    "Dimension mismatch: expected {}, got {}",
                    self.required_dimension,
                    first.len()
                ));
            }

            return Ok(embeddings);
        }
    }
1003
    /// Score `documents` against `query` via the configured reranker.
    ///
    /// Returns `(index, score)` pairs, where `index` refers into `documents`,
    /// in whatever order the reranker returned them. Fails when no reranker
    /// is configured or the request/parse fails.
    pub async fn rerank(&mut self, query: &str, documents: &[String]) -> Result<Vec<(usize, f32)>> {
        let reranker_url = self.reranker_url.as_ref().ok_or_else(|| {
            anyhow!("Reranker not configured. Add [embeddings.reranker] to config.")
        })?;
        let reranker_model = self
            .reranker_model
            .as_ref()
            .ok_or_else(|| anyhow!("Reranker model not configured."))?;

        // Log only a short preview of the query.
        let query_preview: String = query.chars().take(100).collect();
        tracing::debug!(
            "Reranking {} documents for query: {}{}",
            documents.len(),
            query_preview,
            if query.chars().count() > 100 {
                "..."
            } else {
                ""
            }
        );

        let request = RerankRequest {
            query: query.to_string(),
            documents: documents.to_vec(),
            model: reranker_model.clone(),
        };

        let response = match self.client.post(reranker_url).json(&request).send().await {
            Ok(resp) => resp,
            Err(e) => {
                tracing::error!(
                    "Rerank request failed: {:?}\n URL: {}\n Model: {}\n Query: {}\n Documents: {}",
                    e,
                    reranker_url,
                    reranker_model,
                    query_preview,
                    documents.len()
                );
                return Err(anyhow!("Rerank request failed: {}", e));
            }
        };

        // Read the body as text first so it can appear in error logs.
        let status = response.status();
        let response_text = response.text().await.unwrap_or_else(|e| {
            tracing::warn!("Failed to read rerank response body: {:?}", e);
            "<failed to read body>".to_string()
        });

        if !status.is_success() {
            tracing::error!(
                "Rerank API error (HTTP {}):\n URL: {}\n Model: {}\n Response: {}",
                status,
                reranker_url,
                reranker_model,
                response_text
            );
            return Err(anyhow!(
                "Rerank API error (HTTP {}): {}",
                status,
                response_text
            ));
        }

        let parsed: RerankResponse = match serde_json::from_str(&response_text) {
            Ok(r) => r,
            Err(e) => {
                tracing::error!(
                    "Failed to parse rerank response: {:?}\n Response body: {}",
                    e,
                    response_text
                );
                return Err(anyhow!("Failed to parse rerank response: {}", e));
            }
        };

        tracing::debug!("Rerank complete: {} documents scored", parsed.results.len());

        Ok(parsed
            .results
            .into_iter()
            .map(|r| (r.index, r.score))
            .collect())
    }
1087}
1088
/// Probe a provider by embedding a tiny fixed text and return the dimension
/// of the first vector it produces.
///
/// Validates that `base_url`, `endpoint`, and `model` are non-empty before
/// making any request, and uses a short 30s timeout so startup fails over
/// quickly. Errors include the URL, model, and a truncated response body.
pub(crate) async fn probe_provider_dimension(
    client: &Client,
    provider: &ProviderConfig,
) -> Result<usize> {
    let base_url = provider.base_url.trim_end_matches('/');
    if base_url.is_empty() {
        return Err(anyhow!("provider base_url is empty"));
    }

    let endpoint = provider.endpoint.trim();
    if endpoint.is_empty() {
        return Err(anyhow!("provider endpoint is empty"));
    }

    let model = provider.model.trim();
    if model.is_empty() {
        return Err(anyhow!("provider model is empty"));
    }

    let embedder_url = build_provider_endpoint(base_url, endpoint);
    let request = EmbeddingRequest {
        input: vec!["dimension probe".to_string()],
        model: model.to_string(),
    };

    let response = client
        .post(&embedder_url)
        .json(&request)
        .timeout(Duration::from_secs(30))
        .send()
        .await
        .map_err(|e| anyhow!("POST {} failed: {}", embedder_url, e))?;

    let status = response.status();
    let body = response.text().await.unwrap_or_default();
    if !status.is_success() {
        // 404 usually means a wrong endpoint path; give a targeted hint.
        let hint = if status.as_u16() == 404 {
            " Check provider.endpoint; Ollama and OpenAI-compatible servers typically use /v1/embeddings."
        } else {
            ""
        };
        return Err(anyhow!(
            "POST {} returned {} for model '{}': {}{}",
            embedder_url,
            status,
            model,
            body.chars().take(300).collect::<String>(),
            hint
        ));
    }

    let embed_response: EmbeddingResponse = serde_json::from_str(&body).map_err(|e| {
        anyhow!(
            "POST {} returned non-embedding JSON for model '{}': {} (body: {})",
            embedder_url,
            model,
            e,
            body.chars().take(200).collect::<String>()
        )
    })?;

    embed_response
        .data
        .first()
        .map(|d| d.embedding.len())
        .ok_or_else(|| {
            anyhow!(
                "POST {} returned no embeddings for model '{}'",
                embedder_url,
                model
            )
        })
}
1162
/// Truncate `text` to at most `max_chars` characters, preferring to cut at a
/// sentence end ('.', '!', '?', newline) and falling back to whitespace.
///
/// A sentence cut is only accepted past the halfway point of the window, so
/// more than half the budget is never discarded. Returns the text unchanged
/// when it already fits. Operates on char counts, so multi-byte UTF-8 is cut
/// safely.
fn truncate_at_boundary(text: &str, max_chars: usize) -> String {
    if text.chars().count() <= max_chars {
        return text.to_string();
    }

    // Byte offset of the max_chars-th character (a safe UTF-8 cut point).
    let cut = text
        .char_indices()
        .nth(max_chars)
        .map_or(text.len(), |(idx, _)| idx);
    let window = &text[..cut];

    // Byte offset of the halfway character; sentence cuts at or before this
    // point are rejected as too aggressive.
    let midpoint = text
        .char_indices()
        .nth(max_chars / 2)
        .map_or(0, |(idx, _)| idx);

    if let Some(pos) = window.rfind(['.', '!', '?', '\n']) {
        if pos > midpoint {
            // Keep the punctuation character itself.
            return text[..=pos].to_string();
        }
    }

    // Otherwise break on the last whitespace, dropping the partial word.
    match window.rfind([' ', '\t', '\n']) {
        Some(pos) => text[..pos].to_string(),
        None => window.to_string(),
    }
}
1200
/// Default token budget used by `TokenConfig`.
pub const DEFAULT_MAX_TOKENS: usize = 35_000;
1216
/// Heuristic token-budget settings for chunk validation and truncation.
#[derive(Debug, Clone)]
pub struct TokenConfig {
    /// Maximum estimated tokens allowed per chunk.
    pub max_tokens: usize,
    /// Average characters per token used by the estimate (lower = more
    /// conservative; e.g. 4.0 for English, 2.5 for multilingual text).
    pub chars_per_token: f32,
}
1226
1227impl Default for TokenConfig {
1228 fn default() -> Self {
1229 Self {
1230 max_tokens: DEFAULT_MAX_TOKENS,
1231 chars_per_token: 3.0,
1232 }
1233 }
1234}
1235
1236impl TokenConfig {
1237 pub fn english() -> Self {
1239 Self {
1240 max_tokens: DEFAULT_MAX_TOKENS,
1241 chars_per_token: 4.0,
1242 }
1243 }
1244
1245 pub fn for_multilingual_text() -> Self {
1247 Self {
1248 max_tokens: DEFAULT_MAX_TOKENS,
1249 chars_per_token: 2.5,
1250 }
1251 }
1252
1253 pub fn with_max_tokens(mut self, max: usize) -> Self {
1255 self.max_tokens = max;
1256 self
1257 }
1258}
1259
1260pub fn estimate_tokens(text: &str, config: &TokenConfig) -> usize {
1265 let char_count = text.chars().count();
1266 (char_count as f32 / config.chars_per_token).ceil() as usize
1267}
1268
1269pub fn validate_chunk_tokens(chunk: &str, config: &TokenConfig) -> Result<()> {
1273 let estimated = estimate_tokens(chunk, config);
1274
1275 if estimated > config.max_tokens {
1276 return Err(anyhow!(
1277 "Chunk exceeds token limit: ~{} tokens > {} max (text: {} chars). \
1278 Consider reducing chunk_size or enabling truncation.",
1279 estimated,
1280 config.max_tokens,
1281 chunk.chars().count()
1282 ));
1283 }
1284
1285 Ok(())
1286}
1287
1288pub fn safe_chunk_size(config: &TokenConfig) -> usize {
1290 let safe_tokens = (config.max_tokens as f32 * 0.8) as usize;
1292 (safe_tokens as f32 * config.chars_per_token) as usize
1293}
1294
1295pub fn truncate_to_token_limit(text: &str, config: &TokenConfig) -> String {
1297 let safe_chars = safe_chunk_size(config);
1298
1299 if text.chars().count() <= safe_chars {
1300 return text.to_string();
1301 }
1302
1303 truncate_at_boundary(text, safe_chars)
1304}
1305
1306pub fn validate_batch_tokens(texts: &[String], config: &TokenConfig) -> Vec<(usize, usize)> {
1308 texts
1309 .iter()
1310 .enumerate()
1311 .filter_map(|(idx, text)| {
1312 let estimated = estimate_tokens(text, config);
1313 if estimated > config.max_tokens {
1314 Some((idx, estimated))
1315 } else {
1316 None
1317 }
1318 })
1319 .collect()
1320}
1321
1322#[cfg(test)]
1323mod tests {
1324 use super::*;
1325 use axum::{Json, Router, extract::State, routing::post};
1326 use serde_json::json;
1327
    /// Axum handler returning a single constant embedding of `dim` values,
    /// mimicking an OpenAI-compatible embeddings response.
    async fn mock_embeddings(State(dim): State<usize>) -> Json<serde_json::Value> {
        Json(json!({
            "data": [{
                "embedding": vec![0.25_f32; dim]
            }]
        }))
    }
1335
    /// Start a mock embeddings server on an ephemeral port and return its
    /// base URL. The server task is detached and lives for the test duration.
    async fn spawn_mock_embedding_server(dim: usize) -> String {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let app = Router::new()
            .route("/v1/embeddings", post(mock_embeddings))
            .with_state(dim);

        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });

        // Brief pause to let the server task start accepting connections.
        tokio::time::sleep(Duration::from_millis(10)).await;

        format!("http://{}", addr)
    }
1351
    /// Providers must be tried in ascending `priority` order (lower first).
    #[test]
    fn test_provider_sorting() {
        let mut providers = [
            ProviderConfig {
                name: "low".into(),
                base_url: "http://a".into(),
                model: "m".into(),
                priority: 10,
                endpoint: "/v1/embeddings".into(),
            },
            ProviderConfig {
                name: "high".into(),
                base_url: "http://b".into(),
                model: "m".into(),
                priority: 1,
                endpoint: "/v1/embeddings".into(),
            },
        ];
        providers.sort_by_key(|p| p.priority);
        assert_eq!(providers[0].name, "high");
        assert_eq!(providers[1].name, "low");
    }
1374
    /// Legacy MlxConfig converts to a two-provider EmbeddingConfig with a
    /// reranker and preserved batch limits.
    #[test]
    fn test_legacy_conversion() {
        let legacy = MlxConfig {
            disabled: false,
            local_port: 12345,
            dragon_url: "http://dragon".into(),
            dragon_port: 12345,
            embedder_model: "test-model".into(),
            reranker_model: "rerank-model".into(),
            reranker_port_offset: 1,
            max_batch_chars: 32000,
            max_batch_items: 16,
        };
        let config = legacy.to_embedding_config();
        assert_eq!(config.providers.len(), 2);
        assert_eq!(config.providers[0].base_url, "http://localhost:12345");
        assert!(config.reranker.base_url.is_some());
        assert_eq!(config.max_batch_chars, 32000);
        assert_eq!(config.max_batch_items, 16);
    }
1395
1396 #[test]
1397 fn test_default_config() {
1398 let config = EmbeddingConfig::default();
1399 assert_eq!(config.required_dimension, DEFAULT_REQUIRED_DIMENSION);
1400 assert_eq!(config.max_batch_chars, 128000); assert_eq!(config.max_batch_items, 64); assert!(!config.providers.is_empty());
1403 assert_eq!(config.providers[0].model, DEFAULT_OLLAMA_EMBEDDING_MODEL);
1404 }
1405
1406 #[tokio::test]
1407 async fn test_probe_provider_dimension_reads_actual_dimension() {
1408 let base_url = spawn_mock_embedding_server(2560).await;
1409 let client = Client::new();
1410 let provider = ProviderConfig {
1411 name: "mock".into(),
1412 base_url,
1413 model: "mock-embedder".into(),
1414 priority: 1,
1415 endpoint: "/v1/embeddings".into(),
1416 };
1417
1418 let dim = probe_provider_dimension(&client, &provider).await.unwrap();
1419 assert_eq!(dim, 2560);
1420 }
1421
1422 #[tokio::test]
1423 async fn test_embedding_client_fails_fast_on_dimension_mismatch() {
1424 let base_url = spawn_mock_embedding_server(2560).await;
1425 let config = EmbeddingConfig {
1426 required_dimension: 1024,
1427 providers: vec![ProviderConfig {
1428 name: "mock".into(),
1429 base_url,
1430 model: "mock-embedder".into(),
1431 priority: 1,
1432 endpoint: "/v1/embeddings".into(),
1433 }],
1434 ..EmbeddingConfig::default()
1435 };
1436
1437 let err = EmbeddingClient::new(&config)
1438 .await
1439 .err()
1440 .expect("dimension mismatch should fail")
1441 .to_string();
1442 assert!(err.contains("returned 2560 dims"));
1443 assert!(err.contains("required_dimension=1024"));
1444 }
1445
1446 #[test]
1447 fn test_truncate_at_boundary() {
1448 let text = "Hello world. This is a test.";
1450 let truncated = truncate_at_boundary(text, 15);
1451 assert_eq!(truncated, "Hello world.");
1452
1453 let text = "Hello world this is a test";
1455 let truncated = truncate_at_boundary(text, 15);
1456 assert_eq!(truncated, "Hello world");
1457
1458 let text = "Short text";
1460 let truncated = truncate_at_boundary(text, 100);
1461 assert_eq!(truncated, "Short text");
1462 }
1463
1464 #[test]
1465 fn test_token_estimation() {
1466 let config = TokenConfig::default();
1467
1468 let text = "Hello world"; let tokens = estimate_tokens(text, &config);
1471 assert!((3..=5).contains(&tokens));
1472
1473 let english_config = TokenConfig::english();
1475 let tokens = estimate_tokens(text, &english_config);
1476 assert!((2..=4).contains(&tokens));
1477 }
1478
1479 #[test]
1480 fn default_token_ceiling_stays_above_long_transcript_floor() {
1481 let config = TokenConfig::default();
1482
1483 assert_eq!(DEFAULT_MAX_TOKENS, 35_000);
1484 assert_eq!(config.max_tokens, DEFAULT_MAX_TOKENS);
1485 assert!(config.max_tokens >= 35_000);
1486 }
1487
1488 #[test]
1489 fn test_chunk_validation() {
1490 let config = TokenConfig::default().with_max_tokens(100);
1491
1492 let short = "Hello world";
1494 assert!(validate_chunk_tokens(short, &config).is_ok());
1495
1496 let long = "a".repeat(1000); assert!(validate_chunk_tokens(&long, &config).is_err());
1499 }
1500
1501 #[test]
1502 fn test_safe_chunk_size() {
1503 let config = TokenConfig::default(); let safe = safe_chunk_size(&config);
1506 assert!(safe > 80_000 && safe < 90_000);
1508 }
1509
1510 #[test]
1511 fn test_batch_validation() {
1512 let config = TokenConfig::default().with_max_tokens(10);
1513
1514 let texts = vec![
1515 "short".to_string(), "a".repeat(100), "also short".to_string(), "b".repeat(200), ];
1520
1521 let failures = validate_batch_tokens(&texts, &config);
1522 assert_eq!(failures.len(), 2);
1523 assert_eq!(failures[0].0, 1); assert_eq!(failures[1].0, 3); }
1526}
1527
/// Converts embedding vectors between dimensionalities so that vectors
/// produced at one dimension can be compared against data stored at another.
#[derive(Debug, Clone)]
pub struct DimensionAdapter {
    // Dimension the embeddings are produced at.
    pub source_dim: usize,
    // Dimension the embeddings must be converted to.
    pub target_dim: usize,
}

impl DimensionAdapter {
    /// Creates an adapter that maps `source_dim`-sized vectors to `target_dim`.
    pub fn new(source_dim: usize, target_dim: usize) -> Self {
        Self {
            source_dim,
            target_dim,
        }
    }

    /// True when `adapt` would actually change a vector's length.
    pub fn needs_adaptation(&self) -> bool {
        self.source_dim != self.target_dim
    }

    /// Resizes `embedding` to `target_dim`, picking expansion or contraction
    /// from the vector's actual length (not `source_dim`).
    pub fn adapt(&self, embedding: Vec<f32>) -> Vec<f32> {
        let len = embedding.len();
        if len == self.target_dim {
            embedding
        } else if len > self.target_dim {
            self.contract(embedding)
        } else {
            self.expand(embedding)
        }
    }

    /// Zero-pads `embedding` up to `target_dim`, then re-normalizes to unit
    /// length. An input that is already `target_dim` or longer is truncated
    /// instead (without normalization).
    pub fn expand(&self, embedding: Vec<f32>) -> Vec<f32> {
        if embedding.len() >= self.target_dim {
            return embedding[..self.target_dim].to_vec();
        }

        let deficit = self.target_dim - embedding.len();
        let mut padded = embedding;
        padded.extend(std::iter::repeat(0.0).take(deficit));

        self.normalize(&mut padded);
        padded
    }

    /// Shrinks `embedding` down to `target_dim`. Uses block-averaging (with
    /// re-normalization) when the sizes permit a clean even split, otherwise
    /// falls back to keeping the leading `target_dim` components.
    pub fn contract(&self, embedding: Vec<f32>) -> Vec<f32> {
        if embedding.len() <= self.target_dim {
            return embedding;
        }

        if self.is_power_of_two_reduction(embedding.len()) {
            self.average_reduction(embedding)
        } else {
            // Lossy fallback: plain truncation.
            embedding[..self.target_dim].to_vec()
        }
    }

    /// Whether `source_len` can be reduced to `target_dim` by averaging
    /// fixed-size blocks: both lengths are powers of two and divide evenly.
    fn is_power_of_two_reduction(&self, source_len: usize) -> bool {
        if source_len <= self.target_dim {
            return false;
        }
        source_len.is_power_of_two()
            && self.target_dim.is_power_of_two()
            && source_len % self.target_dim == 0
    }

    /// Averages consecutive blocks of `factor` components into one, then
    /// re-normalizes the result to unit length.
    fn average_reduction(&self, embedding: Vec<f32>) -> Vec<f32> {
        let factor = embedding.len() / self.target_dim;
        let mut reduced: Vec<f32> = embedding
            .chunks(factor)
            .map(|block| block.iter().sum::<f32>() / factor as f32)
            .collect();

        self.normalize(&mut reduced);
        reduced
    }

    /// Scales `vec` to unit L2 norm in place; near-zero vectors are left
    /// untouched to avoid dividing by (almost) zero.
    fn normalize(&self, vec: &mut [f32]) {
        let norm = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-10 {
            vec.iter_mut().for_each(|v| *v /= norm);
        }
    }
}
1652
1653pub fn cross_dimension_search_adapt(query_embedding: Vec<f32>, target_dim: usize) -> Vec<f32> {
1655 let adapter = DimensionAdapter::new(query_embedding.len(), target_dim);
1656 adapter.adapt(query_embedding)
1657}
1658
#[cfg(test)]
mod dimension_adapter_tests {
    use super::*;

    /// Euclidean (L2) norm of a vector — test helper.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    #[test]
    fn test_expand_1024_to_4096() {
        let adapter = DimensionAdapter::new(1024, 4096);
        let expanded = adapter.expand(vec![0.1f32; 1024]);

        assert_eq!(expanded.len(), 4096);
        // Original components survive normalization...
        assert!(expanded[0].abs() > 1e-10);
        // ...while the zero padding stays zero.
        assert!(expanded[4095].abs() < 1e-10);
    }

    #[test]
    fn test_contract_4096_to_1024() {
        let adapter = DimensionAdapter::new(4096, 1024);
        let contracted = adapter.contract(vec![0.1f32; 4096]);

        assert_eq!(contracted.len(), 1024);
        // Average reduction re-normalizes to unit length.
        assert!((l2_norm(&contracted) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_adapt_auto_detect() {
        // Smaller-than-target inputs are expanded...
        let result = DimensionAdapter::new(1024, 4096).adapt(vec![0.1f32; 1024]);
        assert_eq!(result.len(), 4096);

        // ...and larger-than-target inputs are contracted.
        let result = DimensionAdapter::new(4096, 1024).adapt(vec![0.1f32; 4096]);
        assert_eq!(result.len(), 1024);
    }

    #[test]
    fn test_no_adaptation_needed() {
        let adapter = DimensionAdapter::new(4096, 4096);
        assert!(!adapter.needs_adaptation());

        let embedding = vec![0.1f32; 4096];
        assert_eq!(adapter.adapt(embedding.clone()), embedding);
    }

    #[test]
    fn test_average_reduction_preserves_info() {
        let adapter = DimensionAdapter::new(4096, 2048);

        let large: Vec<f32> = (0..4096).map(|i| i as f32 / 4096.0).collect();
        let contracted = adapter.contract(large);

        assert_eq!(contracted.len(), 2048);
        assert!((l2_norm(&contracted) - 1.0).abs() < 1e-5);
    }
}