// twine_core/constraint/non_negative.rs
use std::{cmp::Ordering, marker::PhantomData, ops::Add};

use num_traits::Zero;

use super::{Constrained, Constraint, ConstraintError};
/// Marker type for the "value must be non-negative" constraint.
///
/// A value satisfies this constraint when it compares greater than or equal
/// to its type's zero. A value that cannot be ordered against zero (for
/// example `f64::NAN`) is rejected.
///
/// Build constrained values with [`NonNegative::new`] or [`NonNegative::zero`].
///
/// The marker is a zero-sized type. It derives `Default` and `Hash` so that
/// generic code requiring those traits on the constraint parameter (for
/// example a derived `Hash` over a `PhantomData<C>`) also works with it.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct NonNegative;
36impl NonNegative {
37 pub fn new<T: PartialOrd + Zero>(
43 value: T,
44 ) -> Result<Constrained<T, NonNegative>, ConstraintError> {
45 Constrained::<T, NonNegative>::new(value)
46 }
47
48 #[must_use]
52 pub fn zero<T: PartialOrd + Zero>() -> Constrained<T, NonNegative> {
53 Constrained::<T, NonNegative>::zero()
54 }
55}
57impl<T: PartialOrd + Zero> Constraint<T> for NonNegative {
58 fn check(value: &T) -> Result<(), ConstraintError> {
59 match value.partial_cmp(&T::zero()) {
60 Some(Ordering::Greater | Ordering::Equal) => Ok(()),
61 Some(Ordering::Less) => Err(ConstraintError::Negative),
62 None => Err(ConstraintError::NotANumber),
63 }
64 }
65}
67impl<T> Add for Constrained<T, NonNegative>
78where
79 T: Add<Output = T> + PartialOrd + Zero,
80{
81 type Output = Self;
82
83 fn add(self, rhs: Self) -> Self {
84 let value = self.value + rhs.value;
85 debug_assert!(
86 value >= T::zero(),
87 "Addition produced a negative value, violating NonNegative bound invariant"
88 );
89 Self {
90 value,
91 _marker: PhantomData,
92 }
93 }
94}
96impl<T> Zero for Constrained<T, NonNegative>
97where
98 T: PartialOrd + Zero,
99{
100 fn zero() -> Self {
101 Self {
102 value: T::zero(),
103 _marker: PhantomData,
104 }
105 }
106
107 fn is_zero(&self) -> bool {
108 self.value == T::zero()
109 }
110}
#[cfg(test)]
mod tests {
    use super::*;

    use uom::si::{f64::MassRate, mass_rate::kilogram_per_second};

    #[test]
    fn integers() {
        let one = Constrained::<i32, NonNegative>::new(1).unwrap();
        assert_eq!(one.into_inner(), 1);

        let two = NonNegative::new(2).unwrap();
        assert_eq!(two.as_ref(), &2);

        let zero = NonNegative::zero();
        assert_eq!(zero.into_inner(), 0);

        let sum = one + two + zero;
        assert_eq!(sum.into_inner(), 3);

        assert!(NonNegative::new(-1).is_err());
    }

    #[test]
    fn floats() {
        assert!(Constrained::<f64, NonNegative>::new(2.0).is_ok());
        assert!(NonNegative::new(0.0).is_ok());
        assert!(NonNegative::new(-2.0).is_err());
        assert!(NonNegative::new(f64::NAN).is_err());

        // Boundary cases: infinities are ordered, and negative zero
        // compares equal to zero.
        assert!(NonNegative::new(f64::INFINITY).is_ok());
        assert!(NonNegative::new(f64::NEG_INFINITY).is_err());
        assert!(NonNegative::new(-0.0_f64).is_ok());
    }

    #[test]
    fn zero_detection() {
        assert!(NonNegative::zero::<i32>().is_zero());
        assert!(NonNegative::zero::<f64>().is_zero());
        assert!(!NonNegative::new(1.5_f64).unwrap().is_zero());
        assert!(NonNegative::new(-0.0_f64).unwrap().is_zero());
    }

    #[test]
    fn mass_rates() {
        let mass_rate = MassRate::new::<kilogram_per_second>(5.0);
        assert!(NonNegative::new(mass_rate).is_ok());

        let mass_rate = MassRate::new::<kilogram_per_second>(0.0);
        assert!(NonNegative::new(mass_rate).is_ok());

        let mass_rate = MassRate::new::<kilogram_per_second>(-2.0);
        assert!(NonNegative::new(mass_rate).is_err());

        let total = NonNegative::new(MassRate::new::<kilogram_per_second>(1.5)).unwrap()
            + NonNegative::new(MassRate::new::<kilogram_per_second>(2.5)).unwrap();
        assert_eq!(
            total.into_inner(),
            MassRate::new::<kilogram_per_second>(4.0)
        );
    }
}