1use std::{
2 collections::HashMap,
3 fmt::Debug,
4 sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard},
5};
6
7use alloy::{
8 primitives::{Address, Bytes as AlloyBytes, StorageValue, B256, U256},
9 providers::{
10 fillers::{BlobGasFiller, ChainIdFiller, FillProvider, GasFiller, JoinFill, NonceFiller},
11 Provider, RootProvider,
12 },
13 transports::{RpcError, TransportErrorKind},
14};
15use revm::{
16 context::DBErrorMarker,
17 state::{AccountInfo, Bytecode},
18 DatabaseRef,
19};
20use thiserror::Error;
21use tracing::{debug, info};
22use tycho_client::feed::BlockHeader;
23
24use super::{
25 super::account_storage::{AccountStorage, StateUpdate},
26 engine_db_interface::EngineDatabaseInterface,
27};
28
29pub struct OverriddenSimulationDB<'a, DB: DatabaseRef> {
31 pub inner_db: &'a DB,
33 pub overrides: &'a HashMap<Address, HashMap<U256, U256>>,
36}
37
38impl<'a, DB: DatabaseRef> OverriddenSimulationDB<'a, DB> {
39 pub fn new(inner_db: &'a DB, overrides: &'a HashMap<Address, HashMap<U256, U256>>) -> Self {
50 OverriddenSimulationDB { inner_db, overrides }
51 }
52}
53
54impl<DB: DatabaseRef> DatabaseRef for OverriddenSimulationDB<'_, DB> {
55 type Error = DB::Error;
56
57 fn basic_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
58 self.inner_db.basic_ref(address)
59 }
60
61 fn code_by_hash_ref(&self, code_hash: B256) -> Result<Bytecode, Self::Error> {
62 self.inner_db
63 .code_by_hash_ref(code_hash)
64 }
65
66 fn storage_ref(&self, address: Address, index: U256) -> Result<U256, Self::Error> {
67 match self.overrides.get(&address) {
68 None => self
69 .inner_db
70 .storage_ref(address, index),
71 Some(slot_overrides) => match slot_overrides.get(&index) {
72 Some(value) => {
73 debug!(%address, %index, %value, "Requested storage of account {:x?} slot {}", address, index);
74 Ok(*value)
75 }
76 None => self
77 .inner_db
78 .storage_ref(address, index),
79 },
80 }
81 }
82
83 fn block_hash_ref(&self, number: u64) -> Result<B256, Self::Error> {
84 self.inner_db.block_hash_ref(number)
85 }
86}
87
88#[derive(Clone, Debug)]
90pub struct SimulationDB<P: Provider + Debug> {
91 client: Arc<P>,
93 account_storage: Arc<RwLock<AccountStorage>>,
95 block: Option<BlockHeader>,
97 pub runtime: Option<Arc<tokio::runtime::Runtime>>,
99}
100
101pub type EVMProvider = FillProvider<
102 JoinFill<
103 alloy::providers::Identity,
104 JoinFill<GasFiller, JoinFill<BlobGasFiller, JoinFill<NonceFiller, ChainIdFiller>>>,
105 >,
106 RootProvider,
107>;
108
109impl<P: Provider + Debug + 'static> SimulationDB<P> {
110 pub fn new(
111 client: Arc<P>,
112 runtime: Option<Arc<tokio::runtime::Runtime>>,
113 block: Option<BlockHeader>,
114 ) -> Self {
115 Self {
116 client,
117 account_storage: Arc::new(RwLock::new(AccountStorage::new())),
118 block,
119 runtime,
120 }
121 }
122
123 pub fn set_block(&mut self, block: Option<BlockHeader>) {
125 self.block = block;
126 }
127
128 pub fn update_state(
140 &mut self,
141 updates: &HashMap<Address, StateUpdate>,
142 block: BlockHeader,
143 ) -> Result<HashMap<Address, StateUpdate>, SimulationDBError> {
144 info!("Received account state update.");
145 let mut revert_updates = HashMap::new();
146 self.block = Some(block);
147 for (address, update_info) in updates.iter() {
148 let mut revert_entry = StateUpdate::default();
149 if let Some(current_account) = self
150 .read_account_storage()?
151 .get_account_info(address)
152 {
153 revert_entry.balance = Some(current_account.balance);
154 }
155 if let Some(storage_updates) = update_info.storage.as_ref() {
156 let mut revert_storage = HashMap::default();
157 for index in storage_updates.keys() {
158 if let Some(s) = self
159 .read_account_storage()?
160 .get_permanent_storage(address, index)
161 {
162 revert_storage.insert(*index, s);
163 }
164 }
165 revert_entry.storage = Some(revert_storage);
166 }
167 revert_updates.insert(*address, revert_entry);
168
169 self.write_account_storage()?
170 .update_account(address, update_info);
171 }
172 Ok(revert_updates)
173 }
174
175 fn query_account_info(
187 &self,
188 address: Address,
189 ) -> Result<AccountInfo, <SimulationDB<P> as DatabaseRef>::Error> {
190 debug!("Querying account info of {:x?} at block {:?}", address, self.block);
191
192 let (balance, nonce, code) = self.block_on(async {
193 let mut balance_request = self.client.get_balance(address);
194 let mut nonce_request = self
195 .client
196 .get_transaction_count(address);
197 let mut code_request = self.client.get_code_at(address);
198
199 if let Some(block) = &self.block {
200 balance_request = balance_request.number(block.number);
201 nonce_request = nonce_request.number(block.number);
202 code_request = code_request.number(block.number);
203 }
204
205 tokio::join!(balance_request, nonce_request, code_request,)
206 });
207 let code = Bytecode::new_raw(AlloyBytes::copy_from_slice(&code?));
208
209 Ok(AccountInfo::new(balance?, nonce?, code.hash_slow(), code))
210 }
211
212 pub fn query_storage(
224 &self,
225 address: Address,
226 index: U256,
227 ) -> Result<StorageValue, <SimulationDB<P> as DatabaseRef>::Error> {
228 let mut request = self
229 .client
230 .get_storage_at(address, index);
231 if let Some(block) = &self.block {
232 request = request.number(block.number);
233 }
234
235 let storage_future = async move {
236 request.await.map_err(|err| {
237 SimulationDBError::SimulationError(format!(
238 "Failed to fetch storage for {address:?} slot {index}: {err}"
239 ))
240 })
241 };
242
243 self.block_on(storage_future)
244 }
245
246 fn read_account_storage(
247 &self,
248 ) -> Result<RwLockReadGuard<'_, AccountStorage>, SimulationDBError> {
249 self.account_storage
250 .read()
251 .map_err(|_| SimulationDBError::Internal("Account storage read lock poisoned".into()))
252 }
253
254 fn write_account_storage(
255 &self,
256 ) -> Result<RwLockWriteGuard<'_, AccountStorage>, SimulationDBError> {
257 self.account_storage
258 .write()
259 .map_err(|_| SimulationDBError::Internal("Account storage write lock poisoned".into()))
260 }
261
262 fn block_on<F: core::future::Future>(&self, f: F) -> F::Output {
263 match &self.runtime {
267 Some(runtime) => runtime.block_on(f),
268 None => futures::executor::block_on(f),
269 }
270 }
271}
272
273impl<P: Provider + Debug> EngineDatabaseInterface for SimulationDB<P>
274where
275 P: Provider + Send + Sync + 'static,
276{
277 type Error = SimulationDBError;
278
279 fn init_account(
293 &self,
294 address: Address,
295 mut account: AccountInfo,
296 permanent_storage: Option<HashMap<U256, U256>>,
297 mocked: bool,
298 ) -> Result<(), <Self as EngineDatabaseInterface>::Error> {
299 if let Some(code) = account.code.clone() {
300 account.code = Some(code);
301 }
302
303 self.write_account_storage()?
304 .init_account(address, account, permanent_storage, mocked);
305
306 Ok(())
307 }
308
309 fn clear_temp_storage(&mut self) -> Result<(), <Self as EngineDatabaseInterface>::Error> {
314 self.write_account_storage()?
315 .clear_temp_storage();
316
317 Ok(())
318 }
319
320 fn get_current_block(&self) -> Option<BlockHeader> {
321 self.block.clone()
322 }
323}
324
325#[derive(Error, Debug)]
326pub enum SimulationDBError {
327 #[error("Simulation error: {0} ")]
328 SimulationError(String),
329 #[error("Not implemented error: {0}")]
330 NotImplementedError(String),
331 #[error("Simulation DB internal error: {0}")]
332 Internal(String),
333}
334
335impl DBErrorMarker for SimulationDBError {}
336
337impl From<RpcError<TransportErrorKind>> for SimulationDBError {
338 fn from(err: RpcError<TransportErrorKind>) -> Self {
339 SimulationDBError::SimulationError(err.to_string())
340 }
341}
342
343impl<P: Provider> DatabaseRef for SimulationDB<P>
344where
345 P: Provider + Debug + Send + Sync + 'static,
346{
347 type Error = SimulationDBError;
348
349 fn basic_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
380 if let Some(account) = {
381 self.read_account_storage()?
382 .get_account_info(&address)
383 .cloned()
384 } {
385 return Ok(Some(account));
386 }
387 let account_info = self.query_account_info(address)?;
388 self.init_account(address, account_info.clone(), None, false)?;
389 Ok(Some(account_info))
390 }
391
392 fn code_by_hash_ref(&self, _code_hash: B256) -> Result<Bytecode, Self::Error> {
393 Err(SimulationDBError::NotImplementedError(
394 "Code by hash is not implemented in SimulationDB".to_string(),
395 ))
396 }
397
398 fn storage_ref(&self, address: Address, index: U256) -> Result<U256, Self::Error> {
433 debug!("Requested storage of account {:x?} slot {}", address, index);
434 let (is_mocked, local_value) = {
435 let account_storage = self.read_account_storage()?;
436 (
437 account_storage.is_mocked_account(&address),
438 account_storage.get_storage(&address, &index),
439 )
440 };
441
442 if let Some(storage_value) = local_value {
443 debug!(
444 "Got value locally. This is a {} account. Value: {}",
445 if is_mocked.unwrap_or(false) { "mocked" } else { "non-mocked" },
446 storage_value
447 );
448 return Ok(storage_value);
449 }
450
451 match is_mocked {
453 Some(true) => {
454 debug!("This is a mocked account for which we don't have data. Returning zero.");
455 Ok(U256::ZERO)
456 }
457 Some(false) => {
458 let storage_value = self.query_storage(address, index)?;
459 self.write_account_storage()?
460 .set_temp_storage(address, index, storage_value);
461 debug!(
462 "This is a non-mocked account for which we didn't have data. Fetched value: {}",
463 storage_value
464 );
465 Ok(storage_value)
466 }
467 None => {
468 let account_info = self.query_account_info(address)?;
469 let storage_value = self.query_storage(address, index)?;
470 self.init_account(address, account_info, None, false)?;
471 self.write_account_storage()?
472 .set_temp_storage(address, index, storage_value);
473 debug!("This is non-initialised account. Fetched value: {}", storage_value);
474 Ok(storage_value)
475 }
476 }
477 }
478
479 fn block_hash_ref(&self, _number: u64) -> Result<B256, Self::Error> {
482 match &self.block {
483 Some(header) => Ok(B256::from_slice(&header.hash)),
484 None => Ok(B256::ZERO),
485 }
486 }
487}
488
#[cfg(test)]
mod tests {
    // These tests build a real client/runtime via `get_client`/`get_runtime`, and several
    // of them hit live chain state over RPC (they presumably need a configured RPC
    // endpoint — TODO confirm against `engine_db::utils`).
    use std::{error::Error, str::FromStr};

    use alloy::primitives::U160;
    use rstest::rstest;
    use tycho_common::Bytes;

    use super::*;
    use crate::evm::engine_db::utils::{get_client, get_runtime};

    /// Fetching a storage slot without a pinned block should succeed.
    #[rstest]
    fn test_query_storage_latest_block() -> Result<(), Box<dyn Error>> {
        let db = SimulationDB::new(
            get_client(None).expect("Failed to create test client"),
            get_runtime().expect("Failed to create test runtime"),
            None,
        );
        let address = Address::from_str("0xb4e16d0168e52d35cacd2c6185b44281ec28c9dc")?;
        let index = U256::from_limbs_slice(&[8]);
        db.init_account(address, AccountInfo::default(), None, false)
            .expect("Failed to init account");

        db.query_storage(address, index)?;

        Ok(())
    }

    /// Account info fetched at a pinned historical block matches known on-chain values.
    #[rstest]
    fn test_query_account_info() {
        let mut db = SimulationDB::new(
            get_client(None).expect("Failed to create test client"),
            get_runtime().expect("Failed to create test runtime"),
            None,
        );
        let block = BlockHeader {
            number: 20308186,
            hash: Bytes::from_str(
                "0x61c51e3640b02ae58a03201be0271e84e02dac8a4826501995cbe4da24174b52",
            )
            .unwrap(),
            timestamp: 234,
            ..Default::default()
        };
        db.set_block(Some(block));
        let address = Address::from_str("0x168b93113fe5902c87afaecE348581A1481d0f93").unwrap();
        db.init_account(address, AccountInfo::default(), None, false)
            .expect("Failed to init account");

        let account_info = db.query_account_info(address).unwrap();

        assert_eq!(account_info.balance, U256::from_str("6246978663692389").unwrap());
        assert_eq!(account_info.nonce, 17);
    }

    /// `basic_ref` on a mocked account returns the locally stored info.
    #[rstest]
    fn test_mock_account_get_acc_info() {
        let db = SimulationDB::new(
            get_client(None).expect("Failed to create test client"),
            get_runtime().expect("Failed to create test runtime"),
            None,
        );
        let mock_acc_address =
            Address::from_str("0xb4e16d0168e52d35cacd2c6185b44281ec28c9dc").unwrap();
        db.init_account(mock_acc_address, AccountInfo::default(), None, true)
            .expect("Failed to init account");

        let acc_info = db
            .basic_ref(mock_acc_address)
            .unwrap()
            .unwrap();

        assert_eq!(
            db.account_storage
                .read()
                .unwrap()
                .get_account_info(&mock_acc_address)
                .unwrap(),
            &acc_info
        );
    }

    /// Unknown slots of a mocked account read as zero without querying the node.
    #[rstest]
    fn test_mock_account_get_storage() {
        let db = SimulationDB::new(
            get_client(None).expect("Failed to create test client"),
            get_runtime().expect("Failed to create test runtime"),
            None,
        );
        let mock_acc_address =
            Address::from_str("0xb4e16d0168e52d35cacd2c6185b44281ec28c9dc").unwrap();
        let storage_address = U256::ZERO;
        db.init_account(mock_acc_address, AccountInfo::default(), None, true)
            .expect("Failed to init account");

        let storage = db
            .storage_ref(mock_acc_address, storage_address)
            .unwrap();

        assert_eq!(storage, U256::ZERO);
    }

    /// `update_state` applies storage/balance changes, moves the block, and returns
    /// the pre-update values as a revert entry.
    #[rstest]
    fn test_update_state() {
        let mut db = SimulationDB::new(
            get_client(None).expect("Failed to create test client"),
            get_runtime().expect("Failed to create test runtime"),
            None,
        );
        let address = Address::from_str("0xb4e16d0168e52d35cacd2c6185b44281ec28c9dc").unwrap();
        db.init_account(address, AccountInfo::default(), None, false)
            .expect("Failed to init account");

        let mut new_storage = HashMap::default();
        let new_storage_value_index = U256::from_limbs_slice(&[123]);
        new_storage.insert(new_storage_value_index, new_storage_value_index);
        let new_balance = U256::from_limbs_slice(&[500]);
        let update = StateUpdate { storage: Some(new_storage), balance: Some(new_balance) };
        let mut updates = HashMap::default();
        updates.insert(address, update);
        let new_block = BlockHeader { number: 1, timestamp: 234, ..Default::default() };

        let reverse_update = db
            .update_state(&updates, new_block)
            .expect("State update should succeed");

        {
            let account_storage = db
                .account_storage
                .read()
                .expect("Account storage lock should not be poisoned");
            assert_eq!(
                account_storage
                    .get_storage(&address, &new_storage_value_index)
                    .expect("Storage entry should exist"),
                new_storage_value_index
            );
            assert_eq!(
                account_storage
                    .get_account_info(&address)
                    .expect("Account should exist")
                    .balance,
                new_balance
            );
        }
        assert_eq!(db.block.unwrap().number, 1);

        let revert_entry = reverse_update
            .get(&address)
            .expect("Revert entry should exist for the updated account");
        assert_eq!(revert_entry.balance.unwrap(), AccountInfo::default().balance);
        assert_eq!(revert_entry.storage, Some(HashMap::default()));
    }

    /// `OverriddenSimulationDB` serves overridden slots from the override map and falls
    /// back to the wrapped database for everything else.
    #[rstest]
    fn test_overridden_db() {
        let db = SimulationDB::new(
            get_client(None).expect("Failed to create test client"),
            get_runtime().expect("Failed to create test runtime"),
            None,
        );
        let slot1 = U256::from_limbs_slice(&[1]);
        let slot2 = U256::from_limbs_slice(&[2]);
        let orig_value1 = U256::from_limbs_slice(&[100]);
        let orig_value2 = U256::from_limbs_slice(&[200]);
        let original_storage: HashMap<U256, U256> =
            HashMap::from([(slot1, orig_value1), (slot2, orig_value2)]);

        let address1 = Address::from(U160::from(1));
        let address2 = Address::from(U160::from(2));
        let address3 = Address::from(U160::from(3));

        db.init_account(address1, AccountInfo::default(), Some(original_storage.clone()), false)
            .expect("Failed to init account");
        db.init_account(address2, AccountInfo::default(), Some(original_storage), false)
            .expect("Failed to init account");

        let overridden_value1 = U256::from_limbs_slice(&[101]);
        let overrides: HashMap<Address, HashMap<U256, U256>> = HashMap::from([
            (address2, HashMap::from([(slot1, overridden_value1)])),
            (address3, HashMap::from([(slot1, overridden_value1)])),
        ]);

        let overridden_db = OverriddenSimulationDB::new(&db, &overrides);

        assert_eq!(
            overridden_db
                .storage_ref(address1, slot1)
                .expect("Value should be available"),
            orig_value1,
            "Slots of non-overridden account should hold original values."
        );

        assert_eq!(
            overridden_db
                .storage_ref(address1, slot2)
                .expect("Value should be available"),
            orig_value2,
            "Slots of non-overridden account should hold original values."
        );

        assert_eq!(
            overridden_db
                .storage_ref(address2, slot1)
                .expect("Value should be available"),
            overridden_value1,
            "Overridden slot of overridden account should hold an overridden value."
        );

        assert_eq!(
            overridden_db
                .storage_ref(address2, slot2)
                .expect("Value should be available"),
            orig_value2,
            "Non-overridden slot of an account with other slots overridden \
             should hold an original value."
        );

        assert_eq!(
            overridden_db
                .storage_ref(address3, slot1)
                .expect("Value should be available"),
            overridden_value1,
            "Overridden slot of an overridden non-existent account should hold an overridden value."
        );
    }
}