1use crate::xdr::{HostFunction, SorobanAuthorizedFunction, SorobanAuthorizedInvocation};
2
/// Relationship between a Soroban authorization entry's root invocation and
/// the host function of the operation it is attached to, as reported by
/// [`classify_auth_invocation`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthStyle {
    /// The root authorized function is of the same kind as the host function
    /// and its arguments compare equal to the host function's arguments.
    Strict,
    /// The root authorized function differs from the host function, either
    /// in kind or in any of its arguments.
    NonStrict,
    /// The host function is `UploadContractWasm`; the classifier reports this
    /// for any authorization entry paired with such a host function.
    Invalid,
}
19
20pub fn classify_auth_invocation(
26 source_host_fn: &HostFunction,
27 auth_invocation: &SorobanAuthorizedInvocation,
28) -> AuthStyle {
29 if matches!(source_host_fn, HostFunction::UploadContractWasm(_)) {
31 return AuthStyle::Invalid;
32 }
33
34 let is_strict = match (source_host_fn, &auth_invocation.function) {
39 (HostFunction::InvokeContract(op), SorobanAuthorizedFunction::ContractFn(args)) => {
40 args == op
41 }
42 (
43 HostFunction::CreateContract(op),
44 SorobanAuthorizedFunction::CreateContractHostFn(args),
45 ) => args == op,
46 (
47 HostFunction::CreateContractV2(op),
48 SorobanAuthorizedFunction::CreateContractV2HostFn(args),
49 ) => args == op,
50 _ => false,
51 };
52
53 if is_strict {
54 AuthStyle::Strict
55 } else {
56 AuthStyle::NonStrict
57 }
58}
59
#[cfg(test)]
mod tests {
    use super::*;
    use crate::xdr::{
        AccountId, BytesM, ContractExecutable, ContractIdPreimage, ContractIdPreimageFromAddress,
        CreateContractArgsV2, Hash, InvokeContractArgs, PublicKey, ScAddress, ScSymbol, ScVal,
        Uint256, VecM,
    };
    use stellar_strkey::ed25519;

    /// Strkey of the account used as the deployer address in create-contract fixtures.
    const SOURCE_ACCOUNT: &str = "GBZXN7PIRZGNMHGA7MUUUF4GWPY5AYPV6LY4UV2GL6VJGIQRXFDNMADI";

    /// Raw ed25519 public key bytes decoded from [`SOURCE_ACCOUNT`].
    fn source_bytes() -> [u8; 32] {
        ed25519::PublicKey::from_string(SOURCE_ACCOUNT).unwrap().0
    }

    /// Wraps raw ed25519 key bytes in an account-type `ScAddress`.
    fn ed25519_address(bytes: [u8; 32]) -> ScAddress {
        ScAddress::Account(AccountId(PublicKey::PublicKeyTypeEd25519(Uint256(bytes))))
    }

    /// Builds the contract-call arguments shared by the host-function and
    /// auth-invocation fixtures, so both sides are constructed identically.
    fn invoke_args(contract: [u8; 32], fn_name: &str, args: &[ScVal]) -> InvokeContractArgs {
        InvokeContractArgs {
            contract_address: ScAddress::Contract(stellar_xdr::curr::ContractId(Hash(contract))),
            function_name: ScSymbol(fn_name.try_into().unwrap()),
            args: args.try_into().unwrap(),
        }
    }

    /// Builds the create-contract-v2 arguments shared by the host-function and
    /// auth-invocation fixtures. The deployer is always [`SOURCE_ACCOUNT`] with
    /// an all-zero salt.
    fn create_args_v2(wasm_hash: [u8; 32], args: &[ScVal]) -> CreateContractArgsV2 {
        CreateContractArgsV2 {
            contract_id_preimage: ContractIdPreimage::Address(ContractIdPreimageFromAddress {
                address: ed25519_address(source_bytes()),
                salt: Uint256([0u8; 32]),
            }),
            executable: ContractExecutable::Wasm(wasm_hash.into()),
            constructor_args: args.try_into().unwrap(),
        }
    }

    /// Host function invoking `fn_name` on `contract` with `args`.
    fn host_fn_invoke(contract: [u8; 32], fn_name: &str, args: &[ScVal]) -> HostFunction {
        HostFunction::InvokeContract(invoke_args(contract, fn_name, args))
    }

    /// Host function creating a contract (v2) from `wasm_hash` with constructor `args`.
    fn host_fn_create(wasm_hash: [u8; 32], args: &[ScVal]) -> HostFunction {
        HostFunction::CreateContractV2(create_args_v2(wasm_hash, args))
    }

    /// Auth invocation whose root is a contract call and which has no sub-invocations.
    fn invocation_contract(
        contract: [u8; 32],
        fn_name: &str,
        args: &[ScVal],
    ) -> SorobanAuthorizedInvocation {
        SorobanAuthorizedInvocation {
            function: SorobanAuthorizedFunction::ContractFn(invoke_args(contract, fn_name, args)),
            sub_invocations: VecM::default(),
        }
    }

    /// Auth invocation whose root is a create-contract-v2 call and which has no
    /// sub-invocations.
    fn invocation_create(wasm_hash: [u8; 32], args: &[ScVal]) -> SorobanAuthorizedInvocation {
        SorobanAuthorizedInvocation {
            function: SorobanAuthorizedFunction::CreateContractV2HostFn(create_args_v2(
                wasm_hash, args,
            )),
            sub_invocations: VecM::default(),
        }
    }

    #[test]
    fn test_matching_root_invocation_is_strict() {
        let contract = [1u8; 32];
        let args = &[ScVal::U32(42), ScVal::Symbol("hello".try_into().unwrap())];

        let host_fn = host_fn_invoke(contract, "hello", args);
        let invocation = invocation_contract(contract, "hello", args);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::Strict);
    }

    #[test]
    fn test_subinvocations_dont_affect_root_match() {
        let contract = [1u8; 32];
        let other = [99u8; 32];
        let args = &[ScVal::U32(42), ScVal::Symbol("hello".try_into().unwrap())];

        let host_fn = host_fn_invoke(contract, "hello", args);
        let mut invocation = invocation_contract(contract, "hello", args);
        invocation.sub_invocations = [invocation_contract(other, "other", &[])]
            .try_into()
            .unwrap();

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::Strict);
    }

    #[test]
    fn test_different_root_contract_is_non_strict() {
        let contract = [1u8; 32];
        let other = [99u8; 32];

        let host_fn = host_fn_invoke(contract, "hello", &[]);
        let invocation = invocation_contract(other, "hello", &[]);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::NonStrict);
    }

    #[test]
    fn test_different_function_same_contract_is_non_strict() {
        let contract = [1u8; 32];

        let host_fn = host_fn_invoke(contract, "hello", &[]);
        let invocation = invocation_contract(contract, "transfer", &[]);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::NonStrict);
    }

    #[test]
    fn test_different_args_is_non_strict() {
        let contract = [1u8; 32];
        let args = &[ScVal::U32(42), ScVal::Symbol("hello".try_into().unwrap())];
        let wrong = &[ScVal::U32(43), ScVal::Symbol("hello".try_into().unwrap())];

        let host_fn = host_fn_invoke(contract, "hello", args);
        let invocation = invocation_contract(contract, "hello", wrong);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::NonStrict);
    }

    #[test]
    fn test_upload_wasm_with_auth_entry_is_invalid() {
        let contract = [1u8; 32];
        let wasm_hash: BytesM = [42u8; 32].try_into().unwrap();

        let host_fn = HostFunction::UploadContractWasm(wasm_hash);
        let invocation = invocation_contract(contract, "hello", &[]);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::Invalid);
    }

    #[test]
    fn test_matching_create_contract_root_is_strict() {
        let contract = [1u8; 32];
        let wasm_hash = [42u8; 32];
        let args = &[ScVal::U32(42), ScVal::Symbol("hello".try_into().unwrap())];

        let host_fn = host_fn_create(wasm_hash, args);
        let mut invocation = invocation_create(wasm_hash, args);
        invocation.sub_invocations = [invocation_contract(contract, "__constructor", args)]
            .try_into()
            .unwrap();

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::Strict);
    }

    #[test]
    fn test_create_contract_with_different_wasm_is_non_strict() {
        let args = &[ScVal::U32(42)];

        let host_fn = host_fn_create([42u8; 32], args);
        let invocation = invocation_create([43u8; 32], args);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::NonStrict);
    }

    #[test]
    fn test_create_contract_with_different_constructor_args_is_non_strict() {
        let wasm_hash = [42u8; 32];

        let host_fn = host_fn_create(wasm_hash, &[ScVal::U32(42)]);
        let invocation = invocation_create(wasm_hash, &[ScVal::U32(43)]);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::NonStrict);
    }

    #[test]
    fn test_mismatched_function_kinds_are_non_strict() {
        let contract = [1u8; 32];
        let wasm_hash = [42u8; 32];

        // Contract call authorised by a create-contract root.
        let style = classify_auth_invocation(
            &host_fn_invoke(contract, "hello", &[]),
            &invocation_create(wasm_hash, &[]),
        );
        assert_eq!(style, AuthStyle::NonStrict);

        // Create-contract authorised by a contract-call root.
        let style = classify_auth_invocation(
            &host_fn_create(wasm_hash, &[]),
            &invocation_contract(contract, "hello", &[]),
        );
        assert_eq!(style, AuthStyle::NonStrict);
    }
}