1use stellar_strkey::Contract as ContractId;
18use stellar_xdr::curr::{
19 OperationBody, SorobanAuthorizedFunction, SorobanAuthorizedInvocation, Transaction,
20};
21
22use crate::SorobanHelperError;
23
/// A restriction that is evaluated against a transaction via [`Guard::check`]
/// and whose internal budget is consumed via [`Guard::update`].
#[derive(Clone)]
pub enum Guard {
    /// A plain counter of remaining uses: the check passes while the value is
    /// greater than zero, and every update decrements it by one (stopping at
    /// zero). The transaction contents are not inspected for this variant.
    NumberOfAllowedCalls(u16),
    /// A per-contract budget of authorized invocations; checking and updating
    /// are delegated to the wrapped [`AuthorizedCallsForContract`].
    AuthorizedCallsFor(AuthorizedCallsForContract),
}
32
33impl Guard {
34 pub fn check(&self, transaction: &Transaction) -> Result<bool, SorobanHelperError> {
40 match self {
41 Guard::NumberOfAllowedCalls(remaining) => Ok(*remaining > 0),
42 Guard::AuthorizedCallsFor(calls_for_contract) => {
43 Ok(calls_for_contract.check(transaction))
44 } }
46 }
47
48 pub fn update(&mut self, transaction: &Transaction) -> Result<(), SorobanHelperError> {
53 match self {
54 Guard::NumberOfAllowedCalls(remaining) => {
55 if *remaining > 0 {
56 *remaining -= 1;
57 }
58 Ok(())
59 }
60 Guard::AuthorizedCallsFor(calls_for_contract) => {
61 calls_for_contract.update(transaction);
62 Ok(())
63 } }
65 }
66}
67
/// Budget of authorized invocations that may target one specific contract.
///
/// The invocations are counted from the authorization entries of a
/// transaction's `InvokeHostFunction` operations, including nested
/// sub-invocations.
#[derive(Clone)]
pub struct AuthorizedCallsForContract {
    /// Contract whose authorized invocations are counted against the budget.
    pub contract_id: ContractId,
    /// Number of authorized invocations still available; reduced by `update`
    /// when a transaction's matching calls fit within it.
    pub remaining: u16,
}
73
74impl AuthorizedCallsForContract {
75 fn count_authorized_calls(&self, invocation: &SorobanAuthorizedInvocation) -> u16 {
76 let mut count = 0;
77 if let SorobanAuthorizedFunction::ContractFn(args) = &invocation.function {
78 if args.contract_address.to_string() == self.contract_id.to_string() {
79 count += 1;
80 }
81 }
82 for sub_invocation in invocation.sub_invocations.iter() {
84 count += self.count_authorized_calls(sub_invocation);
85 }
86 count
87 }
88
89 fn extract_contract_calls(&self, tx: &Transaction) -> u16 {
90 let mut calls = 0;
91 for op in tx.operations.iter() {
92 if let OperationBody::InvokeHostFunction(invoke_op) = &op.body {
93 for auth_entry in invoke_op.auth.iter() {
94 calls += self.count_authorized_calls(&auth_entry.root_invocation);
95 }
96 }
97 }
98 calls
99 }
100
101 pub fn check(&self, transaction: &Transaction) -> bool {
102 let calls = self.extract_contract_calls(transaction);
103 self.remaining >= calls && calls > 0
104 }
105
106 pub fn update(&mut self, transaction: &Transaction) {
107 let calls = self.extract_contract_calls(transaction);
108 if calls > 0 && self.remaining >= calls {
109 self.remaining -= calls;
110 }
111 }
112}
113
#[cfg(test)]
mod tests {
    use ed25519_dalek::SigningKey;
    use stellar_strkey::{Contract as ContractId, ed25519::PublicKey};
    use stellar_xdr::curr::{
        AccountId, Hash, HostFunction, InvokeContractArgs, InvokeHostFunctionOp, Operation,
        OperationBody, ScAddress, ScSymbol, SorobanAuthorizationEntry, SorobanAuthorizedFunction,
        SorobanAuthorizedInvocation, SorobanCredentials, VecM,
    };

    use crate::{
        Account, AuthorizedCallsForContract, Signer,
        mock::{mock_contract_id, mock_env, mock_transaction},
    };

    /// Arguments for an argument-less call to `dummy_fn` on `target`.
    fn dummy_call(target: &ContractId) -> InvokeContractArgs {
        InvokeContractArgs {
            contract_address: ScAddress::Contract(Hash(target.0)),
            function_name: ScSymbol("dummy_fn".try_into().unwrap()),
            args: VecM::default(),
        }
    }

    /// Builds an authorized invocation of `dummy_fn` on `target_address`
    /// carrying the given nested invocations.
    fn create_invocation(
        target_address: &ContractId,
        sub_invocations: Vec<SorobanAuthorizedInvocation>,
    ) -> SorobanAuthorizedInvocation {
        SorobanAuthorizedInvocation {
            function: SorobanAuthorizedFunction::ContractFn(dummy_call(target_address)),
            sub_invocations: sub_invocations.try_into().unwrap(),
        }
    }

    /// Deterministic source account id and mock contract id shared by the tests.
    fn fixture() -> (AccountId, ContractId) {
        let signing_key = SigningKey::from_bytes(&[42; 32]);
        let public_key = PublicKey(*signing_key.verifying_key().as_bytes());
        let account_id = AccountId(stellar_xdr::curr::PublicKey::PublicKeyTypeEd25519(
            public_key.0.into(),
        ));
        let account = Account::single(Signer::new(signing_key));
        let env = mock_env(None, None, None);

        let contract_id = mock_contract_id(account, &env);
        (account_id, contract_id)
    }

    /// Wraps `root_invocation` in a source-account authorization entry attached
    /// to an `InvokeHostFunction` operation that calls `dummy_fn` on `contract_id`.
    fn invoke_operation(
        contract_id: &ContractId,
        root_invocation: SorobanAuthorizedInvocation,
    ) -> Operation {
        let auth_entry = SorobanAuthorizationEntry {
            credentials: SorobanCredentials::SourceAccount,
            root_invocation,
        };
        let invoke_op = InvokeHostFunctionOp {
            host_function: HostFunction::InvokeContract(dummy_call(contract_id)),
            auth: vec![auth_entry].try_into().unwrap(),
        };

        Operation {
            source_account: None,
            body: OperationBody::InvokeHostFunction(invoke_op),
        }
    }

    #[test]
    fn test_authorized_calls_check_and_update_success() {
        let (account_id, contract_id) = fixture();

        // One root call plus two nested calls, all on the guarded contract.
        let leaf = create_invocation(&contract_id, Vec::new());
        let root = create_invocation(&contract_id, vec![leaf.clone(), leaf]);

        let transaction =
            mock_transaction(account_id, vec![invoke_operation(&contract_id, root)]);

        let mut guard = AuthorizedCallsForContract {
            contract_id,
            remaining: 3,
        };

        assert_eq!(guard.extract_contract_calls(&transaction), 3);
        assert!(guard.check(&transaction));
        guard.update(&transaction);
        assert_eq!(guard.remaining, 0);
    }

    #[test]
    fn test_authorized_calls_for_contract_check_and_update_fail() {
        let (account_id, contract_id) = fixture();

        // One root call plus one nested call: two calls against a budget of one.
        let leaf = create_invocation(&contract_id, Vec::new());
        let root = create_invocation(&contract_id, vec![leaf]);

        let transaction =
            mock_transaction(account_id, vec![invoke_operation(&contract_id, root)]);

        let guard = AuthorizedCallsForContract {
            contract_id,
            remaining: 1,
        };

        assert_eq!(guard.extract_contract_calls(&transaction), 2);
        assert!(!guard.check(&transaction));
    }
}