snarkvm_synthesizer_program/logic/instruction/operation/assert.rs

// Copyright (c) 2019-2025 Provable Inc.
// This file is part of the snarkVM library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16use crate::{Opcode, Operand, RegistersCircuit, RegistersTrait, StackTrait};
17use console::{
18    network::prelude::*,
19    program::{Register, RegisterType},
20};
21
22/// Asserts two operands are equal to each other.
23pub type AssertEq<N> = AssertInstruction<N, { Variant::AssertEq as u8 }>;
24/// Asserts two operands are **not** equal to each other.
25pub type AssertNeq<N> = AssertInstruction<N, { Variant::AssertNeq as u8 }>;
26
/// The kinds of `assert` instruction, used as the `VARIANT` const parameter of
/// `AssertInstruction`.
///
/// The discriminants are spelled out explicitly because the `match VARIANT` arms in
/// this module dispatch on the raw values `0` (`assert.eq`) and `1` (`assert.neq`);
/// reordering or inserting variants must not silently change those values.
#[repr(u8)]
enum Variant {
    AssertEq = 0,
    AssertNeq = 1,
}

32/// Asserts an operation on two operands.
33#[derive(Clone, PartialEq, Eq, Hash)]
34pub struct AssertInstruction<N: Network, const VARIANT: u8> {
35    /// The operands.
36    operands: Vec<Operand<N>>,
37}
38
39impl<N: Network, const VARIANT: u8> AssertInstruction<N, VARIANT> {
40    /// Initializes a new `assert` instruction.
41    #[inline]
42    pub fn new(operands: Vec<Operand<N>>) -> Result<Self> {
43        // Sanity check that the operands is exactly two inputs.
44        ensure!(operands.len() == 2, "Assert instructions must have two operands");
45        // Return the instruction.
46        Ok(Self { operands })
47    }
48
49    /// Returns the opcode.
50    #[inline]
51    pub const fn opcode() -> Opcode {
52        match VARIANT {
53            0 => Opcode::Assert("assert.eq"),
54            1 => Opcode::Assert("assert.neq"),
55            _ => panic!("Invalid 'assert' instruction opcode"),
56        }
57    }
58
59    /// Returns the operands in the operation.
60    #[inline]
61    pub fn operands(&self) -> &[Operand<N>] {
62        // Sanity check that the operands is exactly two inputs.
63        debug_assert!(self.operands.len() == 2, "Assert operations must have two operands");
64        // Return the operands.
65        &self.operands
66    }
67
68    /// Returns the destination register.
69    #[inline]
70    pub fn destinations(&self) -> Vec<Register<N>> {
71        vec![]
72    }
73}
74
75impl<N: Network, const VARIANT: u8> AssertInstruction<N, VARIANT> {
76    /// Evaluates the instruction.
77    pub fn evaluate(&self, stack: &impl StackTrait<N>, registers: &mut impl RegistersTrait<N>) -> Result<()> {
78        // Ensure the number of operands is correct.
79        if self.operands.len() != 2 {
80            bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
81        }
82
83        // Retrieve the inputs.
84        let input_a = registers.load(stack, &self.operands[0])?;
85        let input_b = registers.load(stack, &self.operands[1])?;
86
87        // Assert the inputs.
88        match VARIANT {
89            0 => {
90                if input_a != input_b {
91                    bail!("'{}' failed: '{input_a}' is not equal to '{input_b}' (should be equal)", Self::opcode())
92                }
93            }
94            1 => {
95                if input_a == input_b {
96                    bail!("'{}' failed: '{input_a}' is equal to '{input_b}' (should not be equal)", Self::opcode())
97                }
98            }
99            _ => bail!("Invalid 'assert' variant: {VARIANT}"),
100        }
101        Ok(())
102    }
103
104    /// Executes the instruction.
105    pub fn execute<A: circuit::Aleo<Network = N>>(
106        &self,
107        stack: &impl StackTrait<N>,
108        registers: &mut impl RegistersCircuit<N, A>,
109    ) -> Result<()> {
110        // Ensure the number of operands is correct.
111        if self.operands.len() != 2 {
112            bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
113        }
114
115        // Retrieve the inputs.
116        let input_a = registers.load_circuit(stack, &self.operands[0])?;
117        let input_b = registers.load_circuit(stack, &self.operands[1])?;
118
119        // Assert the inputs.
120        match VARIANT {
121            0 => A::assert(input_a.is_equal(&input_b)),
122            1 => A::assert(input_a.is_not_equal(&input_b)),
123            _ => bail!("Invalid 'assert' variant: {VARIANT}"),
124        }
125        Ok(())
126    }
127
128    /// Finalizes the instruction.
129    #[inline]
130    pub fn finalize(&self, stack: &impl StackTrait<N>, registers: &mut impl RegistersTrait<N>) -> Result<()> {
131        self.evaluate(stack, registers)
132    }
133
134    /// Returns the output type from the given program and input types.
135    pub fn output_types(
136        &self,
137        _stack: &impl StackTrait<N>,
138        input_types: &[RegisterType<N>],
139    ) -> Result<Vec<RegisterType<N>>> {
140        // Ensure the number of input types is correct.
141        if input_types.len() != 2 {
142            bail!("Instruction '{}' expects 2 inputs, found {} inputs", Self::opcode(), input_types.len())
143        }
144        // Ensure the operands are of the same type.
145        if input_types[0] != input_types[1] {
146            bail!(
147                "Instruction '{}' expects inputs of the same type. Found inputs of type '{}' and '{}'",
148                Self::opcode(),
149                input_types[0],
150                input_types[1]
151            )
152        }
153        // Ensure the number of operands is correct.
154        if self.operands.len() != 2 {
155            bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
156        }
157
158        match VARIANT {
159            0 | 1 => Ok(vec![]),
160            _ => bail!("Invalid 'assert' variant: {VARIANT}"),
161        }
162    }
163}
164
165impl<N: Network, const VARIANT: u8> Parser for AssertInstruction<N, VARIANT> {
166    /// Parses a string into an operation.
167    fn parse(string: &str) -> ParserResult<Self> {
168        // Parse the opcode from the string.
169        let (string, _) = tag(*Self::opcode())(string)?;
170        // Parse the whitespace from the string.
171        let (string, _) = Sanitizer::parse_whitespaces(string)?;
172        // Parse the first operand from the string.
173        let (string, first) = Operand::parse(string)?;
174        // Parse the whitespace from the string.
175        let (string, _) = Sanitizer::parse_whitespaces(string)?;
176        // Parse the second operand from the string.
177        let (string, second) = Operand::parse(string)?;
178
179        Ok((string, Self { operands: vec![first, second] }))
180    }
181}
182
183impl<N: Network, const VARIANT: u8> FromStr for AssertInstruction<N, VARIANT> {
184    type Err = Error;
185
186    /// Parses a string into an operation.
187    fn from_str(string: &str) -> Result<Self> {
188        match Self::parse(string) {
189            Ok((remainder, object)) => {
190                // Ensure the remainder is empty.
191                ensure!(remainder.is_empty(), "Failed to parse string. Found invalid character in: \"{remainder}\"");
192                // Return the object.
193                Ok(object)
194            }
195            Err(error) => bail!("Failed to parse string. {error}"),
196        }
197    }
198}
199
200impl<N: Network, const VARIANT: u8> Debug for AssertInstruction<N, VARIANT> {
201    /// Prints the operation as a string.
202    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
203        Display::fmt(self, f)
204    }
205}
206
207impl<N: Network, const VARIANT: u8> Display for AssertInstruction<N, VARIANT> {
208    /// Prints the operation to a string.
209    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
210        // Ensure the number of operands is 2.
211        if self.operands.len() != 2 {
212            return Err(fmt::Error);
213        }
214        // Print the operation.
215        write!(f, "{}", Self::opcode())?;
216        self.operands.iter().try_for_each(|operand| write!(f, " {operand}"))
217    }
218}
219
220impl<N: Network, const VARIANT: u8> FromBytes for AssertInstruction<N, VARIANT> {
221    /// Reads the operation from a buffer.
222    fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
223        // Initialize the vector for the operands.
224        let mut operands = Vec::with_capacity(2);
225        // Read the operands.
226        for _ in 0..2 {
227            operands.push(Operand::read_le(&mut reader)?);
228        }
229
230        // Return the operation.
231        Ok(Self { operands })
232    }
233}
234
235impl<N: Network, const VARIANT: u8> ToBytes for AssertInstruction<N, VARIANT> {
236    /// Writes the operation to a buffer.
237    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
238        // Ensure the number of operands is 2.
239        if self.operands.len() != 2 {
240            return Err(error(format!("The number of operands must be 2, found {}", self.operands.len())));
241        }
242        // Write the operands.
243        self.operands.iter().try_for_each(|operand| operand.write_le(&mut writer))
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use console::network::MainnetV0;

    type CurrentNetwork = MainnetV0;

    /// Builds `count` register operands `r0`, `r1`, ….
    fn registers(count: u64) -> Vec<Operand<CurrentNetwork>> {
        (0..count).map(|locator| Operand::Register(Register::Locator(locator))).collect()
    }

    #[test]
    fn test_parse() {
        let (string, assert) = AssertEq::<CurrentNetwork>::parse("assert.eq r0 r1").unwrap();
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");
        assert_eq!(assert.operands.len(), 2, "The number of operands is incorrect");
        assert_eq!(assert.operands[0], Operand::Register(Register::Locator(0)), "The first operand is incorrect");
        assert_eq!(assert.operands[1], Operand::Register(Register::Locator(1)), "The second operand is incorrect");

        let (string, assert) = AssertNeq::<CurrentNetwork>::parse("assert.neq r0 r1").unwrap();
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");
        assert_eq!(assert.operands.len(), 2, "The number of operands is incorrect");
        assert_eq!(assert.operands[0], Operand::Register(Register::Locator(0)), "The first operand is incorrect");
        assert_eq!(assert.operands[1], Operand::Register(Register::Locator(1)), "The second operand is incorrect");
    }

    #[test]
    fn test_parse_rejects_other_variant_opcode() {
        // Each alias accepts only its own opcode.
        assert!(AssertEq::<CurrentNetwork>::parse("assert.neq r0 r1").is_err());
        assert!(AssertNeq::<CurrentNetwork>::parse("assert.eq r0 r1").is_err());
    }

    #[test]
    fn test_from_str_rejects_trailing_input() {
        assert!(AssertEq::<CurrentNetwork>::from_str("assert.eq r0 r1 r2").is_err());
        assert!(AssertNeq::<CurrentNetwork>::from_str("assert.neq r0 r1 r2").is_err());
    }

    #[test]
    fn test_new_requires_exactly_two_operands() {
        assert!(AssertEq::<CurrentNetwork>::new(registers(1)).is_err());
        assert!(AssertNeq::<CurrentNetwork>::new(registers(3)).is_err());
        assert!(AssertEq::<CurrentNetwork>::new(registers(2)).is_ok());
        assert!(AssertNeq::<CurrentNetwork>::new(registers(2)).is_ok());
    }

    #[test]
    fn test_display_and_bytes_round_trip() {
        let assert = AssertEq::<CurrentNetwork>::from_str("assert.eq r0 r1").unwrap();
        assert_eq!(assert.to_string(), "assert.eq r0 r1");
        let bytes = assert.to_bytes_le().unwrap();
        assert_eq!(AssertEq::<CurrentNetwork>::from_bytes_le(&bytes).unwrap(), assert);

        let assert = AssertNeq::<CurrentNetwork>::from_str("assert.neq r0 r1").unwrap();
        assert_eq!(assert.to_string(), "assert.neq r0 r1");
        let bytes = assert.to_bytes_le().unwrap();
        assert_eq!(AssertNeq::<CurrentNetwork>::from_bytes_le(&bytes).unwrap(), assert);
    }
}