1use crate::embeddings_generator::EmbeddingsGenerator;
3use crate::text_chunker::{chunk_text, ChunkerConfig};
4use crate::token_cleaner::clean_and_redact;
5use anyhow::Result;
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::HashSet;
9use std::path::{Path, PathBuf};
10use std::process::Command;
11use std::sync::Arc;
12use std::time::Instant;
13use tokio::fs;
14use tokio::sync::{Semaphore};
15use tokio::sync::{mpsc, oneshot};
16use std::sync::atomic::{AtomicUsize, Ordering};
17use std::sync::mpsc as std_mpsc;
18
/// Positional and file-level metadata recorded for a single chunk.
///
/// Serialized as part of every [`EmbeddedChunk`] in the JSON database.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkMetadata {
    /// Index of this chunk as reported by the chunker.
    pub chunk_index: usize,
    /// Number of chunks the chunker produced for the file this chunk came from.
    pub total_chunks: usize,
    /// Size of the source file in bytes, taken from filesystem metadata.
    pub file_size: u64,
    /// RFC 3339 modification time of the source file; `None` when the
    /// filesystem does not report one.
    pub last_modified: Option<String>,
    /// Start position reported by the chunker.
    /// NOTE(review): units (bytes vs. chars) are defined by `text_chunker` — confirm there.
    pub start_index: usize,
    /// End position reported by the chunker (same units as `start_index`).
    pub end_index: usize,
}
29
/// A chunk of file text together with its embedding vector, as stored in the
/// output database.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddedChunk {
    /// Path of the source file as listed by `git ls-files` (relative to the
    /// scanned directory).
    pub file_path: String,
    /// Chunk text, taken from the file content after `clean_and_redact`.
    pub content: String,
    /// Embedding vector produced for `content`.
    pub embedding: Vec<f32>,
    /// Position and file-level metadata for this chunk.
    pub metadata: ChunkMetadata,
}
38
/// A chunk that has been staged (read, cleaned, split) but not yet embedded.
///
/// Carries the same fields as [`EmbeddedChunk`] minus the embedding vector.
#[derive(Debug, Clone)]
struct PendingChunk {
    /// Path of the source file as listed by `git ls-files`.
    file_path: String,
    /// Chunk text that will be sent to the embedding workers.
    content: String,
    /// Position and file-level metadata for this chunk.
    metadata: ChunkMetadata,
}
46
/// Top-level document written to the JSON output file.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingsDatabase {
    /// Format version of the database document (the generator writes `"1.0"`).
    pub version: String,
    /// RFC 3339 timestamp (UTC) of when the database was generated.
    pub generated_at: String,
    /// Name of the embedding model recorded by the generator.
    pub model: String,
    /// Chunk size copied from the chunker configuration used for this run.
    pub chunk_size: usize,
    /// Overlap size copied from the chunker configuration used for this run.
    pub overlap_size: usize,
    /// Number of files selected for processing (after exclusions); files that
    /// failed to process are still counted.
    pub total_files: usize,
    /// Number of entries in `chunks`.
    pub total_chunks: usize,
    /// Every embedded chunk, in staging order.
    pub chunks: Vec<EmbeddedChunk>,
}
59
/// Configuration for [`JsonDatabaseGenerator`].
pub struct JsonDatabaseOptions {
    /// Directory in which `git ls-files` is run; listed paths are joined onto it.
    pub dir: PathBuf,
    /// Destination of the generated JSON database.
    pub output_file_path: PathBuf,
    /// File extensions to skip, written with a leading dot (e.g. `".png"`).
    pub file_type_exclusions: HashSet<String>,
    /// Glob patterns; a tracked file matching any of them is skipped.
    pub file_exclusions: Vec<String>,
    /// When `true`, progress, per-file errors and `[perf]` timings are printed.
    pub verbose: bool,
    /// Configuration handed to `chunk_text` for every file.
    pub chunker_config: ChunkerConfig,
    /// Number of permits of the semaphore that gates concurrent file staging.
    pub max_concurrent_files: usize,
    /// Number of embedding worker threads to start (at least one is always started).
    pub embedding_pool_size: usize,
    /// Batch size forwarded to the embedding backend; `None` leaves the choice
    /// to the backend.
    pub embedding_batch_size: Option<usize>,
}
75
76impl Default for JsonDatabaseOptions {
77 fn default() -> Self {
78 let cpu_count = std::thread::available_parallelism()
81 .map(|n| n.get())
82 .unwrap_or(4);
83 let default_pool = cpu_count.min(4).max(1);
84
85 Self {
86 dir: PathBuf::from("."),
87 output_file_path: PathBuf::from("embeddings.json"),
88 file_type_exclusions: Default::default(),
89 file_exclusions: Default::default(),
90 verbose: true,
91 chunker_config: ChunkerConfig::default(),
92 max_concurrent_files: 4,
93 embedding_pool_size: default_pool,
94 embedding_batch_size: None,
95 }
96 }
97}
98
/// Generates a JSON embeddings database from the git-tracked files of a directory.
pub struct JsonDatabaseGenerator {
    /// Run configuration supplied at construction time.
    options: JsonDatabaseOptions,
    /// Pool of worker threads that compute embeddings.
    embeddings_pool: EmbeddingPool,
}
104
105impl JsonDatabaseGenerator {
106 pub fn new(options: JsonDatabaseOptions) -> Result<Self> {
108 let embeddings_pool = EmbeddingPool::new(options.embedding_pool_size)?;
111
112 Ok(Self {
113 options,
114 embeddings_pool,
115 })
116 }
117
    /// Returns the git-tracked files that remain after exclusion filtering.
    ///
    /// Thin wrapper that delegates to `get_tracked_files_internal`.
    async fn get_tracked_files(&self) -> Result<Vec<String>> {
        self.get_tracked_files_internal().await
    }
122
123 async fn get_tracked_files_internal(&self) -> Result<Vec<String>> {
124 let output = Command::new("git")
126 .arg("ls-files")
127 .current_dir(&self.options.dir)
128 .output()?;
129
130 if !output.status.success() {
131 return Err(anyhow::anyhow!("git ls-files failed"));
132 }
133
134 let output_str = String::from_utf8(output.stdout)?;
135 let tracked_files: Vec<String> = output_str
136 .lines()
137 .filter(|line| !line.trim().is_empty())
138 .map(|s| s.to_string())
139 .collect();
140
141 if self.options.verbose {
142 println!("Total tracked files: {}", tracked_files.len());
143 }
144
145 let total_files = tracked_files.len();
146
147 let filtered_files = tracked_files
149 .into_iter()
150 .filter(|file| {
151 let path = Path::new(file);
152 let ext = path
153 .extension()
154 .and_then(|e| e.to_str())
155 .map(|e| format!(".{}", e))
156 .unwrap_or_default();
157
158 if self.options.file_type_exclusions.contains(&ext) {
160 return false;
161 }
162
163 !self.matches_exclusion_patterns(file)
165 })
166 .collect::<Vec<_>>();
167
168 if self.options.verbose {
169 println!("Excluded files: {}", total_files - filtered_files.len());
170 println!(
171 "Files to process for embeddings: {}",
172 filtered_files.len()
173 );
174 }
175
176 Ok(filtered_files)
177 }
178
179 fn matches_exclusion_patterns(&self, file: &str) -> bool {
180 for pattern in &self.options.file_exclusions {
181 if self.glob_match(pattern, file) {
182 return true;
183 }
184 }
185 false
186 }
187
188 fn glob_match(&self, pattern: &str, path: &str) -> bool {
189 use regex::Regex;
190 let pattern = pattern
191 .replace("**", ".*")
192 .replace("*", "[^/]*")
193 .replace("?", "[^/]");
194 let pattern = format!("^{}$", pattern);
195
196 if let Ok(re) = Regex::new(&pattern) {
197 re.is_match(path)
198 } else {
199 false
200 }
201 }
202
203 pub async fn generate_database(&self) -> Result<JsonDatabaseResult> {
205 let overall_start = Instant::now();
206 let tracked_files = self.get_tracked_files().await?;
207
208 if self.options.verbose {
209 println!("Generating embeddings for {} files", tracked_files.len());
210 println!("Processing with max {} concurrent files", self.options.max_concurrent_files);
211 }
212
213 let semaphore = Arc::new(Semaphore::new(self.options.max_concurrent_files));
215
216 let stage_start = Instant::now();
218 let mut tasks = Vec::new();
219 for (file_idx, file) in tracked_files.iter().enumerate() {
220 let absolute_path = self.options.dir.join(file);
221 let file = file.clone();
222 let semaphore = semaphore.clone();
223 let chunker_config = self.options.chunker_config.clone();
224 let verbose = self.options.verbose;
225 let total_files = tracked_files.len();
226
227 let task = tokio::spawn(async move {
228 let _permit = semaphore.acquire().await.unwrap();
230
231 if verbose {
232 println!("Processing file {}/{}: {}", file_idx + 1, total_files, file);
233 }
234
235 match Self::process_file_stage_chunks(&absolute_path, &file, &chunker_config, verbose).await {
236 Ok(chunks) => Ok(chunks),
237 Err(e) => {
238 if verbose {
239 eprintln!("Error processing file {}: {}", file, e);
240 }
241 Err(e)
242 }
243 }
244 });
245
246 tasks.push(task);
247 }
248
249 let mut pending_chunks: Vec<PendingChunk> = Vec::new();
251 for task in tasks {
252 match task.await {
253 Ok(Ok(mut chunks)) => {
254 pending_chunks.append(&mut chunks);
255 }
256 Ok(Err(_)) => {
257 }
259 Err(e) => {
260 if self.options.verbose {
261 eprintln!("Task join error: {}", e);
262 }
263 }
264 }
265 }
266
267 let stage_elapsed = stage_start.elapsed();
268 let total_chunks_count = pending_chunks.len();
269 let staged_bytes: usize = pending_chunks.iter().map(|c| c.content.len()).sum();
270
271 if self.options.verbose {
272 let secs = stage_elapsed.as_secs_f64().max(1e-9);
273 let chunks_per_sec = total_chunks_count as f64 / secs;
274 let mb = staged_bytes as f64 / (1024.0 * 1024.0);
275 println!(
276 "[perf] Staging: files={}, chunks={}, bytes={:.2} MiB, time={:.3}s, throughput={:.1} chunks/s",
277 tracked_files.len(), total_chunks_count, mb, stage_elapsed.as_secs_f64(), chunks_per_sec
278 );
279 }
280
281 if total_chunks_count == 0 {
282 if self.options.verbose {
283 println!("No chunks produced; writing empty database.");
284 }
285 let database = EmbeddingsDatabase {
286 version: "1.0".to_string(),
287 generated_at: Utc::now().to_rfc3339(),
288 model: "EmbeddingGemma300M".to_string(),
289 chunk_size: self.options.chunker_config.chunk_size,
290 overlap_size: self.options.chunker_config.overlap_size,
291 total_files: tracked_files.len(),
292 total_chunks: 0,
293 chunks: vec![],
294 };
295 let json = serde_json::to_string_pretty(&database)?;
296 fs::write(&self.options.output_file_path, json).await?;
297 return Ok(JsonDatabaseResult { success: true, total_files: tracked_files.len(), total_chunks: 0 });
298 }
299
300 if self.options.verbose {
301 println!("Staged {} chunks; generating embeddings in global batches...", total_chunks_count);
302 }
303
304 let documents: Vec<String> = pending_chunks.iter().map(|pc| pc.content.clone()).collect();
306
307 let embed_start = Instant::now();
309 let backend_batch_size = self.options.embedding_batch_size;
310 let per_job_batch = 2048usize; if self.options.verbose {
312 println!(
313 "[perf] Embedding config: pool_size={}, per_job_batch={}, backend_batch_size={:?}",
314 self.options.embedding_pool_size, per_job_batch, backend_batch_size
315 );
316 }
317 let embeddings = self
318 .embeddings_pool
319 .embed_many_ordered(documents, Some(per_job_batch), backend_batch_size)
320 .await?;
321 let embed_elapsed = embed_start.elapsed();
322 if self.options.verbose {
323 let secs = embed_elapsed.as_secs_f64().max(1e-9);
324 let chunks_per_sec = total_chunks_count as f64 / secs;
325 println!(
326 "[perf] Embedding: chunks={}, time={:.3}s, throughput={:.1} chunks/s",
327 total_chunks_count, embed_elapsed.as_secs_f64(), chunks_per_sec
328 );
329 }
330
331 let mut all_chunks: Vec<EmbeddedChunk> = Vec::with_capacity(total_chunks_count);
333 for (i, pending) in pending_chunks.into_iter().enumerate() {
334 let embedding = embeddings.get(i)
335 .cloned()
336 .ok_or_else(|| anyhow::anyhow!("missing embedding for chunk {}", i))?;
337 all_chunks.push(EmbeddedChunk {
338 file_path: pending.file_path,
339 content: pending.content,
340 embedding,
341 metadata: pending.metadata,
342 });
343 }
344
345 if self.options.verbose {
346 println!("Total chunks generated: {}", all_chunks.len());
347 }
348
349 let database = EmbeddingsDatabase {
350 version: "1.0".to_string(),
351 generated_at: Utc::now().to_rfc3339(),
352 model: "EmbeddingGemma300M".to_string(),
353 chunk_size: self.options.chunker_config.chunk_size,
354 overlap_size: self.options.chunker_config.overlap_size,
355 total_files: tracked_files.len(),
356 total_chunks: all_chunks.len(),
357 chunks: all_chunks,
358 };
359
360 let write_start = Instant::now();
362 let json = serde_json::to_string_pretty(&database)?;
363 fs::write(&self.options.output_file_path, json).await?;
364 let write_elapsed = write_start.elapsed();
365
366 if self.options.verbose {
367 println!(
368 "JSON database created at {}",
369 self.options.output_file_path.display()
370 );
371 let total_elapsed = overall_start.elapsed();
372 let stage = stage_elapsed.as_secs_f64();
373 let embed = embed_elapsed.as_secs_f64();
374 let write = write_elapsed.as_secs_f64();
375 let total = total_elapsed.as_secs_f64();
376 println!(
377 "[perf] Totals: time={:.3}s (stage={:.3}s, embed={:.3}s, write={:.3}s)",
378 total, stage, embed, write
379 );
380 if total > 0.0 {
381 println!(
382 "[perf] Breakdown: stage={:.0}%, embed={:.0}%, write={:.0}%",
383 (stage / total * 100.0).round(),
384 (embed / total * 100.0).round(),
385 (write / total * 100.0).round()
386 );
387 }
388 }
389
390 Ok(JsonDatabaseResult {
391 success: true,
392 total_files: tracked_files.len(),
393 total_chunks: database.total_chunks,
394 })
395 }
396
397 async fn process_file_stage_chunks(
399 file_path: &Path,
400 relative_path: &str,
401 chunker_config: &ChunkerConfig,
402 verbose: bool,
403 ) -> Result<Vec<PendingChunk>> {
404 let content = fs::read_to_string(file_path).await?;
406 let content = clean_and_redact(&content);
407
408 if content.trim().is_empty() { return Ok(vec![]); }
409
410 let metadata = fs::metadata(file_path).await?;
412 let file_size = metadata.len();
413
414 let last_modified = metadata
415 .modified()
416 .ok()
417 .and_then(|time| {
418 let datetime: DateTime<Utc> = time.into();
419 Some(datetime.to_rfc3339())
420 });
421
422 let text_chunks = chunk_text(&content, chunker_config);
424 let total_chunks = text_chunks.len();
425
426 if text_chunks.is_empty() { return Ok(vec![]); }
427
428 if verbose { println!(" - Staged {} chunks", total_chunks); }
429
430 let pending: Vec<PendingChunk> = text_chunks
432 .into_iter()
433 .map(|text_chunk| PendingChunk {
434 file_path: relative_path.to_string(),
435 content: text_chunk.content,
436 metadata: ChunkMetadata {
437 chunk_index: text_chunk.chunk_index,
438 total_chunks,
439 file_size,
440 last_modified: last_modified.clone(),
441 start_index: text_chunk.start_index,
442 end_index: text_chunk.end_index,
443 },
444 })
445 .collect();
446
447 Ok(pending)
448 }
449}
450
/// A unit of work sent to an embedding worker thread.
struct EmbeddingJob {
    /// Texts to embed, in the order their embeddings are expected back.
    texts: Vec<String>,
    /// Batch size forwarded to `generate_embeddings`.
    batch_size: Option<usize>,
    /// One-shot channel on which the worker sends back the embeddings or the error.
    resp: oneshot::Sender<Result<Vec<Vec<f32>>>>,
}
458
/// Cloneable handle to a pool of embedding worker threads; clones share the
/// same workers through the inner `Arc`.
#[derive(Clone)]
struct EmbeddingPool(Arc<EmbeddingPoolInner>);
461
/// Shared state behind an [`EmbeddingPool`] handle.
struct EmbeddingPoolInner {
    /// One job channel per worker thread, indexed by worker id.
    senders: Vec<mpsc::Sender<EmbeddingJob>>,
    /// Round-robin counter used to pick the worker for the next job.
    next: AtomicUsize,
}
466
467impl EmbeddingPool {
468 fn new(pool_size: usize) -> Result<Self> {
469 let size = pool_size.max(1);
470 let mut senders = Vec::with_capacity(size);
471 let mut readiness_rxs = Vec::with_capacity(size);
472
473 for worker_id in 0..size {
474 let (tx, mut rx) = mpsc::channel::<EmbeddingJob>(32);
476 let (ready_tx, ready_rx) = std_mpsc::channel::<Result<()>>();
478 std::thread::spawn(move || {
480 let mut generator = match EmbeddingsGenerator::new() {
482 Ok(g) => {
483 let _ = ready_tx.send(Ok(()));
485 g
486 }
487 Err(e) => {
488 let _ = ready_tx.send(Err(anyhow::anyhow!(format!(
490 "embedding worker {} init failed: {}",
491 worker_id, e
492 ))));
493 return;
494 }
495 };
496
497 while let Some(job) = rx.blocking_recv() {
499 let texts_refs: Vec<&str> = job.texts.iter().map(|s| s.as_str()).collect();
501 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
503 generator
504 .generate_embeddings(texts_refs, job.batch_size)
505 }))
506 .map_err(|_| anyhow::anyhow!("embedding worker {} panicked during generate", worker_id))
507 .and_then(|res| res.map_err(|e| anyhow::anyhow!(e)));
508
509 let _ = job.resp.send(result);
510 }
511 });
512
513 senders.push(tx);
514 readiness_rxs.push(ready_rx);
515 }
516
517 let init_timeout_secs: u64 = std::env::var("TOAK_EMBED_INIT_TIMEOUT_SECS")
519 .ok()
520 .and_then(|s| s.parse().ok())
521 .unwrap_or(20);
522 let start_wait = Instant::now();
523 for (idx, rx) in readiness_rxs.into_iter().enumerate() {
524 match rx.recv_timeout(std::time::Duration::from_secs(init_timeout_secs)) {
525 Ok(Ok(())) => { }
526 Ok(Err(e)) => {
527 return Err(anyhow::anyhow!(format!(
528 "embedding pool init failed: worker {} not ready: {}",
529 idx, e
530 )));
531 }
532 Err(_) => {
533 return Err(anyhow::anyhow!(format!(
534 "embedding pool init timed out after {}s waiting for worker {}",
535 init_timeout_secs, idx
536 )));
537 }
538 }
539 }
540 let _elapsed = start_wait.elapsed();
541
542 Ok(Self(Arc::new(EmbeddingPoolInner {
543 senders,
544 next: AtomicUsize::new(0),
545 })))
546 }
547
548 async fn embed(&self, texts: Vec<String>, batch_size: Option<usize>) -> Result<Vec<Vec<f32>>> {
549 let inner = &self.0;
550 let len = inner.senders.len();
551 let idx = inner.next.fetch_add(1, Ordering::Relaxed) % len;
552 let (resp_tx, resp_rx) = oneshot::channel();
553 let job = EmbeddingJob {
554 texts,
555 batch_size,
556 resp: resp_tx,
557 };
558 inner
559 .senders[idx]
560 .send(job)
561 .await
562 .map_err(|e| anyhow::anyhow!(
563 "failed to send embedding job: {}. hint: worker may have failed to initialize; try setting ORT_DISABLE_COREML=1 to force CPU or check startup logs.",
564 e
565 ))?;
566
567 let timeout_secs: u64 = std::env::var("TOAK_EMBED_TIMEOUT_SECS")
569 .ok()
570 .and_then(|s| s.parse().ok())
571 .unwrap_or(120);
572
573 match tokio::time::timeout(std::time::Duration::from_secs(timeout_secs), resp_rx).await {
574 Ok(Ok(res)) => res,
575 Ok(Err(e)) => Err(anyhow::anyhow!("embedding worker dropped: {}", e)),
576 Err(_) => Err(anyhow::anyhow!(
577 "embedding job timed out after {}s; worker may be stalled",
578 timeout_secs
579 )),
580 }
581 }
582
583 async fn embed_many_ordered(
586 &self,
587 texts: Vec<String>,
588 per_job_batch: Option<usize>,
589 batch_size: Option<usize>,
590 ) -> Result<Vec<Vec<f32>>> {
591 let total = texts.len();
592 if total == 0 { return Ok(Vec::new()); }
593
594 let job_batch = per_job_batch.unwrap_or(2048).max(1);
595 let mut starts = Vec::new();
596 let mut futures = Vec::new();
597
598 let inner = &self.0;
599 let workers = inner.senders.len().max(1);
600 let mut rr = inner.next.fetch_add(0, Ordering::Relaxed) % workers; let mut i = 0;
604 while i < total {
605 let end = (i + job_batch).min(total);
606 let slice: Vec<String> = texts[i..end].to_vec();
607 let worker_idx = rr % workers;
608 rr = rr.wrapping_add(1);
609 let (resp_tx, resp_rx) = oneshot::channel();
611 let job = EmbeddingJob { texts: slice, batch_size, resp: resp_tx };
612 let sender = inner.senders[worker_idx].clone();
613 sender
614 .send(job)
615 .await
616 .map_err(|e| anyhow::anyhow!(
617 "failed to send embedding job to worker {}: {}. hint: worker may have failed to initialize; try ORT_DISABLE_COREML=1 or check initialization logs.",
618 worker_idx, e
619 ))?;
620 let rx = resp_rx;
621 starts.push(i);
622 futures.push(rx);
623 i = end;
624 }
625
626 let mut out: Vec<Vec<f32>> = (0..total).map(|_| Vec::new()).collect();
627
628 let timeout_secs: u64 = std::env::var("TOAK_EMBED_TIMEOUT_SECS")
631 .ok()
632 .and_then(|s| s.parse().ok())
633 .unwrap_or(120);
634
635 for (start, rx) in starts.into_iter().zip(futures.into_iter()) {
636 let batch = match tokio::time::timeout(std::time::Duration::from_secs(timeout_secs), rx).await {
637 Ok(Ok(res)) => res?,
638 Ok(Err(e)) => return Err(anyhow::anyhow!("embedding worker dropped: {}", e)),
639 Err(_) => return Err(anyhow::anyhow!(
640 "embedding batch timed out after {}s; worker may be stalled",
641 timeout_secs
642 )),
643 };
644 for (offset, emb) in batch.into_iter().enumerate() {
645 out[start + offset] = emb;
646 }
647 }
648
649 Ok(out)
650 }
651}
652
/// Summary returned by `JsonDatabaseGenerator::generate_database`.
#[derive(Debug, Clone)]
pub struct JsonDatabaseResult {
    /// `true` whenever a result is returned; failures surface as `Err` instead.
    pub success: bool,
    /// Number of files selected for processing (after exclusions).
    pub total_files: usize,
    /// Number of chunks written to the database.
    pub total_chunks: usize,
}