1use crate::{Constant, Constraints, Measurement, Private, Public};
17
18use core::fmt::Debug;
19use std::{
20 cmp::Ordering,
21 collections::{BTreeSet, HashMap},
22 env,
23 fmt::Display,
24 fs,
25 ops::Range,
26 path::{Path, PathBuf},
27 sync::{LazyLock, Mutex, OnceLock},
28};
29
/// Per-file update state, keyed by the `file!()` path stored in an `UpdatableCount`.
/// Entries are created on demand in `UpdatableCount::assert_matches` and the map is guarded by a mutex.
static FILES: LazyLock<Mutex<HashMap<&'static str, FileUpdates>>> = LazyLock::new(Default::default);
/// Directory against which relative `file!()` paths are resolved.
/// It is initialized at most once, on first use inside `FileUpdates::new`.
static WORKSPACE_ROOT: OnceLock<PathBuf> = OnceLock::new();
32
/// Builds an `UpdatableCount` whose four measurements are all `Measurement::Exact`.
///
/// The arguments are, in order, the number of constants, public variables, private variables,
/// and constraints. Each must be a literal; a trailing comma after the last one is accepted.
/// The `file!()`, `line!()`, and `column!()` of the invocation are stored in the result so that
/// the invocation can later be located in its source file.
#[macro_export]
macro_rules! count_is {
    ($num_constants:literal, $num_public:literal, $num_private:literal, $num_constraints:literal $(,)?) => {
        $crate::UpdatableCount {
            constant: $crate::Measurement::Exact($num_constants),
            public: $crate::Measurement::Exact($num_public),
            private: $crate::Measurement::Exact($num_private),
            constraints: $crate::Measurement::Exact($num_constraints),
            file: file!(),
            line: line!(),
            column: column!(),
        }
    };
}
50
/// Builds an `UpdatableCount` whose four measurements are all `Measurement::UpperBound`.
///
/// The arguments are, in order, the bounds for the number of constants, public variables,
/// private variables, and constraints. Each must be a literal; a trailing comma after the last
/// one is accepted. The `file!()`, `line!()`, and `column!()` of the invocation are stored in the
/// result so that the invocation can later be located in its source file.
#[macro_export]
macro_rules! count_less_than {
    ($num_constants:literal, $num_public:literal, $num_private:literal, $num_constraints:literal $(,)?) => {
        $crate::UpdatableCount {
            constant: $crate::Measurement::UpperBound($num_constants),
            public: $crate::Measurement::UpperBound($num_public),
            private: $crate::Measurement::UpperBound($num_private),
            constraints: $crate::Measurement::UpperBound($num_constraints),
            file: file!(),
            line: line!(),
            column: column!(),
        }
    };
}
68
/// Expected counts for constants, public variables, private variables, and constraints,
/// together with the source location of the macro invocation that created the value.
///
/// Values are built by `count_is!` and `count_less_than!`, which fill in the location fields.
#[derive(Copy, Clone, Debug)]
pub struct UpdatableCount {
    /// Expected measurement for the number of constants.
    pub constant: Constant,
    /// Expected measurement for the number of public variables.
    pub public: Public,
    /// Expected measurement for the number of private variables.
    pub private: Private,
    /// Expected measurement for the number of constraints.
    pub constraints: Constraints,
    /// Source file of the constructing macro invocation, as reported by `file!()`.
    #[doc(hidden)]
    pub file: &'static str,
    /// Line of the constructing macro invocation, as reported by `line!()`.
    #[doc(hidden)]
    pub line: u32,
    /// Column of the constructing macro invocation, as reported by `column!()`.
    #[doc(hidden)]
    pub column: u32,
}
84
85impl Display for UpdatableCount {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 write!(
88 f,
89 "Constants: {}, Public: {}, Private: {}, Constraints: {}",
90 self.constant, self.public, self.private, self.constraints
91 )
92 }
93}
94
95impl UpdatableCount {
96 pub fn matches(&self, num_constants: u64, num_public: u64, num_private: u64, num_constraints: u64) -> bool {
102 self.constant.matches(num_constants)
103 && self.public.matches(num_public)
104 && self.private.matches(num_private)
105 && self.constraints.matches(num_constraints)
106 }
107
108 pub fn assert_matches(&self, num_constants: u64, num_public: u64, num_private: u64, num_constraints: u64) {
113 if !self.matches(num_constants, num_public, num_private, num_constraints) {
114 let mut files = FILES.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
115 match env::var("UPDATE_COUNT") {
116 Ok(query_string) if self.file.contains(&query_string) => {
119 files.entry(self.file).or_insert_with(|| FileUpdates::new(self)).update_count(
120 self,
121 num_constants,
122 num_public,
123 num_private,
124 num_constraints,
125 );
126 }
127 _ => {
129 println!(
130 "\n
131\x1b[1m\x1b[91merror\x1b[97m: Count does not match\x1b[0m
132 \x1b[1m\x1b[34m-->\x1b[0m {}:{}:{}
133\x1b[1mExpected\x1b[0m:
134----
135{}
136----
137\x1b[1mActual\x1b[0m:
138----
139Constants: {}, Public: {}, Private: {}, Constraints: {}
140----
141",
142 self.file,
143 self.line,
144 self.column,
145 self,
146 num_constants,
147 num_public,
148 num_private,
149 num_constraints,
150 );
151 std::panic::resume_unwind(Box::new(()));
153 }
154 }
155 }
156 }
157
158 fn locate(&self, file: &str) -> Range<usize> {
169 let mut line_start = 0;
171 let mut starting_index = None;
172 let mut ending_index = None;
173 for (i, line) in LinesWithEnds::from(file).enumerate() {
174 if i == self.line as usize - 1 {
175 let mut argument_character_indices = line.char_indices().skip((self.column - 1).try_into().unwrap())
177 .skip_while(|&(_, c)| c != '!') .skip(1) .skip_while(|(_, c)| c.is_whitespace()); starting_index = Some(
183 line_start
184 + argument_character_indices
185 .next()
186 .expect("Could not find the beginning of the macro invocation.")
187 .0,
188 );
189 }
190
191 if starting_index.is_some() {
192 match line.char_indices().find(|&(_, c)| c == ')') {
194 None => (), Some((offset, _)) => {
196 ending_index = Some(line_start + offset + 1);
198 break;
199 }
200 }
201 }
202 line_start += line.len();
203 }
204
205 Range {
206 start: starting_index.expect("Could not find the beginning of the macro invocation."),
207 end: ending_index.expect("Could not find the ending of the macro invocation."),
208 }
209 }
210
211 pub fn difference_between(&self, other: &Self) -> (i64, i64, i64, i64) {
213 let difference = |self_measurement, other_measurement| match (self_measurement, other_measurement) {
214 (Measurement::Exact(self_value), Measurement::Exact(other_value))
215 | (Measurement::UpperBound(self_value), Measurement::UpperBound(other_value)) => {
216 (self_value as i64) - (other_value as i64)
218 }
219 _ => panic!(
220 "Cannot compute difference for `Measurement::Range` or if both measurements are of different types."
221 ),
222 };
223 (
224 difference(self.constant, other.constant),
225 difference(self.public, other.public),
226 difference(self.private, other.private),
227 difference(self.constraints, other.constraints),
228 )
229 }
230
231 fn dummy(constant: Constant, public: Public, private: Private, constraints: Constraints) -> Self {
234 Self {
235 constant,
236 public,
237 private,
238 constraints,
239 file: Default::default(),
240 line: Default::default(),
241 column: Default::default(),
242 }
243 }
244
245 fn as_argument_string(&self) -> String {
247 let generate_arg = |measurement| match measurement {
248 Measurement::Exact(value) => value,
249 Measurement::UpperBound(bound) => bound,
250 Measurement::Range(..) => panic!(
251 "Cannot create an argument string from an `UpdatableCount` that contains a `Measurement::Range`."
252 ),
253 };
254 format!(
255 "({}, {}, {}, {})",
256 generate_arg(self.constant),
257 generate_arg(self.public),
258 generate_arg(self.private),
259 generate_arg(self.constraints)
260 )
261 }
262}
263
/// Bookkeeping for the rewrites applied to a single source file.
struct FileUpdates {
    /// Absolute path of the file; `modified_text` is written back here after every update.
    absolute_path: PathBuf,
    /// File contents as read when this value was created; macro invocations are located in this text.
    original_text: String,
    /// File contents with all recorded updates applied.
    modified_text: String,
    /// Updates applied so far, ordered by their start offset in `original_text`.
    updates: BTreeSet<Update>,
}
277
278impl FileUpdates {
279 fn new(count: &UpdatableCount) -> Self {
283 let path = Path::new(count.file);
284 let absolute_path = match path.is_absolute() {
285 true => path.to_owned(),
286 false => {
287 WORKSPACE_ROOT
289 .get_or_init(|| {
290 Path::new(&env!("CARGO_MANIFEST_DIR"))
292 .ancestors()
293 .filter(|it| it.join("Cargo.toml").exists())
294 .last()
295 .unwrap()
296 .to_path_buf()
297 })
298 .join(path)
299 }
300 };
301 let original_text = fs::read_to_string(&absolute_path).unwrap();
302 let modified_text = original_text.clone();
303 let updates = Default::default();
304 Self { absolute_path, original_text, modified_text, updates }
305 }
306
307 fn update_count(
312 &mut self,
313 count: &UpdatableCount,
314 num_constants: u64,
315 num_public: u64,
316 num_private: u64,
317 num_constraints: u64,
318 ) {
319 let range = count.locate(&self.original_text);
321
322 let mut new_range = range.clone();
323 let mut update_with_same_start = None;
324
325 for previous_update in &self.updates {
328 let amount_deleted = previous_update.end - previous_update.start;
329 let amount_inserted = previous_update.argument_string.len();
330
331 match previous_update.start.cmp(&range.start) {
332 Ordering::Less => {
334 new_range.start = new_range.start - amount_deleted + amount_inserted;
335 new_range.end = new_range.end - amount_deleted + amount_inserted;
336 }
337 Ordering::Equal => {
339 new_range.end = new_range.end - amount_deleted + amount_inserted;
340 update_with_same_start = Some(previous_update);
341 }
342 Ordering::Greater => {
344 break;
345 }
346 }
347 }
348
349 if let Some(update) = update_with_same_start {
352 if update.count.matches(num_constants, num_public, num_private, num_constraints) {
353 return;
354 }
355 }
356
357 let new_update = match update_with_same_start {
359 None => Update::new(&range, count, num_constants, num_public, num_private, num_constraints),
360 Some(update) => Update::new(&range, &update.count, num_constants, num_public, num_private, num_constraints),
361 };
362
363 self.modified_text.replace_range(new_range, &new_update.argument_string);
365
366 let difference = new_update.count.difference_between(count);
368 println!(
369 "\n
370\x1b[1m\x1b[33mwarning\x1b[97m: Updated count\x1b[0m
371 \x1b[1m\x1b[34m-->\x1b[0m {}:{}:{}
372\x1b[1mOriginal count\x1b[0m:
373----
374{}
375----
376\x1b[1mUpdated count\x1b[0m:
377----
378{}
379----
380\x1b[1mDifference between updated and original\x1b[0m:
381----
382Constants: {}, Public: {}, Private: {}, Constraints: {}
383----
384",
385 count.file,
386 count.line,
387 count.column,
388 count,
389 new_update.count,
390 difference.0,
391 difference.1,
392 difference.2,
393 difference.3
394 );
395
396 self.updates.replace(new_update);
398
399 fs::write(&self.absolute_path, &self.modified_text).unwrap()
401 }
402}
403
/// A single replacement of a macro invocation's argument list.
#[derive(Debug)]
struct Update {
    /// Start of the replaced range, as a byte offset into the original file text.
    start: usize,
    /// End (exclusive) of the replaced range, as a byte offset into the original file text.
    end: usize,
    /// The count that the invocation holds after this update.
    count: UpdatableCount,
    /// The text that replaces the range, i.e. `count` rendered as an argument list.
    argument_string: String,
}
416
417impl Update {
418 fn new(
419 range: &Range<usize>,
420 old_count: &UpdatableCount,
421 num_constants: u64,
422 num_public: u64,
423 num_private: u64,
424 num_constraints: u64,
425 ) -> Self {
426 let generate_new_measurement = |measurement: Measurement<u64>, expected: u64| match measurement {
428 Measurement::Exact(..) => Measurement::Exact(expected),
429 Measurement::Range(..) => panic!("UpdatableCount does not support ranges."),
430 Measurement::UpperBound(bound) => Measurement::UpperBound(std::cmp::max(expected, bound)),
431 };
432 let count = UpdatableCount::dummy(
433 generate_new_measurement(old_count.constant, num_constants),
434 generate_new_measurement(old_count.public, num_public),
435 generate_new_measurement(old_count.private, num_private),
436 generate_new_measurement(old_count.constraints, num_constraints),
437 );
438 Self { start: range.start, end: range.end, count, argument_string: count.as_argument_string() }
439 }
440}
441
442impl PartialEq for Update {
443 fn eq(&self, other: &Self) -> bool {
444 self.start == other.start
445 }
446}
447impl Eq for Update {}
448impl PartialOrd for Update {
449 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
450 Some(self.cmp(other))
451 }
452}
453impl Ord for Update {
454 fn cmp(&self, other: &Self) -> Ordering {
455 self.start.cmp(&other.start)
456 }
457}
458
/// An iterator over the lines of a string that keeps each line's terminating `\n`.
///
/// Concatenating the yielded items gives back the original text, so the lengths of the items
/// can be summed to obtain byte offsets into it.
struct LinesWithEnds<'a> {
    /// The part of the text that has not been yielded yet.
    text: &'a str,
}

impl<'a> Iterator for LinesWithEnds<'a> {
    type Item = &'a str;

    fn next(&mut self) -> Option<&'a str> {
        if self.text.is_empty() {
            return None;
        }

        // Cut just after the next newline, or take everything if there is none.
        let cut = match self.text.find('\n') {
            Some(newline) => newline + 1,
            None => self.text.len(),
        };
        let (line, remainder) = self.text.split_at(cut);
        self.text = remainder;
        Some(line)
    }
}

impl<'a> From<&'a str> for LinesWithEnds<'a> {
    fn from(text: &'a str) -> Self {
        Self { text }
    }
}
486
487#[cfg(test)]
488mod test {
489 use serial_test::serial;
490 use std::env;
491
492 #[test]
493 fn check_position() {
494 let count = count_is!(0, 0, 0, 0);
495 assert_eq!(count.file, "circuit/environment/src/helpers/updatable_count.rs");
496 assert_eq!(count.line, 494);
497 assert_eq!(count.column, 21);
498 }
499
500 #[test]
504 #[serial]
505 fn check_count_passes() {
506 let count = count_is!(1, 2, 3, 4);
507 let num_constants = 1;
508 let num_public = 2;
509 let num_private = 3;
510 let num_inputs = 4;
511 count.assert_matches(num_constants, num_public, num_private, num_inputs);
512 }
513
514 #[test]
515 #[serial]
516 #[should_panic]
517 fn check_count_fails() {
518 let count = count_is!(1, 2, 3, 4);
519 let num_constants = 5;
520 let num_public = 6;
521 let num_private = 7;
522 let num_inputs = 8;
523
524 count.assert_matches(num_constants, num_public, num_private, num_inputs);
525 }
526
527 #[test]
528 #[serial]
529 #[should_panic]
530 fn check_count_does_not_update_if_env_var_is_not_set_correctly() {
531 let count = count_is!(1, 2, 3, 4);
532 let num_constants = 5;
533 let num_public = 6;
534 let num_private = 7;
535 let num_inputs = 8;
536
537 env::set_var("UPDATE_COUNT", "1");
539
540 count.assert_matches(num_constants, num_public, num_private, num_inputs);
541
542 env::remove_var("UPDATE_COUNT");
543 }
544
545 #[test]
546 #[serial]
547 fn check_count_updates_correctly() {
548 let count = count_is!(11, 12, 13, 14);
550 let num_constants = 11;
551 let num_public = 12;
552 let num_private = 13;
553 let num_inputs = 14;
554
555 env::set_var("UPDATE_COUNT", "updatable_count.rs");
557
558 count.assert_matches(num_constants, num_public, num_private, num_inputs);
559
560 env::remove_var("UPDATE_COUNT");
561 }
562
563 #[test]
564 #[serial]
565 fn check_count_updates_correctly_multiple_times() {
566 let count = count_is!(17, 18, 19, 20);
568
569 env::set_var("UPDATE_COUNT", "updatable_count.rs");
570
571 let (num_constants, num_public, num_private, num_inputs) = (5, 6, 7, 8);
572 count.assert_matches(num_constants, num_public, num_private, num_inputs);
573
574 let (num_constants, num_public, num_private, num_inputs) = (9, 10, 11, 12);
575 count.assert_matches(num_constants, num_public, num_private, num_inputs);
576
577 let (num_constants, num_public, num_private, num_inputs) = (13, 14, 15, 16);
578 count.assert_matches(num_constants, num_public, num_private, num_inputs);
579
580 let (num_constants, num_public, num_private, num_inputs) = (17, 18, 19, 20);
581 count.assert_matches(num_constants, num_public, num_private, num_inputs);
582
583 env::remove_var("UPDATE_COUNT");
584 }
585
586 #[test]
587 #[serial]
588 fn check_count_less_than_selects_maximum() {
589 let count = count_less_than!(17, 18, 19, 20);
593
594 env::set_var("UPDATE_COUNT", "updatable_count.rs");
596
597 let (num_constants, num_public, num_private, num_inputs) = (5, 18, 7, 8);
598 count.assert_matches(num_constants, num_public, num_private, num_inputs);
599
600 let (num_constants, num_public, num_private, num_inputs) = (17, 10, 11, 12);
601 count.assert_matches(num_constants, num_public, num_private, num_inputs);
602
603 let (num_constants, num_public, num_private, num_inputs) = (13, 6, 19, 16);
604 count.assert_matches(num_constants, num_public, num_private, num_inputs);
605
606 let (num_constants, num_public, num_private, num_inputs) = (9, 18, 15, 20);
607 count.assert_matches(num_constants, num_public, num_private, num_inputs);
608
609 env::remove_var("UPDATE_COUNT");
610 }
611}