Skip to main content

triplets_core/
ingestion.rs

1use crate::config::SamplerConfig;
2use crate::data::DataRecord;
3use crate::errors::SamplerError;
4use crate::hash::derive_epoch_seed;
5use crate::source::{DataSource, SourceCursor, SourceSnapshot};
6use crate::types::{RecordId, SourceId};
7use chrono::Utc;
8use indexmap::IndexMap;
9use std::collections::HashMap;
10use std::collections::VecDeque;
11use std::sync::{Arc, Condvar, Mutex, RwLock};
12use std::thread;
13use std::time::Duration;
14use tracing::debug;
15
16/// Thread-safe in-memory cache of ingested records keyed by record id.
17#[derive(Clone)]
18pub struct RecordCache {
19    inner: Arc<RwLock<RecordCacheInner>>,
20    notifier: Arc<(Mutex<CacheStats>, Condvar)>,
21}
22
23/// Internal mutable cache storage behind `RecordCache` locks.
24struct RecordCacheInner {
25    records: IndexMap<RecordId, CachedRecord>,
26    order: VecDeque<RecordId>,
27    max_records: usize,
28    next_version: u64,
29}
30
31/// Internal cache entry plus monotonic version marker.
32struct CachedRecord {
33    record: DataRecord,
34    version: u64,
35}
36
37/// Internal ingest notification counters.
38#[derive(Default)]
39struct CacheStats {
40    ingests: u64,
41}
42
43impl RecordCache {
44    /// Create a cache capped to at most `max_records` live records.
45    pub fn new(max_records: usize) -> Self {
46        Self {
47            inner: Arc::new(RwLock::new(RecordCacheInner {
48                records: IndexMap::new(),
49                order: VecDeque::new(),
50                max_records,
51                next_version: 0,
52            })),
53            notifier: Arc::new((Mutex::new(CacheStats::default()), Condvar::new())),
54        }
55    }
56
57    /// Ingest a batch of records, replacing existing entries by id.
58    pub fn ingest<I>(&self, records: I)
59    where
60        I: IntoIterator<Item = DataRecord>,
61    {
62        let mut batch: Vec<DataRecord> = records.into_iter().collect();
63        if batch.is_empty() {
64            return;
65        }
66        let mut inner = self.inner.write().expect("record cache poisoned");
67        inner.ingest_batch(&mut batch);
68        drop(inner);
69        let (lock, cvar) = &*self.notifier;
70        let mut stats = lock.lock().expect("record cache stats poisoned");
71        stats.ingests = stats.ingests.saturating_add(1);
72        cvar.notify_all();
73    }
74
75    /// Remove all cached records.
76    pub fn clear(&self) {
77        let mut inner = self.inner.write().expect("record cache poisoned");
78        inner.records.clear();
79        inner.order.clear();
80    }
81
82    /// Return a cloned snapshot of current cached records.
83    pub fn snapshot(&self) -> Vec<DataRecord> {
84        let inner = self.inner.read().expect("record cache poisoned");
85        inner
86            .records
87            .values()
88            .map(|entry| entry.record.clone())
89            .collect()
90    }
91
92    /// Return the number of completed ingest operations.
93    pub fn ingest_count(&self) -> u64 {
94        let (lock, _) = &*self.notifier;
95        lock.lock().expect("record cache stats poisoned").ingests
96    }
97
98    /// Wait until ingest count exceeds `last_seen`, or until timeout elapses.
99    pub fn wait_for_ingest(&self, last_seen: u64, timeout: Duration) -> u64 {
100        let (lock, cvar) = &*self.notifier;
101        let mut stats = lock.lock().expect("record cache stats poisoned");
102        while stats.ingests <= last_seen {
103            let result = cvar
104                .wait_timeout(stats, timeout)
105                .expect("record cache stats poisoned");
106            stats = result.0;
107            if result.1.timed_out() {
108                break;
109            }
110        }
111        stats.ingests
112    }
113
114    /// Wait indefinitely until ingest count exceeds `last_seen`.
115    pub fn wait_for_ingest_blocking(&self, last_seen: u64) -> u64 {
116        let (lock, cvar) = &*self.notifier;
117        let mut stats = lock.lock().expect("record cache stats poisoned");
118        while stats.ingests <= last_seen {
119            stats = cvar.wait(stats).expect("record cache stats poisoned");
120        }
121        stats.ingests
122    }
123
124    /// Returns `true` when the cache has no records.
125    pub fn is_empty(&self) -> bool {
126        let inner = self.inner.read().expect("record cache poisoned");
127        inner.records.is_empty()
128    }
129
130    /// Return the number of records currently cached.
131    pub fn len(&self) -> usize {
132        let inner = self.inner.read().expect("record cache poisoned");
133        inner.records.len()
134    }
135}
136
137impl RecordCacheInner {
138    fn ingest_batch(&mut self, records: &mut Vec<DataRecord>) {
139        for record in records.drain(..) {
140            self.next_version = self.next_version.saturating_add(1);
141            let record_id = record.id.clone();
142            if self.records.contains_key(&record_id) {
143                if let Some(entry) = self.records.get_mut(&record_id) {
144                    entry.record = record;
145                    entry.version = self.next_version;
146                }
147                Self::refresh_order(&mut self.order, &record_id);
148                self.order.push_back(record_id);
149            } else {
150                self.order.push_back(record_id.clone());
151                self.records.insert(
152                    record_id,
153                    CachedRecord {
154                        record,
155                        version: self.next_version,
156                    },
157                );
158            }
159            self.enforce_limit();
160        }
161    }
162
163    fn enforce_limit(&mut self) {
164        if self.max_records == 0 {
165            self.records.clear();
166            self.order.clear();
167            return;
168        }
169        while self.records.len() > self.max_records {
170            if let Some(oldest) = self.order.pop_front() {
171                self.records.swap_remove(&oldest);
172            } else {
173                break;
174            }
175        }
176    }
177
178    fn refresh_order(order: &mut VecDeque<RecordId>, id: &RecordId) {
179        if order.is_empty() {
180            return;
181        }
182        if let Some(pos) = order.iter().position(|existing| existing == id) {
183            order.remove(pos);
184        }
185    }
186}
187
188/// Coordinates on-demand source refresh and per-source-cache population.
189pub struct IngestionManager {
190    sources: Vec<SourceState>,
191    max_records: usize,
192    sampler_config: SamplerConfig,
193    /// Current source epoch used to vary per-source permutation seeds across epochs.
194    source_epoch: u64,
195    /// Monotonic generation incremented whenever at least one source is refreshed.
196    source_refresh_generation: u64,
197    /// Source ids refreshed during the most recent `refresh_all_internal` call.
198    ///
199    /// This is updated even when cache ingest does not change, and is cleared when
200    /// no source refresh occurs in that cycle.
201    last_refreshed_sources: Vec<SourceId>,
202    /// Rotating start index for the round-robin buffer drain.  Instead of always
203    /// draining from source 0 (which starves high-index sources of refresh
204    /// opportunities), each cycle begins at this position and advances by one.
205    /// Over N cycles every source is drained first exactly once.
206    drain_start: usize,
207}
208
/// Telemetry captured from the latest refresh attempt of a single source.
#[derive(Clone, Debug, Default)]
pub struct SourceRefreshStats {
    /// Wall-clock duration of the latest refresh, in milliseconds.
    pub last_refresh_ms: u128,
    /// Records returned by the latest refresh (`0` when it failed).
    pub last_record_count: usize,
    /// Records per second achieved by the latest refresh (`0.0` when it failed
    /// or took no measurable time).
    pub last_records_per_sec: f64,
    /// Error message from the latest refresh; `None` after a success or before
    /// any refresh has run.
    pub last_error: Option<String>,
    /// Running total of failed refreshes for this source.
    pub error_count: u64,
}
223
224impl IngestionManager {
225    /// Create a new ingestion manager that ingests on demand.
226    pub fn new(max_records: usize, sampler_config: SamplerConfig) -> Self {
227        Self {
228            sources: Vec::new(),
229            max_records,
230            sampler_config,
231            source_epoch: 0,
232            source_refresh_generation: 0,
233            last_refreshed_sources: Vec::new(),
234            drain_start: 0,
235        }
236    }
237
238    /// Return a monotonic generation for source refresh cycles.
239    pub fn source_refresh_generation(&self) -> u64 {
240        self.source_refresh_generation
241    }
242
243    /// Return source ids refreshed by the most recent refresh cycle.
244    pub fn last_refreshed_sources(&self) -> &[SourceId] {
245        &self.last_refreshed_sources
246    }
247
248    /// Update the current source epoch.
249    ///
250    /// Only changes the epoch value so subsequent `refresh` calls pass
251    /// `seed ^ epoch` to sources, producing a different permutation.
252    /// Stream cursors are intentionally NOT reset here — the cursor is a raw
253    /// I/O offset into the source's stream and must continue advancing so
254    /// every record is eventually fetched (resetting it would repeat the
255    /// leading slice of the source on every epoch boundary).
256    pub(crate) fn set_source_epoch(&mut self, epoch: u64) {
257        self.source_epoch = epoch;
258    }
259
260    /// Return the current source epoch.
261    #[cfg(test)]
262    pub fn source_epoch(&self) -> u64 {
263        self.source_epoch
264    }
265
266    /// Reset all raw source stream cursors and drain per-source buffers.
267    ///
268    /// Use this only when starting a deterministic replay from a specific
269    /// epoch (e.g. explicit `set_epoch` calls). A clean-start reset ensures
270    /// the new permutation begins at position 0 of the permuted index space.
271    pub(crate) fn reset_stream_cursors(&mut self) {
272        for state in &mut self.sources {
273            state.cursor = None;
274            state.buffer.clear();
275            state.cache.clear();
276        }
277    }
278
279    /// Register a source for on-demand ingestion.
280    pub fn register_source(&mut self, source: Box<dyn DataSource + 'static>) {
281        let cache = RecordCache::new(self.max_records);
282        self.sources.push(SourceState {
283            source,
284            cursor: None,
285            buffer: VecDeque::new(),
286            cache,
287            stats: SourceRefreshStats::default(),
288        });
289    }
290
291    /// Load persisted per-source stream cursors.
292    pub fn load_cursors(&mut self, cursors: &[(SourceId, u64)]) {
293        if cursors.is_empty() {
294            return;
295        }
296        let mut map = std::collections::HashMap::with_capacity(cursors.len());
297        for (id, revision) in cursors {
298            map.insert(id.as_str(), *revision);
299        }
300        for state in &mut self.sources {
301            if let Some(revision) = map.get(state.source.id()) {
302                state.cursor = Some(SourceCursor {
303                    last_seen: Utc::now(),
304                    revision: *revision,
305                });
306            }
307        }
308    }
309
310    /// Snapshot current per-source stream cursors.
311    pub fn snapshot_cursors(&self) -> Vec<(SourceId, u64)> {
312        let mut out = Vec::new();
313        for state in &self.sources {
314            if let Some(cursor) = state.cursor.as_ref() {
315                out.push((state.source.id().to_string(), cursor.revision));
316            }
317        }
318        out
319    }
320
321    /// Return latest refresh telemetry for each registered source.
322    pub fn source_refresh_stats(&self) -> Vec<(SourceId, SourceRefreshStats)> {
323        self.sources
324            .iter()
325            .map(|state| (state.source.id().to_string(), state.stats.clone()))
326            .collect()
327    }
328
329    /// Return a flat snapshot of every record currently in all per-source caches.
330    ///
331    /// Records are cloned in source order; the `source` field is guaranteed
332    /// to be set (it is normalised in `refresh_all_internal`).
333    pub fn all_records_snapshot(&self) -> Vec<DataRecord> {
334        self.sources
335            .iter()
336            .flat_map(|s| s.cache.snapshot())
337            .collect()
338    }
339
340    /// Returns `true` when ALL per-source caches are empty.
341    pub fn all_caches_empty(&self) -> bool {
342        self.sources.iter().all(|s| s.cache.is_empty())
343    }
344
345    /// Returns the total number of records across all per-source caches.
346    pub fn all_records_len(&self) -> usize {
347        self.sources.iter().map(|s| s.cache.len()).sum()
348    }
349
350    /// Returns the sum of ingest counts across all per-source caches.
351    ///
352    /// Used as a monotonic proxy to detect whether any cache has been updated
353    /// since the last sync.
354    pub fn total_ingest_count(&self) -> u64 {
355        self.sources.iter().map(|s| s.cache.ingest_count()).sum()
356    }
357
358    /// Refresh all registered sources once.
359    pub fn refresh_all(&mut self) {
360        self.refresh_all_internal(false, None, None);
361    }
362
363    /// Advance the ingestion window by ingesting `step` new records.
364    pub fn advance(&mut self, step: usize) {
365        self.refresh_all_internal(false, Some(step), None);
366    }
367
368    /// Advance the ingestion window by ingesting `step` new records with weights.
369    ///
370    /// Returns `Err(SamplerError::InvalidWeight)` if `weights` contains an unregistered
371    /// source ID or a negative value.
372    pub fn advance_with_weights(
373        &mut self,
374        step: usize,
375        weights: &HashMap<SourceId, f32>,
376    ) -> Result<(), SamplerError> {
377        self.validate_weights(weights)?;
378        self.refresh_all_internal(false, Some(step), Some(weights));
379        Ok(())
380    }
381
382    /// Force refresh all registered sources, discarding buffered records.
383    pub fn force_refresh_all(&mut self) {
384        self.refresh_all_internal(true, None, None);
385    }
386
387    /// Refresh all registered sources once with per-call source weights.
388    ///
389    /// Returns `Err(SamplerError::InvalidWeight)` if `weights` contains an unregistered
390    /// source ID or a negative value.
391    pub fn refresh_all_with_weights(
392        &mut self,
393        weights: &HashMap<SourceId, f32>,
394    ) -> Result<(), SamplerError> {
395        self.validate_weights(weights)?;
396        self.refresh_all_internal(false, None, Some(weights));
397        Ok(())
398    }
399
400    /// Force refresh all registered sources with per-call source weights.
401    ///
402    /// Returns `Err(SamplerError::InvalidWeight)` if `weights` contains an unregistered
403    /// source ID or a negative value.
404    pub fn force_refresh_all_with_weights(
405        &mut self,
406        weights: &HashMap<SourceId, f32>,
407    ) -> Result<(), SamplerError> {
408        self.validate_weights(weights)?;
409        self.refresh_all_internal(true, None, Some(weights));
410        Ok(())
411    }
412
413    fn validate_weights(&self, weights: &HashMap<SourceId, f32>) -> Result<(), SamplerError> {
414        let known_ids: std::collections::HashSet<&str> =
415            self.sources.iter().map(|s| s.source.id()).collect();
416        for (id, &w) in weights {
417            if !known_ids.contains(id.as_str()) {
418                return Err(SamplerError::InvalidWeight {
419                    source_id: id.clone(),
420                    reason: "source is not registered".to_string(),
421                });
422            }
423            if w < 0.0 {
424                return Err(SamplerError::InvalidWeight {
425                    source_id: id.clone(),
426                    reason: format!("weight {w} is negative"),
427                });
428            }
429        }
430        Ok(())
431    }
432
433    /// Rebuild the shared cache by round-robin draining per-source buffers.
434    ///
435    /// When `force_refresh` is false, each source only refreshes when its buffer
436    /// is empty; when true, all buffers are cleared and all sources refresh.
437    /// If `step` is provided, performs a rolling update of `step` records (no clear).
438    /// If `step` is None, clears the cache and fills up to max capacity.
439    fn refresh_all_internal(
440        &mut self,
441        force_refresh: bool,
442        step: Option<usize>,
443        weights: Option<&HashMap<SourceId, f32>>,
444    ) {
445        self.last_refreshed_sources.clear();
446        let mut refresh_plan = Vec::new();
447        for (idx, state) in self.sources.iter_mut().enumerate() {
448            if force_refresh {
449                state.buffer.clear();
450            }
451            if force_refresh || state.buffer.is_empty() {
452                refresh_plan.push((idx, state.cursor.clone()));
453            }
454        }
455
456        if !refresh_plan.is_empty() {
457            self.source_refresh_generation = self.source_refresh_generation.saturating_add(1);
458            self.last_refreshed_sources = refresh_plan
459                .iter()
460                .map(|(idx, _)| self.sources[*idx].source.id().to_string())
461                .collect();
462            let mut results: Vec<
463                Option<(Result<SourceSnapshot, SamplerError>, std::time::Duration)>,
464            > = Vec::with_capacity(self.sources.len());
465            results.resize_with(self.sources.len(), || None);
466            let fetch_limit = self.max_records;
467            let sampler_config = self.sampler_config.clone();
468            thread::scope(|scope| {
469                let mut handles = Vec::with_capacity(refresh_plan.len());
470                for (idx, cursor) in &refresh_plan {
471                    let source = &self.sources[*idx].source;
472                    let cursor = cursor.clone();
473                    let sampler_config = sampler_config.clone();
474                    let source_epoch = self.source_epoch;
475                    handles.push((
476                        *idx,
477                        scope.spawn(move || {
478                            let start = std::time::Instant::now();
479                            // XOR the source epoch into the seed so each epoch
480                            // produces a distinct permutation within the source.
481                            let epoch_config = SamplerConfig {
482                                seed: derive_epoch_seed(sampler_config.seed, source_epoch),
483                                ..sampler_config
484                            };
485                            let result =
486                                source.refresh(&epoch_config, cursor.as_ref(), Some(fetch_limit));
487                            let elapsed = start.elapsed();
488                            (result, elapsed)
489                        }),
490                    ));
491                }
492                for (idx, handle) in handles {
493                    let result = match handle.join() {
494                        Ok((result, elapsed)) => {
495                            debug!(
496                                source_id = %self.sources[idx].source.id(),
497                                refresh_ms = elapsed.as_millis(),
498                                "source refresh completed"
499                            );
500                            (result, elapsed)
501                        }
502                        Err(_) => (
503                            Err(SamplerError::SourceUnavailable {
504                                source_id: self.sources[idx].source.id().to_string(),
505                                reason: "source refresh thread panicked".into(),
506                            }),
507                            std::time::Duration::from_secs(0),
508                        ),
509                    };
510                    results[idx] = Some(result);
511                }
512            });
513
514            for (idx, _) in refresh_plan {
515                let Some((result, elapsed)) = results[idx].take() else {
516                    continue;
517                };
518                match result {
519                    Ok(snapshot) => {
520                        let SourceSnapshot {
521                            records,
522                            cursor: next_cursor,
523                        } = snapshot;
524                        let record_count = records.len();
525                        let seconds = elapsed.as_secs_f64();
526                        let per_sec = if seconds > 0.0 {
527                            (record_count as f64) / seconds
528                        } else {
529                            0.0
530                        };
531                        let stats = &mut self.sources[idx].stats;
532                        stats.last_refresh_ms = elapsed.as_millis();
533                        stats.last_record_count = record_count;
534                        stats.last_records_per_sec = per_sec;
535                        stats.last_error = None;
536                        debug!(
537                            source_id = %self.sources[idx].source.id(),
538                            record_count,
539                            refresh_ms = elapsed.as_millis(),
540                            records_per_sec = per_sec,
541                            "source refresh ingested records"
542                        );
543                        let source_id = self.sources[idx].source.id().to_string();
544                        let normalized = records
545                            .into_iter()
546                            .map(|mut record| {
547                                record.source = source_id.clone();
548                                record
549                            })
550                            .collect::<Vec<_>>();
551                        self.sources[idx].buffer.extend(normalized);
552                        self.sources[idx].cursor = Some(next_cursor);
553                    }
554                    Err(err) => {
555                        let stats = &mut self.sources[idx].stats;
556                        stats.last_refresh_ms = elapsed.as_millis();
557                        stats.last_record_count = 0;
558                        stats.last_records_per_sec = 0.0;
559                        stats.last_error = Some(err.to_string());
560                        stats.error_count = stats.error_count.saturating_add(1);
561                        eprintln!(
562                            "[data_sampler] source '{}' refresh failed: {}",
563                            self.sources[idx].source.id(),
564                            err
565                        );
566                    }
567                }
568            }
569        }
570
571        // On a full refresh (step=None) clear every per-source cache so that the
572        // snapshot always reflects the newest window, matching the previous
573        // shared-cache clear semantics.
574        if step.is_none() {
575            for state in self.sources.iter_mut() {
576                state.cache.clear();
577            }
578        }
579        if self.max_records == 0 {
580            return;
581        }
582        let target_limit = step.unwrap_or(self.max_records);
583        if let Some(weights) = weights {
584            self.weighted_drain_into_caches(target_limit, weights);
585        } else {
586            // Fair round-robin drain: start from `drain_start` instead of 0 so
587            // that the drain cursor rotates across cycles.  This prevents head
588            // sources (low indices) from always draining faster than tail sources,
589            // which was starving tail sources of refresh opportunities.
590            let n = self.sources.len();
591            if n > 0 {
592                let mut per_source: Vec<Vec<DataRecord>> = vec![Vec::new(); n];
593                let mut total_drained = 0;
594                let mut any_remaining = true;
595                while total_drained < target_limit && any_remaining {
596                    any_remaining = false;
597                    for offset in 0..n {
598                        if total_drained >= target_limit {
599                            break;
600                        }
601                        let idx = (self.drain_start + offset) % n;
602                        if let Some(record) = self.sources[idx].buffer.pop_front() {
603                            per_source[idx].push(record);
604                            total_drained += 1;
605                            any_remaining = true;
606                        }
607                    }
608                }
609                // Advance the drain cursor so the next cycle starts from a different
610                // position.  Only advance when at least one record was drained, so a
611                // burst of drain-noop cycles on an empty source list doesn't rotate.
612                if total_drained > 0 {
613                    self.drain_start = (self.drain_start + 1) % n;
614                }
615                for (idx, batch) in per_source.into_iter().enumerate() {
616                    if !batch.is_empty() {
617                        self.sources[idx].cache.ingest(batch);
618                    }
619                }
620            }
621        }
622    }
623
624    fn weighted_drain_into_caches(&mut self, limit: usize, weights: &HashMap<SourceId, f32>) {
625        let len = self.sources.len();
626        if len == 0 {
627            return;
628        }
629        let mut weight_values = Vec::with_capacity(len);
630        let mut any_positive = false;
631        for state in &self.sources {
632            let weight = weights.get(state.source.id()).copied().unwrap_or(1.0);
633            if weight > 0.0 {
634                any_positive = true;
635            }
636            weight_values.push(weight);
637        }
638        if !any_positive {
639            weight_values.fill(1.0);
640        }
641
642        let mut current = vec![0.0f32; len];
643        let mut per_source: Vec<Vec<DataRecord>> = vec![Vec::new(); len];
644        let mut total = 0;
645        while total < limit {
646            let mut total_weight = 0.0f32;
647            for (idx, weight) in weight_values.iter().copied().enumerate().take(len) {
648                if weight <= 0.0 {
649                    continue;
650                }
651                if self.sources[idx].buffer.is_empty() {
652                    continue;
653                }
654                total_weight += weight;
655            }
656            if total_weight == 0.0 {
657                break;
658            }
659
660            let mut best_idx = None;
661            let mut best_score = f32::MIN;
662            // Rotating tie-breaker: when scores are equal, prefer the source
663            // just PAST drain_start in rotation order (i.e. the source whose
664            // turn is coming up in the round-robin).
665            let closer_to_start = |a: usize, b: usize| -> bool {
666                let da = (a + len - self.drain_start) % len;
667                let db = (b + len - self.drain_start) % len;
668                da < db
669            };
670            for idx in 0..len {
671                if weight_values[idx] <= 0.0 {
672                    continue;
673                }
674                if self.sources[idx].buffer.is_empty() {
675                    continue;
676                }
677                current[idx] += weight_values[idx];
678                let is_better = if current[idx] > best_score {
679                    true
680                } else if current[idx] == best_score {
681                    closer_to_start(idx, best_idx.unwrap_or(0))
682                } else {
683                    false
684                };
685                if is_better {
686                    best_score = current[idx];
687                    best_idx = Some(idx);
688                }
689            }
690
691            let idx = match best_idx {
692                Some(idx) => idx,
693                None => break,
694            };
695            current[idx] -= total_weight;
696            if let Some(record) = self.sources[idx].buffer.pop_front() {
697                per_source[idx].push(record);
698                total += 1;
699            }
700        }
701
702        if total > 0 && len > 0 {
703            self.drain_start = (self.drain_start + 1) % len;
704        }
705
706        for (idx, batch) in per_source.into_iter().enumerate() {
707            if !batch.is_empty() {
708                self.sources[idx].cache.ingest(batch);
709            }
710        }
711    }
712
713    /// Returns `true` when at least one source is registered.
714    pub fn has_sources(&self) -> bool {
715        !self.sources.is_empty()
716    }
717}
718
719/// Per-source ingestion runtime state.
720struct SourceState {
721    source: Box<dyn DataSource + 'static>,
722    cursor: Option<SourceCursor>,
723    buffer: VecDeque<DataRecord>,
724    /// Per-source LRU record cache capped at `max_records`.
725    cache: RecordCache,
726    stats: SourceRefreshStats,
727}
728
729#[cfg(test)]
730mod tests {
731    use super::*;
732    use crate::TripletSampler;
733    use crate::config::{Selector, TextRecipe, TripletRecipe};
734    use crate::data::{QualityScore, RecordSection, SectionRole};
735    use crate::sampler::Sampler;
736    use crate::splits::{DeterministicSplitStore, SplitLabel, SplitRatios};
737    use chrono::Utc;
738    use std::collections::HashMap;
739    use std::collections::VecDeque;
740    use std::sync::atomic::{AtomicUsize, Ordering};
741    use std::sync::{Arc, Mutex};
742
743    fn make_record(id: &str, source: &str) -> DataRecord {
744        let now = Utc::now();
745        DataRecord {
746            id: id.to_string(),
747            source: source.to_string(),
748            created_at: now,
749            updated_at: now,
750            quality: QualityScore { trust: 1.0 },
751            taxonomy: Vec::new(),
752            sections: vec![RecordSection {
753                role: SectionRole::Anchor,
754                heading: None,
755                text: id.to_string(),
756                sentences: vec![id.to_string()],
757            }],
758            meta_prefix: None,
759        }
760    }
761
762    struct ScriptedSource {
763        id: String,
764        refreshes: Arc<AtomicUsize>,
765        script: Arc<Mutex<VecDeque<Result<SourceSnapshot, SamplerError>>>>,
766    }
767
768    impl ScriptedSource {
769        fn new(
770            id: &str,
771            refreshes: Arc<AtomicUsize>,
772            script: Vec<Result<SourceSnapshot, SamplerError>>,
773        ) -> Self {
774            Self {
775                id: id.to_string(),
776                refreshes,
777                script: Arc::new(Mutex::new(script.into_iter().collect())),
778            }
779        }
780    }
781
782    impl DataSource for ScriptedSource {
783        fn id(&self) -> &str {
784            &self.id
785        }
786
787        fn refresh(
788            &self,
789            _config: &SamplerConfig,
790            _cursor: Option<&SourceCursor>,
791            _limit: Option<usize>,
792        ) -> Result<SourceSnapshot, SamplerError> {
793            self.refreshes.fetch_add(1, Ordering::SeqCst);
794            let mut guard = self.script.lock().expect("script lock poisoned");
795            guard.pop_front().unwrap_or_else(|| {
796                Ok(SourceSnapshot {
797                    records: Vec::new(),
798                    cursor: SourceCursor {
799                        last_seen: Utc::now(),
800                        revision: 0,
801                    },
802                })
803            })
804        }
805
806        fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
807            Ok(0)
808        }
809
810        fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
811            Vec::new()
812        }
813    }
814
    /// Test source whose `refresh` always panics; used to check that the manager
    /// records a panicking refresh as an error instead of propagating it.
    struct PanicSource {
        /// Source id reported through `DataSource::id`.
        id: String,
    }
818
819    impl DataSource for PanicSource {
820        fn id(&self) -> &str {
821            &self.id
822        }
823
824        fn refresh(
825            &self,
826            _config: &SamplerConfig,
827            _cursor: Option<&SourceCursor>,
828            _limit: Option<usize>,
829        ) -> Result<SourceSnapshot, SamplerError> {
830            panic!("panic source refresh")
831        }
832
833        fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
834            Ok(0)
835        }
836    }
837
838    #[test]
839    fn record_cache_waits_len_and_clear_paths_are_covered() {
840        let cache = RecordCache::new(2);
841        assert!(cache.is_empty());
842        assert_eq!(cache.len(), 0);
843        assert_eq!(cache.ingest_count(), 0);
844
845        cache.ingest(Vec::<DataRecord>::new());
846        assert_eq!(cache.wait_for_ingest(0, Duration::from_millis(1)), 0);
847
848        let cache_for_waiter = cache.clone();
849        let handle = std::thread::spawn(move || cache_for_waiter.wait_for_ingest_blocking(0));
850        std::thread::sleep(Duration::from_millis(5));
851        cache.ingest(vec![make_record("r1", "s")]);
852        assert_eq!(handle.join().unwrap(), 1);
853        assert_eq!(cache.ingest_count(), 1);
854
855        cache.ingest(vec![make_record("r2", "s"), make_record("r3", "s")]);
856        assert_eq!(cache.len(), 2);
857        let ids: Vec<String> = cache
858            .snapshot()
859            .into_iter()
860            .map(|record| record.id)
861            .collect();
862        assert!(ids.contains(&"r2".to_string()));
863        assert!(ids.contains(&"r3".to_string()));
864
865        cache.clear();
866        assert!(cache.is_empty());
867    }
868
869    #[test]
870    fn record_cache_zero_limit_discards_everything() {
871        let cache = RecordCache::new(0);
872        cache.ingest(vec![make_record("r1", "s")]);
873        assert!(cache.is_empty());
874        assert_eq!(cache.len(), 0);
875    }
876
877    #[test]
878    fn manager_loads_and_snapshots_cursors_and_reports_has_sources() {
879        let mut manager = IngestionManager::new(4, SamplerConfig::default());
880        assert!(!manager.has_sources());
881        manager.load_cursors(&[]);
882
883        let refreshes = Arc::new(AtomicUsize::new(0));
884        manager.register_source(Box::new(ScriptedSource::new(
885            "cursor_source",
886            refreshes,
887            vec![Ok(SourceSnapshot {
888                records: vec![make_record("id_1", "original_source")],
889                cursor: SourceCursor {
890                    last_seen: Utc::now(),
891                    revision: 33,
892                },
893            })],
894        )));
895        assert!(manager.has_sources());
896
897        manager.load_cursors(&[("cursor_source".to_string(), 7)]);
898        let cursors = manager.snapshot_cursors();
899        assert_eq!(cursors, vec![("cursor_source".to_string(), 7)]);
900
901        manager.refresh_all();
902        let updated = manager.snapshot_cursors();
903        assert_eq!(updated, vec![("cursor_source".to_string(), 33)]);
904        let records = manager.all_records_snapshot();
905        assert_eq!(records.len(), 1);
906        assert_eq!(records[0].source, "cursor_source");
907    }
908
909    #[test]
910    fn advance_uses_buffer_before_refreshing_again() {
911        let refreshes = Arc::new(AtomicUsize::new(0));
912        let mut manager = IngestionManager::new(5, SamplerConfig::default());
913        manager.register_source(Box::new(ScriptedSource::new(
914            "buffered",
915            refreshes.clone(),
916            vec![Ok(SourceSnapshot {
917                records: vec![
918                    make_record("a", "x"),
919                    make_record("b", "x"),
920                    make_record("c", "x"),
921                ],
922                cursor: SourceCursor {
923                    last_seen: Utc::now(),
924                    revision: 1,
925                },
926            })],
927        )));
928
929        manager.advance(1);
930        assert_eq!(refreshes.load(Ordering::SeqCst), 1);
931        assert_eq!(manager.all_records_len(), 1);
932
933        manager.advance(1);
934        assert_eq!(refreshes.load(Ordering::SeqCst), 1);
935        assert_eq!(manager.all_records_len(), 2);
936    }
937
938    #[test]
939    fn force_refresh_clears_buffer_and_fetches_again() {
940        let refreshes = Arc::new(AtomicUsize::new(0));
941        let mut manager = IngestionManager::new(4, SamplerConfig::default());
942        manager.register_source(Box::new(ScriptedSource::new(
943            "force",
944            refreshes.clone(),
945            vec![
946                Ok(SourceSnapshot {
947                    records: vec![
948                        make_record("r1", "x"),
949                        make_record("r2", "x"),
950                        make_record("r3", "x"),
951                    ],
952                    cursor: SourceCursor {
953                        last_seen: Utc::now(),
954                        revision: 10,
955                    },
956                }),
957                Ok(SourceSnapshot {
958                    records: vec![make_record("r4", "x")],
959                    cursor: SourceCursor {
960                        last_seen: Utc::now(),
961                        revision: 11,
962                    },
963                }),
964            ],
965        )));
966
967        manager.advance(1);
968        assert_eq!(manager.all_records_len(), 1);
969        assert_eq!(refreshes.load(Ordering::SeqCst), 1);
970
971        manager.force_refresh_all();
972        assert_eq!(refreshes.load(Ordering::SeqCst), 2);
973        let records = manager.all_records_snapshot();
974        assert_eq!(records.len(), 1);
975        assert_eq!(records[0].id, "r4");
976    }
977
978    #[test]
979    fn weighted_drain_respects_zero_and_fallback_weights() {
980        let mut manager = IngestionManager::new(6, SamplerConfig::default());
981        manager.register_source(Box::new(ScriptedSource::new(
982            "a",
983            Arc::new(AtomicUsize::new(0)),
984            vec![Ok(SourceSnapshot {
985                records: vec![make_record("a1", "a"), make_record("a2", "a")],
986                cursor: SourceCursor {
987                    last_seen: Utc::now(),
988                    revision: 1,
989                },
990            })],
991        )));
992        manager.register_source(Box::new(ScriptedSource::new(
993            "b",
994            Arc::new(AtomicUsize::new(0)),
995            vec![Ok(SourceSnapshot {
996                records: vec![make_record("b1", "b"), make_record("b2", "b")],
997                cursor: SourceCursor {
998                    last_seen: Utc::now(),
999                    revision: 1,
1000                },
1001            })],
1002        )));
1003
1004        let mut only_b = HashMap::new();
1005        only_b.insert("a".to_string(), 0.0);
1006        only_b.insert("b".to_string(), 1.0);
1007        manager.refresh_all_with_weights(&only_b).unwrap();
1008        let ids: Vec<String> = manager
1009            .all_records_snapshot()
1010            .into_iter()
1011            .map(|record| record.id)
1012            .collect();
1013        assert!(ids.iter().all(|id| id.starts_with('b')));
1014
1015        let mut manager_fallback = IngestionManager::new(6, SamplerConfig::default());
1016        manager_fallback.register_source(Box::new(ScriptedSource::new(
1017            "a",
1018            Arc::new(AtomicUsize::new(0)),
1019            vec![Ok(SourceSnapshot {
1020                records: vec![make_record("a1", "a")],
1021                cursor: SourceCursor {
1022                    last_seen: Utc::now(),
1023                    revision: 2,
1024                },
1025            })],
1026        )));
1027        manager_fallback.register_source(Box::new(ScriptedSource::new(
1028            "b",
1029            Arc::new(AtomicUsize::new(0)),
1030            vec![Ok(SourceSnapshot {
1031                records: vec![make_record("b1", "b")],
1032                cursor: SourceCursor {
1033                    last_seen: Utc::now(),
1034                    revision: 2,
1035                },
1036            })],
1037        )));
1038
1039        let mut all_zero = HashMap::new();
1040        all_zero.insert("a".to_string(), 0.0);
1041        all_zero.insert("b".to_string(), 0.0);
1042        manager_fallback
1043            .refresh_all_with_weights(&all_zero)
1044            .unwrap();
1045        let ids: Vec<String> = manager_fallback
1046            .all_records_snapshot()
1047            .into_iter()
1048            .map(|record| record.id)
1049            .collect();
1050        assert!(ids.contains(&"a1".to_string()));
1051        assert!(ids.contains(&"b1".to_string()));
1052    }
1053
1054    #[test]
1055    fn refresh_errors_and_panics_update_source_stats() {
1056        let mut manager = IngestionManager::new(4, SamplerConfig::default());
1057        manager.register_source(Box::new(ScriptedSource::new(
1058            "err_source",
1059            Arc::new(AtomicUsize::new(0)),
1060            vec![Err(SamplerError::SourceUnavailable {
1061                source_id: "err_source".to_string(),
1062                reason: "boom".to_string(),
1063            })],
1064        )));
1065        manager.register_source(Box::new(PanicSource {
1066            id: "panic_source".to_string(),
1067        }));
1068
1069        manager.refresh_all();
1070        let stats = manager.source_refresh_stats();
1071        let err_stats = stats
1072            .iter()
1073            .find(|(source, _)| source == "err_source")
1074            .map(|(_, stats)| stats)
1075            .unwrap();
1076        assert_eq!(err_stats.error_count, 1);
1077        assert!(
1078            err_stats
1079                .last_error
1080                .as_ref()
1081                .is_some_and(|msg| msg.contains("boom"))
1082        );
1083
1084        let panic_stats = stats
1085            .iter()
1086            .find(|(source, _)| source == "panic_source")
1087            .map(|(_, stats)| stats)
1088            .unwrap();
1089        assert_eq!(panic_stats.error_count, 1);
1090        assert!(
1091            panic_stats
1092                .last_error
1093                .as_ref()
1094                .is_some_and(|msg| msg.contains("panicked"))
1095        );
1096    }
1097
1098    #[test]
1099    fn force_refresh_with_weights_path_is_exercised() {
1100        let mut manager = IngestionManager::new(3, SamplerConfig::default());
1101        manager.register_source(Box::new(ScriptedSource::new(
1102            "w",
1103            Arc::new(AtomicUsize::new(0)),
1104            vec![Ok(SourceSnapshot {
1105                records: vec![make_record("w1", "w")],
1106                cursor: SourceCursor {
1107                    last_seen: Utc::now(),
1108                    revision: 3,
1109                },
1110            })],
1111        )));
1112
1113        let mut weights = HashMap::new();
1114        weights.insert("w".to_string(), 1.0);
1115        manager.force_refresh_all_with_weights(&weights).unwrap();
1116        assert_eq!(manager.all_records_len(), 1);
1117    }
1118
1119    #[test]
1120    fn advance_with_weights_rejects_unknown_source() {
1121        let mut manager = IngestionManager::new(4, SamplerConfig::default());
1122        manager.register_source(Box::new(ScriptedSource::new(
1123            "known",
1124            Arc::new(AtomicUsize::new(0)),
1125            vec![],
1126        )));
1127
1128        let mut weights = HashMap::new();
1129        weights.insert("known".to_string(), 1.0);
1130        weights.insert("unknown".to_string(), 0.5);
1131
1132        let err = manager.advance_with_weights(1, &weights).unwrap_err();
1133        assert!(
1134            matches!(err, SamplerError::InvalidWeight { ref source_id, .. } if source_id == "unknown"),
1135            "expected InvalidWeight for unknown source, got {err:?}"
1136        );
1137    }
1138
1139    #[test]
1140    fn refresh_all_with_weights_rejects_negative_weight() {
1141        let mut manager = IngestionManager::new(4, SamplerConfig::default());
1142        manager.register_source(Box::new(ScriptedSource::new(
1143            "src",
1144            Arc::new(AtomicUsize::new(0)),
1145            vec![],
1146        )));
1147
1148        let mut weights = HashMap::new();
1149        weights.insert("src".to_string(), -1.0);
1150
1151        let err = manager.refresh_all_with_weights(&weights).unwrap_err();
1152        assert!(
1153            matches!(err, SamplerError::InvalidWeight { ref source_id, .. } if source_id == "src"),
1154            "expected InvalidWeight for negative weight, got {err:?}"
1155        );
1156    }
1157
1158    #[test]
1159    fn force_refresh_all_with_weights_rejects_unknown_source() {
1160        let mut manager = IngestionManager::new(4, SamplerConfig::default());
1161        manager.register_source(Box::new(ScriptedSource::new(
1162            "real",
1163            Arc::new(AtomicUsize::new(0)),
1164            vec![],
1165        )));
1166
1167        let mut weights = HashMap::new();
1168        weights.insert("ghost".to_string(), 1.0);
1169
1170        let err = manager
1171            .force_refresh_all_with_weights(&weights)
1172            .unwrap_err();
1173        assert!(
1174            matches!(err, SamplerError::InvalidWeight { ref source_id, .. } if source_id == "ghost"),
1175            "expected InvalidWeight for unknown source, got {err:?}"
1176        );
1177    }
1178
    /// A source that records the `config.seed` value it receives on each `refresh()` call.
    struct SeedCapturingSource {
        /// Source id reported through `DataSource::id`.
        id: String,
        /// Shared log of every seed seen by `refresh`, in call order, for tests to inspect.
        received_seeds: Arc<Mutex<Vec<u64>>>,
    }
1184
1185    impl SeedCapturingSource {
1186        fn new(id: &str, received_seeds: Arc<Mutex<Vec<u64>>>) -> Self {
1187            Self {
1188                id: id.to_string(),
1189                received_seeds,
1190            }
1191        }
1192    }
1193
1194    impl DataSource for SeedCapturingSource {
1195        fn id(&self) -> &str {
1196            &self.id
1197        }
1198
1199        fn refresh(
1200            &self,
1201            config: &SamplerConfig,
1202            _cursor: Option<&SourceCursor>,
1203            _limit: Option<usize>,
1204        ) -> Result<SourceSnapshot, SamplerError> {
1205            self.received_seeds
1206                .lock()
1207                .expect("seed lock poisoned")
1208                .push(config.seed);
1209            Ok(SourceSnapshot {
1210                records: Vec::new(),
1211                cursor: SourceCursor {
1212                    last_seen: Utc::now(),
1213                    revision: 0,
1214                },
1215            })
1216        }
1217
1218        fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
1219            Ok(0)
1220        }
1221
1222        fn default_triplet_recipes(&self) -> Vec<crate::config::TripletRecipe> {
1223            Vec::new()
1224        }
1225    }
1226
    /// Checks that the manager passes `derive_epoch_seed(config.seed, epoch)` into
    /// `DataSource::refresh`, and that epochs 0 and 1 yield different seeds.
    #[test]
    fn source_epoch_xor_changes_seed_received_by_source() {
        // Verify that derive_epoch_seed(base, epoch) is actually threaded through to
        // the source's refresh() call, and that different epochs produce different seeds.
        let base_seed = 0xDEAD_BEEF_u64;
        let config = SamplerConfig {
            seed: base_seed,
            ..SamplerConfig::default()
        };

        // Separate logs per manager so each epoch's seed is observed in isolation.
        let seeds_epoch0 = Arc::new(Mutex::new(Vec::<u64>::new()));
        let seeds_epoch1 = Arc::new(Mutex::new(Vec::<u64>::new()));

        // --- epoch 0 ---
        let mut manager = IngestionManager::new(4, config.clone());
        manager.register_source(Box::new(SeedCapturingSource::new(
            "src",
            Arc::clone(&seeds_epoch0),
        )));
        // source_epoch defaults to 0; refresh_all passes derive_epoch_seed(base, 0)
        manager.refresh_all();

        // --- epoch 1 ---
        let mut manager2 = IngestionManager::new(4, config.clone());
        manager2.register_source(Box::new(SeedCapturingSource::new(
            "src",
            Arc::clone(&seeds_epoch1),
        )));
        manager2.set_source_epoch(1);
        manager2.refresh_all();

        let received0 = seeds_epoch0.lock().unwrap();
        let received1 = seeds_epoch1.lock().unwrap();

        assert!(!received0.is_empty(), "epoch-0 source was never refreshed");
        assert!(!received1.is_empty(), "epoch-1 source was never refreshed");

        // Only the first refresh of each manager is compared.
        let seed_at_epoch0 = received0[0];
        let seed_at_epoch1 = received1[0];

        // The seeds must differ — the epoch must actually change the derived seed.
        assert_ne!(
            seed_at_epoch0, seed_at_epoch1,
            "epoch 0 and epoch 1 both produced seed {seed_at_epoch0:#x}; \
             derive_epoch_seed is not reaching the source"
        );

        // They must match the expected derive_epoch_seed values.
        assert_eq!(
            seed_at_epoch0,
            derive_epoch_seed(base_seed, 0),
            "epoch-0 seed mismatch"
        );
        assert_eq!(
            seed_at_epoch1,
            derive_epoch_seed(base_seed, 1),
            "epoch-1 seed mismatch"
        );
    }
1286
1287    #[test]
1288    fn scripted_and_panic_sources_cover_default_trait_paths() {
1289        let refreshes = Arc::new(AtomicUsize::new(0));
1290        let scripted = ScriptedSource::new("scripted", refreshes, vec![]);
1291
1292        // Empty script falls back to an empty snapshot.
1293        let snapshot = scripted
1294            .refresh(&SamplerConfig::default(), None, None)
1295            .expect("fallback snapshot");
1296        assert!(snapshot.records.is_empty());
1297        assert_eq!(snapshot.cursor.revision, 0);
1298
1299        assert_eq!(
1300            scripted
1301                .reported_record_count(&SamplerConfig::default())
1302                .expect("record count"),
1303            0
1304        );
1305        assert!(scripted.default_triplet_recipes().is_empty());
1306
1307        let panic_source = PanicSource {
1308            id: "panic_count".to_string(),
1309        };
1310        assert_eq!(
1311            panic_source
1312                .reported_record_count(&SamplerConfig::default())
1313                .expect("record count"),
1314            0
1315        );
1316    }
1317
1318    #[test]
1319    fn seed_capturing_source_trait_defaults_are_exercised() {
1320        let source = SeedCapturingSource::new("seed_defaults", Arc::new(Mutex::new(Vec::new())));
1321        assert_eq!(
1322            source
1323                .reported_record_count(&SamplerConfig::default())
1324                .expect("record count"),
1325            0
1326        );
1327        assert!(source.default_triplet_recipes().is_empty());
1328    }
1329
1330    #[test]
1331    fn refresh_paths_handle_zero_capacity_and_no_sources() {
1332        let mut manager = IngestionManager::new(0, SamplerConfig::default());
1333        manager.register_source(Box::new(ScriptedSource::new(
1334            "zero_capacity",
1335            Arc::new(AtomicUsize::new(0)),
1336            vec![Ok(SourceSnapshot {
1337                records: vec![make_record("r1", "zero_capacity")],
1338                cursor: SourceCursor {
1339                    last_seen: Utc::now(),
1340                    revision: 1,
1341                },
1342            })],
1343        )));
1344        manager.refresh_all();
1345        assert!(manager.all_caches_empty());
1346
1347        // Weighted refresh with no sources should be a no-op.
1348        let mut empty_manager = IngestionManager::new(4, SamplerConfig::default());
1349        let empty_weights = HashMap::new();
1350        empty_manager
1351            .refresh_all_with_weights(&empty_weights)
1352            .expect("no sources should not error");
1353        assert!(empty_manager.all_caches_empty());
1354    }
1355
1356    #[test]
1357    fn drain_start_rotates_fairly_across_sources() {
1358        // Create 3 sources, each with 10 records in their buffer after refresh.
1359        // The fair round-robin should ensure all 3 drain at the same rate
1360        // over multiple advance cycles.
1361        struct FairSource {
1362            id: String,
1363            refresh_count: Arc<AtomicUsize>,
1364        }
1365
1366        impl DataSource for FairSource {
1367            fn id(&self) -> &str {
1368                &self.id
1369            }
1370            fn refresh(
1371                &self,
1372                _config: &SamplerConfig,
1373                _cursor: Option<&SourceCursor>,
1374                _limit: Option<usize>,
1375            ) -> Result<SourceSnapshot, SamplerError> {
1376                self.refresh_count.fetch_add(1, Ordering::SeqCst);
1377                Ok(SourceSnapshot {
1378                    records: (0..10)
1379                        .map(|i| make_record(&format!("r{i}"), &self.id))
1380                        .collect(),
1381                    cursor: SourceCursor {
1382                        last_seen: Utc::now(),
1383                        revision: 1,
1384                    },
1385                })
1386            }
1387            fn reported_record_count(&self, _: &SamplerConfig) -> Result<u128, SamplerError> {
1388                Ok(10)
1389            }
1390        }
1391
1392        let counts = (
1393            Arc::new(AtomicUsize::new(0)),
1394            Arc::new(AtomicUsize::new(0)),
1395            Arc::new(AtomicUsize::new(0)),
1396        );
1397
1398        let mut manager = IngestionManager::new(30, SamplerConfig::default());
1399        manager.register_source(Box::new(FairSource {
1400            id: "src_0".to_string(),
1401            refresh_count: Arc::clone(&counts.0),
1402        }));
1403        manager.register_source(Box::new(FairSource {
1404            id: "src_1".to_string(),
1405            refresh_count: Arc::clone(&counts.1),
1406        }));
1407        manager.register_source(Box::new(FairSource {
1408            id: "src_2".to_string(),
1409            refresh_count: Arc::clone(&counts.2),
1410        }));
1411
1412        // First refresh_all fills all buffers.
1413        manager.refresh_all();
1414        // All 3 refreshed once.
1415        assert_eq!(counts.0.load(Ordering::SeqCst), 1);
1416        assert_eq!(counts.1.load(Ordering::SeqCst), 1);
1417        assert_eq!(counts.2.load(Ordering::SeqCst), 1);
1418
1419        // Each advance(1) drains 1 record from 1 source, rotating via
1420        // drain_start.  With 3 sources, after 3 advances each source
1421        // loses 1 record.  After 30 advances each source loses 10 records
1422        // and triggers a refresh.  Run 33 advances and check all 3 refreshed
1423        // roughly the same number of times.
1424        for _ in 0..33 {
1425            manager.advance(1);
1426        }
1427
1428        let r0 = counts.0.load(Ordering::SeqCst);
1429        let r1 = counts.1.load(Ordering::SeqCst);
1430        let r2 = counts.2.load(Ordering::SeqCst);
1431
1432        // Each had 10 records after initial refresh_all.
1433        // 10 records / (1 drained per 3 cycles) = 30 cycles to drain each buffer.
1434        // After 33 more advances each buffer emptied ~1 time and re-filled, so
1435        // each source should have refreshed ~2 times total (initial + 1 drain).
1436        // The exact count can vary by 1 due to timing, but all 3 must be within
1437        // 1 of each other — no source can be starved.
1438        let min = r0.min(r1).min(r2);
1439        let max = r0.max(r1).max(r2);
1440        assert!(
1441            max <= min + 1,
1442            "sources should refresh at roughly the same rate: got {r0}/{r1}/{r2}"
1443        );
1444    }
1445
    #[test]
    fn direct_drain_start_rotates_fairly_with_batch_2_of_5() {
        // Direct IngestionManager test with batch_size=2, 5 sources.
        // Isolates the drain_start rotation from the sampler pipeline.
        // Each source yields 8 source-prefixed records per refresh (no sections;
        // only refresh counts are observed) and counts how often it is refreshed.
        struct SimpleSource {
            id: String,
            refresh_count: Arc<AtomicUsize>,
        }

        impl DataSource for SimpleSource {
            fn id(&self) -> &str {
                &self.id
            }
            fn refresh(
                &self,
                _: &SamplerConfig,
                _: Option<&SourceCursor>,
                _: Option<usize>,
            ) -> Result<SourceSnapshot, SamplerError> {
                self.refresh_count.fetch_add(1, Ordering::SeqCst);
                let now = Utc::now();
                let records: Vec<DataRecord> = (0..8)
                    .map(|i| DataRecord {
                        id: format!("{}_r{i}", self.id),
                        source: self.id.clone(),
                        created_at: now,
                        updated_at: now,
                        quality: QualityScore { trust: 1.0 },
                        taxonomy: Vec::new(),
                        sections: Vec::new(),
                        meta_prefix: None,
                    })
                    .collect();
                Ok(SourceSnapshot {
                    records,
                    cursor: SourceCursor {
                        last_seen: now,
                        revision: 1,
                    },
                })
            }
            fn reported_record_count(&self, _: &SamplerConfig) -> Result<u128, SamplerError> {
                Ok(8)
            }
        }

        // Capacity 40 = 5 sources x 8 records per refresh.
        let counts: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
        let mut manager = IngestionManager::new(40, SamplerConfig::default());
        for (i, count) in counts.iter().enumerate() {
            manager.register_source(Box::new(SimpleSource {
                id: format!("src_{i}"),
                refresh_count: Arc::clone(count),
            }));
        }

        // The initial refresh fetches every source exactly once.
        manager.refresh_all();
        for c in &counts {
            assert_eq!(c.load(Ordering::SeqCst), 1);
        }

        // 80 advances of 2 records each.
        for _ in 0..80 {
            manager.advance(2);
        }

        let totals: Vec<usize> = counts.iter().map(|c| c.load(Ordering::SeqCst)).collect();
        // Round-robin drain with 5 sources, 8 records each, batch_size=2, 80
        // advance calls.  Most sources refresh 5×; sources 2-4 get one extra
        // cycle (6) due to the rotating drain_start tie-breaker positioning.
        // This pins the exact distribution, so any change to rotation order shows up here.
        assert_eq!(
            totals,
            vec![5, 5, 6, 6, 6],
            "direct manager: unexpected refresh distribution"
        );
    }
1520
1521    /// Helper: create 5 sources with 8 records each and a sampler configured for
1522    /// text batches with batch_size=2.  Returns the sampler and refresh counters.
1523    fn make_five_source_sampler(
1524        counts: &[Arc<AtomicUsize>],
1525    ) -> TripletSampler<DeterministicSplitStore> {
1526        struct Tracked {
1527            id: String,
1528            refresh_count: Arc<AtomicUsize>,
1529        }
1530        impl DataSource for Tracked {
1531            fn id(&self) -> &str {
1532                &self.id
1533            }
1534            fn refresh(
1535                &self,
1536                _: &SamplerConfig,
1537                _: Option<&SourceCursor>,
1538                _: Option<usize>,
1539            ) -> Result<SourceSnapshot, SamplerError> {
1540                self.refresh_count.fetch_add(1, Ordering::SeqCst);
1541                let now = Utc::now();
1542                let records: Vec<DataRecord> = (0..8)
1543                    .map(|i| DataRecord {
1544                        id: format!("{}_r{i}", self.id),
1545                        source: self.id.clone(),
1546                        created_at: now,
1547                        updated_at: now,
1548                        quality: QualityScore { trust: 1.0 },
1549                        taxonomy: Vec::new(),
1550                        sections: vec![RecordSection {
1551                            role: SectionRole::Anchor,
1552                            heading: None,
1553                            text: format!("x{i}"),
1554                            sentences: vec![format!("x{i}")],
1555                        }],
1556                        meta_prefix: None,
1557                    })
1558                    .collect();
1559                Ok(SourceSnapshot {
1560                    records,
1561                    cursor: SourceCursor {
1562                        last_seen: now,
1563                        revision: 1,
1564                    },
1565                })
1566            }
1567            fn reported_record_count(&self, _: &SamplerConfig) -> Result<u128, SamplerError> {
1568                Ok(8)
1569            }
1570        }
1571
1572        let config = SamplerConfig {
1573            batch_size: 2,
1574            text_recipes: vec![TextRecipe {
1575                name: "anchor".into(),
1576                selector: Selector::Role(SectionRole::Anchor),
1577                weight: 1.0,
1578                instruction: None,
1579            }],
1580            split: SplitRatios {
1581                train: 1.0,
1582                validation: 0.0,
1583                test: 0.0,
1584            },
1585            allowed_splits: vec![SplitLabel::Train],
1586            ingestion_max_records: 40,
1587            ..SamplerConfig::default()
1588        };
1589        let store = Arc::new(DeterministicSplitStore::new(config.split, 99).unwrap());
1590        let sampler = TripletSampler::new(config, store);
1591
1592        for (i, count) in counts.iter().enumerate() {
1593            sampler.register_source(Box::new(Tracked {
1594                id: format!("src_{i}"),
1595                refresh_count: Arc::clone(count),
1596            }));
1597        }
1598        sampler
1599    }
1600
1601    #[test]
1602    fn sampler_unweighted_drain_distributes_evenly() {
1603        let counts: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
1604        let sampler = make_five_source_sampler(&counts);
1605
1606        sampler.next_text_batch(SplitLabel::Train).unwrap();
1607        for _ in 0..80 {
1608            sampler.next_text_batch(SplitLabel::Train).unwrap();
1609        }
1610
1611        let totals: Vec<usize> = counts.iter().map(|c| c.load(Ordering::SeqCst)).collect();
1612        let min = *totals.iter().min().unwrap();
1613        let max = *totals.iter().max().unwrap();
1614        assert!(
1615            max <= min + 1,
1616            "unweighted: all sources must refresh at roughly the same rate: {totals:?}"
1617        );
1618        assert!(
1619            min >= 4,
1620            "unweighted: each source should have refreshed at least 4 times: {totals:?}"
1621        );
1622    }
1623
1624    #[test]
1625    fn sampler_weighted_drain_with_equal_weights_distributes_evenly() {
1626        // Same as unweighted but goes through the public weighted API with equal
1627        // weights for all 5 sources — verifies the weighted drain also rotates
1628        // fairly via drain_start bias.
1629        let counts: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
1630        let sampler = make_five_source_sampler(&counts);
1631
1632        let weights: HashMap<String, f32> = (0..5).map(|i| (format!("src_{i}"), 1.0)).collect();
1633
1634        sampler
1635            .next_text_batch_with_weights(SplitLabel::Train, &weights)
1636            .unwrap();
1637        for _ in 0..80 {
1638            sampler
1639                .next_text_batch_with_weights(SplitLabel::Train, &weights)
1640                .unwrap();
1641        }
1642
1643        let totals: Vec<usize> = counts.iter().map(|c| c.load(Ordering::SeqCst)).collect();
1644        let min = *totals.iter().min().unwrap();
1645        let max = *totals.iter().max().unwrap();
1646        assert!(
1647            max <= min + 1,
1648            "weighted (equal): all sources must refresh at roughly the same rate: {totals:?}"
1649        );
1650        assert!(
1651            min >= 4,
1652            "weighted (equal): each source should have refreshed at least 4 times: {totals:?}"
1653        );
1654    }
1655
1656    #[test]
1657    fn sampler_unweighted_and_weighted_match_distribution() {
1658        // Verify both paths produce similar refresh distributions.
1659        let uc: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
1660        let wc: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
1661        let usampler = make_five_source_sampler(&uc);
1662        let wsampler = make_five_source_sampler(&wc);
1663
1664        let weights: HashMap<String, f32> = (0..5).map(|i| (format!("src_{i}"), 1.0)).collect();
1665
1666        usampler.next_text_batch(SplitLabel::Train).unwrap();
1667        wsampler
1668            .next_text_batch_with_weights(SplitLabel::Train, &weights)
1669            .unwrap();
1670        for _ in 0..80 {
1671            usampler.next_text_batch(SplitLabel::Train).unwrap();
1672            wsampler
1673                .next_text_batch_with_weights(SplitLabel::Train, &weights)
1674                .unwrap();
1675        }
1676
1677        let ut: Vec<usize> = uc.iter().map(|c| c.load(Ordering::SeqCst)).collect();
1678        let wt: Vec<usize> = wc.iter().map(|c| c.load(Ordering::SeqCst)).collect();
1679        let umax = *ut.iter().max().unwrap();
1680        let umin = *ut.iter().min().unwrap();
1681        let wmax = *wt.iter().max().unwrap();
1682        let wmin = *wt.iter().min().unwrap();
1683        assert!(umax <= umin + 1, "unweighted: {ut:?}");
1684        assert!(wmax <= wmin + 1, "weighted equal: {wt:?}");
1685    }
1686
1687    #[test]
1688    fn sampler_weighted_drain_with_unequal_weights_respects_ratios() {
1689        // Source 3 gets weight 2.0, others get 1.0.  The weighted
1690        // proportional-fair drain must give source 3 proportionally more
1691        // refresh cycles than any weight-1.0 source.
1692        let counts: Vec<Arc<AtomicUsize>> = (0..5).map(|_| Arc::new(AtomicUsize::new(0))).collect();
1693        let sampler = make_five_source_sampler(&counts);
1694
1695        let mut weights = HashMap::new();
1696        for i in 0..5 {
1697            weights.insert(format!("src_{i}"), if i == 3 { 2.0f32 } else { 1.0 });
1698        }
1699
1700        sampler
1701            .next_text_batch_with_weights(SplitLabel::Train, &weights)
1702            .unwrap();
1703        for _ in 0..200 {
1704            sampler
1705                .next_text_batch_with_weights(SplitLabel::Train, &weights)
1706                .unwrap();
1707        }
1708
1709        let totals: Vec<usize> = counts.iter().map(|c| c.load(Ordering::SeqCst)).collect();
1710        // The highest-weight source (src_3, w=2.0) must have strictly more
1711        // refreshes than EVERY weight-1.0 source.
1712        // Note: src_4 (index 4, w=1.0) gets 12 refreshes while the other
1713        // weight-1.0 sources get 6-7.  This happens because the proportional-
1714        // fair scheduler uses drain_start as a rotating tie-breaker — being
1715        // positioned right after the high-weight source (src_3) gives src_4
1716        // extra wins when src_3's buffer empties and refills more frequently.
1717        assert!(
1718            totals
1719                .iter()
1720                .enumerate()
1721                .all(|(i, &t)| i == 3 || t < totals[3]),
1722            "src_3 (w=2.0) must outpace all w=1.0 sources (totals: {totals:?})"
1723        );
1724        // Lock in the exact deterministic distribution.
1725        assert_eq!(
1726            totals,
1727            vec![6, 7, 7, 26, 12],
1728            "unequal-weights: unexpected refresh distribution"
1729        );
1730    }
1731}