tycho_simulation/evm/protocol/vm/
state.rs

1#![allow(deprecated)]
2use std::{
3    any::Any,
4    collections::{HashMap, HashSet},
5    fmt::{self, Debug},
6    str::FromStr,
7};
8
9use alloy::primitives::{Address, U256};
10use itertools::Itertools;
11use num_bigint::BigUint;
12use revm::DatabaseRef;
13use tycho_common::{
14    dto::ProtocolStateDelta,
15    models::token::Token,
16    simulation::{
17        errors::{SimulationError, TransitionError},
18        protocol_sim::{Balances, GetAmountOutResult, ProtocolSim},
19    },
20    Bytes,
21};
22
23use super::{
24    constants::{EXTERNAL_ACCOUNT, MAX_BALANCE},
25    erc20_token::{Overwrites, TokenProxyOverwriteFactory},
26    models::Capability,
27    tycho_simulation_contract::TychoSimulationContract,
28};
29use crate::evm::{
30    engine_db::{engine_db_interface::EngineDatabaseInterface, tycho_db::PreCachedDB},
31    protocol::{
32        u256_num::{u256_to_biguint, u256_to_f64},
33        utils::bytes_to_address,
34    },
35};
36
37#[derive(Clone)]
38pub struct EVMPoolState<D: EngineDatabaseInterface + Clone + Debug>
39where
40    <D as DatabaseRef>::Error: Debug,
41    <D as EngineDatabaseInterface>::Error: Debug,
42{
43    /// The pool's identifier
44    id: String,
45    /// The pool's token's addresses
46    pub tokens: Vec<Bytes>,
47    /// The pool's component balances.
48    balances: HashMap<Address, U256>,
49    /// The contract address for where protocol balances are stored (i.e. a vault contract).
50    /// If given, balances will be overwritten here instead of on the pool contract during
51    /// simulations. This has been deprecated in favor of `contract_balances`.
52    #[deprecated(note = "Use contract_balances instead")]
53    balance_owner: Option<Address>,
54    /// Spot prices of the pool by token pair
55    spot_prices: HashMap<(Address, Address), f64>,
56    /// The supported capabilities of this pool
57    capabilities: HashSet<Capability>,
58    /// Storage overwrites that will be applied to all simulations. They will be cleared
59    /// when ``update_pool_state`` is called, i.e. usually at each block. Hence, the name.
60    block_lasting_overwrites: HashMap<Address, Overwrites>,
61    /// A set of all contract addresses involved in the simulation of this pool.
62    involved_contracts: HashSet<Address>,
63    /// A map of contracts to their token balances.
64    contract_balances: HashMap<Address, HashMap<Address, U256>>,
65    /// Indicates if the protocol uses custom update rules and requires update
66    /// triggers to recalculate spot prices ect. Default is to update on all changes on
67    /// the pool.
68    manual_updates: bool,
69    /// The adapter contract. This is used to interact with the protocol when running simulations
70    adapter_contract: TychoSimulationContract<D>,
71    /// Tokens for which balance overwrites should be disabled.
72    disable_overwrite_tokens: HashSet<Address>,
73}
74
75impl<D> Debug for EVMPoolState<D>
76where
77    D: EngineDatabaseInterface + Clone + Debug,
78    <D as DatabaseRef>::Error: Debug,
79    <D as EngineDatabaseInterface>::Error: Debug,
80{
81    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82        f.debug_struct("EVMPoolState")
83            .field("id", &self.id)
84            .field("tokens", &self.tokens)
85            .field("balances", &self.balances)
86            .field("involved_contracts", &self.involved_contracts)
87            .field("contract_balances", &self.contract_balances)
88            .finish_non_exhaustive()
89    }
90}
91
92impl<D> EVMPoolState<D>
93where
94    D: EngineDatabaseInterface + Clone + Debug + 'static,
95    <D as DatabaseRef>::Error: Debug,
96    <D as EngineDatabaseInterface>::Error: Debug,
97{
98    /// Creates a new instance of `EVMPoolState` with the given attributes, with the ability to
99    /// simulate a protocol-agnostic transaction.
100    ///
101    /// See struct definition of `EVMPoolState` for attribute explanations.
102    #[allow(clippy::too_many_arguments)]
103    pub fn new(
104        id: String,
105        tokens: Vec<Bytes>,
106        component_balances: HashMap<Address, U256>,
107        balance_owner: Option<Address>,
108        contract_balances: HashMap<Address, HashMap<Address, U256>>,
109        spot_prices: HashMap<(Address, Address), f64>,
110        capabilities: HashSet<Capability>,
111        block_lasting_overwrites: HashMap<Address, Overwrites>,
112        involved_contracts: HashSet<Address>,
113        manual_updates: bool,
114        adapter_contract: TychoSimulationContract<D>,
115        disable_overwrite_tokens: HashSet<Address>,
116    ) -> Self {
117        Self {
118            id,
119            tokens,
120            balances: component_balances,
121            balance_owner,
122            spot_prices,
123            capabilities,
124            block_lasting_overwrites,
125            involved_contracts,
126            contract_balances,
127            manual_updates,
128            adapter_contract,
129            disable_overwrite_tokens,
130        }
131    }
132
133    /// Ensures the pool supports the given capability
134    ///
135    /// # Arguments
136    ///
137    /// * `capability` - The capability that we would like to check for.
138    ///
139    /// # Returns
140    ///
141    /// * `Result<(), SimulationError>` - Returns `Ok(())` if the capability is supported, or a
142    ///   `SimulationError` otherwise.
143    fn ensure_capability(&self, capability: Capability) -> Result<(), SimulationError> {
144        if !self.capabilities.contains(&capability) {
145            return Err(SimulationError::FatalError(format!(
146                "capability {:?} not supported",
147                capability.to_string()
148            )));
149        }
150        Ok(())
151    }
152    /// Sets the spot prices for a pool for all possible pairs of the given tokens.
153    ///
154    /// # Arguments
155    ///
156    /// * `tokens` - A hashmap of `Token` instances representing the tokens to calculate spot prices
157    ///   for.
158    ///
159    /// # Returns
160    ///
161    /// * `Result<(), SimulationError>` - Returns `Ok(())` if the spot prices are successfully set,
162    ///   or a `SimulationError` if an error occurs during the calculation or processing.
163    ///
164    /// # Behavior
165    ///
166    /// This function performs the following steps:
167    /// 1. Ensures the pool has the required capability to perform price calculations.
168    /// 2. Iterates over all permutations of token pairs (sell token and buy token). For each pair:
169    ///    - Retrieves all possible overwrites, considering the maximum balance limit.
170    ///    - Calculates the sell amount limit, considering the overwrites.
171    ///    - Invokes the adapter contract's `price` function to retrieve the calculated price for
172    ///      the token pair, considering the sell amount limit.
173    ///    - Processes the price based on whether the `ScaledPrice` capability is present:
174    ///       - If `ScaledPrice` is present, uses the price directly from the adapter contract.
175    ///       - If `ScaledPrice` is absent, scales the price by adjusting for token decimals.
176    ///    - Stores the calculated price in the `spot_prices` map with the token addresses as the
177    ///      key.
178    /// 3. Returns `Ok(())` upon successful completion or a `SimulationError` upon failure.
179    ///
180    /// # Usage
181    ///
182    /// Spot prices need to be set before attempting to retrieve prices using `spot_price`.
183    ///
184    /// Tip: Setting spot prices on the pool every time the pool actually changes will result in
185    /// faster price fetching than if prices are only set immediately before attempting to retrieve
186    /// prices.
187    pub fn set_spot_prices(
188        &mut self,
189        tokens: &HashMap<Bytes, Token>,
190    ) -> Result<(), SimulationError> {
191        match self.ensure_capability(Capability::PriceFunction) {
192            Ok(_) => {
193                for [sell_token_address, buy_token_address] in self
194                    .tokens
195                    .iter()
196                    .permutations(2)
197                    .map(|p| [p[0], p[1]])
198                {
199                    let sell_token_address = bytes_to_address(sell_token_address)?;
200                    let buy_token_address = bytes_to_address(buy_token_address)?;
201
202                    let overwrites = Some(self.get_overwrites(
203                        vec![sell_token_address, buy_token_address],
204                        *MAX_BALANCE / U256::from(100),
205                    )?);
206
207                    let (sell_amount_limit, _) = self.get_amount_limits(
208                        vec![sell_token_address, buy_token_address],
209                        overwrites.clone(),
210                    )?;
211                    let price_result = self.adapter_contract.price(
212                        &self.id,
213                        sell_token_address,
214                        buy_token_address,
215                        vec![sell_amount_limit / U256::from(100)],
216                        overwrites,
217                    )?;
218
219                    let price = if self
220                        .capabilities
221                        .contains(&Capability::ScaledPrice)
222                    {
223                        *price_result.first().ok_or_else(|| {
224                            SimulationError::FatalError(
225                                "Calculated price array is empty".to_string(),
226                            )
227                        })?
228                    } else {
229                        let unscaled_price = price_result.first().ok_or_else(|| {
230                            SimulationError::FatalError(
231                                "Calculated price array is empty".to_string(),
232                            )
233                        })?;
234                        let sell_token_decimals = self.get_decimals(tokens, &sell_token_address)?;
235                        let buy_token_decimals = self.get_decimals(tokens, &buy_token_address)?;
236                        *unscaled_price * 10f64.powi(sell_token_decimals as i32) /
237                            10f64.powi(buy_token_decimals as i32)
238                    };
239
240                    self.spot_prices
241                        .insert((sell_token_address, buy_token_address), price);
242                }
243            }
244            Err(SimulationError::FatalError(_)) => {
245                // If the pool does not support price function, we need to calculate spot prices by
246                // swapping two amounts and use the approximation to get the derivative.
247
248                for iter_tokens in self.tokens.iter().permutations(2) {
249                    let t0 = bytes_to_address(iter_tokens[0])?;
250                    let t1 = bytes_to_address(iter_tokens[1])?;
251
252                    let overwrites =
253                        Some(self.get_overwrites(vec![t0, t1], *MAX_BALANCE / U256::from(100))?);
254
255                    // Calculate the first sell amount (x1) as 1% of the maximum limit.
256                    let x1 = self
257                        .get_amount_limits(vec![t0, t1], overwrites.clone())?
258                        .0 /
259                        U256::from(100);
260
261                    // Calculate the second sell amount (x2) as x1 + 1% of x1. 1.01% of the max
262                    // limit
263                    let x2 = x1 + (x1 / U256::from(100));
264
265                    // Perform a swap for the first sell amount (x1) and retrieve the received
266                    // amount (y1).
267                    let y1 = self
268                        .adapter_contract
269                        .swap(&self.id, t0, t1, false, x1, overwrites.clone())?
270                        .0
271                        .received_amount;
272
273                    // Perform a swap for the second sell amount (x2) and retrieve the received
274                    // amount (y2).
275                    let y2 = self
276                        .adapter_contract
277                        .swap(&self.id, t0, t1, false, x2, overwrites)?
278                        .0
279                        .received_amount;
280
281                    let sell_token_decimals = self.get_decimals(tokens, &t0)?;
282                    let buy_token_decimals = self.get_decimals(tokens, &t1)?;
283
284                    let num = y2 - y1;
285                    let den = x2 - x1;
286
287                    // Calculate the marginal price, adjusting for token decimals.
288                    let token_correction =
289                        10f64.powi(sell_token_decimals as i32 - buy_token_decimals as i32);
290                    let num_f64 = u256_to_f64(num)?;
291                    let den_f64 = u256_to_f64(den)?;
292                    if den_f64 == 0.0 {
293                        return Err(SimulationError::FatalError(
294                            "Failed to compute marginal price: denominator converted to 0".into(),
295                        ));
296                    }
297                    let marginal_price = num_f64 / den_f64 * token_correction;
298
299                    self.spot_prices
300                        .insert((t0, t1), marginal_price);
301                }
302            }
303            Err(e) => return Err(e),
304        }
305
306        Ok(())
307    }
308
309    fn get_decimals(
310        &self,
311        tokens: &HashMap<Bytes, Token>,
312        sell_token_address: &Address,
313    ) -> Result<usize, SimulationError> {
314        tokens
315            .get(&Bytes::from(sell_token_address.as_slice()))
316            .map(|t| t.decimals as usize)
317            .ok_or_else(|| {
318                SimulationError::FatalError(format!(
319                    "Failed to scale spot prices! Pool: {} Token 0x{:x} is not available!",
320                    self.id, sell_token_address
321                ))
322            })
323    }
324
325    /// Retrieves the sell and buy amount limit for a given pair of tokens and the given overwrites.
326    ///
327    /// Attempting to swap an amount of the sell token that exceeds the sell amount limit is not
328    /// advised and in most cases will result in a revert.
329    ///
330    /// # Arguments
331    ///
332    /// * `tokens` - A vec of tokens, where the first token is the sell token and the second is the
333    ///   buy token. The order of tokens in the input vector is significant and determines the
334    ///   direction of the price query.
335    /// * `overwrites` - A hashmap of overwrites to apply to the simulation.
336    ///
337    /// # Returns
338    ///
339    /// * `Result<(U256,U256), SimulationError>` - Returns the sell and buy amount limit as a `U256`
340    ///   if successful, or a `SimulationError` on failure.
341    fn get_amount_limits(
342        &self,
343        tokens: Vec<Address>,
344        overwrites: Option<HashMap<Address, HashMap<U256, U256>>>,
345    ) -> Result<(U256, U256), SimulationError> {
346        let limits = self
347            .adapter_contract
348            .get_limits(&self.id, tokens[0], tokens[1], overwrites)?;
349
350        Ok(limits)
351    }
352
353    /// Updates the pool state.
354    ///
355    /// It is assumed this is called on a new block. Therefore, first the pool's overwrites cache is
356    /// cleared, then the balances are updated and the spot prices are recalculated.
357    ///
358    /// # Arguments
359    ///
360    /// * `tokens` - A hashmap of token addresses to `Token` instances. This is necessary for
361    ///   calculating new spot prices.
362    /// * `balances` - A `Balances` instance containing all balance updates on the current block.
363    fn update_pool_state(
364        &mut self,
365        tokens: &HashMap<Bytes, Token>,
366        balances: &Balances,
367    ) -> Result<(), SimulationError> {
368        // clear cache
369        self.adapter_contract
370            .engine
371            .clear_temp_storage()
372            .map_err(|err| {
373                SimulationError::FatalError(format!("Failed to clear temporary storage: {err:?}",))
374            })?;
375        self.block_lasting_overwrites.clear();
376
377        // set balances
378        if !self.balances.is_empty() {
379            // Pool uses component balances for overwrites
380            if let Some(bals) = balances
381                .component_balances
382                .get(&self.id)
383            {
384                // Merge delta balances with existing balances instead of replacing them
385                // Prevents errors when delta balance changes do not affect all the pool tokens.
386                for (token, bal) in bals {
387                    let addr = bytes_to_address(token).map_err(|_| {
388                        SimulationError::FatalError(format!(
389                            "Invalid token address in balance update: {token:?}"
390                        ))
391                    })?;
392                    self.balances
393                        .insert(addr, U256::from_be_slice(bal));
394                }
395            }
396        } else {
397            // Pool uses contract balances for overwrites
398            for contract in &self.involved_contracts {
399                if let Some(bals) = balances
400                    .account_balances
401                    .get(&Bytes::from(contract.as_slice()))
402                {
403                    let contract_entry = self
404                        .contract_balances
405                        .entry(*contract)
406                        .or_default();
407                    for (token, bal) in bals {
408                        let addr = bytes_to_address(token).map_err(|_| {
409                            SimulationError::FatalError(format!(
410                                "Invalid token address in balance update: {token:?}"
411                            ))
412                        })?;
413                        contract_entry.insert(addr, U256::from_be_slice(bal));
414                    }
415                }
416            }
417        }
418
419        // reset spot prices
420        self.set_spot_prices(tokens)?;
421        Ok(())
422    }
423
424    fn get_overwrites(
425        &self,
426        tokens: Vec<Address>,
427        max_amount: U256,
428    ) -> Result<HashMap<Address, Overwrites>, SimulationError> {
429        let token_overwrites = self.get_token_overwrites(tokens, max_amount)?;
430
431        // Merge `block_lasting_overwrites` with `token_overwrites`
432        let merged_overwrites =
433            self.merge(&self.block_lasting_overwrites.clone(), &token_overwrites);
434
435        Ok(merged_overwrites)
436    }
437
438    fn get_token_overwrites(
439        &self,
440        tokens: Vec<Address>,
441        max_amount: U256,
442    ) -> Result<HashMap<Address, Overwrites>, SimulationError> {
443        let sell_token = &tokens[0].clone(); //TODO: need to make it clearer from the interface
444        let mut res: Vec<HashMap<Address, Overwrites>> = Vec::new();
445        if !self
446            .capabilities
447            .contains(&Capability::TokenBalanceIndependent)
448        {
449            res.push(self.get_balance_overwrites()?);
450        }
451
452        let mut overwrites = TokenProxyOverwriteFactory::new(*sell_token, None);
453
454        overwrites.set_balance(max_amount, Address::from_slice(&*EXTERNAL_ACCOUNT.0));
455
456        // Set allowance for adapter_address to max_amount
457        overwrites.set_allowance(max_amount, self.adapter_contract.address, *EXTERNAL_ACCOUNT);
458
459        res.push(overwrites.get_overwrites());
460
461        // Merge all overwrites into a single HashMap
462        Ok(res
463            .into_iter()
464            .fold(HashMap::new(), |acc, overwrite| self.merge(&acc, &overwrite)))
465    }
466
467    /// Gets all balance overwrites for the pool's tokens.
468    ///
469    /// If the pool uses component balances, the balances are set for the balance owner (if exists)
470    /// or for the pool itself. If the pool uses contract balances, the balances are set for the
471    /// contracts involved in the pool.
472    ///
473    /// # Returns
474    ///
475    /// * `Result<HashMap<Address, Overwrites>, SimulationError>` - Returns a hashmap of address to
476    ///   `Overwrites` if successful, or a `SimulationError` on failure.
477    fn get_balance_overwrites(&self) -> Result<HashMap<Address, Overwrites>, SimulationError> {
478        let mut balance_overwrites: HashMap<Address, Overwrites> = HashMap::new();
479
480        // Use component balances for overrides
481        let address = match self.balance_owner {
482            Some(owner) => Some(owner),
483            None if !self.contract_balances.is_empty() => None,
484            None => Some(self.id.parse().map_err(|_| {
485                SimulationError::FatalError(
486                    "Failed to get balance overwrites: Pool ID is not an address".into(),
487                )
488            })?),
489        };
490
491        if let Some(address) = address {
492            // Only override balances that are explicitly provided in self.balances
493            // This preserves existing balances for tokens not updated in delta transitions
494            for (token, bal) in &self.balances {
495                let mut overwrites = TokenProxyOverwriteFactory::new(*token, None);
496                overwrites.set_balance(*bal, address);
497                balance_overwrites.extend(overwrites.get_overwrites());
498            }
499        }
500
501        // Use contract balances for overrides (will overwrite component balances if they were set
502        // for a contract we explicitly track balances for)
503        for (contract, balances) in &self.contract_balances {
504            for (token, balance) in balances {
505                let mut overwrites = TokenProxyOverwriteFactory::new(*token, None);
506                overwrites.set_balance(*balance, *contract);
507                balance_overwrites.extend(overwrites.get_overwrites());
508            }
509        }
510
511        // Apply disables for tokens that should not have any balance overrides
512        for token in &self.disable_overwrite_tokens {
513            balance_overwrites.remove(token);
514        }
515
516        Ok(balance_overwrites)
517    }
518
519    fn merge(
520        &self,
521        target: &HashMap<Address, Overwrites>,
522        source: &HashMap<Address, Overwrites>,
523    ) -> HashMap<Address, Overwrites> {
524        let mut merged = target.clone();
525
526        for (key, source_inner) in source {
527            merged
528                .entry(*key)
529                .or_default()
530                .extend(source_inner.clone());
531        }
532
533        merged
534    }
535
/// Test-only accessor returning a copy of the involved contract addresses.
#[cfg(test)]
pub fn get_involved_contracts(&self) -> HashSet<Address> {
    let contracts = &self.involved_contracts;
    contracts.clone()
}
540
/// Test-only accessor for the `manual_updates` flag.
#[cfg(test)]
pub fn get_manual_updates(&self) -> bool {
    let Self { manual_updates, .. } = self;
    *manual_updates
}
545
/// Test-only accessor for the deprecated `balance_owner` field.
#[cfg(test)]
#[deprecated]
pub fn get_balance_owner(&self) -> Option<Address> {
    let owner = self.balance_owner;
    owner
}
551
552    /// Get the component balances for validation purposes
553    pub fn get_balances(&self) -> &HashMap<Address, U256> {
554        &self.balances
555    }
556}
557
558impl<D> ProtocolSim for EVMPoolState<D>
559where
560    D: EngineDatabaseInterface + Clone + Debug + 'static,
561    <D as DatabaseRef>::Error: Debug,
562    <D as EngineDatabaseInterface>::Error: Debug,
563{
564    fn fee(&self) -> f64 {
565        todo!()
566    }
567
568    fn spot_price(&self, base: &Token, quote: &Token) -> Result<f64, SimulationError> {
569        let base_address = bytes_to_address(&base.address)?;
570        let quote_address = bytes_to_address(&quote.address)?;
571        self.spot_prices
572            .get(&(base_address, quote_address))
573            .cloned()
574            .ok_or(SimulationError::FatalError(format!(
575                "Spot price not found for base token {base_address} and quote token {quote_address}"
576            )))
577    }
578
579    fn get_amount_out(
580        &self,
581        amount_in: BigUint,
582        token_in: &Token,
583        token_out: &Token,
584    ) -> Result<GetAmountOutResult, SimulationError> {
585        let sell_token_address = bytes_to_address(&token_in.address)?;
586        let buy_token_address = bytes_to_address(&token_out.address)?;
587        let sell_amount = U256::from_be_slice(&amount_in.to_bytes_be());
588        let overwrites = self.get_overwrites(
589            vec![sell_token_address, buy_token_address],
590            *MAX_BALANCE / U256::from(100),
591        )?;
592        let (sell_amount_limit, _) = self.get_amount_limits(
593            vec![sell_token_address, buy_token_address],
594            Some(overwrites.clone()),
595        )?;
596        let (sell_amount_respecting_limit, sell_amount_exceeds_limit) = if self
597            .capabilities
598            .contains(&Capability::HardLimits) &&
599            sell_amount_limit < sell_amount
600        {
601            (sell_amount_limit, true)
602        } else {
603            (sell_amount, false)
604        };
605
606        let overwrites_with_sell_limit =
607            self.get_overwrites(vec![sell_token_address, buy_token_address], sell_amount_limit)?;
608        let complete_overwrites = self.merge(&overwrites, &overwrites_with_sell_limit);
609
610        let (trade, state_changes) = self.adapter_contract.swap(
611            &self.id,
612            sell_token_address,
613            buy_token_address,
614            false,
615            sell_amount_respecting_limit,
616            Some(complete_overwrites),
617        )?;
618
619        let mut new_state = self.clone();
620
621        // Apply state changes to the new state
622        for (address, state_update) in state_changes {
623            if let Some(storage) = state_update.storage {
624                let block_overwrites = new_state
625                    .block_lasting_overwrites
626                    .entry(address)
627                    .or_default();
628                for (slot, value) in storage {
629                    let slot = U256::from_str(&slot.to_string()).map_err(|_| {
630                        SimulationError::FatalError("Failed to decode slot index".to_string())
631                    })?;
632                    let value = U256::from_str(&value.to_string()).map_err(|_| {
633                        SimulationError::FatalError("Failed to decode slot overwrite".to_string())
634                    })?;
635                    block_overwrites.insert(slot, value);
636                }
637            }
638        }
639
640        // Update spot prices
641        let tokens = HashMap::from([
642            (token_in.address.clone(), token_in.clone()),
643            (token_out.address.clone(), token_out.clone()),
644        ]);
645        let _ = new_state.set_spot_prices(&tokens);
646
647        let buy_amount = trade.received_amount;
648
649        if sell_amount_exceeds_limit {
650            return Err(SimulationError::InvalidInput(
651                format!("Sell amount exceeds limit {sell_amount_limit}"),
652                Some(GetAmountOutResult::new(
653                    u256_to_biguint(buy_amount),
654                    u256_to_biguint(trade.gas_used),
655                    Box::new(new_state.clone()),
656                )),
657            ));
658        }
659        Ok(GetAmountOutResult::new(
660            u256_to_biguint(buy_amount),
661            u256_to_biguint(trade.gas_used),
662            Box::new(new_state.clone()),
663        ))
664    }
665
666    fn get_limits(
667        &self,
668        sell_token: Bytes,
669        buy_token: Bytes,
670    ) -> Result<(BigUint, BigUint), SimulationError> {
671        let sell_token = bytes_to_address(&sell_token)?;
672        let buy_token = bytes_to_address(&buy_token)?;
673        let overwrites =
674            self.get_overwrites(vec![sell_token, buy_token], *MAX_BALANCE / U256::from(100))?;
675        let limits = self.get_amount_limits(vec![sell_token, buy_token], Some(overwrites))?;
676        Ok((u256_to_biguint(limits.0), u256_to_biguint(limits.1)))
677    }
678
679    fn delta_transition(
680        &mut self,
681        delta: ProtocolStateDelta,
682        tokens: &HashMap<Bytes, Token>,
683        balances: &Balances,
684    ) -> Result<(), TransitionError<String>> {
685        if self.manual_updates {
686            // Directly check for "update_marker" in `updated_attributes`
687            if let Some(marker) = delta
688                .updated_attributes
689                .get("update_marker")
690            {
691                // Assuming `marker` is of type `Bytes`, check its value for "truthiness"
692                if !marker.is_empty() && marker[0] != 0 {
693                    self.update_pool_state(tokens, balances)?;
694                }
695            }
696        } else {
697            self.update_pool_state(tokens, balances)?;
698        }
699
700        Ok(())
701    }
702
703    fn clone_box(&self) -> Box<dyn ProtocolSim> {
704        Box::new(self.clone())
705    }
706
707    fn as_any(&self) -> &dyn Any {
708        self
709    }
710
711    fn as_any_mut(&mut self) -> &mut dyn Any {
712        self
713    }
714
715    fn eq(&self, other: &dyn ProtocolSim) -> bool {
716        if let Some(other_state) = other
717            .as_any()
718            .downcast_ref::<EVMPoolState<PreCachedDB>>()
719        {
720            self.id == other_state.id
721        } else {
722            false
723        }
724    }
725}
726
727#[cfg(test)]
728mod tests {
729    use std::default::Default;
730
731    use num_traits::One;
732    use revm::{
733        primitives::KECCAK_EMPTY,
734        state::{AccountInfo, Bytecode},
735    };
736    use serde_json::Value;
737    use tycho_client::feed::BlockHeader;
738    use tycho_common::models::Chain;
739
740    use super::*;
741    use crate::evm::{
742        engine_db::{create_engine, SHARED_TYCHO_DB},
743        protocol::vm::{
744            constants::{BALANCER_V2, ERC20_PROXY_BYTECODE},
745            state_builder::EVMPoolStateBuilder,
746        },
747        simulation::SimulationEngine,
748        tycho_models::AccountUpdate,
749    };
750
751    fn dai() -> Token {
752        Token::new(
753            &Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap(),
754            "DAI",
755            18,
756            0,
757            &[Some(10_000)],
758            Chain::Ethereum,
759            100,
760        )
761    }
762
763    fn bal() -> Token {
764        Token::new(
765            &Bytes::from_str("0xba100000625a3754423978a60c9317c58a424e3d").unwrap(),
766            "BAL",
767            18,
768            0,
769            &[Some(10_000)],
770            Chain::Ethereum,
771            100,
772        )
773    }
774
775    fn dai_addr() -> Address {
776        bytes_to_address(&dai().address).unwrap()
777    }
778
779    fn bal_addr() -> Address {
780        bytes_to_address(&bal().address).unwrap()
781    }
782
783    async fn setup_pool_state() -> EVMPoolState<PreCachedDB> {
784        let data_str = include_str!("assets/balancer_contract_storage_block_20463609.json");
785        let data: Value = serde_json::from_str(data_str).expect("Failed to parse JSON");
786
787        let accounts: Vec<AccountUpdate> = serde_json::from_value(data["accounts"].clone())
788            .expect("Expected accounts to match AccountUpdate structure");
789
790        let db = SHARED_TYCHO_DB.clone();
791        let engine: SimulationEngine<_> = create_engine(db.clone(), false).unwrap();
792
793        let block = BlockHeader {
794            number: 20463609,
795            hash: Bytes::from_str(
796                "0x4315fd1afc25cc2ebc72029c543293f9fd833eeb305e2e30159459c827733b1b",
797            )
798            .unwrap(),
799            timestamp: 1722875891,
800            ..Default::default()
801        };
802
803        for account in accounts.clone() {
804            engine
805                .state
806                .init_account(
807                    account.address,
808                    AccountInfo {
809                        balance: account.balance.unwrap_or_default(),
810                        nonce: 0u64,
811                        code_hash: KECCAK_EMPTY,
812                        code: account
813                            .code
814                            .clone()
815                            .map(|arg0: Vec<u8>| Bytecode::new_raw(arg0.into())),
816                    },
817                    None,
818                    false,
819                )
820                .expect("Failed to initialize account");
821        }
822        db.update(accounts, Some(block))
823            .unwrap();
824
825        let tokens = vec![dai().address, bal().address];
826        for token in &tokens {
827            engine
828                .state
829                .init_account(
830                    bytes_to_address(token).unwrap(),
831                    AccountInfo {
832                        balance: U256::from(0),
833                        nonce: 0,
834                        code_hash: KECCAK_EMPTY,
835                        code: Some(Bytecode::new_raw(ERC20_PROXY_BYTECODE.into())),
836                    },
837                    None,
838                    true,
839                )
840                .expect("Failed to initialize account");
841        }
842
843        let block = BlockHeader {
844            number: 18485417,
845            hash: Bytes::from_str(
846                "0x28d41d40f2ac275a4f5f621a636b9016b527d11d37d610a45ac3a821346ebf8c",
847            )
848            .expect("Invalid block hash"),
849            timestamp: 0,
850            ..Default::default()
851        };
852        db.update(vec![], Some(block.clone()))
853            .unwrap();
854
855        let pool_id: String =
856            "0x4626d81b3a1711beb79f4cecff2413886d461677000200000000000000000011".into();
857
858        let stateless_contracts = HashMap::from([(
859            String::from("0x3de27efa2f1aa663ae5d458857e731c129069f29"),
860            Some(Vec::new()),
861        )]);
862
863        let balances = HashMap::from([
864            (dai_addr(), U256::from_str("178754012737301807104").unwrap()),
865            (bal_addr(), U256::from_str("91082987763369885696").unwrap()),
866        ]);
867        let adapter_address =
868            Address::from_str("0xA2C5C98A892fD6656a7F39A2f63228C0Bc846270").unwrap();
869
870        EVMPoolStateBuilder::new(pool_id, tokens, adapter_address)
871            .balances(balances)
872            .balance_owner(Address::from_str("0xBA12222222228d8Ba445958a75a0704d566BF2C8").unwrap())
873            .adapter_contract_bytecode(Bytecode::new_raw(BALANCER_V2.into()))
874            .stateless_contracts(stateless_contracts)
875            .build(SHARED_TYCHO_DB.clone())
876            .await
877            .expect("Failed to build pool state")
878    }
879
880    #[tokio::test]
881    async fn test_init() {
882        // Clear DB from this test to prevent interference from other tests
883        SHARED_TYCHO_DB
884            .clear()
885            .expect("Failed to cleared SHARED TX");
886        let pool_state = setup_pool_state().await;
887
888        let expected_capabilities = vec![
889            Capability::SellSide,
890            Capability::BuySide,
891            Capability::PriceFunction,
892            Capability::HardLimits,
893        ]
894        .into_iter()
895        .collect::<HashSet<_>>();
896
897        let capabilities_adapter_contract = pool_state
898            .adapter_contract
899            .get_capabilities(
900                &pool_state.id,
901                bytes_to_address(&pool_state.tokens[0]).unwrap(),
902                bytes_to_address(&pool_state.tokens[1]).unwrap(),
903            )
904            .unwrap();
905
906        assert_eq!(capabilities_adapter_contract, expected_capabilities.clone());
907
908        let capabilities_state = pool_state.clone().capabilities;
909
910        assert_eq!(capabilities_state, expected_capabilities.clone());
911
912        for capability in expected_capabilities.clone() {
913            assert!(pool_state
914                .clone()
915                .ensure_capability(capability)
916                .is_ok());
917        }
918
919        assert!(pool_state
920            .clone()
921            .ensure_capability(Capability::MarginalPrice)
922            .is_err());
923
924        // Verify all tokens are initialized in the engine
925        let engine_accounts = pool_state
926            .adapter_contract
927            .engine
928            .state
929            .clone()
930            .get_account_storage()
931            .expect("Failed to get account storage");
932        for token in pool_state.tokens.clone() {
933            let account = engine_accounts
934                .get_account_info(&bytes_to_address(&token).unwrap())
935                .unwrap();
936            assert_eq!(account.balance, U256::from(0));
937            assert_eq!(account.nonce, 0u64);
938            assert_eq!(account.code_hash, KECCAK_EMPTY);
939            assert!(account.code.is_some());
940        }
941
942        // Verify external account is initialized in the engine
943        let external_account = engine_accounts
944            .get_account_info(&EXTERNAL_ACCOUNT)
945            .unwrap();
946        assert_eq!(external_account.balance, U256::from(*MAX_BALANCE));
947        assert_eq!(external_account.nonce, 0u64);
948        assert_eq!(external_account.code_hash, KECCAK_EMPTY);
949        assert!(external_account.code.is_none());
950    }
951
952    #[tokio::test]
953    async fn test_get_amount_out() -> Result<(), Box<dyn std::error::Error>> {
954        let pool_state = setup_pool_state().await;
955
956        let result = pool_state
957            .get_amount_out(BigUint::from_str("1000000000000000000").unwrap(), &dai(), &bal())
958            .unwrap();
959        let new_state = result
960            .new_state
961            .as_any()
962            .downcast_ref::<EVMPoolState<PreCachedDB>>()
963            .unwrap();
964        assert_eq!(result.amount, BigUint::from_str("137780051463393923").unwrap());
965        assert_ne!(new_state.spot_prices, pool_state.spot_prices);
966        assert!(pool_state
967            .block_lasting_overwrites
968            .is_empty());
969        Ok(())
970    }
971
972    #[tokio::test]
973    async fn test_sequential_get_amount_outs() {
974        let pool_state = setup_pool_state().await;
975
976        let result = pool_state
977            .get_amount_out(BigUint::from_str("1000000000000000000").unwrap(), &dai(), &bal())
978            .unwrap();
979        let new_state = result
980            .new_state
981            .as_any()
982            .downcast_ref::<EVMPoolState<PreCachedDB>>()
983            .unwrap();
984        assert_eq!(result.amount, BigUint::from_str("137780051463393923").unwrap());
985        assert_ne!(new_state.spot_prices, pool_state.spot_prices);
986
987        let new_result = new_state
988            .get_amount_out(BigUint::from_str("1000000000000000000").unwrap(), &dai(), &bal())
989            .unwrap();
990        let new_state_second_swap = new_result
991            .new_state
992            .as_any()
993            .downcast_ref::<EVMPoolState<PreCachedDB>>()
994            .unwrap();
995
996        assert_eq!(new_result.amount, BigUint::from_str("136964651490065626").unwrap());
997        assert_ne!(new_state_second_swap.spot_prices, new_state.spot_prices);
998    }
999
1000    #[tokio::test]
1001    async fn test_get_amount_out_dust() {
1002        let pool_state = setup_pool_state().await;
1003
1004        let result = pool_state
1005            .get_amount_out(BigUint::one(), &dai(), &bal())
1006            .unwrap();
1007
1008        let _ = result
1009            .new_state
1010            .as_any()
1011            .downcast_ref::<EVMPoolState<PreCachedDB>>()
1012            .unwrap();
1013        assert_eq!(result.amount, BigUint::ZERO);
1014    }
1015
1016    #[tokio::test]
1017    async fn test_get_amount_out_sell_limit() {
1018        let pool_state = setup_pool_state().await;
1019
1020        let result = pool_state.get_amount_out(
1021            // sell limit is 100279494253364362835
1022            BigUint::from_str("100379494253364362835").unwrap(),
1023            &dai(),
1024            &bal(),
1025        );
1026
1027        assert!(result.is_err());
1028
1029        match result {
1030            Err(SimulationError::InvalidInput(msg1, amount_out_result)) => {
1031                assert_eq!(msg1, "Sell amount exceeds limit 100279494253364362835");
1032                assert!(amount_out_result.is_some());
1033            }
1034            _ => panic!("Test failed: was expecting an Err(SimulationError::RetryDifferentInput(_, _)) value"),
1035        }
1036    }
1037
1038    #[tokio::test]
1039    async fn test_get_amount_limits() {
1040        let pool_state = setup_pool_state().await;
1041
1042        let overwrites = pool_state
1043            .get_overwrites(
1044                vec![
1045                    bytes_to_address(&pool_state.tokens[0]).unwrap(),
1046                    bytes_to_address(&pool_state.tokens[1]).unwrap(),
1047                ],
1048                *MAX_BALANCE / U256::from(100),
1049            )
1050            .unwrap();
1051        let (dai_limit, _) = pool_state
1052            .get_amount_limits(vec![dai_addr(), bal_addr()], Some(overwrites.clone()))
1053            .unwrap();
1054        assert_eq!(dai_limit, U256::from_str("100279494253364362835").unwrap());
1055
1056        let (bal_limit, _) = pool_state
1057            .get_amount_limits(
1058                vec![
1059                    bytes_to_address(&pool_state.tokens[1]).unwrap(),
1060                    bytes_to_address(&pool_state.tokens[0]).unwrap(),
1061                ],
1062                Some(overwrites),
1063            )
1064            .unwrap();
1065        assert_eq!(bal_limit, U256::from_str("13997408640689987484").unwrap());
1066    }
1067
1068    #[tokio::test]
1069    async fn test_set_spot_prices() {
1070        let mut pool_state = setup_pool_state().await;
1071
1072        pool_state
1073            .set_spot_prices(
1074                &vec![bal(), dai()]
1075                    .into_iter()
1076                    .map(|t| (t.address.clone(), t))
1077                    .collect(),
1078            )
1079            .unwrap();
1080
1081        let dai_bal_spot_price = pool_state
1082            .spot_prices
1083            .get(&(
1084                bytes_to_address(&pool_state.tokens[0]).unwrap(),
1085                bytes_to_address(&pool_state.tokens[1]).unwrap(),
1086            ))
1087            .unwrap();
1088        let bal_dai_spot_price = pool_state
1089            .spot_prices
1090            .get(&(
1091                bytes_to_address(&pool_state.tokens[1]).unwrap(),
1092                bytes_to_address(&pool_state.tokens[0]).unwrap(),
1093            ))
1094            .unwrap();
1095        assert_eq!(dai_bal_spot_price, &0.137_778_914_319_047_9);
1096        assert_eq!(bal_dai_spot_price, &7.071_503_245_428_246);
1097    }
1098
1099    #[tokio::test]
1100    async fn test_set_spot_prices_without_capability() {
1101        // Tests set Spot Prices functions when the pool doesn't have PriceFunction capability
1102        let mut pool_state = setup_pool_state().await;
1103
1104        pool_state
1105            .capabilities
1106            .remove(&Capability::PriceFunction);
1107
1108        pool_state
1109            .set_spot_prices(
1110                &vec![bal(), dai()]
1111                    .into_iter()
1112                    .map(|t| (t.address.clone(), t))
1113                    .collect(),
1114            )
1115            .unwrap();
1116
1117        let dai_bal_spot_price = pool_state
1118            .spot_prices
1119            .get(&(
1120                bytes_to_address(&pool_state.tokens[0]).unwrap(),
1121                bytes_to_address(&pool_state.tokens[1]).unwrap(),
1122            ))
1123            .unwrap();
1124        let bal_dai_spot_price = pool_state
1125            .spot_prices
1126            .get(&(
1127                bytes_to_address(&pool_state.tokens[1]).unwrap(),
1128                bytes_to_address(&pool_state.tokens[0]).unwrap(),
1129            ))
1130            .unwrap();
1131        assert_eq!(dai_bal_spot_price, &0.13736685496467538);
1132        assert_eq!(bal_dai_spot_price, &7.050354297665408);
1133    }
1134
1135    #[tokio::test]
1136    async fn test_get_balance_overwrites_with_component_balances() {
1137        let pool_state: EVMPoolState<PreCachedDB> = setup_pool_state().await;
1138
1139        let overwrites = pool_state
1140            .get_balance_overwrites()
1141            .unwrap();
1142
1143        let dai_address = dai_addr();
1144        let bal_address = bal_addr();
1145        assert!(overwrites.contains_key(&dai_address));
1146        assert!(overwrites.contains_key(&bal_address));
1147    }
1148
1149    #[tokio::test]
1150    async fn test_get_balance_overwrites_with_contract_balances() {
1151        let mut pool_state: EVMPoolState<PreCachedDB> = setup_pool_state().await;
1152
1153        let contract_address =
1154            Address::from_str("0xBA12222222228d8Ba445958a75a0704d566BF2C8").unwrap();
1155
1156        // Ensure no component balances are used
1157        pool_state.balances.clear();
1158        pool_state.balance_owner = None;
1159
1160        // Set contract balances
1161        let dai_address = dai_addr();
1162        let bal_address = bal_addr();
1163        pool_state.contract_balances = HashMap::from([(
1164            contract_address,
1165            HashMap::from([
1166                (dai_address, U256::from_str("7500000000000000000000").unwrap()), // 7500 DAI
1167                (bal_address, U256::from_str("1500000000000000000000").unwrap()), // 1500 BAL
1168            ]),
1169        )]);
1170
1171        let overwrites = pool_state
1172            .get_balance_overwrites()
1173            .unwrap();
1174
1175        assert!(overwrites.contains_key(&dai_address));
1176        assert!(overwrites.contains_key(&bal_address));
1177    }
1178
1179    #[tokio::test]
1180    async fn test_balance_merging_during_delta_transition() {
1181        use std::str::FromStr;
1182
1183        let mut pool_state = setup_pool_state().await;
1184        let pool_id = pool_state.id.clone();
1185
1186        // Test the balance merging logic more directly
1187        // Setup initial balances including DAI and BAL (which the pool already knows about)
1188        let dai_addr = dai_addr();
1189        let bal_addr = bal_addr();
1190        let new_token = Address::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(); // WETH
1191
1192        // Clear and setup clean initial state
1193        pool_state.balances.clear();
1194        pool_state
1195            .balances
1196            .insert(dai_addr, U256::from(1000000000u64));
1197        pool_state
1198            .balances
1199            .insert(bal_addr, U256::from(2000000000u64));
1200        pool_state
1201            .balances
1202            .insert(new_token, U256::from(3000000000u64));
1203
1204        // Create tokens mapping including the existing DAI and BAL
1205        let mut tokens = HashMap::new();
1206        tokens.insert(dai().address.clone(), dai());
1207        tokens.insert(bal().address.clone(), bal());
1208
1209        // Simulate a delta transition with only DAI balance update (missing BAL and new_token)
1210        let mut component_balances = HashMap::new();
1211        let mut delta_balances = HashMap::new();
1212        // Only update DAI balance, leave others unchanged in delta
1213        delta_balances.insert(dai().address.clone(), Bytes::from(vec![0x77, 0x35, 0x94, 0x00])); // 2000000000 (updated value)
1214        component_balances.insert(pool_id.clone(), delta_balances);
1215
1216        let balances = Balances { component_balances, account_balances: HashMap::new() };
1217
1218        // Record initial balance count
1219        let initial_balance_count = pool_state.balances.len();
1220        assert_eq!(initial_balance_count, 3);
1221
1222        // Apply delta transition
1223        pool_state
1224            .update_pool_state(&tokens, &balances)
1225            .unwrap();
1226
1227        // Verify that all 3 balances are preserved (BAL and new_token should still be there)
1228        assert_eq!(
1229            pool_state.balances.len(),
1230            3,
1231            "All balances should be preserved after delta transition"
1232        );
1233        assert!(
1234            pool_state
1235                .balances
1236                .contains_key(&dai_addr),
1237            "DAI balance should be present"
1238        );
1239        assert!(
1240            pool_state
1241                .balances
1242                .contains_key(&bal_addr),
1243            "BAL balance should be present"
1244        );
1245        assert!(
1246            pool_state
1247                .balances
1248                .contains_key(&new_token),
1249            "New token balance should be preserved from before delta"
1250        );
1251
1252        // Verify that updated token (DAI) has new value
1253        assert_eq!(
1254            pool_state.balances[&dai_addr],
1255            U256::from(2000000000u64),
1256            "DAI balance should be updated"
1257        );
1258
1259        // Verify that non-updated tokens retain their original values
1260        assert_eq!(
1261            pool_state.balances[&bal_addr],
1262            U256::from(2000000000u64),
1263            "BAL balance should be unchanged"
1264        );
1265        assert_eq!(
1266            pool_state.balances[&new_token],
1267            U256::from(3000000000u64),
1268            "New token balance should be unchanged"
1269        );
1270    }
1271}