1use std::collections::VecDeque;
4use std::path::PathBuf;
5use std::sync::Arc;
6use std::time::Instant;
7
8use anyhow::{Result, anyhow};
9use chrono::Utc;
10use futures::future::BoxFuture;
11use futures::{FutureExt, StreamExt, stream::FuturesUnordered};
12use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError, mpsc};
13use tokio::task::JoinHandle;
14
15use crate::{
16 EmbeddingClient, EmbeddingConfig, IndexResult, RAGPipeline, SliceMode, StorageManager,
17};
18
19use super::contracts::{IndexControl, IndexEvent, IndexEventSink};
20
/// Async callback that processes one file: `(file_index, path, namespace)` in,
/// a terminal [`FileOutcome`] out. Shared across spawned tasks, hence
/// `Arc + Send + Sync` and a `'static` boxed future.
type FileProcessor =
    Arc<dyn Fn(usize, PathBuf, String) -> BoxFuture<'static, FileOutcome> + Send + Sync>;
23
/// Terminal result of processing one file; carried from a spawned task back
/// to the scheduler through its `JoinHandle`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum FileOutcome {
    /// The file was chunked and indexed successfully.
    Indexed {
        file_index: usize,
        path: PathBuf,
        chunks_indexed: usize,
        content_hash: String,
        /// Wall-clock time of the whole per-file pipeline call.
        duration_ms: u64,
        /// Time spent in the embedder, when the pipeline reports it.
        embedder_ms: Option<u64>,
        tokens_estimated: Option<usize>,
    },
    /// The pipeline declined to (re-)index the file (presumably dedup — the
    /// pipeline entry point is `index_document_with_dedup`).
    Skipped {
        file_index: usize,
        path: PathBuf,
        reason: String,
        content_hash: Option<String>,
    },
    /// The pipeline returned an error for this file; the run continues.
    Failed {
        file_index: usize,
        path: PathBuf,
        error: String,
    },
}
47
/// Mutable bookkeeping for a single scheduler run.
struct SchedulerState {
    /// Files not yet started, as `(original index, path)` pairs.
    pending: VecDeque<(usize, PathBuf)>,
    /// Join handles of tasks currently executing.
    inflight: FuturesUnordered<JoinHandle<FileOutcome>>,
    /// Concurrency budget; each spawned task holds one owned permit.
    semaphore: Arc<Semaphore>,
    /// Woken on resume so a paused loop iteration can proceed.
    resume_notify: Arc<Notify>,
    /// Namespace passed to the processor for every file.
    namespace: String,
    /// Target number of concurrent file tasks (kept >= 1).
    parallelism: usize,
    /// True while a `Pause` control is in effect.
    paused: bool,
    /// True once a `Stop` control was received; no new tasks start.
    stop_requested: bool,
    // Outcome counters; their sum is `processed()`.
    indexed: usize,
    skipped: usize,
    failed: usize,
    /// Running total of chunks written across all indexed files.
    total_chunks: usize,
    /// Start of the run, used for throughput and ETA estimates.
    started_at: Instant,
    /// Total number of files submitted to this run.
    total: usize,
}
64
65impl SchedulerState {
66 fn processed(&self) -> usize {
67 self.indexed + self.skipped + self.failed
68 }
69
70 fn in_flight(&self) -> usize {
71 self.inflight.len()
72 }
73
74 fn files_per_sec(&self) -> f64 {
75 let elapsed = self.started_at.elapsed().as_secs_f64();
76 if elapsed <= f64::EPSILON {
77 0.0
78 } else {
79 self.processed() as f64 / elapsed
80 }
81 }
82
83 fn eta_secs(&self) -> Option<f64> {
84 let rate = self.files_per_sec();
85 if rate <= f64::EPSILON {
86 None
87 } else {
88 Some(self.total.saturating_sub(self.processed()) as f64 / rate)
89 }
90 }
91}
92
/// All inputs required to launch one background indexing run.
pub struct IndexingJob {
    /// Directory the files were discovered under; reported in `RunStarted`.
    pub source_dir: PathBuf,
    /// Files to index, in submission order.
    pub files: Vec<PathBuf>,
    /// Namespace every document is indexed into.
    pub namespace: String,
    /// Settings used to construct the embedding client.
    pub embedding_config: EmbeddingConfig,
    /// Database path; a leading `~` is tilde-expanded before use.
    pub db_path: String,
    /// Requested concurrent file tasks; clamped to >= 1 by the scheduler.
    pub initial_parallelism: usize,
}
106
107pub fn start_indexing(
109 job: IndexingJob,
110 sink: Arc<dyn IndexEventSink>,
111 control_rx: mpsc::Receiver<IndexControl>,
112) -> JoinHandle<Result<()>> {
113 tokio::spawn(async move {
114 let IndexingJob {
115 source_dir,
116 files,
117 namespace,
118 embedding_config,
119 db_path,
120 initial_parallelism,
121 } = job;
122
123 let expanded_db_path = shellexpand::tilde(&db_path).to_string();
124 let storage = Arc::new(StorageManager::new_lance_only(&expanded_db_path).await?);
125 storage.ensure_collection().await?;
126
127 let embedding_client = Arc::new(tokio::sync::Mutex::new(
128 EmbeddingClient::new(&embedding_config).await?,
129 ));
130 let pipeline = Arc::new(RAGPipeline::new(embedding_client, storage).await?);
131
132 let processor: FileProcessor = Arc::new(move |file_index, path, namespace| {
133 let pipeline = pipeline.clone();
134 async move {
135 let started_at = Instant::now();
136 match pipeline
137 .index_document_with_dedup(&path, Some(&namespace), SliceMode::Onion)
138 .await
139 {
140 Ok(IndexResult::Indexed {
141 chunks_indexed,
142 content_hash,
143 embedder_ms,
144 tokens_estimated,
145 }) => FileOutcome::Indexed {
146 file_index,
147 path,
148 chunks_indexed,
149 content_hash,
150 duration_ms: started_at.elapsed().as_millis() as u64,
151 embedder_ms,
152 tokens_estimated,
153 },
154 Ok(IndexResult::Skipped {
155 reason,
156 content_hash,
157 }) => FileOutcome::Skipped {
158 file_index,
159 path,
160 reason,
161 content_hash: Some(content_hash),
162 },
163 Err(error) => FileOutcome::Failed {
164 file_index,
165 path,
166 error: error.to_string(),
167 },
168 }
169 }
170 .boxed()
171 });
172
173 run_scheduler_with_processor(
174 source_dir,
175 files,
176 namespace,
177 sink,
178 control_rx,
179 initial_parallelism,
180 processor,
181 )
182 .await
183 })
184}
185
/// Drives the indexing run to completion.
///
/// Each loop iteration: apply queued control messages, honor stop/pause
/// modes, top up in-flight tasks, then await whichever of {stats tick,
/// control message, task completion} fires first.
async fn run_scheduler_with_processor(
    source_dir: PathBuf,
    files: Vec<PathBuf>,
    namespace: String,
    sink: Arc<dyn IndexEventSink>,
    mut control_rx: mpsc::Receiver<IndexControl>,
    initial_parallelism: usize,
    processor: FileProcessor,
) -> Result<()> {
    // Clamp to at least one worker so the run can make progress.
    let parallelism = initial_parallelism.max(1);
    let mut state = SchedulerState {
        total: files.len(),
        pending: files.into_iter().enumerate().collect(),
        inflight: FuturesUnordered::new(),
        semaphore: Arc::new(Semaphore::new(parallelism)),
        resume_notify: Arc::new(Notify::new()),
        namespace,
        parallelism,
        paused: false,
        stop_requested: false,
        indexed: 0,
        skipped: 0,
        failed: 0,
        total_chunks: 0,
        started_at: Instant::now(),
    };

    sink.on_event(&IndexEvent::RunStarted {
        total_files: state.total,
        namespace: state.namespace.clone(),
        source_dir: source_dir.display().to_string(),
        parallelism: state.parallelism,
        started_at: Utc::now(),
    });
    emit_stats_tick(&state, &sink);

    // Periodic progress snapshots, independent of file completions.
    let mut stats_interval = tokio::time::interval(tokio::time::Duration::from_millis(500));

    loop {
        // Apply any controls that arrived since the last await point.
        drain_control_queue(&mut state, &sink, &mut control_rx);

        if state.stop_requested {
            // Stop mode: spawn nothing new, just drain in-flight tasks.
            if state.inflight.is_empty() {
                break;
            }

            // `inflight` is guaranteed non-empty here, so no guard is needed
            // on the completion branch.
            tokio::select! {
                _ = stats_interval.tick() => {
                    emit_stats_tick(&state, &sink);
                }
                Some(control) = control_rx.recv() => {
                    handle_control(&mut state, &sink, control);
                }
                Some(join_result) = state.inflight.next() => {
                    apply_join_result(&mut state, &sink, join_result)?;
                }
            }
            continue;
        }

        if state.paused {
            // Paused mode: keep ticking and joining completions, but start no
            // new work until a resume notification (or control) arrives.
            let notify = state.resume_notify.clone();
            let resume_wait = notify.notified();
            tokio::pin!(resume_wait);

            tokio::select! {
                _ = stats_interval.tick() => {
                    emit_stats_tick(&state, &sink);
                }
                Some(control) = control_rx.recv() => {
                    handle_control(&mut state, &sink, control);
                }
                Some(join_result) = state.inflight.next(), if !state.inflight.is_empty() => {
                    apply_join_result(&mut state, &sink, join_result)?;
                }
                _ = &mut resume_wait => {}
            }
            continue;
        }

        // Running mode: fill the parallelism budget from the pending queue.
        spawn_ready_tasks(&mut state, &sink, processor.clone());

        if state.pending.is_empty() && state.inflight.is_empty() {
            break;
        }

        tokio::select! {
            _ = stats_interval.tick() => {
                emit_stats_tick(&state, &sink);
            }
            Some(control) = control_rx.recv() => {
                handle_control(&mut state, &sink, control);
            }
            Some(join_result) = state.inflight.next(), if !state.inflight.is_empty() => {
                apply_join_result(&mut state, &sink, join_result)?;
            }
            else => {
                // Reached only if every branch is disabled (e.g. closed control
                // channel with nothing in flight); yield and re-evaluate.
                tokio::task::yield_now().await;
            }
        }
    }

    sink.on_event(&IndexEvent::RunCompleted {
        processed: state.processed(),
        indexed: state.indexed,
        skipped: state.skipped,
        failed: state.failed,
        total_chunks: state.total_chunks,
        elapsed: state.started_at.elapsed(),
        stopped_early: state.stop_requested,
    });

    Ok(())
}
300
301fn drain_control_queue(
302 state: &mut SchedulerState,
303 sink: &Arc<dyn IndexEventSink>,
304 control_rx: &mut mpsc::Receiver<IndexControl>,
305) {
306 while let Ok(control) = control_rx.try_recv() {
307 handle_control(state, sink, control);
308 }
309}
310
311fn spawn_ready_tasks(
312 state: &mut SchedulerState,
313 sink: &Arc<dyn IndexEventSink>,
314 processor: FileProcessor,
315) {
316 if state.paused || state.stop_requested {
317 return;
318 }
319
320 while state.in_flight() < state.parallelism && !state.pending.is_empty() {
321 let permit = match try_acquire_permit(&state.semaphore) {
322 Ok(Some(permit)) => permit,
323 Ok(None) => break,
324 Err(_) => break,
325 };
326
327 let Some((file_index, path)) = state.pending.pop_front() else {
328 break;
329 };
330 let size_bytes = std::fs::metadata(&path)
331 .map(|metadata| metadata.len())
332 .unwrap_or(0);
333 sink.on_event(&IndexEvent::FileStarted {
334 file_index,
335 path: path.display().to_string(),
336 size_bytes,
337 });
338
339 let work = processor.clone();
340 let namespace = state.namespace.clone();
341 let join_handle = tokio::spawn(async move {
342 let _permit = permit;
343 work(file_index, path, namespace).await
344 });
345 state.inflight.push(join_handle);
346 emit_stats_tick(state, sink);
347 }
348}
349
350fn try_acquire_permit(
351 semaphore: &Arc<Semaphore>,
352) -> Result<Option<OwnedSemaphorePermit>, TryAcquireError> {
353 semaphore
354 .clone()
355 .try_acquire_owned()
356 .map(Some)
357 .or_else(|error| {
358 if matches!(error, TryAcquireError::NoPermits) {
359 Ok(None)
360 } else {
361 Err(error)
362 }
363 })
364}
365
366fn handle_control(
367 state: &mut SchedulerState,
368 sink: &Arc<dyn IndexEventSink>,
369 control: IndexControl,
370) {
371 match control {
372 IndexControl::Pause => {
373 if !state.paused && !state.stop_requested {
374 state.paused = true;
375 sink.on_event(&IndexEvent::Paused);
376 emit_stats_tick(state, sink);
377 }
378 }
379 IndexControl::Resume => {
380 if state.paused && !state.stop_requested {
381 state.paused = false;
382 state.resume_notify.notify_waiters();
383 sink.on_event(&IndexEvent::Resumed);
384 emit_stats_tick(state, sink);
385 }
386 }
387 IndexControl::SetParallelism(level) => {
388 let next = level.max(1);
389 let previous = state.parallelism;
390 if next != previous {
391 adjust_parallelism(&state.semaphore, previous, next);
392 state.parallelism = next;
393 sink.on_event(&IndexEvent::ParallelismChanged {
394 previous,
395 current: next,
396 });
397 emit_stats_tick(state, sink);
398 }
399 }
400 IndexControl::Stop => {
401 if !state.stop_requested {
402 state.stop_requested = true;
403 state.paused = false;
404 state.resume_notify.notify_waiters();
405 sink.on_event(&IndexEvent::StopRequested);
406 emit_stats_tick(state, sink);
407 }
408 }
409 }
410}
411
412fn adjust_parallelism(semaphore: &Arc<Semaphore>, previous: usize, next: usize) {
413 if next > previous {
414 semaphore.add_permits(next - previous);
415 return;
416 }
417
418 for _ in 0..(previous - next) {
419 match semaphore.try_acquire() {
420 Ok(permit) => permit.forget(),
421 Err(TryAcquireError::NoPermits) | Err(TryAcquireError::Closed) => break,
422 }
423 }
424}
425
426fn apply_join_result(
427 state: &mut SchedulerState,
428 sink: &Arc<dyn IndexEventSink>,
429 join_result: Result<FileOutcome, tokio::task::JoinError>,
430) -> Result<()> {
431 let outcome = match join_result {
432 Ok(outcome) => outcome,
433 Err(error) => {
434 let message = format!("indexing task join failed: {error}");
435 sink.on_event(&IndexEvent::RunFailed {
436 error: message.clone(),
437 processed_before_failure: state.processed(),
438 });
439 return Err(anyhow!(message));
440 }
441 };
442
443 apply_outcome(state, sink, outcome);
444 Ok(())
445}
446
447fn apply_outcome(state: &mut SchedulerState, sink: &Arc<dyn IndexEventSink>, outcome: FileOutcome) {
448 match outcome {
449 FileOutcome::Indexed {
450 file_index,
451 path,
452 chunks_indexed,
453 content_hash,
454 duration_ms,
455 embedder_ms,
456 tokens_estimated,
457 } => {
458 state.indexed += 1;
459 state.total_chunks += chunks_indexed;
460 sink.on_event(&IndexEvent::FileIndexed {
461 file_index,
462 path: path.display().to_string(),
463 chunks_indexed,
464 content_hash,
465 duration_ms,
466 embedder_ms,
467 tokens_estimated,
468 });
469 }
470 FileOutcome::Skipped {
471 file_index,
472 path,
473 reason,
474 content_hash,
475 } => {
476 state.skipped += 1;
477 sink.on_event(&IndexEvent::FileSkipped {
478 file_index,
479 path: path.display().to_string(),
480 reason,
481 content_hash,
482 });
483 }
484 FileOutcome::Failed {
485 file_index,
486 path,
487 error,
488 } => {
489 state.failed += 1;
490 sink.on_event(&IndexEvent::FileFailed {
491 file_index,
492 path: path.display().to_string(),
493 error,
494 });
495 }
496 }
497
498 emit_stats_tick(state, sink);
499}
500
501fn emit_stats_tick(state: &SchedulerState, sink: &Arc<dyn IndexEventSink>) {
502 sink.on_event(&IndexEvent::StatsTick {
503 processed: state.processed(),
504 indexed: state.indexed,
505 skipped: state.skipped,
506 failed: state.failed,
507 total: state.total,
508 files_per_sec: state.files_per_sec(),
509 eta_secs: state.eta_secs(),
510 total_chunks: state.total_chunks,
511 in_flight: state.in_flight(),
512 });
513}
514
#[cfg(test)]
mod tests {
    use std::path::Path;
    use std::sync::{Arc, Mutex as StdMutex};
    use std::time::Duration;

    use super::*;
    use crate::tui::indexer::contracts::INDEX_CONTROL_CHANNEL_CAPACITY;

    /// Test double that records every emitted event for later inspection.
    struct RecordingSink {
        events: Arc<StdMutex<Vec<IndexEvent>>>,
    }

    impl RecordingSink {
        fn new() -> Self {
            Self {
                events: Arc::new(StdMutex::new(Vec::new())),
            }
        }

        /// Snapshot of the events recorded so far (poison-tolerant).
        fn events(&self) -> Vec<IndexEvent> {
            self.events
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner())
                .clone()
        }
    }

    impl IndexEventSink for RecordingSink {
        fn on_event(&self, event: &IndexEvent) {
            self.events
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner())
                .push(event.clone());
        }
    }

    /// Build `count` fake paths under /tmp; the stub processors below never
    /// touch the filesystem for content, only `FileStarted` stats it.
    fn test_files(count: usize) -> Vec<PathBuf> {
        (0..count)
            .map(|index| Path::new("/tmp").join(format!("file-{index}.txt")))
            .collect()
    }

    // NOTE(review): both tests rely on real sleeps and wall-clock timing, so
    // they may be flaky on heavily loaded machines — consider tokio's paused
    // time if that becomes a problem.
    #[tokio::test]
    async fn scheduler_pause_resume_blocks_new_starts_until_resumed() {
        let sink = Arc::new(RecordingSink::new());
        let (control_tx, control_rx) = mpsc::channel(INDEX_CONTROL_CHANNEL_CAPACITY);

        // Stub processor: every file "indexes" after an 80ms delay.
        let processor: FileProcessor = Arc::new(move |file_index, path, _namespace| {
            async move {
                tokio::time::sleep(Duration::from_millis(80)).await;
                FileOutcome::Indexed {
                    file_index,
                    path,
                    chunks_indexed: 1,
                    content_hash: format!("hash-{file_index}"),
                    duration_ms: 5,
                    embedder_ms: Some(5),
                    tokens_estimated: Some(10),
                }
            }
            .boxed()
        });

        let join = tokio::spawn(run_scheduler_with_processor(
            PathBuf::from("/tmp"),
            test_files(10),
            "kb:test".to_string(),
            sink.clone(),
            control_rx,
            2,
            processor,
        ));

        // Let a couple of files start, then pause mid-run.
        tokio::time::sleep(Duration::from_millis(30)).await;
        control_tx
            .send(IndexControl::Pause)
            .await
            .expect("send pause");
        tokio::time::sleep(Duration::from_millis(30)).await;

        let events_after_pause = sink.events();
        let started_before_resume = events_after_pause
            .iter()
            .filter(|event| matches!(event, IndexEvent::FileStarted { .. }))
            .count();
        assert!(
            events_after_pause
                .iter()
                .any(|event| matches!(event, IndexEvent::Paused))
        );

        // While paused, no new FileStarted events may appear.
        tokio::time::sleep(Duration::from_millis(60)).await;
        let events_still_paused = sink.events();
        let started_while_paused = events_still_paused
            .iter()
            .filter(|event| matches!(event, IndexEvent::FileStarted { .. }))
            .count();
        assert_eq!(started_while_paused, started_before_resume);

        control_tx
            .send(IndexControl::Resume)
            .await
            .expect("send resume");

        join.await
            .expect("scheduler join")
            .expect("scheduler result");

        // After resume the run must finish all 10 files.
        let final_events = sink.events();
        assert!(
            final_events
                .iter()
                .any(|event| matches!(event, IndexEvent::Resumed))
        );
        let final_started = final_events
            .iter()
            .filter(|event| matches!(event, IndexEvent::FileStarted { .. }))
            .count();
        assert_eq!(final_started, 10);
    }

    #[tokio::test]
    async fn scheduler_stop_drains_inflight_and_completes_cleanly() {
        let sink = Arc::new(RecordingSink::new());
        let (control_tx, control_rx) = mpsc::channel(INDEX_CONTROL_CHANNEL_CAPACITY);

        let processor: FileProcessor = Arc::new(move |file_index, path, _namespace| {
            async move {
                tokio::time::sleep(Duration::from_millis(80)).await;
                FileOutcome::Indexed {
                    file_index,
                    path,
                    chunks_indexed: 1,
                    content_hash: format!("hash-{file_index}"),
                    duration_ms: 5,
                    embedder_ms: None,
                    tokens_estimated: None,
                }
            }
            .boxed()
        });

        let join = tokio::spawn(run_scheduler_with_processor(
            PathBuf::from("/tmp"),
            test_files(100),
            "kb:test".to_string(),
            sink.clone(),
            control_rx,
            4,
            processor,
        ));

        // Stop early; the scheduler must drain in-flight work and terminate.
        tokio::time::sleep(Duration::from_millis(30)).await;
        control_tx
            .send(IndexControl::Stop)
            .await
            .expect("send stop");

        join.await
            .expect("scheduler join")
            .expect("scheduler result");

        let events = sink.events();
        assert!(
            events
                .iter()
                .any(|event| matches!(event, IndexEvent::StopRequested))
        );
        assert!(
            events
                .iter()
                .any(|event| matches!(event, IndexEvent::RunCompleted { .. }))
        );
    }
}