Skip to main content

tycho_simulation/evm/protocol/vm/
state.rs

1#![allow(deprecated)]
2use std::{
3    any::Any,
4    collections::{HashMap, HashSet},
5    fmt::{self, Debug},
6    str::FromStr,
7};
8
9use alloy::primitives::{Address, U256};
10use itertools::Itertools;
11use num_bigint::BigUint;
12use revm::DatabaseRef;
13use serde::{Deserialize, Serialize};
14use tycho_common::{
15    dto::ProtocolStateDelta,
16    models::token::Token,
17    simulation::{
18        errors::{SimulationError, TransitionError},
19        protocol_sim::{Balances, GetAmountOutResult, ProtocolSim},
20    },
21    Bytes,
22};
23
24use super::{
25    constants::{EXTERNAL_ACCOUNT, MAX_BALANCE},
26    erc20_token::{Overwrites, TokenProxyOverwriteFactory},
27    models::Capability,
28    tycho_simulation_contract::TychoSimulationContract,
29};
30use crate::evm::{
31    engine_db::{engine_db_interface::EngineDatabaseInterface, tycho_db::PreCachedDB},
32    protocol::{
33        u256_num::{u256_to_biguint, u256_to_f64},
34        utils::bytes_to_address,
35    },
36};
37
38#[derive(Clone)]
39pub struct EVMPoolState<D: EngineDatabaseInterface + Clone + Debug>
40where
41    <D as DatabaseRef>::Error: Debug,
42    <D as EngineDatabaseInterface>::Error: Debug,
43{
44    /// The pool's identifier
45    id: String,
46    /// The pool's token's addresses
47    pub tokens: Vec<Bytes>,
48    /// The pool's component balances.
49    balances: HashMap<Address, U256>,
50    /// The contract address for where protocol balances are stored (i.e. a vault contract).
51    /// If given, balances will be overwritten here instead of on the pool contract during
52    /// simulations. This has been deprecated in favor of `contract_balances`.
53    #[deprecated(note = "Use contract_balances instead")]
54    balance_owner: Option<Address>,
55    /// Spot prices of the pool by token pair
56    spot_prices: HashMap<(Address, Address), f64>,
57    /// The supported capabilities of this pool
58    capabilities: HashSet<Capability>,
59    /// Storage overwrites that will be applied to all simulations. They will be cleared
60    /// when ``update_pool_state`` is called, i.e. usually at each block. Hence, the name.
61    block_lasting_overwrites: HashMap<Address, Overwrites>,
62    /// A set of all contract addresses involved in the simulation of this pool.
63    involved_contracts: HashSet<Address>,
64    /// A map of contracts to their token balances.
65    contract_balances: HashMap<Address, HashMap<Address, U256>>,
66    /// Indicates if the protocol uses custom update rules and requires update
67    /// triggers to recalculate spot prices ect. Default is to update on all changes on
68    /// the pool.
69    manual_updates: bool,
70    /// The adapter contract. This is used to interact with the protocol when running simulations
71    adapter_contract: TychoSimulationContract<D>,
72    /// Tokens for which balance overwrites should be disabled.
73    disable_overwrite_tokens: HashSet<Address>,
74}
75
76impl<D> Debug for EVMPoolState<D>
77where
78    D: EngineDatabaseInterface + Clone + Debug,
79    <D as DatabaseRef>::Error: Debug,
80    <D as EngineDatabaseInterface>::Error: Debug,
81{
82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83        f.debug_struct("EVMPoolState")
84            .field("id", &self.id)
85            .field("tokens", &self.tokens)
86            .field("balances", &self.balances)
87            .field("involved_contracts", &self.involved_contracts)
88            .field("contract_balances", &self.contract_balances)
89            .finish_non_exhaustive()
90    }
91}
92
93impl<D> EVMPoolState<D>
94where
95    D: EngineDatabaseInterface + Clone + Debug + 'static,
96    <D as DatabaseRef>::Error: Debug,
97    <D as EngineDatabaseInterface>::Error: Debug,
98{
99    /// Creates a new instance of `EVMPoolState` with the given attributes, with the ability to
100    /// simulate a protocol-agnostic transaction.
101    ///
102    /// See struct definition of `EVMPoolState` for attribute explanations.
103    #[allow(clippy::too_many_arguments)]
104    pub fn new(
105        id: String,
106        tokens: Vec<Bytes>,
107        component_balances: HashMap<Address, U256>,
108        balance_owner: Option<Address>,
109        contract_balances: HashMap<Address, HashMap<Address, U256>>,
110        spot_prices: HashMap<(Address, Address), f64>,
111        capabilities: HashSet<Capability>,
112        block_lasting_overwrites: HashMap<Address, Overwrites>,
113        involved_contracts: HashSet<Address>,
114        manual_updates: bool,
115        adapter_contract: TychoSimulationContract<D>,
116        disable_overwrite_tokens: HashSet<Address>,
117    ) -> Self {
118        Self {
119            id,
120            tokens,
121            balances: component_balances,
122            balance_owner,
123            spot_prices,
124            capabilities,
125            block_lasting_overwrites,
126            involved_contracts,
127            contract_balances,
128            manual_updates,
129            adapter_contract,
130            disable_overwrite_tokens,
131        }
132    }
133
134    /// Ensures the pool supports the given capability
135    ///
136    /// # Arguments
137    ///
138    /// * `capability` - The capability that we would like to check for.
139    ///
140    /// # Returns
141    ///
142    /// * `Result<(), SimulationError>` - Returns `Ok(())` if the capability is supported, or a
143    ///   `SimulationError` otherwise.
144    fn ensure_capability(&self, capability: Capability) -> Result<(), SimulationError> {
145        if !self.capabilities.contains(&capability) {
146            return Err(SimulationError::FatalError(format!(
147                "capability {:?} not supported",
148                capability.to_string()
149            )));
150        }
151        Ok(())
152    }
153    /// Sets the spot prices for a pool for all possible pairs of the given tokens.
154    ///
155    /// # Arguments
156    ///
157    /// * `tokens` - A hashmap of `Token` instances representing the tokens to calculate spot prices
158    ///   for.
159    ///
160    /// # Returns
161    ///
162    /// * `Result<(), SimulationError>` - Returns `Ok(())` if the spot prices are successfully set,
163    ///   or a `SimulationError` if an error occurs during the calculation or processing.
164    ///
165    /// # Behavior
166    ///
167    /// This function performs the following steps:
168    /// 1. Ensures the pool has the required capability to perform price calculations.
169    /// 2. Iterates over all permutations of token pairs (sell token and buy token). For each pair:
170    ///    - Retrieves all possible overwrites, considering the maximum balance limit.
171    ///    - Calculates the sell amount limit, considering the overwrites.
172    ///    - Invokes the adapter contract's `price` function to retrieve the calculated price for
173    ///      the token pair, considering the sell amount limit.
174    ///    - Processes the price based on whether the `ScaledPrice` capability is present:
175    ///       - If `ScaledPrice` is present, uses the price directly from the adapter contract.
176    ///       - If `ScaledPrice` is absent, scales the price by adjusting for token decimals.
177    ///    - Stores the calculated price in the `spot_prices` map with the token addresses as the
178    ///      key.
179    /// 3. Returns `Ok(())` upon successful completion or a `SimulationError` upon failure.
180    ///
181    /// # Usage
182    ///
183    /// Spot prices need to be set before attempting to retrieve prices using `spot_price`.
184    ///
185    /// Tip: Setting spot prices on the pool every time the pool actually changes will result in
186    /// faster price fetching than if prices are only set immediately before attempting to retrieve
187    /// prices.
188    pub fn set_spot_prices(
189        &mut self,
190        tokens: &HashMap<Bytes, Token>,
191    ) -> Result<(), SimulationError> {
192        match self.ensure_capability(Capability::PriceFunction) {
193            Ok(_) => {
194                for [sell_token_address, buy_token_address] in self
195                    .tokens
196                    .iter()
197                    .permutations(2)
198                    .map(|p| [p[0], p[1]])
199                {
200                    let sell_token_address = bytes_to_address(sell_token_address)?;
201                    let buy_token_address = bytes_to_address(buy_token_address)?;
202
203                    let overwrites = Some(self.get_overwrites(
204                        vec![sell_token_address, buy_token_address],
205                        *MAX_BALANCE / U256::from(100),
206                    )?);
207
208                    let (sell_amount_limit, _) = self.get_amount_limits(
209                        vec![sell_token_address, buy_token_address],
210                        overwrites.clone(),
211                    )?;
212                    let price_result = self.adapter_contract.price(
213                        &self.id,
214                        sell_token_address,
215                        buy_token_address,
216                        vec![sell_amount_limit / U256::from(100)],
217                        overwrites,
218                    )?;
219
220                    let price = if self
221                        .capabilities
222                        .contains(&Capability::ScaledPrice)
223                    {
224                        *price_result.first().ok_or_else(|| {
225                            SimulationError::FatalError(
226                                "Calculated price array is empty".to_string(),
227                            )
228                        })?
229                    } else {
230                        let unscaled_price = price_result.first().ok_or_else(|| {
231                            SimulationError::FatalError(
232                                "Calculated price array is empty".to_string(),
233                            )
234                        })?;
235                        let sell_token_decimals = self.get_decimals(tokens, &sell_token_address)?;
236                        let buy_token_decimals = self.get_decimals(tokens, &buy_token_address)?;
237                        *unscaled_price * 10f64.powi(sell_token_decimals as i32) /
238                            10f64.powi(buy_token_decimals as i32)
239                    };
240
241                    self.spot_prices
242                        .insert((sell_token_address, buy_token_address), price);
243                }
244            }
245            Err(SimulationError::FatalError(_)) => {
246                // If the pool does not support price function, we need to calculate spot prices by
247                // swapping two amounts and use the approximation to get the derivative.
248
249                for iter_tokens in self.tokens.iter().permutations(2) {
250                    let t0 = bytes_to_address(iter_tokens[0])?;
251                    let t1 = bytes_to_address(iter_tokens[1])?;
252
253                    let overwrites =
254                        Some(self.get_overwrites(vec![t0, t1], *MAX_BALANCE / U256::from(100))?);
255
256                    // Calculate the first sell amount (x1) as 1% of the maximum limit.
257                    let x1 = self
258                        .get_amount_limits(vec![t0, t1], overwrites.clone())?
259                        .0 /
260                        U256::from(100);
261
262                    // Calculate the second sell amount (x2) as x1 + 1% of x1. 1.01% of the max
263                    // limit
264                    let x2 = x1 + (x1 / U256::from(100));
265
266                    // Perform a swap for the first sell amount (x1) and retrieve the received
267                    // amount (y1).
268                    let y1 = self
269                        .adapter_contract
270                        .swap(&self.id, t0, t1, false, x1, overwrites.clone())?
271                        .0
272                        .received_amount;
273
274                    // Perform a swap for the second sell amount (x2) and retrieve the received
275                    // amount (y2).
276                    let y2 = self
277                        .adapter_contract
278                        .swap(&self.id, t0, t1, false, x2, overwrites)?
279                        .0
280                        .received_amount;
281
282                    let sell_token_decimals = self.get_decimals(tokens, &t0)?;
283                    let buy_token_decimals = self.get_decimals(tokens, &t1)?;
284
285                    let num = y2 - y1;
286                    let den = x2 - x1;
287
288                    // Calculate the marginal price, adjusting for token decimals.
289                    let token_correction =
290                        10f64.powi(sell_token_decimals as i32 - buy_token_decimals as i32);
291                    let num_f64 = u256_to_f64(num)?;
292                    let den_f64 = u256_to_f64(den)?;
293                    if den_f64 == 0.0 {
294                        return Err(SimulationError::FatalError(
295                            "Failed to compute marginal price: denominator converted to 0".into(),
296                        ));
297                    }
298                    let marginal_price = num_f64 / den_f64 * token_correction;
299
300                    self.spot_prices
301                        .insert((t0, t1), marginal_price);
302                }
303            }
304            Err(e) => return Err(e),
305        }
306
307        Ok(())
308    }
309
310    fn get_decimals(
311        &self,
312        tokens: &HashMap<Bytes, Token>,
313        sell_token_address: &Address,
314    ) -> Result<usize, SimulationError> {
315        tokens
316            .get(&Bytes::from(sell_token_address.as_slice()))
317            .map(|t| t.decimals as usize)
318            .ok_or_else(|| {
319                SimulationError::FatalError(format!(
320                    "Failed to scale spot prices! Pool: {} Token 0x{:x} is not available!",
321                    self.id, sell_token_address
322                ))
323            })
324    }
325
326    /// Retrieves the sell and buy amount limit for a given pair of tokens and the given overwrites.
327    ///
328    /// Attempting to swap an amount of the sell token that exceeds the sell amount limit is not
329    /// advised and in most cases will result in a revert.
330    ///
331    /// # Arguments
332    ///
333    /// * `tokens` - A vec of tokens, where the first token is the sell token and the second is the
334    ///   buy token. The order of tokens in the input vector is significant and determines the
335    ///   direction of the price query.
336    /// * `overwrites` - A hashmap of overwrites to apply to the simulation.
337    ///
338    /// # Returns
339    ///
340    /// * `Result<(U256,U256), SimulationError>` - Returns the sell and buy amount limit as a `U256`
341    ///   if successful, or a `SimulationError` on failure.
342    fn get_amount_limits(
343        &self,
344        tokens: Vec<Address>,
345        overwrites: Option<HashMap<Address, HashMap<U256, U256>>>,
346    ) -> Result<(U256, U256), SimulationError> {
347        let limits = self
348            .adapter_contract
349            .get_limits(&self.id, tokens[0], tokens[1], overwrites)?;
350
351        Ok(limits)
352    }
353
354    /// Updates the pool state.
355    ///
356    /// It is assumed this is called on a new block. Therefore, first the pool's overwrites cache is
357    /// cleared, then the balances are updated and the spot prices are recalculated.
358    ///
359    /// # Arguments
360    ///
361    /// * `tokens` - A hashmap of token addresses to `Token` instances. This is necessary for
362    ///   calculating new spot prices.
363    /// * `balances` - A `Balances` instance containing all balance updates on the current block.
364    fn update_pool_state(
365        &mut self,
366        tokens: &HashMap<Bytes, Token>,
367        balances: &Balances,
368    ) -> Result<(), SimulationError> {
369        // clear cache
370        self.adapter_contract
371            .engine
372            .clear_temp_storage()
373            .map_err(|err| {
374                SimulationError::FatalError(format!("Failed to clear temporary storage: {err:?}",))
375            })?;
376        self.block_lasting_overwrites.clear();
377
378        // set balances
379        if !self.balances.is_empty() {
380            // Pool uses component balances for overwrites
381            if let Some(bals) = balances
382                .component_balances
383                .get(&self.id)
384            {
385                // Merge delta balances with existing balances instead of replacing them
386                // Prevents errors when delta balance changes do not affect all the pool tokens.
387                for (token, bal) in bals {
388                    let addr = bytes_to_address(token).map_err(|_| {
389                        SimulationError::FatalError(format!(
390                            "Invalid token address in balance update: {token:?}"
391                        ))
392                    })?;
393                    self.balances
394                        .insert(addr, U256::from_be_slice(bal));
395                }
396            }
397        } else {
398            // Pool uses contract balances for overwrites
399            for contract in &self.involved_contracts {
400                if let Some(bals) = balances
401                    .account_balances
402                    .get(&Bytes::from(contract.as_slice()))
403                {
404                    let contract_entry = self
405                        .contract_balances
406                        .entry(*contract)
407                        .or_default();
408                    for (token, bal) in bals {
409                        let addr = bytes_to_address(token).map_err(|_| {
410                            SimulationError::FatalError(format!(
411                                "Invalid token address in balance update: {token:?}"
412                            ))
413                        })?;
414                        contract_entry.insert(addr, U256::from_be_slice(bal));
415                    }
416                }
417            }
418        }
419
420        // reset spot prices
421        self.set_spot_prices(tokens)?;
422        Ok(())
423    }
424
425    fn get_overwrites(
426        &self,
427        tokens: Vec<Address>,
428        max_amount: U256,
429    ) -> Result<HashMap<Address, Overwrites>, SimulationError> {
430        let token_overwrites = self.get_token_overwrites(tokens, max_amount)?;
431
432        // Merge `block_lasting_overwrites` with `token_overwrites`
433        let merged_overwrites =
434            self.merge(&self.block_lasting_overwrites.clone(), &token_overwrites);
435
436        Ok(merged_overwrites)
437    }
438
439    fn get_token_overwrites(
440        &self,
441        tokens: Vec<Address>,
442        max_amount: U256,
443    ) -> Result<HashMap<Address, Overwrites>, SimulationError> {
444        let sell_token = &tokens[0].clone(); //TODO: need to make it clearer from the interface
445        let mut res: Vec<HashMap<Address, Overwrites>> = Vec::new();
446        if !self
447            .capabilities
448            .contains(&Capability::TokenBalanceIndependent)
449        {
450            res.push(self.get_balance_overwrites()?);
451        }
452
453        let mut overwrites = TokenProxyOverwriteFactory::new(*sell_token, None);
454
455        overwrites.set_balance(max_amount, Address::from_slice(&*EXTERNAL_ACCOUNT.0));
456
457        // Set allowance for adapter_address to max_amount
458        overwrites.set_allowance(max_amount, self.adapter_contract.address, *EXTERNAL_ACCOUNT);
459
460        res.push(overwrites.get_overwrites());
461
462        // Merge all overwrites into a single HashMap
463        Ok(res
464            .into_iter()
465            .fold(HashMap::new(), |acc, overwrite| self.merge(&acc, &overwrite)))
466    }
467
468    /// Gets all balance overwrites for the pool's tokens.
469    ///
470    /// If the pool uses component balances, the balances are set for the balance owner (if exists)
471    /// or for the pool itself. If the pool uses contract balances, the balances are set for the
472    /// contracts involved in the pool.
473    ///
474    /// # Returns
475    ///
476    /// * `Result<HashMap<Address, Overwrites>, SimulationError>` - Returns a hashmap of address to
477    ///   `Overwrites` if successful, or a `SimulationError` on failure.
478    fn get_balance_overwrites(&self) -> Result<HashMap<Address, Overwrites>, SimulationError> {
479        let mut balance_overwrites: HashMap<Address, Overwrites> = HashMap::new();
480
481        // Use component balances for overrides
482        let address = match self.balance_owner {
483            Some(owner) => Some(owner),
484            None if !self.contract_balances.is_empty() => None,
485            None => Some(self.id.parse().map_err(|_| {
486                SimulationError::FatalError(
487                    "Failed to get balance overwrites: Pool ID is not an address".into(),
488                )
489            })?),
490        };
491
492        if let Some(address) = address {
493            // Only override balances that are explicitly provided in self.balances
494            // This preserves existing balances for tokens not updated in delta transitions
495            for (token, bal) in &self.balances {
496                let mut overwrites = TokenProxyOverwriteFactory::new(*token, None);
497                overwrites.set_balance(*bal, address);
498                balance_overwrites.extend(overwrites.get_overwrites());
499            }
500        }
501
502        // Use contract balances for overrides (will overwrite component balances if they were set
503        // for a contract we explicitly track balances for)
504        for (contract, balances) in &self.contract_balances {
505            for (token, balance) in balances {
506                let mut overwrites = TokenProxyOverwriteFactory::new(*token, None);
507                overwrites.set_balance(*balance, *contract);
508                balance_overwrites.extend(overwrites.get_overwrites());
509            }
510        }
511
512        // Apply disables for tokens that should not have any balance overrides
513        for token in &self.disable_overwrite_tokens {
514            balance_overwrites.remove(token);
515        }
516
517        Ok(balance_overwrites)
518    }
519
520    fn merge(
521        &self,
522        target: &HashMap<Address, Overwrites>,
523        source: &HashMap<Address, Overwrites>,
524    ) -> HashMap<Address, Overwrites> {
525        let mut merged = target.clone();
526
527        for (key, source_inner) in source {
528            merged
529                .entry(*key)
530                .or_default()
531                .extend(source_inner.clone());
532        }
533
534        merged
535    }
536
537    #[cfg(test)]
538    pub fn get_involved_contracts(&self) -> HashSet<Address> {
539        self.involved_contracts.clone()
540    }
541
542    #[cfg(test)]
543    pub fn get_manual_updates(&self) -> bool {
544        self.manual_updates
545    }
546
547    #[cfg(test)]
548    #[deprecated]
549    pub fn get_balance_owner(&self) -> Option<Address> {
550        self.balance_owner
551    }
552
553    /// Get the component balances for validation purposes
554    pub fn get_balances(&self) -> &HashMap<Address, U256> {
555        &self.balances
556    }
557}
558
559impl<D> Serialize for EVMPoolState<D>
560where
561    D: EngineDatabaseInterface + Clone + Debug,
562    <D as DatabaseRef>::Error: Debug,
563    <D as EngineDatabaseInterface>::Error: Debug,
564{
565    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
566    where
567        S: serde::Serializer,
568    {
569        Err(serde::ser::Error::custom("not supported due vm state deps"))
570    }
571}
572
573impl<'de, D> Deserialize<'de> for EVMPoolState<D>
574where
575    D: EngineDatabaseInterface + Clone + Debug,
576    <D as DatabaseRef>::Error: Debug,
577    <D as EngineDatabaseInterface>::Error: Debug,
578{
579    fn deserialize<De>(_deserializer: De) -> Result<Self, De::Error>
580    where
581        De: serde::Deserializer<'de>,
582    {
583        Err(serde::de::Error::custom("not supported due vm state deps"))
584    }
585}
586
587#[typetag::serialize]
588impl<D> ProtocolSim for EVMPoolState<D>
589where
590    D: EngineDatabaseInterface + Clone + Debug + 'static,
591    <D as DatabaseRef>::Error: Debug,
592    <D as EngineDatabaseInterface>::Error: Debug,
593{
594    fn fee(&self) -> f64 {
595        todo!()
596    }
597
598    fn spot_price(&self, base: &Token, quote: &Token) -> Result<f64, SimulationError> {
599        let base_address = bytes_to_address(&base.address)?;
600        let quote_address = bytes_to_address(&quote.address)?;
601        self.spot_prices
602            .get(&(base_address, quote_address))
603            .cloned()
604            .ok_or(SimulationError::FatalError(format!(
605                "Spot price not found for base token {base_address} and quote token {quote_address}"
606            )))
607    }
608
609    fn get_amount_out(
610        &self,
611        amount_in: BigUint,
612        token_in: &Token,
613        token_out: &Token,
614    ) -> Result<GetAmountOutResult, SimulationError> {
615        let sell_token_address = bytes_to_address(&token_in.address)?;
616        let buy_token_address = bytes_to_address(&token_out.address)?;
617        let sell_amount = U256::from_be_slice(&amount_in.to_bytes_be());
618        let overwrites = self.get_overwrites(
619            vec![sell_token_address, buy_token_address],
620            *MAX_BALANCE / U256::from(100),
621        )?;
622        let (sell_amount_limit, _) = self.get_amount_limits(
623            vec![sell_token_address, buy_token_address],
624            Some(overwrites.clone()),
625        )?;
626        let (sell_amount_respecting_limit, sell_amount_exceeds_limit) = if self
627            .capabilities
628            .contains(&Capability::HardLimits) &&
629            sell_amount_limit < sell_amount
630        {
631            (sell_amount_limit, true)
632        } else {
633            (sell_amount, false)
634        };
635
636        let overwrites_with_sell_limit =
637            self.get_overwrites(vec![sell_token_address, buy_token_address], sell_amount_limit)?;
638        let complete_overwrites = self.merge(&overwrites, &overwrites_with_sell_limit);
639
640        let (trade, state_changes) = self.adapter_contract.swap(
641            &self.id,
642            sell_token_address,
643            buy_token_address,
644            false,
645            sell_amount_respecting_limit,
646            Some(complete_overwrites),
647        )?;
648
649        let mut new_state = self.clone();
650
651        // Apply state changes to the new state
652        for (address, state_update) in state_changes {
653            if let Some(storage) = state_update.storage {
654                let block_overwrites = new_state
655                    .block_lasting_overwrites
656                    .entry(address)
657                    .or_default();
658                for (slot, value) in storage {
659                    let slot = U256::from_str(&slot.to_string()).map_err(|_| {
660                        SimulationError::FatalError("Failed to decode slot index".to_string())
661                    })?;
662                    let value = U256::from_str(&value.to_string()).map_err(|_| {
663                        SimulationError::FatalError("Failed to decode slot overwrite".to_string())
664                    })?;
665                    block_overwrites.insert(slot, value);
666                }
667            }
668        }
669
670        // Update spot prices
671        let tokens = HashMap::from([
672            (token_in.address.clone(), token_in.clone()),
673            (token_out.address.clone(), token_out.clone()),
674        ]);
675        let _ = new_state.set_spot_prices(&tokens);
676
677        let buy_amount = trade.received_amount;
678
679        if sell_amount_exceeds_limit {
680            return Err(SimulationError::InvalidInput(
681                format!("Sell amount exceeds limit {sell_amount_limit}"),
682                Some(GetAmountOutResult::new(
683                    u256_to_biguint(buy_amount),
684                    u256_to_biguint(trade.gas_used),
685                    Box::new(new_state.clone()),
686                )),
687            ));
688        }
689        Ok(GetAmountOutResult::new(
690            u256_to_biguint(buy_amount),
691            u256_to_biguint(trade.gas_used),
692            Box::new(new_state.clone()),
693        ))
694    }
695
696    fn get_limits(
697        &self,
698        sell_token: Bytes,
699        buy_token: Bytes,
700    ) -> Result<(BigUint, BigUint), SimulationError> {
701        let sell_token = bytes_to_address(&sell_token)?;
702        let buy_token = bytes_to_address(&buy_token)?;
703        let overwrites =
704            self.get_overwrites(vec![sell_token, buy_token], *MAX_BALANCE / U256::from(100))?;
705        let limits = self.get_amount_limits(vec![sell_token, buy_token], Some(overwrites))?;
706        Ok((u256_to_biguint(limits.0), u256_to_biguint(limits.1)))
707    }
708
709    fn delta_transition(
710        &mut self,
711        delta: ProtocolStateDelta,
712        tokens: &HashMap<Bytes, Token>,
713        balances: &Balances,
714    ) -> Result<(), TransitionError<String>> {
715        if self.manual_updates {
716            // Directly check for "update_marker" in `updated_attributes`
717            if let Some(marker) = delta
718                .updated_attributes
719                .get("update_marker")
720            {
721                // Assuming `marker` is of type `Bytes`, check its value for "truthiness"
722                if !marker.is_empty() && marker[0] != 0 {
723                    self.update_pool_state(tokens, balances)?;
724                }
725            }
726        } else {
727            self.update_pool_state(tokens, balances)?;
728        }
729
730        Ok(())
731    }
732
733    fn clone_box(&self) -> Box<dyn ProtocolSim> {
734        Box::new(self.clone())
735    }
736
737    fn as_any(&self) -> &dyn Any {
738        self
739    }
740
741    fn as_any_mut(&mut self) -> &mut dyn Any {
742        self
743    }
744
745    fn eq(&self, other: &dyn ProtocolSim) -> bool {
746        if let Some(other_state) = other
747            .as_any()
748            .downcast_ref::<EVMPoolState<PreCachedDB>>()
749        {
750            self.id == other_state.id
751        } else {
752            false
753        }
754    }
755
756    /// Implemented manually because `typetag` macro not supports generics
757    fn typetag_deserialize(&self) {
758        // https://github.com/dtolnay/typetag/blob/21ae0d40c9f73443a20204ab4a134441355b52f7/impl/src/tagged_trait.rs#L140
759        unreachable!("Only to catch missing typetag attribute on impl blocks. Not called.")
760    }
761}
762
#[cfg(test)]
mod tests {
    use std::default::Default;

    use num_traits::One;
    use revm::{
        primitives::KECCAK_EMPTY,
        state::{AccountInfo, Bytecode},
    };
    use serde_json::Value;
    use tycho_client::feed::BlockHeader;
    use tycho_common::models::Chain;

    use super::*;
    use crate::evm::{
        engine_db::{create_engine, SHARED_TYCHO_DB},
        protocol::vm::{
            constants::{BALANCER_V2, ERC20_PROXY_BYTECODE},
            state_builder::EVMPoolStateBuilder,
        },
        simulation::SimulationEngine,
        tycho_models::AccountUpdate,
    };

    /// DAI token (Ethereum mainnet, 18 decimals) used as the pool's first token.
    fn dai() -> Token {
        Token::new(
            &Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap(),
            "DAI",
            18,
            0,
            &[Some(10_000)],
            Chain::Ethereum,
            100,
        )
    }

    /// BAL token (Ethereum mainnet, 18 decimals) used as the pool's second token.
    fn bal() -> Token {
        Token::new(
            &Bytes::from_str("0xba100000625a3754423978a60c9317c58a424e3d").unwrap(),
            "BAL",
            18,
            0,
            &[Some(10_000)],
            Chain::Ethereum,
            100,
        )
    }

    /// DAI's address as an alloy `Address`.
    fn dai_addr() -> Address {
        bytes_to_address(&dai().address).unwrap()
    }

    /// BAL's address as an alloy `Address`.
    fn bal_addr() -> Address {
        bytes_to_address(&bal().address).unwrap()
    }

    /// Builds a Balancer V2 DAI/BAL pool state backed by `SHARED_TYCHO_DB`.
    ///
    /// Contract accounts are loaded from the `balancer_contract_storage_block_20463609.json`
    /// fixture. Both tokens are initialised with the ERC20 proxy bytecode, and the adapter
    /// is built from the `BALANCER_V2` bytecode.
    ///
    /// NOTE(review): after loading the fixture at block 20463609, the DB block is moved to
    /// 18485417, which is earlier than the fixture. Confirm this is intentional.
    async fn setup_pool_state() -> EVMPoolState<PreCachedDB> {
        let data_str = include_str!("assets/balancer_contract_storage_block_20463609.json");
        let data: Value = serde_json::from_str(data_str).expect("Failed to parse JSON");

        let accounts: Vec<AccountUpdate> = serde_json::from_value(data["accounts"].clone())
            .expect("Expected accounts to match AccountUpdate structure");

        let db = SHARED_TYCHO_DB.clone();
        let engine: SimulationEngine<_> = create_engine(db.clone(), false).unwrap();

        let block = BlockHeader {
            number: 20463609,
            hash: Bytes::from_str(
                "0x4315fd1afc25cc2ebc72029c543293f9fd833eeb305e2e30159459c827733b1b",
            )
            .unwrap(),
            timestamp: 1722875891,
            ..Default::default()
        };

        // `accounts` is consumed by `db.update` below, so iterate over a copy here.
        for account in accounts.clone() {
            engine
                .state
                .init_account(
                    account.address,
                    AccountInfo {
                        balance: account.balance.unwrap_or_default(),
                        nonce: 0u64,
                        code_hash: KECCAK_EMPTY,
                        code: account
                            .code
                            .clone()
                            .map(|arg0: Vec<u8>| Bytecode::new_raw(arg0.into())),
                    },
                    None,
                    false,
                )
                .expect("Failed to initialize account");
        }
        db.update(accounts, Some(block))
            .unwrap();

        let tokens = vec![dai().address, bal().address];
        for token in &tokens {
            engine
                .state
                .init_account(
                    bytes_to_address(token).unwrap(),
                    AccountInfo {
                        balance: U256::from(0),
                        nonce: 0,
                        code_hash: KECCAK_EMPTY,
                        code: Some(Bytecode::new_raw(ERC20_PROXY_BYTECODE.into())),
                    },
                    None,
                    true,
                )
                .expect("Failed to initialize account");
        }

        let block = BlockHeader {
            number: 18485417,
            hash: Bytes::from_str(
                "0x28d41d40f2ac275a4f5f621a636b9016b527d11d37d610a45ac3a821346ebf8c",
            )
            .expect("Invalid block hash"),
            timestamp: 0,
            ..Default::default()
        };
        db.update(vec![], Some(block))
            .unwrap();

        let pool_id: String =
            "0x4626d81b3a1711beb79f4cecff2413886d461677000200000000000000000011".into();

        let stateless_contracts = HashMap::from([(
            String::from("0x3de27efa2f1aa663ae5d458857e731c129069f29"),
            Some(Vec::new()),
        )]);

        let balances = HashMap::from([
            (dai_addr(), U256::from_str("178754012737301807104").unwrap()),
            (bal_addr(), U256::from_str("91082987763369885696").unwrap()),
        ]);
        let adapter_address =
            Address::from_str("0xA2C5C98A892fD6656a7F39A2f63228C0Bc846270").unwrap();

        EVMPoolStateBuilder::new(pool_id, tokens, adapter_address)
            .balances(balances)
            .balance_owner(Address::from_str("0xBA12222222228d8Ba445958a75a0704d566BF2C8").unwrap())
            .adapter_contract_bytecode(Bytecode::new_raw(BALANCER_V2.into()))
            .stateless_contracts(stateless_contracts)
            .build(SHARED_TYCHO_DB.clone())
            .await
            .expect("Failed to build pool state")
    }

    #[tokio::test]
    async fn test_init() {
        // Clear DB from this test to prevent interference from other tests.
        // NOTE(review): SHARED_TYCHO_DB is shared by every test in this module, and
        // `cargo test` runs tests in parallel by default, so this clear may race with
        // another test's setup. Confirm whether these tests need serialising.
        SHARED_TYCHO_DB
            .clear()
            .expect("Failed to clear SHARED_TYCHO_DB");
        let pool_state = setup_pool_state().await;

        let expected_capabilities = HashSet::from([
            Capability::SellSide,
            Capability::BuySide,
            Capability::PriceFunction,
            Capability::HardLimits,
        ]);

        let capabilities_adapter_contract = pool_state
            .adapter_contract
            .get_capabilities(
                &pool_state.id,
                bytes_to_address(&pool_state.tokens[0]).unwrap(),
                bytes_to_address(&pool_state.tokens[1]).unwrap(),
            )
            .unwrap();

        // `assert_eq!` borrows its operands, so no clones are needed here.
        assert_eq!(capabilities_adapter_contract, expected_capabilities);
        assert_eq!(pool_state.capabilities, expected_capabilities);

        for capability in expected_capabilities {
            assert!(pool_state
                .clone()
                .ensure_capability(capability)
                .is_ok());
        }

        assert!(pool_state
            .clone()
            .ensure_capability(Capability::MarginalPrice)
            .is_err());

        // Verify all tokens are initialized in the engine
        let engine_accounts = pool_state
            .adapter_contract
            .engine
            .state
            .clone()
            .get_account_storage()
            .expect("Failed to get account storage");
        for token in pool_state.tokens.clone() {
            let account = engine_accounts
                .get_account_info(&bytes_to_address(&token).unwrap())
                .unwrap();
            assert_eq!(account.balance, U256::from(0));
            assert_eq!(account.nonce, 0u64);
            assert_eq!(account.code_hash, KECCAK_EMPTY);
            assert!(account.code.is_some());
        }

        // Verify external account is initialized in the engine
        let external_account = engine_accounts
            .get_account_info(&EXTERNAL_ACCOUNT)
            .unwrap();
        assert_eq!(external_account.balance, U256::from(*MAX_BALANCE));
        assert_eq!(external_account.nonce, 0u64);
        assert_eq!(external_account.code_hash, KECCAK_EMPTY);
        assert!(external_account.code.is_none());
    }

    #[tokio::test]
    async fn test_get_amount_out() -> Result<(), Box<dyn std::error::Error>> {
        let pool_state = setup_pool_state().await;

        let result = pool_state
            .get_amount_out(BigUint::from_str("1000000000000000000").unwrap(), &dai(), &bal())
            .unwrap();
        let new_state = result
            .new_state
            .as_any()
            .downcast_ref::<EVMPoolState<PreCachedDB>>()
            .unwrap();
        assert_eq!(result.amount, BigUint::from_str("137780051463393923").unwrap());
        assert_ne!(new_state.spot_prices, pool_state.spot_prices);
        assert!(pool_state
            .block_lasting_overwrites
            .is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn test_sequential_get_amount_outs() {
        let pool_state = setup_pool_state().await;

        let result = pool_state
            .get_amount_out(BigUint::from_str("1000000000000000000").unwrap(), &dai(), &bal())
            .unwrap();
        let new_state = result
            .new_state
            .as_any()
            .downcast_ref::<EVMPoolState<PreCachedDB>>()
            .unwrap();
        assert_eq!(result.amount, BigUint::from_str("137780051463393923").unwrap());
        assert_ne!(new_state.spot_prices, pool_state.spot_prices);

        let new_result = new_state
            .get_amount_out(BigUint::from_str("1000000000000000000").unwrap(), &dai(), &bal())
            .unwrap();
        let new_state_second_swap = new_result
            .new_state
            .as_any()
            .downcast_ref::<EVMPoolState<PreCachedDB>>()
            .unwrap();

        assert_eq!(new_result.amount, BigUint::from_str("136964651490065626").unwrap());
        assert_ne!(new_state_second_swap.spot_prices, new_state.spot_prices);
    }

    #[tokio::test]
    async fn test_get_amount_out_dust() {
        let pool_state = setup_pool_state().await;

        let result = pool_state
            .get_amount_out(BigUint::one(), &dai(), &bal())
            .unwrap();

        // Only checks that the returned state downcasts to the concrete type.
        let _ = result
            .new_state
            .as_any()
            .downcast_ref::<EVMPoolState<PreCachedDB>>()
            .unwrap();
        assert_eq!(result.amount, BigUint::ZERO);
    }

    #[tokio::test]
    async fn test_get_amount_out_sell_limit() {
        let pool_state = setup_pool_state().await;

        let result = pool_state.get_amount_out(
            // sell limit is 100279494253364362835
            BigUint::from_str("100379494253364362835").unwrap(),
            &dai(),
            &bal(),
        );

        assert!(result.is_err());

        match result {
            Err(SimulationError::InvalidInput(msg1, amount_out_result)) => {
                assert_eq!(msg1, "Sell amount exceeds limit 100279494253364362835");
                assert!(amount_out_result.is_some());
            }
            _ => panic!("Test failed: was expecting an Err(SimulationError::InvalidInput(_, _)) value"),
        }
    }

    #[tokio::test]
    async fn test_get_amount_limits() {
        let pool_state = setup_pool_state().await;

        let overwrites = pool_state
            .get_overwrites(
                vec![
                    bytes_to_address(&pool_state.tokens[0]).unwrap(),
                    bytes_to_address(&pool_state.tokens[1]).unwrap(),
                ],
                *MAX_BALANCE / U256::from(100),
            )
            .unwrap();
        let (dai_limit, _) = pool_state
            .get_amount_limits(vec![dai_addr(), bal_addr()], Some(overwrites.clone()))
            .unwrap();
        assert_eq!(dai_limit, U256::from_str("100279494253364362835").unwrap());

        let (bal_limit, _) = pool_state
            .get_amount_limits(
                vec![
                    bytes_to_address(&pool_state.tokens[1]).unwrap(),
                    bytes_to_address(&pool_state.tokens[0]).unwrap(),
                ],
                Some(overwrites),
            )
            .unwrap();
        assert_eq!(bal_limit, U256::from_str("13997408640689987484").unwrap());
    }

    #[tokio::test]
    async fn test_set_spot_prices() {
        let mut pool_state = setup_pool_state().await;

        pool_state
            .set_spot_prices(
                &vec![bal(), dai()]
                    .into_iter()
                    .map(|t| (t.address.clone(), t))
                    .collect(),
            )
            .unwrap();

        let dai_bal_spot_price = pool_state
            .spot_prices
            .get(&(
                bytes_to_address(&pool_state.tokens[0]).unwrap(),
                bytes_to_address(&pool_state.tokens[1]).unwrap(),
            ))
            .unwrap();
        let bal_dai_spot_price = pool_state
            .spot_prices
            .get(&(
                bytes_to_address(&pool_state.tokens[1]).unwrap(),
                bytes_to_address(&pool_state.tokens[0]).unwrap(),
            ))
            .unwrap();
        assert_eq!(dai_bal_spot_price, &0.137_778_914_319_047_9);
        assert_eq!(bal_dai_spot_price, &7.071_503_245_428_246);
    }

    #[tokio::test]
    async fn test_set_spot_prices_without_capability() {
        // Tests set Spot Prices functions when the pool doesn't have PriceFunction capability
        let mut pool_state = setup_pool_state().await;

        pool_state
            .capabilities
            .remove(&Capability::PriceFunction);

        pool_state
            .set_spot_prices(
                &vec![bal(), dai()]
                    .into_iter()
                    .map(|t| (t.address.clone(), t))
                    .collect(),
            )
            .unwrap();

        let dai_bal_spot_price = pool_state
            .spot_prices
            .get(&(
                bytes_to_address(&pool_state.tokens[0]).unwrap(),
                bytes_to_address(&pool_state.tokens[1]).unwrap(),
            ))
            .unwrap();
        let bal_dai_spot_price = pool_state
            .spot_prices
            .get(&(
                bytes_to_address(&pool_state.tokens[1]).unwrap(),
                bytes_to_address(&pool_state.tokens[0]).unwrap(),
            ))
            .unwrap();
        assert_eq!(dai_bal_spot_price, &0.13736685496467538);
        assert_eq!(bal_dai_spot_price, &7.050354297665408);
    }

    #[tokio::test]
    async fn test_get_balance_overwrites_with_component_balances() {
        let pool_state: EVMPoolState<PreCachedDB> = setup_pool_state().await;

        let overwrites = pool_state
            .get_balance_overwrites()
            .unwrap();

        let dai_address = dai_addr();
        let bal_address = bal_addr();
        assert!(overwrites.contains_key(&dai_address));
        assert!(overwrites.contains_key(&bal_address));
    }

    #[tokio::test]
    async fn test_get_balance_overwrites_with_contract_balances() {
        let mut pool_state: EVMPoolState<PreCachedDB> = setup_pool_state().await;

        let contract_address =
            Address::from_str("0xBA12222222228d8Ba445958a75a0704d566BF2C8").unwrap();

        // Ensure no component balances are used
        pool_state.balances.clear();
        pool_state.balance_owner = None;

        // Set contract balances
        let dai_address = dai_addr();
        let bal_address = bal_addr();
        pool_state.contract_balances = HashMap::from([(
            contract_address,
            HashMap::from([
                (dai_address, U256::from_str("7500000000000000000000").unwrap()), // 7500 DAI
                (bal_address, U256::from_str("1500000000000000000000").unwrap()), // 1500 BAL
            ]),
        )]);

        let overwrites = pool_state
            .get_balance_overwrites()
            .unwrap();

        assert!(overwrites.contains_key(&dai_address));
        assert!(overwrites.contains_key(&bal_address));
    }

    #[tokio::test]
    async fn test_balance_merging_during_delta_transition() {
        // `FromStr` is already in scope via `super::*`.
        let mut pool_state = setup_pool_state().await;
        let pool_id = pool_state.id.clone();

        // Test the balance merging logic more directly
        // Setup initial balances including DAI and BAL (which the pool already knows about)
        let dai_addr = dai_addr();
        let bal_addr = bal_addr();
        let new_token = Address::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(); // WETH

        // Clear and setup clean initial state
        pool_state.balances.clear();
        pool_state
            .balances
            .insert(dai_addr, U256::from(1000000000u64));
        pool_state
            .balances
            .insert(bal_addr, U256::from(2000000000u64));
        pool_state
            .balances
            .insert(new_token, U256::from(3000000000u64));

        // Create tokens mapping including the existing DAI and BAL
        let mut tokens = HashMap::new();
        tokens.insert(dai().address.clone(), dai());
        tokens.insert(bal().address.clone(), bal());

        // Simulate a delta transition with only DAI balance update (missing BAL and new_token)
        let mut component_balances = HashMap::new();
        let mut delta_balances = HashMap::new();
        // Only update DAI balance, leave others unchanged in delta
        delta_balances.insert(dai().address.clone(), Bytes::from(vec![0x77, 0x35, 0x94, 0x00])); // 2000000000 (updated value)
        component_balances.insert(pool_id.clone(), delta_balances);

        let balances = Balances { component_balances, account_balances: HashMap::new() };

        // Record initial balance count
        let initial_balance_count = pool_state.balances.len();
        assert_eq!(initial_balance_count, 3);

        // Apply delta transition
        pool_state
            .update_pool_state(&tokens, &balances)
            .unwrap();

        // Verify that all 3 balances are preserved (BAL and new_token should still be there)
        assert_eq!(
            pool_state.balances.len(),
            3,
            "All balances should be preserved after delta transition"
        );
        assert!(
            pool_state
                .balances
                .contains_key(&dai_addr),
            "DAI balance should be present"
        );
        assert!(
            pool_state
                .balances
                .contains_key(&bal_addr),
            "BAL balance should be present"
        );
        assert!(
            pool_state
                .balances
                .contains_key(&new_token),
            "New token balance should be preserved from before delta"
        );

        // Verify that updated token (DAI) has new value
        assert_eq!(
            pool_state.balances[&dai_addr],
            U256::from(2000000000u64),
            "DAI balance should be updated"
        );

        // Verify that non-updated tokens retain their original values
        assert_eq!(
            pool_state.balances[&bal_addr],
            U256::from(2000000000u64),
            "BAL balance should be unchanged"
        );
        assert_eq!(
            pool_state.balances[&new_token],
            U256::from(3000000000u64),
            "New token balance should be unchanged"
        );
    }

    #[test]
    fn should_not_panic_at_typetag_deserialize() {
        let deserialized: Result<Box<dyn ProtocolSim>, _> = serde_json::from_str(
            r#"{"protocol":"EVMPoolState","state":{"reserve_0":1,"reserve_1":2}}"#,
        );

        assert!(deserialized.is_err());
    }
}