Skip to main content

soroban_cli/signer/
validation.rs

1use crate::xdr::{HostFunction, SorobanAuthorizedFunction, SorobanAuthorizedInvocation};
2
/// Classification of an `Address`-credential auth entry's relationship to the
/// transaction's host function.
///
/// `SourceAccount` credential entries are out of scope here — they are signed
/// implicitly via the transaction envelope and never reach this classifier.
///
/// This is a plain fieldless tag, so it is `Copy` and hashable and can be
/// passed by value or used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuthStyle {
    /// `root_invocation` matches the host function exactly. Safe to sign:
    /// the entry is bound to the host function.
    Strict,
    /// `root_invocation` does not match the host function exactly. Any transaction
    /// whose auth tree contains this entry could consume the resulting signature.
    NonStrict,
    /// The host function does not take auth entries at all (e.g.
    /// `UploadContractWasm`), so the entry is unexpected regardless of its
    /// `root_invocation`.
    Invalid,
}
19
20/// Classify an auth invocation against the transaction's host function.
21///
22/// ### Arguments
23/// * `source_host_fn`- The transaction's host function
24/// * `auth_invocation` - The auth entry's root invocation
25pub fn classify_auth_invocation(
26    source_host_fn: &HostFunction,
27    auth_invocation: &SorobanAuthorizedInvocation,
28) -> AuthStyle {
29    // No auth entries are valid for `UploadContractWasm`.
30    if matches!(source_host_fn, HostFunction::UploadContractWasm(_)) {
31        return AuthStyle::Invalid;
32    }
33
34    // Check if the auth entry's root invocation matches the host function exactly.
35    // This is different than just a `root_auth` check, as contracts that authorize with
36    // `require_auth_for_args` at the root are not considered strict auth. This tradeoff is
37    // made to ensure that even a tampered auth entry can be flagged as non-strict.
38    let is_strict = match (source_host_fn, &auth_invocation.function) {
39        (HostFunction::InvokeContract(op), SorobanAuthorizedFunction::ContractFn(args)) => {
40            args == op
41        }
42        (
43            HostFunction::CreateContract(op),
44            SorobanAuthorizedFunction::CreateContractHostFn(args),
45        ) => args == op,
46        (
47            HostFunction::CreateContractV2(op),
48            SorobanAuthorizedFunction::CreateContractV2HostFn(args),
49        ) => args == op,
50        _ => false,
51    };
52
53    if is_strict {
54        AuthStyle::Strict
55    } else {
56        AuthStyle::NonStrict
57    }
58}
59
#[cfg(test)]
mod tests {
    use super::*;
    use crate::xdr::{
        AccountId, BytesM, ContractExecutable, ContractIdPreimage, ContractIdPreimageFromAddress,
        CreateContractArgsV2, Hash, InvokeContractArgs, PublicKey, ScAddress, ScSymbol, ScVal,
        Uint256, VecM,
    };
    use stellar_strkey::ed25519;

    const SOURCE_ACCOUNT: &str = "GBZXN7PIRZGNMHGA7MUUUF4GWPY5AYPV6LY4UV2GL6VJGIQRXFDNMADI";

    fn source_bytes() -> [u8; 32] {
        ed25519::PublicKey::from_string(SOURCE_ACCOUNT).unwrap().0
    }

    fn ed25519_address(bytes: [u8; 32]) -> ScAddress {
        ScAddress::Account(AccountId(PublicKey::PublicKeyTypeEd25519(Uint256(bytes))))
    }

    /// Build an `InvokeContract` host function calling `fn_name(args)` on `contract`.
    fn host_fn_invoke(contract: [u8; 32], fn_name: &str, args: &[ScVal]) -> HostFunction {
        HostFunction::InvokeContract(InvokeContractArgs {
            contract_address: ScAddress::Contract(stellar_xdr::curr::ContractId(Hash(contract))),
            function_name: ScSymbol(fn_name.try_into().unwrap()),
            args: args.try_into().unwrap(),
        })
    }

    /// Build a `CreateContractV2` host function deploying `wasm_hash` from the source account.
    fn host_fn_create(wasm_hash: [u8; 32], args: &[ScVal]) -> HostFunction {
        HostFunction::CreateContractV2(CreateContractArgsV2 {
            contract_id_preimage: ContractIdPreimage::Address(ContractIdPreimageFromAddress {
                address: ed25519_address(source_bytes()),
                salt: Uint256([0u8; 32]),
            }),
            executable: ContractExecutable::Wasm(wasm_hash.into()),
            constructor_args: args.try_into().unwrap(),
        })
    }

    /// Build a `ContractFn` auth invocation with no sub-invocations.
    fn invocation_contract(
        contract: [u8; 32],
        fn_name: &str,
        args: &[ScVal],
    ) -> SorobanAuthorizedInvocation {
        SorobanAuthorizedInvocation {
            function: SorobanAuthorizedFunction::ContractFn(InvokeContractArgs {
                contract_address: ScAddress::Contract(stellar_xdr::curr::ContractId(Hash(
                    contract,
                ))),
                function_name: ScSymbol(fn_name.try_into().unwrap()),
                // Same slice conversion as the other helpers; no intermediate `Vec` needed.
                args: args.try_into().unwrap(),
            }),
            sub_invocations: VecM::default(),
        }
    }

    /// Build a `CreateContractV2HostFn` auth invocation with no sub-invocations.
    fn invocation_create(wasm_hash: [u8; 32], args: &[ScVal]) -> SorobanAuthorizedInvocation {
        SorobanAuthorizedInvocation {
            function: SorobanAuthorizedFunction::CreateContractV2HostFn(CreateContractArgsV2 {
                contract_id_preimage: ContractIdPreimage::Address(ContractIdPreimageFromAddress {
                    address: ed25519_address(source_bytes()),
                    salt: Uint256([0u8; 32]),
                }),
                executable: ContractExecutable::Wasm(wasm_hash.into()),
                constructor_args: args.try_into().unwrap(),
            }),
            sub_invocations: VecM::default(),
        }
    }

    #[test]
    fn test_matching_root_invocation_is_strict() {
        let contract = [1u8; 32];
        let args = &[ScVal::U32(42), ScVal::Symbol("hello".try_into().unwrap())];

        let host_fn = host_fn_invoke(contract, "hello", args);
        let invocation = invocation_contract(contract, "hello", args);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::Strict);
    }

    #[test]
    fn test_subinvocations_dont_affect_root_match() {
        let contract = [1u8; 32];
        let other = [99u8; 32];
        let args = &[ScVal::U32(42), ScVal::Symbol("hello".try_into().unwrap())];

        let host_fn = host_fn_invoke(contract, "hello", args);
        let mut invocation = invocation_contract(contract, "hello", args);
        invocation.sub_invocations = [invocation_contract(other, "other", &[])]
            .try_into()
            .unwrap();

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::Strict);
    }

    #[test]
    fn test_different_root_contract_is_non_strict() {
        let contract = [1u8; 32];
        let other = [99u8; 32];

        let host_fn = host_fn_invoke(contract, "hello", &[]);
        let invocation = invocation_contract(other, "hello", &[]);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::NonStrict);
    }

    #[test]
    fn test_different_function_same_contract_is_non_strict() {
        let contract = [1u8; 32];

        let host_fn = host_fn_invoke(contract, "hello", &[]);
        let invocation = invocation_contract(contract, "transfer", &[]);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::NonStrict);
    }

    #[test]
    fn test_different_args_is_non_strict() {
        let contract = [1u8; 32];
        let args = &[ScVal::U32(42), ScVal::Symbol("hello".try_into().unwrap())];
        let wrong = &[ScVal::U32(43), ScVal::Symbol("hello".try_into().unwrap())];

        let host_fn = host_fn_invoke(contract, "hello", args);
        let invocation = invocation_contract(contract, "hello", wrong);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::NonStrict);
    }

    #[test]
    fn test_mismatched_function_kind_is_non_strict() {
        // An invoke-contract host fn paired with a create-contract auth root exercises the
        // catch-all (kind mismatch) path of the classifier.
        let contract = [1u8; 32];
        let wasm_hash = [42u8; 32];

        let host_fn = host_fn_invoke(contract, "hello", &[]);
        let invocation = invocation_create(wasm_hash, &[]);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::NonStrict);
    }

    #[test]
    fn test_upload_wasm_with_auth_entry_is_invalid() {
        let contract = [1u8; 32];
        let wasm_hash: BytesM = [42u8; 32].try_into().unwrap();

        let host_fn = HostFunction::UploadContractWasm(wasm_hash);
        let invocation = invocation_contract(contract, "hello", &[]);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::Invalid);
    }

    #[test]
    fn test_upload_wasm_with_create_auth_entry_is_invalid() {
        // `Invalid` must win regardless of which kind of auth root is presented.
        let wasm_hash = [42u8; 32];
        let upload_bytes: BytesM = wasm_hash.try_into().unwrap();

        let host_fn = HostFunction::UploadContractWasm(upload_bytes);
        let invocation = invocation_create(wasm_hash, &[]);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::Invalid);
    }

    #[test]
    fn test_matching_create_contract_root_is_strict() {
        let contract = [1u8; 32];
        let wasm_hash = [42u8; 32];
        let args = &[ScVal::U32(42), ScVal::Symbol("hello".try_into().unwrap())];

        let host_fn = host_fn_create(wasm_hash, args);
        let mut invocation = invocation_create(wasm_hash, args);
        invocation.sub_invocations = [invocation_contract(contract, "__constructor", args)]
            .try_into()
            .unwrap();

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::Strict);
    }

    #[test]
    fn test_different_constructor_args_is_non_strict() {
        let wasm_hash = [42u8; 32];

        let host_fn = host_fn_create(wasm_hash, &[ScVal::U32(42)]);
        let invocation = invocation_create(wasm_hash, &[ScVal::U32(43)]);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::NonStrict);
    }

    #[test]
    fn test_different_wasm_hash_is_non_strict() {
        let host_fn = host_fn_create([42u8; 32], &[]);
        let invocation = invocation_create([7u8; 32], &[]);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::NonStrict);
    }
}