1use std::collections::BTreeMap;
4use std::sync::Arc;
5
6use alloy::network::AnyNetwork;
7use alloy::primitives::{Address, Bytes, Log, TxKind, B256, U256};
8use alloy::providers::Provider;
9use alloy::rpc::types::trace::geth::pre_state::{AccountState, DiffMode};
10use foundry_fork_db::{cache::BlockchainDbMeta, BlockchainDb, SharedBackend};
11use revm::context::TxEnv;
12use revm::database::CacheDB;
13use revm::primitives::hardfork::SpecId;
14use revm::state::{AccountInfo, EvmState};
15use revm::{Context, ExecuteEvm, MainBuilder, MainContext};
16
17use crate::error::{Error, Result};
18use crate::types::Operation;
19
20#[derive(Debug, Clone)]
22pub struct SimulationResult {
23 pub success: bool,
25 pub gas_used: u64,
27 pub return_data: Bytes,
29 pub logs: Vec<Log>,
31 pub revert_reason: Option<String>,
33 pub state_diff: DiffMode,
35}
36
37impl SimulationResult {
38 pub fn is_success(&self) -> bool {
40 self.success
41 }
42
43 pub fn error_message(&self) -> Option<&str> {
45 self.revert_reason.as_deref()
46 }
47}
48
49fn build_state_diff(state: &EvmState) -> DiffMode {
54 let mut pre = BTreeMap::new();
55 let mut post = BTreeMap::new();
56
57 for (address, account) in state.iter() {
58 if !account.is_touched() {
60 continue;
61 }
62
63 let mut pre_storage = BTreeMap::new();
65 let mut post_storage = BTreeMap::new();
66
67 for (key, slot) in account.storage.iter() {
68 if slot.is_changed() {
69 pre_storage.insert(B256::from(*key), B256::from(slot.original_value));
70 post_storage.insert(B256::from(*key), B256::from(slot.present_value));
71 }
72 }
73
74 let pre_state = AccountState {
76 balance: Some(account.original_info.balance),
77 nonce: Some(account.original_info.nonce),
78 code: account
79 .original_info
80 .code
81 .as_ref()
82 .map(|c| Bytes::from(c.original_bytes().to_vec())),
83 storage: pre_storage,
84 };
85
86 let post_state = AccountState {
88 balance: Some(account.info.balance),
89 nonce: Some(account.info.nonce),
90 code: account
91 .info
92 .code
93 .as_ref()
94 .map(|c| Bytes::from(c.original_bytes().to_vec())),
95 storage: post_storage,
96 };
97
98 pre.insert(*address, pre_state);
99 post.insert(*address, post_state);
100 }
101
102 DiffMode { pre, post }
103}
104
/// Runs revm simulations on top of a `SharedBackend` fork built from `provider`.
///
/// `block_number` selects the fork block; `None` means the provider's latest block at the
/// moment a fork database is created.
pub struct ForkSimulator<P> {
    /// RPC provider the fork backend fetches chain state from.
    provider: P,
    /// Chain id applied to both the EVM config and the simulated transaction.
    chain_id: u64,
    /// Fork block; resolved via the provider when `None`.
    block_number: Option<u64>,
}
111
112impl<P> ForkSimulator<P>
113where
114 P: Provider<AnyNetwork> + Clone + 'static,
115{
116 pub fn new(provider: P, chain_id: u64) -> Self {
118 Self {
119 provider,
120 chain_id,
121 block_number: None,
122 }
123 }
124
125 pub fn at_block(mut self, block: u64) -> Self {
127 self.block_number = Some(block);
128 self
129 }
130
131 pub async fn create_fork_db(&self) -> Result<CacheDB<SharedBackend>> {
133 let block = match self.block_number {
134 Some(b) => b,
135 None => self
136 .provider
137 .get_block_number()
138 .await
139 .map_err(|e| Error::ForkDb(e.to_string()))?,
140 };
141
142 let meta = BlockchainDbMeta::new(
143 Default::default(), format!("fork-{}", self.chain_id),
145 );
146
147 let db = BlockchainDb::new(meta, None);
148 let backend = SharedBackend::spawn_backend_thread(
149 Arc::new(self.provider.clone()),
150 db,
151 Some(block.into()),
152 );
153
154 Ok(CacheDB::new(backend))
155 }
156
157 pub async fn simulate_call(
159 &self,
160 from: Address,
161 to: Address,
162 value: U256,
163 data: Bytes,
164 operation: Operation,
165 ) -> Result<SimulationResult> {
166 let mut db = self.create_fork_db().await?;
167
168 let caller_info = AccountInfo::default();
170 db.insert_account_info(from, caller_info);
171
172 if let Some(account) = db.cache.accounts.get_mut(&from) {
174 account.info.balance = U256::from(1_000_000_000_000_000_000_000u128); }
176
177 let (call_to, call_data) = match operation {
179 Operation::Call => (to, data.to_vec()),
180 Operation::DelegateCall => {
181 (to, data.to_vec())
184 }
185 };
186
187 let tx = TxEnv {
188 caller: from,
189 gas_limit: 30_000_000,
190 gas_price: 0,
191 kind: TxKind::Call(call_to),
192 value,
193 data: call_data.into(),
194 nonce: 0,
195 chain_id: Some(self.chain_id),
196 ..Default::default()
197 };
198
199 let ctx = Context::mainnet()
201 .with_db(db)
202 .modify_cfg_chained(|cfg| {
203 cfg.spec = SpecId::CANCUN;
204 cfg.chain_id = self.chain_id;
205 })
206 .modify_block_chained(|block| {
207 block.basefee = 0;
208 })
209 .with_tx(tx.clone());
210
211 let mut evm = ctx.build_mainnet();
213 let result = evm.transact(tx).map_err(|e| Error::Revm(format!("{:?}", e)))?;
214
215 Ok(self.process_result(result))
216 }
217
218 pub async fn estimate_safe_tx_gas(
222 &self,
223 from: Address,
224 to: Address,
225 value: U256,
226 data: Bytes,
227 operation: Operation,
228 ) -> Result<U256> {
229 let result = self.simulate_call(from, to, value, data, operation).await?;
230
231 if !result.success {
232 return Err(Error::GasEstimation(format!(
233 "Simulation failed: {}",
234 result.revert_reason.unwrap_or_else(|| "unknown".to_string())
235 )));
236 }
237
238 let gas_with_buffer = result.gas_used + (result.gas_used / 10);
240 Ok(U256::from(gas_with_buffer))
241 }
242
243 fn process_result<H>(
244 &self,
245 result: revm::context::result::ExecResultAndState<revm::context::result::ExecutionResult<H>>,
246 ) -> SimulationResult
247 where
248 H: std::fmt::Debug,
249 {
250 use revm::context::result::{ExecutionResult, Output};
251
252 let state_diff = build_state_diff(&result.state);
254
255 match result.result {
256 ExecutionResult::Success {
257 gas_used,
258 output,
259 logs,
260 ..
261 } => {
262 let return_data = match output {
263 Output::Call(data) => Bytes::from(data.to_vec()),
264 Output::Create(_, _) => Bytes::new(),
265 };
266
267 let logs = logs
268 .into_iter()
269 .filter_map(|log| {
270 Log::new(log.address, log.topics().to_vec(), log.data.data.clone())
271 })
272 .collect();
273
274 SimulationResult {
275 success: true,
276 gas_used,
277 return_data,
278 logs,
279 revert_reason: None,
280 state_diff,
281 }
282 }
283 ExecutionResult::Revert { gas_used, output } => {
284 let revert_reason = Self::decode_revert_reason(&output);
285 SimulationResult {
286 success: false,
287 gas_used,
288 return_data: Bytes::from(output.to_vec()),
289 logs: vec![],
290 revert_reason: Some(revert_reason),
291 state_diff,
292 }
293 }
294 ExecutionResult::Halt { gas_used, reason } => SimulationResult {
295 success: false,
296 gas_used,
297 return_data: Bytes::new(),
298 logs: vec![],
299 revert_reason: Some(format!("Halted: {:?}", reason)),
300 state_diff,
301 },
302 }
303 }
304
305 fn decode_revert_reason(output: &revm::primitives::Bytes) -> String {
306 if output.len() < 4 {
307 return "Unknown revert".to_string();
308 }
309
310 if output[0..4] == [0x08, 0xc3, 0x79, 0xa0] && output.len() >= 68 {
312 let offset = 4 + 32;
314 if output.len() > offset + 32 {
315 let len = u32::from_be_bytes([
316 output[offset + 28],
317 output[offset + 29],
318 output[offset + 30],
319 output[offset + 31],
320 ]) as usize;
321
322 let str_start = offset + 32;
323 if output.len() >= str_start + len {
324 if let Ok(s) = String::from_utf8(output[str_start..str_start + len].to_vec()) {
325 return s;
326 }
327 }
328 }
329 }
330
331 if output[0..4] == [0x4e, 0x48, 0x7b, 0x71] && output.len() >= 36 {
333 let panic_code =
334 u32::from_be_bytes([output[32], output[33], output[34], output[35]]) as usize;
335 return match panic_code {
336 0x00 => "Panic: generic/compiler panic",
337 0x01 => "Panic: assertion failed",
338 0x11 => "Panic: arithmetic overflow/underflow",
339 0x12 => "Panic: division by zero",
340 0x21 => "Panic: invalid enum value",
341 0x22 => "Panic: access to incorrectly encoded storage",
342 0x31 => "Panic: pop on empty array",
343 0x32 => "Panic: array out of bounds",
344 0x41 => "Panic: memory overflow",
345 0x51 => "Panic: call to zero-initialized function",
346 _ => "Panic: unknown code",
347 }
348 .to_string();
349 }
350
351 format!("Revert: 0x{}", alloy::primitives::hex::encode(output))
352 }
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SimulationResult` with empty return data and no logs.
    fn make_result(
        success: bool,
        gas_used: u64,
        revert_reason: Option<String>,
        state_diff: DiffMode,
    ) -> SimulationResult {
        SimulationResult {
            success,
            gas_used,
            return_data: Bytes::new(),
            logs: Vec::new(),
            revert_reason,
            state_diff,
        }
    }

    /// Builds a code-less `AccountState` snapshot.
    fn account(balance: U256, nonce: u64, storage: BTreeMap<B256, B256>) -> AccountState {
        AccountState {
            balance: Some(balance),
            nonce: Some(nonce),
            code: None,
            storage,
        }
    }

    /// Wraps one account's before/after snapshots into a `DiffMode`.
    fn single_account_diff(addr: Address, before: AccountState, after: AccountState) -> DiffMode {
        DiffMode {
            pre: BTreeMap::from([(addr, before)]),
            post: BTreeMap::from([(addr, after)]),
        }
    }

    #[test]
    fn test_simulation_result() {
        let result = make_result(true, 21000, None, DiffMode::default());

        assert!(result.is_success());
        assert!(result.error_message().is_none());
    }

    #[test]
    fn test_simulation_result_revert() {
        let reason = "ERC20: insufficient balance";
        let result = make_result(false, 21000, Some(reason.to_string()), DiffMode::default());

        assert!(!result.is_success());
        assert_eq!(result.error_message(), Some("ERC20: insufficient balance"));
    }

    #[test]
    fn test_state_diff_with_balance_change() {
        let addr = Address::ZERO;
        let diff = single_account_diff(
            addr,
            account(U256::from(1000), 0, BTreeMap::new()),
            account(U256::from(500), 1, BTreeMap::new()),
        );
        let result = make_result(true, 21000, None, diff);

        assert!(result.is_success());
        assert_eq!(result.state_diff.pre.len(), 1);
        assert_eq!(result.state_diff.post.len(), 1);

        let before = result.state_diff.pre.get(&addr).unwrap();
        let after = result.state_diff.post.get(&addr).unwrap();

        assert_eq!(before.balance, Some(U256::from(1000)));
        assert_eq!(after.balance, Some(U256::from(500)));
        assert_eq!(before.nonce, Some(0));
        assert_eq!(after.nonce, Some(1));
    }

    #[test]
    fn test_state_diff_with_storage_change() {
        let addr = Address::ZERO;
        let slot = B256::ZERO;
        let old_value = B256::from(U256::from(100));
        let new_value = B256::from(U256::from(200));

        let diff = single_account_diff(
            addr,
            account(U256::ZERO, 0, BTreeMap::from([(slot, old_value)])),
            account(U256::ZERO, 0, BTreeMap::from([(slot, new_value)])),
        );
        let result = make_result(true, 50000, None, diff);

        let before = result.state_diff.pre.get(&addr).unwrap();
        let after = result.state_diff.post.get(&addr).unwrap();

        assert_eq!(before.storage.get(&slot), Some(&old_value));
        assert_eq!(after.storage.get(&slot), Some(&new_value));
    }
}