Skip to main content

snarkvm_synthesizer_program/logic/instruction/operation/assert.rs

1// Copyright (c) 2019-2025 Provable Inc.
2// This file is part of the snarkVM library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use crate::{Opcode, Operand, RegistersCircuit, RegistersTrait, StackTrait, register_types_equivalent};
17use console::{
18    network::prelude::*,
19    program::{Register, RegisterType},
20};
21
22/// Asserts two operands are equal to each other.
23pub type AssertEq<N> = AssertInstruction<N, { Variant::AssertEq as u8 }>;
24/// Asserts two operands are **not** equal to each other.
25pub type AssertNeq<N> = AssertInstruction<N, { Variant::AssertNeq as u8 }>;
26
27enum Variant {
28    AssertEq,
29    AssertNeq,
30}
31
32/// Asserts an operation on two operands.
33#[derive(Clone, PartialEq, Eq, Hash)]
34pub struct AssertInstruction<N: Network, const VARIANT: u8> {
35    /// The operands.
36    operands: Vec<Operand<N>>,
37}
38
39impl<N: Network, const VARIANT: u8> AssertInstruction<N, VARIANT> {
40    /// Initializes a new `assert` instruction.
41    #[inline]
42    pub fn new(operands: Vec<Operand<N>>) -> Result<Self> {
43        // Sanity check that the operands is exactly two inputs.
44        ensure!(operands.len() == 2, "Assert instructions must have two operands");
45        // Return the instruction.
46        Ok(Self { operands })
47    }
48
49    /// Returns the opcode.
50    #[inline]
51    pub const fn opcode() -> Opcode {
52        match VARIANT {
53            0 => Opcode::Assert("assert.eq"),
54            1 => Opcode::Assert("assert.neq"),
55            _ => panic!("Invalid 'assert' instruction opcode"),
56        }
57    }
58
59    /// Returns the operands in the operation.
60    #[inline]
61    pub fn operands(&self) -> &[Operand<N>] {
62        // Sanity check that the operands is exactly two inputs.
63        debug_assert!(self.operands.len() == 2, "Assert operations must have two operands");
64        // Return the operands.
65        &self.operands
66    }
67
68    /// Returns the destination register.
69    #[inline]
70    pub fn destinations(&self) -> Vec<Register<N>> {
71        vec![]
72    }
73
74    /// Returns whether this instruction refers to an external struct.
75    #[inline]
76    pub fn contains_external_struct(&self) -> bool {
77        false
78    }
79}
80
81impl<N: Network, const VARIANT: u8> AssertInstruction<N, VARIANT> {
82    /// Evaluates the instruction.
83    pub fn evaluate(&self, stack: &impl StackTrait<N>, registers: &mut impl RegistersTrait<N>) -> Result<()> {
84        // Ensure the number of operands is correct.
85        if self.operands.len() != 2 {
86            bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
87        }
88
89        // Retrieve the inputs.
90        let input_a = registers.load(stack, &self.operands[0])?;
91        let input_b = registers.load(stack, &self.operands[1])?;
92
93        // Assert the inputs.
94        match VARIANT {
95            0 => {
96                if input_a != input_b {
97                    bail!("'{}' failed: '{input_a}' is not equal to '{input_b}' (should be equal)", Self::opcode())
98                }
99            }
100            1 => {
101                if input_a == input_b {
102                    bail!("'{}' failed: '{input_a}' is equal to '{input_b}' (should not be equal)", Self::opcode())
103                }
104            }
105            _ => bail!("Invalid 'assert' variant: {VARIANT}"),
106        }
107        Ok(())
108    }
109
110    /// Executes the instruction.
111    pub fn execute<A: circuit::Aleo<Network = N>>(
112        &self,
113        stack: &impl StackTrait<N>,
114        registers: &mut impl RegistersCircuit<N, A>,
115    ) -> Result<()> {
116        // Ensure the number of operands is correct.
117        if self.operands.len() != 2 {
118            bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
119        }
120
121        // Retrieve the inputs.
122        let input_a = registers.load_circuit(stack, &self.operands[0])?;
123        let input_b = registers.load_circuit(stack, &self.operands[1])?;
124
125        // Assert the inputs.
126        match VARIANT {
127            0 => A::assert(input_a.is_equal(&input_b)),
128            1 => A::assert(input_a.is_not_equal(&input_b)),
129            _ => bail!("Invalid 'assert' variant: {VARIANT}"),
130        }
131        Ok(())
132    }
133
134    /// Finalizes the instruction.
135    #[inline]
136    pub fn finalize(&self, stack: &impl StackTrait<N>, registers: &mut impl RegistersTrait<N>) -> Result<()> {
137        self.evaluate(stack, registers)
138    }
139
140    /// Returns the output type from the given program and input types.
141    pub fn output_types(
142        &self,
143        stack: &impl StackTrait<N>,
144        input_types: &[RegisterType<N>],
145    ) -> Result<Vec<RegisterType<N>>> {
146        // Ensure the number of input types is correct.
147        if input_types.len() != 2 {
148            bail!("Instruction '{}' expects 2 inputs, found {} inputs", Self::opcode(), input_types.len())
149        }
150        // Ensure the operands have equivalent types.
151        if !register_types_equivalent(stack, &input_types[0], stack, &input_types[1])? {
152            bail!(
153                "Instruction '{}' expects inputs of equivalent types. Found inputs of type '{}' and '{}'",
154                Self::opcode(),
155                input_types[0],
156                input_types[1]
157            )
158        }
159        // Ensure the number of operands is correct.
160        if self.operands.len() != 2 {
161            bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
162        }
163
164        match VARIANT {
165            0 | 1 => Ok(vec![]),
166            _ => bail!("Invalid 'assert' variant: {VARIANT}"),
167        }
168    }
169}
170
171impl<N: Network, const VARIANT: u8> Parser for AssertInstruction<N, VARIANT> {
172    /// Parses a string into an operation.
173    fn parse(string: &str) -> ParserResult<Self> {
174        // Parse the opcode from the string.
175        let (string, _) = tag(*Self::opcode())(string)?;
176        // Parse the whitespace from the string.
177        let (string, _) = Sanitizer::parse_whitespaces(string)?;
178        // Parse the first operand from the string.
179        let (string, first) = Operand::parse(string)?;
180        // Parse the whitespace from the string.
181        let (string, _) = Sanitizer::parse_whitespaces(string)?;
182        // Parse the second operand from the string.
183        let (string, second) = Operand::parse(string)?;
184
185        Ok((string, Self { operands: vec![first, second] }))
186    }
187}
188
189impl<N: Network, const VARIANT: u8> FromStr for AssertInstruction<N, VARIANT> {
190    type Err = Error;
191
192    /// Parses a string into an operation.
193    fn from_str(string: &str) -> Result<Self> {
194        match Self::parse(string) {
195            Ok((remainder, object)) => {
196                // Ensure the remainder is empty.
197                ensure!(remainder.is_empty(), "Failed to parse string. Found invalid character in: \"{remainder}\"");
198                // Return the object.
199                Ok(object)
200            }
201            Err(error) => bail!("Failed to parse string. {error}"),
202        }
203    }
204}
205
206impl<N: Network, const VARIANT: u8> Debug for AssertInstruction<N, VARIANT> {
207    /// Prints the operation as a string.
208    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
209        Display::fmt(self, f)
210    }
211}
212
213impl<N: Network, const VARIANT: u8> Display for AssertInstruction<N, VARIANT> {
214    /// Prints the operation to a string.
215    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
216        // Ensure the number of operands is 2.
217        if self.operands.len() != 2 {
218            return Err(fmt::Error);
219        }
220        // Print the operation.
221        write!(f, "{}", Self::opcode())?;
222        self.operands.iter().try_for_each(|operand| write!(f, " {operand}"))
223    }
224}
225
226impl<N: Network, const VARIANT: u8> FromBytes for AssertInstruction<N, VARIANT> {
227    /// Reads the operation from a buffer.
228    fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
229        // Initialize the vector for the operands.
230        let mut operands = Vec::with_capacity(2);
231        // Read the operands.
232        for _ in 0..2 {
233            operands.push(Operand::read_le(&mut reader)?);
234        }
235
236        // Return the operation.
237        Ok(Self { operands })
238    }
239}
240
241impl<N: Network, const VARIANT: u8> ToBytes for AssertInstruction<N, VARIANT> {
242    /// Writes the operation to a buffer.
243    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
244        // Ensure the number of operands is 2.
245        if self.operands.len() != 2 {
246            return Err(error(format!("The number of operands must be 2, found {}", self.operands.len())));
247        }
248        // Write the operands.
249        self.operands.iter().try_for_each(|operand| operand.write_le(&mut writer))
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use console::network::MainnetV0;

    type CurrentNetwork = MainnetV0;

    /// Shorthand for the register operand `r{index}`.
    fn register(index: u64) -> Operand<CurrentNetwork> {
        Operand::Register(Register::Locator(index))
    }

    #[test]
    fn test_parse() {
        let (string, assert) = AssertEq::<CurrentNetwork>::parse("assert.eq r0 r1").unwrap();
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");
        assert_eq!(assert.operands.len(), 2, "The number of operands is incorrect");
        assert_eq!(assert.operands[0], register(0), "The first operand is incorrect");
        assert_eq!(assert.operands[1], register(1), "The second operand is incorrect");

        let (string, assert) = AssertNeq::<CurrentNetwork>::parse("assert.neq r0 r1").unwrap();
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");
        assert_eq!(assert.operands.len(), 2, "The number of operands is incorrect");
        assert_eq!(assert.operands[0], register(0), "The first operand is incorrect");
        assert_eq!(assert.operands[1], register(1), "The second operand is incorrect");
    }

    #[test]
    fn test_parse_rejects_mismatched_opcode() {
        // Each alias only accepts its own opcode.
        assert!(AssertEq::<CurrentNetwork>::parse("assert.neq r0 r1").is_err());
        assert!(AssertNeq::<CurrentNetwork>::parse("assert.eq r0 r1").is_err());
    }

    #[test]
    fn test_from_str_rejects_trailing_operand() {
        // `parse` stops after two operands, so `from_str` must reject the leftover input.
        assert!(AssertEq::<CurrentNetwork>::from_str("assert.eq r0 r1 r2").is_err());
        assert!(AssertNeq::<CurrentNetwork>::from_str("assert.neq r0 r1 r2").is_err());
    }

    #[test]
    fn test_new_requires_two_operands() {
        assert!(AssertEq::<CurrentNetwork>::new(vec![register(0), register(1)]).is_ok());
        assert!(AssertEq::<CurrentNetwork>::new(vec![]).is_err());
        assert!(AssertEq::<CurrentNetwork>::new(vec![register(0)]).is_err());
        assert!(AssertEq::<CurrentNetwork>::new(vec![register(0), register(1), register(2)]).is_err());
        assert!(AssertNeq::<CurrentNetwork>::new(vec![register(0)]).is_err());
    }

    #[test]
    fn test_display_round_trip() {
        let eq_instruction = AssertEq::<CurrentNetwork>::from_str("assert.eq r0 r1").unwrap();
        assert_eq!(eq_instruction.to_string(), "assert.eq r0 r1");

        let neq_instruction = AssertNeq::<CurrentNetwork>::from_str("assert.neq r0 r1").unwrap();
        assert_eq!(neq_instruction.to_string(), "assert.neq r0 r1");
    }

    #[test]
    fn test_bytes_round_trip() {
        let expected = AssertEq::<CurrentNetwork>::new(vec![register(0), register(1)]).unwrap();
        let bytes = expected.to_bytes_le().unwrap();
        let candidate = AssertEq::<CurrentNetwork>::from_bytes_le(&bytes).unwrap();
        assert_eq!(expected, candidate);

        let expected = AssertNeq::<CurrentNetwork>::new(vec![register(1), register(0)]).unwrap();
        let bytes = expected.to_bytes_le().unwrap();
        let candidate = AssertNeq::<CurrentNetwork>::from_bytes_le(&bytes).unwrap();
        assert_eq!(expected, candidate);
    }
}