snarkvm_synthesizer_program/logic/instruction/operation/call.rs

// Copyright (c) 2019-2025 Provable Inc.
// This file is part of the snarkVM library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use crate::{Opcode, Operand, RegistersCircuit, RegistersTrait, StackTrait};
17use console::{
18    network::prelude::*,
19    program::{Identifier, Locator, Register, RegisterType, ValueType},
20};
21
22/// The operator references a function name or closure name.
23#[derive(Clone, PartialEq, Eq, Hash)]
24pub enum CallOperator<N: Network> {
25    /// The reference to a non-local function or closure.
26    Locator(Locator<N>),
27    /// The reference to a local function or closure.
28    Resource(Identifier<N>),
29}
30
31impl<N: Network> Parser for CallOperator<N> {
32    /// Parses a string into an operator.
33    #[inline]
34    fn parse(string: &str) -> ParserResult<Self> {
35        alt((map(Locator::parse, CallOperator::Locator), map(Identifier::parse, CallOperator::Resource)))(string)
36    }
37}
38
39impl<N: Network> FromStr for CallOperator<N> {
40    type Err = Error;
41
42    /// Parses a string into an operator.
43    #[inline]
44    fn from_str(string: &str) -> Result<Self> {
45        match Self::parse(string) {
46            Ok((remainder, object)) => {
47                // Ensure the remainder is empty.
48                ensure!(remainder.is_empty(), "Failed to parse string. Found invalid character in: \"{remainder}\"");
49                // Return the object.
50                Ok(object)
51            }
52            Err(error) => bail!("Failed to parse string. {error}"),
53        }
54    }
55}
56
57impl<N: Network> Debug for CallOperator<N> {
58    /// Prints the operator as a string.
59    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
60        Display::fmt(self, f)
61    }
62}
63
64impl<N: Network> Display for CallOperator<N> {
65    /// Prints the operator to a string.
66    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
67        match self {
68            CallOperator::Locator(locator) => Display::fmt(locator, f),
69            CallOperator::Resource(resource) => Display::fmt(resource, f),
70        }
71    }
72}
73
74impl<N: Network> FromBytes for CallOperator<N> {
75    /// Reads the operation from a buffer.
76    fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
77        // Read the variant.
78        let variant = u8::read_le(&mut reader)?;
79        // Match the variant.
80        match variant {
81            0 => Ok(CallOperator::Locator(Locator::read_le(&mut reader)?)),
82            1 => Ok(CallOperator::Resource(Identifier::read_le(&mut reader)?)),
83            _ => Err(error("Failed to read CallOperator. Invalid variant.")),
84        }
85    }
86}
87
88impl<N: Network> ToBytes for CallOperator<N> {
89    /// Writes the operation to a buffer.
90    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
91        match self {
92            CallOperator::Locator(locator) => {
93                // Write the variant.
94                0u8.write_le(&mut writer)?;
95                // Write the locator.
96                locator.write_le(&mut writer)
97            }
98            CallOperator::Resource(resource) => {
99                // Write the variant.
100                1u8.write_le(&mut writer)?;
101                // Write the resource.
102                resource.write_le(&mut writer)
103            }
104        }
105    }
106}
107
108/// Calls the operands into the declared type.
109/// i.e. `call transfer r0.owner 0u64 r1.amount into r1 r2;`
110#[derive(Clone, PartialEq, Eq, Hash)]
111pub struct Call<N: Network> {
112    /// The reference.
113    operator: CallOperator<N>,
114    /// The operands.
115    operands: Vec<Operand<N>>,
116    /// The destination registers.
117    destinations: Vec<Register<N>>,
118}
119
120impl<N: Network> Call<N> {
121    /// Returns the opcode.
122    #[inline]
123    pub const fn opcode() -> Opcode {
124        Opcode::Call
125    }
126
127    /// Return the operator.
128    #[inline]
129    pub const fn operator(&self) -> &CallOperator<N> {
130        &self.operator
131    }
132
133    /// Returns the operands in the operation.
134    #[inline]
135    pub fn operands(&self) -> &[Operand<N>] {
136        &self.operands
137    }
138
139    /// Returns the destination registers.
140    #[inline]
141    pub fn destinations(&self) -> Vec<Register<N>> {
142        self.destinations.clone()
143    }
144
145    /// Returns `true` if the instruction is a function call.
146    #[inline]
147    pub fn is_function_call(&self, stack: &impl StackTrait<N>) -> Result<bool> {
148        match self.operator() {
149            // Check if the locator is for a function.
150            CallOperator::Locator(locator) => {
151                // Get the external stack.
152                let external_stack = stack.get_external_stack(locator.program_id())?;
153                // Retrieve the program.
154                let program = external_stack.program();
155                // Check if the resource is a function.
156                Ok(program.contains_function(locator.resource()))
157            }
158            // Check if the resource is a function.
159            CallOperator::Resource(resource) => Ok(stack.program().contains_function(resource)),
160        }
161    }
162
163    /// Evaluates the instruction.
164    pub fn evaluate(&self, _stack: &impl StackTrait<N>, _registers: &mut impl RegistersTrait<N>) -> Result<()> {
165        bail!("Forbidden operation: Evaluate cannot invoke a 'call' directly. Use 'call' in 'Stack' instead.")
166    }
167
168    /// Executes the instruction.
169    pub fn execute<A: circuit::Aleo<Network = N>>(
170        &self,
171        _stack: &impl StackTrait<N>,
172        _registers: &mut impl RegistersCircuit<N, A>,
173    ) -> Result<()> {
174        bail!("Forbidden operation: Execute cannot invoke a 'call' directly. Use 'call' in 'Stack' instead.")
175    }
176
177    /// Finalizes the instruction.
178    #[inline]
179    pub fn finalize(&self, _stack: &impl StackTrait<N>, _registers: &mut impl RegistersTrait<N>) -> Result<()> {
180        bail!("Forbidden operation: Finalize cannot invoke a 'call' directly. Use 'call' in 'Stack' instead.")
181    }
182
183    /// Returns the output type from the given program and input types.
184    #[inline]
185    pub fn output_types(
186        &self,
187        stack: &impl StackTrait<N>,
188        input_types: &[RegisterType<N>],
189    ) -> Result<Vec<RegisterType<N>>> {
190        // Retrieve the external stack, if needed, and the resource.
191        let (external_stack, resource) = match &self.operator {
192            CallOperator::Locator(locator) => {
193                (Some(stack.get_external_stack(locator.program_id())?), locator.resource())
194            }
195            CallOperator::Resource(resource) => {
196                // TODO (howardwu): Revisit this decision to forbid calling internal functions. A record cannot be spent again.
197                //  But there are legitimate uses for passing a record through to an internal function.
198                //  We could invoke the internal function without a state transition, but need to match visibility.
199                if stack.program().contains_function(resource) {
200                    bail!("Cannot call '{resource}'. Use a closure ('closure {resource}:') instead.")
201                }
202                (None, resource)
203            }
204        };
205        // Retrieve the program.
206        let (is_external, program) = match &external_stack {
207            Some(external_stack) => (true, external_stack.program()),
208            None => (false, stack.program()),
209        };
210        // If the operator is a closure, retrieve the closure and compute the output types.
211        if let Ok(closure) = program.get_closure(resource) {
212            // Ensure the number of operands matches the number of input statements.
213            if closure.inputs().len() != self.operands.len() {
214                bail!("Expected {} inputs, found {}", closure.inputs().len(), self.operands.len())
215            }
216            // Ensure the number of inputs matches the number of input statements.
217            if closure.inputs().len() != input_types.len() {
218                bail!("Expected {} input types, found {}", closure.inputs().len(), input_types.len())
219            }
220            // Ensure the number of destinations matches the number of output statements.
221            if closure.outputs().len() != self.destinations.len() {
222                bail!("Expected {} outputs, found {}", closure.outputs().len(), self.destinations.len())
223            }
224            // Return the output register types.
225            Ok(closure.outputs().iter().map(|output| output.register_type()).cloned().collect())
226        }
227        // If the operator is a function, retrieve the function and compute the output types.
228        else if let Ok(function) = program.get_function(resource) {
229            // Ensure the number of operands matches the number of input statements.
230            if function.inputs().len() != self.operands.len() {
231                bail!("Expected {} inputs, found {}", function.inputs().len(), self.operands.len())
232            }
233            // Ensure the number of inputs matches the number of input statements.
234            if function.inputs().len() != input_types.len() {
235                bail!("Expected {} input types, found {}", function.inputs().len(), input_types.len())
236            }
237            // Ensure the number of destinations matches the number of output statements.
238            if function.outputs().len() != self.destinations.len() {
239                bail!("Expected {} outputs, found {}", function.outputs().len(), self.destinations.len())
240            }
241            // Return the output register types.
242            function
243                .output_types()
244                .into_iter()
245                .map(|output_type| match (is_external, output_type) {
246                    // If the output is a record and the function is external, return the external record type.
247                    (true, ValueType::Record(record_name)) => Ok(RegisterType::ExternalRecord(Locator::from_str(
248                        &format!("{}/{}", program.id(), record_name),
249                    )?)),
250                    // Else, return the register type.
251                    (_, output_type) => Ok(RegisterType::from(output_type)),
252                })
253                .collect::<Result<Vec<_>>>()
254        }
255        // Else, throw an error.
256        else {
257            bail!("Call operator '{}' is invalid or unsupported.", self.operator)
258        }
259    }
260}
261
262impl<N: Network> Parser for Call<N> {
263    /// Parses a string into an operation.
264    #[inline]
265    fn parse(string: &str) -> ParserResult<Self> {
266        /// Parses an operand from the string.
267        fn parse_operand<N: Network>(string: &str) -> ParserResult<Operand<N>> {
268            // Parse the whitespace from the string.
269            let (string, _) = Sanitizer::parse_whitespaces(string)?;
270            // Parse the operand from the string.
271            Operand::parse(string)
272        }
273
274        /// Parses a destination register from the string.
275        fn parse_destination<N: Network>(string: &str) -> ParserResult<Register<N>> {
276            // Parse the whitespace from the string.
277            let (string, _) = Sanitizer::parse_whitespaces(string)?;
278            // Parse the destination from the string.
279            Register::parse(string)
280        }
281
282        // Parse the opcode from the string.
283        let (string, _) = tag(*Self::opcode())(string)?;
284        // Parse the whitespace from the string.
285        let (string, _) = Sanitizer::parse_whitespaces(string)?;
286        // Parse the name of the call from the string.
287        let (string, operator) = CallOperator::parse(string)?;
288        // Parse the whitespace from the string.
289        let (string, _) = Sanitizer::parse_whitespaces(string)?;
290        // Parse the operands from the string.
291        let (string, operands) = map_res(many0(complete(parse_operand)), |operands: Vec<Operand<N>>| {
292            // Ensure the number of operands is within the bounds.
293            match operands.len() <= N::MAX_OPERANDS {
294                true => Ok(operands),
295                false => Err(error("Failed to parse 'call' opcode: too many operands")),
296            }
297        })(string)?;
298        // Parse the whitespace from the string.
299        let (string, _) = Sanitizer::parse_whitespaces(string)?;
300
301        // Optionally parse the "into" from the string.
302        let (string, destinations) = match opt(tag("into"))(string)? {
303            // If the "into" was not parsed, return the string and an empty vector of destinations.
304            (string, None) => (string, vec![]),
305            // If the "into" was parsed, parse the destinations from the string.
306            (string, Some(_)) => {
307                // Parse the whitespace from the string.
308                let (string, _) = Sanitizer::parse_whitespaces(string)?;
309                // Parse the destinations from the string.
310                let (string, destinations) =
311                    map_res(many1(complete(parse_destination)), |destinations: Vec<Register<N>>| {
312                        // Ensure the number of destinations is within the bounds.
313                        match destinations.len() <= N::MAX_OPERANDS {
314                            true => Ok(destinations),
315                            false => Err(error("Failed to parse 'call' opcode: too many destinations")),
316                        }
317                    })(string)?;
318                // Return the string and the destinations.
319                (string, destinations)
320            }
321        };
322
323        Ok((string, Self { operator, operands, destinations }))
324    }
325}
326
327impl<N: Network> FromStr for Call<N> {
328    type Err = Error;
329
330    /// Parses a string into an operation.
331    #[inline]
332    fn from_str(string: &str) -> Result<Self> {
333        match Self::parse(string) {
334            Ok((remainder, object)) => {
335                // Ensure the remainder is empty.
336                ensure!(remainder.is_empty(), "Failed to parse string. Found invalid character in: \"{remainder}\"");
337                // Return the object.
338                Ok(object)
339            }
340            Err(error) => bail!("Failed to parse string. {error}"),
341        }
342    }
343}
344
345impl<N: Network> Debug for Call<N> {
346    /// Prints the operation as a string.
347    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
348        Display::fmt(self, f)
349    }
350}
351
352impl<N: Network> Display for Call<N> {
353    /// Prints the operation to a string.
354    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
355        // Ensure the number of operands is within the bounds.
356        if self.operands.len() > N::MAX_OPERANDS {
357            return Err(fmt::Error);
358        }
359        // Ensure the number of destinations is within the bounds.
360        if self.destinations.len() > N::MAX_OPERANDS {
361            return Err(fmt::Error);
362        }
363        // Print the operation.
364        write!(f, "{} {}", Self::opcode(), self.operator)?;
365        self.operands.iter().try_for_each(|operand| write!(f, " {operand}"))?;
366        if !self.destinations.is_empty() {
367            write!(f, " into")?;
368            self.destinations.iter().try_for_each(|destination| write!(f, " {destination}"))?;
369        }
370        Ok(())
371    }
372}
373
374impl<N: Network> FromBytes for Call<N> {
375    /// Reads the operation from a buffer.
376    fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
377        // Read the operator of the call.
378        let operator = CallOperator::read_le(&mut reader)?;
379
380        // Read the number of operands.
381        let num_operands = u8::read_le(&mut reader)? as usize;
382        // Ensure the number of operands is within the bounds.
383        if num_operands > N::MAX_OPERANDS {
384            return Err(error(format!("The number of operands must be <= {}", N::MAX_OPERANDS)));
385        }
386
387        // Initialize the vector for the operands.
388        let mut operands = Vec::with_capacity(num_operands);
389        // Read the operands.
390        for _ in 0..num_operands {
391            operands.push(Operand::read_le(&mut reader)?);
392        }
393
394        // Read the number of destination registers.
395        let num_destinations = u8::read_le(&mut reader)? as usize;
396        // Ensure the number of destinations is within the bounds.
397        if num_destinations > N::MAX_OPERANDS {
398            return Err(error(format!("The number of destinations must be <= {}", N::MAX_OPERANDS)));
399        }
400
401        // Initialize the vector for the destinations.
402        let mut destinations = Vec::with_capacity(num_destinations);
403        // Read the destination registers.
404        for _ in 0..num_destinations {
405            destinations.push(Register::read_le(&mut reader)?);
406        }
407
408        // Return the operation.
409        Ok(Self { operator, operands, destinations })
410    }
411}
412
413impl<N: Network> ToBytes for Call<N> {
414    /// Writes the operation to a buffer.
415    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
416        // Ensure the number of operands is within the bounds.
417        if self.operands.len() > N::MAX_OPERANDS {
418            return Err(error(format!("The number of operands must be <= {}", N::MAX_OPERANDS)));
419        }
420        // Ensure the number of destinations is within the bounds.
421        if self.destinations.len() > N::MAX_OPERANDS {
422            return Err(error(format!("The number of destinations must be <= {}", N::MAX_OPERANDS)));
423        }
424
425        // Write the name of the call.
426        self.operator.write_le(&mut writer)?;
427        // Write the number of operands.
428        u8::try_from(self.operands.len()).map_err(|e| error(e.to_string()))?.write_le(&mut writer)?;
429        // Write the operands.
430        self.operands.iter().try_for_each(|operand| operand.write_le(&mut writer))?;
431        // Write the number of destination register.
432        u8::try_from(self.destinations.len()).map_err(|e| error(e.to_string()))?.write_le(&mut writer)?;
433        // Write the destination registers.
434        self.destinations.iter().try_for_each(|destination| destination.write_le(&mut writer))
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use console::{
        network::MainnetV0,
        program::{Access, Address, Identifier, Literal, U64},
    };

    type CurrentNetwork = MainnetV0;

    /// Strings that must round-trip through parsing, display, and byte serialization.
    /// Covers both local resources and non-local locators as call operators.
    const TEST_CASES: &[&str] = &[
        "call foo",
        "call foo r0",
        "call foo r0.owner",
        "call foo r0 r1",
        "call foo into r0",
        "call foo into r0 r1",
        "call foo into r0 r1 r2",
        "call foo r0 into r1",
        "call foo r0 r1 into r2",
        "call foo r0 r1 into r2 r3",
        "call foo r0 r1 r2 into r3 r4",
        "call foo r0 r1 r2 into r3 r4 r5",
        "call credits.aleo/transfer_public r0 r1",
        "call credits.aleo/transfer_public r0 r1 into r2",
        "call token.aleo/mint into r0 r1",
    ];

    /// Returns `count` space-separated registers starting at `r{start}`.
    fn registers(start: usize, count: usize) -> String {
        (start..start + count).map(|i| format!("r{i}")).collect::<Vec<_>>().join(" ")
    }

    fn check_parser(
        string: &str,
        expected_operator: CallOperator<CurrentNetwork>,
        expected_operands: Vec<Operand<CurrentNetwork>>,
        expected_destinations: Vec<Register<CurrentNetwork>>,
    ) {
        // Check that the parser works.
        let (string, call) = Call::<CurrentNetwork>::parse(string).unwrap();

        // Check that the entire string was consumed.
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");

        // Check that the operator is correct.
        assert_eq!(call.operator, expected_operator, "The call operator is incorrect");

        // Check that the operands are correct.
        assert_eq!(call.operands.len(), expected_operands.len(), "The number of operands is incorrect");
        for (i, (given, expected)) in call.operands.iter().zip(expected_operands.iter()).enumerate() {
            assert_eq!(given, expected, "The {i}-th operand is incorrect");
        }

        // Check that the destinations are correct.
        assert_eq!(call.destinations.len(), expected_destinations.len(), "The number of destinations is incorrect");
        for (i, (given, expected)) in call.destinations.iter().zip(expected_destinations.iter()).enumerate() {
            assert_eq!(given, expected, "The {i}-th destination is incorrect");
        }
    }

    #[test]
    fn test_parse() {
        check_parser(
            "call transfer r0.owner r0.token_amount into r1 r2 r3",
            CallOperator::from_str("transfer").unwrap(),
            vec![
                Operand::Register(Register::Access(0, vec![Access::from(Identifier::from_str("owner").unwrap())])),
                Operand::Register(Register::Access(0, vec![Access::from(
                    Identifier::from_str("token_amount").unwrap(),
                )])),
            ],
            vec![Register::Locator(1), Register::Locator(2), Register::Locator(3)],
        );

        check_parser(
            "call mint_public aleo1wfyyj2uvwuqw0c0dqa5x70wrawnlkkvuepn4y08xyaqfqqwweqys39jayw 100u64",
            CallOperator::from_str("mint_public").unwrap(),
            vec![
                Operand::Literal(Literal::Address(
                    Address::from_str("aleo1wfyyj2uvwuqw0c0dqa5x70wrawnlkkvuepn4y08xyaqfqqwweqys39jayw").unwrap(),
                )),
                Operand::Literal(Literal::U64(U64::from_str("100u64").unwrap())),
            ],
            vec![],
        );

        check_parser(
            "call get_magic_number into r0",
            CallOperator::from_str("get_magic_number").unwrap(),
            vec![],
            vec![Register::Locator(0)],
        );

        check_parser("call noop", CallOperator::from_str("noop").unwrap(), vec![], vec![]);

        // A non-local callee is parsed as a locator, not as a bare resource.
        let locator = CallOperator::<CurrentNetwork>::from_str("credits.aleo/transfer_public").unwrap();
        assert!(matches!(locator, CallOperator::Locator(_)));
        check_parser(
            "call credits.aleo/transfer_public r0 r1 into r2",
            locator,
            vec![Operand::Register(Register::Locator(0)), Operand::Register(Register::Locator(1))],
            vec![Register::Locator(2)],
        );
    }

    #[test]
    fn test_display() {
        for expected in TEST_CASES {
            assert_eq!(Call::<CurrentNetwork>::from_str(expected).unwrap().to_string(), *expected);
        }
    }

    #[test]
    fn test_bytes() {
        for case in TEST_CASES {
            let expected = Call::<CurrentNetwork>::from_str(case).unwrap();

            // Check the byte representation.
            let expected_bytes = expected.to_bytes_le().unwrap();
            assert_eq!(expected, Call::read_le(&expected_bytes[..]).unwrap());
        }
    }

    #[test]
    fn test_call_operator_round_trip() {
        for operator in ["foo", "credits.aleo/transfer_public"] {
            let expected = CallOperator::<CurrentNetwork>::from_str(operator).unwrap();
            assert_eq!(expected.to_string(), operator);

            let bytes = expected.to_bytes_le().unwrap();
            assert_eq!(expected, CallOperator::read_le(&bytes[..]).unwrap());
        }

        // Only variants 0 (locator) and 1 (resource) are valid.
        assert!(CallOperator::<CurrentNetwork>::read_le(&[2u8][..]).is_err());
    }

    #[test]
    fn test_operand_and_destination_bounds() {
        let max = CurrentNetwork::MAX_OPERANDS;

        // Exactly the maximum number of operands and destinations is accepted and round-trips.
        let at_limit = format!("call foo {} into {}", registers(0, max), registers(max, max));
        let call = Call::<CurrentNetwork>::from_str(&at_limit).unwrap();
        assert_eq!(call.to_string(), at_limit);
        let bytes = call.to_bytes_le().unwrap();
        assert_eq!(call, Call::read_le(&bytes[..]).unwrap());

        // One operand too many is rejected.
        let too_many_operands = format!("call foo {}", registers(0, max + 1));
        assert!(Call::<CurrentNetwork>::from_str(&too_many_operands).is_err());

        // One destination too many is rejected.
        let too_many_destinations = format!("call foo into {}", registers(0, max + 1));
        assert!(Call::<CurrentNetwork>::from_str(&too_many_destinations).is_err());
    }
}