snarkvm_circuit_environment/helpers/
linear_combination.rs1use crate::{Mode, *};
17use snarkvm_fields::PrimeField;
18
19use core::{
20 fmt,
21 ops::{Add, AddAssign, Mul, Neg, Sub},
22};
23use smallvec::SmallVec;
24
/// A linear combination `constant + Σ coefficient_i * variable_i` over the field `F`,
/// stored together with its evaluated value.
#[derive(Clone, Hash)]
pub struct LinearCombination<F: PrimeField> {
    /// The constant term.
    constant: F,
    /// The non-constant terms as `(variable, coefficient)` pairs (up to one stored inline).
    // The constructors and arithmetic in this file keep `terms` sorted by `Variable`'s ordering
    // (binary-search insertion) and drop any term whose coefficient becomes zero.
    terms: SmallVec<[(Variable<F>, F); 1]>,
    /// The value of the linear combination: the constant plus the weighted sum of variable values.
    value: F,
}
43
44impl<F: PrimeField> LinearCombination<F> {
45 pub(crate) fn zero() -> Self {
47 Self { constant: F::zero(), terms: Default::default(), value: Default::default() }
48 }
49
50 pub(crate) fn one() -> Self {
52 Self { constant: F::one(), terms: Default::default(), value: F::one() }
53 }
54
55 pub fn is_constant(&self) -> bool {
57 self.terms.is_empty()
58 }
59
60 pub fn is_public(&self) -> bool {
63 self.constant.is_zero()
64 && self.terms.len() == 1
65 && match self.terms.first() {
66 Some((Variable::Public(..), coefficient)) => *coefficient == F::one(),
67 _ => false,
68 }
69 }
70
71 pub fn is_private(&self) -> bool {
73 !self.is_constant() && !self.is_public()
74 }
75
76 pub fn mode(&self) -> Mode {
78 if self.is_constant() {
79 Mode::Constant
80 } else if self.is_public() {
81 Mode::Public
82 } else {
83 Mode::Private
84 }
85 }
86
87 pub fn value(&self) -> F {
89 self.value
90 }
91
92 pub fn is_boolean_type(&self) -> bool {
101 if self.terms.is_empty() {
103 self.constant.is_zero() || self.constant.is_one()
104 }
105 else if self.constant.is_zero() {
107 if self.terms.iter().any(|(v, _)| !(v.value().is_zero() || v.value().is_one())) {
110 eprintln!("Property 2 of the `Boolean` type was violated in {self}");
111 return false;
112 }
113
114 if !(self.value.is_zero() || self.value.is_one()) {
116 eprintln!("Property 3 of the `Boolean` type was violated");
117 return false;
118 }
119
120 true
121 } else {
122 eprintln!("Both LC::constant and LC::terms contain elements, which is a violation");
125 false
126 }
127 }
128
129 pub(super) fn to_constant(&self) -> F {
131 self.constant
132 }
133
134 pub(super) fn to_terms(&self) -> &[(Variable<F>, F)] {
136 &self.terms
137 }
138
139 pub(super) fn num_nonzeros(&self) -> u64 {
141 match self.constant.is_zero() {
143 true => self.terms.len() as u64,
144 false => (self.terms.len() as u64).saturating_add(1),
145 }
146 }
147
148 #[cfg(test)]
150 pub(super) fn num_additions(&self) -> u64 {
151 match !self.constant.is_zero() && !self.terms.is_empty() {
153 true => self.terms.len() as u64,
154 false => (self.terms.len() as u64).saturating_sub(1),
155 }
156 }
157}
158
159impl<F: PrimeField> From<Variable<F>> for LinearCombination<F> {
160 fn from(variable: Variable<F>) -> Self {
161 Self::from(&variable)
162 }
163}
164
165impl<F: PrimeField> From<&Variable<F>> for LinearCombination<F> {
166 fn from(variable: &Variable<F>) -> Self {
167 Self::from(&[variable.clone()])
168 }
169}
170
171impl<F: PrimeField, const N: usize> From<[Variable<F>; N]> for LinearCombination<F> {
172 fn from(variables: [Variable<F>; N]) -> Self {
173 Self::from(&variables[..])
174 }
175}
176
177impl<F: PrimeField, const N: usize> From<&[Variable<F>; N]> for LinearCombination<F> {
178 fn from(variables: &[Variable<F>; N]) -> Self {
179 Self::from(&variables[..])
180 }
181}
182
183impl<F: PrimeField> From<Vec<Variable<F>>> for LinearCombination<F> {
184 fn from(variables: Vec<Variable<F>>) -> Self {
185 Self::from(variables.as_slice())
186 }
187}
188
189impl<F: PrimeField> From<&Vec<Variable<F>>> for LinearCombination<F> {
190 fn from(variables: &Vec<Variable<F>>) -> Self {
191 Self::from(variables.as_slice())
192 }
193}
194
195impl<F: PrimeField> From<&[Variable<F>]> for LinearCombination<F> {
196 fn from(variables: &[Variable<F>]) -> Self {
197 let mut output = Self::zero();
198 for variable in variables {
199 match variable.is_constant() {
200 true => output.constant += variable.value(),
201 false => {
202 match output.terms.binary_search_by(|(v, _)| v.cmp(variable)) {
203 Ok(idx) => {
204 output.terms[idx].1 += F::one();
206 if output.terms[idx].1.is_zero() {
208 output.terms.remove(idx);
209 }
210 }
211 Err(idx) => {
212 output.terms.insert(idx, (variable.clone(), F::one()));
214 }
215 }
216 }
217 }
218 output.value += variable.value();
220 }
221 output
222 }
223}
224
225impl<F: PrimeField> Neg for LinearCombination<F> {
226 type Output = Self;
227
228 #[inline]
229 fn neg(self) -> Self::Output {
230 let mut output = self;
231 output.constant = -output.constant;
232 output.terms.iter_mut().for_each(|(_, coefficient)| *coefficient = -(*coefficient));
233 output.value = -output.value;
234 output
235 }
236}
237
238impl<F: PrimeField> Neg for &LinearCombination<F> {
239 type Output = LinearCombination<F>;
240
241 #[inline]
242 fn neg(self) -> Self::Output {
243 -(self.clone())
244 }
245}
246
247impl<F: PrimeField> Add<Variable<F>> for LinearCombination<F> {
248 type Output = Self;
249
250 #[allow(clippy::op_ref)]
251 fn add(self, other: Variable<F>) -> Self::Output {
252 self + &other
253 }
254}
255
256impl<F: PrimeField> Add<&Variable<F>> for LinearCombination<F> {
257 type Output = Self;
258
259 fn add(self, other: &Variable<F>) -> Self::Output {
260 self + Self::from(other)
261 }
262}
263
264impl<F: PrimeField> Add<Variable<F>> for &LinearCombination<F> {
265 type Output = LinearCombination<F>;
266
267 #[allow(clippy::op_ref)]
268 fn add(self, other: Variable<F>) -> Self::Output {
269 self.clone() + &other
270 }
271}
272
273impl<F: PrimeField> Add<LinearCombination<F>> for LinearCombination<F> {
274 type Output = Self;
275
276 fn add(self, other: Self) -> Self::Output {
277 self + &other
278 }
279}
280
281impl<F: PrimeField> Add<&LinearCombination<F>> for LinearCombination<F> {
282 type Output = Self;
283
284 fn add(self, other: &Self) -> Self::Output {
285 &self + other
286 }
287}
288
289impl<F: PrimeField> Add<LinearCombination<F>> for &LinearCombination<F> {
290 type Output = LinearCombination<F>;
291
292 fn add(self, other: LinearCombination<F>) -> Self::Output {
293 self + &other
294 }
295}
296
297impl<F: PrimeField> Add<&LinearCombination<F>> for &LinearCombination<F> {
298 type Output = LinearCombination<F>;
299
300 fn add(self, other: &LinearCombination<F>) -> Self::Output {
301 if self.constant.is_zero() && self.terms.is_empty() {
302 other.clone()
303 } else if other.constant.is_zero() && other.terms.is_empty() {
304 self.clone()
305 } else if self.terms.len() > other.terms.len() {
306 let mut output = self.clone();
307 output += other;
308 output
309 } else {
310 let mut output = other.clone();
311 output += self;
312 output
313 }
314 }
315}
316
317impl<F: PrimeField> AddAssign<LinearCombination<F>> for LinearCombination<F> {
318 fn add_assign(&mut self, other: Self) {
319 *self += &other;
320 }
321}
322
323impl<F: PrimeField> AddAssign<&LinearCombination<F>> for LinearCombination<F> {
324 fn add_assign(&mut self, other: &Self) {
325 if other.constant.is_zero() && other.terms.is_empty() {
327 return;
328 }
329
330 if self.constant.is_zero() && self.terms.is_empty() {
331 *self = other.clone();
332 } else {
333 self.constant += other.constant;
335
336 for (variable, coefficient) in other.terms.iter() {
338 match variable.is_constant() {
339 true => panic!("Malformed linear combination found"),
340 false => {
341 match self.terms.binary_search_by(|(v, _)| v.cmp(variable)) {
342 Ok(idx) => {
343 self.terms[idx].1 += *coefficient;
345 if self.terms[idx].1.is_zero() {
347 self.terms.remove(idx);
348 }
349 }
350 Err(idx) => {
351 self.terms.insert(idx, (variable.clone(), *coefficient));
353 }
354 }
355 }
356 }
357 }
358
359 self.value += other.value;
361 }
362 }
363}
364
365impl<F: PrimeField> Sub<Variable<F>> for LinearCombination<F> {
366 type Output = Self;
367
368 #[allow(clippy::op_ref)]
369 fn sub(self, other: Variable<F>) -> Self::Output {
370 self - &other
371 }
372}
373
374impl<F: PrimeField> Sub<&Variable<F>> for LinearCombination<F> {
375 type Output = Self;
376
377 fn sub(self, other: &Variable<F>) -> Self::Output {
378 self - Self::from(other)
379 }
380}
381
382impl<F: PrimeField> Sub<Variable<F>> for &LinearCombination<F> {
383 type Output = LinearCombination<F>;
384
385 #[allow(clippy::op_ref)]
386 fn sub(self, other: Variable<F>) -> Self::Output {
387 self.clone() - &other
388 }
389}
390
391impl<F: PrimeField> Sub<LinearCombination<F>> for LinearCombination<F> {
392 type Output = Self;
393
394 fn sub(self, other: Self) -> Self::Output {
395 self - &other
396 }
397}
398
399impl<F: PrimeField> Sub<&LinearCombination<F>> for LinearCombination<F> {
400 type Output = Self;
401
402 fn sub(self, other: &Self) -> Self::Output {
403 &self - other
404 }
405}
406
407impl<F: PrimeField> Sub<LinearCombination<F>> for &LinearCombination<F> {
408 type Output = LinearCombination<F>;
409
410 fn sub(self, other: LinearCombination<F>) -> Self::Output {
411 self - &other
412 }
413}
414
415impl<F: PrimeField> Sub<&LinearCombination<F>> for &LinearCombination<F> {
416 type Output = LinearCombination<F>;
417
418 fn sub(self, other: &LinearCombination<F>) -> Self::Output {
419 self + &(-other)
420 }
421}
422
423impl<F: PrimeField> Mul<F> for LinearCombination<F> {
424 type Output = Self;
425
426 #[allow(clippy::op_ref)]
427 fn mul(self, coefficient: F) -> Self::Output {
428 self * &coefficient
429 }
430}
431
432impl<F: PrimeField> Mul<&F> for LinearCombination<F> {
433 type Output = Self;
434
435 fn mul(self, coefficient: &F) -> Self::Output {
436 let mut output = self;
437 output.constant *= coefficient;
438 output.terms = output
439 .terms
440 .into_iter()
441 .filter_map(|(v, current_coefficient)| {
442 let res = current_coefficient * coefficient;
443 (!res.is_zero()).then_some((v, res))
444 })
445 .collect();
446 output.value *= coefficient;
447 output
448 }
449}
450
451impl<F: PrimeField> Mul<F> for &LinearCombination<F> {
452 type Output = LinearCombination<F>;
453
454 #[allow(clippy::op_ref)]
455 fn mul(self, coefficient: F) -> Self::Output {
456 self * &coefficient
457 }
458}
459
460impl<F: PrimeField> Mul<&F> for &LinearCombination<F> {
461 type Output = LinearCombination<F>;
462
463 fn mul(self, coefficient: &F) -> Self::Output {
464 self.clone() * coefficient
465 }
466}
467
468impl<F: PrimeField> fmt::Debug for LinearCombination<F> {
469 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
470 let mut output = format!("Constant({})", self.constant);
471
472 for (variable, coefficient) in &self.terms {
473 output += &match (variable.mode(), coefficient.is_one()) {
474 (Mode::Constant, _) => panic!("Malformed linear combination at: ({coefficient} * {variable:?})"),
475 (_, true) => format!(" + {variable:?}"),
476 _ => format!(" + {coefficient} * {variable:?}"),
477 };
478 }
479 write!(f, "{output}")
480 }
481}
482
483impl<F: PrimeField> fmt::Display for LinearCombination<F> {
484 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
485 write!(f, "{}", self.value)
486 }
487}
488
/// Unit tests for `LinearCombination`, run against the `Circuit` environment.
#[cfg(test)]
mod tests {
    use super::*;
    use snarkvm_fields::{One as O, Zero as Z};

    use std::sync::Arc;

    /// `zero()` has a zero constant, no terms, and a zero value.
    #[test]
    fn test_zero() {
        let zero = <Circuit as Environment>::BaseField::zero();

        let candidate = LinearCombination::zero();
        assert_eq!(zero, candidate.constant);
        assert!(candidate.terms.is_empty());
        assert_eq!(zero, candidate.value());
    }

    /// `one()` has a constant of one, no terms, and a value of one.
    #[test]
    fn test_one() {
        let one = <Circuit as Environment>::BaseField::one();

        let candidate = LinearCombination::one();
        assert_eq!(one, candidate.constant);
        assert!(candidate.terms.is_empty());
        assert_eq!(one, candidate.value());
    }

    /// Adding two constant combinations sums their constants and values.
    #[test]
    fn test_two() {
        let one = <Circuit as Environment>::BaseField::one();
        let two = one + one;

        let candidate = LinearCombination::one() + LinearCombination::one();
        assert_eq!(two, candidate.constant);
        assert!(candidate.terms.is_empty());
        assert_eq!(two, candidate.value());
    }

    /// Combinations without terms report themselves as constant.
    #[test]
    fn test_is_constant() {
        let zero = <Circuit as Environment>::BaseField::zero();
        let one = <Circuit as Environment>::BaseField::one();

        let candidate = LinearCombination::zero();
        assert!(candidate.is_constant());
        assert_eq!(zero, candidate.constant);
        assert_eq!(zero, candidate.value());

        let candidate = LinearCombination::one();
        assert!(candidate.is_constant());
        assert_eq!(one, candidate.constant);
        assert_eq!(one, candidate.value());
    }

    /// Scaling a single public variable scales its coefficient and the value.
    #[test]
    fn test_mul() {
        let zero = <Circuit as Environment>::BaseField::zero();
        let one = <Circuit as Environment>::BaseField::one();
        let two = one + one;
        let four = two + two;

        let start = LinearCombination::from(Variable::Public(Arc::new((1, one))));
        assert!(!start.is_constant());
        assert_eq!(one, start.value());

        // Multiply by four: the value and the coefficient change, the constant stays zero.
        let candidate = start * four;
        assert_eq!(four, candidate.value());
        assert_eq!(zero, candidate.constant);
        assert_eq!(1, candidate.terms.len());

        let (candidate_variable, candidate_coefficient) = candidate.terms.first().unwrap();
        assert!(candidate_variable.is_public());
        assert_eq!(one, candidate_variable.value());
        assert_eq!(four, *candidate_coefficient);
    }

    /// The `Debug` output does not depend on the order in which operands are combined.
    #[test]
    fn test_debug() {
        let one_public = &Circuit::new_variable(Mode::Public, <Circuit as Environment>::BaseField::one());
        let one_private = &Circuit::new_variable(Mode::Private, <Circuit as Environment>::BaseField::one());
        // Each variable appears once.
        {
            let expected = "Constant(1) + Public(1, 1) + Private(0, 1)";

            let candidate = LinearCombination::one() + one_public + one_private;
            assert_eq!(expected, format!("{candidate:?}"));

            let candidate = one_private + one_public + LinearCombination::one();
            assert_eq!(expected, format!("{candidate:?}"));

            let candidate = one_private + LinearCombination::one() + one_public;
            assert_eq!(expected, format!("{candidate:?}"));

            let candidate = one_public + LinearCombination::one() + one_private;
            assert_eq!(expected, format!("{candidate:?}"));
        }
        // The public variable appears twice, so its coefficient is printed.
        {
            let expected = "Constant(1) + 2 * Public(1, 1) + Private(0, 1)";

            let candidate = LinearCombination::one() + one_public + one_public + one_private;
            assert_eq!(expected, format!("{candidate:?}"));

            let candidate = one_private + one_public + LinearCombination::one() + one_public;
            assert_eq!(expected, format!("{candidate:?}"));

            let candidate = one_public + one_private + LinearCombination::one() + one_public;
            assert_eq!(expected, format!("{candidate:?}"));

            let candidate = one_public + LinearCombination::one() + one_private + one_public;
            assert_eq!(expected, format!("{candidate:?}"));
        }
        // The private variable appears twice.
        {
            let expected = "Constant(1) + Public(1, 1) + 2 * Private(0, 1)";

            let candidate = LinearCombination::one() + one_public + one_private + one_private;
            assert_eq!(expected, format!("{candidate:?}"));

            let candidate = one_private + one_public + LinearCombination::one() + one_private;
            assert_eq!(expected, format!("{candidate:?}"));

            let candidate = one_private + one_private + LinearCombination::one() + one_public;
            assert_eq!(expected, format!("{candidate:?}"));

            let candidate = one_public + LinearCombination::one() + one_private + one_private;
            assert_eq!(expected, format!("{candidate:?}"));
        }
        // The private variable cancels out and is no longer printed.
        {
            let expected = "Constant(1) + Public(1, 1)";

            let candidate = LinearCombination::one() + one_public + one_private - one_private;
            assert_eq!(expected, format!("{candidate:?}"));

            let candidate = one_private + one_public + LinearCombination::one() - one_private;
            assert_eq!(expected, format!("{candidate:?}"));

            let candidate = one_private - one_private + LinearCombination::one() + one_public;
            assert_eq!(expected, format!("{candidate:?}"));

            let candidate = one_public + LinearCombination::one() + one_private - one_private;
            assert_eq!(expected, format!("{candidate:?}"));
        }
    }

    /// `num_additions` counts one addition per distinct term beyond the first, plus one when a
    /// nonzero constant is combined with at least one term; cancelled terms no longer count.
    #[rustfmt::skip]
    #[test]
    fn test_num_additions() {
        let one_public = &Circuit::new_variable(Mode::Public, <Circuit as Environment>::BaseField::one());
        let one_private = &Circuit::new_variable(Mode::Private, <Circuit as Environment>::BaseField::one());
        let two_private = one_private + one_private;

        // Additions only.

        let candidate = LinearCombination::<<Circuit as Environment>::BaseField>::zero();
        assert_eq!(0, candidate.num_additions());

        let candidate = LinearCombination::<<Circuit as Environment>::BaseField>::one();
        assert_eq!(0, candidate.num_additions());

        let candidate = LinearCombination::zero() + one_public;
        assert_eq!(0, candidate.num_additions());

        let candidate = LinearCombination::one() + one_public;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::zero() + one_public + one_public;
        assert_eq!(0, candidate.num_additions());

        let candidate = LinearCombination::one() + one_public + one_public;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::zero() + one_public + one_private;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::one() + one_public + one_private;
        assert_eq!(2, candidate.num_additions());

        let candidate = LinearCombination::zero() + one_public + one_private + one_public;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::one() + one_public + one_private + one_public;
        assert_eq!(2, candidate.num_additions());

        let candidate = LinearCombination::zero() + one_public + one_private + one_public + one_private;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::one() + one_public + one_private + one_public + one_private;
        assert_eq!(2, candidate.num_additions());

        let candidate = LinearCombination::zero() + LinearCombination::zero() + one_public + one_private + one_public + one_private;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::one() + LinearCombination::zero() + one_public + one_private + one_public + one_private;
        assert_eq!(2, candidate.num_additions());

        let candidate = LinearCombination::one() + LinearCombination::zero() + LinearCombination::one() + one_public + one_private + one_public + one_private;
        assert_eq!(2, candidate.num_additions());

        let candidate = LinearCombination::zero() + LinearCombination::zero() + one_public + one_private + one_public + one_private + &two_private;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::one() + LinearCombination::zero() + one_public + one_private + one_public + one_private + &two_private;
        assert_eq!(2, candidate.num_additions());

        let candidate = LinearCombination::one() + LinearCombination::zero() + LinearCombination::one() + one_public + one_private + one_public + one_private + &two_private;
        assert_eq!(2, candidate.num_additions());

        // Subtractions, including ones that cancel a term entirely.

        let candidate = LinearCombination::zero() - one_public;
        assert_eq!(0, candidate.num_additions());

        let candidate = LinearCombination::one() - one_public;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::zero() + one_public - one_public;
        assert_eq!(0, candidate.num_additions());

        let candidate = LinearCombination::one() + one_public - one_public;
        assert_eq!(0, candidate.num_additions());

        let candidate = LinearCombination::zero() + one_public - one_private;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::one() + one_public - one_private;
        assert_eq!(2, candidate.num_additions());

        let candidate = LinearCombination::zero() + one_public + one_private - one_public;
        assert_eq!(0, candidate.num_additions());

        let candidate = LinearCombination::one() + one_public + one_private - one_public;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::zero() + one_public + one_private + one_public - one_private;
        assert_eq!(0, candidate.num_additions());

        let candidate = LinearCombination::one() + one_public + one_private + one_public - one_private;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::zero() + LinearCombination::zero() + one_public + one_private + one_public - one_private;
        assert_eq!(0, candidate.num_additions());

        let candidate = LinearCombination::one() + LinearCombination::zero() + one_public + one_private + one_public - one_private;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::one() + LinearCombination::zero() + LinearCombination::one() + one_public + one_private + one_public - one_private;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::zero() + LinearCombination::zero() + one_public + one_private + one_public + one_private - &two_private;
        assert_eq!(0, candidate.num_additions());

        let candidate = LinearCombination::one() + LinearCombination::zero() + one_public + one_private + one_public + one_private - &two_private;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::one() + LinearCombination::zero() + LinearCombination::one() + one_public + one_private + one_public + one_private - &two_private;
        assert_eq!(1, candidate.num_additions());
    }
}