Skip to main content

snarkvm_synthesizer_program/logic/instruction/operation/
assert.rs

1// Copyright (c) 2019-2026 Provable Inc.
2// This file is part of the snarkVM library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use crate::{Opcode, Operand, RegistersCircuit, RegistersTrait, StackTrait, register_types_equivalent};
17use console::{
18    network::prelude::*,
19    program::{Register, RegisterType},
20};
21use snarkvm_synthesizer_error::*;
22
23/// Asserts two operands are equal to each other.
24pub type AssertEq<N> = AssertInstruction<N, { Variant::AssertEq as u8 }>;
25/// Asserts two operands are **not** equal to each other.
26pub type AssertNeq<N> = AssertInstruction<N, { Variant::AssertNeq as u8 }>;
27
28enum Variant {
29    AssertEq,
30    AssertNeq,
31}
32
33/// Asserts an operation on two operands.
34#[derive(Clone, PartialEq, Eq, Hash)]
35pub struct AssertInstruction<N: Network, const VARIANT: u8> {
36    /// The operands.
37    operands: Vec<Operand<N>>,
38}
39
40impl<N: Network, const VARIANT: u8> AssertInstruction<N, VARIANT> {
41    /// Initializes a new `assert` instruction.
42    #[inline]
43    pub fn new(operands: Vec<Operand<N>>) -> Result<Self> {
44        // Sanity check that the operands is exactly two inputs.
45        ensure!(operands.len() == 2, "Assert instructions must have two operands");
46        // Return the instruction.
47        Ok(Self { operands })
48    }
49
50    /// Returns the opcode.
51    #[inline]
52    pub const fn opcode() -> Opcode {
53        match VARIANT {
54            0 => Opcode::Assert("assert.eq"),
55            1 => Opcode::Assert("assert.neq"),
56            _ => panic!("Invalid 'assert' instruction opcode"),
57        }
58    }
59
60    /// Returns the operands in the operation.
61    #[inline]
62    pub fn operands(&self) -> &[Operand<N>] {
63        // Sanity check that the operands is exactly two inputs.
64        debug_assert!(self.operands.len() == 2, "Assert operations must have two operands");
65        // Return the operands.
66        &self.operands
67    }
68
69    /// Returns the destination register.
70    #[inline]
71    pub fn destinations(&self) -> Vec<Register<N>> {
72        vec![]
73    }
74
75    /// Returns whether this instruction refers to an external struct.
76    #[inline]
77    pub fn contains_external_struct(&self) -> bool {
78        false
79    }
80}
81
82impl<N: Network, const VARIANT: u8> AssertInstruction<N, VARIANT> {
83    /// Evaluates the instruction.
84    pub fn evaluate(
85        &self,
86        stack: &impl StackTrait<N>,
87        registers: &mut impl RegistersTrait<N>,
88    ) -> Result<(), EvalError> {
89        // Ensure the number of operands is correct.
90        if self.operands.len() != 2 {
91            return Err(anyhow!(
92                "Instruction '{}' expects 2 operands, found {} operands",
93                Self::opcode(),
94                self.operands.len()
95            )
96            .into());
97        }
98
99        // Retrieve the inputs.
100        let input_a = registers.load(stack, &self.operands[0])?;
101        let input_b = registers.load(stack, &self.operands[1])?;
102
103        // Assert the inputs.
104        match VARIANT {
105            0 => {
106                if input_a != input_b {
107                    return Err(AssertError::Eq { lhs: format!("{input_a}"), rhs: format!("{input_b}") }.into());
108                }
109            }
110            1 => {
111                if input_a == input_b {
112                    return Err(AssertError::Neq { lhs: format!("{input_a}"), rhs: format!("{input_b}") }.into());
113                }
114            }
115            _ => return Err(AssertError::Invalid { variant: VARIANT }.into()),
116        }
117        Ok(())
118    }
119
120    /// Executes the instruction.
121    pub fn execute<A: circuit::Aleo<Network = N>>(
122        &self,
123        stack: &impl StackTrait<N>,
124        registers: &mut impl RegistersCircuit<N, A>,
125    ) -> Result<(), ExecError> {
126        // Ensure the number of operands is correct.
127        if self.operands.len() != 2 {
128            return Err(anyhow!(
129                "Instruction '{}' expects 2 operands, found {} operands",
130                Self::opcode(),
131                self.operands.len()
132            )
133            .into());
134        }
135
136        // Retrieve the inputs.
137        let input_a = registers.load_circuit(stack, &self.operands[0])?;
138        let input_b = registers.load_circuit(stack, &self.operands[1])?;
139
140        // Assert the inputs.
141        match VARIANT {
142            0 => A::assert(input_a.is_equal(&input_b))?,
143            1 => A::assert(input_a.is_not_equal(&input_b))?,
144            _ => return Err(anyhow!("Invalid 'assert' variant: {VARIANT}").into()),
145        }
146        Ok(())
147    }
148
149    /// Finalizes the instruction.
150    #[inline]
151    pub fn finalize(
152        &self,
153        stack: &impl StackTrait<N>,
154        registers: &mut impl RegistersTrait<N>,
155    ) -> Result<(), FinalizeError> {
156        self.evaluate(stack, registers)?;
157        Ok(())
158    }
159
160    /// Returns the output type from the given program and input types.
161    pub fn output_types(
162        &self,
163        stack: &impl StackTrait<N>,
164        input_types: &[RegisterType<N>],
165    ) -> Result<Vec<RegisterType<N>>> {
166        // Ensure the number of input types is correct.
167        if input_types.len() != 2 {
168            bail!("Instruction '{}' expects 2 inputs, found {} inputs", Self::opcode(), input_types.len())
169        }
170        // Ensure the operands have equivalent types.
171        if !register_types_equivalent(stack, &input_types[0], stack, &input_types[1])? {
172            bail!(
173                "Instruction '{}' expects inputs of equivalent types. Found inputs of type '{}' and '{}'",
174                Self::opcode(),
175                input_types[0],
176                input_types[1]
177            )
178        }
179        // Ensure the number of operands is correct.
180        if self.operands.len() != 2 {
181            bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
182        }
183
184        match VARIANT {
185            0 | 1 => Ok(vec![]),
186            _ => bail!("Invalid 'assert' variant: {VARIANT}"),
187        }
188    }
189}
190
191impl<N: Network, const VARIANT: u8> Parser for AssertInstruction<N, VARIANT> {
192    /// Parses a string into an operation.
193    fn parse(string: &str) -> ParserResult<Self> {
194        // Parse the opcode from the string.
195        let (string, _) = tag(*Self::opcode())(string)?;
196        // Parse the whitespace from the string.
197        let (string, _) = Sanitizer::parse_whitespaces(string)?;
198        // Parse the first operand from the string.
199        let (string, first) = Operand::parse(string)?;
200        // Parse the whitespace from the string.
201        let (string, _) = Sanitizer::parse_whitespaces(string)?;
202        // Parse the second operand from the string.
203        let (string, second) = Operand::parse(string)?;
204
205        Ok((string, Self { operands: vec![first, second] }))
206    }
207}
208
209impl<N: Network, const VARIANT: u8> FromStr for AssertInstruction<N, VARIANT> {
210    type Err = Error;
211
212    /// Parses a string into an operation.
213    fn from_str(string: &str) -> Result<Self> {
214        match Self::parse(string) {
215            Ok((remainder, object)) => {
216                // Ensure the remainder is empty.
217                ensure!(remainder.is_empty(), "Failed to parse string. Found invalid character in: \"{remainder}\"");
218                // Return the object.
219                Ok(object)
220            }
221            Err(error) => bail!("Failed to parse string. {error}"),
222        }
223    }
224}
225
226impl<N: Network, const VARIANT: u8> Debug for AssertInstruction<N, VARIANT> {
227    /// Prints the operation as a string.
228    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
229        Display::fmt(self, f)
230    }
231}
232
233impl<N: Network, const VARIANT: u8> Display for AssertInstruction<N, VARIANT> {
234    /// Prints the operation to a string.
235    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
236        // Ensure the number of operands is 2.
237        if self.operands.len() != 2 {
238            return Err(fmt::Error);
239        }
240        // Print the operation.
241        write!(f, "{}", Self::opcode())?;
242        self.operands.iter().try_for_each(|operand| write!(f, " {operand}"))
243    }
244}
245
246impl<N: Network, const VARIANT: u8> FromBytes for AssertInstruction<N, VARIANT> {
247    /// Reads the operation from a buffer.
248    fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
249        // Initialize the vector for the operands.
250        let mut operands = Vec::with_capacity(2);
251        // Read the operands.
252        for _ in 0..2 {
253            operands.push(Operand::read_le(&mut reader)?);
254        }
255
256        // Return the operation.
257        Ok(Self { operands })
258    }
259}
260
261impl<N: Network, const VARIANT: u8> ToBytes for AssertInstruction<N, VARIANT> {
262    /// Writes the operation to a buffer.
263    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
264        // Ensure the number of operands is 2.
265        if self.operands.len() != 2 {
266            return Err(error(format!("The number of operands must be 2, found {}", self.operands.len())));
267        }
268        // Write the operands.
269        self.operands.iter().try_for_each(|operand| operand.write_le(&mut writer))
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use console::network::MainnetV0;

    type CurrentNetwork = MainnetV0;

    #[test]
    fn test_parse() {
        let (string, assert) = AssertEq::<CurrentNetwork>::parse("assert.eq r0 r1").unwrap();
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");
        assert_eq!(assert.operands.len(), 2, "The number of operands is incorrect");
        assert_eq!(assert.operands[0], Operand::Register(Register::Locator(0)), "The first operand is incorrect");
        assert_eq!(assert.operands[1], Operand::Register(Register::Locator(1)), "The second operand is incorrect");

        let (string, assert) = AssertNeq::<CurrentNetwork>::parse("assert.neq r0 r1").unwrap();
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");
        assert_eq!(assert.operands.len(), 2, "The number of operands is incorrect");
        assert_eq!(assert.operands[0], Operand::Register(Register::Locator(0)), "The first operand is incorrect");
        assert_eq!(assert.operands[1], Operand::Register(Register::Locator(1)), "The second operand is incorrect");
    }

    /// Each alias maps to its own opcode string.
    #[test]
    fn test_opcode() {
        assert_eq!(*AssertEq::<CurrentNetwork>::opcode(), "assert.eq");
        assert_eq!(*AssertNeq::<CurrentNetwork>::opcode(), "assert.neq");
    }

    /// `new` accepts exactly two operands and preserves their order.
    #[test]
    fn test_new_requires_two_operands() {
        let r0 = Operand::<CurrentNetwork>::Register(Register::Locator(0));
        let r1 = Operand::<CurrentNetwork>::Register(Register::Locator(1));

        assert!(AssertEq::<CurrentNetwork>::new(vec![]).is_err());
        assert!(AssertEq::<CurrentNetwork>::new(vec![r0.clone()]).is_err());
        assert!(AssertEq::<CurrentNetwork>::new(vec![r0.clone(), r1.clone(), r0.clone()]).is_err());

        let assert = AssertEq::<CurrentNetwork>::new(vec![r0.clone(), r1.clone()]).unwrap();
        assert_eq!(assert.operands(), &[r0, r1]);
    }

    /// Display output parses back via `FromStr`, and trailing or mismatched input is rejected.
    #[test]
    fn test_string_roundtrip() {
        let eq = AssertEq::<CurrentNetwork>::from_str("assert.eq r0 r1").unwrap();
        assert_eq!(eq.to_string(), "assert.eq r0 r1");
        let neq = AssertNeq::<CurrentNetwork>::from_str("assert.neq r0 r1").unwrap();
        assert_eq!(neq.to_string(), "assert.neq r0 r1");

        // Extra operands are left unconsumed and must be rejected.
        assert!(AssertEq::<CurrentNetwork>::from_str("assert.eq r0 r1 r2").is_err());
        // The opcode must match the variant.
        assert!(AssertEq::<CurrentNetwork>::from_str("assert.neq r0 r1").is_err());
        assert!(AssertNeq::<CurrentNetwork>::from_str("assert.eq r0 r1").is_err());
    }

    /// Serialization round-trips, and a malformed instruction refuses to serialize.
    #[test]
    fn test_bytes_roundtrip() {
        let expected = AssertNeq::<CurrentNetwork>::from_str("assert.neq r0 r1").unwrap();
        let bytes = expected.to_bytes_le().unwrap();
        let candidate = AssertNeq::<CurrentNetwork>::from_bytes_le(&bytes).unwrap();
        assert_eq!(expected, candidate);

        let malformed = AssertEq::<CurrentNetwork> { operands: vec![] };
        assert!(malformed.to_bytes_le().is_err());
    }
}