1#![doc = include_str!("../README.md")]
2#![warn(missing_docs)]
3
4use std::collections::HashMap;
5use std::error::Error;
6use std::path::PathBuf;
7use std::sync::Arc;
8use std::sync::Once;
9#[cfg(test)]
10use std::sync::OnceLock;
11use std::time::Instant;
12#[cfg(test)]
13use tempfile::TempDir;
14
15use cache_manager::CacheRoot;
16use clap::{Parser, ValueEnum, error::ErrorKind};
17
18use triplets_core::config::{ChunkingStrategy, SamplerConfig, TripletRecipe};
19use triplets_core::constants::cache::{MULTI_SOURCE_DEMO_GROUP, MULTI_SOURCE_DEMO_STORE_FILENAME};
20use triplets_core::data::ChunkView;
21use triplets_core::heuristics::{
22 CapacityTotals, EFFECTIVE_NEGATIVES_PER_ANCHOR, EFFECTIVE_POSITIVES_PER_ANCHOR,
23 estimate_source_split_capacity_from_counts, format_replay_factor, format_u128_with_commas,
24 resolve_text_recipes_for_source, split_counts_for_total,
25};
26use triplets_core::metrics::{chunk_proximity_score, source_skew, window_chunk_distance};
27use triplets_core::sampler::chunk_weight;
28use triplets_core::source::DataSource;
29use triplets_core::splits::{FileSplitStore, SplitLabel, SplitRatios, SplitStore};
30use triplets_core::{
31 RecordChunk, SampleBatch, Sampler, SamplerError, SourceId, TextBatch, TextRecipe, TripletBatch,
32 TripletSampler,
33};
34
/// Boxed, dynamically dispatched data source, as returned by the caller-supplied
/// `build_sources` closures of the example entry points in this module.
type DynSource = Box<dyn DataSource + 'static>;

/// Extended metrics recorded per sampled triplet in `--batches` mode, in push order:
/// the two anchor/positive values from `lexical_similarity_scores`, the two anchor/negative
/// values from the same function, then `chunk_proximity_score(anchor, positive)`.
/// NOTE(review): the local names `pj`/`pc` suggest Jaccard and cosine similarity — confirm.
#[cfg(feature = "extended-metrics")]
type MetricEntry = (f32, f32, f32, f32, f32);

/// Extended metrics grouped by source, keyed by `extract_source(anchor.record_id)`.
#[cfg(feature = "extended-metrics")]
type SourceMetricsMap = HashMap<String, Vec<MetricEntry>>;
42
43fn managed_demo_split_store_path() -> Result<PathBuf, String> {
44 #[cfg(test)]
45 {
46 static TEST_CACHE_ROOT: OnceLock<TempDir> = OnceLock::new();
47 let root = TEST_CACHE_ROOT.get_or_init(|| {
48 TempDir::new().expect("failed to create test demo split-store cache root")
49 });
50 let cache_root = CacheRoot::from_root(root.path());
51 let group = PathBuf::from(MULTI_SOURCE_DEMO_GROUP);
52 let dir = cache_root.ensure_group(&group).map_err(|err| {
53 format!(
54 "failed creating managed demo cache group '{}': {err}",
55 group.display()
56 )
57 })?;
58 Ok(dir.join(MULTI_SOURCE_DEMO_STORE_FILENAME))
59 }
60
61 #[cfg(not(test))]
62 {
63 let cache_root = CacheRoot::from_discovery()
64 .map_err(|err| format!("failed discovering managed cache root: {err}"))?;
65 let group = PathBuf::from(MULTI_SOURCE_DEMO_GROUP);
66 let dir = cache_root.ensure_group(&group).map_err(|err| {
67 format!(
68 "failed creating managed demo cache group '{}': {err}",
69 group.display()
70 )
71 })?;
72 Ok(dir.join(MULTI_SOURCE_DEMO_STORE_FILENAME))
73 }
74}
75
76fn init_example_tracing() {
77 static INIT: Once = Once::new();
78 INIT.call_once(|| {
79 let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
80 .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("triplets=info"));
81 let _ = tracing_subscriber::fmt()
82 .with_env_filter(env_filter)
83 .try_init();
84 });
85}
86
87#[derive(Debug, Clone, Copy, ValueEnum)]
88enum SplitArg {
90 Train,
91 Validation,
92 Test,
93}
94
95impl From<SplitArg> for SplitLabel {
96 fn from(value: SplitArg) -> Self {
97 match value {
98 SplitArg::Train => SplitLabel::Train,
99 SplitArg::Validation => SplitLabel::Validation,
100 SplitArg::Test => SplitLabel::Test,
101 }
102 }
103}
104
// Command-line arguments for `run_estimate_capacity`.
//
// Plain `//` comments are used on purpose: clap's derive reads `///` doc comments as
// help/about text, and these notes must not change the CLI output.
#[derive(Debug, Parser)]
#[command(
    name = "estimate_capacity",
    disable_help_subcommand = true,
    about = "Metadata-only capacity estimation",
    long_about = "Estimate record, pair, triplet, and text-sample capacity using source-reported counts only (no data refresh).",
    after_help = "Source roots are optional and resolved in order by explicit arg, environment variables, then project defaults."
)]
struct EstimateCapacityCli {
    // Copied into `SamplerConfig::seed` and echoed as "split seed" in the report.
    #[arg(
        long,
        default_value_t = 99,
        help = "Deterministic seed used for split allocation"
    )]
    seed: u64,
    // Validated by `parse_split_ratios_arg`; used as `SamplerConfig::split` and to divide
    // each source's reported record count across train/validation/test.
    #[arg(
        long = "split-ratios",
        value_name = "TRAIN,VALIDATION,TEST",
        value_parser = parse_split_ratios_arg,
        default_value = "0.8,0.1,0.1",
        help = "Comma-separated split ratios that must sum to 1.0"
    )]
    split: SplitRatios,
    // Repeatable `--source-root`; handed unchanged to the caller's `resolve_roots` closure.
    #[arg(
        long = "source-root",
        value_name = "PATH",
        help = "Optional source root override, repeat as needed in source order"
    )]
    source_roots: Vec<String>,
}
136
// Command-line arguments for `run_multi_source_demo`.
//
// Mode flags are not mutually exclusive at the clap level; `run_multi_source_demo` checks
// them in this order and runs only the first that is set: `--pair-batch`, `--text-recipes`,
// `--list-text-recipes`, `--batches N`, otherwise a single triplet batch.
//
// Plain `//` comments are used on purpose: clap's derive reads `///` doc comments as
// help/about text, and these notes must not change the CLI output.
#[derive(Debug, Parser)]
#[command(
    name = "multi_source_demo",
    disable_help_subcommand = true,
    about = "Run sampled batches from multiple sources",
    long_about = "Sample triplet, pair, or text batches from multiple sources and persist split/epoch state.",
    after_help = "Source roots are optional and resolved in order by explicit arg, environment variables, then project defaults."
)]
struct MultiSourceDemoCli {
    // Sample one text batch via `next_text_batch`.
    #[arg(
        long = "text-recipes",
        help = "Emit a text batch instead of a triplet batch"
    )]
    show_text_samples: bool,
    // Sample one pair batch via `next_pair_batch` (highest-priority mode flag).
    #[arg(
        long = "pair-batch",
        help = "Emit a pair batch instead of a triplet batch"
    )]
    show_pair_samples: bool,
    // List `sampler.text_recipes()` without sampling.
    #[arg(
        long = "list-text-recipes",
        help = "Print registered text recipes and exit"
    )]
    list_text_recipes: bool,
    // Copied into `SamplerConfig::batch_size`; must be > 0 (`parse_batch_size`).
    #[arg(
        long = "batch-size",
        default_value_t = 4,
        value_parser = parse_batch_size,
        help = "Batch size used for sampling"
    )]
    batch_size: usize,
    // Copied into `SamplerConfig::ingestion_max_records`; defaults to the library default.
    #[arg(
        long = "ingestion-max-records",
        default_value_t = default_ingestion_max_records(),
        value_parser = parse_ingestion_max_records,
        help = "Per-source ingestion buffer target used while refreshing records"
    )]
    ingestion_max_records: usize,
    // When absent, `SamplerConfig::default().seed` is kept.
    #[arg(long, help = "Optional deterministic seed override")]
    seed: Option<u64>,
    // When absent, the train split is sampled.
    #[arg(long, value_enum, help = "Target split to sample from")]
    split: Option<SplitArg>,
    // Repeatable `--source-root`; handed unchanged to the caller's `resolve_roots` closure.
    #[arg(
        long = "source-root",
        value_name = "PATH",
        help = "Optional source root override, repeat as needed in source order"
    )]
    source_roots: Vec<String>,
    // When absent, `managed_demo_split_store_path()` supplies a managed cache location.
    #[arg(
        long = "split-store-path",
        value_name = "SPLIT_STORE_PATH",
        help = "Optional explicit path for persisted split/epoch state file"
    )]
    split_store_path: Option<PathBuf>,
    // Removes the split-store file before the sampler opens it.
    #[arg(
        long = "reset",
        help = "Delete the persisted split/epoch state before sampling, restarting from epoch 0"
    )]
    reset: bool,
    // Benchmark mode: number of consecutive triplet batches; must be > 0.
    #[arg(
        long = "batches",
        value_name = "N",
        value_parser = parse_batch_count,
        help = "Run N triplet batches in succession, printing a timing line per batch and (with --features extended-metrics) a per-source similarity summary at the end"
    )]
    batches: Option<usize>,
}
210
/// Metadata-only view of one source, collected by `run_estimate_capacity`.
#[derive(Debug, Clone)]
struct SourceInventory {
    /// `DataSource::id()` of the source; used as the key in per-source maps.
    source_id: String,
    /// Count returned by `DataSource::reported_record_count` for the active config.
    reported_records: u128,
    /// Triplet recipes used for the estimate: `SamplerConfig::recipes` when non-empty,
    /// otherwise the source's `default_triplet_recipes()`.
    triplet_recipes: Vec<TripletRecipe>,
}
218
219pub fn run_estimate_capacity<R, Resolve, Build, I>(
224 args_iter: I,
225 resolve_roots: Resolve,
226 build_sources: Build,
227) -> Result<(), Box<dyn Error>>
228where
229 Resolve: FnOnce(Vec<String>) -> Result<R, Box<dyn Error>>,
230 Build: FnOnce(&R) -> Vec<DynSource>,
231 I: Iterator<Item = String>,
232{
233 init_example_tracing();
234
235 let Some(cli) = parse_cli::<EstimateCapacityCli, _>(
236 std::iter::once("estimate_capacity".to_string()).chain(args_iter),
237 )?
238 else {
239 return Ok(());
240 };
241
242 let roots = resolve_roots(cli.source_roots)?;
243
244 let config = SamplerConfig {
245 seed: cli.seed,
246 split: cli.split,
247 ..SamplerConfig::default()
248 };
249
250 let sources = build_sources(&roots);
251
252 let mut inventories = Vec::new();
253 for source in &sources {
254 let recipes = if config.recipes.is_empty() {
255 source.default_triplet_recipes()
256 } else {
257 config.recipes.clone()
258 };
259 let reported_records = source.reported_record_count(&config).map_err(|err| {
260 format!(
261 "source '{}' failed to report exact record count: {err}",
262 source.id()
263 )
264 })?;
265 inventories.push(SourceInventory {
266 source_id: source.id().to_string(),
267 reported_records,
268 triplet_recipes: recipes,
269 });
270 }
271
272 let mut per_source_split_counts: HashMap<(String, SplitLabel), u128> = HashMap::new();
273 let mut split_record_counts: HashMap<SplitLabel, u128> = HashMap::new();
274
275 for source in &inventories {
276 let counts = split_counts_for_total(source.reported_records, cli.split);
277 for (label, count) in counts {
278 per_source_split_counts.insert((source.source_id.clone(), label), count);
279 *split_record_counts.entry(label).or_insert(0) += count;
280 }
281 }
282
283 let mut totals_by_split: HashMap<SplitLabel, CapacityTotals> = HashMap::new();
284 let mut totals_by_source_and_split: HashMap<(String, SplitLabel), CapacityTotals> =
285 HashMap::new();
286
287 for split_label in [SplitLabel::Train, SplitLabel::Validation, SplitLabel::Test] {
288 let mut totals = CapacityTotals::default();
289
290 for source in &inventories {
291 let source_split_records = per_source_split_counts
292 .get(&(source.source_id.clone(), split_label))
293 .copied()
294 .unwrap_or(0);
295
296 let triplet_recipes = &source.triplet_recipes;
297 let text_recipes = resolve_text_recipes_for_source(&config, triplet_recipes);
298
299 let capacity = estimate_source_split_capacity_from_counts(
300 source_split_records,
301 triplet_recipes,
302 &text_recipes,
303 );
304
305 totals_by_source_and_split.insert((source.source_id.clone(), split_label), capacity);
306
307 totals.triplets += capacity.triplets;
308 totals.effective_triplets += capacity.effective_triplets;
309 totals.pairs += capacity.pairs;
310 totals.text_samples += capacity.text_samples;
311 }
312
313 totals_by_split.insert(split_label, totals);
314 }
315
316 let min_nonzero_records_by_split: HashMap<SplitLabel, u128> =
317 [SplitLabel::Train, SplitLabel::Validation, SplitLabel::Test]
318 .into_iter()
319 .map(|split_label| {
320 let min_nonzero = inventories
321 .iter()
322 .filter_map(|source| {
323 per_source_split_counts
324 .get(&(source.source_id.clone(), split_label))
325 .copied()
326 })
327 .filter(|&records| records > 0)
328 .min()
329 .unwrap_or(0);
330 (split_label, min_nonzero)
331 })
332 .collect();
333
334 let min_nonzero_records_all_splits = inventories
335 .iter()
336 .map(|source| source.reported_records)
337 .filter(|&records| records > 0)
338 .min()
339 .unwrap_or(0);
340
341 println!("=== capacity estimate (length-only) ===");
342 println!("mode: metadata-only (no source.refresh calls)");
343 println!("classification: heuristic approximation (not exact)");
344 println!("split seed: {}", cli.seed);
345 println!(
346 "split ratios: train={:.4}, validation={:.4}, test={:.4}",
347 cli.split.train, cli.split.validation, cli.split.test
348 );
349 println!();
350
351 println!("[SOURCES]");
352 for source in &inventories {
353 println!(
354 " {} => reported records: {}",
355 source.source_id,
356 format_u128_with_commas(source.reported_records)
357 );
358 }
359 println!();
360
361 println!("[PER SOURCE BREAKDOWN]");
362 for source in &inventories {
363 println!(" {}", source.source_id);
364 let mut source_grand = CapacityTotals::default();
365 let mut source_total_records = 0u128;
366 for split_label in [SplitLabel::Train, SplitLabel::Validation, SplitLabel::Test] {
367 let split_records = per_source_split_counts
368 .get(&(source.source_id.clone(), split_label))
369 .copied()
370 .unwrap_or(0);
371 source_total_records = source_total_records.saturating_add(split_records);
372 let split_longest_records = inventories
373 .iter()
374 .map(|candidate| {
375 per_source_split_counts
376 .get(&(candidate.source_id.clone(), split_label))
377 .copied()
378 .unwrap_or(0)
379 })
380 .max()
381 .unwrap_or(0);
382 let totals = totals_by_source_and_split
383 .get(&(source.source_id.clone(), split_label))
384 .copied()
385 .unwrap_or_default();
386 source_grand.triplets += totals.triplets;
387 source_grand.effective_triplets += totals.effective_triplets;
388 source_grand.pairs += totals.pairs;
389 source_grand.text_samples += totals.text_samples;
390 println!(" [{:?}]", split_label);
391 println!(" records: {}", format_u128_with_commas(split_records));
392 println!(
393 " triplet combinations: {}",
394 format_u128_with_commas(totals.triplets)
395 );
396 println!(
397 " effective sampled triplets (p={}, k={}): {}",
398 EFFECTIVE_POSITIVES_PER_ANCHOR,
399 EFFECTIVE_NEGATIVES_PER_ANCHOR,
400 format_u128_with_commas(totals.effective_triplets)
401 );
402 println!(
403 " pair combinations: {}",
404 format_u128_with_commas(totals.pairs)
405 );
406 println!(
407 " text samples: {}",
408 format_u128_with_commas(totals.text_samples)
409 );
410 println!(
411 " replay factor vs longest source: {}",
412 format_replay_factor(split_longest_records, split_records)
413 );
414 println!(
415 " suggested proportional-size batch weight (0-1): {:.4}",
416 suggested_balancing_weight(split_longest_records, split_records)
417 );
418 let split_smallest_nonzero = min_nonzero_records_by_split
419 .get(&split_label)
420 .copied()
421 .unwrap_or(0);
422 println!(
423 " suggested small-source-boost batch weight (0-1): {:.4}",
424 suggested_oversampling_weight(split_smallest_nonzero, split_records)
425 );
426 println!();
427 }
428 let longest_source_total = inventories
429 .iter()
430 .map(|candidate| candidate.reported_records)
431 .max()
432 .unwrap_or(0);
433 println!(" [ALL SPLITS FOR SOURCE]");
434 println!(
435 " triplet combinations: {}",
436 format_u128_with_commas(source_grand.triplets)
437 );
438 println!(
439 " effective sampled triplets (p={}, k={}): {}",
440 EFFECTIVE_POSITIVES_PER_ANCHOR,
441 EFFECTIVE_NEGATIVES_PER_ANCHOR,
442 format_u128_with_commas(source_grand.effective_triplets)
443 );
444 println!(
445 " pair combinations: {}",
446 format_u128_with_commas(source_grand.pairs)
447 );
448 println!(
449 " text samples: {}",
450 format_u128_with_commas(source_grand.text_samples)
451 );
452 println!(
453 " replay factor vs longest source: {}",
454 format_replay_factor(longest_source_total, source_total_records)
455 );
456 println!(
457 " suggested proportional-size batch weight (0-1): {:.4}",
458 suggested_balancing_weight(longest_source_total, source_total_records)
459 );
460 println!(
461 " suggested small-source-boost batch weight (0-1): {:.4}",
462 suggested_oversampling_weight(min_nonzero_records_all_splits, source_total_records)
463 );
464 println!();
465 }
466
467 let mut grand = CapacityTotals::default();
468 for split_label in [SplitLabel::Train, SplitLabel::Validation, SplitLabel::Test] {
469 let record_count = split_record_counts.get(&split_label).copied().unwrap_or(0);
470 let totals = totals_by_split
471 .get(&split_label)
472 .copied()
473 .unwrap_or_default();
474
475 grand.triplets += totals.triplets;
476 grand.effective_triplets += totals.effective_triplets;
477 grand.pairs += totals.pairs;
478 grand.text_samples += totals.text_samples;
479
480 println!("[{:?}]", split_label);
481 println!(" records: {}", format_u128_with_commas(record_count));
482 println!(
483 " triplet combinations: {}",
484 format_u128_with_commas(totals.triplets)
485 );
486 println!(
487 " effective sampled triplets (p={}, k={}): {}",
488 EFFECTIVE_POSITIVES_PER_ANCHOR,
489 EFFECTIVE_NEGATIVES_PER_ANCHOR,
490 format_u128_with_commas(totals.effective_triplets)
491 );
492 println!(
493 " pair combinations: {}",
494 format_u128_with_commas(totals.pairs)
495 );
496 println!(
497 " text samples: {}",
498 format_u128_with_commas(totals.text_samples)
499 );
500 println!();
501 }
502
503 println!("[ALL SPLITS TOTAL]");
504 println!(
505 " triplet combinations: {}",
506 format_u128_with_commas(grand.triplets)
507 );
508 println!(
509 " effective sampled triplets (p={}, k={}): {}",
510 EFFECTIVE_POSITIVES_PER_ANCHOR,
511 EFFECTIVE_NEGATIVES_PER_ANCHOR,
512 format_u128_with_commas(grand.effective_triplets)
513 );
514 println!(
515 " pair combinations: {}",
516 format_u128_with_commas(grand.pairs)
517 );
518 println!(
519 " text samples: {}",
520 format_u128_with_commas(grand.text_samples)
521 );
522 println!();
523 println!(
524 "Note: counts are heuristic, length-based estimates from source-reported totals and recipe structure. They are approximate, not exact, and assume anchor-positive pairs=records (one positive per anchor by default), negatives=source_records_in_split-1 (anchor excluded as its own negative), and at most one chunk/window realization per sample. In real-world chunked sampling, practical combinations are often higher, so treat this as a floor-like baseline."
525 );
526 println!();
527 println!(
528 "Effective sampled triplets apply a bounded training assumption: effective_triplets = records * p * k per triplet recipe, with defaults p={} positives per anchor and k={} negatives per anchor.",
529 EFFECTIVE_POSITIVES_PER_ANCHOR, EFFECTIVE_NEGATIVES_PER_ANCHOR
530 );
531 println!();
532 println!(
533 "Oversample loops are not inferred from this static report. To measure true oversampling (how many times sampling loops through the combination space), use observed sampled draw counts from an actual run."
534 );
535 println!();
536 println!(
537 "Suggested proportional-size batch weight (0-1) is source/max_source by record count: 1.0 for the largest source in scope, smaller values for smaller sources."
538 );
539 println!();
540 println!(
541 "Suggested small-source-boost batch weight (0-1) is min_nonzero_source/source by record count: 1.0 for the smallest non-zero source in scope, smaller values for larger sources."
542 );
543 println!();
544 println!(
545 "When passed to next_*_batch_with_weights, higher weight means that source is sampled more often relative to lower-weight sources."
546 );
547
548 Ok(())
549}
550
551pub fn run_multi_source_demo<R, Resolve, Build, I>(
556 args_iter: I,
557 resolve_roots: Resolve,
558 build_sources: Build,
559) -> Result<(), Box<dyn Error>>
560where
561 Resolve: FnOnce(Vec<String>) -> Result<R, Box<dyn Error>>,
562 Build: FnOnce(&R) -> Vec<DynSource>,
563 I: Iterator<Item = String>,
564{
565 init_example_tracing();
566
567 let Some(cli) = parse_cli::<MultiSourceDemoCli, _>(
568 std::iter::once("multi_source_demo".to_string()).chain(args_iter),
569 )?
570 else {
571 return Ok(());
572 };
573
574 let roots = resolve_roots(cli.source_roots)?;
575
576 let mut config = SamplerConfig::default();
577 config.seed = cli.seed.unwrap_or(config.seed);
578 config.batch_size = cli.batch_size;
579 config.ingestion_max_records = cli.ingestion_max_records;
580 config.chunking = Default::default();
581 let selected_split = cli.split.map(Into::into).unwrap_or(SplitLabel::Train);
582 config.split = SplitRatios::default();
583 config.allowed_splits = vec![selected_split];
584 let chunking = config.chunking.clone();
585 let config_snapshot = MultiSourceDemoConfigSnapshot {
586 seed: config.seed,
587 batch_size: config.batch_size,
588 ingestion_max_records: config.ingestion_max_records,
589 split: selected_split,
590 split_ratios: config.split,
591 max_window_tokens: config.chunking.max_window_tokens,
592 overlap_tokens: config.chunking.overlap_tokens.clone(),
593 summary_fallback_tokens: config.chunking.summary_fallback_tokens,
594 };
595
596 let split_store_path = if let Some(path) = cli.split_store_path {
597 path
598 } else {
599 managed_demo_split_store_path().map_err(|err| {
600 Box::<dyn Error>::from(format!("failed to resolve demo split-store path: {err}"))
601 })?
602 };
603
604 if cli.reset && split_store_path.exists() {
605 std::fs::remove_file(&split_store_path).map_err(|err| {
606 Box::<dyn Error>::from(format!(
607 "failed to remove split store '{}': {err}",
608 split_store_path.display()
609 ))
610 })?;
611 println!("Reset: removed {}", split_store_path.display());
612 }
613 println!(
614 "Persisting split assignments and epoch state to {}",
615 split_store_path.display()
616 );
617 let sources = build_sources(&roots);
618 let split_store = Arc::new(FileSplitStore::open(&split_store_path, config.split, 99)?);
619 let sampler = TripletSampler::new(config, split_store.clone());
620 for source in sources {
621 sampler
622 .register_source(source)
623 .map_err(|e| Box::new(e) as Box<dyn Error>)?;
624 }
625
626 if cli.show_pair_samples {
627 match sampler.next_pair_batch(selected_split) {
628 Ok(pair_batch) => {
629 if pair_batch.pairs.is_empty() {
630 println!("Pair sampling produced no results.");
631 } else {
632 print_pair_batch(&chunking, &pair_batch, split_store.as_ref());
633 }
634 sampler.save_sampler_state(None)?;
635 }
636 Err(SamplerError::Exhausted(name)) => {
637 eprintln!(
638 "Pair sampler exhausted recipe '{}'. Ensure both positive and negative examples exist.",
639 name
640 );
641 }
642 Err(err) => return Err(err.into()),
643 }
644 } else if cli.show_text_samples {
645 match sampler.next_text_batch(selected_split) {
646 Ok(text_batch) => {
647 if text_batch.samples.is_empty() {
648 println!(
649 "Text sampling produced no results. Ensure each source has eligible sections."
650 );
651 } else {
652 print_text_batch(&chunking, &text_batch, split_store.as_ref());
653 }
654 sampler.save_sampler_state(None)?;
655 }
656 Err(SamplerError::Exhausted(name)) => {
657 eprintln!(
658 "Text sampler exhausted selector '{}'. Ensure matching sections exist.",
659 name
660 );
661 }
662 Err(err) => return Err(err.into()),
663 }
664 } else if cli.list_text_recipes {
665 let recipes = sampler.text_recipes();
666 if recipes.is_empty() {
667 println!(
668 "No text recipes registered. Ensure your sources expose triplet selectors or configure text_recipes explicitly."
669 );
670 } else {
671 print_text_recipes(&recipes);
672 }
673 } else if let Some(batch_count) = cli.batches {
674 print_demo_config(&config_snapshot);
675 println!("=== benchmark: {} triplet batches ===", batch_count);
676
677 #[cfg(feature = "extended-metrics")]
679 let mut source_metrics: SourceMetricsMap = HashMap::new();
680
681 for i in 0..batch_count {
682 let t0 = Instant::now();
683 match sampler.next_triplet_batch(selected_split) {
684 Ok(batch) => {
685 let elapsed = t0.elapsed();
686 let n = batch.triplets.len();
687 println!(
688 "batch {:>4} triplets={:<4} elapsed={:>8.2}ms per_triplet={:.2}ms",
689 i + 1,
690 n,
691 elapsed.as_secs_f64() * 1000.0,
692 if n > 0 {
693 elapsed.as_secs_f64() * 1000.0 / n as f64
694 } else {
695 0.0
696 },
697 );
698 #[cfg(feature = "extended-metrics")]
699 {
700 use triplets_core::metrics::lexical_similarity_scores;
701 for triplet in &batch.triplets {
702 let (pj, pc) = lexical_similarity_scores(
703 &triplet.anchor.text,
704 &triplet.positive.text,
705 );
706 let (nj, nc) = lexical_similarity_scores(
707 &triplet.anchor.text,
708 &triplet.negative.text,
709 );
710 let proximity =
711 chunk_proximity_score(&triplet.anchor, &triplet.positive);
712 let source = extract_source(&triplet.anchor.record_id);
713 source_metrics
714 .entry(source)
715 .or_default()
716 .push((pj, pc, nj, nc, proximity));
717 }
718 }
719 }
720 Err(SamplerError::Exhausted(name)) => {
721 println!(
722 "batch {:>4} exhausted recipe '{}' — stopping early",
723 i + 1,
724 name
725 );
726 break;
727 }
728 Err(err) => return Err(err.into()),
729 }
730 }
731
732 sampler.save_sampler_state(None)?;
733
734 #[cfg(feature = "extended-metrics")]
735 if !source_metrics.is_empty() {
736 println!();
737 print_metric_summary(&source_metrics);
738 }
739
740 #[cfg(all(feature = "extended-metrics", feature = "bm25-mining"))]
741 {
742 let (fallback, total) = sampler.bm25_fallback_stats();
743 if total > 0 {
744 let pct = fallback as f64 / total as f64 * 100.0;
745 println!("bm25 fallback rate : {}/{} ({:.1}%)", fallback, total, pct);
746 }
747 }
748 } else {
749 match sampler.next_triplet_batch(selected_split) {
750 Ok(triplet_batch) => {
751 if triplet_batch.triplets.is_empty() {
752 println!(
753 "Triplet sampling produced no results. Ensure multiple records per source exist."
754 );
755 } else {
756 print_triplet_batch(&chunking, &triplet_batch, split_store.as_ref());
757 }
758 sampler.save_sampler_state(None)?;
759 #[cfg(all(feature = "extended-metrics", feature = "bm25-mining"))]
760 {
761 let (fallback, total) = sampler.bm25_fallback_stats();
762 if total > 0 {
763 let pct = fallback as f64 / total as f64 * 100.0;
764 println!("bm25 fallback rate : {}/{} ({:.1}%)", fallback, total, pct);
765 }
766 }
767 }
768 Err(SamplerError::Exhausted(name)) => {
769 eprintln!(
770 "Triplet sampler exhausted recipe '{}'. Ensure both positive and negative examples exist.",
771 name
772 );
773 }
774 Err(err) => return Err(err.into()),
775 }
776 }
777
778 Ok(())
779}
780
/// Copy of the effective sampler settings, taken in `run_multi_source_demo` before the config
/// is moved into the sampler, and printed by `print_demo_config` in `--batches` mode.
struct MultiSourceDemoConfigSnapshot {
    /// Effective seed: `--seed` when given, otherwise `SamplerConfig::default().seed`.
    seed: u64,
    /// Value of `--batch-size`.
    batch_size: usize,
    /// Value of `--ingestion-max-records`.
    ingestion_max_records: usize,
    /// Split being sampled (`--split`, train when omitted).
    split: SplitLabel,
    /// Split ratios in effect (the demo always uses `SplitRatios::default()`).
    split_ratios: SplitRatios,
    /// `config.chunking.max_window_tokens`.
    max_window_tokens: usize,
    /// `config.chunking.overlap_tokens`.
    overlap_tokens: Vec<usize>,
    /// `config.chunking.summary_fallback_tokens`; printed with the note "0 = disabled".
    summary_fallback_tokens: usize,
}
791
792fn print_demo_config(cfg: &MultiSourceDemoConfigSnapshot) {
793 let overlaps: Vec<String> = cfg.overlap_tokens.iter().map(|t| t.to_string()).collect();
794 println!("=== sampler config ===");
795 println!("seed : {}", cfg.seed);
796 println!("batch_size : {}", cfg.batch_size);
797 println!("ingestion_max_records: {}", cfg.ingestion_max_records);
798 println!("split : {:?}", cfg.split);
799 println!(
800 "split_ratios : train={:.2} val={:.2} test={:.2}",
801 cfg.split_ratios.train, cfg.split_ratios.validation, cfg.split_ratios.test
802 );
803 println!("max_window_tokens : {}", cfg.max_window_tokens);
804 println!("overlap_tokens : [{}]", overlaps.join(", "));
805 println!(
806 "summary_fallback : {} tokens (0 = disabled)",
807 cfg.summary_fallback_tokens
808 );
809 println!();
810}
811
812fn default_ingestion_max_records() -> usize {
813 SamplerConfig::default().ingestion_max_records
814}
815
/// Parses `raw` as a strictly positive `usize` for the CLI flag named `flag`.
///
/// # Errors
///
/// Returns a user-facing message naming `flag` when `raw` is not an unsigned integer, or
/// when it is zero.
fn parse_positive_usize_flag(raw: &str, flag: &str) -> Result<usize, String> {
    match raw.parse::<usize>() {
        Err(_) => Err(format!(
            "Could not parse {} value '{}' as a positive integer",
            flag, raw
        )),
        Ok(0) => Err(format!("{} must be greater than zero", flag)),
        Ok(value) => Ok(value),
    }
}

/// clap value parser for `--batch-size`.
fn parse_batch_size(raw: &str) -> Result<usize, String> {
    const FLAG: &str = "--batch-size";
    parse_positive_usize_flag(raw, FLAG)
}

/// clap value parser for `--ingestion-max-records`.
fn parse_ingestion_max_records(raw: &str) -> Result<usize, String> {
    const FLAG: &str = "--ingestion-max-records";
    parse_positive_usize_flag(raw, FLAG)
}

/// clap value parser for `--batches`.
fn parse_batch_count(raw: &str) -> Result<usize, String> {
    const FLAG: &str = "--batches";
    parse_positive_usize_flag(raw, FLAG)
}
840
/// Proportional-size batch weight: `source_baseline / max_baseline`, clamped to `[0, 1]`.
///
/// The largest source in scope gets `1.0`. Returns `0.0` when either count is zero.
fn suggested_balancing_weight(max_baseline: u128, source_baseline: u128) -> f32 {
    match (max_baseline, source_baseline) {
        (0, _) | (_, 0) => 0.0,
        (largest, own) => (own as f64 / largest as f64).clamp(0.0, 1.0) as f32,
    }
}

/// Small-source-boost batch weight: `min_nonzero_baseline / source_baseline`, clamped to
/// `[0, 1]`.
///
/// The smallest non-empty source in scope gets `1.0`, and larger sources get less. Returns
/// `0.0` when either count is zero.
fn suggested_oversampling_weight(min_nonzero_baseline: u128, source_baseline: u128) -> f32 {
    match (min_nonzero_baseline, source_baseline) {
        (0, _) | (_, 0) => 0.0,
        (smallest, own) => (smallest as f64 / own as f64).clamp(0.0, 1.0) as f32,
    }
}
854
855fn parse_cli<T, I>(args: I) -> Result<Option<T>, Box<dyn Error>>
856where
857 T: Parser,
858 I: IntoIterator,
859 I::Item: Into<std::ffi::OsString> + Clone,
860{
861 match T::try_parse_from(args) {
862 Ok(cli) => Ok(Some(cli)),
863 Err(err) => match err.kind() {
864 ErrorKind::DisplayHelp | ErrorKind::DisplayVersion => {
865 err.print()?;
866 Ok(None)
867 }
868 _ => Err(err.into()),
869 },
870 }
871}
872
873fn parse_split_ratios_arg(raw: &str) -> Result<SplitRatios, String> {
874 let parts: Vec<&str> = raw.split(',').collect();
875 if parts.len() != 3 {
876 return Err("--split-ratios expects exactly 3 comma-separated values".to_string());
877 }
878 let train = parts[0]
879 .trim()
880 .parse::<f32>()
881 .map_err(|_| format!("invalid train ratio '{}': must be a float", parts[0].trim()))?;
882 let validation = parts[1].trim().parse::<f32>().map_err(|_| {
883 format!(
884 "invalid validation ratio '{}': must be a float",
885 parts[1].trim()
886 )
887 })?;
888 let test = parts[2]
889 .trim()
890 .parse::<f32>()
891 .map_err(|_| format!("invalid test ratio '{}': must be a float", parts[2].trim()))?;
892 let ratios = SplitRatios {
893 train,
894 validation,
895 test,
896 };
897 let sum = ratios.train + ratios.validation + ratios.test;
898 if (sum - 1.0).abs() > 1e-5 {
899 return Err(format!(
900 "split ratios must sum to 1.0, got {:.6} (train={}, validation={}, test={})",
901 sum, ratios.train, ratios.validation, ratios.test
902 ));
903 }
904 if ratios.train < 0.0 || ratios.validation < 0.0 || ratios.test < 0.0 {
905 return Err("split ratios must be non-negative".to_string());
906 }
907 Ok(ratios)
908}
909
910fn print_triplet_batch(
911 strategy: &ChunkingStrategy,
912 batch: &TripletBatch,
913 split_store: &impl SplitStore,
914) {
915 println!("=== triplet batch ===");
916 for (idx, triplet) in batch.triplets.iter().enumerate() {
917 println!("--- triplet #{} ---", idx);
918 println!("recipe : {}", triplet.recipe);
919 println!("sample_weight: {:.4}", triplet.weight);
920 if let Some(instr) = &triplet.instruction {
921 println!("instruction shown to model:");
922 println!("{instr}");
923 println!();
924 }
925 let pos_proximity = chunk_proximity_score(&triplet.anchor, &triplet.positive);
926 let pos_distance = window_chunk_distance(&triplet.anchor, &triplet.positive);
927 #[cfg(feature = "extended-metrics")]
928 let (pos_sim, neg_sim) = {
929 use triplets_core::metrics::lexical_similarity_scores;
930 (
931 Some(lexical_similarity_scores(
932 &triplet.anchor.text,
933 &triplet.positive.text,
934 )),
935 Some(lexical_similarity_scores(
936 &triplet.anchor.text,
937 &triplet.negative.text,
938 )),
939 )
940 };
941 #[cfg(not(feature = "extended-metrics"))]
942 let (pos_sim, neg_sim): (Option<(f32, f32)>, Option<(f32, f32)>) = (None, None);
943 print_chunk_block(
944 "ANCHOR",
945 &triplet.anchor,
946 strategy,
947 split_store,
948 None,
949 None,
950 None,
951 );
952 print_chunk_block(
953 "POSITIVE",
954 &triplet.positive,
955 strategy,
956 split_store,
957 pos_sim,
958 Some(pos_proximity),
959 pos_distance,
960 );
961 print_chunk_block(
962 "NEGATIVE",
963 &triplet.negative,
964 strategy,
965 split_store,
966 neg_sim,
967 None,
968 None,
969 );
970 }
971 print_source_summary(
972 "triplet anchors",
973 batch
974 .triplets
975 .iter()
976 .map(|triplet| triplet.anchor.record_id.as_str()),
977 );
978 print_recipe_context_by_source(
979 "triplet recipes by source",
980 batch
981 .triplets
982 .iter()
983 .map(|triplet| (triplet.anchor.record_id.as_str(), triplet.recipe.as_str())),
984 );
985}
986
987fn print_text_batch(strategy: &ChunkingStrategy, batch: &TextBatch, split_store: &impl SplitStore) {
988 println!("=== text batch ===");
989 for (idx, sample) in batch.samples.iter().enumerate() {
990 println!("--- sample #{} ---", idx);
991 println!("recipe : {}", sample.recipe);
992 println!("sample_weight: {:.4}", sample.weight);
993 if let Some(instr) = &sample.instruction {
994 println!("instruction shown to model:");
995 println!("{instr}");
996 println!();
997 }
998 print_chunk_block(
999 "TEXT",
1000 &sample.chunk,
1001 strategy,
1002 split_store,
1003 None,
1004 None,
1005 None,
1006 );
1007 }
1008 print_source_summary(
1009 "text samples",
1010 batch
1011 .samples
1012 .iter()
1013 .map(|sample| sample.chunk.record_id.as_str()),
1014 );
1015 print_recipe_context_by_source(
1016 "text recipes by source",
1017 batch
1018 .samples
1019 .iter()
1020 .map(|sample| (sample.chunk.record_id.as_str(), sample.recipe.as_str())),
1021 );
1022}
1023
1024fn print_pair_batch(
1025 strategy: &ChunkingStrategy,
1026 batch: &SampleBatch,
1027 split_store: &impl SplitStore,
1028) {
1029 println!("=== pair batch ===");
1030 for (idx, pair) in batch.pairs.iter().enumerate() {
1031 println!("--- pair #{} ---", idx);
1032 println!("recipe : {}", pair.recipe);
1033 println!("label : {:?}", pair.label);
1034 if let Some(reason) = &pair.reason {
1035 println!("reason : {}", reason);
1036 }
1037 print_chunk_block(
1038 "ANCHOR",
1039 &pair.anchor,
1040 strategy,
1041 split_store,
1042 None,
1043 None,
1044 None,
1045 );
1046 print_chunk_block(
1047 "OTHER",
1048 &pair.positive,
1049 strategy,
1050 split_store,
1051 None,
1052 None,
1053 None,
1054 );
1055 }
1056 print_source_summary(
1057 "pair anchors",
1058 batch
1059 .pairs
1060 .iter()
1061 .map(|pair| pair.anchor.record_id.as_str()),
1062 );
1063 print_recipe_context_by_source(
1064 "pair recipes by source",
1065 batch
1066 .pairs
1067 .iter()
1068 .map(|pair| (pair.anchor.record_id.as_str(), pair.recipe.as_str())),
1069 );
1070}
1071
1072fn print_text_recipes(recipes: &[TextRecipe]) {
1073 println!("=== available text recipes ===");
1074 for recipe in recipes {
1075 println!(
1076 "- {} (weight: {:.3}) selector={:?}",
1077 recipe.name, recipe.weight, recipe.selector
1078 );
1079 if let Some(instr) = &recipe.instruction {
1080 println!(" instruction: {}", instr);
1081 }
1082 }
1083}
1084
1085#[cfg(feature = "extended-metrics")]
/// Returns `(mean, median)` of `vals`, sorting `vals` in place as a side effect.
///
/// For an even number of values the median is the average of the two middle
/// elements.
///
/// Values are ordered with `f32::total_cmp`, which is a total order even when
/// `NaN` is present. A comparator that treats `NaN` comparisons as `Equal` is
/// not a total order, and the standard-library sorts may panic on such a
/// comparator. With `total_cmp`, a positive `NaN` sorts after every number.
///
/// An empty slice yields `(NaN, NaN)` rather than panicking on the median index.
fn metric_mean_median(vals: &mut [f32]) -> (f32, f32) {
    let len = vals.len();
    if len == 0 {
        // Without this guard, `len / 2 - 1` underflows below.
        return (f32::NAN, f32::NAN);
    }
    let mean = vals.iter().sum::<f32>() / len as f32;
    // Elements that compare equal under `total_cmp` are bit-identical, so an
    // unstable sort gives the same result as a stable one.
    vals.sort_unstable_by(f32::total_cmp);
    let mid = len / 2;
    let median = if len % 2 == 1 {
        vals[mid]
    } else {
        (vals[mid - 1] + vals[mid]) / 2.0
    };
    (mean, median)
}
1096
/// Prints mean/median tables for the extended per-triplet metrics.
///
/// `source_data` maps a source name to one metric tuple per triplet. The tuple
/// positions, as consumed by the three sections below, are:
/// `0` = positive jaccard, `1` = positive byte-cos, `2` = negative jaccard,
/// `3` = negative byte-cos, `4` = anchor-positive proximity.
///
/// Sources are listed in sorted order. A pooled `ALL` row is added only when
/// more than one source is present.
#[cfg(feature = "extended-metrics")]
fn print_metric_summary(source_data: &SourceMetricsMap) {
    let total: usize = source_data.values().map(|v| v.len()).sum();
    let n_sources = source_data.len();
    println!(
        "=== extended metrics summary ({} triplets, {} {}) ===",
        total,
        n_sources,
        if n_sources == 1 { "source" } else { "sources" }
    );

    /// Selects one component of a metric tuple by position.
    ///
    /// Any index above 3 falls back to `entry.4`; callers only pass 0..=4.
    fn metric_value(entry: &MetricEntry, idx: usize) -> f32 {
        match idx {
            0 => entry.0,
            1 => entry.1,
            2 => entry.2,
            3 => entry.3,
            _ => entry.4,
        }
    }

    /// Gathers component `idx` from `entries` and returns its (mean, median).
    fn metric_stats<'a, I>(entries: I, idx: usize) -> (f32, f32)
    where
        I: Iterator<Item = &'a MetricEntry>,
    {
        let mut vals: Vec<f32> = entries.map(|entry| metric_value(entry, idx)).collect();
        metric_mean_median(&mut vals)
    }

    /// (mean, median) for the positive and the negative component, in that order.
    fn metric_pair(entries: &[MetricEntry], pos_idx: usize, neg_idx: usize) -> [(f32, f32); 2] {
        [
            metric_stats(entries.iter(), pos_idx),
            metric_stats(entries.iter(), neg_idx),
        ]
    }

    /// Prints one positive/negative row, including the signed pos-minus-neg gap.
    fn print_pair_row(name: &str, n: usize, pos: (f32, f32), neg: (f32, f32)) {
        println!(
            "{:<24} {:>5} {:.3} / {:.3} {:.3} / {:.3} {:+.3} / {:+.3}",
            name,
            n,
            pos.0,
            pos.1,
            neg.0,
            neg.1,
            pos.0 - neg.0,
            pos.1 - neg.1,
        );
    }

    fn print_metric_section(
        label: &str,
        sources: &[&String],
        source_data: &SourceMetricsMap,
        pos_idx: usize,
        neg_idx: usize,
        total: usize,
        n_sources: usize,
    ) {
        const SEP: usize = 83;
        println!();
        println!("[{}]", label);
        println!(
            "{:<24} {:>5} {:<16} {:<16} {:<16}",
            "source", "n", "positive", "negative", "gap (pos\u{2212}neg)"
        );
        println!(
            "{:<24} {:>5} {:<16} {:<16} {:<16}",
            "", "", "mean / median", "mean / median", "mean / median"
        );
        println!("{}", "-".repeat(SEP));
        for source in sources {
            let entries = &source_data[*source];
            let [pos, neg] = metric_pair(entries, pos_idx, neg_idx);
            print_pair_row(source.as_str(), entries.len(), pos, neg);
        }
        if n_sources > 1 {
            // Pool in sorted-source order (not HashMap order) so the float sum
            // behind the mean is the same on every run.
            let all: Vec<MetricEntry> = sources
                .iter()
                .flat_map(|source| source_data[*source].iter().copied())
                .collect();
            let [pos, neg] = metric_pair(&all, pos_idx, neg_idx);
            println!("{}", "-".repeat(SEP));
            print_pair_row("ALL", total, pos, neg);
        }
    }

    fn print_single_metric_section(
        label: &str,
        sources: &[&String],
        source_data: &SourceMetricsMap,
        idx: usize,
        total: usize,
        n_sources: usize,
    ) {
        const SEP: usize = 58;
        println!();
        println!("[{}]", label);
        println!("{:<24} {:>5} {:<16}", "source", "n", "mean / median");
        println!("{}", "-".repeat(SEP));
        for source in sources {
            let entries = &source_data[*source];
            let (mean, median) = metric_stats(entries.iter(), idx);
            println!(
                "{:<24} {:>5} {:.3} / {:.3}",
                source,
                entries.len(),
                mean,
                median,
            );
        }
        if n_sources > 1 {
            // Same deterministic pooling order as the paired sections.
            let pooled = sources
                .iter()
                .flat_map(|source| source_data[*source].iter());
            let (mean, median) = metric_stats(pooled, idx);
            println!("{}", "-".repeat(SEP));
            println!("{:<24} {:>5} {:.3} / {:.3}", "ALL", total, mean, median);
        }
    }

    let mut sources: Vec<&String> = source_data.keys().collect();
    sources.sort();

    print_metric_section(
        "jaccard \u{2194} anchor",
        &sources,
        source_data,
        0,
        2,
        total,
        n_sources,
    );
    print_metric_section(
        "byte-cos \u{2194} anchor",
        &sources,
        source_data,
        1,
        3,
        total,
        n_sources,
    );
    print_single_metric_section(
        "anchor-positive proximity",
        &sources,
        source_data,
        4,
        total,
        n_sources,
    );
    println!();
}
1264
1265trait ChunkDebug {
1266 fn view_name(&self) -> String;
1267}
1268
1269impl ChunkDebug for RecordChunk {
1270 fn view_name(&self) -> String {
1271 match &self.view {
1272 ChunkView::Window {
1273 index,
1274 span,
1275 overlap,
1276 } => format!(
1277 "window#index={} span={} overlap={} tokens={}",
1278 index, span, overlap, self.tokens_estimate
1279 ),
1280 ChunkView::SummaryFallback { strategy, .. } => {
1281 format!("summary:{} tokens={}", strategy, self.tokens_estimate)
1282 }
1283 }
1284 }
1285}
1286
1287fn print_chunk_block(
1288 title: &str,
1289 chunk: &RecordChunk,
1290 strategy: &ChunkingStrategy,
1291 split_store: &impl SplitStore,
1292 anchor_sim: Option<(f32, f32)>,
1293 ap_proximity: Option<f32>,
1294 ap_distance: Option<f32>,
1295) {
1296 let chunk_weight = chunk_weight(strategy, chunk);
1297 let split = split_store
1298 .label_for(&chunk.record_id)
1299 .map(|label| format!("{:?}", label))
1300 .unwrap_or_else(|| "Unknown".to_string());
1301 println!("--- {} ---", title);
1302 println!("split : {}", split);
1303 println!("view : {}", chunk.view_name());
1304 println!("chunk_weight : {:.4}", chunk_weight);
1305 println!("record_id : {}", chunk.record_id);
1306 println!("section_idx : {}", chunk.section_idx);
1307 println!("token_est : {}", chunk.tokens_estimate);
1308 if let Some(proximity) = ap_proximity {
1309 println!("a<->p proximity : {:.4}", proximity);
1310 }
1311 if let Some(distance) = ap_distance {
1312 println!("a<->p distance : {:.4}", distance);
1313 }
1314 if let Some((j, c)) = anchor_sim {
1315 println!("jaccard(↔a) : {:.4} byte-cos(↔a): {:.4}", j, c);
1316 }
1317 if !chunk.kvp_meta.is_empty() {
1318 let mut kvp_keys: Vec<&String> = chunk.kvp_meta.keys().collect();
1319 kvp_keys.sort();
1320 println!("kvp_meta :");
1321 for k in kvp_keys {
1322 let vals = &chunk.kvp_meta[k];
1323 let display = if vals.len() > 1 {
1324 format!("{} ({} variations)", vals[0], vals.len())
1325 } else {
1326 vals[0].clone()
1327 };
1328 println!("\t{}: {}", k, display);
1329 }
1330 }
1331 println!("model_input (exact text sent to the model):");
1332 println!("<<< BEGIN MODEL TEXT >>>");
1333 println!("{}", chunk.text);
1334 println!("<<< END MODEL TEXT >>>");
1335 println!();
1336}
1337
1338fn print_source_summary<'a, I>(label: &str, ids: I)
1339where
1340 I: Iterator<Item = &'a str>,
1341{
1342 let mut counts: HashMap<SourceId, usize> = HashMap::new();
1343 for id in ids {
1344 let source = extract_source(id);
1345 *counts.entry(source).or_insert(0) += 1;
1346 }
1347 if counts.is_empty() {
1348 return;
1349 }
1350 let skew = source_skew(&counts);
1351 let mut entries: Vec<(String, usize)> = counts.into_iter().collect();
1352 entries.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
1353 println!("--- {} by source ---", label);
1354 if let Some(skew) = skew {
1355 for entry in &skew.per_source {
1356 println!(
1357 "{}: count={} share={:.2}",
1358 entry.source, entry.count, entry.share
1359 );
1360 }
1361 println!(
1362 "skew: sources={} total={} min={} max={} mean={:.2} ratio={:.2}",
1363 skew.sources, skew.total, skew.min, skew.max, skew.mean, skew.ratio
1364 );
1365 }
1366}
1367
1368fn print_recipe_context_by_source<'a, I>(label: &str, entries: I)
1369where
1370 I: Iterator<Item = (&'a str, &'a str)>,
1371{
1372 let mut counts: HashMap<SourceId, HashMap<String, usize>> = HashMap::new();
1373 for (record_id, recipe) in entries {
1374 let source = extract_source(record_id);
1375 let entry = counts
1376 .entry(source)
1377 .or_default()
1378 .entry(recipe.to_string())
1379 .or_insert(0);
1380 *entry += 1;
1381 }
1382 if counts.is_empty() {
1383 return;
1384 }
1385 let mut sources: Vec<(SourceId, HashMap<String, usize>)> = counts.into_iter().collect();
1386 sources.sort_by(|a, b| a.0.cmp(&b.0));
1387 println!("--- {} ---", label);
1388 for (source, recipes) in sources {
1389 println!("{source}");
1390 let mut entries: Vec<(String, usize)> = recipes.into_iter().collect();
1391 entries.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
1392 for (recipe, count) in entries {
1393 println!(" - {recipe}={count}");
1394 }
1395 }
1396}
1397
1398fn extract_source(record_id: &str) -> SourceId {
1399 record_id
1400 .split_once("::")
1401 .map(|(source, _)| source.to_string())
1402 .unwrap_or_else(|| "unknown".to_string())
1403}
1404
1405#[cfg(test)]
1406mod tests {
1407 use super::*;
1408 use chrono::{TimeZone, Utc};
1409 use tempfile::tempdir;
1410 use triplets_core::DataRecord;
1411 use triplets_core::DeterministicSplitStore;
1412 use triplets_core::data::{QualityScore, RecordSection, SectionRole};
1413 use triplets_core::source::{SourceCursor, SourceSnapshot};
1414 use triplets_core::utils::make_section;
1415
1416 fn empty_dyn_sources(_: &()) -> Vec<DynSource> {
1417 Vec::new()
1418 }
1419
1420 fn ok_unit_roots(_: Vec<String>) -> Result<(), Box<dyn Error>> {
1421 Ok(())
1422 }
1423
1424 fn error_unit_roots(_: Vec<String>) -> Result<(), Box<dyn Error>> {
1425 Err("root-resolution-error".into())
1426 }
1427
1428 struct ErrorRefreshSource {
1429 id: String,
1430 }
1431
1432 impl DataSource for ErrorRefreshSource {
1433 fn id(&self) -> &str {
1434 &self.id
1435 }
1436
1437 fn refresh(
1438 &self,
1439 _config: &SamplerConfig,
1440 _cursor: Option<&SourceCursor>,
1441 _limit: Option<usize>,
1442 ) -> Result<SourceSnapshot, SamplerError> {
1443 Err(SamplerError::SourceUnavailable {
1444 source_id: self.id.clone(),
1445 reason: "simulated refresh failure".to_string(),
1446 })
1447 }
1448
1449 fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
1450 Ok(1)
1451 }
1452
1453 fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
1454 vec![default_recipe("error_refresh_recipe")]
1455 }
1456 }
1457
1458 struct TestSource {
1460 id: String,
1461 count: Option<u128>,
1462 recipes: Vec<TripletRecipe>,
1463 }
1464
1465 impl DataSource for TestSource {
1466 fn id(&self) -> &str {
1467 &self.id
1468 }
1469
1470 fn refresh(
1471 &self,
1472 _config: &SamplerConfig,
1473 _cursor: Option<&SourceCursor>,
1474 _limit: Option<usize>,
1475 ) -> Result<SourceSnapshot, SamplerError> {
1476 Ok(SourceSnapshot {
1477 records: Vec::new(),
1478 cursor: SourceCursor {
1479 last_seen: Utc::now(),
1480 revision: 0,
1481 },
1482 })
1483 }
1484
1485 fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
1486 self.count.ok_or_else(|| SamplerError::SourceInconsistent {
1487 source_id: self.id.clone(),
1488 details: "test source has no configured exact count".to_string(),
1489 })
1490 }
1491
1492 fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
1493 self.recipes.clone()
1494 }
1495 }
1496
1497 struct ConfigRequiredSource {
1498 id: String,
1499 expected_seed: u64,
1500 }
1501
1502 impl DataSource for ConfigRequiredSource {
1503 fn id(&self) -> &str {
1504 &self.id
1505 }
1506
1507 fn refresh(
1508 &self,
1509 _config: &SamplerConfig,
1510 _cursor: Option<&SourceCursor>,
1511 _limit: Option<usize>,
1512 ) -> Result<SourceSnapshot, SamplerError> {
1513 Ok(SourceSnapshot {
1514 records: Vec::new(),
1515 cursor: SourceCursor {
1516 last_seen: Utc::now(),
1517 revision: 0,
1518 },
1519 })
1520 }
1521
1522 fn reported_record_count(&self, config: &SamplerConfig) -> Result<u128, SamplerError> {
1523 if config.seed == self.expected_seed {
1524 Ok(1)
1525 } else {
1526 Err(SamplerError::SourceInconsistent {
1527 source_id: self.id.clone(),
1528 details: format!(
1529 "expected sampler seed {} but got {}",
1530 self.expected_seed, config.seed
1531 ),
1532 })
1533 }
1534 }
1535
1536 fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
1537 Vec::new()
1538 }
1539 }
1540
1541 struct FixtureSource {
1542 id: String,
1543 records: Vec<DataRecord>,
1544 recipes: Vec<TripletRecipe>,
1545 }
1546
1547 impl DataSource for FixtureSource {
1548 fn id(&self) -> &str {
1549 &self.id
1550 }
1551
1552 fn refresh(
1553 &self,
1554 _config: &SamplerConfig,
1555 _cursor: Option<&SourceCursor>,
1556 _limit: Option<usize>,
1557 ) -> Result<SourceSnapshot, SamplerError> {
1558 Ok(SourceSnapshot {
1559 records: self.records.clone(),
1560 cursor: SourceCursor {
1561 last_seen: Utc::now(),
1562 revision: 0,
1563 },
1564 })
1565 }
1566
1567 fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
1568 Ok(self.records.len() as u128)
1569 }
1570
1571 fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
1572 self.recipes.clone()
1573 }
1574 }
1575
1576 struct IngestionConfigSource {
1577 expected_ingestion_max_records: usize,
1578 records: Vec<DataRecord>,
1579 }
1580
1581 impl DataSource for IngestionConfigSource {
1582 fn id(&self) -> &str {
1583 "ingestion_config_source"
1584 }
1585
1586 fn refresh(
1587 &self,
1588 config: &SamplerConfig,
1589 _cursor: Option<&SourceCursor>,
1590 _limit: Option<usize>,
1591 ) -> Result<SourceSnapshot, SamplerError> {
1592 if config.ingestion_max_records != self.expected_ingestion_max_records {
1593 return Err(SamplerError::SourceInconsistent {
1594 source_id: self.id().to_string(),
1595 details: format!(
1596 "expected ingestion_max_records {} but got {}",
1597 self.expected_ingestion_max_records, config.ingestion_max_records
1598 ),
1599 });
1600 }
1601 Ok(SourceSnapshot {
1602 records: self.records.clone(),
1603 cursor: SourceCursor {
1604 last_seen: Utc::now(),
1605 revision: 0,
1606 },
1607 })
1608 }
1609
1610 fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
1611 Ok(self.records.len() as u128)
1612 }
1613
1614 fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
1615 vec![default_recipe("ingestion_config_recipe")]
1616 }
1617 }
1618
1619 fn fixture_record(
1620 source: &str,
1621 id_suffix: &str,
1622 day: u32,
1623 title: &str,
1624 body: &str,
1625 ) -> DataRecord {
1626 let now = Utc.with_ymd_and_hms(2025, 1, day, 12, 0, 0).unwrap();
1627 DataRecord {
1628 id: format!("{source}::{id_suffix}"),
1629 source: source.to_string(),
1630 created_at: now,
1631 updated_at: now,
1632 quality: QualityScore { trust: 1.0 },
1633 taxonomy: Vec::new(),
1634 sections: vec![
1635 make_section(SectionRole::Anchor, Some("title"), title),
1636 make_section(SectionRole::Context, Some("body"), body),
1637 ],
1638 meta_prefix: None,
1639 }
1640 }
1641
1642 fn default_recipe(name: &str) -> TripletRecipe {
1643 TripletRecipe {
1644 name: name.to_string().into(),
1645 anchor: triplets_core::config::Selector::Role(SectionRole::Anchor),
1646 positive_selector: triplets_core::config::Selector::Role(SectionRole::Context),
1647 negative_selector: triplets_core::config::Selector::Role(SectionRole::Context),
1648 negative_strategy: triplets_core::config::NegativeStrategy::WrongArticle,
1649 weight: 1.0,
1650 instruction: None,
1651 allow_same_anchor_positive: false,
1652 }
1653 }
1654
1655 #[test]
1656 fn parse_helpers_validate_inputs() {
1657 assert_eq!(parse_batch_size("2").unwrap(), 2);
1658 assert!(parse_batch_size("0").is_err());
1659 assert!(parse_batch_size("abc").is_err());
1660 assert_eq!(parse_ingestion_max_records("16").unwrap(), 16);
1661 assert!(parse_ingestion_max_records("0").is_err());
1662 assert!(parse_batch_count("0").is_err());
1663
1664 let split = parse_split_ratios_arg("0.8,0.1,0.1").unwrap();
1665 assert!((split.train - 0.8).abs() < 1e-6);
1666 assert!(parse_split_ratios_arg("0.8,0.1").is_err());
1667 assert!(parse_split_ratios_arg("1.0,0.0,0.1").is_err());
1668 assert!(parse_split_ratios_arg("-0.1,0.6,0.5").is_err());
1669 }
1670
1671 #[test]
1672 fn fixture_and_ingestion_sources_trait_methods_cover_paths() {
1673 let records = vec![fixture_record("fixture_source", "r1", 1, "Title", "Body")];
1674 let recipes = vec![default_recipe("fixture_recipe")];
1675 let fixture = FixtureSource {
1676 id: "fixture_source".into(),
1677 records: records.clone(),
1678 recipes: recipes.clone(),
1679 };
1680
1681 let snapshot = fixture
1682 .refresh(&SamplerConfig::default(), None, None)
1683 .expect("fixture refresh should succeed");
1684 assert_eq!(snapshot.records.len(), 1);
1685 assert_eq!(
1686 fixture
1687 .reported_record_count(&SamplerConfig::default())
1688 .unwrap(),
1689 1
1690 );
1691 assert_eq!(fixture.default_triplet_recipes().len(), 1);
1692
1693 let source = IngestionConfigSource {
1694 expected_ingestion_max_records: 7,
1695 records,
1696 };
1697 let ok_cfg = SamplerConfig {
1698 ingestion_max_records: 7,
1699 ..SamplerConfig::default()
1700 };
1701 assert!(source.refresh(&ok_cfg, None, None).is_ok());
1702 assert_eq!(source.reported_record_count(&ok_cfg).unwrap(), 1);
1703 assert_eq!(source.default_triplet_recipes().len(), 1);
1704
1705 let bad_cfg = SamplerConfig {
1706 ingestion_max_records: 8,
1707 ..SamplerConfig::default()
1708 };
1709 let err = source.refresh(&bad_cfg, None, None).unwrap_err();
1710 assert!(matches!(err, SamplerError::SourceInconsistent { .. }));
1711 }
1712
1713 #[test]
1714 fn suggested_balancing_weight_is_longest_normalized_and_bounded() {
1715 assert!((suggested_balancing_weight(100, 100) - 1.0).abs() < 1e-6);
1716 assert!((suggested_balancing_weight(400, 100) - 0.25).abs() < 1e-6);
1717 assert!((suggested_balancing_weight(400, 400) - 1.0).abs() < 1e-6);
1718 assert_eq!(suggested_balancing_weight(0, 100), 0.0);
1719 assert_eq!(suggested_balancing_weight(100, 0), 0.0);
1720 }
1721
1722 #[test]
1723 fn suggested_oversampling_weight_is_inverse_in_unit_interval() {
1724 assert!((suggested_oversampling_weight(100, 100) - 1.0).abs() < 1e-6);
1725 assert!((suggested_oversampling_weight(100, 400) - 0.25).abs() < 1e-6);
1726 assert!((suggested_oversampling_weight(100, 1000) - 0.1).abs() < 1e-6);
1727 assert_eq!(suggested_oversampling_weight(0, 100), 0.0);
1728 assert_eq!(suggested_oversampling_weight(100, 0), 0.0);
1729 }
1730
1731 #[test]
1732 fn parse_cli_handles_help_and_invalid_args() {
1733 let help = parse_cli::<EstimateCapacityCli, _>(["estimate_capacity", "--help"]).unwrap();
1734 assert!(help.is_none());
1735
1736 let err = parse_cli::<EstimateCapacityCli, _>(["estimate_capacity", "--unknown"]);
1737 assert!(err.is_err());
1738 }
1739
1740 #[test]
1741 fn run_estimate_capacity_succeeds_with_reported_counts() {
1742 let result = run_estimate_capacity(
1743 std::iter::empty::<String>(),
1744 |roots| {
1745 assert!(roots.is_empty());
1746 Ok(())
1747 },
1748 |_| {
1749 vec![Box::new(TestSource {
1750 id: "source_a".into(),
1751 count: Some(12),
1752 recipes: vec![default_recipe("r1")],
1753 }) as DynSource]
1754 },
1755 );
1756
1757 assert!(result.is_ok());
1758 }
1759
1760 #[test]
1761 fn run_estimate_capacity_errors_when_source_count_missing() {
1762 let result = run_estimate_capacity(
1763 std::iter::empty::<String>(),
1764 |_| Ok(()),
1765 |_| {
1766 vec![Box::new(TestSource {
1767 id: "source_missing".into(),
1768 count: None,
1769 recipes: vec![default_recipe("r1")],
1770 }) as DynSource]
1771 },
1772 );
1773
1774 let err = result.unwrap_err().to_string();
1775 assert!(err.contains("failed to report exact record count"));
1776 }
1777
1778 #[test]
1779 fn run_estimate_capacity_propagates_root_resolution_error() {
1780 let result = run_estimate_capacity(
1781 std::iter::empty::<String>(),
1782 |_| Err("root resolution failed".into()),
1783 empty_dyn_sources,
1784 );
1785
1786 let err = result.unwrap_err().to_string();
1787 assert!(err.contains("root resolution failed"));
1788 }
1789
1790 #[test]
1791 fn run_estimate_capacity_allows_empty_source_list() {
1792 let result =
1793 run_estimate_capacity(std::iter::empty::<String>(), |_| Ok(()), empty_dyn_sources);
1794
1795 assert!(result.is_ok());
1796 }
1797
1798 #[test]
1799 fn run_estimate_capacity_configures_sources_centrally_before_counting() {
1800 let result = run_estimate_capacity(
1801 std::iter::empty::<String>(),
1802 |_| Ok(()),
1803 |_| {
1804 vec![Box::new(ConfigRequiredSource {
1805 id: "requires_config".into(),
1806 expected_seed: 99,
1807 }) as DynSource]
1808 },
1809 );
1810
1811 assert!(result.is_ok());
1812 }
1813
1814 #[test]
1815 fn config_required_source_refresh_and_seed_mismatch_are_exercised() {
1816 let source = ConfigRequiredSource {
1817 id: "cfg-source".to_string(),
1818 expected_seed: 42,
1819 };
1820
1821 let refreshed = source
1822 .refresh(&SamplerConfig::default(), None, None)
1823 .unwrap();
1824 assert!(refreshed.records.is_empty());
1825
1826 let mismatched = source.reported_record_count(&SamplerConfig {
1827 seed: 7,
1828 ..SamplerConfig::default()
1829 });
1830 assert!(matches!(
1831 mismatched,
1832 Err(SamplerError::SourceInconsistent { .. })
1833 ));
1834
1835 assert!(source.default_triplet_recipes().is_empty());
1836 }
1837
1838 #[test]
1839 fn run_multi_source_demo_exhausted_paths_return_ok() {
1840 struct OneRecordSource;
1841
1842 impl DataSource for OneRecordSource {
1843 fn id(&self) -> &str {
1844 "one_record"
1845 }
1846
1847 fn refresh(
1848 &self,
1849 _config: &SamplerConfig,
1850 _cursor: Option<&SourceCursor>,
1851 _limit: Option<usize>,
1852 ) -> Result<SourceSnapshot, SamplerError> {
1853 let now = Utc::now();
1854 Ok(SourceSnapshot {
1855 records: vec![DataRecord {
1856 id: "one_record::r1".to_string(),
1857 source: "one_record".to_string(),
1858 created_at: now,
1859 updated_at: now,
1860 quality: QualityScore { trust: 1.0 },
1861 taxonomy: Vec::new(),
1862 sections: vec![
1863 RecordSection {
1864 role: SectionRole::Anchor,
1865 heading: Some("title".to_string()),
1866 text: "anchor".to_string(),
1867 sentences: vec!["anchor".to_string()],
1868 },
1869 RecordSection {
1870 role: SectionRole::Context,
1871 heading: Some("body".to_string()),
1872 text: "context".to_string(),
1873 sentences: vec!["context".to_string()],
1874 },
1875 ],
1876 meta_prefix: None,
1877 }],
1878 cursor: SourceCursor {
1879 last_seen: now,
1880 revision: 0,
1881 },
1882 })
1883 }
1884
1885 fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
1886 Ok(1)
1887 }
1888
1889 fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
1890 vec![default_recipe("single_record_recipe")]
1891 }
1892 }
1893
1894 let one = OneRecordSource;
1895 assert_eq!(
1896 one.reported_record_count(&SamplerConfig::default())
1897 .unwrap(),
1898 1
1899 );
1900 assert_eq!(one.default_triplet_recipes().len(), 1);
1901
1902 for mode in ["--pair-batch", "--text-recipes", ""] {
1903 let dir = tempdir().unwrap();
1904 let split_store_path = dir.path().join("split_store.bin");
1905 let mut args = vec![
1906 "--split-store-path".to_string(),
1907 split_store_path.to_string_lossy().to_string(),
1908 ];
1909 if !mode.is_empty() {
1910 args.push(mode.to_string());
1911 }
1912
1913 let result = run_multi_source_demo(
1914 args.into_iter(),
1915 |_| Ok(()),
1916 |_| vec![Box::new(OneRecordSource) as DynSource],
1917 );
1918 assert!(result.is_ok());
1919 }
1920 }
1921
1922 #[test]
1923 fn parse_multi_source_cli_handles_help_and_batch_size_validation() {
1924 let help = parse_cli::<MultiSourceDemoCli, _>(["multi_source_demo", "--help"]).unwrap();
1925 assert!(help.is_none());
1926
1927 let err = parse_cli::<MultiSourceDemoCli, _>(["multi_source_demo", "--batch-size", "0"]);
1928 assert!(err.is_err());
1929
1930 let err = parse_cli::<MultiSourceDemoCli, _>([
1931 "multi_source_demo",
1932 "--ingestion-max-records",
1933 "0",
1934 ]);
1935 assert!(err.is_err());
1936
1937 let parsed = parse_cli::<MultiSourceDemoCli, _>(["multi_source_demo"]);
1938 assert!(parsed.is_ok());
1939 }
1940
1941 #[test]
1942 fn run_debug_invalid_cli_args_return_errors() {
1943 let estimate = run_estimate_capacity(
1944 ["--unknown".to_string()].into_iter(),
1945 ok_unit_roots,
1946 empty_dyn_sources,
1947 );
1948 assert!(estimate.is_err());
1949
1950 let demo = run_multi_source_demo(
1951 ["--unknown".to_string()].into_iter(),
1952 ok_unit_roots,
1953 empty_dyn_sources,
1954 );
1955 assert!(demo.is_err());
1956 }
1957
1958 #[test]
1959 fn helper_and_error_refresh_source_methods_are_exercised() {
1960 assert!(ok_unit_roots(Vec::new()).is_ok());
1961 assert!(error_unit_roots(Vec::new()).is_err());
1962
1963 let source = ErrorRefreshSource {
1964 id: "error_refresh_source".to_string(),
1965 };
1966 assert_eq!(
1967 source
1968 .reported_record_count(&SamplerConfig::default())
1969 .unwrap(),
1970 1
1971 );
1972 assert_eq!(source.default_triplet_recipes().len(), 1);
1973 }
1974
1975 #[test]
1976 fn print_source_summary_handles_non_empty_ids() {
1977 let ids = [
1978 "source_a::r1",
1979 "source_a::r2",
1980 "source_b::r1",
1981 "source_without_delimiter",
1982 ];
1983 print_source_summary("non-empty summary", ids.into_iter());
1984 }
1985
1986 #[test]
1987 fn run_multi_source_demo_refresh_failures_degrade_to_exhausted_paths() {
1988 for mode in [
1989 vec!["--pair-batch".to_string()],
1990 vec!["--text-recipes".to_string()],
1991 vec!["--batches".to_string(), "1".to_string()],
1992 Vec::new(),
1993 ] {
1994 let dir = tempdir().unwrap();
1995 let split_store_path = dir.path().join("error_modes_split_store.bin");
1996 let mut args = mode;
1997 args.push("--split-store-path".to_string());
1998 args.push(split_store_path.to_string_lossy().to_string());
1999
2000 let result = run_multi_source_demo(
2001 args.into_iter(),
2002 |_| Ok(()),
2003 |_| {
2004 vec![Box::new(ErrorRefreshSource {
2005 id: "error_refresh_source".to_string(),
2006 }) as DynSource]
2007 },
2008 );
2009
2010 assert!(result.is_ok());
2011 }
2012 }
2013
2014 #[test]
2015 fn run_multi_source_demo_batches_exhausted_path_returns_ok() {
2016 let dir = tempdir().unwrap();
2017 let split_store_path = dir.path().join("batches_exhausted_split_store.bin");
2018 let args = vec![
2019 "--batches".to_string(),
2020 "3".to_string(),
2021 "--split-store-path".to_string(),
2022 split_store_path.to_string_lossy().to_string(),
2023 ];
2024
2025 let result = run_multi_source_demo(
2026 args.into_iter(),
2027 |_| Ok(()),
2028 |_| {
2029 vec![Box::new(FixtureSource {
2030 id: "batches_exhausted_source".into(),
2031 records: vec![fixture_record(
2032 "batches_exhausted_source",
2033 "r1",
2034 1,
2035 "Only one record",
2036 "Single record body",
2037 )],
2038 recipes: vec![default_recipe("batches_exhausted_recipe")],
2039 }) as DynSource]
2040 },
2041 );
2042
2043 assert!(result.is_ok());
2044 }
2045
2046 #[test]
2047 fn run_multi_source_demo_default_triplet_success_path_returns_ok() {
2048 let dir = tempdir().unwrap();
2049 let split_store_path = dir.path().join("default_triplet_success_split_store.bin");
2050 let args = vec![
2051 "--split-store-path".to_string(),
2052 split_store_path.to_string_lossy().to_string(),
2053 ];
2054
2055 let result = run_multi_source_demo(
2056 args.into_iter(),
2057 |_| Ok(()),
2058 |_| {
2059 vec![Box::new(FixtureSource {
2060 id: "default_triplet_success_source".into(),
2061 records: vec![
2062 fixture_record(
2063 "default_triplet_success_source",
2064 "r1",
2065 1,
2066 "Title one",
2067 "Body one",
2068 ),
2069 fixture_record(
2070 "default_triplet_success_source",
2071 "r2",
2072 2,
2073 "Title two",
2074 "Body two",
2075 ),
2076 fixture_record(
2077 "default_triplet_success_source",
2078 "r3",
2079 3,
2080 "Title three",
2081 "Body three",
2082 ),
2083 ],
2084 recipes: vec![default_recipe("default_triplet_success_recipe")],
2085 }) as DynSource]
2086 },
2087 );
2088
2089 assert!(result.is_ok());
2090 }
2091
2092 #[test]
2093 fn run_multi_source_demo_passes_ingestion_max_records_to_sources() {
2094 let dir = tempdir().unwrap();
2095 let split_store_path = dir.path().join("ingestion_config_split_store.bin");
2096 let expected = 7;
2097
2098 let result = run_multi_source_demo(
2099 [
2100 "--pair-batch".to_string(),
2101 "--ingestion-max-records".to_string(),
2102 expected.to_string(),
2103 "--split-store-path".to_string(),
2104 split_store_path.to_string_lossy().to_string(),
2105 ]
2106 .into_iter(),
2107 |_| Ok(()),
2108 |_| {
2109 vec![Box::new(IngestionConfigSource {
2110 expected_ingestion_max_records: expected,
2111 records: (1..=8)
2112 .map(|day| {
2113 fixture_record(
2114 "ingestion_config_source",
2115 &format!("r{day}"),
2116 day,
2117 &format!("Config headline {day}"),
2118 &format!("Config body {day}"),
2119 )
2120 })
2121 .collect(),
2122 }) as DynSource]
2123 },
2124 );
2125
2126 assert!(result.is_ok());
2127 }
2128
2129 #[test]
2130 fn parse_cli_handles_display_version_path() {
2131 #[derive(Debug, Parser)]
2132 #[command(name = "version_test", version = "1.0.0")]
2133 struct VersionCli {}
2134
2135 let parsed = parse_cli::<VersionCli, _>(["version_test", "--version"]).unwrap();
2136 assert!(parsed.is_none());
2137 }
2138
2139 #[test]
2140 fn run_multi_source_demo_list_text_recipes_path_succeeds() {
2141 let dir = tempdir().unwrap();
2142 let split_store_path = dir.path().join("recipes_split_store.bin");
2143 let mut args = vec![
2144 "--list-text-recipes".to_string(),
2145 "--split-store-path".to_string(),
2146 split_store_path.to_string_lossy().to_string(),
2147 ];
2148 let result = run_multi_source_demo(
2149 args.drain(..),
2150 |_| Ok(()),
2151 |_| {
2152 vec![Box::new(TestSource {
2153 id: "source_for_recipes".into(),
2154 count: Some(10),
2155 recipes: vec![default_recipe("recipe_a")],
2156 }) as DynSource]
2157 },
2158 );
2159
2160 assert!(result.is_ok());
2161 }
2162
2163 #[test]
2164 fn run_multi_source_demo_list_text_recipes_uses_explicit_split_store_path() {
2165 let dir = tempdir().unwrap();
2166 let split_store_path = dir.path().join("custom_split_store.bin");
2167 let args = vec![
2168 "--list-text-recipes".to_string(),
2169 "--split-store-path".to_string(),
2170 split_store_path.to_string_lossy().to_string(),
2171 ];
2172
2173 let result = run_multi_source_demo(
2174 args.into_iter(),
2175 |_| Ok(()),
2176 |_| {
2177 vec![Box::new(TestSource {
2178 id: "source_without_text_recipes".into(),
2179 count: Some(1),
2180 recipes: Vec::new(),
2181 }) as DynSource]
2182 },
2183 );
2184
2185 assert!(result.is_ok());
2186 }
2187
2188 #[test]
2189 fn run_multi_source_demo_sampling_modes_handle_empty_sources() {
2190 for mode in [
2191 vec!["--pair-batch".to_string()],
2192 vec!["--text-recipes".to_string()],
2193 vec![],
2194 ] {
2195 let dir = tempdir().unwrap();
2196 let split_store_path = dir.path().join("empty_sources_split_store.bin");
2197 let mut args = mode;
2198 args.push("--split-store-path".to_string());
2199 args.push(split_store_path.to_string_lossy().to_string());
2200 args.push("--split".to_string());
2201 args.push("validation".to_string());
2202
2203 let result = run_multi_source_demo(
2204 args.into_iter(),
2205 |_| Ok(()),
2206 |_| {
2207 vec![Box::new(TestSource {
2208 id: "source_empty".into(),
2209 count: Some(0),
2210 recipes: vec![default_recipe("recipe_empty")],
2211 }) as DynSource]
2212 },
2213 );
2214
2215 assert!(result.is_ok());
2216 }
2217 }
2218
2219 #[test]
2220 fn run_multi_source_demo_propagates_root_resolution_error() {
2221 let dir = tempdir().unwrap();
2222 let split_store_path = dir.path().join("root_resolution_error_store.bin");
2223 let result = run_multi_source_demo(
2224 [
2225 "--split-store-path".to_string(),
2226 split_store_path.to_string_lossy().to_string(),
2227 ]
2228 .into_iter(),
2229 |_| Err("demo root resolution failed".into()),
2230 empty_dyn_sources,
2231 );
2232
2233 let err = result.unwrap_err().to_string();
2234 assert!(err.contains("demo root resolution failed"));
2235 }
2236
2237 #[test]
2238 fn run_multi_source_demo_list_text_recipes_allows_empty_sources() {
2239 let dir = tempdir().unwrap();
2240 let split_store_path = dir.path().join("empty_source_list_recipes.bin");
2241 let result = run_multi_source_demo(
2242 [
2243 "--list-text-recipes".to_string(),
2244 "--split-store-path".to_string(),
2245 split_store_path.to_string_lossy().to_string(),
2246 ]
2247 .into_iter(),
2248 |_| Ok(()),
2249 empty_dyn_sources,
2250 );
2251
2252 assert!(result.is_ok());
2253 }
2254
2255 #[test]
2256 fn print_helpers_and_extract_source_cover_paths() {
2257 let split = SplitRatios::default();
2258 let store = DeterministicSplitStore::new(split, 42).unwrap();
2259 let strategy = ChunkingStrategy::default();
2260
2261 let anchor = RecordChunk {
2262 record_id: "source_a::rec1".to_string(),
2263 section_idx: 0,
2264 view: ChunkView::Window {
2265 index: 1,
2266 overlap: 2,
2267 span: 12,
2268 },
2269 text: "anchor text".to_string(),
2270 tokens_estimate: 8,
2271 quality: triplets_core::data::QualityScore { trust: 0.9 },
2272 kvp_meta: [(
2273 "date".to_string(),
2274 vec!["2025-01-01".to_string(), "Jan 1, 2025".to_string()],
2275 )]
2276 .into_iter()
2277 .collect(),
2278 };
2279 let positive = RecordChunk {
2280 record_id: "source_a::rec2".to_string(),
2281 section_idx: 1,
2282 view: ChunkView::SummaryFallback {
2283 strategy: "summary".to_string(),
2284 weight: 0.7,
2285 },
2286 text: "positive text".to_string(),
2287 tokens_estimate: 6,
2288 quality: triplets_core::data::QualityScore { trust: 0.8 },
2289 kvp_meta: Default::default(),
2290 };
2291 let negative = RecordChunk {
2292 record_id: "source_b::rec3".to_string(),
2293 section_idx: 2,
2294 view: ChunkView::Window {
2295 index: 0,
2296 overlap: 0,
2297 span: 16,
2298 },
2299 text: "negative text".to_string(),
2300 tokens_estimate: 7,
2301 quality: triplets_core::data::QualityScore { trust: 0.5 },
2302 kvp_meta: Default::default(),
2303 };
2304
2305 let triplet_batch = TripletBatch {
2306 triplets: vec![triplets_core::SampleTriplet {
2307 recipe: "triplet_recipe".to_string(),
2308 anchor: anchor.clone(),
2309 positive: positive.clone(),
2310 negative: negative.clone(),
2311 weight: 1.0,
2312 instruction: Some("triplet instruction".to_string()),
2313 }],
2314 };
2315 print_triplet_batch(&strategy, &triplet_batch, &store);
2316
2317 let pair_batch = SampleBatch {
2318 pairs: vec![triplets_core::SamplePair {
2319 recipe: "pair_recipe".to_string(),
2320 anchor: anchor.clone(),
2321 positive: positive.clone(),
2322 weight: 1.0,
2323 instruction: None,
2324 label: triplets_core::PairLabel::Positive,
2325 reason: Some("same topic".to_string()),
2326 }],
2327 };
2328 print_pair_batch(&strategy, &pair_batch, &store);
2329
2330 let text_batch = TextBatch {
2331 samples: vec![triplets_core::TextSample {
2332 recipe: "text_recipe".to_string(),
2333 chunk: negative,
2334 weight: 0.8,
2335 instruction: Some("text instruction".to_string()),
2336 }],
2337 };
2338 print_text_batch(&strategy, &text_batch, &store);
2339
2340 let recipes = vec![TextRecipe {
2341 name: "recipe_name".into(),
2342 selector: triplets_core::config::Selector::Role(SectionRole::Context),
2343 instruction: Some("instruction".into()),
2344 weight: 1.0,
2345 }];
2346 print_text_recipes(&recipes);
2347
2348 assert_eq!(extract_source("source_a::record"), "source_a");
2349 assert_eq!(extract_source("record-without-delimiter"), "unknown");
2350 }
2351
2352 #[test]
2353 fn split_arg_conversion_and_version_parse_paths_are_covered() {
2354 assert!(matches!(
2355 SplitLabel::from(SplitArg::Train),
2356 SplitLabel::Train
2357 ));
2358 assert!(matches!(
2359 SplitLabel::from(SplitArg::Validation),
2360 SplitLabel::Validation
2361 ));
2362 assert!(matches!(SplitLabel::from(SplitArg::Test), SplitLabel::Test));
2363 }
2364
2365 #[test]
2366 fn parse_split_ratios_reports_per_field_parse_errors() {
2367 assert!(
2368 parse_split_ratios_arg("x,0.1,0.9")
2369 .unwrap_err()
2370 .contains("invalid train ratio")
2371 );
2372 assert!(
2373 parse_split_ratios_arg("0.1,y,0.8")
2374 .unwrap_err()
2375 .contains("invalid validation ratio")
2376 );
2377 assert!(
2378 parse_split_ratios_arg("0.1,0.2,z")
2379 .unwrap_err()
2380 .contains("invalid test ratio")
2381 );
2382 }
2383
2384 #[test]
2385 fn run_multi_source_demo_exhausted_paths_are_handled() {
2386 for mode in [
2387 vec!["--pair-batch".to_string()],
2388 vec!["--text-recipes".to_string()],
2389 Vec::new(),
2390 ] {
2391 let dir = tempdir().unwrap();
2392 let split_store_path = dir.path().join("exhausted_split_store.bin");
2393 let mut args = mode;
2394 args.push("--split-store-path".to_string());
2395 args.push(split_store_path.to_string_lossy().to_string());
2396
2397 let result = run_multi_source_demo(
2398 args.into_iter(),
2399 |_| Ok(()),
2400 |_| {
2401 vec![Box::new(TestSource {
2402 id: "source_without_recipes".into(),
2403 count: Some(1),
2404 recipes: Vec::new(),
2405 }) as DynSource]
2406 },
2407 );
2408
2409 assert!(result.is_ok());
2410 }
2411 }
2412
2413 #[test]
2414 fn run_multi_source_demo_reset_recreates_split_store_and_samples() {
2415 let dir = tempdir().unwrap();
2416 let split_store_path = dir.path().join("reset_split_store.bin");
2417 std::fs::write(&split_store_path, b"stale-data").unwrap();
2418
2419 let args = vec![
2420 "--reset".to_string(),
2421 "--pair-batch".to_string(),
2422 "--split-store-path".to_string(),
2423 split_store_path.to_string_lossy().to_string(),
2424 ];
2425
2426 let result = run_multi_source_demo(
2427 args.into_iter(),
2428 |_| Ok(()),
2429 |_| {
2430 let recipes = vec![default_recipe("fixture_recipe")];
2431 let records: Vec<DataRecord> = (1..=8)
2432 .map(|day| {
2433 fixture_record(
2434 "fixture_source",
2435 &format!("r{day}"),
2436 day,
2437 &format!("Fixture headline {day}"),
2438 &format!("Fixture body content for day {day}."),
2439 )
2440 })
2441 .collect();
2442 vec![Box::new(FixtureSource {
2443 id: "fixture_source".into(),
2444 records,
2445 recipes,
2446 }) as DynSource]
2447 },
2448 );
2449
2450 assert!(result.is_ok());
2451 assert!(split_store_path.exists());
2452 let metadata = std::fs::metadata(&split_store_path).unwrap();
2453 assert!(metadata.len() > 0);
2454 }
2455
2456 #[test]
2457 fn run_multi_source_demo_batches_mode_executes_multiple_batches() {
2458 let dir = tempdir().unwrap();
2459 let split_store_path = dir.path().join("batches_split_store.bin");
2460 let args = vec![
2461 "--batches".to_string(),
2462 "2".to_string(),
2463 "--split-store-path".to_string(),
2464 split_store_path.to_string_lossy().to_string(),
2465 ];
2466
2467 let result = run_multi_source_demo(
2468 args.into_iter(),
2469 |_| Ok(()),
2470 |_| {
2471 let recipes = vec![default_recipe("batch_recipe")];
2472 vec![Box::new(FixtureSource {
2473 id: "batch_source".into(),
2474 records: vec![
2475 fixture_record(
2476 "batch_source",
2477 "r1",
2478 3,
2479 "Inflation cools in latest report",
2480 "Core inflation moderated compared with prior quarter.",
2481 ),
2482 fixture_record(
2483 "batch_source",
2484 "r2",
2485 4,
2486 "Labor market remains resilient",
2487 "Job openings remain elevated despite slower growth.",
2488 ),
2489 fixture_record(
2490 "batch_source",
2491 "r3",
2492 5,
2493 "Manufacturing sentiment stabilizes",
2494 "Survey data suggests output expectations are improving.",
2495 ),
2496 ],
2497 recipes,
2498 }) as DynSource]
2499 },
2500 );
2501
2502 assert!(result.is_ok());
2503 assert!(split_store_path.exists());
2504 }
2505
2506 #[test]
2507 fn managed_demo_split_store_path_resolves_under_cache_group() {
2508 let path = managed_demo_split_store_path().unwrap();
2509 assert!(path.ends_with(MULTI_SOURCE_DEMO_STORE_FILENAME));
2510 let parent = path
2511 .parent()
2512 .expect("managed split-store path should have a parent");
2513 assert!(parent.ends_with(PathBuf::from(MULTI_SOURCE_DEMO_GROUP)));
2514 }
2515
2516 #[test]
2517 fn run_multi_source_demo_help_returns_ok_without_work() {
2518 let no_help = run_multi_source_demo(
2519 std::iter::empty::<String>(),
2520 error_unit_roots,
2521 empty_dyn_sources,
2522 );
2523 assert!(
2524 no_help
2525 .expect_err("non-help path should attempt to resolve roots")
2526 .to_string()
2527 .contains("root-resolution-error")
2528 );
2529
2530 let result = run_multi_source_demo(
2531 ["--help".to_string()].into_iter(),
2532 ok_unit_roots,
2533 empty_dyn_sources,
2534 );
2535
2536 assert!(result.is_ok());
2537 }
2538
2539 #[test]
2540 fn run_estimate_capacity_help_returns_ok_without_work() {
2541 let result = run_estimate_capacity(
2542 ["--help".to_string()].into_iter(),
2543 ok_unit_roots,
2544 empty_dyn_sources,
2545 );
2546
2547 assert!(result.is_ok());
2548 }
2549
2550 #[test]
2551 fn run_multi_source_demo_pair_exhausted_branch_returns_ok() {
2552 let dir = tempdir().unwrap();
2553 let split_store_path = dir.path().join("pair_exhausted_split_store.bin");
2554 let args = vec![
2555 "--pair-batch".to_string(),
2556 "--split-store-path".to_string(),
2557 split_store_path.to_string_lossy().to_string(),
2558 ];
2559
2560 let result = run_multi_source_demo(
2561 args.into_iter(),
2562 |_| Ok(()),
2563 |_| {
2564 vec![Box::new(FixtureSource {
2565 id: "pair_exhausted_source".into(),
2566 records: vec![fixture_record(
2567 "pair_exhausted_source",
2568 "r1",
2569 1,
2570 "Single record title",
2571 "Single record body",
2572 )],
2573 recipes: vec![default_recipe("pair_exhausted_recipe")],
2574 }) as DynSource]
2575 },
2576 );
2577
2578 assert!(result.is_ok());
2579 }
2580
2581 #[test]
2582 fn run_multi_source_demo_uses_managed_split_store_path_when_not_provided() {
2583 let result = run_multi_source_demo(
2584 ["--list-text-recipes".to_string()].into_iter(),
2585 |_| Ok(()),
2586 |_| {
2587 vec![Box::new(TestSource {
2588 id: "managed_path_source".into(),
2589 count: Some(2),
2590 recipes: vec![default_recipe("managed_recipe")],
2591 }) as DynSource]
2592 },
2593 );
2594
2595 assert!(result.is_ok());
2596 }
2597
2598 #[test]
2599 fn run_multi_source_demo_reset_errors_when_target_is_directory() {
2600 let dir = tempdir().unwrap();
2601 let split_store_path = dir.path().join("split_store_dir");
2602 std::fs::create_dir(&split_store_path).unwrap();
2603
2604 let result = run_multi_source_demo(
2605 [
2606 "--reset".to_string(),
2607 "--split-store-path".to_string(),
2608 split_store_path.to_string_lossy().to_string(),
2609 ]
2610 .into_iter(),
2611 |_| Ok(()),
2612 empty_dyn_sources,
2613 );
2614
2615 let err = result.unwrap_err().to_string();
2616 assert!(err.contains("failed to remove split store"));
2617 }
2618
2619 #[test]
2620 fn print_summary_helpers_accept_empty_iterators() {
2621 print_source_summary("empty summary", std::iter::empty::<&str>());
2622 print_recipe_context_by_source("empty recipe context", std::iter::empty::<(&str, &str)>());
2623 }
2624
2625 #[cfg(feature = "extended-metrics")]
2626 #[test]
2627 fn metric_mean_median_handles_even_length_inputs() {
2628 let mut vals = [1.0, 4.0, 2.0, 3.0];
2629 let (mean, median) = metric_mean_median(&mut vals);
2630 assert!((mean - 2.5).abs() < 1e-6);
2631 assert!((median - 2.5).abs() < 1e-6);
2632 }
2633
2634 #[cfg(feature = "extended-metrics")]
2635 #[test]
2636 fn metric_mean_median_handles_odd_length_inputs() {
2637 let mut vals = [3.0, 1.0, 2.0];
2638 let (mean, median) = metric_mean_median(&mut vals);
2639 assert!((mean - 2.0).abs() < 1e-6);
2640 assert!((median - 2.0).abs() < 1e-6);
2641 }
2642
2643 #[cfg(feature = "extended-metrics")]
2644 #[test]
2645 fn print_metric_summary_includes_multi_source_aggregate() {
2646 let source_data = HashMap::from([
2647 (
2648 "source_a".to_string(),
2649 vec![(0.9, 0.8, 0.2, 0.1, 0.7), (0.8, 0.7, 0.3, 0.2, 0.8)],
2650 ),
2651 (
2652 "source_b".to_string(),
2653 vec![(0.7, 0.6, 0.4, 0.3, 0.5), (0.6, 0.5, 0.5, 0.4, 0.6)],
2654 ),
2655 ]);
2656
2657 print_metric_summary(&source_data);
2658 }
2659}