1use crate::cli::EmbeddingsCommands;
10use crate::config::resolve_db_path;
11use crate::embeddings::{
12 chunk_text, create_embedding_provider, detect_available_providers, get_embedding_settings,
13 is_embeddings_enabled, prepare_item_text, reset_embedding_settings, save_embedding_settings,
14 ChunkConfig, EmbeddingProviderType, EmbeddingSettings,
15};
16use crate::error::{Error, Result};
17use crate::storage::SqliteStorage;
18use serde::Serialize;
19use std::path::PathBuf;
20
/// JSON payload printed by `sc embeddings status --json`.
///
/// Field order is the JSON key order (serde serializes in declaration order).
#[derive(Serialize)]
struct StatusOutput {
    /// Result of `is_embeddings_enabled()`.
    enabled: bool,
    /// Provider stored in the embedding settings, if any (serialized as `null` when unset).
    configured_provider: Option<String>,
    /// One entry per provider the status command checks (Ollama, HuggingFace).
    available_providers: Vec<ProviderStatus>,
    /// Provider that would be used right now; `None` when embeddings are disabled
    /// or no provider could be created.
    active_provider: Option<ActiveProviderInfo>,
    /// Embedding counts for items; omitted from the JSON entirely when the
    /// database could not be located, opened, or queried.
    #[serde(skip_serializing_if = "Option::is_none")]
    stats: Option<EmbeddingStatsOutput>,
}
31
/// Item counts reported in the `stats` section of the status output.
#[derive(Serialize)]
struct EmbeddingStatsOutput {
    /// Items that have at least one stored embedding.
    items_with_embeddings: usize,
    /// Items that have no embedding yet.
    items_without_embeddings: usize,
    /// Sum of the two counts above.
    total_items: usize,
}
38
/// Detection result for a single embedding provider in the status output.
#[derive(Serialize)]
struct ProviderStatus {
    /// Provider identifier, e.g. "ollama" or "huggingface".
    name: String,
    /// Whether `detect_available_providers` reported this provider as available.
    available: bool,
    /// Model that would be used (configured override or built-in default);
    /// only filled in when the provider is available.
    model: Option<String>,
    /// Embedding dimension of the provider's default model; only filled in when
    /// the provider is available.
    dimensions: Option<usize>,
}
46
/// Information about the provider returned by `create_embedding_provider`,
/// copied from its `info()`.
#[derive(Serialize)]
struct ActiveProviderInfo {
    /// Provider name as reported by the provider itself.
    name: String,
    /// Model the provider is using.
    model: String,
    /// Embedding vector length reported by the provider.
    dimensions: usize,
    /// Maximum input length reported by the provider (`max_chars`).
    max_chars: usize,
}
54
/// JSON payload printed by `sc embeddings test --json`.
///
/// Field order is the JSON key order (serde serializes in declaration order).
#[derive(Serialize)]
struct TestOutput {
    /// Whether the test embedding was generated.
    success: bool,
    /// Name of the provider that was used.
    provider: String,
    /// Model of the provider that was used.
    model: String,
    /// Length of the generated embedding (0 on failure).
    dimensions: usize,
    /// The text that was embedded.
    input_text: String,
    /// First few values of the embedding (empty on failure).
    embedding_sample: Vec<f32>,
    /// Error message on failure; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
}
67
/// JSON payload printed by `sc embeddings configure --json` after saving.
#[derive(Serialize)]
struct ConfigureOutput {
    /// Always `true` when this payload is printed (errors return early instead).
    success: bool,
    /// Comma-separated summary of what was changed.
    message: String,
    /// The full settings as they were saved.
    settings: EmbeddingSettings,
}
75
/// JSON summary printed by `sc embeddings backfill --json`.
#[derive(Serialize)]
struct BackfillOutput {
    /// Items for which at least one chunk embedding was stored.
    processed: usize,
    /// Items whose text produced no chunks.
    skipped: usize,
    /// Failure count accumulated by the backfill loop (see `execute_backfill`
    /// for exactly what is counted).
    errors: usize,
    /// Name of the provider used.
    provider: String,
    /// Model used.
    model: String,
}
85
/// JSON summary printed by `sc embeddings upgrade-quality --json`.
#[derive(Serialize)]
struct UpgradeQualityOutput {
    /// Items for which at least one chunk embedding was stored.
    upgraded: usize,
    /// Items whose text produced no chunks.
    skipped: usize,
    /// Failure count accumulated by the upgrade loop (see
    /// `execute_upgrade_quality` for exactly what is counted).
    errors: usize,
    /// Name of the provider used.
    provider: String,
    /// Model used.
    model: String,
    /// Number of items returned by `get_items_needing_quality_upgrade`.
    total_eligible: usize,
}
96
97pub fn execute(command: EmbeddingsCommands, db_path: Option<&PathBuf>, json: bool) -> Result<()> {
99 let rt = tokio::runtime::Runtime::new()
101 .map_err(|e| Error::Other(format!("Failed to create async runtime: {e}")))?;
102
103 rt.block_on(async { execute_async(command, db_path, json).await })
104}
105
106async fn execute_async(command: EmbeddingsCommands, db_path: Option<&PathBuf>, json: bool) -> Result<()> {
107 match command {
108 EmbeddingsCommands::Status => execute_status(db_path, json).await,
109 EmbeddingsCommands::Configure {
110 provider,
111 enable,
112 disable,
113 model,
114 endpoint,
115 token,
116 } => execute_configure(db_path, provider, enable, disable, model, endpoint, token, json).await,
117 EmbeddingsCommands::Backfill {
118 limit,
119 session,
120 force,
121 } => execute_backfill(db_path, limit, session, force, json).await,
122 EmbeddingsCommands::Test { text } => execute_test(&text, json).await,
123 EmbeddingsCommands::ProcessPending { limit, quiet } => {
124 execute_process_pending(db_path, limit, quiet).await
125 }
126 EmbeddingsCommands::UpgradeQuality { limit, session } => {
127 execute_upgrade_quality(db_path, limit, session, json).await
128 }
129 }
130}
131
132async fn execute_status(db_path: Option<&PathBuf>, json: bool) -> Result<()> {
134 let enabled = is_embeddings_enabled();
135 let settings = get_embedding_settings().unwrap_or_default();
136 let detection = detect_available_providers().await;
137
138 let stats = if let Some(path) = resolve_db_path(db_path.map(|p| p.as_path())) {
140 if path.exists() {
141 SqliteStorage::open(&path)
142 .ok()
143 .and_then(|storage| storage.count_embedding_status(None).ok())
144 .map(|s| EmbeddingStatsOutput {
145 items_with_embeddings: s.with_embeddings,
146 items_without_embeddings: s.without_embeddings,
147 total_items: s.with_embeddings + s.without_embeddings,
148 })
149 } else {
150 None
151 }
152 } else {
153 None
154 };
155
156 let active_provider = if enabled {
158 create_embedding_provider().await
159 } else {
160 None
161 };
162
163 let configured_provider = settings
164 .as_ref()
165 .and_then(|s| s.provider.as_ref())
166 .map(|p| p.to_string());
167
168 let mut providers = Vec::new();
170
171 let ollama_available = detection.available.contains(&"ollama".to_string());
173 providers.push(ProviderStatus {
174 name: "ollama".to_string(),
175 available: ollama_available,
176 model: if ollama_available {
177 Some(
178 settings
179 .as_ref()
180 .and_then(|s| s.OLLAMA_MODEL.clone())
181 .unwrap_or_else(|| "nomic-embed-text".to_string()),
182 )
183 } else {
184 None
185 },
186 dimensions: if ollama_available { Some(768) } else { None },
187 });
188
189 let hf_available = detection.available.contains(&"huggingface".to_string());
191 providers.push(ProviderStatus {
192 name: "huggingface".to_string(),
193 available: hf_available,
194 model: if hf_available {
195 Some(
196 settings
197 .as_ref()
198 .and_then(|s| s.HF_MODEL.clone())
199 .unwrap_or_else(|| "sentence-transformers/all-MiniLM-L6-v2".to_string()),
200 )
201 } else {
202 None
203 },
204 dimensions: if hf_available { Some(384) } else { None },
205 });
206
207 let active_info = active_provider.as_ref().map(|p| {
208 let info = p.info();
209 ActiveProviderInfo {
210 name: info.name,
211 model: info.model,
212 dimensions: info.dimensions,
213 max_chars: info.max_chars,
214 }
215 });
216
217 if json {
218 let output = StatusOutput {
219 enabled,
220 configured_provider,
221 available_providers: providers,
222 active_provider: active_info,
223 stats,
224 };
225 println!("{}", serde_json::to_string(&output)?);
226 } else {
227 println!("Embeddings Status");
228 println!("=================");
229 println!();
230 println!("Enabled: {}", if enabled { "yes" } else { "no" });
231 if let Some(ref p) = configured_provider {
232 println!("Configured Provider: {p}");
233 }
234 println!();
235
236 println!("Available Providers:");
237 for p in &providers {
238 let status = if p.available { "✓" } else { "✗" };
239 print!(" {status} {}", p.name);
240 if let Some(ref m) = p.model {
241 print!(" ({m})");
242 }
243 println!();
244 }
245 println!();
246
247 if let Some(ref active) = active_info {
248 println!("Active Provider:");
249 println!(" Name: {}", active.name);
250 println!(" Model: {}", active.model);
251 println!(" Dimensions: {}", active.dimensions);
252 println!(" Max Chars: {}", active.max_chars);
253 } else if enabled {
254 println!("No embedding provider available.");
255 println!();
256 println!("To enable embeddings:");
257 println!(" - Install Ollama: https://ollama.ai");
258 println!(" - Or set HF_TOKEN environment variable");
259 }
260
261 if let Some(ref s) = stats {
263 println!();
264 println!("Item Statistics:");
265 println!(" With embeddings: {}", s.items_with_embeddings);
266 println!(" Without embeddings: {}", s.items_without_embeddings);
267 println!(" Total items: {}", s.total_items);
268 if s.items_without_embeddings > 0 {
269 println!();
270 println!("Run 'sc embeddings backfill' to generate missing embeddings.");
271 }
272 }
273 }
274
275 Ok(())
276}
277
278#[allow(clippy::fn_params_excessive_bools)]
280async fn execute_configure(
281 db_path: Option<&PathBuf>,
282 provider: Option<String>,
283 enable: bool,
284 disable: bool,
285 model: Option<String>,
286 endpoint: Option<String>,
287 token: Option<String>,
288 json: bool,
289) -> Result<()> {
290 let mut settings = get_embedding_settings()
292 .unwrap_or_default()
293 .unwrap_or_default();
294
295 let mut changed = false;
296 let mut messages = Vec::new();
297
298 if enable && disable {
300 return Err(Error::InvalidArgument(
301 "Cannot specify both --enable and --disable".to_string(),
302 ));
303 }
304
305 if enable {
306 settings.enabled = Some(true);
307 messages.push("Embeddings enabled");
308 changed = true;
309 } else if disable {
310 settings.enabled = Some(false);
311 messages.push("Embeddings disabled");
312 changed = true;
313 }
314
315 if let Some(ref p) = provider {
317 let provider_type = match p.to_lowercase().as_str() {
318 "ollama" => EmbeddingProviderType::Ollama,
319 "huggingface" | "hf" => EmbeddingProviderType::Huggingface,
320 _ => {
321 return Err(Error::InvalidArgument(format!(
322 "Unknown provider: {p}. Valid options: ollama, huggingface"
323 )));
324 }
325 };
326 settings.provider = Some(provider_type);
327 messages.push("Provider configured");
328 changed = true;
329 }
330
331 if let Some(ref m) = model {
333 let provider_type = settings.provider.unwrap_or(EmbeddingProviderType::Ollama);
335 match provider_type {
336 EmbeddingProviderType::Ollama => {
337 settings.OLLAMA_MODEL = Some(m.clone());
338 }
339 EmbeddingProviderType::Huggingface => {
340 settings.HF_MODEL = Some(m.clone());
341 }
342 EmbeddingProviderType::Transformers => {
343 settings.TRANSFORMERS_MODEL = Some(m.clone());
344 }
345 EmbeddingProviderType::Model2vec => {
346 }
349 }
350 messages.push("Model configured");
351 changed = true;
352 }
353
354 if let Some(ref e) = endpoint {
356 let provider_type = settings.provider.unwrap_or(EmbeddingProviderType::Ollama);
357 match provider_type {
358 EmbeddingProviderType::Ollama => {
359 settings.OLLAMA_ENDPOINT = Some(e.clone());
360 }
361 EmbeddingProviderType::Huggingface => {
362 settings.HF_ENDPOINT = Some(e.clone());
363 }
364 _ => {}
365 }
366 messages.push("Endpoint configured");
367 changed = true;
368 }
369
370 if let Some(ref t) = token {
372 settings.HF_TOKEN = Some(t.clone());
373 messages.push("Token configured");
374 changed = true;
375 }
376
377 if !changed {
378 return execute_status(db_path, json).await;
380 }
381
382 save_embedding_settings(&settings)?;
384
385 let message = messages.join(", ");
386
387 if json {
388 let output = ConfigureOutput {
389 success: true,
390 message,
391 settings,
392 };
393 println!("{}", serde_json::to_string(&output)?);
394 } else {
395 println!("Configuration updated: {message}");
396 println!();
397 execute_status(db_path, false).await?;
398 }
399
400 Ok(())
401}
402
403async fn execute_backfill(
411 db_path: Option<&PathBuf>,
412 limit: Option<usize>,
413 session: Option<String>,
414 force: bool,
415 json: bool,
416) -> Result<()> {
417 let db_path = resolve_db_path(db_path.map(|p| p.as_path())).ok_or(Error::NotInitialized)?;
419
420 let provider = create_embedding_provider()
422 .await
423 .ok_or_else(|| Error::Embedding("No embedding provider available".to_string()))?;
424
425 let info = provider.info();
426 let provider_name = info.name.clone();
427 let model_name = info.model.clone();
428
429 let chunk_config = if provider_name.to_lowercase().contains("ollama") {
431 ChunkConfig::for_ollama()
432 } else {
433 ChunkConfig::for_minilm()
434 };
435
436 let mut storage = SqliteStorage::open(&db_path)?;
438
439 let items = if force {
441 storage.get_items_without_embeddings(session.as_deref(), Some(limit.unwrap_or(1000) as u32))?
443 } else {
444 storage.get_items_without_embeddings(session.as_deref(), Some(limit.unwrap_or(1000) as u32))?
446 };
447
448 if items.is_empty() {
449 if json {
450 let output = BackfillOutput {
451 processed: 0,
452 skipped: 0,
453 errors: 0,
454 provider: provider_name,
455 model: model_name,
456 };
457 println!("{}", serde_json::to_string(&output)?);
458 } else {
459 println!("No items to process.");
460 println!("All context items already have embeddings.");
461 }
462 return Ok(());
463 }
464
465 let total_items = items.len();
466 let mut processed = 0;
467 let mut skipped = 0;
468 let mut errors = 0;
469
470 if !json {
471 println!("Backfilling embeddings for {} items...", total_items);
472 println!("Provider: {} ({})", provider_name, model_name);
473 println!();
474 }
475
476 for item in items {
477 let text = prepare_item_text(&item.key, &item.value, Some(&item.category));
479
480 let chunks = chunk_text(&text, &chunk_config);
482
483 if chunks.is_empty() {
484 skipped += 1;
485 continue;
486 }
487
488 let mut chunk_errors = 0;
490 for (chunk_idx, chunk) in chunks.iter().enumerate() {
491 match provider.generate_embedding(&chunk.text).await {
492 Ok(embedding) => {
493 let chunk_id = format!("emb_{}_{}", item.id, chunk_idx);
495
496 if let Err(e) = storage.store_embedding_chunk(
498 &chunk_id,
499 &item.id,
500 chunk_idx as i32,
501 &chunk.text,
502 &embedding,
503 &provider_name,
504 &model_name,
505 ) {
506 if !json {
507 eprintln!(" Error storing chunk {}: {}", chunk_idx, e);
508 }
509 chunk_errors += 1;
510 }
511 }
512 Err(e) => {
513 if !json {
514 eprintln!(" Error generating embedding for {}: {}", item.key, e);
515 }
516 chunk_errors += 1;
517 }
518 }
519 }
520
521 if chunk_errors == 0 {
522 processed += 1;
523 if !json {
524 println!(" ✓ {} ({} chunks)", item.key, chunks.len());
525 }
526 } else if chunk_errors < chunks.len() {
527 processed += 1;
529 errors += chunk_errors;
530 if !json {
531 println!(" ⚠ {} ({}/{} chunks)", item.key, chunks.len() - chunk_errors, chunks.len());
532 }
533 } else {
534 errors += 1;
536 if !json {
537 println!(" ✗ {}", item.key);
538 }
539 }
540 }
541
542 if json {
543 let output = BackfillOutput {
544 processed,
545 skipped,
546 errors,
547 provider: provider_name,
548 model: model_name,
549 };
550 println!("{}", serde_json::to_string(&output)?);
551 } else {
552 println!();
553 println!("Complete!");
554 println!(" Processed: {}", processed);
555 println!(" Skipped: {}", skipped);
556 println!(" Errors: {}", errors);
557 }
558
559 Ok(())
560}
561
562async fn execute_test(text: &str, json: bool) -> Result<()> {
564 let provider = create_embedding_provider()
565 .await
566 .ok_or_else(|| Error::Embedding("No embedding provider available".to_string()))?;
567
568 let info = provider.info();
569
570 let result = provider.generate_embedding(text).await;
572
573 match result {
574 Ok(embedding) => {
575 let sample: Vec<f32> = embedding.iter().take(5).copied().collect();
576
577 if json {
578 let output = TestOutput {
579 success: true,
580 provider: info.name,
581 model: info.model,
582 dimensions: embedding.len(),
583 input_text: text.to_string(),
584 embedding_sample: sample,
585 error: None,
586 };
587 println!("{}", serde_json::to_string(&output)?);
588 } else {
589 println!("Embedding Test: SUCCESS");
590 println!();
591 println!("Provider: {}", info.name);
592 println!("Model: {}", info.model);
593 println!("Dimensions: {}", embedding.len());
594 println!("Input: \"{text}\"");
595 println!();
596 println!("Sample (first 5 values):");
597 for (i, v) in sample.iter().enumerate() {
598 println!(" [{i}] {v:.6}");
599 }
600 }
601 }
602 Err(e) => {
603 if json {
604 let output = TestOutput {
605 success: false,
606 provider: info.name,
607 model: info.model,
608 dimensions: 0,
609 input_text: text.to_string(),
610 embedding_sample: vec![],
611 error: Some(e.to_string()),
612 };
613 println!("{}", serde_json::to_string(&output)?);
614 } else {
615 println!("Embedding Test: FAILED");
616 println!();
617 println!("Provider: {}", info.name);
618 println!("Model: {}", info.model);
619 println!("Error: {e}");
620 }
621 return Err(e);
622 }
623 }
624
625 Ok(())
626}
627
628async fn execute_process_pending(
633 db_path: Option<&PathBuf>,
634 limit: usize,
635 quiet: bool,
636) -> Result<()> {
637 let db_path = resolve_db_path(db_path.map(|p| p.as_path())).ok_or(Error::NotInitialized)?;
639
640 if !db_path.exists() {
641 return Ok(()); }
643
644 if !is_embeddings_enabled() {
646 return Ok(());
647 }
648
649 let provider = match create_embedding_provider().await {
651 Some(p) => p,
652 None => return Ok(()), };
654
655 let info = provider.info();
656 let provider_name = info.name.clone();
657 let model_name = info.model.clone();
658
659 let chunk_config = if provider_name.to_lowercase().contains("ollama") {
661 ChunkConfig::for_ollama()
662 } else {
663 ChunkConfig::for_minilm()
664 };
665
666 let mut storage = SqliteStorage::open(&db_path)?;
668
669 let items = storage.get_items_without_embeddings(None, Some(limit as u32))?;
671
672 if items.is_empty() {
673 return Ok(());
674 }
675
676 let mut processed = 0;
677
678 for item in items {
679 let text = prepare_item_text(&item.key, &item.value, Some(&item.category));
681
682 let chunks = chunk_text(&text, &chunk_config);
684
685 if chunks.is_empty() {
686 continue;
687 }
688
689 let mut success = true;
691 for (chunk_idx, chunk) in chunks.iter().enumerate() {
692 match provider.generate_embedding(&chunk.text).await {
693 Ok(embedding) => {
694 let chunk_id = format!("emb_{}_{}", item.id, chunk_idx);
695 if storage
696 .store_embedding_chunk(
697 &chunk_id,
698 &item.id,
699 chunk_idx as i32,
700 &chunk.text,
701 &embedding,
702 &provider_name,
703 &model_name,
704 )
705 .is_err()
706 {
707 success = false;
708 break;
709 }
710 }
711 Err(_) => {
712 success = false;
713 break;
714 }
715 }
716 }
717
718 if success {
719 processed += 1;
720 if !quiet {
721 eprintln!("[bg] Embedded: {} ({} chunks)", item.key, chunks.len());
722 }
723 }
724 }
725
726 if !quiet && processed > 0 {
727 eprintln!("[bg] Processed {} pending embeddings", processed);
728 }
729
730 Ok(())
731}
732
733pub fn spawn_background_embedder() {
739 use std::process::{Command, Stdio};
740
741 if !is_embeddings_enabled() {
743 return;
744 }
745
746 let exe = match std::env::current_exe() {
748 Ok(path) => path,
749 Err(_) => return, };
751
752 let _ = Command::new(exe)
754 .args(["embeddings", "process-pending", "--quiet"])
755 .stdin(Stdio::null())
756 .stdout(Stdio::null())
757 .stderr(Stdio::null())
758 .spawn();
759 }
761
762#[allow(dead_code)]
764pub fn reset_to_defaults() -> Result<()> {
765 reset_embedding_settings()?;
766 println!("Embedding settings reset to defaults.");
767 Ok(())
768}
769
770async fn execute_upgrade_quality(
779 db_path: Option<&PathBuf>,
780 limit: Option<usize>,
781 session: Option<String>,
782 json: bool,
783) -> Result<()> {
784 let db_path = resolve_db_path(db_path.map(|p| p.as_path())).ok_or(Error::NotInitialized)?;
786
787 if !db_path.exists() {
788 return Err(Error::NotInitialized);
789 }
790
791 let provider = create_embedding_provider()
793 .await
794 .ok_or_else(|| Error::Embedding("No quality embedding provider available. Install Ollama or set HF_TOKEN.".to_string()))?;
795
796 let info = provider.info();
797 let provider_name = info.name.clone();
798 let model_name = info.model.clone();
799
800 let chunk_config = if provider_name.to_lowercase().contains("ollama") {
802 ChunkConfig::for_ollama()
803 } else {
804 ChunkConfig::for_minilm()
805 };
806
807 let mut storage = SqliteStorage::open(&db_path)?;
809
810 let items = storage.get_items_needing_quality_upgrade(
812 session.as_deref(),
813 limit.map(|l| l as u32),
814 )?;
815
816 let total_eligible = items.len();
817
818 if items.is_empty() {
819 if json {
820 let output = UpgradeQualityOutput {
821 upgraded: 0,
822 skipped: 0,
823 errors: 0,
824 provider: provider_name,
825 model: model_name,
826 total_eligible: 0,
827 };
828 println!("{}", serde_json::to_string(&output)?);
829 } else {
830 println!("No items need quality upgrade.");
831 println!("All items with fast embeddings already have quality embeddings.");
832 }
833 return Ok(());
834 }
835
836 if !json {
837 println!("Upgrading {} items to quality embeddings...", total_eligible);
838 println!("Provider: {} ({})", provider_name, model_name);
839 println!();
840 }
841
842 let mut upgraded = 0;
843 let mut skipped = 0;
844 let mut errors = 0;
845
846 for item in items {
847 let text = prepare_item_text(&item.key, &item.value, Some(&item.category));
849
850 let chunks = chunk_text(&text, &chunk_config);
852
853 if chunks.is_empty() {
854 skipped += 1;
855 if !json {
856 println!(" - {} (no content)", item.key);
857 }
858 continue;
859 }
860
861 let mut chunk_errors = 0;
863 for (chunk_idx, chunk) in chunks.iter().enumerate() {
864 match provider.generate_embedding(&chunk.text).await {
865 Ok(embedding) => {
866 let chunk_id = format!("emb_{}_{}", item.id, chunk_idx);
868
869 if let Err(e) = storage.store_embedding_chunk(
871 &chunk_id,
872 &item.id,
873 chunk_idx as i32,
874 &chunk.text,
875 &embedding,
876 &provider_name,
877 &model_name,
878 ) {
879 if !json {
880 eprintln!(" Error storing chunk {}: {}", chunk_idx, e);
881 }
882 chunk_errors += 1;
883 }
884 }
885 Err(e) => {
886 if !json {
887 eprintln!(" Error generating embedding for {}: {}", item.key, e);
888 }
889 chunk_errors += 1;
890 }
891 }
892 }
893
894 if chunk_errors == 0 {
895 upgraded += 1;
896 if !json {
897 println!(" ✓ {} ({} chunks)", item.key, chunks.len());
898 }
899 } else if chunk_errors < chunks.len() {
900 upgraded += 1;
902 errors += chunk_errors;
903 if !json {
904 println!(" ⚠ {} ({}/{} chunks)", item.key, chunks.len() - chunk_errors, chunks.len());
905 }
906 } else {
907 errors += 1;
909 if !json {
910 println!(" ✗ {}", item.key);
911 }
912 }
913 }
914
915 if json {
916 let output = UpgradeQualityOutput {
917 upgraded,
918 skipped,
919 errors,
920 provider: provider_name,
921 model: model_name,
922 total_eligible,
923 };
924 println!("{}", serde_json::to_string(&output)?);
925 } else {
926 println!();
927 println!("Quality upgrade complete!");
928 println!(" Upgraded: {}", upgraded);
929 println!(" Skipped: {}", skipped);
930 println!(" Errors: {}", errors);
931 println!();
932 println!("Items now have both fast (instant) and quality (accurate) embeddings.");
933 }
934
935 Ok(())
936}
937
#[cfg(test)]
mod tests {
    use super::*;

    /// Exact JSON pins both key names and key order of a detected provider.
    #[test]
    fn test_provider_status_serialization() {
        let status = ProviderStatus {
            name: "ollama".to_string(),
            available: true,
            model: Some("nomic-embed-text".to_string()),
            dimensions: Some(768),
        };
        let json = serde_json::to_string(&status).unwrap();
        assert_eq!(
            json,
            r#"{"name":"ollama","available":true,"model":"nomic-embed-text","dimensions":768}"#
        );
    }

    /// Undetected providers keep `model`/`dimensions` keys, serialized as null.
    #[test]
    fn test_unavailable_provider_serializes_nulls() {
        let status = ProviderStatus {
            name: "huggingface".to_string(),
            available: false,
            model: None,
            dimensions: None,
        };
        let json = serde_json::to_string(&status).unwrap();
        assert_eq!(
            json,
            r#"{"name":"huggingface","available":false,"model":null,"dimensions":null}"#
        );
    }

    /// `stats` is skipped when None, but present (and nested) when Some.
    #[test]
    fn test_status_output_stats_serialization() {
        let without_stats = StatusOutput {
            enabled: false,
            configured_provider: None,
            available_providers: vec![],
            active_provider: None,
            stats: None,
        };
        assert_eq!(
            serde_json::to_string(&without_stats).unwrap(),
            r#"{"enabled":false,"configured_provider":null,"available_providers":[],"active_provider":null}"#
        );

        let with_stats = StatusOutput {
            enabled: true,
            configured_provider: Some("ollama".to_string()),
            available_providers: vec![],
            active_provider: None,
            stats: Some(EmbeddingStatsOutput {
                items_with_embeddings: 2,
                items_without_embeddings: 1,
                total_items: 3,
            }),
        };
        assert_eq!(
            serde_json::to_string(&with_stats).unwrap(),
            r#"{"enabled":true,"configured_provider":"ollama","available_providers":[],"active_provider":null,"stats":{"items_with_embeddings":2,"items_without_embeddings":1,"total_items":3}}"#
        );
    }

    /// `error` is omitted from a TestOutput when None.
    #[test]
    fn test_test_output_omits_missing_error() {
        let output = TestOutput {
            success: false,
            provider: "p".to_string(),
            model: "m".to_string(),
            dimensions: 0,
            input_text: "hi".to_string(),
            embedding_sample: vec![],
            error: None,
        };
        assert_eq!(
            serde_json::to_string(&output).unwrap(),
            r#"{"success":false,"provider":"p","model":"m","dimensions":0,"input_text":"hi","embedding_sample":[]}"#
        );
    }
}