snarkvm_circuit_environment/helpers/
linear_combination.rs1use crate::{Mode, *};
17use snarkvm_fields::PrimeField;
18
19use core::{
20 fmt,
21 ops::{Add, AddAssign, Mul, Neg, Sub},
22};
23use smallvec::SmallVec;
24
25#[derive(Clone, Hash)]
36pub struct LinearCombination<F: PrimeField> {
37 constant: F,
38 terms: SmallVec<[(Variable<F>, F); 1]>,
40 value: F,
42}
43
44impl<F: PrimeField> LinearCombination<F> {
45 pub(crate) fn zero() -> Self {
47 Self { constant: F::zero(), terms: Default::default(), value: Default::default() }
48 }
49
50 pub(crate) fn one() -> Self {
52 Self { constant: F::one(), terms: Default::default(), value: F::one() }
53 }
54
55 pub fn is_constant(&self) -> bool {
57 self.terms.is_empty()
58 }
59
60 pub fn is_public(&self) -> bool {
63 self.constant.is_zero()
64 && self.terms.len() == 1
65 && match self.terms.first() {
66 Some((Variable::Public(..), coefficient)) => *coefficient == F::one(),
67 _ => false,
68 }
69 }
70
71 pub fn is_private(&self) -> bool {
73 !self.is_constant() && !self.is_public()
74 }
75
76 pub fn mode(&self) -> Mode {
78 if self.is_constant() {
79 Mode::Constant
80 } else if self.is_public() {
81 Mode::Public
82 } else {
83 Mode::Private
84 }
85 }
86
87 pub fn value(&self) -> F {
89 self.value
90 }
91
92 pub fn is_boolean_type(&self) -> bool {
101 if self.terms.is_empty() {
103 self.constant.is_zero() || self.constant.is_one()
104 }
105 else if self.constant.is_zero() {
107 if self.terms.iter().any(|(v, _)| !(v.value().is_zero() || v.value().is_one())) {
110 eprintln!("Property 2 of the `Boolean` type was violated in {self}");
111 return false;
112 }
113
114 if !(self.value.is_zero() || self.value.is_one()) {
116 eprintln!("Property 3 of the `Boolean` type was violated");
117 return false;
118 }
119
120 true
121 } else {
122 eprintln!("Both LC::constant and LC::terms contain elements, which is a violation");
125 false
126 }
127 }
128
129 pub(super) fn to_constant(&self) -> F {
131 self.constant
132 }
133
134 pub(super) fn to_terms(&self) -> &[(Variable<F>, F)] {
136 &self.terms
137 }
138
139 pub(super) fn num_nonzeros(&self) -> u64 {
141 match self.constant.is_zero() {
143 true => self.terms.len() as u64,
144 false => (self.terms.len() as u64).saturating_add(1),
145 }
146 }
147
148 #[cfg(test)]
150 pub(super) fn num_additions(&self) -> u64 {
151 match !self.constant.is_zero() && !self.terms.is_empty() {
153 true => self.terms.len() as u64,
154 false => (self.terms.len() as u64).saturating_sub(1),
155 }
156 }
157}
158
159impl<F: PrimeField> From<Variable<F>> for LinearCombination<F> {
160 fn from(variable: Variable<F>) -> Self {
161 Self::from(&variable)
162 }
163}
164
165impl<F: PrimeField> From<&Variable<F>> for LinearCombination<F> {
166 fn from(variable: &Variable<F>) -> Self {
167 Self::from(&[variable.clone()])
168 }
169}
170
171impl<F: PrimeField, const N: usize> From<[Variable<F>; N]> for LinearCombination<F> {
172 fn from(variables: [Variable<F>; N]) -> Self {
173 Self::from(&variables[..])
174 }
175}
176
177impl<F: PrimeField, const N: usize> From<&[Variable<F>; N]> for LinearCombination<F> {
178 fn from(variables: &[Variable<F>; N]) -> Self {
179 Self::from(&variables[..])
180 }
181}
182
183impl<F: PrimeField> From<Vec<Variable<F>>> for LinearCombination<F> {
184 fn from(variables: Vec<Variable<F>>) -> Self {
185 Self::from(variables.as_slice())
186 }
187}
188
189impl<F: PrimeField> From<&Vec<Variable<F>>> for LinearCombination<F> {
190 fn from(variables: &Vec<Variable<F>>) -> Self {
191 Self::from(variables.as_slice())
192 }
193}
194
195impl<F: PrimeField> From<&[Variable<F>]> for LinearCombination<F> {
196 fn from(variables: &[Variable<F>]) -> Self {
197 let mut output = Self::zero();
198 for variable in variables {
199 match variable.is_constant() {
200 true => output.constant += variable.value(),
201 false => {
202 match output.terms.binary_search_by(|(v, _)| v.cmp(variable)) {
203 Ok(idx) => {
204 output.terms[idx].1 += F::one();
206 if output.terms[idx].1.is_zero() {
208 output.terms.remove(idx);
209 }
210 }
211 Err(idx) => {
212 output.terms.insert(idx, (variable.clone(), F::one()));
214 }
215 }
216 }
217 }
218 output.value += variable.value();
220 }
221 output
222 }
223}
224
225impl<F: PrimeField> Neg for LinearCombination<F> {
226 type Output = Self;
227
228 #[inline]
229 fn neg(self) -> Self::Output {
230 let mut output = self;
231 output.constant = -output.constant;
232 output.terms.iter_mut().for_each(|(_, coefficient)| *coefficient = -(*coefficient));
233 output.value = -output.value;
234 output
235 }
236}
237
238impl<F: PrimeField> Neg for &LinearCombination<F> {
239 type Output = LinearCombination<F>;
240
241 #[inline]
242 fn neg(self) -> Self::Output {
243 -(self.clone())
244 }
245}
246
247impl<F: PrimeField> Add<Variable<F>> for LinearCombination<F> {
248 type Output = Self;
249
250 #[allow(clippy::op_ref)]
251 fn add(self, other: Variable<F>) -> Self::Output {
252 self + &other
253 }
254}
255
256impl<F: PrimeField> Add<&Variable<F>> for LinearCombination<F> {
257 type Output = Self;
258
259 fn add(self, other: &Variable<F>) -> Self::Output {
260 self + Self::from(other)
261 }
262}
263
264impl<F: PrimeField> Add<Variable<F>> for &LinearCombination<F> {
265 type Output = LinearCombination<F>;
266
267 #[allow(clippy::op_ref)]
268 fn add(self, other: Variable<F>) -> Self::Output {
269 self.clone() + &other
270 }
271}
272
273impl<F: PrimeField> Add<LinearCombination<F>> for LinearCombination<F> {
274 type Output = Self;
275
276 fn add(self, other: Self) -> Self::Output {
277 self + &other
278 }
279}
280
281impl<F: PrimeField> Add<&LinearCombination<F>> for LinearCombination<F> {
282 type Output = Self;
283
284 fn add(self, other: &Self) -> Self::Output {
285 &self + other
286 }
287}
288
289impl<F: PrimeField> Add<LinearCombination<F>> for &LinearCombination<F> {
290 type Output = LinearCombination<F>;
291
292 fn add(self, other: LinearCombination<F>) -> Self::Output {
293 self + &other
294 }
295}
296
297impl<F: PrimeField> Add<&LinearCombination<F>> for &LinearCombination<F> {
298 type Output = LinearCombination<F>;
299
300 fn add(self, other: &LinearCombination<F>) -> Self::Output {
301 if self.constant.is_zero() && self.terms.is_empty() {
302 other.clone()
303 } else if other.constant.is_zero() && other.terms.is_empty() {
304 self.clone()
305 } else if self.terms.len() > other.terms.len() {
306 let mut output = self.clone();
307 output += other;
308 output
309 } else {
310 let mut output = other.clone();
311 output += self;
312 output
313 }
314 }
315}
316
317impl<F: PrimeField> AddAssign<LinearCombination<F>> for LinearCombination<F> {
318 fn add_assign(&mut self, other: Self) {
319 *self += &other;
320 }
321}
322
323impl<F: PrimeField> AddAssign<&LinearCombination<F>> for LinearCombination<F> {
324 fn add_assign(&mut self, other: &Self) {
325 if other.constant.is_zero() && other.terms.is_empty() {
327 return;
328 }
329
330 if self.constant.is_zero() && self.terms.is_empty() {
331 *self = other.clone();
332 } else {
333 self.constant += other.constant;
335
336 for (variable, coefficient) in other.terms.iter() {
338 match variable.is_constant() {
339 true => panic!("Malformed linear combination found"),
340 false => {
341 match self.terms.binary_search_by(|(v, _)| v.cmp(variable)) {
342 Ok(idx) => {
343 self.terms[idx].1 += *coefficient;
345 if self.terms[idx].1.is_zero() {
347 self.terms.remove(idx);
348 }
349 }
350 Err(idx) => {
351 self.terms.insert(idx, (variable.clone(), *coefficient));
353 }
354 }
355 }
356 }
357 }
358
359 self.value += other.value;
361 }
362 }
363}
364
365impl<F: PrimeField> Sub<Variable<F>> for LinearCombination<F> {
366 type Output = Self;
367
368 #[allow(clippy::op_ref)]
369 fn sub(self, other: Variable<F>) -> Self::Output {
370 self - &other
371 }
372}
373
374impl<F: PrimeField> Sub<&Variable<F>> for LinearCombination<F> {
375 type Output = Self;
376
377 fn sub(self, other: &Variable<F>) -> Self::Output {
378 self - Self::from(other)
379 }
380}
381
382impl<F: PrimeField> Sub<Variable<F>> for &LinearCombination<F> {
383 type Output = LinearCombination<F>;
384
385 #[allow(clippy::op_ref)]
386 fn sub(self, other: Variable<F>) -> Self::Output {
387 self.clone() - &other
388 }
389}
390
391impl<F: PrimeField> Sub<LinearCombination<F>> for LinearCombination<F> {
392 type Output = Self;
393
394 fn sub(self, other: Self) -> Self::Output {
395 self - &other
396 }
397}
398
399impl<F: PrimeField> Sub<&LinearCombination<F>> for LinearCombination<F> {
400 type Output = Self;
401
402 fn sub(self, other: &Self) -> Self::Output {
403 &self - other
404 }
405}
406
407impl<F: PrimeField> Sub<LinearCombination<F>> for &LinearCombination<F> {
408 type Output = LinearCombination<F>;
409
410 fn sub(self, other: LinearCombination<F>) -> Self::Output {
411 self - &other
412 }
413}
414
415impl<F: PrimeField> Sub<&LinearCombination<F>> for &LinearCombination<F> {
416 type Output = LinearCombination<F>;
417
418 fn sub(self, other: &LinearCombination<F>) -> Self::Output {
419 self + &(-other)
420 }
421}
422
423impl<F: PrimeField> Mul<F> for LinearCombination<F> {
424 type Output = Self;
425
426 #[allow(clippy::op_ref)]
427 fn mul(self, coefficient: F) -> Self::Output {
428 self * &coefficient
429 }
430}
431
432impl<F: PrimeField> Mul<&F> for LinearCombination<F> {
433 type Output = Self;
434
435 fn mul(self, coefficient: &F) -> Self::Output {
436 let mut output = self;
437 output.constant *= coefficient;
438 output.terms.retain_mut(|(_v, current_coefficient)| {
439 *current_coefficient *= coefficient;
440 !current_coefficient.is_zero()
441 });
442 output.value *= coefficient;
443 output
444 }
445}
446
447impl<F: PrimeField> Mul<F> for &LinearCombination<F> {
448 type Output = LinearCombination<F>;
449
450 #[allow(clippy::op_ref)]
451 fn mul(self, coefficient: F) -> Self::Output {
452 self * &coefficient
453 }
454}
455
456impl<F: PrimeField> Mul<&F> for &LinearCombination<F> {
457 type Output = LinearCombination<F>;
458
459 fn mul(self, coefficient: &F) -> Self::Output {
460 self.clone() * coefficient
461 }
462}
463
464impl<F: PrimeField> fmt::Debug for LinearCombination<F> {
465 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
466 let mut output = format!("Constant({})", self.constant);
467
468 for (variable, coefficient) in &self.terms {
469 output += &match (variable.mode(), coefficient.is_one()) {
470 (Mode::Constant, _) => panic!("Malformed linear combination at: ({coefficient} * {variable:?})"),
471 (_, true) => format!(" + {variable:?}"),
472 _ => format!(" + {coefficient} * {variable:?}"),
473 };
474 }
475 write!(f, "{output}")
476 }
477}
478
479impl<F: PrimeField> fmt::Display for LinearCombination<F> {
480 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
481 write!(f, "{}", self.value)
482 }
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use snarkvm_fields::{One as O, Zero as Z};

    use std::sync::Arc;

    /// The base field of the test circuit.
    type BaseField = <Circuit as Environment>::BaseField;

    /// Asserts that `candidate` is the constant `expected`: matching constant, no terms, and
    /// matching value.
    fn assert_constant(expected: BaseField, candidate: &LinearCombination<BaseField>) {
        assert_eq!(expected, candidate.constant);
        assert!(candidate.terms.is_empty());
        assert_eq!(expected, candidate.value());
    }

    #[test]
    fn test_zero() {
        assert_constant(BaseField::zero(), &LinearCombination::zero());
    }

    #[test]
    fn test_one() {
        assert_constant(BaseField::one(), &LinearCombination::one());
    }

    #[test]
    fn test_two() {
        let one = BaseField::one();
        let sum = LinearCombination::<BaseField>::one() + LinearCombination::one();
        assert_constant(one + one, &sum);
    }

    #[test]
    fn test_is_constant() {
        let cases: Vec<(BaseField, LinearCombination<BaseField>)> =
            vec![(BaseField::zero(), LinearCombination::zero()), (BaseField::one(), LinearCombination::one())];

        for (expected, candidate) in cases {
            assert!(candidate.is_constant());
            assert_eq!(expected, candidate.constant);
            assert_eq!(expected, candidate.value());
        }
    }

    #[test]
    fn test_mul() {
        let one = BaseField::one();
        let two = one + one;
        let four = two + two;

        let start = LinearCombination::from(Variable::Public(Arc::new((1, one))));
        assert!(!start.is_constant());
        assert_eq!(one, start.value());

        let scaled = start * four;
        assert_eq!(four, scaled.value());
        assert_eq!(BaseField::zero(), scaled.constant);
        assert_eq!(1, scaled.terms.len());

        let (variable, coefficient) = scaled.terms.first().unwrap();
        assert!(variable.is_public());
        assert_eq!(one, variable.value());
        assert_eq!(four, *coefficient);
    }

    #[test]
    fn test_debug() {
        let one_public = &Circuit::new_variable(Mode::Public, BaseField::one());
        let one_private = &Circuit::new_variable(Mode::Private, BaseField::one());

        // Every candidate in a group must render to the same `Debug` string.
        let assert_all_debug_as = |expected: &str, candidates: Vec<LinearCombination<BaseField>>| {
            for candidate in &candidates {
                assert_eq!(expected, format!("{candidate:?}"));
            }
        };

        assert_all_debug_as("Constant(1) + Public(1, 1) + Private(0, 1)", vec![
            LinearCombination::one() + one_public + one_private,
            one_private + one_public + LinearCombination::one(),
            one_private + LinearCombination::one() + one_public,
            one_public + LinearCombination::one() + one_private,
        ]);

        assert_all_debug_as("Constant(1) + 2 * Public(1, 1) + Private(0, 1)", vec![
            LinearCombination::one() + one_public + one_public + one_private,
            one_private + one_public + LinearCombination::one() + one_public,
            one_public + one_private + LinearCombination::one() + one_public,
            one_public + LinearCombination::one() + one_private + one_public,
        ]);

        assert_all_debug_as("Constant(1) + Public(1, 1) + 2 * Private(0, 1)", vec![
            LinearCombination::one() + one_public + one_private + one_private,
            one_private + one_public + LinearCombination::one() + one_private,
            one_private + one_private + LinearCombination::one() + one_public,
            one_public + LinearCombination::one() + one_private + one_private,
        ]);

        assert_all_debug_as("Constant(1) + Public(1, 1)", vec![
            LinearCombination::one() + one_public + one_private - one_private,
            one_private + one_public + LinearCombination::one() - one_private,
            one_private - one_private + LinearCombination::one() + one_public,
            one_public + LinearCombination::one() + one_private - one_private,
        ]);
    }

    #[rustfmt::skip]
    #[test]
    fn test_num_additions() {
        let one_public = &Circuit::new_variable(Mode::Public, BaseField::one());
        let one_private = &Circuit::new_variable(Mode::Private, BaseField::one());
        let two_private = one_private + one_private;

        // Each case pairs a candidate linear combination with its expected number of additions.
        let cases: Vec<(LinearCombination<BaseField>, u64)> = vec![
            (LinearCombination::zero(), 0),
            (LinearCombination::one(), 0),
            (LinearCombination::zero() + one_public, 0),
            (LinearCombination::one() + one_public, 1),
            (LinearCombination::zero() + one_public + one_public, 0),
            (LinearCombination::one() + one_public + one_public, 1),
            (LinearCombination::zero() + one_public + one_private, 1),
            (LinearCombination::one() + one_public + one_private, 2),
            (LinearCombination::zero() + one_public + one_private + one_public, 1),
            (LinearCombination::one() + one_public + one_private + one_public, 2),
            (LinearCombination::zero() + one_public + one_private + one_public + one_private, 1),
            (LinearCombination::one() + one_public + one_private + one_public + one_private, 2),
            (LinearCombination::zero() + LinearCombination::zero() + one_public + one_private + one_public + one_private, 1),
            (LinearCombination::one() + LinearCombination::zero() + one_public + one_private + one_public + one_private, 2),
            (LinearCombination::one() + LinearCombination::zero() + LinearCombination::one() + one_public + one_private + one_public + one_private, 2),
            (LinearCombination::zero() + LinearCombination::zero() + one_public + one_private + one_public + one_private + &two_private, 1),
            (LinearCombination::one() + LinearCombination::zero() + one_public + one_private + one_public + one_private + &two_private, 2),
            (LinearCombination::one() + LinearCombination::zero() + LinearCombination::one() + one_public + one_private + one_public + one_private + &two_private, 2),
            (LinearCombination::zero() - one_public, 0),
            (LinearCombination::one() - one_public, 1),
            (LinearCombination::zero() + one_public - one_public, 0),
            (LinearCombination::one() + one_public - one_public, 0),
            (LinearCombination::zero() + one_public - one_private, 1),
            (LinearCombination::one() + one_public - one_private, 2),
            (LinearCombination::zero() + one_public + one_private - one_public, 0),
            (LinearCombination::one() + one_public + one_private - one_public, 1),
            (LinearCombination::zero() + one_public + one_private + one_public - one_private, 0),
            (LinearCombination::one() + one_public + one_private + one_public - one_private, 1),
            (LinearCombination::zero() + LinearCombination::zero() + one_public + one_private + one_public - one_private, 0),
            (LinearCombination::one() + LinearCombination::zero() + one_public + one_private + one_public - one_private, 1),
            (LinearCombination::one() + LinearCombination::zero() + LinearCombination::one() + one_public + one_private + one_public - one_private, 1),
            (LinearCombination::zero() + LinearCombination::zero() + one_public + one_private + one_public + one_private - &two_private, 0),
            (LinearCombination::one() + LinearCombination::zero() + one_public + one_private + one_public + one_private - &two_private, 1),
            (LinearCombination::one() + LinearCombination::zero() + LinearCombination::one() + one_public + one_private + one_public + one_private - &two_private, 1),
        ];

        for (index, (candidate, expected)) in cases.iter().enumerate() {
            assert_eq!(*expected, candidate.num_additions(), "case #{index}");
        }
    }
}