snarkvm_circuit_environment/helpers/
count.rs

1// Copyright (c) 2019-2025 Provable Inc.
2// This file is part of the snarkVM library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use core::{
17    fmt::{Debug, Display, Formatter},
18    ops::{Add, Mul, Sub},
19};
20
21pub type Constant = Measurement<u64>;
22pub type Public = Measurement<u64>;
23pub type Private = Measurement<u64>;
24pub type Constraints = Measurement<u64>;
25
26/// A helper struct for tracking the number of constants, public inputs, private inputs, and constraints.
27#[derive(Copy, Clone, Debug)]
28pub struct Count(pub Constant, pub Public, pub Private, pub Constraints);
29
30impl Count {
31    /// Returns a new `Count` whose constituent metrics are all `Exact`.
32    pub const fn zero() -> Self {
33        Count(Measurement::Exact(0), Measurement::Exact(0), Measurement::Exact(0), Measurement::Exact(0))
34    }
35
36    /// Returns a new `Count` whose constituent metrics are all `Exact`.
37    pub const fn is(num_constants: u64, num_public: u64, num_private: u64, num_constraints: u64) -> Self {
38        Count(
39            Measurement::Exact(num_constants),
40            Measurement::Exact(num_public),
41            Measurement::Exact(num_private),
42            Measurement::Exact(num_constraints),
43        )
44    }
45
46    /// Returns a new `Count` whose constituent metrics are all inclusive `UpperBound`.
47    pub const fn less_than(num_constants: u64, num_public: u64, num_private: u64, num_constraints: u64) -> Self {
48        Count(
49            Measurement::UpperBound(num_constants),
50            Measurement::UpperBound(num_public),
51            Measurement::UpperBound(num_private),
52            Measurement::UpperBound(num_constraints),
53        )
54    }
55
56    /// Returns `true` if all constituent metrics match.
57    pub fn matches(&self, num_constants: u64, num_public: u64, num_private: u64, num_constraints: u64) -> bool {
58        self.0.matches(num_constants)
59            && self.1.matches(num_public)
60            && self.2.matches(num_private)
61            && self.3.matches(num_constraints)
62    }
63}
64
65impl Add for Count {
66    type Output = Count;
67
68    /// Adds the `Count` to another `Count` by summing its constituent metrics.
69    fn add(self, other: Count) -> Self::Output {
70        Count(self.0 + other.0, self.1 + other.1, self.2 + other.2, self.3 + other.3)
71    }
72}
73
74impl Mul<u64> for Count {
75    type Output = Count;
76
77    /// Scales the `Count` by a `u64`.
78    fn mul(self, other: u64) -> Self::Output {
79        Count(self.0 * other, self.1 * other, self.2 * other, self.3 * other)
80    }
81}
82
83impl Mul<Count> for u64 {
84    type Output = Count;
85
86    /// Scales the `Count` by a `u64`.
87    fn mul(self, other: Count) -> Self::Output {
88        other * self
89    }
90}
91
92impl Display for Count {
93    fn fmt(&self, f: &mut Formatter) -> core::fmt::Result {
94        write!(f, "Constants: {}, Public: {}, Private: {}, Constraints: {}", self.0, self.1, self.2, self.3)
95    }
96}
97
/// A `Measurement` is a quantity that can be measured.
///
/// Each variant defines the condition that a measured value must satisfy; see
/// [`Measurement::matches`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Measurement<V: Copy + Debug + Display + Ord + Add<Output = V> + Sub<Output = V> + Mul<Output = V>> {
    /// The measured value must equal this quantity exactly.
    Exact(V),
    /// The measured value must lie between the lower and upper bound, both inclusive.
    Range(V, V),
    /// The measured value must be at most this bound (inclusive).
    UpperBound(V),
}
106
107impl<V: Copy + Debug + Display + Ord + Add<Output = V> + Sub<Output = V> + Mul<Output = V>> Measurement<V> {
108    /// Returns `true` if the value matches the metric.
109    ///
110    /// For an `Exact` metric, `value` must be equal to the exact value defined by the metric.
111    /// For a `Range` metric, `value` must be satisfy lower bound and the upper bound.
112    /// For an `UpperBound` metric, `value` must be satisfy the upper bound.
113    pub fn matches(&self, candidate: V) -> bool {
114        let outcome = match self {
115            Measurement::Exact(expected) => *expected == candidate,
116            Measurement::Range(lower, upper) => candidate >= *lower && candidate <= *upper,
117            Measurement::UpperBound(bound) => candidate <= *bound,
118        };
119
120        if !outcome {
121            eprintln!("Metrics claims the count should be {self:?}, found {candidate:?} during synthesis");
122        }
123
124        outcome
125    }
126}
127
128impl<V: Copy + Debug + Display + Ord + Add<Output = V> + Sub<Output = V> + Mul<Output = V>> Add for Measurement<V> {
129    type Output = Measurement<V>;
130
131    /// Adds two variants of `Measurement` together, returning the newly-summed `Measurement`.
132    fn add(self, other: Measurement<V>) -> Self::Output {
133        match (self, other) {
134            // `Exact` + `Exact` => `Exact`
135            (Measurement::Exact(exact_a), Measurement::Exact(exact_b)) => Measurement::Exact(exact_a + exact_b),
136            // `Range` + `Range` => `Range`
137            (Measurement::Range(lower_a, upper_a), Measurement::Range(lower_b, upper_b)) => {
138                Measurement::Range(lower_a + lower_b, upper_a + upper_b)
139            }
140            // `UpperBound` + `UpperBound` => `UpperBound`
141            (Measurement::UpperBound(upper_a), Measurement::UpperBound(upper_b)) => {
142                Measurement::UpperBound(upper_a + upper_b)
143            }
144            // `Exact` + `Range` => `Range`
145            // `Range` + `Exact` => `Range`
146            (Measurement::Exact(exact), Measurement::Range(lower, upper))
147            | (Measurement::Range(lower, upper), Measurement::Exact(exact)) => {
148                Measurement::Range(exact + lower, exact + upper)
149            }
150            // `Exact` + `UpperBound` => `UpperBound`
151            // `UpperBound` + `Exact` => `UpperBound`
152            (Measurement::Exact(exact), Measurement::UpperBound(upper))
153            | (Measurement::UpperBound(upper), Measurement::Exact(exact)) => Measurement::UpperBound(exact + upper),
154            // `Range` + `UpperBound` => `Range`
155            // `UpperBound` + `Range` => `Range`
156            (Measurement::Range(lower, upper_a), Measurement::UpperBound(upper_b))
157            | (Measurement::UpperBound(upper_a), Measurement::Range(lower, upper_b)) => {
158                Measurement::Range(lower, upper_a + upper_b)
159            }
160        }
161    }
162}
163
164impl<V: Copy + Debug + Display + Ord + Add<Output = V> + Sub<Output = V> + Mul<Output = V>> Mul<V> for Measurement<V> {
165    type Output = Measurement<V>;
166
167    /// Scales the `Measurement` by a value.
168    fn mul(self, other: V) -> Self::Output {
169        match self {
170            Measurement::Exact(value) => Measurement::Exact(value * other),
171            Measurement::Range(lower, upper) => Measurement::Range(lower * other, upper * other),
172            Measurement::UpperBound(bound) => Measurement::UpperBound(bound * other),
173        }
174    }
175}
176
177impl<V: Copy + Debug + Display + Ord + Add<Output = V> + Sub<Output = V> + Mul<Output = V>> Display for Measurement<V> {
178    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
179        match self {
180            Measurement::Exact(value) => write!(f, "{value}"),
181            Measurement::Range(lower, upper) => write!(f, "[{lower}, {upper}]"),
182            Measurement::UpperBound(bound) => write!(f, "<={bound}"),
183        }
184    }
185}
186
#[cfg(test)]
mod test {
    use super::*;
    use snarkvm_utilities::{TestRng, Uniform};

    // Random values are drawn as `u32` and widened to `u64`, so sums and products of two of
    // them can never overflow.
    const ITERATIONS: u64 = 1024;

    // Test matching.

    #[test]
    fn test_exact_matches() {
        let mut rng = TestRng::default();

        for _ in 0..ITERATIONS {
            // Generate a random `Measurement::Exact` and candidate value.
            let value = u32::rand(&mut rng) as u64;
            let candidate = u32::rand(&mut rng) as u64;
            let metric = Measurement::Exact(value);

            // Check that the metric is only satisfied if the candidate is equal to the value.
            assert!(metric.matches(value));
            if candidate == value {
                assert!(metric.matches(candidate));
            } else {
                assert!(!metric.matches(candidate));
            }
        }
    }

    #[test]
    fn test_upper_matches() {
        let mut rng = TestRng::default();

        for _ in 0..ITERATIONS {
            // Generate a random `Measurement::UpperBound` and candidate value.
            let upper = u32::rand(&mut rng) as u64;
            let candidate = u32::rand(&mut rng) as u64;
            let metric = Measurement::UpperBound(upper);

            // Check that the metric is only satisfied if the candidate is at most `upper` (inclusive).
            assert!(metric.matches(upper));
            if candidate <= upper {
                assert!(metric.matches(candidate));
            } else {
                assert!(!metric.matches(candidate));
            }
        }
    }

    #[test]
    fn test_range_matches() {
        let mut rng = TestRng::default();

        for _ in 0..ITERATIONS {
            // Generate a random `Measurement::Range` (with ordered bounds) and candidate value.
            let first_bound = u32::rand(&mut rng) as u64;
            let second_bound = u32::rand(&mut rng) as u64;
            let candidate = u32::rand(&mut rng) as u64;
            let (metric, lower, upper) = if first_bound <= second_bound {
                (Measurement::Range(first_bound, second_bound), first_bound, second_bound)
            } else {
                (Measurement::Range(second_bound, first_bound), second_bound, first_bound)
            };

            // Check that the metric is only satisfied if the candidate lies within `[lower, upper]`.
            assert!(metric.matches(lower));
            assert!(metric.matches(upper));
            if lower <= candidate && candidate <= upper {
                assert!(metric.matches(candidate));
            } else {
                assert!(!metric.matches(candidate));
            }
        }
    }

    // Test addition.

    #[test]
    fn test_exact_plus_exact() {
        let mut rng = TestRng::default();

        for _ in 0..ITERATIONS {
            let first = u32::rand(&mut rng) as u64;
            let second = u32::rand(&mut rng) as u64;
            let candidate = u32::rand(&mut rng) as u64;

            let a = Measurement::Exact(first);
            let b = Measurement::Exact(second);
            let c = a + b;

            assert!(c.matches(first + second));
            if candidate == first + second {
                assert!(c.matches(candidate));
            } else {
                assert!(!c.matches(candidate));
            }
        }
    }

    #[test]
    fn test_exact_plus_upper() {
        let mut rng = TestRng::default();

        for _ in 0..ITERATIONS {
            let first = u32::rand(&mut rng) as u64;
            let second = u32::rand(&mut rng) as u64;
            let candidate = u32::rand(&mut rng) as u64;

            let a = Measurement::Exact(first);
            let b = Measurement::UpperBound(second);
            let c = a + b;

            assert!(c.matches(first + second));
            if candidate <= first + second {
                assert!(c.matches(candidate));
            } else {
                assert!(!c.matches(candidate));
            }
        }
    }

    #[test]
    fn test_exact_plus_range() {
        let mut rng = TestRng::default();

        for _ in 0..ITERATIONS {
            let value = u32::rand(&mut rng) as u64;
            let first_bound = u32::rand(&mut rng) as u64;
            let second_bound = u32::rand(&mut rng) as u64;
            let candidate = u32::rand(&mut rng) as u64;

            let a = Measurement::Exact(value);
            let (b, lower, upper) = if first_bound <= second_bound {
                (Measurement::Range(first_bound, second_bound), first_bound, second_bound)
            } else {
                (Measurement::Range(second_bound, first_bound), second_bound, first_bound)
            };
            let c = a + b;

            assert!(c.matches(value + lower));
            assert!(c.matches(value + upper));
            if value + lower <= candidate && candidate <= value + upper {
                assert!(c.matches(candidate));
            } else {
                assert!(!c.matches(candidate));
            }
        }
    }

    #[test]
    fn test_range_plus_exact() {
        let mut rng = TestRng::default();

        for _ in 0..ITERATIONS {
            let value = u32::rand(&mut rng) as u64;
            let first_bound = u32::rand(&mut rng) as u64;
            let second_bound = u32::rand(&mut rng) as u64;
            let candidate = u32::rand(&mut rng) as u64;

            let (a, lower, upper) = if first_bound <= second_bound {
                (Measurement::Range(first_bound, second_bound), first_bound, second_bound)
            } else {
                (Measurement::Range(second_bound, first_bound), second_bound, first_bound)
            };
            let b = Measurement::Exact(value);
            let c = a + b;

            assert!(c.matches(value + lower));
            assert!(c.matches(value + upper));
            if value + lower <= candidate && candidate <= value + upper {
                assert!(c.matches(candidate));
            } else {
                assert!(!c.matches(candidate));
            }
        }
    }

    #[test]
    fn test_range_plus_range() {
        let mut rng = TestRng::default();

        for _ in 0..ITERATIONS {
            let first = u32::rand(&mut rng) as u64;
            let second = u32::rand(&mut rng) as u64;
            let third = u32::rand(&mut rng) as u64;
            let fourth = u32::rand(&mut rng) as u64;
            let candidate = u32::rand(&mut rng) as u64;

            let (a, first_lower, first_upper) = if first <= second {
                (Measurement::Range(first, second), first, second)
            } else {
                (Measurement::Range(second, first), second, first)
            };
            let (b, second_lower, second_upper) = if third <= fourth {
                (Measurement::Range(third, fourth), third, fourth)
            } else {
                (Measurement::Range(fourth, third), fourth, third)
            };
            let c = a + b;

            assert!(c.matches(first_lower + second_lower));
            assert!(c.matches(first_upper + second_upper));
            if first_lower + second_lower <= candidate && candidate <= first_upper + second_upper {
                assert!(c.matches(candidate));
            } else {
                assert!(!c.matches(candidate));
            }
        }
    }

    #[test]
    fn test_range_plus_upper() {
        let mut rng = TestRng::default();

        for _ in 0..ITERATIONS {
            let first = u32::rand(&mut rng) as u64;
            let second = u32::rand(&mut rng) as u64;
            let third = u32::rand(&mut rng) as u64;
            let candidate = u32::rand(&mut rng) as u64;

            let (a, lower, upper) = if second <= third {
                (Measurement::Range(second, third), second, third)
            } else {
                (Measurement::Range(third, second), third, second)
            };
            let b = Measurement::UpperBound(first);
            let c = a + b;

            // The upper bound contributes nothing to the lower end, so the range keeps `lower`.
            assert!(c.matches(lower));
            assert!(c.matches(first + upper));
            if lower <= candidate && candidate <= first + upper {
                assert!(c.matches(candidate));
            } else {
                assert!(!c.matches(candidate));
            }
        }
    }

    #[test]
    fn test_upper_plus_exact() {
        let mut rng = TestRng::default();

        for _ in 0..ITERATIONS {
            let first = u32::rand(&mut rng) as u64;
            let second = u32::rand(&mut rng) as u64;
            let candidate = u32::rand(&mut rng) as u64;

            let a = Measurement::UpperBound(second);
            let b = Measurement::Exact(first);
            let c = a + b;

            assert!(c.matches(first + second));
            if candidate <= first + second {
                assert!(c.matches(candidate));
            } else {
                assert!(!c.matches(candidate));
            }
        }
    }

    #[test]
    fn test_upper_plus_range() {
        let mut rng = TestRng::default();

        for _ in 0..ITERATIONS {
            let first = u32::rand(&mut rng) as u64;
            let second = u32::rand(&mut rng) as u64;
            let third = u32::rand(&mut rng) as u64;
            let candidate = u32::rand(&mut rng) as u64;

            let a = Measurement::UpperBound(first);
            let (b, lower, upper) = if second <= third {
                (Measurement::Range(second, third), second, third)
            } else {
                (Measurement::Range(third, second), third, second)
            };
            let c = a + b;

            // The upper bound contributes nothing to the lower end, so the range keeps `lower`.
            assert!(c.matches(lower));
            assert!(c.matches(first + upper));
            if lower <= candidate && candidate <= first + upper {
                assert!(c.matches(candidate));
            } else {
                assert!(!c.matches(candidate));
            }
        }
    }

    #[test]
    fn test_upper_plus_upper() {
        let mut rng = TestRng::default();

        for _ in 0..ITERATIONS {
            let first = u32::rand(&mut rng) as u64;
            let second = u32::rand(&mut rng) as u64;
            let candidate = u32::rand(&mut rng) as u64;

            let a = Measurement::UpperBound(second);
            let b = Measurement::UpperBound(first);
            let c = a + b;

            assert!(c.matches(first + second));
            if candidate <= first + second {
                assert!(c.matches(candidate));
            } else {
                assert!(!c.matches(candidate));
            }
        }
    }

    // Test multiplication.

    #[test]
    fn test_exact_mul() {
        let mut rng = TestRng::default();

        for _ in 0..ITERATIONS {
            let start = u32::rand(&mut rng) as u64;
            let scalar = u32::rand(&mut rng) as u64;

            let expected = Measurement::Exact(start * scalar);
            let candidate = Measurement::Exact(start) * scalar;
            assert_eq!(candidate, expected);
        }
    }

    #[test]
    fn test_upper_bound_mul() {
        let mut rng = TestRng::default();

        for _ in 0..ITERATIONS {
            let start = u32::rand(&mut rng) as u64;
            let scalar = u32::rand(&mut rng) as u64;

            let expected = Measurement::UpperBound(start * scalar);
            let candidate = Measurement::UpperBound(start) * scalar;
            assert_eq!(candidate, expected);
        }
    }

    #[test]
    fn test_range_mul() {
        let mut rng = TestRng::default();

        for _ in 0..ITERATIONS {
            // The bounds are not ordered here: scaling is checked structurally, not via `matches`.
            let start = u32::rand(&mut rng) as u64;
            let end = u32::rand(&mut rng) as u64;
            let scalar = u32::rand(&mut rng) as u64;

            let expected = Measurement::Range(start * scalar, end * scalar);
            let candidate = Measurement::Range(start, end) * scalar;
            assert_eq!(candidate, expected);
        }
    }
}