safe_rs/simulation/
fork.rs

1//! Fork database and revm simulation
2
3use std::collections::BTreeMap;
4use std::sync::Arc;
5
6use alloy::network::AnyNetwork;
7use alloy::primitives::{Address, Bytes, Log, TxKind, B256, U256};
8use alloy::providers::Provider;
9use alloy::rpc::types::trace::geth::pre_state::{AccountState, DiffMode};
10use foundry_fork_db::{cache::BlockchainDbMeta, BlockchainDb, SharedBackend};
11use revm::context::TxEnv;
12use revm::database::CacheDB;
13use revm::primitives::hardfork::SpecId;
14use revm::state::{AccountInfo, EvmState};
15use revm::{Context, ExecuteEvm, MainBuilder, MainContext};
16
17use crate::error::{Error, Result};
18use crate::types::Operation;
19
20/// Result of a simulated transaction
21#[derive(Debug, Clone)]
22pub struct SimulationResult {
23    /// Whether the simulation succeeded
24    pub success: bool,
25    /// Gas used during simulation
26    pub gas_used: u64,
27    /// Return data from the call
28    pub return_data: Bytes,
29    /// Logs emitted during simulation
30    pub logs: Vec<Log>,
31    /// Revert reason if the call reverted
32    pub revert_reason: Option<String>,
33    /// State changes from simulation (pre/post state for touched accounts)
34    pub state_diff: DiffMode,
35}
36
37impl SimulationResult {
38    /// Returns true if the simulation was successful
39    pub fn is_success(&self) -> bool {
40        self.success
41    }
42
43    /// Returns the revert reason if available
44    pub fn error_message(&self) -> Option<&str> {
45        self.revert_reason.as_deref()
46    }
47}
48
49/// Builds a state diff from REVM's execution state
50///
51/// REVM tracks original values in `Account.original_info` and `EvmStorageSlot.original_value`,
52/// so we can reconstruct both pre and post state from the final state.
53fn build_state_diff(state: &EvmState) -> DiffMode {
54    let mut pre = BTreeMap::new();
55    let mut post = BTreeMap::new();
56
57    for (address, account) in state.iter() {
58        // Skip if account wasn't touched
59        if !account.is_touched() {
60            continue;
61        }
62
63        // Build storage diffs - only include changed slots
64        let mut pre_storage = BTreeMap::new();
65        let mut post_storage = BTreeMap::new();
66
67        for (key, slot) in account.storage.iter() {
68            if slot.is_changed() {
69                pre_storage.insert(B256::from(*key), B256::from(slot.original_value));
70                post_storage.insert(B256::from(*key), B256::from(slot.present_value));
71            }
72        }
73
74        // Build pre-state from original_info
75        let pre_state = AccountState {
76            balance: Some(account.original_info.balance),
77            nonce: Some(account.original_info.nonce),
78            code: account
79                .original_info
80                .code
81                .as_ref()
82                .map(|c| Bytes::from(c.original_bytes().to_vec())),
83            storage: pre_storage,
84        };
85
86        // Build post-state from current info
87        let post_state = AccountState {
88            balance: Some(account.info.balance),
89            nonce: Some(account.info.nonce),
90            code: account
91                .info
92                .code
93                .as_ref()
94                .map(|c| Bytes::from(c.original_bytes().to_vec())),
95            storage: post_storage,
96        };
97
98        pre.insert(*address, pre_state);
99        post.insert(*address, post_state);
100    }
101
102    DiffMode { pre, post }
103}
104
/// Simulates transactions against a fork of a live chain.
///
/// Holds the RPC `provider` used to fetch remote state, the `chain_id` written into the
/// EVM configuration, and an optional block number to fork from (`None` = latest block).
///
/// `Debug`/`Clone` are derived so the simulator can be logged and reused like
/// [`SimulationResult`]; both impls only exist when `P` itself implements the trait.
#[derive(Debug, Clone)]
pub struct ForkSimulator<P> {
    provider: P,
    chain_id: u64,
    block_number: Option<u64>,
}
111
112impl<P> ForkSimulator<P>
113where
114    P: Provider<AnyNetwork> + Clone + 'static,
115{
116    /// Creates a new fork simulator
117    pub fn new(provider: P, chain_id: u64) -> Self {
118        Self {
119            provider,
120            chain_id,
121            block_number: None,
122        }
123    }
124
125    /// Sets the block number to fork from
126    pub fn at_block(mut self, block: u64) -> Self {
127        self.block_number = Some(block);
128        self
129    }
130
131    /// Creates a forked database from the current provider state
132    pub async fn create_fork_db(&self) -> Result<CacheDB<SharedBackend>> {
133        let block = match self.block_number {
134            Some(b) => b,
135            None => self
136                .provider
137                .get_block_number()
138                .await
139                .map_err(|e| Error::ForkDb(e.to_string()))?,
140        };
141
142        let meta = BlockchainDbMeta::new(
143            Default::default(), // empty known contracts
144            format!("fork-{}", self.chain_id),
145        );
146
147        let db = BlockchainDb::new(meta, None);
148        let backend = SharedBackend::spawn_backend_thread(
149            Arc::new(self.provider.clone()),
150            db,
151            Some(block.into()),
152        );
153
154        Ok(CacheDB::new(backend))
155    }
156
157    /// Simulates a call from the Safe
158    pub async fn simulate_call(
159        &self,
160        from: Address,
161        to: Address,
162        value: U256,
163        data: Bytes,
164        operation: Operation,
165    ) -> Result<SimulationResult> {
166        let mut db = self.create_fork_db().await?;
167
168        // Set a high balance for the caller to ensure the call can proceed
169        let caller_info = AccountInfo::default();
170        db.insert_account_info(from, caller_info);
171
172        // Update the balance separately
173        if let Some(account) = db.cache.accounts.get_mut(&from) {
174            account.info.balance = U256::from(1_000_000_000_000_000_000_000u128); // 1000 ETH
175        }
176
177        // Determine the actual call target and calldata
178        let (call_to, call_data) = match operation {
179            Operation::Call => (to, data.to_vec()),
180            Operation::DelegateCall => {
181                // For delegatecall simulation, we execute directly from the Safe
182                // This is a simplification - in reality the Safe would delegatecall
183                (to, data.to_vec())
184            }
185        };
186
187        let tx = TxEnv {
188            caller: from,
189            gas_limit: 30_000_000,
190            gas_price: 0,
191            kind: TxKind::Call(call_to),
192            value,
193            data: call_data.into(),
194            nonce: 0,
195            chain_id: Some(self.chain_id),
196            ..Default::default()
197        };
198
199        // Build the EVM context
200        let ctx = Context::mainnet()
201            .with_db(db)
202            .modify_cfg_chained(|cfg| {
203                cfg.spec = SpecId::CANCUN;
204                cfg.chain_id = self.chain_id;
205            })
206            .modify_block_chained(|block| {
207                block.basefee = 0;
208            })
209            .with_tx(tx.clone());
210
211        // Create and run the EVM
212        let mut evm = ctx.build_mainnet();
213        let result = evm.transact(tx).map_err(|e| Error::Revm(format!("{:?}", e)))?;
214
215        Ok(self.process_result(result))
216    }
217
218    /// Estimates gas for a Safe internal call
219    ///
220    /// Runs the simulation and returns gas used + 10% buffer
221    pub async fn estimate_safe_tx_gas(
222        &self,
223        from: Address,
224        to: Address,
225        value: U256,
226        data: Bytes,
227        operation: Operation,
228    ) -> Result<U256> {
229        let result = self.simulate_call(from, to, value, data, operation).await?;
230
231        if !result.success {
232            return Err(Error::GasEstimation(format!(
233                "Simulation failed: {}",
234                result.revert_reason.unwrap_or_else(|| "unknown".to_string())
235            )));
236        }
237
238        // Add 10% buffer to the gas used
239        let gas_with_buffer = result.gas_used + (result.gas_used / 10);
240        Ok(U256::from(gas_with_buffer))
241    }
242
243    fn process_result<H>(
244        &self,
245        result: revm::context::result::ExecResultAndState<revm::context::result::ExecutionResult<H>>,
246    ) -> SimulationResult
247    where
248        H: std::fmt::Debug,
249    {
250        use revm::context::result::{ExecutionResult, Output};
251
252        // Build state diff from the execution state
253        let state_diff = build_state_diff(&result.state);
254
255        match result.result {
256            ExecutionResult::Success {
257                gas_used,
258                output,
259                logs,
260                ..
261            } => {
262                let return_data = match output {
263                    Output::Call(data) => Bytes::from(data.to_vec()),
264                    Output::Create(_, _) => Bytes::new(),
265                };
266
267                let logs = logs
268                    .into_iter()
269                    .filter_map(|log| {
270                        Log::new(log.address, log.topics().to_vec(), log.data.data.clone())
271                    })
272                    .collect();
273
274                SimulationResult {
275                    success: true,
276                    gas_used,
277                    return_data,
278                    logs,
279                    revert_reason: None,
280                    state_diff,
281                }
282            }
283            ExecutionResult::Revert { gas_used, output } => {
284                let revert_reason = Self::decode_revert_reason(&output);
285                SimulationResult {
286                    success: false,
287                    gas_used,
288                    return_data: Bytes::from(output.to_vec()),
289                    logs: vec![],
290                    revert_reason: Some(revert_reason),
291                    state_diff,
292                }
293            }
294            ExecutionResult::Halt { gas_used, reason } => SimulationResult {
295                success: false,
296                gas_used,
297                return_data: Bytes::new(),
298                logs: vec![],
299                revert_reason: Some(format!("Halted: {:?}", reason)),
300                state_diff,
301            },
302        }
303    }
304
305    fn decode_revert_reason(output: &revm::primitives::Bytes) -> String {
306        if output.len() < 4 {
307            return "Unknown revert".to_string();
308        }
309
310        // Check for Error(string) selector: 0x08c379a0
311        if output[0..4] == [0x08, 0xc3, 0x79, 0xa0] && output.len() >= 68 {
312            // Skip selector (4) + offset (32) + length position
313            let offset = 4 + 32;
314            if output.len() > offset + 32 {
315                let len = u32::from_be_bytes([
316                    output[offset + 28],
317                    output[offset + 29],
318                    output[offset + 30],
319                    output[offset + 31],
320                ]) as usize;
321
322                let str_start = offset + 32;
323                if output.len() >= str_start + len {
324                    if let Ok(s) = String::from_utf8(output[str_start..str_start + len].to_vec()) {
325                        return s;
326                    }
327                }
328            }
329        }
330
331        // Check for Panic(uint256) selector: 0x4e487b71
332        if output[0..4] == [0x4e, 0x48, 0x7b, 0x71] && output.len() >= 36 {
333            let panic_code =
334                u32::from_be_bytes([output[32], output[33], output[34], output[35]]) as usize;
335            return match panic_code {
336                0x00 => "Panic: generic/compiler panic",
337                0x01 => "Panic: assertion failed",
338                0x11 => "Panic: arithmetic overflow/underflow",
339                0x12 => "Panic: division by zero",
340                0x21 => "Panic: invalid enum value",
341                0x22 => "Panic: access to incorrectly encoded storage",
342                0x31 => "Panic: pop on empty array",
343                0x32 => "Panic: array out of bounds",
344                0x41 => "Panic: memory overflow",
345                0x51 => "Panic: call to zero-initialized function",
346                _ => "Panic: unknown code",
347            }
348            .to_string();
349        }
350
351        format!("Revert: 0x{}", alloy::primitives::hex::encode(output))
352    }
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SimulationResult` with empty return data and no logs.
    fn make_result(
        success: bool,
        gas_used: u64,
        revert_reason: Option<&str>,
        state_diff: DiffMode,
    ) -> SimulationResult {
        SimulationResult {
            success,
            gas_used,
            return_data: Bytes::new(),
            logs: Vec::new(),
            revert_reason: revert_reason.map(str::to_owned),
            state_diff,
        }
    }

    /// Builds a code-less `AccountState` snapshot.
    fn account_state(balance: U256, nonce: u64, storage: BTreeMap<B256, B256>) -> AccountState {
        AccountState {
            balance: Some(balance),
            nonce: Some(nonce),
            code: None,
            storage,
        }
    }

    /// Wraps one account's before/after snapshots into a `DiffMode`.
    fn single_account_diff(addr: Address, before: AccountState, after: AccountState) -> DiffMode {
        DiffMode {
            pre: BTreeMap::from([(addr, before)]),
            post: BTreeMap::from([(addr, after)]),
        }
    }

    #[test]
    fn test_simulation_result() {
        let result = make_result(true, 21_000, None, DiffMode::default());

        assert!(result.is_success());
        assert!(result.error_message().is_none());
    }

    #[test]
    fn test_simulation_result_revert() {
        let reason = "ERC20: insufficient balance";
        let result = make_result(false, 21_000, Some(reason), DiffMode::default());

        assert!(!result.is_success());
        assert_eq!(result.error_message(), Some(reason));
    }

    #[test]
    fn test_state_diff_with_balance_change() {
        let addr = Address::ZERO;
        let diff = single_account_diff(
            addr,
            account_state(U256::from(1000), 0, BTreeMap::new()),
            account_state(U256::from(500), 1, BTreeMap::new()),
        );
        let result = make_result(true, 21_000, None, diff);

        assert!(result.is_success());
        assert_eq!(result.state_diff.pre.len(), 1);
        assert_eq!(result.state_diff.post.len(), 1);

        let before = &result.state_diff.pre[&addr];
        let after = &result.state_diff.post[&addr];

        assert_eq!(before.balance, Some(U256::from(1000)));
        assert_eq!(after.balance, Some(U256::from(500)));
        assert_eq!(before.nonce, Some(0));
        assert_eq!(after.nonce, Some(1));
    }

    #[test]
    fn test_state_diff_with_storage_change() {
        let addr = Address::ZERO;
        let slot = B256::ZERO;

        // AccountState storage values are B256 words, so convert from U256.
        let old_value = B256::from(U256::from(100));
        let new_value = B256::from(U256::from(200));

        let diff = single_account_diff(
            addr,
            account_state(U256::ZERO, 0, BTreeMap::from([(slot, old_value)])),
            account_state(U256::ZERO, 0, BTreeMap::from([(slot, new_value)])),
        );
        let result = make_result(true, 50_000, None, diff);

        assert_eq!(
            result.state_diff.pre[&addr].storage.get(&slot),
            Some(&old_value)
        );
        assert_eq!(
            result.state_diff.post[&addr].storage.get(&slot),
            Some(&new_value)
        );
    }
}