snarkvm_synthesizer_process/stack/helpers/
stack_trait.rs

1// Copyright (c) 2019-2025 Provable Inc.
2// This file is part of the snarkVM library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use super::*;
17
18impl<N: Network> StackTrait<N> for Stack<N> {
19    /// Checks that the given value matches the layout of the value type.
20    fn matches_value_type(&self, value: &Value<N>, value_type: &ValueType<N>) -> Result<()> {
21        // Ensure the value matches the declared value type in the register.
22        match (value, value_type) {
23            (Value::Plaintext(plaintext), ValueType::Constant(plaintext_type))
24            | (Value::Plaintext(plaintext), ValueType::Public(plaintext_type))
25            | (Value::Plaintext(plaintext), ValueType::Private(plaintext_type)) => {
26                self.matches_plaintext(plaintext, plaintext_type)
27            }
28            (Value::Record(record), ValueType::Record(record_name)) => self.matches_record(record, record_name),
29            (Value::Record(record), ValueType::ExternalRecord(locator)) => {
30                self.matches_external_record(record, locator)
31            }
32            (Value::Future(future), ValueType::Future(locator)) => self.matches_future(future, locator),
33            _ => bail!("A value does not match its declared value type '{value_type}'"),
34        }
35    }
36
37    /// Checks that the given stack value matches the layout of the register type.
38    fn matches_register_type(&self, stack_value: &Value<N>, register_type: &RegisterType<N>) -> Result<()> {
39        match (stack_value, register_type) {
40            (Value::Plaintext(plaintext), RegisterType::Plaintext(plaintext_type)) => {
41                self.matches_plaintext(plaintext, plaintext_type)
42            }
43            (Value::Record(record), RegisterType::Record(record_name)) => self.matches_record(record, record_name),
44            (Value::Record(record), RegisterType::ExternalRecord(locator)) => {
45                self.matches_external_record(record, locator)
46            }
47            (Value::Future(future), RegisterType::Future(locator)) => self.matches_future(future, locator),
48            _ => bail!("A value does not match its declared register type '{register_type}'"),
49        }
50    }
51
52    /// Checks that the given record matches the layout of the external record type.
53    fn matches_external_record(&self, record: &Record<N, Plaintext<N>>, locator: &Locator<N>) -> Result<()> {
54        // Retrieve the record name.
55        let record_name = locator.resource();
56
57        // Ensure the record name is valid.
58        ensure!(!Program::is_reserved_keyword(record_name), "Record name '{record_name}' is reserved");
59
60        // Retrieve the external stack.
61        let external_stack = self.get_external_stack(locator.program_id())?;
62        // Retrieve the record type from the program.
63        let Ok(record_type) = external_stack.program().get_record(locator.resource()) else {
64            bail!("External '{locator}' is not defined in the program")
65        };
66
67        // Ensure the record name matches.
68        if record_type.name() != record_name {
69            bail!("Expected external record '{record_name}', found external record '{}'", record_type.name())
70        }
71
72        self.matches_record_internal(record, record_type, 0)
73    }
74
75    /// Checks that the given record matches the layout of the record type.
76    fn matches_record(&self, record: &Record<N, Plaintext<N>>, record_name: &Identifier<N>) -> Result<()> {
77        // Ensure the record name is valid.
78        ensure!(!Program::is_reserved_keyword(record_name), "Record name '{record_name}' is reserved");
79
80        // Retrieve the record type from the program.
81        let Ok(record_type) = self.program().get_record(record_name) else {
82            bail!("Record '{record_name}' is not defined in the program")
83        };
84
85        // Ensure the record name matches.
86        if record_type.name() != record_name {
87            bail!("Expected record '{record_name}', found record '{}'", record_type.name())
88        }
89
90        self.matches_record_internal(record, record_type, 0)
91    }
92
93    /// Checks that the given plaintext matches the layout of the plaintext type.
94    fn matches_plaintext(&self, plaintext: &Plaintext<N>, plaintext_type: &PlaintextType<N>) -> Result<()> {
95        self.matches_plaintext_internal(plaintext, plaintext_type, 0)
96    }
97
98    /// Checks that the given future matches the layout of the future type.
99    fn matches_future(&self, future: &Future<N>, locator: &Locator<N>) -> Result<()> {
100        self.matches_future_internal(future, locator, 0)
101    }
102
103    /// Returns `true` if the proving key for the given function name exists.
104    fn contains_proving_key(&self, function_name: &Identifier<N>) -> bool {
105        self.proving_keys.read().contains_key(function_name)
106    }
107
108    /// Returns the proving key for the given function name.
109    fn get_proving_key(&self, function_name: &Identifier<N>) -> Result<ProvingKey<N>> {
110        // If the program is 'credits.aleo', try to load the proving key, if it does not exist.
111        self.try_insert_credits_function_proving_key(function_name)?;
112        // Return the proving key, if it exists.
113        match self.proving_keys.read().get(function_name) {
114            Some(pk) => Ok(pk.clone()),
115            None => bail!("Proving key not found for: {}/{}", self.program.id(), function_name),
116        }
117    }
118
119    /// Inserts the given proving key for the given function name.
120    fn insert_proving_key(&self, function_name: &Identifier<N>, proving_key: ProvingKey<N>) -> Result<()> {
121        // Ensure the function name exists in the program.
122        ensure!(
123            self.program.contains_function(function_name),
124            "Function '{function_name}' does not exist in program '{}'.",
125            self.program.id()
126        );
127        // Insert the proving key.
128        self.proving_keys.write().insert(*function_name, proving_key);
129        Ok(())
130    }
131
132    /// Removes the proving key for the given function name.
133    fn remove_proving_key(&self, function_name: &Identifier<N>) {
134        self.proving_keys.write().shift_remove(function_name);
135    }
136
137    /// Returns `true` if the verifying key for the given function name exists.
138    fn contains_verifying_key(&self, function_name: &Identifier<N>) -> bool {
139        self.verifying_keys.read().contains_key(function_name)
140    }
141
142    /// Returns the verifying key for the given function name.
143    fn get_verifying_key(&self, function_name: &Identifier<N>) -> Result<VerifyingKey<N>> {
144        // Return the verifying key, if it exists.
145        match self.verifying_keys.read().get(function_name) {
146            Some(vk) => Ok(vk.clone()),
147            None => bail!("Verifying key not found for: {}/{}", self.program.id(), function_name),
148        }
149    }
150
151    /// Inserts the given verifying key for the given function name.
152    fn insert_verifying_key(&self, function_name: &Identifier<N>, verifying_key: VerifyingKey<N>) -> Result<()> {
153        // Ensure the function name exists in the program.
154        ensure!(
155            self.program.contains_function(function_name),
156            "Function '{function_name}' does not exist in program '{}'.",
157            self.program.id()
158        );
159        // Insert the verifying key.
160        self.verifying_keys.write().insert(*function_name, verifying_key);
161        Ok(())
162    }
163
164    /// Removes the verifying key for the given function name.
165    fn remove_verifying_key(&self, function_name: &Identifier<N>) {
166        self.verifying_keys.write().shift_remove(function_name);
167    }
168
169    /// Returns the program.
170    fn program(&self) -> &Program<N> {
171        &self.program
172    }
173
174    /// Returns the program ID.
175    fn program_id(&self) -> &ProgramID<N> {
176        self.program.id()
177    }
178
179    /// Returns the program address.
180    fn program_address(&self) -> &Address<N> {
181        &self.program_address
182    }
183
184    /// Returns the program checksum.
185    fn program_checksum(&self) -> &[U8<N>; 32] {
186        &self.program_checksum
187    }
188
189    /// Returns the program checksum as a field element.
190    #[inline]
191    fn program_checksum_as_field(&self) -> Result<Field<N>> {
192        // Get the bits of the program checksum, truncated to the field size.
193        let bits = self
194            .program_checksum
195            .iter()
196            .flat_map(|byte| byte.to_bits_le())
197            .take(Field::<N>::SIZE_IN_DATA_BITS)
198            .collect::<Vec<_>>();
199        // Return the field element from the bits.
200        Field::from_bits_le(&bits)
201    }
202
203    /// Returns the program edition.
204    #[inline]
205    fn program_edition(&self) -> U16<N> {
206        self.program_edition
207    }
208
209    /// Returns the program owner.
210    #[inline]
211    fn program_owner(&self) -> &Option<Address<N>> {
212        &self.program_owner
213    }
214
215    /// Sets the program owner.
216    /// The program owner should only be set for programs that are deployed after `ConsensusVersion::V9` is active.
217    fn set_program_owner(&mut self, program_owner: Option<Address<N>>) {
218        self.program_owner = program_owner;
219    }
220
221    /// Returns the external stack for the given program ID.
222    ///
223    /// Attention - this function is used to check the existence of the external program.
224    /// Developers should explicitly handle the error case so as to not default to the main program.
225    fn get_external_stack(&self, program_id: &ProgramID<N>) -> Result<Arc<Stack<N>>> {
226        // Check that the program ID is not itself.
227        ensure!(
228            program_id != self.program.id(),
229            "Attempted to get the main program '{program_id}' as an external program."
230        );
231        // Check that the program ID is imported by the program.
232        ensure!(self.program.contains_import(program_id), "External program '{program_id}' is not imported.");
233        // Upgrade the weak reference to the process-level stack map and retrieve the external stack.
234        self.stacks
235            .upgrade()
236            .ok_or_else(|| anyhow!("Process-level stack map does not exist"))?
237            .read()
238            .get(program_id)
239            .cloned()
240            .ok_or_else(|| anyhow!("External stack for '{program_id}' does not exist"))
241    }
242
243    /// Returns the function with the given function name.
244    fn get_function(&self, function_name: &Identifier<N>) -> Result<Function<N>> {
245        self.program.get_function(function_name)
246    }
247
248    /// Returns a reference to the function with the given function name.
249    fn get_function_ref(&self, function_name: &Identifier<N>) -> Result<&Function<N>> {
250        self.program.get_function_ref(function_name)
251    }
252
253    /// Returns the expected number of calls for the given function name.
254    fn get_number_of_calls(&self, function_name: &Identifier<N>) -> Result<usize> {
255        // Initialize the base number of calls.
256        let mut num_calls = 1;
257        // Initialize a queue of functions to check.
258        let mut queue = vec![(StackRef::Internal(self), *function_name)];
259        // Iterate over the queue.
260        while let Some((stack_ref, function_name)) = queue.pop() {
261            // Ensure that the number of calls does not exceed the maximum.
262            // Note that one transition is reserved for the fee.
263            ensure!(
264                num_calls < Transaction::<N>::MAX_TRANSITIONS,
265                "Number of calls must be less than '{}'",
266                Transaction::<N>::MAX_TRANSITIONS
267            );
268            // Determine the number of calls for the function.
269            for instruction in stack_ref.get_function_ref(&function_name)?.instructions() {
270                if let Instruction::Call(call) = instruction {
271                    // Determine if this is a function call.
272                    if call.is_function_call(&*stack_ref)? {
273                        // Increment by the number of calls.
274                        num_calls += 1;
275                        // Add the function to the queue.
276                        match call.operator() {
277                            CallOperator::Locator(locator) => {
278                                // If the locator matches the program ID of the provided stack, use it directly.
279                                // Otherwise, retrieve the external stack.
280                                let stack = if locator.program_id() == self.program().id() {
281                                    StackRef::Internal(self)
282                                } else {
283                                    StackRef::External(stack_ref.get_external_stack(locator.program_id())?)
284                                };
285                                queue.push((stack, *locator.resource()));
286                            }
287                            CallOperator::Resource(resource) => {
288                                queue.push((stack_ref.clone(), *resource));
289                            }
290                        }
291                    }
292                }
293            }
294        }
295        // Return the number of calls.
296        Ok(num_calls)
297    }
298
299    /// Returns a value for the given register type.
300    fn sample_value<R: Rng + CryptoRng>(
301        &self,
302        burner_address: &Address<N>,
303        register_type: &RegisterType<N>,
304        rng: &mut R,
305    ) -> Result<Value<N>> {
306        match register_type {
307            RegisterType::Plaintext(plaintext_type) => {
308                Ok(Value::Plaintext(self.sample_plaintext(plaintext_type, rng)?))
309            }
310            RegisterType::Record(record_name) => {
311                Ok(Value::Record(self.sample_record(burner_address, record_name, Group::rand(rng), rng)?))
312            }
313            RegisterType::ExternalRecord(locator) => {
314                // Retrieve the external stack.
315                let stack = self.get_external_stack(locator.program_id())?;
316                // Sample the output.
317                Ok(Value::Record(stack.sample_record(burner_address, locator.resource(), Group::rand(rng), rng)?))
318            }
319            RegisterType::Future(locator) => Ok(Value::Future(self.sample_future(locator, rng)?)),
320        }
321    }
322
323    /// Returns a record for the given record name, with the given burner address and nonce.
324    fn sample_record<R: Rng + CryptoRng>(
325        &self,
326        burner_address: &Address<N>,
327        record_name: &Identifier<N>,
328        nonce: Group<N>,
329        rng: &mut R,
330    ) -> Result<Record<N, Plaintext<N>>> {
331        // Sample a record.
332        let record = self.sample_record_internal(burner_address, record_name, nonce, 0, rng)?;
333        // Ensure the record matches the value type.
334        self.matches_record(&record, record_name)?;
335        // Return the record.
336        Ok(record)
337    }
338
339    /// Returns a record for the given record name, deriving the nonce from tvk and index.
340    fn sample_record_using_tvk<R: Rng + CryptoRng>(
341        &self,
342        burner_address: &Address<N>,
343        record_name: &Identifier<N>,
344        tvk: Field<N>,
345        index: Field<N>,
346        rng: &mut R,
347    ) -> Result<Record<N, Plaintext<N>>> {
348        // Compute the randomizer.
349        let randomizer = N::hash_to_scalar_psd2(&[tvk, index])?;
350        // Construct the record nonce from that randomizer.
351        let record_nonce = N::g_scalar_multiply(&randomizer);
352        // Sample the record with that nonce.
353        self.sample_record(burner_address, record_name, record_nonce, rng)
354    }
355}
356
357impl<N: Network> Stack<N> {
358    /// Checks that the given record matches the layout of the record type.
359    fn matches_record_internal(
360        &self,
361        record: &Record<N, Plaintext<N>>,
362        record_type: &RecordType<N>,
363        depth: usize,
364    ) -> Result<()> {
365        // If the depth exceeds the maximum depth, then the plaintext type is invalid.
366        ensure!(depth <= N::MAX_DATA_DEPTH, "Plaintext exceeded maximum depth of {}", N::MAX_DATA_DEPTH);
367
368        // Retrieve the record name.
369        let record_name = record_type.name();
370        // Ensure the record name is valid.
371        ensure!(!Program::is_reserved_keyword(record_name), "Record name '{record_name}' is reserved");
372
373        // Ensure the visibility of the record owner matches the visibility in the record type.
374        ensure!(
375            record.owner().is_public() == record_type.owner().is_public(),
376            "Visibility of record entry 'owner' does not match"
377        );
378        ensure!(
379            record.owner().is_private() == record_type.owner().is_private(),
380            "Visibility of record entry 'owner' does not match"
381        );
382
383        // Ensure the number of record entries does not exceed the maximum.
384        let num_entries = record.data().len();
385        ensure!(num_entries <= N::MAX_DATA_ENTRIES, "'{record_name}' cannot exceed {} entries", N::MAX_DATA_ENTRIES);
386
387        // Ensure the number of record entries match.
388        let expected_num_entries = record_type.entries().len();
389        if expected_num_entries != num_entries {
390            bail!("'{record_name}' expected {expected_num_entries} entries, found {num_entries} entries")
391        }
392
393        // Ensure the record data match, in the same order.
394        for (i, ((expected_name, expected_type), (entry_name, entry))) in
395            record_type.entries().iter().zip_eq(record.data().iter()).enumerate()
396        {
397            // Ensure the entry name matches.
398            if expected_name != entry_name {
399                bail!("Entry '{i}' in '{record_name}' is incorrect: expected '{expected_name}', found '{entry_name}'")
400            }
401            // Ensure the entry name is valid.
402            ensure!(!Program::is_reserved_keyword(entry_name), "Entry name '{entry_name}' is reserved");
403            // Ensure the entry matches (recursive call).
404            self.matches_entry_internal(record_name, entry_name, entry, expected_type, depth + 1)?;
405        }
406
407        Ok(())
408    }
409
410    /// Checks that the given entry matches the layout of the entry type.
411    fn matches_entry_internal(
412        &self,
413        record_name: &Identifier<N>,
414        entry_name: &Identifier<N>,
415        entry: &Entry<N, Plaintext<N>>,
416        entry_type: &EntryType<N>,
417        depth: usize,
418    ) -> Result<()> {
419        match (entry, entry_type) {
420            (Entry::Constant(plaintext), EntryType::Constant(plaintext_type))
421            | (Entry::Public(plaintext), EntryType::Public(plaintext_type))
422            | (Entry::Private(plaintext), EntryType::Private(plaintext_type)) => {
423                match self.matches_plaintext_internal(plaintext, plaintext_type, depth) {
424                    Ok(()) => Ok(()),
425                    Err(error) => bail!("Invalid record entry '{record_name}.{entry_name}': {error}"),
426                }
427            }
428            _ => bail!(
429                "Type mismatch in record entry '{record_name}.{entry_name}':\n'{entry}'\n does not match\n'{entry_type}'"
430            ),
431        }
432    }
433
434    /// Checks that the given plaintext matches the layout of the plaintext type.
435    fn matches_plaintext_internal(
436        &self,
437        plaintext: &Plaintext<N>,
438        plaintext_type: &PlaintextType<N>,
439        depth: usize,
440    ) -> Result<()> {
441        // If the depth exceeds the maximum depth, then the plaintext type is invalid.
442        ensure!(depth <= N::MAX_DATA_DEPTH, "Plaintext exceeded maximum depth of {}", N::MAX_DATA_DEPTH);
443
444        // Ensure the plaintext matches the plaintext definition in the program.
445        match plaintext_type {
446            PlaintextType::Literal(literal_type) => match plaintext {
447                // If `plaintext` is a literal, it must match the literal type.
448                Plaintext::Literal(literal, ..) => {
449                    // Ensure the literal type matches.
450                    match literal.to_type() == *literal_type {
451                        true => Ok(()),
452                        false => bail!("'{literal}' is invalid: expected {literal_type}"),
453                    }
454                }
455                // If `plaintext` is a struct, this is a mismatch.
456                Plaintext::Struct(..) => bail!("'{plaintext_type}' is invalid: expected literal, found struct"),
457                // If `plaintext` is an array, this is a mismatch.
458                Plaintext::Array(..) => bail!("'{plaintext_type}' is invalid: expected literal, found array"),
459            },
460            PlaintextType::Struct(struct_name) => {
461                // Ensure the struct name is valid.
462                ensure!(!Program::is_reserved_keyword(struct_name), "Struct '{struct_name}' is reserved");
463
464                // Retrieve the struct from the program.
465                let Ok(struct_) = self.program().get_struct(struct_name) else {
466                    bail!("Struct '{struct_name}' is not defined in the program")
467                };
468
469                // Ensure the struct name matches.
470                if struct_.name() != struct_name {
471                    bail!("Expected struct '{struct_name}', found struct '{}'", struct_.name())
472                }
473
474                // Retrieve the struct members.
475                let members = match plaintext {
476                    Plaintext::Literal(..) => bail!("'{struct_name}' is invalid: expected struct, found literal"),
477                    Plaintext::Struct(members, ..) => members,
478                    Plaintext::Array(..) => bail!("'{struct_name}' is invalid: expected struct, found array"),
479                };
480
481                let num_members = members.len();
482                // Ensure the number of struct members does not go below the minimum.
483                ensure!(
484                    num_members >= N::MIN_STRUCT_ENTRIES,
485                    "'{struct_name}' cannot be less than {} entries",
486                    N::MIN_STRUCT_ENTRIES
487                );
488                // Ensure the number of struct members does not exceed the maximum.
489                ensure!(
490                    num_members <= N::MAX_STRUCT_ENTRIES,
491                    "'{struct_name}' cannot exceed {} entries",
492                    N::MAX_STRUCT_ENTRIES
493                );
494
495                // Ensure the number of struct members match.
496                let expected_num_members = struct_.members().len();
497                if expected_num_members != num_members {
498                    bail!("'{struct_name}' expected {expected_num_members} members, found {num_members} members")
499                }
500
501                // Ensure the struct members match, in the same order.
502                for (i, ((expected_name, expected_type), (member_name, member))) in
503                    struct_.members().iter().zip_eq(members.iter()).enumerate()
504                {
505                    // Ensure the member name matches.
506                    if expected_name != member_name {
507                        bail!(
508                            "Member '{i}' in '{struct_name}' is incorrect: expected '{expected_name}', found '{member_name}'"
509                        )
510                    }
511                    // Ensure the member name is valid.
512                    ensure!(!Program::is_reserved_keyword(member_name), "Member name '{member_name}' is reserved");
513                    // Ensure the member plaintext matches (recursive call).
514                    self.matches_plaintext_internal(member, expected_type, depth + 1)?;
515                }
516
517                Ok(())
518            }
519            PlaintextType::Array(array_type) => match plaintext {
520                // If `plaintext` is a literal, this is a mismatch.
521                Plaintext::Literal(..) => bail!("'{plaintext_type}' is invalid: expected array, found literal"),
522                // If `plaintext` is a struct, this is a mismatch.
523                Plaintext::Struct(..) => bail!("'{plaintext_type}' is invalid: expected array, found struct"),
524                // If `plaintext` is an array, it must match the array type.
525                Plaintext::Array(array, ..) => {
526                    // Ensure the array length matches.
527                    let (actual_length, expected_length) = (array.len(), array_type.length());
528                    if **expected_length as usize != actual_length {
529                        bail!(
530                            "'{plaintext_type}' is invalid: expected {expected_length} elements, found {actual_length} elements"
531                        )
532                    }
533                    // Ensure the array elements match.
534                    for element in array.iter() {
535                        self.matches_plaintext_internal(element, array_type.next_element_type(), depth + 1)?;
536                    }
537                    Ok(())
538                }
539            },
540        }
541    }
542
543    /// Checks that the given future matches the layout of the future type.
544    fn matches_future_internal(&self, future: &Future<N>, locator: &Locator<N>, depth: usize) -> Result<()> {
545        // If the depth exceeds the maximum depth, then the future type is invalid.
546        ensure!(depth <= N::MAX_DATA_DEPTH, "Future exceeded maximum depth of {}", N::MAX_DATA_DEPTH);
547
548        // Ensure that the program IDs match.
549        ensure!(future.program_id() == locator.program_id(), "Future program ID does not match");
550
551        // Ensure that the function names match.
552        ensure!(future.function_name() == locator.resource(), "Future name does not match");
553
554        // Retrieve the external stack, if needed.
555        let external_stack = match locator.program_id() == self.program_id() {
556            true => None,
557            // Attention - This method must fail here and early return if the external program is missing.
558            // Otherwise, this method will proceed to look for the requested function in its own program.
559            false => Some(self.get_external_stack(locator.program_id())?),
560        };
561        // Retrieve the associated function.
562        let function = match &external_stack {
563            Some(external_stack) => external_stack.get_function_ref(locator.resource())?,
564            None => self.get_function_ref(locator.resource())?,
565        };
566        // Retrieve the finalize inputs.
567        let inputs = match function.finalize_logic() {
568            Some(finalize_logic) => finalize_logic.inputs(),
569            None => bail!("Function '{locator}' does not have a finalize block"),
570        };
571
572        // Ensure the number of arguments matches the number of inputs.
573        ensure!(future.arguments().len() == inputs.len(), "Future arguments do not match");
574
575        // Check that the arguments match the inputs.
576        for (argument, input) in future.arguments().iter().zip_eq(inputs.iter()) {
577            match (argument, input.finalize_type()) {
578                (Argument::Plaintext(plaintext), FinalizeType::Plaintext(plaintext_type)) => {
579                    self.matches_plaintext_internal(plaintext, plaintext_type, depth + 1)?
580                }
581                (Argument::Future(future), FinalizeType::Future(locator)) => {
582                    self.matches_future_internal(future, locator, depth + 1)?
583                }
584                (_, input_type) => {
585                    bail!("Argument type does not match input type: expected '{input_type}'")
586                }
587            }
588        }
589
590        Ok(())
591    }
592}