snarkvm_synthesizer_program/logic/instruction/operation/
assert.rs1use crate::{
17 Opcode,
18 Operand,
19 traits::{RegistersLoad, RegistersLoadCircuit, StackMatches, StackProgram},
20};
21use console::{
22 network::prelude::*,
23 program::{Register, RegisterType},
24};
25
26pub type AssertEq<N> = AssertInstruction<N, { Variant::AssertEq as u8 }>;
28pub type AssertNeq<N> = AssertInstruction<N, { Variant::AssertNeq as u8 }>;
30
31enum Variant {
32 AssertEq,
33 AssertNeq,
34}
35
36#[derive(Clone, PartialEq, Eq, Hash)]
38pub struct AssertInstruction<N: Network, const VARIANT: u8> {
39 operands: Vec<Operand<N>>,
41}
42
43impl<N: Network, const VARIANT: u8> AssertInstruction<N, VARIANT> {
44 #[inline]
46 pub fn new(operands: Vec<Operand<N>>) -> Result<Self> {
47 ensure!(operands.len() == 2, "Assert instructions must have two operands");
49 Ok(Self { operands })
51 }
52
53 #[inline]
55 pub const fn opcode() -> Opcode {
56 match VARIANT {
57 0 => Opcode::Assert("assert.eq"),
58 1 => Opcode::Assert("assert.neq"),
59 _ => panic!("Invalid 'assert' instruction opcode"),
60 }
61 }
62
63 #[inline]
65 pub fn operands(&self) -> &[Operand<N>] {
66 debug_assert!(self.operands.len() == 2, "Assert operations must have two operands");
68 &self.operands
70 }
71
72 #[inline]
74 pub fn destinations(&self) -> Vec<Register<N>> {
75 vec![]
76 }
77}
78
79impl<N: Network, const VARIANT: u8> AssertInstruction<N, VARIANT> {
80 #[inline]
82 pub fn evaluate(
83 &self,
84 stack: &(impl StackMatches<N> + StackProgram<N>),
85 registers: &mut impl RegistersLoad<N>,
86 ) -> Result<()> {
87 if self.operands.len() != 2 {
89 bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
90 }
91
92 let input_a = registers.load(stack, &self.operands[0])?;
94 let input_b = registers.load(stack, &self.operands[1])?;
95
96 match VARIANT {
98 0 => {
99 if input_a != input_b {
100 bail!("'{}' failed: '{input_a}' is not equal to '{input_b}' (should be equal)", Self::opcode())
101 }
102 }
103 1 => {
104 if input_a == input_b {
105 bail!("'{}' failed: '{input_a}' is equal to '{input_b}' (should not be equal)", Self::opcode())
106 }
107 }
108 _ => bail!("Invalid 'assert' variant: {VARIANT}"),
109 }
110 Ok(())
111 }
112
113 #[inline]
115 pub fn execute<A: circuit::Aleo<Network = N>>(
116 &self,
117 stack: &(impl StackMatches<N> + StackProgram<N>),
118 registers: &mut impl RegistersLoadCircuit<N, A>,
119 ) -> Result<()> {
120 if self.operands.len() != 2 {
122 bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
123 }
124
125 let input_a = registers.load_circuit(stack, &self.operands[0])?;
127 let input_b = registers.load_circuit(stack, &self.operands[1])?;
128
129 match VARIANT {
131 0 => A::assert(input_a.is_equal(&input_b)),
132 1 => A::assert(input_a.is_not_equal(&input_b)),
133 _ => bail!("Invalid 'assert' variant: {VARIANT}"),
134 }
135 Ok(())
136 }
137
138 #[inline]
140 pub fn finalize(
141 &self,
142 stack: &(impl StackMatches<N> + StackProgram<N>),
143 registers: &mut impl RegistersLoad<N>,
144 ) -> Result<()> {
145 self.evaluate(stack, registers)
146 }
147
148 #[inline]
150 pub fn output_types(
151 &self,
152 _stack: &impl StackProgram<N>,
153 input_types: &[RegisterType<N>],
154 ) -> Result<Vec<RegisterType<N>>> {
155 if input_types.len() != 2 {
157 bail!("Instruction '{}' expects 2 inputs, found {} inputs", Self::opcode(), input_types.len())
158 }
159 if input_types[0] != input_types[1] {
161 bail!(
162 "Instruction '{}' expects inputs of the same type. Found inputs of type '{}' and '{}'",
163 Self::opcode(),
164 input_types[0],
165 input_types[1]
166 )
167 }
168 if self.operands.len() != 2 {
170 bail!("Instruction '{}' expects 2 operands, found {} operands", Self::opcode(), self.operands.len())
171 }
172
173 match VARIANT {
174 0 | 1 => Ok(vec![]),
175 _ => bail!("Invalid 'assert' variant: {VARIANT}"),
176 }
177 }
178}
179
180impl<N: Network, const VARIANT: u8> Parser for AssertInstruction<N, VARIANT> {
181 #[inline]
183 fn parse(string: &str) -> ParserResult<Self> {
184 let (string, _) = tag(*Self::opcode())(string)?;
186 let (string, _) = Sanitizer::parse_whitespaces(string)?;
188 let (string, first) = Operand::parse(string)?;
190 let (string, _) = Sanitizer::parse_whitespaces(string)?;
192 let (string, second) = Operand::parse(string)?;
194
195 Ok((string, Self { operands: vec![first, second] }))
196 }
197}
198
199impl<N: Network, const VARIANT: u8> FromStr for AssertInstruction<N, VARIANT> {
200 type Err = Error;
201
202 #[inline]
204 fn from_str(string: &str) -> Result<Self> {
205 match Self::parse(string) {
206 Ok((remainder, object)) => {
207 ensure!(remainder.is_empty(), "Failed to parse string. Found invalid character in: \"{remainder}\"");
209 Ok(object)
211 }
212 Err(error) => bail!("Failed to parse string. {error}"),
213 }
214 }
215}
216
217impl<N: Network, const VARIANT: u8> Debug for AssertInstruction<N, VARIANT> {
218 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
220 Display::fmt(self, f)
221 }
222}
223
224impl<N: Network, const VARIANT: u8> Display for AssertInstruction<N, VARIANT> {
225 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
227 if self.operands.len() != 2 {
229 return Err(fmt::Error);
230 }
231 write!(f, "{} ", Self::opcode())?;
233 self.operands.iter().try_for_each(|operand| write!(f, "{operand} "))
234 }
235}
236
237impl<N: Network, const VARIANT: u8> FromBytes for AssertInstruction<N, VARIANT> {
238 fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
240 let mut operands = Vec::with_capacity(2);
242 for _ in 0..2 {
244 operands.push(Operand::read_le(&mut reader)?);
245 }
246
247 Ok(Self { operands })
249 }
250}
251
252impl<N: Network, const VARIANT: u8> ToBytes for AssertInstruction<N, VARIANT> {
253 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
255 if self.operands.len() != 2 {
257 return Err(error(format!("The number of operands must be 2, found {}", self.operands.len())));
258 }
259 self.operands.iter().try_for_each(|operand| operand.write_le(&mut writer))
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use console::network::MainnetV0;

    type CurrentNetwork = MainnetV0;

    /// Both variants parse their opcode followed by two register operands.
    #[test]
    fn test_parse() {
        let (string, assert) = AssertEq::<CurrentNetwork>::parse("assert.eq r0 r1").unwrap();
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");
        assert_eq!(assert.operands.len(), 2, "The number of operands is incorrect");
        assert_eq!(assert.operands[0], Operand::Register(Register::Locator(0)), "The first operand is incorrect");
        assert_eq!(assert.operands[1], Operand::Register(Register::Locator(1)), "The second operand is incorrect");

        let (string, assert) = AssertNeq::<CurrentNetwork>::parse("assert.neq r0 r1").unwrap();
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");
        assert_eq!(assert.operands.len(), 2, "The number of operands is incorrect");
        assert_eq!(assert.operands[0], Operand::Register(Register::Locator(0)), "The first operand is incorrect");
        assert_eq!(assert.operands[1], Operand::Register(Register::Locator(1)), "The second operand is incorrect");
    }

    /// `new` accepts exactly two operands and rejects every other count.
    #[test]
    fn test_new_requires_two_operands() {
        let r0 = Operand::Register(Register::Locator(0));
        let r1 = Operand::Register(Register::Locator(1));

        assert!(AssertEq::<CurrentNetwork>::new(vec![r0.clone(), r1.clone()]).is_ok());
        assert!(AssertEq::<CurrentNetwork>::new(vec![]).is_err());
        assert!(AssertEq::<CurrentNetwork>::new(vec![r0.clone()]).is_err());
        assert!(AssertNeq::<CurrentNetwork>::new(vec![r0.clone(), r1, r0]).is_err());
    }

    /// `from_str` rejects leftover input, a missing operand, and the other variant's opcode.
    #[test]
    fn test_from_str_rejects_invalid_input() {
        // A third operand is left over after parsing.
        assert!(AssertEq::<CurrentNetwork>::from_str("assert.eq r0 r1 r2").is_err());
        // The second operand is missing.
        assert!(AssertEq::<CurrentNetwork>::from_str("assert.eq r0").is_err());
        // Each variant only accepts its own opcode.
        assert!(AssertEq::<CurrentNetwork>::from_str("assert.neq r0 r1").is_err());
        assert!(AssertNeq::<CurrentNetwork>::from_str("assert.eq r0 r1").is_err());
    }

    /// Writing and re-reading the little-endian bytes yields the same instruction.
    #[test]
    fn test_bytes_roundtrip() {
        let expected = AssertEq::<CurrentNetwork>::from_str("assert.eq r0 r1").unwrap();
        let bytes = expected.to_bytes_le().unwrap();
        assert_eq!(expected, AssertEq::<CurrentNetwork>::from_bytes_le(&bytes).unwrap());

        let expected = AssertNeq::<CurrentNetwork>::from_str("assert.neq r2 r3").unwrap();
        let bytes = expected.to_bytes_le().unwrap();
        assert_eq!(expected, AssertNeq::<CurrentNetwork>::from_bytes_le(&bytes).unwrap());
    }

    /// An instruction that bypassed `new` with the wrong operand count is not serialized.
    #[test]
    fn test_malformed_instruction_is_not_serialized() {
        let malformed =
            AssertInstruction::<CurrentNetwork, 0> { operands: vec![Operand::Register(Register::Locator(0))] };
        assert!(malformed.to_bytes_le().is_err());
    }
}