soroban_rs/
guard.rs

//! # Soroban Account Guards
//!
//! Guards restrict what an account is allowed to do, for example by capping how many
//! calls it may make or how many authorized invocations of a given contract it may sign.
//! See [`Guard`] for the available kinds.
//!
//! ## Example
//!
//! ```rust,no_run
//! use soroban_rs::{Env, Signer, Account, Guard};
//! use ed25519_dalek::SigningKey;
//!
//! async fn example(signing_key: SigningKey) {
//!     let mut account = Account::single(Signer::new(signing_key));
//!     let guard = Guard::NumberOfAllowedCalls(3);
//!     account.add_guard(guard);
//! }
//! ```
use stellar_strkey::Contract as ContractId;
use stellar_xdr::curr::{
    Hash, OperationBody, ScAddress, SorobanAuthorizedFunction, SorobanAuthorizedInvocation,
    Transaction,
};

use crate::SorobanHelperError;
23
/// A restriction attached to an account.
///
/// [`Guard::check`] decides whether a transaction may proceed, and [`Guard::update`]
/// records that a transaction was performed.
#[derive(Clone)]
pub enum Guard {
    /// Limits the total number of allowed calls; the transaction's contents are not inspected.
    /// The `u16` value is the number of calls still allowed. Each [`Guard::update`]
    /// decrements it by one, never going below zero.
    NumberOfAllowedCalls(u16),
    /// Limits the number of authorized invocations of a single contract;
    /// see [`AuthorizedCallsForContract`].
    AuthorizedCallsFor(AuthorizedCallsForContract),
}
32
33impl Guard {
34    /// Checks if the guard condition is satisfied.
35    ///
36    /// # Returns
37    /// * `true` if the operation is allowed to proceed
38    /// * `false` if the operation should be blocked
39    pub fn check(&self, transaction: &Transaction) -> Result<bool, SorobanHelperError> {
40        match self {
41            Guard::NumberOfAllowedCalls(remaining) => Ok(*remaining > 0),
42            Guard::AuthorizedCallsFor(calls_for_contract) => {
43                Ok(calls_for_contract.check(transaction))
44            } // handle other variants
45        }
46    }
47
48    /// Updates the guard state after an operation has been performed.
49    ///
50    /// This method should be called after a successful operation to update
51    /// the internal state of the guard (e.g., decrement remaining allowed calls).
52    pub fn update(&mut self, transaction: &Transaction) -> Result<(), SorobanHelperError> {
53        match self {
54            Guard::NumberOfAllowedCalls(remaining) => {
55                if *remaining > 0 {
56                    *remaining -= 1;
57                }
58                Ok(())
59            }
60            Guard::AuthorizedCallsFor(calls_for_contract) => {
61                calls_for_contract.update(transaction);
62                Ok(())
63            } // handle other variants
64        }
65    }
66}
67
/// Guard state that limits how many authorized invocations of `contract_id` may be performed.
///
/// Invocations are counted over every node of the authorization invocation trees of all
/// `InvokeHostFunction` operations in a transaction.
#[derive(Clone)]
pub struct AuthorizedCallsForContract {
    /// Contract whose authorized invocations are counted.
    pub contract_id: ContractId,
    /// Number of authorized invocations still permitted; reduced by `update`.
    pub remaining: u16,
}
73
74impl AuthorizedCallsForContract {
75    fn count_authorized_calls(&self, invocation: &SorobanAuthorizedInvocation) -> u16 {
76        let mut count = 0;
77        if let SorobanAuthorizedFunction::ContractFn(args) = &invocation.function {
78            if args.contract_address.to_string() == self.contract_id.to_string() {
79                count += 1;
80            }
81        }
82        // visit all nodes in the tree of invocations.
83        for sub_invocation in invocation.sub_invocations.iter() {
84            count += self.count_authorized_calls(sub_invocation);
85        }
86        count
87    }
88
89    fn extract_contract_calls(&self, tx: &Transaction) -> u16 {
90        let mut calls = 0;
91        for op in tx.operations.iter() {
92            if let OperationBody::InvokeHostFunction(invoke_op) = &op.body {
93                for auth_entry in invoke_op.auth.iter() {
94                    calls += self.count_authorized_calls(&auth_entry.root_invocation);
95                }
96            }
97        }
98        calls
99    }
100
101    pub fn check(&self, transaction: &Transaction) -> bool {
102        let calls = self.extract_contract_calls(transaction);
103        self.remaining >= calls && calls > 0
104    }
105
106    pub fn update(&mut self, transaction: &Transaction) {
107        let calls = self.extract_contract_calls(transaction);
108        if calls > 0 && self.remaining >= calls {
109            self.remaining -= calls;
110        }
111    }
112}
113
#[cfg(test)]
mod tests {
    use ed25519_dalek::SigningKey;
    use stellar_strkey::{ed25519::PublicKey, Contract as ContractId};
    use stellar_xdr::curr::{
        AccountId, Hash, HostFunction, InvokeContractArgs, InvokeHostFunctionOp, Operation,
        OperationBody, ScAddress, ScSymbol, SorobanAuthorizationEntry, SorobanAuthorizedFunction,
        SorobanAuthorizedInvocation, SorobanCredentials, Transaction, VecM,
    };

    use crate::{
        mock::{mock_contract_id, mock_env, mock_transaction},
        Account, AuthorizedCallsForContract, Signer,
    };

    /// Argument-less call of `dummy_fn` on `target`.
    fn dummy_call(target: &ContractId) -> InvokeContractArgs {
        InvokeContractArgs {
            contract_address: ScAddress::Contract(Hash(target.0)),
            function_name: ScSymbol("dummy_fn".try_into().unwrap()),
            args: VecM::default(),
        }
    }

    /// Authorized `dummy_fn` invocation of `target_address` with `sub_invocations` as children.
    fn create_invocation(
        target_address: &ContractId,
        sub_invocations: Vec<SorobanAuthorizedInvocation>,
    ) -> SorobanAuthorizedInvocation {
        SorobanAuthorizedInvocation {
            function: SorobanAuthorizedFunction::ContractFn(dummy_call(target_address)),
            sub_invocations: sub_invocations.try_into().unwrap(),
        }
    }

    /// Deterministic fixture: the account id of a fixed signing key and a mock contract id
    /// obtained for an account signed by that same key.
    fn setup() -> (AccountId, ContractId) {
        let signing_key = SigningKey::from_bytes(&[42; 32]);
        let public_key = PublicKey(*signing_key.verifying_key().as_bytes());
        let account_id = AccountId(stellar_xdr::curr::PublicKey::PublicKeyTypeEd25519(
            public_key.0.into(),
        ));
        let account = Account::single(Signer::new(signing_key));
        let env = mock_env(None, None, None);
        let contract_id = mock_contract_id(account, &env);
        (account_id, contract_id)
    }

    /// Builds a transaction containing a single `InvokeHostFunction` operation. The operation
    /// calls `dummy_fn` on `contract_id` and carries one source-account auth entry rooted at
    /// `root_invocation`.
    fn transaction_with_auth(
        account_id: AccountId,
        contract_id: &ContractId,
        root_invocation: SorobanAuthorizedInvocation,
    ) -> Transaction {
        let auth_entry = SorobanAuthorizationEntry {
            credentials: SorobanCredentials::SourceAccount,
            root_invocation,
        };
        let op = Operation {
            source_account: None,
            body: OperationBody::InvokeHostFunction(InvokeHostFunctionOp {
                host_function: HostFunction::InvokeContract(dummy_call(contract_id)),
                auth: vec![auth_entry].try_into().unwrap(),
            }),
        };
        mock_transaction(account_id, vec![op])
    }

    #[test]
    fn test_authorized_calls_check_and_update_success() {
        let (account_id, contract_id) = setup();
        let leaf = create_invocation(&contract_id, vec![]);
        let root = create_invocation(&contract_id, vec![leaf.clone(), leaf]);
        let transaction = transaction_with_auth(account_id, &contract_id, root);

        let mut guard = AuthorizedCallsForContract {
            contract_id,
            remaining: 3,
        };

        // The root plus its two children all target the guarded contract.
        assert_eq!(guard.extract_contract_calls(&transaction), 3);
        assert!(guard.check(&transaction));
        guard.update(&transaction);
        assert_eq!(guard.remaining, 0);
    }

    #[test]
    fn test_authorized_calls_for_contract_check_and_update_fail() {
        let (account_id, contract_id) = setup();
        let leaf = create_invocation(&contract_id, vec![]);
        let root = create_invocation(&contract_id, vec![leaf]);
        let transaction = transaction_with_auth(account_id, &contract_id, root);

        let guard = AuthorizedCallsForContract {
            contract_id,
            remaining: 1,
        };

        // Two calls against a budget of one must be rejected.
        assert_eq!(guard.extract_contract_calls(&transaction), 2);
        assert!(!guard.check(&transaction));
    }
}
237}