1use crate::{Constant, Constraints, Measurement, Private, Public};
17
18use core::fmt::Debug;
19use once_cell::sync::{Lazy, OnceCell};
20use std::{
21 cmp::Ordering,
22 collections::{BTreeSet, HashMap},
23 env,
24 fmt::Display,
25 fs,
26 ops::Range,
27 path::{Path, PathBuf},
28 sync::Mutex,
29};
30
31static FILES: Lazy<Mutex<HashMap<&'static str, FileUpdates>>> = Lazy::new(Default::default);
32static WORKSPACE_ROOT: OnceCell<PathBuf> = OnceCell::new();
33
#[macro_export]
/// Creates an `UpdatableCount` whose four measurements are exact (`Measurement::Exact`).
///
/// The literals are the expected constant, public, private, and constraint counts, in that
/// order. The invocation's file, line, and column are recorded so that its arguments can be
/// rewritten in place when the `UPDATE_COUNT` environment variable selects this file.
macro_rules! count_is {
    ($constants:literal, $public:literal, $private:literal, $constraints:literal) => {
        $crate::UpdatableCount {
            constant: $crate::Measurement::Exact($constants),
            public: $crate::Measurement::Exact($public),
            private: $crate::Measurement::Exact($private),
            constraints: $crate::Measurement::Exact($constraints),
            file: file!(),
            line: line!(),
            column: column!(),
        }
    };
}
51
#[macro_export]
/// Creates an `UpdatableCount` whose four measurements are upper bounds
/// (`Measurement::UpperBound`).
///
/// The literals are the bounds on the constant, public, private, and constraint counts, in
/// that order. The invocation's file, line, and column are recorded so that its arguments can
/// be rewritten in place when the `UPDATE_COUNT` environment variable selects this file. When
/// rewritten, each bound is raised to the observed count if that count is larger.
macro_rules! count_less_than {
    ($constants:literal, $public:literal, $private:literal, $constraints:literal) => {
        $crate::UpdatableCount {
            constant: $crate::Measurement::UpperBound($constants),
            public: $crate::Measurement::UpperBound($public),
            private: $crate::Measurement::UpperBound($private),
            constraints: $crate::Measurement::UpperBound($constraints),
            file: file!(),
            line: line!(),
            column: column!(),
        }
    };
}
69
70#[derive(Copy, Clone, Debug)]
73pub struct UpdatableCount {
74 pub constant: Constant,
75 pub public: Public,
76 pub private: Private,
77 pub constraints: Constraints,
78 #[doc(hidden)]
79 pub file: &'static str,
80 #[doc(hidden)]
81 pub line: u32,
82 #[doc(hidden)]
83 pub column: u32,
84}
85
86impl Display for UpdatableCount {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(
89 f,
90 "Constants: {}, Public: {}, Private: {}, Constraints: {}",
91 self.constant, self.public, self.private, self.constraints
92 )
93 }
94}
95
96impl UpdatableCount {
97 pub fn matches(&self, num_constants: u64, num_public: u64, num_private: u64, num_constraints: u64) -> bool {
103 self.constant.matches(num_constants)
104 && self.public.matches(num_public)
105 && self.private.matches(num_private)
106 && self.constraints.matches(num_constraints)
107 }
108
109 pub fn assert_matches(&self, num_constants: u64, num_public: u64, num_private: u64, num_constraints: u64) {
114 if !self.matches(num_constants, num_public, num_private, num_constraints) {
115 let mut files = FILES.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
116 match env::var("UPDATE_COUNT") {
117 Ok(query_string) if self.file.contains(&query_string) => {
120 files.entry(self.file).or_insert_with(|| FileUpdates::new(self)).update_count(
121 self,
122 num_constants,
123 num_public,
124 num_private,
125 num_constraints,
126 );
127 }
128 _ => {
130 println!(
131 "\n
132\x1b[1m\x1b[91merror\x1b[97m: Count does not match\x1b[0m
133 \x1b[1m\x1b[34m-->\x1b[0m {}:{}:{}
134\x1b[1mExpected\x1b[0m:
135----
136{}
137----
138\x1b[1mActual\x1b[0m:
139----
140Constants: {}, Public: {}, Private: {}, Constraints: {}
141----
142",
143 self.file,
144 self.line,
145 self.column,
146 self,
147 num_constants,
148 num_public,
149 num_private,
150 num_constraints,
151 );
152 snarkvm_utilities::panic::resume_unwind(Box::new(()));
154 }
155 }
156 }
157 }
158
159 fn locate(&self, file: &str) -> Range<usize> {
170 let mut line_start = 0;
172 let mut starting_index = None;
173 let mut ending_index = None;
174 for (i, line) in LinesWithEnds::from(file).enumerate() {
175 if i == self.line as usize - 1 {
176 let mut argument_character_indices = line.char_indices().skip((self.column - 1).try_into().unwrap())
178 .skip_while(|&(_, c)| c != '!') .skip(1) .skip_while(|(_, c)| c.is_whitespace()); starting_index = Some(
184 line_start
185 + argument_character_indices
186 .next()
187 .expect("Could not find the beginning of the macro invocation.")
188 .0,
189 );
190 }
191
192 if starting_index.is_some() {
193 match line.char_indices().find(|&(_, c)| c == ')') {
195 None => (), Some((offset, _)) => {
197 ending_index = Some(line_start + offset + 1);
199 break;
200 }
201 }
202 }
203 line_start += line.len();
204 }
205
206 Range {
207 start: starting_index.expect("Could not find the beginning of the macro invocation."),
208 end: ending_index.expect("Could not find the ending of the macro invocation."),
209 }
210 }
211
212 pub fn difference_between(&self, other: &Self) -> (i64, i64, i64, i64) {
214 let difference = |self_measurement, other_measurement| match (self_measurement, other_measurement) {
215 (Measurement::Exact(self_value), Measurement::Exact(other_value))
216 | (Measurement::UpperBound(self_value), Measurement::UpperBound(other_value)) => {
217 (self_value as i64) - (other_value as i64)
219 }
220 _ => panic!(
221 "Cannot compute difference for `Measurement::Range` or if both measurements are of different types."
222 ),
223 };
224 (
225 difference(self.constant, other.constant),
226 difference(self.public, other.public),
227 difference(self.private, other.private),
228 difference(self.constraints, other.constraints),
229 )
230 }
231
232 fn dummy(constant: Constant, public: Public, private: Private, constraints: Constraints) -> Self {
235 Self {
236 constant,
237 public,
238 private,
239 constraints,
240 file: Default::default(),
241 line: Default::default(),
242 column: Default::default(),
243 }
244 }
245
246 fn as_argument_string(&self) -> String {
248 let generate_arg = |measurement| match measurement {
249 Measurement::Exact(value) => value,
250 Measurement::UpperBound(bound) => bound,
251 Measurement::Range(..) => panic!(
252 "Cannot create an argument string from an `UpdatableCount` that contains a `Measurement::Range`."
253 ),
254 };
255 format!(
256 "({}, {}, {}, {})",
257 generate_arg(self.constant),
258 generate_arg(self.public),
259 generate_arg(self.private),
260 generate_arg(self.constraints)
261 )
262 }
263}
264
265struct FileUpdates {
269 absolute_path: PathBuf,
270 original_text: String,
271 modified_text: String,
272 updates: BTreeSet<Update>,
277}
278
279impl FileUpdates {
280 fn new(count: &UpdatableCount) -> Self {
284 let path = Path::new(count.file);
285 let absolute_path = match path.is_absolute() {
286 true => path.to_owned(),
287 false => {
288 WORKSPACE_ROOT
290 .get_or_try_init(|| {
291 let workspace_root = Path::new(&env!("CARGO_MANIFEST_DIR"))
293 .ancestors()
294 .filter(|it| it.join("Cargo.toml").exists())
295 .last()
296 .unwrap()
297 .to_path_buf();
298
299 Ok(workspace_root)
300 })
301 .unwrap_or_else(|_: env::VarError| {
302 panic!("No CARGO_MANIFEST_DIR env var and the path is relative: {}", path.display())
303 })
304 .join(path)
305 }
306 };
307 let original_text = fs::read_to_string(&absolute_path).unwrap();
308 let modified_text = original_text.clone();
309 let updates = Default::default();
310 Self { absolute_path, original_text, modified_text, updates }
311 }
312
313 fn update_count(
318 &mut self,
319 count: &UpdatableCount,
320 num_constants: u64,
321 num_public: u64,
322 num_private: u64,
323 num_constraints: u64,
324 ) {
325 let range = count.locate(&self.original_text);
327
328 let mut new_range = range.clone();
329 let mut update_with_same_start = None;
330
331 for previous_update in &self.updates {
334 let amount_deleted = previous_update.end - previous_update.start;
335 let amount_inserted = previous_update.argument_string.len();
336
337 match previous_update.start.cmp(&range.start) {
338 Ordering::Less => {
340 new_range.start = new_range.start - amount_deleted + amount_inserted;
341 new_range.end = new_range.end - amount_deleted + amount_inserted;
342 }
343 Ordering::Equal => {
345 new_range.end = new_range.end - amount_deleted + amount_inserted;
346 update_with_same_start = Some(previous_update);
347 }
348 Ordering::Greater => {
350 break;
351 }
352 }
353 }
354
355 if let Some(update) = update_with_same_start {
358 if update.count.matches(num_constants, num_public, num_private, num_constraints) {
359 return;
360 }
361 }
362
363 let new_update = match update_with_same_start {
365 None => Update::new(&range, count, num_constants, num_public, num_private, num_constraints),
366 Some(update) => Update::new(&range, &update.count, num_constants, num_public, num_private, num_constraints),
367 };
368
369 self.modified_text.replace_range(new_range, &new_update.argument_string);
371
372 let difference = new_update.count.difference_between(count);
374 println!(
375 "\n
376\x1b[1m\x1b[33mwarning\x1b[97m: Updated count\x1b[0m
377 \x1b[1m\x1b[34m-->\x1b[0m {}:{}:{}
378\x1b[1mOriginal count\x1b[0m:
379----
380{}
381----
382\x1b[1mUpdated count\x1b[0m:
383----
384{}
385----
386\x1b[1mDifference between updated and original\x1b[0m:
387----
388Constants: {}, Public: {}, Private: {}, Constraints: {}
389----
390",
391 count.file,
392 count.line,
393 count.column,
394 count,
395 new_update.count,
396 difference.0,
397 difference.1,
398 difference.2,
399 difference.3
400 );
401
402 self.updates.replace(new_update);
404
405 fs::write(&self.absolute_path, &self.modified_text).unwrap()
407 }
408}
409
410#[derive(Debug)]
412struct Update {
413 start: usize,
415 end: usize,
417 count: UpdatableCount,
419 argument_string: String,
421}
422
423impl Update {
424 fn new(
425 range: &Range<usize>,
426 old_count: &UpdatableCount,
427 num_constants: u64,
428 num_public: u64,
429 num_private: u64,
430 num_constraints: u64,
431 ) -> Self {
432 let generate_new_measurement = |measurement: Measurement<u64>, expected: u64| match measurement {
434 Measurement::Exact(..) => Measurement::Exact(expected),
435 Measurement::Range(..) => panic!("UpdatableCount does not support ranges."),
436 Measurement::UpperBound(bound) => Measurement::UpperBound(std::cmp::max(expected, bound)),
437 };
438 let count = UpdatableCount::dummy(
439 generate_new_measurement(old_count.constant, num_constants),
440 generate_new_measurement(old_count.public, num_public),
441 generate_new_measurement(old_count.private, num_private),
442 generate_new_measurement(old_count.constraints, num_constraints),
443 );
444 Self { start: range.start, end: range.end, count, argument_string: count.as_argument_string() }
445 }
446}
447
448impl PartialEq for Update {
449 fn eq(&self, other: &Self) -> bool {
450 self.start == other.start
451 }
452}
453impl Eq for Update {}
454impl PartialOrd for Update {
455 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
456 Some(self.cmp(other))
457 }
458}
459impl Ord for Update {
460 fn cmp(&self, other: &Self) -> Ordering {
461 self.start.cmp(&other.start)
462 }
463}
464
/// Iterates over the lines of a string, keeping each line's trailing `'\n'` (if any).
///
/// Unlike `str::lines`, the yielded slices concatenate back to the original text, so summing
/// their lengths gives byte offsets into it. An empty string yields no lines. A final line
/// without a trailing newline is yielded as-is.
struct LinesWithEnds<'a> {
    /// The remaining lines, split just after every `'\n'`.
    lines: std::str::SplitInclusive<'a, char>,
}

impl<'a> Iterator for LinesWithEnds<'a> {
    type Item = &'a str;

    fn next(&mut self) -> Option<&'a str> {
        self.lines.next()
    }
}

impl<'a> From<&'a str> for LinesWithEnds<'a> {
    fn from(text: &'a str) -> Self {
        LinesWithEnds { lines: text.split_inclusive('\n') }
    }
}
492
#[cfg(test)]
mod test {
    use serial_test::serial;
    use std::env;

    /// Sets `UPDATE_COUNT` and restores its previous state when dropped. This keeps the variable
    /// from leaking into later tests even when a test panics before reaching its cleanup.
    struct UpdateCountVar {
        previous: Option<std::ffi::OsString>,
    }

    impl UpdateCountVar {
        fn set(value: &str) -> Self {
            let previous = env::var_os("UPDATE_COUNT");
            env::set_var("UPDATE_COUNT", value);
            Self { previous }
        }
    }

    impl Drop for UpdateCountVar {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(previous) => env::set_var("UPDATE_COUNT", previous),
                None => env::remove_var("UPDATE_COUNT"),
            }
        }
    }

    #[test]
    fn check_position() {
        // Derive the expected line from `line!()` so that edits above do not break this test.
        let expected_line = line!() + 1;
        let count = count_is!(0, 0, 0, 0);
        assert_eq!(count.file, "circuit/environment/src/helpers/updatable_count.rs");
        assert_eq!(count.line, expected_line);
        // `count_is!` starts after eight spaces of indentation and `let count = `.
        assert_eq!(count.column, 21);
    }

    #[test]
    #[serial]
    fn check_count_passes() {
        let count = count_is!(1, 2, 3, 4);
        let (num_constants, num_public, num_private, num_inputs) = (1, 2, 3, 4);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    #[should_panic]
    fn check_count_fails() {
        let count = count_is!(1, 2, 3, 4);
        let (num_constants, num_public, num_private, num_inputs) = (5, 6, 7, 8);

        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    #[should_panic]
    fn check_count_does_not_update_if_env_var_is_not_set_correctly() {
        let count = count_is!(1, 2, 3, 4);
        let (num_constants, num_public, num_private, num_inputs) = (5, 6, 7, 8);

        // "1" does not occur in this file's path, so the mismatch must panic rather than rewrite
        // the file. The guard restores `UPDATE_COUNT` while the panic unwinds.
        let _update_count = UpdateCountVar::set("1");

        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    fn check_count_updates_correctly() {
        let count = count_is!(11, 12, 13, 14);
        let (num_constants, num_public, num_private, num_inputs) = (11, 12, 13, 14);

        let _update_count = UpdateCountVar::set("updatable_count.rs");

        // The observed counts match `count`, so there is nothing to rewrite and nothing to report.
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    fn check_count_updates_correctly_multiple_times() {
        let count = count_is!(17, 18, 19, 20);

        // `UPDATE_COUNT` selects this file, so mismatches rewrite the invocation above instead
        // of panicking.
        let _update_count = UpdateCountVar::set("updatable_count.rs");

        let (num_constants, num_public, num_private, num_inputs) = (5, 6, 7, 8);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (9, 10, 11, 12);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (13, 14, 15, 16);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (17, 18, 19, 20);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    fn check_count_less_than_selects_maximum() {
        let count = count_less_than!(17, 18, 19, 20);

        // `UPDATE_COUNT` selects this file, so an exceeded bound would be raised in place rather
        // than causing a panic.
        let _update_count = UpdateCountVar::set("updatable_count.rs");

        let (num_constants, num_public, num_private, num_inputs) = (5, 18, 7, 8);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (17, 10, 11, 12);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (13, 6, 19, 16);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (9, 18, 15, 20);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }
}
617}