tycho_simulation/evm/protocol/vm/
state_builder.rs

1use std::{
2    collections::{HashMap, HashSet},
3    fmt::Debug,
4};
5
6use alloy::{
7    primitives::{Address, Bytes, Keccak256, U256},
8    sol_types::SolValue,
9};
10use itertools::Itertools;
11use revm::{
12    primitives::KECCAK_EMPTY,
13    state::{AccountInfo, Bytecode},
14    DatabaseRef,
15};
16use tracing::warn;
17use tycho_common::{simulation::errors::SimulationError, Bytes as TychoBytes};
18
19use super::{
20    constants::{EXTERNAL_ACCOUNT, MAX_BALANCE},
21    models::Capability,
22    state::EVMPoolState,
23    tycho_simulation_contract::TychoSimulationContract,
24    utils::get_code_for_contract,
25};
26use crate::evm::{
27    engine_db::{create_engine, engine_db_interface::EngineDatabaseInterface},
28    protocol::utils::bytes_to_address,
29    simulation::{SimulationEngine, SimulationParameters},
30};
31
32#[derive(Debug)]
33/// `EVMPoolStateBuilder` is a builder pattern implementation for creating instances of
34/// `EVMPoolState`.
35///
36/// This struct provides a flexible way to construct `EVMPoolState` objects with
37/// multiple optional parameters. It handles the validation of required fields and applies default
38/// values for optional parameters where necessary.
39/// # Example
40/// Constructing a `EVMPoolState` with only the required parameters:
41/// ```rust
42/// use alloy::primitives::Address;
43/// use std::path::PathBuf;
44/// use tycho_common::Bytes;
45/// use tycho_simulation::evm::engine_db::SHARED_TYCHO_DB;
46/// use tycho_simulation::evm::protocol::vm::state_builder::EVMPoolStateBuilder;
47/// use tycho_simulation::evm::protocol::vm::constants::BALANCER_V2;
48/// /// use tycho_common::simulation::errors::SimulationError;
49/// use revm::state::Bytecode;
50///
51/// #[tokio::main]
52/// async fn main() -> Result<(), tycho_common::simulation::errors::SimulationError> {
53///     use tycho_client::feed::BlockHeader;
54///
55///     let pool_id: String = "0x4626d81b3a1711beb79f4cecff2413886d461677000200000000000000000011".into();
56///
57///     let tokens = vec![
58///         Bytes::from("0x6b175474e89094c44da98b954eedeac495271d0f"),
59///         Bytes::from("0xba100000625a3754423978a60c9317c58a424e3d"),
60///     ];
61///
62///     // Set up the block for the database
63///     let block = BlockHeader {
64///         number: 1,
65///         hash: Default::default(),
66///         timestamp: 1632456789,
67///         ..Default::default()
68///     };
69///     SHARED_TYCHO_DB.update(vec![], Some(block)).unwrap();
70///
71///     // Build the EVMPoolState
72///     let pool_state = EVMPoolStateBuilder::new(pool_id, tokens, Address::random())
73///         .adapter_contract_bytecode(Bytecode::new_raw(BALANCER_V2.into()))
74///         .build(SHARED_TYCHO_DB.clone())
75///         .await?;
76///     Ok(())
77/// }
78/// ```
79pub struct EVMPoolStateBuilder<D: EngineDatabaseInterface + Clone + Debug>
80where
81    <D as DatabaseRef>::Error: Debug,
82    <D as EngineDatabaseInterface>::Error: Debug,
83{
84    id: String,
85    tokens: Vec<TychoBytes>,
86    balances: HashMap<Address, U256>,
87    adapter_address: Address,
88    balance_owner: Option<Address>,
89    capabilities: Option<HashSet<Capability>>,
90    involved_contracts: Option<HashSet<Address>>,
91    contract_balances: HashMap<Address, HashMap<Address, U256>>,
92    stateless_contracts: Option<HashMap<String, Option<Vec<u8>>>>,
93    manual_updates: Option<bool>,
94    trace: Option<bool>,
95    engine: Option<SimulationEngine<D>>,
96    adapter_contract: Option<TychoSimulationContract<D>>,
97    adapter_contract_bytecode: Option<Bytecode>,
98    disable_overwrite_tokens: HashSet<Address>,
99}
100
101impl<D> EVMPoolStateBuilder<D>
102where
103    D: EngineDatabaseInterface + Clone + Debug + 'static,
104    <D as DatabaseRef>::Error: Debug,
105    <D as EngineDatabaseInterface>::Error: Debug,
106{
107    pub fn new(id: String, tokens: Vec<TychoBytes>, adapter_address: Address) -> Self {
108        Self {
109            id,
110            tokens,
111            balances: HashMap::new(),
112            adapter_address,
113            balance_owner: None,
114            capabilities: None,
115            involved_contracts: None,
116            contract_balances: HashMap::new(),
117            stateless_contracts: None,
118            manual_updates: None,
119            trace: None,
120            engine: None,
121            adapter_contract: None,
122            adapter_contract_bytecode: None,
123            disable_overwrite_tokens: HashSet::new(),
124        }
125    }
126
127    #[deprecated(note = "Use account balances instead")]
128    pub fn balance_owner(mut self, balance_owner: Address) -> Self {
129        self.balance_owner = Some(balance_owner);
130        self
131    }
132
133    /// Set component balances. This balance belongs to the 'balance_owner' if one is set,
134    /// otherwise it belongs to the pool itself.
135    pub fn balances(mut self, balances: HashMap<Address, U256>) -> Self {
136        self.balances = balances;
137        self
138    }
139
140    /// Set contract balances
141    pub fn account_balances(
142        mut self,
143        account_balances: HashMap<Address, HashMap<Address, U256>>,
144    ) -> Self {
145        self.contract_balances = account_balances;
146        self
147    }
148
149    pub fn capabilities(mut self, capabilities: HashSet<Capability>) -> Self {
150        self.capabilities = Some(capabilities);
151        self
152    }
153
154    pub fn involved_contracts(mut self, involved_contracts: HashSet<Address>) -> Self {
155        self.involved_contracts = Some(involved_contracts);
156        self
157    }
158
159    pub fn stateless_contracts(
160        mut self,
161        stateless_contracts: HashMap<String, Option<Vec<u8>>>,
162    ) -> Self {
163        self.stateless_contracts = Some(stateless_contracts);
164        self
165    }
166    pub fn manual_updates(mut self, manual_updates: bool) -> Self {
167        self.manual_updates = Some(manual_updates);
168        self
169    }
170
171    pub fn trace(mut self, trace: bool) -> Self {
172        self.trace = Some(trace);
173        self
174    }
175
176    pub fn engine(mut self, engine: SimulationEngine<D>) -> Self {
177        self.engine = Some(engine);
178        self
179    }
180
181    pub fn adapter_contract(mut self, adapter_contract: TychoSimulationContract<D>) -> Self {
182        self.adapter_contract = Some(adapter_contract);
183        self
184    }
185
186    pub fn adapter_contract_bytecode(mut self, adapter_contract_bytecode: Bytecode) -> Self {
187        self.adapter_contract_bytecode = Some(adapter_contract_bytecode);
188        self
189    }
190
191    pub fn disable_overwrite_tokens(mut self, disable_overwrite_tokens: HashSet<Address>) -> Self {
192        self.disable_overwrite_tokens = disable_overwrite_tokens;
193        self
194    }
195
196    /// Build the final EVMPoolState object
197    pub async fn build(mut self, db: D) -> Result<EVMPoolState<D>, SimulationError> {
198        let engine = if let Some(engine) = &self.engine {
199            engine.clone()
200        } else {
201            self.engine = Some(self.get_default_engine(db).await?);
202            self.engine.clone().ok_or_else(|| {
203                SimulationError::FatalError(
204                    "Failed to get build engine: Engine not initialized".to_string(),
205                )
206            })?
207        };
208
209        if self.adapter_contract.is_none() {
210            self.adapter_contract = Some(TychoSimulationContract::new_contract(
211                self.adapter_address,
212                self.adapter_contract_bytecode
213                    .clone()
214                    .ok_or_else(|| {
215                        SimulationError::FatalError("Adapter contract bytecode not set".to_string())
216                    })?,
217                engine.clone(),
218            )?)
219        };
220
221        let capabilities = if let Some(capabilities) = &self.capabilities {
222            capabilities.clone()
223        } else {
224            self.get_default_capabilities()?
225        };
226
227        let adapter_contract = self.adapter_contract.ok_or_else(|| {
228            SimulationError::FatalError(
229                "Failed to get build engine: Adapter contract not initialized".to_string(),
230            )
231        })?;
232
233        Ok(EVMPoolState::new(
234            self.id,
235            self.tokens,
236            self.balances,
237            self.balance_owner,
238            self.contract_balances,
239            HashMap::new(),
240            capabilities,
241            HashMap::new(),
242            self.involved_contracts
243                .unwrap_or_default(),
244            self.manual_updates.unwrap_or(false),
245            adapter_contract,
246            self.disable_overwrite_tokens,
247        ))
248    }
249
250    async fn get_default_engine(&self, db: D) -> Result<SimulationEngine<D>, SimulationError> {
251        let engine = create_engine(db, self.trace.unwrap_or(false))?;
252
253        engine
254            .state
255            .init_account(
256                *EXTERNAL_ACCOUNT,
257                AccountInfo {
258                    balance: *MAX_BALANCE,
259                    nonce: 0,
260                    code_hash: KECCAK_EMPTY,
261                    code: None,
262                },
263                None,
264                false,
265            )
266            .map_err(|err| {
267                SimulationError::FatalError(format!(
268                    "Failed to get default engine: Failed to init external account: {err:?}"
269                ))
270            })?;
271
272        if let Some(stateless_contracts) = &self.stateless_contracts {
273            for (address, bytecode) in stateless_contracts.iter() {
274                let mut addr_str = address.clone();
275                let (code, code_hash) = if bytecode.is_none() {
276                    if addr_str.starts_with("call") {
277                        addr_str = self
278                            .get_address_from_call(&engine, &addr_str)?
279                            .to_string();
280                    }
281                    let code = get_code_for_contract(&addr_str, None).await?;
282                    (Some(code.clone()), code.hash_slow())
283                } else {
284                    let code =
285                        Bytecode::new_raw(Bytes::from(bytecode.clone().ok_or_else(|| {
286                            SimulationError::FatalError(
287                                "Failed to get default engine: Byte code from stateless contracts is None".into(),
288                            )
289                        })?));
290                    (Some(code.clone()), code.hash_slow())
291                };
292                let account_address: Address = addr_str.parse().map_err(|_| {
293                    SimulationError::FatalError(format!(
294                        "Failed to get default engine: Couldn't parse address string {address}"
295                    ))
296                })?;
297                engine.state.init_account(
298                    Address(*account_address),
299                    AccountInfo { balance: Default::default(), nonce: 0, code_hash, code },
300                    None,
301                    false,
302                ).map_err(|err| {
303                    SimulationError::FatalError(format!(
304                        "Failed to get default engine: Failed to init stateless contract account: {err:?}"
305                    ))
306                })?;
307            }
308        }
309        Ok(engine)
310    }
311
312    fn get_default_capabilities(&mut self) -> Result<HashSet<Capability>, SimulationError> {
313        let mut capabilities = Vec::new();
314
315        // Generate all permutations of tokens and retrieve capabilities
316        for tokens_pair in self.tokens.iter().permutations(2) {
317            // Manually unpack the inner vector
318            if let [t0, t1] = tokens_pair[..] {
319                let caps = self
320                    .adapter_contract
321                    .clone()
322                    .ok_or_else(|| {
323                        SimulationError::FatalError(
324                            "Failed to get default capabilities: Adapter contract not initialized"
325                                .to_string(),
326                        )
327                    })?
328                    .get_capabilities(&self.id, bytes_to_address(t0)?, bytes_to_address(t1)?)?;
329                capabilities.push(caps);
330            }
331        }
332
333        // Find the maximum capabilities length
334        let max_capabilities = capabilities
335            .iter()
336            .map(|c| c.len())
337            .max()
338            .unwrap_or(0);
339
340        // Intersect all capability sets
341        let common_capabilities: HashSet<_> = capabilities
342            .iter()
343            .fold(capabilities[0].clone(), |acc, cap| acc.intersection(cap).cloned().collect());
344
345        // Check for mismatches in capabilities
346        if common_capabilities.len() < max_capabilities {
347            warn!(
348                "Warning: Pool {} has different capabilities depending on the token pair!",
349                self.id
350            );
351        }
352        Ok(common_capabilities)
353    }
354
355    /// Gets the address of the code - mostly used for dynamic proxy implementations. For example,
356    /// some protocols have some dynamic math implementation that is given by the factory. When
357    /// we swap on the pools for such protocols, it will call the factory to get the implementation
358    /// and use it for the swap.
359    /// This method simulates the call to the pool, which gives us the address of the
360    /// implementation.
361    ///
362    /// # See Also
363    /// [Dynamic Address Resolution Example](https://github.com/propeller-heads/propeller-protocol-lib/blob/main/docs/indexing/reserved-attributes.md#description-2)
364    fn get_address_from_call(
365        &self,
366        engine: &SimulationEngine<D>,
367        decoded: &str,
368    ) -> Result<Address, SimulationError> {
369        let method_name = decoded
370            .split(':')
371            .next_back()
372            .ok_or_else(|| {
373                SimulationError::FatalError(
374                    "Failed to get address from call: Could not decode method name from call"
375                        .into(),
376                )
377            })?;
378
379        let selector = {
380            let mut hasher = Keccak256::new();
381            hasher.update(method_name.as_bytes());
382            let result = hasher.finalize();
383            result[..4].to_vec()
384        };
385
386        let to_address = decoded
387            .split(':')
388            .nth(1)
389            .ok_or_else(|| {
390                SimulationError::FatalError(
391                    "Failed to get address from call: Could not decode to_address from call".into(),
392                )
393            })?;
394
395        let parsed_address: Address = to_address.parse().map_err(|_| {
396            SimulationError::FatalError(format!(
397                "Failed to get address from call: Invalid address format: {to_address}"
398            ))
399        })?;
400
401        let sim_params = SimulationParameters {
402            data: selector.to_vec(),
403            to: parsed_address,
404            overrides: Some(HashMap::new()),
405            caller: *EXTERNAL_ACCOUNT,
406            value: U256::from(0u64),
407            gas_limit: None,
408            transient_storage: None,
409        };
410
411        let sim_result = engine
412            .simulate(&sim_params)
413            .map_err(|err| SimulationError::FatalError(err.to_string()))?;
414
415        let address: Address = Address::abi_decode(&sim_result.result).map_err(|e| {
416            SimulationError::FatalError(format!("Failed to get address from call: Failed to decode address list from simulation result {e:?}"))
417        })?;
418
419        Ok(address)
420    }
421}
422
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::*;
    use crate::evm::engine_db::{tycho_db::PreCachedDB, SHARED_TYCHO_DB};

    /// Adapter address shared by the tests in this module.
    fn test_adapter_address() -> Address {
        Address::from_str("0xA2C5C98A892fD6656a7F39A2f63228C0Bc846270").unwrap()
    }

    /// Parses a bare (un-prefixed) hex string into a token address.
    fn token(hex: &str) -> TychoBytes {
        TychoBytes::from_str(hex).unwrap()
    }

    #[test]
    fn test_build_without_required_fields() {
        // Neither an adapter contract nor its bytecode is provided, so `build` must fail.
        let builder = EVMPoolStateBuilder::<PreCachedDB>::new(
            "pool_1".to_string(),
            vec![token("0000000000000000000000000000000000000000")],
            test_adapter_address(),
        )
        .balances(HashMap::new());

        let result = tokio_test::block_on(builder.build(SHARED_TYCHO_DB.clone()));

        assert!(result.is_err());
        if let Err(SimulationError::FatalError(field)) = result {
            assert_eq!(field, "Adapter contract bytecode not set")
        } else {
            panic!("Unexpected error type");
        }
    }

    #[test]
    fn test_engine_setup() {
        let tokens = vec![
            token("0000000000000000000000000000000000000002"),
            token("0000000000000000000000000000000000000003"),
        ];
        let builder =
            EVMPoolStateBuilder::<PreCachedDB>::new("pool_1".to_string(), tokens, test_adapter_address())
                .balances(HashMap::new());

        let engine = tokio_test::block_on(builder.get_default_engine(SHARED_TYCHO_DB.clone()));
        assert!(engine.is_ok());

        // The default engine must always contain the funded external account.
        let engine = engine.unwrap();
        let storage = engine
            .state
            .get_account_storage()
            .expect("Failed to get account storage");
        assert!(storage.account_present(&EXTERNAL_ACCOUNT));
    }
}