// snarkvm_circuit_environment/helpers/updatable_count.rs

// Copyright 2024-2025 Aleo Network Foundation
// This file is part of the snarkVM library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use crate::{Constant, Constraints, Measurement, Private, Public};
17
18use core::fmt::Debug;
19use once_cell::sync::{Lazy, OnceCell};
20use std::{
21    cmp::Ordering,
22    collections::{BTreeSet, HashMap},
23    env,
24    fmt::Display,
25    fs,
26    ops::Range,
27    path::{Path, PathBuf},
28    sync::Mutex,
29};
30
31static FILES: Lazy<Mutex<HashMap<&'static str, FileUpdates>>> = Lazy::new(Default::default);
32static WORKSPACE_ROOT: OnceCell<PathBuf> = OnceCell::new();
33
/// Constructs an `UpdatableCount` in which all four measurements are `Measurement::Exact`.
///
/// The arguments are, in order, the expected number of constants, public inputs, private inputs,
/// and constraints. The invocation's `file!()`, `line!()`, and `column!()` are recorded so that the
/// arguments can later be rewritten in place.
///
/// To update the arguments to `count_is!`, run cargo test with the `UPDATE_COUNT` environment variable
/// set to (a substring of) the name of the file containing the macro invocation,
/// e.g. `UPDATE_COUNT=boolean cargo test`.
/// See <https://github.com/ProvableHQ/snarkVM/pull/1688> for more details.
#[macro_export]
macro_rules! count_is {
    ($num_constants:literal, $num_public:literal, $num_private:literal, $num_constraints:literal) => {
        $crate::UpdatableCount {
            constant: $crate::Measurement::Exact($num_constants),
            public: $crate::Measurement::Exact($num_public),
            private: $crate::Measurement::Exact($num_private),
            constraints: $crate::Measurement::Exact($num_constraints),
            file: file!(),
            line: line!(),
            column: column!(),
        }
    };
}
51
/// Constructs an `UpdatableCount` in which all four measurements are `Measurement::UpperBound`.
///
/// The arguments are, in order, the upper bounds on the number of constants, public inputs,
/// private inputs, and constraints. The invocation's `file!()`, `line!()`, and `column!()` are
/// recorded so that the arguments can later be rewritten in place.
///
/// To update the arguments to `count_less_than!`, run cargo test with the `UPDATE_COUNT` environment variable
/// set to (a substring of) the name of the file containing the macro invocation,
/// e.g. `UPDATE_COUNT=boolean cargo test`.
/// See <https://github.com/ProvableHQ/snarkVM/pull/1688> for more details.
#[macro_export]
macro_rules! count_less_than {
    ($num_constants:literal, $num_public:literal, $num_private:literal, $num_constraints:literal) => {
        $crate::UpdatableCount {
            constant: $crate::Measurement::UpperBound($num_constants),
            public: $crate::Measurement::UpperBound($num_public),
            private: $crate::Measurement::UpperBound($num_private),
            constraints: $crate::Measurement::UpperBound($num_constraints),
            file: file!(),
            line: line!(),
            column: column!(),
        }
    };
}
69
70/// A helper struct for tracking the number of constants, public inputs, private inputs, and constraints.
71/// Warning: Do not construct this struct directly. Instead, use the `count_is!` and `count_less_than!` macros.
72#[derive(Copy, Clone, Debug)]
73pub struct UpdatableCount {
74    pub constant: Constant,
75    pub public: Public,
76    pub private: Private,
77    pub constraints: Constraints,
78    #[doc(hidden)]
79    pub file: &'static str,
80    #[doc(hidden)]
81    pub line: u32,
82    #[doc(hidden)]
83    pub column: u32,
84}
85
86impl Display for UpdatableCount {
87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        write!(
89            f,
90            "Constants: {}, Public: {}, Private: {}, Constraints: {}",
91            self.constant, self.public, self.private, self.constraints
92        )
93    }
94}
95
96impl UpdatableCount {
97    /// Returns `true` if the values matches the `Measurement`s in `UpdatableCount`.
98    ///
99    /// For an `Exact` metric, `value` must be equal to the exact value defined by the metric.
100    /// For a `Range` metric, `value` must be satisfy lower bound and the upper bound.
101    /// For an `UpperBound` metric, `value` must be satisfy the upper bound.
102    pub fn matches(&self, num_constants: u64, num_public: u64, num_private: u64, num_constraints: u64) -> bool {
103        self.constant.matches(num_constants)
104            && self.public.matches(num_public)
105            && self.private.matches(num_private)
106            && self.constraints.matches(num_constraints)
107    }
108
109    /// If all values match, do nothing.
110    /// If all values metrics do not match:
111    ///    - If the update condition is satisfied, then update the macro invocation that constructed this `UpdatableCount`.
112    ///    - Otherwise, panic.
113    pub fn assert_matches(&self, num_constants: u64, num_public: u64, num_private: u64, num_constraints: u64) {
114        if !self.matches(num_constants, num_public, num_private, num_constraints) {
115            let mut files = FILES.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
116            match env::var("UPDATE_COUNT") {
117                // If `UPDATE_COUNT` is set and the `query_string` matches the file containing the macro invocation
118                // that constructed this `UpdatableCount`, then update the macro invocation.
119                Ok(query_string) if self.file.contains(&query_string) => {
120                    files.entry(self.file).or_insert_with(|| FileUpdates::new(self)).update_count(
121                        self,
122                        num_constants,
123                        num_public,
124                        num_private,
125                        num_constraints,
126                    );
127                }
128                // Otherwise, error.
129                _ => {
130                    println!(
131                        "\n
132\x1b[1m\x1b[91merror\x1b[97m: Count does not match\x1b[0m
133   \x1b[1m\x1b[34m-->\x1b[0m {}:{}:{}
134\x1b[1mExpected\x1b[0m:
135----
136{}
137----
138\x1b[1mActual\x1b[0m:
139----
140Constants: {}, Public: {}, Private: {}, Constraints: {}
141----
142",
143                        self.file,
144                        self.line,
145                        self.column,
146                        self,
147                        num_constants,
148                        num_public,
149                        num_private,
150                        num_constraints,
151                    );
152                    // Use resume_unwind instead of panic!() to prevent a backtrace, which is unnecessary noise.
153                    snarkvm_utilities::panic::resume_unwind(Box::new(()));
154                }
155            }
156        }
157    }
158
159    /// Given a string containing the contents of a file, `locate` returns a range delimiting the arguments
160    /// to the macro invocation that constructed this `UpdatableCount`.
161    /// The beginning of the range corresponds to the opening parenthesis of the macro invocation.
162    /// The end of the range corresponds to the closing parenthesis of the macro invocation.
163    /// ```ignore
164    ///              count_is!(0, 1, 2, 3)
165    /// ```                   ^          ^
166    ///           starting_index     ending_index
167    ///
168    /// Note: This function must always invoked with the file contents of the same file as the macro invocation.
169    fn locate(&self, file: &str) -> Range<usize> {
170        // `line_start` is the absolute byte offset from the beginning of the file to the beginning of the current line.
171        let mut line_start = 0;
172        let mut starting_index = None;
173        let mut ending_index = None;
174        for (i, line) in LinesWithEnds::from(file).enumerate() {
175            if i == self.line as usize - 1 {
176                // Seek past the exclamation point, then skip any whitespace and the macro delimiter to get to the opening parentheses.
177                let mut argument_character_indices = line.char_indices().skip((self.column - 1).try_into().unwrap())
178                    .skip_while(|&(_, c)| c != '!') // Skip up to the exclamation point.
179                    .skip(1) // Skip `!`.
180                    .skip_while(|(_, c)| c.is_whitespace()); // Skip any whitespace.
181
182                // Set `starting_index` to the absolute position of the opening parenthesis in `file`.
183                starting_index = Some(
184                    line_start
185                        + argument_character_indices
186                            .next()
187                            .expect("Could not find the beginning of the macro invocation.")
188                            .0,
189                );
190            }
191
192            if starting_index.is_some() {
193                // At this point, we have found the opening parentheses, so we continue to skip all characters until the closing parentheses.
194                match line.char_indices().find(|&(_, c)| c == ')') {
195                    None => (), // Do nothing. This means that the closing parentheses was not found on the same line as the opening parentheses.
196                    Some((offset, _)) => {
197                        // Note that the `+ 1` is to account for the fact that `std::ops::Range` is exclusive on the upper bound.
198                        ending_index = Some(line_start + offset + 1);
199                        break;
200                    }
201                }
202            }
203            line_start += line.len();
204        }
205
206        Range {
207            start: starting_index.expect("Could not find the beginning of the macro invocation."),
208            end: ending_index.expect("Could not find the ending of the macro invocation."),
209        }
210    }
211
212    /// Computes the difference between the number of constants, public, private, and constraints of `self` and those of `other`.
213    pub fn difference_between(&self, other: &Self) -> (i64, i64, i64, i64) {
214        let difference = |self_measurement, other_measurement| match (self_measurement, other_measurement) {
215            (Measurement::Exact(self_value), Measurement::Exact(other_value))
216            | (Measurement::UpperBound(self_value), Measurement::UpperBound(other_value)) => {
217                // Note: This assumes that the number of constants, public, private, and constraints do not exceed `i64::MAX`.
218                (self_value as i64) - (other_value as i64)
219            }
220            _ => panic!(
221                "Cannot compute difference for `Measurement::Range` or if both measurements are of different types."
222            ),
223        };
224        (
225            difference(self.constant, other.constant),
226            difference(self.public, other.public),
227            difference(self.private, other.private),
228            difference(self.constraints, other.constraints),
229        )
230    }
231
232    /// Initializes an `UpdatableCount` without a specified location.
233    /// This is only used to store intermediate counts as the source file is updated.
234    fn dummy(constant: Constant, public: Public, private: Private, constraints: Constraints) -> Self {
235        Self {
236            constant,
237            public,
238            private,
239            constraints,
240            file: Default::default(),
241            line: Default::default(),
242            column: Default::default(),
243        }
244    }
245
246    /// Returns a string that is intended to replace the arguments to `count_is` or `count_less_than` in the source file.
247    fn as_argument_string(&self) -> String {
248        let generate_arg = |measurement| match measurement {
249            Measurement::Exact(value) => value,
250            Measurement::UpperBound(bound) => bound,
251            Measurement::Range(..) => panic!(
252                "Cannot create an argument string from an `UpdatableCount` that contains a `Measurement::Range`."
253            ),
254        };
255        format!(
256            "({}, {}, {}, {})",
257            generate_arg(self.constant),
258            generate_arg(self.public),
259            generate_arg(self.private),
260            generate_arg(self.constraints)
261        )
262    }
263}
264
265/// This struct is used to track updates to the `UpdatableCount`s in a file.
266/// It is used to ensure that the updates are written to the appropriate location in the file as the file is modified.
267/// This design avoids having to re-read the source file in the event that it has been modified.
268struct FileUpdates {
269    absolute_path: PathBuf,
270    original_text: String,
271    modified_text: String,
272    /// An ordered set of `Update`s.
273    /// `Update`s are ordered by their starting location.
274    /// We assume that all `Updates` are made to disjoint ranges in the original file.
275    /// This assumption is valid since invocations of `count_is` and `count_less_than` cannot be nested.
276    updates: BTreeSet<Update>,
277}
278
279impl FileUpdates {
280    /// Initializes a new instance of `FileUpdates`.
281    /// This function will read the contents of the file at the specified path and store it in the `original_text` field.
282    /// This function will also initialize the `updates` field to an empty vector.
283    fn new(count: &UpdatableCount) -> Self {
284        let path = Path::new(count.file);
285        let absolute_path = match path.is_absolute() {
286            true => path.to_owned(),
287            false => {
288                // Append `path` to the workspace root.
289                WORKSPACE_ROOT
290                    .get_or_try_init(|| {
291                        // Heuristic, see https://github.com/rust-lang/cargo/issues/3946
292                        let workspace_root = Path::new(&env!("CARGO_MANIFEST_DIR"))
293                            .ancestors()
294                            .filter(|it| it.join("Cargo.toml").exists())
295                            .last()
296                            .unwrap()
297                            .to_path_buf();
298
299                        Ok(workspace_root)
300                    })
301                    .unwrap_or_else(|_: env::VarError| {
302                        panic!("No CARGO_MANIFEST_DIR env var and the path is relative: {}", path.display())
303                    })
304                    .join(path)
305            }
306        };
307        let original_text = fs::read_to_string(&absolute_path).unwrap();
308        let modified_text = original_text.clone();
309        let updates = Default::default();
310        Self { absolute_path, original_text, modified_text, updates }
311    }
312
313    /// This function will update the `modified_text` field with the new text that is being inserted.
314    /// The resulting `modified_text` is written to the file at the specified path.
315    /// This implementation allows us to avoid re-reading the source file in the case where multiple updates
316    /// are being made to the same location in the source code.
317    fn update_count(
318        &mut self,
319        count: &UpdatableCount,
320        num_constants: u64,
321        num_public: u64,
322        num_private: u64,
323        num_constraints: u64,
324    ) {
325        // Get the location of arguments in the macro invocation.
326        let range = count.locate(&self.original_text);
327
328        let mut new_range = range.clone();
329        let mut update_with_same_start = None;
330
331        // Shift the range to account for changes made to the original file.
332        // Note that the `Update`s in self.updates are ordered by their starting location.
333        for previous_update in &self.updates {
334            let amount_deleted = previous_update.end - previous_update.start;
335            let amount_inserted = previous_update.argument_string.len();
336
337            match previous_update.start.cmp(&range.start) {
338                // If an update was made in a location preceding the range in the original file, we need to shift the range by the length of the text that was changed.
339                Ordering::Less => {
340                    new_range.start = new_range.start - amount_deleted + amount_inserted;
341                    new_range.end = new_range.end - amount_deleted + amount_inserted;
342                }
343                // If an update was made at the same location as the range in the original file, we need to shift the end of the range by the amount of text that was changed.
344                Ordering::Equal => {
345                    new_range.end = new_range.end - amount_deleted + amount_inserted;
346                    update_with_same_start = Some(previous_update);
347                }
348                // We do not need to shift the range if an update was made in a location following the range in the original file.
349                Ordering::Greater => {
350                    break;
351                }
352            }
353        }
354
355        // If the original `UpdatableCount` has been modified, then check if the counts satisfy the most recent `UpdatableCount`.
356        // If so, then there is no need to write to update the file and we can return early.
357        if let Some(update) = update_with_same_start {
358            if update.count.matches(num_constants, num_public, num_private, num_constraints) {
359                return;
360            }
361        }
362
363        // Construct the new update.
364        let new_update = match update_with_same_start {
365            None => Update::new(&range, count, num_constants, num_public, num_private, num_constraints),
366            Some(update) => Update::new(&range, &update.count, num_constants, num_public, num_private, num_constraints),
367        };
368
369        // Apply the update at the adjusted location.
370        self.modified_text.replace_range(new_range, &new_update.argument_string);
371
372        // Print the difference between the original and updated counts.
373        let difference = new_update.count.difference_between(count);
374        println!(
375            "\n
376\x1b[1m\x1b[33mwarning\x1b[97m: Updated count\x1b[0m
377   \x1b[1m\x1b[34m-->\x1b[0m {}:{}:{}
378\x1b[1mOriginal count\x1b[0m:
379----
380{}
381----
382\x1b[1mUpdated count\x1b[0m:
383----
384{}
385----
386\x1b[1mDifference between updated and original\x1b[0m:
387----
388Constants: {}, Public: {}, Private: {}, Constraints: {}
389----
390",
391            count.file,
392            count.line,
393            count.column,
394            count,
395            new_update.count,
396            difference.0,
397            difference.1,
398            difference.2,
399            difference.3
400        );
401
402        // Add the new update to the set of updates.
403        self.updates.replace(new_update);
404
405        // Update the original file with the modified text.
406        fs::write(&self.absolute_path, &self.modified_text).unwrap()
407    }
408}
409
410/// Helper struct to keep track of updates to the original file.
411#[derive(Debug)]
412struct Update {
413    /// Starting location in the original file.
414    start: usize,
415    /// Ending location in the original file.
416    end: usize,
417    /// A dummy count with the new `Measurement`s.
418    count: UpdatableCount,
419    /// A string representation of `count`.
420    argument_string: String,
421}
422
423impl Update {
424    fn new(
425        range: &Range<usize>,
426        old_count: &UpdatableCount,
427        num_constants: u64,
428        num_public: u64,
429        num_private: u64,
430        num_constraints: u64,
431    ) -> Self {
432        // Helper function to determine the new `Measurement` based on the expected value.
433        let generate_new_measurement = |measurement: Measurement<u64>, expected: u64| match measurement {
434            Measurement::Exact(..) => Measurement::Exact(expected),
435            Measurement::Range(..) => panic!("UpdatableCount does not support ranges."),
436            Measurement::UpperBound(bound) => Measurement::UpperBound(std::cmp::max(expected, bound)),
437        };
438        let count = UpdatableCount::dummy(
439            generate_new_measurement(old_count.constant, num_constants),
440            generate_new_measurement(old_count.public, num_public),
441            generate_new_measurement(old_count.private, num_private),
442            generate_new_measurement(old_count.constraints, num_constraints),
443        );
444        Self { start: range.start, end: range.end, count, argument_string: count.as_argument_string() }
445    }
446}
447
448impl PartialEq for Update {
449    fn eq(&self, other: &Self) -> bool {
450        self.start == other.start
451    }
452}
453impl Eq for Update {}
454impl PartialOrd for Update {
455    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
456        Some(self.cmp(other))
457    }
458}
459impl Ord for Update {
460    fn cmp(&self, other: &Self) -> Ordering {
461        self.start.cmp(&other.start)
462    }
463}
464
/// An iterator over the lines of a string that keeps each line's terminating `'\n'`
/// (and therefore any `'\r'` preceding it).
///
/// Unlike `str::lines`, the yielded slices concatenate back to the original text exactly,
/// so summing their lengths gives exact byte offsets into that text.
struct LinesWithEnds<'a> {
    /// The part of the input that has not been yielded yet.
    text: &'a str,
}

impl<'a> Iterator for LinesWithEnds<'a> {
    type Item = &'a str;

    fn next(&mut self) -> Option<&'a str> {
        let text = self.text;
        // `split_inclusive` yields nothing for an empty string and keeps the `'\n'` on each piece.
        let line = text.split_inclusive('\n').next()?;
        self.text = &text[line.len()..];
        Some(line)
    }
}

impl<'a> From<&'a str> for LinesWithEnds<'a> {
    fn from(text: &'a str) -> Self {
        Self { text }
    }
}
492
#[cfg(test)]
mod test {
    use serial_test::serial;
    use std::env;

    /// Sets the `UPDATE_COUNT` environment variable for as long as the guard is alive.
    ///
    /// The variable is removed in `Drop`, so it is cleaned up even when the test body panics
    /// (as the `#[should_panic]` tests do), instead of leaking into tests that run later in the same process.
    #[must_use = "`UPDATE_COUNT` is removed as soon as the guard is dropped"]
    struct UpdateCountGuard;

    impl UpdateCountGuard {
        fn set(value: &str) -> Self {
            env::set_var("UPDATE_COUNT", value);
            Self
        }
    }

    impl Drop for UpdateCountGuard {
        fn drop(&mut self) {
            env::remove_var("UPDATE_COUNT");
        }
    }

    #[test]
    fn check_position() {
        let count = count_is!(0, 0, 0, 0);
        // Derived from `line!()` rather than hard-coded, so that editing lines above this test does not break it.
        let expected_line = line!() - 1;
        assert_eq!(count.file, "circuit/environment/src/helpers/updatable_count.rs");
        assert_eq!(count.line, expected_line);
        assert_eq!(count.column, 21);
    }

    // Note: The below tests must be run serially since the behavior of `assert_matches` depends on whether or not
    // the environment variable `UPDATE_COUNT` is set.

    #[test]
    #[serial]
    fn check_count_passes() {
        let count = count_is!(1, 2, 3, 4);
        let num_constants = 1;
        let num_public = 2;
        let num_private = 3;
        let num_inputs = 4;
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    #[should_panic]
    fn check_count_fails() {
        let count = count_is!(1, 2, 3, 4);
        let num_constants = 5;
        let num_public = 6;
        let num_private = 7;
        let num_inputs = 8;

        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    #[should_panic]
    fn check_count_does_not_update_if_env_var_is_not_set_correctly() {
        let count = count_is!(1, 2, 3, 4);
        let num_constants = 5;
        let num_public = 6;
        let num_private = 7;
        let num_inputs = 8;

        // Set the environment variable to a value that does not match this file.
        // It is removed when `_guard` is dropped, even though `assert_matches` panics below.
        let _guard = UpdateCountGuard::set("1");

        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    fn check_count_updates_correctly() {
        // `count` was originally `count_is!(1, 2, 3, 4)`; running this test rewrote it to the values below.
        // Change the arguments back to `(1, 2, 3, 4)` to see the replacement happen again.
        let count = count_is!(11, 12, 13, 14);
        let num_constants = 11;
        let num_public = 12;
        let num_private = 13;
        let num_inputs = 14;

        // Set the environment variable to update the file.
        let _guard = UpdateCountGuard::set("updatable_count.rs");

        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    fn check_count_updates_correctly_multiple_times() {
        // `count` was originally `count_is!(1, 2, 3, 4)`; running this test rewrote it once per call below,
        // leaving the final values. Change the arguments back to `(1, 2, 3, 4)` to see the replacements again.
        let count = count_is!(17, 18, 19, 20);

        let _guard = UpdateCountGuard::set("updatable_count.rs");

        let (num_constants, num_public, num_private, num_inputs) = (5, 6, 7, 8);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (9, 10, 11, 12);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (13, 14, 15, 16);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (17, 18, 19, 20);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    fn check_count_less_than_selects_maximum() {
        // `count` is initially `count_less_than!(1, 2, 3, 4)`.
        // After counts are updated, `count` is `count_less_than!(17, 18, 19, 20)`.
        // In other words, count is updated to be the maximum of the original and updated counts.
        let count = count_less_than!(17, 18, 19, 20);

        // Set the environment variable to update the file.
        let _guard = UpdateCountGuard::set("updatable_count.rs");

        let (num_constants, num_public, num_private, num_inputs) = (5, 18, 7, 8);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (17, 10, 11, 12);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (13, 6, 19, 16);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (9, 18, 15, 20);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }
}