1#![allow(deprecated)]
2use std::{
3 any::Any,
4 collections::{HashMap, HashSet},
5 fmt::{self, Debug},
6 str::FromStr,
7};
8
9use alloy::primitives::{Address, U256};
10use itertools::Itertools;
11use num_bigint::BigUint;
12use revm::DatabaseRef;
13use tycho_common::{
14 dto::ProtocolStateDelta,
15 models::token::Token,
16 simulation::{
17 errors::{SimulationError, TransitionError},
18 protocol_sim::{Balances, GetAmountOutResult, ProtocolSim},
19 },
20 Bytes,
21};
22
23use super::{
24 constants::{EXTERNAL_ACCOUNT, MAX_BALANCE},
25 erc20_token::{Overwrites, TokenProxyOverwriteFactory},
26 models::Capability,
27 tycho_simulation_contract::TychoSimulationContract,
28};
29use crate::evm::{
30 engine_db::{engine_db_interface::EngineDatabaseInterface, tycho_db::PreCachedDB},
31 protocol::{
32 u256_num::{u256_to_biguint, u256_to_f64},
33 utils::bytes_to_address,
34 },
35};
36
/// State of a pool whose behaviour is simulated by running an adapter contract in the EVM.
///
/// Prices, limits and swap outputs are obtained by calling the adapter contract with storage
/// overwrites built from the balances tracked here.
#[derive(Clone)]
pub struct EVMPoolState<D: EngineDatabaseInterface + Clone + Debug>
where
    <D as DatabaseRef>::Error: Debug,
    <D as EngineDatabaseInterface>::Error: Debug,
{
    /// The pool's id. It is passed to every adapter-contract call. It is also parsed as the
    /// address holding `balances` when no balance owner is set and `contract_balances` is empty.
    id: String,
    /// Addresses of the pool's tokens.
    pub tokens: Vec<Bytes>,
    /// The pool's own token balances (token address -> amount).
    balances: HashMap<Address, U256>,
    /// Address credited with `balances` when building balance storage overwrites.
    #[deprecated(note = "Use contract_balances instead")]
    balance_owner: Option<Address>,
    /// Cached spot prices keyed by `(sell_token, buy_token)`.
    spot_prices: HashMap<(Address, Address), f64>,
    /// Capabilities of the pool; they select which simulation paths are used
    /// (e.g. price function vs. swap-based marginal price, hard sell limits).
    capabilities: HashSet<Capability>,
    /// Storage overwrites accumulated from simulated swaps, merged into every later
    /// simulation. Cleared whenever the state is refreshed from a delta.
    block_lasting_overwrites: HashMap<Address, Overwrites>,
    /// Contracts whose per-token balances are refreshed into `contract_balances`.
    involved_contracts: HashSet<Address>,
    /// Token balances held by each contract (contract -> token -> amount).
    contract_balances: HashMap<Address, HashMap<Address, U256>>,
    /// If true, `delta_transition` only refreshes the state when the delta carries a set
    /// `update_marker` attribute.
    manual_updates: bool,
    /// Adapter contract used for price, limit and swap simulations.
    adapter_contract: TychoSimulationContract<D>,
    /// Tokens for which no balance storage overwrites are generated.
    disable_overwrite_tokens: HashSet<Address>,
}
74
75impl<D> Debug for EVMPoolState<D>
76where
77 D: EngineDatabaseInterface + Clone + Debug,
78 <D as DatabaseRef>::Error: Debug,
79 <D as EngineDatabaseInterface>::Error: Debug,
80{
81 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82 f.debug_struct("EVMPoolState")
83 .field("id", &self.id)
84 .field("tokens", &self.tokens)
85 .field("balances", &self.balances)
86 .field("involved_contracts", &self.involved_contracts)
87 .field("contract_balances", &self.contract_balances)
88 .finish_non_exhaustive()
89 }
90}
91
92impl<D> EVMPoolState<D>
93where
94 D: EngineDatabaseInterface + Clone + Debug + 'static,
95 <D as DatabaseRef>::Error: Debug,
96 <D as EngineDatabaseInterface>::Error: Debug,
97{
98 #[allow(clippy::too_many_arguments)]
103 pub fn new(
104 id: String,
105 tokens: Vec<Bytes>,
106 component_balances: HashMap<Address, U256>,
107 balance_owner: Option<Address>,
108 contract_balances: HashMap<Address, HashMap<Address, U256>>,
109 spot_prices: HashMap<(Address, Address), f64>,
110 capabilities: HashSet<Capability>,
111 block_lasting_overwrites: HashMap<Address, Overwrites>,
112 involved_contracts: HashSet<Address>,
113 manual_updates: bool,
114 adapter_contract: TychoSimulationContract<D>,
115 disable_overwrite_tokens: HashSet<Address>,
116 ) -> Self {
117 Self {
118 id,
119 tokens,
120 balances: component_balances,
121 balance_owner,
122 spot_prices,
123 capabilities,
124 block_lasting_overwrites,
125 involved_contracts,
126 contract_balances,
127 manual_updates,
128 adapter_contract,
129 disable_overwrite_tokens,
130 }
131 }
132
133 fn ensure_capability(&self, capability: Capability) -> Result<(), SimulationError> {
144 if !self.capabilities.contains(&capability) {
145 return Err(SimulationError::FatalError(format!(
146 "capability {:?} not supported",
147 capability.to_string()
148 )));
149 }
150 Ok(())
151 }
152 pub fn set_spot_prices(
188 &mut self,
189 tokens: &HashMap<Bytes, Token>,
190 ) -> Result<(), SimulationError> {
191 match self.ensure_capability(Capability::PriceFunction) {
192 Ok(_) => {
193 for [sell_token_address, buy_token_address] in self
194 .tokens
195 .iter()
196 .permutations(2)
197 .map(|p| [p[0], p[1]])
198 {
199 let sell_token_address = bytes_to_address(sell_token_address)?;
200 let buy_token_address = bytes_to_address(buy_token_address)?;
201
202 let overwrites = Some(self.get_overwrites(
203 vec![sell_token_address, buy_token_address],
204 *MAX_BALANCE / U256::from(100),
205 )?);
206
207 let (sell_amount_limit, _) = self.get_amount_limits(
208 vec![sell_token_address, buy_token_address],
209 overwrites.clone(),
210 )?;
211 let price_result = self.adapter_contract.price(
212 &self.id,
213 sell_token_address,
214 buy_token_address,
215 vec![sell_amount_limit / U256::from(100)],
216 overwrites,
217 )?;
218
219 let price = if self
220 .capabilities
221 .contains(&Capability::ScaledPrice)
222 {
223 *price_result.first().ok_or_else(|| {
224 SimulationError::FatalError(
225 "Calculated price array is empty".to_string(),
226 )
227 })?
228 } else {
229 let unscaled_price = price_result.first().ok_or_else(|| {
230 SimulationError::FatalError(
231 "Calculated price array is empty".to_string(),
232 )
233 })?;
234 let sell_token_decimals = self.get_decimals(tokens, &sell_token_address)?;
235 let buy_token_decimals = self.get_decimals(tokens, &buy_token_address)?;
236 *unscaled_price * 10f64.powi(sell_token_decimals as i32) /
237 10f64.powi(buy_token_decimals as i32)
238 };
239
240 self.spot_prices
241 .insert((sell_token_address, buy_token_address), price);
242 }
243 }
244 Err(SimulationError::FatalError(_)) => {
245 for iter_tokens in self.tokens.iter().permutations(2) {
249 let t0 = bytes_to_address(iter_tokens[0])?;
250 let t1 = bytes_to_address(iter_tokens[1])?;
251
252 let overwrites =
253 Some(self.get_overwrites(vec![t0, t1], *MAX_BALANCE / U256::from(100))?);
254
255 let x1 = self
257 .get_amount_limits(vec![t0, t1], overwrites.clone())?
258 .0 /
259 U256::from(100);
260
261 let x2 = x1 + (x1 / U256::from(100));
264
265 let y1 = self
268 .adapter_contract
269 .swap(&self.id, t0, t1, false, x1, overwrites.clone())?
270 .0
271 .received_amount;
272
273 let y2 = self
276 .adapter_contract
277 .swap(&self.id, t0, t1, false, x2, overwrites)?
278 .0
279 .received_amount;
280
281 let sell_token_decimals = self.get_decimals(tokens, &t0)?;
282 let buy_token_decimals = self.get_decimals(tokens, &t1)?;
283
284 let num = y2 - y1;
285 let den = x2 - x1;
286
287 let token_correction =
289 10f64.powi(sell_token_decimals as i32 - buy_token_decimals as i32);
290 let num_f64 = u256_to_f64(num)?;
291 let den_f64 = u256_to_f64(den)?;
292 if den_f64 == 0.0 {
293 return Err(SimulationError::FatalError(
294 "Failed to compute marginal price: denominator converted to 0".into(),
295 ));
296 }
297 let marginal_price = num_f64 / den_f64 * token_correction;
298
299 self.spot_prices
300 .insert((t0, t1), marginal_price);
301 }
302 }
303 Err(e) => return Err(e),
304 }
305
306 Ok(())
307 }
308
309 fn get_decimals(
310 &self,
311 tokens: &HashMap<Bytes, Token>,
312 sell_token_address: &Address,
313 ) -> Result<usize, SimulationError> {
314 tokens
315 .get(&Bytes::from(sell_token_address.as_slice()))
316 .map(|t| t.decimals as usize)
317 .ok_or_else(|| {
318 SimulationError::FatalError(format!(
319 "Failed to scale spot prices! Pool: {} Token 0x{:x} is not available!",
320 self.id, sell_token_address
321 ))
322 })
323 }
324
325 fn get_amount_limits(
342 &self,
343 tokens: Vec<Address>,
344 overwrites: Option<HashMap<Address, HashMap<U256, U256>>>,
345 ) -> Result<(U256, U256), SimulationError> {
346 let limits = self
347 .adapter_contract
348 .get_limits(&self.id, tokens[0], tokens[1], overwrites)?;
349
350 Ok(limits)
351 }
352
353 fn update_pool_state(
364 &mut self,
365 tokens: &HashMap<Bytes, Token>,
366 balances: &Balances,
367 ) -> Result<(), SimulationError> {
368 self.adapter_contract
370 .engine
371 .clear_temp_storage()
372 .map_err(|err| {
373 SimulationError::FatalError(format!("Failed to clear temporary storage: {err:?}",))
374 })?;
375 self.block_lasting_overwrites.clear();
376
377 if !self.balances.is_empty() {
379 if let Some(bals) = balances
381 .component_balances
382 .get(&self.id)
383 {
384 for (token, bal) in bals {
387 let addr = bytes_to_address(token).map_err(|_| {
388 SimulationError::FatalError(format!(
389 "Invalid token address in balance update: {token:?}"
390 ))
391 })?;
392 self.balances
393 .insert(addr, U256::from_be_slice(bal));
394 }
395 }
396 } else {
397 for contract in &self.involved_contracts {
399 if let Some(bals) = balances
400 .account_balances
401 .get(&Bytes::from(contract.as_slice()))
402 {
403 let contract_entry = self
404 .contract_balances
405 .entry(*contract)
406 .or_default();
407 for (token, bal) in bals {
408 let addr = bytes_to_address(token).map_err(|_| {
409 SimulationError::FatalError(format!(
410 "Invalid token address in balance update: {token:?}"
411 ))
412 })?;
413 contract_entry.insert(addr, U256::from_be_slice(bal));
414 }
415 }
416 }
417 }
418
419 self.set_spot_prices(tokens)?;
421 Ok(())
422 }
423
424 fn get_overwrites(
425 &self,
426 tokens: Vec<Address>,
427 max_amount: U256,
428 ) -> Result<HashMap<Address, Overwrites>, SimulationError> {
429 let token_overwrites = self.get_token_overwrites(tokens, max_amount)?;
430
431 let merged_overwrites =
433 self.merge(&self.block_lasting_overwrites.clone(), &token_overwrites);
434
435 Ok(merged_overwrites)
436 }
437
438 fn get_token_overwrites(
439 &self,
440 tokens: Vec<Address>,
441 max_amount: U256,
442 ) -> Result<HashMap<Address, Overwrites>, SimulationError> {
443 let sell_token = &tokens[0].clone(); let mut res: Vec<HashMap<Address, Overwrites>> = Vec::new();
445 if !self
446 .capabilities
447 .contains(&Capability::TokenBalanceIndependent)
448 {
449 res.push(self.get_balance_overwrites()?);
450 }
451
452 let mut overwrites = TokenProxyOverwriteFactory::new(*sell_token, None);
453
454 overwrites.set_balance(max_amount, Address::from_slice(&*EXTERNAL_ACCOUNT.0));
455
456 overwrites.set_allowance(max_amount, self.adapter_contract.address, *EXTERNAL_ACCOUNT);
458
459 res.push(overwrites.get_overwrites());
460
461 Ok(res
463 .into_iter()
464 .fold(HashMap::new(), |acc, overwrite| self.merge(&acc, &overwrite)))
465 }
466
467 fn get_balance_overwrites(&self) -> Result<HashMap<Address, Overwrites>, SimulationError> {
478 let mut balance_overwrites: HashMap<Address, Overwrites> = HashMap::new();
479
480 let address = match self.balance_owner {
482 Some(owner) => Some(owner),
483 None if !self.contract_balances.is_empty() => None,
484 None => Some(self.id.parse().map_err(|_| {
485 SimulationError::FatalError(
486 "Failed to get balance overwrites: Pool ID is not an address".into(),
487 )
488 })?),
489 };
490
491 if let Some(address) = address {
492 for (token, bal) in &self.balances {
495 let mut overwrites = TokenProxyOverwriteFactory::new(*token, None);
496 overwrites.set_balance(*bal, address);
497 balance_overwrites.extend(overwrites.get_overwrites());
498 }
499 }
500
501 for (contract, balances) in &self.contract_balances {
504 for (token, balance) in balances {
505 let mut overwrites = TokenProxyOverwriteFactory::new(*token, None);
506 overwrites.set_balance(*balance, *contract);
507 balance_overwrites.extend(overwrites.get_overwrites());
508 }
509 }
510
511 for token in &self.disable_overwrite_tokens {
513 balance_overwrites.remove(token);
514 }
515
516 Ok(balance_overwrites)
517 }
518
519 fn merge(
520 &self,
521 target: &HashMap<Address, Overwrites>,
522 source: &HashMap<Address, Overwrites>,
523 ) -> HashMap<Address, Overwrites> {
524 let mut merged = target.clone();
525
526 for (key, source_inner) in source {
527 merged
528 .entry(*key)
529 .or_default()
530 .extend(source_inner.clone());
531 }
532
533 merged
534 }
535
536 #[cfg(test)]
537 pub fn get_involved_contracts(&self) -> HashSet<Address> {
538 self.involved_contracts.clone()
539 }
540
541 #[cfg(test)]
542 pub fn get_manual_updates(&self) -> bool {
543 self.manual_updates
544 }
545
546 #[cfg(test)]
547 #[deprecated]
548 pub fn get_balance_owner(&self) -> Option<Address> {
549 self.balance_owner
550 }
551
552 pub fn get_balances(&self) -> &HashMap<Address, U256> {
554 &self.balances
555 }
556}
557
558impl<D> ProtocolSim for EVMPoolState<D>
559where
560 D: EngineDatabaseInterface + Clone + Debug + 'static,
561 <D as DatabaseRef>::Error: Debug,
562 <D as EngineDatabaseInterface>::Error: Debug,
563{
564 fn fee(&self) -> f64 {
565 todo!()
566 }
567
568 fn spot_price(&self, base: &Token, quote: &Token) -> Result<f64, SimulationError> {
569 let base_address = bytes_to_address(&base.address)?;
570 let quote_address = bytes_to_address("e.address)?;
571 self.spot_prices
572 .get(&(base_address, quote_address))
573 .cloned()
574 .ok_or(SimulationError::FatalError(format!(
575 "Spot price not found for base token {base_address} and quote token {quote_address}"
576 )))
577 }
578
579 fn get_amount_out(
580 &self,
581 amount_in: BigUint,
582 token_in: &Token,
583 token_out: &Token,
584 ) -> Result<GetAmountOutResult, SimulationError> {
585 let sell_token_address = bytes_to_address(&token_in.address)?;
586 let buy_token_address = bytes_to_address(&token_out.address)?;
587 let sell_amount = U256::from_be_slice(&amount_in.to_bytes_be());
588 let overwrites = self.get_overwrites(
589 vec![sell_token_address, buy_token_address],
590 *MAX_BALANCE / U256::from(100),
591 )?;
592 let (sell_amount_limit, _) = self.get_amount_limits(
593 vec![sell_token_address, buy_token_address],
594 Some(overwrites.clone()),
595 )?;
596 let (sell_amount_respecting_limit, sell_amount_exceeds_limit) = if self
597 .capabilities
598 .contains(&Capability::HardLimits) &&
599 sell_amount_limit < sell_amount
600 {
601 (sell_amount_limit, true)
602 } else {
603 (sell_amount, false)
604 };
605
606 let overwrites_with_sell_limit =
607 self.get_overwrites(vec![sell_token_address, buy_token_address], sell_amount_limit)?;
608 let complete_overwrites = self.merge(&overwrites, &overwrites_with_sell_limit);
609
610 let (trade, state_changes) = self.adapter_contract.swap(
611 &self.id,
612 sell_token_address,
613 buy_token_address,
614 false,
615 sell_amount_respecting_limit,
616 Some(complete_overwrites),
617 )?;
618
619 let mut new_state = self.clone();
620
621 for (address, state_update) in state_changes {
623 if let Some(storage) = state_update.storage {
624 let block_overwrites = new_state
625 .block_lasting_overwrites
626 .entry(address)
627 .or_default();
628 for (slot, value) in storage {
629 let slot = U256::from_str(&slot.to_string()).map_err(|_| {
630 SimulationError::FatalError("Failed to decode slot index".to_string())
631 })?;
632 let value = U256::from_str(&value.to_string()).map_err(|_| {
633 SimulationError::FatalError("Failed to decode slot overwrite".to_string())
634 })?;
635 block_overwrites.insert(slot, value);
636 }
637 }
638 }
639
640 let tokens = HashMap::from([
642 (token_in.address.clone(), token_in.clone()),
643 (token_out.address.clone(), token_out.clone()),
644 ]);
645 let _ = new_state.set_spot_prices(&tokens);
646
647 let buy_amount = trade.received_amount;
648
649 if sell_amount_exceeds_limit {
650 return Err(SimulationError::InvalidInput(
651 format!("Sell amount exceeds limit {sell_amount_limit}"),
652 Some(GetAmountOutResult::new(
653 u256_to_biguint(buy_amount),
654 u256_to_biguint(trade.gas_used),
655 Box::new(new_state.clone()),
656 )),
657 ));
658 }
659 Ok(GetAmountOutResult::new(
660 u256_to_biguint(buy_amount),
661 u256_to_biguint(trade.gas_used),
662 Box::new(new_state.clone()),
663 ))
664 }
665
666 fn get_limits(
667 &self,
668 sell_token: Bytes,
669 buy_token: Bytes,
670 ) -> Result<(BigUint, BigUint), SimulationError> {
671 let sell_token = bytes_to_address(&sell_token)?;
672 let buy_token = bytes_to_address(&buy_token)?;
673 let overwrites =
674 self.get_overwrites(vec![sell_token, buy_token], *MAX_BALANCE / U256::from(100))?;
675 let limits = self.get_amount_limits(vec![sell_token, buy_token], Some(overwrites))?;
676 Ok((u256_to_biguint(limits.0), u256_to_biguint(limits.1)))
677 }
678
679 fn delta_transition(
680 &mut self,
681 delta: ProtocolStateDelta,
682 tokens: &HashMap<Bytes, Token>,
683 balances: &Balances,
684 ) -> Result<(), TransitionError<String>> {
685 if self.manual_updates {
686 if let Some(marker) = delta
688 .updated_attributes
689 .get("update_marker")
690 {
691 if !marker.is_empty() && marker[0] != 0 {
693 self.update_pool_state(tokens, balances)?;
694 }
695 }
696 } else {
697 self.update_pool_state(tokens, balances)?;
698 }
699
700 Ok(())
701 }
702
703 fn clone_box(&self) -> Box<dyn ProtocolSim> {
704 Box::new(self.clone())
705 }
706
707 fn as_any(&self) -> &dyn Any {
708 self
709 }
710
711 fn as_any_mut(&mut self) -> &mut dyn Any {
712 self
713 }
714
715 fn eq(&self, other: &dyn ProtocolSim) -> bool {
716 if let Some(other_state) = other
717 .as_any()
718 .downcast_ref::<EVMPoolState<PreCachedDB>>()
719 {
720 self.id == other_state.id
721 } else {
722 false
723 }
724 }
725}
726
#[cfg(test)]
mod tests {
    use std::default::Default;

    use num_traits::One;
    use revm::{
        primitives::KECCAK_EMPTY,
        state::{AccountInfo, Bytecode},
    };
    use serde_json::Value;
    use tycho_client::feed::BlockHeader;
    use tycho_common::models::Chain;

    use super::*;
    use crate::evm::{
        engine_db::{create_engine, SHARED_TYCHO_DB},
        protocol::vm::{
            constants::{BALANCER_V2, ERC20_PROXY_BYTECODE},
            state_builder::EVMPoolStateBuilder,
        },
        simulation::SimulationEngine,
        tycho_models::AccountUpdate,
    };

    fn dai() -> Token {
        Token::new(
            &Bytes::from_str("0x6b175474e89094c44da98b954eedeac495271d0f").unwrap(),
            "DAI",
            18,
            0,
            &[Some(10_000)],
            Chain::Ethereum,
            100,
        )
    }

    fn bal() -> Token {
        Token::new(
            &Bytes::from_str("0xba100000625a3754423978a60c9317c58a424e3d").unwrap(),
            "BAL",
            18,
            0,
            &[Some(10_000)],
            Chain::Ethereum,
            100,
        )
    }

    fn dai_addr() -> Address {
        bytes_to_address(&dai().address).unwrap()
    }

    fn bal_addr() -> Address {
        bytes_to_address(&bal().address).unwrap()
    }

    /// Indexes `tokens` by their address, as expected by `set_spot_prices`.
    fn tokens_by_address(tokens: Vec<Token>) -> HashMap<Bytes, Token> {
        tokens
            .into_iter()
            .map(|t| (t.address.clone(), t))
            .collect()
    }

    /// Builds a Balancer V2 DAI/BAL pool state backed by the shared Tycho DB, loaded with the
    /// contract storage snapshot of block 20463609.
    async fn setup_pool_state() -> EVMPoolState<PreCachedDB> {
        let data_str = include_str!("assets/balancer_contract_storage_block_20463609.json");
        let data: Value = serde_json::from_str(data_str).expect("Failed to parse JSON");

        let accounts: Vec<AccountUpdate> = serde_json::from_value(data["accounts"].clone())
            .expect("Expected accounts to match AccountUpdate structure");

        let db = SHARED_TYCHO_DB.clone();
        let engine: SimulationEngine<_> = create_engine(db.clone(), false).unwrap();

        let block = BlockHeader {
            number: 20463609,
            hash: Bytes::from_str(
                "0x4315fd1afc25cc2ebc72029c543293f9fd833eeb305e2e30159459c827733b1b",
            )
            .unwrap(),
            timestamp: 1722875891,
            ..Default::default()
        };

        for account in accounts.clone() {
            engine
                .state
                .init_account(
                    account.address,
                    AccountInfo {
                        balance: account.balance.unwrap_or_default(),
                        nonce: 0u64,
                        code_hash: KECCAK_EMPTY,
                        code: account
                            .code
                            .clone()
                            .map(|arg0: Vec<u8>| Bytecode::new_raw(arg0.into())),
                    },
                    None,
                    false,
                )
                .expect("Failed to initialize account");
        }
        db.update(accounts, Some(block))
            .unwrap();

        let tokens = vec![dai().address, bal().address];
        for token in &tokens {
            engine
                .state
                .init_account(
                    bytes_to_address(token).unwrap(),
                    AccountInfo {
                        balance: U256::from(0),
                        nonce: 0,
                        code_hash: KECCAK_EMPTY,
                        code: Some(Bytecode::new_raw(ERC20_PROXY_BYTECODE.into())),
                    },
                    None,
                    true,
                )
                .expect("Failed to initialize account");
        }

        let block = BlockHeader {
            number: 18485417,
            hash: Bytes::from_str(
                "0x28d41d40f2ac275a4f5f621a636b9016b527d11d37d610a45ac3a821346ebf8c",
            )
            .expect("Invalid block hash"),
            timestamp: 0,
            ..Default::default()
        };
        db.update(vec![], Some(block.clone()))
            .unwrap();

        let pool_id: String =
            "0x4626d81b3a1711beb79f4cecff2413886d461677000200000000000000000011".into();

        let stateless_contracts = HashMap::from([(
            String::from("0x3de27efa2f1aa663ae5d458857e731c129069f29"),
            Some(Vec::new()),
        )]);

        let balances = HashMap::from([
            (dai_addr(), U256::from_str("178754012737301807104").unwrap()),
            (bal_addr(), U256::from_str("91082987763369885696").unwrap()),
        ]);
        let adapter_address =
            Address::from_str("0xA2C5C98A892fD6656a7F39A2f63228C0Bc846270").unwrap();

        EVMPoolStateBuilder::new(pool_id, tokens, adapter_address)
            .balances(balances)
            .balance_owner(Address::from_str("0xBA12222222228d8Ba445958a75a0704d566BF2C8").unwrap())
            .adapter_contract_bytecode(Bytecode::new_raw(BALANCER_V2.into()))
            .stateless_contracts(stateless_contracts)
            .build(SHARED_TYCHO_DB.clone())
            .await
            .expect("Failed to build pool state")
    }

    #[tokio::test]
    async fn test_init() {
        SHARED_TYCHO_DB
            .clear()
            .expect("Failed to clear SHARED_TYCHO_DB");
        let pool_state = setup_pool_state().await;

        let expected_capabilities = vec![
            Capability::SellSide,
            Capability::BuySide,
            Capability::PriceFunction,
            Capability::HardLimits,
        ]
        .into_iter()
        .collect::<HashSet<_>>();

        let capabilities_adapter_contract = pool_state
            .adapter_contract
            .get_capabilities(
                &pool_state.id,
                bytes_to_address(&pool_state.tokens[0]).unwrap(),
                bytes_to_address(&pool_state.tokens[1]).unwrap(),
            )
            .unwrap();

        assert_eq!(capabilities_adapter_contract, expected_capabilities.clone());

        let capabilities_state = pool_state.clone().capabilities;

        assert_eq!(capabilities_state, expected_capabilities.clone());

        for capability in expected_capabilities.clone() {
            assert!(pool_state
                .clone()
                .ensure_capability(capability)
                .is_ok());
        }

        assert!(pool_state
            .clone()
            .ensure_capability(Capability::MarginalPrice)
            .is_err());

        let engine_accounts = pool_state
            .adapter_contract
            .engine
            .state
            .clone()
            .get_account_storage()
            .expect("Failed to get account storage");
        for token in pool_state.tokens.clone() {
            let account = engine_accounts
                .get_account_info(&bytes_to_address(&token).unwrap())
                .unwrap();
            assert_eq!(account.balance, U256::from(0));
            assert_eq!(account.nonce, 0u64);
            assert_eq!(account.code_hash, KECCAK_EMPTY);
            assert!(account.code.is_some());
        }

        let external_account = engine_accounts
            .get_account_info(&EXTERNAL_ACCOUNT)
            .unwrap();
        assert_eq!(external_account.balance, U256::from(*MAX_BALANCE));
        assert_eq!(external_account.nonce, 0u64);
        assert_eq!(external_account.code_hash, KECCAK_EMPTY);
        assert!(external_account.code.is_none());
    }

    #[tokio::test]
    async fn test_get_amount_out() -> Result<(), Box<dyn std::error::Error>> {
        let pool_state = setup_pool_state().await;

        let result = pool_state
            .get_amount_out(BigUint::from_str("1000000000000000000").unwrap(), &dai(), &bal())
            .unwrap();
        let new_state = result
            .new_state
            .as_any()
            .downcast_ref::<EVMPoolState<PreCachedDB>>()
            .unwrap();
        assert_eq!(result.amount, BigUint::from_str("137780051463393923").unwrap());
        assert_ne!(new_state.spot_prices, pool_state.spot_prices);
        // The original state must not be mutated by the simulation.
        assert!(pool_state
            .block_lasting_overwrites
            .is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn test_sequential_get_amount_outs() {
        let pool_state = setup_pool_state().await;

        let result = pool_state
            .get_amount_out(BigUint::from_str("1000000000000000000").unwrap(), &dai(), &bal())
            .unwrap();
        let new_state = result
            .new_state
            .as_any()
            .downcast_ref::<EVMPoolState<PreCachedDB>>()
            .unwrap();
        assert_eq!(result.amount, BigUint::from_str("137780051463393923").unwrap());
        assert_ne!(new_state.spot_prices, pool_state.spot_prices);

        // A second swap on the post-swap state sees the first swap's storage changes.
        let new_result = new_state
            .get_amount_out(BigUint::from_str("1000000000000000000").unwrap(), &dai(), &bal())
            .unwrap();
        let new_state_second_swap = new_result
            .new_state
            .as_any()
            .downcast_ref::<EVMPoolState<PreCachedDB>>()
            .unwrap();

        assert_eq!(new_result.amount, BigUint::from_str("136964651490065626").unwrap());
        assert_ne!(new_state_second_swap.spot_prices, new_state.spot_prices);
    }

    #[tokio::test]
    async fn test_get_amount_out_dust() {
        let pool_state = setup_pool_state().await;

        let result = pool_state
            .get_amount_out(BigUint::one(), &dai(), &bal())
            .unwrap();

        let _ = result
            .new_state
            .as_any()
            .downcast_ref::<EVMPoolState<PreCachedDB>>()
            .unwrap();
        assert_eq!(result.amount, BigUint::ZERO);
    }

    #[tokio::test]
    async fn test_get_amount_out_sell_limit() {
        let pool_state = setup_pool_state().await;

        let result = pool_state.get_amount_out(
            // Just above the DAI sell limit of 100279494253364362835.
            BigUint::from_str("100379494253364362835").unwrap(),
            &dai(),
            &bal(),
        );

        assert!(result.is_err());

        match result {
            Err(SimulationError::InvalidInput(msg1, amount_out_result)) => {
                assert_eq!(msg1, "Sell amount exceeds limit 100279494253364362835");
                assert!(amount_out_result.is_some());
            }
            _ => panic!(
                "Test failed: was expecting an Err(SimulationError::InvalidInput(_, _)) value"
            ),
        }
    }

    #[tokio::test]
    async fn test_get_amount_limits() {
        let pool_state = setup_pool_state().await;

        let overwrites = pool_state
            .get_overwrites(
                vec![
                    bytes_to_address(&pool_state.tokens[0]).unwrap(),
                    bytes_to_address(&pool_state.tokens[1]).unwrap(),
                ],
                *MAX_BALANCE / U256::from(100),
            )
            .unwrap();
        let (dai_limit, _) = pool_state
            .get_amount_limits(vec![dai_addr(), bal_addr()], Some(overwrites.clone()))
            .unwrap();
        assert_eq!(dai_limit, U256::from_str("100279494253364362835").unwrap());

        let (bal_limit, _) = pool_state
            .get_amount_limits(
                vec![
                    bytes_to_address(&pool_state.tokens[1]).unwrap(),
                    bytes_to_address(&pool_state.tokens[0]).unwrap(),
                ],
                Some(overwrites),
            )
            .unwrap();
        assert_eq!(bal_limit, U256::from_str("13997408640689987484").unwrap());
    }

    #[tokio::test]
    async fn test_set_spot_prices() {
        let mut pool_state = setup_pool_state().await;

        pool_state
            .set_spot_prices(&tokens_by_address(vec![bal(), dai()]))
            .unwrap();

        let dai_bal_spot_price = pool_state
            .spot_prices
            .get(&(
                bytes_to_address(&pool_state.tokens[0]).unwrap(),
                bytes_to_address(&pool_state.tokens[1]).unwrap(),
            ))
            .unwrap();
        let bal_dai_spot_price = pool_state
            .spot_prices
            .get(&(
                bytes_to_address(&pool_state.tokens[1]).unwrap(),
                bytes_to_address(&pool_state.tokens[0]).unwrap(),
            ))
            .unwrap();
        assert_eq!(dai_bal_spot_price, &0.137_778_914_319_047_9);
        assert_eq!(bal_dai_spot_price, &7.071_503_245_428_246);
    }

    #[tokio::test]
    async fn test_set_spot_prices_without_capability() {
        let mut pool_state = setup_pool_state().await;

        // Force the swap-based marginal price fallback.
        pool_state
            .capabilities
            .remove(&Capability::PriceFunction);

        pool_state
            .set_spot_prices(&tokens_by_address(vec![bal(), dai()]))
            .unwrap();

        let dai_bal_spot_price = pool_state
            .spot_prices
            .get(&(
                bytes_to_address(&pool_state.tokens[0]).unwrap(),
                bytes_to_address(&pool_state.tokens[1]).unwrap(),
            ))
            .unwrap();
        let bal_dai_spot_price = pool_state
            .spot_prices
            .get(&(
                bytes_to_address(&pool_state.tokens[1]).unwrap(),
                bytes_to_address(&pool_state.tokens[0]).unwrap(),
            ))
            .unwrap();
        assert_eq!(dai_bal_spot_price, &0.13736685496467538);
        assert_eq!(bal_dai_spot_price, &7.050354297665408);
    }

    #[tokio::test]
    async fn test_get_balance_overwrites_with_component_balances() {
        let pool_state: EVMPoolState<PreCachedDB> = setup_pool_state().await;

        let overwrites = pool_state
            .get_balance_overwrites()
            .unwrap();

        let dai_address = dai_addr();
        let bal_address = bal_addr();
        assert!(overwrites.contains_key(&dai_address));
        assert!(overwrites.contains_key(&bal_address));
    }

    #[tokio::test]
    async fn test_get_balance_overwrites_with_contract_balances() {
        let mut pool_state: EVMPoolState<PreCachedDB> = setup_pool_state().await;

        let contract_address =
            Address::from_str("0xBA12222222228d8Ba445958a75a0704d566BF2C8").unwrap();

        // Switch the pool from component balances to contract balances.
        pool_state.balances.clear();
        pool_state.balance_owner = None;

        let dai_address = dai_addr();
        let bal_address = bal_addr();
        pool_state.contract_balances = HashMap::from([(
            contract_address,
            HashMap::from([
                (dai_address, U256::from_str("7500000000000000000000").unwrap()),
                (bal_address, U256::from_str("1500000000000000000000").unwrap()),
            ]),
        )]);

        let overwrites = pool_state
            .get_balance_overwrites()
            .unwrap();

        assert!(overwrites.contains_key(&dai_address));
        assert!(overwrites.contains_key(&bal_address));
    }

    #[tokio::test]
    async fn test_balance_merging_during_delta_transition() {
        let mut pool_state = setup_pool_state().await;
        let pool_id = pool_state.id.clone();

        let dai_addr = dai_addr();
        let bal_addr = bal_addr();
        let new_token = Address::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap();

        pool_state.balances.clear();
        pool_state
            .balances
            .insert(dai_addr, U256::from(1000000000u64));
        pool_state
            .balances
            .insert(bal_addr, U256::from(2000000000u64));
        pool_state
            .balances
            .insert(new_token, U256::from(3000000000u64));

        let mut tokens = HashMap::new();
        tokens.insert(dai().address.clone(), dai());
        tokens.insert(bal().address.clone(), bal());

        // The delta only updates DAI: 0x77359400 == 2_000_000_000.
        let mut component_balances = HashMap::new();
        let mut delta_balances = HashMap::new();
        delta_balances.insert(dai().address.clone(), Bytes::from(vec![0x77, 0x35, 0x94, 0x00]));
        component_balances.insert(pool_id.clone(), delta_balances);

        let balances = Balances { component_balances, account_balances: HashMap::new() };

        let initial_balance_count = pool_state.balances.len();
        assert_eq!(initial_balance_count, 3);

        pool_state
            .update_pool_state(&tokens, &balances)
            .unwrap();

        assert_eq!(
            pool_state.balances.len(),
            3,
            "All balances should be preserved after delta transition"
        );
        assert!(
            pool_state
                .balances
                .contains_key(&dai_addr),
            "DAI balance should be present"
        );
        assert!(
            pool_state
                .balances
                .contains_key(&bal_addr),
            "BAL balance should be present"
        );
        assert!(
            pool_state
                .balances
                .contains_key(&new_token),
            "New token balance should be preserved from before delta"
        );

        assert_eq!(
            pool_state.balances[&dai_addr],
            U256::from(2000000000u64),
            "DAI balance should be updated"
        );

        assert_eq!(
            pool_state.balances[&bal_addr],
            U256::from(2000000000u64),
            "BAL balance should be unchanged"
        );
        assert_eq!(
            pool_state.balances[&new_token],
            U256::from(3000000000u64),
            "New token balance should be unchanged"
        );
    }
}