Skip to main content

tycho_simulation/evm/protocol/vm/
state_builder.rs

1use std::{
2    collections::{HashMap, HashSet},
3    fmt::Debug,
4};
5
6use alloy::{
7    primitives::{Address, Bytes, Keccak256, U256},
8    sol_types::SolValue,
9};
10use itertools::Itertools;
11use revm::{
12    primitives::KECCAK_EMPTY,
13    state::{AccountInfo, Bytecode},
14    DatabaseRef,
15};
16use tracing::warn;
17use tycho_common::{simulation::errors::SimulationError, Bytes as TychoBytes};
18
19use super::{
20    constants::{EXTERNAL_ACCOUNT, MAX_BALANCE},
21    models::Capability,
22    state::EVMPoolState,
23    tycho_simulation_contract::TychoSimulationContract,
24    utils::get_code_for_contract,
25};
26use crate::evm::{
27    engine_db::{create_engine, engine_db_interface::EngineDatabaseInterface},
28    protocol::utils::bytes_to_address,
29    simulation::{BlockEnvOverrides, SimulationEngine, SimulationParameters},
30};
31
32#[derive(Debug)]
33/// `EVMPoolStateBuilder` is a builder pattern implementation for creating instances of
34/// `EVMPoolState`.
35///
36/// This struct provides a flexible way to construct `EVMPoolState` objects with
37/// multiple optional parameters. It handles the validation of required fields and applies default
38/// values for optional parameters where necessary.
39/// # Example
40/// Constructing a `EVMPoolState` with only the required parameters:
41/// ```rust
42/// use alloy::primitives::Address;
43/// use std::path::PathBuf;
44/// use tycho_common::Bytes;
45/// use tycho_simulation::evm::engine_db::SHARED_TYCHO_DB;
46/// use tycho_simulation::evm::protocol::vm::state_builder::EVMPoolStateBuilder;
47/// use tycho_simulation::evm::protocol::vm::constants::BALANCER_V2;
48/// /// use tycho_common::simulation::errors::SimulationError;
49/// use revm::state::Bytecode;
50///
51/// #[tokio::main]
52/// async fn main() -> Result<(), tycho_common::simulation::errors::SimulationError> {
53///     use tycho_client::feed::BlockHeader;
54///
55///     let pool_id: String = "0x4626d81b3a1711beb79f4cecff2413886d461677000200000000000000000011".into();
56///
57///     let tokens = vec![
58///         Bytes::from("0x6b175474e89094c44da98b954eedeac495271d0f"),
59///         Bytes::from("0xba100000625a3754423978a60c9317c58a424e3d"),
60///     ];
61///
62///     // Set up the block for the database
63///     let block = BlockHeader {
64///         number: 1,
65///         hash: Default::default(),
66///         timestamp: 1632456789,
67///         ..Default::default()
68///     };
69///     SHARED_TYCHO_DB.update(vec![], Some(block)).unwrap();
70///
71///     // Build the EVMPoolState
72///     let pool_state = EVMPoolStateBuilder::new(pool_id, tokens, Address::random())
73///         .adapter_contract_bytecode(Bytecode::new_raw(BALANCER_V2.into()))
74///         .build(SHARED_TYCHO_DB.clone())
75///         .await?;
76///     Ok(())
77/// }
78/// ```
79pub struct EVMPoolStateBuilder<D: EngineDatabaseInterface + Clone + Debug>
80where
81    <D as DatabaseRef>::Error: Debug,
82    <D as EngineDatabaseInterface>::Error: Debug,
83{
84    id: String,
85    tokens: Vec<TychoBytes>,
86    balances: HashMap<Address, U256>,
87    adapter_address: Address,
88    balance_owner: Option<Address>,
89    capabilities: Option<HashSet<Capability>>,
90    involved_contracts: Option<HashSet<Address>>,
91    contract_balances: HashMap<Address, HashMap<Address, U256>>,
92    stateless_contracts: Option<HashMap<String, Option<Vec<u8>>>>,
93    manual_updates: Option<bool>,
94    trace: Option<bool>,
95    engine: Option<SimulationEngine<D>>,
96    adapter_contract: Option<TychoSimulationContract<D>>,
97    adapter_contract_bytecode: Option<Bytecode>,
98    disable_overwrite_tokens: HashSet<Address>,
99    block_overrides: Option<BlockEnvOverrides>,
100}
101
102impl<D> EVMPoolStateBuilder<D>
103where
104    D: EngineDatabaseInterface + Clone + Debug + 'static,
105    <D as DatabaseRef>::Error: Debug,
106    <D as EngineDatabaseInterface>::Error: Debug,
107{
108    pub fn new(id: String, tokens: Vec<TychoBytes>, adapter_address: Address) -> Self {
109        Self {
110            id,
111            tokens,
112            balances: HashMap::new(),
113            adapter_address,
114            balance_owner: None,
115            capabilities: None,
116            involved_contracts: None,
117            contract_balances: HashMap::new(),
118            stateless_contracts: None,
119            manual_updates: None,
120            trace: None,
121            engine: None,
122            adapter_contract: None,
123            adapter_contract_bytecode: None,
124            disable_overwrite_tokens: HashSet::new(),
125            block_overrides: None,
126        }
127    }
128
129    #[deprecated(note = "Use account balances instead")]
130    pub fn balance_owner(mut self, balance_owner: Address) -> Self {
131        self.balance_owner = Some(balance_owner);
132        self
133    }
134
135    /// Set component balances. This balance belongs to the 'balance_owner' if one is set,
136    /// otherwise it belongs to the pool itself.
137    pub fn balances(mut self, balances: HashMap<Address, U256>) -> Self {
138        self.balances = balances;
139        self
140    }
141
142    /// Set contract balances
143    pub fn account_balances(
144        mut self,
145        account_balances: HashMap<Address, HashMap<Address, U256>>,
146    ) -> Self {
147        self.contract_balances = account_balances;
148        self
149    }
150
151    pub fn capabilities(mut self, capabilities: HashSet<Capability>) -> Self {
152        self.capabilities = Some(capabilities);
153        self
154    }
155
156    pub fn involved_contracts(mut self, involved_contracts: HashSet<Address>) -> Self {
157        self.involved_contracts = Some(involved_contracts);
158        self
159    }
160
161    pub fn stateless_contracts(
162        mut self,
163        stateless_contracts: HashMap<String, Option<Vec<u8>>>,
164    ) -> Self {
165        self.stateless_contracts = Some(stateless_contracts);
166        self
167    }
168    pub fn manual_updates(mut self, manual_updates: bool) -> Self {
169        self.manual_updates = Some(manual_updates);
170        self
171    }
172
173    pub fn trace(mut self, trace: bool) -> Self {
174        self.trace = Some(trace);
175        self
176    }
177
178    pub fn engine(mut self, engine: SimulationEngine<D>) -> Self {
179        self.engine = Some(engine);
180        self
181    }
182
183    pub fn adapter_contract(mut self, adapter_contract: TychoSimulationContract<D>) -> Self {
184        self.adapter_contract = Some(adapter_contract);
185        self
186    }
187
188    pub fn adapter_contract_bytecode(mut self, adapter_contract_bytecode: Bytecode) -> Self {
189        self.adapter_contract_bytecode = Some(adapter_contract_bytecode);
190        self
191    }
192
193    pub fn disable_overwrite_tokens(mut self, disable_overwrite_tokens: HashSet<Address>) -> Self {
194        self.disable_overwrite_tokens = disable_overwrite_tokens;
195        self
196    }
197
198    pub fn block_overrides(mut self, block_overrides: Option<BlockEnvOverrides>) -> Self {
199        self.block_overrides = block_overrides;
200        self
201    }
202
203    /// Build the final EVMPoolState object
204    pub async fn build(mut self, db: D) -> Result<EVMPoolState<D>, SimulationError> {
205        let engine = if let Some(engine) = &self.engine {
206            engine.clone()
207        } else {
208            self.engine = Some(self.get_default_engine(db).await?);
209            self.engine.clone().ok_or_else(|| {
210                SimulationError::FatalError(
211                    "Failed to get build engine: Engine not initialized".to_string(),
212                )
213            })?
214        };
215
216        if self.adapter_contract.is_none() {
217            self.adapter_contract = Some(TychoSimulationContract::new_contract(
218                self.adapter_address,
219                self.adapter_contract_bytecode
220                    .clone()
221                    .ok_or_else(|| {
222                        SimulationError::FatalError("Adapter contract bytecode not set".to_string())
223                    })?,
224                engine.clone(),
225            )?)
226        };
227
228        let capabilities = if let Some(capabilities) = &self.capabilities {
229            capabilities.clone()
230        } else {
231            self.get_default_capabilities()?
232        };
233
234        let adapter_contract = self.adapter_contract.ok_or_else(|| {
235            SimulationError::FatalError(
236                "Failed to get build engine: Adapter contract not initialized".to_string(),
237            )
238        })?;
239
240        Ok(EVMPoolState::new(
241            self.id,
242            self.tokens,
243            self.balances,
244            self.balance_owner,
245            self.contract_balances,
246            HashMap::new(),
247            capabilities,
248            HashMap::new(),
249            self.involved_contracts
250                .unwrap_or_default(),
251            self.manual_updates.unwrap_or(false),
252            adapter_contract,
253            self.disable_overwrite_tokens,
254            self.block_overrides,
255        ))
256    }
257
258    async fn get_default_engine(&self, db: D) -> Result<SimulationEngine<D>, SimulationError> {
259        let engine = create_engine(db, self.trace.unwrap_or(false))?;
260
261        engine
262            .state
263            .init_account(
264                *EXTERNAL_ACCOUNT,
265                AccountInfo {
266                    balance: *MAX_BALANCE,
267                    nonce: 0,
268                    code_hash: KECCAK_EMPTY,
269                    code: None,
270                },
271                None,
272                false,
273            )
274            .map_err(|err| {
275                SimulationError::FatalError(format!(
276                    "Failed to get default engine: Failed to init external account: {err:?}"
277                ))
278            })?;
279
280        if let Some(stateless_contracts) = &self.stateless_contracts {
281            for (address, bytecode) in stateless_contracts.iter() {
282                let mut addr_str = address.clone();
283                let (code, code_hash) = if bytecode.is_none() {
284                    if addr_str.starts_with("call") {
285                        addr_str = self
286                            .get_address_from_call(&engine, &addr_str)?
287                            .to_string();
288                    }
289                    let code = get_code_for_contract(&addr_str, None).await?;
290                    (Some(code.clone()), code.hash_slow())
291                } else {
292                    let code =
293                        Bytecode::new_raw(Bytes::from(bytecode.clone().ok_or_else(|| {
294                            SimulationError::FatalError(
295                                "Failed to get default engine: Byte code from stateless contracts is None".into(),
296                            )
297                        })?));
298                    (Some(code.clone()), code.hash_slow())
299                };
300                let account_address: Address = addr_str.parse().map_err(|_| {
301                    SimulationError::FatalError(format!(
302                        "Failed to get default engine: Couldn't parse address string {address}"
303                    ))
304                })?;
305                engine.state.init_account(
306                    Address(*account_address),
307                    AccountInfo { balance: Default::default(), nonce: 0, code_hash, code },
308                    None,
309                    false,
310                ).map_err(|err| {
311                    SimulationError::FatalError(format!(
312                        "Failed to get default engine: Failed to init stateless contract account: {err:?}"
313                    ))
314                })?;
315            }
316        }
317        Ok(engine)
318    }
319
320    fn get_default_capabilities(&mut self) -> Result<HashSet<Capability>, SimulationError> {
321        let mut capabilities = Vec::new();
322
323        // Generate all permutations of tokens and retrieve capabilities
324        for tokens_pair in self.tokens.iter().permutations(2) {
325            // Manually unpack the inner vector
326            if let [t0, t1] = tokens_pair[..] {
327                let caps = self
328                    .adapter_contract
329                    .clone()
330                    .ok_or_else(|| {
331                        SimulationError::FatalError(
332                            "Failed to get default capabilities: Adapter contract not initialized"
333                                .to_string(),
334                        )
335                    })?
336                    .get_capabilities(&self.id, bytes_to_address(t0)?, bytes_to_address(t1)?)?;
337                capabilities.push(caps);
338            }
339        }
340
341        // Find the maximum capabilities length
342        let max_capabilities = capabilities
343            .iter()
344            .map(|c| c.len())
345            .max()
346            .unwrap_or(0);
347
348        // Intersect all capability sets
349        let common_capabilities: HashSet<_> = capabilities
350            .iter()
351            .fold(capabilities[0].clone(), |acc, cap| acc.intersection(cap).cloned().collect());
352
353        // Check for mismatches in capabilities
354        if common_capabilities.len() < max_capabilities {
355            warn!(
356                "Warning: Pool {} has different capabilities depending on the token pair!",
357                self.id
358            );
359        }
360        Ok(common_capabilities)
361    }
362
363    /// Gets the address of the code - mostly used for dynamic proxy implementations. For example,
364    /// some protocols have some dynamic math implementation that is given by the factory. When
365    /// we swap on the pools for such protocols, it will call the factory to get the implementation
366    /// and use it for the swap.
367    /// This method simulates the call to the pool, which gives us the address of the
368    /// implementation.
369    ///
370    /// # See Also
371    /// [Dynamic Address Resolution Example](https://github.com/propeller-heads/propeller-protocol-lib/blob/main/docs/indexing/reserved-attributes.md#description-2)
372    fn get_address_from_call(
373        &self,
374        engine: &SimulationEngine<D>,
375        decoded: &str,
376    ) -> Result<Address, SimulationError> {
377        let method_name = decoded
378            .split(':')
379            .next_back()
380            .ok_or_else(|| {
381                SimulationError::FatalError(
382                    "Failed to get address from call: Could not decode method name from call"
383                        .into(),
384                )
385            })?;
386
387        let selector = {
388            let mut hasher = Keccak256::new();
389            hasher.update(method_name.as_bytes());
390            let result = hasher.finalize();
391            result[..4].to_vec()
392        };
393
394        let to_address = decoded
395            .split(':')
396            .nth(1)
397            .ok_or_else(|| {
398                SimulationError::FatalError(
399                    "Failed to get address from call: Could not decode to_address from call".into(),
400                )
401            })?;
402
403        let parsed_address: Address = to_address.parse().map_err(|_| {
404            SimulationError::FatalError(format!(
405                "Failed to get address from call: Invalid address format: {to_address}"
406            ))
407        })?;
408
409        let sim_params = SimulationParameters {
410            data: selector.to_vec(),
411            to: parsed_address,
412            overrides: Some(HashMap::new()),
413            caller: *EXTERNAL_ACCOUNT,
414            value: U256::from(0u64),
415            gas_limit: None,
416            transient_storage: None,
417            block_overrides: None,
418        };
419
420        let sim_result = engine
421            .simulate(&sim_params)
422            .map_err(|err| SimulationError::FatalError(err.to_string()))?;
423
424        let address: Address = Address::abi_decode(&sim_result.result).map_err(|e| {
425            SimulationError::FatalError(format!("Failed to get address from call: Failed to decode address list from simulation result {e:?}"))
426        })?;
427
428        Ok(address)
429    }
430}
431
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::*;
    use crate::evm::engine_db::{tycho_db::PreCachedDB, SHARED_TYCHO_DB};

    /// Adapter address shared by the tests in this module.
    const ADAPTER: &str = "0xA2C5C98A892fD6656a7F39A2f63228C0Bc846270";

    /// Parses a 20-byte hex string into a token identifier.
    fn token(hex: &str) -> TychoBytes {
        TychoBytes::from_str(hex).unwrap()
    }

    /// Neither an adapter contract nor its bytecode is supplied, so `build` must fail with the
    /// dedicated error message.
    #[test]
    fn test_build_without_required_fields() {
        let builder = EVMPoolStateBuilder::<PreCachedDB>::new(
            "pool_1".to_string(),
            vec![token("0000000000000000000000000000000000000000")],
            Address::from_str(ADAPTER).unwrap(),
        )
        .balances(HashMap::new());

        let outcome = tokio_test::block_on(builder.build(SHARED_TYCHO_DB.clone()));

        assert!(outcome.is_err());
        if let Err(SimulationError::FatalError(message)) = outcome {
            assert_eq!(message, "Adapter contract bytecode not set")
        } else {
            panic!("Unexpected error type")
        }
    }

    /// The default engine must contain the external account.
    #[test]
    fn test_engine_setup() {
        let builder = EVMPoolStateBuilder::<PreCachedDB>::new(
            "pool_1".to_string(),
            vec![
                token("0000000000000000000000000000000000000002"),
                token("0000000000000000000000000000000000000003"),
            ],
            Address::from_str(ADAPTER).unwrap(),
        )
        .balances(HashMap::new());

        let result = tokio_test::block_on(builder.get_default_engine(SHARED_TYCHO_DB.clone()));
        assert!(result.is_ok());

        let engine = result.unwrap();
        let external_present = engine
            .state
            .get_account_storage()
            .expect("Failed to get account storage")
            .account_present(&EXTERNAL_ACCOUNT);
        assert!(external_present);
    }
}
480            .get_account_storage()
481            .expect("Failed to get account storage")
482            .account_present(&EXTERNAL_ACCOUNT));
483    }
484}