1use chrono::{DateTime, Utc};
10use std::hash::Hash;
11use std::time::Instant;
12
13use crate::config::{SamplerConfig, TripletRecipe};
14use crate::data::DataRecord;
15use crate::errors::SamplerError;
16use crate::hash::stable_hash_with;
17use crate::types::SourceId;
18
19pub mod backends;
21pub mod indexing;
23pub use backends::csv_source::{CsvSource, CsvSourceConfig};
24pub use backends::file_source::{
25 FileSource, FileSourceConfig, SectionBuilder, TaxonomyBuilder, anchor_context_sections,
26 taxonomy_from_path,
27};
28
29pub use backends::in_memory_source::InMemorySource;
30
31#[derive(Clone, Debug)]
36pub struct SourceCursor {
37 pub last_seen: DateTime<Utc>,
39 pub revision: u64,
41}
42
43#[derive(Clone, Debug)]
47pub struct SourceSnapshot {
48 pub records: Vec<DataRecord>,
50 pub cursor: SourceCursor,
52}
53
54pub trait DataSource: Send + Sync {
59 fn id(&self) -> &str;
61 fn refresh(
65 &self,
66 config: &SamplerConfig,
67 cursor: Option<&SourceCursor>,
68 limit: Option<usize>,
69 ) -> Result<SourceSnapshot, SamplerError>;
70
71 fn reported_record_count(&self, config: &SamplerConfig) -> Result<u128, SamplerError>;
81
82 fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
86 Vec::new()
87 }
88}
89
90pub trait IndexableSource: Send + Sync {
99 fn id(&self) -> &str;
101 fn len_hint(&self) -> Option<usize>;
103 fn record_at(&self, idx: usize) -> Result<Option<DataRecord>, SamplerError>;
105}
106
107pub struct IndexablePager {
112 source_id: SourceId,
113}
114
115impl IndexablePager {
116 pub fn new(source_id: impl Into<SourceId>) -> Self {
118 Self {
119 source_id: source_id.into(),
120 }
121 }
122
123 pub fn refresh(
125 &self,
126 source: &dyn IndexableSource,
127 cursor: Option<&SourceCursor>,
128 limit: Option<usize>,
129 ) -> Result<SourceSnapshot, SamplerError> {
130 let total = source
131 .len_hint()
132 .ok_or_else(|| SamplerError::SourceInconsistent {
133 source_id: source.id().to_string(),
134 details: "indexable source did not provide len_hint".into(),
135 })?;
136 self.refresh_with(total, cursor, limit, |idx| source.record_at(idx))
137 }
138
139 pub fn refresh_with<F>(
148 &self,
149 total: usize,
150 cursor: Option<&SourceCursor>,
151 limit: Option<usize>,
152 fetch: F,
153 ) -> Result<SourceSnapshot, SamplerError>
154 where
155 F: Fn(usize) -> Result<Option<DataRecord>, SamplerError> + Send + Sync,
156 {
157 if total == 0 {
158 return Ok(SourceSnapshot {
159 records: Vec::new(),
160 cursor: SourceCursor {
161 last_seen: Utc::now(),
162 revision: 0,
163 },
164 });
165 }
166 let mut start = cursor.map(|cursor| cursor.revision as usize).unwrap_or(0);
167 if start >= total {
168 start = 0;
169 }
170 let max = limit.unwrap_or(total);
171 let seed = Self::seed_for(&self.source_id, total);
172
173 let mut permutation = IndexPermutation::new(total, seed, start as u64);
176 let seq: Vec<(usize, usize)> = (0..total)
177 .map(|_| {
178 let idx = permutation.next();
179 (idx, permutation.cursor())
180 })
181 .collect();
182
183 let should_report = total >= 10_000 || max >= 1_024;
184 let refresh_start = Instant::now();
185 if should_report {
186 eprintln!(
187 "[triplets:source] refresh start source='{}' source_records={} ingestion_limit={}",
188 self.source_id, total, max
189 );
190 }
191
192 use rayon::prelude::*;
193 let par_end = max.min(total);
204 let results: Vec<Result<Option<DataRecord>, SamplerError>> = seq[..par_end]
205 .par_iter()
206 .map(|&(idx, _)| fetch(idx))
207 .collect();
208 let mut records = Vec::with_capacity(max.min(total));
209 let mut final_cursor = start;
210 for (result, &(_, cursor_after)) in results.into_iter().zip(seq[..par_end].iter()) {
211 if records.len() >= max {
212 break;
213 }
214 if let Some(r) = result? {
215 records.push(r)
216 }
217 final_cursor = cursor_after;
218 }
219 for &(idx, cursor_after) in &seq[par_end..] {
221 if records.len() >= max {
222 break;
223 }
224 if let Some(r) = fetch(idx)? {
225 records.push(r);
226 }
227 final_cursor = cursor_after;
228 }
229
230 if should_report {
231 eprintln!(
232 "[triplets:source] refresh done source='{}' source_records={} ingested={} elapsed={:.2}s",
233 self.source_id,
234 total,
235 records.len(),
236 refresh_start.elapsed().as_secs_f64()
237 );
238 }
239 let last_seen = records
240 .iter()
241 .map(|record| record.updated_at)
242 .max()
243 .unwrap_or_else(Utc::now);
244 Ok(SourceSnapshot {
245 records,
246 cursor: SourceCursor {
247 last_seen,
248 revision: final_cursor as u64,
249 },
250 })
251 }
252
253 pub(crate) fn seed_for(source_id: &SourceId, total: usize) -> u64 {
255 Self::stable_index_shuffle_key(source_id, 0)
256 ^ Self::stable_index_shuffle_key(source_id, total)
257 }
258
259 pub fn seed_for_sampler(source_id: &SourceId, total: usize, sampler_seed: u64) -> u64 {
261 Self::seed_for(source_id, total)
262 ^ stable_hash_with(|hasher| {
263 "triplets_sampler_seed".hash(hasher);
264 source_id.hash(hasher);
265 total.hash(hasher);
266 sampler_seed.hash(hasher);
267 })
268 }
269
270 fn stable_index_shuffle_key(source_id: &SourceId, idx: usize) -> u64 {
271 stable_hash_with(|hasher| {
272 source_id.hash(hasher);
273 idx.hash(hasher);
274 })
275 }
276}
277
278pub struct IndexableAdapter<T: IndexableSource> {
280 inner: T,
281}
282
283impl<T: IndexableSource> IndexableAdapter<T> {
284 pub fn new(inner: T) -> Self {
286 Self { inner }
287 }
288}
289
290impl<T: IndexableSource> DataSource for IndexableAdapter<T> {
291 fn id(&self) -> &str {
292 self.inner.id()
293 }
294
295 fn refresh(
296 &self,
297 _config: &SamplerConfig,
298 cursor: Option<&SourceCursor>,
299 limit: Option<usize>,
300 ) -> Result<SourceSnapshot, SamplerError> {
301 let pager = IndexablePager::new(self.inner.id());
302 pager.refresh(&self.inner, cursor, limit)
303 }
304
305 fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
306 self.inner
307 .len_hint()
308 .map(|value| value as u128)
309 .ok_or_else(|| SamplerError::SourceInconsistent {
310 source_id: self.inner.id().to_string(),
311 details: "indexable source did not provide len_hint".into(),
312 })
313 }
314}
315
/// Seeded, resumable walk over the indices `0..total`.
///
/// The walk steps a counter through a power-of-two domain of
/// `2^domain_bits >= total` slots. Each slot is mapped through an affine
/// function with an odd multiplier, which is a bijection on that domain, and
/// results that fall at or beyond `total` are skipped.
pub struct IndexPermutation {
    /// Number of valid indices.
    total: u64,
    /// Width of the slot domain in bits (at least 1).
    domain_bits: u32,
    /// Number of slots, `2^domain_bits`. A 64-bit domain does not fit in a
    /// `u64` and is stored as its wrapped value `0`.
    domain_size: u64,
    /// Seed from which the affine multiplier and offset are derived.
    seed: u64,
    /// Next slot to visit (reduced into the domain on use).
    counter: u64,
}

impl IndexPermutation {
    /// Creates a walk over `0..total` that starts at slot `counter`.
    ///
    /// `total == 0` is accepted so that construction never fails; such a walk
    /// has nothing to yield (see [`Self::next`]).
    pub fn new(total: usize, seed: u64, counter: u64) -> Self {
        let total_u64 = total as u64;
        // `saturating_sub` keeps `total == 0` from underflowing.
        let domain_bits = (u64::BITS - total_u64.saturating_sub(1).leading_zeros()).max(1);
        // 2^64 is not representable; `checked_shl` reports that case and it is
        // stored as 0, the value 2^64 wraps to.
        let domain_size = 1u64.checked_shl(domain_bits).unwrap_or(0);
        Self {
            total: total_u64,
            domain_bits,
            domain_size,
            seed,
            counter,
        }
    }

    /// Returns the next index of the walk and advances the counter.
    ///
    /// The walk never ends: after a full cycle it starts over.
    ///
    /// # Panics
    /// Panics when the permutation was built with `total == 0`, because no
    /// valid index exists.
    // Deliberately not `Iterator::next`: the sequence is infinite and the
    // method returns a bare `usize` rather than an `Option`.
    #[allow(clippy::should_implement_trait)]
    pub fn next(&mut self) -> usize {
        assert!(
            self.total > 0,
            "IndexPermutation::next called on an empty permutation (total == 0)"
        );
        // The domain size is a power of two, so masking with `size - 1` equals
        // `% size`; the wrapped value 0 (for 2^64) yields an all-ones mask.
        let slot_mask = self.domain_size.wrapping_sub(1);
        loop {
            let candidate =
                Self::permute_bits(self.counter & slot_mask, self.domain_bits, self.seed);
            self.counter = self.counter.wrapping_add(1);
            if candidate < self.total {
                return candidate as usize;
            }
        }
    }

    /// Returns the current counter reduced modulo `total` (`0` when `total`
    /// is zero).
    ///
    /// NOTE(review): `next` walks `domain_size` slots while this value is
    /// reduced modulo `total`. When `domain_size > total`, a counter in
    /// `total..domain_size` reports the same cursor as an earlier slot —
    /// confirm that callers resuming from this value expect that.
    pub fn cursor(&self) -> usize {
        self.counter.checked_rem(self.total).unwrap_or(0) as usize
    }

    /// Maps `value` to another value of the same `bits`-wide domain using
    /// `a * value + b (mod 2^bits)`, where `a` and `b` come from `seed`.
    ///
    /// Returns `0` when `bits == 0`.
    fn permute_bits(value: u64, bits: u32, seed: u64) -> u64 {
        if bits == 0 {
            return 0;
        }
        let mask = if bits >= u64::BITS {
            u64::MAX
        } else {
            (1u64 << bits) - 1
        };
        // Forcing the low bit makes the multiplier odd (and therefore never
        // zero), which is what makes the mapping a bijection modulo 2^bits.
        let a = (seed | 1) & mask;
        let b = (seed >> 1) & mask;
        a.wrapping_mul(value).wrapping_add(b) & mask
    }
}
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::{QualityScore, RecordSection, SectionRole};
    use crate::types::RecordId;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::thread;
    use std::time::Duration as StdDuration;

    /// Builds a minimal record with a single anchor section whose text and
    /// only sentence are both `text`.
    fn stub_record(id: RecordId, source: &str, text: &str) -> DataRecord {
        let now = Utc::now();
        DataRecord {
            id,
            source: source.to_string(),
            created_at: now,
            updated_at: now,
            quality: QualityScore { trust: 1.0 },
            taxonomy: Vec::new(),
            sections: vec![RecordSection {
                role: SectionRole::Anchor,
                heading: None,
                text: text.to_string(),
                sentences: vec![text.to_string()],
            }],
            meta_prefix: None,
        }
    }

    /// Indexable source that serves `count` synthetic records.
    struct IndexableStub {
        id: SourceId,
        count: usize,
    }

    /// Indexable source that never reports a length.
    struct NoLenHintStub {
        id: SourceId,
    }

    impl IndexableStub {
        fn new(id: &str, count: usize) -> Self {
            Self {
                id: id.to_string(),
                count,
            }
        }
    }

    impl NoLenHintStub {
        fn new(id: &str) -> Self {
            Self { id: id.to_string() }
        }
    }

    impl IndexableSource for IndexableStub {
        fn id(&self) -> &str {
            &self.id
        }

        fn len_hint(&self) -> Option<usize> {
            Some(self.count)
        }

        fn record_at(&self, idx: usize) -> Result<Option<DataRecord>, SamplerError> {
            if idx >= self.count {
                return Ok(None);
            }
            Ok(Some(stub_record(format!("record_{idx}"), &self.id, "stub")))
        }
    }

    impl IndexableSource for NoLenHintStub {
        fn id(&self) -> &str {
            &self.id
        }

        fn len_hint(&self) -> Option<usize> {
            None
        }

        fn record_at(&self, _idx: usize) -> Result<Option<DataRecord>, SamplerError> {
            Ok(None)
        }
    }

    #[test]
    fn indexable_adapter_pages_in_stable_order() {
        let adapter = IndexableAdapter::new(IndexableStub::new("stub", 6));
        let config = SamplerConfig::default();
        let full = adapter.refresh(&config, None, None).unwrap();
        let full_ids: Vec<RecordId> = full.records.into_iter().map(|r| r.id).collect();

        // Three pages of two records must reproduce the unpaged order.
        let mut cursor = None;
        let mut paged = Vec::new();
        for _ in 0..3 {
            let snapshot = adapter.refresh(&config, cursor.as_ref(), Some(2)).unwrap();
            cursor = Some(snapshot.cursor);
            paged.extend(snapshot.records.into_iter().map(|r| r.id));
        }
        assert_eq!(paged, full_ids);
    }

    #[test]
    fn indexable_paging_spans_multiple_regimes() {
        let total = 256usize;
        // Pick a source id whose multiplier is neither 1 nor the all-ones
        // mask, mirroring how `permute_bits` derives it from the seed.
        let mask = (1u64 << (64 - (total as u64 - 1).leading_zeros())) - 1;
        let source_id = (0..512)
            .map(|idx| format!("regime_test_{idx}"))
            .find(|id| {
                let seed = IndexablePager::seed_for(id, total);
                let a = (seed | 1) & mask;
                a != 1 && a != mask
            })
            .unwrap();

        let adapter = IndexableAdapter::new(IndexableStub::new(&source_id, total));
        let snapshot = adapter
            .refresh(&SamplerConfig::default(), None, Some(64))
            .unwrap();
        let indices: Vec<usize> = snapshot
            .records
            .into_iter()
            .map(|r| {
                r.id.strip_prefix("record_")
                    .unwrap()
                    .parse::<usize>()
                    .unwrap()
            })
            .collect();
        let min_idx = *indices.iter().min().unwrap();
        let max_idx = *indices.iter().max().unwrap();
        assert!(
            max_idx - min_idx >= total / 2,
            "expected spread across the index space, got min={min_idx} max={max_idx}"
        );
    }

    #[test]
    fn indexable_pager_errors_when_len_hint_missing() {
        let pager = IndexablePager::new("no_len_hint");
        let source = NoLenHintStub::new("no_len_hint");
        let result = pager.refresh(&source, None, Some(3));
        assert!(matches!(
            result,
            Err(SamplerError::SourceInconsistent { .. })
        ));
    }

    #[test]
    fn indexable_adapter_reported_count_errors_when_len_hint_missing() {
        let adapter = IndexableAdapter::new(NoLenHintStub::new("no_len_hint"));
        let result = adapter.reported_record_count(&SamplerConfig::default());
        assert!(matches!(
            result,
            Err(SamplerError::SourceInconsistent { .. })
        ));
    }

    #[test]
    fn indexable_pager_refresh_with_zero_total_returns_empty_snapshot() {
        let pager = IndexablePager::new("empty");
        let snapshot = pager
            .refresh_with(0, None, Some(4), |_idx| Ok(None))
            .unwrap();
        assert!(snapshot.records.is_empty());
        assert_eq!(snapshot.cursor.revision, 0);
    }

    #[test]
    fn index_permutation_permute_bits_handles_zero_bits_and_zero_seed_path() {
        assert_eq!(IndexPermutation::permute_bits(123, 0, 99), 0);

        let bits = 1;
        let value = 1;
        let out = IndexPermutation::permute_bits(value, bits, 0);
        assert!(out <= 1);
    }

    #[test]
    fn index_permutation_next_stays_within_total_and_cursor_advances() {
        let mut perm = IndexPermutation::new(3, 7, 0);
        let mut seen = Vec::new();
        for _ in 0..8 {
            seen.push(perm.next());
        }
        assert!(seen.iter().all(|idx| *idx < 3));
        assert!(perm.cursor() < 3);
    }

    #[test]
    fn indexable_pager_large_refresh_triggers_reporting_branch_and_wraps_cursor() {
        let pager = IndexablePager::new("reporting");
        // A revision beyond `total` must be treated as "start over".
        let cursor = SourceCursor {
            last_seen: Utc::now(),
            revision: 20_000,
        };
        let snapshot = pager
            .refresh_with(10_000, Some(&cursor), Some(4), |idx| {
                Ok(Some(stub_record(format!("record_{idx}"), "reporting", "t")))
            })
            .unwrap();

        assert_eq!(snapshot.records.len(), 4);
        assert!(snapshot.cursor.revision < 10_000);
    }

    #[test]
    fn indexable_pager_reporting_branch_emits_progress_when_refresh_is_slow() {
        let pager = IndexablePager::new("slow_reporting");
        // Exactly one fetch is made slow; every fetch yields no record.
        let slept = AtomicBool::new(false);
        let snapshot = pager
            .refresh_with(2_000, None, Some(1_024), |_idx| {
                if !slept.swap(true, Ordering::Relaxed) {
                    thread::sleep(StdDuration::from_millis(800));
                }
                Ok(None)
            })
            .unwrap();

        assert!(snapshot.records.is_empty());
        assert!(snapshot.cursor.revision < 2_000);
    }

    #[test]
    fn source_ids_and_reported_counts_are_exposed() {
        let adapter = IndexableAdapter::new(IndexableStub::new("stub_id", 3));
        assert_eq!(adapter.id(), "stub_id");
        assert_eq!(
            adapter
                .reported_record_count(&SamplerConfig::default())
                .unwrap(),
            3
        );
    }

    #[test]
    fn indexable_pager_sequential_fallback_fills_quota_when_parallel_pass_yields_none() {
        let pager = IndexablePager::new("fallback_fill");
        // The first `par_end` fetches (the parallel batch) return nothing, so
        // the page can only be filled by the sequential fallback.
        let call_count = AtomicUsize::new(0);
        let par_end = 4usize;
        let snapshot = pager
            .refresh_with(8, None, Some(par_end), |idx| {
                let n = call_count.fetch_add(1, Ordering::Relaxed);
                if n < par_end {
                    Ok(None)
                } else {
                    Ok(Some(stub_record(format!("r_{idx}"), "fallback_fill", "t")))
                }
            })
            .unwrap();
        assert_eq!(snapshot.records.len(), par_end);
    }

    #[test]
    fn indexable_pager_refresh_with_propagates_fetch_error() {
        let pager = IndexablePager::new("err");
        let err = pager
            .refresh_with(8, None, Some(2), |_idx| {
                Err(SamplerError::SourceUnavailable {
                    source_id: "err".to_string(),
                    reason: "fetch failed".to_string(),
                })
            })
            .unwrap_err();
        assert!(matches!(
            err,
            SamplerError::SourceUnavailable { ref reason, .. } if reason.contains("fetch failed")
        ));
    }

    #[test]
    fn seed_for_sampler_depends_on_sampler_seed() {
        let source_id = "seeded".to_string();
        let base = IndexablePager::seed_for(&source_id, 17);
        let with_a = IndexablePager::seed_for_sampler(&source_id, 17, 1);
        let with_b = IndexablePager::seed_for_sampler(&source_id, 17, 2);
        assert_ne!(with_a, with_b);
        assert_ne!(with_a, base);
    }
}