snarkvm_synthesizer_program/logic/instruction/operation/
is.rs

1// Copyright 2024-2025 Aleo Network Foundation
2// This file is part of the snarkVM library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use crate::{
17    Opcode,
18    Operand,
19    traits::{RegistersLoad, RegistersLoadCircuit, RegistersStore, RegistersStoreCircuit, StackMatches, StackProgram},
20};
21use console::{
22    network::prelude::*,
23    program::{Literal, LiteralType, Plaintext, PlaintextType, Register, RegisterType, Value},
24    types::Boolean,
25};
26
27/// Computes whether `first` equals `second` as a boolean, storing the outcome in `destination`.
28pub type IsEq<N> = IsInstruction<N, { Variant::IsEq as u8 }>;
29/// Computes whether `first` does **not** equals `second` as a boolean, storing the outcome in `destination`.
30pub type IsNeq<N> = IsInstruction<N, { Variant::IsNeq as u8 }>;
31
/// The comparison selected by the `VARIANT` const parameter of `IsInstruction`.
///
/// Each discriminant, cast to `u8`, is the `VARIANT` value used by the matching alias.
enum Variant {
    /// Equality comparison (`is.eq`).
    IsEq = 0,
    /// Inequality comparison (`is.neq`).
    IsNeq = 1,
}
36
37/// Computes an equality operation on two operands, and stores the outcome in `destination`.
38#[derive(Clone, PartialEq, Eq, Hash)]
39pub struct IsInstruction<N: Network, const VARIANT: u8> {
40    /// The operands.
41    operands: Vec<Operand<N>>,
42    /// The destination register.
43    destination: Register<N>,
44}
45
46impl<N: Network, const VARIANT: u8> IsInstruction<N, VARIANT> {
47    /// Initializes a new `is` instruction.
48    #[inline]
49    pub fn new(operands: Vec<Operand<N>>, destination: Register<N>) -> Result<Self> {
50        // Sanity check the number of operands.
51        ensure!(operands.len() == 2, "Instruction '{}' must have two operands", Self::opcode());
52        // Return the instruction.
53        Ok(Self { operands, destination })
54    }
55
56    /// Returns the opcode.
57    #[inline]
58    pub const fn opcode() -> Opcode {
59        match VARIANT {
60            0 => Opcode::Is("is.eq"),
61            1 => Opcode::Is("is.neq"),
62            _ => panic!("Invalid 'is' instruction opcode"),
63        }
64    }
65
66    /// Returns the operands in the operation.
67    #[inline]
68    pub fn operands(&self) -> &[Operand<N>] {
69        // Sanity check that the operands is exactly two inputs.
70        debug_assert!(self.operands.len() == 2, "Instruction '{}' must have two operands", Self::opcode());
71        // Return the operands.
72        &self.operands
73    }
74
75    /// Returns the destination register.
76    #[inline]
77    pub fn destinations(&self) -> Vec<Register<N>> {
78        vec![self.destination.clone()]
79    }
80}
81
82impl<N: Network, const VARIANT: u8> IsInstruction<N, VARIANT> {
83    /// Evaluates the instruction.
84    #[inline]
85    pub fn evaluate(
86        &self,
87        stack: &(impl StackMatches<N> + StackProgram<N>),
88        registers: &mut (impl RegistersLoad<N> + RegistersStore<N>),
89    ) -> Result<()> {
90        // Ensure the number of operands is correct.
91        if self.operands.len() != 2 {
92            bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
93        }
94
95        // Retrieve the inputs.
96        let input_a = registers.load(stack, &self.operands[0])?;
97        let input_b = registers.load(stack, &self.operands[1])?;
98
99        // Check the inputs.
100        let output = match VARIANT {
101            0 => Literal::Boolean(Boolean::new(input_a == input_b)),
102            1 => Literal::Boolean(Boolean::new(input_a != input_b)),
103            _ => bail!("Invalid 'is' variant: {VARIANT}"),
104        };
105        // Store the output.
106        registers.store(stack, &self.destination, Value::Plaintext(Plaintext::from(output)))
107    }
108
109    /// Executes the instruction.
110    #[inline]
111    pub fn execute<A: circuit::Aleo<Network = N>>(
112        &self,
113        stack: &(impl StackMatches<N> + StackProgram<N>),
114        registers: &mut (impl RegistersLoadCircuit<N, A> + RegistersStoreCircuit<N, A>),
115    ) -> Result<()> {
116        // Ensure the number of operands is correct.
117        if self.operands.len() != 2 {
118            bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
119        }
120
121        // Retrieve the inputs.
122        let input_a = registers.load_circuit(stack, &self.operands[0])?;
123        let input_b = registers.load_circuit(stack, &self.operands[1])?;
124
125        // Check the inputs.
126        let output = match VARIANT {
127            0 => circuit::Literal::Boolean(input_a.is_equal(&input_b)),
128            1 => circuit::Literal::Boolean(input_a.is_not_equal(&input_b)),
129            _ => bail!("Invalid 'is' variant: {VARIANT}"),
130        };
131        // Convert the output to a stack value.
132        let output = circuit::Value::Plaintext(circuit::Plaintext::Literal(output, Default::default()));
133        // Store the output.
134        registers.store_circuit(stack, &self.destination, output)
135    }
136
137    /// Finalizes the instruction.
138    #[inline]
139    pub fn finalize(
140        &self,
141        stack: &(impl StackMatches<N> + StackProgram<N>),
142        registers: &mut (impl RegistersLoad<N> + RegistersStore<N>),
143    ) -> Result<()> {
144        self.evaluate(stack, registers)
145    }
146
147    /// Returns the output type from the given program and input types.
148    #[inline]
149    pub fn output_types(
150        &self,
151        _stack: &impl StackProgram<N>,
152        input_types: &[RegisterType<N>],
153    ) -> Result<Vec<RegisterType<N>>> {
154        // Ensure the number of input types is correct.
155        if input_types.len() != 2 {
156            bail!("Instruction '{}' expects 2 inputs, found {} inputs", Self::opcode(), input_types.len())
157        }
158        // Ensure the operands are of the same type.
159        if input_types[0] != input_types[1] {
160            bail!(
161                "Instruction '{}' expects inputs of the same type. Found inputs of type '{}' and '{}'",
162                Self::opcode(),
163                input_types[0],
164                input_types[1]
165            )
166        }
167        // Ensure the number of operands is correct.
168        if self.operands.len() != 2 {
169            bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
170        }
171
172        match VARIANT {
173            0 | 1 => Ok(vec![RegisterType::Plaintext(PlaintextType::Literal(LiteralType::Boolean))]),
174            _ => bail!("Invalid 'is' variant: {VARIANT}"),
175        }
176    }
177}
178
179impl<N: Network, const VARIANT: u8> Parser for IsInstruction<N, VARIANT> {
180    /// Parses a string into an operation.
181    #[inline]
182    fn parse(string: &str) -> ParserResult<Self> {
183        // Parse the opcode from the string.
184        let (string, _) = tag(*Self::opcode())(string)?;
185        // Parse the whitespace from the string.
186        let (string, _) = Sanitizer::parse_whitespaces(string)?;
187        // Parse the first operand from the string.
188        let (string, first) = Operand::parse(string)?;
189        // Parse the whitespace from the string.
190        let (string, _) = Sanitizer::parse_whitespaces(string)?;
191        // Parse the second operand from the string.
192        let (string, second) = Operand::parse(string)?;
193        // Parse the whitespace from the string.
194        let (string, _) = Sanitizer::parse_whitespaces(string)?;
195        // Parse the "into" from the string.
196        let (string, _) = tag("into")(string)?;
197        // Parse the whitespace from the string.
198        let (string, _) = Sanitizer::parse_whitespaces(string)?;
199        // Parse the destination register from the string.
200        let (string, destination) = Register::parse(string)?;
201
202        Ok((string, Self { operands: vec![first, second], destination }))
203    }
204}
205
206impl<N: Network, const VARIANT: u8> FromStr for IsInstruction<N, VARIANT> {
207    type Err = Error;
208
209    /// Parses a string into an operation.
210    #[inline]
211    fn from_str(string: &str) -> Result<Self> {
212        match Self::parse(string) {
213            Ok((remainder, object)) => {
214                // Ensure the remainder is empty.
215                ensure!(remainder.is_empty(), "Failed to parse string. Found invalid character in: \"{remainder}\"");
216                // Return the object.
217                Ok(object)
218            }
219            Err(error) => bail!("Failed to parse string. {error}"),
220        }
221    }
222}
223
224impl<N: Network, const VARIANT: u8> Debug for IsInstruction<N, VARIANT> {
225    /// Prints the operation as a string.
226    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
227        Display::fmt(self, f)
228    }
229}
230
231impl<N: Network, const VARIANT: u8> Display for IsInstruction<N, VARIANT> {
232    /// Prints the operation to a string.
233    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
234        // Ensure the number of operands is 2.
235        if self.operands.len() != 2 {
236            return Err(fmt::Error);
237        }
238        // Print the operation.
239        write!(f, "{} ", Self::opcode())?;
240        self.operands.iter().try_for_each(|operand| write!(f, "{operand} "))?;
241        write!(f, "into {}", self.destination)
242    }
243}
244
245impl<N: Network, const VARIANT: u8> FromBytes for IsInstruction<N, VARIANT> {
246    /// Reads the operation from a buffer.
247    fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
248        // Initialize the vector for the operands.
249        let mut operands = Vec::with_capacity(2);
250        // Read the operands.
251        for _ in 0..2 {
252            operands.push(Operand::read_le(&mut reader)?);
253        }
254        // Read the destination register.
255        let destination = Register::read_le(&mut reader)?;
256
257        // Return the operation.
258        Ok(Self { operands, destination })
259    }
260}
261
262impl<N: Network, const VARIANT: u8> ToBytes for IsInstruction<N, VARIANT> {
263    /// Writes the operation to a buffer.
264    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
265        // Ensure the number of operands is 2.
266        if self.operands.len() != 2 {
267            return Err(error(format!("The number of operands must be 2, found {}", self.operands.len())));
268        }
269        // Write the operands.
270        self.operands.iter().try_for_each(|operand| operand.write_le(&mut writer))?;
271        // Write the destination register.
272        self.destination.write_le(&mut writer)
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use console::network::MainnetV0;

    type CurrentNetwork = MainnetV0;

    #[test]
    fn test_parse() {
        let (string, is) = IsEq::<CurrentNetwork>::parse("is.eq r0 r1 into r2").unwrap();
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");
        assert_eq!(is.operands.len(), 2, "The number of operands is incorrect");
        assert_eq!(is.operands[0], Operand::Register(Register::Locator(0)), "The first operand is incorrect");
        assert_eq!(is.operands[1], Operand::Register(Register::Locator(1)), "The second operand is incorrect");
        assert_eq!(is.destination, Register::Locator(2), "The destination register is incorrect");

        let (string, is) = IsNeq::<CurrentNetwork>::parse("is.neq r0 r1 into r2").unwrap();
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");
        assert_eq!(is.operands.len(), 2, "The number of operands is incorrect");
        assert_eq!(is.operands[0], Operand::Register(Register::Locator(0)), "The first operand is incorrect");
        assert_eq!(is.operands[1], Operand::Register(Register::Locator(1)), "The second operand is incorrect");
        assert_eq!(is.destination, Register::Locator(2), "The destination register is incorrect");
    }

    #[test]
    fn test_parse_rejects_mismatched_opcode() {
        // Each alias only accepts its own opcode.
        assert!(IsEq::<CurrentNetwork>::parse("is.neq r0 r1 into r2").is_err());
        assert!(IsNeq::<CurrentNetwork>::parse("is.eq r0 r1 into r2").is_err());
    }

    #[test]
    fn test_from_str_rejects_trailing_input() {
        assert!(IsEq::<CurrentNetwork>::from_str("is.eq r0 r1 into r2 r3").is_err());
    }

    #[test]
    fn test_new_requires_two_operands() {
        let operand = Operand::Register(Register::Locator(0));
        assert!(IsEq::<CurrentNetwork>::new(vec![operand.clone()], Register::Locator(1)).is_err());
        assert!(
            IsEq::<CurrentNetwork>::new(vec![operand.clone(), operand.clone(), operand.clone()], Register::Locator(1))
                .is_err()
        );
        assert!(IsEq::<CurrentNetwork>::new(vec![operand.clone(), operand], Register::Locator(1)).is_ok());
    }

    #[test]
    fn test_display_round_trip() {
        let is_eq = IsEq::<CurrentNetwork>::from_str("is.eq r0 r1 into r2").unwrap();
        assert_eq!(is_eq.to_string(), "is.eq r0 r1 into r2");

        let is_neq = IsNeq::<CurrentNetwork>::from_str("is.neq r0 r1 into r2").unwrap();
        assert_eq!(is_neq.to_string(), "is.neq r0 r1 into r2");
    }

    #[test]
    fn test_bytes_round_trip() {
        let is_eq = IsEq::<CurrentNetwork>::from_str("is.eq r0 r1 into r2").unwrap();
        let bytes = is_eq.to_bytes_le().unwrap();
        assert_eq!(IsEq::<CurrentNetwork>::from_bytes_le(&bytes).unwrap(), is_eq);

        let is_neq = IsNeq::<CurrentNetwork>::from_str("is.neq r0 r1 into r2").unwrap();
        let bytes = is_neq.to_bytes_le().unwrap();
        assert_eq!(IsNeq::<CurrentNetwork>::from_bytes_le(&bytes).unwrap(), is_neq);
    }
}