Skip to main content

safe_rs/
wallet.rs

1//! Generic Wallet type for Safe and EOA accounts
2//!
3//! This module provides a `Wallet<A>` type that wraps any account implementing
4//! the `Account` trait, enabling generic code that works with both Safe and EOA.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use safe_rs::{WalletBuilder, WalletConfig, Account};
10//!
11//! // Connect to a Safe wallet using the fluent builder API
12//! let safe_wallet = WalletBuilder::new(provider, signer)
13//!     .connect(safe_address)
14//!     .await?;
15//!
16//! // Connect to an EOA wallet
17//! let eoa_wallet = WalletBuilder::new(provider, signer)
18//!     .connect_eoa()
19//!     .await?;
20//!
21//! // Deploy a new Safe and connect to it
22//! let builder = WalletBuilder::new(provider, signer);
23//! let address = builder.deploy(rpc_url, config.clone()).await?;
24//! let wallet = builder.connect(address).await?;
25//!
26//! // Generic function that works with any account type
27//! async fn do_something<A: Account>(wallet: &Wallet<A>) -> Result<()> {
28//!     wallet.batch()
29//!         .add_typed(token, IERC20::transferCall { to: recipient, amount })
30//!         .simulate().await?
31//!         .execute().await?;
32//!     Ok(())
33//! }
34//! ```
35
36use alloy::network::{AnyNetwork, EthereumWallet};
37use alloy::primitives::{Address, Bytes, U256};
38use alloy::providers::{Provider, ProviderBuilder};
39use alloy::signers::local::PrivateKeySigner;
40use url::Url;
41
42use crate::account::Account;
43use crate::chain::{ChainAddresses, ChainConfig};
44use crate::create2::{compute_create2_address, encode_setup_call};
45use crate::eoa::Eoa;
46use crate::error::{Error, Result};
47use crate::safe::{is_safe, ExecutionResult, Safe};
48use crate::types::Operation;
49use crate::ISafeProxyFactory;
50
51/// Configuration for Safe address computation and deployment
52#[derive(Debug, Clone)]
53pub struct WalletConfig {
54    /// Salt nonce for CREATE2 address computation (default: 0)
55    pub salt_nonce: U256,
56    /// Additional owners beyond the signer (default: empty)
57    pub additional_owners: Vec<Address>,
58    /// Threshold for the Safe (default: 1)
59    pub threshold: u64,
60    /// Fallback handler address (default: v1.4.1 fallback handler)
61    pub fallback_handler: Option<Address>,
62}
63
64impl Default for WalletConfig {
65    fn default() -> Self {
66        Self {
67            salt_nonce: U256::ZERO,
68            additional_owners: Vec::new(),
69            threshold: 1,
70            fallback_handler: None,
71        }
72    }
73}
74
75impl WalletConfig {
76    /// Creates a new WalletConfig with default values
77    pub fn new() -> Self {
78        Self::default()
79    }
80
81    /// Sets the salt nonce for CREATE2 address computation
82    pub fn with_salt_nonce(mut self, salt_nonce: U256) -> Self {
83        self.salt_nonce = salt_nonce;
84        self
85    }
86
87    /// Sets additional owners beyond the signer
88    pub fn with_additional_owners(mut self, owners: Vec<Address>) -> Self {
89        self.additional_owners = owners;
90        self
91    }
92
93    /// Sets the threshold for the Safe
94    pub fn with_threshold(mut self, threshold: u64) -> Self {
95        self.threshold = threshold;
96        self
97    }
98
99    /// Sets a custom fallback handler
100    pub fn with_fallback_handler(mut self, handler: Address) -> Self {
101        self.fallback_handler = Some(handler);
102        self
103    }
104
105    /// Builds the owners array (signer + additional owners)
106    fn build_owners(&self, signer_address: Address) -> Vec<Address> {
107        let mut owners = vec![signer_address];
108        for owner in &self.additional_owners {
109            if !owners.contains(owner) {
110                owners.push(*owner);
111            }
112        }
113        owners
114    }
115
116    /// Gets the fallback handler, using the v1.4.1 default if not specified
117    fn get_fallback_handler(&self) -> Address {
118        self.fallback_handler
119            .unwrap_or_else(|| ChainAddresses::v1_4_1().fallback_handler)
120    }
121}
122
123// =============================================================================
124// WalletBuilder
125// =============================================================================
126
127/// Builder for creating wallets with a fluent API.
128///
129/// This builder holds the provider and signer, allowing you to:
130/// - Connect to an existing Safe at a known address
131/// - Connect to a Safe at a computed CREATE2 address
132/// - Connect as an EOA (no Safe)
133/// - Deploy a new Safe and then connect to it
134///
135/// # Example
136///
137/// ```rust,ignore
138/// // Connect to existing Safe
139/// let wallet = WalletBuilder::new(provider, signer)
140///     .connect(address)
141///     .await?;
142///
143/// // Deploy then connect (builder not consumed by deploy)
144/// let builder = WalletBuilder::new(provider, signer);
145/// let address = builder.deploy(rpc_url, config.clone()).await?;
146/// let wallet = builder.connect(address).await?;
147/// ```
148pub struct WalletBuilder<P> {
149    provider: P,
150    signer: PrivateKeySigner,
151}
152
153impl<P> WalletBuilder<P> {
154    /// Creates a new WalletBuilder with the given provider and signer.
155    pub fn new(provider: P, signer: PrivateKeySigner) -> Self {
156        Self { provider, signer }
157    }
158
159    /// Returns a reference to the signer.
160    pub fn signer(&self) -> &PrivateKeySigner {
161        &self.signer
162    }
163
164    /// Returns the signer's address.
165    pub fn signer_address(&self) -> Address {
166        self.signer.address()
167    }
168
169    /// Returns a reference to the provider.
170    pub fn provider(&self) -> &P {
171        &self.provider
172    }
173}
174
175impl<P> WalletBuilder<P>
176where
177    P: Provider<AnyNetwork> + Clone + 'static,
178{
179    /// Connects to an existing Safe at the given address.
180    ///
181    /// # Arguments
182    /// * `address` - The Safe contract address
183    ///
184    /// # Example
185    ///
186    /// ```rust,ignore
187    /// let wallet = WalletBuilder::new(provider, signer)
188    ///     .connect(safe_address)
189    ///     .await?;
190    /// ```
191    pub async fn connect(self, address: Address) -> Result<Wallet<Safe<P>>> {
192        let safe = Safe::connect(self.provider, self.signer, address).await?;
193        Ok(Wallet::from_account(safe))
194    }
195
196    /// Connects to a Safe at the computed CREATE2 address for the given config.
197    ///
198    /// This computes the deterministic Safe address based on the signer and config,
199    /// then connects to it. Returns an error if no Safe is deployed at that address.
200    ///
201    /// # Arguments
202    /// * `config` - Configuration for Safe address computation
203    ///
204    /// # Example
205    ///
206    /// ```rust,ignore
207    /// let config = WalletConfig::new().with_salt_nonce(U256::from(42));
208    /// let wallet = WalletBuilder::new(provider, signer)
209    ///     .connect_with_config(config)
210    ///     .await?;
211    /// ```
212    pub async fn connect_with_config(self, config: WalletConfig) -> Result<Wallet<Safe<P>>> {
213        let safe_address = self.compute_address(&config).await?;
214
215        // Check if Safe is deployed
216        if !is_safe(&self.provider, safe_address).await? {
217            return Err(Error::InvalidConfig(format!(
218                "No Safe deployed at computed address {}",
219                safe_address
220            )));
221        }
222
223        let safe = Safe::connect(self.provider, self.signer, safe_address).await?;
224        Ok(Wallet::from_account(safe))
225    }
226
227    /// Connects as an EOA (no Safe).
228    ///
229    /// # Example
230    ///
231    /// ```rust,ignore
232    /// let wallet = WalletBuilder::new(provider, signer)
233    ///     .connect_eoa()
234    ///     .await?;
235    /// ```
236    pub async fn connect_eoa(self) -> Result<Wallet<Eoa<P>>> {
237        let eoa = Eoa::connect(self.provider, self.signer).await?;
238        Ok(Wallet::from_account(eoa))
239    }
240
241    /// Computes the Safe address that would be used for the given config.
242    ///
243    /// This is useful for checking what Safe address would be computed without
244    /// actually connecting or deploying.
245    ///
246    /// # Arguments
247    /// * `config` - Configuration for Safe address computation
248    ///
249    /// # Example
250    ///
251    /// ```rust,ignore
252    /// let builder = WalletBuilder::new(provider, signer);
253    /// let config = WalletConfig::new().with_salt_nonce(U256::from(42));
254    /// let address = builder.compute_address(&config).await?;
255    /// ```
256    pub async fn compute_address(&self, config: &WalletConfig) -> Result<Address> {
257        let addresses = ChainAddresses::v1_4_1();
258        let signer_address = self.signer.address();
259
260        // Build owners array
261        let owners = config.build_owners(signer_address);
262
263        // Get fallback handler
264        let fallback_handler = config.get_fallback_handler();
265
266        // Encode initializer
267        let initializer = encode_setup_call(&owners, config.threshold, fallback_handler);
268
269        // Get proxy creation code
270        let factory = ISafeProxyFactory::new(addresses.proxy_factory, &self.provider);
271        let creation_code = factory
272            .proxyCreationCode()
273            .call()
274            .await
275            .map_err(|e| Error::Fetch {
276                what: "proxy creation code",
277                reason: e.to_string(),
278            })?;
279
280        // Compute deterministic address
281        let safe_address = compute_create2_address(
282            addresses.proxy_factory,
283            addresses.safe_singleton,
284            &initializer,
285            config.salt_nonce,
286            &creation_code,
287        );
288
289        Ok(safe_address)
290    }
291
292    /// Deploys a Safe with the given configuration. Idempotent.
293    ///
294    /// If a Safe already exists at the computed address, returns that address
295    /// without deploying. Otherwise, deploys a new Safe.
296    ///
297    /// Uses `&self` so the builder can be reused for `connect()` afterward.
298    ///
299    /// # Arguments
300    /// * `rpc_url` - The RPC URL for sending the deployment transaction
301    /// * `config` - Configuration for Safe deployment
302    ///
303    /// # Example
304    ///
305    /// ```rust,ignore
306    /// let builder = WalletBuilder::new(provider, signer);
307    /// let config = WalletConfig::new().with_salt_nonce(U256::from(42));
308    /// let address = builder.deploy(rpc_url, config.clone()).await?;
309    /// let wallet = builder.connect(address).await?;
310    /// ```
311    pub async fn deploy(&self, rpc_url: Url, config: WalletConfig) -> Result<Address> {
312        let addresses = ChainAddresses::v1_4_1();
313        let signer_address = self.signer.address();
314
315        // Build owners array
316        let owners = config.build_owners(signer_address);
317
318        // Validate threshold
319        if config.threshold == 0 || config.threshold as usize > owners.len() {
320            return Err(Error::InvalidConfig(format!(
321                "Invalid threshold: {} (must be 1-{})",
322                config.threshold,
323                owners.len()
324            )));
325        }
326
327        // Get fallback handler
328        let fallback_handler = config.get_fallback_handler();
329
330        // Encode initializer
331        let initializer = encode_setup_call(&owners, config.threshold, fallback_handler);
332
333        // Get proxy creation code
334        let factory = ISafeProxyFactory::new(addresses.proxy_factory, &self.provider);
335        let creation_code = factory
336            .proxyCreationCode()
337            .call()
338            .await
339            .map_err(|e| Error::Fetch {
340                what: "proxy creation code",
341                reason: e.to_string(),
342            })?;
343
344        // Compute deterministic address
345        let safe_address = compute_create2_address(
346            addresses.proxy_factory,
347            addresses.safe_singleton,
348            &initializer,
349            config.salt_nonce,
350            &creation_code,
351        );
352
353        // Check if Safe is already deployed
354        if is_safe(&self.provider, safe_address).await? {
355            return Ok(safe_address);
356        }
357
358        // Deploy the Safe
359        let wallet_provider = ProviderBuilder::new()
360            .network::<AnyNetwork>()
361            .wallet(EthereumWallet::from(self.signer.clone()))
362            .connect_http(rpc_url);
363
364        let factory_with_wallet = ISafeProxyFactory::new(addresses.proxy_factory, &wallet_provider);
365
366        let pending_tx = factory_with_wallet
367            .createProxyWithNonce(addresses.safe_singleton, initializer, config.salt_nonce)
368            .send()
369            .await
370            .map_err(|e| Error::ExecutionFailed {
371                reason: format!("Failed to send deployment transaction: {}", e),
372            })?;
373
374        let _receipt = pending_tx.get_receipt().await.map_err(|e| Error::ExecutionFailed {
375            reason: format!("Failed to get deployment receipt: {}", e),
376        })?;
377
378        // Verify deployment
379        if !is_safe(&self.provider, safe_address).await? {
380            return Err(Error::ExecutionFailed {
381                reason: format!("Deployment failed: no Safe at expected address {}", safe_address),
382            });
383        }
384
385        Ok(safe_address)
386    }
387}
388
/// A wallet that wraps any account type implementing the [`Account`] trait.
///
/// This provides a unified interface for both Safe and EOA wallets with
/// compile-time polymorphism.
///
/// The `Account` bound is placed on the `impl` blocks rather than on the struct
/// definition (Rust API guideline C-STRUCT-BOUNDS). The constructor
/// `from_account` and the generic methods still require `A: Account`.
///
/// # Type Parameters
///
/// * `A` - The account type (e.g., `Safe<P>` or `Eoa<P>`)
///
/// # Example
///
/// ```rust,ignore
/// // Connect to a Safe using the fluent builder API
/// let wallet = WalletBuilder::new(provider, signer)
///     .connect(safe_address)
///     .await?;
///
/// // Use the unified batch API
/// wallet.batch()
///     .add_typed(token, call)
///     .execute().await?;
/// ```
pub struct Wallet<A> {
    account: A,
}
414
415impl<A: Account> Wallet<A> {
416    /// Creates a new wallet wrapping the given account.
417    pub fn from_account(account: A) -> Self {
418        Self { account }
419    }
420
421    /// Returns the wallet's address.
422    ///
423    /// For Safe wallets, returns the Safe contract address.
424    /// For EOA wallets, returns the signer address.
425    pub fn address(&self) -> Address {
426        self.account.address()
427    }
428
429    /// Returns the underlying signer address.
430    ///
431    /// For Safe wallets, this is the owner/signer address.
432    /// For EOA wallets, this is the same as `address()`.
433    pub fn signer_address(&self) -> Address {
434        self.account.signer_address()
435    }
436
437    /// Returns a reference to the provider.
438    pub fn provider(&self) -> &A::Provider {
439        self.account.provider()
440    }
441
442    /// Returns the chain configuration.
443    pub fn config(&self) -> &ChainConfig {
444        self.account.config()
445    }
446
447    /// Gets the current nonce for the account.
448    ///
449    /// For Safe wallets, this is the Safe's internal nonce.
450    /// For EOA wallets, this is the account's transaction count.
451    pub async fn nonce(&self) -> Result<U256> {
452        self.account.nonce().await
453    }
454
455    /// Creates a new builder for batching transactions.
456    ///
457    /// Returns `A::Builder<'_>` which implements `CallBuilder`.
458    ///
459    /// # Example
460    ///
461    /// ```rust,ignore
462    /// wallet.batch()
463    ///     .add_typed(token, IERC20::transferCall { to: recipient, amount })
464    ///     .simulate().await?
465    ///     .execute().await?;
466    /// ```
467    pub fn batch(&self) -> A::Builder<'_> {
468        self.account.batch()
469    }
470
471    /// Executes a single transaction.
472    ///
473    /// This is a convenience method for executing a single call without
474    /// the batch builder. For multiple calls, use `batch()` instead.
475    ///
476    /// # Errors
477    /// Returns `Error::UnsupportedEoaOperation` if `operation` is `DelegateCall`
478    /// and the wallet is an EOA.
479    pub async fn execute_single(
480        &self,
481        to: Address,
482        value: U256,
483        data: Bytes,
484        operation: Operation,
485    ) -> Result<ExecutionResult> {
486        self.account.execute_single(to, value, data, operation).await
487    }
488
489    /// Returns a reference to the underlying account.
490    pub fn inner(&self) -> &A {
491        &self.account
492    }
493
494    /// Consumes the wallet and returns the underlying account.
495    pub fn into_inner(self) -> A {
496        self.account
497    }
498}
499
500// =============================================================================
501// Safe-specific implementation
502// =============================================================================
503
504impl<P> Wallet<Safe<P>>
505where
506    P: Provider<AnyNetwork> + Clone + 'static,
507{
508    /// Returns true (this is a Safe wallet).
509    pub fn is_safe(&self) -> bool {
510        true
511    }
512
513    /// Returns false (this is not an EOA wallet).
514    pub fn is_eoa(&self) -> bool {
515        false
516    }
517
518    /// Returns a reference to the underlying Safe.
519    pub fn safe(&self) -> &Safe<P> {
520        self.inner()
521    }
522}
523
524// =============================================================================
525// EOA-specific implementation
526// =============================================================================
527
528impl<P> Wallet<Eoa<P>>
529where
530    P: Provider<AnyNetwork> + Clone + 'static,
531{
532    /// Returns false (this is not a Safe wallet).
533    pub fn is_safe(&self) -> bool {
534        false
535    }
536
537    /// Returns true (this is an EOA wallet).
538    pub fn is_eoa(&self) -> bool {
539        true
540    }
541
542    /// Returns a reference to the underlying Eoa.
543    pub fn eoa(&self) -> &Eoa<P> {
544        self.inner()
545    }
546}
547
#[cfg(test)]
mod tests {
    use super::*;
    use alloy::primitives::address;

    #[test]
    fn test_wallet_config_default() {
        let config = WalletConfig::default();
        assert_eq!(config.salt_nonce, U256::ZERO);
        assert!(config.additional_owners.is_empty());
        assert_eq!(config.threshold, 1);
        assert!(config.fallback_handler.is_none());
    }

    #[test]
    fn test_wallet_config_new_uses_defaults() {
        let config = WalletConfig::new();
        assert_eq!(config.salt_nonce, U256::ZERO);
        assert!(config.additional_owners.is_empty());
        assert_eq!(config.threshold, 1);
        assert!(config.fallback_handler.is_none());
    }

    #[test]
    fn test_wallet_config_builder() {
        let owner2 = address!("2222222222222222222222222222222222222222");
        let handler = address!("fd0732Dc9E303f09fCEf3a7388Ad10A83459Ec99");

        let config = WalletConfig::new()
            .with_salt_nonce(U256::from(42))
            .with_additional_owners(vec![owner2])
            .with_threshold(2)
            .with_fallback_handler(handler);

        assert_eq!(config.salt_nonce, U256::from(42));
        assert_eq!(config.additional_owners, vec![owner2]);
        assert_eq!(config.threshold, 2);
        assert_eq!(config.fallback_handler, Some(handler));
    }

    #[test]
    fn test_wallet_config_with_additional_owners_replaces_previous() {
        let owner2 = address!("2222222222222222222222222222222222222222");
        let owner3 = address!("3333333333333333333333333333333333333333");

        let config = WalletConfig::new()
            .with_additional_owners(vec![owner2])
            .with_additional_owners(vec![owner3]);

        assert_eq!(config.additional_owners, vec![owner3]);
    }

    #[test]
    fn test_wallet_config_build_owners() {
        let signer = address!("1111111111111111111111111111111111111111");
        let owner2 = address!("2222222222222222222222222222222222222222");
        let owner3 = address!("3333333333333333333333333333333333333333");

        let config = WalletConfig::new().with_additional_owners(vec![owner2, owner3]);
        let owners = config.build_owners(signer);

        assert_eq!(owners.len(), 3);
        assert_eq!(owners[0], signer);
        assert_eq!(owners[1], owner2);
        assert_eq!(owners[2], owner3);
    }

    #[test]
    fn test_wallet_config_build_owners_no_duplicates() {
        let signer = address!("1111111111111111111111111111111111111111");
        // Include signer in additional owners (should not duplicate)
        let config = WalletConfig::new().with_additional_owners(vec![signer]);
        let owners = config.build_owners(signer);

        assert_eq!(owners.len(), 1);
        assert_eq!(owners[0], signer);
    }

    #[test]
    fn test_wallet_config_build_owners_dedupes_additional_owners_in_order() {
        let signer = address!("1111111111111111111111111111111111111111");
        let owner2 = address!("2222222222222222222222222222222222222222");
        let owner3 = address!("3333333333333333333333333333333333333333");

        // Duplicates among the additional owners (and of the signer) are dropped,
        // keeping the first occurrence's position.
        let config = WalletConfig::new()
            .with_additional_owners(vec![owner2, signer, owner2, owner3, owner3]);

        assert_eq!(config.build_owners(signer), vec![signer, owner2, owner3]);
    }

    #[test]
    fn test_wallet_config_build_owners_signer_only() {
        let signer = address!("1111111111111111111111111111111111111111");
        assert_eq!(WalletConfig::default().build_owners(signer), vec![signer]);
    }

    #[test]
    fn test_wallet_config_get_fallback_handler_default() {
        let config = WalletConfig::default();
        let handler = config.get_fallback_handler();
        assert_eq!(handler, ChainAddresses::v1_4_1().fallback_handler);
    }

    #[test]
    fn test_wallet_config_get_fallback_handler_custom() {
        let custom_handler = address!("dead000000000000000000000000000000000000");
        let config = WalletConfig::new().with_fallback_handler(custom_handler);
        let handler = config.get_fallback_handler();
        assert_eq!(handler, custom_handler);
    }
}