Skip to main content

tycho_simulation/rfq/protocols/hashflow/
state.rs

1use std::{any::Any, collections::HashMap, fmt};
2
3use async_trait::async_trait;
4use num_bigint::BigUint;
5use num_traits::{FromPrimitive, Pow, ToPrimitive};
6use serde::{Deserialize, Serialize};
7use tycho_common::{
8    dto::ProtocolStateDelta,
9    models::{protocol::GetAmountOutParams, token::Token},
10    simulation::{
11        errors::{SimulationError, TransitionError},
12        indicatively_priced::{IndicativelyPriced, SignedQuote},
13        protocol_sim::{Balances, GetAmountOutResult, ProtocolSim},
14    },
15    Bytes,
16};
17
18use crate::rfq::{
19    client::RFQClient,
20    protocols::hashflow::{client::HashflowClient, models::HashflowMarketMakerLevels},
21};
22
/// Simulation state for one Hashflow market maker quoting a single base/quote pair.
///
/// Indicative pricing (spot price, amount out, limits) is computed from `levels`; binding
/// quotes are requested through `client` (see the `IndicativelyPriced` impl).
#[derive(Clone, Serialize, Deserialize)]
pub struct HashflowState {
    /// Token sold into the market maker. Only the base -> quote direction is supported.
    pub base_token: Token,
    /// Token received from the market maker.
    pub quote_token: Token,
    /// Price levels used for spot price, amount-out and limit simulation.
    pub levels: HashflowMarketMakerLevels,
    /// Market maker associated with this state; not used in pricing nor compared in `eq`.
    pub market_maker: String,
    /// Client used to request binding (signed) quotes.
    pub client: HashflowClient,
}
31
32impl fmt::Debug for HashflowState {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        f.debug_struct("HashflowState")
35            .field("base_token", &self.base_token)
36            .field("quote_token", &self.quote_token)
37            .field("market_maker", &self.market_maker)
38            .finish_non_exhaustive()
39    }
40}
41
42impl HashflowState {
43    pub fn new(
44        base_token: Token,
45        quote_token: Token,
46        levels: HashflowMarketMakerLevels,
47        market_maker: String,
48        client: HashflowClient,
49    ) -> Self {
50        Self { base_token, quote_token, levels, market_maker, client }
51    }
52
53    fn valid_direction_guard(
54        &self,
55        token_address_in: &Bytes,
56        token_address_out: &Bytes,
57    ) -> Result<(), SimulationError> {
58        // The current levels are only valid for the base/quote pair.
59        if !(token_address_in == &self.base_token.address &&
60            token_address_out == &self.quote_token.address)
61        {
62            Err(SimulationError::InvalidInput(
63                format!("Invalid token addresses. Got in={token_address_in}, out={token_address_out}, expected in={}, out={}", self.base_token.address, self.quote_token.address),
64                None,
65            ))
66        } else {
67            Ok(())
68        }
69    }
70
71    fn valid_levels_guard(&self) -> Result<(), SimulationError> {
72        if self.levels.levels.is_empty() {
73            return Err(SimulationError::RecoverableError("No liquidity".into()));
74        }
75        Ok(())
76    }
77}
78
79#[typetag::serde]
80impl ProtocolSim for HashflowState {
81    fn fee(&self) -> f64 {
82        todo!()
83    }
84
85    fn spot_price(&self, base: &Token, quote: &Token) -> Result<f64, SimulationError> {
86        self.valid_direction_guard(&base.address, &quote.address)?;
87
88        // Hashflow's levels are sorted by price, so the first level represents the best price.
89        self.levels
90            .levels
91            .first()
92            .ok_or(SimulationError::RecoverableError("No liquidity".into()))
93            .map(|level| level.price)
94    }
95
96    fn get_amount_out(
97        &self,
98        amount_in: BigUint,
99        token_in: &Token,
100        token_out: &Token,
101    ) -> Result<GetAmountOutResult, SimulationError> {
102        self.valid_direction_guard(&token_in.address, &token_out.address)?;
103        self.valid_levels_guard()?;
104
105        let amount_in = amount_in.to_f64().ok_or_else(|| {
106            SimulationError::RecoverableError("Can't convert amount in to f64".into())
107        })? / 10f64.powi(token_in.decimals as i32);
108
109        // First level represents the minimum amount that can be traded
110        let min_amount = self.levels.levels[0].quantity;
111        if amount_in < min_amount {
112            return Err(SimulationError::RecoverableError(format!(
113                "Amount below minimum. Input amount: {amount_in}, min amount: {min_amount}"
114            )));
115        }
116
117        // Calculate amount out
118        let (amount_out, remaining_amount_in) = self
119            .levels
120            .get_amount_out_from_levels(amount_in);
121
122        let res = GetAmountOutResult {
123            amount: BigUint::from_f64(amount_out * 10f64.powi(token_out.decimals as i32))
124                .ok_or_else(|| {
125                    SimulationError::RecoverableError("Can't convert amount out to BigUInt".into())
126                })?,
127            gas: BigUint::from(134_000u64), // Rough gas estimation
128            new_state: self.clone_box(),    // The state doesn't change after a swap
129        };
130
131        if remaining_amount_in > 0.0 {
132            return Err(SimulationError::InvalidInput(
133                format!("Pool has not enough liquidity to support complete swap. Input amount: {amount_in}, consumed amount: {}", amount_in-remaining_amount_in),
134                Some(res)));
135        }
136
137        Ok(res)
138    }
139
140    fn get_limits(
141        &self,
142        sell_token: Bytes,
143        buy_token: Bytes,
144    ) -> Result<(BigUint, BigUint), SimulationError> {
145        self.valid_direction_guard(&sell_token, &buy_token)?;
146        self.valid_levels_guard()?;
147
148        let sell_decimals = self.base_token.decimals;
149        let buy_decimals = self.quote_token.decimals;
150        let (total_sell_amount, total_buy_amount) =
151            self.levels
152                .levels
153                .iter()
154                .fold((0.0, 0.0), |(sell_sum, buy_sum), level| {
155                    (sell_sum + level.quantity, buy_sum + level.quantity * level.price)
156                });
157
158        let sell_limit =
159            BigUint::from((total_sell_amount * 10_f64.pow(sell_decimals as f64)) as u128);
160        let buy_limit = BigUint::from((total_buy_amount * 10_f64.pow(buy_decimals as f64)) as u128);
161
162        Ok((sell_limit, buy_limit))
163    }
164
165    fn as_indicatively_priced(&self) -> Result<&dyn IndicativelyPriced, SimulationError> {
166        Ok(self)
167    }
168
169    fn delta_transition(
170        &mut self,
171        _delta: ProtocolStateDelta,
172        _tokens: &HashMap<Bytes, Token>,
173        _balances: &Balances,
174    ) -> Result<(), TransitionError> {
175        todo!()
176    }
177
178    fn clone_box(&self) -> Box<dyn ProtocolSim> {
179        Box::new(self.clone())
180    }
181
182    fn as_any(&self) -> &dyn Any {
183        self
184    }
185
186    fn as_any_mut(&mut self) -> &mut dyn Any {
187        self
188    }
189
190    fn eq(&self, other: &dyn ProtocolSim) -> bool {
191        if let Some(other_state) = other
192            .as_any()
193            .downcast_ref::<HashflowState>()
194        {
195            self.base_token == other_state.base_token &&
196                self.quote_token == other_state.quote_token &&
197                self.levels == other_state.levels
198        } else {
199            false
200        }
201    }
202}
203
204#[async_trait]
205impl IndicativelyPriced for HashflowState {
206    async fn request_signed_quote(
207        &self,
208        params: GetAmountOutParams,
209    ) -> Result<SignedQuote, SimulationError> {
210        Ok(self
211            .client
212            .request_binding_quote(&params)
213            .await?)
214    }
215}
216
#[cfg(test)]
mod tests {
    use std::{collections::HashSet, str::FromStr};

    use tokio::time::Duration;
    use tycho_common::models::Chain;

    use super::*;
    use crate::rfq::protocols::hashflow::models::{HashflowPair, HashflowPriceLevel};

    fn wbtc() -> Token {
        Token::new(
            &hex::decode("2260fac5e5542a773aa44fbcfedf7c193bc2c599")
                .unwrap()
                .into(),
            "WBTC",
            8,
            0,
            &[Some(10_000)],
            Chain::Ethereum,
            100,
        )
    }

    fn usdc() -> Token {
        Token::new(
            &hex::decode("a0b86991c6218a76c1d19d4a2e9eb0ce3606eb48")
                .unwrap()
                .into(),
            "USDC",
            6,
            0,
            &[Some(10_000)],
            Chain::Ethereum,
            100,
        )
    }

    fn weth() -> Token {
        Token::new(
            &Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(),
            "WETH",
            18,
            0,
            &[],
            // Mainnet WETH address: use the same explicit chain as the other fixtures.
            Chain::Ethereum,
            100,
        )
    }

    fn empty_hashflow_client() -> HashflowClient {
        HashflowClient::new(
            Chain::Ethereum,
            HashSet::new(),
            0.0,
            HashSet::new(),
            "".to_string(),
            "".to_string(),
            Duration::from_secs(0),
            Duration::from_secs(30),
        )
        .unwrap()
    }

    /// WETH/USDC state with three levels totalling 7.0 WETH (0.5 + 1.5 + 5.0).
    fn create_test_hashflow_state() -> HashflowState {
        HashflowState {
            base_token: weth(),
            quote_token: usdc(),
            levels: HashflowMarketMakerLevels {
                // Reuse the fixture addresses so the pair can't drift from the tokens.
                pair: HashflowPair { base_token: weth().address, quote_token: usdc().address },
                levels: vec![
                    HashflowPriceLevel { quantity: 0.5, price: 3000.0 },
                    HashflowPriceLevel { quantity: 1.5, price: 3000.0 },
                    HashflowPriceLevel { quantity: 5.0, price: 2999.0 },
                ],
            },
            market_maker: "test_mm".to_string(),
            client: empty_hashflow_client(),
        }
    }

    mod spot_price {
        use super::*;

        #[test]
        fn returns_best_price() {
            let state = create_test_hashflow_state();
            let price = state
                .spot_price(&state.base_token, &state.quote_token)
                .unwrap();
            // The best price is the first level's price (3000.0)
            assert_eq!(price, 3000.0);
        }

        #[test]
        fn returns_invalid_input_error() {
            let state = create_test_hashflow_state();
            let result = state.spot_price(&wbtc(), &usdc());
            assert!(result.is_err());
            if let Err(SimulationError::InvalidInput(msg, _)) = result {
                assert!(msg.contains("Invalid token addresses"));
            } else {
                panic!("Expected InvalidInput");
            }
        }

        #[test]
        fn rejects_reversed_direction() {
            let state = create_test_hashflow_state();
            // Levels are only valid base -> quote; they can't be inverted for quote -> base.
            let result = state.spot_price(&usdc(), &weth());
            assert!(result.is_err());
            if let Err(SimulationError::InvalidInput(msg, _)) = result {
                assert!(msg.contains("Invalid token addresses"));
            } else {
                panic!("Expected InvalidInput");
            }
        }

        #[test]
        fn returns_no_liquidity_error() {
            let mut state = create_test_hashflow_state();
            state.levels.levels.clear();
            let result = state.spot_price(&state.base_token, &state.quote_token);
            assert!(result.is_err());
            if let Err(SimulationError::RecoverableError(msg)) = result {
                assert_eq!(msg, "No liquidity");
            } else {
                panic!("Expected RecoverableError");
            }
        }
    }

    mod get_amount_out {
        use super::*;

        #[test]
        fn weth_to_usdc() {
            let state = create_test_hashflow_state();

            // Test swapping 1.5 WETH -> USDC
            // Should consume first level (0.5 WETH at 3000) + partial second level (1.0 WETH at
            // 3000)
            let amount_out_result = state
                .get_amount_out(
                    BigUint::from_str("1500000000000000000").unwrap(), // 1.5 WETH (18 decimals)
                    &weth(),
                    &usdc(),
                )
                .unwrap();

            // Expected: (0.5 * 3000) + (1.0 * 3000) = 1500 + 3000 = 4500 USDC
            assert_eq!(amount_out_result.amount, BigUint::from_str("4500000000").unwrap()); // 6 decimals
            assert_eq!(amount_out_result.gas, BigUint::from(134_000u64));
        }

        #[test]
        fn usdc_to_weth() {
            let state = create_test_hashflow_state();

            // Test swapping 10000 USDC -> WETH
            // The price levels returned by Hashflow are only valid for the requested pair,
            // and they can't be inverted to derive the reverse swap.
            // In that case, we should return an error.
            let result = state.get_amount_out(
                BigUint::from_str("10000000000").unwrap(), // 10000 USDC (6 decimals)
                &usdc(),
                &weth(),
            );

            assert!(result.is_err());
            if let Err(SimulationError::InvalidInput(msg, ..)) = result {
                assert!(msg.contains("Invalid token addresses"));
            } else {
                panic!("Expected InvalidInput");
            }
        }

        #[test]
        fn below_minimum() {
            let state = create_test_hashflow_state();

            // Test with amount below minimum (first level quantity is 0.5 WETH)
            let result = state.get_amount_out(
                BigUint::from_str("250000000000000000").unwrap(), // 0.25 WETH (18 decimals)
                &weth(),
                &usdc(),
            );

            assert!(result.is_err());
            if let Err(SimulationError::RecoverableError(msg)) = result {
                assert!(msg.contains("Amount below minimum"));
            } else {
                panic!("Expected RecoverableError");
            }
        }

        #[test]
        fn insufficient_liquidity() {
            let state = create_test_hashflow_state();

            // Test with amount exceeding total liquidity (total is 7.0 WETH)
            let result = state.get_amount_out(
                BigUint::from_str("8000000000000000000").unwrap(), // 8.0 WETH (18 decimals)
                &weth(),
                &usdc(),
            );

            assert!(result.is_err());
            if let Err(SimulationError::InvalidInput(msg, _)) = result {
                assert!(msg.contains("Pool has not enough liquidity"));
            } else {
                panic!("Expected InvalidInput");
            }
        }

        #[test]
        fn invalid_token_pair() {
            let state = create_test_hashflow_state();

            // Test with invalid token pair (WBTC not in WETH/USDC pool)
            let result = state.get_amount_out(
                BigUint::from_str("100000000").unwrap(), // 1 WBTC
                &wbtc(),
                &usdc(),
            );

            assert!(result.is_err());
            if let Err(SimulationError::InvalidInput(msg, ..)) = result {
                assert!(msg.contains("Invalid token addresses"));
            } else {
                panic!("Expected InvalidInput");
            }
        }

        #[test]
        fn no_liquidity() {
            let mut state = create_test_hashflow_state();
            state.levels.levels = vec![]; // Remove all levels

            let result = state.get_amount_out(
                BigUint::from_str("1000000000000000000").unwrap(), // 1.0 WETH
                &weth(),
                &usdc(),
            );

            assert!(result.is_err());
            if let Err(SimulationError::RecoverableError(msg)) = result {
                assert_eq!(msg, "No liquidity");
            } else {
                panic!("Expected RecoverableError");
            }
        }
    }

    mod get_limits {
        use super::*;

        #[test]
        fn valid_limits() {
            let state = create_test_hashflow_state();
            let (sell_limit, buy_limit) = state
                .get_limits(state.base_token.address.clone(), state.quote_token.address.clone())
                .unwrap();

            // Total sell: 0.5 + 1.5 + 5.0 = 7.0 WETH (18 decimals)
            // Total buy: (0.5+1.5)*3000 + 5.0*2999 = 20995 USDC (6 decimals)
            assert_eq!(sell_limit, BigUint::from((7.0 * 10f64.powi(18)) as u128));
            assert_eq!(buy_limit, BigUint::from((20995.0 * 10f64.powi(6)) as u128));
        }

        #[test]
        fn invalid_token_pair() {
            let state = create_test_hashflow_state();
            let result = state.get_limits(wbtc().address, state.quote_token.address.clone());
            assert!(result.is_err());
            if let Err(SimulationError::InvalidInput(msg, _)) = result {
                assert!(msg.contains("Invalid token addresses"));
            } else {
                panic!("Expected InvalidInput");
            }
        }

        #[test]
        fn no_liquidity() {
            let mut state = create_test_hashflow_state();
            state.levels.levels = vec![];
            let result = state
                .get_limits(state.base_token.address.clone(), state.quote_token.address.clone());
            assert!(result.is_err());
            if let Err(SimulationError::RecoverableError(msg)) = result {
                assert_eq!(msg, "No liquidity");
            } else {
                panic!("Expected RecoverableError");
            }
        }
    }
}