1use crate::config::SamplerConfig;
2use crate::constants::splits::{STEP_CURSOR_KEY, is_reserved_source_id};
3use crate::data::DataRecord;
4use crate::errors::SamplerError;
5use crate::hash::derive_epoch_seed;
6use crate::source::{DataSource, SourceCursor, SourceSnapshot};
7use crate::types::{RecordId, SourceId};
8use chrono::Utc;
9use indexmap::IndexMap;
10use std::collections::HashMap;
11use std::collections::VecDeque;
12use std::sync::{Arc, Condvar, Mutex, RwLock};
13use std::thread;
14use std::time::Duration;
15use tracing::debug;
16
17#[derive(Clone)]
19pub struct RecordCache {
20 inner: Arc<RwLock<RecordCacheInner>>,
21 notifier: Arc<(Mutex<CacheStats>, Condvar)>,
22}
23
24struct RecordCacheInner {
26 records: IndexMap<RecordId, CachedRecord>,
27 order: VecDeque<RecordId>,
28 max_records: usize,
29 next_version: u64,
30}
31
32struct CachedRecord {
34 record: DataRecord,
35 version: u64,
36}
37
/// Counters guarded by the notifier mutex of `RecordCache`.
#[derive(Default)]
struct CacheStats {
    /// Number of non-empty `RecordCache::ingest` calls so far (saturating).
    ingests: u64,
}
43
44impl RecordCache {
45 pub fn new(max_records: usize) -> Self {
47 Self {
48 inner: Arc::new(RwLock::new(RecordCacheInner {
49 records: IndexMap::new(),
50 order: VecDeque::new(),
51 max_records,
52 next_version: 0,
53 })),
54 notifier: Arc::new((Mutex::new(CacheStats::default()), Condvar::new())),
55 }
56 }
57
58 pub fn ingest<I>(&self, records: I)
60 where
61 I: IntoIterator<Item = DataRecord>,
62 {
63 let mut batch: Vec<DataRecord> = records.into_iter().collect();
64 if batch.is_empty() {
65 return;
66 }
67 let mut inner = self.inner.write().expect("record cache poisoned");
68 inner.ingest_batch(&mut batch);
69 drop(inner);
70 let (lock, cvar) = &*self.notifier;
71 let mut stats = lock.lock().expect("record cache stats poisoned");
72 stats.ingests = stats.ingests.saturating_add(1);
73 cvar.notify_all();
74 }
75
76 pub fn clear(&self) {
78 let mut inner = self.inner.write().expect("record cache poisoned");
79 inner.records.clear();
80 inner.order.clear();
81 }
82
83 pub fn snapshot(&self) -> Vec<DataRecord> {
85 let inner = self.inner.read().expect("record cache poisoned");
86 inner
87 .records
88 .values()
89 .map(|entry| entry.record.clone())
90 .collect()
91 }
92
93 pub fn ingest_count(&self) -> u64 {
95 let (lock, _) = &*self.notifier;
96 lock.lock().expect("record cache stats poisoned").ingests
97 }
98
99 pub fn wait_for_ingest(&self, last_seen: u64, timeout: Duration) -> u64 {
101 let (lock, cvar) = &*self.notifier;
102 let mut stats = lock.lock().expect("record cache stats poisoned");
103 while stats.ingests <= last_seen {
104 let result = cvar
105 .wait_timeout(stats, timeout)
106 .expect("record cache stats poisoned");
107 stats = result.0;
108 if result.1.timed_out() {
109 break;
110 }
111 }
112 stats.ingests
113 }
114
115 pub fn wait_for_ingest_blocking(&self, last_seen: u64) -> u64 {
117 let (lock, cvar) = &*self.notifier;
118 let mut stats = lock.lock().expect("record cache stats poisoned");
119 while stats.ingests <= last_seen {
120 stats = cvar.wait(stats).expect("record cache stats poisoned");
121 }
122 stats.ingests
123 }
124
125 pub fn is_empty(&self) -> bool {
127 let inner = self.inner.read().expect("record cache poisoned");
128 inner.records.is_empty()
129 }
130
131 pub fn len(&self) -> usize {
133 let inner = self.inner.read().expect("record cache poisoned");
134 inner.records.len()
135 }
136}
137
138impl RecordCacheInner {
139 fn ingest_batch(&mut self, records: &mut Vec<DataRecord>) {
140 for record in records.drain(..) {
141 self.next_version = self.next_version.saturating_add(1);
142 let record_id = record.id.clone();
143 if self.records.contains_key(&record_id) {
144 if let Some(entry) = self.records.get_mut(&record_id) {
145 entry.record = record;
146 entry.version = self.next_version;
147 }
148 Self::refresh_order(&mut self.order, &record_id);
149 self.order.push_back(record_id);
150 } else {
151 self.order.push_back(record_id.clone());
152 self.records.insert(
153 record_id,
154 CachedRecord {
155 record,
156 version: self.next_version,
157 },
158 );
159 }
160 self.enforce_limit();
161 }
162 }
163
164 fn enforce_limit(&mut self) {
165 if self.max_records == 0 {
166 self.records.clear();
167 self.order.clear();
168 return;
169 }
170 while self.records.len() > self.max_records {
171 if let Some(oldest) = self.order.pop_front() {
172 self.records.swap_remove(&oldest);
173 } else {
174 break;
175 }
176 }
177 }
178
179 fn refresh_order(order: &mut VecDeque<RecordId>, id: &RecordId) {
180 if order.is_empty() {
181 return;
182 }
183 if let Some(pos) = order.iter().position(|existing| existing == id) {
184 order.remove(pos);
185 }
186 }
187}
188
189pub struct IngestionManager {
191 sources: Vec<SourceState>,
192 max_records: usize,
193 sampler_config: SamplerConfig,
194 source_epoch: u64,
196 source_refresh_generation: u64,
198 last_refreshed_sources: Vec<SourceId>,
203 drain_start: usize,
208 step_counter: u64,
211}
212
/// Per-source telemetry describing the most recent refresh attempt.
#[derive(Clone, Debug, Default)]
pub struct SourceRefreshStats {
    /// Duration of the last refresh in milliseconds (0 if its worker thread panicked).
    pub last_refresh_ms: u128,
    /// Number of records returned by the last refresh (0 when it failed).
    pub last_record_count: usize,
    /// Records per second for the last refresh (0 on failure or a zero-length duration).
    pub last_records_per_sec: f64,
    /// Error message from the last refresh, or `None` if it succeeded.
    pub last_error: Option<String>,
    /// Total number of failed refreshes for this source (saturating).
    pub error_count: u64,
}
227
228impl IngestionManager {
229 pub fn new(max_records: usize, sampler_config: SamplerConfig) -> Self {
231 Self {
232 sources: Vec::new(),
233 max_records,
234 sampler_config,
235 source_epoch: 0,
236 source_refresh_generation: 0,
237 last_refreshed_sources: Vec::new(),
238 drain_start: 0,
239 step_counter: 0,
240 }
241 }
242
243 pub fn source_refresh_generation(&self) -> u64 {
245 self.source_refresh_generation
246 }
247
248 pub fn last_refreshed_sources(&self) -> &[SourceId] {
250 &self.last_refreshed_sources
251 }
252
253 pub(crate) fn reset_step_counter(&mut self) {
263 self.step_counter = 0;
264 }
265
266 #[cfg(test)]
268 pub fn step_counter(&self) -> u64 {
269 self.step_counter
270 }
271
272 pub(crate) fn set_source_epoch(&mut self, epoch: u64) {
273 self.source_epoch = epoch;
274 }
275
276 #[cfg(test)]
278 pub fn source_epoch(&self) -> u64 {
279 self.source_epoch
280 }
281
282 pub(crate) fn reset_stream_cursors(&mut self) {
288 for state in &mut self.sources {
289 state.cursor = None;
290 state.buffer.clear();
291 state.cache.clear();
292 }
293 }
294
295 pub fn register_source(
300 &mut self,
301 source: Box<dyn DataSource + 'static>,
302 ) -> Result<(), SamplerError> {
303 let source_id = source.id().to_string();
304 if is_reserved_source_id(&source_id) {
305 return Err(SamplerError::ReservedSourceId(source_id));
306 }
307 let cache = RecordCache::new(self.max_records);
308 self.sources.push(SourceState {
309 source,
310 cursor: None,
311 buffer: VecDeque::new(),
312 cache,
313 stats: SourceRefreshStats::default(),
314 });
315 Ok(())
316 }
317
318 pub fn load_cursors(&mut self, cursors: &[(SourceId, u64)]) {
320 if cursors.is_empty() {
321 return;
322 }
323 let mut map = std::collections::HashMap::with_capacity(cursors.len());
324 for (id, revision) in cursors {
325 if id == STEP_CURSOR_KEY {
330 self.step_counter = *revision;
331 } else {
332 map.insert(id.as_str(), *revision);
333 }
334 }
335 for state in &mut self.sources {
336 if let Some(revision) = map.get(state.source.id()) {
337 state.cursor = Some(SourceCursor {
338 last_seen: Utc::now(),
339 revision: *revision,
340 });
341 }
342 }
343 }
344
345 pub fn snapshot_cursors(&self) -> Vec<(SourceId, u64)> {
347 let mut out = Vec::new();
348 for state in &self.sources {
349 if let Some(cursor) = state.cursor.as_ref() {
350 out.push((state.source.id().to_string(), cursor.revision));
351 }
352 }
353 out.push((STEP_CURSOR_KEY.to_string(), self.step_counter));
358 out
359 }
360
361 pub fn source_refresh_stats(&self) -> Vec<(SourceId, SourceRefreshStats)> {
363 self.sources
364 .iter()
365 .map(|state| (state.source.id().to_string(), state.stats.clone()))
366 .collect()
367 }
368
369 pub fn all_records_snapshot(&self) -> Vec<DataRecord> {
374 self.sources
375 .iter()
376 .flat_map(|s| s.cache.snapshot())
377 .collect()
378 }
379
380 pub fn all_caches_empty(&self) -> bool {
382 self.sources.iter().all(|s| s.cache.is_empty())
383 }
384
385 pub fn all_records_len(&self) -> usize {
387 self.sources.iter().map(|s| s.cache.len()).sum()
388 }
389
390 pub fn total_ingest_count(&self) -> u64 {
395 self.sources.iter().map(|s| s.cache.ingest_count()).sum()
396 }
397
398 pub fn refresh_all(&mut self) {
400 self.refresh_all_internal(false, None, None);
401 }
402
403 pub fn advance(&mut self, step: usize) {
405 self.refresh_all_internal(false, Some(step), None);
406 }
407
408 pub fn advance_with_weights(
413 &mut self,
414 step: usize,
415 weights: &HashMap<SourceId, f32>,
416 ) -> Result<(), SamplerError> {
417 self.validate_weights(weights)?;
418 self.refresh_all_internal(false, Some(step), Some(weights));
419 Ok(())
420 }
421
422 pub fn force_refresh_all(&mut self) {
424 self.refresh_all_internal(true, None, None);
425 }
426
427 pub fn refresh_all_with_weights(
432 &mut self,
433 weights: &HashMap<SourceId, f32>,
434 ) -> Result<(), SamplerError> {
435 self.validate_weights(weights)?;
436 self.refresh_all_internal(false, None, Some(weights));
437 Ok(())
438 }
439
440 pub fn force_refresh_all_with_weights(
445 &mut self,
446 weights: &HashMap<SourceId, f32>,
447 ) -> Result<(), SamplerError> {
448 self.validate_weights(weights)?;
449 self.refresh_all_internal(true, None, Some(weights));
450 Ok(())
451 }
452
453 fn validate_weights(&self, weights: &HashMap<SourceId, f32>) -> Result<(), SamplerError> {
454 let known_ids: std::collections::HashSet<&str> =
455 self.sources.iter().map(|s| s.source.id()).collect();
456 for (id, &w) in weights {
457 if !known_ids.contains(id.as_str()) {
458 return Err(SamplerError::InvalidWeight {
459 source_id: id.clone(),
460 reason: "source is not registered".to_string(),
461 });
462 }
463 if w < 0.0 {
464 return Err(SamplerError::InvalidWeight {
465 source_id: id.clone(),
466 reason: format!("weight {w} is negative"),
467 });
468 }
469 }
470 Ok(())
471 }
472
473 fn refresh_all_internal(
480 &mut self,
481 force_refresh: bool,
482 step: Option<usize>,
483 weights: Option<&HashMap<SourceId, f32>>,
484 ) {
485 self.last_refreshed_sources.clear();
486 let mut refresh_plan = Vec::new();
487 for (idx, state) in self.sources.iter_mut().enumerate() {
488 if force_refresh {
489 state.buffer.clear();
490 }
491 if force_refresh || state.buffer.is_empty() {
492 refresh_plan.push((idx, state.cursor.clone()));
493 }
494 }
495
496 if !refresh_plan.is_empty() {
497 self.step_counter = self.step_counter.saturating_add(1);
498 self.source_refresh_generation = self.source_refresh_generation.saturating_add(1);
499 self.last_refreshed_sources = refresh_plan
500 .iter()
501 .map(|(idx, _)| self.sources[*idx].source.id().to_string())
502 .collect();
503 let mut results: Vec<
504 Option<(Result<SourceSnapshot, SamplerError>, std::time::Duration)>,
505 > = Vec::with_capacity(self.sources.len());
506 results.resize_with(self.sources.len(), || None);
507 let fetch_limit = self.max_records;
508 let sampler_config = self.sampler_config.clone();
509 let step = self.step_counter;
510 thread::scope(|scope| {
511 let mut handles = Vec::with_capacity(refresh_plan.len());
512 for (idx, cursor) in &refresh_plan {
513 let source = &self.sources[*idx].source;
514 let cursor = cursor.clone();
515 let sampler_config = sampler_config.clone();
516 let source_epoch = self.source_epoch;
517 handles.push((
518 *idx,
519 scope.spawn(move || {
520 let start = std::time::Instant::now();
521 let epoch_config = SamplerConfig {
527 seed: derive_epoch_seed(sampler_config.seed, source_epoch) ^ step,
528 ..sampler_config
529 };
530 let result =
531 source.refresh(&epoch_config, cursor.as_ref(), Some(fetch_limit));
532 let elapsed = start.elapsed();
533 (result, elapsed)
534 }),
535 ));
536 }
537 for (idx, handle) in handles {
538 let result = match handle.join() {
539 Ok((result, elapsed)) => {
540 debug!(
541 source_id = %self.sources[idx].source.id(),
542 refresh_ms = elapsed.as_millis(),
543 "source refresh completed"
544 );
545 (result, elapsed)
546 }
547 Err(_) => (
548 Err(SamplerError::SourceUnavailable {
549 source_id: self.sources[idx].source.id().to_string(),
550 reason: "source refresh thread panicked".into(),
551 }),
552 std::time::Duration::from_secs(0),
553 ),
554 };
555 results[idx] = Some(result);
556 }
557 });
558
559 for (idx, _) in refresh_plan {
560 let Some((result, elapsed)) = results[idx].take() else {
561 continue;
562 };
563 match result {
564 Ok(snapshot) => {
565 let SourceSnapshot {
566 records,
567 cursor: next_cursor,
568 } = snapshot;
569 let record_count = records.len();
570 let seconds = elapsed.as_secs_f64();
571 let per_sec = if seconds > 0.0 {
572 (record_count as f64) / seconds
573 } else {
574 0.0
575 };
576 let stats = &mut self.sources[idx].stats;
577 stats.last_refresh_ms = elapsed.as_millis();
578 stats.last_record_count = record_count;
579 stats.last_records_per_sec = per_sec;
580 stats.last_error = None;
581 debug!(
582 source_id = %self.sources[idx].source.id(),
583 record_count,
584 refresh_ms = elapsed.as_millis(),
585 records_per_sec = per_sec,
586 "source refresh ingested records"
587 );
588 let source_id = self.sources[idx].source.id().to_string();
589 let normalized = records
590 .into_iter()
591 .map(|mut record| {
592 record.source = source_id.clone();
593 record
594 })
595 .collect::<Vec<_>>();
596 self.sources[idx].buffer.extend(normalized);
597 self.sources[idx].cursor = Some(next_cursor);
598 }
599 Err(err) => {
600 let stats = &mut self.sources[idx].stats;
601 stats.last_refresh_ms = elapsed.as_millis();
602 stats.last_record_count = 0;
603 stats.last_records_per_sec = 0.0;
604 stats.last_error = Some(err.to_string());
605 stats.error_count = stats.error_count.saturating_add(1);
606 eprintln!(
607 "[data_sampler] source '{}' refresh failed: {}",
608 self.sources[idx].source.id(),
609 err
610 );
611 }
612 }
613 }
614 }
615
616 if step.is_none() {
620 for state in self.sources.iter_mut() {
621 state.cache.clear();
622 }
623 }
624 if self.max_records == 0 {
625 return;
626 }
627 let target_limit = step.unwrap_or(self.max_records);
628 if let Some(weights) = weights {
629 self.weighted_drain_into_caches(target_limit, weights);
630 } else {
631 let n = self.sources.len();
636 if n > 0 {
637 let mut per_source: Vec<Vec<DataRecord>> = vec![Vec::new(); n];
638 let mut total_drained = 0;
639 let mut any_remaining = true;
640 while total_drained < target_limit && any_remaining {
641 any_remaining = false;
642 for offset in 0..n {
643 if total_drained >= target_limit {
644 break;
645 }
646 let idx = (self.drain_start + offset) % n;
647 if let Some(record) = self.sources[idx].buffer.pop_front() {
648 per_source[idx].push(record);
649 total_drained += 1;
650 any_remaining = true;
651 }
652 }
653 }
654 if total_drained > 0 {
658 self.drain_start = (self.drain_start + 1) % n;
659 }
660 for (idx, batch) in per_source.into_iter().enumerate() {
661 if !batch.is_empty() {
662 self.sources[idx].cache.ingest(batch);
663 }
664 }
665 }
666 }
667 }
668
669 fn weighted_drain_into_caches(&mut self, limit: usize, weights: &HashMap<SourceId, f32>) {
670 let len = self.sources.len();
671 if len == 0 {
672 return;
673 }
674 let mut weight_values = Vec::with_capacity(len);
675 let mut any_positive = false;
676 for state in &self.sources {
677 let weight = weights.get(state.source.id()).copied().unwrap_or(1.0);
678 if weight > 0.0 {
679 any_positive = true;
680 }
681 weight_values.push(weight);
682 }
683 if !any_positive {
684 weight_values.fill(1.0);
685 }
686
687 let mut current = vec![0.0f32; len];
688 let mut per_source: Vec<Vec<DataRecord>> = vec![Vec::new(); len];
689 let mut total = 0;
690 while total < limit {
691 let mut total_weight = 0.0f32;
692 for (idx, weight) in weight_values.iter().copied().enumerate().take(len) {
693 if weight <= 0.0 {
694 continue;
695 }
696 if self.sources[idx].buffer.is_empty() {
697 continue;
698 }
699 total_weight += weight;
700 }
701 if total_weight == 0.0 {
702 break;
703 }
704
705 let mut best_idx = None;
706 let mut best_score = f32::MIN;
707 let closer_to_start = |a: usize, b: usize| -> bool {
711 let da = (a + len - self.drain_start) % len;
712 let db = (b + len - self.drain_start) % len;
713 da < db
714 };
715 for idx in 0..len {
716 if weight_values[idx] <= 0.0 {
717 continue;
718 }
719 if self.sources[idx].buffer.is_empty() {
720 continue;
721 }
722 current[idx] += weight_values[idx];
723 let is_better = if current[idx] > best_score {
724 true
725 } else if current[idx] == best_score {
726 closer_to_start(idx, best_idx.unwrap_or(0))
727 } else {
728 false
729 };
730 if is_better {
731 best_score = current[idx];
732 best_idx = Some(idx);
733 }
734 }
735
736 let idx = match best_idx {
737 Some(idx) => idx,
738 None => break,
739 };
740 current[idx] -= total_weight;
741 if let Some(record) = self.sources[idx].buffer.pop_front() {
742 per_source[idx].push(record);
743 total += 1;
744 }
745 }
746
747 if total > 0 && len > 0 {
748 self.drain_start = (self.drain_start + 1) % len;
749 }
750
751 for (idx, batch) in per_source.into_iter().enumerate() {
752 if !batch.is_empty() {
753 self.sources[idx].cache.ingest(batch);
754 }
755 }
756 }
757
758 pub fn has_sources(&self) -> bool {
760 !self.sources.is_empty()
761 }
762}
763
764struct SourceState {
766 source: Box<dyn DataSource + 'static>,
767 cursor: Option<SourceCursor>,
768 buffer: VecDeque<DataRecord>,
769 cache: RecordCache,
771 stats: SourceRefreshStats,
772}
773
774#[cfg(test)]
775mod tests {
776 use super::*;
777 use crate::TripletSampler;
778 use crate::config::{Selector, TextRecipe, TripletRecipe};
779 use crate::data::{QualityScore, RecordSection, SectionRole};
780 use crate::sampler::Sampler;
781 use crate::splits::{DeterministicSplitStore, SamplerStateStore, SplitLabel, SplitRatios};
782 use chrono::Utc;
783 use std::collections::HashMap;
784 use std::collections::VecDeque;
785 use std::sync::atomic::{AtomicUsize, Ordering};
786 use std::sync::{Arc, Mutex};
787
788 fn make_record(id: &str, source: &str) -> DataRecord {
789 let now = Utc::now();
790 DataRecord {
791 id: id.to_string(),
792 source: source.to_string(),
793 created_at: now,
794 updated_at: now,
795 quality: QualityScore { trust: 1.0 },
796 taxonomy: Vec::new(),
797 sections: vec![RecordSection {
798 role: SectionRole::Anchor,
799 heading: None,
800 text: id.to_string(),
801 sentences: vec![id.to_string()],
802 }],
803 meta_prefix: None,
804 }
805 }
806
807 struct ScriptedSource {
808 id: String,
809 refreshes: Arc<AtomicUsize>,
810 script: Arc<Mutex<VecDeque<Result<SourceSnapshot, SamplerError>>>>,
811 }
812
813 impl ScriptedSource {
814 fn new(
815 id: &str,
816 refreshes: Arc<AtomicUsize>,
817 script: Vec<Result<SourceSnapshot, SamplerError>>,
818 ) -> Self {
819 Self {
820 id: id.to_string(),
821 refreshes,
822 script: Arc::new(Mutex::new(script.into_iter().collect())),
823 }
824 }
825 }
826
827 impl DataSource for ScriptedSource {
828 fn id(&self) -> &str {
829 &self.id
830 }
831
832 fn refresh(
833 &self,
834 _config: &SamplerConfig,
835 _cursor: Option<&SourceCursor>,
836 _limit: Option<usize>,
837 ) -> Result<SourceSnapshot, SamplerError> {
838 self.refreshes.fetch_add(1, Ordering::SeqCst);
839 let mut guard = self.script.lock().expect("script lock poisoned");
840 guard.pop_front().unwrap_or_else(|| {
841 Ok(SourceSnapshot {
842 records: Vec::new(),
843 cursor: SourceCursor {
844 last_seen: Utc::now(),
845 revision: 0,
846 },
847 })
848 })
849 }
850
851 fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
852 Ok(0)
853 }
854
855 fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
856 Vec::new()
857 }
858 }
859
860 struct PanicSource {
861 id: String,
862 }
863
864 impl DataSource for PanicSource {
865 fn id(&self) -> &str {
866 &self.id
867 }
868
869 fn refresh(
870 &self,
871 _config: &SamplerConfig,
872 _cursor: Option<&SourceCursor>,
873 _limit: Option<usize>,
874 ) -> Result<SourceSnapshot, SamplerError> {
875 panic!("panic source refresh")
876 }
877
878 fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
879 Ok(0)
880 }
881 }
882
883 #[test]
884 fn record_cache_waits_len_and_clear_paths_are_covered() {
885 let cache = RecordCache::new(2);
886 assert!(cache.is_empty());
887 assert_eq!(cache.len(), 0);
888 assert_eq!(cache.ingest_count(), 0);
889
890 cache.ingest(Vec::<DataRecord>::new());
891 assert_eq!(cache.wait_for_ingest(0, Duration::from_millis(1)), 0);
892
893 let cache_for_waiter = cache.clone();
894 let handle = std::thread::spawn(move || cache_for_waiter.wait_for_ingest_blocking(0));
895 std::thread::sleep(Duration::from_millis(5));
896 cache.ingest(vec![make_record("r1", "s")]);
897 assert_eq!(handle.join().unwrap(), 1);
898 assert_eq!(cache.ingest_count(), 1);
899
900 cache.ingest(vec![make_record("r2", "s"), make_record("r3", "s")]);
901 assert_eq!(cache.len(), 2);
902 let ids: Vec<String> = cache
903 .snapshot()
904 .into_iter()
905 .map(|record| record.id)
906 .collect();
907 assert!(ids.contains(&"r2".to_string()));
908 assert!(ids.contains(&"r3".to_string()));
909
910 cache.clear();
911 assert!(cache.is_empty());
912 }
913
914 #[test]
915 fn record_cache_zero_limit_discards_everything() {
916 let cache = RecordCache::new(0);
917 cache.ingest(vec![make_record("r1", "s")]);
918 assert!(cache.is_empty());
919 assert_eq!(cache.len(), 0);
920 }
921
922 #[test]
923 fn manager_loads_and_snapshots_cursors_and_reports_has_sources() {
924 let mut manager = IngestionManager::new(4, SamplerConfig::default());
925 assert!(!manager.has_sources());
926 manager.load_cursors(&[]);
927
928 let refreshes = Arc::new(AtomicUsize::new(0));
929 manager
930 .register_source(Box::new(ScriptedSource::new(
931 "cursor_source",
932 refreshes,
933 vec![Ok(SourceSnapshot {
934 records: vec![make_record("id_1", "original_source")],
935 cursor: SourceCursor {
936 last_seen: Utc::now(),
937 revision: 33,
938 },
939 })],
940 )))
941 .unwrap();
942 assert!(manager.has_sources());
943
944 manager.load_cursors(&[("cursor_source".to_string(), 7)]);
945 let cursors = manager.snapshot_cursors();
946 assert_eq!(cursors.len(), 2);
948 assert_eq!(cursors[0], ("cursor_source".to_string(), 7));
949 assert_eq!(cursors[1].0, STEP_CURSOR_KEY);
950
951 manager.refresh_all();
952 let updated = manager.snapshot_cursors();
953 assert_eq!(updated.len(), 2);
954 assert_eq!(updated[0], ("cursor_source".to_string(), 33));
955 assert_eq!(updated[1].0, STEP_CURSOR_KEY);
956 let records = manager.all_records_snapshot();
957 assert_eq!(records.len(), 1);
958 assert_eq!(records[0].source, "cursor_source");
959 }
960
961 #[test]
962 fn advance_uses_buffer_before_refreshing_again() {
963 let refreshes = Arc::new(AtomicUsize::new(0));
964 let mut manager = IngestionManager::new(5, SamplerConfig::default());
965 manager
966 .register_source(Box::new(ScriptedSource::new(
967 "buffered",
968 refreshes.clone(),
969 vec![Ok(SourceSnapshot {
970 records: vec![
971 make_record("a", "x"),
972 make_record("b", "x"),
973 make_record("c", "x"),
974 ],
975 cursor: SourceCursor {
976 last_seen: Utc::now(),
977 revision: 1,
978 },
979 })],
980 )))
981 .unwrap();
982
983 manager.advance(1);
984 assert_eq!(refreshes.load(Ordering::SeqCst), 1);
985 assert_eq!(manager.all_records_len(), 1);
986
987 manager.advance(1);
988 assert_eq!(refreshes.load(Ordering::SeqCst), 1);
989 assert_eq!(manager.all_records_len(), 2);
990 }
991
992 #[test]
993 fn force_refresh_clears_buffer_and_fetches_again() {
994 let refreshes = Arc::new(AtomicUsize::new(0));
995 let mut manager = IngestionManager::new(4, SamplerConfig::default());
996 manager
997 .register_source(Box::new(ScriptedSource::new(
998 "force",
999 refreshes.clone(),
1000 vec![
1001 Ok(SourceSnapshot {
1002 records: vec![
1003 make_record("r1", "x"),
1004 make_record("r2", "x"),
1005 make_record("r3", "x"),
1006 ],
1007 cursor: SourceCursor {
1008 last_seen: Utc::now(),
1009 revision: 10,
1010 },
1011 }),
1012 Ok(SourceSnapshot {
1013 records: vec![make_record("r4", "x")],
1014 cursor: SourceCursor {
1015 last_seen: Utc::now(),
1016 revision: 11,
1017 },
1018 }),
1019 ],
1020 )))
1021 .unwrap();
1022
1023 manager.advance(1);
1024 assert_eq!(manager.all_records_len(), 1);
1025 assert_eq!(refreshes.load(Ordering::SeqCst), 1);
1026
1027 manager.force_refresh_all();
1028 assert_eq!(refreshes.load(Ordering::SeqCst), 2);
1029 let records = manager.all_records_snapshot();
1030 assert_eq!(records.len(), 1);
1031 assert_eq!(records[0].id, "r4");
1032 }
1033
1034 #[test]
1035 fn weighted_drain_respects_zero_and_fallback_weights() {
1036 let mut manager = IngestionManager::new(6, SamplerConfig::default());
1037 manager
1038 .register_source(Box::new(ScriptedSource::new(
1039 "a",
1040 Arc::new(AtomicUsize::new(0)),
1041 vec![Ok(SourceSnapshot {
1042 records: vec![make_record("a1", "a"), make_record("a2", "a")],
1043 cursor: SourceCursor {
1044 last_seen: Utc::now(),
1045 revision: 1,
1046 },
1047 })],
1048 )))
1049 .unwrap();
1050 manager
1051 .register_source(Box::new(ScriptedSource::new(
1052 "b",
1053 Arc::new(AtomicUsize::new(0)),
1054 vec![Ok(SourceSnapshot {
1055 records: vec![make_record("b1", "b"), make_record("b2", "b")],
1056 cursor: SourceCursor {
1057 last_seen: Utc::now(),
1058 revision: 1,
1059 },
1060 })],
1061 )))
1062 .unwrap();
1063
1064 let mut only_b = HashMap::new();
1065 only_b.insert("a".to_string(), 0.0);
1066 only_b.insert("b".to_string(), 1.0);
1067 manager.refresh_all_with_weights(&only_b).unwrap();
1068 let ids: Vec<String> = manager
1069 .all_records_snapshot()
1070 .into_iter()
1071 .map(|record| record.id)
1072 .collect();
1073 assert!(ids.iter().all(|id| id.starts_with('b')));
1074
1075 let mut manager_fallback = IngestionManager::new(6, SamplerConfig::default());
1076 manager_fallback
1077 .register_source(Box::new(ScriptedSource::new(
1078 "a",
1079 Arc::new(AtomicUsize::new(0)),
1080 vec![Ok(SourceSnapshot {
1081 records: vec![make_record("a1", "a")],
1082 cursor: SourceCursor {
1083 last_seen: Utc::now(),
1084 revision: 2,
1085 },
1086 })],
1087 )))
1088 .unwrap();
1089 manager_fallback
1090 .register_source(Box::new(ScriptedSource::new(
1091 "b",
1092 Arc::new(AtomicUsize::new(0)),
1093 vec![Ok(SourceSnapshot {
1094 records: vec![make_record("b1", "b")],
1095 cursor: SourceCursor {
1096 last_seen: Utc::now(),
1097 revision: 2,
1098 },
1099 })],
1100 )))
1101 .unwrap();
1102
1103 let mut all_zero = HashMap::new();
1104 all_zero.insert("a".to_string(), 0.0);
1105 all_zero.insert("b".to_string(), 0.0);
1106 manager_fallback
1107 .refresh_all_with_weights(&all_zero)
1108 .unwrap();
1109 let ids: Vec<String> = manager_fallback
1110 .all_records_snapshot()
1111 .into_iter()
1112 .map(|record| record.id)
1113 .collect();
1114 assert!(ids.contains(&"a1".to_string()));
1115 assert!(ids.contains(&"b1".to_string()));
1116 }
1117
1118 #[test]
1119 fn refresh_errors_and_panics_update_source_stats() {
1120 let mut manager = IngestionManager::new(4, SamplerConfig::default());
1121 manager
1122 .register_source(Box::new(ScriptedSource::new(
1123 "err_source",
1124 Arc::new(AtomicUsize::new(0)),
1125 vec![Err(SamplerError::SourceUnavailable {
1126 source_id: "err_source".to_string(),
1127 reason: "boom".to_string(),
1128 })],
1129 )))
1130 .unwrap();
1131 manager
1132 .register_source(Box::new(PanicSource {
1133 id: "panic_source".to_string(),
1134 }))
1135 .unwrap();
1136
1137 manager.refresh_all();
1138 let stats = manager.source_refresh_stats();
1139 let err_stats = stats
1140 .iter()
1141 .find(|(source, _)| source == "err_source")
1142 .map(|(_, stats)| stats)
1143 .unwrap();
1144 assert_eq!(err_stats.error_count, 1);
1145 assert!(
1146 err_stats
1147 .last_error
1148 .as_ref()
1149 .is_some_and(|msg| msg.contains("boom"))
1150 );
1151
1152 let panic_stats = stats
1153 .iter()
1154 .find(|(source, _)| source == "panic_source")
1155 .map(|(_, stats)| stats)
1156 .unwrap();
1157 assert_eq!(panic_stats.error_count, 1);
1158 assert!(
1159 panic_stats
1160 .last_error
1161 .as_ref()
1162 .is_some_and(|msg| msg.contains("panicked"))
1163 );
1164 }
1165
1166 #[test]
1167 fn force_refresh_with_weights_path_is_exercised() {
1168 let mut manager = IngestionManager::new(3, SamplerConfig::default());
1169 manager
1170 .register_source(Box::new(ScriptedSource::new(
1171 "w",
1172 Arc::new(AtomicUsize::new(0)),
1173 vec![Ok(SourceSnapshot {
1174 records: vec![make_record("w1", "w")],
1175 cursor: SourceCursor {
1176 last_seen: Utc::now(),
1177 revision: 3,
1178 },
1179 })],
1180 )))
1181 .unwrap();
1182
1183 let mut weights = HashMap::new();
1184 weights.insert("w".to_string(), 1.0);
1185 manager.force_refresh_all_with_weights(&weights).unwrap();
1186 assert_eq!(manager.all_records_len(), 1);
1187 }
1188
1189 #[test]
1190 fn advance_with_weights_rejects_unknown_source() {
1191 let mut manager = IngestionManager::new(4, SamplerConfig::default());
1192 manager
1193 .register_source(Box::new(ScriptedSource::new(
1194 "known",
1195 Arc::new(AtomicUsize::new(0)),
1196 vec![],
1197 )))
1198 .unwrap();
1199
1200 let mut weights = HashMap::new();
1201 weights.insert("known".to_string(), 1.0);
1202 weights.insert("unknown".to_string(), 0.5);
1203
1204 let err = manager.advance_with_weights(1, &weights).unwrap_err();
1205 assert!(
1206 matches!(err, SamplerError::InvalidWeight { ref source_id, .. } if source_id == "unknown"),
1207 "expected InvalidWeight for unknown source, got {err:?}"
1208 );
1209 }
1210
1211 #[test]
1212 fn refresh_all_with_weights_rejects_negative_weight() {
1213 let mut manager = IngestionManager::new(4, SamplerConfig::default());
1214 manager
1215 .register_source(Box::new(ScriptedSource::new(
1216 "src",
1217 Arc::new(AtomicUsize::new(0)),
1218 vec![],
1219 )))
1220 .unwrap();
1221
1222 let mut weights = HashMap::new();
1223 weights.insert("src".to_string(), -1.0);
1224
1225 let err = manager.refresh_all_with_weights(&weights).unwrap_err();
1226 assert!(
1227 matches!(err, SamplerError::InvalidWeight { ref source_id, .. } if source_id == "src"),
1228 "expected InvalidWeight for negative weight, got {err:?}"
1229 );
1230 }
1231
1232 #[test]
1233 fn force_refresh_all_with_weights_rejects_unknown_source() {
1234 let mut manager = IngestionManager::new(4, SamplerConfig::default());
1235 manager
1236 .register_source(Box::new(ScriptedSource::new(
1237 "real",
1238 Arc::new(AtomicUsize::new(0)),
1239 vec![],
1240 )))
1241 .unwrap();
1242
1243 let mut weights = HashMap::new();
1244 weights.insert("ghost".to_string(), 1.0);
1245
1246 let err = manager
1247 .force_refresh_all_with_weights(&weights)
1248 .unwrap_err();
1249 assert!(
1250 matches!(err, SamplerError::InvalidWeight { ref source_id, .. } if source_id == "ghost"),
1251 "expected InvalidWeight for unknown source, got {err:?}"
1252 );
1253 }
1254
1255 struct SeedCapturingSource {
1257 id: String,
1258 received_seeds: Arc<Mutex<Vec<u64>>>,
1259 }
1260
1261 impl SeedCapturingSource {
1262 fn new(id: &str, received_seeds: Arc<Mutex<Vec<u64>>>) -> Self {
1263 Self {
1264 id: id.to_string(),
1265 received_seeds,
1266 }
1267 }
1268 }
1269
1270 impl DataSource for SeedCapturingSource {
1271 fn id(&self) -> &str {
1272 &self.id
1273 }
1274
1275 fn refresh(
1276 &self,
1277 config: &SamplerConfig,
1278 _cursor: Option<&SourceCursor>,
1279 _limit: Option<usize>,
1280 ) -> Result<SourceSnapshot, SamplerError> {
1281 self.received_seeds
1282 .lock()
1283 .expect("seed lock poisoned")
1284 .push(config.seed);
1285 Ok(SourceSnapshot {
1286 records: Vec::new(),
1287 cursor: SourceCursor {
1288 last_seen: Utc::now(),
1289 revision: 0,
1290 },
1291 })
1292 }
1293
1294 fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
1295 Ok(0)
1296 }
1297
1298 fn default_triplet_recipes(&self) -> Vec<crate::config::TripletRecipe> {
1299 Vec::new()
1300 }
1301 }
1302
1303 #[test]
1304 fn source_epoch_xor_changes_seed_received_by_source() {
1305 let base_seed = 0xDEAD_BEEF_u64;
1308 let config = SamplerConfig {
1309 seed: base_seed,
1310 ..SamplerConfig::default()
1311 };
1312
1313 let seeds_epoch0 = Arc::new(Mutex::new(Vec::<u64>::new()));
1314 let seeds_epoch1 = Arc::new(Mutex::new(Vec::<u64>::new()));
1315
1316 let mut manager = IngestionManager::new(4, config.clone());
1318 manager
1319 .register_source(Box::new(SeedCapturingSource::new(
1320 "src",
1321 Arc::clone(&seeds_epoch0),
1322 )))
1323 .unwrap();
1324 manager.refresh_all();
1326
1327 let mut manager2 = IngestionManager::new(4, config.clone());
1329 manager2
1330 .register_source(Box::new(SeedCapturingSource::new(
1331 "src",
1332 Arc::clone(&seeds_epoch1),
1333 )))
1334 .unwrap();
1335 manager2.set_source_epoch(1);
1336 manager2.refresh_all();
1337
1338 let received0 = seeds_epoch0.lock().unwrap();
1339 let received1 = seeds_epoch1.lock().unwrap();
1340
1341 assert!(!received0.is_empty(), "epoch-0 source was never refreshed");
1342 assert!(!received1.is_empty(), "epoch-1 source was never refreshed");
1343
1344 let seed_at_epoch0 = received0[0];
1345 let seed_at_epoch1 = received1[0];
1346
1347 assert_ne!(
1349 seed_at_epoch0, seed_at_epoch1,
1350 "epoch 0 and epoch 1 both produced seed {seed_at_epoch0:#x}; \
1351 derive_epoch_seed is not reaching the source"
1352 );
1353
1354 assert_eq!(
1358 seed_at_epoch0,
1359 derive_epoch_seed(base_seed, 0) ^ 1,
1360 "epoch-0 seed mismatch (step_counter=1)"
1361 );
1362 assert_eq!(
1363 seed_at_epoch1,
1364 derive_epoch_seed(base_seed, 1) ^ 1,
1365 "epoch-1 seed mismatch (step_counter=1)"
1366 );
1367 }
1368
1369 #[test]
1370 fn step_counter_resets_on_epoch_change() {
1371 let config = SamplerConfig::default();
1374 let seeds = Arc::new(Mutex::new(Vec::new()));
1375
1376 let mut manager = IngestionManager::new(4, config.clone());
1377 manager
1378 .register_source(Box::new(SeedCapturingSource::new(
1379 "src",
1380 Arc::clone(&seeds),
1381 )))
1382 .unwrap();
1383
1384 manager.refresh_all();
1386 let step1_seed = seeds.lock().unwrap()[0];
1387 assert_eq!(
1388 step1_seed,
1389 derive_epoch_seed(config.seed, 0) ^ 1,
1390 "epoch 0 step 1 seed"
1391 );
1392
1393 manager.set_source_epoch(1);
1395 manager.reset_step_counter();
1396 seeds.lock().unwrap().clear();
1397
1398 manager.refresh_all();
1401 let step1_epoch1 = seeds.lock().unwrap()[0];
1402 assert_eq!(
1403 step1_epoch1,
1404 derive_epoch_seed(config.seed, 1) ^ 1,
1405 "epoch 1 step 1 seed (must be ^1, not ^2)"
1406 );
1407 }
1408
    /// The ingestion step counter must round-trip through persisted sampler
    /// state.
    ///
    /// - `save_sampler_state` stores it under `STEP_CURSOR_KEY`.
    /// - `IngestionManager::load_cursors` restores it into a fresh manager.
    /// - The next `refresh_all` continues from the restored value.
    #[test]
    fn step_counter_survives_sampler_save_and_load_state() {
        let store = Arc::new(DeterministicSplitStore::new(SplitRatios::default(), 1).unwrap());

        let sampler = TripletSampler::new(
            SamplerConfig {
                seed: 42,
                ..SamplerConfig::default()
            },
            Arc::clone(&store),
        );
        let refreshes = Arc::new(AtomicUsize::new(0));
        sampler
            .register_source(Box::new(ScriptedSource::new(
                "src",
                refreshes,
                vec![Ok(SourceSnapshot {
                    records: vec![make_record("r1", "src")],
                    cursor: SourceCursor {
                        last_seen: Utc::now(),
                        revision: 1,
                    },
                })],
            )))
            .unwrap();

        // The batch result is discarded. Presumably this call only serves to
        // drive ingestion so the step counter advances; the `step_saved > 0`
        // assertion below is what actually checks that.
        let _ = sampler.next_text_batch(SplitLabel::Train);

        sampler.save_sampler_state(None).unwrap();

        // Read the persisted state back from the store and extract the
        // reserved step-cursor entry.
        let loaded = store.load_sampler_state().unwrap().unwrap();
        let step_saved = loaded
            .source_stream_cursors
            .iter()
            .find(|(k, _)| k == STEP_CURSOR_KEY)
            .map(|(_, v)| *v)
            .expect("STEP_CURSOR_KEY must be in persisted source_stream_cursors");
        assert!(
            step_saved > 0,
            "step_counter must be >0 after batch call through TripletSampler"
        );

        // A brand-new manager (default config) is fed the persisted cursors,
        // which include the step entry.
        let mut manager2 = IngestionManager::new(4, SamplerConfig::default());
        manager2
            .register_source(Box::new(ScriptedSource::new(
                "src",
                Arc::new(AtomicUsize::new(0)),
                vec![Ok(SourceSnapshot {
                    records: vec![make_record("r2", "src")],
                    cursor: SourceCursor {
                        last_seen: Utc::now(),
                        revision: 2,
                    },
                })],
            )))
            .unwrap();
        manager2.load_cursors(&loaded.source_stream_cursors);

        let step_before = manager2
            .snapshot_cursors()
            .iter()
            .find(|(k, _)| k == STEP_CURSOR_KEY)
            .map(|(_, v)| *v)
            .unwrap();
        assert_eq!(
            step_before, step_saved,
            "load_cursors must restore __step__ to saved value"
        );

        // A single refresh must advance the restored counter by exactly one.
        manager2.refresh_all();
        let step_after = manager2
            .snapshot_cursors()
            .iter()
            .find(|(k, _)| k == STEP_CURSOR_KEY)
            .map(|(_, v)| *v)
            .unwrap();
        assert_eq!(
            step_after,
            step_saved + 1,
            "step must continue: loaded {step_saved}, refresh incremented to {step_saved}+1"
        );
    }
1504
1505 #[test]
1506 fn scripted_and_panic_sources_cover_default_trait_paths() {
1507 let refreshes = Arc::new(AtomicUsize::new(0));
1508 let scripted = ScriptedSource::new("scripted", refreshes, vec![]);
1509
1510 let snapshot = scripted
1512 .refresh(&SamplerConfig::default(), None, None)
1513 .expect("fallback snapshot");
1514 assert!(snapshot.records.is_empty());
1515 assert_eq!(snapshot.cursor.revision, 0);
1516
1517 assert_eq!(
1518 scripted
1519 .reported_record_count(&SamplerConfig::default())
1520 .expect("record count"),
1521 0
1522 );
1523 assert!(scripted.default_triplet_recipes().is_empty());
1524
1525 let panic_source = PanicSource {
1526 id: "panic_count".to_string(),
1527 };
1528 assert_eq!(
1529 panic_source
1530 .reported_record_count(&SamplerConfig::default())
1531 .expect("record count"),
1532 0
1533 );
1534 }
1535
1536 #[test]
1537 fn seed_capturing_source_trait_defaults_are_exercised() {
1538 let source = SeedCapturingSource::new("seed_defaults", Arc::new(Mutex::new(Vec::new())));
1539 assert_eq!(
1540 source
1541 .reported_record_count(&SamplerConfig::default())
1542 .expect("record count"),
1543 0
1544 );
1545 assert!(source.default_triplet_recipes().is_empty());
1546 }
1547
1548 #[test]
1549 fn refresh_paths_handle_zero_capacity_and_no_sources() {
1550 let mut manager = IngestionManager::new(0, SamplerConfig::default());
1551 manager
1552 .register_source(Box::new(ScriptedSource::new(
1553 "zero_capacity",
1554 Arc::new(AtomicUsize::new(0)),
1555 vec![Ok(SourceSnapshot {
1556 records: vec![make_record("r1", "zero_capacity")],
1557 cursor: SourceCursor {
1558 last_seen: Utc::now(),
1559 revision: 1,
1560 },
1561 })],
1562 )))
1563 .unwrap();
1564 manager.refresh_all();
1565 assert!(manager.all_caches_empty());
1566
1567 let mut empty_manager = IngestionManager::new(4, SamplerConfig::default());
1569 let empty_weights = HashMap::new();
1570 empty_manager
1571 .refresh_all_with_weights(&empty_weights)
1572 .expect("no sources should not error");
1573 assert!(empty_manager.all_caches_empty());
1574 }
1575
1576 #[test]
1577 fn drain_start_rotates_fairly_across_sources() {
1578 struct FairSource {
1582 id: String,
1583 refresh_count: Arc<AtomicUsize>,
1584 }
1585
1586 impl DataSource for FairSource {
1587 fn id(&self) -> &str {
1588 &self.id
1589 }
1590 fn refresh(
1591 &self,
1592 _config: &SamplerConfig,
1593 _cursor: Option<&SourceCursor>,
1594 _limit: Option<usize>,
1595 ) -> Result<SourceSnapshot, SamplerError> {
1596 self.refresh_count.fetch_add(1, Ordering::SeqCst);
1597 Ok(SourceSnapshot {
1598 records: (0..10)
1599 .map(|i| make_record(&format!("r{i}"), &self.id))
1600 .collect(),
1601 cursor: SourceCursor {
1602 last_seen: Utc::now(),
1603 revision: 1,
1604 },
1605 })
1606 }
1607 fn reported_record_count(&self, _: &SamplerConfig) -> Result<u128, SamplerError> {
1608 Ok(10)
1609 }
1610 }
1611
1612 let counts = (
1613 Arc::new(AtomicUsize::new(0)),
1614 Arc::new(AtomicUsize::new(0)),
1615 Arc::new(AtomicUsize::new(0)),
1616 );
1617
1618 let mut manager = IngestionManager::new(30, SamplerConfig::default());
1619 manager
1620 .register_source(Box::new(FairSource {
1621 id: "src_0".to_string(),
1622 refresh_count: Arc::clone(&counts.0),
1623 }))
1624 .unwrap();
1625 manager
1626 .register_source(Box::new(FairSource {
1627 id: "src_1".to_string(),
1628 refresh_count: Arc::clone(&counts.1),
1629 }))
1630 .unwrap();
1631 manager
1632 .register_source(Box::new(FairSource {
1633 id: "src_2".to_string(),
1634 refresh_count: Arc::clone(&counts.2),
1635 }))
1636 .unwrap();
1637
1638 manager.refresh_all();
1640 assert_eq!(counts.0.load(Ordering::SeqCst), 1);
1642 assert_eq!(counts.1.load(Ordering::SeqCst), 1);
1643 assert_eq!(counts.2.load(Ordering::SeqCst), 1);
1644
1645 for _ in 0..33 {
1651 manager.advance(1);
1652 }
1653
1654 let r0 = counts.0.load(Ordering::SeqCst);
1655 let r1 = counts.1.load(Ordering::SeqCst);
1656 let r2 = counts.2.load(Ordering::SeqCst);
1657
1658 let min = r0.min(r1).min(r2);
1665 let max = r0.max(r1).max(r2);
1666 assert!(
1667 max <= min + 1,
1668 "sources should refresh at roughly the same rate: got {r0}/{r1}/{r2}"
1669 );
1670 }
1671
    /// Drives an `IngestionManager` directly (no sampler) with five identical
    /// sources and 80 `advance(2)` calls, then pins the exact number of
    /// refreshes each source received.
    #[test]
    fn direct_drain_start_rotates_fairly_with_batch_2_of_5() {
        // Minimal source: each refresh bumps `refresh_count` and returns eight
        // section-less records whose ids are prefixed with the source id.
        struct SimpleSource {
            id: String,
            refresh_count: Arc<AtomicUsize>,
        }

        impl DataSource for SimpleSource {
            fn id(&self) -> &str {
                &self.id
            }
            fn refresh(
                &self,
                _: &SamplerConfig,
                _: Option<&SourceCursor>,
                _: Option<usize>,
            ) -> Result<SourceSnapshot, SamplerError> {
                self.refresh_count.fetch_add(1, Ordering::SeqCst);
                let now = Utc::now();
                let records: Vec<DataRecord> = (0..8)
                    .map(|i| DataRecord {
                        id: format!("{}_r{i}", self.id),
                        source: self.id.clone(),
                        created_at: now,
                        updated_at: now,
                        quality: QualityScore { trust: 1.0 },
                        taxonomy: Vec::new(),
                        sections: Vec::new(),
                        meta_prefix: None,
                    })
                    .collect();
                Ok(SourceSnapshot {
                    records,
                    cursor: SourceCursor {
                        last_seen: now,
                        revision: 1,
                    },
                })
            }
            fn reported_record_count(&self, _: &SamplerConfig) -> Result<u128, SamplerError> {
                Ok(8)
            }
        }

        let counts: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
        // 40 = 5 sources x 8 records per refresh. Presumably sized so one full
        // round of refreshes fits in the cache; confirm if capacity semantics change.
        let mut manager = IngestionManager::new(40, SamplerConfig::default());
        for (i, count) in counts.iter().enumerate() {
            manager
                .register_source(Box::new(SimpleSource {
                    id: format!("src_{i}"),
                    refresh_count: Arc::clone(count),
                }))
                .unwrap();
        }

        // The initial refresh touches every source exactly once.
        manager.refresh_all();
        for c in &counts {
            assert_eq!(c.load(Ordering::SeqCst), 1);
        }

        for _ in 0..80 {
            manager.advance(2);
        }

        // NOTE(review): this is an exact regression snapshot of per-source
        // refresh totals (spread <= 1). Any change to the manager's drain/rotation
        // order will require re-deriving it.
        let totals: Vec<usize> = counts.iter().map(|c| c.load(Ordering::SeqCst)).collect();
        assert_eq!(
            totals,
            vec![5, 5, 6, 6, 6],
            "direct manager: unexpected refresh distribution"
        );
    }
1748
    /// Builds a `TripletSampler` over five tracked sources `src_0`..`src_4`.
    /// Each source increments the matching entry of `counts` on every refresh.
    ///
    /// Expects `counts` to hold one counter per source. Only as many sources
    /// as there are counters are registered; callers in this file pass five.
    fn make_five_source_sampler(
        counts: &[Arc<AtomicUsize>],
    ) -> TripletSampler<DeterministicSplitStore> {
        // Source returning eight records per refresh. Each record has a single
        // Anchor section, so the anchor text recipe below can select from it.
        struct Tracked {
            id: String,
            refresh_count: Arc<AtomicUsize>,
        }
        impl DataSource for Tracked {
            fn id(&self) -> &str {
                &self.id
            }
            fn refresh(
                &self,
                _: &SamplerConfig,
                _: Option<&SourceCursor>,
                _: Option<usize>,
            ) -> Result<SourceSnapshot, SamplerError> {
                self.refresh_count.fetch_add(1, Ordering::SeqCst);
                let now = Utc::now();
                let records: Vec<DataRecord> = (0..8)
                    .map(|i| DataRecord {
                        id: format!("{}_r{i}", self.id),
                        source: self.id.clone(),
                        created_at: now,
                        updated_at: now,
                        quality: QualityScore { trust: 1.0 },
                        taxonomy: Vec::new(),
                        sections: vec![RecordSection {
                            role: SectionRole::Anchor,
                            heading: None,
                            text: format!("x{i}"),
                            sentences: vec![format!("x{i}")],
                        }],
                        meta_prefix: None,
                    })
                    .collect();
                Ok(SourceSnapshot {
                    records,
                    cursor: SourceCursor {
                        last_seen: now,
                        revision: 1,
                    },
                })
            }
            fn reported_record_count(&self, _: &SamplerConfig) -> Result<u128, SamplerError> {
                Ok(8)
            }
        }

        // Batches of two and one anchor-only text recipe. Every record lands in
        // Train and only Train is allowed. `ingestion_max_records: 4` is small
        // relative to 8 records per refresh; presumably this forces frequent
        // re-refreshes, which is what the distribution tests count. Confirm
        // against the ingestion capacity semantics.
        let config = SamplerConfig {
            batch_size: 2,
            text_recipes: vec![TextRecipe {
                name: "anchor".into(),
                selector: Selector::Role(SectionRole::Anchor),
                weight: 1.0,
                instruction: None,
            }],
            split: SplitRatios {
                train: 1.0,
                validation: 0.0,
                test: 0.0,
            },
            allowed_splits: vec![SplitLabel::Train],
            ingestion_max_records: 4,
            ..SamplerConfig::default()
        };
        // The split store is seeded with a fixed value (99), so split assignment
        // is reproducible across runs.
        let store = Arc::new(DeterministicSplitStore::new(config.split, 99).unwrap());
        let sampler = TripletSampler::new(config, store);

        for (i, count) in counts.iter().enumerate() {
            sampler
                .register_source(Box::new(Tracked {
                    id: format!("src_{i}"),
                    refresh_count: Arc::clone(count),
                }))
                .unwrap();
        }
        sampler
    }
1835
1836 #[test]
1837 fn sampler_unweighted_drain_distributes_evenly() {
1838 let counts: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
1839 let sampler = make_five_source_sampler(&counts);
1840
1841 sampler.next_text_batch(SplitLabel::Train).unwrap();
1842 for _ in 0..80 {
1843 sampler.next_text_batch(SplitLabel::Train).unwrap();
1844 }
1845
1846 let totals: Vec<usize> = counts.iter().map(|c| c.load(Ordering::SeqCst)).collect();
1847 let min = *totals.iter().min().unwrap();
1848 let max = *totals.iter().max().unwrap();
1849 assert!(
1850 max <= min + 1,
1851 "unweighted: all sources must refresh at roughly the same rate: {totals:?}"
1852 );
1853 assert!(
1854 min >= 4,
1855 "unweighted: each source should have refreshed at least 4 times: {totals:?}"
1856 );
1857 }
1858
1859 #[test]
1860 fn sampler_weighted_drain_with_equal_weights_distributes_evenly() {
1861 let counts: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
1865 let sampler = make_five_source_sampler(&counts);
1866
1867 let weights: HashMap<String, f32> = (0..5).map(|i| (format!("src_{i}"), 1.0)).collect();
1868
1869 sampler
1870 .next_text_batch_with_weights(SplitLabel::Train, &weights)
1871 .unwrap();
1872 for _ in 0..80 {
1873 sampler
1874 .next_text_batch_with_weights(SplitLabel::Train, &weights)
1875 .unwrap();
1876 }
1877
1878 let totals: Vec<usize> = counts.iter().map(|c| c.load(Ordering::SeqCst)).collect();
1879 let min = *totals.iter().min().unwrap();
1880 let max = *totals.iter().max().unwrap();
1881 assert!(
1882 max <= min + 1,
1883 "weighted (equal): all sources must refresh at roughly the same rate: {totals:?}"
1884 );
1885 assert!(
1886 min >= 4,
1887 "weighted (equal): each source should have refreshed at least 4 times: {totals:?}"
1888 );
1889 }
1890
1891 #[test]
1892 fn sampler_unweighted_and_weighted_match_distribution() {
1893 let uc: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
1895 let wc: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
1896 let usampler = make_five_source_sampler(&uc);
1897 let wsampler = make_five_source_sampler(&wc);
1898
1899 let weights: HashMap<String, f32> = (0..5).map(|i| (format!("src_{i}"), 1.0)).collect();
1900
1901 usampler.next_text_batch(SplitLabel::Train).unwrap();
1902 wsampler
1903 .next_text_batch_with_weights(SplitLabel::Train, &weights)
1904 .unwrap();
1905 for _ in 0..80 {
1906 usampler.next_text_batch(SplitLabel::Train).unwrap();
1907 wsampler
1908 .next_text_batch_with_weights(SplitLabel::Train, &weights)
1909 .unwrap();
1910 }
1911
1912 let ut: Vec<usize> = uc.iter().map(|c| c.load(Ordering::SeqCst)).collect();
1913 let wt: Vec<usize> = wc.iter().map(|c| c.load(Ordering::SeqCst)).collect();
1914 let umax = *ut.iter().max().unwrap();
1915 let umin = *ut.iter().min().unwrap();
1916 let wmax = *wt.iter().max().unwrap();
1917 let wmin = *wt.iter().min().unwrap();
1918 assert!(umax <= umin + 1, "unweighted: {ut:?}");
1919 assert!(wmax <= wmin + 1, "weighted equal: {wt:?}");
1920 }
1921
    /// With `src_3` weighted 2.0 and every other source 1.0, 201 weighted
    /// batches must refresh `src_3` strictly more often than any other source.
    /// The exact per-source totals are also pinned.
    #[test]
    fn sampler_weighted_drain_with_unequal_weights_respects_ratios() {
        let counts: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
        let sampler = make_five_source_sampler(&counts);

        // src_3 gets double weight; the rest share weight 1.0.
        let mut weights = HashMap::new();
        for i in 0..5 {
            weights.insert(format!("src_{i}"), if i == 3 { 2.0f32 } else { 1.0 });
        }

        // One warm-up batch followed by 200 more: 201 calls in total.
        sampler
            .next_text_batch_with_weights(SplitLabel::Train, &weights)
            .unwrap();
        for _ in 0..200 {
            sampler
                .next_text_batch_with_weights(SplitLabel::Train, &weights)
                .unwrap();
        }

        let totals: Vec<usize> = counts.iter().map(|c| c.load(Ordering::SeqCst)).collect();
        assert!(
            totals
                .iter()
                .enumerate()
                .all(|(i, &t)| i == 3 || t < totals[3]),
            "src_3 (w=2.0) must outpace all w=1.0 sources (totals: {totals:?})"
        );
        // NOTE(review): this is an exact regression snapshot. In it, src_4
        // (w=1.0) refreshes noticeably more than src_0..src_2 (also w=1.0),
        // presumably as an artifact of the drain ordering rather than its
        // weight. Confirm before relying on that detail.
        assert_eq!(
            totals,
            vec![6, 6, 6, 26, 11],
            "unequal-weights: unexpected refresh distribution"
        );
    }
1966
1967 #[test]
1968 fn register_source_rejects_reserved_id_pattern() {
1969 use crate::constants::splits::is_reserved_source_id;
1970
1971 let mut manager = IngestionManager::new(4, SamplerConfig::default());
1972
1973 assert!(is_reserved_source_id("__step__"));
1975 assert!(is_reserved_source_id("__meta__"));
1976 assert!(is_reserved_source_id("__anything__"));
1977 assert!(is_reserved_source_id("__x__"));
1978 assert!(!is_reserved_source_id(""));
1979 assert!(!is_reserved_source_id("__"));
1980 assert!(!is_reserved_source_id("___"));
1981 assert!(!is_reserved_source_id("normal_source"));
1982 assert!(!is_reserved_source_id("_prefix_suffix_"));
1983 assert!(!is_reserved_source_id("__unclosed"));
1984 assert!(!is_reserved_source_id("unopened__"));
1985
1986 let result = manager.register_source(Box::new(ScriptedSource::new(
1988 "__reserved__",
1989 Arc::new(AtomicUsize::new(0)),
1990 vec![],
1991 )));
1992 assert!(
1993 result.is_err(),
1994 "register_source should reject reserved source id"
1995 );
1996 let err = result.unwrap_err();
1997 assert!(
1998 matches!(&err, SamplerError::ReservedSourceId(id) if id == "__reserved__"),
1999 "expected ReservedSourceId error, got: {err}"
2000 );
2001
2002 assert!(!manager.has_sources());
2004
2005 manager
2007 .register_source(Box::new(ScriptedSource::new(
2008 "valid_source",
2009 Arc::new(AtomicUsize::new(0)),
2010 vec![],
2011 )))
2012 .unwrap();
2013 assert!(manager.has_sources());
2014 }
2015}