1use alloc::vec;
2use alloc::vec::Vec;
3use core::ops::{Add, Mul, Neg, Sub};
4use ff::Field;
5use group::Group;
6
7use super::{GroupVar, ScalarTerm, ScalarVar, Sum, Term, Weighted};
8
9mod add {
10 use super::*;
11
12 macro_rules! impl_add_term {
13 ($($type:ty),+) => {
14 $(
15 impl<G> Add<$type> for $type {
16 type Output = Sum<$type>;
17
18 fn add(self, rhs: $type) -> Self::Output {
19 Sum(vec![self, rhs])
20 }
21 }
22 )+
23 };
24 }
25
26 impl_add_term!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
27
28 impl<T> Add<T> for Sum<T> {
29 type Output = Sum<T>;
30
31 fn add(mut self, rhs: T) -> Self::Output {
32 self.0.push(rhs);
33 self
34 }
35 }
36
37 macro_rules! impl_add_sum_term {
38 ($($type:ty),+) => {
39 $(
40 impl<G> Add<Sum<$type>> for $type {
41 type Output = Sum<$type>;
42
43 fn add(self, rhs: Sum<$type>) -> Self::Output {
44 rhs + self
45 }
46 }
47 )+
48 };
49 }
50
51 impl_add_sum_term!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
52
53 impl<T> Add<Sum<T>> for Sum<T> {
54 type Output = Sum<T>;
55
56 fn add(mut self, rhs: Sum<T>) -> Self::Output {
57 self.0.extend(rhs.0);
58 self
59 }
60 }
61
62 impl<T, F> Add<Weighted<T, F>> for Weighted<T, F> {
63 type Output = Sum<Weighted<T, F>>;
64
65 fn add(self, rhs: Weighted<T, F>) -> Self::Output {
66 Sum(vec![self, rhs])
67 }
68 }
69
70 impl<T, F: Field> Add<T> for Weighted<T, F> {
71 type Output = Sum<Weighted<T, F>>;
72
73 fn add(self, rhs: T) -> Self::Output {
74 Sum(vec![self, rhs.into()])
75 }
76 }
77
78 macro_rules! impl_add_weighted_term {
79 ($($type:ty),+) => {
80 $(
81 impl<G: Group> Add<Weighted<$type, G::Scalar>> for $type {
82 type Output = Sum<Weighted<$type, G::Scalar>>;
83
84 fn add(self, rhs: Weighted<$type, G::Scalar>) -> Self::Output {
85 rhs + self
86 }
87 }
88 )+
89 };
90 }
91
92 impl_add_weighted_term!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
93
94 impl<T, F: Field> Add<T> for Sum<Weighted<T, F>> {
95 type Output = Sum<Weighted<T, F>>;
96
97 fn add(mut self, rhs: T) -> Self::Output {
98 self.0.push(rhs.into());
99 self
100 }
101 }
102
103 macro_rules! impl_add_weighted_sum_term {
104 ($($type:ty),+) => {
105 $(
106 impl<G: Group> Add<Sum<Weighted<$type, G::Scalar>>> for $type {
107 type Output = Sum<Weighted<$type, G::Scalar>>;
108
109 fn add(self, rhs: Sum<Weighted<$type, G::Scalar>>) -> Self::Output {
110 rhs + self
111 }
112 }
113 )+
114 };
115 }
116
117 impl_add_weighted_sum_term!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
118
119 impl<T, F: Field> Add<Sum<T>> for Sum<Weighted<T, F>> {
120 type Output = Sum<Weighted<T, F>>;
121
122 fn add(self, rhs: Sum<T>) -> Self::Output {
123 self + Self::from(rhs)
124 }
125 }
126
127 impl<T, F: Field> Add<Sum<Weighted<T, F>>> for Sum<T> {
128 type Output = Sum<Weighted<T, F>>;
129
130 fn add(self, rhs: Sum<Weighted<T, F>>) -> Self::Output {
131 rhs + self
132 }
133 }
134
135 impl<T, F: Field> Add<Weighted<T, F>> for Sum<T> {
136 type Output = Sum<Weighted<T, F>>;
137
138 fn add(self, rhs: Weighted<T, F>) -> Self::Output {
139 Self::Output::from(self) + rhs
140 }
141 }
142
143 impl<T, F: Field> Add<Sum<T>> for Weighted<T, F> {
144 type Output = Sum<Weighted<T, F>>;
145
146 fn add(self, rhs: Sum<T>) -> Self::Output {
147 rhs + self
148 }
149 }
150
151 impl<T, F: Field> Add<Sum<Weighted<T, F>>> for Weighted<T, F> {
152 type Output = Sum<Weighted<T, F>>;
153
154 fn add(self, rhs: Sum<Weighted<T, F>>) -> Self::Output {
155 rhs + self
156 }
157 }
158
159 impl<G> Add<ScalarVar<G>> for ScalarTerm<G> {
160 type Output = Sum<ScalarTerm<G>>;
161
162 fn add(self, rhs: ScalarVar<G>) -> Self::Output {
163 self + ScalarTerm::from(rhs)
164 }
165 }
166
167 impl<G> Add<ScalarTerm<G>> for ScalarVar<G> {
168 type Output = Sum<ScalarTerm<G>>;
169
170 fn add(self, rhs: ScalarTerm<G>) -> Self::Output {
171 rhs + self
172 }
173 }
174
175 impl<T: Field + Into<G::Scalar>, G: Group> Add<T> for Weighted<ScalarTerm<G>, G::Scalar> {
176 type Output = Sum<Weighted<ScalarTerm<G>, G::Scalar>>;
177
178 fn add(self, rhs: T) -> Self::Output {
179 self + Self::from(rhs.into())
180 }
181 }
182
183 impl<T: Field + Into<G::Scalar>, G: Group> Add<T> for Weighted<ScalarVar<G>, G::Scalar> {
184 type Output = Sum<Weighted<ScalarTerm<G>, G::Scalar>>;
185
186 fn add(self, rhs: T) -> Self::Output {
187 <Weighted<ScalarTerm<G>, G::Scalar>>::from(self) + rhs.into()
188 }
189 }
190
191 impl<T: Field + Into<G::Scalar>, G: Group> Add<T> for ScalarVar<G> {
192 type Output = Sum<Weighted<ScalarTerm<G>, G::Scalar>>;
193
194 fn add(self, rhs: T) -> Self::Output {
195 Weighted::from(ScalarTerm::from(self)) + rhs.into()
196 }
197 }
198
199 impl<G: Group> Add<GroupVar<G>> for Sum<Weighted<Term<G>, G::Scalar>> {
200 type Output = Sum<Weighted<Term<G>, G::Scalar>>;
201
202 fn add(self, rhs: GroupVar<G>) -> Self::Output {
203 self + Self::from(rhs)
204 }
205 }
206
207 impl<G: Group> Add<GroupVar<G>> for Sum<Term<G>> {
208 type Output = Sum<Term<G>>;
209
210 fn add(self, rhs: GroupVar<G>) -> Self::Output {
211 self + Self::from(rhs)
212 }
213 }
214
215 impl<G: Group> Add<GroupVar<G>> for Weighted<Term<G>, G::Scalar> {
216 type Output = Sum<Weighted<Term<G>, G::Scalar>>;
217
218 fn add(self, rhs: GroupVar<G>) -> Self::Output {
219 self + Self::from(rhs)
220 }
221 }
222
223 impl<G: Group> Add<GroupVar<G>> for Term<G> {
224 type Output = Sum<Term<G>>;
225
226 fn add(self, rhs: GroupVar<G>) -> Self::Output {
227 self + Self::from(rhs)
228 }
229 }
230
231 impl<G: Group> Add<Weighted<GroupVar<G>, G::Scalar>> for Term<G> {
232 type Output = Sum<Weighted<Term<G>, G::Scalar>>;
233
234 fn add(self, rhs: Weighted<GroupVar<G>, G::Scalar>) -> Self::Output {
235 Sum(vec![
236 Weighted {
237 term: self,
238 weight: G::Scalar::ONE,
239 },
240 Weighted {
241 term: Term {
242 scalar: super::ScalarTerm::Unit,
243 elem: rhs.term,
244 },
245 weight: rhs.weight,
246 },
247 ])
248 }
249 }
250
251 impl<G: Group> Add<Weighted<GroupVar<G>, G::Scalar>> for Sum<Term<G>> {
252 type Output = Sum<Weighted<Term<G>, G::Scalar>>;
253
254 fn add(self, rhs: Weighted<GroupVar<G>, G::Scalar>) -> Self::Output {
255 Sum::<Weighted<Term<G>, G::Scalar>>::from(self) + rhs
256 }
257 }
258
259 impl<G: Group> Add<Sum<Term<G>>> for Weighted<GroupVar<G>, G::Scalar> {
260 type Output = Sum<Weighted<Term<G>, G::Scalar>>;
261
262 fn add(self, rhs: Sum<Term<G>>) -> Self::Output {
263 rhs + self
264 }
265 }
266
267 impl<G: Group> Add<Sum<Weighted<Term<G>, G::Scalar>>> for Weighted<GroupVar<G>, G::Scalar> {
268 type Output = Sum<Weighted<Term<G>, G::Scalar>>;
269
270 fn add(self, rhs: Sum<Weighted<Term<G>, G::Scalar>>) -> Self::Output {
271 let weighted_term = Weighted {
272 term: Term {
273 scalar: super::ScalarTerm::Unit,
274 elem: self.term,
275 },
276 weight: self.weight,
277 };
278 rhs + weighted_term
279 }
280 }
281
282 impl<G: Group> Add<Term<G>> for Sum<Weighted<GroupVar<G>, G::Scalar>> {
283 type Output = Sum<Weighted<Term<G>, G::Scalar>>;
284
285 fn add(self, rhs: Term<G>) -> Self::Output {
286 Sum::<Weighted<Term<G>, G::Scalar>>::from(self) + rhs
287 }
288 }
289
290 impl<G: Group> Add<Sum<Term<G>>> for Sum<Weighted<GroupVar<G>, G::Scalar>> {
291 type Output = Sum<Weighted<Term<G>, G::Scalar>>;
292
293 fn add(self, rhs: Sum<Term<G>>) -> Self::Output {
294 Sum::<Weighted<Term<G>, G::Scalar>>::from(self) + rhs
295 }
296 }
297
298 impl<G: Group> Add<Sum<Weighted<GroupVar<G>, G::Scalar>>> for Term<G> {
299 type Output = Sum<Weighted<Term<G>, G::Scalar>>;
300
301 fn add(self, rhs: Sum<Weighted<GroupVar<G>, G::Scalar>>) -> Self::Output {
302 Sum::<Weighted<Term<G>, G::Scalar>>::from(rhs) + self
303 }
304 }
305
306 impl<G: Group> Add<Sum<Weighted<GroupVar<G>, G::Scalar>>> for Sum<Term<G>> {
307 type Output = Sum<Weighted<Term<G>, G::Scalar>>;
308
309 fn add(self, rhs: Sum<Weighted<GroupVar<G>, G::Scalar>>) -> Self::Output {
310 self + Sum::<Weighted<Term<G>, G::Scalar>>::from(rhs)
311 }
312 }
313
314 impl<G: Group> Add<Weighted<GroupVar<G>, G::Scalar>> for Sum<Weighted<Term<G>, G::Scalar>> {
315 type Output = Sum<Weighted<Term<G>, G::Scalar>>;
316
317 fn add(mut self, rhs: Weighted<GroupVar<G>, G::Scalar>) -> Self::Output {
318 self.0.push(Weighted {
319 term: Term {
320 scalar: super::ScalarTerm::Unit,
321 elem: rhs.term,
322 },
323 weight: rhs.weight,
324 });
325 self
326 }
327 }
328
329 impl<G: Group> Add<Term<G>> for Weighted<GroupVar<G>, G::Scalar> {
330 type Output = Sum<Weighted<Term<G>, G::Scalar>>;
331
332 fn add(self, rhs: Term<G>) -> Self::Output {
333 rhs + self
334 }
335 }
336
337 impl<T: Field + Into<G::Scalar>, G: Group> Add<T> for Sum<ScalarVar<G>> {
338 type Output = Sum<Weighted<ScalarTerm<G>, G::Scalar>>;
339
340 fn add(self, rhs: T) -> Self::Output {
341 let mut weighted_terms = Vec::new();
343 for var in self.0 {
344 weighted_terms.push(Weighted {
345 term: ScalarTerm::from(var),
346 weight: G::Scalar::ONE,
347 });
348 }
349 let weighted_sum: Sum<Weighted<ScalarTerm<G>, G::Scalar>> = Sum(weighted_terms);
350
351 let weighted_scalar = Weighted {
353 term: ScalarTerm::Unit,
354 weight: rhs.into(),
355 };
356
357 weighted_sum + weighted_scalar
358 }
359 }
360}
361
362mod mul {
363 use super::*;
364
365 impl<G> Mul<ScalarVar<G>> for GroupVar<G> {
366 type Output = Term<G>;
367
368 fn mul(self, rhs: ScalarVar<G>) -> Term<G> {
370 Term {
371 elem: self,
372 scalar: rhs.into(),
373 }
374 }
375 }
376
377 impl<G> Mul<GroupVar<G>> for ScalarVar<G> {
378 type Output = Term<G>;
379
380 fn mul(self, rhs: GroupVar<G>) -> Term<G> {
382 rhs * self
383 }
384 }
385
386 impl<G> Mul<ScalarTerm<G>> for GroupVar<G> {
387 type Output = Term<G>;
388
389 fn mul(self, rhs: ScalarTerm<G>) -> Term<G> {
390 Term {
391 elem: self,
392 scalar: rhs,
393 }
394 }
395 }
396
397 impl<G> Mul<GroupVar<G>> for ScalarTerm<G> {
398 type Output = Term<G>;
399
400 fn mul(self, rhs: GroupVar<G>) -> Term<G> {
401 rhs * self
402 }
403 }
404
405 impl<Rhs: Clone, Lhs: Mul<Rhs>> Mul<Rhs> for Sum<Lhs> {
406 type Output = Sum<<Lhs as Mul<Rhs>>::Output>;
407
408 fn mul(self, rhs: Rhs) -> Self::Output {
410 Sum(self.0.into_iter().map(|x| x * rhs.clone()).collect())
411 }
412 }
413
414 macro_rules! impl_scalar_mul_term {
419 ($($type:ty),+) => {
420 $(
421 impl<F: Field + Into<G::Scalar>, G: Group> Mul<F> for $type {
423 type Output = Weighted<$type, G::Scalar>;
424
425 fn mul(self, rhs: F) -> Self::Output {
426 Weighted {
427 term: self,
428 weight: rhs.into(),
429 }
430 }
431 }
432 )+
433 };
434 }
435
436 impl_scalar_mul_term!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
437
438 impl<T, F: Field> Mul<F> for Weighted<T, F> {
439 type Output = Weighted<T, F>;
440
441 fn mul(self, rhs: F) -> Self::Output {
442 Weighted {
443 term: self.term,
444 weight: self.weight * rhs,
445 }
446 }
447 }
448
449 impl<G: Group> Mul<ScalarVar<G>> for Weighted<GroupVar<G>, G::Scalar> {
450 type Output = Weighted<Term<G>, G::Scalar>;
451
452 fn mul(self, rhs: ScalarVar<G>) -> Self::Output {
453 Weighted {
454 term: self.term * rhs,
455 weight: self.weight,
456 }
457 }
458 }
459
460 impl<G: Group> Mul<Weighted<GroupVar<G>, G::Scalar>> for ScalarVar<G> {
461 type Output = Weighted<Term<G>, G::Scalar>;
462
463 fn mul(self, rhs: Weighted<GroupVar<G>, G::Scalar>) -> Self::Output {
464 rhs * self
465 }
466 }
467
468 impl<G: Group> Mul<GroupVar<G>> for Weighted<ScalarVar<G>, G::Scalar> {
469 type Output = Weighted<Term<G>, G::Scalar>;
470
471 fn mul(self, rhs: GroupVar<G>) -> Self::Output {
472 Weighted {
473 term: self.term * rhs,
474 weight: self.weight,
475 }
476 }
477 }
478
479 impl<G: Group> Mul<Weighted<ScalarVar<G>, G::Scalar>> for GroupVar<G> {
480 type Output = Weighted<Term<G>, G::Scalar>;
481
482 fn mul(self, rhs: Weighted<ScalarVar<G>, G::Scalar>) -> Self::Output {
483 rhs * self
484 }
485 }
486
487 impl<G: Group> Mul<ScalarTerm<G>> for Weighted<GroupVar<G>, G::Scalar> {
488 type Output = Weighted<Term<G>, G::Scalar>;
489
490 fn mul(self, rhs: ScalarTerm<G>) -> Self::Output {
491 Weighted {
492 term: self.term * rhs,
493 weight: self.weight,
494 }
495 }
496 }
497
498 impl<G: Group> Mul<Weighted<GroupVar<G>, G::Scalar>> for ScalarTerm<G> {
499 type Output = Weighted<Term<G>, G::Scalar>;
500
501 fn mul(self, rhs: Weighted<GroupVar<G>, G::Scalar>) -> Self::Output {
502 rhs * self
503 }
504 }
505
506 impl<G: Group> Mul<GroupVar<G>> for Weighted<ScalarTerm<G>, G::Scalar> {
507 type Output = Weighted<Term<G>, G::Scalar>;
508
509 fn mul(self, rhs: GroupVar<G>) -> Self::Output {
510 Weighted {
511 term: self.term * rhs,
512 weight: self.weight,
513 }
514 }
515 }
516
517 impl<G: Group> Mul<Weighted<ScalarTerm<G>, G::Scalar>> for GroupVar<G> {
518 type Output = Weighted<Term<G>, G::Scalar>;
519
520 fn mul(self, rhs: Weighted<ScalarTerm<G>, G::Scalar>) -> Self::Output {
521 rhs * self
522 }
523 }
524}
525
526mod neg {
527 use super::*;
528
529 impl<T: Neg> Neg for Sum<T> {
530 type Output = Sum<<T as Neg>::Output>;
531
532 fn neg(self) -> Self::Output {
534 Sum(self.0.into_iter().map(|x| x.neg()).collect())
535 }
536 }
537
538 impl<T, F: Field> Neg for Weighted<T, F> {
539 type Output = Weighted<T, F>;
540
541 fn neg(self) -> Self::Output {
543 Weighted {
544 term: self.term,
545 weight: -self.weight,
546 }
547 }
548 }
549
550 macro_rules! impl_neg_term {
551 ($($type:ty),+) => {
552 $(
553 impl<G: Group> Neg for $type {
554 type Output = Weighted<$type, G::Scalar>;
555
556 fn neg(self) -> Self::Output {
557 Weighted {
558 term: self,
559 weight: -G::Scalar::ONE,
560 }
561 }
562 }
563 )+
564 };
565 }
566
567 impl_neg_term!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
568}
569
570mod sub {
571 use super::*;
572
573 impl<T, Rhs> Sub<Rhs> for Sum<T>
574 where
575 Rhs: Neg,
576 <Rhs as Neg>::Output: Add<Self>,
577 {
578 type Output = <<Rhs as Neg>::Output as Add<Self>>::Output;
579
580 #[allow(clippy::suspicious_arithmetic_impl)]
581 fn sub(self, rhs: Rhs) -> Self::Output {
582 rhs.neg() + self
583 }
584 }
585
586 impl<T, F, Rhs> Sub<Rhs> for Weighted<T, F>
587 where
588 Rhs: Neg,
589 <Rhs as Neg>::Output: Add<Self>,
590 {
591 type Output = <<Rhs as Neg>::Output as Add<Self>>::Output;
592
593 #[allow(clippy::suspicious_arithmetic_impl)]
594 fn sub(self, rhs: Rhs) -> Self::Output {
595 rhs.neg() + self
596 }
597 }
598
599 macro_rules! impl_sub_as_neg_add {
600 ($($type:ty),+) => {
601 $(
602 impl<G, Rhs> Sub<Rhs> for $type
603 where
604 Rhs: Neg,
605 Self: Add<<Rhs as Neg>::Output>,
606 {
607 type Output = <Self as Add<<Rhs as Neg>::Output>>::Output;
608
609 #[allow(clippy::suspicious_arithmetic_impl)]
610 fn sub(self, rhs: Rhs) -> Self::Output {
611 self + rhs.neg()
612 }
613 }
614 )+
615 };
616 }
617
618 impl_sub_as_neg_add!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
619}
620
621#[cfg(test)]
622mod tests {
623 use crate::linear_relation::{GroupVar, ScalarTerm, ScalarVar, Term};
624 use core::marker::PhantomData;
625 use curve25519_dalek::RistrettoPoint as G;
626 use curve25519_dalek::Scalar;
627
628 fn scalar_var(i: usize) -> ScalarVar<G> {
629 ScalarVar(i, PhantomData)
630 }
631
632 fn group_var(i: usize) -> GroupVar<G> {
633 GroupVar(i, PhantomData)
634 }
635
#[test]
fn test_scalar_var_addition() {
    let (x, y) = (scalar_var(0), scalar_var(1));

    let sum = x + y;
    let terms = sum.terms();

    assert_eq!(terms.len(), 2);
    assert_eq!(terms[0], x);
    assert_eq!(terms[1], y);
}
646
#[test]
fn test_scalar_var_scalar_addition() {
    let x = scalar_var(0);
    let five = Scalar::from(5u64);

    let sum = x + five;
    let terms = sum.terms();

    assert_eq!(terms.len(), 2);
    let (var_part, constant_part) = (&terms[0], &terms[1]);
    assert_eq!(var_part.term, x.into());
    assert_eq!(var_part.weight, Scalar::ONE);
    assert_eq!(constant_part.term, ScalarTerm::Unit);
    assert_eq!(constant_part.weight, five);
}
658
#[test]
fn test_scalar_var_scalar_addition_mul_group() {
    let x = scalar_var(0);
    let g = group_var(0);
    let five = Scalar::from(5u64);

    let res = (x + five) * g;
    let terms = res.terms();

    assert_eq!(terms.len(), 2);
    let (scaled, offset) = (&terms[0], &terms[1]);
    assert_eq!(scaled.term, Term { scalar: x.into(), elem: g });
    assert_eq!(scaled.weight, Scalar::ONE);
    assert_eq!(offset.term, Term { scalar: ScalarTerm::Unit, elem: g });
    assert_eq!(offset.weight, five);
}
684
#[test]
fn test_group_var_addition() {
    let (g, h) = (group_var(0), group_var(1));

    let sum = g + h;
    let terms = sum.terms();

    assert_eq!(terms.len(), 2);
    assert_eq!(terms[0], g);
    assert_eq!(terms[1], h);
}
695
#[test]
fn test_term_addition() {
    let (x, y) = (scalar_var(0), scalar_var(1));
    let (g, h) = (group_var(0), group_var(1));
    let term1 = Term { scalar: x.into(), elem: g };
    let term2 = Term { scalar: y.into(), elem: h };

    let sum = term1 + term2;
    let terms = sum.terms();

    assert_eq!(terms.len(), 2);
    assert_eq!(terms[0], term1);
    assert_eq!(terms[1], term2);
}
717
#[test]
fn test_term_group_var_addition() {
    let x = scalar_var(0);
    let g = group_var(0);

    let res = (x * g) + g;
    let terms = res.terms();

    assert_eq!(terms.len(), 2);
    assert_eq!(terms[0], Term { scalar: x.into(), elem: g });
    // The bare group variable is promoted to a term with a unit scalar.
    assert_eq!(terms[1], Term { scalar: ScalarTerm::Unit, elem: g });
}
741
#[test]
fn test_scalar_group_multiplication() {
    let x = scalar_var(0);
    let g = group_var(0);

    // Both operand orders must yield the same term.
    for term in &[x * g, g * x] {
        assert_eq!(term.scalar, x.into());
        assert_eq!(term.elem, g);
    }
}
755
#[test]
fn test_scalar_coefficient_multiplication() {
    let x = scalar_var(0);
    let coefficient = Scalar::from(5u64);

    let weighted = x * coefficient;

    assert_eq!(weighted.term, x);
    assert_eq!(weighted.weight, coefficient);
}
764
#[test]
fn test_group_coefficient_multiplication() {
    let g = group_var(0);
    let coefficient = Scalar::from(3u64);

    let weighted = g * coefficient;

    assert_eq!(weighted.term, g);
    assert_eq!(weighted.weight, coefficient);
}
773
#[test]
fn test_term_coefficient_multiplication() {
    let term = Term {
        scalar: scalar_var(0).into(),
        elem: group_var(0),
    };
    let coefficient = Scalar::from(7u64);

    let weighted = term * coefficient;

    assert_eq!(weighted.term, term);
    assert_eq!(weighted.weight, coefficient);
}
787
#[test]
fn test_scalar_var_negation() {
    let x = scalar_var(0);
    let minus_one = -Scalar::ONE;

    let negated = -x;

    assert_eq!(negated.term, x);
    assert_eq!(negated.weight, minus_one);
}
796
#[test]
fn test_group_var_negation() {
    let g = group_var(0);
    let minus_one = -Scalar::ONE;

    let negated = -g;

    assert_eq!(negated.term, g);
    assert_eq!(negated.weight, minus_one);
}
805
#[test]
fn test_term_negation() {
    let term = Term {
        scalar: scalar_var(0).into(),
        elem: group_var(0),
    };

    let negated = -term;

    assert_eq!(negated.term, term);
    assert_eq!(negated.weight, -Scalar::ONE);
}
819
#[test]
fn test_weighted_negation() {
    let x = scalar_var(0);
    let five = Scalar::from(5u64);

    let negated = -(x * five);

    assert_eq!(negated.term, x);
    assert_eq!(negated.weight, -five);
}
829
#[test]
fn test_scalar_var_subtraction() {
    let (x, y) = (scalar_var(0), scalar_var(1));

    let diff = x - y;
    let terms = diff.terms();

    // The negated subtrahend comes first, the minuend second.
    assert_eq!(terms.len(), 2);
    let (negated, kept) = (&terms[0], &terms[1]);
    assert_eq!(negated.term, y);
    assert_eq!(negated.weight, -Scalar::ONE);
    assert_eq!(kept.term, x);
    assert_eq!(kept.weight, Scalar::ONE);
}
842
#[test]
fn test_scalar_var_subtraction_by_scalar() {
    let x = scalar_var(0);

    let diff = x - Scalar::ONE;
    let terms = diff.terms();

    // Here the variable stays first and the negated constant follows it.
    assert_eq!(terms.len(), 2);
    let (variable, constant) = (&terms[0], &terms[1]);
    assert_eq!(variable.term, ScalarTerm::Var(x));
    assert_eq!(variable.weight, Scalar::ONE);
    assert_eq!(constant.term, ScalarTerm::Unit);
    assert_eq!(constant.weight, -Scalar::ONE);
}
854
#[test]
fn test_group_var_subtraction() {
    let (g, h) = (group_var(0), group_var(1));

    let diff = g - h;
    let terms = diff.terms();

    assert_eq!(terms.len(), 2);
    let (negated, kept) = (&terms[0], &terms[1]);
    assert_eq!(negated.term, h);
    assert_eq!(negated.weight, -Scalar::ONE);
    assert_eq!(kept.term, g);
    assert_eq!(kept.weight, Scalar::ONE);
}
867
#[test]
fn test_term_subtraction() {
    let (x, y) = (scalar_var(0), scalar_var(1));
    let (g, h) = (group_var(0), group_var(1));
    let term1 = Term { scalar: x.into(), elem: g };
    let term2 = Term { scalar: y.into(), elem: h };

    let diff = term1 - term2;
    let terms = diff.terms();

    assert_eq!(terms.len(), 2);
    let (negated, kept) = (&terms[0], &terms[1]);
    assert_eq!(negated.term, term2);
    assert_eq!(negated.weight, -Scalar::ONE);
    assert_eq!(kept.term, term1);
    assert_eq!(kept.weight, Scalar::ONE);
}
891
#[test]
fn test_sum_addition_chaining() {
    let vars = [scalar_var(0), scalar_var(1), scalar_var(2)];

    let sum = vars[0] + vars[1] + vars[2];
    let terms = sum.terms();

    assert_eq!(terms.len(), 3);
    for (actual, expected) in terms.iter().zip(vars.iter()) {
        assert_eq!(actual, expected);
    }
}
904
#[test]
fn test_sum_plus_scalar_var() {
    let (x, y, z) = (scalar_var(0), scalar_var(1), scalar_var(2));

    // A lone variable on the left is appended after the existing summands.
    let result = z + (x + y);
    let terms = result.terms();

    let expected = [x, y, z];
    assert_eq!(terms.len(), 3);
    for (actual, wanted) in terms.iter().zip(expected.iter()) {
        assert_eq!(actual, wanted);
    }
}
918
#[test]
fn test_sum_plus_sum() {
    let vars = [scalar_var(0), scalar_var(1), scalar_var(2), scalar_var(3)];

    let result = (vars[0] + vars[1]) + (vars[2] + vars[3]);
    let terms = result.terms();

    assert_eq!(terms.len(), 4);
    for (actual, expected) in terms.iter().zip(vars.iter()) {
        assert_eq!(actual, expected);
    }
}
936
#[test]
fn test_sum_negation() {
    let vars = [scalar_var(0), scalar_var(1)];

    let neg_sum = -(vars[0] + vars[1]);
    let terms = neg_sum.terms();

    assert_eq!(terms.len(), 2);
    for (summand, var) in terms.iter().zip(vars.iter()) {
        assert_eq!(summand.term, *var);
        assert_eq!(summand.weight, -Scalar::ONE);
    }
}
951
#[test]
fn test_weighted_addition() {
    let (x, y) = (scalar_var(0), scalar_var(1));
    let (three, five) = (Scalar::from(3u64), Scalar::from(5u64));

    let sum = x * three + y * five;
    let terms = sum.terms();

    assert_eq!(terms.len(), 2);
    assert_eq!(terms[0].term, x);
    assert_eq!(terms[0].weight, three);
    assert_eq!(terms[1].term, y);
    assert_eq!(terms[1].weight, five);
}
967
#[test]
fn test_weighted_plus_term() {
    let (x, y) = (scalar_var(0), scalar_var(1));
    let two = Scalar::from(2u64);

    let sum = x * two + y;
    let terms = sum.terms();

    // The unweighted right operand is appended with weight one.
    assert_eq!(terms.len(), 2);
    assert_eq!(terms[0].term, x);
    assert_eq!(terms[0].weight, two);
    assert_eq!(terms[1].term, y);
    assert_eq!(terms[1].weight, Scalar::ONE);
}
982
#[test]
fn test_weighted_scalar_multiplication() {
    let x = scalar_var(0);

    // Successive coefficients multiply: 2 * 3 = 6.
    let result = (x * Scalar::from(2u64)) * Scalar::from(3u64);

    assert_eq!(result.term, x);
    assert_eq!(result.weight, Scalar::from(6u64));
}
992
#[test]
fn test_weighted_group_var_times_scalar_var() {
    let x = scalar_var(0);
    let g = group_var(0);
    let five = Scalar::from(5u64);

    let result = x * (g * five);

    assert_eq!(result.term.scalar, x.into());
    assert_eq!(result.term.elem, g);
    assert_eq!(result.weight, five);
}
1005
#[test]
fn test_weighted_scalar_var_times_group_var() {
    let x = scalar_var(0);
    let g = group_var(0);
    let three = Scalar::from(3u64);

    let result = (x * three) * g;

    assert_eq!(result.term.scalar, x.into());
    assert_eq!(result.term.elem, g);
    assert_eq!(result.weight, three);
}
1018
#[test]
fn test_sum_scalar_multiplication_distributive() {
    let vars = [scalar_var(0), scalar_var(1)];
    let two = Scalar::from(2u64);

    let result = (vars[0] + vars[1]) * two;
    let terms = result.terms();

    assert_eq!(terms.len(), 2);
    for (summand, var) in terms.iter().zip(vars.iter()) {
        assert_eq!(summand.term, *var);
        assert_eq!(summand.weight, two);
    }
}
1033
#[test]
fn test_sum_subtraction_distributive() {
    let (x, y, z) = (scalar_var(0), scalar_var(1), scalar_var(2));

    let result = (x + y) - z;
    let terms = result.terms();

    let expected = [(x, Scalar::ONE), (y, Scalar::ONE), (z, -Scalar::ONE)];
    assert_eq!(terms.len(), 3);
    for (summand, (var, weight)) in terms.iter().zip(expected.iter()) {
        assert_eq!(summand.term, *var);
        assert_eq!(summand.weight, *weight);
    }
}
1051
#[test]
fn test_weighted_sum_scalar_multiplication() {
    let (x, y) = (scalar_var(0), scalar_var(1));

    let sum = x * Scalar::from(2u64) + y * Scalar::from(3u64);
    let result = sum * Scalar::from(4u64);
    let terms = result.terms();

    // Scaling a weighted sum multiplies every weight: 2 * 4 = 8 and 3 * 4 = 12.
    assert_eq!(terms.len(), 2);
    assert_eq!(terms[0].term, x);
    assert_eq!(terms[0].weight, Scalar::from(8u64));
    assert_eq!(terms[1].term, y);
    assert_eq!(terms[1].weight, Scalar::from(12u64));
}
1068
#[test]
fn test_pedersen_commitment_expression() {
    let (x, r) = (scalar_var(0), scalar_var(1));
    let (g, h) = (group_var(0), group_var(1));

    let commitment = x * g + r * h;
    let terms = commitment.terms();

    assert_eq!(terms.len(), 2);
    let (message_part, blinding_part) = (&terms[0], &terms[1]);
    assert_eq!(message_part.scalar, x.into());
    assert_eq!(message_part.elem, g);
    assert_eq!(blinding_part.scalar, r.into());
    assert_eq!(blinding_part.elem, h);
}
1083
#[test]
fn test_weighted_pedersen_commitment() {
    let (x, r) = (scalar_var(0), scalar_var(1));
    let (g, h) = (group_var(0), group_var(1));
    let (three, two) = (Scalar::from(3u64), Scalar::from(2u64));

    let commitment = x * g * three + r * h * two;
    let terms = commitment.terms();

    assert_eq!(terms.len(), 2);
    let (first, second) = (&terms[0], &terms[1]);
    assert_eq!(first.term.scalar, x.into());
    assert_eq!(first.term.elem, g);
    assert_eq!(first.weight, three);
    assert_eq!(second.term.scalar, r.into());
    assert_eq!(second.term.elem, h);
    assert_eq!(second.weight, two);
}
1100
/// `a0*G0 + a1*G1 + a2*G2 - a3*G3`: the three added products keep weight one and stay in
/// order, and the subtracted product lands last with weight minus one.
#[test]
fn test_complex_multi_term_expression() {
    let scalars = [scalar_var(0), scalar_var(1), scalar_var(2), scalar_var(3)];
    let groups = [group_var(0), group_var(1), group_var(2), group_var(3)];

    let expr = scalars[0] * groups[0] + scalars[1] * groups[1] + scalars[2] * groups[2]
        - scalars[3] * groups[3];
    let terms = expr.terms();

    assert_eq!(terms.len(), 4);

    // Zip the parallel expectations instead of indexing them (clippy: needless_range_loop).
    let expected_weights = [Scalar::ONE, Scalar::ONE, Scalar::ONE, -Scalar::ONE];
    for (((summand, scalar), group), weight) in terms
        .iter()
        .zip(scalars.iter())
        .zip(groups.iter())
        .zip(expected_weights.iter())
    {
        assert_eq!(summand.term.scalar, (*scalar).into());
        assert_eq!(summand.term.elem, *group);
        assert_eq!(summand.weight, *weight);
    }
}
1121
/// Chaining three weighted products keeps each product's coefficient and the input order.
#[test]
fn test_chained_addition_with_coefficients() {
    let (x, y, z) = (scalar_var(0), scalar_var(1), scalar_var(2));
    let (g, h, k) = (group_var(0), group_var(1), group_var(2));

    let expr =
        x * g * Scalar::from(2u64) + y * h * Scalar::from(3u64) + z * k * Scalar::from(5u64);
    let terms = expr.terms();

    assert_eq!(terms.len(), 3);

    // Iterate over (scalar, group, coefficient) rows instead of indexing three parallel
    // arrays (clippy: needless_range_loop).
    let expected = [(x, g, 2u64), (y, h, 3u64), (z, k, 5u64)];
    for (summand, &(scalar, group, coefficient)) in terms.iter().zip(expected.iter()) {
        assert_eq!(summand.term.scalar, scalar.into());
        assert_eq!(summand.term.elem, group);
        assert_eq!(summand.weight, Scalar::from(coefficient));
    }
}
1145
#[test]
fn test_mixing_sum_term_and_sum_weighted() {
    let (x, y, z) = (scalar_var(0), scalar_var(1), scalar_var(2));
    let (g, h, k) = (group_var(0), group_var(1), group_var(2));

    let basic_sum = x * g + y * h;
    let weighted_term = z * k * Scalar::from(3u64);
    let mixed = basic_sum + weighted_term;
    let terms = mixed.terms();

    let expected = [
        (ScalarTerm::from(x), g, Scalar::ONE),
        (ScalarTerm::from(y), h, Scalar::ONE),
        (ScalarTerm::from(z), k, Scalar::from(3u64)),
    ];
    assert_eq!(terms.len(), 3);
    for (summand, (scalar, elem, weight)) in terms.iter().zip(expected.iter()) {
        assert_eq!(summand.term.scalar, *scalar);
        assert_eq!(summand.term.elem, *elem);
        assert_eq!(summand.weight, *weight);
    }
}
1170
#[test]
fn test_sum_term_plus_weighted_group_var() {
    let (x, y) = (scalar_var(0), scalar_var(1));
    let (g, h) = (group_var(0), group_var(1));

    let sum_term = x * g + y * h;
    let result = sum_term + h * Scalar::from(3u64);
    let terms = result.terms();

    // The weighted group variable is appended last, with a unit scalar.
    let expected = [
        (ScalarTerm::from(x), g, Scalar::ONE),
        (ScalarTerm::from(y), h, Scalar::ONE),
        (ScalarTerm::Unit, h, Scalar::from(3u64)),
    ];
    assert_eq!(terms.len(), 3);
    for (summand, (scalar, elem, weight)) in terms.iter().zip(expected.iter()) {
        assert_eq!(summand.term.scalar, *scalar);
        assert_eq!(summand.term.elem, *elem);
        assert_eq!(summand.weight, *weight);
    }
}
1193
#[test]
fn test_weighted_group_var_plus_sum_term() {
    let (x, y) = (scalar_var(0), scalar_var(1));
    let (g, h) = (group_var(0), group_var(1));

    let sum_term = x * g + y * h;
    let result = h * Scalar::from(5u64) + sum_term;
    let terms = result.terms();

    // Even on the left, the weighted group variable ends up after the sum's summands.
    let expected = [
        (ScalarTerm::from(x), g, Scalar::ONE),
        (ScalarTerm::from(y), h, Scalar::ONE),
        (ScalarTerm::Unit, h, Scalar::from(5u64)),
    ];
    assert_eq!(terms.len(), 3);
    for (summand, (scalar, elem, weight)) in terms.iter().zip(expected.iter()) {
        assert_eq!(summand.term.scalar, *scalar);
        assert_eq!(summand.term.elem, *elem);
        assert_eq!(summand.weight, *weight);
    }
}
1216
#[test]
fn test_sum_weighted_group_var_plus_term() {
    let x = scalar_var(0);
    let (g, h) = (group_var(0), group_var(1));

    let sum_weighted = g * Scalar::from(2u64) + h * Scalar::from(3u64);
    let result = sum_weighted + x * g;
    let terms = result.terms();

    let expected = [
        (ScalarTerm::Unit, g, Scalar::from(2u64)),
        (ScalarTerm::Unit, h, Scalar::from(3u64)),
        (ScalarTerm::from(x), g, Scalar::ONE),
    ];
    assert_eq!(terms.len(), 3);
    for (summand, (scalar, elem, weight)) in terms.iter().zip(expected.iter()) {
        assert_eq!(summand.term.scalar, *scalar);
        assert_eq!(summand.term.elem, *elem);
        assert_eq!(summand.weight, *weight);
    }
}
1238
/// `Sum<Weighted<GroupVar>> + Sum<Term>`: the result lists the weighted
/// group-variable entries first, followed by the terms with weight one.
#[test]
fn test_sum_weighted_group_var_plus_sum_term() {
    let x = scalar_var(0);
    let y = scalar_var(1);
    let g = group_var(0);
    let h = group_var(1);

    let weighted_sum = g * Scalar::from(2u64) + h * Scalar::from(3u64);
    let term_sum = x * g + y * h;
    let combined = weighted_sum + term_sum;
    let terms = combined.terms();

    assert_eq!(terms.len(), 4);

    // Entries coming from the left (weighted) operand.
    assert_eq!(terms[0].term.scalar, ScalarTerm::Unit);
    assert_eq!(terms[0].term.elem, g);
    assert_eq!(terms[0].weight, Scalar::from(2u64));
    assert_eq!(terms[1].term.scalar, ScalarTerm::Unit);
    assert_eq!(terms[1].term.elem, h);
    assert_eq!(terms[1].weight, Scalar::from(3u64));

    // Entries coming from the right (`Sum<Term>`) operand.
    assert_eq!(terms[2].term.scalar, x.into());
    assert_eq!(terms[2].term.elem, g);
    assert_eq!(terms[2].weight, Scalar::ONE);
    assert_eq!(terms[3].term.scalar, y.into());
    assert_eq!(terms[3].term.elem, h);
    assert_eq!(terms[3].weight, Scalar::ONE);
}
1264
/// `Term + Sum<Weighted<GroupVar>>`: although the term is the left operand,
/// the assertions pin it to the END of the result, after the weighted entries.
#[test]
fn test_term_plus_sum_weighted_group_var() {
    let x = scalar_var(0);
    let g = group_var(0);
    let h = group_var(1);

    let single = x * g;
    let weighted_sum = g * Scalar::from(2u64) + h * Scalar::from(3u64);
    let combined = single + weighted_sum;
    let terms = combined.terms();

    assert_eq!(terms.len(), 3);

    // Weighted group-variable entries come first.
    assert_eq!(terms[0].term.scalar, ScalarTerm::Unit);
    assert_eq!(terms[0].term.elem, g);
    assert_eq!(terms[0].weight, Scalar::from(2u64));
    assert_eq!(terms[1].term.scalar, ScalarTerm::Unit);
    assert_eq!(terms[1].term.elem, h);
    assert_eq!(terms[1].weight, Scalar::from(3u64));

    // The left-hand term is appended last with weight one.
    assert_eq!(terms[2].term.scalar, x.into());
    assert_eq!(terms[2].term.elem, g);
    assert_eq!(terms[2].weight, Scalar::ONE);
}
1286
/// `Sum<Term> + Sum<Weighted<GroupVar>>`: the assertions pin the weighted
/// entries (right operand) first, followed by the left operand's terms with
/// weight one.
#[test]
fn test_sum_term_plus_sum_weighted_group_var() {
    let x = scalar_var(0);
    let y = scalar_var(1);
    let g = group_var(0);
    let h = group_var(1);

    let term_sum = x * g + y * h;
    let weighted_sum = g * Scalar::from(2u64) + h * Scalar::from(3u64);
    let combined = term_sum + weighted_sum;
    let terms = combined.terms();

    assert_eq!(terms.len(), 4);

    // Weighted group-variable entries come first.
    assert_eq!(terms[0].term.scalar, ScalarTerm::Unit);
    assert_eq!(terms[0].term.elem, g);
    assert_eq!(terms[0].weight, Scalar::from(2u64));
    assert_eq!(terms[1].term.scalar, ScalarTerm::Unit);
    assert_eq!(terms[1].term.elem, h);
    assert_eq!(terms[1].weight, Scalar::from(3u64));

    // Terms from the left-hand `Sum<Term>` follow, each with weight one.
    assert_eq!(terms[2].term.scalar, x.into());
    assert_eq!(terms[2].term.elem, g);
    assert_eq!(terms[2].weight, Scalar::ONE);
    assert_eq!(terms[3].term.scalar, y.into());
    assert_eq!(terms[3].term.elem, h);
    assert_eq!(terms[3].weight, Scalar::ONE);
}
1312
/// `x * b + b * (-1)`: a term plus a negatively weighted group variable keeps
/// the term first (weight one) and the group variable second with weight `-1`.
#[test]
fn test_scalar_var_minus_scalar_times_group() {
    let x = scalar_var(0);
    let b = group_var(0);

    let minus_one = -Scalar::ONE;
    let combined = x * b + b * minus_one;
    let terms = combined.terms();

    assert_eq!(terms.len(), 2);

    assert_eq!(terms[0].term.scalar, x.into());
    assert_eq!(terms[0].term.elem, b);
    assert_eq!(terms[0].weight, Scalar::ONE);

    assert_eq!(terms[1].term.scalar, ScalarTerm::Unit);
    assert_eq!(terms[1].term.elem, b);
    assert_eq!(terms[1].weight, -Scalar::ONE);
}
1330
/// `a * 1 + x_r * b` (a `Weighted<GroupVar>` plus a `Term`).
///
/// The assertions pin the term to the FRONT of the result. The weighted group
/// variable comes second, as a `Unit`-scalar entry.
#[test]
fn test_group_var_times_scalar_plus_scalar_times_group() {
    // Previously named `gen__disj1_x_r`. The double underscore trips rustc's
    // `non_snake_case` lint, which fails `-D warnings` builds.
    let disj1_x_r = scalar_var(0);
    let a = group_var(0);
    let b = group_var(1);

    let result = a * Scalar::ONE + disj1_x_r * b;
    let terms = result.terms();

    assert_eq!(terms.len(), 2);

    assert_eq!(terms[0].term.scalar, disj1_x_r.into());
    assert_eq!(terms[0].term.elem, b);
    assert_eq!(terms[0].weight, Scalar::ONE);

    assert_eq!(terms[1].term.scalar, ScalarTerm::Unit);
    assert_eq!(terms[1].term.elem, a);
    assert_eq!(terms[1].weight, Scalar::ONE);
}
1349}