1use std::{any::Any, collections::HashMap, fmt};
2
3use async_trait::async_trait;
4use num_bigint::BigUint;
5use num_traits::{FromPrimitive, Pow, ToPrimitive};
6use serde::{Deserialize, Serialize};
7use tycho_common::{
8 dto::ProtocolStateDelta,
9 models::{protocol::GetAmountOutParams, token::Token},
10 simulation::{
11 errors::{SimulationError, TransitionError},
12 indicatively_priced::{IndicativelyPriced, SignedQuote},
13 protocol_sim::{Balances, GetAmountOutResult, ProtocolSim},
14 },
15 Bytes,
16};
17
18use crate::rfq::{
19 client::RFQClient,
20 protocols::hashflow::{client::HashflowClient, models::HashflowMarketMakerLevels},
21};
22
/// Simulation state for one Hashflow market maker's price levels on a token pair.
///
/// Only the `base_token` -> `quote_token` direction is accepted by the simulation methods
/// (see `valid_direction_guard`); swaps in the opposite direction are rejected.
#[derive(Clone, Serialize, Deserialize)]
pub struct HashflowState {
    /// Token sold into the market maker (the input token of a swap).
    pub base_token: Token,
    /// Token received from the market maker (the output token of a swap).
    pub quote_token: Token,
    /// Indicative price levels used to simulate amounts out and limits.
    pub levels: HashflowMarketMakerLevels,
    /// Market maker identifier these levels belong to (presumably the Hashflow MM name — confirm).
    pub market_maker: String,
    /// Client used to request binding (signed) quotes for this state.
    pub client: HashflowClient,
}
31
32impl fmt::Debug for HashflowState {
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 f.debug_struct("HashflowState")
35 .field("base_token", &self.base_token)
36 .field("quote_token", &self.quote_token)
37 .field("market_maker", &self.market_maker)
38 .finish_non_exhaustive()
39 }
40}
41
42impl HashflowState {
43 pub fn new(
44 base_token: Token,
45 quote_token: Token,
46 levels: HashflowMarketMakerLevels,
47 market_maker: String,
48 client: HashflowClient,
49 ) -> Self {
50 Self { base_token, quote_token, levels, market_maker, client }
51 }
52
53 fn valid_direction_guard(
54 &self,
55 token_address_in: &Bytes,
56 token_address_out: &Bytes,
57 ) -> Result<(), SimulationError> {
58 if !(token_address_in == &self.base_token.address &&
60 token_address_out == &self.quote_token.address)
61 {
62 Err(SimulationError::InvalidInput(
63 format!("Invalid token addresses. Got in={token_address_in}, out={token_address_out}, expected in={}, out={}", self.base_token.address, self.quote_token.address),
64 None,
65 ))
66 } else {
67 Ok(())
68 }
69 }
70
71 fn valid_levels_guard(&self) -> Result<(), SimulationError> {
72 if self.levels.levels.is_empty() {
73 return Err(SimulationError::RecoverableError("No liquidity".into()));
74 }
75 Ok(())
76 }
77}
78
79#[typetag::serde]
80impl ProtocolSim for HashflowState {
81 fn fee(&self) -> f64 {
82 todo!()
83 }
84
85 fn spot_price(&self, base: &Token, quote: &Token) -> Result<f64, SimulationError> {
86 self.valid_direction_guard(&base.address, "e.address)?;
87
88 self.levels
90 .levels
91 .first()
92 .ok_or(SimulationError::RecoverableError("No liquidity".into()))
93 .map(|level| level.price)
94 }
95
96 fn get_amount_out(
97 &self,
98 amount_in: BigUint,
99 token_in: &Token,
100 token_out: &Token,
101 ) -> Result<GetAmountOutResult, SimulationError> {
102 self.valid_direction_guard(&token_in.address, &token_out.address)?;
103 self.valid_levels_guard()?;
104
105 let amount_in = amount_in.to_f64().ok_or_else(|| {
106 SimulationError::RecoverableError("Can't convert amount in to f64".into())
107 })? / 10f64.powi(token_in.decimals as i32);
108
109 let min_amount = self.levels.levels[0].quantity;
111 if amount_in < min_amount {
112 return Err(SimulationError::RecoverableError(format!(
113 "Amount below minimum. Input amount: {amount_in}, min amount: {min_amount}"
114 )));
115 }
116
117 let (amount_out, remaining_amount_in) = self
119 .levels
120 .get_amount_out_from_levels(amount_in);
121
122 let res = GetAmountOutResult {
123 amount: BigUint::from_f64(amount_out * 10f64.powi(token_out.decimals as i32))
124 .ok_or_else(|| {
125 SimulationError::RecoverableError("Can't convert amount out to BigUInt".into())
126 })?,
127 gas: BigUint::from(134_000u64), new_state: self.clone_box(), };
130
131 if remaining_amount_in > 0.0 {
132 return Err(SimulationError::InvalidInput(
133 format!("Pool has not enough liquidity to support complete swap. Input amount: {amount_in}, consumed amount: {}", amount_in-remaining_amount_in),
134 Some(res)));
135 }
136
137 Ok(res)
138 }
139
140 fn get_limits(
141 &self,
142 sell_token: Bytes,
143 buy_token: Bytes,
144 ) -> Result<(BigUint, BigUint), SimulationError> {
145 self.valid_direction_guard(&sell_token, &buy_token)?;
146 self.valid_levels_guard()?;
147
148 let sell_decimals = self.base_token.decimals;
149 let buy_decimals = self.quote_token.decimals;
150 let (total_sell_amount, total_buy_amount) =
151 self.levels
152 .levels
153 .iter()
154 .fold((0.0, 0.0), |(sell_sum, buy_sum), level| {
155 (sell_sum + level.quantity, buy_sum + level.quantity * level.price)
156 });
157
158 let sell_limit =
159 BigUint::from((total_sell_amount * 10_f64.pow(sell_decimals as f64)) as u128);
160 let buy_limit = BigUint::from((total_buy_amount * 10_f64.pow(buy_decimals as f64)) as u128);
161
162 Ok((sell_limit, buy_limit))
163 }
164
165 fn as_indicatively_priced(&self) -> Result<&dyn IndicativelyPriced, SimulationError> {
166 Ok(self)
167 }
168
169 fn delta_transition(
170 &mut self,
171 _delta: ProtocolStateDelta,
172 _tokens: &HashMap<Bytes, Token>,
173 _balances: &Balances,
174 ) -> Result<(), TransitionError> {
175 todo!()
176 }
177
178 fn clone_box(&self) -> Box<dyn ProtocolSim> {
179 Box::new(self.clone())
180 }
181
182 fn as_any(&self) -> &dyn Any {
183 self
184 }
185
186 fn as_any_mut(&mut self) -> &mut dyn Any {
187 self
188 }
189
190 fn eq(&self, other: &dyn ProtocolSim) -> bool {
191 if let Some(other_state) = other
192 .as_any()
193 .downcast_ref::<HashflowState>()
194 {
195 self.base_token == other_state.base_token &&
196 self.quote_token == other_state.quote_token &&
197 self.levels == other_state.levels
198 } else {
199 false
200 }
201 }
202}
203
204#[async_trait]
205impl IndicativelyPriced for HashflowState {
206 async fn request_signed_quote(
207 &self,
208 params: GetAmountOutParams,
209 ) -> Result<SignedQuote, SimulationError> {
210 Ok(self
211 .client
212 .request_binding_quote(¶ms)
213 .await?)
214 }
215}
216
#[cfg(test)]
mod tests {
    use std::{collections::HashSet, str::FromStr};

    use tokio::time::Duration;
    use tycho_common::models::Chain;

    use super::*;
    use crate::rfq::protocols::hashflow::models::{HashflowPair, HashflowPriceLevel};

    fn wbtc() -> Token {
        Token::new(
            &hex::decode("2260fac5e5542a773aa44fbcfedf7c193bc2c599")
                .unwrap()
                .into(),
            "WBTC",
            8,
            0,
            &[Some(10_000)],
            Chain::Ethereum,
            100,
        )
    }

    fn usdc() -> Token {
        Token::new(
            &hex::decode("a0b86991c6218a76c1d19d4a2e9eb0ce3606eb48")
                .unwrap()
                .into(),
            "USDC",
            6,
            0,
            &[Some(10_000)],
            Chain::Ethereum,
            100,
        )
    }

    fn weth() -> Token {
        Token::new(
            &Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(),
            "WETH",
            18,
            0,
            &[],
            Default::default(),
            100,
        )
    }

    /// Client with no configured pairs/market makers and empty credentials; the tests below
    /// never perform network requests.
    fn empty_hashflow_client() -> HashflowClient {
        HashflowClient::new(
            Chain::Ethereum,
            HashSet::new(),
            0.0,
            HashSet::new(),
            "".to_string(),
            "".to_string(),
            Duration::from_secs(0),
            Duration::from_secs(30),
        )
        .unwrap()
    }

    /// WETH -> USDC state with levels 0.5 @ 3000, 1.5 @ 3000 and 5.0 @ 2999
    /// (7 WETH / 20_995 USDC in total).
    fn create_test_hashflow_state() -> HashflowState {
        HashflowState {
            base_token: weth(),
            quote_token: usdc(),
            levels: HashflowMarketMakerLevels {
                pair: HashflowPair {
                    base_token: Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2")
                        .unwrap(),
                    quote_token: Bytes::from_str("0xa0b86991c6218a76c1d19d4a2e9eb0ce3606eb48")
                        .unwrap(),
                },
                levels: vec![
                    HashflowPriceLevel { quantity: 0.5, price: 3000.0 },
                    HashflowPriceLevel { quantity: 1.5, price: 3000.0 },
                    HashflowPriceLevel { quantity: 5.0, price: 2999.0 },
                ],
            },
            market_maker: "test_mm".to_string(),
            client: empty_hashflow_client(),
        }
    }

    mod spot_price {
        use super::*;

        #[test]
        fn returns_best_price() {
            let state = create_test_hashflow_state();
            let price = state
                .spot_price(&state.base_token, &state.quote_token)
                .unwrap();
            assert_eq!(price, 3000.0);
        }

        #[test]
        fn returns_invalid_input_error() {
            let state = create_test_hashflow_state();
            let result = state.spot_price(&wbtc(), &usdc());
            assert!(result.is_err());
            if let Err(SimulationError::InvalidInput(msg, _)) = result {
                assert!(msg.contains("Invalid token addresses"));
            } else {
                panic!("Expected InvalidInput");
            }
        }

        #[test]
        fn returns_no_liquidity_error() {
            let mut state = create_test_hashflow_state();
            state.levels.levels.clear();
            let result = state.spot_price(&state.base_token, &state.quote_token);
            assert!(result.is_err());
            if let Err(SimulationError::RecoverableError(msg)) = result {
                assert_eq!(msg, "No liquidity");
            } else {
                panic!("Expected RecoverableError");
            }
        }
    }

    mod get_amount_out {
        use super::*;

        #[test]
        fn weth_to_usdc() {
            let state = create_test_hashflow_state();

            // 1.5 WETH is filled at 3000 USDC/WETH (0.5 from the first level, 1.0 from the
            // second), i.e. 4500 USDC.
            let amount_out_result = state
                .get_amount_out(BigUint::from_str("1500000000000000000").unwrap(), &weth(), &usdc())
                .unwrap();

            assert_eq!(amount_out_result.amount, BigUint::from_str("4500000000").unwrap());
            assert_eq!(amount_out_result.gas, BigUint::from(134_000u64));
        }

        #[test]
        fn usdc_to_weth() {
            let state = create_test_hashflow_state();

            // The state only quotes WETH -> USDC, so the reverse direction is rejected.
            let result =
                state.get_amount_out(BigUint::from_str("10000000000").unwrap(), &usdc(), &weth());

            assert!(result.is_err());
            if let Err(SimulationError::InvalidInput(msg, ..)) = result {
                assert!(msg.contains("Invalid token addresses"));
            } else {
                panic!("Expected InvalidInput");
            }
        }

        #[test]
        fn below_minimum() {
            let state = create_test_hashflow_state();

            // 0.25 WETH is below the first level's 0.5 WETH minimum.
            let result = state.get_amount_out(
                BigUint::from_str("250000000000000000").unwrap(),
                &weth(),
                &usdc(),
            );

            assert!(result.is_err());
            if let Err(SimulationError::RecoverableError(msg)) = result {
                assert!(msg.contains("Amount below minimum"));
            } else {
                panic!("Expected RecoverableError");
            }
        }

        #[test]
        fn insufficient_liquidity() {
            let state = create_test_hashflow_state();

            // 8 WETH exceeds the 7 WETH available across all levels.
            let result = state.get_amount_out(
                BigUint::from_str("8000000000000000000").unwrap(),
                &weth(),
                &usdc(),
            );

            assert!(result.is_err());
            if let Err(SimulationError::InvalidInput(msg, _)) = result {
                assert!(msg.contains("Pool has not enough liquidity"));
            } else {
                panic!("Expected InvalidInput");
            }
        }

        #[test]
        fn invalid_token_pair() {
            let state = create_test_hashflow_state();

            let result =
                state.get_amount_out(BigUint::from_str("100000000").unwrap(), &wbtc(), &usdc());

            assert!(result.is_err());
            if let Err(SimulationError::InvalidInput(msg, ..)) = result {
                assert!(msg.contains("Invalid token addresses"));
            } else {
                panic!("Expected InvalidInput");
            }
        }

        #[test]
        fn no_liquidity() {
            let mut state = create_test_hashflow_state();
            state.levels.levels = vec![];

            let result = state.get_amount_out(
                BigUint::from_str("1000000000000000000").unwrap(),
                &weth(),
                &usdc(),
            );

            assert!(result.is_err());
            if let Err(SimulationError::RecoverableError(msg)) = result {
                assert_eq!(msg, "No liquidity");
            } else {
                panic!("Expected RecoverableError");
            }
        }
    }

    mod get_limits {
        use super::*;

        #[test]
        fn valid_limits() {
            let state = create_test_hashflow_state();
            let (sell_limit, buy_limit) = state
                .get_limits(state.base_token.address.clone(), state.quote_token.address.clone())
                .unwrap();

            // 0.5 + 1.5 + 5.0 = 7 WETH; 0.5*3000 + 1.5*3000 + 5.0*2999 = 20_995 USDC.
            assert_eq!(sell_limit, BigUint::from((7.0 * 10f64.powi(18)) as u128));
            assert_eq!(buy_limit, BigUint::from((20995.0 * 10f64.powi(6)) as u128));
        }

        #[test]
        fn invalid_token_pair() {
            let state = create_test_hashflow_state();
            let result =
                state.get_limits(wbtc().address.clone(), state.quote_token.address.clone());
            assert!(result.is_err());
            if let Err(SimulationError::InvalidInput(msg, _)) = result {
                assert!(msg.contains("Invalid token addresses"));
            } else {
                panic!("Expected InvalidInput");
            }
        }

        #[test]
        fn no_liquidity() {
            let mut state = create_test_hashflow_state();
            state.levels.levels = vec![];
            let result = state
                .get_limits(state.base_token.address.clone(), state.quote_token.address.clone());
            assert!(result.is_err());
            if let Err(SimulationError::RecoverableError(msg)) = result {
                assert_eq!(msg, "No liquidity");
            } else {
                panic!("Expected RecoverableError");
            }
        }
    }
}