// rust_memex/tui/indexer/scheduler.rs

1//! Concurrent indexer scheduler with pause/resume/stop controls.
2
3use std::collections::VecDeque;
4use std::path::PathBuf;
5use std::sync::Arc;
6use std::time::Instant;
7
8use anyhow::{Result, anyhow};
9use chrono::Utc;
10use futures::future::BoxFuture;
11use futures::{FutureExt, StreamExt, stream::FuturesUnordered};
12use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError, mpsc};
13use tokio::task::JoinHandle;
14
15use crate::{
16    EmbeddingClient, EmbeddingConfig, IndexResult, RAGPipeline, SliceMode, StorageManager,
17};
18
19use super::contracts::{IndexControl, IndexEvent, IndexEventSink};
20
/// Boxed async worker invoked once per file.
///
/// Arguments are `(file_index, path, namespace)`; the returned future resolves
/// to a [`FileOutcome`] describing success, skip, or failure. Wrapped in `Arc`
/// so each spawned task can hold its own handle to the same closure.
type FileProcessor =
    Arc<dyn Fn(usize, PathBuf, String) -> BoxFuture<'static, FileOutcome> + Send + Sync>;
/// Terminal result of processing a single file, produced by a worker task
/// and folded back into the scheduler's counters and event stream.
#[derive(Debug, Clone, PartialEq, Eq)]
enum FileOutcome {
    /// The file was chunked and written to storage.
    Indexed {
        file_index: usize,
        path: PathBuf,
        chunks_indexed: usize,
        content_hash: String,
        /// Wall-clock time for the whole per-file pipeline call.
        duration_ms: u64,
        /// Embedder timing / token estimate, when the pipeline reports them.
        embedder_ms: Option<u64>,
        tokens_estimated: Option<usize>,
    },
    /// The file was deliberately not indexed; `reason` explains why.
    Skipped {
        file_index: usize,
        path: PathBuf,
        reason: String,
        /// Content hash when known — presumably populated on dedup matches;
        /// TODO confirm against `index_document_with_dedup`.
        content_hash: Option<String>,
    },
    /// Indexing errored; `error` is the stringified pipeline error.
    Failed {
        file_index: usize,
        path: PathBuf,
        error: String,
    },
}
47
/// Mutable bookkeeping for one indexing run, owned by the scheduler loop.
struct SchedulerState {
    /// Files not yet started, as `(original_index, path)` pairs.
    pending: VecDeque<(usize, PathBuf)>,
    /// Join handles for worker tasks currently running.
    inflight: FuturesUnordered<JoinHandle<FileOutcome>>,
    /// Bounds concurrent workers; resized by `adjust_parallelism`.
    semaphore: Arc<Semaphore>,
    /// Woken on resume (and stop) so a paused loop re-checks flags promptly.
    resume_notify: Arc<Notify>,
    /// Namespace passed through to every worker invocation.
    namespace: String,
    /// Current concurrency target (kept >= 1).
    parallelism: usize,
    /// True while new task starts are suspended; in-flight work continues.
    paused: bool,
    /// True once a stop was requested; only draining in-flight work remains.
    stop_requested: bool,
    // Terminal-outcome counters.
    indexed: usize,
    skipped: usize,
    failed: usize,
    /// Sum of chunks written across all indexed files.
    total_chunks: usize,
    /// Run start time, used for throughput and ETA computation.
    started_at: Instant,
    /// Total number of files in the run (fixed at start).
    total: usize,
}
64
65impl SchedulerState {
66    fn processed(&self) -> usize {
67        self.indexed + self.skipped + self.failed
68    }
69
70    fn in_flight(&self) -> usize {
71        self.inflight.len()
72    }
73
74    fn files_per_sec(&self) -> f64 {
75        let elapsed = self.started_at.elapsed().as_secs_f64();
76        if elapsed <= f64::EPSILON {
77            0.0
78        } else {
79            self.processed() as f64 / elapsed
80        }
81    }
82
83    fn eta_secs(&self) -> Option<f64> {
84        let rate = self.files_per_sec();
85        if rate <= f64::EPSILON {
86            None
87        } else {
88            Some(self.total.saturating_sub(self.processed()) as f64 / rate)
89        }
90    }
91}
92
/// Parameters describing *what* to index and *how* (data + tuning).
///
/// Runtime wiring (event sink + control channel) is kept as separate
/// arguments to `start_indexing` because those are caller-owned ownership
/// handles rather than job configuration.
pub struct IndexingJob {
    /// Root directory of the files being indexed (reported in `RunStarted`).
    pub source_dir: PathBuf,
    /// Concrete files to index, in submission order.
    pub files: Vec<PathBuf>,
    /// Namespace each file is indexed under.
    pub namespace: String,
    /// Configuration handed to `EmbeddingClient::new`.
    pub embedding_config: EmbeddingConfig,
    /// Database path; tilde-expanded before the storage manager opens it.
    pub db_path: String,
    /// Starting worker count; clamped to at least 1 by the scheduler.
    pub initial_parallelism: usize,
}
106
107/// Start the concurrent indexing scheduler.
108pub fn start_indexing(
109    job: IndexingJob,
110    sink: Arc<dyn IndexEventSink>,
111    control_rx: mpsc::Receiver<IndexControl>,
112) -> JoinHandle<Result<()>> {
113    tokio::spawn(async move {
114        let IndexingJob {
115            source_dir,
116            files,
117            namespace,
118            embedding_config,
119            db_path,
120            initial_parallelism,
121        } = job;
122
123        let expanded_db_path = shellexpand::tilde(&db_path).to_string();
124        let storage = Arc::new(StorageManager::new_lance_only(&expanded_db_path).await?);
125        storage.ensure_collection().await?;
126
127        let embedding_client = Arc::new(tokio::sync::Mutex::new(
128            EmbeddingClient::new(&embedding_config).await?,
129        ));
130        let pipeline = Arc::new(RAGPipeline::new(embedding_client, storage).await?);
131
132        let processor: FileProcessor = Arc::new(move |file_index, path, namespace| {
133            let pipeline = pipeline.clone();
134            async move {
135                let started_at = Instant::now();
136                match pipeline
137                    .index_document_with_dedup(&path, Some(&namespace), SliceMode::Onion)
138                    .await
139                {
140                    Ok(IndexResult::Indexed {
141                        chunks_indexed,
142                        content_hash,
143                        embedder_ms,
144                        tokens_estimated,
145                    }) => FileOutcome::Indexed {
146                        file_index,
147                        path,
148                        chunks_indexed,
149                        content_hash,
150                        duration_ms: started_at.elapsed().as_millis() as u64,
151                        embedder_ms,
152                        tokens_estimated,
153                    },
154                    Ok(IndexResult::Skipped {
155                        reason,
156                        content_hash,
157                    }) => FileOutcome::Skipped {
158                        file_index,
159                        path,
160                        reason,
161                        content_hash: Some(content_hash),
162                    },
163                    Err(error) => FileOutcome::Failed {
164                        file_index,
165                        path,
166                        error: error.to_string(),
167                    },
168                }
169            }
170            .boxed()
171        });
172
173        run_scheduler_with_processor(
174            source_dir,
175            files,
176            namespace,
177            sink,
178            control_rx,
179            initial_parallelism,
180            processor,
181        )
182        .await
183    })
184}
185
/// Core scheduler loop: spawns file tasks up to `parallelism`, reacts to
/// pause/resume/stop controls, and streams progress events to `sink`.
///
/// Terminates when the pending queue and in-flight set are both empty, or —
/// after a stop request — once remaining in-flight tasks drain. Always emits
/// a final `RunCompleted` event; returns `Err` only when a worker task fails
/// to join (panic or cancellation).
async fn run_scheduler_with_processor(
    source_dir: PathBuf,
    files: Vec<PathBuf>,
    namespace: String,
    sink: Arc<dyn IndexEventSink>,
    mut control_rx: mpsc::Receiver<IndexControl>,
    initial_parallelism: usize,
    processor: FileProcessor,
) -> Result<()> {
    // Clamp to at least one worker so the run can make progress.
    let parallelism = initial_parallelism.max(1);
    let mut state = SchedulerState {
        total: files.len(),
        pending: files.into_iter().enumerate().collect(),
        inflight: FuturesUnordered::new(),
        semaphore: Arc::new(Semaphore::new(parallelism)),
        resume_notify: Arc::new(Notify::new()),
        namespace,
        parallelism,
        paused: false,
        stop_requested: false,
        indexed: 0,
        skipped: 0,
        failed: 0,
        total_chunks: 0,
        started_at: Instant::now(),
    };

    sink.on_event(&IndexEvent::RunStarted {
        total_files: state.total,
        namespace: state.namespace.clone(),
        source_dir: source_dir.display().to_string(),
        parallelism: state.parallelism,
        started_at: Utc::now(),
    });
    emit_stats_tick(&state, &sink);

    // Periodic stats heartbeat, independent of task completions.
    let mut stats_interval = tokio::time::interval(tokio::time::Duration::from_millis(500));

    loop {
        // Apply any controls that queued up since the last iteration.
        drain_control_queue(&mut state, &sink, &mut control_rx);

        if state.stop_requested {
            // Stopping: launch nothing new, but let in-flight work drain.
            if state.inflight.is_empty() {
                break;
            }

            tokio::select! {
                _ = stats_interval.tick() => {
                    emit_stats_tick(&state, &sink);
                }
                Some(control) = control_rx.recv() => {
                    handle_control(&mut state, &sink, control);
                }
                // No emptiness guard needed: this branch is only reached when
                // inflight is non-empty (checked just above).
                Some(join_result) = state.inflight.next() => {
                    apply_join_result(&mut state, &sink, join_result)?;
                }
            }
            continue;
        }

        if state.paused {
            // Paused: no new starts; still tick stats, accept controls, and
            // collect tasks that were already running. The Notify future lets
            // a Resume/Stop wake this select immediately.
            let notify = state.resume_notify.clone();
            let resume_wait = notify.notified();
            tokio::pin!(resume_wait);

            tokio::select! {
                _ = stats_interval.tick() => {
                    emit_stats_tick(&state, &sink);
                }
                Some(control) = control_rx.recv() => {
                    handle_control(&mut state, &sink, control);
                }
                // Guard keeps an empty FuturesUnordered (which yields an
                // immediate None) from disabling useful branches.
                Some(join_result) = state.inflight.next(), if !state.inflight.is_empty() => {
                    apply_join_result(&mut state, &sink, join_result)?;
                }
                _ = &mut resume_wait => {}
            }
            continue;
        }

        spawn_ready_tasks(&mut state, &sink, processor.clone());

        // Normal completion: nothing queued and nothing running.
        if state.pending.is_empty() && state.inflight.is_empty() {
            break;
        }

        tokio::select! {
            _ = stats_interval.tick() => {
                emit_stats_tick(&state, &sink);
            }
            Some(control) = control_rx.recv() => {
                handle_control(&mut state, &sink, control);
            }
            Some(join_result) = state.inflight.next(), if !state.inflight.is_empty() => {
                apply_join_result(&mut state, &sink, join_result)?;
            }
            // All branches disabled (e.g. control channel closed and no
            // tasks in flight): yield so the loop re-evaluates without
            // spinning the executor.
            else => {
                tokio::task::yield_now().await;
            }
        }
    }

    sink.on_event(&IndexEvent::RunCompleted {
        processed: state.processed(),
        indexed: state.indexed,
        skipped: state.skipped,
        failed: state.failed,
        total_chunks: state.total_chunks,
        elapsed: state.started_at.elapsed(),
        stopped_early: state.stop_requested,
    });

    Ok(())
}
300
301fn drain_control_queue(
302    state: &mut SchedulerState,
303    sink: &Arc<dyn IndexEventSink>,
304    control_rx: &mut mpsc::Receiver<IndexControl>,
305) {
306    while let Ok(control) = control_rx.try_recv() {
307        handle_control(state, sink, control);
308    }
309}
310
311fn spawn_ready_tasks(
312    state: &mut SchedulerState,
313    sink: &Arc<dyn IndexEventSink>,
314    processor: FileProcessor,
315) {
316    if state.paused || state.stop_requested {
317        return;
318    }
319
320    while state.in_flight() < state.parallelism && !state.pending.is_empty() {
321        let permit = match try_acquire_permit(&state.semaphore) {
322            Ok(Some(permit)) => permit,
323            Ok(None) => break,
324            Err(_) => break,
325        };
326
327        let Some((file_index, path)) = state.pending.pop_front() else {
328            break;
329        };
330        let size_bytes = std::fs::metadata(&path)
331            .map(|metadata| metadata.len())
332            .unwrap_or(0);
333        sink.on_event(&IndexEvent::FileStarted {
334            file_index,
335            path: path.display().to_string(),
336            size_bytes,
337        });
338
339        let work = processor.clone();
340        let namespace = state.namespace.clone();
341        let join_handle = tokio::spawn(async move {
342            let _permit = permit;
343            work(file_index, path, namespace).await
344        });
345        state.inflight.push(join_handle);
346        emit_stats_tick(state, sink);
347    }
348}
349
350fn try_acquire_permit(
351    semaphore: &Arc<Semaphore>,
352) -> Result<Option<OwnedSemaphorePermit>, TryAcquireError> {
353    semaphore
354        .clone()
355        .try_acquire_owned()
356        .map(Some)
357        .or_else(|error| {
358            if matches!(error, TryAcquireError::NoPermits) {
359                Ok(None)
360            } else {
361                Err(error)
362            }
363        })
364}
365
366fn handle_control(
367    state: &mut SchedulerState,
368    sink: &Arc<dyn IndexEventSink>,
369    control: IndexControl,
370) {
371    match control {
372        IndexControl::Pause => {
373            if !state.paused && !state.stop_requested {
374                state.paused = true;
375                sink.on_event(&IndexEvent::Paused);
376                emit_stats_tick(state, sink);
377            }
378        }
379        IndexControl::Resume => {
380            if state.paused && !state.stop_requested {
381                state.paused = false;
382                state.resume_notify.notify_waiters();
383                sink.on_event(&IndexEvent::Resumed);
384                emit_stats_tick(state, sink);
385            }
386        }
387        IndexControl::SetParallelism(level) => {
388            let next = level.max(1);
389            let previous = state.parallelism;
390            if next != previous {
391                adjust_parallelism(&state.semaphore, previous, next);
392                state.parallelism = next;
393                sink.on_event(&IndexEvent::ParallelismChanged {
394                    previous,
395                    current: next,
396                });
397                emit_stats_tick(state, sink);
398            }
399        }
400        IndexControl::Stop => {
401            if !state.stop_requested {
402                state.stop_requested = true;
403                state.paused = false;
404                state.resume_notify.notify_waiters();
405                sink.on_event(&IndexEvent::StopRequested);
406                emit_stats_tick(state, sink);
407            }
408        }
409    }
410}
411
412fn adjust_parallelism(semaphore: &Arc<Semaphore>, previous: usize, next: usize) {
413    if next > previous {
414        semaphore.add_permits(next - previous);
415        return;
416    }
417
418    for _ in 0..(previous - next) {
419        match semaphore.try_acquire() {
420            Ok(permit) => permit.forget(),
421            Err(TryAcquireError::NoPermits) | Err(TryAcquireError::Closed) => break,
422        }
423    }
424}
425
426fn apply_join_result(
427    state: &mut SchedulerState,
428    sink: &Arc<dyn IndexEventSink>,
429    join_result: Result<FileOutcome, tokio::task::JoinError>,
430) -> Result<()> {
431    let outcome = match join_result {
432        Ok(outcome) => outcome,
433        Err(error) => {
434            let message = format!("indexing task join failed: {error}");
435            sink.on_event(&IndexEvent::RunFailed {
436                error: message.clone(),
437                processed_before_failure: state.processed(),
438            });
439            return Err(anyhow!(message));
440        }
441    };
442
443    apply_outcome(state, sink, outcome);
444    Ok(())
445}
446
447fn apply_outcome(state: &mut SchedulerState, sink: &Arc<dyn IndexEventSink>, outcome: FileOutcome) {
448    match outcome {
449        FileOutcome::Indexed {
450            file_index,
451            path,
452            chunks_indexed,
453            content_hash,
454            duration_ms,
455            embedder_ms,
456            tokens_estimated,
457        } => {
458            state.indexed += 1;
459            state.total_chunks += chunks_indexed;
460            sink.on_event(&IndexEvent::FileIndexed {
461                file_index,
462                path: path.display().to_string(),
463                chunks_indexed,
464                content_hash,
465                duration_ms,
466                embedder_ms,
467                tokens_estimated,
468            });
469        }
470        FileOutcome::Skipped {
471            file_index,
472            path,
473            reason,
474            content_hash,
475        } => {
476            state.skipped += 1;
477            sink.on_event(&IndexEvent::FileSkipped {
478                file_index,
479                path: path.display().to_string(),
480                reason,
481                content_hash,
482            });
483        }
484        FileOutcome::Failed {
485            file_index,
486            path,
487            error,
488        } => {
489            state.failed += 1;
490            sink.on_event(&IndexEvent::FileFailed {
491                file_index,
492                path: path.display().to_string(),
493                error,
494            });
495        }
496    }
497
498    emit_stats_tick(state, sink);
499}
500
501fn emit_stats_tick(state: &SchedulerState, sink: &Arc<dyn IndexEventSink>) {
502    sink.on_event(&IndexEvent::StatsTick {
503        processed: state.processed(),
504        indexed: state.indexed,
505        skipped: state.skipped,
506        failed: state.failed,
507        total: state.total,
508        files_per_sec: state.files_per_sec(),
509        eta_secs: state.eta_secs(),
510        total_chunks: state.total_chunks,
511        in_flight: state.in_flight(),
512    });
513}
514
#[cfg(test)]
mod tests {
    use std::path::Path;
    use std::sync::{Arc, Mutex as StdMutex};
    use std::time::Duration;

    use super::*;
    use crate::tui::indexer::contracts::INDEX_CONTROL_CHANNEL_CAPACITY;

    /// Event sink that records every event for later assertions.
    struct RecordingSink {
        events: Arc<StdMutex<Vec<IndexEvent>>>,
    }

    impl RecordingSink {
        fn new() -> Self {
            Self {
                events: Arc::new(StdMutex::new(Vec::new())),
            }
        }

        /// Snapshot of everything recorded so far (tolerates lock poisoning).
        fn events(&self) -> Vec<IndexEvent> {
            self.events
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner())
                .clone()
        }
    }

    impl IndexEventSink for RecordingSink {
        fn on_event(&self, event: &IndexEvent) {
            self.events
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner())
                .push(event.clone());
        }
    }

    /// Build `count` fake paths; the fake processors below never open them,
    /// so the files don't need to exist on disk.
    fn test_files(count: usize) -> Vec<PathBuf> {
        (0..count)
            .map(|index| Path::new("/tmp").join(format!("file-{index}.txt")))
            .collect()
    }

    // NOTE(review): both tests below are sleep-driven and may be sensitive
    // to heavy CI load; consider tokio::time::pause for determinism.
    #[tokio::test]
    async fn scheduler_pause_resume_blocks_new_starts_until_resumed() {
        let sink = Arc::new(RecordingSink::new());
        let (control_tx, control_rx) = mpsc::channel(INDEX_CONTROL_CHANNEL_CAPACITY);

        // Fake processor: every file "indexes" after an 80ms delay.
        let processor: FileProcessor = Arc::new(move |file_index, path, _namespace| {
            async move {
                tokio::time::sleep(Duration::from_millis(80)).await;
                FileOutcome::Indexed {
                    file_index,
                    path,
                    chunks_indexed: 1,
                    content_hash: format!("hash-{file_index}"),
                    duration_ms: 5,
                    embedder_ms: Some(5),
                    tokens_estimated: Some(10),
                }
            }
            .boxed()
        });

        let join = tokio::spawn(run_scheduler_with_processor(
            PathBuf::from("/tmp"),
            test_files(10),
            "kb:test".to_string(),
            sink.clone(),
            control_rx,
            2,
            processor,
        ));

        // Let a couple of tasks start, then pause mid-run.
        tokio::time::sleep(Duration::from_millis(30)).await;
        control_tx
            .send(IndexControl::Pause)
            .await
            .expect("send pause");
        tokio::time::sleep(Duration::from_millis(30)).await;

        let events_after_pause = sink.events();
        let started_before_resume = events_after_pause
            .iter()
            .filter(|event| matches!(event, IndexEvent::FileStarted { .. }))
            .count();
        assert!(
            events_after_pause
                .iter()
                .any(|event| matches!(event, IndexEvent::Paused))
        );

        // While paused, no new FileStarted events should appear.
        tokio::time::sleep(Duration::from_millis(60)).await;
        let events_still_paused = sink.events();
        let started_while_paused = events_still_paused
            .iter()
            .filter(|event| matches!(event, IndexEvent::FileStarted { .. }))
            .count();
        assert_eq!(started_while_paused, started_before_resume);

        control_tx
            .send(IndexControl::Resume)
            .await
            .expect("send resume");

        join.await
            .expect("scheduler join")
            .expect("scheduler result");

        // After resuming, the run completes and all 10 files start exactly once.
        let final_events = sink.events();
        assert!(
            final_events
                .iter()
                .any(|event| matches!(event, IndexEvent::Resumed))
        );
        let final_started = final_events
            .iter()
            .filter(|event| matches!(event, IndexEvent::FileStarted { .. }))
            .count();
        assert_eq!(final_started, 10);
    }

    #[tokio::test]
    async fn scheduler_stop_drains_inflight_and_completes_cleanly() {
        let sink = Arc::new(RecordingSink::new());
        let (control_tx, control_rx) = mpsc::channel(INDEX_CONTROL_CHANNEL_CAPACITY);

        // Slow fake processor so work is still in flight when Stop arrives.
        let processor: FileProcessor = Arc::new(move |file_index, path, _namespace| {
            async move {
                tokio::time::sleep(Duration::from_millis(80)).await;
                FileOutcome::Indexed {
                    file_index,
                    path,
                    chunks_indexed: 1,
                    content_hash: format!("hash-{file_index}"),
                    duration_ms: 5,
                    embedder_ms: None,
                    tokens_estimated: None,
                }
            }
            .boxed()
        });

        let join = tokio::spawn(run_scheduler_with_processor(
            PathBuf::from("/tmp"),
            test_files(100),
            "kb:test".to_string(),
            sink.clone(),
            control_rx,
            4,
            processor,
        ));

        tokio::time::sleep(Duration::from_millis(30)).await;
        control_tx
            .send(IndexControl::Stop)
            .await
            .expect("send stop");

        // The scheduler should drain in-flight tasks and exit Ok (no hang).
        join.await
            .expect("scheduler join")
            .expect("scheduler result");

        let events = sink.events();
        assert!(
            events
                .iter()
                .any(|event| matches!(event, IndexEvent::StopRequested))
        );
        assert!(
            events
                .iter()
                .any(|event| matches!(event, IndexEvent::RunCompleted { .. }))
        );
    }
}