Skip to main content

tycho_simulation/evm/protocol/vm/
state.rs

1#![allow(deprecated)]
2use std::{
3    any::Any,
4    collections::{HashMap, HashSet},
5    fmt::{self, Debug},
6    str::FromStr,
7};
8
9use alloy::primitives::{Address, U256};
10use itertools::Itertools;
11use num_bigint::BigUint;
12use revm::DatabaseRef;
13use serde::{Deserialize, Serialize};
14use tycho_common::{
15    dto::ProtocolStateDelta,
16    models::token::Token,
17    simulation::{
18        errors::{SimulationError, TransitionError},
19        protocol_sim::{Balances, GetAmountOutResult, ProtocolSim},
20    },
21    Bytes,
22};
23
24use super::{
25    constants::{EXTERNAL_ACCOUNT, MAX_BALANCE},
26    erc20_token::{Overwrites, TokenProxyOverwriteFactory},
27    models::Capability,
28    tycho_simulation_contract::TychoSimulationContract,
29};
30use crate::evm::{
31    engine_db::{engine_db_interface::EngineDatabaseInterface, tycho_db::PreCachedDB},
32    protocol::{
33        u256_num::{u256_to_biguint, u256_to_f64},
34        utils::bytes_to_address,
35    },
36};
37
38#[derive(Clone)]
39pub struct EVMPoolState<D: EngineDatabaseInterface + Clone + Debug>
40where
41    <D as DatabaseRef>::Error: Debug,
42    <D as EngineDatabaseInterface>::Error: Debug,
43{
44    /// The pool's identifier
45    id: String,
46    /// The pool's token's addresses
47    pub tokens: Vec<Bytes>,
48    /// The pool's component balances.
49    balances: HashMap<Address, U256>,
50    /// The contract address for where protocol balances are stored (i.e. a vault contract).
51    /// If given, balances will be overwritten here instead of on the pool contract during
52    /// simulations. This has been deprecated in favor of `contract_balances`.
53    #[deprecated(note = "Use contract_balances instead")]
54    balance_owner: Option<Address>,
55    /// Spot prices of the pool by token pair
56    spot_prices: HashMap<(Address, Address), f64>,
57    /// The supported capabilities of this pool
58    capabilities: HashSet<Capability>,
59    /// Storage overwrites that will be applied to all simulations. They will be cleared
60    /// when ``update_pool_state`` is called, i.e. usually at each block. Hence, the name.
61    block_lasting_overwrites: HashMap<Address, Overwrites>,
62    /// A set of all contract addresses involved in the simulation of this pool.
63    involved_contracts: HashSet<Address>,
64    /// A map of contracts to their token balances.
65    contract_balances: HashMap<Address, HashMap<Address, U256>>,
66    /// Indicates if the protocol uses custom update rules and requires update
67    /// triggers to recalculate spot prices ect. Default is to update on all changes on
68    /// the pool.
69    manual_updates: bool,
70    /// The adapter contract. This is used to interact with the protocol when running simulations
71    adapter_contract: TychoSimulationContract<D>,
72    /// Tokens for which balance overwrites should be disabled.
73    disable_overwrite_tokens: HashSet<Address>,
74}
75
76impl<D> Debug for EVMPoolState<D>
77where
78    D: EngineDatabaseInterface + Clone + Debug,
79    <D as DatabaseRef>::Error: Debug,
80    <D as EngineDatabaseInterface>::Error: Debug,
81{
82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83        f.debug_struct("EVMPoolState")
84            .field("id", &self.id)
85            .field("tokens", &self.tokens)
86            .field("balances", &self.balances)
87            .field("involved_contracts", &self.involved_contracts)
88            .field("contract_balances", &self.contract_balances)
89            .finish_non_exhaustive()
90    }
91}
92
93impl<D> EVMPoolState<D>
94where
95    D: EngineDatabaseInterface + Clone + Debug + 'static,
96    <D as DatabaseRef>::Error: Debug,
97    <D as EngineDatabaseInterface>::Error: Debug,
98{
99    /// Creates a new instance of `EVMPoolState` with the given attributes, with the ability to
100    /// simulate a protocol-agnostic transaction.
101    ///
102    /// See struct definition of `EVMPoolState` for attribute explanations.
103    #[allow(clippy::too_many_arguments)]
104    pub fn new(
105        id: String,
106        tokens: Vec<Bytes>,
107        component_balances: HashMap<Address, U256>,
108        balance_owner: Option<Address>,
109        contract_balances: HashMap<Address, HashMap<Address, U256>>,
110        spot_prices: HashMap<(Address, Address), f64>,
111        capabilities: HashSet<Capability>,
112        block_lasting_overwrites: HashMap<Address, Overwrites>,
113        involved_contracts: HashSet<Address>,
114        manual_updates: bool,
115        adapter_contract: TychoSimulationContract<D>,
116        disable_overwrite_tokens: HashSet<Address>,
117    ) -> Self {
118        Self {
119            id,
120            tokens,
121            balances: component_balances,
122            balance_owner,
123            spot_prices,
124            capabilities,
125            block_lasting_overwrites,
126            involved_contracts,
127            contract_balances,
128            manual_updates,
129            adapter_contract,
130            disable_overwrite_tokens,
131        }
132    }
133
134    /// Ensures the pool supports the given capability
135    ///
136    /// # Arguments
137    ///
138    /// * `capability` - The capability that we would like to check for.
139    ///
140    /// # Returns
141    ///
142    /// * `Result<(), SimulationError>` - Returns `Ok(())` if the capability is supported, or a
143    ///   `SimulationError` otherwise.
144    fn ensure_capability(&self, capability: Capability) -> Result<(), SimulationError> {
145        if !self.capabilities.contains(&capability) {
146            return Err(SimulationError::FatalError(format!(
147                "capability {:?} not supported",
148                capability.to_string()
149            )));
150        }
151        Ok(())
152    }
153    /// Sets the spot prices for a pool for all possible pairs of the given tokens.
154    ///
155    /// # Arguments
156    ///
157    /// * `tokens` - A hashmap of `Token` instances representing the tokens to calculate spot prices
158    ///   for.
159    ///
160    /// # Returns
161    ///
162    /// * `Result<(), SimulationError>` - Returns `Ok(())` if the spot prices are successfully set,
163    ///   or a `SimulationError` if an error occurs during the calculation or processing.
164    ///
165    /// # Behavior
166    ///
167    /// This function performs the following steps:
168    /// 1. Ensures the pool has the required capability to perform price calculations.
169    /// 2. Iterates over all permutations of token pairs (sell token and buy token). For each pair:
170    ///    - Retrieves all possible overwrites, considering the maximum balance limit.
171    ///    - Calculates the sell amount limit, considering the overwrites.
172    ///    - Invokes the adapter contract's `price` function to retrieve the calculated price for
173    ///      the token pair, considering the sell amount limit.
174    ///    - Processes the price based on whether the `ScaledPrice` capability is present:
175    ///       - If `ScaledPrice` is present, uses the price directly from the adapter contract.
176    ///       - If `ScaledPrice` is absent, scales the price by adjusting for token decimals.
177    ///    - Stores the calculated price in the `spot_prices` map with the token addresses as the
178    ///      key.
179    /// 3. Returns `Ok(())` upon successful completion or a `SimulationError` upon failure.
180    ///
181    /// # Usage
182    ///
183    /// Spot prices need to be set before attempting to retrieve prices using `spot_price`.
184    ///
185    /// Tip: Setting spot prices on the pool every time the pool actually changes will result in
186    /// faster price fetching than if prices are only set immediately before attempting to retrieve
187    /// prices.
188    pub fn set_spot_prices(
189        &mut self,
190        tokens: &HashMap<Bytes, Token>,
191    ) -> Result<(), SimulationError> {
192        match self.ensure_capability(Capability::PriceFunction) {
193            Ok(_) => {
194                for [sell_token_address, buy_token_address] in self
195                    .tokens
196                    .iter()
197                    .permutations(2)
198                    .map(|p| [p[0], p[1]])
199                {
200                    let sell_token_address = bytes_to_address(sell_token_address)?;
201                    let buy_token_address = bytes_to_address(buy_token_address)?;
202
203                    let overwrites = Some(self.get_overwrites(
204                        vec![sell_token_address, buy_token_address],
205                        *MAX_BALANCE / U256::from(100),
206                    )?);
207
208                    let (sell_amount_limit, _) = self.get_amount_limits(
209                        vec![sell_token_address, buy_token_address],
210                        overwrites.clone(),
211                    )?;
212                    let price_result = self.adapter_contract.price(
213                        &self.id,
214                        sell_token_address,
215                        buy_token_address,
216                        vec![sell_amount_limit / U256::from(100)],
217                        overwrites,
218                    )?;
219
220                    let price = if self
221                        .capabilities
222                        .contains(&Capability::ScaledPrice)
223                    {
224                        *price_result.first().ok_or_else(|| {
225                            SimulationError::FatalError(
226                                "Calculated price array is empty".to_string(),
227                            )
228                        })?
229                    } else {
230                        let unscaled_price = price_result.first().ok_or_else(|| {
231                            SimulationError::FatalError(
232                                "Calculated price array is empty".to_string(),
233                            )
234                        })?;
235                        let sell_token_decimals = self.get_decimals(tokens, &sell_token_address)?;
236                        let buy_token_decimals = self.get_decimals(tokens, &buy_token_address)?;
237                        *unscaled_price * 10f64.powi(sell_token_decimals as i32) /
238                            10f64.powi(buy_token_decimals as i32)
239                    };
240
241                    self.spot_prices
242                        .insert((sell_token_address, buy_token_address), price);
243                }
244            }
245            Err(SimulationError::FatalError(_)) => {
246                // If the pool does not support price function, we need to calculate spot prices by
247                // swapping two amounts and use the approximation to get the derivative.
248
249                for iter_tokens in self.tokens.iter().permutations(2) {
250                    let t0 = bytes_to_address(iter_tokens[0])?;
251                    let t1 = bytes_to_address(iter_tokens[1])?;
252
253                    let overwrites =
254                        Some(self.get_overwrites(vec![t0, t1], *MAX_BALANCE / U256::from(100))?);
255
256                    // Calculate the first sell amount (x1) as 1% of the maximum limit.
257                    let x1 = self
258                        .get_amount_limits(vec![t0, t1], overwrites.clone())?
259                        .0 /
260                        U256::from(100);
261
262                    // Calculate the second sell amount (x2) as x1 + 1% of x1. 1.01% of the max
263                    // limit
264                    let x2 = x1 + (x1 / U256::from(100));
265
266                    // Perform a swap for the first sell amount (x1) and retrieve the received
267                    // amount (y1).
268                    let y1 = self
269                        .adapter_contract
270                        .swap(&self.id, t0, t1, false, x1, overwrites.clone())?
271                        .0
272                        .received_amount;
273
274                    // Perform a swap for the second sell amount (x2) and retrieve the received
275                    // amount (y2).
276                    let y2 = self
277                        .adapter_contract
278                        .swap(&self.id, t0, t1, false, x2, overwrites)?
279                        .0
280                        .received_amount;
281
282                    let sell_token_decimals = self.get_decimals(tokens, &t0)?;
283                    let buy_token_decimals = self.get_decimals(tokens, &t1)?;
284
285                    let num = y2 - y1;
286                    let den = x2 - x1;
287
288                    // Calculate the marginal price, adjusting for token decimals.
289                    let token_correction =
290                        10f64.powi(sell_token_decimals as i32 - buy_token_decimals as i32);
291                    let num_f64 = u256_to_f64(num)?;
292                    let den_f64 = u256_to_f64(den)?;
293                    if den_f64 == 0.0 {
294                        return Err(SimulationError::FatalError(
295                            "Failed to compute marginal price: denominator converted to 0".into(),
296                        ));
297                    }
298                    let marginal_price = num_f64 / den_f64 * token_correction;
299
300                    self.spot_prices
301                        .insert((t0, t1), marginal_price);
302                }
303            }
304            Err(e) => return Err(e),
305        }
306
307        Ok(())
308    }
309
310    fn get_decimals(
311        &self,
312        tokens: &HashMap<Bytes, Token>,
313        sell_token_address: &Address,
314    ) -> Result<usize, SimulationError> {
315        tokens
316            .get(&Bytes::from(sell_token_address.as_slice()))
317            .map(|t| t.decimals as usize)
318            .ok_or_else(|| {
319                SimulationError::FatalError(format!(
320                    "Failed to scale spot prices! Pool: {} Token 0x{:x} is not available!",
321                    self.id, sell_token_address
322                ))
323            })
324    }
325
326    /// Retrieves the sell and buy amount limit for a given pair of tokens and the given overwrites.
327    ///
328    /// Attempting to swap an amount of the sell token that exceeds the sell amount limit is not
329    /// advised and in most cases will result in a revert.
330    ///
331    /// # Arguments
332    ///
333    /// * `tokens` - A vec of tokens, where the first token is the sell token and the second is the
334    ///   buy token. The order of tokens in the input vector is significant and determines the
335    ///   direction of the price query.
336    /// * `overwrites` - A hashmap of overwrites to apply to the simulation.
337    ///
338    /// # Returns
339    ///
340    /// * `Result<(U256,U256), SimulationError>` - Returns the sell and buy amount limit as a `U256`
341    ///   if successful, or a `SimulationError` on failure.
342    fn get_amount_limits(
343        &self,
344        tokens: Vec<Address>,
345        overwrites: Option<HashMap<Address, HashMap<U256, U256>>>,
346    ) -> Result<(U256, U256), SimulationError> {
347        let limits = self
348            .adapter_contract
349            .get_limits(&self.id, tokens[0], tokens[1], overwrites)?;
350
351        Ok(limits)
352    }
353
354    /// Updates the pool state.
355    ///
356    /// It is assumed this is called on a new block. Therefore, first the pool's overwrites cache is
357    /// cleared, then the balances are updated and the spot prices are recalculated.
358    ///
359    /// # Arguments
360    ///
361    /// * `tokens` - A hashmap of token addresses to `Token` instances. This is necessary for
362    ///   calculating new spot prices.
363    /// * `balances` - A `Balances` instance containing all balance updates on the current block.
364    fn update_pool_state(
365        &mut self,
366        tokens: &HashMap<Bytes, Token>,
367        balances: &Balances,
368    ) -> Result<(), SimulationError> {
369        // clear cache
370        self.adapter_contract
371            .engine
372            .clear_temp_storage()
373            .map_err(|err| {
374                SimulationError::FatalError(format!("Failed to clear temporary storage: {err:?}",))
375            })?;
376        self.block_lasting_overwrites.clear();
377
378        // set balances
379        if !self.balances.is_empty() {
380            // Pool uses component balances for overwrites
381            if let Some(bals) = balances
382                .component_balances
383                .get(&self.id)
384            {
385                // Merge delta balances with existing balances instead of replacing them
386                // Prevents errors when delta balance changes do not affect all the pool tokens.
387                for (token, bal) in bals {
388                    let addr = bytes_to_address(token).map_err(|_| {
389                        SimulationError::FatalError(format!(
390                            "Invalid token address in balance update: {token:?}"
391                        ))
392                    })?;
393                    self.balances
394                        .insert(addr, U256::from_be_slice(bal));
395                }
396            }
397        } else {
398            // Pool uses contract balances for overwrites
399            for contract in &self.involved_contracts {
400                if let Some(bals) = balances
401                    .account_balances
402                    .get(&Bytes::from(contract.as_slice()))
403                {
404                    let contract_entry = self
405                        .contract_balances
406                        .entry(*contract)
407                        .or_default();
408                    for (token, bal) in bals {
409                        let addr = bytes_to_address(token).map_err(|_| {
410                            SimulationError::FatalError(format!(
411                                "Invalid token address in balance update: {token:?}"
412                            ))
413                        })?;
414                        contract_entry.insert(addr, U256::from_be_slice(bal));
415                    }
416                }
417            }
418        }
419
420        // reset spot prices
421        self.set_spot_prices(tokens)?;
422        Ok(())
423    }
424
425    fn get_overwrites(
426        &self,
427        tokens: Vec<Address>,
428        max_amount: U256,
429    ) -> Result<HashMap<Address, Overwrites>, SimulationError> {
430        let token_overwrites = self.get_token_overwrites(tokens, max_amount)?;
431
432        // Merge `block_lasting_overwrites` with `token_overwrites`
433        let merged_overwrites =
434            self.merge(&self.block_lasting_overwrites.clone(), &token_overwrites);
435
436        Ok(merged_overwrites)
437    }
438
439    fn get_token_overwrites(
440        &self,
441        tokens: Vec<Address>,
442        max_amount: U256,
443    ) -> Result<HashMap<Address, Overwrites>, SimulationError> {
444        let sell_token = &tokens[0].clone(); //TODO: need to make it clearer from the interface
445        let mut res: Vec<HashMap<Address, Overwrites>> = Vec::new();
446        if !self
447            .capabilities
448            .contains(&Capability::TokenBalanceIndependent)
449        {
450            res.push(self.get_balance_overwrites()?);
451        }
452
453        let mut overwrites = TokenProxyOverwriteFactory::new(*sell_token, None);
454
455        overwrites.set_balance(max_amount, Address::from_slice(&*EXTERNAL_ACCOUNT.0));
456
457        // Set allowance for adapter_address to max_amount
458        overwrites.set_allowance(max_amount, self.adapter_contract.address, *EXTERNAL_ACCOUNT);
459
460        res.push(overwrites.get_overwrites());
461
462        // Merge all overwrites into a single HashMap
463        Ok(res
464            .into_iter()
465            .fold(HashMap::new(), |acc, overwrite| self.merge(&acc, &overwrite)))
466    }
467
468    /// Gets all balance overwrites for the pool's tokens.
469    ///
470    /// If the pool uses component balances, the balances are set for the balance owner (if exists)
471    /// or for the pool itself. If the pool uses contract balances, the balances are set for the
472    /// contracts involved in the pool.
473    ///
474    /// # Returns
475    ///
476    /// * `Result<HashMap<Address, Overwrites>, SimulationError>` - Returns a hashmap of address to
477    ///   `Overwrites` if successful, or a `SimulationError` on failure.
478    fn get_balance_overwrites(&self) -> Result<HashMap<Address, Overwrites>, SimulationError> {
479        let mut balance_overwrites: HashMap<Address, Overwrites> = HashMap::new();
480
481        // Use component balances for overrides
482        let address = match self.balance_owner {
483            Some(owner) => Some(owner),
484            None if !self.contract_balances.is_empty() => None,
485            None => Some(self.id.parse().map_err(|_| {
486                SimulationError::FatalError(
487                    "Failed to get balance overwrites: Pool ID is not an address".into(),
488                )
489            })?),
490        };
491
492        if let Some(address) = address {
493            // Only override balances that are explicitly provided in self.balances
494            // This preserves existing balances for tokens not updated in delta transitions
495            for (token, bal) in &self.balances {
496                let mut overwrites = TokenProxyOverwriteFactory::new(*token, None);
497                overwrites.set_balance(*bal, address);
498                balance_overwrites.extend(overwrites.get_overwrites());
499            }
500        }
501
502        // Use contract balances for overrides (will overwrite component balances if they were set
503        // for a contract we explicitly track balances for)
504        for (contract, balances) in &self.contract_balances {
505            for (token, balance) in balances {
506                let mut overwrites = TokenProxyOverwriteFactory::new(*token, None);
507                overwrites.set_balance(*balance, *contract);
508                balance_overwrites.extend(overwrites.get_overwrites());
509            }
510        }
511
512        // Apply disables for tokens that should not have any balance overrides
513        for token in &self.disable_overwrite_tokens {
514            balance_overwrites.remove(token);
515        }
516
517        Ok(balance_overwrites)
518    }
519
520    fn merge(
521        &self,
522        target: &HashMap<Address, Overwrites>,
523        source: &HashMap<Address, Overwrites>,
524    ) -> HashMap<Address, Overwrites> {
525        let mut merged = target.clone();
526
527        for (key, source_inner) in source {
528            merged
529                .entry(*key)
530                .or_default()
531                .extend(source_inner.clone());
532        }
533
534        merged
535    }
536
537    #[cfg(test)]
538    pub fn get_involved_contracts(&self) -> HashSet<Address> {
539        self.involved_contracts.clone()
540    }
541
542    #[cfg(test)]
543    pub fn get_manual_updates(&self) -> bool {
544        self.manual_updates
545    }
546
547    #[cfg(test)]
548    pub fn get_balance_owner(&self) -> Option<Address> {
549        self.balance_owner
550    }
551
552    /// Get the component balances for validation purposes
553    pub fn get_balances(&self) -> &HashMap<Address, U256> {
554        &self.balances
555    }
556}
557
558impl<D> Serialize for EVMPoolState<D>
559where
560    D: EngineDatabaseInterface + Clone + Debug,
561    <D as DatabaseRef>::Error: Debug,
562    <D as EngineDatabaseInterface>::Error: Debug,
563{
564    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
565    where
566        S: serde::Serializer,
567    {
568        Err(serde::ser::Error::custom("not supported due vm state deps"))
569    }
570}
571
572impl<'de, D> Deserialize<'de> for EVMPoolState<D>
573where
574    D: EngineDatabaseInterface + Clone + Debug,
575    <D as DatabaseRef>::Error: Debug,
576    <D as EngineDatabaseInterface>::Error: Debug,
577{
578    fn deserialize<De>(_deserializer: De) -> Result<Self, De::Error>
579    where
580        De: serde::Deserializer<'de>,
581    {
582        Err(serde::de::Error::custom("not supported due vm state deps"))
583    }
584}
585
586#[typetag::serialize]
587impl<D> ProtocolSim for EVMPoolState<D>
588where
589    D: EngineDatabaseInterface + Clone + Debug + 'static,
590    <D as DatabaseRef>::Error: Debug,
591    <D as EngineDatabaseInterface>::Error: Debug,
592{
593    fn fee(&self) -> f64 {
594        todo!()
595    }
596
597    fn spot_price(&self, base: &Token, quote: &Token) -> Result<f64, SimulationError> {
598        let base_address = bytes_to_address(&base.address)?;
599        let quote_address = bytes_to_address(&quote.address)?;
600        self.spot_prices
601            .get(&(base_address, quote_address))
602            .cloned()
603            .ok_or(SimulationError::FatalError(format!(
604                "Spot price not found for base token {base_address} and quote token {quote_address}"
605            )))
606    }
607
608    fn get_amount_out(
609        &self,
610        amount_in: BigUint,
611        token_in: &Token,
612        token_out: &Token,
613    ) -> Result<GetAmountOutResult, SimulationError> {
614        let sell_token_address = bytes_to_address(&token_in.address)?;
615        let buy_token_address = bytes_to_address(&token_out.address)?;
616        let sell_amount = U256::from_be_slice(&amount_in.to_bytes_be());
617        let overwrites = self.get_overwrites(
618            vec![sell_token_address, buy_token_address],
619            *MAX_BALANCE / U256::from(100),
620        )?;
621        let (sell_amount_limit, _) = self.get_amount_limits(
622            vec![sell_token_address, buy_token_address],
623            Some(overwrites.clone()),
624        )?;
625        let (sell_amount_respecting_limit, sell_amount_exceeds_limit) = if self
626            .capabilities
627            .contains(&Capability::HardLimits) &&
628            sell_amount_limit < sell_amount
629        {
630            (sell_amount_limit, true)
631        } else {
632            (sell_amount, false)
633        };
634
635        let overwrites_with_sell_limit =
636            self.get_overwrites(vec![sell_token_address, buy_token_address], sell_amount_limit)?;
637        let complete_overwrites = self.merge(&overwrites, &overwrites_with_sell_limit);
638
639        let (trade, state_changes) = self.adapter_contract.swap(
640            &self.id,
641            sell_token_address,
642            buy_token_address,
643            false,
644            sell_amount_respecting_limit,
645            Some(complete_overwrites),
646        )?;
647
648        let mut new_state = self.clone();
649
650        // Apply state changes to the new state
651        for (address, state_update) in state_changes {
652            if let Some(storage) = state_update.storage {
653                let block_overwrites = new_state
654                    .block_lasting_overwrites
655                    .entry(address)
656                    .or_default();
657                for (slot, value) in storage {
658                    let slot = U256::from_str(&slot.to_string()).map_err(|_| {
659                        SimulationError::FatalError("Failed to decode slot index".to_string())
660                    })?;
661                    let value = U256::from_str(&value.to_string()).map_err(|_| {
662                        SimulationError::FatalError("Failed to decode slot overwrite".to_string())
663                    })?;
664                    block_overwrites.insert(slot, value);
665                }
666            }
667        }
668
669        // Update spot prices
670        let tokens = HashMap::from([
671            (token_in.address.clone(), token_in.clone()),
672            (token_out.address.clone(), token_out.clone()),
673        ]);
674        let _ = new_state.set_spot_prices(&tokens);
675
676        let buy_amount = trade.received_amount;
677
678        if sell_amount_exceeds_limit {
679            return Err(SimulationError::InvalidInput(
680                format!("Sell amount exceeds limit {sell_amount_limit}"),
681                Some(GetAmountOutResult::new(
682                    u256_to_biguint(buy_amount),
683                    u256_to_biguint(trade.gas_used),
684                    Box::new(new_state.clone()),
685                )),
686            ));
687        }
688        Ok(GetAmountOutResult::new(
689            u256_to_biguint(buy_amount),
690            u256_to_biguint(trade.gas_used),
691            Box::new(new_state.clone()),
692        ))
693    }
694
695    fn get_limits(
696        &self,
697        sell_token: Bytes,
698        buy_token: Bytes,
699    ) -> Result<(BigUint, BigUint), SimulationError> {
700        let sell_token = bytes_to_address(&sell_token)?;
701        let buy_token = bytes_to_address(&buy_token)?;
702        let overwrites =
703            self.get_overwrites(vec![sell_token, buy_token], *MAX_BALANCE / U256::from(100))?;
704        let limits = self.get_amount_limits(vec![sell_token, buy_token], Some(overwrites))?;
705        Ok((u256_to_biguint(limits.0), u256_to_biguint(limits.1)))
706    }
707
708    fn delta_transition(
709        &mut self,
710        delta: ProtocolStateDelta,
711        tokens: &HashMap<Bytes, Token>,
712        balances: &Balances,
713    ) -> Result<(), TransitionError> {
714        if self.manual_updates {
715            // Directly check for "update_marker" in `updated_attributes`
716            if let Some(marker) = delta
717                .updated_attributes
718                .get("update_marker")
719            {
720                // Assuming `marker` is of type `Bytes`, check its value for "truthiness"
721                if !marker.is_empty() && marker[0] != 0 {
722                    self.update_pool_state(tokens, balances)?;
723                }
724            }
725        } else {
726            self.update_pool_state(tokens, balances)?;
727        }
728
729        Ok(())
730    }
731
732    fn query_pool_swap(
733        &self,
734        params: &tycho_common::simulation::protocol_sim::QueryPoolSwapParams,
735    ) -> Result<tycho_common::simulation::protocol_sim::PoolSwap, SimulationError> {
736        crate::evm::query_pool_swap::query_pool_swap(self, params)
737    }
738
739    fn clone_box(&self) -> Box<dyn ProtocolSim> {
740        Box::new(self.clone())
741    }
742
743    fn as_any(&self) -> &dyn Any {
744        self
745    }
746
747    fn as_any_mut(&mut self) -> &mut dyn Any {
748        self
749    }
750
751    fn eq(&self, other: &dyn ProtocolSim) -> bool {
752        if let Some(other_state) = other
753            .as_any()
754            .downcast_ref::<EVMPoolState<PreCachedDB>>()
755        {
756            self.id == other_state.id
757        } else {
758            false
759        }
760    }
761
762    /// Implemented manually because `typetag` macro not supports generics
763    fn typetag_deserialize(&self) {
764        // https://github.com/dtolnay/typetag/blob/21ae0d40c9f73443a20204ab4a134441355b52f7/impl/src/tagged_trait.rs#L140
765        unreachable!("Only to catch missing typetag attribute on impl blocks. Not called.")
766    }
767}
768
#[cfg(test)]
mod tests {
    use std::default::Default;

    use num_traits::One;
    use revm::{
        primitives::KECCAK_EMPTY,
        state::{AccountInfo, Bytecode},
    };
    use serde_json::Value;
    use tycho_client::feed::BlockHeader;
    use tycho_common::models::Chain;

    use super::*;
    use crate::evm::{
        engine_db::{create_engine, SHARED_TYCHO_DB},
        protocol::vm::{
            constants::{BALANCER_V2, ERC20_PROXY_BYTECODE},
            state_builder::EVMPoolStateBuilder,
        },
        simulation::SimulationEngine,
        tycho_models::AccountUpdate,
    };

    /// DAI token fixture on Ethereum.
    fn dai() -> Token {
        Token::new(
            &Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap(),
            "DAI",
            18,
            0,
            &[Some(10_000)],
            Chain::Ethereum,
            100,
        )
    }

    /// BAL token fixture on Ethereum.
    fn bal() -> Token {
        Token::new(
            &Bytes::from_str("0xba100000625a3754423978a60c9317c58a424e3d").unwrap(),
            "BAL",
            18,
            0,
            &[Some(10_000)],
            Chain::Ethereum,
            100,
        )
    }

    /// DAI token address as an alloy `Address`.
    fn dai_addr() -> Address {
        bytes_to_address(&dai().address).unwrap()
    }

    /// BAL token address as an alloy `Address`.
    fn bal_addr() -> Address {
        bytes_to_address(&bal().address).unwrap()
    }

    /// Builds a Balancer V2 DAI/BAL pool state backed by `SHARED_TYCHO_DB`.
    ///
    /// Seeds the simulation engine and the shared DB with the account snapshot in
    /// `assets/balancer_contract_storage_block_20463609.json`. Installs the ERC20 proxy
    /// bytecode on both tokens, then builds the pool through `EVMPoolStateBuilder`.
    async fn setup_pool_state() -> EVMPoolState<PreCachedDB> {
        let data_str = include_str!("assets/balancer_contract_storage_block_20463609.json");
        let data: Value = serde_json::from_str(data_str).expect("Failed to parse JSON");

        let accounts: Vec<AccountUpdate> = serde_json::from_value(data["accounts"].clone())
            .expect("Expected accounts to match AccountUpdate structure");

        let db = SHARED_TYCHO_DB.clone();
        let engine: SimulationEngine<_> = create_engine(db.clone(), false).unwrap();

        let block = BlockHeader {
            number: 20463609,
            hash: Bytes::from_str(
                "0x4315fd1afc25cc2ebc72029c543293f9fd833eeb305e2e30159459c827733b1b",
            )
            .unwrap(),
            timestamp: 1722875891,
            ..Default::default()
        };

        // Register every snapshot account with the engine, then push the same updates
        // into the shared DB.
        for account in accounts.clone() {
            engine
                .state
                .init_account(
                    account.address,
                    AccountInfo {
                        balance: account.balance.unwrap_or_default(),
                        nonce: 0u64,
                        code_hash: KECCAK_EMPTY,
                        code: account
                            .code
                            .clone()
                            .map(|arg0: Vec<u8>| Bytecode::new_raw(arg0.into())),
                    },
                    None,
                    false,
                )
                .expect("Failed to initialize account");
        }
        db.update(accounts, Some(block))
            .unwrap();

        // Install the ERC20 proxy bytecode on both pool tokens.
        let tokens = vec![dai().address, bal().address];
        for token in &tokens {
            engine
                .state
                .init_account(
                    bytes_to_address(token).unwrap(),
                    AccountInfo {
                        balance: U256::from(0),
                        nonce: 0,
                        code_hash: KECCAK_EMPTY,
                        code: Some(Bytecode::new_raw(ERC20_PROXY_BYTECODE.into())),
                    },
                    None,
                    true,
                )
                .expect("Failed to initialize account");
        }

        // NOTE(review): this block number (18485417) is lower than the snapshot block
        // above (20463609) — confirm the backwards move is intentional.
        let block = BlockHeader {
            number: 18485417,
            hash: Bytes::from_str(
                "0x28d41d40f2ac275a4f5f621a636b9016b527d11d37d610a45ac3a821346ebf8c",
            )
            .expect("Invalid block hash"),
            timestamp: 0,
            ..Default::default()
        };
        db.update(vec![], Some(block))
            .unwrap();

        let pool_id: String =
            "0x4626d81b3a1711beb79f4cecff2413886d461677000200000000000000000011".into();

        let stateless_contracts = HashMap::from([(
            String::from("0x3de27efa2f1aa663ae5d458857e731c129069f29"),
            Some(Vec::new()),
        )]);

        let balances = HashMap::from([
            (dai_addr(), U256::from_str("178754012737301807104").unwrap()),
            (bal_addr(), U256::from_str("91082987763369885696").unwrap()),
        ]);
        let adapter_address =
            Address::from_str("0xA2C5C98A892fD6656a7F39A2f63228C0Bc846270").unwrap();

        EVMPoolStateBuilder::new(pool_id, tokens, adapter_address)
            .balances(balances)
            .balance_owner(Address::from_str("0xBA12222222228d8Ba445958a75a0704d566BF2C8").unwrap())
            .adapter_contract_bytecode(Bytecode::new_raw(BALANCER_V2.into()))
            .stateless_contracts(stateless_contracts)
            .build(SHARED_TYCHO_DB.clone())
            .await
            .expect("Failed to build pool state")
    }

    /// Checks the capabilities reported by the adapter and stored on the state, and that
    /// the token and external accounts are initialized in the engine.
    #[tokio::test]
    async fn test_init() {
        // Clear DB from this test to prevent interference from other tests
        SHARED_TYCHO_DB
            .clear()
            .expect("Failed to clear SHARED_TYCHO_DB");
        let pool_state = setup_pool_state().await;

        let expected_capabilities = vec![
            Capability::SellSide,
            Capability::BuySide,
            Capability::PriceFunction,
            Capability::HardLimits,
        ]
        .into_iter()
        .collect::<HashSet<_>>();

        let capabilities_adapter_contract = pool_state
            .adapter_contract
            .get_capabilities(
                &pool_state.id,
                bytes_to_address(&pool_state.tokens[0]).unwrap(),
                bytes_to_address(&pool_state.tokens[1]).unwrap(),
            )
            .unwrap();

        // `assert_eq!` only borrows its operands, so no clones are needed here.
        assert_eq!(capabilities_adapter_contract, expected_capabilities);
        assert_eq!(pool_state.capabilities, expected_capabilities);

        for capability in expected_capabilities {
            assert!(pool_state
                .clone()
                .ensure_capability(capability)
                .is_ok());
        }

        assert!(pool_state
            .clone()
            .ensure_capability(Capability::MarginalPrice)
            .is_err());

        // Verify all tokens are initialized in the engine
        let engine_accounts = pool_state
            .adapter_contract
            .engine
            .state
            .clone()
            .get_account_storage()
            .expect("Failed to get account storage");
        for token in &pool_state.tokens {
            let account = engine_accounts
                .get_account_info(&bytes_to_address(token).unwrap())
                .unwrap();
            assert_eq!(account.balance, U256::from(0));
            assert_eq!(account.nonce, 0u64);
            assert_eq!(account.code_hash, KECCAK_EMPTY);
            assert!(account.code.is_some());
        }

        // Verify external account is initialized in the engine
        let external_account = engine_accounts
            .get_account_info(&EXTERNAL_ACCOUNT)
            .unwrap();
        assert_eq!(external_account.balance, U256::from(*MAX_BALANCE));
        assert_eq!(external_account.nonce, 0u64);
        assert_eq!(external_account.code_hash, KECCAK_EMPTY);
        assert!(external_account.code.is_none());
    }

    /// A single 1 DAI -> BAL swap returns the expected amount, changes spot prices on the
    /// new state, and leaves the original state's block-lasting overwrites empty.
    #[tokio::test]
    async fn test_get_amount_out() -> Result<(), Box<dyn std::error::Error>> {
        let pool_state = setup_pool_state().await;

        let result = pool_state
            .get_amount_out(BigUint::from_str("1000000000000000000").unwrap(), &dai(), &bal())
            .unwrap();
        let new_state = result
            .new_state
            .as_any()
            .downcast_ref::<EVMPoolState<PreCachedDB>>()
            .unwrap();
        assert_eq!(result.amount, BigUint::from_str("137780051463393923").unwrap());
        assert_ne!(new_state.spot_prices, pool_state.spot_prices);
        assert!(pool_state
            .block_lasting_overwrites
            .is_empty());
        Ok(())
    }

    /// Two consecutive swaps, where the second runs on the state returned by the first,
    /// yield different amounts and different spot prices.
    #[tokio::test]
    async fn test_sequential_get_amount_outs() {
        let pool_state = setup_pool_state().await;

        let result = pool_state
            .get_amount_out(BigUint::from_str("1000000000000000000").unwrap(), &dai(), &bal())
            .unwrap();
        let new_state = result
            .new_state
            .as_any()
            .downcast_ref::<EVMPoolState<PreCachedDB>>()
            .unwrap();
        assert_eq!(result.amount, BigUint::from_str("137780051463393923").unwrap());
        assert_ne!(new_state.spot_prices, pool_state.spot_prices);

        let new_result = new_state
            .get_amount_out(BigUint::from_str("1000000000000000000").unwrap(), &dai(), &bal())
            .unwrap();
        let new_state_second_swap = new_result
            .new_state
            .as_any()
            .downcast_ref::<EVMPoolState<PreCachedDB>>()
            .unwrap();

        assert_eq!(new_result.amount, BigUint::from_str("136964651490065626").unwrap());
        assert_ne!(new_state_second_swap.spot_prices, new_state.spot_prices);
    }

    /// Swapping 1 wei of DAI yields zero BAL without erroring.
    #[tokio::test]
    async fn test_get_amount_out_dust() {
        let pool_state = setup_pool_state().await;

        let result = pool_state
            .get_amount_out(BigUint::one(), &dai(), &bal())
            .unwrap();

        let _ = result
            .new_state
            .as_any()
            .downcast_ref::<EVMPoolState<PreCachedDB>>()
            .unwrap();
        assert_eq!(result.amount, BigUint::ZERO);
    }

    /// Selling more than the pool's sell limit returns `InvalidInput` carrying a partial
    /// result.
    #[tokio::test]
    async fn test_get_amount_out_sell_limit() {
        let pool_state = setup_pool_state().await;

        let result = pool_state.get_amount_out(
            // sell limit is 100279494253364362835
            BigUint::from_str("100379494253364362835").unwrap(),
            &dai(),
            &bal(),
        );

        assert!(result.is_err());

        match result {
            Err(SimulationError::InvalidInput(msg1, amount_out_result)) => {
                assert_eq!(msg1, "Sell amount exceeds limit 100279494253364362835");
                assert!(amount_out_result.is_some());
            }
            _ => panic!("Test failed: was expecting an Err(SimulationError::InvalidInput(_, _)) value"),
        }
    }

    /// Sell limits in both directions match the known values for this pool.
    #[tokio::test]
    async fn test_get_amount_limits() {
        let pool_state = setup_pool_state().await;

        let overwrites = pool_state
            .get_overwrites(
                vec![
                    bytes_to_address(&pool_state.tokens[0]).unwrap(),
                    bytes_to_address(&pool_state.tokens[1]).unwrap(),
                ],
                *MAX_BALANCE / U256::from(100),
            )
            .unwrap();
        let (dai_limit, _) = pool_state
            .get_amount_limits(vec![dai_addr(), bal_addr()], Some(overwrites.clone()))
            .unwrap();
        assert_eq!(dai_limit, U256::from_str("100279494253364362835").unwrap());

        let (bal_limit, _) = pool_state
            .get_amount_limits(
                vec![
                    bytes_to_address(&pool_state.tokens[1]).unwrap(),
                    bytes_to_address(&pool_state.tokens[0]).unwrap(),
                ],
                Some(overwrites),
            )
            .unwrap();
        assert_eq!(bal_limit, U256::from_str("13997408640689987484").unwrap());
    }

    /// Spot prices computed with the `PriceFunction` capability match known values.
    #[tokio::test]
    async fn test_set_spot_prices() {
        let mut pool_state = setup_pool_state().await;

        pool_state
            .set_spot_prices(
                &vec![bal(), dai()]
                    .into_iter()
                    .map(|t| (t.address.clone(), t))
                    .collect(),
            )
            .unwrap();

        let dai_bal_spot_price = pool_state
            .spot_prices
            .get(&(
                bytes_to_address(&pool_state.tokens[0]).unwrap(),
                bytes_to_address(&pool_state.tokens[1]).unwrap(),
            ))
            .unwrap();
        let bal_dai_spot_price = pool_state
            .spot_prices
            .get(&(
                bytes_to_address(&pool_state.tokens[1]).unwrap(),
                bytes_to_address(&pool_state.tokens[0]).unwrap(),
            ))
            .unwrap();
        assert_eq!(dai_bal_spot_price, &0.137_778_914_319_047_9);
        assert_eq!(bal_dai_spot_price, &7.071_503_245_428_246);
    }

    /// Spot prices are still computed, with different values, when the pool lacks the
    /// `PriceFunction` capability.
    #[tokio::test]
    async fn test_set_spot_prices_without_capability() {
        // Tests set Spot Prices functions when the pool doesn't have PriceFunction capability
        let mut pool_state = setup_pool_state().await;

        pool_state
            .capabilities
            .remove(&Capability::PriceFunction);

        pool_state
            .set_spot_prices(
                &vec![bal(), dai()]
                    .into_iter()
                    .map(|t| (t.address.clone(), t))
                    .collect(),
            )
            .unwrap();

        let dai_bal_spot_price = pool_state
            .spot_prices
            .get(&(
                bytes_to_address(&pool_state.tokens[0]).unwrap(),
                bytes_to_address(&pool_state.tokens[1]).unwrap(),
            ))
            .unwrap();
        let bal_dai_spot_price = pool_state
            .spot_prices
            .get(&(
                bytes_to_address(&pool_state.tokens[1]).unwrap(),
                bytes_to_address(&pool_state.tokens[0]).unwrap(),
            ))
            .unwrap();
        assert_eq!(dai_bal_spot_price, &0.13736685496467538);
        assert_eq!(bal_dai_spot_price, &7.050354297665408);
    }

    /// Balance overwrites include both pool tokens when component balances are set.
    #[tokio::test]
    async fn test_get_balance_overwrites_with_component_balances() {
        let pool_state: EVMPoolState<PreCachedDB> = setup_pool_state().await;

        let overwrites = pool_state
            .get_balance_overwrites()
            .unwrap();

        let dai_address = dai_addr();
        let bal_address = bal_addr();
        assert!(overwrites.contains_key(&dai_address));
        assert!(overwrites.contains_key(&bal_address));
    }

    /// Balance overwrites include both pool tokens when only contract balances are set.
    #[tokio::test]
    async fn test_get_balance_overwrites_with_contract_balances() {
        let mut pool_state: EVMPoolState<PreCachedDB> = setup_pool_state().await;

        let contract_address =
            Address::from_str("0xBA12222222228d8Ba445958a75a0704d566BF2C8").unwrap();

        // Ensure no component balances are used
        pool_state.balances.clear();
        pool_state.balance_owner = None;

        // Set contract balances
        let dai_address = dai_addr();
        let bal_address = bal_addr();
        pool_state.contract_balances = HashMap::from([(
            contract_address,
            HashMap::from([
                (dai_address, U256::from_str("7500000000000000000000").unwrap()), // 7500 DAI
                (bal_address, U256::from_str("1500000000000000000000").unwrap()), // 1500 BAL
            ]),
        )]);

        let overwrites = pool_state
            .get_balance_overwrites()
            .unwrap();

        assert!(overwrites.contains_key(&dai_address));
        assert!(overwrites.contains_key(&bal_address));
    }

    /// A delta that updates only one token's balance overwrites that entry and keeps
    /// every other existing balance.
    #[tokio::test]
    async fn test_balance_merging_during_delta_transition() {
        let mut pool_state = setup_pool_state().await;
        let pool_id = pool_state.id.clone();

        // Test the balance merging logic more directly
        // Setup initial balances including DAI and BAL (which the pool already knows about)
        let dai_addr = dai_addr();
        let bal_addr = bal_addr();
        let new_token = Address::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(); // WETH

        // Clear and setup clean initial state
        pool_state.balances.clear();
        pool_state
            .balances
            .insert(dai_addr, U256::from(1000000000u64));
        pool_state
            .balances
            .insert(bal_addr, U256::from(2000000000u64));
        pool_state
            .balances
            .insert(new_token, U256::from(3000000000u64));

        // Create tokens mapping including the existing DAI and BAL
        let mut tokens = HashMap::new();
        tokens.insert(dai().address.clone(), dai());
        tokens.insert(bal().address.clone(), bal());

        // Simulate a delta transition with only DAI balance update (missing BAL and new_token)
        let mut component_balances = HashMap::new();
        let mut delta_balances = HashMap::new();
        // Only update DAI balance, leave others unchanged in delta
        delta_balances.insert(dai().address.clone(), Bytes::from(vec![0x77, 0x35, 0x94, 0x00])); // 2000000000 (updated value)
        component_balances.insert(pool_id.clone(), delta_balances);

        let balances = Balances { component_balances, account_balances: HashMap::new() };

        // Record initial balance count
        let initial_balance_count = pool_state.balances.len();
        assert_eq!(initial_balance_count, 3);

        // Apply delta transition
        pool_state
            .update_pool_state(&tokens, &balances)
            .unwrap();

        // Verify that all 3 balances are preserved (BAL and new_token should still be there)
        assert_eq!(
            pool_state.balances.len(),
            3,
            "All balances should be preserved after delta transition"
        );
        assert!(
            pool_state
                .balances
                .contains_key(&dai_addr),
            "DAI balance should be present"
        );
        assert!(
            pool_state
                .balances
                .contains_key(&bal_addr),
            "BAL balance should be present"
        );
        assert!(
            pool_state
                .balances
                .contains_key(&new_token),
            "New token balance should be preserved from before delta"
        );

        // Verify that updated token (DAI) has new value
        assert_eq!(
            pool_state.balances[&dai_addr],
            U256::from(2000000000u64),
            "DAI balance should be updated"
        );

        // Verify that non-updated tokens retain their original values
        assert_eq!(
            pool_state.balances[&bal_addr],
            U256::from(2000000000u64),
            "BAL balance should be unchanged"
        );
        assert_eq!(
            pool_state.balances[&new_token],
            U256::from(3000000000u64),
            "New token balance should be unchanged"
        );
    }

    /// Deserializing a boxed `ProtocolSim` tagged as `EVMPoolState` fails cleanly rather
    /// than panicking, because the generic type is not registered with `typetag`.
    #[test]
    fn should_not_panic_at_typetag_deserialize() {
        let deserialized: Result<Box<dyn ProtocolSim>, _> = serde_json::from_str(
            r#"{"protocol":"EVMPoolState","state":{"reserve_0":1,"reserve_1":2}}"#,
        );

        assert!(deserialized.is_err());
    }
}