// snarkvm_synthesizer_program/logic/instruction/operation/call.rs

// Copyright 2024-2025 Aleo Network Foundation
// This file is part of the snarkVM library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use crate::{
17    Opcode,
18    Operand,
19    traits::{RegistersLoad, RegistersLoadCircuit, StackMatches, StackProgram},
20};
21use console::{
22    network::prelude::*,
23    program::{Identifier, Locator, Register, RegisterType, ValueType},
24};
25
26/// The operator references a function name or closure name.
27#[derive(Clone, PartialEq, Eq, Hash)]
28pub enum CallOperator<N: Network> {
29    /// The reference to a non-local function or closure.
30    Locator(Locator<N>),
31    /// The reference to a local function or closure.
32    Resource(Identifier<N>),
33}
34
35impl<N: Network> Parser for CallOperator<N> {
36    /// Parses a string into an operator.
37    #[inline]
38    fn parse(string: &str) -> ParserResult<Self> {
39        alt((map(Locator::parse, CallOperator::Locator), map(Identifier::parse, CallOperator::Resource)))(string)
40    }
41}
42
43impl<N: Network> FromStr for CallOperator<N> {
44    type Err = Error;
45
46    /// Parses a string into an operator.
47    #[inline]
48    fn from_str(string: &str) -> Result<Self> {
49        match Self::parse(string) {
50            Ok((remainder, object)) => {
51                // Ensure the remainder is empty.
52                ensure!(remainder.is_empty(), "Failed to parse string. Found invalid character in: \"{remainder}\"");
53                // Return the object.
54                Ok(object)
55            }
56            Err(error) => bail!("Failed to parse string. {error}"),
57        }
58    }
59}
60
61impl<N: Network> Debug for CallOperator<N> {
62    /// Prints the operator as a string.
63    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
64        Display::fmt(self, f)
65    }
66}
67
68impl<N: Network> Display for CallOperator<N> {
69    /// Prints the operator to a string.
70    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
71        match self {
72            CallOperator::Locator(locator) => Display::fmt(locator, f),
73            CallOperator::Resource(resource) => Display::fmt(resource, f),
74        }
75    }
76}
77
78impl<N: Network> FromBytes for CallOperator<N> {
79    /// Reads the operation from a buffer.
80    fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
81        // Read the variant.
82        let variant = u8::read_le(&mut reader)?;
83        // Match the variant.
84        match variant {
85            0 => Ok(CallOperator::Locator(Locator::read_le(&mut reader)?)),
86            1 => Ok(CallOperator::Resource(Identifier::read_le(&mut reader)?)),
87            _ => Err(error("Failed to read CallOperator. Invalid variant.")),
88        }
89    }
90}
91
92impl<N: Network> ToBytes for CallOperator<N> {
93    /// Writes the operation to a buffer.
94    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
95        match self {
96            CallOperator::Locator(locator) => {
97                // Write the variant.
98                0u8.write_le(&mut writer)?;
99                // Write the locator.
100                locator.write_le(&mut writer)
101            }
102            CallOperator::Resource(resource) => {
103                // Write the variant.
104                1u8.write_le(&mut writer)?;
105                // Write the resource.
106                resource.write_le(&mut writer)
107            }
108        }
109    }
110}
111
112/// Calls the operands into the declared type.
113/// i.e. `call transfer r0.owner 0u64 r1.amount into r1 r2;`
114#[derive(Clone, PartialEq, Eq, Hash)]
115pub struct Call<N: Network> {
116    /// The reference.
117    operator: CallOperator<N>,
118    /// The operands.
119    operands: Vec<Operand<N>>,
120    /// The destination registers.
121    destinations: Vec<Register<N>>,
122}
123
124impl<N: Network> Call<N> {
125    /// Returns the opcode.
126    #[inline]
127    pub const fn opcode() -> Opcode {
128        Opcode::Call
129    }
130
131    /// Return the operator.
132    #[inline]
133    pub const fn operator(&self) -> &CallOperator<N> {
134        &self.operator
135    }
136
137    /// Returns the operands in the operation.
138    #[inline]
139    pub fn operands(&self) -> &[Operand<N>] {
140        &self.operands
141    }
142
143    /// Returns the destination registers.
144    #[inline]
145    pub fn destinations(&self) -> Vec<Register<N>> {
146        self.destinations.clone()
147    }
148}
149
150impl<N: Network> Call<N> {
151    /// Returns `true` if the instruction is a function call.
152    #[inline]
153    pub fn is_function_call(&self, stack: &impl StackProgram<N>) -> Result<bool> {
154        match self.operator() {
155            // Check if the locator is for a function.
156            CallOperator::Locator(locator) => {
157                // Get the external stack.
158                let external_stack = stack.get_external_stack(locator.program_id())?;
159                // Retrieve the program.
160                let program = external_stack.program();
161                // Check if the resource is a function.
162                Ok(program.contains_function(locator.resource()))
163            }
164            // Check if the resource is a function.
165            CallOperator::Resource(resource) => Ok(stack.program().contains_function(resource)),
166        }
167    }
168
169    /// Evaluates the instruction.
170    pub fn evaluate(&self, _stack: &impl StackProgram<N>, _registers: &mut impl RegistersLoad<N>) -> Result<()> {
171        bail!("Forbidden operation: Evaluate cannot invoke a 'call' directly. Use 'call' in 'Stack' instead.")
172    }
173
174    /// Executes the instruction.
175    pub fn execute<A: circuit::Aleo<Network = N>>(
176        &self,
177        _stack: &impl StackProgram<N>,
178        _registers: &mut impl RegistersLoadCircuit<N, A>,
179    ) -> Result<()> {
180        bail!("Forbidden operation: Execute cannot invoke a 'call' directly. Use 'call' in 'Stack' instead.")
181    }
182
183    /// Finalizes the instruction.
184    #[inline]
185    pub fn finalize(
186        &self,
187        _stack: &(impl StackMatches<N> + StackProgram<N>),
188        _registers: &mut impl RegistersLoad<N>,
189    ) -> Result<()> {
190        bail!("Forbidden operation: Finalize cannot invoke a 'call' directly. Use 'call' in 'Stack' instead.")
191    }
192
193    /// Returns the output type from the given program and input types.
194    #[inline]
195    pub fn output_types(
196        &self,
197        stack: &impl StackProgram<N>,
198        input_types: &[RegisterType<N>],
199    ) -> Result<Vec<RegisterType<N>>> {
200        // Retrieve the external stack, if needed, and the resource.
201        let (external_stack, resource) = match &self.operator {
202            CallOperator::Locator(locator) => {
203                (Some(stack.get_external_stack(locator.program_id())?), locator.resource())
204            }
205            CallOperator::Resource(resource) => {
206                // TODO (howardwu): Revisit this decision to forbid calling internal functions. A record cannot be spent again.
207                //  But there are legitimate uses for passing a record through to an internal function.
208                //  We could invoke the internal function without a state transition, but need to match visibility.
209                if stack.program().contains_function(resource) {
210                    bail!("Cannot call '{resource}'. Use a closure ('closure {resource}:') instead.")
211                }
212                (None, resource)
213            }
214        };
215        // Retrieve the program.
216        let (is_external, program) = match &external_stack {
217            Some(external_stack) => (true, external_stack.program()),
218            None => (false, stack.program()),
219        };
220        // If the operator is a closure, retrieve the closure and compute the output types.
221        if let Ok(closure) = program.get_closure(resource) {
222            // Ensure the number of operands matches the number of input statements.
223            if closure.inputs().len() != self.operands.len() {
224                bail!("Expected {} inputs, found {}", closure.inputs().len(), self.operands.len())
225            }
226            // Ensure the number of inputs matches the number of input statements.
227            if closure.inputs().len() != input_types.len() {
228                bail!("Expected {} input types, found {}", closure.inputs().len(), input_types.len())
229            }
230            // Ensure the number of destinations matches the number of output statements.
231            if closure.outputs().len() != self.destinations.len() {
232                bail!("Expected {} outputs, found {}", closure.outputs().len(), self.destinations.len())
233            }
234            // Return the output register types.
235            Ok(closure.outputs().iter().map(|output| output.register_type()).cloned().collect())
236        }
237        // If the operator is a function, retrieve the function and compute the output types.
238        else if let Ok(function) = program.get_function(resource) {
239            // Ensure the number of operands matches the number of input statements.
240            if function.inputs().len() != self.operands.len() {
241                bail!("Expected {} inputs, found {}", function.inputs().len(), self.operands.len())
242            }
243            // Ensure the number of inputs matches the number of input statements.
244            if function.inputs().len() != input_types.len() {
245                bail!("Expected {} input types, found {}", function.inputs().len(), input_types.len())
246            }
247            // Ensure the number of destinations matches the number of output statements.
248            if function.outputs().len() != self.destinations.len() {
249                bail!("Expected {} outputs, found {}", function.outputs().len(), self.destinations.len())
250            }
251            // Return the output register types.
252            function
253                .output_types()
254                .into_iter()
255                .map(|output_type| match (is_external, output_type) {
256                    // If the output is a record and the function is external, return the external record type.
257                    (true, ValueType::Record(record_name)) => Ok(RegisterType::ExternalRecord(Locator::from_str(
258                        &format!("{}/{}", program.id(), record_name),
259                    )?)),
260                    // Else, return the register type.
261                    (_, output_type) => Ok(RegisterType::from(output_type)),
262                })
263                .collect::<Result<Vec<_>>>()
264        }
265        // Else, throw an error.
266        else {
267            bail!("Call operator '{}' is invalid or unsupported.", self.operator)
268        }
269    }
270}
271
272impl<N: Network> Parser for Call<N> {
273    /// Parses a string into an operation.
274    #[inline]
275    fn parse(string: &str) -> ParserResult<Self> {
276        /// Parses an operand from the string.
277        fn parse_operand<N: Network>(string: &str) -> ParserResult<Operand<N>> {
278            // Parse the whitespace from the string.
279            let (string, _) = Sanitizer::parse_whitespaces(string)?;
280            // Parse the operand from the string.
281            Operand::parse(string)
282        }
283
284        /// Parses a destination register from the string.
285        fn parse_destination<N: Network>(string: &str) -> ParserResult<Register<N>> {
286            // Parse the whitespace from the string.
287            let (string, _) = Sanitizer::parse_whitespaces(string)?;
288            // Parse the destination from the string.
289            Register::parse(string)
290        }
291
292        // Parse the opcode from the string.
293        let (string, _) = tag(*Self::opcode())(string)?;
294        // Parse the whitespace from the string.
295        let (string, _) = Sanitizer::parse_whitespaces(string)?;
296        // Parse the name of the call from the string.
297        let (string, operator) = CallOperator::parse(string)?;
298        // Parse the whitespace from the string.
299        let (string, _) = Sanitizer::parse_whitespaces(string)?;
300        // Parse the operands from the string.
301        let (string, operands) = map_res(many0(complete(parse_operand)), |operands: Vec<Operand<N>>| {
302            // Ensure the number of operands is within the bounds.
303            match operands.len() <= N::MAX_OPERANDS {
304                true => Ok(operands),
305                false => Err(error("Failed to parse 'call' opcode: too many operands")),
306            }
307        })(string)?;
308        // Parse the whitespace from the string.
309        let (string, _) = Sanitizer::parse_whitespaces(string)?;
310
311        // Optionally parse the "into" from the string.
312        let (string, destinations) = match opt(tag("into"))(string)? {
313            // If the "into" was not parsed, return the string and an empty vector of destinations.
314            (string, None) => (string, vec![]),
315            // If the "into" was parsed, parse the destinations from the string.
316            (string, Some(_)) => {
317                // Parse the whitespace from the string.
318                let (string, _) = Sanitizer::parse_whitespaces(string)?;
319                // Parse the destinations from the string.
320                let (string, destinations) =
321                    map_res(many1(complete(parse_destination)), |destinations: Vec<Register<N>>| {
322                        // Ensure the number of destinations is within the bounds.
323                        match destinations.len() <= N::MAX_OPERANDS {
324                            true => Ok(destinations),
325                            false => Err(error("Failed to parse 'call' opcode: too many destinations")),
326                        }
327                    })(string)?;
328                // Return the string and the destinations.
329                (string, destinations)
330            }
331        };
332
333        Ok((string, Self { operator, operands, destinations }))
334    }
335}
336
337impl<N: Network> FromStr for Call<N> {
338    type Err = Error;
339
340    /// Parses a string into an operation.
341    #[inline]
342    fn from_str(string: &str) -> Result<Self> {
343        match Self::parse(string) {
344            Ok((remainder, object)) => {
345                // Ensure the remainder is empty.
346                ensure!(remainder.is_empty(), "Failed to parse string. Found invalid character in: \"{remainder}\"");
347                // Return the object.
348                Ok(object)
349            }
350            Err(error) => bail!("Failed to parse string. {error}"),
351        }
352    }
353}
354
355impl<N: Network> Debug for Call<N> {
356    /// Prints the operation as a string.
357    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
358        Display::fmt(self, f)
359    }
360}
361
362impl<N: Network> Display for Call<N> {
363    /// Prints the operation to a string.
364    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
365        // Ensure the number of operands is within the bounds.
366        if self.operands.len() > N::MAX_OPERANDS {
367            return Err(fmt::Error);
368        }
369        // Ensure the number of destinations is within the bounds.
370        if self.destinations.len() > N::MAX_OPERANDS {
371            return Err(fmt::Error);
372        }
373        // Print the operation.
374        write!(f, "{} {}", Self::opcode(), self.operator)?;
375        self.operands.iter().try_for_each(|operand| write!(f, " {operand}"))?;
376        if !self.destinations.is_empty() {
377            write!(f, " into")?;
378            self.destinations.iter().try_for_each(|destination| write!(f, " {destination}"))?;
379        }
380        Ok(())
381    }
382}
383
384impl<N: Network> FromBytes for Call<N> {
385    /// Reads the operation from a buffer.
386    fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
387        // Read the operator of the call.
388        let operator = CallOperator::read_le(&mut reader)?;
389
390        // Read the number of operands.
391        let num_operands = u8::read_le(&mut reader)? as usize;
392        // Ensure the number of operands is within the bounds.
393        if num_operands > N::MAX_OPERANDS {
394            return Err(error(format!("The number of operands must be <= {}", N::MAX_OPERANDS)));
395        }
396
397        // Initialize the vector for the operands.
398        let mut operands = Vec::with_capacity(num_operands);
399        // Read the operands.
400        for _ in 0..num_operands {
401            operands.push(Operand::read_le(&mut reader)?);
402        }
403
404        // Read the number of destination registers.
405        let num_destinations = u8::read_le(&mut reader)? as usize;
406        // Ensure the number of destinations is within the bounds.
407        if num_destinations > N::MAX_OPERANDS {
408            return Err(error(format!("The number of destinations must be <= {}", N::MAX_OPERANDS)));
409        }
410
411        // Initialize the vector for the destinations.
412        let mut destinations = Vec::with_capacity(num_destinations);
413        // Read the destination registers.
414        for _ in 0..num_destinations {
415            destinations.push(Register::read_le(&mut reader)?);
416        }
417
418        // Return the operation.
419        Ok(Self { operator, operands, destinations })
420    }
421}
422
423impl<N: Network> ToBytes for Call<N> {
424    /// Writes the operation to a buffer.
425    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
426        // Ensure the number of operands is within the bounds.
427        if self.operands.len() > N::MAX_OPERANDS {
428            return Err(error(format!("The number of operands must be <= {}", N::MAX_OPERANDS)));
429        }
430        // Ensure the number of destinations is within the bounds.
431        if self.destinations.len() > N::MAX_OPERANDS {
432            return Err(error(format!("The number of destinations must be <= {}", N::MAX_OPERANDS)));
433        }
434
435        // Write the name of the call.
436        self.operator.write_le(&mut writer)?;
437        // Write the number of operands.
438        u8::try_from(self.operands.len()).map_err(|e| error(e.to_string()))?.write_le(&mut writer)?;
439        // Write the operands.
440        self.operands.iter().try_for_each(|operand| operand.write_le(&mut writer))?;
441        // Write the number of destination register.
442        u8::try_from(self.destinations.len()).map_err(|e| error(e.to_string()))?.write_le(&mut writer)?;
443        // Write the destination registers.
444        self.destinations.iter().try_for_each(|destination| destination.write_le(&mut writer))
445    }
446}
447
#[cfg(test)]
mod tests {
    use super::*;
    use console::{
        network::MainnetV0,
        program::{Access, Address, Identifier, Literal, U64},
    };

    type CurrentNetwork = MainnetV0;

    /// Canonical `call` strings. Each one must survive a parse/display and a parse/bytes round trip.
    const TEST_CASES: &[&str] = &[
        "call foo",
        "call foo r0",
        "call foo r0.owner",
        "call foo r0 r1",
        "call foo into r0",
        "call foo into r0 r1",
        "call foo into r0 r1 r2",
        "call foo r0 into r1",
        "call foo r0 r1 into r2",
        "call foo r0 r1 into r2 r3",
        "call foo r0 r1 r2 into r3 r4",
        "call foo r0 r1 r2 into r3 r4 r5",
        // External calls exercise the `CallOperator::Locator` variant.
        "call credits.aleo/transfer_public",
        "call credits.aleo/transfer_public r0 r1",
        "call credits.aleo/transfer_public r0 r1 into r2",
        "call token.aleo/mint r0.owner 100u64 into r1 r2",
    ];

    /// Parses `string` and checks that the operator, operands, and destinations match the expected
    /// values and that the whole input is consumed.
    fn check_parser(
        string: &str,
        expected_operator: CallOperator<CurrentNetwork>,
        expected_operands: Vec<Operand<CurrentNetwork>>,
        expected_destinations: Vec<Register<CurrentNetwork>>,
    ) {
        // Check that the parser works.
        let (string, call) = Call::<CurrentNetwork>::parse(string).unwrap();

        // Check that the entire string was consumed.
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");

        // Check that the operator is correct.
        assert_eq!(call.operator, expected_operator, "The call operator is incorrect");

        // Check that the operands are correct.
        assert_eq!(call.operands.len(), expected_operands.len(), "The number of operands is incorrect");
        for (i, (given, expected)) in call.operands.iter().zip(expected_operands.iter()).enumerate() {
            assert_eq!(given, expected, "The {i}-th operand is incorrect");
        }

        // Check that the destinations are correct.
        assert_eq!(call.destinations.len(), expected_destinations.len(), "The number of destinations is incorrect");
        for (i, (given, expected)) in call.destinations.iter().zip(expected_destinations.iter()).enumerate() {
            assert_eq!(given, expected, "The {i}-th destination is incorrect");
        }
    }

    #[test]
    fn test_parse() {
        check_parser(
            "call transfer r0.owner r0.token_amount into r1 r2 r3",
            CallOperator::from_str("transfer").unwrap(),
            vec![
                Operand::Register(Register::Access(0, vec![Access::from(Identifier::from_str("owner").unwrap())])),
                Operand::Register(Register::Access(0, vec![Access::from(
                    Identifier::from_str("token_amount").unwrap(),
                )])),
            ],
            vec![Register::Locator(1), Register::Locator(2), Register::Locator(3)],
        );

        check_parser(
            "call mint_public aleo1wfyyj2uvwuqw0c0dqa5x70wrawnlkkvuepn4y08xyaqfqqwweqys39jayw 100u64",
            CallOperator::from_str("mint_public").unwrap(),
            vec![
                Operand::Literal(Literal::Address(
                    Address::from_str("aleo1wfyyj2uvwuqw0c0dqa5x70wrawnlkkvuepn4y08xyaqfqqwweqys39jayw").unwrap(),
                )),
                Operand::Literal(Literal::U64(U64::from_str("100u64").unwrap())),
            ],
            vec![],
        );

        check_parser(
            "call get_magic_number into r0",
            CallOperator::from_str("get_magic_number").unwrap(),
            vec![],
            vec![Register::Locator(0)],
        );

        // A program-qualified callee.
        check_parser(
            "call credits.aleo/transfer_public r0 r1 into r2",
            CallOperator::from_str("credits.aleo/transfer_public").unwrap(),
            vec![Operand::Register(Register::Locator(0)), Operand::Register(Register::Locator(1))],
            vec![Register::Locator(2)],
        );

        check_parser("call noop", CallOperator::from_str("noop").unwrap(), vec![], vec![])
    }

    #[test]
    fn test_call_operator() {
        // A bare identifier is a local resource; a program-qualified name is a locator.
        assert!(matches!(CallOperator::<CurrentNetwork>::from_str("foo").unwrap(), CallOperator::Resource(_)));
        assert!(matches!(
            CallOperator::<CurrentNetwork>::from_str("credits.aleo/transfer_public").unwrap(),
            CallOperator::Locator(_)
        ));

        // Trailing input is rejected.
        assert!(CallOperator::<CurrentNetwork>::from_str("foo bar").is_err());

        // Both variants round-trip through `Display` and through bytes.
        for case in ["foo", "credits.aleo/transfer_public"] {
            let operator = CallOperator::<CurrentNetwork>::from_str(case).unwrap();
            assert_eq!(operator.to_string(), case);
            let bytes = operator.to_bytes_le().unwrap();
            assert_eq!(operator, CallOperator::read_le(&bytes[..]).unwrap());
        }

        // An unknown variant tag is rejected.
        assert!(CallOperator::<CurrentNetwork>::read_le(&[2u8][..]).is_err());
    }

    #[test]
    fn test_display() {
        for expected in TEST_CASES {
            assert_eq!(Call::<CurrentNetwork>::from_str(expected).unwrap().to_string(), *expected);
        }
    }

    #[test]
    fn test_bytes() {
        for case in TEST_CASES {
            let expected = Call::<CurrentNetwork>::from_str(case).unwrap();

            // Check the byte representation round-trips in both directions.
            let expected_bytes = expected.to_bytes_le().unwrap();
            let candidate = Call::read_le(&expected_bytes[..]).unwrap();
            assert_eq!(expected, candidate);
            assert_eq!(expected_bytes, candidate.to_bytes_le().unwrap());
        }
    }

    #[test]
    fn test_bounds() {
        let max = CurrentNetwork::MAX_OPERANDS;

        // Exactly `MAX_OPERANDS` operands and destinations is accepted.
        let at_limit = format!("call foo{} into{}", " r0".repeat(max), " r1".repeat(max));
        assert!(Call::<CurrentNetwork>::from_str(&at_limit).is_ok());

        // One more operand, or one more destination, is rejected by the parser.
        let too_many_operands = format!("call foo{}", " r0".repeat(max + 1));
        assert!(Call::<CurrentNetwork>::from_str(&too_many_operands).is_err());
        let too_many_destinations = format!("call foo into{}", " r0".repeat(max + 1));
        assert!(Call::<CurrentNetwork>::from_str(&too_many_destinations).is_err());

        // An out-of-bounds operand count is rejected when deserializing.
        let mut bytes = CallOperator::<CurrentNetwork>::from_str("foo").unwrap().to_bytes_le().unwrap();
        bytes.push(u8::try_from(max + 1).unwrap());
        assert!(Call::<CurrentNetwork>::read_le(&bytes[..]).is_err());
    }
}