1use serde::{Deserialize, Serialize};
27use sha2::{Digest, Sha256};
28
29use crate::distribution::Distribution;
30
/// Classification attached to each [`Factor`] via its `kind` field.
///
/// Serialized by serde as an internally tagged enum whose tag field is
/// named `kind`. [`FactorKind::Continuous`] is the `Default` value.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(tag = "kind")]
pub enum FactorKind {
    /// The default kind; used by `ProblemBuilder::factor`.
    #[default]
    Continuous,
    /// Discrete factor.
    Discrete,
    /// Categorical factor. `ProblemBuilder::build` rejects `n == 0`.
    // NOTE(review): `n` is presumably the number of categories — confirm.
    Categorical { n: usize },
    /// Boolean factor.
    Boolean,
}
51
/// A single named input of a [`Problem`]: a distribution plus a kind tag.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct Factor {
    /// Factor name; `ProblemBuilder::build` rejects duplicate names.
    pub name: String,
    /// Distribution associated with this factor; its parameters are
    /// validated by `ProblemBuilder::build`.
    pub distribution: Distribution,
    /// Kind tag for this factor.
    pub kind: FactorKind,
}
62
/// A named set of factors, referenced by position.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Group {
    /// Group name (used in `BuildError` messages).
    pub name: String,
    /// Indices into the problem's factor list. `ProblemBuilder::build`
    /// requires this to be non-empty, every index to be `< dim`, and no
    /// index to be claimed more than once across all groups.
    pub factor_indices: Vec<usize>,
}
69
/// An ordered list of factors with an optional grouping of those factors.
///
/// `ProblemBuilder::build` is the validating constructor visible in this
/// file. Note that the fields are public and the type derives
/// `Deserialize`, so values obtained by other means have not necessarily
/// passed that validation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct Problem {
    /// Factors in insertion order.
    pub factors: Vec<Factor>,
    /// Factor groups; `ProblemBuilder::build` stores `None` when no group
    /// was added to the builder.
    pub groups: Option<Vec<Group>>,
}
83
84impl Problem {
85 #[must_use]
87 pub fn dim(&self) -> usize {
88 self.factors.len()
89 }
90
91 #[must_use]
93 pub fn factors(&self) -> &[Factor] {
94 &self.factors
95 }
96
97 #[must_use]
105 #[allow(clippy::expect_used)]
106 pub fn content_hash(&self) -> [u8; 32] {
107 let bytes = serde_json::to_vec(self)
113 .expect("serializing Problem to JSON cannot fail (all plain data)");
114 let mut hasher = Sha256::new();
115 hasher.update(&bytes);
116 hasher.finalize().into()
117 }
118}
119
/// Accumulates factors and groups, then validates them in
/// `ProblemBuilder::build` to produce a [`Problem`].
#[derive(Debug, Default, Clone)]
pub struct ProblemBuilder {
    /// Factors in the order they were added.
    factors: Vec<Factor>,
    /// Groups in the order they were added; may be empty.
    groups: Vec<Group>,
}
126
/// Reasons `ProblemBuilder::build` rejects a builder.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
#[non_exhaustive]
pub enum BuildError {
    /// The builder contains no factors.
    #[error("Problem must have at least one factor")]
    Empty,
    /// Two factors share the same name; `name` is the repeated name.
    #[error("duplicate factor name: {name}")]
    DuplicateName { name: String },
    /// The distribution of factor `name` failed parameter validation;
    /// `reason` is the human-readable explanation.
    #[error("invalid distribution for factor {name}: {reason}")]
    InvalidDistribution { name: String, reason: String },
    /// Factor `name` has kind `Categorical { n: 0 }`.
    #[error("Categorical factor {name} must have n >= 1")]
    EmptyCategorical { name: String },
    /// Group `group` lists no factor indices.
    #[error("group {group} has empty factor_indices")]
    EmptyGroup { group: String },
    /// Group `group` references `index`, which is not `< dim`
    /// (`dim` is the number of factors).
    #[error("group {group}: factor index {index} out of range (dim={dim})")]
    GroupIndexOutOfRange {
        group: String,
        index: usize,
        dim: usize,
    },
    /// Factor `index` was already claimed by a group when it was
    /// encountered again. As implemented in `build`, this is also reported
    /// when one group lists the same index twice.
    #[error("factor {index} appears in multiple groups")]
    FactorInMultipleGroups { index: usize },
}
157
158impl ProblemBuilder {
159 #[must_use]
161 pub fn new() -> Self {
162 Self::default()
163 }
164
165 #[must_use]
167 pub fn factor(mut self, name: &str, distribution: Distribution) -> Self {
168 self.factors.push(Factor {
169 name: name.to_string(),
170 distribution,
171 kind: FactorKind::Continuous,
172 });
173 self
174 }
175
176 #[must_use]
178 pub fn group(mut self, name: &str, factor_indices: &[usize]) -> Self {
179 self.groups.push(Group {
180 name: name.to_string(),
181 factor_indices: factor_indices.to_vec(),
182 });
183 self
184 }
185
186 #[must_use]
188 pub fn factor_with_kind(
189 mut self,
190 name: &str,
191 distribution: Distribution,
192 kind: FactorKind,
193 ) -> Self {
194 self.factors.push(Factor {
195 name: name.to_string(),
196 distribution,
197 kind,
198 });
199 self
200 }
201
202 pub fn build(self) -> Result<Problem, BuildError> {
206 if self.factors.is_empty() {
207 return Err(BuildError::Empty);
208 }
209
210 let mut seen: Vec<&str> = Vec::with_capacity(self.factors.len());
212 for f in &self.factors {
213 if seen.contains(&f.name.as_str()) {
214 return Err(BuildError::DuplicateName {
215 name: f.name.clone(),
216 });
217 }
218 seen.push(f.name.as_str());
219 }
220
221 for f in &self.factors {
223 validate_distribution(&f.distribution).map_err(|reason| {
224 BuildError::InvalidDistribution {
225 name: f.name.clone(),
226 reason,
227 }
228 })?;
229 if let FactorKind::Categorical { n: 0 } = f.kind {
230 return Err(BuildError::EmptyCategorical {
231 name: f.name.clone(),
232 });
233 }
234 }
235
236 let dim = self.factors.len();
238 let mut factor_group_owner: Vec<Option<usize>> = vec![None; dim];
239 for (gi, g) in self.groups.iter().enumerate() {
240 if g.factor_indices.is_empty() {
241 return Err(BuildError::EmptyGroup {
242 group: g.name.clone(),
243 });
244 }
245 for &idx in &g.factor_indices {
246 if idx >= dim {
247 return Err(BuildError::GroupIndexOutOfRange {
248 group: g.name.clone(),
249 index: idx,
250 dim,
251 });
252 }
253 if factor_group_owner[idx].is_some() {
254 return Err(BuildError::FactorInMultipleGroups { index: idx });
255 }
256 factor_group_owner[idx] = Some(gi);
257 }
258 }
259
260 let groups = if self.groups.is_empty() {
261 None
262 } else {
263 Some(self.groups)
264 };
265
266 Ok(Problem {
267 factors: self.factors,
268 groups,
269 })
270 }
271}
272
273#[allow(clippy::neg_cmp_op_on_partial_ord, clippy::nonminimal_bool)]
285fn validate_distribution(d: &Distribution) -> Result<(), String> {
286 match *d {
287 Distribution::Uniform { lo, hi } => {
288 if !(lo < hi) {
289 return Err(format!("Uniform: lo ({lo}) must be < hi ({hi})"));
290 }
291 }
292 Distribution::Normal { sigma, .. } => {
293 if !(sigma > 0.0) {
294 return Err(format!("Normal: sigma ({sigma}) must be > 0"));
295 }
296 }
297 Distribution::LogNormal { sigma_log, .. } => {
298 if !(sigma_log > 0.0) {
299 return Err(format!("LogNormal: sigma_log ({sigma_log}) must be > 0"));
300 }
301 }
302 Distribution::Triangular { lo, mode, hi } => {
303 if !(lo < hi) {
304 return Err(format!("Triangular: lo ({lo}) must be < hi ({hi})"));
305 }
306 if !(lo <= mode && mode <= hi) {
307 return Err(format!(
308 "Triangular: mode ({mode}) must be in [lo ({lo}), hi ({hi})]"
309 ));
310 }
311 }
312 Distribution::Beta {
313 alpha,
314 beta,
315 lo,
316 hi,
317 } => {
318 if !(alpha > 0.0) {
319 return Err(format!("Beta: alpha ({alpha}) must be > 0"));
320 }
321 if !(beta > 0.0) {
322 return Err(format!("Beta: beta ({beta}) must be > 0"));
323 }
324 if !(lo < hi) {
325 return Err(format!("Beta: lo ({lo}) must be < hi ({hi})"));
326 }
327 }
328 Distribution::Gamma { shape, scale } => {
329 if !(shape > 0.0) {
330 return Err(format!("Gamma: shape ({shape}) must be > 0"));
331 }
332 if !(scale > 0.0) {
333 return Err(format!("Gamma: scale ({scale}) must be > 0"));
334 }
335 }
336 Distribution::Weibull { shape, scale } => {
337 if !(shape > 0.0) {
338 return Err(format!("Weibull: shape ({shape}) must be > 0"));
339 }
340 if !(scale > 0.0) {
341 return Err(format!("Weibull: scale ({scale}) must be > 0"));
342 }
343 }
344 Distribution::Exponential { lambda } => {
345 if !(lambda > 0.0) {
346 return Err(format!("Exponential: lambda ({lambda}) must be > 0"));
347 }
348 }
349 Distribution::Bernoulli { p } => {
350 if !(0.0..=1.0).contains(&p) {
351 return Err(format!("Bernoulli: p ({p}) must be in [0, 1]"));
352 }
353 }
354 Distribution::DiscreteUniform { lo, hi } => {
355 if !(lo <= hi) {
356 return Err(format!("DiscreteUniform: lo ({lo}) must be <= hi ({hi})"));
357 }
358 }
359 }
360 Ok(())
361}
362
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    /// Shorthand for `Distribution::Uniform { lo, hi }`.
    fn uniform(lo: f64, hi: f64) -> Distribution {
        Distribution::Uniform { lo, hi }
    }

    /// Builds a one-factor problem named "x" and asserts that it is rejected
    /// with `InvalidDistribution` naming that factor.
    fn assert_rejected(distribution: Distribution) {
        let err = ProblemBuilder::new()
            .factor("x", distribution)
            .build()
            .unwrap_err();
        assert!(
            matches!(
                &err,
                BuildError::InvalidDistribution { name, .. } if name.as_str() == "x"
            ),
            "wrong error variant: {err:?}"
        );
    }

    // ---- successful builds ------------------------------------------------

    #[test]
    fn build_single_factor() {
        let p = ProblemBuilder::new()
            .factor("x", uniform(0.0, 1.0))
            .build()
            .expect("builds");
        assert_eq!(p.dim(), 1);
        assert_eq!(p.factors()[0].name, "x");
        assert_eq!(p.factors()[0].kind, FactorKind::Continuous);
    }

    #[test]
    fn build_three_factors() {
        let p = ProblemBuilder::new()
            .factor("a", uniform(0.0, 1.0))
            .factor(
                "b",
                Distribution::Normal {
                    mu: 0.0,
                    sigma: 1.0,
                },
            )
            .factor("c", Distribution::Exponential { lambda: 1.0 })
            .build()
            .expect("builds");
        assert_eq!(p.dim(), 3);
        let names: Vec<&str> = p.factors().iter().map(|f| f.name.as_str()).collect();
        assert_eq!(names, vec!["a", "b", "c"]);
    }

    #[test]
    fn build_with_kind_preserves_kind() {
        let p = ProblemBuilder::new()
            .factor_with_kind("d", uniform(0.0, 10.0), FactorKind::Discrete)
            .factor_with_kind(
                "c",
                Distribution::DiscreteUniform { lo: 0, hi: 4 },
                FactorKind::Categorical { n: 5 },
            )
            .factor_with_kind("b", Distribution::Bernoulli { p: 0.5 }, FactorKind::Boolean)
            .build()
            .expect("builds");
        assert_eq!(p.factors()[0].kind, FactorKind::Discrete);
        assert_eq!(p.factors()[1].kind, FactorKind::Categorical { n: 5 });
        assert_eq!(p.factors()[2].kind, FactorKind::Boolean);
    }

    #[test]
    fn factor_kind_default_is_continuous() {
        assert_eq!(FactorKind::default(), FactorKind::Continuous);
    }

    #[test]
    fn bernoulli_boundary_probabilities_build() {
        // The accepted range is the closed interval [0, 1].
        let p = ProblemBuilder::new()
            .factor("never", Distribution::Bernoulli { p: 0.0 })
            .factor("always", Distribution::Bernoulli { p: 1.0 })
            .build()
            .expect("builds");
        assert_eq!(p.dim(), 2);
    }

    // ---- rejected builds ----------------------------------------------------

    #[test]
    fn empty_builder_fails() {
        let err = ProblemBuilder::new().build().unwrap_err();
        assert_eq!(err, BuildError::Empty);
    }

    #[test]
    fn duplicate_name_fails() {
        let err = ProblemBuilder::new()
            .factor("x", uniform(0.0, 1.0))
            .factor("x", uniform(2.0, 3.0))
            .build()
            .unwrap_err();
        assert_eq!(
            err,
            BuildError::DuplicateName {
                name: "x".to_string()
            }
        );
    }

    #[test]
    fn invalid_uniform_lo_geq_hi_fails() {
        assert_rejected(Distribution::Uniform { lo: 1.0, hi: 1.0 });
    }

    #[test]
    fn invalid_uniform_nan_bound_fails() {
        // Comparisons against NaN are false, so the negated checks reject it.
        assert_rejected(uniform(f64::NAN, 1.0));
        assert_rejected(uniform(0.0, f64::NAN));
    }

    #[test]
    fn invalid_normal_sigma_zero_fails() {
        assert_rejected(Distribution::Normal {
            mu: 0.0,
            sigma: 0.0,
        });
    }

    #[test]
    fn invalid_beta_alpha_zero_fails() {
        assert_rejected(Distribution::Beta {
            alpha: 0.0,
            beta: 1.0,
            lo: 0.0,
            hi: 1.0,
        });
    }

    #[test]
    fn invalid_triangular_mode_outside_range_fails() {
        assert_rejected(Distribution::Triangular {
            lo: 0.0,
            mode: 2.0,
            hi: 1.0,
        });
    }

    #[test]
    fn invalid_triangular_degenerate_range_fails() {
        assert_rejected(Distribution::Triangular {
            lo: 1.0,
            mode: 1.0,
            hi: 1.0,
        });
    }

    #[test]
    fn invalid_bernoulli_p_above_one_fails() {
        assert_rejected(Distribution::Bernoulli { p: 1.5 });
    }

    #[test]
    fn invalid_exponential_lambda_zero_fails() {
        assert_rejected(Distribution::Exponential { lambda: 0.0 });
    }

    #[test]
    fn invalid_gamma_and_weibull_parameters_fail() {
        assert_rejected(Distribution::Gamma {
            shape: 0.0,
            scale: 1.0,
        });
        assert_rejected(Distribution::Gamma {
            shape: 1.0,
            scale: 0.0,
        });
        assert_rejected(Distribution::Weibull {
            shape: 0.0,
            scale: 1.0,
        });
        assert_rejected(Distribution::Weibull {
            shape: 1.0,
            scale: 0.0,
        });
    }

    #[test]
    fn invalid_discrete_uniform_lo_above_hi_fails() {
        assert_rejected(Distribution::DiscreteUniform { lo: 3, hi: 1 });
    }

    #[test]
    fn empty_categorical_fails() {
        let err = ProblemBuilder::new()
            .factor_with_kind(
                "x",
                Distribution::DiscreteUniform { lo: 0, hi: 0 },
                FactorKind::Categorical { n: 0 },
            )
            .build()
            .unwrap_err();
        assert_eq!(
            err,
            BuildError::EmptyCategorical {
                name: "x".to_string()
            }
        );
    }

    // ---- accessors ----------------------------------------------------------

    #[test]
    fn dim_matches_factor_count() {
        let p = ProblemBuilder::new()
            .factor("a", uniform(0.0, 1.0))
            .factor("b", uniform(0.0, 1.0))
            .factor("c", uniform(0.0, 1.0))
            .build()
            .expect("builds");
        assert_eq!(p.dim(), 3);
    }

    #[test]
    fn factors_returns_in_insertion_order() {
        let p = ProblemBuilder::new()
            .factor("alpha", uniform(0.0, 1.0))
            .factor(
                "beta",
                Distribution::Normal {
                    mu: 0.0,
                    sigma: 1.0,
                },
            )
            .factor("gamma", Distribution::Exponential { lambda: 1.0 })
            .build()
            .expect("builds");
        let names: Vec<&str> = p.factors().iter().map(|f| f.name.as_str()).collect();
        assert_eq!(names, vec!["alpha", "beta", "gamma"]);
    }

    // ---- content hash -------------------------------------------------------

    #[test]
    fn content_hash_is_stable_across_calls() {
        let p = ProblemBuilder::new()
            .factor("x", uniform(0.0, 1.0))
            .build()
            .expect("builds");
        let h1 = p.content_hash();
        let h2 = p.content_hash();
        let h3 = p.content_hash();
        assert_eq!(h1, h2);
        assert_eq!(h2, h3);
    }

    #[test]
    fn content_hash_equal_for_equal_problems() {
        let make = || {
            ProblemBuilder::new()
                .factor("x", uniform(0.0, 1.0))
                .factor(
                    "y",
                    Distribution::Normal {
                        mu: 0.0,
                        sigma: 2.0,
                    },
                )
                .build()
                .expect("builds")
        };
        assert_eq!(make().content_hash(), make().content_hash());
    }

    #[test]
    fn content_hash_distinct_for_different_distributions() {
        let p1 = ProblemBuilder::new()
            .factor("x", Distribution::Uniform { lo: 0.0, hi: 1.0 })
            .build()
            .expect("builds");
        let p2 = ProblemBuilder::new()
            .factor("x", Distribution::Uniform { lo: 0.0, hi: 2.0 })
            .build()
            .expect("builds");
        assert_ne!(p1.content_hash(), p2.content_hash());
    }

    #[test]
    fn content_hash_distinct_for_different_factor_names() {
        let p1 = ProblemBuilder::new()
            .factor("x", uniform(0.0, 1.0))
            .build()
            .expect("builds");
        let p2 = ProblemBuilder::new()
            .factor("y", uniform(0.0, 1.0))
            .build()
            .expect("builds");
        assert_ne!(p1.content_hash(), p2.content_hash());
    }

    #[test]
    fn content_hash_distinct_for_factor_order_swap() {
        let p1 = ProblemBuilder::new()
            .factor("a", uniform(0.0, 1.0))
            .factor("b", uniform(2.0, 3.0))
            .build()
            .expect("builds");
        let p2 = ProblemBuilder::new()
            .factor("b", uniform(2.0, 3.0))
            .factor("a", uniform(0.0, 1.0))
            .build()
            .expect("builds");
        assert_ne!(p1.content_hash(), p2.content_hash());
    }

    #[test]
    fn content_hash_distinct_for_different_kinds() {
        let p1 = ProblemBuilder::new()
            .factor("x", uniform(0.0, 1.0))
            .build()
            .expect("builds");
        let p2 = ProblemBuilder::new()
            .factor_with_kind("x", uniform(0.0, 1.0), FactorKind::Discrete)
            .build()
            .expect("builds");
        assert_ne!(p1.content_hash(), p2.content_hash());
    }

    #[test]
    fn content_hash_distinct_for_grouping() {
        // Groups are part of the serialized form, so they affect the hash.
        let ungrouped = ProblemBuilder::new()
            .factor("x", uniform(0.0, 1.0))
            .build()
            .expect("builds");
        let grouped = ProblemBuilder::new()
            .factor("x", uniform(0.0, 1.0))
            .group("g", &[0])
            .build()
            .expect("builds");
        assert_ne!(ungrouped.content_hash(), grouped.content_hash());
    }

    #[test]
    fn content_hash_returns_thirty_two_bytes() {
        let p = ProblemBuilder::new()
            .factor("x", uniform(0.0, 1.0))
            .build()
            .expect("builds");
        let h = p.content_hash();
        assert_eq!(h.len(), 32);
    }

    // ---- serde --------------------------------------------------------------

    #[test]
    fn problem_serde_round_trip() {
        let p = ProblemBuilder::new()
            .factor("a", uniform(0.0, 1.0))
            .factor(
                "b",
                Distribution::Beta {
                    alpha: 2.0,
                    beta: 5.0,
                    lo: 0.0,
                    hi: 1.0,
                },
            )
            .factor_with_kind("c", Distribution::Bernoulli { p: 0.3 }, FactorKind::Boolean)
            .build()
            .expect("builds");
        let json = serde_json::to_string(&p).expect("serialize");
        let back: Problem = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back, p);
        assert_eq!(back.content_hash(), p.content_hash());
    }

    #[test]
    fn factor_kind_serde_round_trip() {
        let cases = vec![
            FactorKind::Continuous,
            FactorKind::Discrete,
            FactorKind::Categorical { n: 4 },
            FactorKind::Boolean,
        ];
        for k in cases {
            let json = serde_json::to_string(&k).expect("serialize");
            let back: FactorKind = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(back, k);
        }
    }

    // ---- error type ---------------------------------------------------------

    #[test]
    fn build_error_implements_display_and_debug() {
        let err = BuildError::Empty;
        let _ = format!("{err}");
        let _ = format!("{err:?}");
        let err = BuildError::DuplicateName { name: "x".into() };
        let _ = format!("{err}");
    }

    // ---- groups -------------------------------------------------------------

    #[test]
    fn grouped_problem_builds() {
        let p = ProblemBuilder::new()
            .factor("x1", uniform(0.0, 1.0))
            .factor("x2", uniform(0.0, 1.0))
            .factor("x3", uniform(0.0, 1.0))
            .group("shape", &[0, 1])
            .group("scale", &[2])
            .build()
            .expect("builds");
        let groups = p.groups.expect("groups present");
        assert_eq!(
            groups,
            vec![
                Group {
                    name: "shape".to_string(),
                    factor_indices: vec![0, 1],
                },
                Group {
                    name: "scale".to_string(),
                    factor_indices: vec![2],
                },
            ]
        );
    }

    #[test]
    fn no_groups_gives_none() {
        let p = ProblemBuilder::new()
            .factor("x1", uniform(0.0, 1.0))
            .build()
            .expect("builds");
        assert!(p.groups.is_none());
    }

    #[test]
    fn group_index_out_of_range_fails() {
        let err = ProblemBuilder::new()
            .factor("x1", uniform(0.0, 1.0))
            .group("bad", &[5])
            .build()
            .unwrap_err();
        assert_eq!(
            err,
            BuildError::GroupIndexOutOfRange {
                group: "bad".to_string(),
                index: 5,
                dim: 1,
            }
        );
    }

    #[test]
    fn factor_in_multiple_groups_fails() {
        let err = ProblemBuilder::new()
            .factor("x1", uniform(0.0, 1.0))
            .factor("x2", uniform(0.0, 1.0))
            .group("a", &[0])
            .group("b", &[0, 1])
            .build()
            .unwrap_err();
        assert_eq!(err, BuildError::FactorInMultipleGroups { index: 0 });
    }

    #[test]
    fn repeated_index_within_one_group_fails() {
        // A repeat inside a single group trips the same ownership check.
        let err = ProblemBuilder::new()
            .factor("x1", uniform(0.0, 1.0))
            .group("a", &[0, 0])
            .build()
            .unwrap_err();
        assert_eq!(err, BuildError::FactorInMultipleGroups { index: 0 });
    }

    #[test]
    fn empty_group_fails() {
        let err = ProblemBuilder::new()
            .factor("x1", uniform(0.0, 1.0))
            .group("empty", &[])
            .build()
            .unwrap_err();
        assert_eq!(
            err,
            BuildError::EmptyGroup {
                group: "empty".to_string()
            }
        );
    }
}