Skip to main content

tycho_simulation/evm/protocol/vm/
state.rs

1#![allow(deprecated)]
2use std::{
3    any::Any,
4    collections::{HashMap, HashSet},
5    fmt::{self, Debug},
6    str::FromStr,
7};
8
9use alloy::primitives::{Address, U256};
10use itertools::Itertools;
11use num_bigint::BigUint;
12use revm::DatabaseRef;
13use serde::{Deserialize, Serialize};
14use tycho_common::{
15    dto::ProtocolStateDelta,
16    models::token::Token,
17    simulation::{
18        errors::{SimulationError, TransitionError},
19        protocol_sim::{Balances, GetAmountOutResult, ProtocolSim},
20    },
21    Bytes,
22};
23
24use super::{
25    constants::{EXTERNAL_ACCOUNT, MAX_BALANCE},
26    erc20_token::{Overwrites, TokenProxyOverwriteFactory},
27    models::Capability,
28    tycho_simulation_contract::TychoSimulationContract,
29};
30use crate::evm::{
31    engine_db::{engine_db_interface::EngineDatabaseInterface, tycho_db::PreCachedDB},
32    protocol::{
33        u256_num::{u256_to_biguint, u256_to_f64},
34        utils::bytes_to_address,
35    },
36};
37
38#[derive(Clone)]
39pub struct EVMPoolState<D: EngineDatabaseInterface + Clone + Debug>
40where
41    <D as DatabaseRef>::Error: Debug,
42    <D as EngineDatabaseInterface>::Error: Debug,
43{
44    /// The pool's identifier
45    id: String,
46    /// The pool's token's addresses
47    pub tokens: Vec<Bytes>,
48    /// The pool's component balances.
49    balances: HashMap<Address, U256>,
50    /// The contract address for where protocol balances are stored (i.e. a vault contract).
51    /// If given, balances will be overwritten here instead of on the pool contract during
52    /// simulations. This has been deprecated in favor of `contract_balances`.
53    #[deprecated(note = "Use contract_balances instead")]
54    balance_owner: Option<Address>,
55    /// Spot prices of the pool by token pair
56    spot_prices: HashMap<(Address, Address), f64>,
57    /// The supported capabilities of this pool
58    capabilities: HashSet<Capability>,
59    /// Storage overwrites that will be applied to all simulations. They will be cleared
60    /// when ``update_pool_state`` is called, i.e. usually at each block. Hence, the name.
61    block_lasting_overwrites: HashMap<Address, Overwrites>,
62    /// A set of all contract addresses involved in the simulation of this pool.
63    involved_contracts: HashSet<Address>,
64    /// A map of contracts to their token balances.
65    contract_balances: HashMap<Address, HashMap<Address, U256>>,
66    /// Indicates if the protocol uses custom update rules and requires update
67    /// triggers to recalculate spot prices ect. Default is to update on all changes on
68    /// the pool.
69    manual_updates: bool,
70    /// The adapter contract. This is used to interact with the protocol when running simulations
71    adapter_contract: TychoSimulationContract<D>,
72    /// Tokens for which balance overwrites should be disabled.
73    disable_overwrite_tokens: HashSet<Address>,
74}
75
76impl<D> Debug for EVMPoolState<D>
77where
78    D: EngineDatabaseInterface + Clone + Debug,
79    <D as DatabaseRef>::Error: Debug,
80    <D as EngineDatabaseInterface>::Error: Debug,
81{
82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83        f.debug_struct("EVMPoolState")
84            .field("id", &self.id)
85            .field("tokens", &self.tokens)
86            .field("balances", &self.balances)
87            .field("involved_contracts", &self.involved_contracts)
88            .field("contract_balances", &self.contract_balances)
89            .finish_non_exhaustive()
90    }
91}
92
93impl<D> EVMPoolState<D>
94where
95    D: EngineDatabaseInterface + Clone + Debug + 'static,
96    <D as DatabaseRef>::Error: Debug,
97    <D as EngineDatabaseInterface>::Error: Debug,
98{
99    /// Creates a new instance of `EVMPoolState` with the given attributes, with the ability to
100    /// simulate a protocol-agnostic transaction.
101    ///
102    /// See struct definition of `EVMPoolState` for attribute explanations.
103    #[allow(clippy::too_many_arguments)]
104    pub fn new(
105        id: String,
106        tokens: Vec<Bytes>,
107        component_balances: HashMap<Address, U256>,
108        balance_owner: Option<Address>,
109        contract_balances: HashMap<Address, HashMap<Address, U256>>,
110        spot_prices: HashMap<(Address, Address), f64>,
111        capabilities: HashSet<Capability>,
112        block_lasting_overwrites: HashMap<Address, Overwrites>,
113        involved_contracts: HashSet<Address>,
114        manual_updates: bool,
115        adapter_contract: TychoSimulationContract<D>,
116        disable_overwrite_tokens: HashSet<Address>,
117    ) -> Self {
118        Self {
119            id,
120            tokens,
121            balances: component_balances,
122            balance_owner,
123            spot_prices,
124            capabilities,
125            block_lasting_overwrites,
126            involved_contracts,
127            contract_balances,
128            manual_updates,
129            adapter_contract,
130            disable_overwrite_tokens,
131        }
132    }
133
134    /// Ensures the pool supports the given capability
135    ///
136    /// # Arguments
137    ///
138    /// * `capability` - The capability that we would like to check for.
139    ///
140    /// # Returns
141    ///
142    /// * `Result<(), SimulationError>` - Returns `Ok(())` if the capability is supported, or a
143    ///   `SimulationError` otherwise.
144    fn ensure_capability(&self, capability: Capability) -> Result<(), SimulationError> {
145        if !self.capabilities.contains(&capability) {
146            return Err(SimulationError::FatalError(format!(
147                "capability {:?} not supported",
148                capability.to_string()
149            )));
150        }
151        Ok(())
152    }
153    /// Sets the spot prices for a pool for all possible pairs of the given tokens.
154    ///
155    /// # Arguments
156    ///
157    /// * `tokens` - A hashmap of `Token` instances representing the tokens to calculate spot prices
158    ///   for.
159    ///
160    /// # Returns
161    ///
162    /// * `Result<(), SimulationError>` - Returns `Ok(())` if the spot prices are successfully set,
163    ///   or a `SimulationError` if an error occurs during the calculation or processing.
164    ///
165    /// # Behavior
166    ///
167    /// This function performs the following steps:
168    /// 1. Ensures the pool has the required capability to perform price calculations.
169    /// 2. Iterates over all permutations of token pairs (sell token and buy token). For each pair:
170    ///    - Retrieves all possible overwrites, considering the maximum balance limit.
171    ///    - Calculates the sell amount limit, considering the overwrites.
172    ///    - Invokes the adapter contract's `price` function to retrieve the calculated price for
173    ///      the token pair, considering the sell amount limit.
174    ///    - Processes the price based on whether the `ScaledPrice` capability is present:
175    ///       - If `ScaledPrice` is present, uses the price directly from the adapter contract.
176    ///       - If `ScaledPrice` is absent, scales the price by adjusting for token decimals.
177    ///    - Stores the calculated price in the `spot_prices` map with the token addresses as the
178    ///      key.
179    /// 3. Returns `Ok(())` upon successful completion or a `SimulationError` upon failure.
180    ///
181    /// # Usage
182    ///
183    /// Spot prices need to be set before attempting to retrieve prices using `spot_price`.
184    ///
185    /// Tip: Setting spot prices on the pool every time the pool actually changes will result in
186    /// faster price fetching than if prices are only set immediately before attempting to retrieve
187    /// prices.
188    pub fn set_spot_prices(
189        &mut self,
190        tokens: &HashMap<Bytes, Token>,
191    ) -> Result<(), SimulationError> {
192        match self.ensure_capability(Capability::PriceFunction) {
193            Ok(_) => {
194                for [sell_token_address, buy_token_address] in self
195                    .tokens
196                    .iter()
197                    .permutations(2)
198                    .map(|p| [p[0], p[1]])
199                {
200                    let sell_token_address = bytes_to_address(sell_token_address)?;
201                    let buy_token_address = bytes_to_address(buy_token_address)?;
202
203                    let overwrites = Some(self.get_overwrites(
204                        vec![sell_token_address, buy_token_address],
205                        *MAX_BALANCE / U256::from(100),
206                    )?);
207
208                    let (sell_amount_limit, _) = self.get_amount_limits(
209                        vec![sell_token_address, buy_token_address],
210                        overwrites.clone(),
211                    )?;
212                    let price_result = self.adapter_contract.price(
213                        &self.id,
214                        sell_token_address,
215                        buy_token_address,
216                        vec![sell_amount_limit / U256::from(100)],
217                        overwrites,
218                    )?;
219
220                    let price = if self
221                        .capabilities
222                        .contains(&Capability::ScaledPrice)
223                    {
224                        *price_result.first().ok_or_else(|| {
225                            SimulationError::FatalError(
226                                "Calculated price array is empty".to_string(),
227                            )
228                        })?
229                    } else {
230                        let unscaled_price = price_result.first().ok_or_else(|| {
231                            SimulationError::FatalError(
232                                "Calculated price array is empty".to_string(),
233                            )
234                        })?;
235                        let sell_token_decimals = self.get_decimals(tokens, &sell_token_address)?;
236                        let buy_token_decimals = self.get_decimals(tokens, &buy_token_address)?;
237                        *unscaled_price * 10f64.powi(sell_token_decimals as i32) /
238                            10f64.powi(buy_token_decimals as i32)
239                    };
240
241                    self.spot_prices
242                        .insert((sell_token_address, buy_token_address), price);
243                }
244            }
245            Err(SimulationError::FatalError(_)) => {
246                // If the pool does not support price function, we need to calculate spot prices by
247                // swapping two amounts and use the approximation to get the derivative.
248
249                for iter_tokens in self.tokens.iter().permutations(2) {
250                    let t0 = bytes_to_address(iter_tokens[0])?;
251                    let t1 = bytes_to_address(iter_tokens[1])?;
252
253                    let overwrites =
254                        Some(self.get_overwrites(vec![t0, t1], *MAX_BALANCE / U256::from(100))?);
255
256                    // Calculate the first sell amount (x1) as 1% of the maximum limit.
257                    let x1 = self
258                        .get_amount_limits(vec![t0, t1], overwrites.clone())?
259                        .0 /
260                        U256::from(100);
261
262                    // Calculate the second sell amount (x2) as x1 + 1% of x1. 1.01% of the max
263                    // limit
264                    let x2 = x1 + (x1 / U256::from(100));
265
266                    // Perform a swap for the first sell amount (x1) and retrieve the received
267                    // amount (y1).
268                    let y1 = self
269                        .adapter_contract
270                        .swap(&self.id, t0, t1, false, x1, overwrites.clone())?
271                        .0
272                        .received_amount;
273
274                    // Perform a swap for the second sell amount (x2) and retrieve the received
275                    // amount (y2).
276                    let y2 = self
277                        .adapter_contract
278                        .swap(&self.id, t0, t1, false, x2, overwrites)?
279                        .0
280                        .received_amount;
281
282                    let sell_token_decimals = self.get_decimals(tokens, &t0)?;
283                    let buy_token_decimals = self.get_decimals(tokens, &t1)?;
284
285                    let num = y2 - y1;
286                    let den = x2 - x1;
287
288                    // Calculate the marginal price, adjusting for token decimals.
289                    let token_correction =
290                        10f64.powi(sell_token_decimals as i32 - buy_token_decimals as i32);
291                    let num_f64 = u256_to_f64(num)?;
292                    let den_f64 = u256_to_f64(den)?;
293                    if den_f64 == 0.0 {
294                        return Err(SimulationError::FatalError(
295                            "Failed to compute marginal price: denominator converted to 0".into(),
296                        ));
297                    }
298                    let marginal_price = num_f64 / den_f64 * token_correction;
299
300                    self.spot_prices
301                        .insert((t0, t1), marginal_price);
302                }
303            }
304            Err(e) => return Err(e),
305        }
306
307        Ok(())
308    }
309
310    fn get_decimals(
311        &self,
312        tokens: &HashMap<Bytes, Token>,
313        sell_token_address: &Address,
314    ) -> Result<usize, SimulationError> {
315        tokens
316            .get(&Bytes::from(sell_token_address.as_slice()))
317            .map(|t| t.decimals as usize)
318            .ok_or_else(|| {
319                SimulationError::FatalError(format!(
320                    "Failed to scale spot prices! Pool: {} Token 0x{:x} is not available!",
321                    self.id, sell_token_address
322                ))
323            })
324    }
325
326    /// Retrieves the sell and buy amount limit for a given pair of tokens and the given overwrites.
327    ///
328    /// Attempting to swap an amount of the sell token that exceeds the sell amount limit is not
329    /// advised and in most cases will result in a revert.
330    ///
331    /// # Arguments
332    ///
333    /// * `tokens` - A vec of tokens, where the first token is the sell token and the second is the
334    ///   buy token. The order of tokens in the input vector is significant and determines the
335    ///   direction of the price query.
336    /// * `overwrites` - A hashmap of overwrites to apply to the simulation.
337    ///
338    /// # Returns
339    ///
340    /// * `Result<(U256,U256), SimulationError>` - Returns the sell and buy amount limit as a `U256`
341    ///   if successful, or a `SimulationError` on failure.
342    fn get_amount_limits(
343        &self,
344        tokens: Vec<Address>,
345        overwrites: Option<HashMap<Address, HashMap<U256, U256>>>,
346    ) -> Result<(U256, U256), SimulationError> {
347        let limits = self
348            .adapter_contract
349            .get_limits(&self.id, tokens[0], tokens[1], overwrites)?;
350
351        Ok(limits)
352    }
353
354    /// Updates the pool state.
355    ///
356    /// It is assumed this is called on a new block. Therefore, first the pool's overwrites cache is
357    /// cleared, then the balances are updated and the spot prices are recalculated.
358    ///
359    /// # Arguments
360    ///
361    /// * `tokens` - A hashmap of token addresses to `Token` instances. This is necessary for
362    ///   calculating new spot prices.
363    /// * `balances` - A `Balances` instance containing all balance updates on the current block.
364    fn update_pool_state(
365        &mut self,
366        tokens: &HashMap<Bytes, Token>,
367        balances: &Balances,
368    ) -> Result<(), SimulationError> {
369        // clear cache
370        self.adapter_contract
371            .engine
372            .clear_temp_storage()
373            .map_err(|err| {
374                SimulationError::FatalError(format!("Failed to clear temporary storage: {err:?}",))
375            })?;
376        self.block_lasting_overwrites.clear();
377
378        // set balances
379        if !self.balances.is_empty() {
380            // Pool uses component balances for overwrites
381            if let Some(bals) = balances
382                .component_balances
383                .get(&self.id)
384            {
385                // Merge delta balances with existing balances instead of replacing them
386                // Prevents errors when delta balance changes do not affect all the pool tokens.
387                for (token, bal) in bals {
388                    let addr = bytes_to_address(token).map_err(|_| {
389                        SimulationError::FatalError(format!(
390                            "Invalid token address in balance update: {token:?}"
391                        ))
392                    })?;
393                    self.balances
394                        .insert(addr, U256::from_be_slice(bal));
395                }
396            }
397        } else {
398            // Pool uses contract balances for overwrites
399            for contract in &self.involved_contracts {
400                if let Some(bals) = balances
401                    .account_balances
402                    .get(&Bytes::from(contract.as_slice()))
403                {
404                    let contract_entry = self
405                        .contract_balances
406                        .entry(*contract)
407                        .or_default();
408                    for (token, bal) in bals {
409                        let addr = bytes_to_address(token).map_err(|_| {
410                            SimulationError::FatalError(format!(
411                                "Invalid token address in balance update: {token:?}"
412                            ))
413                        })?;
414                        contract_entry.insert(addr, U256::from_be_slice(bal));
415                    }
416                }
417            }
418        }
419
420        // reset spot prices
421        self.set_spot_prices(tokens)?;
422        Ok(())
423    }
424
425    fn get_overwrites(
426        &self,
427        tokens: Vec<Address>,
428        max_amount: U256,
429    ) -> Result<HashMap<Address, Overwrites>, SimulationError> {
430        let token_overwrites = self.get_token_overwrites(tokens, max_amount)?;
431
432        // Merge `block_lasting_overwrites` with `token_overwrites`
433        let merged_overwrites =
434            self.merge(&self.block_lasting_overwrites.clone(), &token_overwrites);
435
436        Ok(merged_overwrites)
437    }
438
439    fn get_token_overwrites(
440        &self,
441        tokens: Vec<Address>,
442        max_amount: U256,
443    ) -> Result<HashMap<Address, Overwrites>, SimulationError> {
444        let sell_token = &tokens[0].clone(); //TODO: need to make it clearer from the interface
445        let mut res: Vec<HashMap<Address, Overwrites>> = Vec::new();
446        if !self
447            .capabilities
448            .contains(&Capability::TokenBalanceIndependent)
449        {
450            res.push(self.get_balance_overwrites()?);
451        }
452
453        let mut overwrites = TokenProxyOverwriteFactory::new(*sell_token, None);
454
455        overwrites.set_balance(max_amount, Address::from_slice(&*EXTERNAL_ACCOUNT.0));
456
457        // Set allowance for adapter_address to max_amount
458        overwrites.set_allowance(max_amount, self.adapter_contract.address, *EXTERNAL_ACCOUNT);
459
460        res.push(overwrites.get_overwrites());
461
462        // Merge all overwrites into a single HashMap
463        Ok(res
464            .into_iter()
465            .fold(HashMap::new(), |acc, overwrite| self.merge(&acc, &overwrite)))
466    }
467
468    /// Gets all balance overwrites for the pool's tokens.
469    ///
470    /// If the pool uses component balances, the balances are set for the balance owner (if exists)
471    /// or for the pool itself. If the pool uses contract balances, the balances are set for the
472    /// contracts involved in the pool.
473    ///
474    /// # Returns
475    ///
476    /// * `Result<HashMap<Address, Overwrites>, SimulationError>` - Returns a hashmap of address to
477    ///   `Overwrites` if successful, or a `SimulationError` on failure.
478    fn get_balance_overwrites(&self) -> Result<HashMap<Address, Overwrites>, SimulationError> {
479        let mut balance_overwrites: HashMap<Address, Overwrites> = HashMap::new();
480
481        // Use component balances for overrides
482        let address = match self.balance_owner {
483            Some(owner) => Some(owner),
484            None if !self.contract_balances.is_empty() => None,
485            None => Some(self.id.parse().map_err(|_| {
486                SimulationError::FatalError(
487                    "Failed to get balance overwrites: Pool ID is not an address".into(),
488                )
489            })?),
490        };
491
492        if let Some(address) = address {
493            // Only override balances that are explicitly provided in self.balances
494            // This preserves existing balances for tokens not updated in delta transitions
495            for (token, bal) in &self.balances {
496                let mut overwrites = TokenProxyOverwriteFactory::new(*token, None);
497                overwrites.set_balance(*bal, address);
498                balance_overwrites.extend(overwrites.get_overwrites());
499            }
500        }
501
502        // Use contract balances for overrides (will overwrite component balances if they were set
503        // for a contract we explicitly track balances for)
504        for (contract, balances) in &self.contract_balances {
505            for (token, balance) in balances {
506                let mut overwrites = TokenProxyOverwriteFactory::new(*token, None);
507                overwrites.set_balance(*balance, *contract);
508                balance_overwrites.extend(overwrites.get_overwrites());
509            }
510        }
511
512        // Apply disables for tokens that should not have any balance overrides
513        for token in &self.disable_overwrite_tokens {
514            balance_overwrites.remove(token);
515        }
516
517        Ok(balance_overwrites)
518    }
519
520    fn merge(
521        &self,
522        target: &HashMap<Address, Overwrites>,
523        source: &HashMap<Address, Overwrites>,
524    ) -> HashMap<Address, Overwrites> {
525        let mut merged = target.clone();
526
527        for (key, source_inner) in source {
528            merged
529                .entry(*key)
530                .or_default()
531                .extend(source_inner.clone());
532        }
533
534        merged
535    }
536
537    #[cfg(test)]
538    pub fn get_involved_contracts(&self) -> HashSet<Address> {
539        self.involved_contracts.clone()
540    }
541
542    #[cfg(test)]
543    pub fn get_manual_updates(&self) -> bool {
544        self.manual_updates
545    }
546
547    #[cfg(test)]
548    #[deprecated]
549    pub fn get_balance_owner(&self) -> Option<Address> {
550        self.balance_owner
551    }
552
553    /// Get the component balances for validation purposes
554    pub fn get_balances(&self) -> &HashMap<Address, U256> {
555        &self.balances
556    }
557}
558
559impl<D> Serialize for EVMPoolState<D>
560where
561    D: EngineDatabaseInterface + Clone + Debug,
562    <D as DatabaseRef>::Error: Debug,
563    <D as EngineDatabaseInterface>::Error: Debug,
564{
565    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
566    where
567        S: serde::Serializer,
568    {
569        Err(serde::ser::Error::custom("not supported due vm state deps"))
570    }
571}
572
573impl<'de, D> Deserialize<'de> for EVMPoolState<D>
574where
575    D: EngineDatabaseInterface + Clone + Debug,
576    <D as DatabaseRef>::Error: Debug,
577    <D as EngineDatabaseInterface>::Error: Debug,
578{
579    fn deserialize<De>(_deserializer: De) -> Result<Self, De::Error>
580    where
581        De: serde::Deserializer<'de>,
582    {
583        Err(serde::de::Error::custom("not supported due vm state deps"))
584    }
585}
586
587#[typetag::serialize]
588impl<D> ProtocolSim for EVMPoolState<D>
589where
590    D: EngineDatabaseInterface + Clone + Debug + 'static,
591    <D as DatabaseRef>::Error: Debug,
592    <D as EngineDatabaseInterface>::Error: Debug,
593{
594    fn fee(&self) -> f64 {
595        todo!()
596    }
597
598    fn spot_price(&self, base: &Token, quote: &Token) -> Result<f64, SimulationError> {
599        let base_address = bytes_to_address(&base.address)?;
600        let quote_address = bytes_to_address(&quote.address)?;
601        self.spot_prices
602            .get(&(base_address, quote_address))
603            .cloned()
604            .ok_or(SimulationError::FatalError(format!(
605                "Spot price not found for base token {base_address} and quote token {quote_address}"
606            )))
607    }
608
609    fn get_amount_out(
610        &self,
611        amount_in: BigUint,
612        token_in: &Token,
613        token_out: &Token,
614    ) -> Result<GetAmountOutResult, SimulationError> {
615        let sell_token_address = bytes_to_address(&token_in.address)?;
616        let buy_token_address = bytes_to_address(&token_out.address)?;
617        let sell_amount = U256::from_be_slice(&amount_in.to_bytes_be());
618        let overwrites = self.get_overwrites(
619            vec![sell_token_address, buy_token_address],
620            *MAX_BALANCE / U256::from(100),
621        )?;
622        let (sell_amount_limit, _) = self.get_amount_limits(
623            vec![sell_token_address, buy_token_address],
624            Some(overwrites.clone()),
625        )?;
626        let (sell_amount_respecting_limit, sell_amount_exceeds_limit) = if self
627            .capabilities
628            .contains(&Capability::HardLimits) &&
629            sell_amount_limit < sell_amount
630        {
631            (sell_amount_limit, true)
632        } else {
633            (sell_amount, false)
634        };
635
636        let overwrites_with_sell_limit =
637            self.get_overwrites(vec![sell_token_address, buy_token_address], sell_amount_limit)?;
638        let complete_overwrites = self.merge(&overwrites, &overwrites_with_sell_limit);
639
640        let (trade, state_changes) = self.adapter_contract.swap(
641            &self.id,
642            sell_token_address,
643            buy_token_address,
644            false,
645            sell_amount_respecting_limit,
646            Some(complete_overwrites),
647        )?;
648
649        let mut new_state = self.clone();
650
651        // Apply state changes to the new state
652        for (address, state_update) in state_changes {
653            if let Some(storage) = state_update.storage {
654                let block_overwrites = new_state
655                    .block_lasting_overwrites
656                    .entry(address)
657                    .or_default();
658                for (slot, value) in storage {
659                    let slot = U256::from_str(&slot.to_string()).map_err(|_| {
660                        SimulationError::FatalError("Failed to decode slot index".to_string())
661                    })?;
662                    let value = U256::from_str(&value.to_string()).map_err(|_| {
663                        SimulationError::FatalError("Failed to decode slot overwrite".to_string())
664                    })?;
665                    block_overwrites.insert(slot, value);
666                }
667            }
668        }
669
670        // Update spot prices
671        let tokens = HashMap::from([
672            (token_in.address.clone(), token_in.clone()),
673            (token_out.address.clone(), token_out.clone()),
674        ]);
675        let _ = new_state.set_spot_prices(&tokens);
676
677        let buy_amount = trade.received_amount;
678
679        if sell_amount_exceeds_limit {
680            return Err(SimulationError::InvalidInput(
681                format!("Sell amount exceeds limit {sell_amount_limit}"),
682                Some(GetAmountOutResult::new(
683                    u256_to_biguint(buy_amount),
684                    u256_to_biguint(trade.gas_used),
685                    Box::new(new_state.clone()),
686                )),
687            ));
688        }
689        Ok(GetAmountOutResult::new(
690            u256_to_biguint(buy_amount),
691            u256_to_biguint(trade.gas_used),
692            Box::new(new_state.clone()),
693        ))
694    }
695
696    fn get_limits(
697        &self,
698        sell_token: Bytes,
699        buy_token: Bytes,
700    ) -> Result<(BigUint, BigUint), SimulationError> {
701        let sell_token = bytes_to_address(&sell_token)?;
702        let buy_token = bytes_to_address(&buy_token)?;
703        let overwrites =
704            self.get_overwrites(vec![sell_token, buy_token], *MAX_BALANCE / U256::from(100))?;
705        let limits = self.get_amount_limits(vec![sell_token, buy_token], Some(overwrites))?;
706        Ok((u256_to_biguint(limits.0), u256_to_biguint(limits.1)))
707    }
708
709    fn delta_transition(
710        &mut self,
711        delta: ProtocolStateDelta,
712        tokens: &HashMap<Bytes, Token>,
713        balances: &Balances,
714    ) -> Result<(), TransitionError> {
715        if self.manual_updates {
716            // Directly check for "update_marker" in `updated_attributes`
717            if let Some(marker) = delta
718                .updated_attributes
719                .get("update_marker")
720            {
721                // Assuming `marker` is of type `Bytes`, check its value for "truthiness"
722                if !marker.is_empty() && marker[0] != 0 {
723                    self.update_pool_state(tokens, balances)?;
724                }
725            }
726        } else {
727            self.update_pool_state(tokens, balances)?;
728        }
729
730        Ok(())
731    }
732
733    fn query_pool_swap(
734        &self,
735        params: &tycho_common::simulation::protocol_sim::QueryPoolSwapParams,
736    ) -> Result<tycho_common::simulation::protocol_sim::PoolSwap, SimulationError> {
737        crate::evm::query_pool_swap::query_pool_swap(self, params)
738    }
739
740    fn clone_box(&self) -> Box<dyn ProtocolSim> {
741        Box::new(self.clone())
742    }
743
744    fn as_any(&self) -> &dyn Any {
745        self
746    }
747
748    fn as_any_mut(&mut self) -> &mut dyn Any {
749        self
750    }
751
752    fn eq(&self, other: &dyn ProtocolSim) -> bool {
753        if let Some(other_state) = other
754            .as_any()
755            .downcast_ref::<EVMPoolState<PreCachedDB>>()
756        {
757            self.id == other_state.id
758        } else {
759            false
760        }
761    }
762
763    /// Implemented manually because `typetag` macro not supports generics
764    fn typetag_deserialize(&self) {
765        // https://github.com/dtolnay/typetag/blob/21ae0d40c9f73443a20204ab4a134441355b52f7/impl/src/tagged_trait.rs#L140
766        unreachable!("Only to catch missing typetag attribute on impl blocks. Not called.")
767    }
768}
769
#[cfg(test)]
mod tests {
    use std::default::Default;

    use num_traits::One;
    use revm::{
        primitives::KECCAK_EMPTY,
        state::{AccountInfo, Bytecode},
    };
    use serde_json::Value;
    use tycho_client::feed::BlockHeader;
    use tycho_common::models::Chain;

    use super::*;
    use crate::evm::{
        engine_db::{create_engine, SHARED_TYCHO_DB},
        protocol::vm::{
            constants::{BALANCER_V2, ERC20_PROXY_BYTECODE},
            state_builder::EVMPoolStateBuilder,
        },
        simulation::SimulationEngine,
        tycho_models::AccountUpdate,
    };

    /// DAI token fixture on Ethereum (`0x6b17…1d0f`).
    fn dai() -> Token {
        Token::new(
            &Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap(),
            "DAI",
            18,
            0,
            &[Some(10_000)],
            Chain::Ethereum,
            100,
        )
    }

    /// BAL token fixture on Ethereum (`0xba10…4e3d`).
    fn bal() -> Token {
        Token::new(
            &Bytes::from_str("0xba100000625a3754423978a60c9317c58a424e3d").unwrap(),
            "BAL",
            18,
            0,
            &[Some(10_000)],
            Chain::Ethereum,
            100,
        )
    }

    /// Address of the `dai()` fixture as an alloy `Address`.
    fn dai_addr() -> Address {
        bytes_to_address(&dai().address).unwrap()
    }

    /// Address of the `bal()` fixture as an alloy `Address`.
    fn bal_addr() -> Address {
        bytes_to_address(&bal().address).unwrap()
    }

    /// Builds the DAI/BAL pool state used by the tests in this module.
    ///
    /// Steps:
    /// 1. Parse the account snapshots in `assets/balancer_contract_storage_block_20463609.json`.
    /// 2. Initialize each snapshot account in a `SimulationEngine` created over
    ///    `SHARED_TYCHO_DB`, then push the accounts into the DB together with block 20463609.
    /// 3. Initialize both pool tokens as accounts whose code is `ERC20_PROXY_BYTECODE`.
    /// 4. Call `db.update` again with no account updates and block 18485417.
    /// 5. Build the pool with `EVMPoolStateBuilder`, using fixed component balances, a
    ///    `balance_owner`, the `BALANCER_V2` adapter bytecode and one stateless contract.
    ///
    /// Panics if any of these fixture steps fails.
    async fn setup_pool_state() -> EVMPoolState<PreCachedDB> {
        let data_str = include_str!("assets/balancer_contract_storage_block_20463609.json");
        let data: Value = serde_json::from_str(data_str).expect("Failed to parse JSON");

        let accounts: Vec<AccountUpdate> = serde_json::from_value(data["accounts"].clone())
            .expect("Expected accounts to match AccountUpdate structure");

        let db = SHARED_TYCHO_DB.clone();
        let engine: SimulationEngine<_> = create_engine(db.clone(), false).unwrap();

        let block = BlockHeader {
            number: 20463609,
            hash: Bytes::from_str(
                "0x4315fd1afc25cc2ebc72029c543293f9fd833eeb305e2e30159459c827733b1b",
            )
            .unwrap(),
            timestamp: 1722875891,
            ..Default::default()
        };

        // Register every snapshot account (balance and optional code) with the engine.
        // `accounts` is cloned here because it is moved into `db.update` right after.
        for account in accounts.clone() {
            engine
                .state
                .init_account(
                    account.address,
                    AccountInfo {
                        balance: account.balance.unwrap_or_default(),
                        nonce: 0u64,
                        code_hash: KECCAK_EMPTY,
                        code: account
                            .code
                            .clone()
                            .map(|code: Vec<u8>| Bytecode::new_raw(code.into())),
                    },
                    None,
                    false,
                )
                .expect("Failed to initialize account");
        }
        db.update(accounts, Some(block))
            .unwrap();

        // Install the ERC20 proxy bytecode as the code of both pool tokens.
        let tokens = vec![dai().address, bal().address];
        for token in &tokens {
            engine
                .state
                .init_account(
                    bytes_to_address(token).unwrap(),
                    AccountInfo {
                        balance: U256::from(0),
                        nonce: 0,
                        code_hash: KECCAK_EMPTY,
                        code: Some(Bytecode::new_raw(ERC20_PROXY_BYTECODE.into())),
                    },
                    None,
                    true,
                )
                .expect("Failed to initialize account");
        }

        // NOTE(review): block 18485417 is older than the 20463609 snapshot loaded above —
        // confirm this second header is intentional.
        let block = BlockHeader {
            number: 18485417,
            hash: Bytes::from_str(
                "0x28d41d40f2ac275a4f5f621a636b9016b527d11d37d610a45ac3a821346ebf8c",
            )
            .expect("Invalid block hash"),
            timestamp: 0,
            ..Default::default()
        };
        db.update(vec![], Some(block))
            .unwrap();

        let pool_id: String =
            "0x4626d81b3a1711beb79f4cecff2413886d461677000200000000000000000011".into();

        let stateless_contracts = HashMap::from([(
            String::from("0x3de27efa2f1aa663ae5d458857e731c129069f29"),
            Some(Vec::new()),
        )]);

        let balances = HashMap::from([
            (dai_addr(), U256::from_str("178754012737301807104").unwrap()),
            (bal_addr(), U256::from_str("91082987763369885696").unwrap()),
        ]);
        let adapter_address =
            Address::from_str("0xA2C5C98A892fD6656a7F39A2f63228C0Bc846270").unwrap();

        EVMPoolStateBuilder::new(pool_id, tokens, adapter_address)
            .balances(balances)
            .balance_owner(Address::from_str("0xBA12222222228d8Ba445958a75a0704d566BF2C8").unwrap())
            .adapter_contract_bytecode(Bytecode::new_raw(BALANCER_V2.into()))
            .stateless_contracts(stateless_contracts)
            .build(SHARED_TYCHO_DB.clone())
            .await
            .expect("Failed to build pool state")
    }

    #[tokio::test]
    async fn test_init() {
        // Clear the shared DB first so state left behind by other tests cannot interfere.
        SHARED_TYCHO_DB
            .clear()
            .expect("Failed to clear SHARED_TYCHO_DB");
        let pool_state = setup_pool_state().await;

        let expected_capabilities = HashSet::from([
            Capability::SellSide,
            Capability::BuySide,
            Capability::PriceFunction,
            Capability::HardLimits,
        ]);

        // Capabilities reported by the adapter contract...
        let capabilities_adapter_contract = pool_state
            .adapter_contract
            .get_capabilities(
                &pool_state.id,
                bytes_to_address(&pool_state.tokens[0]).unwrap(),
                bytes_to_address(&pool_state.tokens[1]).unwrap(),
            )
            .unwrap();
        assert_eq!(capabilities_adapter_contract, expected_capabilities);

        // ...must match the capabilities stored on the state itself.
        assert_eq!(pool_state.capabilities, expected_capabilities);

        for capability in expected_capabilities {
            assert!(pool_state
                .clone()
                .ensure_capability(capability)
                .is_ok());
        }

        assert!(pool_state
            .clone()
            .ensure_capability(Capability::MarginalPrice)
            .is_err());

        // Every pool token must be initialized in the engine with zero balance and nonce,
        // an empty code hash and some code attached.
        let engine_accounts = pool_state
            .adapter_contract
            .engine
            .state
            .clone()
            .get_account_storage()
            .expect("Failed to get account storage");
        for token in &pool_state.tokens {
            let account = engine_accounts
                .get_account_info(&bytes_to_address(token).unwrap())
                .unwrap();
            assert_eq!(account.balance, U256::from(0));
            assert_eq!(account.nonce, 0u64);
            assert_eq!(account.code_hash, KECCAK_EMPTY);
            assert!(account.code.is_some());
        }

        // The external account must be initialized with MAX_BALANCE and no code.
        let external_account = engine_accounts
            .get_account_info(&EXTERNAL_ACCOUNT)
            .unwrap();
        assert_eq!(external_account.balance, U256::from(*MAX_BALANCE));
        assert_eq!(external_account.nonce, 0u64);
        assert_eq!(external_account.code_hash, KECCAK_EMPTY);
        assert!(external_account.code.is_none());
    }

    #[tokio::test]
    async fn test_get_amount_out() -> Result<(), Box<dyn std::error::Error>> {
        let pool_state = setup_pool_state().await;

        let result = pool_state
            .get_amount_out(BigUint::from_str("1000000000000000000").unwrap(), &dai(), &bal())
            .unwrap();
        let new_state = result
            .new_state
            .as_any()
            .downcast_ref::<EVMPoolState<PreCachedDB>>()
            .unwrap();
        assert_eq!(result.amount, BigUint::from_str("137780051463393923").unwrap());
        // The swap produces a new state with different spot prices...
        assert_ne!(new_state.spot_prices, pool_state.spot_prices);
        // ...while the original state keeps no block-lasting overwrites.
        assert!(pool_state
            .block_lasting_overwrites
            .is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn test_sequential_get_amount_outs() {
        let pool_state = setup_pool_state().await;

        let result = pool_state
            .get_amount_out(BigUint::from_str("1000000000000000000").unwrap(), &dai(), &bal())
            .unwrap();
        let new_state = result
            .new_state
            .as_any()
            .downcast_ref::<EVMPoolState<PreCachedDB>>()
            .unwrap();
        assert_eq!(result.amount, BigUint::from_str("137780051463393923").unwrap());
        assert_ne!(new_state.spot_prices, pool_state.spot_prices);

        // Swapping the same amount again on the resulting state yields less output.
        let new_result = new_state
            .get_amount_out(BigUint::from_str("1000000000000000000").unwrap(), &dai(), &bal())
            .unwrap();
        let new_state_second_swap = new_result
            .new_state
            .as_any()
            .downcast_ref::<EVMPoolState<PreCachedDB>>()
            .unwrap();

        assert_eq!(new_result.amount, BigUint::from_str("136964651490065626").unwrap());
        assert_ne!(new_state_second_swap.spot_prices, new_state.spot_prices);
    }

    #[tokio::test]
    async fn test_get_amount_out_dust() {
        let pool_state = setup_pool_state().await;

        let result = pool_state
            .get_amount_out(BigUint::one(), &dai(), &bal())
            .unwrap();

        // The new state must still downcast to the concrete type even for a dust swap.
        let _ = result
            .new_state
            .as_any()
            .downcast_ref::<EVMPoolState<PreCachedDB>>()
            .unwrap();
        assert_eq!(result.amount, BigUint::ZERO);
    }

    #[tokio::test]
    async fn test_get_amount_out_sell_limit() {
        let pool_state = setup_pool_state().await;

        let result = pool_state.get_amount_out(
            // sell limit is 100279494253364362835
            BigUint::from_str("100379494253364362835").unwrap(),
            &dai(),
            &bal(),
        );

        match result {
            Err(SimulationError::InvalidInput(msg, amount_out_result)) => {
                assert_eq!(msg, "Sell amount exceeds limit 100279494253364362835");
                // A partial result is still returned alongside the error.
                assert!(amount_out_result.is_some());
            }
            _ => panic!(
                "Test failed: was expecting an Err(SimulationError::InvalidInput(_, _)) value"
            ),
        }
    }

    #[tokio::test]
    async fn test_get_amount_limits() {
        let pool_state = setup_pool_state().await;

        let overwrites = pool_state
            .get_overwrites(
                vec![
                    bytes_to_address(&pool_state.tokens[0]).unwrap(),
                    bytes_to_address(&pool_state.tokens[1]).unwrap(),
                ],
                *MAX_BALANCE / U256::from(100),
            )
            .unwrap();
        let (dai_limit, _) = pool_state
            .get_amount_limits(vec![dai_addr(), bal_addr()], Some(overwrites.clone()))
            .unwrap();
        assert_eq!(dai_limit, U256::from_str("100279494253364362835").unwrap());

        let (bal_limit, _) = pool_state
            .get_amount_limits(
                vec![
                    bytes_to_address(&pool_state.tokens[1]).unwrap(),
                    bytes_to_address(&pool_state.tokens[0]).unwrap(),
                ],
                Some(overwrites),
            )
            .unwrap();
        assert_eq!(bal_limit, U256::from_str("13997408640689987484").unwrap());
    }

    #[tokio::test]
    async fn test_set_spot_prices() {
        let mut pool_state = setup_pool_state().await;

        pool_state
            .set_spot_prices(
                &vec![bal(), dai()]
                    .into_iter()
                    .map(|t| (t.address.clone(), t))
                    .collect(),
            )
            .unwrap();

        let dai_bal_spot_price = pool_state
            .spot_prices
            .get(&(
                bytes_to_address(&pool_state.tokens[0]).unwrap(),
                bytes_to_address(&pool_state.tokens[1]).unwrap(),
            ))
            .unwrap();
        let bal_dai_spot_price = pool_state
            .spot_prices
            .get(&(
                bytes_to_address(&pool_state.tokens[1]).unwrap(),
                bytes_to_address(&pool_state.tokens[0]).unwrap(),
            ))
            .unwrap();
        assert_eq!(dai_bal_spot_price, &0.137_778_914_319_047_9);
        assert_eq!(bal_dai_spot_price, &7.071_503_245_428_246);
    }

    #[tokio::test]
    async fn test_set_spot_prices_without_capability() {
        // Tests set_spot_prices when the pool doesn't have the PriceFunction capability
        let mut pool_state = setup_pool_state().await;

        pool_state
            .capabilities
            .remove(&Capability::PriceFunction);

        pool_state
            .set_spot_prices(
                &vec![bal(), dai()]
                    .into_iter()
                    .map(|t| (t.address.clone(), t))
                    .collect(),
            )
            .unwrap();

        let dai_bal_spot_price = pool_state
            .spot_prices
            .get(&(
                bytes_to_address(&pool_state.tokens[0]).unwrap(),
                bytes_to_address(&pool_state.tokens[1]).unwrap(),
            ))
            .unwrap();
        let bal_dai_spot_price = pool_state
            .spot_prices
            .get(&(
                bytes_to_address(&pool_state.tokens[1]).unwrap(),
                bytes_to_address(&pool_state.tokens[0]).unwrap(),
            ))
            .unwrap();
        assert_eq!(dai_bal_spot_price, &0.13736685496467538);
        assert_eq!(bal_dai_spot_price, &7.050354297665408);
    }

    #[tokio::test]
    async fn test_get_balance_overwrites_with_component_balances() {
        let pool_state: EVMPoolState<PreCachedDB> = setup_pool_state().await;

        let overwrites = pool_state
            .get_balance_overwrites()
            .unwrap();

        let dai_address = dai_addr();
        let bal_address = bal_addr();
        assert!(overwrites.contains_key(&dai_address));
        assert!(overwrites.contains_key(&bal_address));
    }

    #[tokio::test]
    async fn test_get_balance_overwrites_with_contract_balances() {
        let mut pool_state: EVMPoolState<PreCachedDB> = setup_pool_state().await;

        let contract_address =
            Address::from_str("0xBA12222222228d8Ba445958a75a0704d566BF2C8").unwrap();

        // Ensure no component balances are used
        pool_state.balances.clear();
        pool_state.balance_owner = None;

        // Set contract balances
        let dai_address = dai_addr();
        let bal_address = bal_addr();
        pool_state.contract_balances = HashMap::from([(
            contract_address,
            HashMap::from([
                (dai_address, U256::from_str("7500000000000000000000").unwrap()), // 7500 DAI
                (bal_address, U256::from_str("1500000000000000000000").unwrap()), // 1500 BAL
            ]),
        )]);

        let overwrites = pool_state
            .get_balance_overwrites()
            .unwrap();

        assert!(overwrites.contains_key(&dai_address));
        assert!(overwrites.contains_key(&bal_address));
    }

    #[tokio::test]
    async fn test_balance_merging_during_delta_transition() {
        let mut pool_state = setup_pool_state().await;
        let pool_id = pool_state.id.clone();

        // Test the balance merging logic more directly.
        // Set up initial balances for DAI and BAL (which the pool already knows about) plus
        // an extra token.
        let dai_addr = dai_addr();
        let bal_addr = bal_addr();
        let new_token = Address::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(); // WETH

        // Clear and set up a clean initial state
        pool_state.balances.clear();
        pool_state
            .balances
            .insert(dai_addr, U256::from(1000000000u64));
        pool_state
            .balances
            .insert(bal_addr, U256::from(2000000000u64));
        pool_state
            .balances
            .insert(new_token, U256::from(3000000000u64));

        // Create tokens mapping including the existing DAI and BAL
        let mut tokens = HashMap::new();
        tokens.insert(dai().address.clone(), dai());
        tokens.insert(bal().address.clone(), bal());

        // Simulate a delta transition with only a DAI balance update (BAL and new_token missing)
        let mut component_balances = HashMap::new();
        let mut delta_balances = HashMap::new();
        // Only update the DAI balance, leave the others unchanged in the delta
        delta_balances.insert(dai().address.clone(), Bytes::from(vec![0x77, 0x35, 0x94, 0x00])); // 2000000000 (updated value)
        component_balances.insert(pool_id.clone(), delta_balances);

        let balances = Balances { component_balances, account_balances: HashMap::new() };

        // Record initial balance count
        let initial_balance_count = pool_state.balances.len();
        assert_eq!(initial_balance_count, 3);

        // Apply delta transition
        pool_state
            .update_pool_state(&tokens, &balances)
            .unwrap();

        // Verify that all 3 balances are preserved (BAL and new_token should still be there)
        assert_eq!(
            pool_state.balances.len(),
            3,
            "All balances should be preserved after delta transition"
        );
        assert!(
            pool_state
                .balances
                .contains_key(&dai_addr),
            "DAI balance should be present"
        );
        assert!(
            pool_state
                .balances
                .contains_key(&bal_addr),
            "BAL balance should be present"
        );
        assert!(
            pool_state
                .balances
                .contains_key(&new_token),
            "New token balance should be preserved from before delta"
        );

        // Verify that the updated token (DAI) has the new value
        assert_eq!(
            pool_state.balances[&dai_addr],
            U256::from(2000000000u64),
            "DAI balance should be updated"
        );

        // Verify that non-updated tokens retain their original values
        assert_eq!(
            pool_state.balances[&bal_addr],
            U256::from(2000000000u64),
            "BAL balance should be unchanged"
        );
        assert_eq!(
            pool_state.balances[&new_token],
            U256::from(3000000000u64),
            "New token balance should be unchanged"
        );
    }

    #[test]
    fn should_not_panic_at_typetag_deserialize() {
        // Deserializing a `Box<dyn ProtocolSim>` tagged "EVMPoolState" must yield an error
        // rather than panic.
        let deserialized: Result<Box<dyn ProtocolSim>, _> = serde_json::from_str(
            r#"{"protocol":"EVMPoolState","state":{"reserve_0":1,"reserve_1":2}}"#,
        );

        assert!(deserialized.is_err());
    }
}