Skip to main content

triplets_core/source/
mod.rs

1//! Data source interfaces and paging helpers.
2//!
3//! Ownership model:
4//! - `DataSource` is the sampler-facing interface that produces batches.
5//! - `IndexableSource` exposes stable, index-based access into a corpus.
6//! - `IndexablePager` owns the deterministic pseudo-random paging logic and
7//!   can page any indexable source without retaining per-record state.
8
9use chrono::{DateTime, Utc};
10use std::hash::Hash;
11use std::time::Instant;
12
13use crate::config::{SamplerConfig, TripletRecipe};
14use crate::data::DataRecord;
15use crate::errors::SamplerError;
16use crate::hash::stable_hash_with;
17use crate::types::SourceId;
18
19/// Source implementation modules.
20pub mod backends;
21/// Utility helpers used by source implementations.
22pub mod indexing;
23pub use backends::csv_source::{CsvSource, CsvSourceConfig};
24pub use backends::file_source::{
25    FileSource, FileSourceConfig, SectionBuilder, TaxonomyBuilder, anchor_context_sections,
26    taxonomy_from_path,
27};
28
29pub use backends::in_memory_source::InMemorySource;
30
/// Source-owned incremental refresh position.
///
/// The sampler stores and returns this value between refresh calls.
/// `revision` is opaque to the sampler and interpreted only by the source.
#[derive(Clone, Debug)]
pub struct SourceCursor {
    /// Most recent observation timestamp produced by the source.
    /// `IndexablePager` sets this to the max `updated_at` of the returned
    /// records, falling back to `Utc::now()` for an empty batch.
    pub last_seen: DateTime<Utc>,
    /// Opaque paging position token used to continue incremental refresh.
    /// `IndexablePager` stores its permutation counter (reduced into
    /// `[0, total)`) here; other sources may encode any position they need.
    pub revision: u64,
}
42
/// Result of a single source refresh call.
///
/// Pass the returned `cursor` back into the next refresh to continue paging.
#[derive(Clone, Debug)]
pub struct SourceSnapshot {
    /// Records returned by the refresh operation. May be shorter than the
    /// requested limit when the source is sparse or exhausted.
    pub records: Vec<DataRecord>,
    /// Next cursor to pass into a future refresh call.
    pub cursor: SourceCursor,
}
53
/// Sampler-facing data source interface.
///
/// Implementations may be streaming or index-backed. For a fixed dataset state
/// and cursor, refresh output should be deterministic.
pub trait DataSource: Send + Sync {
    /// Stable source identifier used in records, metrics, and persistence state.
    fn id(&self) -> &str;
    /// Fetch up to `limit` records starting from `cursor` state.
    ///
    /// A `None` cursor starts from the beginning; a `None` limit requests an
    /// implementation-defined full batch (the indexable pager treats it as
    /// "the whole corpus").
    ///
    /// Return the next cursor position in `SourceSnapshot.cursor`.
    fn refresh(
        &self,
        config: &SamplerConfig,
        cursor: Option<&SourceCursor>,
        limit: Option<usize>,
    ) -> Result<SourceSnapshot, SamplerError>;

    /// Exact metadata record count reported by the source.
    ///
    /// This is intended for estimators that must avoid iterating records.
    /// Implementations should return `Ok(count)` only when the count is
    /// exact for the source scope. Return `Err` when exact counting is not
    /// possible or the source is unavailable.
    ///
    /// Keep this consistent with `refresh` by using the same backend scope,
    /// filtering, and logical corpus definition.
    fn reported_record_count(&self, config: &SamplerConfig) -> Result<u128, SamplerError>;

    /// Optional source-provided default triplet recipes.
    ///
    /// Used when sampler config does not provide explicit recipes.
    /// Defaults to no recipes.
    fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
        Vec::new()
    }
}
89
/// Index-addressable source interface used by deterministic pagers.
///
/// `len_hint` must be stable within an epoch, and `record_at` must return the
/// record corresponding to the same index across runs.
///
/// Dense indexing is strongly recommended: implement indices as `0..len_hint`
/// with minimal gaps. Sparse indexes (returning `None` for many positions)
/// still work but waste paging capacity and reduce batch fill rates.
pub trait IndexableSource: Send + Sync {
    /// Stable source identifier.
    fn id(&self) -> &str;
    /// Current index domain size, typically `Some(total_records)`.
    /// `None` makes the source unusable by `IndexablePager` (it errors).
    fn len_hint(&self) -> Option<usize>;
    /// Return the record at index `idx`, or `None` for sparse/missing positions.
    /// Must be callable concurrently (`&self`) — the pager fetches in parallel.
    fn record_at(&self, idx: usize) -> Result<Option<DataRecord>, SamplerError>;
}
106
/// Deterministic pager for `IndexableSource`.
///
/// Encapsulates shuffle seed and cursor math so callers can reuse a stable
/// paging algorithm without implementing permutation logic themselves.
/// Stateless apart from the source id: all paging position lives in the
/// `SourceCursor` passed in and out of `refresh`.
pub struct IndexablePager {
    /// Identifier of the paged source; feeds the deterministic shuffle seed.
    source_id: SourceId,
}
114
impl IndexablePager {
    /// Create a new deterministic pager for `source_id`.
    pub fn new(source_id: impl Into<SourceId>) -> Self {
        Self {
            source_id: source_id.into(),
        }
    }

    /// Page records from an `IndexableSource` using the provided cursor.
    ///
    /// # Errors
    ///
    /// Returns `SamplerError::SourceInconsistent` when the source does not
    /// provide a `len_hint`, and propagates any `record_at` failure.
    pub fn refresh(
        &self,
        source: &dyn IndexableSource,
        cursor: Option<&SourceCursor>,
        limit: Option<usize>,
    ) -> Result<SourceSnapshot, SamplerError> {
        let total = source
            .len_hint()
            .ok_or_else(|| SamplerError::SourceInconsistent {
                source_id: source.id().to_string(),
                details: "indexable source did not provide len_hint".into(),
            })?;
        self.refresh_with(total, cursor, limit, |idx| source.record_at(idx))
    }

    /// Page records using a custom index fetcher.
    ///
    /// Useful when records are indexable but not exposed through `IndexableSource`
    /// (for example, temporary index stores or precomputed path lists).
    ///
    /// The fetcher is called concurrently using rayon. It must be `Fn + Send + Sync`
    /// (not merely `FnMut`). All callers that pass a closure over a shared
    /// `&IndexableSource` satisfy this because `record_at` takes `&self`.
    ///
    /// # Errors
    ///
    /// Propagates the first fetch error encountered while filling the batch.
    pub fn refresh_with<F>(
        &self,
        total: usize,
        cursor: Option<&SourceCursor>,
        limit: Option<usize>,
        fetch: F,
    ) -> Result<SourceSnapshot, SamplerError>
    where
        F: Fn(usize) -> Result<Option<DataRecord>, SamplerError> + Send + Sync,
    {
        // Empty corpus: nothing to page; return a zeroed cursor.
        if total == 0 {
            return Ok(SourceSnapshot {
                records: Vec::new(),
                cursor: SourceCursor {
                    last_seen: Utc::now(),
                    revision: 0,
                },
            });
        }
        // Resume from the stored revision; an out-of-range revision (e.g. the
        // corpus shrank since the cursor was issued) wraps back to the start.
        let mut start = cursor.map(|cursor| cursor.revision as usize).unwrap_or(0);
        if start >= total {
            start = 0;
        }
        // `None` limit means "fetch the whole corpus".
        let max = limit.unwrap_or(total);
        // Seed depends only on (source_id, total), so the shuffle order is
        // stable across runs for an unchanged corpus size.
        let seed = Self::seed_for(&self.source_id, total);

        // Pre-generate the full permuted index sequence with per-position cursor
        // values. Pure integer arithmetic — negligible cost vs. record fetch.
        //
        // NOTE(review): `IndexPermutation::cursor()` reduces the counter mod
        // `total` while the underlying permutation domain can be larger
        // (next power of two); resuming after the counter wraps past `total`
        // may revisit a few positions — confirm this is acceptable for the
        // epoch semantics.
        let mut permutation = IndexPermutation::new(total, seed, start as u64);
        let seq: Vec<(usize, usize)> = (0..total)
            .map(|_| {
                let idx = permutation.next();
                (idx, permutation.cursor())
            })
            .collect();

        // Only emit progress lines for refreshes big enough to be slow.
        let should_report = total >= 10_000 || max >= 1_024;
        let refresh_start = Instant::now();
        if should_report {
            eprintln!(
                "[triplets:source] refresh start source='{}' source_records={} ingestion_limit={}",
                self.source_id, total, max
            );
        }

        use rayon::prelude::*;
        // The permutation `seq` covers all `total` positions.  We try up to
        // `max` positions in parallel first (rayon's global pool keeps
        // in-flight requests at `num_cpus`).  For dense sources (every index
        // returns a record) this single parallel batch fills the quota and
        // the rest of the loop body below is a no-op.
        //
        // If the source is sparse (some positions return None), the parallel
        // batch may come up short.  We then walk the remaining positions
        // sequentially to find enough records.  In practice this fallback
        // rarely runs and only for a handful of positions.
        let par_end = max.min(total);
        // rayon's indexed `collect` preserves input order, so `zip`-ing the
        // results against `seq` below pairs each result with its position.
        let results: Vec<Result<Option<DataRecord>, SamplerError>> = seq[..par_end]
            .par_iter()
            .map(|&(idx, _)| fetch(idx))
            .collect();
        let mut records = Vec::with_capacity(max.min(total));
        // Cursor to persist: advanced past every position we consumed,
        // whether or not it yielded a record.
        let mut final_cursor = start;
        for (result, &(_, cursor_after)) in results.into_iter().zip(seq[..par_end].iter()) {
            if records.len() >= max {
                break;
            }
            // `?` here surfaces the first fetch error after the parallel pass.
            if let Some(r) = result? {
                records.push(r)
            }
            final_cursor = cursor_after;
        }
        // Sparse-source fallback: sequential walk through remaining positions.
        for &(idx, cursor_after) in &seq[par_end..] {
            if records.len() >= max {
                break;
            }
            if let Some(r) = fetch(idx)? {
                records.push(r);
            }
            final_cursor = cursor_after;
        }

        if should_report {
            eprintln!(
                "[triplets:source] refresh done source='{}' source_records={} ingested={} elapsed={:.2}s",
                self.source_id,
                total,
                records.len(),
                refresh_start.elapsed().as_secs_f64()
            );
        }
        // `last_seen` tracks the freshest record in the batch; an empty batch
        // falls back to "now" so the cursor timestamp still advances.
        let last_seen = records
            .iter()
            .map(|record| record.updated_at)
            .max()
            .unwrap_or_else(Utc::now);
        Ok(SourceSnapshot {
            records,
            cursor: SourceCursor {
                last_seen,
                revision: final_cursor as u64,
            },
        })
    }

    /// Build a deterministic seed for a source and total size.
    ///
    /// XOR of two stable hashes so the seed changes whenever the corpus size
    /// changes, re-shuffling the paging order for the new epoch.
    pub(crate) fn seed_for(source_id: &SourceId, total: usize) -> u64 {
        Self::stable_index_shuffle_key(source_id, 0)
            ^ Self::stable_index_shuffle_key(source_id, total)
    }

    /// Build a deterministic seed for a source/total pair with explicit sampler seed.
    pub fn seed_for_sampler(source_id: &SourceId, total: usize, sampler_seed: u64) -> u64 {
        Self::seed_for(source_id, total)
            ^ stable_hash_with(|hasher| {
                // Domain-separation tag keeps this hash distinct from other
                // stable hashes of the same (source_id, total) inputs.
                "triplets_sampler_seed".hash(hasher);
                source_id.hash(hasher);
                total.hash(hasher);
                sampler_seed.hash(hasher);
            })
    }

    /// Stable hash of (source_id, idx) used as seed material.
    fn stable_index_shuffle_key(source_id: &SourceId, idx: usize) -> u64 {
        stable_hash_with(|hasher| {
            source_id.hash(hasher);
            idx.hash(hasher);
        })
    }
}
277
278/// DataSource adapter that pages an `IndexableSource` via `IndexablePager`.
279pub struct IndexableAdapter<T: IndexableSource> {
280    inner: T,
281}
282
283impl<T: IndexableSource> IndexableAdapter<T> {
284    /// Wrap an `IndexableSource` so it can be registered as a `DataSource`.
285    pub fn new(inner: T) -> Self {
286        Self { inner }
287    }
288}
289
290impl<T: IndexableSource> DataSource for IndexableAdapter<T> {
291    fn id(&self) -> &str {
292        self.inner.id()
293    }
294
295    fn refresh(
296        &self,
297        _config: &SamplerConfig,
298        cursor: Option<&SourceCursor>,
299        limit: Option<usize>,
300    ) -> Result<SourceSnapshot, SamplerError> {
301        let pager = IndexablePager::new(self.inner.id());
302        pager.refresh(&self.inner, cursor, limit)
303    }
304
305    fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
306        self.inner
307            .len_hint()
308            .map(|value| value as u128)
309            .ok_or_else(|| SamplerError::SourceInconsistent {
310                source_id: self.inner.id().to_string(),
311                details: "indexable source did not provide len_hint".into(),
312            })
313    }
314}
315
/// Internal permutation used by `IndexablePager`.
///
/// Maps a monotonically increasing counter through an affine bijection over
/// the power-of-two domain `[0, 2^domain_bits)` and rejects values outside
/// `[0, total)`. Callers must pass `total >= 1`.
pub struct IndexPermutation {
    total: u64,
    domain_bits: u32,
    domain_size: u64,
    seed: u64,
    counter: u64,
}

impl IndexPermutation {
    /// Creates a new deterministic permutation over `[0, total)`.
    ///
    /// `seed` selects the affine map; `counter` sets the starting position so
    /// an interrupted walk can be resumed. `total` must be at least 1.
    pub fn new(total: usize, seed: u64, counter: u64) -> Self {
        let total_u64 = total as u64;
        // Smallest power-of-two domain covering `total` (at least 2 wide).
        let domain_bits = u32::max(64 - (total_u64 - 1).leading_zeros(), 1);
        Self {
            total: total_u64,
            domain_bits,
            domain_size: 1u64 << domain_bits,
            seed,
            counter,
        }
    }

    /// Returns the next permuted index, staying within `[0, total)`.
    ///
    /// Each call advances the internal counter and returns a deterministic
    /// pseudo-random index that is guaranteed to be less than `total`.
    #[allow(clippy::should_implement_trait)]
    pub fn next(&mut self) -> usize {
        loop {
            let position = self.counter % self.domain_size;
            self.counter = self.counter.wrapping_add(1);
            // Rejection sampling: skip permuted values that land outside the
            // valid range. `domain_size` is at most twice `total`, so on
            // average no more than half the counter positions are skipped.
            let candidate = Self::permute_bits(position, self.domain_bits, self.seed);
            if candidate < self.total {
                return candidate as usize;
            }
        }
    }

    /// Current consumption cursor position, reduced into `[0, total)`.
    pub fn cursor(&self) -> usize {
        (self.counter as usize) % (self.total as usize)
    }

    /// Affine permutation `value -> (a * value + b) mod 2^bits` with odd `a`,
    /// which is a bijection on the masked domain.
    fn permute_bits(value: u64, bits: u32, seed: u64) -> u64 {
        if bits == 0 {
            return 0;
        }
        let mask = match bits {
            64 => u64::MAX,
            _ => (1u64 << bits) - 1,
        };
        // `seed | 1` forces an odd (hence invertible mod 2^bits) multiplier.
        let mut multiplier = (seed | 1) & mask;
        if multiplier == 0 {
            multiplier = 1;
        }
        let offset = (seed >> 1) & mask;
        multiplier.wrapping_mul(value).wrapping_add(offset) & mask
    }
}
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::{QualityScore, RecordSection, SectionRole};
    use crate::types::RecordId;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::thread;
    use std::time::Duration as StdDuration;

    /// Minimal `IndexableSource` test fixture.
    /// Dense: every index in `0..count` yields a record named `record_{idx}`.
    struct IndexableStub {
        id: SourceId,
        count: usize,
    }

    /// Fixture whose `len_hint` is always `None`, for exercising error paths.
    struct NoLenHintStub {
        id: SourceId,
    }

    impl IndexableStub {
        fn new(id: &str, count: usize) -> Self {
            Self {
                id: id.to_string(),
                count,
            }
        }
    }

    impl NoLenHintStub {
        fn new(id: &str) -> Self {
            Self { id: id.to_string() }
        }
    }

    impl IndexableSource for IndexableStub {
        fn id(&self) -> &str {
            &self.id
        }

        fn len_hint(&self) -> Option<usize> {
            Some(self.count)
        }

        fn record_at(&self, idx: usize) -> Result<Option<DataRecord>, SamplerError> {
            if idx >= self.count {
                return Ok(None);
            }
            let now = Utc::now();
            Ok(Some(DataRecord {
                id: format!("record_{idx}"),
                source: self.id.clone(),
                created_at: now,
                updated_at: now,
                quality: QualityScore { trust: 1.0 },
                taxonomy: Vec::new(),
                sections: vec![RecordSection {
                    role: SectionRole::Anchor,
                    heading: None,
                    text: "stub".into(),
                    sentences: vec!["stub".into()],
                }],
                meta_prefix: None,
            }))
        }
    }

    impl IndexableSource for NoLenHintStub {
        fn id(&self) -> &str {
            &self.id
        }

        fn len_hint(&self) -> Option<usize> {
            None
        }

        fn record_at(&self, _idx: usize) -> Result<Option<DataRecord>, SamplerError> {
            Ok(None)
        }
    }

    // Paging in chunks of 3 must reproduce the exact order of one full refresh.
    #[test]
    fn indexable_adapter_pages_in_stable_order() {
        let adapter = IndexableAdapter::new(IndexableStub::new("stub", 6));
        let config = SamplerConfig::default();
        let full = adapter.refresh(&config, None, None).unwrap();
        let full_ids: Vec<RecordId> = full.records.into_iter().map(|r| r.id).collect();

        let mut cursor = None;
        let mut paged = Vec::new();
        for _ in 0..3 {
            let snapshot = adapter.refresh(&config, cursor.as_ref(), Some(3)).unwrap();
            cursor = Some(snapshot.cursor);
            paged.extend(snapshot.records.into_iter().map(|r| r.id));
            if paged.len() >= full_ids.len() {
                break;
            }
        }
        assert_eq!(paged, full_ids);
    }

    #[test]
    fn indexable_paging_spans_multiple_regimes() {
        // Use a source id whose permutation step is not 1 or -1 mod 2^k,
        // otherwise the sequence would be a simple rotation/reversal.
        let total = 256usize;
        let mask = (1u64 << (64 - (total as u64 - 1).leading_zeros())) - 1;
        let source_id = (0..512)
            .map(|idx| format!("regime_test_{idx}"))
            .find(|id| {
                let seed = IndexablePager::seed_for(id, total);
                let a = (seed | 1) & mask;
                a != 1 && a != mask
            })
            .unwrap();

        // Pull a single page and ensure the indices are spread across the space,
        // which indicates the permutation isn't stuck in a narrow regime.
        let adapter = IndexableAdapter::new(IndexableStub::new(&source_id, total));
        let snapshot = adapter
            .refresh(&SamplerConfig::default(), None, Some(64))
            .unwrap();
        // Recover the original indices from the synthetic `record_{idx}` ids.
        let indices: Vec<usize> = snapshot
            .records
            .into_iter()
            .map(|r| {
                r.id.strip_prefix("record_")
                    .unwrap()
                    .parse::<usize>()
                    .unwrap()
            })
            .collect();
        let min_idx = *indices.iter().min().unwrap();
        let max_idx = *indices.iter().max().unwrap();
        assert!(
            max_idx - min_idx >= total / 2,
            "expected spread across the index space, got min={min_idx} max={max_idx}"
        );
    }

    #[test]
    fn indexable_pager_errors_when_len_hint_missing() {
        let pager = IndexablePager::new("no_len_hint");
        let source = NoLenHintStub::new("no_len_hint");
        let result = pager.refresh(&source, None, Some(3));
        assert!(result.is_err());
    }

    #[test]
    fn indexable_adapter_reported_count_errors_when_len_hint_missing() {
        let adapter = IndexableAdapter::new(NoLenHintStub::new("no_len_hint"));
        let result = adapter.reported_record_count(&SamplerConfig::default());
        assert!(result.is_err());
    }

    #[test]
    fn indexable_pager_refresh_with_zero_total_returns_empty_snapshot() {
        let pager = IndexablePager::new("empty");
        let snapshot = pager
            .refresh_with(0, None, Some(4), |_idx| Ok(None))
            .unwrap();
        assert!(snapshot.records.is_empty());
        assert_eq!(snapshot.cursor.revision, 0);
    }

    #[test]
    fn index_permutation_permute_bits_handles_zero_bits_and_zero_seed_path() {
        // bits == 0 short-circuits to 0 regardless of value/seed.
        assert_eq!(IndexPermutation::permute_bits(123, 0, 99), 0);

        let bits = 1;
        let value = 1;
        let out = IndexPermutation::permute_bits(value, bits, 0);
        assert!(out <= 1);
    }

    #[test]
    fn index_permutation_next_stays_within_total_and_cursor_advances() {
        // total=3 is not a power of two, so rejection sampling must kick in.
        let mut perm = IndexPermutation::new(3, 7, 0);
        let mut seen = Vec::new();
        for _ in 0..8 {
            seen.push(perm.next());
        }
        assert!(seen.iter().all(|idx| *idx < 3));
        assert!(perm.cursor() < 3);
    }

    // total >= 10_000 triggers the eprintln reporting branch; the stale
    // revision (>= total) must wrap back into range.
    #[test]
    fn indexable_pager_large_refresh_triggers_reporting_branch_and_wraps_cursor() {
        let pager = IndexablePager::new("reporting");
        let cursor = SourceCursor {
            last_seen: Utc::now(),
            revision: 20_000,
        };
        let snapshot = pager
            .refresh_with(10_000, Some(&cursor), Some(4), |idx| {
                Ok(Some(DataRecord {
                    id: format!("record_{idx}"),
                    source: "reporting".to_string(),
                    created_at: Utc::now(),
                    updated_at: Utc::now(),
                    quality: QualityScore { trust: 1.0 },
                    taxonomy: Vec::new(),
                    sections: vec![RecordSection {
                        role: SectionRole::Anchor,
                        heading: None,
                        text: "t".to_string(),
                        sentences: vec!["t".to_string()],
                    }],
                    meta_prefix: None,
                }))
            })
            .unwrap();

        assert_eq!(snapshot.records.len(), 4);
        assert!(snapshot.cursor.revision < 10_000);
    }

    // limit >= 1_024 also triggers reporting; one slow fetch exercises the
    // elapsed-time path without asserting on stderr output.
    #[test]
    fn indexable_pager_reporting_branch_emits_progress_when_refresh_is_slow() {
        let pager = IndexablePager::new("slow_reporting");
        let slept = AtomicBool::new(false);
        let snapshot = pager
            .refresh_with(2_000, None, Some(1_024), |_idx| {
                if !slept.swap(true, Ordering::Relaxed) {
                    thread::sleep(StdDuration::from_millis(800));
                }
                Ok(None)
            })
            .unwrap();

        assert!(snapshot.records.is_empty());
        assert!(snapshot.cursor.revision < 2_000);
    }

    #[test]
    fn source_ids_and_reported_counts_are_exposed() {
        let adapter = IndexableAdapter::new(IndexableStub::new("stub_id", 3));
        assert_eq!(adapter.id(), "stub_id");
        assert_eq!(
            adapter
                .reported_record_count(&SamplerConfig::default())
                .unwrap(),
            3
        );
    }

    #[test]
    fn indexable_pager_sequential_fallback_fills_quota_when_parallel_pass_yields_none() {
        // Exercise the seq[par_end..] fallback loop: parallel pass entries all
        // return None, so the sequential sweep has to supply the records.
        // total=8, limit=4 -> par_end=4. First 4 calls (parallel) get None;
        // next 4 calls (sequential fallback) get Some, filling the quota.
        let pager = IndexablePager::new("fallback_fill");
        let call_count = AtomicUsize::new(0);
        let par_end = 4usize;
        let snapshot = pager
            .refresh_with(8, None, Some(par_end), |idx| {
                let n = call_count.fetch_add(1, Ordering::Relaxed);
                if n < par_end {
                    Ok(None)
                } else {
                    Ok(Some(DataRecord {
                        id: format!("r_{idx}"),
                        source: "fallback_fill".to_string(),
                        created_at: Utc::now(),
                        updated_at: Utc::now(),
                        quality: QualityScore { trust: 1.0 },
                        taxonomy: Vec::new(),
                        sections: vec![RecordSection {
                            role: SectionRole::Anchor,
                            heading: None,
                            text: "t".to_string(),
                            sentences: vec!["t".to_string()],
                        }],
                        meta_prefix: None,
                    }))
                }
            })
            .unwrap();
        assert_eq!(snapshot.records.len(), par_end);
    }

    #[test]
    fn indexable_pager_refresh_with_propagates_fetch_error() {
        let pager = IndexablePager::new("err");
        let err = pager
            .refresh_with(8, None, Some(2), |_idx| {
                Err(SamplerError::SourceUnavailable {
                    source_id: "err".to_string(),
                    reason: "fetch failed".to_string(),
                })
            })
            .unwrap_err();
        assert!(matches!(
            err,
            SamplerError::SourceUnavailable { ref reason, .. } if reason.contains("fetch failed")
        ));
    }

    // Different sampler seeds (and no seed at all) must produce distinct seeds.
    #[test]
    fn seed_for_sampler_depends_on_sampler_seed() {
        let source_id = "seeded".to_string();
        let base = IndexablePager::seed_for(&source_id, 17);
        let with_a = IndexablePager::seed_for_sampler(&source_id, 17, 1);
        let with_b = IndexablePager::seed_for_sampler(&source_id, 17, 2);
        assert_ne!(with_a, with_b);
        assert_ne!(with_a, base);
    }
}
685}