// snarkvm_synthesizer_program/logic/instruction/operation/assert.rs
use crate::{Opcode, Operand, RegistersCircuit, RegistersTrait, StackTrait, register_types_equivalent};
17use console::{
18 network::prelude::*,
19 program::{Register, RegisterType},
20};
21
22pub type AssertEq<N> = AssertInstruction<N, { Variant::AssertEq as u8 }>;
24pub type AssertNeq<N> = AssertInstruction<N, { Variant::AssertNeq as u8 }>;
26
/// Selects which assertion an `AssertInstruction` performs.
///
/// Each variant is cast to `u8` to form the `VARIANT` const generic of the
/// instruction, and the instruction's methods then match that value against
/// the literals `0` and `1`. The discriminants are therefore written out
/// explicitly: reordering the variants or inserting a new one must never
/// silently change which opcode a given number stands for.
#[repr(u8)]
enum Variant {
    /// Selects `assert.eq` (matched as `0`).
    AssertEq = 0,
    /// Selects `assert.neq` (matched as `1`).
    AssertNeq = 1,
}
31
32#[derive(Clone, PartialEq, Eq, Hash)]
34pub struct AssertInstruction<N: Network, const VARIANT: u8> {
35 operands: Vec<Operand<N>>,
37}
38
39impl<N: Network, const VARIANT: u8> AssertInstruction<N, VARIANT> {
40 #[inline]
42 pub fn new(operands: Vec<Operand<N>>) -> Result<Self> {
43 ensure!(operands.len() == 2, "Assert instructions must have two operands");
45 Ok(Self { operands })
47 }
48
49 #[inline]
51 pub const fn opcode() -> Opcode {
52 match VARIANT {
53 0 => Opcode::Assert("assert.eq"),
54 1 => Opcode::Assert("assert.neq"),
55 _ => panic!("Invalid 'assert' instruction opcode"),
56 }
57 }
58
59 #[inline]
61 pub fn operands(&self) -> &[Operand<N>] {
62 debug_assert!(self.operands.len() == 2, "Assert operations must have two operands");
64 &self.operands
66 }
67
68 #[inline]
70 pub fn destinations(&self) -> Vec<Register<N>> {
71 vec![]
72 }
73
74 #[inline]
76 pub fn contains_external_struct(&self) -> bool {
77 false
78 }
79}
80
81impl<N: Network, const VARIANT: u8> AssertInstruction<N, VARIANT> {
82 pub fn evaluate(&self, stack: &impl StackTrait<N>, registers: &mut impl RegistersTrait<N>) -> Result<()> {
84 if self.operands.len() != 2 {
86 bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
87 }
88
89 let input_a = registers.load(stack, &self.operands[0])?;
91 let input_b = registers.load(stack, &self.operands[1])?;
92
93 match VARIANT {
95 0 => {
96 if input_a != input_b {
97 bail!("'{}' failed: '{input_a}' is not equal to '{input_b}' (should be equal)", Self::opcode())
98 }
99 }
100 1 => {
101 if input_a == input_b {
102 bail!("'{}' failed: '{input_a}' is equal to '{input_b}' (should not be equal)", Self::opcode())
103 }
104 }
105 _ => bail!("Invalid 'assert' variant: {VARIANT}"),
106 }
107 Ok(())
108 }
109
110 pub fn execute<A: circuit::Aleo<Network = N>>(
112 &self,
113 stack: &impl StackTrait<N>,
114 registers: &mut impl RegistersCircuit<N, A>,
115 ) -> Result<()> {
116 if self.operands.len() != 2 {
118 bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
119 }
120
121 let input_a = registers.load_circuit(stack, &self.operands[0])?;
123 let input_b = registers.load_circuit(stack, &self.operands[1])?;
124
125 match VARIANT {
127 0 => A::assert(input_a.is_equal(&input_b)),
128 1 => A::assert(input_a.is_not_equal(&input_b)),
129 _ => bail!("Invalid 'assert' variant: {VARIANT}"),
130 }
131 Ok(())
132 }
133
134 #[inline]
136 pub fn finalize(&self, stack: &impl StackTrait<N>, registers: &mut impl RegistersTrait<N>) -> Result<()> {
137 self.evaluate(stack, registers)
138 }
139
140 pub fn output_types(
142 &self,
143 stack: &impl StackTrait<N>,
144 input_types: &[RegisterType<N>],
145 ) -> Result<Vec<RegisterType<N>>> {
146 if input_types.len() != 2 {
148 bail!("Instruction '{}' expects 2 inputs, found {} inputs", Self::opcode(), input_types.len())
149 }
150 if !register_types_equivalent(stack, &input_types[0], stack, &input_types[1])? {
152 bail!(
153 "Instruction '{}' expects inputs of equivalent types. Found inputs of type '{}' and '{}'",
154 Self::opcode(),
155 input_types[0],
156 input_types[1]
157 )
158 }
159 if self.operands.len() != 2 {
161 bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
162 }
163
164 match VARIANT {
165 0 | 1 => Ok(vec![]),
166 _ => bail!("Invalid 'assert' variant: {VARIANT}"),
167 }
168 }
169}
170
171impl<N: Network, const VARIANT: u8> Parser for AssertInstruction<N, VARIANT> {
172 fn parse(string: &str) -> ParserResult<Self> {
174 let (string, _) = tag(*Self::opcode())(string)?;
176 let (string, _) = Sanitizer::parse_whitespaces(string)?;
178 let (string, first) = Operand::parse(string)?;
180 let (string, _) = Sanitizer::parse_whitespaces(string)?;
182 let (string, second) = Operand::parse(string)?;
184
185 Ok((string, Self { operands: vec![first, second] }))
186 }
187}
188
189impl<N: Network, const VARIANT: u8> FromStr for AssertInstruction<N, VARIANT> {
190 type Err = Error;
191
192 fn from_str(string: &str) -> Result<Self> {
194 match Self::parse(string) {
195 Ok((remainder, object)) => {
196 ensure!(remainder.is_empty(), "Failed to parse string. Found invalid character in: \"{remainder}\"");
198 Ok(object)
200 }
201 Err(error) => bail!("Failed to parse string. {error}"),
202 }
203 }
204}
205
206impl<N: Network, const VARIANT: u8> Debug for AssertInstruction<N, VARIANT> {
207 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
209 Display::fmt(self, f)
210 }
211}
212
213impl<N: Network, const VARIANT: u8> Display for AssertInstruction<N, VARIANT> {
214 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
216 if self.operands.len() != 2 {
218 return Err(fmt::Error);
219 }
220 write!(f, "{}", Self::opcode())?;
222 self.operands.iter().try_for_each(|operand| write!(f, " {operand}"))
223 }
224}
225
226impl<N: Network, const VARIANT: u8> FromBytes for AssertInstruction<N, VARIANT> {
227 fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
229 let mut operands = Vec::with_capacity(2);
231 for _ in 0..2 {
233 operands.push(Operand::read_le(&mut reader)?);
234 }
235
236 Ok(Self { operands })
238 }
239}
240
241impl<N: Network, const VARIANT: u8> ToBytes for AssertInstruction<N, VARIANT> {
242 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
244 if self.operands.len() != 2 {
246 return Err(error(format!("The number of operands must be 2, found {}", self.operands.len())));
247 }
248 self.operands.iter().try_for_each(|operand| operand.write_le(&mut writer))
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use console::network::MainnetV0;

    type CurrentNetwork = MainnetV0;

    /// Checks that a parse result consumed all input and holds the operands `r0` and `r1`.
    fn check_parsed<const VARIANT: u8>(parsed: (&str, AssertInstruction<CurrentNetwork, VARIANT>)) {
        let (string, assert) = parsed;
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");
        assert_eq!(assert.operands.len(), 2, "The number of operands is incorrect");
        assert_eq!(assert.operands[0], Operand::Register(Register::Locator(0)), "The first operand is incorrect");
        assert_eq!(assert.operands[1], Operand::Register(Register::Locator(1)), "The second operand is incorrect");
    }

    #[test]
    fn test_parse() {
        // Both aliases go through the same checks; only the opcode text differs.
        check_parsed(AssertEq::<CurrentNetwork>::parse("assert.eq r0 r1").unwrap());
        check_parsed(AssertNeq::<CurrentNetwork>::parse("assert.neq r0 r1").unwrap());
    }
}