Skip to main content

safe_rs/simulation/
fork.rs

1//! Fork database and revm simulation
2
3use std::collections::BTreeMap;
4use std::sync::Arc;
5
6use alloy::network::AnyNetwork;
7use alloy::primitives::{Address, Bytes, Log, TxKind, B256, U256};
8use alloy::providers::Provider;
9use alloy::rpc::types::trace::geth::pre_state::{AccountState, DiffMode};
10use foundry_fork_db::{cache::BlockchainDbMeta, BlockchainDb, SharedBackend};
11use revm::context::TxEnv;
12use revm::database::CacheDB;
13use revm::primitives::hardfork::SpecId;
14use revm::state::{AccountInfo, EvmState};
15use revm::{Context, ExecuteEvm, MainBuilder, MainContext};
16use revm_inspectors::tracing::{TracingInspector, TracingInspectorConfig};
17
18pub use revm_inspectors::tracing::CallTraceArena;
19
20use crate::error::{Error, Result};
21use crate::types::Operation;
22
23/// Result of a simulated transaction
24#[derive(Debug, Clone)]
25pub struct SimulationResult {
26    /// Whether the simulation succeeded
27    pub success: bool,
28    /// Gas used during simulation
29    pub gas_used: u64,
30    /// Return data from the call
31    pub return_data: Bytes,
32    /// Logs emitted during simulation
33    pub logs: Vec<Log>,
34    /// Revert reason if the call reverted
35    pub revert_reason: Option<String>,
36    /// State changes from simulation (pre/post state for touched accounts)
37    pub state_diff: DiffMode,
38    /// Call trace arena (if tracing was enabled)
39    pub traces: Option<CallTraceArena>,
40}
41
42impl SimulationResult {
43    /// Returns true if the simulation was successful
44    pub fn is_success(&self) -> bool {
45        self.success
46    }
47
48    /// Returns the revert reason if available
49    pub fn error_message(&self) -> Option<&str> {
50        self.revert_reason.as_deref()
51    }
52
53    /// Format traces as human-readable text (cast run style)
54    ///
55    /// Returns `None` if tracing was not enabled for this simulation.
56    pub fn format_traces(&self) -> Option<String> {
57        use revm_inspectors::tracing::TraceWriter;
58
59        let traces = self.traces.as_ref()?;
60        let mut writer = TraceWriter::new(Vec::<u8>::new());
61        writer.write_arena(traces).ok()?;
62        String::from_utf8(writer.into_writer()).ok()
63    }
64}
65
66/// Builds a state diff from REVM's execution state
67///
68/// REVM tracks original values in `Account.original_info` and `EvmStorageSlot.original_value`,
69/// so we can reconstruct both pre and post state from the final state.
70fn build_state_diff(state: &EvmState) -> DiffMode {
71    let mut pre = BTreeMap::new();
72    let mut post = BTreeMap::new();
73
74    for (address, account) in state.iter() {
75        // Skip if account wasn't touched
76        if !account.is_touched() {
77            continue;
78        }
79
80        // Build storage diffs - only include changed slots
81        let mut pre_storage = BTreeMap::new();
82        let mut post_storage = BTreeMap::new();
83
84        for (key, slot) in account.storage.iter() {
85            if slot.is_changed() {
86                pre_storage.insert(B256::from(*key), B256::from(slot.original_value));
87                post_storage.insert(B256::from(*key), B256::from(slot.present_value));
88            }
89        }
90
91        // Build pre-state from original_info
92        let pre_state = AccountState {
93            balance: Some(account.original_info.balance),
94            nonce: Some(account.original_info.nonce),
95            code: account
96                .original_info
97                .code
98                .as_ref()
99                .map(|c| Bytes::from(c.original_bytes().to_vec())),
100            storage: pre_storage,
101        };
102
103        // Build post-state from current info
104        let post_state = AccountState {
105            balance: Some(account.info.balance),
106            nonce: Some(account.info.nonce),
107            code: account
108                .info
109                .code
110                .as_ref()
111                .map(|c| Bytes::from(c.original_bytes().to_vec())),
112            storage: post_storage,
113        };
114
115        pre.insert(*address, pre_state);
116        post.insert(*address, post_state);
117    }
118
119    DiffMode { pre, post }
120}
121
122/// Fork simulator for executing transactions against a forked state
123pub struct ForkSimulator<P> {
124    provider: P,
125    chain_id: u64,
126    block_number: Option<u64>,
127    tracing: bool,
128    caller_balance: Option<U256>,
129}
130
131impl<P> ForkSimulator<P>
132where
133    P: Provider<AnyNetwork> + Clone + 'static,
134{
135    /// Creates a new fork simulator
136    pub fn new(provider: P, chain_id: u64) -> Self {
137        Self {
138            provider,
139            chain_id,
140            block_number: None,
141            tracing: false,
142            caller_balance: None,
143        }
144    }
145
146    /// Sets the block number to fork from
147    pub fn at_block(mut self, block: u64) -> Self {
148        self.block_number = Some(block);
149        self
150    }
151
152    /// Enables transaction tracing (cast run style)
153    ///
154    /// When enabled, `simulate_call()` will capture detailed call traces
155    /// showing the nested call hierarchy, gas per call, and call/return data.
156    /// Access traces via `SimulationResult::format_traces()`.
157    pub fn with_tracing(mut self, enable: bool) -> Self {
158        self.tracing = enable;
159        self
160    }
161
162    /// Sets a custom balance for the caller during simulation.
163    ///
164    /// If not set, the caller's on-chain balance is used.
165    pub fn with_caller_balance(mut self, balance: U256) -> Self {
166        self.caller_balance = Some(balance);
167        self
168    }
169
170    /// Creates a forked database from the current provider state
171    pub async fn create_fork_db(&self) -> Result<CacheDB<SharedBackend>> {
172        let block = match self.block_number {
173            Some(b) => b,
174            None => self
175                .provider
176                .get_block_number()
177                .await
178                .map_err(|e| Error::ForkDb(e.to_string()))?,
179        };
180
181        let meta = BlockchainDbMeta::new(
182            Default::default(), // empty known contracts
183            format!("fork-{}", self.chain_id),
184        );
185
186        let db = BlockchainDb::new(meta, None);
187        let backend = SharedBackend::spawn_backend_thread(
188            Arc::new(self.provider.clone()),
189            db,
190            Some(block.into()),
191        );
192
193        Ok(CacheDB::new(backend))
194    }
195
196    /// Simulates a call from the Safe
197    pub async fn simulate_call(
198        &self,
199        from: Address,
200        to: Address,
201        value: U256,
202        data: Bytes,
203        operation: Operation,
204    ) -> Result<SimulationResult> {
205        let mut db = self.create_fork_db().await?;
206
207        // Only override caller balance if configured
208        if let Some(balance) = self.caller_balance {
209            let caller_info = AccountInfo::default();
210            db.insert_account_info(from, caller_info);
211
212            if let Some(account) = db.cache.accounts.get_mut(&from) {
213                account.info.balance = balance;
214            }
215        }
216
217        // Determine the actual call target and calldata
218        let (call_to, call_data) = match operation {
219            Operation::Call => (to, data.to_vec()),
220            Operation::DelegateCall => {
221                // For delegatecall simulation, we execute directly from the Safe
222                // This is a simplification - in reality the Safe would delegatecall
223                (to, data.to_vec())
224            }
225        };
226
227        let tx = TxEnv {
228            caller: from,
229            gas_limit: 30_000_000,
230            gas_price: 0,
231            kind: TxKind::Call(call_to),
232            value,
233            data: call_data.into(),
234            nonce: 0,
235            chain_id: Some(self.chain_id),
236            ..Default::default()
237        };
238
239        // Build the EVM context
240        let ctx = Context::mainnet()
241            .with_db(db)
242            .modify_cfg_chained(|cfg| {
243                cfg.spec = SpecId::CANCUN;
244                cfg.chain_id = self.chain_id;
245            })
246            .modify_block_chained(|block| {
247                block.basefee = 0;
248            })
249            .with_tx(tx.clone());
250
251        if self.tracing {
252            // Create inspector for tracing
253            let config = TracingInspectorConfig::default_parity();
254            let mut inspector = TracingInspector::new(config);
255
256            // Build EVM with inspector attached and execute
257            let mut evm = ctx.build_mainnet_with_inspector(&mut inspector);
258            let result = evm.transact(tx).map_err(|e| Error::Revm(format!("{:?}", e)))?;
259
260            // Extract traces from the inspector
261            let traces = Some(inspector.into_traces());
262
263            let mut sim_result = self.process_result(result);
264            sim_result.traces = traces;
265            Ok(sim_result)
266        } else {
267            // Create and run the EVM without tracing
268            let mut evm = ctx.build_mainnet();
269            let result = evm.transact(tx).map_err(|e| Error::Revm(format!("{:?}", e)))?;
270
271            Ok(self.process_result(result))
272        }
273    }
274
275    /// Estimates gas for a Safe internal call
276    ///
277    /// Runs the simulation and returns gas used + 10% buffer
278    pub async fn estimate_safe_tx_gas(
279        &self,
280        from: Address,
281        to: Address,
282        value: U256,
283        data: Bytes,
284        operation: Operation,
285    ) -> Result<U256> {
286        let result = self.simulate_call(from, to, value, data, operation).await?;
287
288        if !result.success {
289            return Err(Error::GasEstimation(format!(
290                "Simulation failed: {}",
291                result.revert_reason.unwrap_or_else(|| "unknown".to_string())
292            )));
293        }
294
295        // Add 10% buffer to the gas used
296        let gas_with_buffer = result.gas_used + (result.gas_used / 10);
297        Ok(U256::from(gas_with_buffer))
298    }
299
300    fn process_result<H>(
301        &self,
302        result: revm::context::result::ExecResultAndState<revm::context::result::ExecutionResult<H>>,
303    ) -> SimulationResult
304    where
305        H: std::fmt::Debug,
306    {
307        use revm::context::result::{ExecutionResult, Output};
308
309        // Build state diff from the execution state
310        let state_diff = build_state_diff(&result.state);
311
312        match result.result {
313            ExecutionResult::Success {
314                gas_used,
315                output,
316                logs,
317                ..
318            } => {
319                let return_data = match output {
320                    Output::Call(data) => Bytes::from(data.to_vec()),
321                    Output::Create(_, _) => Bytes::new(),
322                };
323
324                let logs = logs
325                    .into_iter()
326                    .filter_map(|log| {
327                        Log::new(log.address, log.topics().to_vec(), log.data.data.clone())
328                    })
329                    .collect();
330
331                SimulationResult {
332                    success: true,
333                    gas_used,
334                    return_data,
335                    logs,
336                    revert_reason: None,
337                    state_diff,
338                    traces: None,
339                }
340            }
341            ExecutionResult::Revert { gas_used, output } => {
342                let revert_reason = Self::decode_revert_reason(&output);
343                SimulationResult {
344                    success: false,
345                    gas_used,
346                    return_data: Bytes::from(output.to_vec()),
347                    logs: vec![],
348                    revert_reason: Some(revert_reason),
349                    state_diff,
350                    traces: None,
351                }
352            }
353            ExecutionResult::Halt { gas_used, reason } => SimulationResult {
354                success: false,
355                gas_used,
356                return_data: Bytes::new(),
357                logs: vec![],
358                revert_reason: Some(format!("Halted: {:?}", reason)),
359                state_diff,
360                traces: None,
361            },
362        }
363    }
364
365    fn decode_revert_reason(output: &revm::primitives::Bytes) -> String {
366        if output.len() < 4 {
367            return "Unknown revert".to_string();
368        }
369
370        // Check for Error(string) selector: 0x08c379a0
371        if output[0..4] == [0x08, 0xc3, 0x79, 0xa0] && output.len() >= 68 {
372            // Skip selector (4) + offset (32) + length position
373            let offset = 4 + 32;
374            if output.len() > offset + 32 {
375                let len = u32::from_be_bytes([
376                    output[offset + 28],
377                    output[offset + 29],
378                    output[offset + 30],
379                    output[offset + 31],
380                ]) as usize;
381
382                let str_start = offset + 32;
383                if output.len() >= str_start + len {
384                    if let Ok(s) = String::from_utf8(output[str_start..str_start + len].to_vec()) {
385                        return s;
386                    }
387                }
388            }
389        }
390
391        // Check for Panic(uint256) selector: 0x4e487b71
392        if output[0..4] == [0x4e, 0x48, 0x7b, 0x71] && output.len() >= 36 {
393            let panic_code =
394                u32::from_be_bytes([output[32], output[33], output[34], output[35]]) as usize;
395            return match panic_code {
396                0x00 => "Panic: generic/compiler panic",
397                0x01 => "Panic: assertion failed",
398                0x11 => "Panic: arithmetic overflow/underflow",
399                0x12 => "Panic: division by zero",
400                0x21 => "Panic: invalid enum value",
401                0x22 => "Panic: access to incorrectly encoded storage",
402                0x31 => "Panic: pop on empty array",
403                0x32 => "Panic: array out of bounds",
404                0x41 => "Panic: memory overflow",
405                0x51 => "Panic: call to zero-initialized function",
406                _ => "Panic: unknown code",
407            }
408            .to_string();
409        }
410
411        format!("Revert: 0x{}", alloy::primitives::hex::encode(output))
412    }
413}
414
#[cfg(test)]
mod tests {
    use super::*;

    /// Concrete simulator type, used only to name `ForkSimulator`'s associated
    /// functions; no provider is ever constructed.
    type TestSimulator = ForkSimulator<alloy::providers::RootProvider<AnyNetwork>>;

    /// Encodes `value` as a 32-byte big-endian ABI word.
    fn abi_word(value: u64) -> [u8; 32] {
        let mut word = [0u8; 32];
        word[24..].copy_from_slice(&value.to_be_bytes());
        word
    }

    /// Builds `Error(string)` revert data the way Solidity's `revert("...")` does.
    fn encode_error_string(message: &str) -> revm::primitives::Bytes {
        let mut data = vec![0x08, 0xc3, 0x79, 0xa0];
        data.extend_from_slice(&abi_word(32)); // offset of the string payload
        data.extend_from_slice(&abi_word(message.len() as u64));
        data.extend_from_slice(message.as_bytes());
        // Right-pad the message bytes to a whole number of 32-byte words.
        let padded = message.len().div_ceil(32) * 32;
        data.resize(4 + 64 + padded, 0);
        data.into()
    }

    /// Builds `Panic(uint256)` revert data for `code`.
    fn encode_panic(code: u64) -> revm::primitives::Bytes {
        let mut data = vec![0x4e, 0x48, 0x7b, 0x71];
        data.extend_from_slice(&abi_word(code));
        data.into()
    }

    #[test]
    fn test_simulation_result() {
        let result = SimulationResult {
            success: true,
            gas_used: 21000,
            return_data: Bytes::new(),
            logs: vec![],
            revert_reason: None,
            state_diff: DiffMode::default(),
            traces: None,
        };

        assert!(result.is_success());
        assert!(result.error_message().is_none());
        assert!(result.format_traces().is_none());
    }

    #[test]
    fn test_simulation_result_revert() {
        let result = SimulationResult {
            success: false,
            gas_used: 21000,
            return_data: Bytes::new(),
            logs: vec![],
            revert_reason: Some("ERC20: insufficient balance".to_string()),
            state_diff: DiffMode::default(),
            traces: None,
        };

        assert!(!result.is_success());
        assert_eq!(result.error_message(), Some("ERC20: insufficient balance"));
    }

    #[test]
    fn test_state_diff_with_balance_change() {
        let mut pre = BTreeMap::new();
        let mut post = BTreeMap::new();

        let addr = Address::ZERO;

        pre.insert(
            addr,
            AccountState {
                balance: Some(U256::from(1000)),
                nonce: Some(0),
                code: None,
                storage: BTreeMap::new(),
            },
        );

        post.insert(
            addr,
            AccountState {
                balance: Some(U256::from(500)),
                nonce: Some(1),
                code: None,
                storage: BTreeMap::new(),
            },
        );

        let state_diff = DiffMode { pre, post };

        let result = SimulationResult {
            success: true,
            gas_used: 21000,
            return_data: Bytes::new(),
            logs: vec![],
            revert_reason: None,
            state_diff,
            traces: None,
        };

        assert!(result.is_success());
        assert_eq!(result.state_diff.pre.len(), 1);
        assert_eq!(result.state_diff.post.len(), 1);

        let pre_account = result.state_diff.pre.get(&addr).unwrap();
        let post_account = result.state_diff.post.get(&addr).unwrap();

        assert_eq!(pre_account.balance, Some(U256::from(1000)));
        assert_eq!(post_account.balance, Some(U256::from(500)));
        assert_eq!(pre_account.nonce, Some(0));
        assert_eq!(post_account.nonce, Some(1));
    }

    #[test]
    fn test_state_diff_with_storage_change() {
        let mut pre = BTreeMap::new();
        let mut post = BTreeMap::new();

        let addr = Address::ZERO;
        let storage_key = B256::ZERO;

        // Storage values in AccountState are B256, not U256
        let pre_value = B256::from(U256::from(100));
        let post_value = B256::from(U256::from(200));

        let mut pre_storage = BTreeMap::new();
        pre_storage.insert(storage_key, pre_value);

        let mut post_storage = BTreeMap::new();
        post_storage.insert(storage_key, post_value);

        pre.insert(
            addr,
            AccountState {
                balance: Some(U256::ZERO),
                nonce: Some(0),
                code: None,
                storage: pre_storage,
            },
        );

        post.insert(
            addr,
            AccountState {
                balance: Some(U256::ZERO),
                nonce: Some(0),
                code: None,
                storage: post_storage,
            },
        );

        let state_diff = DiffMode { pre, post };

        let result = SimulationResult {
            success: true,
            gas_used: 50000,
            return_data: Bytes::new(),
            logs: vec![],
            revert_reason: None,
            state_diff,
            traces: None,
        };

        let pre_account = result.state_diff.pre.get(&addr).unwrap();
        let post_account = result.state_diff.post.get(&addr).unwrap();

        assert_eq!(pre_account.storage.get(&storage_key), Some(&pre_value));
        assert_eq!(post_account.storage.get(&storage_key), Some(&post_value));
    }

    #[test]
    fn test_decode_revert_error_string() {
        let data = encode_error_string("ERC20: insufficient balance");
        assert_eq!(
            TestSimulator::decode_revert_reason(&data),
            "ERC20: insufficient balance"
        );
    }

    #[test]
    fn test_decode_revert_panic_codes() {
        assert_eq!(
            TestSimulator::decode_revert_reason(&encode_panic(0x11)),
            "Panic: arithmetic overflow/underflow"
        );
        assert_eq!(
            TestSimulator::decode_revert_reason(&encode_panic(0x32)),
            "Panic: array out of bounds"
        );
        assert_eq!(
            TestSimulator::decode_revert_reason(&encode_panic(0x99)),
            "Panic: unknown code"
        );
    }

    #[test]
    fn test_decode_revert_short_and_custom_data() {
        assert_eq!(
            TestSimulator::decode_revert_reason(&revm::primitives::Bytes::new()),
            "Unknown revert"
        );
        assert_eq!(
            TestSimulator::decode_revert_reason(&revm::primitives::Bytes::from(vec![0x01, 0x02])),
            "Unknown revert"
        );
        let custom_error = revm::primitives::Bytes::from(vec![0xde, 0xad, 0xbe, 0xef]);
        assert_eq!(
            TestSimulator::decode_revert_reason(&custom_error),
            "Revert: 0xdeadbeef"
        );
    }
}