tycho_simulation/rfq/stream.rs

1use std::collections::HashMap;
2
3use futures::{stream::select_all, StreamExt};
4use tycho_client::feed::{synchronizer::ComponentWithState, FeedMessage};
5use tycho_common::{
6    models::token::Token,
7    simulation::{errors::SimulationError, protocol_sim::ProtocolSim},
8    Bytes,
9};
10
11use crate::{
12    evm::decoder::TychoStreamDecoder,
13    protocol::{
14        errors::InvalidSnapshotError,
15        models::{TryFromWithBlock, Update},
16    },
17    rfq::{client::RFQClient, models::TimestampHeader},
18};
19
20/// `RFQStreamBuilder` is a utility for constructing and managing a merged stream of RFQ (Request
21/// For Quote) providers in Tycho.
22///
23/// It allows you to:
24/// - Register multiple `RFQClient` implementations, each providing its own stream of RFQ price
25///   updates.
26/// - Dynamically decode incoming updates into `Update` objects using `TychoStreamDecoder`.
27///
28/// The `build` method consumes the builder and runs the event loop, sending decoded `Update`s
29/// through the provided `mpsc::Sender`. It returns an error if decoding an update or forwarding
30/// it to the channel fails.
31///
32/// ### Error Handling:
33/// - Each `RFQClient`'s stream is expected to yield `Result<(String, StateSyncMessage), RFQError>`.
34/// - If a client's stream returns an `Err` (e.g., `RFQError::FatalError`), the client is
35///   **removed** from the merged stream, and the system continues running without it.
36#[derive(Default)]
37pub struct RFQStreamBuilder {
38    clients: Vec<Box<dyn RFQClient>>,
39    decoder: TychoStreamDecoder<TimestampHeader>,
40}
41
42impl RFQStreamBuilder {
43    pub fn new() -> Self {
44        Self { clients: Vec::new(), decoder: TychoStreamDecoder::new() }
45    }
46
47    pub fn add_client<T>(mut self, name: &str, provider: Box<dyn RFQClient>) -> Self
48    where
49        T: ProtocolSim
50            + TryFromWithBlock<ComponentWithState, TimestampHeader, Error = InvalidSnapshotError>
51            + Send
52            + 'static,
53    {
54        self.clients.push(provider);
55        self.decoder.register_decoder::<T>(name);
56        self
57    }
58
59    pub async fn build(self, tx: tokio::sync::mpsc::Sender<Update>) -> Result<(), SimulationError> {
60        let streams: Vec<_> = self
61            .clients
62            .into_iter()
63            .map(|provider| provider.stream())
64            .collect();
65
66        let mut merged = select_all(streams);
67
68        while let Some(next) = merged.next().await {
69            match next {
70                Ok((provider, msg)) => {
71                    let update = self
72                        .decoder
73                        .decode(&FeedMessage {
74                            state_msgs: HashMap::from([(provider.clone(), msg)]),
75                            sync_states: HashMap::new(),
76                        })
77                        .await
78                        .map_err(|e| {
79                            SimulationError::RecoverableError(format!("Decoding error: {e}"))
80                        })?;
81                    tx.send(update).await.map_err(|e| {
82                        SimulationError::RecoverableError(format!(
83                            "Failed to send update through channel: {e}"
84                        ))
85                    })?;
86                }
87                Err(e) => {
88                    tracing::error!(
89                        "RFQ stream fatal error: {e}. Assuming this stream will not emit more messages."
90                    );
91                }
92            }
93        }
94
95        Ok(())
96    }
97
98    /// Sets the currently known tokens which to be considered during decoding.
99    ///
100    /// Protocol components containing tokens which are not included in this initial list, or
101    /// added when applying deltas, will not be decoded.
102    pub async fn set_tokens(self, tokens: HashMap<Bytes, Token>) -> Self {
103        self.decoder.set_tokens(tokens).await;
104        self
105    }
106}
107
#[cfg(test)]
mod tests {
    use std::{any::Any, time::Duration};

    use async_trait::async_trait;
    use futures::stream::BoxStream;
    use num_bigint::BigUint;
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::IntervalStream;
    use tycho_client::feed::synchronizer::{Snapshot, StateSyncMessage};
    use tycho_common::{
        dto::{ProtocolComponent, ProtocolStateDelta, ResponseProtocolState},
        models::{protocol::GetAmountOutParams, token::Token},
        simulation::{
            errors::{SimulationError, TransitionError},
            indicatively_priced::SignedQuote,
            protocol_sim::{Balances, GetAmountOutResult},
        },
        Bytes,
    };

    use super::*;
    use crate::{protocol::models::DecoderContext, rfq::errors::RFQError};

    /// Minimal `ProtocolSim` that only exists so the decoder has a type to build; none of its
    /// simulation methods are called by these tests.
    #[derive(Clone, Debug)]
    pub struct DummyProtocol;

    impl ProtocolSim for DummyProtocol {
        fn fee(&self) -> f64 {
            unimplemented!("Not needed for this test")
        }

        fn spot_price(&self, _base: &Token, _quote: &Token) -> Result<f64, SimulationError> {
            unimplemented!("Not needed for this test")
        }

        fn get_amount_out(
            &self,
            _amount_in: BigUint,
            _token_in: &Token,
            _token_out: &Token,
        ) -> Result<GetAmountOutResult, SimulationError> {
            unimplemented!("Not needed for this test")
        }

        fn get_limits(
            &self,
            _sell_token: Bytes,
            _buy_token: Bytes,
        ) -> Result<(BigUint, BigUint), SimulationError> {
            unimplemented!("Not needed for this test")
        }

        fn delta_transition(
            &mut self,
            _delta: ProtocolStateDelta,
            _tokens: &HashMap<Bytes, Token>,
            _balances: &Balances,
        ) -> Result<(), TransitionError<String>> {
            unimplemented!("Not needed for this test")
        }

        fn clone_box(&self) -> Box<dyn ProtocolSim> {
            Box::new(self.clone())
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn as_any_mut(&mut self) -> &mut dyn Any {
            self
        }

        fn eq(&self, _other: &dyn ProtocolSim) -> bool {
            unimplemented!("Not needed for this test")
        }
    }

    impl TryFromWithBlock<ComponentWithState, TimestampHeader> for DummyProtocol {
        type Error = InvalidSnapshotError;

        // Ignores every input and always succeeds, so each snapshot decodes to a `DummyProtocol`.
        async fn try_from_with_header(
            _value: ComponentWithState,
            _header: TimestampHeader,
            _account_balances: &HashMap<Bytes, HashMap<Bytes, Bytes>>,
            _all_tokens: &HashMap<Bytes, Token>,
            _decoder_context: &DecoderContext,
        ) -> Result<Self, Self::Error> {
            Ok(DummyProtocol)
        }
    }

    /// Mock RFQ client that emits a single-component snapshot on every tick of `interval`.
    ///
    /// The first tick fires immediately with timestamp 0, and each successful message advances
    /// the timestamp by `interval` in milliseconds. If `error_at_time` is set, the tick that
    /// reaches that timestamp yields `RFQError::FatalError`. The timestamp is not advanced on
    /// error, so every later tick errors too.
    pub struct MockRFQClient {
        name: String,
        interval: Duration,
        error_at_time: Option<u128>,
    }

    impl MockRFQClient {
        pub fn new(name: &str, interval: Duration, error_at_time: Option<u128>) -> Self {
            Self { name: name.to_string(), interval, error_at_time }
        }
    }

    #[async_trait]
    impl RFQClient for MockRFQClient {
        fn stream(
            &self,
        ) -> BoxStream<'static, Result<(String, StateSyncMessage<TimestampHeader>), RFQError>>
        {
            let name = self.name.clone();
            let error_at_time = self.error_at_time;
            let step_ms = self.interval.as_millis();
            let mut current_time: u128 = 0;
            let ticks = IntervalStream::new(tokio::time::interval(self.interval)).map(move |_| {
                if error_at_time == Some(current_time) {
                    return Err(RFQError::FatalError(format!(
                        "{name} stream is dying and can't go on"
                    )));
                }

                let protocol_component =
                    ProtocolComponent { protocol_system: name.clone(), ..Default::default() };

                let snapshot = Snapshot {
                    states: HashMap::from([(
                        name.clone(),
                        ComponentWithState {
                            state: ResponseProtocolState {
                                component_id: name.clone(),
                                attributes: HashMap::new(),
                                balances: HashMap::new(),
                            },
                            component: protocol_component,
                            component_tvl: None,
                            entrypoints: vec![],
                        },
                    )]),
                    vm_storage: HashMap::new(),
                };

                let msg = StateSyncMessage {
                    header: TimestampHeader { timestamp: current_time as u64 },
                    snapshots: snapshot,
                    ..Default::default()
                };

                current_time += step_ms;
                Ok((name.clone(), msg))
            });
            Box::pin(ticks)
        }

        async fn request_binding_quote(
            &self,
            _params: &GetAmountOutParams,
        ) -> Result<SignedQuote, RFQError> {
            unimplemented!("Not needed for this test")
        }
    }

    #[tokio::test]
    async fn test_rfq_stream_builder() {
        // This test has two mocked RFQ clients:
        // 1. bebop emits a message every 100ms and fails fatally at t=300ms
        // 2. hashflow emits a message every 200ms and never fails
        let (tx, mut rx) = mpsc::channel::<Update>(10);

        let builder = RFQStreamBuilder::new()
            .add_client::<DummyProtocol>(
                "bebop",
                Box::new(MockRFQClient::new("bebop", Duration::from_millis(100), Some(300))),
            )
            .add_client::<DummyProtocol>(
                "hashflow",
                Box::new(MockRFQClient::new("hashflow", Duration::from_millis(200), None)),
            );

        tokio::spawn(builder.build(tx));

        // Collect the first 6 updates: bebop at 0, 100, 200 and hashflow at 0, 200, 400.
        // bebop's 300ms tick is an error and produces no update.
        let mut updates = Vec::with_capacity(6);
        for _ in 0..6 {
            updates.push(rx.recv().await.unwrap());
        }

        // Split the collected updates per provider
        let bebop_updates: Vec<_> = updates
            .iter()
            .filter(|u| u.new_pairs.contains_key("bebop"))
            .collect();
        let hashflow_updates: Vec<_> = updates
            .iter()
            .filter(|u| u.new_pairs.contains_key("hashflow"))
            .collect();

        assert_eq!(bebop_updates[0].block_number_or_timestamp, 0);
        assert_eq!(hashflow_updates[0].block_number_or_timestamp, 0);
        assert_eq!(bebop_updates[1].block_number_or_timestamp, 100);
        assert_eq!(bebop_updates[2].block_number_or_timestamp, 200);
        assert_eq!(hashflow_updates[1].block_number_or_timestamp, 200);
        // At this point the bebop stream dies, so there must be no more bebop updates, only
        // hashflow ones
        assert_eq!(bebop_updates.len(), 3);
        assert_eq!(hashflow_updates[2].block_number_or_timestamp, 400);
    }
}