snarkvm_synthesizer_program/logic/instruction/operation/
assert.rs1use crate::{Opcode, Operand, RegistersCircuit, RegistersTrait, StackTrait, register_types_equivalent};
17use console::{
18 network::prelude::*,
19 program::{Register, RegisterType},
20};
21use snarkvm_synthesizer_error::*;
22
23pub type AssertEq<N> = AssertInstruction<N, { Variant::AssertEq as u8 }>;
25pub type AssertNeq<N> = AssertInstruction<N, { Variant::AssertNeq as u8 }>;
27
/// Selects which assertion an `AssertInstruction` performs.
///
/// Each variant is cast to `u8` and used as the const generic parameter of
/// `AssertInstruction`. The instruction's methods match on the raw values `0` and `1`,
/// so the discriminants are spelled out here and must not change.
enum Variant {
    /// Discriminant `0`: backs the `AssertEq` alias.
    AssertEq = 0,
    /// Discriminant `1`: backs the `AssertNeq` alias.
    AssertNeq = 1,
}
32
33#[derive(Clone, PartialEq, Eq, Hash)]
35pub struct AssertInstruction<N: Network, const VARIANT: u8> {
36 operands: Vec<Operand<N>>,
38}
39
40impl<N: Network, const VARIANT: u8> AssertInstruction<N, VARIANT> {
41 #[inline]
43 pub fn new(operands: Vec<Operand<N>>) -> Result<Self> {
44 ensure!(operands.len() == 2, "Assert instructions must have two operands");
46 Ok(Self { operands })
48 }
49
50 #[inline]
52 pub const fn opcode() -> Opcode {
53 match VARIANT {
54 0 => Opcode::Assert("assert.eq"),
55 1 => Opcode::Assert("assert.neq"),
56 _ => panic!("Invalid 'assert' instruction opcode"),
57 }
58 }
59
60 #[inline]
62 pub fn operands(&self) -> &[Operand<N>] {
63 debug_assert!(self.operands.len() == 2, "Assert operations must have two operands");
65 &self.operands
67 }
68
69 #[inline]
71 pub fn destinations(&self) -> Vec<Register<N>> {
72 vec![]
73 }
74
75 #[inline]
77 pub fn contains_external_struct(&self) -> bool {
78 false
79 }
80}
81
82impl<N: Network, const VARIANT: u8> AssertInstruction<N, VARIANT> {
83 pub fn evaluate(
85 &self,
86 stack: &impl StackTrait<N>,
87 registers: &mut impl RegistersTrait<N>,
88 ) -> Result<(), EvalError> {
89 if self.operands.len() != 2 {
91 return Err(anyhow!(
92 "Instruction '{}' expects 2 operands, found {} operands",
93 Self::opcode(),
94 self.operands.len()
95 )
96 .into());
97 }
98
99 let input_a = registers.load(stack, &self.operands[0])?;
101 let input_b = registers.load(stack, &self.operands[1])?;
102
103 match VARIANT {
105 0 => {
106 if input_a != input_b {
107 return Err(AssertError::Eq { lhs: format!("{input_a}"), rhs: format!("{input_b}") }.into());
108 }
109 }
110 1 => {
111 if input_a == input_b {
112 return Err(AssertError::Neq { lhs: format!("{input_a}"), rhs: format!("{input_b}") }.into());
113 }
114 }
115 _ => return Err(AssertError::Invalid { variant: VARIANT }.into()),
116 }
117 Ok(())
118 }
119
120 pub fn execute<A: circuit::Aleo<Network = N>>(
122 &self,
123 stack: &impl StackTrait<N>,
124 registers: &mut impl RegistersCircuit<N, A>,
125 ) -> Result<(), ExecError> {
126 if self.operands.len() != 2 {
128 return Err(anyhow!(
129 "Instruction '{}' expects 2 operands, found {} operands",
130 Self::opcode(),
131 self.operands.len()
132 )
133 .into());
134 }
135
136 let input_a = registers.load_circuit(stack, &self.operands[0])?;
138 let input_b = registers.load_circuit(stack, &self.operands[1])?;
139
140 match VARIANT {
142 0 => A::assert(input_a.is_equal(&input_b))?,
143 1 => A::assert(input_a.is_not_equal(&input_b))?,
144 _ => return Err(anyhow!("Invalid 'assert' variant: {VARIANT}").into()),
145 }
146 Ok(())
147 }
148
149 #[inline]
151 pub fn finalize(
152 &self,
153 stack: &impl StackTrait<N>,
154 registers: &mut impl RegistersTrait<N>,
155 ) -> Result<(), FinalizeError> {
156 self.evaluate(stack, registers)?;
157 Ok(())
158 }
159
160 pub fn output_types(
162 &self,
163 stack: &impl StackTrait<N>,
164 input_types: &[RegisterType<N>],
165 ) -> Result<Vec<RegisterType<N>>> {
166 if input_types.len() != 2 {
168 bail!("Instruction '{}' expects 2 inputs, found {} inputs", Self::opcode(), input_types.len())
169 }
170 if !register_types_equivalent(stack, &input_types[0], stack, &input_types[1])? {
172 bail!(
173 "Instruction '{}' expects inputs of equivalent types. Found inputs of type '{}' and '{}'",
174 Self::opcode(),
175 input_types[0],
176 input_types[1]
177 )
178 }
179 if self.operands.len() != 2 {
181 bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
182 }
183
184 match VARIANT {
185 0 | 1 => Ok(vec![]),
186 _ => bail!("Invalid 'assert' variant: {VARIANT}"),
187 }
188 }
189}
190
191impl<N: Network, const VARIANT: u8> Parser for AssertInstruction<N, VARIANT> {
192 fn parse(string: &str) -> ParserResult<Self> {
194 let (string, _) = tag(*Self::opcode())(string)?;
196 let (string, _) = Sanitizer::parse_whitespaces(string)?;
198 let (string, first) = Operand::parse(string)?;
200 let (string, _) = Sanitizer::parse_whitespaces(string)?;
202 let (string, second) = Operand::parse(string)?;
204
205 Ok((string, Self { operands: vec![first, second] }))
206 }
207}
208
209impl<N: Network, const VARIANT: u8> FromStr for AssertInstruction<N, VARIANT> {
210 type Err = Error;
211
212 fn from_str(string: &str) -> Result<Self> {
214 match Self::parse(string) {
215 Ok((remainder, object)) => {
216 ensure!(remainder.is_empty(), "Failed to parse string. Found invalid character in: \"{remainder}\"");
218 Ok(object)
220 }
221 Err(error) => bail!("Failed to parse string. {error}"),
222 }
223 }
224}
225
226impl<N: Network, const VARIANT: u8> Debug for AssertInstruction<N, VARIANT> {
227 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
229 Display::fmt(self, f)
230 }
231}
232
233impl<N: Network, const VARIANT: u8> Display for AssertInstruction<N, VARIANT> {
234 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
236 if self.operands.len() != 2 {
238 return Err(fmt::Error);
239 }
240 write!(f, "{}", Self::opcode())?;
242 self.operands.iter().try_for_each(|operand| write!(f, " {operand}"))
243 }
244}
245
246impl<N: Network, const VARIANT: u8> FromBytes for AssertInstruction<N, VARIANT> {
247 fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
249 let mut operands = Vec::with_capacity(2);
251 for _ in 0..2 {
253 operands.push(Operand::read_le(&mut reader)?);
254 }
255
256 Ok(Self { operands })
258 }
259}
260
261impl<N: Network, const VARIANT: u8> ToBytes for AssertInstruction<N, VARIANT> {
262 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
264 if self.operands.len() != 2 {
266 return Err(error(format!("The number of operands must be 2, found {}", self.operands.len())));
267 }
268 self.operands.iter().try_for_each(|operand| operand.write_le(&mut writer))
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use console::network::MainnetV0;

    type CurrentNetwork = MainnetV0;

    /// Asserts that the parser consumed all of its input and produced the operands `r0` and `r1`.
    fn check_parsed(string: &str, operands: &[Operand<CurrentNetwork>]) {
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");
        assert_eq!(operands.len(), 2, "The number of operands is incorrect");
        assert_eq!(operands[0], Operand::Register(Register::Locator(0)), "The first operand is incorrect");
        assert_eq!(operands[1], Operand::Register(Register::Locator(1)), "The second operand is incorrect");
    }

    /// Parses one `assert.eq` and one `assert.neq` instruction and checks their operands.
    #[test]
    fn test_parse() {
        let (string, assert) = AssertEq::<CurrentNetwork>::parse("assert.eq r0 r1").unwrap();
        check_parsed(string, &assert.operands);

        let (string, assert) = AssertNeq::<CurrentNetwork>::parse("assert.neq r0 r1").unwrap();
        check_parsed(string, &assert.operands);
    }
}