snarkvm_circuit_program/data/plaintext/
equal.rs

// Copyright (c) 2019-2025 Provable Inc.
// This file is part of the snarkVM library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use super::*;
17
18impl<A: Aleo> Equal<Self> for Plaintext<A> {
19    type Output = Boolean<A>;
20
21    /// Returns `true` if `self` and `other` are equal.
22    fn is_equal(&self, other: &Self) -> Self::Output {
23        match (self, other) {
24            (Self::Literal(a, _), Self::Literal(b, _)) => a.is_equal(b),
25            (Self::Struct(a, _), Self::Struct(b, _)) => match a.len() == b.len() {
26                true => {
27                    // Recursively check each member for equality.
28                    let mut equal = Boolean::constant(true);
29                    for ((name_a, plaintext_a), (name_b, plaintext_b)) in a.iter().zip_eq(b.iter()) {
30                        equal = equal & name_a.is_equal(name_b) & plaintext_a.is_equal(plaintext_b);
31                    }
32                    equal
33                }
34                false => Boolean::constant(false),
35            },
36            (Self::Array(a, _), Self::Array(b, _)) => match a.len() == b.len() {
37                true => {
38                    // Recursively check each element for equality.
39                    let mut equal = Boolean::constant(true);
40                    for (plaintext_a, plaintext_b) in a.iter().zip_eq(b.iter()) {
41                        equal &= plaintext_a.is_equal(plaintext_b);
42                    }
43                    equal
44                }
45                false => Boolean::constant(false),
46            },
47            (Self::Literal(..), _) | (Self::Struct(..), _) | (Self::Array(..), _) => Boolean::constant(false),
48        }
49    }
50
51    /// Returns `true` if `self` and `other` are *not* equal.
52    fn is_not_equal(&self, other: &Self) -> Self::Output {
53        match (self, other) {
54            (Self::Literal(a, _), Self::Literal(b, _)) => a.is_not_equal(b),
55            (Self::Struct(a, _), Self::Struct(b, _)) => match a.len() == b.len() {
56                true => {
57                    // Recursively check each member for inequality.
58                    let mut not_equal = Boolean::constant(false);
59                    for ((name_a, plaintext_a), (name_b, plaintext_b)) in a.iter().zip_eq(b.iter()) {
60                        not_equal = not_equal | name_a.is_not_equal(name_b) | plaintext_a.is_not_equal(plaintext_b);
61                    }
62                    not_equal
63                }
64                false => Boolean::constant(true),
65            },
66            (Self::Array(a, _), Self::Array(b, _)) => match a.len() == b.len() {
67                true => {
68                    // Recursively check each element for inequality.
69                    let mut not_equal = Boolean::constant(false);
70                    for (plaintext_a, plaintext_b) in a.iter().zip_eq(b.iter()) {
71                        not_equal |= plaintext_a.is_not_equal(plaintext_b);
72                    }
73                    not_equal
74                }
75                false => Boolean::constant(true),
76            },
77            (Self::Literal(..), _) | (Self::Struct(..), _) | (Self::Array(..), _) => Boolean::constant(true),
78        }
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Circuit;

    /// Every mode the shape/array tests are exercised under.
    const MODES: [Mode; 3] = [Mode::Constant, Mode::Public, Mode::Private];

    /// Parses `input` as a console plaintext and injects it into the circuit with the given `mode`.
    fn parse_plaintext(mode: Mode, input: &str) -> Plaintext<Circuit> {
        let plaintext = console::Plaintext::<<Circuit as Environment>::Network>::from_str(input).unwrap();
        Plaintext::new(mode, plaintext)
    }

    /// Samples a nested struct plaintext.
    fn sample_plaintext(mode: Mode) -> Plaintext<Circuit> {
        parse_plaintext(
            mode,
            r"{
    a: true,
    b: 123456789field,
    c: 0group,
    d: {
        e: true,
        f: 123456789field,
        g: 0group
    }
}",
        )
    }

    /// Samples the same struct as `sample_plaintext`, except that member `a` is flipped.
    fn sample_mismatched_plaintext(mode: Mode) -> Plaintext<Circuit> {
        parse_plaintext(
            mode,
            r"{
    a: false,
    b: 123456789field,
    c: 0group,
    d: {
        e: true,
        f: 123456789field,
        g: 0group
    }
}",
        )
    }

    fn check_is_equal(
        mode: Mode,
        num_constants: u64,
        num_public: u64,
        num_private: u64,
        num_constraints: u64,
    ) -> Result<()> {
        // Sample the plaintext.
        let plaintext = sample_plaintext(mode);
        let mismatched_plaintext = sample_mismatched_plaintext(mode);

        Circuit::scope(format!("{mode}"), || {
            let candidate = plaintext.is_equal(&plaintext);
            assert!(candidate.eject_value());
            assert_scope!(<=num_constants, <=num_public, <=num_private, <=num_constraints);
        });

        Circuit::scope(format!("{mode}"), || {
            let candidate = plaintext.is_equal(&mismatched_plaintext);
            assert!(!candidate.eject_value());
            assert_scope!(<=num_constants, <=num_public, <=num_private, <=num_constraints);
        });

        Circuit::reset();
        Ok(())
    }

    fn check_is_not_equal(
        mode: Mode,
        num_constants: u64,
        num_public: u64,
        num_private: u64,
        num_constraints: u64,
    ) -> Result<()> {
        // Sample the plaintext.
        let plaintext = sample_plaintext(mode);
        let mismatched_plaintext = sample_mismatched_plaintext(mode);

        Circuit::scope(format!("{mode}"), || {
            let candidate = plaintext.is_not_equal(&mismatched_plaintext);
            assert!(candidate.eject_value());
            assert_scope!(<=num_constants, <=num_public, <=num_private, <=num_constraints);
        });

        Circuit::scope(format!("{mode}"), || {
            let candidate = plaintext.is_not_equal(&plaintext);
            assert!(!candidate.eject_value());
            assert_scope!(<=num_constants, <=num_public, <=num_private, <=num_constraints);
        });

        Circuit::reset();
        Ok(())
    }

    #[test]
    fn test_is_equal_constant() -> Result<()> {
        check_is_equal(Mode::Constant, 13, 0, 0, 0)
    }

    #[test]
    fn test_is_equal_public() -> Result<()> {
        check_is_equal(Mode::Public, 13, 0, 21, 21)
    }

    #[test]
    fn test_is_equal_private() -> Result<()> {
        check_is_equal(Mode::Private, 13, 0, 21, 21)
    }

    #[test]
    fn test_is_not_equal_constant() -> Result<()> {
        check_is_not_equal(Mode::Constant, 13, 0, 0, 0)
    }

    #[test]
    fn test_is_not_equal_public() -> Result<()> {
        check_is_not_equal(Mode::Public, 13, 0, 21, 21)
    }

    #[test]
    fn test_is_not_equal_private() -> Result<()> {
        check_is_not_equal(Mode::Private, 13, 0, 21, 21)
    }

    /// Exercises the array branch, which the struct-based tests above never reach.
    #[test]
    fn test_array_equality() -> Result<()> {
        for mode in MODES {
            let array = parse_plaintext(mode, "[1u8, 2u8, 3u8]");
            let mismatched_array = parse_plaintext(mode, "[1u8, 2u8, 4u8]");

            assert!(array.is_equal(&array).eject_value());
            assert!(!array.is_not_equal(&array).eject_value());
            assert!(!array.is_equal(&mismatched_array).eject_value());
            assert!(array.is_not_equal(&mismatched_array).eject_value());

            Circuit::reset();
        }
        Ok(())
    }

    /// Exercises the size-mismatch early exits and the cross-variant fallback arm.
    #[test]
    fn test_shape_mismatch() -> Result<()> {
        for mode in MODES {
            let literal = parse_plaintext(mode, "true");
            let short_struct = parse_plaintext(mode, "{ a: true }");
            let long_struct = parse_plaintext(mode, "{ a: true, b: 1field }");
            let short_array = parse_plaintext(mode, "[true]");
            let long_array = parse_plaintext(mode, "[true, true]");

            let mismatched_pairs = [
                (&short_struct, &long_struct),
                (&short_array, &long_array),
                (&literal, &short_struct),
                (&literal, &short_array),
                (&short_struct, &short_array),
            ];
            for (lhs, rhs) in mismatched_pairs {
                // Check both argument orders, since each variant pairing hits a different match arm.
                assert!(!lhs.is_equal(rhs).eject_value());
                assert!(lhs.is_not_equal(rhs).eject_value());
                assert!(!rhs.is_equal(lhs).eject_value());
                assert!(rhs.is_not_equal(lhs).eject_value());
            }

            Circuit::reset();
        }
        Ok(())
    }
}