snarkvm_circuit_environment/helpers/
linear_combination.rs1use crate::{Mode, *};
17use snarkvm_fields::PrimeField;
18
19use core::{
20 fmt,
21 ops::{Add, AddAssign, Mul, Neg, Sub},
22};
23
/// A linear combination of circuit variables, `constant + Σ coefficient_i * variable_i`,
/// together with its cached evaluation.
#[derive(Clone)]
pub struct LinearCombination<F: PrimeField> {
    /// The constant term of the linear combination.
    constant: F,
    /// The variable terms as `(variable, coefficient)` pairs. The arithmetic in this file keeps
    /// them ordered by variable (binary-search insertion), drops zero coefficients, and expects
    /// no constant variables here (constants are folded into `constant`).
    terms: Vec<(Variable<F>, F)>,
    /// The value of the linear combination, updated alongside every operation.
    value: F,
}
42
43impl<F: PrimeField> LinearCombination<F> {
44 pub(crate) fn zero() -> Self {
46 Self { constant: F::zero(), terms: Default::default(), value: Default::default() }
47 }
48
49 pub(crate) fn one() -> Self {
51 Self { constant: F::one(), terms: Default::default(), value: F::one() }
52 }
53
54 pub fn is_constant(&self) -> bool {
56 self.terms.is_empty()
57 }
58
59 pub fn is_public(&self) -> bool {
62 self.constant.is_zero()
63 && self.terms.len() == 1
64 && match self.terms.first() {
65 Some((Variable::Public(..), coefficient)) => *coefficient == F::one(),
66 _ => false,
67 }
68 }
69
70 pub fn is_private(&self) -> bool {
72 !self.is_constant() && !self.is_public()
73 }
74
75 pub fn mode(&self) -> Mode {
77 if self.is_constant() {
78 Mode::Constant
79 } else if self.is_public() {
80 Mode::Public
81 } else {
82 Mode::Private
83 }
84 }
85
86 pub fn value(&self) -> F {
88 self.value
89 }
90
91 pub fn is_boolean_type(&self) -> bool {
100 if self.terms.is_empty() {
102 self.constant.is_zero() || self.constant.is_one()
103 }
104 else if self.constant.is_zero() {
106 if self.terms.iter().any(|(v, _)| !(v.value().is_zero() || v.value().is_one())) {
109 eprintln!("Property 2 of the `Boolean` type was violated in {self}");
110 return false;
111 }
112
113 if !(self.value.is_zero() || self.value.is_one()) {
115 eprintln!("Property 3 of the `Boolean` type was violated");
116 return false;
117 }
118
119 true
120 } else {
121 eprintln!("Both LC::constant and LC::terms contain elements, which is a violation");
124 false
125 }
126 }
127
128 pub(super) fn to_constant(&self) -> F {
130 self.constant
131 }
132
133 pub(super) fn to_terms(&self) -> &[(Variable<F>, F)] {
135 &self.terms
136 }
137
138 pub(super) fn num_nonzeros(&self) -> u64 {
140 match self.constant.is_zero() {
142 true => self.terms.len() as u64,
143 false => (self.terms.len() as u64).saturating_add(1),
144 }
145 }
146
147 #[cfg(test)]
149 pub(super) fn num_additions(&self) -> u64 {
150 match !self.constant.is_zero() && !self.terms.is_empty() {
152 true => self.terms.len() as u64,
153 false => (self.terms.len() as u64).saturating_sub(1),
154 }
155 }
156}
157
158impl<F: PrimeField> From<Variable<F>> for LinearCombination<F> {
159 fn from(variable: Variable<F>) -> Self {
160 Self::from(&variable)
161 }
162}
163
164impl<F: PrimeField> From<&Variable<F>> for LinearCombination<F> {
165 fn from(variable: &Variable<F>) -> Self {
166 Self::from(&[variable.clone()])
167 }
168}
169
170impl<F: PrimeField, const N: usize> From<[Variable<F>; N]> for LinearCombination<F> {
171 fn from(variables: [Variable<F>; N]) -> Self {
172 Self::from(&variables[..])
173 }
174}
175
176impl<F: PrimeField, const N: usize> From<&[Variable<F>; N]> for LinearCombination<F> {
177 fn from(variables: &[Variable<F>; N]) -> Self {
178 Self::from(&variables[..])
179 }
180}
181
182impl<F: PrimeField> From<Vec<Variable<F>>> for LinearCombination<F> {
183 fn from(variables: Vec<Variable<F>>) -> Self {
184 Self::from(variables.as_slice())
185 }
186}
187
188impl<F: PrimeField> From<&Vec<Variable<F>>> for LinearCombination<F> {
189 fn from(variables: &Vec<Variable<F>>) -> Self {
190 Self::from(variables.as_slice())
191 }
192}
193
194impl<F: PrimeField> From<&[Variable<F>]> for LinearCombination<F> {
195 fn from(variables: &[Variable<F>]) -> Self {
196 let mut output = Self::zero();
197 for variable in variables {
198 match variable.is_constant() {
199 true => output.constant += variable.value(),
200 false => {
201 match output.terms.binary_search_by(|(v, _)| v.cmp(variable)) {
202 Ok(idx) => {
203 output.terms[idx].1 += F::one();
205 if output.terms[idx].1.is_zero() {
207 output.terms.remove(idx);
208 }
209 }
210 Err(idx) => {
211 output.terms.insert(idx, (variable.clone(), F::one()));
213 }
214 }
215 }
216 }
217 output.value += variable.value();
219 }
220 output
221 }
222}
223
224impl<F: PrimeField> Neg for LinearCombination<F> {
225 type Output = Self;
226
227 #[inline]
228 fn neg(self) -> Self::Output {
229 let mut output = self;
230 output.constant = -output.constant;
231 output.terms.iter_mut().for_each(|(_, coefficient)| *coefficient = -(*coefficient));
232 output.value = -output.value;
233 output
234 }
235}
236
237impl<F: PrimeField> Neg for &LinearCombination<F> {
238 type Output = LinearCombination<F>;
239
240 #[inline]
241 fn neg(self) -> Self::Output {
242 -(self.clone())
243 }
244}
245
246impl<F: PrimeField> Add<Variable<F>> for LinearCombination<F> {
247 type Output = Self;
248
249 #[allow(clippy::op_ref)]
250 fn add(self, other: Variable<F>) -> Self::Output {
251 self + &other
252 }
253}
254
255impl<F: PrimeField> Add<&Variable<F>> for LinearCombination<F> {
256 type Output = Self;
257
258 fn add(self, other: &Variable<F>) -> Self::Output {
259 self + Self::from(other)
260 }
261}
262
263impl<F: PrimeField> Add<Variable<F>> for &LinearCombination<F> {
264 type Output = LinearCombination<F>;
265
266 #[allow(clippy::op_ref)]
267 fn add(self, other: Variable<F>) -> Self::Output {
268 self.clone() + &other
269 }
270}
271
272impl<F: PrimeField> Add<LinearCombination<F>> for LinearCombination<F> {
273 type Output = Self;
274
275 fn add(self, other: Self) -> Self::Output {
276 self + &other
277 }
278}
279
280impl<F: PrimeField> Add<&LinearCombination<F>> for LinearCombination<F> {
281 type Output = Self;
282
283 fn add(self, other: &Self) -> Self::Output {
284 &self + other
285 }
286}
287
288impl<F: PrimeField> Add<LinearCombination<F>> for &LinearCombination<F> {
289 type Output = LinearCombination<F>;
290
291 fn add(self, other: LinearCombination<F>) -> Self::Output {
292 self + &other
293 }
294}
295
296impl<F: PrimeField> Add<&LinearCombination<F>> for &LinearCombination<F> {
297 type Output = LinearCombination<F>;
298
299 fn add(self, other: &LinearCombination<F>) -> Self::Output {
300 if self.constant.is_zero() && self.terms.is_empty() {
301 other.clone()
302 } else if other.constant.is_zero() && other.terms.is_empty() {
303 self.clone()
304 } else if self.terms.len() > other.terms.len() {
305 let mut output = self.clone();
306 output += other;
307 output
308 } else {
309 let mut output = other.clone();
310 output += self;
311 output
312 }
313 }
314}
315
316impl<F: PrimeField> AddAssign<LinearCombination<F>> for LinearCombination<F> {
317 fn add_assign(&mut self, other: Self) {
318 *self += &other;
319 }
320}
321
322impl<F: PrimeField> AddAssign<&LinearCombination<F>> for LinearCombination<F> {
323 fn add_assign(&mut self, other: &Self) {
324 if other.constant.is_zero() && other.terms.is_empty() {
326 return;
327 }
328
329 if self.constant.is_zero() && self.terms.is_empty() {
330 *self = other.clone();
331 } else {
332 self.constant += other.constant;
334
335 for (variable, coefficient) in other.terms.iter() {
337 match variable.is_constant() {
338 true => panic!("Malformed linear combination found"),
339 false => {
340 match self.terms.binary_search_by(|(v, _)| v.cmp(variable)) {
341 Ok(idx) => {
342 self.terms[idx].1 += *coefficient;
344 if self.terms[idx].1.is_zero() {
346 self.terms.remove(idx);
347 }
348 }
349 Err(idx) => {
350 self.terms.insert(idx, (variable.clone(), *coefficient));
352 }
353 }
354 }
355 }
356 }
357
358 self.value += other.value;
360 }
361 }
362}
363
364impl<F: PrimeField> Sub<Variable<F>> for LinearCombination<F> {
365 type Output = Self;
366
367 #[allow(clippy::op_ref)]
368 fn sub(self, other: Variable<F>) -> Self::Output {
369 self - &other
370 }
371}
372
373impl<F: PrimeField> Sub<&Variable<F>> for LinearCombination<F> {
374 type Output = Self;
375
376 fn sub(self, other: &Variable<F>) -> Self::Output {
377 self - Self::from(other)
378 }
379}
380
381impl<F: PrimeField> Sub<Variable<F>> for &LinearCombination<F> {
382 type Output = LinearCombination<F>;
383
384 #[allow(clippy::op_ref)]
385 fn sub(self, other: Variable<F>) -> Self::Output {
386 self.clone() - &other
387 }
388}
389
390impl<F: PrimeField> Sub<LinearCombination<F>> for LinearCombination<F> {
391 type Output = Self;
392
393 fn sub(self, other: Self) -> Self::Output {
394 self - &other
395 }
396}
397
398impl<F: PrimeField> Sub<&LinearCombination<F>> for LinearCombination<F> {
399 type Output = Self;
400
401 fn sub(self, other: &Self) -> Self::Output {
402 &self - other
403 }
404}
405
406impl<F: PrimeField> Sub<LinearCombination<F>> for &LinearCombination<F> {
407 type Output = LinearCombination<F>;
408
409 fn sub(self, other: LinearCombination<F>) -> Self::Output {
410 self - &other
411 }
412}
413
414impl<F: PrimeField> Sub<&LinearCombination<F>> for &LinearCombination<F> {
415 type Output = LinearCombination<F>;
416
417 fn sub(self, other: &LinearCombination<F>) -> Self::Output {
418 self + &(-other)
419 }
420}
421
422impl<F: PrimeField> Mul<F> for LinearCombination<F> {
423 type Output = Self;
424
425 #[allow(clippy::op_ref)]
426 fn mul(self, coefficient: F) -> Self::Output {
427 self * &coefficient
428 }
429}
430
431impl<F: PrimeField> Mul<&F> for LinearCombination<F> {
432 type Output = Self;
433
434 fn mul(self, coefficient: &F) -> Self::Output {
435 let mut output = self;
436 output.constant *= coefficient;
437 output.terms = output
438 .terms
439 .into_iter()
440 .filter_map(|(v, current_coefficient)| {
441 let res = current_coefficient * coefficient;
442 (!res.is_zero()).then_some((v, res))
443 })
444 .collect();
445 output.value *= coefficient;
446 output
447 }
448}
449
450impl<F: PrimeField> Mul<F> for &LinearCombination<F> {
451 type Output = LinearCombination<F>;
452
453 #[allow(clippy::op_ref)]
454 fn mul(self, coefficient: F) -> Self::Output {
455 self * &coefficient
456 }
457}
458
459impl<F: PrimeField> Mul<&F> for &LinearCombination<F> {
460 type Output = LinearCombination<F>;
461
462 fn mul(self, coefficient: &F) -> Self::Output {
463 self.clone() * coefficient
464 }
465}
466
467impl<F: PrimeField> fmt::Debug for LinearCombination<F> {
468 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
469 let mut output = format!("Constant({})", self.constant);
470
471 for (variable, coefficient) in &self.terms {
472 output += &match (variable.mode(), coefficient.is_one()) {
473 (Mode::Constant, _) => panic!("Malformed linear combination at: ({coefficient} * {variable:?})"),
474 (_, true) => format!(" + {variable:?}"),
475 _ => format!(" + {coefficient} * {variable:?}"),
476 };
477 }
478 write!(f, "{output}")
479 }
480}
481
482impl<F: PrimeField> fmt::Display for LinearCombination<F> {
483 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
484 write!(f, "{}", self.value)
485 }
486}
487
#[cfg(test)]
mod tests {
    use super::*;
    // Imported under aliases so the `One`/`Zero` trait items are in scope for the base-field
    // `one()`/`zero()` calls below.
    use snarkvm_fields::{One as O, Zero as Z};

    use std::rc::Rc;

    /// `zero()` is the constant `0` with no terms.
    #[test]
    fn test_zero() {
        let zero = <Circuit as Environment>::BaseField::zero();

        let candidate = LinearCombination::zero();
        assert_eq!(zero, candidate.constant);
        assert!(candidate.terms.is_empty());
        assert_eq!(zero, candidate.value());
    }

    /// `one()` is the constant `1` with no terms.
    #[test]
    fn test_one() {
        let one = <Circuit as Environment>::BaseField::one();

        let candidate = LinearCombination::one();
        assert_eq!(one, candidate.constant);
        assert!(candidate.terms.is_empty());
        assert_eq!(one, candidate.value());
    }

    /// Adding two constant combinations sums constants and values without creating terms.
    #[test]
    fn test_two() {
        let one = <Circuit as Environment>::BaseField::one();
        let two = one + one;

        let candidate = LinearCombination::one() + LinearCombination::one();
        assert_eq!(two, candidate.constant);
        assert!(candidate.terms.is_empty());
        assert_eq!(two, candidate.value());
    }

    /// Both `zero()` and `one()` report themselves as constant.
    #[test]
    fn test_is_constant() {
        let zero = <Circuit as Environment>::BaseField::zero();
        let one = <Circuit as Environment>::BaseField::one();

        let candidate = LinearCombination::zero();
        assert!(candidate.is_constant());
        assert_eq!(zero, candidate.constant);
        assert_eq!(zero, candidate.value());

        let candidate = LinearCombination::one();
        assert!(candidate.is_constant());
        assert_eq!(one, candidate.constant);
        assert_eq!(one, candidate.value());
    }

    /// Scaling a single-variable combination scales its coefficient and value, leaving the
    /// variable itself and the zero constant untouched.
    #[test]
    fn test_mul() {
        let zero = <Circuit as Environment>::BaseField::zero();
        let one = <Circuit as Environment>::BaseField::one();
        let two = one + one;
        let four = two + two;

        // A combination built directly from a public variable with value one.
        let start = LinearCombination::from(Variable::Public(Rc::new((1, one))));
        assert!(!start.is_constant());
        assert_eq!(one, start.value());

        // Multiply the combination by the field element four.
        let candidate = start * four;
        assert_eq!(four, candidate.value());
        assert_eq!(zero, candidate.constant);
        assert_eq!(1, candidate.terms.len());

        let (candidate_variable, candidate_coefficient) = candidate.terms.first().unwrap();
        assert!(candidate_variable.is_public());
        assert_eq!(one, candidate_variable.value());
        assert_eq!(four, *candidate_coefficient);
    }

    /// The `Debug` rendering does not depend on the order in which operands are added,
    /// repeated variables are shown with a merged coefficient, and cancelled terms are omitted.
    #[test]
    fn test_debug() {
        let one_public = &Circuit::new_variable(Mode::Public, <Circuit as Environment>::BaseField::one());
        let one_private = &Circuit::new_variable(Mode::Private, <Circuit as Environment>::BaseField::one());
        {
            // One constant, one public, one private, in various addition orders.
            let expected = "Constant(1) + Public(1, 1) + Private(0, 1)";

            let candidate = LinearCombination::one() + one_public + one_private;
            assert_eq!(expected, format!("{candidate:?}"));

            let candidate = one_private + one_public + LinearCombination::one();
            assert_eq!(expected, format!("{candidate:?}"));

            let candidate = one_private + LinearCombination::one() + one_public;
            assert_eq!(expected, format!("{candidate:?}"));

            let candidate = one_public + LinearCombination::one() + one_private;
            assert_eq!(expected, format!("{candidate:?}"));
        }
        {
            // The public variable appears twice, so its coefficient is 2.
            let expected = "Constant(1) + 2 * Public(1, 1) + Private(0, 1)";

            let candidate = LinearCombination::one() + one_public + one_public + one_private;
            assert_eq!(expected, format!("{candidate:?}"));

            let candidate = one_private + one_public + LinearCombination::one() + one_public;
            assert_eq!(expected, format!("{candidate:?}"));

            let candidate = one_public + one_private + LinearCombination::one() + one_public;
            assert_eq!(expected, format!("{candidate:?}"));

            let candidate = one_public + LinearCombination::one() + one_private + one_public;
            assert_eq!(expected, format!("{candidate:?}"));
        }
        {
            // The private variable appears twice, so its coefficient is 2.
            let expected = "Constant(1) + Public(1, 1) + 2 * Private(0, 1)";

            let candidate = LinearCombination::one() + one_public + one_private + one_private;
            assert_eq!(expected, format!("{candidate:?}"));

            let candidate = one_private + one_public + LinearCombination::one() + one_private;
            assert_eq!(expected, format!("{candidate:?}"));

            let candidate = one_private + one_private + LinearCombination::one() + one_public;
            assert_eq!(expected, format!("{candidate:?}"));

            let candidate = one_public + LinearCombination::one() + one_private + one_private;
            assert_eq!(expected, format!("{candidate:?}"));
        }
        {
            // The private variable is added and subtracted, so its term disappears.
            let expected = "Constant(1) + Public(1, 1)";

            let candidate = LinearCombination::one() + one_public + one_private - one_private;
            assert_eq!(expected, format!("{candidate:?}"));

            let candidate = one_private + one_public + LinearCombination::one() - one_private;
            assert_eq!(expected, format!("{candidate:?}"));

            let candidate = one_private - one_private + LinearCombination::one() + one_public;
            assert_eq!(expected, format!("{candidate:?}"));

            let candidate = one_public + LinearCombination::one() + one_private - one_private;
            assert_eq!(expected, format!("{candidate:?}"));
        }
    }

    /// `num_additions` counts the additions remaining after repeated variables are merged and
    /// cancelled terms are removed; a nonzero constant alongside terms adds one.
    #[rustfmt::skip]
    #[test]
    fn test_num_additions() {
        let one_public = &Circuit::new_variable(Mode::Public, <Circuit as Environment>::BaseField::one());
        let one_private = &Circuit::new_variable(Mode::Private, <Circuit as Environment>::BaseField::one());
        let two_private = one_private + one_private;

        // Additions only: repeated variables merge into one term each.
        let candidate = LinearCombination::<<Circuit as Environment>::BaseField>::zero();
        assert_eq!(0, candidate.num_additions());

        let candidate = LinearCombination::<<Circuit as Environment>::BaseField>::one();
        assert_eq!(0, candidate.num_additions());

        let candidate = LinearCombination::zero() + one_public;
        assert_eq!(0, candidate.num_additions());

        let candidate = LinearCombination::one() + one_public;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::zero() + one_public + one_public;
        assert_eq!(0, candidate.num_additions());

        let candidate = LinearCombination::one() + one_public + one_public;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::zero() + one_public + one_private;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::one() + one_public + one_private;
        assert_eq!(2, candidate.num_additions());

        let candidate = LinearCombination::zero() + one_public + one_private + one_public;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::one() + one_public + one_private + one_public;
        assert_eq!(2, candidate.num_additions());

        let candidate = LinearCombination::zero() + one_public + one_private + one_public + one_private;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::one() + one_public + one_private + one_public + one_private;
        assert_eq!(2, candidate.num_additions());

        let candidate = LinearCombination::zero() + LinearCombination::zero() + one_public + one_private + one_public + one_private;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::one() + LinearCombination::zero() + one_public + one_private + one_public + one_private;
        assert_eq!(2, candidate.num_additions());

        let candidate = LinearCombination::one() + LinearCombination::zero() + LinearCombination::one() + one_public + one_private + one_public + one_private;
        assert_eq!(2, candidate.num_additions());

        let candidate = LinearCombination::zero() + LinearCombination::zero() + one_public + one_private + one_public + one_private + &two_private;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::one() + LinearCombination::zero() + one_public + one_private + one_public + one_private + &two_private;
        assert_eq!(2, candidate.num_additions());

        let candidate = LinearCombination::one() + LinearCombination::zero() + LinearCombination::one() + one_public + one_private + one_public + one_private + &two_private;
        assert_eq!(2, candidate.num_additions());

        // With subtraction: terms that cancel to a zero coefficient no longer count.
        let candidate = LinearCombination::zero() - one_public;
        assert_eq!(0, candidate.num_additions());

        let candidate = LinearCombination::one() - one_public;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::zero() + one_public - one_public;
        assert_eq!(0, candidate.num_additions());

        let candidate = LinearCombination::one() + one_public - one_public;
        assert_eq!(0, candidate.num_additions());

        let candidate = LinearCombination::zero() + one_public - one_private;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::one() + one_public - one_private;
        assert_eq!(2, candidate.num_additions());

        let candidate = LinearCombination::zero() + one_public + one_private - one_public;
        assert_eq!(0, candidate.num_additions());

        let candidate = LinearCombination::one() + one_public + one_private - one_public;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::zero() + one_public + one_private + one_public - one_private;
        assert_eq!(0, candidate.num_additions());

        let candidate = LinearCombination::one() + one_public + one_private + one_public - one_private;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::zero() + LinearCombination::zero() + one_public + one_private + one_public - one_private;
        assert_eq!(0, candidate.num_additions());

        let candidate = LinearCombination::one() + LinearCombination::zero() + one_public + one_private + one_public - one_private;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::one() + LinearCombination::zero() + LinearCombination::one() + one_public + one_private + one_public - one_private;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::zero() + LinearCombination::zero() + one_public + one_private + one_public + one_private - &two_private;
        assert_eq!(0, candidate.num_additions());

        let candidate = LinearCombination::one() + LinearCombination::zero() + one_public + one_private + one_public + one_private - &two_private;
        assert_eq!(1, candidate.num_additions());

        let candidate = LinearCombination::one() + LinearCombination::zero() + LinearCombination::one() + one_public + one_private + one_public + one_private - &two_private;
        assert_eq!(1, candidate.num_additions());
    }
}