Skip to main content

triplets_core/
ingestion.rs

1use crate::config::SamplerConfig;
2use crate::constants::splits::is_reserved_source_id;
3use crate::data::DataRecord;
4use crate::errors::SamplerError;
5use crate::hash::derive_epoch_seed;
6use crate::source::{DataSource, SourceCursor, SourceSnapshot};
7use crate::types::{RecordId, SourceId};
8use chrono::Utc;
9use indexmap::IndexMap;
10use std::collections::HashMap;
11use std::collections::VecDeque;
12use std::sync::{Arc, Condvar, Mutex, RwLock};
13use std::thread;
14use std::time::Duration;
15use tracing::debug;
16
/// Thread-safe in-memory cache of ingested records keyed by record id.
///
/// Cloning is cheap: clones share the same underlying storage and the same
/// ingest notifier through `Arc`, so any clone observes ingests from any other.
#[derive(Clone)]
pub struct RecordCache {
    /// Record storage plus eviction bookkeeping, behind a reader/writer lock.
    inner: Arc<RwLock<RecordCacheInner>>,
    /// Ingest counter + condvar pair used to wake `wait_for_ingest*` callers.
    notifier: Arc<(Mutex<CacheStats>, Condvar)>,
}
23
/// Internal mutable cache storage behind `RecordCache` locks.
struct RecordCacheInner {
    /// Live records keyed by id; map order is the snapshot iteration order.
    records: IndexMap<RecordId, CachedRecord>,
    /// Ids in least-recently-ingested-first order, used to pick evictions.
    order: VecDeque<RecordId>,
    /// Maximum number of live records; 0 means the cache retains nothing.
    max_records: usize,
    /// Monotonic counter stamped onto each ingested record as its version.
    next_version: u64,
}
31
/// Internal cache entry plus monotonic version marker.
struct CachedRecord {
    /// The cached record payload.
    record: DataRecord,
    /// Value of `RecordCacheInner::next_version` when this entry was written;
    /// strictly increases across ingests, so newer writes carry larger versions.
    version: u64,
}
37
/// Internal ingest notification counters.
#[derive(Default)]
struct CacheStats {
    // Number of completed (non-empty) ingest operations; saturating, never wraps.
    ingests: u64,
}
43
44impl RecordCache {
45    /// Create a cache capped to at most `max_records` live records.
46    pub fn new(max_records: usize) -> Self {
47        Self {
48            inner: Arc::new(RwLock::new(RecordCacheInner {
49                records: IndexMap::new(),
50                order: VecDeque::new(),
51                max_records,
52                next_version: 0,
53            })),
54            notifier: Arc::new((Mutex::new(CacheStats::default()), Condvar::new())),
55        }
56    }
57
58    /// Ingest a batch of records, replacing existing entries by id.
59    pub fn ingest<I>(&self, records: I)
60    where
61        I: IntoIterator<Item = DataRecord>,
62    {
63        let mut batch: Vec<DataRecord> = records.into_iter().collect();
64        if batch.is_empty() {
65            return;
66        }
67        let mut inner = self.inner.write().expect("record cache poisoned");
68        inner.ingest_batch(&mut batch);
69        drop(inner);
70        let (lock, cvar) = &*self.notifier;
71        let mut stats = lock.lock().expect("record cache stats poisoned");
72        stats.ingests = stats.ingests.saturating_add(1);
73        cvar.notify_all();
74    }
75
76    /// Remove all cached records.
77    pub fn clear(&self) {
78        let mut inner = self.inner.write().expect("record cache poisoned");
79        inner.records.clear();
80        inner.order.clear();
81    }
82
83    /// Return a cloned snapshot of current cached records.
84    pub fn snapshot(&self) -> Vec<DataRecord> {
85        let inner = self.inner.read().expect("record cache poisoned");
86        inner
87            .records
88            .values()
89            .map(|entry| entry.record.clone())
90            .collect()
91    }
92
93    /// Return the number of completed ingest operations.
94    pub fn ingest_count(&self) -> u64 {
95        let (lock, _) = &*self.notifier;
96        lock.lock().expect("record cache stats poisoned").ingests
97    }
98
99    /// Wait until ingest count exceeds `last_seen`, or until timeout elapses.
100    pub fn wait_for_ingest(&self, last_seen: u64, timeout: Duration) -> u64 {
101        let (lock, cvar) = &*self.notifier;
102        let mut stats = lock.lock().expect("record cache stats poisoned");
103        while stats.ingests <= last_seen {
104            let result = cvar
105                .wait_timeout(stats, timeout)
106                .expect("record cache stats poisoned");
107            stats = result.0;
108            if result.1.timed_out() {
109                break;
110            }
111        }
112        stats.ingests
113    }
114
115    /// Wait indefinitely until ingest count exceeds `last_seen`.
116    pub fn wait_for_ingest_blocking(&self, last_seen: u64) -> u64 {
117        let (lock, cvar) = &*self.notifier;
118        let mut stats = lock.lock().expect("record cache stats poisoned");
119        while stats.ingests <= last_seen {
120            stats = cvar.wait(stats).expect("record cache stats poisoned");
121        }
122        stats.ingests
123    }
124
125    /// Returns `true` when the cache has no records.
126    pub fn is_empty(&self) -> bool {
127        let inner = self.inner.read().expect("record cache poisoned");
128        inner.records.is_empty()
129    }
130
131    /// Return the number of records currently cached.
132    pub fn len(&self) -> usize {
133        let inner = self.inner.read().expect("record cache poisoned");
134        inner.records.len()
135    }
136}
137
138impl RecordCacheInner {
139    fn ingest_batch(&mut self, records: &mut Vec<DataRecord>) {
140        for record in records.drain(..) {
141            self.next_version = self.next_version.saturating_add(1);
142            let record_id = record.id.clone();
143            if self.records.contains_key(&record_id) {
144                if let Some(entry) = self.records.get_mut(&record_id) {
145                    entry.record = record;
146                    entry.version = self.next_version;
147                }
148                Self::refresh_order(&mut self.order, &record_id);
149                self.order.push_back(record_id);
150            } else {
151                self.order.push_back(record_id.clone());
152                self.records.insert(
153                    record_id,
154                    CachedRecord {
155                        record,
156                        version: self.next_version,
157                    },
158                );
159            }
160            self.enforce_limit();
161        }
162    }
163
164    fn enforce_limit(&mut self) {
165        if self.max_records == 0 {
166            self.records.clear();
167            self.order.clear();
168            return;
169        }
170        while self.records.len() > self.max_records {
171            if let Some(oldest) = self.order.pop_front() {
172                self.records.swap_remove(&oldest);
173            } else {
174                break;
175            }
176        }
177    }
178
179    fn refresh_order(order: &mut VecDeque<RecordId>, id: &RecordId) {
180        if order.is_empty() {
181            return;
182        }
183        if let Some(pos) = order.iter().position(|existing| existing == id) {
184            order.remove(pos);
185        }
186    }
187}
188
/// Coordinates on-demand source refresh and per-source-cache population.
pub struct IngestionManager {
    /// Registered sources with their buffers, caches, cursors, and telemetry.
    sources: Vec<SourceState>,
    /// Cap applied both to each per-source cache and to each refresh fetch.
    max_records: usize,
    /// Base sampler configuration; its seed is combined with `epoch` and
    /// `epoch_step` to derive the per-refresh seed handed to sources.
    sampler_config: SamplerConfig,
    /// Current epoch used to vary per-source permutation seeds across epochs.
    epoch: u64,
    /// Monotonic generation incremented whenever at least one source is refreshed.
    source_refresh_generation: u64,
    /// Source ids refreshed during the most recent `refresh_all_internal` call.
    ///
    /// This is updated even when cache ingest does not change, and is cleared when
    /// no source refresh occurs in that cycle.
    last_refreshed_sources: Vec<SourceId>,
    /// Rotating start index for the round-robin buffer drain.  Instead of always
    /// draining from source 0 (which starves high-index sources of refresh
    /// opportunities), each cycle begins at this position and advances by one.
    /// Over N cycles every source is drained first exactly once.
    drain_start: usize,
    /// Monotonic step counter incremented on every advance/refresh_all and set
    /// on SourceCursor.step so all sources can observe per-call progress.
    epoch_step: u64,
}
212
/// Last-refresh telemetry captured per source.
#[derive(Clone, Debug, Default)]
pub struct SourceRefreshStats {
    /// Duration of the most recent refresh in milliseconds.
    pub last_refresh_ms: u128,
    /// Number of records returned by the most recent refresh.
    pub last_record_count: usize,
    /// Throughput estimate from the most recent refresh (records/second;
    /// reported as 0.0 when the measured duration is zero).
    pub last_records_per_sec: f64,
    /// Last refresh error message, if any (cleared on the next success).
    pub last_error: Option<String>,
    /// Total refresh failures seen for this source.
    pub error_count: u64,
}
227
228impl IngestionManager {
229    /// Create a new ingestion manager that ingests on demand.
230    pub fn new(max_records: usize, sampler_config: SamplerConfig) -> Self {
231        Self {
232            sources: Vec::new(),
233            max_records,
234            sampler_config,
235            epoch: 0,
236            source_refresh_generation: 0,
237            last_refreshed_sources: Vec::new(),
238            drain_start: 0,
239            epoch_step: 0,
240        }
241    }
242
    /// Return a monotonic generation for source refresh cycles.
    ///
    /// Incremented by `refresh_all_internal` whenever at least one source was
    /// actually refreshed, so callers can cheaply detect refresh activity.
    pub fn source_refresh_generation(&self) -> u64 {
        self.source_refresh_generation
    }

    /// Return source ids refreshed by the most recent refresh cycle.
    ///
    /// Empty when the most recent cycle refreshed nothing.
    pub fn last_refreshed_sources(&self) -> &[SourceId] {
        &self.last_refreshed_sources
    }

    /// Update the current epoch value so subsequent `refresh` calls pass a
    /// seed derived from `(seed, epoch)` to sources, producing a different
    /// permutation.
    /// Stream cursors are intentionally NOT reset here — the cursor is a raw
    /// I/O offset into the source's stream and must continue advancing so
    /// every record is eventually fetched (resetting it would repeat the
    /// leading slice of the source on every epoch boundary).
    pub(crate) fn set_epoch(&mut self, epoch: u64) {
        self.epoch = epoch;
    }

    /// Reset epoch step counter to 0 (called at epoch boundaries).
    /// The step counter is a sub-counter within an epoch: each epoch starts
    /// with step=0 so that batch calls produce deterministic seeds.
    pub(crate) fn reset_epoch_step(&mut self) {
        self.epoch_step = 0;
    }

    /// Increment epoch step counter by 1. Called once per `next_*_batch` call so that
    /// the epoch step tracks model training steps, not ingestion refresh events.
    /// Saturates at `u64::MAX` instead of wrapping.
    pub(crate) fn increment_epoch_step(&mut self) {
        self.epoch_step = self.epoch_step.saturating_add(1);
    }

    /// Return the current epoch step counter.
    pub fn epoch_step(&self) -> u64 {
        self.epoch_step
    }

    /// Set the epoch step counter directly (used when restoring persisted state).
    pub(crate) fn set_epoch_step(&mut self, step: u64) {
        self.epoch_step = step;
    }

    /// Return the current epoch (test-only accessor).
    #[cfg(test)]
    pub fn epoch(&self) -> u64 {
        self.epoch
    }
291
292    /// Reset all raw source stream cursors and drain per-source buffers.
293    ///
294    /// Use this only when starting a deterministic replay from a specific
295    /// epoch (e.g. explicit `set_epoch` calls). A clean-start reset ensures
296    /// the new permutation begins at position 0 of the permuted index space.
297    pub(crate) fn reset_stream_cursors(&mut self) {
298        for state in &mut self.sources {
299            state.cursor = None;
300            state.buffer.clear();
301            state.cache.clear();
302        }
303    }
304
305    /// Register a source for on-demand ingestion.
306    ///
307    /// Returns an error if the source's `id()` matches the reserved `__*__`
308    /// pattern used for internal synthetic/metadata source identifiers.
309    pub fn register_source(
310        &mut self,
311        source: Box<dyn DataSource + 'static>,
312    ) -> Result<(), SamplerError> {
313        let source_id = source.id().to_string();
314        if is_reserved_source_id(&source_id) {
315            return Err(SamplerError::ReservedSourceId(source_id));
316        }
317        let cache = RecordCache::new(self.max_records);
318        self.sources.push(SourceState {
319            source,
320            cursor: None,
321            buffer: VecDeque::new(),
322            cache,
323            stats: SourceRefreshStats::default(),
324        });
325        Ok(())
326    }
327
328    /// Load persisted per-source stream cursors.
329    pub fn load_cursors(&mut self, cursors: &[(SourceId, u64)]) {
330        if cursors.is_empty() {
331            return;
332        }
333        let mut map = std::collections::HashMap::with_capacity(cursors.len());
334        for (id, revision) in cursors {
335            map.insert(id.as_str(), *revision);
336        }
337        for state in &mut self.sources {
338            if let Some(revision) = map.get(state.source.id()) {
339                state.cursor = Some(SourceCursor {
340                    last_seen: Utc::now(),
341                    revision: *revision,
342                });
343            }
344        }
345    }
346
347    /// Snapshot current per-source stream cursors.
348    pub fn snapshot_cursors(&self) -> Vec<(SourceId, u64)> {
349        let mut out = Vec::new();
350        for state in &self.sources {
351            if let Some(cursor) = state.cursor.as_ref() {
352                out.push((state.source.id().to_string(), cursor.revision));
353            }
354        }
355        out
356    }
357
358    /// Return latest refresh telemetry for each registered source.
359    pub fn source_refresh_stats(&self) -> Vec<(SourceId, SourceRefreshStats)> {
360        self.sources
361            .iter()
362            .map(|state| (state.source.id().to_string(), state.stats.clone()))
363            .collect()
364    }
365
366    /// Return a flat snapshot of every record currently in all per-source caches.
367    ///
368    /// Records are cloned in source order; the `source` field is guaranteed
369    /// to be set (it is normalised in `refresh_all_internal`).
370    pub fn all_records_snapshot(&self) -> Vec<DataRecord> {
371        self.sources
372            .iter()
373            .flat_map(|s| s.cache.snapshot())
374            .collect()
375    }
376
377    /// Returns `true` when ALL per-source caches are empty.
378    pub fn all_caches_empty(&self) -> bool {
379        self.sources.iter().all(|s| s.cache.is_empty())
380    }
381
382    /// Returns the total number of records across all per-source caches.
383    pub fn all_records_len(&self) -> usize {
384        self.sources.iter().map(|s| s.cache.len()).sum()
385    }
386
387    /// Returns the sum of ingest counts across all per-source caches.
388    ///
389    /// Used as a monotonic proxy to detect whether any cache has been updated
390    /// since the last sync.
391    pub fn total_ingest_count(&self) -> u64 {
392        self.sources.iter().map(|s| s.cache.ingest_count()).sum()
393    }
394
    /// Refresh all registered sources once.
    ///
    /// Sources whose buffers still hold records are skipped; per-source
    /// caches are cleared and refilled up to capacity (the `step = None` path).
    pub fn refresh_all(&mut self) {
        self.refresh_all_internal(false, None, None);
    }

    /// Advance the ingestion window by ingesting `step` new records.
    ///
    /// Rolling update: existing cache contents are kept and up to `step`
    /// buffered records are drained in round-robin order.
    pub fn advance(&mut self, step: usize) {
        self.refresh_all_internal(false, Some(step), None);
    }

    /// Advance the ingestion window by ingesting `step` new records with weights.
    ///
    /// Returns `Err(SamplerError::InvalidWeight)` if `weights` contains an unregistered
    /// source ID or a negative value.
    pub fn advance_with_weights(
        &mut self,
        step: usize,
        weights: &HashMap<SourceId, f32>,
    ) -> Result<(), SamplerError> {
        self.validate_weights(weights)?;
        self.refresh_all_internal(false, Some(step), Some(weights));
        Ok(())
    }

    /// Force refresh all registered sources, discarding buffered records.
    pub fn force_refresh_all(&mut self) {
        self.refresh_all_internal(true, None, None);
    }

    /// Refresh all registered sources once with per-call source weights.
    ///
    /// Returns `Err(SamplerError::InvalidWeight)` if `weights` contains an unregistered
    /// source ID or a negative value.
    pub fn refresh_all_with_weights(
        &mut self,
        weights: &HashMap<SourceId, f32>,
    ) -> Result<(), SamplerError> {
        self.validate_weights(weights)?;
        self.refresh_all_internal(false, None, Some(weights));
        Ok(())
    }

    /// Force refresh all registered sources with per-call source weights.
    ///
    /// Returns `Err(SamplerError::InvalidWeight)` if `weights` contains an unregistered
    /// source ID or a negative value.
    pub fn force_refresh_all_with_weights(
        &mut self,
        weights: &HashMap<SourceId, f32>,
    ) -> Result<(), SamplerError> {
        self.validate_weights(weights)?;
        self.refresh_all_internal(true, None, Some(weights));
        Ok(())
    }
449
450    fn validate_weights(&self, weights: &HashMap<SourceId, f32>) -> Result<(), SamplerError> {
451        let known_ids: std::collections::HashSet<&str> =
452            self.sources.iter().map(|s| s.source.id()).collect();
453        for (id, &w) in weights {
454            if !known_ids.contains(id.as_str()) {
455                return Err(SamplerError::InvalidWeight {
456                    source_id: id.clone(),
457                    reason: "source is not registered".to_string(),
458                });
459            }
460            if w < 0.0 {
461                return Err(SamplerError::InvalidWeight {
462                    source_id: id.clone(),
463                    reason: format!("weight {w} is negative"),
464                });
465            }
466        }
467        Ok(())
468    }
469
    /// Rebuild the shared cache by round-robin draining per-source buffers.
    ///
    /// When `force_refresh` is false, each source only refreshes when its buffer
    /// is empty; when true, all buffers are cleared and all sources refresh.
    /// If `step` is provided, performs a rolling update of `step` records (no clear).
    /// If `step` is None, clears the cache and fills up to max capacity.
    /// When `weights` is provided, the drain is weighted instead of round-robin.
    fn refresh_all_internal(
        &mut self,
        force_refresh: bool,
        step: Option<usize>,
        weights: Option<&HashMap<SourceId, f32>>,
    ) {
        self.last_refreshed_sources.clear();
        // Phase 1: plan — refresh every source whose buffer is exhausted (or
        // all of them when forcing), capturing each source's current cursor.
        let mut refresh_plan = Vec::new();
        for (idx, state) in self.sources.iter_mut().enumerate() {
            if force_refresh {
                state.buffer.clear();
            }
            if force_refresh || state.buffer.is_empty() {
                refresh_plan.push((idx, state.cursor.clone()));
            }
        }

        // Phase 2: refresh all planned sources in parallel, one scoped thread
        // each, and fold results plus telemetry back into per-source state.
        if !refresh_plan.is_empty() {
            self.source_refresh_generation = self.source_refresh_generation.saturating_add(1);
            self.last_refreshed_sources = refresh_plan
                .iter()
                .map(|(idx, _)| self.sources[*idx].source.id().to_string())
                .collect();
            // Result slots indexed by source position; slots stay `None` for
            // sources not in this cycle's plan.
            let mut results: Vec<
                Option<(Result<SourceSnapshot, SamplerError>, std::time::Duration)>,
            > = Vec::with_capacity(self.sources.len());
            results.resize_with(self.sources.len(), || None);
            let fetch_limit = self.max_records;
            let sampler_config = self.sampler_config.clone();
            // NOTE: this deliberately shadows the `step: Option<usize>`
            // parameter with the u64 epoch-step counter for the rest of this
            // `if` block; the parameter is visible again for the drain phase.
            let step = self.epoch_step;
            // Scoped threads may borrow `self.sources` immutably without
            // requiring `'static` captures.
            thread::scope(|scope| {
                let mut handles = Vec::with_capacity(refresh_plan.len());
                for (idx, cursor) in &refresh_plan {
                    let source = &self.sources[*idx].source;
                    let cursor = cursor.clone();
                    let sampler_config = sampler_config.clone();
                    let epoch = self.epoch;
                    handles.push((
                        *idx,
                        scope.spawn(move || {
                            let start = std::time::Instant::now();
                            // XOR the source epoch into the seed so each epoch
                            // produces a distinct permutation within the source.
                            // XOR the step counter into the seed so every
                            // advance/refresh call gets a distinct seed, which
                            // sources can use for e.g. shard ordering.
                            let epoch_config = SamplerConfig {
                                seed: derive_epoch_seed(sampler_config.seed, epoch) ^ step,
                                ..sampler_config
                            };
                            let result =
                                source.refresh(&epoch_config, cursor.as_ref(), Some(fetch_limit));
                            let elapsed = start.elapsed();
                            (result, elapsed)
                        }),
                    ));
                }
                for (idx, handle) in handles {
                    let result = match handle.join() {
                        Ok((result, elapsed)) => {
                            debug!(
                                source_id = %self.sources[idx].source.id(),
                                refresh_ms = elapsed.as_millis(),
                                "source refresh completed"
                            );
                            (result, elapsed)
                        }
                        // A panicking refresh thread is converted into a
                        // SourceUnavailable error rather than propagated.
                        Err(_) => (
                            Err(SamplerError::SourceUnavailable {
                                source_id: self.sources[idx].source.id().to_string(),
                                reason: "source refresh thread panicked".into(),
                            }),
                            std::time::Duration::from_secs(0),
                        ),
                    };
                    results[idx] = Some(result);
                }
            });

            // Apply results: buffer + cursor on success, telemetry either way.
            for (idx, _) in refresh_plan {
                let Some((result, elapsed)) = results[idx].take() else {
                    continue;
                };
                match result {
                    Ok(snapshot) => {
                        let SourceSnapshot {
                            records,
                            cursor: next_cursor,
                        } = snapshot;
                        let record_count = records.len();
                        let seconds = elapsed.as_secs_f64();
                        let per_sec = if seconds > 0.0 {
                            (record_count as f64) / seconds
                        } else {
                            0.0
                        };
                        let stats = &mut self.sources[idx].stats;
                        stats.last_refresh_ms = elapsed.as_millis();
                        stats.last_record_count = record_count;
                        stats.last_records_per_sec = per_sec;
                        stats.last_error = None;
                        debug!(
                            source_id = %self.sources[idx].source.id(),
                            record_count,
                            refresh_ms = elapsed.as_millis(),
                            records_per_sec = per_sec,
                            "source refresh ingested records"
                        );
                        // Normalise the `source` field so consumers can rely
                        // on it matching the registered source id.
                        let source_id = self.sources[idx].source.id().to_string();
                        let normalized = records
                            .into_iter()
                            .map(|mut record| {
                                record.source = source_id.clone();
                                record
                            })
                            .collect::<Vec<_>>();
                        self.sources[idx].buffer.extend(normalized);
                        self.sources[idx].cursor = Some(next_cursor);
                    }
                    // Refresh failures are recorded, counted, and logged; they
                    // never abort the cycle for the remaining sources.
                    Err(err) => {
                        let stats = &mut self.sources[idx].stats;
                        stats.last_refresh_ms = elapsed.as_millis();
                        stats.last_record_count = 0;
                        stats.last_records_per_sec = 0.0;
                        stats.last_error = Some(err.to_string());
                        stats.error_count = stats.error_count.saturating_add(1);
                        eprintln!(
                            "[data_sampler] source '{}' refresh failed: {}",
                            self.sources[idx].source.id(),
                            err
                        );
                    }
                }
            }
        }

        // On a full refresh (step=None) clear every per-source cache so that the
        // snapshot always reflects the newest window, matching the previous
        // shared-cache clear semantics.
        if step.is_none() {
            for state in self.sources.iter_mut() {
                state.cache.clear();
            }
        }
        if self.max_records == 0 {
            return;
        }
        let target_limit = step.unwrap_or(self.max_records);
        if let Some(weights) = weights {
            self.weighted_drain_into_caches(target_limit, weights);
        } else {
            // Fair round-robin drain: start from `drain_start` instead of 0 so
            // that the drain cursor rotates across cycles.  This prevents head
            // sources (low indices) from always draining faster than tail sources,
            // which was starving tail sources of refresh opportunities.
            let n = self.sources.len();
            if n > 0 {
                let mut per_source: Vec<Vec<DataRecord>> = vec![Vec::new(); n];
                let mut total_drained = 0;
                let mut any_remaining = true;
                while total_drained < target_limit && any_remaining {
                    any_remaining = false;
                    for offset in 0..n {
                        if total_drained >= target_limit {
                            break;
                        }
                        let idx = (self.drain_start + offset) % n;
                        if let Some(record) = self.sources[idx].buffer.pop_front() {
                            per_source[idx].push(record);
                            total_drained += 1;
                            any_remaining = true;
                        }
                    }
                }
                // Advance the drain cursor so the next cycle starts from a different
                // position.  Only advance when at least one record was drained, so a
                // burst of drain-noop cycles on an empty source list doesn't rotate.
                if total_drained > 0 {
                    self.drain_start = (self.drain_start + 1) % n;
                }
                // Batch-ingest per source so each cache's ingest counter bumps
                // at most once per cycle.
                for (idx, batch) in per_source.into_iter().enumerate() {
                    if !batch.is_empty() {
                        self.sources[idx].cache.ingest(batch);
                    }
                }
            }
        }
    }
664
    /// Drain up to `limit` buffered records into per-source caches using a
    /// smooth-weighted-round-robin selection: each round every eligible source
    /// accrues its weight into an accumulator, the highest accumulator wins one
    /// record, and the winner is debited by the total active weight. Ties are
    /// broken in favour of the source nearest (in rotation order) to
    /// `drain_start`.
    fn weighted_drain_into_caches(&mut self, limit: usize, weights: &HashMap<SourceId, f32>) {
        let len = self.sources.len();
        if len == 0 {
            return;
        }
        // Resolve one weight per source; sources absent from `weights`
        // default to 1.0.
        let mut weight_values = Vec::with_capacity(len);
        let mut any_positive = false;
        for state in &self.sources {
            let weight = weights.get(state.source.id()).copied().unwrap_or(1.0);
            if weight > 0.0 {
                any_positive = true;
            }
            weight_values.push(weight);
        }
        // All-zero weights fall back to uniform draining rather than
        // draining nothing.
        if !any_positive {
            weight_values.fill(1.0);
        }

        let mut current = vec![0.0f32; len];
        let mut per_source: Vec<Vec<DataRecord>> = vec![Vec::new(); len];
        let mut total = 0;
        while total < limit {
            // Sum the weights of sources that are still eligible: positive
            // weight AND records remaining in their buffer.
            let mut total_weight = 0.0f32;
            for (idx, weight) in weight_values.iter().copied().enumerate().take(len) {
                if weight <= 0.0 {
                    continue;
                }
                if self.sources[idx].buffer.is_empty() {
                    continue;
                }
                total_weight += weight;
            }
            // No eligible source left: every positively-weighted buffer ran dry.
            if total_weight == 0.0 {
                break;
            }

            let mut best_idx = None;
            let mut best_score = f32::MIN;
            // Rotating tie-breaker: when scores are equal, prefer the source
            // just PAST drain_start in rotation order (i.e. the source whose
            // turn is coming up in the round-robin).
            let closer_to_start = |a: usize, b: usize| -> bool {
                let da = (a + len - self.drain_start) % len;
                let db = (b + len - self.drain_start) % len;
                da < db
            };
            for idx in 0..len {
                if weight_values[idx] <= 0.0 {
                    continue;
                }
                if self.sources[idx].buffer.is_empty() {
                    continue;
                }
                current[idx] += weight_values[idx];
                let is_better = if current[idx] > best_score {
                    true
                } else if current[idx] == best_score {
                    closer_to_start(idx, best_idx.unwrap_or(0))
                } else {
                    false
                };
                if is_better {
                    best_score = current[idx];
                    best_idx = Some(idx);
                }
            }

            let idx = match best_idx {
                Some(idx) => idx,
                None => break,
            };
            // Debit the winner by the total active weight so its accumulator
            // falls behind and other sources get turns in later rounds.
            current[idx] -= total_weight;
            if let Some(record) = self.sources[idx].buffer.pop_front() {
                per_source[idx].push(record);
                total += 1;
            }
        }

        // Rotate the round-robin start only when something was drained,
        // mirroring the unweighted drain path.
        if total > 0 && len > 0 {
            self.drain_start = (self.drain_start + 1) % len;
        }

        // Batch-ingest per source so each cache's ingest counter bumps at
        // most once per call.
        for (idx, batch) in per_source.into_iter().enumerate() {
            if !batch.is_empty() {
                self.sources[idx].cache.ingest(batch);
            }
        }
    }
753
    /// Returns `true` when at least one source is registered
    /// (via `register_source`).
    pub fn has_sources(&self) -> bool {
        !self.sources.is_empty()
    }
758}
759
/// Per-source ingestion runtime state.
struct SourceState {
    /// The registered data source implementation.
    source: Box<dyn DataSource + 'static>,
    /// Stream cursor from the most recent successful refresh; `None` until
    /// the source has refreshed (or after a cursor reset).
    cursor: Option<SourceCursor>,
    /// Records fetched from the source but not yet drained into `cache`.
    buffer: VecDeque<DataRecord>,
    /// Per-source LRU record cache capped at `max_records`.
    cache: RecordCache,
    /// Telemetry from the most recent refresh attempt.
    stats: SourceRefreshStats,
}
769
770#[cfg(test)]
771mod tests {
772    use super::*;
773    use crate::TripletSampler;
774    use crate::config::{Selector, TextRecipe, TripletRecipe};
775    use crate::data::{QualityScore, RecordSection, SectionRole};
776    use crate::sampler::Sampler;
777    use crate::splits::{DeterministicSplitStore, SamplerStateStore, SplitLabel, SplitRatios};
778    use chrono::Utc;
779    use std::collections::HashMap;
780    use std::collections::VecDeque;
781    use std::sync::atomic::{AtomicUsize, Ordering};
782    use std::sync::{Arc, Mutex};
783
784    fn make_record(id: &str, source: &str) -> DataRecord {
785        let now = Utc::now();
786        DataRecord {
787            id: id.to_string(),
788            source: source.to_string(),
789            created_at: now,
790            updated_at: now,
791            quality: QualityScore { trust: 1.0 },
792            taxonomy: Vec::new(),
793            sections: vec![RecordSection {
794                role: SectionRole::Anchor,
795                heading: None,
796                text: id.to_string(),
797                sentences: vec![id.to_string()],
798            }],
799            meta_prefix: None,
800        }
801    }
802
    /// Test source that replays a fixed script of refresh results.
    struct ScriptedSource {
        // Source id reported via `DataSource::id`.
        id: String,
        // Shared counter incremented on every `refresh()` call.
        refreshes: Arc<AtomicUsize>,
        // Queue of scripted refresh results, consumed front-to-back.
        script: Arc<Mutex<VecDeque<Result<SourceSnapshot, SamplerError>>>>,
    }
808
809    impl ScriptedSource {
810        fn new(
811            id: &str,
812            refreshes: Arc<AtomicUsize>,
813            script: Vec<Result<SourceSnapshot, SamplerError>>,
814        ) -> Self {
815            Self {
816                id: id.to_string(),
817                refreshes,
818                script: Arc::new(Mutex::new(script.into_iter().collect())),
819            }
820        }
821    }
822
823    impl DataSource for ScriptedSource {
824        fn id(&self) -> &str {
825            &self.id
826        }
827
828        fn refresh(
829            &self,
830            _config: &SamplerConfig,
831            _cursor: Option<&SourceCursor>,
832            _limit: Option<usize>,
833        ) -> Result<SourceSnapshot, SamplerError> {
834            self.refreshes.fetch_add(1, Ordering::SeqCst);
835            let mut guard = self.script.lock().expect("script lock poisoned");
836            guard.pop_front().unwrap_or_else(|| {
837                Ok(SourceSnapshot {
838                    records: Vec::new(),
839                    cursor: SourceCursor {
840                        last_seen: Utc::now(),
841                        revision: 0,
842                    },
843                })
844            })
845        }
846
847        fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
848            Ok(0)
849        }
850
851        fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
852            Vec::new()
853        }
854    }
855
    /// Test source whose `refresh()` always panics, used to verify that the
    /// manager survives and records panicking sources.
    struct PanicSource {
        // Source id reported via `DataSource::id`.
        id: String,
    }
859
860    impl DataSource for PanicSource {
861        fn id(&self) -> &str {
862            &self.id
863        }
864
865        fn refresh(
866            &self,
867            _config: &SamplerConfig,
868            _cursor: Option<&SourceCursor>,
869            _limit: Option<usize>,
870        ) -> Result<SourceSnapshot, SamplerError> {
871            panic!("panic source refresh")
872        }
873
874        fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
875            Ok(0)
876        }
877    }
878
    #[test]
    fn record_cache_waits_len_and_clear_paths_are_covered() {
        // A fresh cache with capacity 2 starts empty with zero ingests.
        let cache = RecordCache::new(2);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.ingest_count(), 0);

        // An empty batch does not bump the ingest counter, and a timed wait
        // against counter value 0 returns immediately with 0.
        cache.ingest(Vec::<DataRecord>::new());
        assert_eq!(cache.wait_for_ingest(0, Duration::from_millis(1)), 0);

        // A blocked waiter wakes once a non-empty ingest lands; the sleep just
        // gives the waiter thread time to block before we ingest.
        let cache_for_waiter = cache.clone();
        let handle = std::thread::spawn(move || cache_for_waiter.wait_for_ingest_blocking(0));
        std::thread::sleep(Duration::from_millis(5));
        cache.ingest(vec![make_record("r1", "s")]);
        assert_eq!(handle.join().unwrap(), 1);
        assert_eq!(cache.ingest_count(), 1);

        // Capacity is 2, so ingesting two more records leaves only r2/r3 live
        // (r1, the oldest, is evicted).
        cache.ingest(vec![make_record("r2", "s"), make_record("r3", "s")]);
        assert_eq!(cache.len(), 2);
        let ids: Vec<String> = cache
            .snapshot()
            .into_iter()
            .map(|record| record.id)
            .collect();
        assert!(ids.contains(&"r2".to_string()));
        assert!(ids.contains(&"r3".to_string()));

        // clear() returns the cache to the empty state.
        cache.clear();
        assert!(cache.is_empty());
    }
909
910    #[test]
911    fn record_cache_zero_limit_discards_everything() {
912        let cache = RecordCache::new(0);
913        cache.ingest(vec![make_record("r1", "s")]);
914        assert!(cache.is_empty());
915        assert_eq!(cache.len(), 0);
916    }
917
    #[test]
    fn manager_loads_and_snapshots_cursors_and_reports_has_sources() {
        let mut manager = IngestionManager::new(4, SamplerConfig::default());
        assert!(!manager.has_sources());
        // Loading an empty cursor list on an empty manager is a no-op.
        manager.load_cursors(&[]);

        let refreshes = Arc::new(AtomicUsize::new(0));
        manager
            .register_source(Box::new(ScriptedSource::new(
                "cursor_source",
                refreshes,
                vec![Ok(SourceSnapshot {
                    // Deliberately different `source` field: the manager is
                    // expected to stamp records with the registered source id.
                    records: vec![make_record("id_1", "original_source")],
                    cursor: SourceCursor {
                        last_seen: Utc::now(),
                        revision: 33,
                    },
                })],
            )))
            .unwrap();
        assert!(manager.has_sources());

        // Externally loaded cursor revisions round-trip through snapshot_cursors.
        manager.load_cursors(&[("cursor_source".to_string(), 7)]);
        let cursors = manager.snapshot_cursors();
        assert_eq!(cursors.len(), 1);
        assert_eq!(cursors[0], ("cursor_source".to_string(), 7));

        // After a refresh, the snapshot reflects the revision the source reported.
        manager.refresh_all();
        let updated = manager.snapshot_cursors();
        assert_eq!(updated.len(), 1);
        assert_eq!(updated[0], ("cursor_source".to_string(), 33));
        let records = manager.all_records_snapshot();
        assert_eq!(records.len(), 1);
        // The record's source was rewritten from "original_source".
        assert_eq!(records[0].source, "cursor_source");
    }
953
954    #[test]
955    fn advance_uses_buffer_before_refreshing_again() {
956        let refreshes = Arc::new(AtomicUsize::new(0));
957        let mut manager = IngestionManager::new(5, SamplerConfig::default());
958        manager
959            .register_source(Box::new(ScriptedSource::new(
960                "buffered",
961                refreshes.clone(),
962                vec![Ok(SourceSnapshot {
963                    records: vec![
964                        make_record("a", "x"),
965                        make_record("b", "x"),
966                        make_record("c", "x"),
967                    ],
968                    cursor: SourceCursor {
969                        last_seen: Utc::now(),
970                        revision: 1,
971                    },
972                })],
973            )))
974            .unwrap();
975
976        manager.advance(1);
977        assert_eq!(refreshes.load(Ordering::SeqCst), 1);
978        assert_eq!(manager.all_records_len(), 1);
979
980        manager.advance(1);
981        assert_eq!(refreshes.load(Ordering::SeqCst), 1);
982        assert_eq!(manager.all_records_len(), 2);
983    }
984
    #[test]
    fn force_refresh_clears_buffer_and_fetches_again() {
        let refreshes = Arc::new(AtomicUsize::new(0));
        let mut manager = IngestionManager::new(4, SamplerConfig::default());
        manager
            .register_source(Box::new(ScriptedSource::new(
                "force",
                refreshes.clone(),
                vec![
                    // First snapshot over-fills the buffer relative to advance(1),
                    // leaving r2/r3 buffered.
                    Ok(SourceSnapshot {
                        records: vec![
                            make_record("r1", "x"),
                            make_record("r2", "x"),
                            make_record("r3", "x"),
                        ],
                        cursor: SourceCursor {
                            last_seen: Utc::now(),
                            revision: 10,
                        },
                    }),
                    // Second snapshot is what a forced refresh should surface.
                    Ok(SourceSnapshot {
                        records: vec![make_record("r4", "x")],
                        cursor: SourceCursor {
                            last_seen: Utc::now(),
                            revision: 11,
                        },
                    }),
                ],
            )))
            .unwrap();

        // advance(1) fetches once and drains a single record.
        manager.advance(1);
        assert_eq!(manager.all_records_len(), 1);
        assert_eq!(refreshes.load(Ordering::SeqCst), 1);

        // force_refresh_all discards the buffered r2/r3 and fetches again, so
        // only the second snapshot's record remains visible.
        manager.force_refresh_all();
        assert_eq!(refreshes.load(Ordering::SeqCst), 2);
        let records = manager.all_records_snapshot();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].id, "r4");
    }
1026
    #[test]
    fn weighted_drain_respects_zero_and_fallback_weights() {
        // Two sources, two records each; weights decide which buffers drain.
        let mut manager = IngestionManager::new(6, SamplerConfig::default());
        manager
            .register_source(Box::new(ScriptedSource::new(
                "a",
                Arc::new(AtomicUsize::new(0)),
                vec![Ok(SourceSnapshot {
                    records: vec![make_record("a1", "a"), make_record("a2", "a")],
                    cursor: SourceCursor {
                        last_seen: Utc::now(),
                        revision: 1,
                    },
                })],
            )))
            .unwrap();
        manager
            .register_source(Box::new(ScriptedSource::new(
                "b",
                Arc::new(AtomicUsize::new(0)),
                vec![Ok(SourceSnapshot {
                    records: vec![make_record("b1", "b"), make_record("b2", "b")],
                    cursor: SourceCursor {
                        last_seen: Utc::now(),
                        revision: 1,
                    },
                })],
            )))
            .unwrap();

        // Weight 0.0 for "a" must exclude it entirely: only "b" records land.
        let mut only_b = HashMap::new();
        only_b.insert("a".to_string(), 0.0);
        only_b.insert("b".to_string(), 1.0);
        manager.refresh_all_with_weights(&only_b).unwrap();
        let ids: Vec<String> = manager
            .all_records_snapshot()
            .into_iter()
            .map(|record| record.id)
            .collect();
        assert!(ids.iter().all(|id| id.starts_with('b')));

        // Second scenario: all-zero weights must NOT starve every source —
        // records from both "a" and "b" are expected to appear.
        let mut manager_fallback = IngestionManager::new(6, SamplerConfig::default());
        manager_fallback
            .register_source(Box::new(ScriptedSource::new(
                "a",
                Arc::new(AtomicUsize::new(0)),
                vec![Ok(SourceSnapshot {
                    records: vec![make_record("a1", "a")],
                    cursor: SourceCursor {
                        last_seen: Utc::now(),
                        revision: 2,
                    },
                })],
            )))
            .unwrap();
        manager_fallback
            .register_source(Box::new(ScriptedSource::new(
                "b",
                Arc::new(AtomicUsize::new(0)),
                vec![Ok(SourceSnapshot {
                    records: vec![make_record("b1", "b")],
                    cursor: SourceCursor {
                        last_seen: Utc::now(),
                        revision: 2,
                    },
                })],
            )))
            .unwrap();

        let mut all_zero = HashMap::new();
        all_zero.insert("a".to_string(), 0.0);
        all_zero.insert("b".to_string(), 0.0);
        manager_fallback
            .refresh_all_with_weights(&all_zero)
            .unwrap();
        let ids: Vec<String> = manager_fallback
            .all_records_snapshot()
            .into_iter()
            .map(|record| record.id)
            .collect();
        // Both sources contributed despite zero weights.
        assert!(ids.contains(&"a1".to_string()));
        assert!(ids.contains(&"b1".to_string()));
    }
1110
    #[test]
    fn refresh_errors_and_panics_update_source_stats() {
        let mut manager = IngestionManager::new(4, SamplerConfig::default());
        // A source whose refresh returns an Err result.
        manager
            .register_source(Box::new(ScriptedSource::new(
                "err_source",
                Arc::new(AtomicUsize::new(0)),
                vec![Err(SamplerError::SourceUnavailable {
                    source_id: "err_source".to_string(),
                    reason: "boom".to_string(),
                })],
            )))
            .unwrap();
        // A source whose refresh panics outright.
        manager
            .register_source(Box::new(PanicSource {
                id: "panic_source".to_string(),
            }))
            .unwrap();

        // refresh_all must survive both failure modes and record them in stats.
        manager.refresh_all();
        let stats = manager.source_refresh_stats();
        let err_stats = stats
            .iter()
            .find(|(source, _)| source == "err_source")
            .map(|(_, stats)| stats)
            .unwrap();
        // The Err variant's reason string is preserved in last_error.
        assert_eq!(err_stats.error_count, 1);
        assert!(
            err_stats
                .last_error
                .as_ref()
                .is_some_and(|msg| msg.contains("boom"))
        );

        // The panic is caught and surfaced as an error mentioning "panicked".
        let panic_stats = stats
            .iter()
            .find(|(source, _)| source == "panic_source")
            .map(|(_, stats)| stats)
            .unwrap();
        assert_eq!(panic_stats.error_count, 1);
        assert!(
            panic_stats
                .last_error
                .as_ref()
                .is_some_and(|msg| msg.contains("panicked"))
        );
    }
1158
1159    #[test]
1160    fn force_refresh_with_weights_path_is_exercised() {
1161        let mut manager = IngestionManager::new(3, SamplerConfig::default());
1162        manager
1163            .register_source(Box::new(ScriptedSource::new(
1164                "w",
1165                Arc::new(AtomicUsize::new(0)),
1166                vec![Ok(SourceSnapshot {
1167                    records: vec![make_record("w1", "w")],
1168                    cursor: SourceCursor {
1169                        last_seen: Utc::now(),
1170                        revision: 3,
1171                    },
1172                })],
1173            )))
1174            .unwrap();
1175
1176        let mut weights = HashMap::new();
1177        weights.insert("w".to_string(), 1.0);
1178        manager.force_refresh_all_with_weights(&weights).unwrap();
1179        assert_eq!(manager.all_records_len(), 1);
1180    }
1181
1182    #[test]
1183    fn advance_with_weights_rejects_unknown_source() {
1184        let mut manager = IngestionManager::new(4, SamplerConfig::default());
1185        manager
1186            .register_source(Box::new(ScriptedSource::new(
1187                "known",
1188                Arc::new(AtomicUsize::new(0)),
1189                vec![],
1190            )))
1191            .unwrap();
1192
1193        let mut weights = HashMap::new();
1194        weights.insert("known".to_string(), 1.0);
1195        weights.insert("unknown".to_string(), 0.5);
1196
1197        let err = manager.advance_with_weights(1, &weights).unwrap_err();
1198        assert!(
1199            matches!(err, SamplerError::InvalidWeight { ref source_id, .. } if source_id == "unknown"),
1200            "expected InvalidWeight for unknown source, got {err:?}"
1201        );
1202    }
1203
1204    #[test]
1205    fn refresh_all_with_weights_rejects_negative_weight() {
1206        let mut manager = IngestionManager::new(4, SamplerConfig::default());
1207        manager
1208            .register_source(Box::new(ScriptedSource::new(
1209                "src",
1210                Arc::new(AtomicUsize::new(0)),
1211                vec![],
1212            )))
1213            .unwrap();
1214
1215        let mut weights = HashMap::new();
1216        weights.insert("src".to_string(), -1.0);
1217
1218        let err = manager.refresh_all_with_weights(&weights).unwrap_err();
1219        assert!(
1220            matches!(err, SamplerError::InvalidWeight { ref source_id, .. } if source_id == "src"),
1221            "expected InvalidWeight for negative weight, got {err:?}"
1222        );
1223    }
1224
1225    #[test]
1226    fn force_refresh_all_with_weights_rejects_unknown_source() {
1227        let mut manager = IngestionManager::new(4, SamplerConfig::default());
1228        manager
1229            .register_source(Box::new(ScriptedSource::new(
1230                "real",
1231                Arc::new(AtomicUsize::new(0)),
1232                vec![],
1233            )))
1234            .unwrap();
1235
1236        let mut weights = HashMap::new();
1237        weights.insert("ghost".to_string(), 1.0);
1238
1239        let err = manager
1240            .force_refresh_all_with_weights(&weights)
1241            .unwrap_err();
1242        assert!(
1243            matches!(err, SamplerError::InvalidWeight { ref source_id, .. } if source_id == "ghost"),
1244            "expected InvalidWeight for unknown source, got {err:?}"
1245        );
1246    }
1247
    /// A source that records the `config.seed` value it receives on each `refresh()` call.
    struct SeedCapturingSource {
        // Source id reported via `DataSource::id`.
        id: String,
        // Shared sink that collects every seed observed by `refresh()`.
        received_seeds: Arc<Mutex<Vec<u64>>>,
    }
1253
1254    impl SeedCapturingSource {
1255        fn new(id: &str, received_seeds: Arc<Mutex<Vec<u64>>>) -> Self {
1256            Self {
1257                id: id.to_string(),
1258                received_seeds,
1259            }
1260        }
1261    }
1262
1263    impl DataSource for SeedCapturingSource {
1264        fn id(&self) -> &str {
1265            &self.id
1266        }
1267
1268        fn refresh(
1269            &self,
1270            config: &SamplerConfig,
1271            _cursor: Option<&SourceCursor>,
1272            _limit: Option<usize>,
1273        ) -> Result<SourceSnapshot, SamplerError> {
1274            self.received_seeds
1275                .lock()
1276                .expect("seed lock poisoned")
1277                .push(config.seed);
1278            Ok(SourceSnapshot {
1279                records: Vec::new(),
1280                cursor: SourceCursor {
1281                    last_seen: Utc::now(),
1282                    revision: 0,
1283                },
1284            })
1285        }
1286
1287        fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
1288            Ok(0)
1289        }
1290
1291        fn default_triplet_recipes(&self) -> Vec<crate::config::TripletRecipe> {
1292            Vec::new()
1293        }
1294    }
1295
    #[test]
    fn epoch_xor_changes_seed_received_by_source() {
        // Verify that derive_epoch_seed(base, epoch) is actually threaded through to
        // the source's refresh() call, and that different epochs produce different seeds.
        let base_seed = 0xDEAD_BEEF_u64;
        let config = SamplerConfig {
            seed: base_seed,
            ..SamplerConfig::default()
        };

        // Separate sinks so each manager's observed seeds stay distinguishable.
        let seeds_epoch0 = Arc::new(Mutex::new(Vec::<u64>::new()));
        let seeds_epoch1 = Arc::new(Mutex::new(Vec::<u64>::new()));

        // --- epoch 0 ---
        let mut manager = IngestionManager::new(4, config.clone());
        manager
            .register_source(Box::new(SeedCapturingSource::new(
                "src",
                Arc::clone(&seeds_epoch0),
            )))
            .unwrap();
        // epoch defaults to 0; refresh_all passes derive_epoch_seed(base, 0)
        manager.refresh_all();

        // --- epoch 1 ---
        let mut manager2 = IngestionManager::new(4, config.clone());
        manager2
            .register_source(Box::new(SeedCapturingSource::new(
                "src",
                Arc::clone(&seeds_epoch1),
            )))
            .unwrap();
        manager2.set_epoch(1);
        manager2.refresh_all();

        let received0 = seeds_epoch0.lock().unwrap();
        let received1 = seeds_epoch1.lock().unwrap();

        assert!(!received0.is_empty(), "epoch-0 source was never refreshed");
        assert!(!received1.is_empty(), "epoch-1 source was never refreshed");

        // Only the first observed seed per epoch matters here.
        let seed_at_epoch0 = received0[0];
        let seed_at_epoch1 = received1[0];

        // The seeds must differ — epoch XOR has a real effect.
        assert_ne!(
            seed_at_epoch0, seed_at_epoch1,
            "epoch 0 and epoch 1 both produced seed {seed_at_epoch0:#x}; \
             derive_epoch_seed is not reaching the source"
        );

        // They must match the expected derive_epoch_seed values.
        // epoch_step stays at 0 for direct refresh_all calls
        // (it's only bumped by next_*_batch calls).
        assert_eq!(
            seed_at_epoch0,
            derive_epoch_seed(base_seed, 0),
            "epoch-0 seed mismatch (epoch_step=0)"
        );
        assert_eq!(
            seed_at_epoch1,
            derive_epoch_seed(base_seed, 1),
            "epoch-1 seed mismatch (epoch_step=0)"
        );
    }
1361
    #[test]
    fn epoch_step_resets_on_epoch_change() {
        // Proves that set_epoch + reset_epoch_step gives epoch_step=0 so each
        // epoch produces the same step sequence starting from step 1.
        let config = SamplerConfig::default();
        let seeds = Arc::new(Mutex::new(Vec::new()));

        let mut manager = IngestionManager::new(4, config.clone());
        manager
            .register_source(Box::new(SeedCapturingSource::new(
                "src",
                Arc::clone(&seeds),
            )))
            .unwrap();

        // Epoch 0, first refresh: epoch_step stays at 0 (no longer
        // incremented by refresh_all — it's bumped per batch call instead).
        manager.refresh_all();
        let step1_seed = seeds.lock().unwrap()[0];
        assert_eq!(
            step1_seed,
            derive_epoch_seed(config.seed, 0),
            "epoch 0 step 0 seed"
        );

        // Advance epoch — epoch_step must stay 0 (no batch calls yet).
        manager.set_epoch(1);
        assert_eq!(manager.epoch_step(), 0);
        // Forget the epoch-0 seed so index [0] below is the epoch-1 seed.
        seeds.lock().unwrap().clear();

        // Epoch 1, first refresh: step stays at 0 (no batch calls in this test).
        manager.refresh_all();
        let step1_epoch1 = seeds.lock().unwrap()[0];
        assert_eq!(
            step1_epoch1,
            derive_epoch_seed(config.seed, 1),
            "epoch 1 step 0 seed (must be ^0 since refresh_all no longer bumps step)"
        );
    }
1401
    #[test]
    fn epoch_step_survives_sampler_save_and_load_state() {
        // Proves the epoch step survives through the REAL API path:
        //   TripletSampler::save_sampler_state → persist_source_state
        //   → self.ingestion.snapshot_cursors() internally
        //   → DeterministicSplitStore.save_sampler_state
        //   → load_sampler_state → load_cursors
        let store = Arc::new(DeterministicSplitStore::new(SplitRatios::default(), 1).unwrap());

        // Sampler with ScriptedSource (provides empty recipes but that's fine)
        let sampler = TripletSampler::new(
            SamplerConfig {
                seed: 42,
                ..SamplerConfig::default()
            },
            Arc::clone(&store),
        );
        let refreshes = Arc::new(AtomicUsize::new(0));
        sampler
            .register_source(Box::new(ScriptedSource::new(
                "src",
                refreshes,
                vec![Ok(SourceSnapshot {
                    records: vec![make_record("r1", "src")],
                    cursor: SourceCursor {
                        last_seen: Utc::now(),
                        revision: 1,
                    },
                })],
            )))
            .unwrap();

        // Call a batch method to trigger ensure_source_state() so
        // save_sampler_state actually writes.  The batch itself may
        // exhaust (no recipes) but the state is loaded regardless.
        let _ = sampler.next_text_batch(SplitLabel::Train);

        // Save through the REAL sampler API
        sampler.save_sampler_state(None).unwrap();

        // Load back what the sampler saved
        let loaded = store.load_sampler_state().unwrap().unwrap();
        let step_saved = loaded.epoch_step;
        assert!(
            step_saved > 0,
            "epoch_step must be >0 after batch call through TripletSampler"
        );

        // Feed the REAL loaded cursors to a new manager
        let mut manager2 = IngestionManager::new(4, SamplerConfig::default());
        manager2
            .register_source(Box::new(ScriptedSource::new(
                "src",
                Arc::new(AtomicUsize::new(0)),
                vec![Ok(SourceSnapshot {
                    records: vec![make_record("r2", "src")],
                    cursor: SourceCursor {
                        last_seen: Utc::now(),
                        revision: 2,
                    },
                })],
            )))
            .unwrap();
        // Restore persisted cursors first, then the epoch step.
        manager2.load_cursors(&loaded.source_stream_cursors);
        manager2.set_epoch_step(loaded.epoch_step);

        // epoch_step restored to saved value; verify it's stable (no batch
        // call was made, so step does not change).
        let step_before = manager2.epoch_step();
        assert_eq!(
            step_before, step_saved,
            "load_cursors must restore epoch_step to saved value"
        );

        // Verify the step is still present after a refresh (no increment).
        manager2.refresh_all();
        let step_after = manager2.epoch_step();
        assert_eq!(
            step_after, step_saved,
            "epoch_step must survive refresh_all without increment (step is per-batch,
             not per-refresh): loaded {step_saved}, got {step_after}"
        );
    }
1485
1486    #[test]
1487    fn scripted_and_panic_sources_cover_default_trait_paths() {
1488        let refreshes = Arc::new(AtomicUsize::new(0));
1489        let scripted = ScriptedSource::new("scripted", refreshes, vec![]);
1490
1491        // Empty script falls back to an empty snapshot.
1492        let snapshot = scripted
1493            .refresh(&SamplerConfig::default(), None, None)
1494            .expect("fallback snapshot");
1495        assert!(snapshot.records.is_empty());
1496        assert_eq!(snapshot.cursor.revision, 0);
1497
1498        assert_eq!(
1499            scripted
1500                .reported_record_count(&SamplerConfig::default())
1501                .expect("record count"),
1502            0
1503        );
1504        assert!(scripted.default_triplet_recipes().is_empty());
1505
1506        let panic_source = PanicSource {
1507            id: "panic_count".to_string(),
1508        };
1509        assert_eq!(
1510            panic_source
1511                .reported_record_count(&SamplerConfig::default())
1512                .expect("record count"),
1513            0
1514        );
1515    }
1516
1517    #[test]
1518    fn seed_capturing_source_trait_defaults_are_exercised() {
1519        let source = SeedCapturingSource::new("seed_defaults", Arc::new(Mutex::new(Vec::new())));
1520        assert_eq!(
1521            source
1522                .reported_record_count(&SamplerConfig::default())
1523                .expect("record count"),
1524            0
1525        );
1526        assert!(source.default_triplet_recipes().is_empty());
1527    }
1528
1529    #[test]
1530    fn refresh_paths_handle_zero_capacity_and_no_sources() {
1531        let mut manager = IngestionManager::new(0, SamplerConfig::default());
1532        manager
1533            .register_source(Box::new(ScriptedSource::new(
1534                "zero_capacity",
1535                Arc::new(AtomicUsize::new(0)),
1536                vec![Ok(SourceSnapshot {
1537                    records: vec![make_record("r1", "zero_capacity")],
1538                    cursor: SourceCursor {
1539                        last_seen: Utc::now(),
1540                        revision: 1,
1541                    },
1542                })],
1543            )))
1544            .unwrap();
1545        manager.refresh_all();
1546        assert!(manager.all_caches_empty());
1547
1548        // Weighted refresh with no sources should be a no-op.
1549        let mut empty_manager = IngestionManager::new(4, SamplerConfig::default());
1550        let empty_weights = HashMap::new();
1551        empty_manager
1552            .refresh_all_with_weights(&empty_weights)
1553            .expect("no sources should not error");
1554        assert!(empty_manager.all_caches_empty());
1555    }
1556
    #[test]
    fn drain_start_rotates_fairly_across_sources() {
        // Create 3 sources, each with 10 records in their buffer after refresh.
        // The fair round-robin should ensure all 3 drain at the same rate
        // over multiple advance cycles.
        struct FairSource {
            // Source id reported to the manager.
            id: String,
            // Shared counter bumped on every refresh() call.
            refresh_count: Arc<AtomicUsize>,
        }

        impl DataSource for FairSource {
            fn id(&self) -> &str {
                &self.id
            }
            fn refresh(
                &self,
                _config: &SamplerConfig,
                _cursor: Option<&SourceCursor>,
                _limit: Option<usize>,
            ) -> Result<SourceSnapshot, SamplerError> {
                // Every refresh yields a fresh batch of 10 records r0..r9.
                self.refresh_count.fetch_add(1, Ordering::SeqCst);
                Ok(SourceSnapshot {
                    records: (0..10)
                        .map(|i| make_record(&format!("r{i}"), &self.id))
                        .collect(),
                    cursor: SourceCursor {
                        last_seen: Utc::now(),
                        revision: 1,
                    },
                })
            }
            fn reported_record_count(&self, _: &SamplerConfig) -> Result<u128, SamplerError> {
                Ok(10)
            }
        }

        // One refresh counter per source, shared with the assertions below.
        let counts = (
            Arc::new(AtomicUsize::new(0)),
            Arc::new(AtomicUsize::new(0)),
            Arc::new(AtomicUsize::new(0)),
        );

        let mut manager = IngestionManager::new(30, SamplerConfig::default());
        manager
            .register_source(Box::new(FairSource {
                id: "src_0".to_string(),
                refresh_count: Arc::clone(&counts.0),
            }))
            .unwrap();
        manager
            .register_source(Box::new(FairSource {
                id: "src_1".to_string(),
                refresh_count: Arc::clone(&counts.1),
            }))
            .unwrap();
        manager
            .register_source(Box::new(FairSource {
                id: "src_2".to_string(),
                refresh_count: Arc::clone(&counts.2),
            }))
            .unwrap();

        // First refresh_all fills all buffers.
        manager.refresh_all();
        // All 3 refreshed once.
        assert_eq!(counts.0.load(Ordering::SeqCst), 1);
        assert_eq!(counts.1.load(Ordering::SeqCst), 1);
        assert_eq!(counts.2.load(Ordering::SeqCst), 1);

        // Each advance(1) drains 1 record from 1 source, rotating via
        // drain_start.  With 3 sources, after 3 advances each source
        // loses 1 record.  After 30 advances each source loses 10 records
        // and triggers a refresh.  Run 33 advances and check all 3 refreshed
        // roughly the same number of times.
        for _ in 0..33 {
            manager.advance(1);
        }

        let r0 = counts.0.load(Ordering::SeqCst);
        let r1 = counts.1.load(Ordering::SeqCst);
        let r2 = counts.2.load(Ordering::SeqCst);

        // Each had 10 records after initial refresh_all.
        // 10 records / (1 drained per 3 cycles) = 30 cycles to drain each buffer.
        // After 33 more advances each buffer emptied ~1 time and re-filled, so
        // each source should have refreshed ~2 times total (initial + 1 drain).
        // The exact count can vary by 1 due to timing, but all 3 must be within
        // 1 of each other — no source can be starved.
        let min = r0.min(r1).min(r2);
        let max = r0.max(r1).max(r2);
        assert!(
            max <= min + 1,
            "sources should refresh at roughly the same rate: got {r0}/{r1}/{r2}"
        );
    }
1652
1653    #[test]
1654    fn direct_drain_start_rotates_fairly_with_batch_2_of_5() {
1655        // Direct IngestionManager test with batch_size=2, 5 sources.
1656        // Isolates the drain_start rotation from the sampler pipeline.
1657        struct SimpleSource {
1658            id: String,
1659            refresh_count: Arc<AtomicUsize>,
1660        }
1661
1662        impl DataSource for SimpleSource {
1663            fn id(&self) -> &str {
1664                &self.id
1665            }
1666            fn refresh(
1667                &self,
1668                _: &SamplerConfig,
1669                _: Option<&SourceCursor>,
1670                _: Option<usize>,
1671            ) -> Result<SourceSnapshot, SamplerError> {
1672                self.refresh_count.fetch_add(1, Ordering::SeqCst);
1673                let now = Utc::now();
1674                let records: Vec<DataRecord> = (0..8)
1675                    .map(|i| DataRecord {
1676                        id: format!("{}_r{i}", self.id),
1677                        source: self.id.clone(),
1678                        created_at: now,
1679                        updated_at: now,
1680                        quality: QualityScore { trust: 1.0 },
1681                        taxonomy: Vec::new(),
1682                        sections: Vec::new(),
1683                        meta_prefix: None,
1684                    })
1685                    .collect();
1686                Ok(SourceSnapshot {
1687                    records,
1688                    cursor: SourceCursor {
1689                        last_seen: now,
1690                        revision: 1,
1691                    },
1692                })
1693            }
1694            fn reported_record_count(&self, _: &SamplerConfig) -> Result<u128, SamplerError> {
1695                Ok(8)
1696            }
1697        }
1698
1699        let counts: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
1700        let mut manager = IngestionManager::new(40, SamplerConfig::default());
1701        for (i, count) in counts.iter().enumerate() {
1702            manager
1703                .register_source(Box::new(SimpleSource {
1704                    id: format!("src_{i}"),
1705                    refresh_count: Arc::clone(count),
1706                }))
1707                .unwrap();
1708        }
1709
1710        manager.refresh_all();
1711        for c in &counts {
1712            assert_eq!(c.load(Ordering::SeqCst), 1);
1713        }
1714
1715        for _ in 0..80 {
1716            manager.advance(2);
1717        }
1718
1719        let totals: Vec<usize> = counts.iter().map(|c| c.load(Ordering::SeqCst)).collect();
1720        // Round-robin drain with 5 sources, 8 records each, batch_size=2, 80
1721        // advance calls.  Most sources refresh 5×; sources 2-4 get one extra
1722        // cycle (6) due to the rotating drain_start tie-breaker positioning.
1723        assert_eq!(
1724            totals,
1725            vec![5, 5, 6, 6, 6],
1726            "direct manager: unexpected refresh distribution"
1727        );
1728    }
1729
1730    /// Helper: create 5 sources with 8 records each and a sampler configured for
1731    /// text batches with batch_size=2.  Returns the sampler and refresh counters.
1732    fn make_five_source_sampler(
1733        counts: &[Arc<AtomicUsize>],
1734    ) -> TripletSampler<DeterministicSplitStore> {
1735        struct Tracked {
1736            id: String,
1737            refresh_count: Arc<AtomicUsize>,
1738        }
1739        impl DataSource for Tracked {
1740            fn id(&self) -> &str {
1741                &self.id
1742            }
1743            fn refresh(
1744                &self,
1745                _: &SamplerConfig,
1746                _: Option<&SourceCursor>,
1747                _: Option<usize>,
1748            ) -> Result<SourceSnapshot, SamplerError> {
1749                self.refresh_count.fetch_add(1, Ordering::SeqCst);
1750                let now = Utc::now();
1751                let records: Vec<DataRecord> = (0..8)
1752                    .map(|i| DataRecord {
1753                        id: format!("{}_r{i}", self.id),
1754                        source: self.id.clone(),
1755                        created_at: now,
1756                        updated_at: now,
1757                        quality: QualityScore { trust: 1.0 },
1758                        taxonomy: Vec::new(),
1759                        sections: vec![RecordSection {
1760                            role: SectionRole::Anchor,
1761                            heading: None,
1762                            text: format!("x{i}"),
1763                            sentences: vec![format!("x{i}")],
1764                        }],
1765                        meta_prefix: None,
1766                    })
1767                    .collect();
1768                Ok(SourceSnapshot {
1769                    records,
1770                    cursor: SourceCursor {
1771                        last_seen: now,
1772                        revision: 1,
1773                    },
1774                })
1775            }
1776            fn reported_record_count(&self, _: &SamplerConfig) -> Result<u128, SamplerError> {
1777                Ok(8)
1778            }
1779        }
1780
1781        let config = SamplerConfig {
1782            batch_size: 2,
1783            text_recipes: vec![TextRecipe {
1784                name: "anchor".into(),
1785                selector: Selector::Role(SectionRole::Anchor),
1786                weight: 1.0,
1787                instruction: None,
1788            }],
1789            split: SplitRatios {
1790                train: 1.0,
1791                validation: 0.0,
1792                test: 0.0,
1793            },
1794            allowed_splits: vec![SplitLabel::Train],
1795            // Small enough that the cache slides on every advance, changing
1796            // the record pool each batch.  Without this, all 40 records fit in
1797            // the cache permanently and the cross-batch text dedup would never
1798            // clear `emitted_texts`, causing early Exhausted after only a few
1799            // batches (only 8 unique texts across 40 records).
1800            ingestion_max_records: 4,
1801            ..SamplerConfig::default()
1802        };
1803        let store = Arc::new(DeterministicSplitStore::new(config.split, 99).unwrap());
1804        let sampler = TripletSampler::new(config, store);
1805
1806        for (i, count) in counts.iter().enumerate() {
1807            sampler
1808                .register_source(Box::new(Tracked {
1809                    id: format!("src_{i}"),
1810                    refresh_count: Arc::clone(count),
1811                }))
1812                .unwrap();
1813        }
1814        sampler
1815    }
1816
1817    #[test]
1818    fn sampler_unweighted_drain_distributes_evenly() {
1819        let counts: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
1820        let sampler = make_five_source_sampler(&counts);
1821
1822        sampler.next_text_batch(SplitLabel::Train).unwrap();
1823        for _ in 0..80 {
1824            sampler.next_text_batch(SplitLabel::Train).unwrap();
1825        }
1826
1827        let totals: Vec<usize> = counts.iter().map(|c| c.load(Ordering::SeqCst)).collect();
1828        let min = *totals.iter().min().unwrap();
1829        let max = *totals.iter().max().unwrap();
1830        assert!(
1831            max <= min + 1,
1832            "unweighted: all sources must refresh at roughly the same rate: {totals:?}"
1833        );
1834        assert!(
1835            min >= 4,
1836            "unweighted: each source should have refreshed at least 4 times: {totals:?}"
1837        );
1838    }
1839
1840    #[test]
1841    fn sampler_weighted_drain_with_equal_weights_distributes_evenly() {
1842        // Same as unweighted but goes through the public weighted API with equal
1843        // weights for all 5 sources — verifies the weighted drain also rotates
1844        // fairly via drain_start bias.
1845        let counts: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
1846        let sampler = make_five_source_sampler(&counts);
1847
1848        let weights: HashMap<String, f32> = (0..5).map(|i| (format!("src_{i}"), 1.0)).collect();
1849
1850        sampler
1851            .next_text_batch_with_weights(SplitLabel::Train, &weights)
1852            .unwrap();
1853        for _ in 0..80 {
1854            sampler
1855                .next_text_batch_with_weights(SplitLabel::Train, &weights)
1856                .unwrap();
1857        }
1858
1859        let totals: Vec<usize> = counts.iter().map(|c| c.load(Ordering::SeqCst)).collect();
1860        let min = *totals.iter().min().unwrap();
1861        let max = *totals.iter().max().unwrap();
1862        assert!(
1863            max <= min + 1,
1864            "weighted (equal): all sources must refresh at roughly the same rate: {totals:?}"
1865        );
1866        assert!(
1867            min >= 4,
1868            "weighted (equal): each source should have refreshed at least 4 times: {totals:?}"
1869        );
1870    }
1871
1872    #[test]
1873    fn sampler_unweighted_and_weighted_match_distribution() {
1874        // Verify both paths produce similar refresh distributions.
1875        let uc: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
1876        let wc: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
1877        let usampler = make_five_source_sampler(&uc);
1878        let wsampler = make_five_source_sampler(&wc);
1879
1880        let weights: HashMap<String, f32> = (0..5).map(|i| (format!("src_{i}"), 1.0)).collect();
1881
1882        usampler.next_text_batch(SplitLabel::Train).unwrap();
1883        wsampler
1884            .next_text_batch_with_weights(SplitLabel::Train, &weights)
1885            .unwrap();
1886        for _ in 0..80 {
1887            usampler.next_text_batch(SplitLabel::Train).unwrap();
1888            wsampler
1889                .next_text_batch_with_weights(SplitLabel::Train, &weights)
1890                .unwrap();
1891        }
1892
1893        let ut: Vec<usize> = uc.iter().map(|c| c.load(Ordering::SeqCst)).collect();
1894        let wt: Vec<usize> = wc.iter().map(|c| c.load(Ordering::SeqCst)).collect();
1895        let umax = *ut.iter().max().unwrap();
1896        let umin = *ut.iter().min().unwrap();
1897        let wmax = *wt.iter().max().unwrap();
1898        let wmin = *wt.iter().min().unwrap();
1899        assert!(umax <= umin + 1, "unweighted: {ut:?}");
1900        assert!(wmax <= wmin + 1, "weighted equal: {wt:?}");
1901    }
1902
1903    #[test]
1904    fn sampler_weighted_drain_with_unequal_weights_respects_ratios() {
1905        // Source 3 gets weight 2.0, others get 1.0.  The weighted
1906        // proportional-fair drain must give source 3 proportionally more
1907        // refresh cycles than any weight-1.0 source.
1908        let counts: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
1909        let sampler = make_five_source_sampler(&counts);
1910
1911        let mut weights = HashMap::new();
1912        for i in 0..5 {
1913            weights.insert(format!("src_{i}"), if i == 3 { 2.0f32 } else { 1.0 });
1914        }
1915
1916        sampler
1917            .next_text_batch_with_weights(SplitLabel::Train, &weights)
1918            .unwrap();
1919        for _ in 0..200 {
1920            sampler
1921                .next_text_batch_with_weights(SplitLabel::Train, &weights)
1922                .unwrap();
1923        }
1924
1925        let totals: Vec<usize> = counts.iter().map(|c| c.load(Ordering::SeqCst)).collect();
1926        // The highest-weight source (src_3, w=2.0) must have strictly more
1927        // refreshes than EVERY weight-1.0 source.
1928        // Note: src_4 (index 4, w=1.0) gets 12 refreshes while the other
1929        // weight-1.0 sources get 6-7.  This happens because the proportional-
1930        // fair scheduler uses drain_start as a rotating tie-breaker — being
1931        // positioned right after the high-weight source (src_3) gives src_4
1932        // extra wins when src_3's buffer empties and refills more frequently.
1933        assert!(
1934            totals
1935                .iter()
1936                .enumerate()
1937                .all(|(i, &t)| i == 3 || t < totals[3]),
1938            "src_3 (w=2.0) must outpace all w=1.0 sources (totals: {totals:?})"
1939        );
1940        // Lock in the exact deterministic distribution.
1941        // (Distribution changed after the cross-batch dedup fix: with proper
1942        // per-record hash tracking, the sampler exhausts and force-refreshes
1943        // more often, which increases the total number of refresh cycles but
1944        // preserves the weighted distribution pattern. The key invariant
1945        // (src_3 with w=2.0 outpaces all w=1.0 sources) is unchanged.)
1946        assert_eq!(
1947            totals,
1948            vec![11, 11, 11, 31, 17],
1949            "unequal-weights: unexpected refresh distribution"
1950        );
1951    }
1952
1953    #[test]
1954    fn register_source_rejects_reserved_id_pattern() {
1955        use crate::constants::splits::is_reserved_source_id;
1956
1957        let mut manager = IngestionManager::new(4, SamplerConfig::default());
1958
1959        // Verify the utility function catches common patterns
1960        assert!(is_reserved_source_id("__meta__"));
1961        assert!(is_reserved_source_id("__anything__"));
1962        assert!(is_reserved_source_id("__x__"));
1963        assert!(!is_reserved_source_id(""));
1964        assert!(!is_reserved_source_id("__"));
1965        assert!(!is_reserved_source_id("___"));
1966        assert!(!is_reserved_source_id("normal_source"));
1967        assert!(!is_reserved_source_id("_prefix_suffix_"));
1968        assert!(!is_reserved_source_id("__unclosed"));
1969        assert!(!is_reserved_source_id("unopened__"));
1970
1971        // Registering with a `__*__` id should fail
1972        let result = manager.register_source(Box::new(ScriptedSource::new(
1973            "__reserved__",
1974            Arc::new(AtomicUsize::new(0)),
1975            vec![],
1976        )));
1977        assert!(
1978            result.is_err(),
1979            "register_source should reject reserved source id"
1980        );
1981        let err = result.unwrap_err();
1982        assert!(
1983            matches!(&err, SamplerError::ReservedSourceId(id) if id == "__reserved__"),
1984            "expected ReservedSourceId error, got: {err}"
1985        );
1986
1987        // Verify source was NOT registered (still zero sources)
1988        assert!(!manager.has_sources());
1989
1990        // Normal source IDs still work
1991        manager
1992            .register_source(Box::new(ScriptedSource::new(
1993                "valid_source",
1994                Arc::new(AtomicUsize::new(0)),
1995                vec![],
1996            )))
1997            .unwrap();
1998        assert!(manager.has_sources());
1999    }
2000}