snarkvm_circuit_environment/helpers/
updatable_count.rs

// Copyright (c) 2019-2025 Provable Inc.
// This file is part of the snarkVM library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use crate::{Constant, Constraints, Measurement, Private, Public};
17
18use core::fmt::Debug;
19use std::{
20    cmp::Ordering,
21    collections::{BTreeSet, HashMap},
22    env,
23    fmt::Display,
24    fs,
25    ops::Range,
26    path::{Path, PathBuf},
27    sync::{LazyLock, Mutex, OnceLock},
28};
29
30static FILES: LazyLock<Mutex<HashMap<&'static str, FileUpdates>>> = LazyLock::new(Default::default);
31static WORKSPACE_ROOT: OnceLock<PathBuf> = OnceLock::new();
32
/// Constructs an `UpdatableCount` whose measurements must match exactly.
///
/// The arguments are, in order, the expected number of constants, public inputs, private inputs,
/// and constraints. The location of the invocation (`file!()`, `line!()`, `column!()`) is recorded
/// so that the arguments can be rewritten automatically.
///
/// To update the arguments to `count_is!`, run `cargo test` with the `UPDATE_COUNT` environment variable
/// set to the name (or any substring of the path) of the file containing the macro invocation,
/// e.g. `UPDATE_COUNT=boolean cargo test`.
/// See <https://github.com/ProvableHQ/snarkVM/pull/1688> for more details.
#[macro_export]
macro_rules! count_is {
    ($num_constants:literal, $num_public:literal, $num_private:literal, $num_constraints:literal) => {
        $crate::UpdatableCount {
            constant: $crate::Measurement::Exact($num_constants),
            public: $crate::Measurement::Exact($num_public),
            private: $crate::Measurement::Exact($num_private),
            constraints: $crate::Measurement::Exact($num_constraints),
            file: file!(),
            line: line!(),
            column: column!(),
        }
    };
}
50
/// Constructs an `UpdatableCount` whose measurements are upper bounds.
///
/// The arguments are, in order, the upper bounds on the number of constants, public inputs,
/// private inputs, and constraints. When the arguments are updated, each bound is raised to the
/// observed value if that value exceeds it; bounds are never lowered.
///
/// To update the arguments to `count_less_than!`, run `cargo test` with the `UPDATE_COUNT` environment
/// variable set to the name (or any substring of the path) of the file containing the macro invocation,
/// e.g. `UPDATE_COUNT=boolean cargo test`.
/// See <https://github.com/ProvableHQ/snarkVM/pull/1688> for more details.
#[macro_export]
macro_rules! count_less_than {
    ($num_constants:literal, $num_public:literal, $num_private:literal, $num_constraints:literal) => {
        $crate::UpdatableCount {
            constant: $crate::Measurement::UpperBound($num_constants),
            public: $crate::Measurement::UpperBound($num_public),
            private: $crate::Measurement::UpperBound($num_private),
            constraints: $crate::Measurement::UpperBound($num_constraints),
            file: file!(),
            line: line!(),
            column: column!(),
        }
    };
}
68
/// A helper struct for tracking the number of constants, public inputs, private inputs, and constraints.
/// Warning: Do not construct this struct directly. Instead, use the `count_is!` and `count_less_than!` macros.
///
/// Besides the four measurements, the struct records the source location of the macro invocation that
/// constructed it, so that `assert_matches` can rewrite the invocation's arguments when `UPDATE_COUNT` is set.
#[derive(Copy, Clone, Debug)]
pub struct UpdatableCount {
    /// The measurement for the number of constants.
    pub constant: Constant,
    /// The measurement for the number of public inputs.
    pub public: Public,
    /// The measurement for the number of private inputs.
    pub private: Private,
    /// The measurement for the number of constraints.
    pub constraints: Constraints,
    /// The file containing the macro invocation, as produced by `file!()`.
    #[doc(hidden)]
    pub file: &'static str,
    /// The line of the macro invocation, as produced by `line!()`.
    #[doc(hidden)]
    pub line: u32,
    /// The column of the macro invocation, as produced by `column!()`.
    #[doc(hidden)]
    pub column: u32,
}
84
85impl Display for UpdatableCount {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        write!(
88            f,
89            "Constants: {}, Public: {}, Private: {}, Constraints: {}",
90            self.constant, self.public, self.private, self.constraints
91        )
92    }
93}
94
95impl UpdatableCount {
96    /// Returns `true` if the values matches the `Measurement`s in `UpdatableCount`.
97    ///
98    /// For an `Exact` metric, `value` must be equal to the exact value defined by the metric.
99    /// For a `Range` metric, `value` must be satisfy lower bound and the upper bound.
100    /// For an `UpperBound` metric, `value` must be satisfy the upper bound.
101    pub fn matches(&self, num_constants: u64, num_public: u64, num_private: u64, num_constraints: u64) -> bool {
102        self.constant.matches(num_constants)
103            && self.public.matches(num_public)
104            && self.private.matches(num_private)
105            && self.constraints.matches(num_constraints)
106    }
107
108    /// If all values match, do nothing.
109    /// If all values metrics do not match:
110    ///    - If the update condition is satisfied, then update the macro invocation that constructed this `UpdatableCount`.
111    ///    - Otherwise, panic.
112    pub fn assert_matches(&self, num_constants: u64, num_public: u64, num_private: u64, num_constraints: u64) {
113        if !self.matches(num_constants, num_public, num_private, num_constraints) {
114            let mut files = FILES.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
115            match env::var("UPDATE_COUNT") {
116                // If `UPDATE_COUNT` is set and the `query_string` matches the file containing the macro invocation
117                // that constructed this `UpdatableCount`, then update the macro invocation.
118                Ok(query_string) if self.file.contains(&query_string) => {
119                    files.entry(self.file).or_insert_with(|| FileUpdates::new(self)).update_count(
120                        self,
121                        num_constants,
122                        num_public,
123                        num_private,
124                        num_constraints,
125                    );
126                }
127                // Otherwise, error.
128                _ => {
129                    println!(
130                        "\n
131\x1b[1m\x1b[91merror\x1b[97m: Count does not match\x1b[0m
132   \x1b[1m\x1b[34m-->\x1b[0m {}:{}:{}
133\x1b[1mExpected\x1b[0m:
134----
135{}
136----
137\x1b[1mActual\x1b[0m:
138----
139Constants: {}, Public: {}, Private: {}, Constraints: {}
140----
141",
142                        self.file,
143                        self.line,
144                        self.column,
145                        self,
146                        num_constants,
147                        num_public,
148                        num_private,
149                        num_constraints,
150                    );
151                    // Use resume_unwind instead of panic!() to prevent a backtrace, which is unnecessary noise.
152                    std::panic::resume_unwind(Box::new(()));
153                }
154            }
155        }
156    }
157
158    /// Given a string containing the contents of a file, `locate` returns a range delimiting the arguments
159    /// to the macro invocation that constructed this `UpdatableCount`.
160    /// The beginning of the range corresponds to the opening parenthesis of the macro invocation.
161    /// The end of the range corresponds to the closing parenthesis of the macro invocation.
162    /// ```ignore
163    ///              count_is!(0, 1, 2, 3)
164    /// ```                   ^          ^
165    ///           starting_index     ending_index
166    ///
167    /// Note: This function must always invoked with the file contents of the same file as the macro invocation.
168    fn locate(&self, file: &str) -> Range<usize> {
169        // `line_start` is the absolute byte offset from the beginning of the file to the beginning of the current line.
170        let mut line_start = 0;
171        let mut starting_index = None;
172        let mut ending_index = None;
173        for (i, line) in LinesWithEnds::from(file).enumerate() {
174            if i == self.line as usize - 1 {
175                // Seek past the exclamation point, then skip any whitespace and the macro delimiter to get to the opening parentheses.
176                let mut argument_character_indices = line.char_indices().skip((self.column - 1).try_into().unwrap())
177                    .skip_while(|&(_, c)| c != '!') // Skip up to the exclamation point.
178                    .skip(1) // Skip `!`.
179                    .skip_while(|(_, c)| c.is_whitespace()); // Skip any whitespace.
180
181                // Set `starting_index` to the absolute position of the opening parenthesis in `file`.
182                starting_index = Some(
183                    line_start
184                        + argument_character_indices
185                            .next()
186                            .expect("Could not find the beginning of the macro invocation.")
187                            .0,
188                );
189            }
190
191            if starting_index.is_some() {
192                // At this point, we have found the opening parentheses, so we continue to skip all characters until the closing parentheses.
193                match line.char_indices().find(|&(_, c)| c == ')') {
194                    None => (), // Do nothing. This means that the closing parentheses was not found on the same line as the opening parentheses.
195                    Some((offset, _)) => {
196                        // Note that the `+ 1` is to account for the fact that `std::ops::Range` is exclusive on the upper bound.
197                        ending_index = Some(line_start + offset + 1);
198                        break;
199                    }
200                }
201            }
202            line_start += line.len();
203        }
204
205        Range {
206            start: starting_index.expect("Could not find the beginning of the macro invocation."),
207            end: ending_index.expect("Could not find the ending of the macro invocation."),
208        }
209    }
210
211    /// Computes the difference between the number of constants, public, private, and constraints of `self` and those of `other`.
212    pub fn difference_between(&self, other: &Self) -> (i64, i64, i64, i64) {
213        let difference = |self_measurement, other_measurement| match (self_measurement, other_measurement) {
214            (Measurement::Exact(self_value), Measurement::Exact(other_value))
215            | (Measurement::UpperBound(self_value), Measurement::UpperBound(other_value)) => {
216                // Note: This assumes that the number of constants, public, private, and constraints do not exceed `i64::MAX`.
217                (self_value as i64) - (other_value as i64)
218            }
219            _ => panic!(
220                "Cannot compute difference for `Measurement::Range` or if both measurements are of different types."
221            ),
222        };
223        (
224            difference(self.constant, other.constant),
225            difference(self.public, other.public),
226            difference(self.private, other.private),
227            difference(self.constraints, other.constraints),
228        )
229    }
230
231    /// Initializes an `UpdatableCount` without a specified location.
232    /// This is only used to store intermediate counts as the source file is updated.
233    fn dummy(constant: Constant, public: Public, private: Private, constraints: Constraints) -> Self {
234        Self {
235            constant,
236            public,
237            private,
238            constraints,
239            file: Default::default(),
240            line: Default::default(),
241            column: Default::default(),
242        }
243    }
244
245    /// Returns a string that is intended to replace the arguments to `count_is` or `count_less_than` in the source file.
246    fn as_argument_string(&self) -> String {
247        let generate_arg = |measurement| match measurement {
248            Measurement::Exact(value) => value,
249            Measurement::UpperBound(bound) => bound,
250            Measurement::Range(..) => panic!(
251                "Cannot create an argument string from an `UpdatableCount` that contains a `Measurement::Range`."
252            ),
253        };
254        format!(
255            "({}, {}, {}, {})",
256            generate_arg(self.constant),
257            generate_arg(self.public),
258            generate_arg(self.private),
259            generate_arg(self.constraints)
260        )
261    }
262}
263
264/// This struct is used to track updates to the `UpdatableCount`s in a file.
265/// It is used to ensure that the updates are written to the appropriate location in the file as the file is modified.
266/// This design avoids having to re-read the source file in the event that it has been modified.
267struct FileUpdates {
268    absolute_path: PathBuf,
269    original_text: String,
270    modified_text: String,
271    /// An ordered set of `Update`s.
272    /// `Update`s are ordered by their starting location.
273    /// We assume that all `Updates` are made to disjoint ranges in the original file.
274    /// This assumption is valid since invocations of `count_is` and `count_less_than` cannot be nested.
275    updates: BTreeSet<Update>,
276}
277
278impl FileUpdates {
279    /// Initializes a new instance of `FileUpdates`.
280    /// This function will read the contents of the file at the specified path and store it in the `original_text` field.
281    /// This function will also initialize the `updates` field to an empty vector.
282    fn new(count: &UpdatableCount) -> Self {
283        let path = Path::new(count.file);
284        let absolute_path = match path.is_absolute() {
285            true => path.to_owned(),
286            false => {
287                // Append `path` to the workspace root.
288                WORKSPACE_ROOT
289                    .get_or_init(|| {
290                        // Heuristic, see https://github.com/rust-lang/cargo/issues/3946
291                        Path::new(&env!("CARGO_MANIFEST_DIR"))
292                            .ancestors()
293                            .filter(|it| it.join("Cargo.toml").exists())
294                            .last()
295                            .unwrap()
296                            .to_path_buf()
297                    })
298                    .join(path)
299            }
300        };
301        let original_text = fs::read_to_string(&absolute_path).unwrap();
302        let modified_text = original_text.clone();
303        let updates = Default::default();
304        Self { absolute_path, original_text, modified_text, updates }
305    }
306
307    /// This function will update the `modified_text` field with the new text that is being inserted.
308    /// The resulting `modified_text` is written to the file at the specified path.
309    /// This implementation allows us to avoid re-reading the source file in the case where multiple updates
310    /// are being made to the same location in the source code.
311    fn update_count(
312        &mut self,
313        count: &UpdatableCount,
314        num_constants: u64,
315        num_public: u64,
316        num_private: u64,
317        num_constraints: u64,
318    ) {
319        // Get the location of arguments in the macro invocation.
320        let range = count.locate(&self.original_text);
321
322        let mut new_range = range.clone();
323        let mut update_with_same_start = None;
324
325        // Shift the range to account for changes made to the original file.
326        // Note that the `Update`s in self.updates are ordered by their starting location.
327        for previous_update in &self.updates {
328            let amount_deleted = previous_update.end - previous_update.start;
329            let amount_inserted = previous_update.argument_string.len();
330
331            match previous_update.start.cmp(&range.start) {
332                // If an update was made in a location preceding the range in the original file, we need to shift the range by the length of the text that was changed.
333                Ordering::Less => {
334                    new_range.start = new_range.start - amount_deleted + amount_inserted;
335                    new_range.end = new_range.end - amount_deleted + amount_inserted;
336                }
337                // If an update was made at the same location as the range in the original file, we need to shift the end of the range by the amount of text that was changed.
338                Ordering::Equal => {
339                    new_range.end = new_range.end - amount_deleted + amount_inserted;
340                    update_with_same_start = Some(previous_update);
341                }
342                // We do not need to shift the range if an update was made in a location following the range in the original file.
343                Ordering::Greater => {
344                    break;
345                }
346            }
347        }
348
349        // If the original `UpdatableCount` has been modified, then check if the counts satisfy the most recent `UpdatableCount`.
350        // If so, then there is no need to write to update the file and we can return early.
351        if let Some(update) = update_with_same_start {
352            if update.count.matches(num_constants, num_public, num_private, num_constraints) {
353                return;
354            }
355        }
356
357        // Construct the new update.
358        let new_update = match update_with_same_start {
359            None => Update::new(&range, count, num_constants, num_public, num_private, num_constraints),
360            Some(update) => Update::new(&range, &update.count, num_constants, num_public, num_private, num_constraints),
361        };
362
363        // Apply the update at the adjusted location.
364        self.modified_text.replace_range(new_range, &new_update.argument_string);
365
366        // Print the difference between the original and updated counts.
367        let difference = new_update.count.difference_between(count);
368        println!(
369            "\n
370\x1b[1m\x1b[33mwarning\x1b[97m: Updated count\x1b[0m
371   \x1b[1m\x1b[34m-->\x1b[0m {}:{}:{}
372\x1b[1mOriginal count\x1b[0m:
373----
374{}
375----
376\x1b[1mUpdated count\x1b[0m:
377----
378{}
379----
380\x1b[1mDifference between updated and original\x1b[0m:
381----
382Constants: {}, Public: {}, Private: {}, Constraints: {}
383----
384",
385            count.file,
386            count.line,
387            count.column,
388            count,
389            new_update.count,
390            difference.0,
391            difference.1,
392            difference.2,
393            difference.3
394        );
395
396        // Add the new update to the set of updates.
397        self.updates.replace(new_update);
398
399        // Update the original file with the modified text.
400        fs::write(&self.absolute_path, &self.modified_text).unwrap()
401    }
402}
403
/// Helper struct to keep track of updates to the original file.
///
/// Equality and ordering of `Update`s consider only `start`, so a `BTreeSet<Update>` holds at most
/// one update per starting location.
#[derive(Debug)]
struct Update {
    /// Starting location in the original file.
    /// This is the byte offset of the opening parenthesis of the macro arguments.
    start: usize,
    /// Ending location in the original file (exclusive, one past the closing parenthesis).
    end: usize,
    /// A dummy count with the new `Measurement`s.
    count: UpdatableCount,
    /// A string representation of `count`.
    /// This text replaces what occupied `start..end` in the original file.
    argument_string: String,
}
416
417impl Update {
418    fn new(
419        range: &Range<usize>,
420        old_count: &UpdatableCount,
421        num_constants: u64,
422        num_public: u64,
423        num_private: u64,
424        num_constraints: u64,
425    ) -> Self {
426        // Helper function to determine the new `Measurement` based on the expected value.
427        let generate_new_measurement = |measurement: Measurement<u64>, expected: u64| match measurement {
428            Measurement::Exact(..) => Measurement::Exact(expected),
429            Measurement::Range(..) => panic!("UpdatableCount does not support ranges."),
430            Measurement::UpperBound(bound) => Measurement::UpperBound(std::cmp::max(expected, bound)),
431        };
432        let count = UpdatableCount::dummy(
433            generate_new_measurement(old_count.constant, num_constants),
434            generate_new_measurement(old_count.public, num_public),
435            generate_new_measurement(old_count.private, num_private),
436            generate_new_measurement(old_count.constraints, num_constraints),
437        );
438        Self { start: range.start, end: range.end, count, argument_string: count.as_argument_string() }
439    }
440}
441
442impl PartialEq for Update {
443    fn eq(&self, other: &Self) -> bool {
444        self.start == other.start
445    }
446}
447impl Eq for Update {}
448impl PartialOrd for Update {
449    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
450        Some(self.cmp(other))
451    }
452}
453impl Ord for Update {
454    fn cmp(&self, other: &Self) -> Ordering {
455        self.start.cmp(&other.start)
456    }
457}
458
/// An iterator over the lines of a string that keeps each line's terminating `'\n'` (and therefore any `"\r\n"`).
/// `str::lines` cannot be used here because it strips line endings, so summing the lengths of the lines
/// it yields would not reproduce byte offsets into the original text.
struct LinesWithEnds<'a> {
    /// The part of the text that has not been yielded yet.
    text: &'a str,
}

impl<'a> From<&'a str> for LinesWithEnds<'a> {
    fn from(text: &'a str) -> Self {
        Self { text }
    }
}

impl<'a> Iterator for LinesWithEnds<'a> {
    type Item = &'a str;

    fn next(&mut self) -> Option<&'a str> {
        if self.text.is_empty() {
            return None;
        }
        // A line ends just past its `'\n'`; the final line may have no terminator at all.
        let line_len = match self.text.find('\n') {
            Some(newline) => newline + 1,
            None => self.text.len(),
        };
        let line = &self.text[..line_len];
        self.text = &self.text[line_len..];
        Some(line)
    }
}
486
#[cfg(test)]
mod test {
    use serial_test::serial;
    use std::env;

    /// Sets the `UPDATE_COUNT` environment variable for as long as the guard is alive.
    ///
    /// The variable is removed when the guard is dropped. This also happens while a test unwinds, so a test
    /// that is expected to panic cannot leak `UPDATE_COUNT` into the tests that run after it.
    struct UpdateCountGuard;

    impl UpdateCountGuard {
        fn set(value: &str) -> Self {
            env::set_var("UPDATE_COUNT", value);
            Self
        }
    }

    impl Drop for UpdateCountGuard {
        fn drop(&mut self) {
            env::remove_var("UPDATE_COUNT");
        }
    }

    #[test]
    fn check_position() {
        // Compare against `line!()` on the next line instead of a hard-coded line number,
        // so that this test does not break whenever lines are added or removed above it.
        let count = count_is!(0, 0, 0, 0);
        let expected_line = line!() - 1;
        assert_eq!(count.file, "circuit/environment/src/helpers/updatable_count.rs");
        assert_eq!(count.line, expected_line);
        assert_eq!(count.column, 21);
    }

    // Note: The below tests must be run serially since the behavior of `assert_matches` depends on whether or not
    // the environment variable `UPDATE_COUNT` is set.

    #[test]
    #[serial]
    fn check_count_passes() {
        let count = count_is!(1, 2, 3, 4);
        let num_constants = 1;
        let num_public = 2;
        let num_private = 3;
        let num_inputs = 4;
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    #[should_panic]
    fn check_count_fails() {
        let count = count_is!(1, 2, 3, 4);
        let num_constants = 5;
        let num_public = 6;
        let num_private = 7;
        let num_inputs = 8;

        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    #[should_panic]
    fn check_count_does_not_update_if_env_var_is_not_set_correctly() {
        let count = count_is!(1, 2, 3, 4);
        let num_constants = 5;
        let num_public = 6;
        let num_private = 7;
        let num_inputs = 8;

        // Set the environment variable to a value that does not match this file.
        // `assert_matches` panics, but the guard still removes the variable while unwinding.
        let _update_count = UpdateCountGuard::set("1");

        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    fn check_count_updates_correctly() {
        // This invocation was originally `count_is!(1, 2, 3, 4)` and was rewritten to the observed counts
        // by running this test. Change its arguments and rerun the test to see the replacement again.
        let count = count_is!(11, 12, 13, 14);
        let num_constants = 11;
        let num_public = 12;
        let num_private = 13;
        let num_inputs = 14;

        // Set the environment variable to update the file.
        let _update_count = UpdateCountGuard::set("updatable_count.rs");

        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    fn check_count_updates_correctly_multiple_times() {
        // This invocation was originally `count_is!(1, 2, 3, 4)`. Every mismatching call below rewrites
        // its arguments in the source file.
        let count = count_is!(17, 18, 19, 20);

        // Set the environment variable to update the file.
        let _update_count = UpdateCountGuard::set("updatable_count.rs");

        let (num_constants, num_public, num_private, num_inputs) = (5, 6, 7, 8);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (9, 10, 11, 12);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (13, 14, 15, 16);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (17, 18, 19, 20);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    fn check_count_less_than_selects_maximum() {
        // This invocation was initially `count_less_than!(1, 2, 3, 4)`.
        // After the counts were updated, it became `count_less_than!(17, 18, 19, 20)`.
        // In other words, each bound is updated to the maximum of the original and observed counts.
        let count = count_less_than!(17, 18, 19, 20);

        // Set the environment variable to update the file.
        let _update_count = UpdateCountGuard::set("updatable_count.rs");

        let (num_constants, num_public, num_private, num_inputs) = (5, 18, 7, 8);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (17, 10, 11, 12);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (13, 6, 19, 16);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (9, 18, 15, 20);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }
}