1use chrono::{DateTime, Utc};
10use std::hash::Hash;
11use std::time::Instant;
12
13use crate::config::{SamplerConfig, TripletRecipe};
14use crate::data::DataRecord;
15use crate::errors::SamplerError;
16use crate::hash::stable_hash_with;
17use crate::types::SourceId;
18
19pub mod backends;
21pub mod indexing;
23pub use backends::csv_source::{CsvSource, CsvSourceConfig};
24pub use backends::file_source::{
25 FileSource, FileSourceConfig, SectionBuilder, TaxonomyBuilder, anchor_context_sections,
26 taxonomy_from_path,
27};
28
29pub use backends::in_memory_source::InMemorySource;
30
/// Resume position handed back by a source after a refresh.
///
/// A cursor taken from a [`SourceSnapshot`] can be passed to the next
/// [`DataSource::refresh`] call so the source continues where it stopped.
#[derive(Clone, Debug)]
pub struct SourceCursor {
    /// Timestamp associated with the refresh that produced this cursor.
    ///
    /// [`IndexablePager`] sets it to the newest `updated_at` among the returned
    /// records, or to the current time when no record was returned.
    pub last_seen: DateTime<Utc>,
    /// Source-defined position marker.
    ///
    /// [`IndexablePager`] stores the permutation cursor reached by the refresh here.
    pub revision: u64,
}
42
/// Result of a single refresh: the fetched records plus the cursor to resume from.
#[derive(Clone, Debug)]
pub struct SourceSnapshot {
    /// Records produced by this refresh.
    pub records: Vec<DataRecord>,
    /// Cursor to pass to the next refresh to continue after these records.
    pub cursor: SourceCursor,
}
53
/// A provider of [`DataRecord`]s that can be refreshed incrementally.
///
/// Implementations must be `Send + Sync`.
pub trait DataSource: Send + Sync {
    /// Returns the identifier of this source.
    fn id(&self) -> &str;

    /// Fetches a snapshot of records.
    ///
    /// * `config` - sampler configuration made available to the source.
    /// * `cursor` - cursor from a previous snapshot, or `None` to start fresh.
    /// * `limit` - optional cap on the number of records to return.
    ///
    /// NOTE(review): the exact interpretation of `cursor` and `limit` is left to
    /// each implementation; see [`IndexablePager::refresh_with`] for the behaviour
    /// of index-backed sources.
    fn refresh(
        &self,
        config: &SamplerConfig,
        cursor: Option<&SourceCursor>,
        limit: Option<usize>,
    ) -> Result<SourceSnapshot, SamplerError>;

    /// Returns the number of records this source reports having.
    fn reported_record_count(&self, config: &SamplerConfig) -> Result<u128, SamplerError>;

    /// Returns the triplet recipes this source suggests by default.
    ///
    /// The provided implementation returns an empty list.
    fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
        Vec::new()
    }
}
89
/// A source whose records can be addressed by a dense integer index.
///
/// [`IndexablePager`] and [`IndexableAdapter`] build paging on top of this trait.
pub trait IndexableSource: Send + Sync {
    /// Returns the identifier of this source.
    fn id(&self) -> &str;
    /// Returns the number of addressable indices, if known.
    ///
    /// [`IndexablePager::refresh`] fails with `SamplerError::SourceInconsistent`
    /// when this returns `None`.
    fn len_hint(&self) -> Option<usize>;
    /// Fetches the record stored at `idx`.
    ///
    /// `Ok(None)` means there is no record for that index; the pager skips such
    /// indices without counting them toward its limit.
    fn record_at(&self, idx: usize) -> Result<Option<DataRecord>, SamplerError>;
}
106
/// Pages through an index-addressable source in a deterministic shuffled order.
///
/// The shuffle seed is derived from `source_id` and the total record count, so
/// the visiting order is stable for a given source id and size.
pub struct IndexablePager {
    /// Identifier used to derive the shuffle seed and to label log output.
    source_id: SourceId,
}
114
115impl IndexablePager {
116 pub fn new(source_id: impl Into<SourceId>) -> Self {
118 Self {
119 source_id: source_id.into(),
120 }
121 }
122
123 pub fn refresh(
125 &self,
126 source: &dyn IndexableSource,
127 cursor: Option<&SourceCursor>,
128 limit: Option<usize>,
129 ) -> Result<SourceSnapshot, SamplerError> {
130 let total = source
131 .len_hint()
132 .ok_or_else(|| SamplerError::SourceInconsistent {
133 source_id: source.id().to_string(),
134 details: "indexable source did not provide len_hint".into(),
135 })?;
136 self.refresh_with(total, cursor, limit, |idx| source.record_at(idx))
137 }
138
139 pub fn refresh_with<F>(
148 &self,
149 total: usize,
150 cursor: Option<&SourceCursor>,
151 limit: Option<usize>,
152 fetch: F,
153 ) -> Result<SourceSnapshot, SamplerError>
154 where
155 F: Fn(usize) -> Result<Option<DataRecord>, SamplerError> + Send + Sync,
156 {
157 if total == 0 {
158 return Ok(SourceSnapshot {
159 records: Vec::new(),
160 cursor: SourceCursor {
161 last_seen: Utc::now(),
162 revision: 0,
163 },
164 });
165 }
166 let mut start = cursor.map(|cursor| cursor.revision as usize).unwrap_or(0);
167 if start >= total {
168 start = 0;
169 }
170 let max = limit.unwrap_or(total);
171 let seed = Self::seed_for(&self.source_id, total);
172
173 let mut permutation = IndexPermutation::new(total, seed, start as u64);
176 let seq: Vec<(usize, usize)> = (0..total)
177 .map(|_| {
178 let idx = permutation.next();
179 (idx, permutation.cursor())
180 })
181 .collect();
182
183 let should_report = total >= 10_000 || max >= 1_024;
184 let refresh_start = Instant::now();
185 if should_report {
186 eprintln!(
187 "[triplets:source] refresh start source='{}' source_records={} ingestion_limit={}",
188 self.source_id, total, max
189 );
190 }
191
192 use rayon::prelude::*;
193 let par_end = max.min(total);
204 let results: Vec<Result<Option<DataRecord>, SamplerError>> = seq[..par_end]
205 .par_iter()
206 .map(|&(idx, _)| fetch(idx))
207 .collect();
208 let mut records = Vec::with_capacity(max.min(total));
209 let mut final_cursor = start;
210 for (result, &(_, cursor_after)) in results.into_iter().zip(seq[..par_end].iter()) {
211 if records.len() >= max {
212 break;
213 }
214 if let Some(r) = result? {
215 records.push(r)
216 }
217 final_cursor = cursor_after;
218 }
219 for &(idx, cursor_after) in &seq[par_end..] {
221 if records.len() >= max {
222 break;
223 }
224 if let Some(r) = fetch(idx)? {
225 records.push(r);
226 }
227 final_cursor = cursor_after;
228 }
229
230 if should_report {
231 eprintln!(
232 "[triplets:source] refresh done source='{}' source_records={} ingested={} elapsed={:.2}s",
233 self.source_id,
234 total,
235 records.len(),
236 refresh_start.elapsed().as_secs_f64()
237 );
238 }
239 let last_seen = records
240 .iter()
241 .map(|record| record.updated_at)
242 .max()
243 .unwrap_or_else(Utc::now);
244 Ok(SourceSnapshot {
245 records,
246 cursor: SourceCursor {
247 last_seen,
248 revision: final_cursor as u64,
249 },
250 })
251 }
252
253 pub(crate) fn seed_for(source_id: &SourceId, total: usize) -> u64 {
255 Self::stable_index_shuffle_key(source_id, 0)
256 ^ Self::stable_index_shuffle_key(source_id, total)
257 }
258
259 pub fn seed_for_sampler(source_id: &SourceId, total: usize, sampler_seed: u64) -> u64 {
261 Self::seed_for(source_id, total)
262 ^ stable_hash_with(|hasher| {
263 "triplets_sampler_seed".hash(hasher);
264 source_id.hash(hasher);
265 total.hash(hasher);
266 sampler_seed.hash(hasher);
267 })
268 }
269
270 fn stable_index_shuffle_key(source_id: &SourceId, idx: usize) -> u64 {
271 stable_hash_with(|hasher| {
272 source_id.hash(hasher);
273 idx.hash(hasher);
274 })
275 }
276}
277
/// Adapts an [`IndexableSource`] so it can be used as a [`DataSource`].
pub struct IndexableAdapter<T: IndexableSource> {
    /// The wrapped index-addressable source.
    inner: T,
}
282
283impl<T: IndexableSource> IndexableAdapter<T> {
284 pub fn new(inner: T) -> Self {
286 Self { inner }
287 }
288}
289
290impl<T: IndexableSource> DataSource for IndexableAdapter<T> {
291 fn id(&self) -> &str {
292 self.inner.id()
293 }
294
295 fn refresh(
296 &self,
297 _config: &SamplerConfig,
298 cursor: Option<&SourceCursor>,
299 limit: Option<usize>,
300 ) -> Result<SourceSnapshot, SamplerError> {
301 let pager = IndexablePager::new(self.inner.id());
302 pager.refresh(&self.inner, cursor, limit)
303 }
304
305 fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
306 self.inner
307 .len_hint()
308 .map(|value| value as u128)
309 .ok_or_else(|| SamplerError::SourceInconsistent {
310 source_id: self.inner.id().to_string(),
311 details: "indexable source did not provide len_hint".into(),
312 })
313 }
314}
315
/// Deterministic pseudo-random walk over the indices `0..total`.
///
/// Indices are produced by applying an affine map to a counter over a
/// power-of-two domain of at least `total` values and discarding outputs that
/// fall outside `0..total`. Any `total` consecutive calls to [`Self::next`]
/// therefore yield every index exactly once.
pub struct IndexPermutation {
    /// Number of valid indices.
    total: u64,
    /// Width in bits of the power-of-two domain (always at least 1).
    domain_bits: u32,
    /// Size of the power-of-two domain, `1 << domain_bits`.
    domain_size: u64,
    /// Seed from which the affine multiplier and offset are derived.
    seed: u64,
    /// Number of domain positions consumed so far (plus the starting offset).
    counter: u64,
}

impl IndexPermutation {
    /// Creates a permutation over `0..total` that starts at domain position `counter`.
    ///
    /// # Panics
    ///
    /// Panics if `total` is zero (there is no index to produce) or if `total`
    /// exceeds `2^63` (the enclosing power-of-two domain would not fit in a `u64`).
    pub fn new(total: usize, seed: u64, counter: u64) -> Self {
        assert!(total > 0, "IndexPermutation requires a non-zero total");
        let total_u64 = total as u64;
        // Smallest bit width whose domain covers `0..total`; at least one bit so
        // that a single-element permutation still has a valid domain.
        let domain_bits = (u64::BITS - (total_u64 - 1).leading_zeros()).max(1);
        assert!(
            domain_bits < u64::BITS,
            "IndexPermutation total is too large for a 64-bit domain"
        );
        let domain_size = 1u64 << domain_bits;
        Self {
            total: total_u64,
            domain_bits,
            domain_size,
            seed,
            counter,
        }
    }

    /// Returns the next index in `0..total`, advancing the internal counter.
    ///
    /// Domain positions whose image is `>= total` are skipped, so the counter may
    /// advance by more than one per call.
    // The inherent `next` never ends, so it deliberately does not implement `Iterator`.
    #[allow(clippy::should_implement_trait)]
    pub fn next(&mut self) -> usize {
        loop {
            let position = self.counter % self.domain_size;
            let candidate = Self::permute_bits(position, self.domain_bits, self.seed);
            self.counter = self.counter.wrapping_add(1);
            if candidate < self.total {
                return candidate as usize;
            }
        }
    }

    /// Returns the current counter reduced modulo `total`.
    ///
    /// NOTE(review): `next` walks the counter modulo the power-of-two domain size,
    /// whereas this value is reduced modulo `total`. When `total` is not a power of
    /// two, a counter in `total..domain_size` maps to a smaller value, so a
    /// permutation recreated from this cursor may not resume at the same position.
    pub fn cursor(&self) -> usize {
        // Reduce in `u64` first so the result is correct even where `usize` is
        // narrower than 64 bits; the result is below `total`, which came from a `usize`.
        (self.counter % self.total) as usize
    }

    /// Applies the affine map `a * value + b` modulo `2^bits`.
    ///
    /// `a` is the seed with its lowest bit forced to 1 (odd, hence invertible
    /// modulo a power of two) and `b` is the seed shifted right by one; both are
    /// truncated to `bits` bits. Returns `0` when `bits` is zero.
    fn permute_bits(value: u64, bits: u32, seed: u64) -> u64 {
        if bits == 0 {
            return 0;
        }
        let mask = if bits == 64 {
            u64::MAX
        } else {
            (1u64 << bits) - 1
        };
        // `seed | 1` keeps bit 0 set and `mask` always includes bit 0, so the
        // multiplier is never zero.
        let a = (seed | 1) & mask;
        let b = (seed >> 1) & mask;
        a.wrapping_mul(value).wrapping_add(b) & mask
    }
}
377
// Unit tests for the indexable paging helpers and the index permutation.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::{QualityScore, RecordSection, SectionRole};
    use crate::types::RecordId;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::thread;
    use std::time::Duration as StdDuration;

    /// Indexable source that synthesizes a record for every index below `count`.
    struct IndexableStub {
        id: SourceId,
        count: usize,
    }

    /// Indexable source that never reports a length and never returns a record.
    struct NoLenHintStub {
        id: SourceId,
    }

    impl IndexableStub {
        fn new(id: &str, count: usize) -> Self {
            Self {
                id: id.to_string(),
                count,
            }
        }
    }

    impl NoLenHintStub {
        fn new(id: &str) -> Self {
            Self { id: id.to_string() }
        }
    }

    impl IndexableSource for IndexableStub {
        fn id(&self) -> &str {
            &self.id
        }

        fn len_hint(&self) -> Option<usize> {
            Some(self.count)
        }

        fn record_at(&self, idx: usize) -> Result<Option<DataRecord>, SamplerError> {
            if idx >= self.count {
                return Ok(None);
            }
            let now = Utc::now();
            // The record id encodes the index so tests can recover the visiting order.
            Ok(Some(DataRecord {
                id: format!("record_{idx}"),
                source: self.id.clone(),
                created_at: now,
                updated_at: now,
                quality: QualityScore { trust: 1.0 },
                taxonomy: Vec::new(),
                sections: vec![RecordSection {
                    role: SectionRole::Anchor,
                    heading: None,
                    text: "stub".into(),
                    sentences: vec!["stub".into()],
                }],
                meta_prefix: None,
            }))
        }
    }

    impl IndexableSource for NoLenHintStub {
        fn id(&self) -> &str {
            &self.id
        }

        fn len_hint(&self) -> Option<usize> {
            None
        }

        fn record_at(&self, _idx: usize) -> Result<Option<DataRecord>, SamplerError> {
            Ok(None)
        }
    }

    // Paging in chunks of three must reproduce the order of one unbounded refresh.
    #[test]
    fn indexable_adapter_pages_in_stable_order() {
        let adapter = IndexableAdapter::new(IndexableStub::new("stub", 6));
        let config = SamplerConfig::default();
        let full = adapter.refresh(&config, None, None).unwrap();
        let full_ids: Vec<RecordId> = full.records.into_iter().map(|r| r.id).collect();

        let mut cursor = None;
        let mut paged = Vec::new();
        for _ in 0..3 {
            let snapshot = adapter.refresh(&config, cursor.as_ref(), Some(3)).unwrap();
            cursor = Some(snapshot.cursor);
            paged.extend(snapshot.records.into_iter().map(|r| r.id));
            if paged.len() >= full_ids.len() {
                break;
            }
        }
        assert_eq!(paged, full_ids);
    }

    // A single page of 64 out of 256 records should already span at least half
    // of the index space.
    #[test]
    fn indexable_paging_spans_multiple_regimes() {
        let total = 256usize;
        // Same mask the permutation derives for this total (8 bits -> 255).
        let mask = (1u64 << (64 - (total as u64 - 1).leading_zeros())) - 1;
        // Pick a source id whose derived multiplier is neither 1 nor `mask`
        // (i.e. -1 modulo the domain), which would give a sequential or reversed walk.
        let source_id = (0..512)
            .map(|idx| format!("regime_test_{idx}"))
            .find(|id| {
                let seed = IndexablePager::seed_for(id, total);
                let a = (seed | 1) & mask;
                a != 1 && a != mask
            })
            .unwrap();

        let adapter = IndexableAdapter::new(IndexableStub::new(&source_id, total));
        let snapshot = adapter
            .refresh(&SamplerConfig::default(), None, Some(64))
            .unwrap();
        // Recover the visited indices from the stub's `record_{idx}` ids.
        let indices: Vec<usize> = snapshot
            .records
            .into_iter()
            .map(|r| {
                r.id.strip_prefix("record_")
                    .unwrap()
                    .parse::<usize>()
                    .unwrap()
            })
            .collect();
        let min_idx = *indices.iter().min().unwrap();
        let max_idx = *indices.iter().max().unwrap();
        assert!(
            max_idx - min_idx >= total / 2,
            "expected spread across the index space, got min={min_idx} max={max_idx}"
        );
    }

    // Without a length hint the pager cannot page and must report an error.
    #[test]
    fn indexable_pager_errors_when_len_hint_missing() {
        let pager = IndexablePager::new("no_len_hint");
        let source = NoLenHintStub::new("no_len_hint");
        let result = pager.refresh(&source, None, Some(3));
        assert!(result.is_err());
    }

    // The adapter's reported count relies on the length hint as well.
    #[test]
    fn indexable_adapter_reported_count_errors_when_len_hint_missing() {
        let adapter = IndexableAdapter::new(NoLenHintStub::new("no_len_hint"));
        let result = adapter.reported_record_count(&SamplerConfig::default());
        assert!(result.is_err());
    }

    // A total of zero short-circuits to an empty snapshot with revision 0.
    #[test]
    fn indexable_pager_refresh_with_zero_total_returns_empty_snapshot() {
        let pager = IndexablePager::new("empty");
        let snapshot = pager
            .refresh_with(0, None, Some(4), |_idx| Ok(None))
            .unwrap();
        assert!(snapshot.records.is_empty());
        assert_eq!(snapshot.cursor.revision, 0);
    }

    // Covers the `bits == 0` early return and a zero seed with a one-bit domain.
    #[test]
    fn index_permutation_permute_bits_handles_zero_bits_and_zero_seed_path() {
        assert_eq!(IndexPermutation::permute_bits(123, 0, 99), 0);

        let bits = 1;
        let value = 1;
        let out = IndexPermutation::permute_bits(value, bits, 0);
        assert!(out <= 1);
    }

    // Drawing more values than `total` must keep every index and the cursor in range.
    #[test]
    fn index_permutation_next_stays_within_total_and_cursor_advances() {
        let mut perm = IndexPermutation::new(3, 7, 0);
        let mut seen = Vec::new();
        for _ in 0..8 {
            seen.push(perm.next());
        }
        assert!(seen.iter().all(|idx| *idx < 3));
        assert!(perm.cursor() < 3);
    }

    // A total of 10_000 enables the reporting branch; the out-of-range incoming
    // revision (20_000) must be reset rather than used as a start position.
    #[test]
    fn indexable_pager_large_refresh_triggers_reporting_branch_and_wraps_cursor() {
        let pager = IndexablePager::new("reporting");
        let cursor = SourceCursor {
            last_seen: Utc::now(),
            revision: 20_000,
        };
        let snapshot = pager
            .refresh_with(10_000, Some(&cursor), Some(4), |idx| {
                Ok(Some(DataRecord {
                    id: format!("record_{idx}"),
                    source: "reporting".to_string(),
                    created_at: Utc::now(),
                    updated_at: Utc::now(),
                    quality: QualityScore { trust: 1.0 },
                    taxonomy: Vec::new(),
                    sections: vec![RecordSection {
                        role: SectionRole::Anchor,
                        heading: None,
                        text: "t".to_string(),
                        sentences: vec!["t".to_string()],
                    }],
                    meta_prefix: None,
                }))
            })
            .unwrap();

        assert_eq!(snapshot.records.len(), 4);
        assert!(snapshot.cursor.revision < 10_000);
    }

    // A limit of 1_024 enables the reporting branch; exactly one fetch sleeps so
    // the refresh is measurably slow, and every fetch yields no record.
    #[test]
    fn indexable_pager_reporting_branch_emits_progress_when_refresh_is_slow() {
        let pager = IndexablePager::new("slow_reporting");
        let slept = AtomicBool::new(false);
        let snapshot = pager
            .refresh_with(2_000, None, Some(1_024), |_idx| {
                if !slept.swap(true, Ordering::Relaxed) {
                    thread::sleep(StdDuration::from_millis(800));
                }
                Ok(None)
            })
            .unwrap();

        assert!(snapshot.records.is_empty());
        assert!(snapshot.cursor.revision < 2_000);
    }

    // The adapter forwards the id and the length hint of the wrapped source.
    #[test]
    fn source_ids_and_reported_counts_are_exposed() {
        let adapter = IndexableAdapter::new(IndexableStub::new("stub_id", 3));
        assert_eq!(adapter.id(), "stub_id");
        assert_eq!(
            adapter
                .reported_record_count(&SamplerConfig::default())
                .unwrap(),
            3
        );
    }

    // The first `par_end` fetch calls return `Ok(None)`; the pager must keep
    // fetching further indices until the limit is filled.
    #[test]
    fn indexable_pager_sequential_fallback_fills_quota_when_parallel_pass_yields_none() {
        let pager = IndexablePager::new("fallback_fill");
        let call_count = AtomicUsize::new(0);
        let par_end = 4usize;
        let snapshot = pager
            .refresh_with(8, None, Some(par_end), |idx| {
                let n = call_count.fetch_add(1, Ordering::Relaxed);
                if n < par_end {
                    Ok(None)
                } else {
                    Ok(Some(DataRecord {
                        id: format!("r_{idx}"),
                        source: "fallback_fill".to_string(),
                        created_at: Utc::now(),
                        updated_at: Utc::now(),
                        quality: QualityScore { trust: 1.0 },
                        taxonomy: Vec::new(),
                        sections: vec![RecordSection {
                            role: SectionRole::Anchor,
                            heading: None,
                            text: "t".to_string(),
                            sentences: vec!["t".to_string()],
                        }],
                        meta_prefix: None,
                    }))
                }
            })
            .unwrap();
        assert_eq!(snapshot.records.len(), par_end);
    }

    // An error from the fetch closure must surface unchanged to the caller.
    #[test]
    fn indexable_pager_refresh_with_propagates_fetch_error() {
        let pager = IndexablePager::new("err");
        let err = pager
            .refresh_with(8, None, Some(2), |_idx| {
                Err(SamplerError::SourceUnavailable {
                    source_id: "err".to_string(),
                    reason: "fetch failed".to_string(),
                })
            })
            .unwrap_err();
        assert!(matches!(
            err,
            SamplerError::SourceUnavailable { ref reason, .. } if reason.contains("fetch failed")
        ));
    }

    // Different sampler seeds must give different seeds, distinct from the base seed.
    #[test]
    fn seed_for_sampler_depends_on_sampler_seed() {
        let source_id = "seeded".to_string();
        let base = IndexablePager::seed_for(&source_id, 17);
        let with_a = IndexablePager::seed_for_sampler(&source_id, 17, 1);
        let with_b = IndexablePager::seed_for_sampler(&source_id, 17, 2);
        assert_ne!(with_a, with_b);
        assert_ne!(with_a, base);
    }
}