1use stellar_strkey::Contract as ContractId;
18use stellar_xdr::curr::{
19 OperationBody, SorobanAuthorizedFunction, SorobanAuthorizedInvocation, Transaction,
20};
21
22use crate::SorobanHelperError;
23
/// A limit that is evaluated against a [`Transaction`] with `Guard::check` and advanced with
/// `Guard::update`.
#[derive(Clone)]
pub enum Guard {
    /// Allows transactions while the counter is non-zero, without inspecting them; every
    /// `update` lowers the counter by one and never takes it below zero.
    NumberOfAllowedCalls(u16),
    /// Limits how many authorized calls to one contract may appear in transactions; see
    /// [`AuthorizedCallsForContract`] for how calls are counted.
    AuthorizedCallsFor(AuthorizedCallsForContract),
}
32
33impl Guard {
34 pub fn check(&self, transaction: &Transaction) -> Result<bool, SorobanHelperError> {
40 match self {
41 Guard::NumberOfAllowedCalls(remaining) => Ok(*remaining > 0),
42 Guard::AuthorizedCallsFor(calls_for_contract) => {
43 Ok(calls_for_contract.check(transaction))
44 } }
46 }
47
48 pub fn update(&mut self, transaction: &Transaction) -> Result<(), SorobanHelperError> {
53 match self {
54 Guard::NumberOfAllowedCalls(remaining) => {
55 if *remaining > 0 {
56 *remaining -= 1;
57 }
58 Ok(())
59 }
60 Guard::AuthorizedCallsFor(calls_for_contract) => {
61 calls_for_contract.update(transaction);
62 Ok(())
63 } }
65 }
66}
67
68#[derive(Clone)]
69pub struct AuthorizedCallsForContract {
70 pub contract_id: ContractId,
71 pub remaining: u16,
72}
73
74impl AuthorizedCallsForContract {
75 fn count_authorized_calls(&self, invocation: &SorobanAuthorizedInvocation) -> u16 {
76 let mut count = 0;
77 if let SorobanAuthorizedFunction::ContractFn(args) = &invocation.function {
78 if args.contract_address.to_string() == self.contract_id.to_string() {
79 count += 1;
80 }
81 }
82 for sub_invocation in invocation.sub_invocations.iter() {
84 count += self.count_authorized_calls(sub_invocation);
85 }
86 count
87 }
88
89 fn extract_contract_calls(&self, tx: &Transaction) -> u16 {
90 let mut calls = 0;
91 for op in tx.operations.iter() {
92 if let OperationBody::InvokeHostFunction(invoke_op) = &op.body {
93 for auth_entry in invoke_op.auth.iter() {
94 calls += self.count_authorized_calls(&auth_entry.root_invocation);
95 }
96 }
97 }
98 calls
99 }
100
101 pub fn check(&self, transaction: &Transaction) -> bool {
102 let calls = self.extract_contract_calls(transaction);
103 self.remaining >= calls && calls > 0
104 }
105
106 pub fn update(&mut self, transaction: &Transaction) {
107 let calls = self.extract_contract_calls(transaction);
108 if calls > 0 && self.remaining >= calls {
109 self.remaining -= calls;
110 }
111 }
112}
113
#[cfg(test)]
mod tests {
    use ed25519_dalek::SigningKey;
    use stellar_strkey::{ed25519::PublicKey, Contract as ContractId};
    use stellar_xdr::curr::{
        AccountId, Hash, HostFunction, InvokeContractArgs, InvokeHostFunctionOp, Operation,
        OperationBody, ScAddress, ScSymbol, SorobanAuthorizationEntry, SorobanAuthorizedFunction,
        SorobanAuthorizedInvocation, SorobanCredentials, VecM,
    };

    use super::Guard;
    use crate::{
        mock::{mock_contract_id, mock_env, mock_transaction},
        Account, AuthorizedCallsForContract, Signer,
    };

    /// Builds the ed25519 source account shared by the tests and a mock contract for it.
    /// Returns `(source account id, contract id)`.
    fn setup() -> (AccountId, ContractId) {
        let signing_key = SigningKey::from_bytes(&[42; 32]);
        let public_key = PublicKey(*signing_key.verifying_key().as_bytes());
        let account_id = AccountId(stellar_xdr::curr::PublicKey::PublicKeyTypeEd25519(
            public_key.0.into(),
        ));
        let account = Account::single(Signer::new(signing_key));
        let env = mock_env(None, None, None);
        let contract_id = mock_contract_id(account, &env);
        (account_id, contract_id)
    }

    /// Builds a `ContractFn` authorization node calling `dummy_fn` on `target_address`
    /// with the given children.
    fn create_invocation(
        target_address: &ContractId,
        sub_invocations: Vec<SorobanAuthorizedInvocation>,
    ) -> SorobanAuthorizedInvocation {
        SorobanAuthorizedInvocation {
            function: SorobanAuthorizedFunction::ContractFn(InvokeContractArgs {
                contract_address: ScAddress::Contract(stellar_xdr::curr::ContractId(Hash(
                    target_address.0,
                ))),
                function_name: ScSymbol("dummy_fn".try_into().unwrap()),
                args: VecM::default(),
            }),
            sub_invocations: sub_invocations.try_into().unwrap(),
        }
    }

    /// Wraps `root_invocation` in a source-account auth entry. The entry is attached to a
    /// single `InvokeHostFunction` operation that calls `dummy_fn` on `contract_id`.
    fn invoke_op_with_auth(
        contract_id: &ContractId,
        root_invocation: SorobanAuthorizedInvocation,
    ) -> Operation {
        let invoke_op = InvokeHostFunctionOp {
            host_function: HostFunction::InvokeContract(InvokeContractArgs {
                contract_address: ScAddress::Contract(stellar_xdr::curr::ContractId(Hash(
                    contract_id.0,
                ))),
                function_name: ScSymbol("dummy_fn".try_into().unwrap()),
                args: VecM::default(),
            }),
            auth: vec![SorobanAuthorizationEntry {
                credentials: SorobanCredentials::SourceAccount,
                root_invocation,
            }]
            .try_into()
            .unwrap(),
        };
        Operation {
            source_account: None,
            body: OperationBody::InvokeHostFunction(invoke_op),
        }
    }

    #[test]
    fn test_authorized_calls_check_and_update_success() {
        let (account_id, contract_id) = setup();

        // Root plus two children, all targeting the contract: three calls.
        let sub_invocation = create_invocation(&contract_id, vec![]);
        let root_invocation =
            create_invocation(&contract_id, vec![sub_invocation.clone(), sub_invocation]);
        let transaction = mock_transaction(
            account_id,
            vec![invoke_op_with_auth(&contract_id, root_invocation)],
        );

        let mut guard = AuthorizedCallsForContract {
            contract_id,
            remaining: 3,
        };

        assert_eq!(guard.extract_contract_calls(&transaction), 3);
        assert!(guard.check(&transaction));
        guard.update(&transaction);
        assert_eq!(guard.remaining, 0);
        // The budget is exhausted, so the same transaction is no longer allowed.
        assert!(!guard.check(&transaction));
    }

    #[test]
    fn test_authorized_calls_for_contract_check_and_update_fail() {
        let (account_id, contract_id) = setup();

        // Root plus one child: two calls against a budget of one.
        let sub_invocation = create_invocation(&contract_id, vec![]);
        let root_invocation = create_invocation(&contract_id, vec![sub_invocation]);
        let transaction = mock_transaction(
            account_id,
            vec![invoke_op_with_auth(&contract_id, root_invocation)],
        );

        let mut guard = AuthorizedCallsForContract {
            contract_id,
            remaining: 1,
        };

        assert_eq!(guard.extract_contract_calls(&transaction), 2);
        assert!(!guard.check(&transaction));
        // A rejected transaction must not consume any budget.
        guard.update(&transaction);
        assert_eq!(guard.remaining, 1);
    }

    #[test]
    fn test_authorized_calls_ignore_other_contracts() {
        let (account_id, contract_id) = setup();
        let other_contract = ContractId([7; 32]);

        // Only the first child targets `contract_id`.
        let root_invocation = create_invocation(
            &other_contract,
            vec![
                create_invocation(&contract_id, vec![]),
                create_invocation(&other_contract, vec![]),
            ],
        );
        let transaction = mock_transaction(
            account_id,
            vec![invoke_op_with_auth(&other_contract, root_invocation)],
        );

        let mut guard = AuthorizedCallsForContract {
            contract_id,
            remaining: 5,
        };

        assert_eq!(guard.extract_contract_calls(&transaction), 1);
        assert!(guard.check(&transaction));
        guard.update(&transaction);
        assert_eq!(guard.remaining, 4);
    }

    #[test]
    fn test_authorized_calls_without_matching_calls_are_rejected() {
        let (account_id, contract_id) = setup();
        let other_contract = ContractId([7; 32]);
        let transaction = mock_transaction(
            account_id,
            vec![invoke_op_with_auth(
                &other_contract,
                create_invocation(&other_contract, vec![]),
            )],
        );

        let mut guard = AuthorizedCallsForContract {
            contract_id,
            remaining: 5,
        };

        assert_eq!(guard.extract_contract_calls(&transaction), 0);
        assert!(!guard.check(&transaction));
        guard.update(&transaction);
        assert_eq!(guard.remaining, 5);
    }

    #[test]
    fn test_guard_number_of_allowed_calls_counts_down_to_zero() {
        let (account_id, contract_id) = setup();
        let transaction = mock_transaction(
            account_id,
            vec![invoke_op_with_auth(
                &contract_id,
                create_invocation(&contract_id, vec![]),
            )],
        );

        let mut guard = Guard::NumberOfAllowedCalls(1);
        assert!(matches!(guard.check(&transaction), Ok(true)));
        assert!(guard.update(&transaction).is_ok());
        assert!(matches!(guard.check(&transaction), Ok(false)));
        // Updating an exhausted counter leaves it at zero instead of underflowing.
        assert!(guard.update(&transaction).is_ok());
        assert!(matches!(guard, Guard::NumberOfAllowedCalls(0)));
    }

    #[test]
    fn test_guard_authorized_calls_for_delegates_to_contract_budget() {
        let (account_id, contract_id) = setup();
        let root_invocation =
            create_invocation(&contract_id, vec![create_invocation(&contract_id, vec![])]);
        let transaction = mock_transaction(
            account_id,
            vec![invoke_op_with_auth(&contract_id, root_invocation)],
        );

        let mut guard = Guard::AuthorizedCallsFor(AuthorizedCallsForContract {
            contract_id,
            remaining: 2,
        });
        assert!(matches!(guard.check(&transaction), Ok(true)));
        assert!(guard.update(&transaction).is_ok());
        assert!(matches!(
            guard,
            Guard::AuthorizedCallsFor(AuthorizedCallsForContract { remaining: 0, .. })
        ));
        assert!(matches!(guard.check(&transaction), Ok(false)));
    }
}