Skip to main content

sigma_proofs/linear_relation/
ops.rs

1use alloc::vec;
2use alloc::vec::Vec;
3use core::ops::{Add, Mul, Neg, Sub};
4use ff::Field;
5use group::Group;
6
7use super::{GroupVar, ScalarTerm, ScalarVar, Sum, Term, Weighted};
8
9mod add {
10    use super::*;
11
12    macro_rules! impl_add_term {
13        ($($type:ty),+) => {
14            $(
15            impl<G> Add<$type> for $type {
16                type Output = Sum<$type>;
17
18                fn add(self, rhs: $type) -> Self::Output {
19                    Sum(vec![self, rhs])
20                }
21            }
22            )+
23        };
24    }
25
26    impl_add_term!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
27
28    impl<T> Add<T> for Sum<T> {
29        type Output = Sum<T>;
30
31        fn add(mut self, rhs: T) -> Self::Output {
32            self.0.push(rhs);
33            self
34        }
35    }
36
37    macro_rules! impl_add_sum_term {
38        ($($type:ty),+) => {
39            $(
40            impl<G> Add<Sum<$type>> for $type {
41                type Output = Sum<$type>;
42
43                fn add(self, rhs: Sum<$type>) -> Self::Output {
44                    rhs + self
45                }
46            }
47            )+
48        };
49    }
50
51    impl_add_sum_term!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
52
53    impl<T> Add<Sum<T>> for Sum<T> {
54        type Output = Sum<T>;
55
56        fn add(mut self, rhs: Sum<T>) -> Self::Output {
57            self.0.extend(rhs.0);
58            self
59        }
60    }
61
62    impl<T, F> Add<Weighted<T, F>> for Weighted<T, F> {
63        type Output = Sum<Weighted<T, F>>;
64
65        fn add(self, rhs: Weighted<T, F>) -> Self::Output {
66            Sum(vec![self, rhs])
67        }
68    }
69
70    impl<T, F: Field> Add<T> for Weighted<T, F> {
71        type Output = Sum<Weighted<T, F>>;
72
73        fn add(self, rhs: T) -> Self::Output {
74            Sum(vec![self, rhs.into()])
75        }
76    }
77
78    macro_rules! impl_add_weighted_term {
79        ($($type:ty),+) => {
80            $(
81            impl<G: Group> Add<Weighted<$type, G::Scalar>> for $type {
82                type Output = Sum<Weighted<$type, G::Scalar>>;
83
84                fn add(self, rhs: Weighted<$type, G::Scalar>) -> Self::Output {
85                    rhs + self
86                }
87            }
88            )+
89        };
90    }
91
92    impl_add_weighted_term!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
93
94    impl<T, F: Field> Add<T> for Sum<Weighted<T, F>> {
95        type Output = Sum<Weighted<T, F>>;
96
97        fn add(mut self, rhs: T) -> Self::Output {
98            self.0.push(rhs.into());
99            self
100        }
101    }
102
103    macro_rules! impl_add_weighted_sum_term {
104        ($($type:ty),+) => {
105            $(
106            impl<G: Group> Add<Sum<Weighted<$type, G::Scalar>>> for $type {
107                type Output = Sum<Weighted<$type, G::Scalar>>;
108
109                fn add(self, rhs: Sum<Weighted<$type, G::Scalar>>) -> Self::Output {
110                    rhs + self
111                }
112            }
113            )+
114        };
115    }
116
117    impl_add_weighted_sum_term!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
118
119    impl<T, F: Field> Add<Sum<T>> for Sum<Weighted<T, F>> {
120        type Output = Sum<Weighted<T, F>>;
121
122        fn add(self, rhs: Sum<T>) -> Self::Output {
123            self + Self::from(rhs)
124        }
125    }
126
127    impl<T, F: Field> Add<Sum<Weighted<T, F>>> for Sum<T> {
128        type Output = Sum<Weighted<T, F>>;
129
130        fn add(self, rhs: Sum<Weighted<T, F>>) -> Self::Output {
131            rhs + self
132        }
133    }
134
135    impl<T, F: Field> Add<Weighted<T, F>> for Sum<T> {
136        type Output = Sum<Weighted<T, F>>;
137
138        fn add(self, rhs: Weighted<T, F>) -> Self::Output {
139            Self::Output::from(self) + rhs
140        }
141    }
142
143    impl<T, F: Field> Add<Sum<T>> for Weighted<T, F> {
144        type Output = Sum<Weighted<T, F>>;
145
146        fn add(self, rhs: Sum<T>) -> Self::Output {
147            rhs + self
148        }
149    }
150
151    impl<T, F: Field> Add<Sum<Weighted<T, F>>> for Weighted<T, F> {
152        type Output = Sum<Weighted<T, F>>;
153
154        fn add(self, rhs: Sum<Weighted<T, F>>) -> Self::Output {
155            rhs + self
156        }
157    }
158
159    impl<G> Add<ScalarVar<G>> for ScalarTerm<G> {
160        type Output = Sum<ScalarTerm<G>>;
161
162        fn add(self, rhs: ScalarVar<G>) -> Self::Output {
163            self + ScalarTerm::from(rhs)
164        }
165    }
166
167    impl<G> Add<ScalarTerm<G>> for ScalarVar<G> {
168        type Output = Sum<ScalarTerm<G>>;
169
170        fn add(self, rhs: ScalarTerm<G>) -> Self::Output {
171            rhs + self
172        }
173    }
174
175    impl<T: Field + Into<G::Scalar>, G: Group> Add<T> for Weighted<ScalarTerm<G>, G::Scalar> {
176        type Output = Sum<Weighted<ScalarTerm<G>, G::Scalar>>;
177
178        fn add(self, rhs: T) -> Self::Output {
179            self + Self::from(rhs.into())
180        }
181    }
182
183    impl<T: Field + Into<G::Scalar>, G: Group> Add<T> for Weighted<ScalarVar<G>, G::Scalar> {
184        type Output = Sum<Weighted<ScalarTerm<G>, G::Scalar>>;
185
186        fn add(self, rhs: T) -> Self::Output {
187            <Weighted<ScalarTerm<G>, G::Scalar>>::from(self) + rhs.into()
188        }
189    }
190
191    impl<T: Field + Into<G::Scalar>, G: Group> Add<T> for ScalarVar<G> {
192        type Output = Sum<Weighted<ScalarTerm<G>, G::Scalar>>;
193
194        fn add(self, rhs: T) -> Self::Output {
195            Weighted::from(ScalarTerm::from(self)) + rhs.into()
196        }
197    }
198
199    impl<G: Group> Add<GroupVar<G>> for Sum<Weighted<Term<G>, G::Scalar>> {
200        type Output = Sum<Weighted<Term<G>, G::Scalar>>;
201
202        fn add(self, rhs: GroupVar<G>) -> Self::Output {
203            self + Self::from(rhs)
204        }
205    }
206
207    impl<G: Group> Add<GroupVar<G>> for Sum<Term<G>> {
208        type Output = Sum<Term<G>>;
209
210        fn add(self, rhs: GroupVar<G>) -> Self::Output {
211            self + Self::from(rhs)
212        }
213    }
214
215    impl<G: Group> Add<GroupVar<G>> for Weighted<Term<G>, G::Scalar> {
216        type Output = Sum<Weighted<Term<G>, G::Scalar>>;
217
218        fn add(self, rhs: GroupVar<G>) -> Self::Output {
219            self + Self::from(rhs)
220        }
221    }
222
223    impl<G: Group> Add<GroupVar<G>> for Term<G> {
224        type Output = Sum<Term<G>>;
225
226        fn add(self, rhs: GroupVar<G>) -> Self::Output {
227            self + Self::from(rhs)
228        }
229    }
230
231    impl<G: Group> Add<Weighted<GroupVar<G>, G::Scalar>> for Term<G> {
232        type Output = Sum<Weighted<Term<G>, G::Scalar>>;
233
234        fn add(self, rhs: Weighted<GroupVar<G>, G::Scalar>) -> Self::Output {
235            Sum(vec![
236                Weighted {
237                    term: self,
238                    weight: G::Scalar::ONE,
239                },
240                Weighted {
241                    term: Term {
242                        scalar: super::ScalarTerm::Unit,
243                        elem: rhs.term,
244                    },
245                    weight: rhs.weight,
246                },
247            ])
248        }
249    }
250
251    impl<G: Group> Add<Weighted<GroupVar<G>, G::Scalar>> for Sum<Term<G>> {
252        type Output = Sum<Weighted<Term<G>, G::Scalar>>;
253
254        fn add(self, rhs: Weighted<GroupVar<G>, G::Scalar>) -> Self::Output {
255            Sum::<Weighted<Term<G>, G::Scalar>>::from(self) + rhs
256        }
257    }
258
259    impl<G: Group> Add<Sum<Term<G>>> for Weighted<GroupVar<G>, G::Scalar> {
260        type Output = Sum<Weighted<Term<G>, G::Scalar>>;
261
262        fn add(self, rhs: Sum<Term<G>>) -> Self::Output {
263            rhs + self
264        }
265    }
266
267    impl<G: Group> Add<Sum<Weighted<Term<G>, G::Scalar>>> for Weighted<GroupVar<G>, G::Scalar> {
268        type Output = Sum<Weighted<Term<G>, G::Scalar>>;
269
270        fn add(self, rhs: Sum<Weighted<Term<G>, G::Scalar>>) -> Self::Output {
271            let weighted_term = Weighted {
272                term: Term {
273                    scalar: super::ScalarTerm::Unit,
274                    elem: self.term,
275                },
276                weight: self.weight,
277            };
278            rhs + weighted_term
279        }
280    }
281
282    impl<G: Group> Add<Term<G>> for Sum<Weighted<GroupVar<G>, G::Scalar>> {
283        type Output = Sum<Weighted<Term<G>, G::Scalar>>;
284
285        fn add(self, rhs: Term<G>) -> Self::Output {
286            Sum::<Weighted<Term<G>, G::Scalar>>::from(self) + rhs
287        }
288    }
289
290    impl<G: Group> Add<Sum<Term<G>>> for Sum<Weighted<GroupVar<G>, G::Scalar>> {
291        type Output = Sum<Weighted<Term<G>, G::Scalar>>;
292
293        fn add(self, rhs: Sum<Term<G>>) -> Self::Output {
294            Sum::<Weighted<Term<G>, G::Scalar>>::from(self) + rhs
295        }
296    }
297
298    impl<G: Group> Add<Sum<Weighted<GroupVar<G>, G::Scalar>>> for Term<G> {
299        type Output = Sum<Weighted<Term<G>, G::Scalar>>;
300
301        fn add(self, rhs: Sum<Weighted<GroupVar<G>, G::Scalar>>) -> Self::Output {
302            Sum::<Weighted<Term<G>, G::Scalar>>::from(rhs) + self
303        }
304    }
305
306    impl<G: Group> Add<Sum<Weighted<GroupVar<G>, G::Scalar>>> for Sum<Term<G>> {
307        type Output = Sum<Weighted<Term<G>, G::Scalar>>;
308
309        fn add(self, rhs: Sum<Weighted<GroupVar<G>, G::Scalar>>) -> Self::Output {
310            self + Sum::<Weighted<Term<G>, G::Scalar>>::from(rhs)
311        }
312    }
313
314    impl<G: Group> Add<Weighted<GroupVar<G>, G::Scalar>> for Sum<Weighted<Term<G>, G::Scalar>> {
315        type Output = Sum<Weighted<Term<G>, G::Scalar>>;
316
317        fn add(mut self, rhs: Weighted<GroupVar<G>, G::Scalar>) -> Self::Output {
318            self.0.push(Weighted {
319                term: Term {
320                    scalar: super::ScalarTerm::Unit,
321                    elem: rhs.term,
322                },
323                weight: rhs.weight,
324            });
325            self
326        }
327    }
328
329    impl<G: Group> Add<Term<G>> for Weighted<GroupVar<G>, G::Scalar> {
330        type Output = Sum<Weighted<Term<G>, G::Scalar>>;
331
332        fn add(self, rhs: Term<G>) -> Self::Output {
333            rhs + self
334        }
335    }
336
337    impl<T: Field + Into<G::Scalar>, G: Group> Add<T> for Sum<ScalarVar<G>> {
338        type Output = Sum<Weighted<ScalarTerm<G>, G::Scalar>>;
339
340        fn add(self, rhs: T) -> Self::Output {
341            // Convert Sum<ScalarVar<G>> to Sum<Weighted<ScalarTerm<G>, G::Scalar>>
342            let mut weighted_terms = Vec::new();
343            for var in self.0 {
344                weighted_terms.push(Weighted {
345                    term: ScalarTerm::from(var),
346                    weight: G::Scalar::ONE,
347                });
348            }
349            let weighted_sum: Sum<Weighted<ScalarTerm<G>, G::Scalar>> = Sum(weighted_terms);
350
351            // Convert the scalar to a weighted term
352            let weighted_scalar = Weighted {
353                term: ScalarTerm::Unit,
354                weight: rhs.into(),
355            };
356
357            weighted_sum + weighted_scalar
358        }
359    }
360}
361
362mod mul {
363    use super::*;
364
365    impl<G> Mul<ScalarVar<G>> for GroupVar<G> {
366        type Output = Term<G>;
367
368        /// Multiply a [ScalarVar] by a [GroupVar] to form a new [Term].
369        fn mul(self, rhs: ScalarVar<G>) -> Term<G> {
370            Term {
371                elem: self,
372                scalar: rhs.into(),
373            }
374        }
375    }
376
377    impl<G> Mul<GroupVar<G>> for ScalarVar<G> {
378        type Output = Term<G>;
379
380        /// Multiply a [ScalarVar] by a [GroupVar] to form a new [Term].
381        fn mul(self, rhs: GroupVar<G>) -> Term<G> {
382            rhs * self
383        }
384    }
385
386    impl<G> Mul<ScalarTerm<G>> for GroupVar<G> {
387        type Output = Term<G>;
388
389        fn mul(self, rhs: ScalarTerm<G>) -> Term<G> {
390            Term {
391                elem: self,
392                scalar: rhs,
393            }
394        }
395    }
396
397    impl<G> Mul<GroupVar<G>> for ScalarTerm<G> {
398        type Output = Term<G>;
399
400        fn mul(self, rhs: GroupVar<G>) -> Term<G> {
401            rhs * self
402        }
403    }
404
405    impl<Rhs: Clone, Lhs: Mul<Rhs>> Mul<Rhs> for Sum<Lhs> {
406        type Output = Sum<<Lhs as Mul<Rhs>>::Output>;
407
408        /// Multiplication of the sum by a term, implemented as a general distributive property.
409        fn mul(self, rhs: Rhs) -> Self::Output {
410            Sum(self.0.into_iter().map(|x| x * rhs.clone()).collect())
411        }
412    }
413
414    // NOTE: Rust forbids implementation of foreign traits (e.g. Mul) over bare generic types (e.g. F:
415    // Field). It can be implemented over specific types (e.g. curve25519_dalek::Scalar or u64). As a
416    // result, this generic implements `var * scalar`, but not `scalar * var`.
417
418    macro_rules! impl_scalar_mul_term {
419        ($($type:ty),+) => {
420            $(
421            // NOTE: Rust does not like this impl when F is replaced by G::Scalar.
422            impl<F: Field + Into<G::Scalar>, G: Group> Mul<F> for $type {
423                type Output = Weighted<$type, G::Scalar>;
424
425                fn mul(self, rhs: F) -> Self::Output {
426                    Weighted {
427                        term: self,
428                        weight: rhs.into(),
429                    }
430                }
431            }
432            )+
433        };
434    }
435
436    impl_scalar_mul_term!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
437
438    impl<T, F: Field> Mul<F> for Weighted<T, F> {
439        type Output = Weighted<T, F>;
440
441        fn mul(self, rhs: F) -> Self::Output {
442            Weighted {
443                term: self.term,
444                weight: self.weight * rhs,
445            }
446        }
447    }
448
449    impl<G: Group> Mul<ScalarVar<G>> for Weighted<GroupVar<G>, G::Scalar> {
450        type Output = Weighted<Term<G>, G::Scalar>;
451
452        fn mul(self, rhs: ScalarVar<G>) -> Self::Output {
453            Weighted {
454                term: self.term * rhs,
455                weight: self.weight,
456            }
457        }
458    }
459
460    impl<G: Group> Mul<Weighted<GroupVar<G>, G::Scalar>> for ScalarVar<G> {
461        type Output = Weighted<Term<G>, G::Scalar>;
462
463        fn mul(self, rhs: Weighted<GroupVar<G>, G::Scalar>) -> Self::Output {
464            rhs * self
465        }
466    }
467
468    impl<G: Group> Mul<GroupVar<G>> for Weighted<ScalarVar<G>, G::Scalar> {
469        type Output = Weighted<Term<G>, G::Scalar>;
470
471        fn mul(self, rhs: GroupVar<G>) -> Self::Output {
472            Weighted {
473                term: self.term * rhs,
474                weight: self.weight,
475            }
476        }
477    }
478
479    impl<G: Group> Mul<Weighted<ScalarVar<G>, G::Scalar>> for GroupVar<G> {
480        type Output = Weighted<Term<G>, G::Scalar>;
481
482        fn mul(self, rhs: Weighted<ScalarVar<G>, G::Scalar>) -> Self::Output {
483            rhs * self
484        }
485    }
486
487    impl<G: Group> Mul<ScalarTerm<G>> for Weighted<GroupVar<G>, G::Scalar> {
488        type Output = Weighted<Term<G>, G::Scalar>;
489
490        fn mul(self, rhs: ScalarTerm<G>) -> Self::Output {
491            Weighted {
492                term: self.term * rhs,
493                weight: self.weight,
494            }
495        }
496    }
497
498    impl<G: Group> Mul<Weighted<GroupVar<G>, G::Scalar>> for ScalarTerm<G> {
499        type Output = Weighted<Term<G>, G::Scalar>;
500
501        fn mul(self, rhs: Weighted<GroupVar<G>, G::Scalar>) -> Self::Output {
502            rhs * self
503        }
504    }
505
506    impl<G: Group> Mul<GroupVar<G>> for Weighted<ScalarTerm<G>, G::Scalar> {
507        type Output = Weighted<Term<G>, G::Scalar>;
508
509        fn mul(self, rhs: GroupVar<G>) -> Self::Output {
510            Weighted {
511                term: self.term * rhs,
512                weight: self.weight,
513            }
514        }
515    }
516
517    impl<G: Group> Mul<Weighted<ScalarTerm<G>, G::Scalar>> for GroupVar<G> {
518        type Output = Weighted<Term<G>, G::Scalar>;
519
520        fn mul(self, rhs: Weighted<ScalarTerm<G>, G::Scalar>) -> Self::Output {
521            rhs * self
522        }
523    }
524}
525
526mod neg {
527    use super::*;
528
529    impl<T: Neg> Neg for Sum<T> {
530        type Output = Sum<<T as Neg>::Output>;
531
532        /// Negation a sum, implemented as a general distributive property.
533        fn neg(self) -> Self::Output {
534            Sum(self.0.into_iter().map(|x| x.neg()).collect())
535        }
536    }
537
538    impl<T, F: Field> Neg for Weighted<T, F> {
539        type Output = Weighted<T, F>;
540
541        /// Negation of a weighted term, implemented as negation of its weight.
542        fn neg(self) -> Self::Output {
543            Weighted {
544                term: self.term,
545                weight: -self.weight,
546            }
547        }
548    }
549
550    macro_rules! impl_neg_term {
551        ($($type:ty),+) => {
552            $(
553            impl<G: Group> Neg for $type {
554                type Output = Weighted<$type, G::Scalar>;
555
556                fn neg(self) -> Self::Output {
557                    Weighted {
558                        term: self,
559                        weight: -G::Scalar::ONE,
560                    }
561                }
562            }
563            )+
564        };
565    }
566
567    impl_neg_term!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
568}
569
570mod sub {
571    use super::*;
572
573    impl<T, Rhs> Sub<Rhs> for Sum<T>
574    where
575        Rhs: Neg,
576        <Rhs as Neg>::Output: Add<Self>,
577    {
578        type Output = <<Rhs as Neg>::Output as Add<Self>>::Output;
579
580        #[allow(clippy::suspicious_arithmetic_impl)]
581        fn sub(self, rhs: Rhs) -> Self::Output {
582            rhs.neg() + self
583        }
584    }
585
586    impl<T, F, Rhs> Sub<Rhs> for Weighted<T, F>
587    where
588        Rhs: Neg,
589        <Rhs as Neg>::Output: Add<Self>,
590    {
591        type Output = <<Rhs as Neg>::Output as Add<Self>>::Output;
592
593        #[allow(clippy::suspicious_arithmetic_impl)]
594        fn sub(self, rhs: Rhs) -> Self::Output {
595            rhs.neg() + self
596        }
597    }
598
599    macro_rules! impl_sub_as_neg_add {
600        ($($type:ty),+) => {
601            $(
602            impl<G, Rhs> Sub<Rhs> for $type
603            where
604                Rhs: Neg,
605                Self: Add<<Rhs as Neg>::Output>,
606            {
607                type Output = <Self as Add<<Rhs as Neg>::Output>>::Output;
608
609                #[allow(clippy::suspicious_arithmetic_impl)]
610                fn sub(self, rhs: Rhs) -> Self::Output {
611                    self + rhs.neg()
612                }
613            }
614            )+
615        };
616    }
617
618    impl_sub_as_neg_add!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
619}
620
621#[cfg(test)]
622mod tests {
623    use crate::linear_relation::{GroupVar, ScalarTerm, ScalarVar, Term};
624    use core::marker::PhantomData;
625    use curve25519_dalek::RistrettoPoint as G;
626    use curve25519_dalek::Scalar;
627
628    fn scalar_var(i: usize) -> ScalarVar<G> {
629        ScalarVar(i, PhantomData)
630    }
631
632    fn group_var(i: usize) -> GroupVar<G> {
633        GroupVar(i, PhantomData)
634    }
635
636    #[test]
637    fn test_scalar_var_addition() {
638        let x = scalar_var(0);
639        let y = scalar_var(1);
640
641        let sum = x + y;
642        assert_eq!(sum.terms().len(), 2);
643        assert_eq!(sum.terms()[0], x);
644        assert_eq!(sum.terms()[1], y);
645    }
646
647    #[test]
648    fn test_scalar_var_scalar_addition() {
649        let x = scalar_var(0);
650
651        let sum = x + Scalar::from(5u64);
652        assert_eq!(sum.terms().len(), 2);
653        assert_eq!(sum.terms()[0].term, x.into());
654        assert_eq!(sum.terms()[0].weight, Scalar::ONE);
655        assert_eq!(sum.terms()[1].term, ScalarTerm::Unit);
656        assert_eq!(sum.terms()[1].weight, Scalar::from(5u64));
657    }
658
659    #[test]
660    fn test_scalar_var_scalar_addition_mul_group() {
661        let x = scalar_var(0);
662        let g = group_var(0);
663
664        let res = (x + Scalar::from(5u64)) * g;
665
666        assert_eq!(res.terms().len(), 2);
667        assert_eq!(
668            res.terms()[0].term,
669            Term {
670                scalar: x.into(),
671                elem: g
672            }
673        );
674        assert_eq!(res.terms()[0].weight, Scalar::ONE);
675        assert_eq!(
676            res.terms()[1].term,
677            Term {
678                scalar: ScalarTerm::Unit,
679                elem: g
680            }
681        );
682        assert_eq!(res.terms()[1].weight, Scalar::from(5u64));
683    }
684
685    #[test]
686    fn test_group_var_addition() {
687        let g = group_var(0);
688        let h = group_var(1);
689
690        let sum = g + h;
691        assert_eq!(sum.terms().len(), 2);
692        assert_eq!(sum.terms()[0], g);
693        assert_eq!(sum.terms()[1], h);
694    }
695
696    #[test]
697    fn test_term_addition() {
698        let x = scalar_var(0);
699        let g = group_var(0);
700        let y = scalar_var(1);
701        let h = group_var(1);
702
703        let term1 = Term {
704            scalar: x.into(),
705            elem: g,
706        };
707        let term2 = Term {
708            scalar: y.into(),
709            elem: h,
710        };
711
712        let sum = term1 + term2;
713        assert_eq!(sum.terms().len(), 2);
714        assert_eq!(sum.terms()[0], term1);
715        assert_eq!(sum.terms()[1], term2);
716    }
717
718    #[test]
719    fn test_term_group_var_addition() {
720        let x = scalar_var(0);
721        let g = group_var(0);
722
723        let res = (x * g) + g;
724
725        assert_eq!(res.terms().len(), 2);
726        assert_eq!(
727            res.terms()[0],
728            Term {
729                scalar: x.into(),
730                elem: g
731            }
732        );
733        assert_eq!(
734            res.terms()[1],
735            Term {
736                scalar: ScalarTerm::Unit,
737                elem: g
738            }
739        );
740    }
741
742    #[test]
743    fn test_scalar_group_multiplication() {
744        let x = scalar_var(0);
745        let g = group_var(0);
746
747        let term1 = x * g;
748        let term2 = g * x;
749
750        assert_eq!(term1.scalar, x.into());
751        assert_eq!(term1.elem, g);
752        assert_eq!(term2.scalar, x.into());
753        assert_eq!(term2.elem, g);
754    }
755
756    #[test]
757    fn test_scalar_coefficient_multiplication() {
758        let x = scalar_var(0);
759        let weighted = x * Scalar::from(5u64);
760
761        assert_eq!(weighted.term, x);
762        assert_eq!(weighted.weight, Scalar::from(5u64));
763    }
764
765    #[test]
766    fn test_group_coefficient_multiplication() {
767        let g = group_var(0);
768        let weighted = g * Scalar::from(3u64);
769
770        assert_eq!(weighted.term, g);
771        assert_eq!(weighted.weight, Scalar::from(3u64));
772    }
773
774    #[test]
775    fn test_term_coefficient_multiplication() {
776        let x = scalar_var(0);
777        let g = group_var(0);
778        let term = Term {
779            scalar: x.into(),
780            elem: g,
781        };
782        let weighted = term * Scalar::from(7u64);
783
784        assert_eq!(weighted.term, term);
785        assert_eq!(weighted.weight, Scalar::from(7u64));
786    }
787
788    #[test]
789    fn test_scalar_var_negation() {
790        let x = scalar_var(0);
791        let neg_x = -x;
792
793        assert_eq!(neg_x.term, x);
794        assert_eq!(neg_x.weight, -Scalar::ONE);
795    }
796
797    #[test]
798    fn test_group_var_negation() {
799        let g = group_var(0);
800        let neg_g = -g;
801
802        assert_eq!(neg_g.term, g);
803        assert_eq!(neg_g.weight, -Scalar::ONE);
804    }
805
806    #[test]
807    fn test_term_negation() {
808        let x = scalar_var(0);
809        let g = group_var(0);
810        let term = Term {
811            scalar: x.into(),
812            elem: g,
813        };
814        let neg_term = -term;
815
816        assert_eq!(neg_term.term, term);
817        assert_eq!(neg_term.weight, -Scalar::ONE);
818    }
819
820    #[test]
821    fn test_weighted_negation() {
822        let x = scalar_var(0);
823        let weighted = x * Scalar::from(5u64);
824        let neg_weighted = -weighted;
825
826        assert_eq!(neg_weighted.term, x);
827        assert_eq!(neg_weighted.weight, -Scalar::from(5u64));
828    }
829
830    #[test]
831    fn test_scalar_var_subtraction() {
832        let x = scalar_var(0);
833        let y = scalar_var(1);
834
835        let diff = x - y;
836        assert_eq!(diff.terms().len(), 2);
837        assert_eq!(diff.terms()[0].term, y);
838        assert_eq!(diff.terms()[0].weight, -Scalar::ONE);
839        assert_eq!(diff.terms()[1].term, x);
840        assert_eq!(diff.terms()[1].weight, Scalar::ONE);
841    }
842
843    #[test]
844    fn test_scalar_var_subtraction_by_scalar() {
845        let x = scalar_var(0);
846
847        let diff = x - Scalar::ONE;
848        assert_eq!(diff.terms().len(), 2);
849        assert_eq!(diff.terms()[0].term, ScalarTerm::Var(x));
850        assert_eq!(diff.terms()[0].weight, Scalar::ONE);
851        assert_eq!(diff.terms()[1].term, ScalarTerm::Unit);
852        assert_eq!(diff.terms()[1].weight, -Scalar::ONE);
853    }
854
855    #[test]
856    fn test_group_var_subtraction() {
857        let g = group_var(0);
858        let h = group_var(1);
859
860        let diff = g - h;
861        assert_eq!(diff.terms().len(), 2);
862        assert_eq!(diff.terms()[0].term, h);
863        assert_eq!(diff.terms()[0].weight, -Scalar::ONE);
864        assert_eq!(diff.terms()[1].term, g);
865        assert_eq!(diff.terms()[1].weight, Scalar::ONE);
866    }
867
868    #[test]
869    fn test_term_subtraction() {
870        let x = scalar_var(0);
871        let g = group_var(0);
872        let y = scalar_var(1);
873        let h = group_var(1);
874
875        let term1 = Term {
876            scalar: x.into(),
877            elem: g,
878        };
879        let term2 = Term {
880            scalar: y.into(),
881            elem: h,
882        };
883
884        let diff = term1 - term2;
885        assert_eq!(diff.terms().len(), 2);
886        assert_eq!(diff.terms()[0].term, term2);
887        assert_eq!(diff.terms()[0].weight, -Scalar::ONE);
888        assert_eq!(diff.terms()[1].term, term1);
889        assert_eq!(diff.terms()[1].weight, Scalar::ONE);
890    }
891
892    #[test]
893    fn test_sum_addition_chaining() {
894        let x = scalar_var(0);
895        let y = scalar_var(1);
896        let z = scalar_var(2);
897
898        let sum = x + y + z;
899        assert_eq!(sum.terms().len(), 3);
900        assert_eq!(sum.terms()[0], x);
901        assert_eq!(sum.terms()[1], y);
902        assert_eq!(sum.terms()[2], z);
903    }
904
905    #[test]
906    fn test_sum_plus_scalar_var() {
907        let x = scalar_var(0);
908        let y = scalar_var(1);
909        let z = scalar_var(2);
910
911        let sum = x + y;
912        let result = z + sum;
913        assert_eq!(result.terms().len(), 3);
914        assert_eq!(result.terms()[0], x);
915        assert_eq!(result.terms()[1], y);
916        assert_eq!(result.terms()[2], z);
917    }
918
919    #[test]
920    fn test_sum_plus_sum() {
921        let x = scalar_var(0);
922        let y = scalar_var(1);
923        let z = scalar_var(2);
924        let w = scalar_var(3);
925
926        let sum1 = x + y;
927        let sum2 = z + w;
928        let result = sum1 + sum2;
929
930        assert_eq!(result.terms().len(), 4);
931        assert_eq!(result.terms()[0], x);
932        assert_eq!(result.terms()[1], y);
933        assert_eq!(result.terms()[2], z);
934        assert_eq!(result.terms()[3], w);
935    }
936
937    #[test]
938    fn test_sum_negation() {
939        let x = scalar_var(0);
940        let y = scalar_var(1);
941
942        let sum = x + y;
943        let neg_sum = -sum;
944
945        assert_eq!(neg_sum.terms().len(), 2);
946        assert_eq!(neg_sum.terms()[0].term, x);
947        assert_eq!(neg_sum.terms()[0].weight, -Scalar::ONE);
948        assert_eq!(neg_sum.terms()[1].term, y);
949        assert_eq!(neg_sum.terms()[1].weight, -Scalar::ONE);
950    }
951
952    #[test]
953    fn test_weighted_addition() {
954        let x = scalar_var(0);
955        let y = scalar_var(1);
956
957        let weighted1 = x * Scalar::from(3u64);
958        let weighted2 = y * Scalar::from(5u64);
959        let sum = weighted1 + weighted2;
960
961        assert_eq!(sum.terms().len(), 2);
962        assert_eq!(sum.terms()[0].term, x);
963        assert_eq!(sum.terms()[0].weight, Scalar::from(3u64));
964        assert_eq!(sum.terms()[1].term, y);
965        assert_eq!(sum.terms()[1].weight, Scalar::from(5u64));
966    }
967
968    #[test]
969    fn test_weighted_plus_term() {
970        let x = scalar_var(0);
971        let y = scalar_var(1);
972
973        let weighted = x * Scalar::from(2u64);
974        let sum = weighted + y;
975
976        assert_eq!(sum.terms().len(), 2);
977        assert_eq!(sum.terms()[0].term, x);
978        assert_eq!(sum.terms()[0].weight, Scalar::from(2u64));
979        assert_eq!(sum.terms()[1].term, y);
980        assert_eq!(sum.terms()[1].weight, Scalar::ONE);
981    }
982
983    #[test]
984    fn test_weighted_scalar_multiplication() {
985        let x = scalar_var(0);
986        let weighted = x * Scalar::from(2u64);
987        let result = weighted * Scalar::from(3u64);
988
989        assert_eq!(result.term, x);
990        assert_eq!(result.weight, Scalar::from(6u64));
991    }
992
993    #[test]
994    fn test_weighted_group_var_times_scalar_var() {
995        let x = scalar_var(0);
996        let g = group_var(0);
997
998        let weighted_g = g * Scalar::from(5u64);
999        let result = x * weighted_g;
1000
1001        assert_eq!(result.term.scalar, x.into());
1002        assert_eq!(result.term.elem, g);
1003        assert_eq!(result.weight, Scalar::from(5u64));
1004    }
1005
1006    #[test]
1007    fn test_weighted_scalar_var_times_group_var() {
1008        let x = scalar_var(0);
1009        let g = group_var(0);
1010
1011        let weighted_x = x * Scalar::from(3u64);
1012        let result = weighted_x * g;
1013
1014        assert_eq!(result.term.scalar, x.into());
1015        assert_eq!(result.term.elem, g);
1016        assert_eq!(result.weight, Scalar::from(3u64));
1017    }
1018
1019    #[test]
1020    fn test_sum_scalar_multiplication_distributive() {
1021        let x = scalar_var(0);
1022        let y = scalar_var(1);
1023
1024        let sum = x + y;
1025        let result = sum * Scalar::from(2u64);
1026
1027        assert_eq!(result.terms().len(), 2);
1028        assert_eq!(result.terms()[0].term, x);
1029        assert_eq!(result.terms()[0].weight, Scalar::from(2u64));
1030        assert_eq!(result.terms()[1].term, y);
1031        assert_eq!(result.terms()[1].weight, Scalar::from(2u64));
1032    }
1033
1034    #[test]
1035    fn test_sum_subtraction_distributive() {
1036        let x = scalar_var(0);
1037        let y = scalar_var(1);
1038        let z = scalar_var(2);
1039
1040        let sum1 = x + y;
1041        let result = sum1 - z;
1042
1043        assert_eq!(result.terms().len(), 3);
1044        assert_eq!(result.terms()[0].term, x);
1045        assert_eq!(result.terms()[0].weight, Scalar::ONE);
1046        assert_eq!(result.terms()[1].term, y);
1047        assert_eq!(result.terms()[1].weight, Scalar::ONE);
1048        assert_eq!(result.terms()[2].term, z);
1049        assert_eq!(result.terms()[2].weight, -Scalar::ONE);
1050    }
1051
1052    #[test]
1053    fn test_weighted_sum_scalar_multiplication() {
1054        let x = scalar_var(0);
1055        let y = scalar_var(1);
1056
1057        let weighted1 = x * Scalar::from(2u64);
1058        let weighted2 = y * Scalar::from(3u64);
1059        let sum = weighted1 + weighted2;
1060        let result = sum * Scalar::from(4u64);
1061
1062        assert_eq!(result.terms().len(), 2);
1063        assert_eq!(result.terms()[0].term, x);
1064        assert_eq!(result.terms()[0].weight, Scalar::from(8u64));
1065        assert_eq!(result.terms()[1].term, y);
1066        assert_eq!(result.terms()[1].weight, Scalar::from(12u64));
1067    }
1068
1069    #[test]
1070    fn test_pedersen_commitment_expression() {
1071        let x = scalar_var(0);
1072        let r = scalar_var(1);
1073        let g = group_var(0);
1074        let h = group_var(1);
1075
1076        let commitment = x * g + r * h;
1077        assert_eq!(commitment.terms().len(), 2);
1078        assert_eq!(commitment.terms()[0].scalar, x.into());
1079        assert_eq!(commitment.terms()[0].elem, g);
1080        assert_eq!(commitment.terms()[1].scalar, r.into());
1081        assert_eq!(commitment.terms()[1].elem, h);
1082    }
1083
1084    #[test]
1085    fn test_weighted_pedersen_commitment() {
1086        let x = scalar_var(0);
1087        let r = scalar_var(1);
1088        let g = group_var(0);
1089        let h = group_var(1);
1090
1091        let commitment = x * g * Scalar::from(3u64) + r * h * Scalar::from(2u64);
1092        assert_eq!(commitment.terms().len(), 2);
1093        assert_eq!(commitment.terms()[0].term.scalar, x.into());
1094        assert_eq!(commitment.terms()[0].term.elem, g);
1095        assert_eq!(commitment.terms()[0].weight, Scalar::from(3u64));
1096        assert_eq!(commitment.terms()[1].term.scalar, r.into());
1097        assert_eq!(commitment.terms()[1].term.elem, h);
1098        assert_eq!(commitment.terms()[1].weight, Scalar::from(2u64));
1099    }
1100
1101    #[test]
1102    fn test_complex_multi_term_expression() {
1103        let scalars = [scalar_var(0), scalar_var(1), scalar_var(2), scalar_var(3)];
1104        let groups = [group_var(0), group_var(1), group_var(2), group_var(3)];
1105
1106        let expr = scalars[0] * groups[0] + scalars[1] * groups[1] + scalars[2] * groups[2]
1107            - scalars[3] * groups[3];
1108
1109        assert_eq!(expr.terms().len(), 4);
1110
1111        for i in 0..3 {
1112            assert_eq!(expr.terms()[i].term.scalar, scalars[i].into());
1113            assert_eq!(expr.terms()[i].term.elem, groups[i]);
1114            assert_eq!(expr.terms()[i].weight, Scalar::ONE);
1115        }
1116
1117        assert_eq!(expr.terms()[3].term.scalar, scalars[3].into());
1118        assert_eq!(expr.terms()[3].term.elem, groups[3]);
1119        assert_eq!(expr.terms()[3].weight, -Scalar::ONE);
1120    }
1121
1122    #[test]
1123    fn test_chained_addition_with_coefficients() {
1124        let x = scalar_var(0);
1125        let y = scalar_var(1);
1126        let z = scalar_var(2);
1127        let g = group_var(0);
1128        let h = group_var(1);
1129        let k = group_var(2);
1130
1131        let expr =
1132            x * g * Scalar::from(2u64) + y * h * Scalar::from(3u64) + z * k * Scalar::from(5u64);
1133        assert_eq!(expr.terms().len(), 3);
1134
1135        let expected_coeffs = [2u64, 3u64, 5u64];
1136        let expected_scalars = [x, y, z];
1137        let expected_groups = [g, h, k];
1138
1139        for i in 0..3 {
1140            assert_eq!(expr.terms()[i].term.scalar, expected_scalars[i].into());
1141            assert_eq!(expr.terms()[i].term.elem, expected_groups[i]);
1142            assert_eq!(expr.terms()[i].weight, Scalar::from(expected_coeffs[i]));
1143        }
1144    }
1145
1146    #[test]
1147    fn test_mixing_sum_term_and_sum_weighted() {
1148        let x = scalar_var(0);
1149        let y = scalar_var(1);
1150        let z = scalar_var(2);
1151        let g = group_var(0);
1152        let h = group_var(1);
1153        let k = group_var(2);
1154
1155        let basic_sum = x * g + y * h; // Sum<Term>
1156        let weighted_term = z * k * Scalar::from(3u64); // Weighted<Term>
1157        let mixed = basic_sum + weighted_term;
1158
1159        assert_eq!(mixed.terms().len(), 3);
1160        assert_eq!(mixed.terms()[0].term.scalar, x.into());
1161        assert_eq!(mixed.terms()[0].term.elem, g);
1162        assert_eq!(mixed.terms()[0].weight, Scalar::ONE);
1163        assert_eq!(mixed.terms()[1].term.scalar, y.into());
1164        assert_eq!(mixed.terms()[1].term.elem, h);
1165        assert_eq!(mixed.terms()[1].weight, Scalar::ONE);
1166        assert_eq!(mixed.terms()[2].term.scalar, z.into());
1167        assert_eq!(mixed.terms()[2].term.elem, k);
1168        assert_eq!(mixed.terms()[2].weight, Scalar::from(3u64));
1169    }
1170
1171    #[test]
1172    fn test_sum_term_plus_weighted_group_var() {
1173        let x = scalar_var(0);
1174        let y = scalar_var(1);
1175        let g = group_var(0);
1176        let h = group_var(1);
1177
1178        let sum_term = x * g + y * h;
1179        let weighted = h * Scalar::from(3u64);
1180        let result = sum_term + weighted;
1181
1182        assert_eq!(result.terms().len(), 3);
1183        assert_eq!(result.terms()[0].term.scalar, x.into());
1184        assert_eq!(result.terms()[0].term.elem, g);
1185        assert_eq!(result.terms()[0].weight, Scalar::ONE);
1186        assert_eq!(result.terms()[1].term.scalar, y.into());
1187        assert_eq!(result.terms()[1].term.elem, h);
1188        assert_eq!(result.terms()[1].weight, Scalar::ONE);
1189        assert_eq!(result.terms()[2].term.scalar, ScalarTerm::Unit);
1190        assert_eq!(result.terms()[2].term.elem, h);
1191        assert_eq!(result.terms()[2].weight, Scalar::from(3u64));
1192    }
1193
1194    #[test]
1195    fn test_weighted_group_var_plus_sum_term() {
1196        let x = scalar_var(0);
1197        let y = scalar_var(1);
1198        let g = group_var(0);
1199        let h = group_var(1);
1200
1201        let sum_term = x * g + y * h;
1202        let weighted = h * Scalar::from(5u64);
1203        let result = weighted + sum_term;
1204
1205        assert_eq!(result.terms().len(), 3);
1206        assert_eq!(result.terms()[0].term.scalar, x.into());
1207        assert_eq!(result.terms()[0].term.elem, g);
1208        assert_eq!(result.terms()[0].weight, Scalar::ONE);
1209        assert_eq!(result.terms()[1].term.scalar, y.into());
1210        assert_eq!(result.terms()[1].term.elem, h);
1211        assert_eq!(result.terms()[1].weight, Scalar::ONE);
1212        assert_eq!(result.terms()[2].term.scalar, ScalarTerm::Unit);
1213        assert_eq!(result.terms()[2].term.elem, h);
1214        assert_eq!(result.terms()[2].weight, Scalar::from(5u64));
1215    }
1216
1217    #[test]
1218    fn test_sum_weighted_group_var_plus_term() {
1219        let x = scalar_var(0);
1220        let g = group_var(0);
1221        let h = group_var(1);
1222
1223        let sum_weighted = g * Scalar::from(2u64) + h * Scalar::from(3u64);
1224        let term = x * g;
1225        let result = sum_weighted + term;
1226
1227        assert_eq!(result.terms().len(), 3);
1228        assert_eq!(result.terms()[0].term.scalar, ScalarTerm::Unit);
1229        assert_eq!(result.terms()[0].term.elem, g);
1230        assert_eq!(result.terms()[0].weight, Scalar::from(2u64));
1231        assert_eq!(result.terms()[1].term.scalar, ScalarTerm::Unit);
1232        assert_eq!(result.terms()[1].term.elem, h);
1233        assert_eq!(result.terms()[1].weight, Scalar::from(3u64));
1234        assert_eq!(result.terms()[2].term.scalar, x.into());
1235        assert_eq!(result.terms()[2].term.elem, g);
1236        assert_eq!(result.terms()[2].weight, Scalar::ONE);
1237    }
1238
1239    #[test]
1240    fn test_sum_weighted_group_var_plus_sum_term() {
1241        let x = scalar_var(0);
1242        let y = scalar_var(1);
1243        let g = group_var(0);
1244        let h = group_var(1);
1245
1246        let sum_weighted = g * Scalar::from(2u64) + h * Scalar::from(3u64);
1247        let sum_term = x * g + y * h;
1248        let result = sum_weighted + sum_term;
1249
1250        assert_eq!(result.terms().len(), 4);
1251        assert_eq!(result.terms()[0].term.scalar, ScalarTerm::Unit);
1252        assert_eq!(result.terms()[0].term.elem, g);
1253        assert_eq!(result.terms()[0].weight, Scalar::from(2u64));
1254        assert_eq!(result.terms()[1].term.scalar, ScalarTerm::Unit);
1255        assert_eq!(result.terms()[1].term.elem, h);
1256        assert_eq!(result.terms()[1].weight, Scalar::from(3u64));
1257        assert_eq!(result.terms()[2].term.scalar, x.into());
1258        assert_eq!(result.terms()[2].term.elem, g);
1259        assert_eq!(result.terms()[2].weight, Scalar::ONE);
1260        assert_eq!(result.terms()[3].term.scalar, y.into());
1261        assert_eq!(result.terms()[3].term.elem, h);
1262        assert_eq!(result.terms()[3].weight, Scalar::ONE);
1263    }
1264
1265    #[test]
1266    fn test_term_plus_sum_weighted_group_var() {
1267        let x = scalar_var(0);
1268        let g = group_var(0);
1269        let h = group_var(1);
1270
1271        let term = x * g;
1272        let sum_weighted = g * Scalar::from(2u64) + h * Scalar::from(3u64);
1273        let result = term + sum_weighted;
1274
1275        assert_eq!(result.terms().len(), 3);
1276        assert_eq!(result.terms()[0].term.scalar, ScalarTerm::Unit);
1277        assert_eq!(result.terms()[0].term.elem, g);
1278        assert_eq!(result.terms()[0].weight, Scalar::from(2u64));
1279        assert_eq!(result.terms()[1].term.scalar, ScalarTerm::Unit);
1280        assert_eq!(result.terms()[1].term.elem, h);
1281        assert_eq!(result.terms()[1].weight, Scalar::from(3u64));
1282        assert_eq!(result.terms()[2].term.scalar, x.into());
1283        assert_eq!(result.terms()[2].term.elem, g);
1284        assert_eq!(result.terms()[2].weight, Scalar::ONE);
1285    }
1286
1287    #[test]
1288    fn test_sum_term_plus_sum_weighted_group_var() {
1289        let x = scalar_var(0);
1290        let y = scalar_var(1);
1291        let g = group_var(0);
1292        let h = group_var(1);
1293
1294        let sum_term = x * g + y * h;
1295        let sum_weighted = g * Scalar::from(2u64) + h * Scalar::from(3u64);
1296        let result = sum_term + sum_weighted;
1297
1298        assert_eq!(result.terms().len(), 4);
1299        assert_eq!(result.terms()[0].term.scalar, ScalarTerm::Unit);
1300        assert_eq!(result.terms()[0].term.elem, g);
1301        assert_eq!(result.terms()[0].weight, Scalar::from(2u64));
1302        assert_eq!(result.terms()[1].term.scalar, ScalarTerm::Unit);
1303        assert_eq!(result.terms()[1].term.elem, h);
1304        assert_eq!(result.terms()[1].weight, Scalar::from(3u64));
1305        assert_eq!(result.terms()[2].term.scalar, x.into());
1306        assert_eq!(result.terms()[2].term.elem, g);
1307        assert_eq!(result.terms()[2].weight, Scalar::ONE);
1308        assert_eq!(result.terms()[3].term.scalar, y.into());
1309        assert_eq!(result.terms()[3].term.elem, h);
1310        assert_eq!(result.terms()[3].weight, Scalar::ONE);
1311    }
1312
1313    #[test]
1314    fn test_scalar_var_minus_scalar_times_group() {
1315        let x = scalar_var(0);
1316        let b = group_var(0);
1317
1318        // Test the user's example: (x - Scalar::from_u128(1u128)) * B
1319        // For now, demonstrate the equivalent: x * B + b * (-1)
1320        let result = x * b + b * (-Scalar::ONE);
1321
1322        assert_eq!(result.terms().len(), 2);
1323        assert_eq!(result.terms()[0].term.scalar, x.into());
1324        assert_eq!(result.terms()[0].term.elem, b);
1325        assert_eq!(result.terms()[0].weight, Scalar::ONE);
1326        assert_eq!(result.terms()[1].term.scalar, ScalarTerm::Unit);
1327        assert_eq!(result.terms()[1].term.elem, b);
1328        assert_eq!(result.terms()[1].weight, -Scalar::ONE);
1329    }
1330
1331    #[test]
1332    fn test_group_var_times_scalar_plus_scalar_times_group() {
1333        let gen__disj1_x_r = scalar_var(0);
1334        let a = group_var(0);
1335        let b = group_var(1);
1336
1337        // Test the user's example: A * Scalar::from_u128(1u128) + gen__disj1_x_r * B
1338        let result = a * Scalar::ONE + gen__disj1_x_r * b;
1339
1340        assert_eq!(result.terms().len(), 2);
1341        // The order is reversed from what we expected due to implementation details
1342        assert_eq!(result.terms()[0].term.scalar, gen__disj1_x_r.into());
1343        assert_eq!(result.terms()[0].term.elem, b);
1344        assert_eq!(result.terms()[0].weight, Scalar::ONE);
1345        assert_eq!(result.terms()[1].term.scalar, ScalarTerm::Unit);
1346        assert_eq!(result.terms()[1].term.elem, a);
1347        assert_eq!(result.terms()[1].weight, Scalar::ONE);
1348    }
1349}