Skip to main content

triplets_core/
ingestion.rs

1use crate::config::SamplerConfig;
2use crate::constants::splits::{STEP_CURSOR_KEY, is_reserved_source_id};
3use crate::data::DataRecord;
4use crate::errors::SamplerError;
5use crate::hash::derive_epoch_seed;
6use crate::source::{DataSource, SourceCursor, SourceSnapshot};
7use crate::types::{RecordId, SourceId};
8use chrono::Utc;
9use indexmap::IndexMap;
10use std::collections::HashMap;
11use std::collections::VecDeque;
12use std::sync::{Arc, Condvar, Mutex, RwLock};
13use std::thread;
14use std::time::Duration;
15use tracing::debug;
16
/// Thread-safe in-memory cache of ingested records keyed by record id.
///
/// Cloning is cheap: both fields are `Arc`s, so every clone shares the same
/// record storage and the same ingest notifier.
#[derive(Clone)]
pub struct RecordCache {
    /// Record storage guarded by a reader/writer lock.
    inner: Arc<RwLock<RecordCacheInner>>,
    /// Ingest counter paired with the condition variable that `ingest`
    /// signals and the `wait_for_ingest*` methods block on.
    notifier: Arc<(Mutex<CacheStats>, Condvar)>,
}
23
/// Internal mutable cache storage behind `RecordCache` locks.
struct RecordCacheInner {
    /// Cached entries keyed by record id.
    records: IndexMap<RecordId, CachedRecord>,
    /// Record ids in eviction order: the front is evicted first, and
    /// ingesting a record (new or replacement) puts its id at the back.
    order: VecDeque<RecordId>,
    /// Capacity limit; with `0` nothing is retained.
    max_records: usize,
    /// Most recently assigned version; bumped (saturating) once per
    /// ingested record.
    next_version: u64,
}
31
/// Internal cache entry plus monotonic version marker.
struct CachedRecord {
    /// The cached record payload.
    record: DataRecord,
    /// Value of the cache's version counter when this entry was last written.
    version: u64,
}
37
/// Internal ingest notification counters.
#[derive(Default)]
struct CacheStats {
    /// Number of non-empty `ingest` calls completed so far (saturating).
    ingests: u64,
}
43
44impl RecordCache {
45    /// Create a cache capped to at most `max_records` live records.
46    pub fn new(max_records: usize) -> Self {
47        Self {
48            inner: Arc::new(RwLock::new(RecordCacheInner {
49                records: IndexMap::new(),
50                order: VecDeque::new(),
51                max_records,
52                next_version: 0,
53            })),
54            notifier: Arc::new((Mutex::new(CacheStats::default()), Condvar::new())),
55        }
56    }
57
58    /// Ingest a batch of records, replacing existing entries by id.
59    pub fn ingest<I>(&self, records: I)
60    where
61        I: IntoIterator<Item = DataRecord>,
62    {
63        let mut batch: Vec<DataRecord> = records.into_iter().collect();
64        if batch.is_empty() {
65            return;
66        }
67        let mut inner = self.inner.write().expect("record cache poisoned");
68        inner.ingest_batch(&mut batch);
69        drop(inner);
70        let (lock, cvar) = &*self.notifier;
71        let mut stats = lock.lock().expect("record cache stats poisoned");
72        stats.ingests = stats.ingests.saturating_add(1);
73        cvar.notify_all();
74    }
75
76    /// Remove all cached records.
77    pub fn clear(&self) {
78        let mut inner = self.inner.write().expect("record cache poisoned");
79        inner.records.clear();
80        inner.order.clear();
81    }
82
83    /// Return a cloned snapshot of current cached records.
84    pub fn snapshot(&self) -> Vec<DataRecord> {
85        let inner = self.inner.read().expect("record cache poisoned");
86        inner
87            .records
88            .values()
89            .map(|entry| entry.record.clone())
90            .collect()
91    }
92
93    /// Return the number of completed ingest operations.
94    pub fn ingest_count(&self) -> u64 {
95        let (lock, _) = &*self.notifier;
96        lock.lock().expect("record cache stats poisoned").ingests
97    }
98
99    /// Wait until ingest count exceeds `last_seen`, or until timeout elapses.
100    pub fn wait_for_ingest(&self, last_seen: u64, timeout: Duration) -> u64 {
101        let (lock, cvar) = &*self.notifier;
102        let mut stats = lock.lock().expect("record cache stats poisoned");
103        while stats.ingests <= last_seen {
104            let result = cvar
105                .wait_timeout(stats, timeout)
106                .expect("record cache stats poisoned");
107            stats = result.0;
108            if result.1.timed_out() {
109                break;
110            }
111        }
112        stats.ingests
113    }
114
115    /// Wait indefinitely until ingest count exceeds `last_seen`.
116    pub fn wait_for_ingest_blocking(&self, last_seen: u64) -> u64 {
117        let (lock, cvar) = &*self.notifier;
118        let mut stats = lock.lock().expect("record cache stats poisoned");
119        while stats.ingests <= last_seen {
120            stats = cvar.wait(stats).expect("record cache stats poisoned");
121        }
122        stats.ingests
123    }
124
125    /// Returns `true` when the cache has no records.
126    pub fn is_empty(&self) -> bool {
127        let inner = self.inner.read().expect("record cache poisoned");
128        inner.records.is_empty()
129    }
130
131    /// Return the number of records currently cached.
132    pub fn len(&self) -> usize {
133        let inner = self.inner.read().expect("record cache poisoned");
134        inner.records.len()
135    }
136}
137
138impl RecordCacheInner {
139    fn ingest_batch(&mut self, records: &mut Vec<DataRecord>) {
140        for record in records.drain(..) {
141            self.next_version = self.next_version.saturating_add(1);
142            let record_id = record.id.clone();
143            if self.records.contains_key(&record_id) {
144                if let Some(entry) = self.records.get_mut(&record_id) {
145                    entry.record = record;
146                    entry.version = self.next_version;
147                }
148                Self::refresh_order(&mut self.order, &record_id);
149                self.order.push_back(record_id);
150            } else {
151                self.order.push_back(record_id.clone());
152                self.records.insert(
153                    record_id,
154                    CachedRecord {
155                        record,
156                        version: self.next_version,
157                    },
158                );
159            }
160            self.enforce_limit();
161        }
162    }
163
164    fn enforce_limit(&mut self) {
165        if self.max_records == 0 {
166            self.records.clear();
167            self.order.clear();
168            return;
169        }
170        while self.records.len() > self.max_records {
171            if let Some(oldest) = self.order.pop_front() {
172                self.records.swap_remove(&oldest);
173            } else {
174                break;
175            }
176        }
177    }
178
179    fn refresh_order(order: &mut VecDeque<RecordId>, id: &RecordId) {
180        if order.is_empty() {
181            return;
182        }
183        if let Some(pos) = order.iter().position(|existing| existing == id) {
184            order.remove(pos);
185        }
186    }
187}
188
/// Coordinates on-demand source refresh and per-source-cache population.
pub struct IngestionManager {
    /// Registered sources with their cursor, buffer, cache and stats, in
    /// registration order.
    sources: Vec<SourceState>,
    /// Capacity of each per-source cache; also used as the fetch limit for
    /// source refreshes and as the drain target of a full refresh.
    max_records: usize,
    /// Base sampler configuration; its seed is re-derived per refresh.
    sampler_config: SamplerConfig,
    /// Current source epoch used to vary per-source permutation seeds across epochs.
    source_epoch: u64,
    /// Monotonic generation incremented whenever at least one source is refreshed.
    source_refresh_generation: u64,
    /// Source ids refreshed during the most recent `refresh_all_internal` call.
    ///
    /// This is updated even when cache ingest does not change, and is cleared when
    /// no source refresh occurs in that cycle.
    last_refreshed_sources: Vec<SourceId>,
    /// Rotating start index for the round-robin buffer drain.  Instead of always
    /// draining from source 0 (which starves high-index sources of refresh
    /// opportunities), each cycle begins at this position and advances by one.
    /// It only advances on cycles that drained at least one record.
    drain_start: usize,
    /// Step counter incremented (saturating) on every refresh cycle in which
    /// at least one source is refreshed, and XORed into the seed handed to
    /// sources so each such cycle sees a distinct seed. It is persisted
    /// alongside the source cursors and can be reset to zero.
    step_counter: u64,
}
212
/// Last-refresh telemetry captured per source.
#[derive(Clone, Debug, Default)]
pub struct SourceRefreshStats {
    /// Duration of the most recent refresh in milliseconds.
    pub last_refresh_ms: u128,
    /// Number of records returned by the most recent refresh (0 on failure).
    pub last_record_count: usize,
    /// Throughput estimate from the most recent refresh; 0.0 when the
    /// refresh failed or its measured duration was zero.
    pub last_records_per_sec: f64,
    /// Error message of the most recent refresh if it failed; reset to
    /// `None` by a successful refresh.
    pub last_error: Option<String>,
    /// Total refresh failures seen for this source.
    pub error_count: u64,
}
227
228impl IngestionManager {
    /// Create a new ingestion manager that ingests on demand.
    ///
    /// `max_records` becomes the capacity of every per-source cache, the
    /// fetch limit passed to each source refresh, and the drain target of a
    /// full refresh. `sampler_config` supplies the base seed from which
    /// per-refresh seeds are derived. The manager starts with no sources and
    /// all counters at zero.
    pub fn new(max_records: usize, sampler_config: SamplerConfig) -> Self {
        Self {
            sources: Vec::new(),
            max_records,
            sampler_config,
            source_epoch: 0,
            source_refresh_generation: 0,
            last_refreshed_sources: Vec::new(),
            drain_start: 0,
            step_counter: 0,
        }
    }
242
    /// Return a monotonic generation for source refresh cycles.
    ///
    /// The value grows by one for every refresh cycle in which at least one
    /// source was refreshed, so comparing two readings tells a caller
    /// whether any refresh happened in between.
    pub fn source_refresh_generation(&self) -> u64 {
        self.source_refresh_generation
    }
247
    /// Return source ids refreshed by the most recent refresh cycle.
    ///
    /// The slice is empty when the last cycle refreshed nothing. Sources
    /// whose refresh attempt failed are still listed.
    pub fn last_refreshed_sources(&self) -> &[SourceId] {
        &self.last_refreshed_sources
    }
252
    /// Reset step counter to 0 (called at epoch boundaries).
    pub(crate) fn reset_step_counter(&mut self) {
        self.step_counter = 0;
    }

    /// Return the current step counter.
    #[cfg(test)]
    pub fn step_counter(&self) -> u64 {
        self.step_counter
    }

    /// Update the current source epoch.
    ///
    /// Only changes the epoch value so subsequent refreshes derive their
    /// seed from the new epoch (via `derive_epoch_seed`), producing a
    /// different permutation.
    /// Stream cursors are intentionally NOT reset here — the cursor is a raw
    /// I/O offset into the source's stream and must continue advancing so
    /// every record is eventually fetched (resetting it would repeat the
    /// leading slice of the source on every epoch boundary).
    pub(crate) fn set_source_epoch(&mut self, epoch: u64) {
        self.source_epoch = epoch;
    }
275
    /// Return the current source epoch.
    ///
    /// Compiled for tests only.
    #[cfg(test)]
    pub fn source_epoch(&self) -> u64 {
        self.source_epoch
    }
281
    /// Reset all raw source stream cursors, drop per-source buffers, and
    /// clear per-source caches.
    ///
    /// Use this only when starting a deterministic replay from a specific
    /// epoch (e.g. explicit `set_epoch` calls). A clean-start reset ensures
    /// the new permutation begins at position 0 of the permuted index space.
    ///
    /// Refresh stats, the step counter and the source epoch are not touched.
    pub(crate) fn reset_stream_cursors(&mut self) {
        for state in &mut self.sources {
            // A `None` cursor is what a freshly registered source starts with.
            state.cursor = None;
            state.buffer.clear();
            state.cache.clear();
        }
    }
294
295    /// Register a source for on-demand ingestion.
296    ///
297    /// Returns an error if the source's `id()` matches the reserved `__*__`
298    /// pattern used for internal synthetic/metadata source identifiers.
299    pub fn register_source(
300        &mut self,
301        source: Box<dyn DataSource + 'static>,
302    ) -> Result<(), SamplerError> {
303        let source_id = source.id().to_string();
304        if is_reserved_source_id(&source_id) {
305            return Err(SamplerError::ReservedSourceId(source_id));
306        }
307        let cache = RecordCache::new(self.max_records);
308        self.sources.push(SourceState {
309            source,
310            cursor: None,
311            buffer: VecDeque::new(),
312            cache,
313            stats: SourceRefreshStats::default(),
314        });
315        Ok(())
316    }
317
318    /// Load persisted per-source stream cursors.
319    pub fn load_cursors(&mut self, cursors: &[(SourceId, u64)]) {
320        if cursors.is_empty() {
321            return;
322        }
323        let mut map = std::collections::HashMap::with_capacity(cursors.len());
324        for (id, revision) in cursors {
325            // The step counter is stored as a (STEP_CURSOR_KEY, revision)
326            // tuple inside the same source_stream_cursors Vec.  Bitcode
327            // serialises it identically to any ("source_a", 14) cursor;
328            // the special meaning is applied here at read time.
329            if id == STEP_CURSOR_KEY {
330                self.step_counter = *revision;
331            } else {
332                map.insert(id.as_str(), *revision);
333            }
334        }
335        for state in &mut self.sources {
336            if let Some(revision) = map.get(state.source.id()) {
337                state.cursor = Some(SourceCursor {
338                    last_seen: Utc::now(),
339                    revision: *revision,
340                });
341            }
342        }
343    }
344
345    /// Snapshot current per-source stream cursors.
346    pub fn snapshot_cursors(&self) -> Vec<(SourceId, u64)> {
347        let mut out = Vec::new();
348        for state in &self.sources {
349            if let Some(cursor) = state.cursor.as_ref() {
350                out.push((state.source.id().to_string(), cursor.revision));
351            }
352        }
353        // Append the step counter as a (STEP_CURSOR_KEY, step) tuple so it
354        // survives restarts inside the same Vec that bitcode already encodes.
355        // Bitcode sees only Vec<(String, u64)> — this tuple is no different
356        // from any source cursor.  load_cursors reverses the interpretation.
357        out.push((STEP_CURSOR_KEY.to_string(), self.step_counter));
358        out
359    }
360
    /// Return latest refresh telemetry for each registered source.
    ///
    /// One `(source id, stats)` pair per source, in registration order; the
    /// stats are cloned, so the result is detached from later refreshes.
    pub fn source_refresh_stats(&self) -> Vec<(SourceId, SourceRefreshStats)> {
        self.sources
            .iter()
            .map(|state| (state.source.id().to_string(), state.stats.clone()))
            .collect()
    }
368
    /// Return a flat snapshot of every record currently in all per-source caches.
    ///
    /// Records are cloned in source order; the `source` field is guaranteed
    /// to be set (it is normalised in `refresh_all_internal`).
    ///
    /// Each cache is read under its own lock, one after another.
    pub fn all_records_snapshot(&self) -> Vec<DataRecord> {
        self.sources
            .iter()
            .flat_map(|s| s.cache.snapshot())
            .collect()
    }
379
    /// Returns `true` when ALL per-source caches are empty.
    ///
    /// Also `true` when no source is registered.
    pub fn all_caches_empty(&self) -> bool {
        self.sources.iter().all(|s| s.cache.is_empty())
    }
384
    /// Returns the total number of records across all per-source caches.
    ///
    /// Each cache is counted under its own lock, one after another.
    pub fn all_records_len(&self) -> usize {
        self.sources.iter().map(|s| s.cache.len()).sum()
    }
389
    /// Returns the sum of ingest counts across all per-source caches.
    ///
    /// Used as a monotonic proxy to detect whether any cache has been updated
    /// since the last sync. Clearing a cache does not lower its ingest count.
    pub fn total_ingest_count(&self) -> u64 {
        self.sources.iter().map(|s| s.cache.ingest_count()).sum()
    }
397
    /// Run one full refresh cycle.
    ///
    /// Sources whose buffers are empty are refreshed, every per-source cache
    /// is cleared, and buffered records are drained round-robin into the
    /// caches up to `max_records` in total.
    pub fn refresh_all(&mut self) {
        self.refresh_all_internal(false, None, None);
    }
402
    /// Advance the ingestion window by ingesting `step` new records.
    ///
    /// Caches are not cleared first. Sources whose buffers are empty are
    /// refreshed, then at most `step` buffered records in total are drained
    /// round-robin into the per-source caches.
    pub fn advance(&mut self, step: usize) {
        self.refresh_all_internal(false, Some(step), None);
    }
407
    /// Advance the ingestion window by ingesting `step` new records with weights.
    ///
    /// Same as `advance`, except the drain is distributed across sources
    /// according to `weights`. Weights are validated before any refresh work
    /// starts, so a rejected call changes nothing.
    ///
    /// Returns `Err(SamplerError::InvalidWeight)` if `weights` contains an unregistered
    /// source ID or a negative value.
    pub fn advance_with_weights(
        &mut self,
        step: usize,
        weights: &HashMap<SourceId, f32>,
    ) -> Result<(), SamplerError> {
        self.validate_weights(weights)?;
        self.refresh_all_internal(false, Some(step), Some(weights));
        Ok(())
    }
421
    /// Force refresh all registered sources, discarding buffered records.
    ///
    /// Every buffer is dropped and every source refreshed regardless of its
    /// buffer state; caches are then cleared and refilled round-robin up to
    /// `max_records` in total.
    pub fn force_refresh_all(&mut self) {
        self.refresh_all_internal(true, None, None);
    }
426
    /// Refresh all registered sources once with per-call source weights.
    ///
    /// Same as `refresh_all`, except the drain is distributed across sources
    /// according to `weights`. Weights are validated before any refresh work
    /// starts, so a rejected call changes nothing.
    ///
    /// Returns `Err(SamplerError::InvalidWeight)` if `weights` contains an unregistered
    /// source ID or a negative value.
    pub fn refresh_all_with_weights(
        &mut self,
        weights: &HashMap<SourceId, f32>,
    ) -> Result<(), SamplerError> {
        self.validate_weights(weights)?;
        self.refresh_all_internal(false, None, Some(weights));
        Ok(())
    }
439
    /// Force refresh all registered sources with per-call source weights.
    ///
    /// Same as `force_refresh_all`, except the drain is distributed across
    /// sources according to `weights`. Weights are validated before any
    /// buffer is discarded, so a rejected call changes nothing.
    ///
    /// Returns `Err(SamplerError::InvalidWeight)` if `weights` contains an unregistered
    /// source ID or a negative value.
    pub fn force_refresh_all_with_weights(
        &mut self,
        weights: &HashMap<SourceId, f32>,
    ) -> Result<(), SamplerError> {
        self.validate_weights(weights)?;
        self.refresh_all_internal(true, None, Some(weights));
        Ok(())
    }
452
453    fn validate_weights(&self, weights: &HashMap<SourceId, f32>) -> Result<(), SamplerError> {
454        let known_ids: std::collections::HashSet<&str> =
455            self.sources.iter().map(|s| s.source.id()).collect();
456        for (id, &w) in weights {
457            if !known_ids.contains(id.as_str()) {
458                return Err(SamplerError::InvalidWeight {
459                    source_id: id.clone(),
460                    reason: "source is not registered".to_string(),
461                });
462            }
463            if w < 0.0 {
464                return Err(SamplerError::InvalidWeight {
465                    source_id: id.clone(),
466                    reason: format!("weight {w} is negative"),
467                });
468            }
469        }
470        Ok(())
471    }
472
    /// Refresh sources as needed, then drain buffered records into the
    /// per-source caches.
    ///
    /// The cycle has two phases:
    ///
    /// 1. **Refresh.** When `force_refresh` is false, a source is refreshed
    ///    only if its buffer is empty; when true, all buffers are cleared and
    ///    all sources refresh. Planned refreshes run concurrently, one scoped
    ///    thread per source. A successful refresh appends its records to the
    ///    source's buffer (with `record.source` overwritten by the source id)
    ///    and replaces the source's cursor; a failed or panicked refresh
    ///    only updates the source's stats and leaves buffer and cursor alone.
    /// 2. **Drain.** If `step` is `Some(n)`, at most `n` records in total are
    ///    moved from the buffers into the caches (rolling update, no clear).
    ///    If `step` is `None`, every cache is cleared first and up to
    ///    `max_records` records in total are moved. With `weights` the drain
    ///    is weighted; otherwise it is a rotating round-robin.
    ///
    /// `last_refreshed_sources` is reset at the start of every call. The
    /// step counter and the refresh generation are bumped only when at least
    /// one source is in the refresh plan.
    fn refresh_all_internal(
        &mut self,
        force_refresh: bool,
        step: Option<usize>,
        weights: Option<&HashMap<SourceId, f32>>,
    ) {
        self.last_refreshed_sources.clear();
        // Plan: (source index, cursor to resume from) for each source that
        // needs a refresh this cycle.
        let mut refresh_plan = Vec::new();
        for (idx, state) in self.sources.iter_mut().enumerate() {
            if force_refresh {
                state.buffer.clear();
            }
            if force_refresh || state.buffer.is_empty() {
                refresh_plan.push((idx, state.cursor.clone()));
            }
        }

        if !refresh_plan.is_empty() {
            self.step_counter = self.step_counter.saturating_add(1);
            self.source_refresh_generation = self.source_refresh_generation.saturating_add(1);
            self.last_refreshed_sources = refresh_plan
                .iter()
                .map(|(idx, _)| self.sources[*idx].source.id().to_string())
                .collect();
            // One slot per registered source, indexed by source index; only
            // the planned sources' slots get filled.
            let mut results: Vec<
                Option<(Result<SourceSnapshot, SamplerError>, std::time::Duration)>,
            > = Vec::with_capacity(self.sources.len());
            results.resize_with(self.sources.len(), || None);
            let fetch_limit = self.max_records;
            let sampler_config = self.sampler_config.clone();
            // NOTE: shadows the `step` parameter for the rest of this `if`
            // block only; the drain phase below sees the parameter again.
            let step = self.step_counter;
            thread::scope(|scope| {
                let mut handles = Vec::with_capacity(refresh_plan.len());
                for (idx, cursor) in &refresh_plan {
                    let source = &self.sources[*idx].source;
                    let cursor = cursor.clone();
                    let sampler_config = sampler_config.clone();
                    let source_epoch = self.source_epoch;
                    handles.push((
                        *idx,
                        scope.spawn(move || {
                            let start = std::time::Instant::now();
                            // XOR the source epoch into the seed so each epoch
                            // produces a distinct permutation within the source.
                            // XOR the step counter into the seed so every
                            // advance/refresh call gets a distinct seed, which
                            // sources can use for e.g. shard ordering.
                            let epoch_config = SamplerConfig {
                                seed: derive_epoch_seed(sampler_config.seed, source_epoch) ^ step,
                                ..sampler_config
                            };
                            let result =
                                source.refresh(&epoch_config, cursor.as_ref(), Some(fetch_limit));
                            let elapsed = start.elapsed();
                            (result, elapsed)
                        }),
                    ));
                }
                // Join in plan order; a panicked refresh thread is reported
                // as a `SourceUnavailable` error with zero elapsed time.
                for (idx, handle) in handles {
                    let result = match handle.join() {
                        Ok((result, elapsed)) => {
                            debug!(
                                source_id = %self.sources[idx].source.id(),
                                refresh_ms = elapsed.as_millis(),
                                "source refresh completed"
                            );
                            (result, elapsed)
                        }
                        Err(_) => (
                            Err(SamplerError::SourceUnavailable {
                                source_id: self.sources[idx].source.id().to_string(),
                                reason: "source refresh thread panicked".into(),
                            }),
                            std::time::Duration::from_secs(0),
                        ),
                    };
                    results[idx] = Some(result);
                }
            });

            // Apply the refresh outcomes to per-source state.
            for (idx, _) in refresh_plan {
                let Some((result, elapsed)) = results[idx].take() else {
                    continue;
                };
                match result {
                    Ok(snapshot) => {
                        let SourceSnapshot {
                            records,
                            cursor: next_cursor,
                        } = snapshot;
                        let record_count = records.len();
                        let seconds = elapsed.as_secs_f64();
                        // Guard against a zero-length measurement.
                        let per_sec = if seconds > 0.0 {
                            (record_count as f64) / seconds
                        } else {
                            0.0
                        };
                        let stats = &mut self.sources[idx].stats;
                        stats.last_refresh_ms = elapsed.as_millis();
                        stats.last_record_count = record_count;
                        stats.last_records_per_sec = per_sec;
                        stats.last_error = None;
                        debug!(
                            source_id = %self.sources[idx].source.id(),
                            record_count,
                            refresh_ms = elapsed.as_millis(),
                            records_per_sec = per_sec,
                            "source refresh ingested records"
                        );
                        // Stamp every record with the id of the source that
                        // produced it, whatever the source itself put there.
                        let source_id = self.sources[idx].source.id().to_string();
                        let normalized = records
                            .into_iter()
                            .map(|mut record| {
                                record.source = source_id.clone();
                                record
                            })
                            .collect::<Vec<_>>();
                        self.sources[idx].buffer.extend(normalized);
                        self.sources[idx].cursor = Some(next_cursor);
                    }
                    Err(err) => {
                        // Failure: record it in the stats and keep the
                        // previous cursor and buffer as they were.
                        let stats = &mut self.sources[idx].stats;
                        stats.last_refresh_ms = elapsed.as_millis();
                        stats.last_record_count = 0;
                        stats.last_records_per_sec = 0.0;
                        stats.last_error = Some(err.to_string());
                        stats.error_count = stats.error_count.saturating_add(1);
                        // NOTE(review): failures go to stderr while successes
                        // use `tracing`; confirm whether this is intentional.
                        eprintln!(
                            "[data_sampler] source '{}' refresh failed: {}",
                            self.sources[idx].source.id(),
                            err
                        );
                    }
                }
            }
        }

        // On a full refresh (step=None) clear every per-source cache so that the
        // snapshot always reflects the newest window, matching the previous
        // shared-cache clear semantics.
        if step.is_none() {
            for state in self.sources.iter_mut() {
                state.cache.clear();
            }
        }
        // Zero capacity: nothing could be retained, so skip the drain.
        if self.max_records == 0 {
            return;
        }
        // Total number of records to drain across all sources this cycle.
        let target_limit = step.unwrap_or(self.max_records);
        if let Some(weights) = weights {
            self.weighted_drain_into_caches(target_limit, weights);
        } else {
            // Fair round-robin drain: start from `drain_start` instead of 0 so
            // that the drain cursor rotates across cycles.  This prevents head
            // sources (low indices) from always draining faster than tail sources,
            // which was starving tail sources of refresh opportunities.
            let n = self.sources.len();
            if n > 0 {
                let mut per_source: Vec<Vec<DataRecord>> = vec![Vec::new(); n];
                let mut total_drained = 0;
                let mut any_remaining = true;
                // Each pass takes at most one record per source; stop when
                // the target is reached or a whole pass finds nothing.
                while total_drained < target_limit && any_remaining {
                    any_remaining = false;
                    for offset in 0..n {
                        if total_drained >= target_limit {
                            break;
                        }
                        let idx = (self.drain_start + offset) % n;
                        if let Some(record) = self.sources[idx].buffer.pop_front() {
                            per_source[idx].push(record);
                            total_drained += 1;
                            any_remaining = true;
                        }
                    }
                }
                // Advance the drain cursor so the next cycle starts from a different
                // position.  Only advance when at least one record was drained, so a
                // burst of drain-noop cycles on an empty source list doesn't rotate.
                if total_drained > 0 {
                    self.drain_start = (self.drain_start + 1) % n;
                }
                // One ingest call per source that contributed records, so
                // each cache's ingest count moves at most once per cycle.
                for (idx, batch) in per_source.into_iter().enumerate() {
                    if !batch.is_empty() {
                        self.sources[idx].cache.ingest(batch);
                    }
                }
            }
        }
    }
668
669    fn weighted_drain_into_caches(&mut self, limit: usize, weights: &HashMap<SourceId, f32>) {
670        let len = self.sources.len();
671        if len == 0 {
672            return;
673        }
674        let mut weight_values = Vec::with_capacity(len);
675        let mut any_positive = false;
676        for state in &self.sources {
677            let weight = weights.get(state.source.id()).copied().unwrap_or(1.0);
678            if weight > 0.0 {
679                any_positive = true;
680            }
681            weight_values.push(weight);
682        }
683        if !any_positive {
684            weight_values.fill(1.0);
685        }
686
687        let mut current = vec![0.0f32; len];
688        let mut per_source: Vec<Vec<DataRecord>> = vec![Vec::new(); len];
689        let mut total = 0;
690        while total < limit {
691            let mut total_weight = 0.0f32;
692            for (idx, weight) in weight_values.iter().copied().enumerate().take(len) {
693                if weight <= 0.0 {
694                    continue;
695                }
696                if self.sources[idx].buffer.is_empty() {
697                    continue;
698                }
699                total_weight += weight;
700            }
701            if total_weight == 0.0 {
702                break;
703            }
704
705            let mut best_idx = None;
706            let mut best_score = f32::MIN;
707            // Rotating tie-breaker: when scores are equal, prefer the source
708            // just PAST drain_start in rotation order (i.e. the source whose
709            // turn is coming up in the round-robin).
710            let closer_to_start = |a: usize, b: usize| -> bool {
711                let da = (a + len - self.drain_start) % len;
712                let db = (b + len - self.drain_start) % len;
713                da < db
714            };
715            for idx in 0..len {
716                if weight_values[idx] <= 0.0 {
717                    continue;
718                }
719                if self.sources[idx].buffer.is_empty() {
720                    continue;
721                }
722                current[idx] += weight_values[idx];
723                let is_better = if current[idx] > best_score {
724                    true
725                } else if current[idx] == best_score {
726                    closer_to_start(idx, best_idx.unwrap_or(0))
727                } else {
728                    false
729                };
730                if is_better {
731                    best_score = current[idx];
732                    best_idx = Some(idx);
733                }
734            }
735
736            let idx = match best_idx {
737                Some(idx) => idx,
738                None => break,
739            };
740            current[idx] -= total_weight;
741            if let Some(record) = self.sources[idx].buffer.pop_front() {
742                per_source[idx].push(record);
743                total += 1;
744            }
745        }
746
747        if total > 0 && len > 0 {
748            self.drain_start = (self.drain_start + 1) % len;
749        }
750
751        for (idx, batch) in per_source.into_iter().enumerate() {
752            if !batch.is_empty() {
753                self.sources[idx].cache.ingest(batch);
754            }
755        }
756    }
757
    /// Returns `true` when at least one source is registered.
    pub fn has_sources(&self) -> bool {
        // Only registration is checked; whether the sources have produced
        // any records yet is irrelevant here.
        !self.sources.is_empty()
    }
762}
763
/// Per-source ingestion runtime state.
struct SourceState {
    /// The registered data source.
    source: Box<dyn DataSource + 'static>,
    /// Cursor returned by the last successful refresh (or loaded from
    /// persisted state); `None` until then and after a cursor reset.
    cursor: Option<SourceCursor>,
    /// Records fetched from the source but not yet drained into `cache`.
    buffer: VecDeque<DataRecord>,
    /// Per-source record cache capped at `max_records`; once full, the
    /// least recently ingested record id is evicted first.
    cache: RecordCache,
    /// Telemetry from the most recent refresh attempt.
    stats: SourceRefreshStats,
}
773
774#[cfg(test)]
775mod tests {
776    use super::*;
777    use crate::TripletSampler;
778    use crate::config::{Selector, TextRecipe, TripletRecipe};
779    use crate::data::{QualityScore, RecordSection, SectionRole};
780    use crate::sampler::Sampler;
781    use crate::splits::{DeterministicSplitStore, SamplerStateStore, SplitLabel, SplitRatios};
782    use chrono::Utc;
783    use std::collections::HashMap;
784    use std::collections::VecDeque;
785    use std::sync::atomic::{AtomicUsize, Ordering};
786    use std::sync::{Arc, Mutex};
787
788    fn make_record(id: &str, source: &str) -> DataRecord {
789        let now = Utc::now();
790        DataRecord {
791            id: id.to_string(),
792            source: source.to_string(),
793            created_at: now,
794            updated_at: now,
795            quality: QualityScore { trust: 1.0 },
796            taxonomy: Vec::new(),
797            sections: vec![RecordSection {
798                role: SectionRole::Anchor,
799                heading: None,
800                text: id.to_string(),
801                sentences: vec![id.to_string()],
802            }],
803            meta_prefix: None,
804        }
805    }
806
807    struct ScriptedSource {
808        id: String,
809        refreshes: Arc<AtomicUsize>,
810        script: Arc<Mutex<VecDeque<Result<SourceSnapshot, SamplerError>>>>,
811    }
812
813    impl ScriptedSource {
814        fn new(
815            id: &str,
816            refreshes: Arc<AtomicUsize>,
817            script: Vec<Result<SourceSnapshot, SamplerError>>,
818        ) -> Self {
819            Self {
820                id: id.to_string(),
821                refreshes,
822                script: Arc::new(Mutex::new(script.into_iter().collect())),
823            }
824        }
825    }
826
827    impl DataSource for ScriptedSource {
828        fn id(&self) -> &str {
829            &self.id
830        }
831
832        fn refresh(
833            &self,
834            _config: &SamplerConfig,
835            _cursor: Option<&SourceCursor>,
836            _limit: Option<usize>,
837        ) -> Result<SourceSnapshot, SamplerError> {
838            self.refreshes.fetch_add(1, Ordering::SeqCst);
839            let mut guard = self.script.lock().expect("script lock poisoned");
840            guard.pop_front().unwrap_or_else(|| {
841                Ok(SourceSnapshot {
842                    records: Vec::new(),
843                    cursor: SourceCursor {
844                        last_seen: Utc::now(),
845                        revision: 0,
846                    },
847                })
848            })
849        }
850
851        fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
852            Ok(0)
853        }
854
855        fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
856            Vec::new()
857        }
858    }
859
860    struct PanicSource {
861        id: String,
862    }
863
864    impl DataSource for PanicSource {
865        fn id(&self) -> &str {
866            &self.id
867        }
868
869        fn refresh(
870            &self,
871            _config: &SamplerConfig,
872            _cursor: Option<&SourceCursor>,
873            _limit: Option<usize>,
874        ) -> Result<SourceSnapshot, SamplerError> {
875            panic!("panic source refresh")
876        }
877
878        fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
879            Ok(0)
880        }
881    }
882
/// Walks `RecordCache` through its length, wait, snapshot and clear calls.
#[test]
fn record_cache_waits_len_and_clear_paths_are_covered() {
    let cache = RecordCache::new(2);
    assert!(cache.is_empty());
    assert_eq!(cache.len(), 0);
    assert_eq!(cache.ingest_count(), 0);

    // The test expects an empty batch to leave the ingest counter at zero,
    // so the short timed wait comes back with 0.
    cache.ingest(Vec::<DataRecord>::new());
    assert_eq!(cache.wait_for_ingest(0, Duration::from_millis(1)), 0);

    // A blocked waiter is expected to observe the first real ingest.
    let waiter = {
        let shared = cache.clone();
        std::thread::spawn(move || shared.wait_for_ingest_blocking(0))
    };
    std::thread::sleep(Duration::from_millis(5));
    cache.ingest(vec![make_record("r1", "s")]);
    assert_eq!(waiter.join().unwrap(), 1);
    assert_eq!(cache.ingest_count(), 1);

    // With a capacity of two, the two newest records must both be present.
    cache.ingest(vec![make_record("r2", "s"), make_record("r3", "s")]);
    assert_eq!(cache.len(), 2);
    let ids: Vec<String> = cache
        .snapshot()
        .into_iter()
        .map(|record| record.id)
        .collect();
    for expected in ["r2", "r3"] {
        assert!(ids.contains(&expected.to_string()));
    }

    cache.clear();
    assert!(cache.is_empty());
}
913
/// A cache created with a limit of zero must stay empty after an ingest.
#[test]
fn record_cache_zero_limit_discards_everything() {
    let zero_capacity = RecordCache::new(0);
    let batch = vec![make_record("r1", "s")];
    zero_capacity.ingest(batch);
    assert!(zero_capacity.is_empty());
    assert_eq!(zero_capacity.len(), 0);
}
921
/// Covers `has_sources`, `load_cursors` and `snapshot_cursors` before and
/// after a refresh.
#[test]
fn manager_loads_and_snapshots_cursors_and_reports_has_sources() {
    let mut manager = IngestionManager::new(4, SamplerConfig::default());
    assert!(!manager.has_sources());
    // Loading an empty cursor list while nothing is registered must not panic.
    manager.load_cursors(&[]);

    let scripted_snapshot = SourceSnapshot {
        records: vec![make_record("id_1", "original_source")],
        cursor: SourceCursor {
            last_seen: Utc::now(),
            revision: 33,
        },
    };
    let source = ScriptedSource::new(
        "cursor_source",
        Arc::new(AtomicUsize::new(0)),
        vec![Ok(scripted_snapshot)],
    );
    manager.register_source(Box::new(source)).unwrap();
    assert!(manager.has_sources());

    // The loaded cursor is reported back first; the step entry follows it.
    manager.load_cursors(&[("cursor_source".to_string(), 7)]);
    let before_refresh = manager.snapshot_cursors();
    assert_eq!(before_refresh.len(), 2);
    assert_eq!(before_refresh[0], ("cursor_source".to_string(), 7));
    assert_eq!(before_refresh[1].0, STEP_CURSOR_KEY);

    // After a refresh the source entry carries the scripted revision.
    manager.refresh_all();
    let after_refresh = manager.snapshot_cursors();
    assert_eq!(after_refresh.len(), 2);
    assert_eq!(after_refresh[0], ("cursor_source".to_string(), 33));
    assert_eq!(after_refresh[1].0, STEP_CURSOR_KEY);

    // The record is expected under the registering source's id rather than
    // the `source` value it was built with.
    let records = manager.all_records_snapshot();
    assert_eq!(records.len(), 1);
    assert_eq!(records[0].source, "cursor_source");
}
960
/// Two single-record advances must be served by one refresh: the refresh
/// counter stays at 1 while the record count grows.
#[test]
fn advance_uses_buffer_before_refreshing_again() {
    let refreshes = Arc::new(AtomicUsize::new(0));
    let mut manager = IngestionManager::new(5, SamplerConfig::default());

    let records: Vec<DataRecord> = ["a", "b", "c"]
        .iter()
        .map(|&id| make_record(id, "x"))
        .collect();
    let snapshot = SourceSnapshot {
        records,
        cursor: SourceCursor {
            last_seen: Utc::now(),
            revision: 1,
        },
    };
    let source = ScriptedSource::new("buffered", Arc::clone(&refreshes), vec![Ok(snapshot)]);
    manager.register_source(Box::new(source)).unwrap();

    for expected_len in [1, 2] {
        manager.advance(1);
        assert_eq!(refreshes.load(Ordering::SeqCst), 1);
        assert_eq!(manager.all_records_len(), expected_len);
    }
}
991
/// After a forced refresh only the record from the second scripted snapshot
/// may remain, and the source must have been refreshed a second time.
#[test]
fn force_refresh_clears_buffer_and_fetches_again() {
    let refreshes = Arc::new(AtomicUsize::new(0));
    let mut manager = IngestionManager::new(4, SamplerConfig::default());

    let first = SourceSnapshot {
        records: vec![
            make_record("r1", "x"),
            make_record("r2", "x"),
            make_record("r3", "x"),
        ],
        cursor: SourceCursor {
            last_seen: Utc::now(),
            revision: 10,
        },
    };
    let second = SourceSnapshot {
        records: vec![make_record("r4", "x")],
        cursor: SourceCursor {
            last_seen: Utc::now(),
            revision: 11,
        },
    };
    let source = ScriptedSource::new("force", Arc::clone(&refreshes), vec![Ok(first), Ok(second)]);
    manager.register_source(Box::new(source)).unwrap();

    // One record advanced, one refresh performed so far.
    manager.advance(1);
    assert_eq!(manager.all_records_len(), 1);
    assert_eq!(refreshes.load(Ordering::SeqCst), 1);

    manager.force_refresh_all();
    assert_eq!(refreshes.load(Ordering::SeqCst), 2);
    let remaining = manager.all_records_snapshot();
    assert_eq!(remaining.len(), 1);
    assert_eq!(remaining[0].id, "r4");
}
1033
/// Checks two weighting scenarios: a zero weight excludes that source, and
/// all-zero weights still let every source contribute.
#[test]
fn weighted_drain_respects_zero_and_fallback_weights() {
    // Scenario 1: "a" is weighted 0.0, so only "b" records may appear.
    let mut manager = IngestionManager::new(6, SamplerConfig::default());
    for source_id in ["a", "b"] {
        let counter = Arc::new(AtomicUsize::new(0));
        let snapshot = SourceSnapshot {
            records: vec![
                make_record(&format!("{source_id}1"), source_id),
                make_record(&format!("{source_id}2"), source_id),
            ],
            cursor: SourceCursor {
                last_seen: Utc::now(),
                revision: 1,
            },
        };
        let source = ScriptedSource::new(source_id, counter, vec![Ok(snapshot)]);
        manager.register_source(Box::new(source)).unwrap();
    }

    let only_b = HashMap::from([("a".to_string(), 0.0), ("b".to_string(), 1.0)]);
    manager.refresh_all_with_weights(&only_b).unwrap();
    let ids: Vec<String> = manager
        .all_records_snapshot()
        .into_iter()
        .map(|record| record.id)
        .collect();
    assert!(ids.iter().all(|id| id.starts_with('b')));

    // Scenario 2: every weight is 0.0; both sources are still expected to
    // deliver their single record.
    let mut manager_fallback = IngestionManager::new(6, SamplerConfig::default());
    for source_id in ["a", "b"] {
        let counter = Arc::new(AtomicUsize::new(0));
        let snapshot = SourceSnapshot {
            records: vec![make_record(&format!("{source_id}1"), source_id)],
            cursor: SourceCursor {
                last_seen: Utc::now(),
                revision: 2,
            },
        };
        let source = ScriptedSource::new(source_id, counter, vec![Ok(snapshot)]);
        manager_fallback.register_source(Box::new(source)).unwrap();
    }

    let all_zero = HashMap::from([("a".to_string(), 0.0), ("b".to_string(), 0.0)]);
    manager_fallback
        .refresh_all_with_weights(&all_zero)
        .unwrap();
    let ids: Vec<String> = manager_fallback
        .all_records_snapshot()
        .into_iter()
        .map(|record| record.id)
        .collect();
    for expected in ["a1", "b1"] {
        assert!(ids.contains(&expected.to_string()));
    }
}
1117
/// A source returning `Err` and a source that panics must each end up with
/// one recorded error and a matching `last_error` message.
#[test]
fn refresh_errors_and_panics_update_source_stats() {
    let mut manager = IngestionManager::new(4, SamplerConfig::default());

    let failing = ScriptedSource::new(
        "err_source",
        Arc::new(AtomicUsize::new(0)),
        vec![Err(SamplerError::SourceUnavailable {
            source_id: "err_source".to_string(),
            reason: "boom".to_string(),
        })],
    );
    manager.register_source(Box::new(failing)).unwrap();

    let panicking = PanicSource {
        id: "panic_source".to_string(),
    };
    manager.register_source(Box::new(panicking)).unwrap();

    manager.refresh_all();
    let stats = manager.source_refresh_stats();

    // (source id, text expected somewhere inside `last_error`)
    let expectations = [("err_source", "boom"), ("panic_source", "panicked")];
    for (source_id, fragment) in expectations {
        let entry = stats
            .iter()
            .find(|(source, _)| source == source_id)
            .map(|(_, stats)| stats)
            .unwrap();
        assert_eq!(entry.error_count, 1);
        assert!(
            entry
                .last_error
                .as_ref()
                .is_some_and(|msg| msg.contains(fragment))
        );
    }
}
1165
/// A forced weighted refresh over one source must ingest its single record.
#[test]
fn force_refresh_with_weights_path_is_exercised() {
    let mut manager = IngestionManager::new(3, SamplerConfig::default());

    let snapshot = SourceSnapshot {
        records: vec![make_record("w1", "w")],
        cursor: SourceCursor {
            last_seen: Utc::now(),
            revision: 3,
        },
    };
    let source = ScriptedSource::new("w", Arc::new(AtomicUsize::new(0)), vec![Ok(snapshot)]);
    manager.register_source(Box::new(source)).unwrap();

    let weights = HashMap::from([("w".to_string(), 1.0)]);
    manager.force_refresh_all_with_weights(&weights).unwrap();
    assert_eq!(manager.all_records_len(), 1);
}
1188
/// A weight for a source id that was never registered must be rejected
/// with `InvalidWeight` naming that id.
#[test]
fn advance_with_weights_rejects_unknown_source() {
    let mut manager = IngestionManager::new(4, SamplerConfig::default());
    let known = ScriptedSource::new("known", Arc::new(AtomicUsize::new(0)), vec![]);
    manager.register_source(Box::new(known)).unwrap();

    let weights = HashMap::from([("known".to_string(), 1.0), ("unknown".to_string(), 0.5)]);

    let err = manager.advance_with_weights(1, &weights).unwrap_err();
    match &err {
        SamplerError::InvalidWeight { source_id, .. } if source_id == "unknown" => {}
        _ => panic!("expected InvalidWeight for unknown source, got {err:?}"),
    }
}
1210
/// A negative weight for a registered source must be rejected with
/// `InvalidWeight` naming that source.
#[test]
fn refresh_all_with_weights_rejects_negative_weight() {
    let mut manager = IngestionManager::new(4, SamplerConfig::default());
    let source = ScriptedSource::new("src", Arc::new(AtomicUsize::new(0)), vec![]);
    manager.register_source(Box::new(source)).unwrap();

    let weights = HashMap::from([("src".to_string(), -1.0)]);

    let err = manager.refresh_all_with_weights(&weights).unwrap_err();
    match &err {
        SamplerError::InvalidWeight { source_id, .. } if source_id == "src" => {}
        _ => panic!("expected InvalidWeight for negative weight, got {err:?}"),
    }
}
1231
/// The forced weighted refresh must also reject a weight for an
/// unregistered source id with `InvalidWeight`.
#[test]
fn force_refresh_all_with_weights_rejects_unknown_source() {
    let mut manager = IngestionManager::new(4, SamplerConfig::default());
    let real = ScriptedSource::new("real", Arc::new(AtomicUsize::new(0)), vec![]);
    manager.register_source(Box::new(real)).unwrap();

    let weights = HashMap::from([("ghost".to_string(), 1.0)]);

    let err = manager
        .force_refresh_all_with_weights(&weights)
        .unwrap_err();
    match &err {
        SamplerError::InvalidWeight { source_id, .. } if source_id == "ghost" => {}
        _ => panic!("expected InvalidWeight for unknown source, got {err:?}"),
    }
}
1254
1255    /// A source that records the `config.seed` value it receives on each `refresh()` call.
1256    struct SeedCapturingSource {
1257        id: String,
1258        received_seeds: Arc<Mutex<Vec<u64>>>,
1259    }
1260
1261    impl SeedCapturingSource {
1262        fn new(id: &str, received_seeds: Arc<Mutex<Vec<u64>>>) -> Self {
1263            Self {
1264                id: id.to_string(),
1265                received_seeds,
1266            }
1267        }
1268    }
1269
1270    impl DataSource for SeedCapturingSource {
1271        fn id(&self) -> &str {
1272            &self.id
1273        }
1274
1275        fn refresh(
1276            &self,
1277            config: &SamplerConfig,
1278            _cursor: Option<&SourceCursor>,
1279            _limit: Option<usize>,
1280        ) -> Result<SourceSnapshot, SamplerError> {
1281            self.received_seeds
1282                .lock()
1283                .expect("seed lock poisoned")
1284                .push(config.seed);
1285            Ok(SourceSnapshot {
1286                records: Vec::new(),
1287                cursor: SourceCursor {
1288                    last_seen: Utc::now(),
1289                    revision: 0,
1290                },
1291            })
1292        }
1293
1294        fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
1295            Ok(0)
1296        }
1297
1298        fn default_triplet_recipes(&self) -> Vec<crate::config::TripletRecipe> {
1299            Vec::new()
1300        }
1301    }
1302
/// The seed handed to a source's `refresh()` must depend on the manager's
/// source epoch and equal `derive_epoch_seed(base, epoch) ^ 1` on the first
/// refresh.
#[test]
fn source_epoch_xor_changes_seed_received_by_source() {
    let base_seed = 0xDEAD_BEEF_u64;
    let config = SamplerConfig {
        seed: base_seed,
        ..SamplerConfig::default()
    };

    let seeds_epoch0 = Arc::new(Mutex::new(Vec::<u64>::new()));
    let seeds_epoch1 = Arc::new(Mutex::new(Vec::<u64>::new()));

    // Epoch 0: the epoch is never set, so the manager's initial value applies.
    let mut manager_epoch0 = IngestionManager::new(4, config.clone());
    let capture0 = SeedCapturingSource::new("src", Arc::clone(&seeds_epoch0));
    manager_epoch0.register_source(Box::new(capture0)).unwrap();
    manager_epoch0.refresh_all();

    // Epoch 1: identical setup, but the epoch is bumped before refreshing.
    let mut manager_epoch1 = IngestionManager::new(4, config);
    let capture1 = SeedCapturingSource::new("src", Arc::clone(&seeds_epoch1));
    manager_epoch1.register_source(Box::new(capture1)).unwrap();
    manager_epoch1.set_source_epoch(1);
    manager_epoch1.refresh_all();

    let seed_at_epoch0 = {
        let received = seeds_epoch0.lock().unwrap();
        assert!(!received.is_empty(), "epoch-0 source was never refreshed");
        received[0]
    };
    let seed_at_epoch1 = {
        let received = seeds_epoch1.lock().unwrap();
        assert!(!received.is_empty(), "epoch-1 source was never refreshed");
        received[0]
    };

    // Different epochs must yield different seeds.
    assert_ne!(
        seed_at_epoch0, seed_at_epoch1,
        "epoch 0 and epoch 1 both produced seed {seed_at_epoch0:#x}; \
         derive_epoch_seed is not reaching the source"
    );

    // The expected values XOR the epoch seed with 1, i.e. the step counter
    // value the first refresh is expected to run with.
    assert_eq!(
        seed_at_epoch0,
        derive_epoch_seed(base_seed, 0) ^ 1,
        "epoch-0 seed mismatch (step_counter=1)"
    );
    assert_eq!(
        seed_at_epoch1,
        derive_epoch_seed(base_seed, 1) ^ 1,
        "epoch-1 seed mismatch (step_counter=1)"
    );
}
1368
/// After switching epochs and calling `reset_step_counter()`, the first
/// refresh of the new epoch must again use step 1 (seed `^ 1`, not `^ 2`).
#[test]
fn step_counter_resets_on_epoch_change() {
    let config = SamplerConfig::default();
    let seeds = Arc::new(Mutex::new(Vec::new()));

    let mut manager = IngestionManager::new(4, config.clone());
    let capture = SeedCapturingSource::new("src", Arc::clone(&seeds));
    manager.register_source(Box::new(capture)).unwrap();

    // Epoch 0, first refresh: expected seed is derive(seed, 0) ^ 1.
    manager.refresh_all();
    let first_seed_epoch0 = seeds.lock().unwrap()[0];
    assert_eq!(
        first_seed_epoch0,
        derive_epoch_seed(config.seed, 0) ^ 1,
        "epoch 0 step 1 seed"
    );

    // Move to epoch 1. The step counter is reset explicitly here; this test
    // does not show whether `set_source_epoch` alone would reset it.
    manager.set_source_epoch(1);
    manager.reset_step_counter();
    seeds.lock().unwrap().clear();

    // Epoch 1, first refresh: with the counter back at 0 the step is 1 again.
    // Without the reset the expected XOR operand would be 2.
    manager.refresh_all();
    let first_seed_epoch1 = seeds.lock().unwrap()[0];
    assert_eq!(
        first_seed_epoch1,
        derive_epoch_seed(config.seed, 1) ^ 1,
        "epoch 1 step 1 seed (must be ^1, not ^2)"
    );
}
1408
/// The step cursor saved through `TripletSampler::save_sampler_state` must
/// be readable from the store, restorable via `load_cursors`, and continue
/// counting from the saved value.
#[test]
fn step_counter_survives_sampler_save_and_load_state() {
    // The state is written through the sampler's public API and read back
    // from the same store, rather than by poking the manager directly.
    let store = Arc::new(DeterministicSplitStore::new(SplitRatios::default(), 1).unwrap());

    let sampler_config = SamplerConfig {
        seed: 42,
        ..SamplerConfig::default()
    };
    let sampler = TripletSampler::new(sampler_config, Arc::clone(&store));

    // A scripted source is enough here even though it supplies no recipes.
    let first_snapshot = SourceSnapshot {
        records: vec![make_record("r1", "src")],
        cursor: SourceCursor {
            last_seen: Utc::now(),
            revision: 1,
        },
    };
    let sampler_source = ScriptedSource::new(
        "src",
        Arc::new(AtomicUsize::new(0)),
        vec![Ok(first_snapshot)],
    );
    sampler.register_source(Box::new(sampler_source)).unwrap();

    // Request a batch so the sampler has state worth saving. The result is
    // deliberately ignored: with no recipes the batch itself may fail.
    let _ = sampler.next_text_batch(SplitLabel::Train);

    sampler.save_sampler_state(None).unwrap();

    // Read back what was persisted and pull out the step entry.
    let loaded = store.load_sampler_state().unwrap().unwrap();
    let step_saved = loaded
        .source_stream_cursors
        .iter()
        .find_map(|(key, value)| if key == STEP_CURSOR_KEY { Some(*value) } else { None })
        .expect("STEP_CURSOR_KEY must be in persisted source_stream_cursors");
    assert!(
        step_saved > 0,
        "step_counter must be >0 after batch call through TripletSampler"
    );

    // Feed the persisted cursors into a fresh manager.
    let mut manager2 = IngestionManager::new(4, SamplerConfig::default());
    let second_snapshot = SourceSnapshot {
        records: vec![make_record("r2", "src")],
        cursor: SourceCursor {
            last_seen: Utc::now(),
            revision: 2,
        },
    };
    let manager_source = ScriptedSource::new(
        "src",
        Arc::new(AtomicUsize::new(0)),
        vec![Ok(second_snapshot)],
    );
    manager2.register_source(Box::new(manager_source)).unwrap();
    manager2.load_cursors(&loaded.source_stream_cursors);

    // The step entry must match what was saved...
    let step_before = manager2
        .snapshot_cursors()
        .iter()
        .find_map(|(key, value)| if key == STEP_CURSOR_KEY { Some(*value) } else { None })
        .unwrap();
    assert_eq!(
        step_before, step_saved,
        "load_cursors must restore __step__ to saved value"
    );

    // ...and one refresh must advance it by exactly one.
    manager2.refresh_all();
    let step_after = manager2
        .snapshot_cursors()
        .iter()
        .find_map(|(key, value)| if key == STEP_CURSOR_KEY { Some(*value) } else { None })
        .unwrap();
    assert_eq!(
        step_after,
        step_saved + 1,
        "step must continue: loaded {step_saved}, refresh incremented to {step_saved}+1"
    );
}
1504
/// Calls the trait methods of the test doubles that the manager-driven
/// tests do not otherwise reach.
#[test]
fn scripted_and_panic_sources_cover_default_trait_paths() {
    let config = SamplerConfig::default();
    let scripted = ScriptedSource::new("scripted", Arc::new(AtomicUsize::new(0)), vec![]);

    // With nothing scripted, refresh falls back to an empty snapshot.
    let snapshot = scripted
        .refresh(&config, None, None)
        .expect("fallback snapshot");
    assert!(snapshot.records.is_empty());
    assert_eq!(snapshot.cursor.revision, 0);

    let scripted_count = scripted
        .reported_record_count(&config)
        .expect("record count");
    assert_eq!(scripted_count, 0);
    assert!(scripted.default_triplet_recipes().is_empty());

    let panic_source = PanicSource {
        id: "panic_count".to_string(),
    };
    let panic_count = panic_source
        .reported_record_count(&config)
        .expect("record count");
    assert_eq!(panic_count, 0);
}
1535
/// `SeedCapturingSource` must report zero records and no recipes.
#[test]
fn seed_capturing_source_trait_defaults_are_exercised() {
    let seed_log = Arc::new(Mutex::new(Vec::new()));
    let source = SeedCapturingSource::new("seed_defaults", seed_log);

    let reported = source
        .reported_record_count(&SamplerConfig::default())
        .expect("record count");
    assert_eq!(reported, 0);
    assert!(source.default_triplet_recipes().is_empty());
}
1547
/// Edge cases: a manager created with capacity 0 keeps its caches empty, and
/// a weighted refresh with no registered sources succeeds without effect.
#[test]
fn refresh_paths_handle_zero_capacity_and_no_sources() {
    let mut manager = IngestionManager::new(0, SamplerConfig::default());
    let snapshot = SourceSnapshot {
        records: vec![make_record("r1", "zero_capacity")],
        cursor: SourceCursor {
            last_seen: Utc::now(),
            revision: 1,
        },
    };
    let source = ScriptedSource::new(
        "zero_capacity",
        Arc::new(AtomicUsize::new(0)),
        vec![Ok(snapshot)],
    );
    manager.register_source(Box::new(source)).unwrap();
    manager.refresh_all();
    assert!(manager.all_caches_empty());

    // No sources and no weights: the call must return Ok and change nothing.
    let mut empty_manager = IngestionManager::new(4, SamplerConfig::default());
    let empty_weights = HashMap::new();
    let outcome = empty_manager.refresh_all_with_weights(&empty_weights);
    outcome.expect("no sources should not error");
    assert!(empty_manager.all_caches_empty());
}
1575
1576    #[test]
1577    fn drain_start_rotates_fairly_across_sources() {
1578        // Create 3 sources, each with 10 records in their buffer after refresh.
1579        // The fair round-robin should ensure all 3 drain at the same rate
1580        // over multiple advance cycles.
1581        struct FairSource {
1582            id: String,
1583            refresh_count: Arc<AtomicUsize>,
1584        }
1585
1586        impl DataSource for FairSource {
1587            fn id(&self) -> &str {
1588                &self.id
1589            }
1590            fn refresh(
1591                &self,
1592                _config: &SamplerConfig,
1593                _cursor: Option<&SourceCursor>,
1594                _limit: Option<usize>,
1595            ) -> Result<SourceSnapshot, SamplerError> {
1596                self.refresh_count.fetch_add(1, Ordering::SeqCst);
1597                Ok(SourceSnapshot {
1598                    records: (0..10)
1599                        .map(|i| make_record(&format!("r{i}"), &self.id))
1600                        .collect(),
1601                    cursor: SourceCursor {
1602                        last_seen: Utc::now(),
1603                        revision: 1,
1604                    },
1605                })
1606            }
1607            fn reported_record_count(&self, _: &SamplerConfig) -> Result<u128, SamplerError> {
1608                Ok(10)
1609            }
1610        }
1611
1612        let counts = (
1613            Arc::new(AtomicUsize::new(0)),
1614            Arc::new(AtomicUsize::new(0)),
1615            Arc::new(AtomicUsize::new(0)),
1616        );
1617
1618        let mut manager = IngestionManager::new(30, SamplerConfig::default());
1619        manager
1620            .register_source(Box::new(FairSource {
1621                id: "src_0".to_string(),
1622                refresh_count: Arc::clone(&counts.0),
1623            }))
1624            .unwrap();
1625        manager
1626            .register_source(Box::new(FairSource {
1627                id: "src_1".to_string(),
1628                refresh_count: Arc::clone(&counts.1),
1629            }))
1630            .unwrap();
1631        manager
1632            .register_source(Box::new(FairSource {
1633                id: "src_2".to_string(),
1634                refresh_count: Arc::clone(&counts.2),
1635            }))
1636            .unwrap();
1637
1638        // First refresh_all fills all buffers.
1639        manager.refresh_all();
1640        // All 3 refreshed once.
1641        assert_eq!(counts.0.load(Ordering::SeqCst), 1);
1642        assert_eq!(counts.1.load(Ordering::SeqCst), 1);
1643        assert_eq!(counts.2.load(Ordering::SeqCst), 1);
1644
1645        // Each advance(1) drains 1 record from 1 source, rotating via
1646        // drain_start.  With 3 sources, after 3 advances each source
1647        // loses 1 record.  After 30 advances each source loses 10 records
1648        // and triggers a refresh.  Run 33 advances and check all 3 refreshed
1649        // roughly the same number of times.
1650        for _ in 0..33 {
1651            manager.advance(1);
1652        }
1653
1654        let r0 = counts.0.load(Ordering::SeqCst);
1655        let r1 = counts.1.load(Ordering::SeqCst);
1656        let r2 = counts.2.load(Ordering::SeqCst);
1657
1658        // Each had 10 records after initial refresh_all.
1659        // 10 records / (1 drained per 3 cycles) = 30 cycles to drain each buffer.
1660        // After 33 more advances each buffer emptied ~1 time and re-filled, so
1661        // each source should have refreshed ~2 times total (initial + 1 drain).
1662        // The exact count can vary by 1 due to timing, but all 3 must be within
1663        // 1 of each other — no source can be starved.
1664        let min = r0.min(r1).min(r2);
1665        let max = r0.max(r1).max(r2);
1666        assert!(
1667            max <= min + 1,
1668            "sources should refresh at roughly the same rate: got {r0}/{r1}/{r2}"
1669        );
1670    }
1671
1672    #[test]
1673    fn direct_drain_start_rotates_fairly_with_batch_2_of_5() {
1674        // Direct IngestionManager test with batch_size=2, 5 sources.
1675        // Isolates the drain_start rotation from the sampler pipeline.
1676        struct SimpleSource {
1677            id: String,
1678            refresh_count: Arc<AtomicUsize>,
1679        }
1680
1681        impl DataSource for SimpleSource {
1682            fn id(&self) -> &str {
1683                &self.id
1684            }
1685            fn refresh(
1686                &self,
1687                _: &SamplerConfig,
1688                _: Option<&SourceCursor>,
1689                _: Option<usize>,
1690            ) -> Result<SourceSnapshot, SamplerError> {
1691                self.refresh_count.fetch_add(1, Ordering::SeqCst);
1692                let now = Utc::now();
1693                let records: Vec<DataRecord> = (0..8)
1694                    .map(|i| DataRecord {
1695                        id: format!("{}_r{i}", self.id),
1696                        source: self.id.clone(),
1697                        created_at: now,
1698                        updated_at: now,
1699                        quality: QualityScore { trust: 1.0 },
1700                        taxonomy: Vec::new(),
1701                        sections: Vec::new(),
1702                        meta_prefix: None,
1703                    })
1704                    .collect();
1705                Ok(SourceSnapshot {
1706                    records,
1707                    cursor: SourceCursor {
1708                        last_seen: now,
1709                        revision: 1,
1710                    },
1711                })
1712            }
1713            fn reported_record_count(&self, _: &SamplerConfig) -> Result<u128, SamplerError> {
1714                Ok(8)
1715            }
1716        }
1717
1718        let counts: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
1719        let mut manager = IngestionManager::new(40, SamplerConfig::default());
1720        for (i, count) in counts.iter().enumerate() {
1721            manager
1722                .register_source(Box::new(SimpleSource {
1723                    id: format!("src_{i}"),
1724                    refresh_count: Arc::clone(count),
1725                }))
1726                .unwrap();
1727        }
1728
1729        manager.refresh_all();
1730        for c in &counts {
1731            assert_eq!(c.load(Ordering::SeqCst), 1);
1732        }
1733
1734        for _ in 0..80 {
1735            manager.advance(2);
1736        }
1737
1738        let totals: Vec<usize> = counts.iter().map(|c| c.load(Ordering::SeqCst)).collect();
1739        // Round-robin drain with 5 sources, 8 records each, batch_size=2, 80
1740        // advance calls.  Most sources refresh 5×; sources 2-4 get one extra
1741        // cycle (6) due to the rotating drain_start tie-breaker positioning.
1742        assert_eq!(
1743            totals,
1744            vec![5, 5, 6, 6, 6],
1745            "direct manager: unexpected refresh distribution"
1746        );
1747    }
1748
1749    /// Helper: create 5 sources with 8 records each and a sampler configured for
1750    /// text batches with batch_size=2.  Returns the sampler and refresh counters.
1751    fn make_five_source_sampler(
1752        counts: &[Arc<AtomicUsize>],
1753    ) -> TripletSampler<DeterministicSplitStore> {
1754        struct Tracked {
1755            id: String,
1756            refresh_count: Arc<AtomicUsize>,
1757        }
1758        impl DataSource for Tracked {
1759            fn id(&self) -> &str {
1760                &self.id
1761            }
1762            fn refresh(
1763                &self,
1764                _: &SamplerConfig,
1765                _: Option<&SourceCursor>,
1766                _: Option<usize>,
1767            ) -> Result<SourceSnapshot, SamplerError> {
1768                self.refresh_count.fetch_add(1, Ordering::SeqCst);
1769                let now = Utc::now();
1770                let records: Vec<DataRecord> = (0..8)
1771                    .map(|i| DataRecord {
1772                        id: format!("{}_r{i}", self.id),
1773                        source: self.id.clone(),
1774                        created_at: now,
1775                        updated_at: now,
1776                        quality: QualityScore { trust: 1.0 },
1777                        taxonomy: Vec::new(),
1778                        sections: vec![RecordSection {
1779                            role: SectionRole::Anchor,
1780                            heading: None,
1781                            text: format!("x{i}"),
1782                            sentences: vec![format!("x{i}")],
1783                        }],
1784                        meta_prefix: None,
1785                    })
1786                    .collect();
1787                Ok(SourceSnapshot {
1788                    records,
1789                    cursor: SourceCursor {
1790                        last_seen: now,
1791                        revision: 1,
1792                    },
1793                })
1794            }
1795            fn reported_record_count(&self, _: &SamplerConfig) -> Result<u128, SamplerError> {
1796                Ok(8)
1797            }
1798        }
1799
1800        let config = SamplerConfig {
1801            batch_size: 2,
1802            text_recipes: vec![TextRecipe {
1803                name: "anchor".into(),
1804                selector: Selector::Role(SectionRole::Anchor),
1805                weight: 1.0,
1806                instruction: None,
1807            }],
1808            split: SplitRatios {
1809                train: 1.0,
1810                validation: 0.0,
1811                test: 0.0,
1812            },
1813            allowed_splits: vec![SplitLabel::Train],
1814            // Small enough that the cache slides on every advance, changing
1815            // the record pool each batch.  Without this, all 40 records fit in
1816            // the cache permanently and the cross-batch text dedup would never
1817            // clear `emitted_texts`, causing early Exhausted after only a few
1818            // batches (only 8 unique texts across 40 records).
1819            ingestion_max_records: 4,
1820            ..SamplerConfig::default()
1821        };
1822        let store = Arc::new(DeterministicSplitStore::new(config.split, 99).unwrap());
1823        let sampler = TripletSampler::new(config, store);
1824
1825        for (i, count) in counts.iter().enumerate() {
1826            sampler
1827                .register_source(Box::new(Tracked {
1828                    id: format!("src_{i}"),
1829                    refresh_count: Arc::clone(count),
1830                }))
1831                .unwrap();
1832        }
1833        sampler
1834    }
1835
1836    #[test]
1837    fn sampler_unweighted_drain_distributes_evenly() {
1838        let counts: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
1839        let sampler = make_five_source_sampler(&counts);
1840
1841        sampler.next_text_batch(SplitLabel::Train).unwrap();
1842        for _ in 0..80 {
1843            sampler.next_text_batch(SplitLabel::Train).unwrap();
1844        }
1845
1846        let totals: Vec<usize> = counts.iter().map(|c| c.load(Ordering::SeqCst)).collect();
1847        let min = *totals.iter().min().unwrap();
1848        let max = *totals.iter().max().unwrap();
1849        assert!(
1850            max <= min + 1,
1851            "unweighted: all sources must refresh at roughly the same rate: {totals:?}"
1852        );
1853        assert!(
1854            min >= 4,
1855            "unweighted: each source should have refreshed at least 4 times: {totals:?}"
1856        );
1857    }
1858
1859    #[test]
1860    fn sampler_weighted_drain_with_equal_weights_distributes_evenly() {
1861        // Same as unweighted but goes through the public weighted API with equal
1862        // weights for all 5 sources — verifies the weighted drain also rotates
1863        // fairly via drain_start bias.
1864        let counts: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
1865        let sampler = make_five_source_sampler(&counts);
1866
1867        let weights: HashMap<String, f32> = (0..5).map(|i| (format!("src_{i}"), 1.0)).collect();
1868
1869        sampler
1870            .next_text_batch_with_weights(SplitLabel::Train, &weights)
1871            .unwrap();
1872        for _ in 0..80 {
1873            sampler
1874                .next_text_batch_with_weights(SplitLabel::Train, &weights)
1875                .unwrap();
1876        }
1877
1878        let totals: Vec<usize> = counts.iter().map(|c| c.load(Ordering::SeqCst)).collect();
1879        let min = *totals.iter().min().unwrap();
1880        let max = *totals.iter().max().unwrap();
1881        assert!(
1882            max <= min + 1,
1883            "weighted (equal): all sources must refresh at roughly the same rate: {totals:?}"
1884        );
1885        assert!(
1886            min >= 4,
1887            "weighted (equal): each source should have refreshed at least 4 times: {totals:?}"
1888        );
1889    }
1890
1891    #[test]
1892    fn sampler_unweighted_and_weighted_match_distribution() {
1893        // Verify both paths produce similar refresh distributions.
1894        let uc: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
1895        let wc: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
1896        let usampler = make_five_source_sampler(&uc);
1897        let wsampler = make_five_source_sampler(&wc);
1898
1899        let weights: HashMap<String, f32> = (0..5).map(|i| (format!("src_{i}"), 1.0)).collect();
1900
1901        usampler.next_text_batch(SplitLabel::Train).unwrap();
1902        wsampler
1903            .next_text_batch_with_weights(SplitLabel::Train, &weights)
1904            .unwrap();
1905        for _ in 0..80 {
1906            usampler.next_text_batch(SplitLabel::Train).unwrap();
1907            wsampler
1908                .next_text_batch_with_weights(SplitLabel::Train, &weights)
1909                .unwrap();
1910        }
1911
1912        let ut: Vec<usize> = uc.iter().map(|c| c.load(Ordering::SeqCst)).collect();
1913        let wt: Vec<usize> = wc.iter().map(|c| c.load(Ordering::SeqCst)).collect();
1914        let umax = *ut.iter().max().unwrap();
1915        let umin = *ut.iter().min().unwrap();
1916        let wmax = *wt.iter().max().unwrap();
1917        let wmin = *wt.iter().min().unwrap();
1918        assert!(umax <= umin + 1, "unweighted: {ut:?}");
1919        assert!(wmax <= wmin + 1, "weighted equal: {wt:?}");
1920    }
1921
1922    #[test]
1923    fn sampler_weighted_drain_with_unequal_weights_respects_ratios() {
1924        // Source 3 gets weight 2.0, others get 1.0.  The weighted
1925        // proportional-fair drain must give source 3 proportionally more
1926        // refresh cycles than any weight-1.0 source.
1927        let counts: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
1928        let sampler = make_five_source_sampler(&counts);
1929
1930        let mut weights = HashMap::new();
1931        for i in 0..5 {
1932            weights.insert(format!("src_{i}"), if i == 3 { 2.0f32 } else { 1.0 });
1933        }
1934
1935        sampler
1936            .next_text_batch_with_weights(SplitLabel::Train, &weights)
1937            .unwrap();
1938        for _ in 0..200 {
1939            sampler
1940                .next_text_batch_with_weights(SplitLabel::Train, &weights)
1941                .unwrap();
1942        }
1943
1944        let totals: Vec<usize> = counts.iter().map(|c| c.load(Ordering::SeqCst)).collect();
1945        // The highest-weight source (src_3, w=2.0) must have strictly more
1946        // refreshes than EVERY weight-1.0 source.
1947        // Note: src_4 (index 4, w=1.0) gets 12 refreshes while the other
1948        // weight-1.0 sources get 6-7.  This happens because the proportional-
1949        // fair scheduler uses drain_start as a rotating tie-breaker — being
1950        // positioned right after the high-weight source (src_3) gives src_4
1951        // extra wins when src_3's buffer empties and refills more frequently.
1952        assert!(
1953            totals
1954                .iter()
1955                .enumerate()
1956                .all(|(i, &t)| i == 3 || t < totals[3]),
1957            "src_3 (w=2.0) must outpace all w=1.0 sources (totals: {totals:?})"
1958        );
1959        // Lock in the exact deterministic distribution.
1960        assert_eq!(
1961            totals,
1962            vec![6, 6, 6, 26, 11],
1963            "unequal-weights: unexpected refresh distribution"
1964        );
1965    }
1966
1967    #[test]
1968    fn register_source_rejects_reserved_id_pattern() {
1969        use crate::constants::splits::is_reserved_source_id;
1970
1971        let mut manager = IngestionManager::new(4, SamplerConfig::default());
1972
1973        // Verify the utility function catches common patterns
1974        assert!(is_reserved_source_id("__step__"));
1975        assert!(is_reserved_source_id("__meta__"));
1976        assert!(is_reserved_source_id("__anything__"));
1977        assert!(is_reserved_source_id("__x__"));
1978        assert!(!is_reserved_source_id(""));
1979        assert!(!is_reserved_source_id("__"));
1980        assert!(!is_reserved_source_id("___"));
1981        assert!(!is_reserved_source_id("normal_source"));
1982        assert!(!is_reserved_source_id("_prefix_suffix_"));
1983        assert!(!is_reserved_source_id("__unclosed"));
1984        assert!(!is_reserved_source_id("unopened__"));
1985
1986        // Registering with a `__*__` id should fail
1987        let result = manager.register_source(Box::new(ScriptedSource::new(
1988            "__reserved__",
1989            Arc::new(AtomicUsize::new(0)),
1990            vec![],
1991        )));
1992        assert!(
1993            result.is_err(),
1994            "register_source should reject reserved source id"
1995        );
1996        let err = result.unwrap_err();
1997        assert!(
1998            matches!(&err, SamplerError::ReservedSourceId(id) if id == "__reserved__"),
1999            "expected ReservedSourceId error, got: {err}"
2000        );
2001
2002        // Verify source was NOT registered (still zero sources)
2003        assert!(!manager.has_sources());
2004
2005        // Normal source IDs still work
2006        manager
2007            .register_source(Box::new(ScriptedSource::new(
2008                "valid_source",
2009                Arc::new(AtomicUsize::new(0)),
2010                vec![],
2011            )))
2012            .unwrap();
2013        assert!(manager.has_sources());
2014    }
2015}