1use std::{
2 collections::{HashMap, HashSet},
3 fmt::Debug,
4};
5
6use alloy::{
7 primitives::{Address, Bytes, Keccak256, U256},
8 sol_types::SolValue,
9};
10use itertools::Itertools;
11use revm::{
12 primitives::KECCAK_EMPTY,
13 state::{AccountInfo, Bytecode},
14 DatabaseRef,
15};
16use tracing::warn;
17use tycho_common::{simulation::errors::SimulationError, Bytes as TychoBytes};
18
19use super::{
20 constants::{EXTERNAL_ACCOUNT, MAX_BALANCE},
21 models::Capability,
22 state::EVMPoolState,
23 tycho_simulation_contract::TychoSimulationContract,
24 utils::get_code_for_contract,
25};
26use crate::evm::{
27 engine_db::{create_engine, engine_db_interface::EngineDatabaseInterface},
28 protocol::utils::bytes_to_address,
29 simulation::{BlockEnvOverrides, SimulationEngine, SimulationParameters},
30};
31
/// Builder that assembles an [`EVMPoolState`] step by step.
///
/// Only `id`, `tokens` and `adapter_address` are mandatory (see `new`); every
/// other field is optional and is either forwarded as-is or replaced by a
/// default inside `build`.
#[derive(Debug)]
pub struct EVMPoolStateBuilder<D: EngineDatabaseInterface + Clone + Debug>
where
    <D as DatabaseRef>::Error: Debug,
    <D as EngineDatabaseInterface>::Error: Debug,
{
    /// Pool identifier; forwarded to `EVMPoolState::new` and passed to the
    /// adapter's `get_capabilities` call when capabilities are derived.
    id: String,
    /// Pool tokens as raw bytes; converted with `bytes_to_address` when the
    /// default capabilities are queried.
    tokens: Vec<TychoBytes>,
    /// Balances keyed by address, forwarded unchanged to the pool state.
    /// NOTE(review): presumably token address -> pool balance; confirm.
    balances: HashMap<Address, U256>,
    /// Address used when the adapter contract is created from bytecode.
    adapter_address: Address,
    /// Optional balance owner. Its setter is deprecated in favour of
    /// account balances.
    balance_owner: Option<Address>,
    /// Explicit capabilities; when `None`, `build` derives them from the
    /// adapter contract.
    capabilities: Option<HashSet<Capability>>,
    /// Contracts involved in simulations; defaults to an empty set in `build`.
    involved_contracts: Option<HashSet<Address>>,
    /// Nested balance map set through `account_balances` and forwarded
    /// unchanged. NOTE(review): presumably account -> token -> balance; confirm.
    contract_balances: HashMap<Address, HashMap<Address, U256>>,
    /// Contracts to register in the default engine, keyed by an address
    /// string or a `call...` descriptor that is resolved via simulation.
    /// Entries without bytecode have their code fetched with
    /// `get_code_for_contract`.
    stateless_contracts: Option<HashMap<String, Option<Vec<u8>>>>,
    /// Forwarded to the pool state; defaults to `false` in `build`.
    manual_updates: Option<bool>,
    /// Forwarded to `create_engine` for the default engine; defaults to `false`.
    trace: Option<bool>,
    /// Pre-built simulation engine; when `None`, a default engine is created
    /// from the database handed to `build`.
    engine: Option<SimulationEngine<D>>,
    /// Pre-built adapter contract; when `None`, one is created from
    /// `adapter_address` and `adapter_contract_bytecode`.
    adapter_contract: Option<TychoSimulationContract<D>>,
    /// Adapter bytecode; required by `build` whenever `adapter_contract` is
    /// not supplied.
    adapter_contract_bytecode: Option<Bytecode>,
    /// Token addresses forwarded unchanged to the pool state.
    /// NOTE(review): presumably tokens whose storage must not be
    /// overwritten during simulation; confirm against `EVMPoolState`.
    disable_overwrite_tokens: HashSet<Address>,
    /// Optional block environment overrides, forwarded unchanged.
    block_overrides: Option<BlockEnvOverrides>,
}
101
102impl<D> EVMPoolStateBuilder<D>
103where
104 D: EngineDatabaseInterface + Clone + Debug + 'static,
105 <D as DatabaseRef>::Error: Debug,
106 <D as EngineDatabaseInterface>::Error: Debug,
107{
108 pub fn new(id: String, tokens: Vec<TychoBytes>, adapter_address: Address) -> Self {
109 Self {
110 id,
111 tokens,
112 balances: HashMap::new(),
113 adapter_address,
114 balance_owner: None,
115 capabilities: None,
116 involved_contracts: None,
117 contract_balances: HashMap::new(),
118 stateless_contracts: None,
119 manual_updates: None,
120 trace: None,
121 engine: None,
122 adapter_contract: None,
123 adapter_contract_bytecode: None,
124 disable_overwrite_tokens: HashSet::new(),
125 block_overrides: None,
126 }
127 }
128
129 #[deprecated(note = "Use account balances instead")]
130 pub fn balance_owner(mut self, balance_owner: Address) -> Self {
131 self.balance_owner = Some(balance_owner);
132 self
133 }
134
135 pub fn balances(mut self, balances: HashMap<Address, U256>) -> Self {
138 self.balances = balances;
139 self
140 }
141
142 pub fn account_balances(
144 mut self,
145 account_balances: HashMap<Address, HashMap<Address, U256>>,
146 ) -> Self {
147 self.contract_balances = account_balances;
148 self
149 }
150
151 pub fn capabilities(mut self, capabilities: HashSet<Capability>) -> Self {
152 self.capabilities = Some(capabilities);
153 self
154 }
155
156 pub fn involved_contracts(mut self, involved_contracts: HashSet<Address>) -> Self {
157 self.involved_contracts = Some(involved_contracts);
158 self
159 }
160
161 pub fn stateless_contracts(
162 mut self,
163 stateless_contracts: HashMap<String, Option<Vec<u8>>>,
164 ) -> Self {
165 self.stateless_contracts = Some(stateless_contracts);
166 self
167 }
168 pub fn manual_updates(mut self, manual_updates: bool) -> Self {
169 self.manual_updates = Some(manual_updates);
170 self
171 }
172
173 pub fn trace(mut self, trace: bool) -> Self {
174 self.trace = Some(trace);
175 self
176 }
177
178 pub fn engine(mut self, engine: SimulationEngine<D>) -> Self {
179 self.engine = Some(engine);
180 self
181 }
182
183 pub fn adapter_contract(mut self, adapter_contract: TychoSimulationContract<D>) -> Self {
184 self.adapter_contract = Some(adapter_contract);
185 self
186 }
187
188 pub fn adapter_contract_bytecode(mut self, adapter_contract_bytecode: Bytecode) -> Self {
189 self.adapter_contract_bytecode = Some(adapter_contract_bytecode);
190 self
191 }
192
193 pub fn disable_overwrite_tokens(mut self, disable_overwrite_tokens: HashSet<Address>) -> Self {
194 self.disable_overwrite_tokens = disable_overwrite_tokens;
195 self
196 }
197
198 pub fn block_overrides(mut self, block_overrides: Option<BlockEnvOverrides>) -> Self {
199 self.block_overrides = block_overrides;
200 self
201 }
202
203 pub async fn build(mut self, db: D) -> Result<EVMPoolState<D>, SimulationError> {
205 let engine = if let Some(engine) = &self.engine {
206 engine.clone()
207 } else {
208 self.engine = Some(self.get_default_engine(db).await?);
209 self.engine.clone().ok_or_else(|| {
210 SimulationError::FatalError(
211 "Failed to get build engine: Engine not initialized".to_string(),
212 )
213 })?
214 };
215
216 if self.adapter_contract.is_none() {
217 self.adapter_contract = Some(TychoSimulationContract::new_contract(
218 self.adapter_address,
219 self.adapter_contract_bytecode
220 .clone()
221 .ok_or_else(|| {
222 SimulationError::FatalError("Adapter contract bytecode not set".to_string())
223 })?,
224 engine.clone(),
225 )?)
226 };
227
228 let capabilities = if let Some(capabilities) = &self.capabilities {
229 capabilities.clone()
230 } else {
231 self.get_default_capabilities()?
232 };
233
234 let adapter_contract = self.adapter_contract.ok_or_else(|| {
235 SimulationError::FatalError(
236 "Failed to get build engine: Adapter contract not initialized".to_string(),
237 )
238 })?;
239
240 Ok(EVMPoolState::new(
241 self.id,
242 self.tokens,
243 self.balances,
244 self.balance_owner,
245 self.contract_balances,
246 HashMap::new(),
247 capabilities,
248 HashMap::new(),
249 self.involved_contracts
250 .unwrap_or_default(),
251 self.manual_updates.unwrap_or(false),
252 adapter_contract,
253 self.disable_overwrite_tokens,
254 self.block_overrides,
255 ))
256 }
257
258 async fn get_default_engine(&self, db: D) -> Result<SimulationEngine<D>, SimulationError> {
259 let engine = create_engine(db, self.trace.unwrap_or(false))?;
260
261 engine
262 .state
263 .init_account(
264 *EXTERNAL_ACCOUNT,
265 AccountInfo {
266 balance: *MAX_BALANCE,
267 nonce: 0,
268 code_hash: KECCAK_EMPTY,
269 code: None,
270 },
271 None,
272 false,
273 )
274 .map_err(|err| {
275 SimulationError::FatalError(format!(
276 "Failed to get default engine: Failed to init external account: {err:?}"
277 ))
278 })?;
279
280 if let Some(stateless_contracts) = &self.stateless_contracts {
281 for (address, bytecode) in stateless_contracts.iter() {
282 let mut addr_str = address.clone();
283 let (code, code_hash) = if bytecode.is_none() {
284 if addr_str.starts_with("call") {
285 addr_str = self
286 .get_address_from_call(&engine, &addr_str)?
287 .to_string();
288 }
289 let code = get_code_for_contract(&addr_str, None).await?;
290 (Some(code.clone()), code.hash_slow())
291 } else {
292 let code =
293 Bytecode::new_raw(Bytes::from(bytecode.clone().ok_or_else(|| {
294 SimulationError::FatalError(
295 "Failed to get default engine: Byte code from stateless contracts is None".into(),
296 )
297 })?));
298 (Some(code.clone()), code.hash_slow())
299 };
300 let account_address: Address = addr_str.parse().map_err(|_| {
301 SimulationError::FatalError(format!(
302 "Failed to get default engine: Couldn't parse address string {address}"
303 ))
304 })?;
305 engine.state.init_account(
306 Address(*account_address),
307 AccountInfo { balance: Default::default(), nonce: 0, code_hash, code },
308 None,
309 false,
310 ).map_err(|err| {
311 SimulationError::FatalError(format!(
312 "Failed to get default engine: Failed to init stateless contract account: {err:?}"
313 ))
314 })?;
315 }
316 }
317 Ok(engine)
318 }
319
320 fn get_default_capabilities(&mut self) -> Result<HashSet<Capability>, SimulationError> {
321 let mut capabilities = Vec::new();
322
323 for tokens_pair in self.tokens.iter().permutations(2) {
325 if let [t0, t1] = tokens_pair[..] {
327 let caps = self
328 .adapter_contract
329 .clone()
330 .ok_or_else(|| {
331 SimulationError::FatalError(
332 "Failed to get default capabilities: Adapter contract not initialized"
333 .to_string(),
334 )
335 })?
336 .get_capabilities(&self.id, bytes_to_address(t0)?, bytes_to_address(t1)?)?;
337 capabilities.push(caps);
338 }
339 }
340
341 let max_capabilities = capabilities
343 .iter()
344 .map(|c| c.len())
345 .max()
346 .unwrap_or(0);
347
348 let common_capabilities: HashSet<_> = capabilities
350 .iter()
351 .fold(capabilities[0].clone(), |acc, cap| acc.intersection(cap).cloned().collect());
352
353 if common_capabilities.len() < max_capabilities {
355 warn!(
356 "Warning: Pool {} has different capabilities depending on the token pair!",
357 self.id
358 );
359 }
360 Ok(common_capabilities)
361 }
362
363 fn get_address_from_call(
373 &self,
374 engine: &SimulationEngine<D>,
375 decoded: &str,
376 ) -> Result<Address, SimulationError> {
377 let method_name = decoded
378 .split(':')
379 .next_back()
380 .ok_or_else(|| {
381 SimulationError::FatalError(
382 "Failed to get address from call: Could not decode method name from call"
383 .into(),
384 )
385 })?;
386
387 let selector = {
388 let mut hasher = Keccak256::new();
389 hasher.update(method_name.as_bytes());
390 let result = hasher.finalize();
391 result[..4].to_vec()
392 };
393
394 let to_address = decoded
395 .split(':')
396 .nth(1)
397 .ok_or_else(|| {
398 SimulationError::FatalError(
399 "Failed to get address from call: Could not decode to_address from call".into(),
400 )
401 })?;
402
403 let parsed_address: Address = to_address.parse().map_err(|_| {
404 SimulationError::FatalError(format!(
405 "Failed to get address from call: Invalid address format: {to_address}"
406 ))
407 })?;
408
409 let sim_params = SimulationParameters {
410 data: selector.to_vec(),
411 to: parsed_address,
412 overrides: Some(HashMap::new()),
413 caller: *EXTERNAL_ACCOUNT,
414 value: U256::from(0u64),
415 gas_limit: None,
416 transient_storage: None,
417 block_overrides: None,
418 };
419
420 let sim_result = engine
421 .simulate(&sim_params)
422 .map_err(|err| SimulationError::FatalError(err.to_string()))?;
423
424 let address: Address = Address::abi_decode(&sim_result.result).map_err(|e| {
425 SimulationError::FatalError(format!("Failed to get address from call: Failed to decode address list from simulation result {e:?}"))
426 })?;
427
428 Ok(address)
429 }
430}
431
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::*;
    use crate::evm::engine_db::{tycho_db::PreCachedDB, SHARED_TYCHO_DB};

    /// Adapter address shared by every test in this module.
    fn adapter_address() -> Address {
        Address::from_str("0xA2C5C98A892fD6656a7F39A2f63228C0Bc846270").unwrap()
    }

    /// Parses a hex string into the byte representation used for tokens.
    fn token(hex: &str) -> TychoBytes {
        TychoBytes::from_str(hex).unwrap()
    }

    /// Building without adapter bytecode (and without an adapter contract)
    /// must fail with the dedicated fatal error.
    #[test]
    fn test_build_without_required_fields() {
        let builder = EVMPoolStateBuilder::<PreCachedDB>::new(
            "pool_1".to_string(),
            vec![token("0000000000000000000000000000000000000000")],
            adapter_address(),
        )
        .balances(HashMap::new());

        let result = tokio_test::block_on(builder.build(SHARED_TYCHO_DB.clone()));

        assert!(result.is_err());
        if let SimulationError::FatalError(field) = result.unwrap_err() {
            assert_eq!(field, "Adapter contract bytecode not set")
        } else {
            panic!("Unexpected error type")
        }
    }

    /// The default engine must register the external account.
    #[test]
    fn test_engine_setup() {
        let tokens = vec![
            token("0000000000000000000000000000000000000002"),
            token("0000000000000000000000000000000000000003"),
        ];
        let builder = EVMPoolStateBuilder::<PreCachedDB>::new(
            "pool_1".to_string(),
            tokens,
            adapter_address(),
        )
        .balances(HashMap::new());

        let engine = tokio_test::block_on(builder.get_default_engine(SHARED_TYCHO_DB.clone()));

        assert!(engine.is_ok());
        let engine = engine.unwrap();
        let external_account_present = engine
            .state
            .get_account_storage()
            .expect("Failed to get account storage")
            .account_present(&EXTERNAL_ACCOUNT);
        assert!(external_account_present);
    }
}