1#![doc = include_str!("../README.md")]
2#![warn(missing_docs)]
3
4use std::collections::HashMap;
5use std::error::Error;
6use std::path::PathBuf;
7use std::sync::Arc;
8use std::sync::Once;
9#[cfg(test)]
10use std::sync::OnceLock;
11use std::time::Instant;
12#[cfg(test)]
13use tempfile::TempDir;
14
15use cache_manager::CacheRoot;
16use clap::{Parser, ValueEnum, error::ErrorKind};
17
18use triplets_core::config::{ChunkingStrategy, SamplerConfig, TripletRecipe};
19use triplets_core::constants::cache::{MULTI_SOURCE_DEMO_GROUP, MULTI_SOURCE_DEMO_STORE_FILENAME};
20use triplets_core::data::ChunkView;
21use triplets_core::heuristics::{
22 CapacityTotals, EFFECTIVE_NEGATIVES_PER_ANCHOR, EFFECTIVE_POSITIVES_PER_ANCHOR,
23 estimate_source_split_capacity_from_counts, format_replay_factor, format_u128_with_commas,
24 resolve_text_recipes_for_source, split_counts_for_total,
25};
26use triplets_core::metrics::{chunk_proximity_score, source_skew, window_chunk_distance};
27use triplets_core::sampler::chunk_weight;
28use triplets_core::source::DataSource;
29use triplets_core::splits::{FileSplitStore, SplitLabel, SplitRatios, SplitStore};
30use triplets_core::{
31 RecordChunk, SampleBatch, Sampler, SamplerError, SourceId, TextBatch, TextRecipe, TripletBatch,
32 TripletSampler,
33};
34
35type DynSource = Box<dyn DataSource + 'static>;
36
/// One per-triplet metrics row, pushed in the order
/// `(positive.0, positive.1, negative.0, negative.1, proximity)`: the two
/// anchor/positive lexical similarity scores, the two anchor/negative lexical
/// similarity scores, and the anchor/positive chunk proximity score.
#[cfg(feature = "extended-metrics")]
type MetricEntry = (f32, f32, f32, f32, f32);
39
/// Per-triplet metric rows grouped under a `String` key identifying the source.
#[cfg(feature = "extended-metrics")]
type SourceMetricsMap = HashMap<String, Vec<MetricEntry>>;
42
43fn managed_demo_split_store_path() -> Result<PathBuf, String> {
44 #[cfg(test)]
45 {
46 static TEST_CACHE_ROOT: OnceLock<TempDir> = OnceLock::new();
47 let root = TEST_CACHE_ROOT.get_or_init(|| {
48 TempDir::new().expect("failed to create test demo split-store cache root")
49 });
50 let cache_root = CacheRoot::from_root(root.path());
51 let group = PathBuf::from(MULTI_SOURCE_DEMO_GROUP);
52 let dir = cache_root.ensure_group(&group).map_err(|err| {
53 format!(
54 "failed creating managed demo cache group '{}': {err}",
55 group.display()
56 )
57 })?;
58 Ok(dir.join(MULTI_SOURCE_DEMO_STORE_FILENAME))
59 }
60
61 #[cfg(not(test))]
62 {
63 let cache_root = CacheRoot::from_discovery()
64 .map_err(|err| format!("failed discovering managed cache root: {err}"))?;
65 let group = PathBuf::from(MULTI_SOURCE_DEMO_GROUP);
66 let dir = cache_root.ensure_group(&group).map_err(|err| {
67 format!(
68 "failed creating managed demo cache group '{}': {err}",
69 group.display()
70 )
71 })?;
72 Ok(dir.join(MULTI_SOURCE_DEMO_STORE_FILENAME))
73 }
74}
75
76fn init_example_tracing() {
77 static INIT: Once = Once::new();
78 INIT.call_once(|| {
79 let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
80 .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("triplets=info"));
81 let _ = tracing_subscriber::fmt()
82 .with_env_filter(env_filter)
83 .try_init();
84 });
85}
86
// Command-line spelling of the split to sample from; converted into
// `SplitLabel` through the `From` impl that follows.
// Plain `//` comments are used here because clap's derive turns `///` doc
// comments into help text.
#[derive(Debug, Clone, Copy, ValueEnum)]
enum SplitArg {
    // Maps to `SplitLabel::Train`.
    Train,
    // Maps to `SplitLabel::Validation`.
    Validation,
    // Maps to `SplitLabel::Test`.
    Test,
}
94
95impl From<SplitArg> for SplitLabel {
96 fn from(value: SplitArg) -> Self {
97 match value {
98 SplitArg::Train => SplitLabel::Train,
99 SplitArg::Validation => SplitLabel::Validation,
100 SplitArg::Test => SplitLabel::Test,
101 }
102 }
103}
104
// Arguments accepted by the `estimate_capacity` example.
// Plain `//` comments are used because clap's derive turns `///` doc comments
// into help text.
#[derive(Debug, Parser)]
#[command(
    name = "estimate_capacity",
    disable_help_subcommand = true,
    about = "Metadata-only capacity estimation",
    long_about = "Estimate record, pair, triplet, and text-sample capacity using source-reported counts only (no data refresh).",
    after_help = "Source roots are optional and resolved in order by explicit arg, environment variables, then project defaults."
)]
struct EstimateCapacityCli {
    // Seed copied into the sampler config and echoed in the report (default 99).
    #[arg(
        long,
        default_value_t = 99,
        help = "Deterministic seed used for split allocation"
    )]
    seed: u64,
    // Train/validation/test ratios, validated by `parse_split_ratios_arg`.
    #[arg(
        long = "split-ratios",
        value_name = "TRAIN,VALIDATION,TEST",
        value_parser = parse_split_ratios_arg,
        default_value = "0.8,0.1,0.1",
        help = "Comma-separated split ratios that must sum to 1.0"
    )]
    split: SplitRatios,
    // Repeatable `--source-root` values, handed to the caller's root resolver.
    #[arg(
        long = "source-root",
        value_name = "PATH",
        help = "Optional source root override, repeat as needed in source order"
    )]
    source_roots: Vec<String>,
}
136
// Arguments accepted by the `multi_source_demo` example.
// Plain `//` comments are used because clap's derive turns `///` doc comments
// into help text.
#[derive(Debug, Parser)]
#[command(
    name = "multi_source_demo",
    disable_help_subcommand = true,
    about = "Run sampled batches from multiple sources",
    long_about = "Sample triplet, pair, or text batches from multiple sources and persist split/epoch state.",
    after_help = "Source roots are optional and resolved in order by explicit arg, environment variables, then project defaults."
)]
struct MultiSourceDemoCli {
    // `--text-recipes`: sample and print a text batch.
    #[arg(
        long = "text-recipes",
        help = "Emit a text batch instead of a triplet batch"
    )]
    show_text_samples: bool,
    // `--pair-batch`: sample and print a pair batch. The runner checks this
    // flag before every other mode flag.
    #[arg(
        long = "pair-batch",
        help = "Emit a pair batch instead of a triplet batch"
    )]
    show_pair_samples: bool,
    // `--list-text-recipes`: print the sampler's text recipes without sampling.
    #[arg(
        long = "list-text-recipes",
        help = "Print registered text recipes and exit"
    )]
    list_text_recipes: bool,
    // Batch size copied into the sampler config; must be greater than zero.
    #[arg(
        long = "batch-size",
        default_value_t = 4,
        value_parser = parse_batch_size,
        help = "Batch size used for sampling"
    )]
    batch_size: usize,
    // Ingestion buffer target copied into the sampler config; defaults to the
    // value in `SamplerConfig::default()` and must be greater than zero.
    #[arg(
        long = "ingestion-max-records",
        default_value_t = default_ingestion_max_records(),
        value_parser = parse_ingestion_max_records,
        help = "Per-source ingestion buffer target used while refreshing records"
    )]
    ingestion_max_records: usize,
    // Seed override; when absent the default sampler config's seed is kept.
    #[arg(long, help = "Optional deterministic seed override")]
    seed: Option<u64>,
    // Split to sample from; the runner falls back to the train split.
    #[arg(long, value_enum, help = "Target split to sample from")]
    split: Option<SplitArg>,
    // Repeatable `--source-root` values, handed to the caller's root resolver.
    #[arg(
        long = "source-root",
        value_name = "PATH",
        help = "Optional source root override, repeat as needed in source order"
    )]
    source_roots: Vec<String>,
    // Explicit split-store file; when absent the managed demo cache path is used.
    #[arg(
        long = "split-store-path",
        value_name = "SPLIT_STORE_PATH",
        help = "Optional explicit path for persisted split/epoch state file"
    )]
    split_store_path: Option<PathBuf>,
    // `--reset`: remove the split-store file (if it exists) before sampling.
    #[arg(
        long = "reset",
        help = "Delete the persisted split/epoch state before sampling, restarting from epoch 0"
    )]
    reset: bool,
    // `--batches N`: benchmark mode running N triplet batches; N must be > 0.
    #[arg(
        long = "batches",
        value_name = "N",
        value_parser = parse_batch_count,
        help = "Run N triplet batches in succession, printing a timing line per batch and (with --features extended-metrics) a per-source similarity summary at the end"
    )]
    batches: Option<usize>,
}
210
/// What the capacity estimator records about one source before computing totals.
#[derive(Debug, Clone)]
struct SourceInventory {
    /// Source identifier, copied from `source.id()`.
    source_id: String,
    /// Record count returned by the source's `reported_record_count`.
    reported_records: u128,
    /// Triplet recipes used for this source: the configured recipes when any
    /// are set, otherwise the source's own defaults.
    triplet_recipes: Vec<TripletRecipe>,
}
218
219pub fn run_estimate_capacity<R, Resolve, Build, I>(
224 args_iter: I,
225 resolve_roots: Resolve,
226 build_sources: Build,
227) -> Result<(), Box<dyn Error>>
228where
229 Resolve: FnOnce(Vec<String>) -> Result<R, Box<dyn Error>>,
230 Build: FnOnce(&R) -> Vec<DynSource>,
231 I: Iterator<Item = String>,
232{
233 init_example_tracing();
234
235 let Some(cli) = parse_cli::<EstimateCapacityCli, _>(
236 std::iter::once("estimate_capacity".to_string()).chain(args_iter),
237 )?
238 else {
239 return Ok(());
240 };
241
242 let roots = resolve_roots(cli.source_roots)?;
243
244 let config = SamplerConfig {
245 seed: cli.seed,
246 split: cli.split,
247 ..SamplerConfig::default()
248 };
249
250 let sources = build_sources(&roots);
251
252 let mut inventories = Vec::new();
253 for source in &sources {
254 let recipes = if config.recipes.is_empty() {
255 source.default_triplet_recipes()
256 } else {
257 config.recipes.clone()
258 };
259 let reported_records = source.reported_record_count(&config).map_err(|err| {
260 format!(
261 "source '{}' failed to report exact record count: {err}",
262 source.id()
263 )
264 })?;
265 inventories.push(SourceInventory {
266 source_id: source.id().to_string(),
267 reported_records,
268 triplet_recipes: recipes,
269 });
270 }
271
272 let mut per_source_split_counts: HashMap<(String, SplitLabel), u128> = HashMap::new();
273 let mut split_record_counts: HashMap<SplitLabel, u128> = HashMap::new();
274
275 for source in &inventories {
276 let counts = split_counts_for_total(source.reported_records, cli.split);
277 for (label, count) in counts {
278 per_source_split_counts.insert((source.source_id.clone(), label), count);
279 *split_record_counts.entry(label).or_insert(0) += count;
280 }
281 }
282
283 let mut totals_by_split: HashMap<SplitLabel, CapacityTotals> = HashMap::new();
284 let mut totals_by_source_and_split: HashMap<(String, SplitLabel), CapacityTotals> =
285 HashMap::new();
286
287 for split_label in [SplitLabel::Train, SplitLabel::Validation, SplitLabel::Test] {
288 let mut totals = CapacityTotals::default();
289
290 for source in &inventories {
291 let source_split_records = per_source_split_counts
292 .get(&(source.source_id.clone(), split_label))
293 .copied()
294 .unwrap_or(0);
295
296 let triplet_recipes = &source.triplet_recipes;
297 let text_recipes = resolve_text_recipes_for_source(&config, triplet_recipes);
298
299 let capacity = estimate_source_split_capacity_from_counts(
300 source_split_records,
301 triplet_recipes,
302 &text_recipes,
303 );
304
305 totals_by_source_and_split.insert((source.source_id.clone(), split_label), capacity);
306
307 totals.triplets += capacity.triplets;
308 totals.effective_triplets += capacity.effective_triplets;
309 totals.pairs += capacity.pairs;
310 totals.text_samples += capacity.text_samples;
311 }
312
313 totals_by_split.insert(split_label, totals);
314 }
315
316 let min_nonzero_records_by_split: HashMap<SplitLabel, u128> =
317 [SplitLabel::Train, SplitLabel::Validation, SplitLabel::Test]
318 .into_iter()
319 .map(|split_label| {
320 let min_nonzero = inventories
321 .iter()
322 .filter_map(|source| {
323 per_source_split_counts
324 .get(&(source.source_id.clone(), split_label))
325 .copied()
326 })
327 .filter(|&records| records > 0)
328 .min()
329 .unwrap_or(0);
330 (split_label, min_nonzero)
331 })
332 .collect();
333
334 let min_nonzero_records_all_splits = inventories
335 .iter()
336 .map(|source| source.reported_records)
337 .filter(|&records| records > 0)
338 .min()
339 .unwrap_or(0);
340
341 println!("=== capacity estimate (length-only) ===");
342 println!("mode: metadata-only (no source.refresh calls)");
343 println!("classification: heuristic approximation (not exact)");
344 println!("split seed: {}", cli.seed);
345 println!(
346 "split ratios: train={:.4}, validation={:.4}, test={:.4}",
347 cli.split.train, cli.split.validation, cli.split.test
348 );
349 println!();
350
351 println!("[SOURCES]");
352 for source in &inventories {
353 println!(
354 " {} => reported records: {}",
355 source.source_id,
356 format_u128_with_commas(source.reported_records)
357 );
358 }
359 println!();
360
361 println!("[PER SOURCE BREAKDOWN]");
362 for source in &inventories {
363 println!(" {}", source.source_id);
364 let mut source_grand = CapacityTotals::default();
365 let mut source_total_records = 0u128;
366 for split_label in [SplitLabel::Train, SplitLabel::Validation, SplitLabel::Test] {
367 let split_records = per_source_split_counts
368 .get(&(source.source_id.clone(), split_label))
369 .copied()
370 .unwrap_or(0);
371 source_total_records = source_total_records.saturating_add(split_records);
372 let split_longest_records = inventories
373 .iter()
374 .map(|candidate| {
375 per_source_split_counts
376 .get(&(candidate.source_id.clone(), split_label))
377 .copied()
378 .unwrap_or(0)
379 })
380 .max()
381 .unwrap_or(0);
382 let totals = totals_by_source_and_split
383 .get(&(source.source_id.clone(), split_label))
384 .copied()
385 .unwrap_or_default();
386 source_grand.triplets += totals.triplets;
387 source_grand.effective_triplets += totals.effective_triplets;
388 source_grand.pairs += totals.pairs;
389 source_grand.text_samples += totals.text_samples;
390 println!(" [{:?}]", split_label);
391 println!(" records: {}", format_u128_with_commas(split_records));
392 println!(
393 " triplet combinations: {}",
394 format_u128_with_commas(totals.triplets)
395 );
396 println!(
397 " effective sampled triplets (p={}, k={}): {}",
398 EFFECTIVE_POSITIVES_PER_ANCHOR,
399 EFFECTIVE_NEGATIVES_PER_ANCHOR,
400 format_u128_with_commas(totals.effective_triplets)
401 );
402 println!(
403 " pair combinations: {}",
404 format_u128_with_commas(totals.pairs)
405 );
406 println!(
407 " text samples: {}",
408 format_u128_with_commas(totals.text_samples)
409 );
410 println!(
411 " replay factor vs longest source: {}",
412 format_replay_factor(split_longest_records, split_records)
413 );
414 println!(
415 " suggested proportional-size batch weight (0-1): {:.4}",
416 suggested_balancing_weight(split_longest_records, split_records)
417 );
418 let split_smallest_nonzero = min_nonzero_records_by_split
419 .get(&split_label)
420 .copied()
421 .unwrap_or(0);
422 println!(
423 " suggested small-source-boost batch weight (0-1): {:.4}",
424 suggested_oversampling_weight(split_smallest_nonzero, split_records)
425 );
426 println!();
427 }
428 let longest_source_total = inventories
429 .iter()
430 .map(|candidate| candidate.reported_records)
431 .max()
432 .unwrap_or(0);
433 println!(" [ALL SPLITS FOR SOURCE]");
434 println!(
435 " triplet combinations: {}",
436 format_u128_with_commas(source_grand.triplets)
437 );
438 println!(
439 " effective sampled triplets (p={}, k={}): {}",
440 EFFECTIVE_POSITIVES_PER_ANCHOR,
441 EFFECTIVE_NEGATIVES_PER_ANCHOR,
442 format_u128_with_commas(source_grand.effective_triplets)
443 );
444 println!(
445 " pair combinations: {}",
446 format_u128_with_commas(source_grand.pairs)
447 );
448 println!(
449 " text samples: {}",
450 format_u128_with_commas(source_grand.text_samples)
451 );
452 println!(
453 " replay factor vs longest source: {}",
454 format_replay_factor(longest_source_total, source_total_records)
455 );
456 println!(
457 " suggested proportional-size batch weight (0-1): {:.4}",
458 suggested_balancing_weight(longest_source_total, source_total_records)
459 );
460 println!(
461 " suggested small-source-boost batch weight (0-1): {:.4}",
462 suggested_oversampling_weight(min_nonzero_records_all_splits, source_total_records)
463 );
464 println!();
465 }
466
467 let mut grand = CapacityTotals::default();
468 for split_label in [SplitLabel::Train, SplitLabel::Validation, SplitLabel::Test] {
469 let record_count = split_record_counts.get(&split_label).copied().unwrap_or(0);
470 let totals = totals_by_split
471 .get(&split_label)
472 .copied()
473 .unwrap_or_default();
474
475 grand.triplets += totals.triplets;
476 grand.effective_triplets += totals.effective_triplets;
477 grand.pairs += totals.pairs;
478 grand.text_samples += totals.text_samples;
479
480 println!("[{:?}]", split_label);
481 println!(" records: {}", format_u128_with_commas(record_count));
482 println!(
483 " triplet combinations: {}",
484 format_u128_with_commas(totals.triplets)
485 );
486 println!(
487 " effective sampled triplets (p={}, k={}): {}",
488 EFFECTIVE_POSITIVES_PER_ANCHOR,
489 EFFECTIVE_NEGATIVES_PER_ANCHOR,
490 format_u128_with_commas(totals.effective_triplets)
491 );
492 println!(
493 " pair combinations: {}",
494 format_u128_with_commas(totals.pairs)
495 );
496 println!(
497 " text samples: {}",
498 format_u128_with_commas(totals.text_samples)
499 );
500 println!();
501 }
502
503 println!("[ALL SPLITS TOTAL]");
504 println!(
505 " triplet combinations: {}",
506 format_u128_with_commas(grand.triplets)
507 );
508 println!(
509 " effective sampled triplets (p={}, k={}): {}",
510 EFFECTIVE_POSITIVES_PER_ANCHOR,
511 EFFECTIVE_NEGATIVES_PER_ANCHOR,
512 format_u128_with_commas(grand.effective_triplets)
513 );
514 println!(
515 " pair combinations: {}",
516 format_u128_with_commas(grand.pairs)
517 );
518 println!(
519 " text samples: {}",
520 format_u128_with_commas(grand.text_samples)
521 );
522 println!();
523 println!(
524 "Note: counts are heuristic, length-based estimates from source-reported totals and recipe structure. They are approximate, not exact, and assume anchor-positive pairs=records (one positive per anchor by default), negatives=source_records_in_split-1 (anchor excluded as its own negative), and at most one chunk/window realization per sample. In real-world chunked sampling, practical combinations are often higher, so treat this as a floor-like baseline."
525 );
526 println!();
527 println!(
528 "Effective sampled triplets apply a bounded training assumption: effective_triplets = records * p * k per triplet recipe, with defaults p={} positives per anchor and k={} negatives per anchor.",
529 EFFECTIVE_POSITIVES_PER_ANCHOR, EFFECTIVE_NEGATIVES_PER_ANCHOR
530 );
531 println!();
532 println!(
533 "Oversample loops are not inferred from this static report. To measure true oversampling (how many times sampling loops through the combination space), use observed sampled draw counts from an actual run."
534 );
535 println!();
536 println!(
537 "Suggested proportional-size batch weight (0-1) is source/max_source by record count: 1.0 for the largest source in scope, smaller values for smaller sources."
538 );
539 println!();
540 println!(
541 "Suggested small-source-boost batch weight (0-1) is min_nonzero_source/source by record count: 1.0 for the smallest non-zero source in scope, smaller values for larger sources."
542 );
543 println!();
544 println!(
545 "When passed to next_*_batch_with_weights, higher weight means that source is sampled more often relative to lower-weight sources."
546 );
547
548 Ok(())
549}
550
551pub fn run_multi_source_demo<R, Resolve, Build, I>(
556 args_iter: I,
557 resolve_roots: Resolve,
558 build_sources: Build,
559) -> Result<(), Box<dyn Error>>
560where
561 Resolve: FnOnce(Vec<String>) -> Result<R, Box<dyn Error>>,
562 Build: FnOnce(&R) -> Vec<DynSource>,
563 I: Iterator<Item = String>,
564{
565 init_example_tracing();
566
567 let Some(cli) = parse_cli::<MultiSourceDemoCli, _>(
568 std::iter::once("multi_source_demo".to_string()).chain(args_iter),
569 )?
570 else {
571 return Ok(());
572 };
573
574 let roots = resolve_roots(cli.source_roots)?;
575
576 let mut config = SamplerConfig::default();
577 config.seed = cli.seed.unwrap_or(config.seed);
578 config.batch_size = cli.batch_size;
579 config.ingestion_max_records = cli.ingestion_max_records;
580 config.chunking = Default::default();
581 let selected_split = cli.split.map(Into::into).unwrap_or(SplitLabel::Train);
582 config.split = SplitRatios::default();
583 config.allowed_splits = vec![selected_split];
584 let chunking = config.chunking.clone();
585 let config_snapshot = MultiSourceDemoConfigSnapshot {
586 seed: config.seed,
587 batch_size: config.batch_size,
588 ingestion_max_records: config.ingestion_max_records,
589 split: selected_split,
590 split_ratios: config.split,
591 max_window_tokens: config.chunking.max_window_tokens,
592 overlap_tokens: config.chunking.overlap_tokens.clone(),
593 summary_fallback_tokens: config.chunking.summary_fallback_tokens,
594 };
595
596 let split_store_path = if let Some(path) = cli.split_store_path {
597 path
598 } else {
599 managed_demo_split_store_path().map_err(|err| {
600 Box::<dyn Error>::from(format!("failed to resolve demo split-store path: {err}"))
601 })?
602 };
603
604 if cli.reset && split_store_path.exists() {
605 std::fs::remove_file(&split_store_path).map_err(|err| {
606 Box::<dyn Error>::from(format!(
607 "failed to remove split store '{}': {err}",
608 split_store_path.display()
609 ))
610 })?;
611 println!("Reset: removed {}", split_store_path.display());
612 }
613 println!(
614 "Persisting split assignments and epoch state to {}",
615 split_store_path.display()
616 );
617 let sources = build_sources(&roots);
618 let split_store = Arc::new(FileSplitStore::open(&split_store_path, config.split, 99)?);
619 let sampler = TripletSampler::new(config, split_store.clone());
620 for source in sources {
621 sampler.register_source(source);
622 }
623
624 if cli.show_pair_samples {
625 match sampler.next_pair_batch(selected_split) {
626 Ok(pair_batch) => {
627 if pair_batch.pairs.is_empty() {
628 println!("Pair sampling produced no results.");
629 } else {
630 print_pair_batch(&chunking, &pair_batch, split_store.as_ref());
631 }
632 sampler.save_sampler_state(None)?;
633 }
634 Err(SamplerError::Exhausted(name)) => {
635 eprintln!(
636 "Pair sampler exhausted recipe '{}'. Ensure both positive and negative examples exist.",
637 name
638 );
639 }
640 Err(err) => return Err(err.into()),
641 }
642 } else if cli.show_text_samples {
643 match sampler.next_text_batch(selected_split) {
644 Ok(text_batch) => {
645 if text_batch.samples.is_empty() {
646 println!(
647 "Text sampling produced no results. Ensure each source has eligible sections."
648 );
649 } else {
650 print_text_batch(&chunking, &text_batch, split_store.as_ref());
651 }
652 sampler.save_sampler_state(None)?;
653 }
654 Err(SamplerError::Exhausted(name)) => {
655 eprintln!(
656 "Text sampler exhausted selector '{}'. Ensure matching sections exist.",
657 name
658 );
659 }
660 Err(err) => return Err(err.into()),
661 }
662 } else if cli.list_text_recipes {
663 let recipes = sampler.text_recipes();
664 if recipes.is_empty() {
665 println!(
666 "No text recipes registered. Ensure your sources expose triplet selectors or configure text_recipes explicitly."
667 );
668 } else {
669 print_text_recipes(&recipes);
670 }
671 } else if let Some(batch_count) = cli.batches {
672 print_demo_config(&config_snapshot);
673 println!("=== benchmark: {} triplet batches ===", batch_count);
674
675 #[cfg(feature = "extended-metrics")]
677 let mut source_metrics: SourceMetricsMap = HashMap::new();
678
679 for i in 0..batch_count {
680 let t0 = Instant::now();
681 match sampler.next_triplet_batch(selected_split) {
682 Ok(batch) => {
683 let elapsed = t0.elapsed();
684 let n = batch.triplets.len();
685 println!(
686 "batch {:>4} triplets={:<4} elapsed={:>8.2}ms per_triplet={:.2}ms",
687 i + 1,
688 n,
689 elapsed.as_secs_f64() * 1000.0,
690 if n > 0 {
691 elapsed.as_secs_f64() * 1000.0 / n as f64
692 } else {
693 0.0
694 },
695 );
696 #[cfg(feature = "extended-metrics")]
697 {
698 use triplets_core::metrics::lexical_similarity_scores;
699 for triplet in &batch.triplets {
700 let (pj, pc) = lexical_similarity_scores(
701 &triplet.anchor.text,
702 &triplet.positive.text,
703 );
704 let (nj, nc) = lexical_similarity_scores(
705 &triplet.anchor.text,
706 &triplet.negative.text,
707 );
708 let proximity =
709 chunk_proximity_score(&triplet.anchor, &triplet.positive);
710 let source = extract_source(&triplet.anchor.record_id);
711 source_metrics
712 .entry(source)
713 .or_default()
714 .push((pj, pc, nj, nc, proximity));
715 }
716 }
717 }
718 Err(SamplerError::Exhausted(name)) => {
719 println!(
720 "batch {:>4} exhausted recipe '{}' — stopping early",
721 i + 1,
722 name
723 );
724 break;
725 }
726 Err(err) => return Err(err.into()),
727 }
728 }
729
730 sampler.save_sampler_state(None)?;
731
732 #[cfg(feature = "extended-metrics")]
733 if !source_metrics.is_empty() {
734 println!();
735 print_metric_summary(&source_metrics);
736 }
737
738 #[cfg(all(feature = "extended-metrics", feature = "bm25-mining"))]
739 {
740 let (fallback, total) = sampler.bm25_fallback_stats();
741 if total > 0 {
742 let pct = fallback as f64 / total as f64 * 100.0;
743 println!("bm25 fallback rate : {}/{} ({:.1}%)", fallback, total, pct);
744 }
745 }
746 } else {
747 match sampler.next_triplet_batch(selected_split) {
748 Ok(triplet_batch) => {
749 if triplet_batch.triplets.is_empty() {
750 println!(
751 "Triplet sampling produced no results. Ensure multiple records per source exist."
752 );
753 } else {
754 print_triplet_batch(&chunking, &triplet_batch, split_store.as_ref());
755 }
756 sampler.save_sampler_state(None)?;
757 #[cfg(all(feature = "extended-metrics", feature = "bm25-mining"))]
758 {
759 let (fallback, total) = sampler.bm25_fallback_stats();
760 if total > 0 {
761 let pct = fallback as f64 / total as f64 * 100.0;
762 println!("bm25 fallback rate : {}/{} ({:.1}%)", fallback, total, pct);
763 }
764 }
765 }
766 Err(SamplerError::Exhausted(name)) => {
767 eprintln!(
768 "Triplet sampler exhausted recipe '{}'. Ensure both positive and negative examples exist.",
769 name
770 );
771 }
772 Err(err) => return Err(err.into()),
773 }
774 }
775
776 Ok(())
777}
778
/// Copy of the sampler settings that `print_demo_config` displays, taken
/// before the config is moved into the sampler.
struct MultiSourceDemoConfigSnapshot {
    /// Sampler seed in effect (CLI override or config default).
    seed: u64,
    /// Batch size in effect.
    batch_size: usize,
    /// Ingestion buffer target in effect.
    ingestion_max_records: usize,
    /// Split selected for sampling.
    split: SplitLabel,
    /// Train/validation/test ratios from the sampler config.
    split_ratios: SplitRatios,
    /// `max_window_tokens` from the chunking settings.
    max_window_tokens: usize,
    /// `overlap_tokens` from the chunking settings.
    overlap_tokens: Vec<usize>,
    /// `summary_fallback_tokens` from the chunking settings; printed with the
    /// hint that 0 means disabled.
    summary_fallback_tokens: usize,
}
789
790fn print_demo_config(cfg: &MultiSourceDemoConfigSnapshot) {
791 let overlaps: Vec<String> = cfg.overlap_tokens.iter().map(|t| t.to_string()).collect();
792 println!("=== sampler config ===");
793 println!("seed : {}", cfg.seed);
794 println!("batch_size : {}", cfg.batch_size);
795 println!("ingestion_max_records: {}", cfg.ingestion_max_records);
796 println!("split : {:?}", cfg.split);
797 println!(
798 "split_ratios : train={:.2} val={:.2} test={:.2}",
799 cfg.split_ratios.train, cfg.split_ratios.validation, cfg.split_ratios.test
800 );
801 println!("max_window_tokens : {}", cfg.max_window_tokens);
802 println!("overlap_tokens : [{}]", overlaps.join(", "));
803 println!(
804 "summary_fallback : {} tokens (0 = disabled)",
805 cfg.summary_fallback_tokens
806 );
807 println!();
808}
809
810fn default_ingestion_max_records() -> usize {
811 SamplerConfig::default().ingestion_max_records
812}
813
/// Parses `raw` as a `usize` that must be greater than zero.
///
/// `flag` is the flag name used in the error text.
///
/// # Errors
///
/// Returns a message when `raw` is not a valid `usize` or when it is zero.
fn parse_positive_usize_flag(raw: &str, flag: &str) -> Result<usize, String> {
    match raw.parse::<usize>() {
        Ok(0) => Err(format!("{} must be greater than zero", flag)),
        Ok(value) => Ok(value),
        Err(_) => Err(format!(
            "Could not parse {} value '{}' as a positive integer",
            flag, raw
        )),
    }
}
826
827fn parse_batch_size(raw: &str) -> Result<usize, String> {
828 parse_positive_usize_flag(raw, "--batch-size")
829}
830
831fn parse_ingestion_max_records(raw: &str) -> Result<usize, String> {
832 parse_positive_usize_flag(raw, "--ingestion-max-records")
833}
834
835fn parse_batch_count(raw: &str) -> Result<usize, String> {
836 parse_positive_usize_flag(raw, "--batches")
837}
838
/// Proportional-size weight: `source_baseline / max_baseline`, clamped to
/// `[0, 1]`. Returns `0.0` when either count is zero.
fn suggested_balancing_weight(max_baseline: u128, source_baseline: u128) -> f32 {
    match (max_baseline, source_baseline) {
        (0, _) | (_, 0) => 0.0,
        (largest, own) => {
            let ratio = own as f64 / largest as f64;
            ratio.clamp(0.0, 1.0) as f32
        }
    }
}
845
/// Small-source-boost weight: `min_nonzero_baseline / source_baseline`, clamped
/// to `[0, 1]`. Returns `0.0` when either count is zero.
fn suggested_oversampling_weight(min_nonzero_baseline: u128, source_baseline: u128) -> f32 {
    match (min_nonzero_baseline, source_baseline) {
        (0, _) | (_, 0) => 0.0,
        (smallest, own) => {
            let ratio = smallest as f64 / own as f64;
            ratio.clamp(0.0, 1.0) as f32
        }
    }
}
852
853fn parse_cli<T, I>(args: I) -> Result<Option<T>, Box<dyn Error>>
854where
855 T: Parser,
856 I: IntoIterator,
857 I::Item: Into<std::ffi::OsString> + Clone,
858{
859 match T::try_parse_from(args) {
860 Ok(cli) => Ok(Some(cli)),
861 Err(err) => match err.kind() {
862 ErrorKind::DisplayHelp | ErrorKind::DisplayVersion => {
863 err.print()?;
864 Ok(None)
865 }
866 _ => Err(err.into()),
867 },
868 }
869}
870
871fn parse_split_ratios_arg(raw: &str) -> Result<SplitRatios, String> {
872 let parts: Vec<&str> = raw.split(',').collect();
873 if parts.len() != 3 {
874 return Err("--split-ratios expects exactly 3 comma-separated values".to_string());
875 }
876 let train = parts[0]
877 .trim()
878 .parse::<f32>()
879 .map_err(|_| format!("invalid train ratio '{}': must be a float", parts[0].trim()))?;
880 let validation = parts[1].trim().parse::<f32>().map_err(|_| {
881 format!(
882 "invalid validation ratio '{}': must be a float",
883 parts[1].trim()
884 )
885 })?;
886 let test = parts[2]
887 .trim()
888 .parse::<f32>()
889 .map_err(|_| format!("invalid test ratio '{}': must be a float", parts[2].trim()))?;
890 let ratios = SplitRatios {
891 train,
892 validation,
893 test,
894 };
895 let sum = ratios.train + ratios.validation + ratios.test;
896 if (sum - 1.0).abs() > 1e-5 {
897 return Err(format!(
898 "split ratios must sum to 1.0, got {:.6} (train={}, validation={}, test={})",
899 sum, ratios.train, ratios.validation, ratios.test
900 ));
901 }
902 if ratios.train < 0.0 || ratios.validation < 0.0 || ratios.test < 0.0 {
903 return Err("split ratios must be non-negative".to_string());
904 }
905 Ok(ratios)
906}
907
908fn print_triplet_batch(
909 strategy: &ChunkingStrategy,
910 batch: &TripletBatch,
911 split_store: &impl SplitStore,
912) {
913 println!("=== triplet batch ===");
914 for (idx, triplet) in batch.triplets.iter().enumerate() {
915 println!("--- triplet #{} ---", idx);
916 println!("recipe : {}", triplet.recipe);
917 println!("sample_weight: {:.4}", triplet.weight);
918 if let Some(instr) = &triplet.instruction {
919 println!("instruction shown to model:");
920 println!("{instr}");
921 println!();
922 }
923 let pos_proximity = chunk_proximity_score(&triplet.anchor, &triplet.positive);
924 let pos_distance = window_chunk_distance(&triplet.anchor, &triplet.positive);
925 #[cfg(feature = "extended-metrics")]
926 let (pos_sim, neg_sim) = {
927 use triplets_core::metrics::lexical_similarity_scores;
928 (
929 Some(lexical_similarity_scores(
930 &triplet.anchor.text,
931 &triplet.positive.text,
932 )),
933 Some(lexical_similarity_scores(
934 &triplet.anchor.text,
935 &triplet.negative.text,
936 )),
937 )
938 };
939 #[cfg(not(feature = "extended-metrics"))]
940 let (pos_sim, neg_sim): (Option<(f32, f32)>, Option<(f32, f32)>) = (None, None);
941 print_chunk_block(
942 "ANCHOR",
943 &triplet.anchor,
944 strategy,
945 split_store,
946 None,
947 None,
948 None,
949 );
950 print_chunk_block(
951 "POSITIVE",
952 &triplet.positive,
953 strategy,
954 split_store,
955 pos_sim,
956 Some(pos_proximity),
957 pos_distance,
958 );
959 print_chunk_block(
960 "NEGATIVE",
961 &triplet.negative,
962 strategy,
963 split_store,
964 neg_sim,
965 None,
966 None,
967 );
968 }
969 print_source_summary(
970 "triplet anchors",
971 batch
972 .triplets
973 .iter()
974 .map(|triplet| triplet.anchor.record_id.as_str()),
975 );
976 print_recipe_context_by_source(
977 "triplet recipes by source",
978 batch
979 .triplets
980 .iter()
981 .map(|triplet| (triplet.anchor.record_id.as_str(), triplet.recipe.as_str())),
982 );
983}
984
985fn print_text_batch(strategy: &ChunkingStrategy, batch: &TextBatch, split_store: &impl SplitStore) {
986 println!("=== text batch ===");
987 for (idx, sample) in batch.samples.iter().enumerate() {
988 println!("--- sample #{} ---", idx);
989 println!("recipe : {}", sample.recipe);
990 println!("sample_weight: {:.4}", sample.weight);
991 if let Some(instr) = &sample.instruction {
992 println!("instruction shown to model:");
993 println!("{instr}");
994 println!();
995 }
996 print_chunk_block(
997 "TEXT",
998 &sample.chunk,
999 strategy,
1000 split_store,
1001 None,
1002 None,
1003 None,
1004 );
1005 }
1006 print_source_summary(
1007 "text samples",
1008 batch
1009 .samples
1010 .iter()
1011 .map(|sample| sample.chunk.record_id.as_str()),
1012 );
1013 print_recipe_context_by_source(
1014 "text recipes by source",
1015 batch
1016 .samples
1017 .iter()
1018 .map(|sample| (sample.chunk.record_id.as_str(), sample.recipe.as_str())),
1019 );
1020}
1021
1022fn print_pair_batch(
1023 strategy: &ChunkingStrategy,
1024 batch: &SampleBatch,
1025 split_store: &impl SplitStore,
1026) {
1027 println!("=== pair batch ===");
1028 for (idx, pair) in batch.pairs.iter().enumerate() {
1029 println!("--- pair #{} ---", idx);
1030 println!("recipe : {}", pair.recipe);
1031 println!("label : {:?}", pair.label);
1032 if let Some(reason) = &pair.reason {
1033 println!("reason : {}", reason);
1034 }
1035 print_chunk_block(
1036 "ANCHOR",
1037 &pair.anchor,
1038 strategy,
1039 split_store,
1040 None,
1041 None,
1042 None,
1043 );
1044 print_chunk_block(
1045 "OTHER",
1046 &pair.positive,
1047 strategy,
1048 split_store,
1049 None,
1050 None,
1051 None,
1052 );
1053 }
1054 print_source_summary(
1055 "pair anchors",
1056 batch
1057 .pairs
1058 .iter()
1059 .map(|pair| pair.anchor.record_id.as_str()),
1060 );
1061 print_recipe_context_by_source(
1062 "pair recipes by source",
1063 batch
1064 .pairs
1065 .iter()
1066 .map(|pair| (pair.anchor.record_id.as_str(), pair.recipe.as_str())),
1067 );
1068}
1069
1070fn print_text_recipes(recipes: &[TextRecipe]) {
1071 println!("=== available text recipes ===");
1072 for recipe in recipes {
1073 println!(
1074 "- {} (weight: {:.3}) selector={:?}",
1075 recipe.name, recipe.weight, recipe.selector
1076 );
1077 if let Some(instr) = &recipe.instruction {
1078 println!(" instruction: {}", instr);
1079 }
1080 }
1081}
1082
1083#[cfg(feature = "extended-metrics")]
/// Computes the arithmetic mean and the median of `vals`.
///
/// The slice is sorted in place (ascending) as a side effect. For an even
/// number of elements the median is the average of the two middle values.
///
/// Returns `(mean, median)`. An empty slice yields `(NaN, NaN)` instead of
/// panicking, because neither statistic is defined without data.
fn metric_mean_median(vals: &mut [f32]) -> (f32, f32) {
    if vals.is_empty() {
        return (f32::NAN, f32::NAN);
    }
    let count = vals.len();
    // The mean is taken before sorting so the summation order matches the
    // caller's element order.
    let mean = vals.iter().sum::<f32>() / count as f32;
    // `total_cmp` is a real total order, so NaN inputs can never hand the sort
    // an inconsistent comparator.
    vals.sort_unstable_by(f32::total_cmp);
    let mid = count / 2;
    let median = if count % 2 == 1 {
        vals[mid]
    } else {
        (vals[mid - 1] + vals[mid]) / 2.0
    };
    (mean, median)
}
1094
1095#[cfg(feature = "extended-metrics")]
1096fn print_metric_summary(source_data: &SourceMetricsMap) {
1097 let total: usize = source_data.values().map(|v| v.len()).sum();
1098 let n_sources = source_data.len();
1099 println!(
1100 "=== extended metrics summary ({} triplets, {} {}) ===",
1101 total,
1102 n_sources,
1103 if n_sources == 1 { "source" } else { "sources" }
1104 );
1105
1106 fn metric_pair(entries: &[MetricEntry], pos_idx: usize, neg_idx: usize) -> [(f32, f32); 2] {
1108 let extract = |idx: usize| -> Vec<f32> {
1109 entries
1110 .iter()
1111 .map(|e| match idx {
1112 0 => e.0,
1113 1 => e.1,
1114 2 => e.2,
1115 3 => e.3,
1116 _ => e.4,
1117 })
1118 .collect()
1119 };
1120 let mut pos_vals = extract(pos_idx);
1121 let mut neg_vals = extract(neg_idx);
1122 [
1123 metric_mean_median(&mut pos_vals),
1124 metric_mean_median(&mut neg_vals),
1125 ]
1126 }
1127
1128 fn print_metric_section(
1129 label: &str,
1130 sources: &[&String],
1131 source_data: &SourceMetricsMap,
1132 pos_idx: usize,
1133 neg_idx: usize,
1134 total: usize,
1135 n_sources: usize,
1136 ) {
1137 const SEP: usize = 83;
1138 println!();
1139 println!("[{}]", label);
1140 println!(
1141 "{:<24} {:>5} {:<16} {:<16} {:<16}",
1142 "source", "n", "positive", "negative", "gap (pos\u{2212}neg)"
1143 );
1144 println!(
1145 "{:<24} {:>5} {:<16} {:<16} {:<16}",
1146 "", "", "mean / median", "mean / median", "mean / median"
1147 );
1148 println!("{}", "-".repeat(SEP));
1149 for source in sources {
1150 let entries = &source_data[*source];
1151 let [pos, neg] = metric_pair(entries, pos_idx, neg_idx);
1152 let gap_mean = pos.0 - neg.0;
1153 let gap_med = pos.1 - neg.1;
1154 println!(
1155 "{:<24} {:>5} {:.3} / {:.3} {:.3} / {:.3} {:+.3} / {:+.3}",
1156 source,
1157 entries.len(),
1158 pos.0,
1159 pos.1,
1160 neg.0,
1161 neg.1,
1162 gap_mean,
1163 gap_med,
1164 );
1165 }
1166 if n_sources > 1 {
1167 let all: Vec<MetricEntry> = source_data.values().flatten().copied().collect();
1168 let [pos, neg] = metric_pair(&all, pos_idx, neg_idx);
1169 let gap_mean = pos.0 - neg.0;
1170 let gap_med = pos.1 - neg.1;
1171 println!("{}", "-".repeat(SEP));
1172 println!(
1173 "{:<24} {:>5} {:.3} / {:.3} {:.3} / {:.3} {:+.3} / {:+.3}",
1174 "ALL", total, pos.0, pos.1, neg.0, neg.1, gap_mean, gap_med,
1175 );
1176 }
1177 }
1178
1179 fn print_single_metric_section(
1180 label: &str,
1181 sources: &[&String],
1182 source_data: &SourceMetricsMap,
1183 idx: usize,
1184 total: usize,
1185 n_sources: usize,
1186 ) {
1187 const SEP: usize = 58;
1188 println!();
1189 println!("[{}]", label);
1190 println!("{:<24} {:>5} {:<16}", "source", "n", "mean / median");
1191 println!("{}", "-".repeat(SEP));
1192 for source in sources {
1193 let entries = &source_data[*source];
1194 let mut vals: Vec<f32> = entries
1195 .iter()
1196 .map(|e| match idx {
1197 0 => e.0,
1198 1 => e.1,
1199 2 => e.2,
1200 3 => e.3,
1201 _ => e.4,
1202 })
1203 .collect();
1204 let (mean, median) = metric_mean_median(&mut vals);
1205 println!(
1206 "{:<24} {:>5} {:.3} / {:.3}",
1207 source,
1208 entries.len(),
1209 mean,
1210 median,
1211 );
1212 }
1213 if n_sources > 1 {
1214 let mut all: Vec<f32> = source_data
1215 .values()
1216 .flatten()
1217 .map(|e| match idx {
1218 0 => e.0,
1219 1 => e.1,
1220 2 => e.2,
1221 3 => e.3,
1222 _ => e.4,
1223 })
1224 .collect();
1225 let (mean, median) = metric_mean_median(&mut all);
1226 println!("{}", "-".repeat(SEP));
1227 println!("{:<24} {:>5} {:.3} / {:.3}", "ALL", total, mean, median);
1228 }
1229 }
1230
1231 let mut sources: Vec<&String> = source_data.keys().collect();
1232 sources.sort();
1233
1234 print_metric_section(
1235 "jaccard \u{2194} anchor",
1236 &sources,
1237 source_data,
1238 0,
1239 2,
1240 total,
1241 n_sources,
1242 );
1243 print_metric_section(
1244 "byte-cos \u{2194} anchor",
1245 &sources,
1246 source_data,
1247 1,
1248 3,
1249 total,
1250 n_sources,
1251 );
1252 print_single_metric_section(
1253 "anchor-positive proximity",
1254 &sources,
1255 source_data,
1256 4,
1257 total,
1258 n_sources,
1259 );
1260 println!();
1261}
1262
/// Debug-oriented rendering helper for chunk values.
trait ChunkDebug {
    /// Returns a human-readable description of the chunk's view.
    fn view_name(&self) -> String;
}
1266
1267impl ChunkDebug for RecordChunk {
1268 fn view_name(&self) -> String {
1269 match &self.view {
1270 ChunkView::Window {
1271 index,
1272 span,
1273 overlap,
1274 } => format!(
1275 "window#index={} span={} overlap={} tokens={}",
1276 index, span, overlap, self.tokens_estimate
1277 ),
1278 ChunkView::SummaryFallback { strategy, .. } => {
1279 format!("summary:{} tokens={}", strategy, self.tokens_estimate)
1280 }
1281 }
1282 }
1283}
1284
1285fn print_chunk_block(
1286 title: &str,
1287 chunk: &RecordChunk,
1288 strategy: &ChunkingStrategy,
1289 split_store: &impl SplitStore,
1290 anchor_sim: Option<(f32, f32)>,
1291 ap_proximity: Option<f32>,
1292 ap_distance: Option<f32>,
1293) {
1294 let chunk_weight = chunk_weight(strategy, chunk);
1295 let split = split_store
1296 .label_for(&chunk.record_id)
1297 .map(|label| format!("{:?}", label))
1298 .unwrap_or_else(|| "Unknown".to_string());
1299 println!("--- {} ---", title);
1300 println!("split : {}", split);
1301 println!("view : {}", chunk.view_name());
1302 println!("chunk_weight : {:.4}", chunk_weight);
1303 println!("record_id : {}", chunk.record_id);
1304 println!("section_idx : {}", chunk.section_idx);
1305 println!("token_est : {}", chunk.tokens_estimate);
1306 if let Some(proximity) = ap_proximity {
1307 println!("a<->p proximity : {:.4}", proximity);
1308 }
1309 if let Some(distance) = ap_distance {
1310 println!("a<->p distance : {:.4}", distance);
1311 }
1312 if let Some((j, c)) = anchor_sim {
1313 println!("jaccard(↔a) : {:.4} byte-cos(↔a): {:.4}", j, c);
1314 }
1315 if !chunk.kvp_meta.is_empty() {
1316 let mut kvp_keys: Vec<&String> = chunk.kvp_meta.keys().collect();
1317 kvp_keys.sort();
1318 println!("kvp_meta :");
1319 for k in kvp_keys {
1320 let vals = &chunk.kvp_meta[k];
1321 let display = if vals.len() > 1 {
1322 format!("{} ({} variations)", vals[0], vals.len())
1323 } else {
1324 vals[0].clone()
1325 };
1326 println!("\t{}: {}", k, display);
1327 }
1328 }
1329 println!("model_input (exact text sent to the model):");
1330 println!("<<< BEGIN MODEL TEXT >>>");
1331 println!("{}", chunk.text);
1332 println!("<<< END MODEL TEXT >>>");
1333 println!();
1334}
1335
1336fn print_source_summary<'a, I>(label: &str, ids: I)
1337where
1338 I: Iterator<Item = &'a str>,
1339{
1340 let mut counts: HashMap<SourceId, usize> = HashMap::new();
1341 for id in ids {
1342 let source = extract_source(id);
1343 *counts.entry(source).or_insert(0) += 1;
1344 }
1345 if counts.is_empty() {
1346 return;
1347 }
1348 let skew = source_skew(&counts);
1349 let mut entries: Vec<(String, usize)> = counts.into_iter().collect();
1350 entries.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
1351 println!("--- {} by source ---", label);
1352 if let Some(skew) = skew {
1353 for entry in &skew.per_source {
1354 println!(
1355 "{}: count={} share={:.2}",
1356 entry.source, entry.count, entry.share
1357 );
1358 }
1359 println!(
1360 "skew: sources={} total={} min={} max={} mean={:.2} ratio={:.2}",
1361 skew.sources, skew.total, skew.min, skew.max, skew.mean, skew.ratio
1362 );
1363 }
1364}
1365
1366fn print_recipe_context_by_source<'a, I>(label: &str, entries: I)
1367where
1368 I: Iterator<Item = (&'a str, &'a str)>,
1369{
1370 let mut counts: HashMap<SourceId, HashMap<String, usize>> = HashMap::new();
1371 for (record_id, recipe) in entries {
1372 let source = extract_source(record_id);
1373 let entry = counts
1374 .entry(source)
1375 .or_default()
1376 .entry(recipe.to_string())
1377 .or_insert(0);
1378 *entry += 1;
1379 }
1380 if counts.is_empty() {
1381 return;
1382 }
1383 let mut sources: Vec<(SourceId, HashMap<String, usize>)> = counts.into_iter().collect();
1384 sources.sort_by(|a, b| a.0.cmp(&b.0));
1385 println!("--- {} ---", label);
1386 for (source, recipes) in sources {
1387 println!("{source}");
1388 let mut entries: Vec<(String, usize)> = recipes.into_iter().collect();
1389 entries.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
1390 for (recipe, count) in entries {
1391 println!(" - {recipe}={count}");
1392 }
1393 }
1394}
1395
1396fn extract_source(record_id: &str) -> SourceId {
1397 record_id
1398 .split_once("::")
1399 .map(|(source, _)| source.to_string())
1400 .unwrap_or_else(|| "unknown".to_string())
1401}
1402
1403#[cfg(test)]
1404mod tests {
1405 use super::*;
1406 use chrono::{TimeZone, Utc};
1407 use tempfile::tempdir;
1408 use triplets_core::DataRecord;
1409 use triplets_core::DeterministicSplitStore;
1410 use triplets_core::data::{QualityScore, RecordSection, SectionRole};
1411 use triplets_core::source::{SourceCursor, SourceSnapshot};
1412 use triplets_core::utils::make_section;
1413
1414 fn empty_dyn_sources(_: &()) -> Vec<DynSource> {
1415 Vec::new()
1416 }
1417
    /// Root-resolver stub: ignores its arguments and always succeeds with `()`.
    fn ok_unit_roots(_: Vec<String>) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1421
1422 fn error_unit_roots(_: Vec<String>) -> Result<(), Box<dyn Error>> {
1423 Err("root-resolution-error".into())
1424 }
1425
1426 struct ErrorRefreshSource {
1427 id: String,
1428 }
1429
1430 impl DataSource for ErrorRefreshSource {
1431 fn id(&self) -> &str {
1432 &self.id
1433 }
1434
1435 fn refresh(
1436 &self,
1437 _config: &SamplerConfig,
1438 _cursor: Option<&SourceCursor>,
1439 _limit: Option<usize>,
1440 ) -> Result<SourceSnapshot, SamplerError> {
1441 Err(SamplerError::SourceUnavailable {
1442 source_id: self.id.clone(),
1443 reason: "simulated refresh failure".to_string(),
1444 })
1445 }
1446
1447 fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
1448 Ok(1)
1449 }
1450
1451 fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
1452 vec![default_recipe("error_refresh_recipe")]
1453 }
1454 }
1455
1456 struct TestSource {
1458 id: String,
1459 count: Option<u128>,
1460 recipes: Vec<TripletRecipe>,
1461 }
1462
1463 impl DataSource for TestSource {
1464 fn id(&self) -> &str {
1465 &self.id
1466 }
1467
1468 fn refresh(
1469 &self,
1470 _config: &SamplerConfig,
1471 _cursor: Option<&SourceCursor>,
1472 _limit: Option<usize>,
1473 ) -> Result<SourceSnapshot, SamplerError> {
1474 Ok(SourceSnapshot {
1475 records: Vec::new(),
1476 cursor: SourceCursor {
1477 last_seen: Utc::now(),
1478 revision: 0,
1479 },
1480 })
1481 }
1482
1483 fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
1484 self.count.ok_or_else(|| SamplerError::SourceInconsistent {
1485 source_id: self.id.clone(),
1486 details: "test source has no configured exact count".to_string(),
1487 })
1488 }
1489
1490 fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
1491 self.recipes.clone()
1492 }
1493 }
1494
1495 struct ConfigRequiredSource {
1496 id: String,
1497 expected_seed: u64,
1498 }
1499
1500 impl DataSource for ConfigRequiredSource {
1501 fn id(&self) -> &str {
1502 &self.id
1503 }
1504
1505 fn refresh(
1506 &self,
1507 _config: &SamplerConfig,
1508 _cursor: Option<&SourceCursor>,
1509 _limit: Option<usize>,
1510 ) -> Result<SourceSnapshot, SamplerError> {
1511 Ok(SourceSnapshot {
1512 records: Vec::new(),
1513 cursor: SourceCursor {
1514 last_seen: Utc::now(),
1515 revision: 0,
1516 },
1517 })
1518 }
1519
1520 fn reported_record_count(&self, config: &SamplerConfig) -> Result<u128, SamplerError> {
1521 if config.seed == self.expected_seed {
1522 Ok(1)
1523 } else {
1524 Err(SamplerError::SourceInconsistent {
1525 source_id: self.id.clone(),
1526 details: format!(
1527 "expected sampler seed {} but got {}",
1528 self.expected_seed, config.seed
1529 ),
1530 })
1531 }
1532 }
1533
1534 fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
1535 Vec::new()
1536 }
1537 }
1538
1539 struct FixtureSource {
1540 id: String,
1541 records: Vec<DataRecord>,
1542 recipes: Vec<TripletRecipe>,
1543 }
1544
1545 impl DataSource for FixtureSource {
1546 fn id(&self) -> &str {
1547 &self.id
1548 }
1549
1550 fn refresh(
1551 &self,
1552 _config: &SamplerConfig,
1553 _cursor: Option<&SourceCursor>,
1554 _limit: Option<usize>,
1555 ) -> Result<SourceSnapshot, SamplerError> {
1556 Ok(SourceSnapshot {
1557 records: self.records.clone(),
1558 cursor: SourceCursor {
1559 last_seen: Utc::now(),
1560 revision: 0,
1561 },
1562 })
1563 }
1564
1565 fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
1566 Ok(self.records.len() as u128)
1567 }
1568
1569 fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
1570 self.recipes.clone()
1571 }
1572 }
1573
1574 struct IngestionConfigSource {
1575 expected_ingestion_max_records: usize,
1576 records: Vec<DataRecord>,
1577 }
1578
1579 impl DataSource for IngestionConfigSource {
1580 fn id(&self) -> &str {
1581 "ingestion_config_source"
1582 }
1583
1584 fn refresh(
1585 &self,
1586 config: &SamplerConfig,
1587 _cursor: Option<&SourceCursor>,
1588 _limit: Option<usize>,
1589 ) -> Result<SourceSnapshot, SamplerError> {
1590 if config.ingestion_max_records != self.expected_ingestion_max_records {
1591 return Err(SamplerError::SourceInconsistent {
1592 source_id: self.id().to_string(),
1593 details: format!(
1594 "expected ingestion_max_records {} but got {}",
1595 self.expected_ingestion_max_records, config.ingestion_max_records
1596 ),
1597 });
1598 }
1599 Ok(SourceSnapshot {
1600 records: self.records.clone(),
1601 cursor: SourceCursor {
1602 last_seen: Utc::now(),
1603 revision: 0,
1604 },
1605 })
1606 }
1607
1608 fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
1609 Ok(self.records.len() as u128)
1610 }
1611
1612 fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
1613 vec![default_recipe("ingestion_config_recipe")]
1614 }
1615 }
1616
1617 fn fixture_record(
1618 source: &str,
1619 id_suffix: &str,
1620 day: u32,
1621 title: &str,
1622 body: &str,
1623 ) -> DataRecord {
1624 let now = Utc.with_ymd_and_hms(2025, 1, day, 12, 0, 0).unwrap();
1625 DataRecord {
1626 id: format!("{source}::{id_suffix}"),
1627 source: source.to_string(),
1628 created_at: now,
1629 updated_at: now,
1630 quality: QualityScore { trust: 1.0 },
1631 taxonomy: Vec::new(),
1632 sections: vec![
1633 make_section(SectionRole::Anchor, Some("title"), title),
1634 make_section(SectionRole::Context, Some("body"), body),
1635 ],
1636 meta_prefix: None,
1637 }
1638 }
1639
1640 fn default_recipe(name: &str) -> TripletRecipe {
1641 TripletRecipe {
1642 name: name.to_string().into(),
1643 anchor: triplets_core::config::Selector::Role(SectionRole::Anchor),
1644 positive_selector: triplets_core::config::Selector::Role(SectionRole::Context),
1645 negative_selector: triplets_core::config::Selector::Role(SectionRole::Context),
1646 negative_strategy: triplets_core::config::NegativeStrategy::WrongArticle,
1647 weight: 1.0,
1648 instruction: None,
1649 allow_same_anchor_positive: false,
1650 }
1651 }
1652
1653 #[test]
1654 fn parse_helpers_validate_inputs() {
1655 assert_eq!(parse_batch_size("2").unwrap(), 2);
1656 assert!(parse_batch_size("0").is_err());
1657 assert!(parse_batch_size("abc").is_err());
1658 assert_eq!(parse_ingestion_max_records("16").unwrap(), 16);
1659 assert!(parse_ingestion_max_records("0").is_err());
1660 assert!(parse_batch_count("0").is_err());
1661
1662 let split = parse_split_ratios_arg("0.8,0.1,0.1").unwrap();
1663 assert!((split.train - 0.8).abs() < 1e-6);
1664 assert!(parse_split_ratios_arg("0.8,0.1").is_err());
1665 assert!(parse_split_ratios_arg("1.0,0.0,0.1").is_err());
1666 assert!(parse_split_ratios_arg("-0.1,0.6,0.5").is_err());
1667 }
1668
1669 #[test]
1670 fn fixture_and_ingestion_sources_trait_methods_cover_paths() {
1671 let records = vec![fixture_record("fixture_source", "r1", 1, "Title", "Body")];
1672 let recipes = vec![default_recipe("fixture_recipe")];
1673 let fixture = FixtureSource {
1674 id: "fixture_source".into(),
1675 records: records.clone(),
1676 recipes: recipes.clone(),
1677 };
1678
1679 let snapshot = fixture
1680 .refresh(&SamplerConfig::default(), None, None)
1681 .expect("fixture refresh should succeed");
1682 assert_eq!(snapshot.records.len(), 1);
1683 assert_eq!(
1684 fixture
1685 .reported_record_count(&SamplerConfig::default())
1686 .unwrap(),
1687 1
1688 );
1689 assert_eq!(fixture.default_triplet_recipes().len(), 1);
1690
1691 let source = IngestionConfigSource {
1692 expected_ingestion_max_records: 7,
1693 records,
1694 };
1695 let ok_cfg = SamplerConfig {
1696 ingestion_max_records: 7,
1697 ..SamplerConfig::default()
1698 };
1699 assert!(source.refresh(&ok_cfg, None, None).is_ok());
1700 assert_eq!(source.reported_record_count(&ok_cfg).unwrap(), 1);
1701 assert_eq!(source.default_triplet_recipes().len(), 1);
1702
1703 let bad_cfg = SamplerConfig {
1704 ingestion_max_records: 8,
1705 ..SamplerConfig::default()
1706 };
1707 let err = source.refresh(&bad_cfg, None, None).unwrap_err();
1708 assert!(matches!(err, SamplerError::SourceInconsistent { .. }));
1709 }
1710
1711 #[test]
1712 fn suggested_balancing_weight_is_longest_normalized_and_bounded() {
1713 assert!((suggested_balancing_weight(100, 100) - 1.0).abs() < 1e-6);
1714 assert!((suggested_balancing_weight(400, 100) - 0.25).abs() < 1e-6);
1715 assert!((suggested_balancing_weight(400, 400) - 1.0).abs() < 1e-6);
1716 assert_eq!(suggested_balancing_weight(0, 100), 0.0);
1717 assert_eq!(suggested_balancing_weight(100, 0), 0.0);
1718 }
1719
1720 #[test]
1721 fn suggested_oversampling_weight_is_inverse_in_unit_interval() {
1722 assert!((suggested_oversampling_weight(100, 100) - 1.0).abs() < 1e-6);
1723 assert!((suggested_oversampling_weight(100, 400) - 0.25).abs() < 1e-6);
1724 assert!((suggested_oversampling_weight(100, 1000) - 0.1).abs() < 1e-6);
1725 assert_eq!(suggested_oversampling_weight(0, 100), 0.0);
1726 assert_eq!(suggested_oversampling_weight(100, 0), 0.0);
1727 }
1728
1729 #[test]
1730 fn parse_cli_handles_help_and_invalid_args() {
1731 let help = parse_cli::<EstimateCapacityCli, _>(["estimate_capacity", "--help"]).unwrap();
1732 assert!(help.is_none());
1733
1734 let err = parse_cli::<EstimateCapacityCli, _>(["estimate_capacity", "--unknown"]);
1735 assert!(err.is_err());
1736 }
1737
1738 #[test]
1739 fn run_estimate_capacity_succeeds_with_reported_counts() {
1740 let result = run_estimate_capacity(
1741 std::iter::empty::<String>(),
1742 |roots| {
1743 assert!(roots.is_empty());
1744 Ok(())
1745 },
1746 |_| {
1747 vec![Box::new(TestSource {
1748 id: "source_a".into(),
1749 count: Some(12),
1750 recipes: vec![default_recipe("r1")],
1751 }) as DynSource]
1752 },
1753 );
1754
1755 assert!(result.is_ok());
1756 }
1757
1758 #[test]
1759 fn run_estimate_capacity_errors_when_source_count_missing() {
1760 let result = run_estimate_capacity(
1761 std::iter::empty::<String>(),
1762 |_| Ok(()),
1763 |_| {
1764 vec![Box::new(TestSource {
1765 id: "source_missing".into(),
1766 count: None,
1767 recipes: vec![default_recipe("r1")],
1768 }) as DynSource]
1769 },
1770 );
1771
1772 let err = result.unwrap_err().to_string();
1773 assert!(err.contains("failed to report exact record count"));
1774 }
1775
1776 #[test]
1777 fn run_estimate_capacity_propagates_root_resolution_error() {
1778 let result = run_estimate_capacity(
1779 std::iter::empty::<String>(),
1780 |_| Err("root resolution failed".into()),
1781 empty_dyn_sources,
1782 );
1783
1784 let err = result.unwrap_err().to_string();
1785 assert!(err.contains("root resolution failed"));
1786 }
1787
1788 #[test]
1789 fn run_estimate_capacity_allows_empty_source_list() {
1790 let result =
1791 run_estimate_capacity(std::iter::empty::<String>(), |_| Ok(()), empty_dyn_sources);
1792
1793 assert!(result.is_ok());
1794 }
1795
1796 #[test]
1797 fn run_estimate_capacity_configures_sources_centrally_before_counting() {
1798 let result = run_estimate_capacity(
1799 std::iter::empty::<String>(),
1800 |_| Ok(()),
1801 |_| {
1802 vec![Box::new(ConfigRequiredSource {
1803 id: "requires_config".into(),
1804 expected_seed: 99,
1805 }) as DynSource]
1806 },
1807 );
1808
1809 assert!(result.is_ok());
1810 }
1811
1812 #[test]
1813 fn config_required_source_refresh_and_seed_mismatch_are_exercised() {
1814 let source = ConfigRequiredSource {
1815 id: "cfg-source".to_string(),
1816 expected_seed: 42,
1817 };
1818
1819 let refreshed = source
1820 .refresh(&SamplerConfig::default(), None, None)
1821 .unwrap();
1822 assert!(refreshed.records.is_empty());
1823
1824 let mismatched = source.reported_record_count(&SamplerConfig {
1825 seed: 7,
1826 ..SamplerConfig::default()
1827 });
1828 assert!(matches!(
1829 mismatched,
1830 Err(SamplerError::SourceInconsistent { .. })
1831 ));
1832
1833 assert!(source.default_triplet_recipes().is_empty());
1834 }
1835
1836 #[test]
1837 fn run_multi_source_demo_exhausted_paths_return_ok() {
1838 struct OneRecordSource;
1839
1840 impl DataSource for OneRecordSource {
1841 fn id(&self) -> &str {
1842 "one_record"
1843 }
1844
1845 fn refresh(
1846 &self,
1847 _config: &SamplerConfig,
1848 _cursor: Option<&SourceCursor>,
1849 _limit: Option<usize>,
1850 ) -> Result<SourceSnapshot, SamplerError> {
1851 let now = Utc::now();
1852 Ok(SourceSnapshot {
1853 records: vec![DataRecord {
1854 id: "one_record::r1".to_string(),
1855 source: "one_record".to_string(),
1856 created_at: now,
1857 updated_at: now,
1858 quality: QualityScore { trust: 1.0 },
1859 taxonomy: Vec::new(),
1860 sections: vec![
1861 RecordSection {
1862 role: SectionRole::Anchor,
1863 heading: Some("title".to_string()),
1864 text: "anchor".to_string(),
1865 sentences: vec!["anchor".to_string()],
1866 },
1867 RecordSection {
1868 role: SectionRole::Context,
1869 heading: Some("body".to_string()),
1870 text: "context".to_string(),
1871 sentences: vec!["context".to_string()],
1872 },
1873 ],
1874 meta_prefix: None,
1875 }],
1876 cursor: SourceCursor {
1877 last_seen: now,
1878 revision: 0,
1879 },
1880 })
1881 }
1882
1883 fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
1884 Ok(1)
1885 }
1886
1887 fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
1888 vec![default_recipe("single_record_recipe")]
1889 }
1890 }
1891
1892 let one = OneRecordSource;
1893 assert_eq!(
1894 one.reported_record_count(&SamplerConfig::default())
1895 .unwrap(),
1896 1
1897 );
1898 assert_eq!(one.default_triplet_recipes().len(), 1);
1899
1900 for mode in ["--pair-batch", "--text-recipes", ""] {
1901 let dir = tempdir().unwrap();
1902 let split_store_path = dir.path().join("split_store.bin");
1903 let mut args = vec![
1904 "--split-store-path".to_string(),
1905 split_store_path.to_string_lossy().to_string(),
1906 ];
1907 if !mode.is_empty() {
1908 args.push(mode.to_string());
1909 }
1910
1911 let result = run_multi_source_demo(
1912 args.into_iter(),
1913 |_| Ok(()),
1914 |_| vec![Box::new(OneRecordSource) as DynSource],
1915 );
1916 assert!(result.is_ok());
1917 }
1918 }
1919
1920 #[test]
1921 fn parse_multi_source_cli_handles_help_and_batch_size_validation() {
1922 let help = parse_cli::<MultiSourceDemoCli, _>(["multi_source_demo", "--help"]).unwrap();
1923 assert!(help.is_none());
1924
1925 let err = parse_cli::<MultiSourceDemoCli, _>(["multi_source_demo", "--batch-size", "0"]);
1926 assert!(err.is_err());
1927
1928 let err = parse_cli::<MultiSourceDemoCli, _>([
1929 "multi_source_demo",
1930 "--ingestion-max-records",
1931 "0",
1932 ]);
1933 assert!(err.is_err());
1934
1935 let parsed = parse_cli::<MultiSourceDemoCli, _>(["multi_source_demo"]);
1936 assert!(parsed.is_ok());
1937 }
1938
1939 #[test]
1940 fn run_debug_invalid_cli_args_return_errors() {
1941 let estimate = run_estimate_capacity(
1942 ["--unknown".to_string()].into_iter(),
1943 ok_unit_roots,
1944 empty_dyn_sources,
1945 );
1946 assert!(estimate.is_err());
1947
1948 let demo = run_multi_source_demo(
1949 ["--unknown".to_string()].into_iter(),
1950 ok_unit_roots,
1951 empty_dyn_sources,
1952 );
1953 assert!(demo.is_err());
1954 }
1955
1956 #[test]
1957 fn helper_and_error_refresh_source_methods_are_exercised() {
1958 assert!(ok_unit_roots(Vec::new()).is_ok());
1959 assert!(error_unit_roots(Vec::new()).is_err());
1960
1961 let source = ErrorRefreshSource {
1962 id: "error_refresh_source".to_string(),
1963 };
1964 assert_eq!(
1965 source
1966 .reported_record_count(&SamplerConfig::default())
1967 .unwrap(),
1968 1
1969 );
1970 assert_eq!(source.default_triplet_recipes().len(), 1);
1971 }
1972
1973 #[test]
1974 fn print_source_summary_handles_non_empty_ids() {
1975 let ids = [
1976 "source_a::r1",
1977 "source_a::r2",
1978 "source_b::r1",
1979 "source_without_delimiter",
1980 ];
1981 print_source_summary("non-empty summary", ids.into_iter());
1982 }
1983
1984 #[test]
1985 fn run_multi_source_demo_refresh_failures_degrade_to_exhausted_paths() {
1986 for mode in [
1987 vec!["--pair-batch".to_string()],
1988 vec!["--text-recipes".to_string()],
1989 vec!["--batches".to_string(), "1".to_string()],
1990 Vec::new(),
1991 ] {
1992 let dir = tempdir().unwrap();
1993 let split_store_path = dir.path().join("error_modes_split_store.bin");
1994 let mut args = mode;
1995 args.push("--split-store-path".to_string());
1996 args.push(split_store_path.to_string_lossy().to_string());
1997
1998 let result = run_multi_source_demo(
1999 args.into_iter(),
2000 |_| Ok(()),
2001 |_| {
2002 vec![Box::new(ErrorRefreshSource {
2003 id: "error_refresh_source".to_string(),
2004 }) as DynSource]
2005 },
2006 );
2007
2008 assert!(result.is_ok());
2009 }
2010 }
2011
2012 #[test]
2013 fn run_multi_source_demo_batches_exhausted_path_returns_ok() {
2014 let dir = tempdir().unwrap();
2015 let split_store_path = dir.path().join("batches_exhausted_split_store.bin");
2016 let args = vec![
2017 "--batches".to_string(),
2018 "3".to_string(),
2019 "--split-store-path".to_string(),
2020 split_store_path.to_string_lossy().to_string(),
2021 ];
2022
2023 let result = run_multi_source_demo(
2024 args.into_iter(),
2025 |_| Ok(()),
2026 |_| {
2027 vec![Box::new(FixtureSource {
2028 id: "batches_exhausted_source".into(),
2029 records: vec![fixture_record(
2030 "batches_exhausted_source",
2031 "r1",
2032 1,
2033 "Only one record",
2034 "Single record body",
2035 )],
2036 recipes: vec![default_recipe("batches_exhausted_recipe")],
2037 }) as DynSource]
2038 },
2039 );
2040
2041 assert!(result.is_ok());
2042 }
2043
2044 #[test]
2045 fn run_multi_source_demo_default_triplet_success_path_returns_ok() {
2046 let dir = tempdir().unwrap();
2047 let split_store_path = dir.path().join("default_triplet_success_split_store.bin");
2048 let args = vec![
2049 "--split-store-path".to_string(),
2050 split_store_path.to_string_lossy().to_string(),
2051 ];
2052
2053 let result = run_multi_source_demo(
2054 args.into_iter(),
2055 |_| Ok(()),
2056 |_| {
2057 vec![Box::new(FixtureSource {
2058 id: "default_triplet_success_source".into(),
2059 records: vec![
2060 fixture_record(
2061 "default_triplet_success_source",
2062 "r1",
2063 1,
2064 "Title one",
2065 "Body one",
2066 ),
2067 fixture_record(
2068 "default_triplet_success_source",
2069 "r2",
2070 2,
2071 "Title two",
2072 "Body two",
2073 ),
2074 fixture_record(
2075 "default_triplet_success_source",
2076 "r3",
2077 3,
2078 "Title three",
2079 "Body three",
2080 ),
2081 ],
2082 recipes: vec![default_recipe("default_triplet_success_recipe")],
2083 }) as DynSource]
2084 },
2085 );
2086
2087 assert!(result.is_ok());
2088 }
2089
2090 #[test]
2091 fn run_multi_source_demo_passes_ingestion_max_records_to_sources() {
2092 let dir = tempdir().unwrap();
2093 let split_store_path = dir.path().join("ingestion_config_split_store.bin");
2094 let expected = 7;
2095
2096 let result = run_multi_source_demo(
2097 [
2098 "--pair-batch".to_string(),
2099 "--ingestion-max-records".to_string(),
2100 expected.to_string(),
2101 "--split-store-path".to_string(),
2102 split_store_path.to_string_lossy().to_string(),
2103 ]
2104 .into_iter(),
2105 |_| Ok(()),
2106 |_| {
2107 vec![Box::new(IngestionConfigSource {
2108 expected_ingestion_max_records: expected,
2109 records: (1..=8)
2110 .map(|day| {
2111 fixture_record(
2112 "ingestion_config_source",
2113 &format!("r{day}"),
2114 day,
2115 &format!("Config headline {day}"),
2116 &format!("Config body {day}"),
2117 )
2118 })
2119 .collect(),
2120 }) as DynSource]
2121 },
2122 );
2123
2124 assert!(result.is_ok());
2125 }
2126
    /// `--version` on a CLI that declares a version parses to `Ok(None)`.
    #[test]
    fn parse_cli_handles_display_version_path() {
        // Minimal CLI whose only purpose is to declare a version string.
        #[derive(Debug, Parser)]
        #[command(name = "version_test", version = "1.0.0")]
        struct VersionCli {}

        let parsed = parse_cli::<VersionCli, _>(["version_test", "--version"]).unwrap();
        assert!(parsed.is_none());
    }
2136
2137 #[test]
2138 fn run_multi_source_demo_list_text_recipes_path_succeeds() {
2139 let dir = tempdir().unwrap();
2140 let split_store_path = dir.path().join("recipes_split_store.bin");
2141 let mut args = vec![
2142 "--list-text-recipes".to_string(),
2143 "--split-store-path".to_string(),
2144 split_store_path.to_string_lossy().to_string(),
2145 ];
2146 let result = run_multi_source_demo(
2147 args.drain(..),
2148 |_| Ok(()),
2149 |_| {
2150 vec![Box::new(TestSource {
2151 id: "source_for_recipes".into(),
2152 count: Some(10),
2153 recipes: vec![default_recipe("recipe_a")],
2154 }) as DynSource]
2155 },
2156 );
2157
2158 assert!(result.is_ok());
2159 }
2160
2161 #[test]
2162 fn run_multi_source_demo_list_text_recipes_uses_explicit_split_store_path() {
2163 let dir = tempdir().unwrap();
2164 let split_store_path = dir.path().join("custom_split_store.bin");
2165 let args = vec![
2166 "--list-text-recipes".to_string(),
2167 "--split-store-path".to_string(),
2168 split_store_path.to_string_lossy().to_string(),
2169 ];
2170
2171 let result = run_multi_source_demo(
2172 args.into_iter(),
2173 |_| Ok(()),
2174 |_| {
2175 vec![Box::new(TestSource {
2176 id: "source_without_text_recipes".into(),
2177 count: Some(1),
2178 recipes: Vec::new(),
2179 }) as DynSource]
2180 },
2181 );
2182
2183 assert!(result.is_ok());
2184 }
2185
2186 #[test]
2187 fn run_multi_source_demo_sampling_modes_handle_empty_sources() {
2188 for mode in [
2189 vec!["--pair-batch".to_string()],
2190 vec!["--text-recipes".to_string()],
2191 vec![],
2192 ] {
2193 let dir = tempdir().unwrap();
2194 let split_store_path = dir.path().join("empty_sources_split_store.bin");
2195 let mut args = mode;
2196 args.push("--split-store-path".to_string());
2197 args.push(split_store_path.to_string_lossy().to_string());
2198 args.push("--split".to_string());
2199 args.push("validation".to_string());
2200
2201 let result = run_multi_source_demo(
2202 args.into_iter(),
2203 |_| Ok(()),
2204 |_| {
2205 vec![Box::new(TestSource {
2206 id: "source_empty".into(),
2207 count: Some(0),
2208 recipes: vec![default_recipe("recipe_empty")],
2209 }) as DynSource]
2210 },
2211 );
2212
2213 assert!(result.is_ok());
2214 }
2215 }
2216
/// An error returned by the root-resolution callback must surface in the
/// error returned by `run_multi_source_demo`.
#[test]
fn run_multi_source_demo_propagates_root_resolution_error() {
    let scratch = tempdir().unwrap();
    let store_file = scratch.path().join("root_resolution_error_store.bin");
    let cli_args = vec![
        "--split-store-path".to_string(),
        store_file.to_string_lossy().to_string(),
    ];

    let outcome = run_multi_source_demo(
        cli_args.into_iter(),
        |_| Err("demo root resolution failed".into()),
        empty_dyn_sources,
    );

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("demo root resolution failed"));
}
2234
/// `--list-text-recipes` must return `Ok` when the source builder is
/// `empty_dyn_sources`.
#[test]
fn run_multi_source_demo_list_text_recipes_allows_empty_sources() {
    let scratch = tempdir().unwrap();
    let store_file = scratch.path().join("empty_source_list_recipes.bin");
    let cli_args = vec![
        "--list-text-recipes".to_string(),
        "--split-store-path".to_string(),
        store_file.to_string_lossy().to_string(),
    ];

    let outcome = run_multi_source_demo(cli_args.into_iter(), |_| Ok(()), empty_dyn_sources);

    assert!(outcome.is_ok());
}
2252
/// Smoke-tests the batch/recipe print helpers with hand-built chunks and
/// checks `extract_source` on ids with and without the `::` delimiter.
///
/// The print helpers are only exercised for "does not panic"; the two
/// `extract_source` assertions are the only value checks.
#[test]
fn print_helpers_and_extract_source_cover_paths() {
    let chunking = ChunkingStrategy::default();
    let split_store = DeterministicSplitStore::new(SplitRatios::default(), 42).unwrap();

    // Window-view chunk carrying key/value metadata.
    let anchor_chunk = RecordChunk {
        record_id: String::from("source_a::rec1"),
        section_idx: 0,
        view: ChunkView::Window {
            index: 1,
            overlap: 2,
            span: 12,
        },
        text: String::from("anchor text"),
        tokens_estimate: 8,
        quality: triplets_core::data::QualityScore { trust: 0.9 },
        kvp_meta: [(
            String::from("date"),
            vec![String::from("2025-01-01"), String::from("Jan 1, 2025")],
        )]
        .into_iter()
        .collect(),
    };
    // Summary-fallback view chunk with no metadata.
    let positive_chunk = RecordChunk {
        record_id: String::from("source_a::rec2"),
        section_idx: 1,
        view: ChunkView::SummaryFallback {
            strategy: String::from("summary"),
            weight: 0.7,
        },
        text: String::from("positive text"),
        tokens_estimate: 6,
        quality: triplets_core::data::QualityScore { trust: 0.8 },
        kvp_meta: Default::default(),
    };
    // Window-view chunk whose record id names a different source.
    let negative_chunk = RecordChunk {
        record_id: String::from("source_b::rec3"),
        section_idx: 2,
        view: ChunkView::Window {
            index: 0,
            overlap: 0,
            span: 16,
        },
        text: String::from("negative text"),
        tokens_estimate: 7,
        quality: triplets_core::data::QualityScore { trust: 0.5 },
        kvp_meta: Default::default(),
    };

    let triplets = TripletBatch {
        triplets: vec![triplets_core::SampleTriplet {
            recipe: String::from("triplet_recipe"),
            anchor: anchor_chunk.clone(),
            positive: positive_chunk.clone(),
            negative: negative_chunk.clone(),
            weight: 1.0,
            instruction: Some(String::from("triplet instruction")),
        }],
    };
    print_triplet_batch(&chunking, &triplets, &split_store);

    // Anchor and positive are not needed afterwards, so they are moved here.
    let pairs = SampleBatch {
        pairs: vec![triplets_core::SamplePair {
            recipe: String::from("pair_recipe"),
            anchor: anchor_chunk,
            positive: positive_chunk,
            weight: 1.0,
            instruction: None,
            label: triplets_core::PairLabel::Positive,
            reason: Some(String::from("same topic")),
        }],
    };
    print_pair_batch(&chunking, &pairs, &split_store);

    let texts = TextBatch {
        samples: vec![triplets_core::TextSample {
            recipe: String::from("text_recipe"),
            chunk: negative_chunk,
            weight: 0.8,
            instruction: Some(String::from("text instruction")),
        }],
    };
    print_text_batch(&chunking, &texts, &split_store);

    let text_recipes = vec![TextRecipe {
        name: "recipe_name".into(),
        selector: triplets_core::config::Selector::Role(SectionRole::Context),
        instruction: Some("instruction".into()),
        weight: 1.0,
    }];
    print_text_recipes(&text_recipes);

    assert_eq!(extract_source("source_a::record"), "source_a");
    assert_eq!(extract_source("record-without-delimiter"), "unknown");
}
2349
/// Each `SplitArg` variant converts to the `SplitLabel` variant of the same
/// name.
#[test]
fn split_arg_conversion_and_version_parse_paths_are_covered() {
    let train: SplitLabel = SplitArg::Train.into();
    assert!(matches!(train, SplitLabel::Train));

    let validation: SplitLabel = SplitArg::Validation.into();
    assert!(matches!(validation, SplitLabel::Validation));

    let test: SplitLabel = SplitArg::Test.into();
    assert!(matches!(test, SplitLabel::Test));
}
2362
/// A non-numeric value in the first, second, or third position of the
/// ratio string must be reported as the train, validation, or test ratio
/// respectively.
#[test]
fn parse_split_ratios_reports_per_field_parse_errors() {
    let cases = [
        ("x,0.1,0.9", "invalid train ratio"),
        ("0.1,y,0.8", "invalid validation ratio"),
        ("0.1,0.2,z", "invalid test ratio"),
    ];

    for (input, expected_fragment) in cases {
        let message = parse_split_ratios_arg(input).unwrap_err();
        assert!(
            message.contains(expected_fragment),
            "input {input:?} should report {expected_fragment:?}"
        );
    }
}
2381
/// Runs the demo once per sampling mode (`--pair-batch`, `--text-recipes`,
/// and no mode flag) against a `TestSource` with `count: Some(1)` and no
/// recipes, and expects `Ok` every time.
#[test]
fn run_multi_source_demo_exhausted_paths_are_handled() {
    let mode_flags: [&[&str]; 3] = [&["--pair-batch"], &["--text-recipes"], &[]];

    for flags in mode_flags {
        // Each mode gets its own temp directory so runs never share a store.
        let scratch = tempdir().unwrap();
        let store_file = scratch.path().join("exhausted_split_store.bin");
        let cli_args: Vec<String> = flags
            .iter()
            .copied()
            .map(String::from)
            .chain([
                "--split-store-path".to_string(),
                store_file.to_string_lossy().to_string(),
            ])
            .collect();

        let outcome = run_multi_source_demo(
            cli_args.into_iter(),
            |_| Ok(()),
            |_| {
                let source: DynSource = Box::new(TestSource {
                    id: "source_without_recipes".into(),
                    count: Some(1),
                    recipes: Vec::new(),
                });
                vec![source]
            },
        );

        assert!(outcome.is_ok());
    }
}
2410
/// `--reset` must replace a pre-existing split-store file and the
/// pair-batch run must then succeed.
///
/// The file is seeded with junk bytes before the run. Afterwards it must
/// exist, be non-empty, and no longer begin with the junk; the last check
/// is what shows the stale file was replaced rather than left in place.
#[test]
fn run_multi_source_demo_reset_recreates_split_store_and_samples() {
    const STALE_CONTENTS: &[u8] = b"stale-data";

    let dir = tempdir().unwrap();
    let split_store_path = dir.path().join("reset_split_store.bin");
    std::fs::write(&split_store_path, STALE_CONTENTS).unwrap();

    let args = vec![
        "--reset".to_string(),
        "--pair-batch".to_string(),
        "--split-store-path".to_string(),
        split_store_path.to_string_lossy().to_string(),
    ];

    let result = run_multi_source_demo(
        args.into_iter(),
        |_| Ok(()),
        |_| {
            let recipes = vec![default_recipe("fixture_recipe")];
            let records: Vec<DataRecord> = (1..=8)
                .map(|day| {
                    fixture_record(
                        "fixture_source",
                        &format!("r{day}"),
                        day,
                        &format!("Fixture headline {day}"),
                        &format!("Fixture body content for day {day}."),
                    )
                })
                .collect();
            vec![Box::new(FixtureSource {
                id: "fixture_source".into(),
                records,
                recipes,
            }) as DynSource]
        },
    );

    assert!(result.is_ok());
    assert!(split_store_path.exists());

    let contents = std::fs::read(&split_store_path).unwrap();
    assert!(!contents.is_empty());
    // Existence and non-emptiness alone would also hold for the untouched
    // stale file, so check that the seeded bytes are gone.
    assert!(
        !contents.starts_with(STALE_CONTENTS),
        "split store still begins with the stale bytes written before --reset"
    );
}
2453
/// `--batches 2` over a three-record `FixtureSource` must return `Ok` and
/// leave a file at the requested split-store path.
#[test]
fn run_multi_source_demo_batches_mode_executes_multiple_batches() {
    let scratch = tempdir().unwrap();
    let store_file = scratch.path().join("batches_split_store.bin");
    let cli_args = [
        "--batches".to_string(),
        "2".to_string(),
        "--split-store-path".to_string(),
        store_file.to_string_lossy().to_string(),
    ];

    let outcome = run_multi_source_demo(
        cli_args.into_iter(),
        |_| Ok(()),
        |_| {
            // (record id, day, headline, body) rows for the fixture source.
            let rows = [
                (
                    "r1",
                    3,
                    "Inflation cools in latest report",
                    "Core inflation moderated compared with prior quarter.",
                ),
                (
                    "r2",
                    4,
                    "Labor market remains resilient",
                    "Job openings remain elevated despite slower growth.",
                ),
                (
                    "r3",
                    5,
                    "Manufacturing sentiment stabilizes",
                    "Survey data suggests output expectations are improving.",
                ),
            ];
            let records: Vec<DataRecord> = rows
                .into_iter()
                .map(|(record_id, day, headline, body)| {
                    fixture_record("batch_source", record_id, day, headline, body)
                })
                .collect();

            let source: DynSource = Box::new(FixtureSource {
                id: "batch_source".into(),
                records,
                recipes: vec![default_recipe("batch_recipe")],
            });
            vec![source]
        },
    );

    assert!(outcome.is_ok());
    assert!(store_file.exists());
}
2503
/// The managed split-store path must end with the demo store filename and
/// its parent directory must end with the demo cache group.
#[test]
fn managed_demo_split_store_path_resolves_under_cache_group() {
    let resolved = managed_demo_split_store_path().unwrap();
    assert!(resolved.ends_with(MULTI_SOURCE_DEMO_STORE_FILENAME));

    let Some(group_dir) = resolved.parent() else {
        panic!("managed split-store path should have a parent");
    };
    let expected_group = PathBuf::from(MULTI_SOURCE_DEMO_GROUP);
    assert!(group_dir.ends_with(expected_group));
}
2513
/// Contrasts the no-argument path with the `--help` path.
///
/// With no arguments and `error_unit_roots`, the demo must fail with an
/// error mentioning `root-resolution-error`; with `--help` it must
/// return `Ok`.
#[test]
fn run_multi_source_demo_help_returns_ok_without_work() {
    let root_error = run_multi_source_demo(
        std::iter::empty::<String>(),
        error_unit_roots,
        empty_dyn_sources,
    )
    .expect_err("non-help path should attempt to resolve roots");
    assert!(root_error.to_string().contains("root-resolution-error"));

    let help_args = [String::from("--help")];
    let help_outcome =
        run_multi_source_demo(help_args.into_iter(), ok_unit_roots, empty_dyn_sources);

    assert!(help_outcome.is_ok());
}
2536
/// `run_estimate_capacity` must return `Ok` when invoked with `--help`.
#[test]
fn run_estimate_capacity_help_returns_ok_without_work() {
    let help_args = [String::from("--help")];

    let outcome = run_estimate_capacity(help_args.into_iter(), ok_unit_roots, empty_dyn_sources);

    assert!(outcome.is_ok());
}
2547
/// `--pair-batch` over a `FixtureSource` holding a single record must
/// still return `Ok`.
#[test]
fn run_multi_source_demo_pair_exhausted_branch_returns_ok() {
    let scratch = tempdir().unwrap();
    let store_file = scratch.path().join("pair_exhausted_split_store.bin");
    let cli_args = [
        "--pair-batch".to_string(),
        "--split-store-path".to_string(),
        store_file.to_string_lossy().to_string(),
    ];

    let outcome = run_multi_source_demo(
        cli_args.into_iter(),
        |_| Ok(()),
        |_| {
            let only_record = fixture_record(
                "pair_exhausted_source",
                "r1",
                1,
                "Single record title",
                "Single record body",
            );
            let source: DynSource = Box::new(FixtureSource {
                id: "pair_exhausted_source".into(),
                records: vec![only_record],
                recipes: vec![default_recipe("pair_exhausted_recipe")],
            });
            vec![source]
        },
    );

    assert!(outcome.is_ok());
}
2578
/// `--list-text-recipes` with no `--split-store-path` argument must
/// return `Ok`.
#[test]
fn run_multi_source_demo_uses_managed_split_store_path_when_not_provided() {
    let cli_args = std::iter::once("--list-text-recipes".to_string());

    let outcome = run_multi_source_demo(
        cli_args,
        |_| Ok(()),
        |_| {
            let source: DynSource = Box::new(TestSource {
                id: "managed_path_source".into(),
                count: Some(2),
                recipes: vec![default_recipe("managed_recipe")],
            });
            vec![source]
        },
    );

    assert!(outcome.is_ok());
}
2595
/// `--reset` pointed at a path that is a directory must fail with an
/// error mentioning `failed to remove split store`.
#[test]
fn run_multi_source_demo_reset_errors_when_target_is_directory() {
    let scratch = tempdir().unwrap();
    let store_dir = scratch.path().join("split_store_dir");
    std::fs::create_dir(&store_dir).unwrap();

    let cli_args = vec![
        "--reset".to_string(),
        "--split-store-path".to_string(),
        store_dir.to_string_lossy().to_string(),
    ];
    let outcome = run_multi_source_demo(cli_args.into_iter(), |_| Ok(()), empty_dyn_sources);

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("failed to remove split store"));
}
2616
/// Both summary printers must accept an empty iterator without panicking.
#[test]
fn print_summary_helpers_accept_empty_iterators() {
    let no_sources = std::iter::empty::<&str>();
    print_source_summary("empty summary", no_sources);

    let no_recipe_contexts = std::iter::empty::<(&str, &str)>();
    print_recipe_context_by_source("empty recipe context", no_recipe_contexts);
}
2622
/// For the four unsorted values 1, 4, 2, 3 both the mean and the median
/// must come out as 2.5.
#[cfg(feature = "extended-metrics")]
#[test]
fn metric_mean_median_handles_even_length_inputs() {
    let tolerance = 1e-6;
    let mut samples = [1.0, 4.0, 2.0, 3.0];

    let (mean, median) = metric_mean_median(&mut samples);

    assert!((mean - 2.5).abs() < tolerance);
    assert!((median - 2.5).abs() < tolerance);
}
2631
/// For the three unsorted values 3, 1, 2 both the mean and the median
/// must come out as 2.0.
#[cfg(feature = "extended-metrics")]
#[test]
fn metric_mean_median_handles_odd_length_inputs() {
    let tolerance = 1e-6;
    let mut samples = [3.0, 1.0, 2.0];

    let (mean, median) = metric_mean_median(&mut samples);

    assert!((mean - 2.0).abs() < tolerance);
    assert!((median - 2.0).abs() < tolerance);
}
2640
/// Smoke-tests `print_metric_summary` with metric entries for two
/// sources; the call only has to complete without panicking.
#[cfg(feature = "extended-metrics")]
#[test]
fn print_metric_summary_includes_multi_source_aggregate() {
    let mut per_source = HashMap::new();
    per_source.insert(
        "source_a".to_string(),
        vec![(0.9, 0.8, 0.2, 0.1, 0.7), (0.8, 0.7, 0.3, 0.2, 0.8)],
    );
    per_source.insert(
        "source_b".to_string(),
        vec![(0.7, 0.6, 0.4, 0.3, 0.5), (0.6, 0.5, 0.5, 0.4, 0.6)],
    );

    print_metric_summary(&per_source);
}
2657}