1use crate::xdr::{HostFunction, SorobanAuthorizedFunction, SorobanAuthorizedInvocation};
2
/// How an authorization entry's root invocation relates to the host function
/// it accompanies, as reported by [`classify_auth_invocation`].
///
/// The enum is field-less, so it is cheap to copy and usable as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuthStyle {
    /// The auth entry's root invocation is the counterpart of the host function
    /// (same kind) and carries exactly equal arguments.
    Strict,
    /// The auth entry's root invocation differs from the host function, either
    /// in kind or in any of its arguments.
    NonStrict,
    /// The host function is one for which any attached auth entry is rejected
    /// outright (`UploadContractWasm`).
    Invalid,
}
19
20pub fn classify_auth_invocation(
26 source_host_fn: &HostFunction,
27 auth_invocation: &SorobanAuthorizedInvocation,
28) -> AuthStyle {
29 if matches!(source_host_fn, HostFunction::UploadContractWasm(_)) {
31 return AuthStyle::Invalid;
32 }
33
34 let is_strict = match (source_host_fn, &auth_invocation.function) {
39 (HostFunction::InvokeContract(op), SorobanAuthorizedFunction::ContractFn(args)) => {
40 args == op
41 }
42 (
43 HostFunction::CreateContract(op),
44 SorobanAuthorizedFunction::CreateContractHostFn(args),
45 ) => args == op,
46 (
47 HostFunction::CreateContractV2(op),
48 SorobanAuthorizedFunction::CreateContractV2HostFn(args),
49 ) => args == op,
50 _ => false,
51 };
52
53 if is_strict {
54 AuthStyle::Strict
55 } else {
56 AuthStyle::NonStrict
57 }
58}
59
#[cfg(test)]
mod tests {
    use super::*;
    use crate::xdr::{
        AccountId, BytesM, ContractExecutable, ContractIdPreimage, ContractIdPreimageFromAddress,
        CreateContractArgsV2, Hash, InvokeContractArgs, PublicKey, ScAddress, ScSymbol, ScVal,
        Uint256, VecM,
    };
    use stellar_strkey::ed25519;

    const SOURCE_ACCOUNT: &str = "GBZXN7PIRZGNMHGA7MUUUF4GWPY5AYPV6LY4UV2GL6VJGIQRXFDNMADI";

    /// Raw ed25519 public-key bytes of `SOURCE_ACCOUNT`.
    fn source_bytes() -> [u8; 32] {
        ed25519::PublicKey::from_string(SOURCE_ACCOUNT).unwrap().0
    }

    /// Wraps raw ed25519 key bytes as an account `ScAddress`.
    fn ed25519_address(bytes: [u8; 32]) -> ScAddress {
        ScAddress::Account(AccountId(PublicKey::PublicKeyTypeEd25519(Uint256(bytes))))
    }

    /// Contract-call arguments shared by the host-function and auth-entry
    /// fixtures, so "matching" fixtures are identical by construction.
    fn invoke_args(contract: [u8; 32], fn_name: &str, args: &[ScVal]) -> InvokeContractArgs {
        InvokeContractArgs {
            contract_address: ScAddress::Contract(stellar_xdr::ContractId(Hash(contract))),
            function_name: ScSymbol(fn_name.try_into().unwrap()),
            args: args.try_into().unwrap(),
        }
    }

    /// Contract-creation arguments (deployer `SOURCE_ACCOUNT`, all-zero salt)
    /// shared by the host-function and auth-entry fixtures.
    fn create_args(wasm_hash: [u8; 32], args: &[ScVal]) -> CreateContractArgsV2 {
        CreateContractArgsV2 {
            contract_id_preimage: ContractIdPreimage::Address(ContractIdPreimageFromAddress {
                address: ed25519_address(source_bytes()),
                salt: Uint256([0u8; 32]),
            }),
            executable: ContractExecutable::Wasm(wasm_hash.into()),
            constructor_args: args.try_into().unwrap(),
        }
    }

    fn host_fn_invoke(contract: [u8; 32], fn_name: &str, args: &[ScVal]) -> HostFunction {
        HostFunction::InvokeContract(invoke_args(contract, fn_name, args))
    }

    fn host_fn_create(wasm_hash: [u8; 32], args: &[ScVal]) -> HostFunction {
        HostFunction::CreateContractV2(create_args(wasm_hash, args))
    }

    /// An auth invocation rooted at `function` with no sub-invocations.
    fn root_invocation(function: SorobanAuthorizedFunction) -> SorobanAuthorizedInvocation {
        SorobanAuthorizedInvocation {
            function,
            sub_invocations: VecM::default(),
        }
    }

    fn invocation_contract(
        contract: [u8; 32],
        fn_name: &str,
        args: &[ScVal],
    ) -> SorobanAuthorizedInvocation {
        root_invocation(SorobanAuthorizedFunction::ContractFn(invoke_args(
            contract, fn_name, args,
        )))
    }

    fn invocation_create(wasm_hash: [u8; 32], args: &[ScVal]) -> SorobanAuthorizedInvocation {
        root_invocation(SorobanAuthorizedFunction::CreateContractV2HostFn(create_args(
            wasm_hash, args,
        )))
    }

    #[test]
    fn test_matching_root_invocation_is_strict() {
        let contract = [1u8; 32];
        let args = &[ScVal::U32(42), ScVal::Symbol("hello".try_into().unwrap())];

        let host_fn = host_fn_invoke(contract, "hello", args);
        let invocation = invocation_contract(contract, "hello", args);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::Strict);
    }

    #[test]
    fn test_subinvocations_dont_affect_root_match() {
        let contract = [1u8; 32];
        let other = [99u8; 32];
        let args = &[ScVal::U32(42), ScVal::Symbol("hello".try_into().unwrap())];

        let host_fn = host_fn_invoke(contract, "hello", args);
        let mut invocation = invocation_contract(contract, "hello", args);
        invocation.sub_invocations = [invocation_contract(other, "other", &[])]
            .try_into()
            .unwrap();

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::Strict);
    }

    #[test]
    fn test_different_root_contract_is_non_strict() {
        let contract = [1u8; 32];
        let other = [99u8; 32];

        let host_fn = host_fn_invoke(contract, "hello", &[]);
        let invocation = invocation_contract(other, "hello", &[]);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::NonStrict);
    }

    #[test]
    fn test_different_function_same_contract_is_non_strict() {
        let contract = [1u8; 32];

        let host_fn = host_fn_invoke(contract, "hello", &[]);
        let invocation = invocation_contract(contract, "transfer", &[]);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::NonStrict);
    }

    #[test]
    fn test_different_args_is_non_strict() {
        let contract = [1u8; 32];
        let args = &[ScVal::U32(42), ScVal::Symbol("hello".try_into().unwrap())];
        let wrong = &[ScVal::U32(43), ScVal::Symbol("hello".try_into().unwrap())];

        let host_fn = host_fn_invoke(contract, "hello", args);
        let invocation = invocation_contract(contract, "hello", wrong);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::NonStrict);
    }

    #[test]
    fn test_upload_wasm_with_auth_entry_is_invalid() {
        let contract = [1u8; 32];
        let wasm_hash: BytesM = [42u8; 32].try_into().unwrap();

        let host_fn = HostFunction::UploadContractWasm(wasm_hash);
        let invocation = invocation_contract(contract, "hello", &[]);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::Invalid);
    }

    #[test]
    fn test_matching_create_contract_root_is_strict() {
        let contract = [1u8; 32];
        let wasm_hash = [42u8; 32];
        let args = &[ScVal::U32(42), ScVal::Symbol("hello".try_into().unwrap())];

        let host_fn = host_fn_create(wasm_hash, args);
        let mut invocation = invocation_create(wasm_hash, args);
        invocation.sub_invocations = [invocation_contract(contract, "__constructor", args)]
            .try_into()
            .unwrap();

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::Strict);
    }

    #[test]
    fn test_different_create_contract_args_is_non_strict() {
        let wasm_hash = [42u8; 32];
        let args = &[ScVal::U32(42)];
        let wrong = &[ScVal::U32(43)];

        let host_fn = host_fn_create(wasm_hash, args);
        let invocation = invocation_create(wasm_hash, wrong);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::NonStrict);
    }

    #[test]
    fn test_different_wasm_hash_is_non_strict() {
        let host_fn = host_fn_create([42u8; 32], &[]);
        let invocation = invocation_create([43u8; 32], &[]);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::NonStrict);
    }

    #[test]
    fn test_create_host_fn_with_contract_fn_root_is_non_strict() {
        let host_fn = host_fn_create([42u8; 32], &[]);
        let invocation = invocation_contract([1u8; 32], "hello", &[]);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::NonStrict);
    }

    #[test]
    fn test_invoke_host_fn_with_create_root_is_non_strict() {
        let host_fn = host_fn_invoke([1u8; 32], "hello", &[]);
        let invocation = invocation_create([42u8; 32], &[]);

        let style = classify_auth_invocation(&host_fn, &invocation);
        assert_eq!(style, AuthStyle::NonStrict);
    }
}