// snarkvm_synthesizer_program/logic/instruction/operation/assert.rs

// Copyright 2024-2025 Aleo Network Foundation
// This file is part of the snarkVM library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use crate::{
17    Opcode,
18    Operand,
19    traits::{RegistersLoad, RegistersLoadCircuit, StackMatches, StackProgram},
20};
21use console::{
22    network::prelude::*,
23    program::{Register, RegisterType},
24};
25
26/// Asserts two operands are equal to each other.
27pub type AssertEq<N> = AssertInstruction<N, { Variant::AssertEq as u8 }>;
28/// Asserts two operands are **not** equal to each other.
29pub type AssertNeq<N> = AssertInstruction<N, { Variant::AssertNeq as u8 }>;
30
/// The assertion flavors, used (cast to `u8`) as the `VARIANT` const parameter of
/// `AssertInstruction`.
///
/// The discriminants are spelled out explicitly because the instruction logic matches on
/// the raw values `0` and `1`; reordering or inserting variants must not silently change
/// which opcode a value maps to.
#[repr(u8)]
enum Variant {
    /// `assert.eq`: the two operands must be equal.
    AssertEq = 0,
    /// `assert.neq`: the two operands must differ.
    AssertNeq = 1,
}
35
36/// Asserts an operation on two operands.
37#[derive(Clone, PartialEq, Eq, Hash)]
38pub struct AssertInstruction<N: Network, const VARIANT: u8> {
39    /// The operands.
40    operands: Vec<Operand<N>>,
41}
42
43impl<N: Network, const VARIANT: u8> AssertInstruction<N, VARIANT> {
44    /// Initializes a new `assert` instruction.
45    #[inline]
46    pub fn new(operands: Vec<Operand<N>>) -> Result<Self> {
47        // Sanity check that the operands is exactly two inputs.
48        ensure!(operands.len() == 2, "Assert instructions must have two operands");
49        // Return the instruction.
50        Ok(Self { operands })
51    }
52
53    /// Returns the opcode.
54    #[inline]
55    pub const fn opcode() -> Opcode {
56        match VARIANT {
57            0 => Opcode::Assert("assert.eq"),
58            1 => Opcode::Assert("assert.neq"),
59            _ => panic!("Invalid 'assert' instruction opcode"),
60        }
61    }
62
63    /// Returns the operands in the operation.
64    #[inline]
65    pub fn operands(&self) -> &[Operand<N>] {
66        // Sanity check that the operands is exactly two inputs.
67        debug_assert!(self.operands.len() == 2, "Assert operations must have two operands");
68        // Return the operands.
69        &self.operands
70    }
71
72    /// Returns the destination register.
73    #[inline]
74    pub fn destinations(&self) -> Vec<Register<N>> {
75        vec![]
76    }
77}
78
79impl<N: Network, const VARIANT: u8> AssertInstruction<N, VARIANT> {
80    /// Evaluates the instruction.
81    #[inline]
82    pub fn evaluate(
83        &self,
84        stack: &(impl StackMatches<N> + StackProgram<N>),
85        registers: &mut impl RegistersLoad<N>,
86    ) -> Result<()> {
87        // Ensure the number of operands is correct.
88        if self.operands.len() != 2 {
89            bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
90        }
91
92        // Retrieve the inputs.
93        let input_a = registers.load(stack, &self.operands[0])?;
94        let input_b = registers.load(stack, &self.operands[1])?;
95
96        // Assert the inputs.
97        match VARIANT {
98            0 => {
99                if input_a != input_b {
100                    bail!("'{}' failed: '{input_a}' is not equal to '{input_b}' (should be equal)", Self::opcode())
101                }
102            }
103            1 => {
104                if input_a == input_b {
105                    bail!("'{}' failed: '{input_a}' is equal to '{input_b}' (should not be equal)", Self::opcode())
106                }
107            }
108            _ => bail!("Invalid 'assert' variant: {VARIANT}"),
109        }
110        Ok(())
111    }
112
113    /// Executes the instruction.
114    #[inline]
115    pub fn execute<A: circuit::Aleo<Network = N>>(
116        &self,
117        stack: &(impl StackMatches<N> + StackProgram<N>),
118        registers: &mut impl RegistersLoadCircuit<N, A>,
119    ) -> Result<()> {
120        // Ensure the number of operands is correct.
121        if self.operands.len() != 2 {
122            bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
123        }
124
125        // Retrieve the inputs.
126        let input_a = registers.load_circuit(stack, &self.operands[0])?;
127        let input_b = registers.load_circuit(stack, &self.operands[1])?;
128
129        // Assert the inputs.
130        match VARIANT {
131            0 => A::assert(input_a.is_equal(&input_b)),
132            1 => A::assert(input_a.is_not_equal(&input_b)),
133            _ => bail!("Invalid 'assert' variant: {VARIANT}"),
134        }
135        Ok(())
136    }
137
138    /// Finalizes the instruction.
139    #[inline]
140    pub fn finalize(
141        &self,
142        stack: &(impl StackMatches<N> + StackProgram<N>),
143        registers: &mut impl RegistersLoad<N>,
144    ) -> Result<()> {
145        self.evaluate(stack, registers)
146    }
147
148    /// Returns the output type from the given program and input types.
149    #[inline]
150    pub fn output_types(
151        &self,
152        _stack: &impl StackProgram<N>,
153        input_types: &[RegisterType<N>],
154    ) -> Result<Vec<RegisterType<N>>> {
155        // Ensure the number of input types is correct.
156        if input_types.len() != 2 {
157            bail!("Instruction '{}' expects 2 inputs, found {} inputs", Self::opcode(), input_types.len())
158        }
159        // Ensure the operands are of the same type.
160        if input_types[0] != input_types[1] {
161            bail!(
162                "Instruction '{}' expects inputs of the same type. Found inputs of type '{}' and '{}'",
163                Self::opcode(),
164                input_types[0],
165                input_types[1]
166            )
167        }
168        // Ensure the number of operands is correct.
169        if self.operands.len() != 2 {
170            bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
171        }
172
173        match VARIANT {
174            0 | 1 => Ok(vec![]),
175            _ => bail!("Invalid 'assert' variant: {VARIANT}"),
176        }
177    }
178}
179
180impl<N: Network, const VARIANT: u8> Parser for AssertInstruction<N, VARIANT> {
181    /// Parses a string into an operation.
182    #[inline]
183    fn parse(string: &str) -> ParserResult<Self> {
184        // Parse the opcode from the string.
185        let (string, _) = tag(*Self::opcode())(string)?;
186        // Parse the whitespace from the string.
187        let (string, _) = Sanitizer::parse_whitespaces(string)?;
188        // Parse the first operand from the string.
189        let (string, first) = Operand::parse(string)?;
190        // Parse the whitespace from the string.
191        let (string, _) = Sanitizer::parse_whitespaces(string)?;
192        // Parse the second operand from the string.
193        let (string, second) = Operand::parse(string)?;
194
195        Ok((string, Self { operands: vec![first, second] }))
196    }
197}
198
199impl<N: Network, const VARIANT: u8> FromStr for AssertInstruction<N, VARIANT> {
200    type Err = Error;
201
202    /// Parses a string into an operation.
203    #[inline]
204    fn from_str(string: &str) -> Result<Self> {
205        match Self::parse(string) {
206            Ok((remainder, object)) => {
207                // Ensure the remainder is empty.
208                ensure!(remainder.is_empty(), "Failed to parse string. Found invalid character in: \"{remainder}\"");
209                // Return the object.
210                Ok(object)
211            }
212            Err(error) => bail!("Failed to parse string. {error}"),
213        }
214    }
215}
216
217impl<N: Network, const VARIANT: u8> Debug for AssertInstruction<N, VARIANT> {
218    /// Prints the operation as a string.
219    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
220        Display::fmt(self, f)
221    }
222}
223
224impl<N: Network, const VARIANT: u8> Display for AssertInstruction<N, VARIANT> {
225    /// Prints the operation to a string.
226    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
227        // Ensure the number of operands is 2.
228        if self.operands.len() != 2 {
229            return Err(fmt::Error);
230        }
231        // Print the operation.
232        write!(f, "{} ", Self::opcode())?;
233        self.operands.iter().try_for_each(|operand| write!(f, "{operand} "))
234    }
235}
236
237impl<N: Network, const VARIANT: u8> FromBytes for AssertInstruction<N, VARIANT> {
238    /// Reads the operation from a buffer.
239    fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
240        // Initialize the vector for the operands.
241        let mut operands = Vec::with_capacity(2);
242        // Read the operands.
243        for _ in 0..2 {
244            operands.push(Operand::read_le(&mut reader)?);
245        }
246
247        // Return the operation.
248        Ok(Self { operands })
249    }
250}
251
252impl<N: Network, const VARIANT: u8> ToBytes for AssertInstruction<N, VARIANT> {
253    /// Writes the operation to a buffer.
254    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
255        // Ensure the number of operands is 2.
256        if self.operands.len() != 2 {
257            return Err(error(format!("The number of operands must be 2, found {}", self.operands.len())));
258        }
259        // Write the operands.
260        self.operands.iter().try_for_each(|operand| operand.write_le(&mut writer))
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use console::network::MainnetV0;

    type CurrentNetwork = MainnetV0;

    #[test]
    fn test_parse() {
        let (string, assert) = AssertEq::<CurrentNetwork>::parse("assert.eq r0 r1").unwrap();
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");
        assert_eq!(assert.operands.len(), 2, "The number of operands is incorrect");
        assert_eq!(assert.operands[0], Operand::Register(Register::Locator(0)), "The first operand is incorrect");
        assert_eq!(assert.operands[1], Operand::Register(Register::Locator(1)), "The second operand is incorrect");

        let (string, assert) = AssertNeq::<CurrentNetwork>::parse("assert.neq r0 r1").unwrap();
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");
        assert_eq!(assert.operands.len(), 2, "The number of operands is incorrect");
        assert_eq!(assert.operands[0], Operand::Register(Register::Locator(0)), "The first operand is incorrect");
        assert_eq!(assert.operands[1], Operand::Register(Register::Locator(1)), "The second operand is incorrect");
    }

    /// Each variant's parser must only accept its own opcode.
    #[test]
    fn test_parse_rejects_other_variant() {
        assert!(AssertEq::<CurrentNetwork>::parse("assert.neq r0 r1").is_err());
        assert!(AssertNeq::<CurrentNetwork>::parse("assert.eq r0 r1").is_err());
    }

    /// `new` must accept exactly two operands and reject every other count.
    #[test]
    fn test_new_requires_two_operands() {
        let r0 = Operand::<CurrentNetwork>::Register(Register::Locator(0));
        let r1 = Operand::<CurrentNetwork>::Register(Register::Locator(1));

        assert!(AssertEq::<CurrentNetwork>::new(vec![]).is_err());
        assert!(AssertEq::<CurrentNetwork>::new(vec![r0.clone()]).is_err());
        assert!(AssertEq::<CurrentNetwork>::new(vec![r0.clone(), r1.clone(), r0.clone()]).is_err());
        assert!(AssertEq::<CurrentNetwork>::new(vec![r0, r1]).is_ok());
    }

    /// Serializing to bytes and reading back must reproduce the same instruction.
    #[test]
    fn test_bytes_roundtrip() {
        let (_, expected) = AssertNeq::<CurrentNetwork>::parse("assert.neq r0 r1").unwrap();
        let bytes = expected.to_bytes_le().unwrap();
        let candidate = AssertNeq::<CurrentNetwork>::from_bytes_le(&bytes).unwrap();
        assert_eq!(expected, candidate);
    }
}