snarkvm_synthesizer_program/logic/instruction/operation/
assert.rs1use crate::{Opcode, Operand, RegistersCircuit, RegistersTrait, StackTrait};
17use console::{
18 network::prelude::*,
19 program::{Register, RegisterType},
20};
21
22pub type AssertEq<N> = AssertInstruction<N, { Variant::AssertEq as u8 }>;
24pub type AssertNeq<N> = AssertInstruction<N, { Variant::AssertNeq as u8 }>;
26
27enum Variant {
28 AssertEq,
29 AssertNeq,
30}
31
32#[derive(Clone, PartialEq, Eq, Hash)]
34pub struct AssertInstruction<N: Network, const VARIANT: u8> {
35 operands: Vec<Operand<N>>,
37}
38
39impl<N: Network, const VARIANT: u8> AssertInstruction<N, VARIANT> {
40 #[inline]
42 pub fn new(operands: Vec<Operand<N>>) -> Result<Self> {
43 ensure!(operands.len() == 2, "Assert instructions must have two operands");
45 Ok(Self { operands })
47 }
48
49 #[inline]
51 pub const fn opcode() -> Opcode {
52 match VARIANT {
53 0 => Opcode::Assert("assert.eq"),
54 1 => Opcode::Assert("assert.neq"),
55 _ => panic!("Invalid 'assert' instruction opcode"),
56 }
57 }
58
59 #[inline]
61 pub fn operands(&self) -> &[Operand<N>] {
62 debug_assert!(self.operands.len() == 2, "Assert operations must have two operands");
64 &self.operands
66 }
67
68 #[inline]
70 pub fn destinations(&self) -> Vec<Register<N>> {
71 vec![]
72 }
73}
74
75impl<N: Network, const VARIANT: u8> AssertInstruction<N, VARIANT> {
76 pub fn evaluate(&self, stack: &impl StackTrait<N>, registers: &mut impl RegistersTrait<N>) -> Result<()> {
78 if self.operands.len() != 2 {
80 bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
81 }
82
83 let input_a = registers.load(stack, &self.operands[0])?;
85 let input_b = registers.load(stack, &self.operands[1])?;
86
87 match VARIANT {
89 0 => {
90 if input_a != input_b {
91 bail!("'{}' failed: '{input_a}' is not equal to '{input_b}' (should be equal)", Self::opcode())
92 }
93 }
94 1 => {
95 if input_a == input_b {
96 bail!("'{}' failed: '{input_a}' is equal to '{input_b}' (should not be equal)", Self::opcode())
97 }
98 }
99 _ => bail!("Invalid 'assert' variant: {VARIANT}"),
100 }
101 Ok(())
102 }
103
104 pub fn execute<A: circuit::Aleo<Network = N>>(
106 &self,
107 stack: &impl StackTrait<N>,
108 registers: &mut impl RegistersCircuit<N, A>,
109 ) -> Result<()> {
110 if self.operands.len() != 2 {
112 bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
113 }
114
115 let input_a = registers.load_circuit(stack, &self.operands[0])?;
117 let input_b = registers.load_circuit(stack, &self.operands[1])?;
118
119 match VARIANT {
121 0 => A::assert(input_a.is_equal(&input_b)),
122 1 => A::assert(input_a.is_not_equal(&input_b)),
123 _ => bail!("Invalid 'assert' variant: {VARIANT}"),
124 }
125 Ok(())
126 }
127
128 #[inline]
130 pub fn finalize(&self, stack: &impl StackTrait<N>, registers: &mut impl RegistersTrait<N>) -> Result<()> {
131 self.evaluate(stack, registers)
132 }
133
134 pub fn output_types(
136 &self,
137 _stack: &impl StackTrait<N>,
138 input_types: &[RegisterType<N>],
139 ) -> Result<Vec<RegisterType<N>>> {
140 if input_types.len() != 2 {
142 bail!("Instruction '{}' expects 2 inputs, found {} inputs", Self::opcode(), input_types.len())
143 }
144 if input_types[0] != input_types[1] {
146 bail!(
147 "Instruction '{}' expects inputs of the same type. Found inputs of type '{}' and '{}'",
148 Self::opcode(),
149 input_types[0],
150 input_types[1]
151 )
152 }
153 if self.operands.len() != 2 {
155 bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
156 }
157
158 match VARIANT {
159 0 | 1 => Ok(vec![]),
160 _ => bail!("Invalid 'assert' variant: {VARIANT}"),
161 }
162 }
163}
164
165impl<N: Network, const VARIANT: u8> Parser for AssertInstruction<N, VARIANT> {
166 fn parse(string: &str) -> ParserResult<Self> {
168 let (string, _) = tag(*Self::opcode())(string)?;
170 let (string, _) = Sanitizer::parse_whitespaces(string)?;
172 let (string, first) = Operand::parse(string)?;
174 let (string, _) = Sanitizer::parse_whitespaces(string)?;
176 let (string, second) = Operand::parse(string)?;
178
179 Ok((string, Self { operands: vec![first, second] }))
180 }
181}
182
183impl<N: Network, const VARIANT: u8> FromStr for AssertInstruction<N, VARIANT> {
184 type Err = Error;
185
186 fn from_str(string: &str) -> Result<Self> {
188 match Self::parse(string) {
189 Ok((remainder, object)) => {
190 ensure!(remainder.is_empty(), "Failed to parse string. Found invalid character in: \"{remainder}\"");
192 Ok(object)
194 }
195 Err(error) => bail!("Failed to parse string. {error}"),
196 }
197 }
198}
199
200impl<N: Network, const VARIANT: u8> Debug for AssertInstruction<N, VARIANT> {
201 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
203 Display::fmt(self, f)
204 }
205}
206
207impl<N: Network, const VARIANT: u8> Display for AssertInstruction<N, VARIANT> {
208 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
210 if self.operands.len() != 2 {
212 return Err(fmt::Error);
213 }
214 write!(f, "{}", Self::opcode())?;
216 self.operands.iter().try_for_each(|operand| write!(f, " {operand}"))
217 }
218}
219
220impl<N: Network, const VARIANT: u8> FromBytes for AssertInstruction<N, VARIANT> {
221 fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
223 let mut operands = Vec::with_capacity(2);
225 for _ in 0..2 {
227 operands.push(Operand::read_le(&mut reader)?);
228 }
229
230 Ok(Self { operands })
232 }
233}
234
235impl<N: Network, const VARIANT: u8> ToBytes for AssertInstruction<N, VARIANT> {
236 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
238 if self.operands.len() != 2 {
240 return Err(error(format!("The number of operands must be 2, found {}", self.operands.len())));
241 }
242 self.operands.iter().try_for_each(|operand| operand.write_le(&mut writer))
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use console::network::MainnetV0;

    type CurrentNetwork = MainnetV0;

    /// Builds the operand list `[r0, r1]` shared by several tests.
    fn operands_r0_r1() -> Vec<Operand<CurrentNetwork>> {
        vec![Operand::Register(Register::Locator(0)), Operand::Register(Register::Locator(1))]
    }

    #[test]
    fn test_parse() {
        let (string, assert) = AssertEq::<CurrentNetwork>::parse("assert.eq r0 r1").unwrap();
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");
        assert_eq!(assert.operands.len(), 2, "The number of operands is incorrect");
        assert_eq!(assert.operands[0], Operand::Register(Register::Locator(0)), "The first operand is incorrect");
        assert_eq!(assert.operands[1], Operand::Register(Register::Locator(1)), "The second operand is incorrect");

        let (string, assert) = AssertNeq::<CurrentNetwork>::parse("assert.neq r0 r1").unwrap();
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");
        assert_eq!(assert.operands.len(), 2, "The number of operands is incorrect");
        assert_eq!(assert.operands[0], Operand::Register(Register::Locator(0)), "The first operand is incorrect");
        assert_eq!(assert.operands[1], Operand::Register(Register::Locator(1)), "The second operand is incorrect");
    }

    #[test]
    fn test_new_requires_two_operands() {
        assert!(AssertEq::<CurrentNetwork>::new(vec![]).is_err());
        assert!(AssertEq::<CurrentNetwork>::new(vec![Operand::Register(Register::Locator(0))]).is_err());

        let mut three = operands_r0_r1();
        three.push(Operand::Register(Register::Locator(2)));
        assert!(AssertNeq::<CurrentNetwork>::new(three).is_err());

        assert!(AssertEq::<CurrentNetwork>::new(operands_r0_r1()).is_ok());
        assert!(AssertNeq::<CurrentNetwork>::new(operands_r0_r1()).is_ok());
    }

    #[test]
    fn test_display_round_trip() {
        let eq_instruction = AssertEq::<CurrentNetwork>::from_str("assert.eq r0 r1").unwrap();
        assert_eq!(eq_instruction.to_string(), "assert.eq r0 r1");

        let neq_instruction = AssertNeq::<CurrentNetwork>::from_str("assert.neq r0 r1").unwrap();
        assert_eq!(neq_instruction.to_string(), "assert.neq r0 r1");
    }

    #[test]
    fn test_from_str_rejects_invalid_input() {
        // Trailing input after the second operand must not be silently ignored.
        assert!(AssertEq::<CurrentNetwork>::from_str("assert.eq r0 r1 r2").is_err());
        // A second operand is required.
        assert!(AssertEq::<CurrentNetwork>::from_str("assert.eq r0").is_err());
        // Each alias only accepts its own opcode.
        assert!(AssertEq::<CurrentNetwork>::from_str("assert.neq r0 r1").is_err());
        assert!(AssertNeq::<CurrentNetwork>::from_str("assert.eq r0 r1").is_err());
    }

    #[test]
    fn test_bytes_round_trip() {
        let expected = AssertEq::<CurrentNetwork>::new(operands_r0_r1()).unwrap();
        let bytes = expected.to_bytes_le().unwrap();
        let candidate = AssertEq::<CurrentNetwork>::from_bytes_le(&bytes).unwrap();
        assert_eq!(expected, candidate);

        let expected = AssertNeq::<CurrentNetwork>::new(operands_r0_r1()).unwrap();
        let bytes = expected.to_bytes_le().unwrap();
        let candidate = AssertNeq::<CurrentNetwork>::from_bytes_le(&bytes).unwrap();
        assert_eq!(expected, candidate);
    }
}