1use std::collections::BTreeMap;
4use std::sync::Arc;
5
6use alloy::network::AnyNetwork;
7use alloy::primitives::{Address, Bytes, Log, TxKind, B256, U256};
8use alloy::providers::Provider;
9use alloy::rpc::types::trace::geth::pre_state::{AccountState, DiffMode};
10use foundry_fork_db::{cache::BlockchainDbMeta, BlockchainDb, SharedBackend};
11use revm::context::TxEnv;
12use revm::database::CacheDB;
13use revm::primitives::hardfork::SpecId;
14use revm::state::{AccountInfo, EvmState};
15use revm::{Context, ExecuteEvm, MainBuilder, MainContext};
16use revm_inspectors::tracing::{TracingInspector, TracingInspectorConfig};
17
18pub use revm_inspectors::tracing::CallTraceArena;
19
20use crate::error::{Error, Result};
21use crate::types::Operation;
22
23#[derive(Debug, Clone)]
25pub struct SimulationResult {
26 pub success: bool,
28 pub gas_used: u64,
30 pub return_data: Bytes,
32 pub logs: Vec<Log>,
34 pub revert_reason: Option<String>,
36 pub state_diff: DiffMode,
38 pub traces: Option<CallTraceArena>,
40}
41
42impl SimulationResult {
43 pub fn is_success(&self) -> bool {
45 self.success
46 }
47
48 pub fn error_message(&self) -> Option<&str> {
50 self.revert_reason.as_deref()
51 }
52
53 pub fn format_traces(&self) -> Option<String> {
57 use revm_inspectors::tracing::TraceWriter;
58
59 let traces = self.traces.as_ref()?;
60 let mut writer = TraceWriter::new(Vec::<u8>::new());
61 writer.write_arena(traces).ok()?;
62 String::from_utf8(writer.into_writer()).ok()
63 }
64}
65
66fn build_state_diff(state: &EvmState) -> DiffMode {
71 let mut pre = BTreeMap::new();
72 let mut post = BTreeMap::new();
73
74 for (address, account) in state.iter() {
75 if !account.is_touched() {
77 continue;
78 }
79
80 let mut pre_storage = BTreeMap::new();
82 let mut post_storage = BTreeMap::new();
83
84 for (key, slot) in account.storage.iter() {
85 if slot.is_changed() {
86 pre_storage.insert(B256::from(*key), B256::from(slot.original_value));
87 post_storage.insert(B256::from(*key), B256::from(slot.present_value));
88 }
89 }
90
91 let pre_state = AccountState {
93 balance: Some(account.original_info.balance),
94 nonce: Some(account.original_info.nonce),
95 code: account
96 .original_info
97 .code
98 .as_ref()
99 .map(|c| Bytes::from(c.original_bytes().to_vec())),
100 storage: pre_storage,
101 };
102
103 let post_state = AccountState {
105 balance: Some(account.info.balance),
106 nonce: Some(account.info.nonce),
107 code: account
108 .info
109 .code
110 .as_ref()
111 .map(|c| Bytes::from(c.original_bytes().to_vec())),
112 storage: post_storage,
113 };
114
115 pre.insert(*address, pre_state);
116 post.insert(*address, post_state);
117 }
118
119 DiffMode { pre, post }
120}
121
122pub struct ForkSimulator<P> {
124 provider: P,
125 chain_id: u64,
126 block_number: Option<u64>,
127 tracing: bool,
128 caller_balance: Option<U256>,
129}
130
131impl<P> ForkSimulator<P>
132where
133 P: Provider<AnyNetwork> + Clone + 'static,
134{
135 pub fn new(provider: P, chain_id: u64) -> Self {
137 Self {
138 provider,
139 chain_id,
140 block_number: None,
141 tracing: false,
142 caller_balance: None,
143 }
144 }
145
146 pub fn at_block(mut self, block: u64) -> Self {
148 self.block_number = Some(block);
149 self
150 }
151
152 pub fn with_tracing(mut self, enable: bool) -> Self {
158 self.tracing = enable;
159 self
160 }
161
162 pub fn with_caller_balance(mut self, balance: U256) -> Self {
166 self.caller_balance = Some(balance);
167 self
168 }
169
170 pub async fn create_fork_db(&self) -> Result<CacheDB<SharedBackend>> {
172 let block = match self.block_number {
173 Some(b) => b,
174 None => self
175 .provider
176 .get_block_number()
177 .await
178 .map_err(|e| Error::ForkDb(e.to_string()))?,
179 };
180
181 let meta = BlockchainDbMeta::new(
182 Default::default(), format!("fork-{}", self.chain_id),
184 );
185
186 let db = BlockchainDb::new(meta, None);
187 let backend = SharedBackend::spawn_backend_thread(
188 Arc::new(self.provider.clone()),
189 db,
190 Some(block.into()),
191 );
192
193 Ok(CacheDB::new(backend))
194 }
195
196 pub async fn simulate_call(
198 &self,
199 from: Address,
200 to: Address,
201 value: U256,
202 data: Bytes,
203 operation: Operation,
204 ) -> Result<SimulationResult> {
205 let mut db = self.create_fork_db().await?;
206
207 if let Some(balance) = self.caller_balance {
209 let caller_info = AccountInfo::default();
210 db.insert_account_info(from, caller_info);
211
212 if let Some(account) = db.cache.accounts.get_mut(&from) {
213 account.info.balance = balance;
214 }
215 }
216
217 let (call_to, call_data) = match operation {
219 Operation::Call => (to, data.to_vec()),
220 Operation::DelegateCall => {
221 (to, data.to_vec())
224 }
225 };
226
227 let tx = TxEnv {
228 caller: from,
229 gas_limit: 30_000_000,
230 gas_price: 0,
231 kind: TxKind::Call(call_to),
232 value,
233 data: call_data.into(),
234 nonce: 0,
235 chain_id: Some(self.chain_id),
236 ..Default::default()
237 };
238
239 let ctx = Context::mainnet()
241 .with_db(db)
242 .modify_cfg_chained(|cfg| {
243 cfg.spec = SpecId::CANCUN;
244 cfg.chain_id = self.chain_id;
245 })
246 .modify_block_chained(|block| {
247 block.basefee = 0;
248 })
249 .with_tx(tx.clone());
250
251 if self.tracing {
252 let config = TracingInspectorConfig::default_parity();
254 let mut inspector = TracingInspector::new(config);
255
256 let mut evm = ctx.build_mainnet_with_inspector(&mut inspector);
258 let result = evm.transact(tx).map_err(|e| Error::Revm(format!("{:?}", e)))?;
259
260 let traces = Some(inspector.into_traces());
262
263 let mut sim_result = self.process_result(result);
264 sim_result.traces = traces;
265 Ok(sim_result)
266 } else {
267 let mut evm = ctx.build_mainnet();
269 let result = evm.transact(tx).map_err(|e| Error::Revm(format!("{:?}", e)))?;
270
271 Ok(self.process_result(result))
272 }
273 }
274
275 pub async fn estimate_safe_tx_gas(
279 &self,
280 from: Address,
281 to: Address,
282 value: U256,
283 data: Bytes,
284 operation: Operation,
285 ) -> Result<U256> {
286 let result = self.simulate_call(from, to, value, data, operation).await?;
287
288 if !result.success {
289 return Err(Error::GasEstimation(format!(
290 "Simulation failed: {}",
291 result.revert_reason.unwrap_or_else(|| "unknown".to_string())
292 )));
293 }
294
295 let gas_with_buffer = result.gas_used + (result.gas_used / 10);
297 Ok(U256::from(gas_with_buffer))
298 }
299
300 fn process_result<H>(
301 &self,
302 result: revm::context::result::ExecResultAndState<revm::context::result::ExecutionResult<H>>,
303 ) -> SimulationResult
304 where
305 H: std::fmt::Debug,
306 {
307 use revm::context::result::{ExecutionResult, Output};
308
309 let state_diff = build_state_diff(&result.state);
311
312 match result.result {
313 ExecutionResult::Success {
314 gas_used,
315 output,
316 logs,
317 ..
318 } => {
319 let return_data = match output {
320 Output::Call(data) => Bytes::from(data.to_vec()),
321 Output::Create(_, _) => Bytes::new(),
322 };
323
324 let logs = logs
325 .into_iter()
326 .filter_map(|log| {
327 Log::new(log.address, log.topics().to_vec(), log.data.data.clone())
328 })
329 .collect();
330
331 SimulationResult {
332 success: true,
333 gas_used,
334 return_data,
335 logs,
336 revert_reason: None,
337 state_diff,
338 traces: None,
339 }
340 }
341 ExecutionResult::Revert { gas_used, output } => {
342 let revert_reason = Self::decode_revert_reason(&output);
343 SimulationResult {
344 success: false,
345 gas_used,
346 return_data: Bytes::from(output.to_vec()),
347 logs: vec![],
348 revert_reason: Some(revert_reason),
349 state_diff,
350 traces: None,
351 }
352 }
353 ExecutionResult::Halt { gas_used, reason } => SimulationResult {
354 success: false,
355 gas_used,
356 return_data: Bytes::new(),
357 logs: vec![],
358 revert_reason: Some(format!("Halted: {:?}", reason)),
359 state_diff,
360 traces: None,
361 },
362 }
363 }
364
365 fn decode_revert_reason(output: &revm::primitives::Bytes) -> String {
366 if output.len() < 4 {
367 return "Unknown revert".to_string();
368 }
369
370 if output[0..4] == [0x08, 0xc3, 0x79, 0xa0] && output.len() >= 68 {
372 let offset = 4 + 32;
374 if output.len() > offset + 32 {
375 let len = u32::from_be_bytes([
376 output[offset + 28],
377 output[offset + 29],
378 output[offset + 30],
379 output[offset + 31],
380 ]) as usize;
381
382 let str_start = offset + 32;
383 if output.len() >= str_start + len {
384 if let Ok(s) = String::from_utf8(output[str_start..str_start + len].to_vec()) {
385 return s;
386 }
387 }
388 }
389 }
390
391 if output[0..4] == [0x4e, 0x48, 0x7b, 0x71] && output.len() >= 36 {
393 let panic_code =
394 u32::from_be_bytes([output[32], output[33], output[34], output[35]]) as usize;
395 return match panic_code {
396 0x00 => "Panic: generic/compiler panic",
397 0x01 => "Panic: assertion failed",
398 0x11 => "Panic: arithmetic overflow/underflow",
399 0x12 => "Panic: division by zero",
400 0x21 => "Panic: invalid enum value",
401 0x22 => "Panic: access to incorrectly encoded storage",
402 0x31 => "Panic: pop on empty array",
403 0x32 => "Panic: array out of bounds",
404 0x41 => "Panic: memory overflow",
405 0x51 => "Panic: call to zero-initialized function",
406 _ => "Panic: unknown code",
407 }
408 .to_string();
409 }
410
411 format!("Revert: 0x{}", alloy::primitives::hex::encode(output))
412 }
413}
414
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SimulationResult` with empty return data, no logs and no traces.
    fn result_with(
        success: bool,
        gas_used: u64,
        revert_reason: Option<&str>,
        state_diff: DiffMode,
    ) -> SimulationResult {
        SimulationResult {
            success,
            gas_used,
            return_data: Bytes::new(),
            logs: Vec::new(),
            revert_reason: revert_reason.map(str::to_owned),
            state_diff,
            traces: None,
        }
    }

    /// Builds a code-less `AccountState` with the given balance, nonce and storage.
    fn account_state(balance: U256, nonce: u64, storage: BTreeMap<B256, B256>) -> AccountState {
        AccountState {
            balance: Some(balance),
            nonce: Some(nonce),
            code: None,
            storage,
        }
    }

    /// Wraps one account's pre/post states into a `DiffMode`.
    fn single_account_diff(addr: Address, before: AccountState, after: AccountState) -> DiffMode {
        DiffMode {
            pre: BTreeMap::from([(addr, before)]),
            post: BTreeMap::from([(addr, after)]),
        }
    }

    #[test]
    fn test_simulation_result() {
        let outcome = result_with(true, 21000, None, DiffMode::default());

        assert!(outcome.is_success());
        assert!(outcome.error_message().is_none());
        assert!(outcome.format_traces().is_none());
    }

    #[test]
    fn test_simulation_result_revert() {
        let outcome = result_with(
            false,
            21000,
            Some("ERC20: insufficient balance"),
            DiffMode::default(),
        );

        assert!(!outcome.is_success());
        assert_eq!(outcome.error_message(), Some("ERC20: insufficient balance"));
    }

    #[test]
    fn test_state_diff_with_balance_change() {
        let addr = Address::ZERO;
        let diff = single_account_diff(
            addr,
            account_state(U256::from(1000), 0, BTreeMap::new()),
            account_state(U256::from(500), 1, BTreeMap::new()),
        );
        let outcome = result_with(true, 21000, None, diff);

        assert!(outcome.is_success());
        assert_eq!(outcome.state_diff.pre.len(), 1);
        assert_eq!(outcome.state_diff.post.len(), 1);

        let before = &outcome.state_diff.pre[&addr];
        let after = &outcome.state_diff.post[&addr];

        assert_eq!(before.balance, Some(U256::from(1000)));
        assert_eq!(after.balance, Some(U256::from(500)));
        assert_eq!(before.nonce, Some(0));
        assert_eq!(after.nonce, Some(1));
    }

    #[test]
    fn test_state_diff_with_storage_change() {
        let addr = Address::ZERO;
        let slot = B256::ZERO;
        let old_value = B256::from(U256::from(100));
        let new_value = B256::from(U256::from(200));

        let diff = single_account_diff(
            addr,
            account_state(U256::ZERO, 0, BTreeMap::from([(slot, old_value)])),
            account_state(U256::ZERO, 0, BTreeMap::from([(slot, new_value)])),
        );
        let outcome = result_with(true, 50000, None, diff);

        assert_eq!(
            outcome.state_diff.pre[&addr].storage.get(&slot),
            Some(&old_value)
        );
        assert_eq!(
            outcome.state_diff.post[&addr].storage.get(&slot),
            Some(&new_value)
        );
    }
}