starknet_devnet_core/
contract_class_choice.rs

1use std::str::FromStr;
2
3use starknet_rs_core::types::Felt;
4use starknet_rs_core::utils::get_selector_from_name;
5use starknet_types::contract_class::deprecated::json_contract_class::Cairo0Json;
6use starknet_types::contract_class::{Cairo0ContractClass, ContractClass};
7use starknet_types::traits::HashProducer;
8
9use crate::constants::{CAIRO_0_ACCOUNT_CONTRACT, CAIRO_1_ACCOUNT_CONTRACT_SIERRA};
10use crate::error::DevnetResult;
11
/// Selects one of the account contract classes whose artifacts are shipped as constants in
/// `crate::constants`.
///
/// Derives `clap::ValueEnum`, so a variant can be chosen as a command-line value. Use
/// [`AccountContractClassChoice::get_class_wrapper`] to load the selected class.
// NOTE: variants are annotated with plain `//` comments on purpose. clap uses `///` doc
// comments on `ValueEnum` variants as possible-value help text, so turning these into doc
// comments would change the CLI `--help` output.
#[derive(clap::ValueEnum, Debug, Clone)]
pub enum AccountContractClassChoice {
    // Cairo 0 account class loaded from `CAIRO_0_ACCOUNT_CONTRACT`
    // (metadata "OpenZeppelin 0.5.1").
    Cairo0,
    // Cairo 1 account class loaded from the Sierra artifact `CAIRO_1_ACCOUNT_CONTRACT_SIERRA`
    // (metadata "OpenZeppelin 1.0.0").
    Cairo1,
}
17
18impl AccountContractClassChoice {
19    pub fn get_class_wrapper(&self) -> DevnetResult<AccountClassWrapper> {
20        Ok(match self {
21            AccountContractClassChoice::Cairo0 => {
22                let contract_json = Cairo0Json::raw_json_from_json_str(CAIRO_0_ACCOUNT_CONTRACT)?;
23                let contract_class = Cairo0ContractClass::RawJson(contract_json);
24
25                AccountClassWrapper {
26                    class_hash: contract_class.generate_hash()?,
27                    contract_class: ContractClass::Cairo0(contract_class),
28                    class_metadata: "OpenZeppelin 0.5.1",
29                }
30            }
31            AccountContractClassChoice::Cairo1 => {
32                let contract_class = ContractClass::Cairo1(
33                    ContractClass::cairo_1_from_sierra_json_str(CAIRO_1_ACCOUNT_CONTRACT_SIERRA)?,
34                );
35                AccountClassWrapper {
36                    class_hash: contract_class.generate_hash()?,
37                    contract_class,
38                    class_metadata: "OpenZeppelin 1.0.0",
39                }
40            }
41        })
42    }
43}
/// An account contract class bundled with its class hash and a short label describing its
/// origin.
///
/// In this module it is produced either by
/// [`AccountContractClassChoice::get_class_wrapper`] (bundled classes) or by parsing a file
/// path through the `FromStr` impl (custom Sierra artifacts).
#[derive(Clone, Debug)]
pub struct AccountClassWrapper {
    /// The account contract class. It is built here as `ContractClass::Cairo0` or
    /// `ContractClass::Cairo1`.
    pub contract_class: ContractClass,
    /// Class hash computed with `HashProducer::generate_hash` when the wrapper is built.
    pub class_hash: Felt,
    /// Human-readable origin label: `"OpenZeppelin 0.5.1"` or `"OpenZeppelin 1.0.0"` for the
    /// bundled classes, `"Custom"` for user-supplied artifacts.
    pub class_metadata: &'static str,
}
50
51impl FromStr for AccountClassWrapper {
52    type Err = crate::error::Error;
53
54    fn from_str(path_candidate: &str) -> Result<Self, Self::Err> {
55        // load artifact
56        let contract_class = ContractClass::cairo_1_from_sierra_json_str(
57            std::fs::read_to_string(path_candidate)?.as_str(),
58        )?;
59
60        // check that artifact is really account
61        let execute_selector = get_selector_from_name("__execute__")
62            .map_err(|err| crate::error::Error::UnexpectedInternalError { msg: err.to_string() })?;
63        let validate_selector = get_selector_from_name("__validate__")
64            .map_err(|err| crate::error::Error::UnexpectedInternalError { msg: err.to_string() })?;
65        let mut has_execute = false;
66        let mut has_validate = false;
67        for entry_point in &contract_class.entry_points_by_type.external {
68            let selector: Felt = (&entry_point.selector).into();
69            has_execute |= selector == execute_selector;
70            has_validate |= selector == validate_selector;
71        }
72
73        if !has_execute || !has_validate {
74            let msg = format!(
75                "Not a valid Sierra account artifact; has __execute__: {has_execute}; has \
76                 __validate__: {has_validate}"
77            );
78            return Err(crate::error::Error::ContractClassLoadError(msg));
79        }
80
81        // generate the hash and return
82        let contract_class = ContractClass::Cairo1(contract_class);
83        let class_hash = contract_class.generate_hash()?;
84        Ok(Self { contract_class, class_hash, class_metadata: "Custom" })
85    }
86}
87
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use clap::ValueEnum;
    use starknet_types::felt::felt_from_prefixed_hex;
    use starknet_types::traits::HashProducer;

    use super::{AccountClassWrapper, AccountContractClassChoice};
    use crate::constants::{
        CAIRO_0_ACCOUNT_CONTRACT_HASH, CAIRO_1_ACCOUNT_CONTRACT_SIERRA_HASH,
        CAIRO_1_ACCOUNT_CONTRACT_SIERRA_PATH,
    };

    /// Every selectable variant loads. Its stored hash matches a fresh recomputation, and it
    /// carries an OpenZeppelin label.
    #[test]
    fn all_methods_work_with_all_options() {
        for choice in AccountContractClassChoice::value_variants() {
            let wrapper = choice.get_class_wrapper().unwrap();
            assert_eq!(wrapper.contract_class.generate_hash().unwrap(), wrapper.class_hash);
            assert!(wrapper.class_metadata.starts_with("OpenZeppelin"));
        }
    }

    /// Each bundled class hashes to the expected value in `crate::constants`.
    #[test]
    fn correct_hash_calculated() {
        let expectations = [
            (AccountContractClassChoice::Cairo0, CAIRO_0_ACCOUNT_CONTRACT_HASH),
            (AccountContractClassChoice::Cairo1, CAIRO_1_ACCOUNT_CONTRACT_SIERRA_HASH),
        ];

        for (choice, expected_hex) in expectations {
            let actual_hash = choice.get_class_wrapper().unwrap().class_hash;
            assert_eq!(actual_hash, felt_from_prefixed_hex(expected_hex).unwrap());
        }
    }

    /// Bundled classes carry their OpenZeppelin version label. A class loaded from a path
    /// is labelled "Custom".
    #[test]
    fn correct_metadata() {
        let expectations = [
            (AccountContractClassChoice::Cairo0, "OpenZeppelin 0.5.1"),
            (AccountContractClassChoice::Cairo1, "OpenZeppelin 1.0.0"),
        ];

        for (choice, expected_metadata) in expectations {
            assert_eq!(choice.get_class_wrapper().unwrap().class_metadata, expected_metadata);
        }

        let custom_wrapper =
            AccountClassWrapper::from_str(CAIRO_1_ACCOUNT_CONTRACT_SIERRA_PATH).unwrap();
        assert_eq!(custom_wrapper.class_metadata, "Custom");
    }
}