tycho_client/feed/synchronizer.rs

1use std::{
2    collections::{HashMap, HashSet},
3    sync::Arc,
4    time::Duration,
5};
6
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use tokio::{
10    select,
11    sync::{
12        mpsc::{channel, Receiver, Sender},
13        oneshot, Mutex,
14    },
15    task::JoinHandle,
16    time::timeout,
17};
18use tracing::{debug, error, info, instrument, trace, warn};
19use tycho_common::{
20    dto::{
21        BlockChanges, BlockParam, ExtractorIdentity, ProtocolComponent, ResponseAccount,
22        ResponseProtocolState, VersionParam,
23    },
24    Bytes,
25};
26
27use crate::{
28    deltas::{DeltasClient, SubscriptionOptions},
29    feed::{
30        component_tracker::{ComponentFilter, ComponentTracker},
31        Header,
32    },
33    rpc::RPCClient,
34};
35
36pub type SyncResult<T> = anyhow::Result<T>;
37
38#[derive(Clone)]
39pub struct ProtocolStateSynchronizer<R: RPCClient, D: DeltasClient> {
40    extractor_id: ExtractorIdentity,
41    retrieve_balances: bool,
42    rpc_client: R,
43    deltas_client: D,
44    max_retries: u64,
45    include_snapshots: bool,
46    component_tracker: Arc<Mutex<ComponentTracker<R>>>,
47    shared: Arc<Mutex<SharedState>>,
48    end_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
49    timeout: u64,
50}
51
52#[derive(Debug, Default)]
53struct SharedState {
54    last_synced_block: Option<Header>,
55}
56
57#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
58pub struct ComponentWithState {
59    pub state: ResponseProtocolState,
60    pub component: ProtocolComponent,
61}
62
63#[derive(Clone, PartialEq, Debug, Default, Serialize, Deserialize)]
64pub struct Snapshot {
65    pub states: HashMap<String, ComponentWithState>,
66    pub vm_storage: HashMap<Bytes, ResponseAccount>,
67}
68
69impl Snapshot {
70    fn extend(&mut self, other: Snapshot) {
71        self.states.extend(other.states);
72        self.vm_storage.extend(other.vm_storage);
73    }
74
75    pub fn get_states(&self) -> &HashMap<String, ComponentWithState> {
76        &self.states
77    }
78
79    pub fn get_vm_storage(&self) -> &HashMap<Bytes, ResponseAccount> {
80        &self.vm_storage
81    }
82}
83
84#[derive(Clone, PartialEq, Debug, Default, Serialize, Deserialize)]
85pub struct StateSyncMessage {
86    /// The block information for this update.
87    pub header: Header,
88    /// Snapshot for new components.
89    pub snapshots: Snapshot,
90    /// A single delta contains state updates for all tracked components, as well as additional
91    /// information about the system components e.g. newly added components (even below tvl), tvl
92    /// updates, balance updates.
93    pub deltas: Option<BlockChanges>,
94    /// Components that stopped being tracked.
95    pub removed_components: HashMap<String, ProtocolComponent>,
96}
97
98impl StateSyncMessage {
99    pub fn merge(mut self, other: Self) -> Self {
100        // be careful with removed and snapshots attributes here, these can be ambiguous.
101        self.removed_components
102            .retain(|k, _| !other.snapshots.states.contains_key(k));
103        self.snapshots
104            .states
105            .retain(|k, _| !other.removed_components.contains_key(k));
106
107        self.snapshots.extend(other.snapshots);
108        let deltas = match (self.deltas, other.deltas) {
109            (Some(l), Some(r)) => Some(l.merge(r)),
110            (None, Some(r)) => Some(r),
111            (Some(l), None) => Some(l),
112            (None, None) => None,
113        };
114        self.removed_components
115            .extend(other.removed_components);
116        Self {
117            header: other.header,
118            snapshots: self.snapshots,
119            deltas,
120            removed_components: self.removed_components,
121        }
122    }
123}
124
125/// StateSynchronizer
126///
127/// Used to synchronize the state of a single protocol. The synchronizer is responsible for
128/// delivering messages to the client that let him reconstruct subsets of the protocol state.
129///
130/// This involves deciding which components to track according to the clients preferences,
131/// retrieving & emitting snapshots of components which the client has not seen yet and subsequently
132/// delivering delta messages for the components that have changed.
133#[async_trait]
134pub trait StateSynchronizer: Send + Sync + 'static {
135    async fn initialize(&self) -> SyncResult<()>;
136    /// Starts the state synchronization.
137    async fn start(&self) -> SyncResult<(JoinHandle<SyncResult<()>>, Receiver<StateSyncMessage>)>;
138    /// Ends the sychronization loop.
139    async fn close(&mut self) -> SyncResult<()>;
140}
141
142impl<R, D> ProtocolStateSynchronizer<R, D>
143where
144    // TODO: Consider moving these constraints directly to the
145    // client...
146    R: RPCClient + Clone + Send + Sync + 'static,
147    D: DeltasClient + Clone + Send + Sync + 'static,
148{
149    /// Creates a new state synchronizer.
150    #[allow(clippy::too_many_arguments)]
151    pub fn new(
152        extractor_id: ExtractorIdentity,
153        retrieve_balances: bool,
154        component_filter: ComponentFilter,
155        max_retries: u64,
156        include_snapshots: bool,
157        rpc_client: R,
158        deltas_client: D,
159        timeout: u64,
160    ) -> Self {
161        Self {
162            extractor_id: extractor_id.clone(),
163            retrieve_balances,
164            rpc_client: rpc_client.clone(),
165            include_snapshots,
166            deltas_client,
167            component_tracker: Arc::new(Mutex::new(ComponentTracker::new(
168                extractor_id.chain,
169                extractor_id.name.as_str(),
170                component_filter,
171                rpc_client,
172            ))),
173            max_retries,
174            shared: Arc::new(Mutex::new(SharedState::default())),
175            end_tx: Arc::new(Mutex::new(None)),
176            timeout,
177        }
178    }
179
180    /// Retrieves state snapshots of the requested components
181    ///
182    /// TODO:
183    /// Future considerations:
184    /// The current design separates the concepts of snapshots and deltas, therefore requiring us to
185    /// fetch data for snapshots that might already exist in the deltas messages. This is
186    /// unnecessary and could be optimized by removing snapshots entirely and only using deltas.
187    #[allow(deprecated)]
188    async fn get_snapshots<'a, I: IntoIterator<Item = &'a String>>(
189        &self,
190        header: Header,
191        tracked_components: &ComponentTracker<R>,
192        ids: Option<I>,
193    ) -> SyncResult<StateSyncMessage> {
194        if !self.include_snapshots {
195            return Ok(StateSyncMessage { header, ..Default::default() });
196        }
197        let version = VersionParam::new(
198            None,
199            Some(BlockParam {
200                chain: Some(self.extractor_id.chain),
201                hash: None,
202                number: Some(header.number as i64),
203            }),
204        );
205
206        // Use given ids or use all if not passed
207        let request_ids = ids
208            .map(|it| {
209                it.into_iter()
210                    .cloned()
211                    .collect::<Vec<_>>()
212            })
213            .unwrap_or_else(|| tracked_components.get_tracked_component_ids());
214
215        let component_ids = request_ids
216            .iter()
217            .collect::<HashSet<_>>();
218
219        if component_ids.is_empty() {
220            return Ok(StateSyncMessage { header, ..Default::default() });
221        }
222
223        let mut protocol_states = self
224            .rpc_client
225            .get_protocol_states_paginated(
226                self.extractor_id.chain,
227                &request_ids,
228                &self.extractor_id.name,
229                self.retrieve_balances,
230                &version,
231                100,
232                4,
233            )
234            .await?
235            .states
236            .into_iter()
237            .map(|state| (state.component_id.clone(), state))
238            .collect::<HashMap<_, _>>();
239
240        trace!(states=?&protocol_states, "Retrieved ProtocolStates");
241        let states = tracked_components
242            .components
243            .values()
244            .filter_map(|component| {
245                if let Some(state) = protocol_states.remove(&component.id) {
246                    Some((
247                        component.id.clone(),
248                        ComponentWithState { state, component: component.clone() },
249                    ))
250                } else if component_ids.contains(&&component.id) {
251                    // only emit error event if we requested this component
252                    let component_id = &component.id;
253                    error!(?component_id, "Missing state for native component!");
254                    None
255                } else {
256                    None
257                }
258            })
259            .collect();
260
261        let contract_ids = tracked_components.get_contracts_by_component(component_ids.clone());
262        let vm_storage = if !contract_ids.is_empty() {
263            let ids: Vec<Bytes> = contract_ids
264                .clone()
265                .into_iter()
266                .collect();
267            let contract_states = self
268                .rpc_client
269                .get_contract_state_paginated(
270                    self.extractor_id.chain,
271                    ids.as_slice(),
272                    &self.extractor_id.name,
273                    &version,
274                    100,
275                    4,
276                )
277                .await?
278                .accounts
279                .into_iter()
280                .map(|acc| (acc.address.clone(), acc))
281                .collect::<HashMap<_, _>>();
282
283            trace!(states=?&contract_states, "Retrieved ContractState");
284
285            let contract_address_to_components = tracked_components
286                .components
287                .iter()
288                .filter_map(|(id, comp)| {
289                    if component_ids.contains(&id) {
290                        Some(
291                            comp.contract_ids
292                                .iter()
293                                .map(|address| (address.clone(), comp.id.clone())),
294                        )
295                    } else {
296                        None
297                    }
298                })
299                .flatten()
300                .fold(HashMap::<Bytes, Vec<String>>::new(), |mut acc, (addr, c_id)| {
301                    acc.entry(addr).or_default().push(c_id);
302                    acc
303                });
304
305            contract_ids
306                .iter()
307                .filter_map(|address| {
308                    if let Some(state) = contract_states.get(address) {
309                        Some((address.clone(), state.clone()))
310                    } else if let Some(ids) = contract_address_to_components.get(address) {
311                        // only emit error even if we did actually request this address
312                        error!(
313                            ?address,
314                            ?ids,
315                            "Component with lacking contract storage encountered!"
316                        );
317                        None
318                    } else {
319                        None
320                    }
321                })
322                .collect()
323        } else {
324            HashMap::new()
325        };
326
327        Ok(StateSyncMessage {
328            header,
329            snapshots: Snapshot { states, vm_storage },
330            deltas: None,
331            removed_components: HashMap::new(),
332        })
333    }
334
335    /// Main method that does all the work.
336    #[instrument(skip(self, block_tx), fields(extractor_id = %self.extractor_id))]
337    async fn state_sync(self, block_tx: &mut Sender<StateSyncMessage>) -> SyncResult<()> {
338        // initialisation
339        let mut tracker = self.component_tracker.lock().await;
340
341        let subscription_options = SubscriptionOptions::new().with_state(self.include_snapshots);
342        let (_, mut msg_rx) = self
343            .deltas_client
344            .subscribe(self.extractor_id.clone(), subscription_options)
345            .await?;
346
347        info!("Waiting for deltas...");
348        // wait for first deltas message
349        let mut first_msg = timeout(Duration::from_secs(self.timeout), msg_rx.recv())
350            .await?
351            .ok_or_else(|| anyhow::format_err!("Subscription ended too soon"))?;
352        self.filter_deltas(&mut first_msg, &tracker);
353
354        // initial snapshot
355        let block = first_msg.get_block().clone();
356        info!(height = &block.number, "Deltas received. Retrieving snapshot");
357        let header = Header::from_block(first_msg.get_block(), first_msg.is_revert());
358        let snapshot = self
359            .get_snapshots::<Vec<&String>>(Header::from_block(&block, false), &tracker, None)
360            .await
361            .map_err(|rpc_err| anyhow::format_err!("failed to get initial snapshot: {}", rpc_err))?
362            .merge(StateSyncMessage {
363                header: Header::from_block(first_msg.get_block(), first_msg.is_revert()),
364                snapshots: Default::default(),
365                deltas: Some(first_msg),
366                removed_components: Default::default(),
367            });
368
369        let n_components = tracker.components.len();
370        let n_snapshots = snapshot.snapshots.states.len();
371        info!(n_components, n_snapshots, "Initial snapshot retrieved, starting delta message feed");
372
373        {
374            let mut shared = self.shared.lock().await;
375            block_tx.send(snapshot).await?;
376            shared.last_synced_block = Some(header.clone());
377        }
378
379        loop {
380            if let Some(mut deltas) = msg_rx.recv().await {
381                let header = Header::from_block(deltas.get_block(), deltas.is_revert());
382                debug!(block_number=?header.number, "Received delta message");
383                let (snapshots, removed_components) = {
384                    // 1. Remove components based on latest changes
385                    // 2. Add components based on latest changes, query those for snapshots
386                    let (to_add, to_remove) = tracker.filter_updated_components(&deltas);
387
388                    // Only components we don't track yet need a snapshot,
389                    let requiring_snapshot: Vec<_> = to_add
390                        .iter()
391                        .filter(|id| {
392                            !tracker
393                                .components
394                                .contains_key(id.as_str())
395                        })
396                        .collect();
397                    debug!(components=?requiring_snapshot, "SnapshotRequest");
398                    tracker
399                        .start_tracking(requiring_snapshot.as_slice())
400                        .await?;
401                    let snapshots = self
402                        .get_snapshots(header.clone(), &tracker, Some(requiring_snapshot))
403                        .await?
404                        .snapshots;
405
406                    let removed_components = if !to_remove.is_empty() {
407                        tracker.stop_tracking(&to_remove)
408                    } else {
409                        Default::default()
410                    };
411                    (snapshots, removed_components)
412                };
413
414                // 3. Filter deltas by currently tracked components / contracts
415                self.filter_deltas(&mut deltas, &tracker);
416                let n_changes = deltas.n_changes();
417
418                let next = StateSyncMessage {
419                    header: header.clone(),
420                    snapshots,
421                    deltas: Some(deltas),
422                    removed_components,
423                };
424                block_tx.send(next).await?;
425                {
426                    let mut shared = self.shared.lock().await;
427                    shared.last_synced_block = Some(header.clone());
428                }
429
430                debug!(block_number=?header.number, n_changes, "Finished processing delta message");
431            } else {
432                let mut shared = self.shared.lock().await;
433                warn!(shared = ?&shared, "Deltas channel closed, resetting shared state.");
434                shared.last_synced_block = None;
435
436                return Err(anyhow::format_err!("Deltas channel closed!"));
437            }
438        }
439    }
440
441    fn filter_deltas(&self, second_msg: &mut BlockChanges, tracker: &ComponentTracker<R>) {
442        second_msg.filter_by_component(|id| tracker.components.contains_key(id));
443        second_msg.filter_by_contract(|id| tracker.contracts.contains(id));
444    }
445}
446
447#[async_trait]
448impl<R, D> StateSynchronizer for ProtocolStateSynchronizer<R, D>
449where
450    R: RPCClient + Clone + Send + Sync + 'static,
451    D: DeltasClient + Clone + Send + Sync + 'static,
452{
453    async fn initialize(&self) -> SyncResult<()> {
454        let mut tracker = self.component_tracker.lock().await;
455        info!("Retrieving relevant protocol components");
456        tracker.initialise_components().await?;
457        info!(
458            n_components = tracker.components.len(),
459            n_contracts = tracker.contracts.len(),
460            "Finished retrieving components",
461        );
462
463        Ok(())
464    }
465    async fn start(&self) -> SyncResult<(JoinHandle<SyncResult<()>>, Receiver<StateSyncMessage>)> {
466        let (mut tx, rx) = channel(15);
467
468        let this = self.clone();
469        let jh = tokio::spawn(async move {
470            let mut retry_count = 0;
471            while retry_count < this.max_retries {
472                info!(extractor_id=%&this.extractor_id, retry_count, "(Re)starting synchronization loop");
473                let (end_tx, end_rx) = oneshot::channel::<()>();
474                {
475                    let mut end_tx_guard = this.end_tx.lock().await;
476                    *end_tx_guard = Some(end_tx);
477                }
478
479                select! {
480                    res = this.clone().state_sync(&mut tx) => {
481                        match  res
482                        {
483                            Err(e) => {
484                                error!(
485                                    extractor_id=%&this.extractor_id,
486                                    retry_count,
487                                    error=%e,
488                                    "State synchronization errored!"
489                                );
490                            }
491                            _ => {
492                                warn!(
493                                    extractor_id=%&this.extractor_id,
494                                    retry_count,
495                                    "State sync exited with Ok(())"
496                                );
497                            }
498                        }
499                    },
500                    _ = end_rx => {
501                        info!(
502                            extractor_id=%&this.extractor_id,
503                            retry_count,
504                            "StateSynchronizer received close signal. Stopping"
505                        );
506                        return Ok(())
507                    }
508                }
509                retry_count += 1;
510            }
511            Err(anyhow::format_err!("Max retries exceeded giving up"))
512        });
513
514        Ok((jh, rx))
515    }
516
517    async fn close(&mut self) -> SyncResult<()> {
518        let mut end_tx = self.end_tx.lock().await;
519        if let Some(tx) = end_tx.take() {
520            let _ = tx.send(());
521            Ok(())
522        } else {
523            Err(anyhow::format_err!("Not started"))
524        }
525    }
526}
527
528#[cfg(test)]
529mod test {
530    use test_log::test;
531    use tycho_common::dto::{
532        Block, Chain, PaginationResponse, ProtocolComponentRequestResponse,
533        ProtocolComponentsRequestBody, ProtocolStateRequestBody, ProtocolStateRequestResponse,
534        ProtocolSystemsRequestBody, ProtocolSystemsRequestResponse, StateRequestBody,
535        StateRequestResponse, TokensRequestBody, TokensRequestResponse,
536    };
537    use uuid::Uuid;
538
539    use super::*;
540    use crate::{deltas::MockDeltasClient, rpc::MockRPCClient, DeltasError, RPCError};
541
542    // Required for mock client to implement clone
543    struct ArcRPCClient<T>(Arc<T>);
544
545    // Default derive(Clone) does require T to be Clone as well.
546    impl<T> Clone for ArcRPCClient<T> {
547        fn clone(&self) -> Self {
548            ArcRPCClient(self.0.clone())
549        }
550    }
551
552    #[async_trait]
553    impl<T> RPCClient for ArcRPCClient<T>
554    where
555        T: RPCClient + Sync + Send + 'static,
556    {
557        async fn get_tokens(
558            &self,
559            request: &TokensRequestBody,
560        ) -> Result<TokensRequestResponse, RPCError> {
561            self.0.get_tokens(request).await
562        }
563
564        async fn get_contract_state(
565            &self,
566            request: &StateRequestBody,
567        ) -> Result<StateRequestResponse, RPCError> {
568            self.0.get_contract_state(request).await
569        }
570
571        async fn get_protocol_components(
572            &self,
573            request: &ProtocolComponentsRequestBody,
574        ) -> Result<ProtocolComponentRequestResponse, RPCError> {
575            self.0
576                .get_protocol_components(request)
577                .await
578        }
579
580        async fn get_protocol_states(
581            &self,
582            request: &ProtocolStateRequestBody,
583        ) -> Result<ProtocolStateRequestResponse, RPCError> {
584            self.0
585                .get_protocol_states(request)
586                .await
587        }
588
589        async fn get_protocol_systems(
590            &self,
591            request: &ProtocolSystemsRequestBody,
592        ) -> Result<ProtocolSystemsRequestResponse, RPCError> {
593            self.0
594                .get_protocol_systems(request)
595                .await
596        }
597    }
598
599    // Required for mock client to implement clone
600    struct ArcDeltasClient<T>(Arc<T>);
601
602    // Default derive(Clone) does require T to be Clone as well.
603    impl<T> Clone for ArcDeltasClient<T> {
604        fn clone(&self) -> Self {
605            ArcDeltasClient(self.0.clone())
606        }
607    }
608
609    #[async_trait]
610    impl<T> DeltasClient for ArcDeltasClient<T>
611    where
612        T: DeltasClient + Sync + Send + 'static,
613    {
614        async fn subscribe(
615            &self,
616            extractor_id: ExtractorIdentity,
617            options: SubscriptionOptions,
618        ) -> Result<(Uuid, Receiver<BlockChanges>), DeltasError> {
619            self.0
620                .subscribe(extractor_id, options)
621                .await
622        }
623
624        async fn unsubscribe(&self, subscription_id: Uuid) -> Result<(), DeltasError> {
625            self.0
626                .unsubscribe(subscription_id)
627                .await
628        }
629
630        async fn connect(&self) -> Result<JoinHandle<Result<(), DeltasError>>, DeltasError> {
631            self.0.connect().await
632        }
633
634        async fn close(&self) -> Result<(), DeltasError> {
635            self.0.close().await
636        }
637    }
638
639    fn with_mocked_clients(
640        native: bool,
641        rpc_client: Option<MockRPCClient>,
642        deltas_client: Option<MockDeltasClient>,
643    ) -> ProtocolStateSynchronizer<ArcRPCClient<MockRPCClient>, ArcDeltasClient<MockDeltasClient>>
644    {
645        let rpc_client = ArcRPCClient(Arc::new(rpc_client.unwrap_or_default()));
646        let deltas_client = ArcDeltasClient(Arc::new(deltas_client.unwrap_or_default()));
647
648        ProtocolStateSynchronizer::new(
649            ExtractorIdentity::new(Chain::Ethereum, "uniswap-v2"),
650            native,
651            ComponentFilter::with_tvl_range(50.0, 50.0),
652            1,
653            true,
654            rpc_client,
655            deltas_client,
656            10_u64,
657        )
658    }
659
660    fn state_snapshot_native() -> ProtocolStateRequestResponse {
661        ProtocolStateRequestResponse {
662            states: vec![ResponseProtocolState {
663                component_id: "Component1".to_string(),
664                ..Default::default()
665            }],
666            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
667        }
668    }
669
670    #[test(tokio::test)]
671    async fn test_get_snapshots_native() {
672        let header = Header::default();
673        let mut rpc = MockRPCClient::new();
674        rpc.expect_get_protocol_states()
675            .returning(|_| Ok(state_snapshot_native()));
676        let state_sync = with_mocked_clients(true, Some(rpc), None);
677        let mut tracker = ComponentTracker::new(
678            Chain::Ethereum,
679            "uniswap-v2",
680            ComponentFilter::with_tvl_range(0.0, 0.0),
681            state_sync.rpc_client.clone(),
682        );
683        let component = ProtocolComponent { id: "Component1".to_string(), ..Default::default() };
684        tracker
685            .components
686            .insert("Component1".to_string(), component.clone());
687        let components_arg = ["Component1".to_string()];
688        let exp = StateSyncMessage {
689            header: header.clone(),
690            snapshots: Snapshot {
691                states: state_snapshot_native()
692                    .states
693                    .into_iter()
694                    .map(|state| {
695                        (
696                            state.component_id.clone(),
697                            ComponentWithState { state, component: component.clone() },
698                        )
699                    })
700                    .collect(),
701                vm_storage: HashMap::new(),
702            },
703            deltas: None,
704            removed_components: Default::default(),
705        };
706
707        let snap = state_sync
708            .get_snapshots(header, &tracker, Some(&components_arg))
709            .await
710            .expect("Retrieving snapshot failed");
711
712        assert_eq!(snap, exp);
713    }
714
715    fn state_snapshot_vm() -> StateRequestResponse {
716        StateRequestResponse {
717            accounts: vec![
718                ResponseAccount { address: Bytes::from("0x0badc0ffee"), ..Default::default() },
719                ResponseAccount { address: Bytes::from("0xbabe42"), ..Default::default() },
720            ],
721            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
722        }
723    }
724
725    #[test(tokio::test)]
726    async fn test_get_snapshots_vm() {
727        let header = Header::default();
728        let mut rpc = MockRPCClient::new();
729        rpc.expect_get_protocol_states()
730            .returning(|_| Ok(state_snapshot_native()));
731        rpc.expect_get_contract_state()
732            .returning(|_| Ok(state_snapshot_vm()));
733        let state_sync = with_mocked_clients(false, Some(rpc), None);
734        let mut tracker = ComponentTracker::new(
735            Chain::Ethereum,
736            "uniswap-v2",
737            ComponentFilter::with_tvl_range(0.0, 0.0),
738            state_sync.rpc_client.clone(),
739        );
740        let component = ProtocolComponent {
741            id: "Component1".to_string(),
742            contract_ids: vec![Bytes::from("0x0badc0ffee"), Bytes::from("0xbabe42")],
743            ..Default::default()
744        };
745        tracker
746            .components
747            .insert("Component1".to_string(), component.clone());
748        let components_arg = ["Component1".to_string()];
749        let exp = StateSyncMessage {
750            header: header.clone(),
751            snapshots: Snapshot {
752                states: [(
753                    component.id.clone(),
754                    ComponentWithState {
755                        state: ResponseProtocolState {
756                            component_id: "Component1".to_string(),
757                            ..Default::default()
758                        },
759                        component: component.clone(),
760                    },
761                )]
762                .into_iter()
763                .collect(),
764                vm_storage: state_snapshot_vm()
765                    .accounts
766                    .into_iter()
767                    .map(|state| (state.address.clone(), state))
768                    .collect(),
769            },
770            deltas: None,
771            removed_components: Default::default(),
772        };
773
774        let snap = state_sync
775            .get_snapshots(header, &tracker, Some(&components_arg))
776            .await
777            .expect("Retrieving snapshot failed");
778
779        assert_eq!(snap, exp);
780    }
781
782    fn mock_clients_for_state_sync() -> (MockRPCClient, MockDeltasClient, Sender<BlockChanges>) {
783        let mut rpc_client = MockRPCClient::new();
784        // Mocks for the start_tracking call, these need to come first because they are more
785        // specific, see: https://docs.rs/mockall/latest/mockall/#matching-multiple-calls
786        rpc_client
787            .expect_get_protocol_components()
788            .with(mockall::predicate::function(
789                move |request_params: &ProtocolComponentsRequestBody| {
790                    if let Some(ids) = request_params.component_ids.as_ref() {
791                        ids.contains(&"Component3".to_string())
792                    } else {
793                        false
794                    }
795                },
796            ))
797            .returning(|_| {
798                // return Component3
799                Ok(ProtocolComponentRequestResponse {
800                    protocol_components: vec![
801                        // this component shall have a tvl update above threshold
802                        ProtocolComponent { id: "Component3".to_string(), ..Default::default() },
803                    ],
804                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
805                })
806            });
807        rpc_client
808            .expect_get_protocol_states()
809            .with(mockall::predicate::function(move |request_params: &ProtocolStateRequestBody| {
810                let expected_id = "Component3".to_string();
811                if let Some(ids) = request_params.protocol_ids.as_ref() {
812                    ids.contains(&expected_id)
813                } else {
814                    false
815                }
816            }))
817            .returning(|_| {
818                // return Component3 state
819                Ok(ProtocolStateRequestResponse {
820                    states: vec![ResponseProtocolState {
821                        component_id: "Component3".to_string(),
822                        ..Default::default()
823                    }],
824                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
825                })
826            });
827
828        // mock calls for the initial state snapshots
829        rpc_client
830            .expect_get_protocol_components()
831            .returning(|_| {
832                // Initial sync of components
833                Ok(ProtocolComponentRequestResponse {
834                    protocol_components: vec![
835                        // this component shall have a tvl update above threshold
836                        ProtocolComponent { id: "Component1".to_string(), ..Default::default() },
837                        // this component shall have a tvl update below threshold.
838                        ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
839                        // a third component will have a tvl update above threshold
840                    ],
841                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
842                })
843            });
844        rpc_client
845            .expect_get_protocol_states()
846            .returning(|_| {
847                // Initial state snapshot
848                Ok(ProtocolStateRequestResponse {
849                    states: vec![
850                        ResponseProtocolState {
851                            component_id: "Component1".to_string(),
852                            ..Default::default()
853                        },
854                        ResponseProtocolState {
855                            component_id: "Component2".to_string(),
856                            ..Default::default()
857                        },
858                    ],
859                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
860                })
861            });
862        // Mock deltas client and messages
863        let mut deltas_client = MockDeltasClient::new();
864        let (tx, rx) = channel(1);
865        deltas_client
866            .expect_subscribe()
867            .return_once(move |_, _| {
868                // Return subscriber id and a channel
869                Ok((Uuid::default(), rx))
870            });
871        (rpc_client, deltas_client, tx)
872    }
873
874    /// Test strategy
875    ///
876    /// - initial snapshot retrieval returns two component1 and component2 as snapshots
877    /// - send 2 dummy messages, containing only blocks
878    /// - third message contains a new component with some significant tvl, one initial component
879    ///   slips below tvl threshold, another one is above tvl but does not get re-requested.
880    #[test(tokio::test)]
881    async fn test_state_sync() {
882        let (rpc_client, deltas_client, tx) = mock_clients_for_state_sync();
883        let deltas = [
884            BlockChanges {
885                extractor: "uniswap-v2".to_string(),
886                chain: Chain::Ethereum,
887                block: Block {
888                    number: 1,
889                    hash: Bytes::from("0x01"),
890                    parent_hash: Bytes::from("0x00"),
891                    chain: Chain::Ethereum,
892                    ts: Default::default(),
893                },
894                revert: false,
895                ..Default::default()
896            },
897            BlockChanges {
898                extractor: "uniswap-v2".to_string(),
899                chain: Chain::Ethereum,
900                block: Block {
901                    number: 2,
902                    hash: Bytes::from("0x02"),
903                    parent_hash: Bytes::from("0x01"),
904                    chain: Chain::Ethereum,
905                    ts: Default::default(),
906                },
907                revert: false,
908                component_tvl: [
909                    ("Component1".to_string(), 100.0),
910                    ("Component2".to_string(), 0.0),
911                    ("Component3".to_string(), 1000.0),
912                ]
913                .into_iter()
914                .collect(),
915                ..Default::default()
916            },
917        ];
918        let mut state_sync = with_mocked_clients(true, Some(rpc_client), Some(deltas_client));
919        state_sync
920            .initialize()
921            .await
922            .expect("Init failed");
923
924        // Test starts here
925        let (jh, mut rx) = state_sync
926            .start()
927            .await
928            .expect("Failed to start state synchronizer");
929        tx.send(deltas[0].clone())
930            .await
931            .expect("deltas channel msg 0 closed!");
932        let first_msg = timeout(Duration::from_millis(100), rx.recv())
933            .await
934            .expect("waiting for first state msg timed out!")
935            .expect("state sync block sender closed!");
936        tx.send(deltas[1].clone())
937            .await
938            .expect("deltas channel msg 1 closed!");
939        let second_msg = timeout(Duration::from_millis(100), rx.recv())
940            .await
941            .expect("waiting for second state msg timed out!")
942            .expect("state sync block sender closed!");
943        let _ = state_sync.close().await;
944        let exit = jh
945            .await
946            .expect("state sync task panicked!");
947
948        // assertions
949        let exp1 = StateSyncMessage {
950            header: Header {
951                number: 1,
952                hash: Bytes::from("0x01"),
953                parent_hash: Bytes::from("0x00"),
954                revert: false,
955            },
956            snapshots: Snapshot {
957                states: [
958                    (
959                        "Component1".to_string(),
960                        ComponentWithState {
961                            state: ResponseProtocolState {
962                                component_id: "Component1".to_string(),
963                                ..Default::default()
964                            },
965                            component: ProtocolComponent {
966                                id: "Component1".to_string(),
967                                ..Default::default()
968                            },
969                        },
970                    ),
971                    (
972                        "Component2".to_string(),
973                        ComponentWithState {
974                            state: ResponseProtocolState {
975                                component_id: "Component2".to_string(),
976                                ..Default::default()
977                            },
978                            component: ProtocolComponent {
979                                id: "Component2".to_string(),
980                                ..Default::default()
981                            },
982                        },
983                    ),
984                ]
985                .into_iter()
986                .collect(),
987                vm_storage: HashMap::new(),
988            },
989            deltas: Some(deltas[0].clone()),
990            removed_components: Default::default(),
991        };
992
993        let exp2 = StateSyncMessage {
994            header: Header {
995                number: 2,
996                hash: Bytes::from("0x02"),
997                parent_hash: Bytes::from("0x01"),
998                revert: false,
999            },
1000            snapshots: Snapshot {
1001                states: [
1002                    // This is the new component we queried once it passed the tvl threshold.
1003                    (
1004                        "Component3".to_string(),
1005                        ComponentWithState {
1006                            state: ResponseProtocolState {
1007                                component_id: "Component3".to_string(),
1008                                ..Default::default()
1009                            },
1010                            component: ProtocolComponent {
1011                                id: "Component3".to_string(),
1012                                ..Default::default()
1013                            },
1014                        },
1015                    ),
1016                ]
1017                .into_iter()
1018                .collect(),
1019                vm_storage: HashMap::new(),
1020            },
1021            // Our deltas are empty and since merge methods are
1022            // tested in tycho-common we don't have much to do here.
1023            deltas: Some(BlockChanges {
1024                extractor: "uniswap-v2".to_string(),
1025                chain: Chain::Ethereum,
1026                block: Block {
1027                    number: 2,
1028                    hash: Bytes::from("0x02"),
1029                    parent_hash: Bytes::from("0x01"),
1030                    chain: Chain::Ethereum,
1031                    ts: Default::default(),
1032                },
1033                revert: false,
1034                component_tvl: [
1035                    // "Component2" should not show here.
1036                    ("Component1".to_string(), 100.0),
1037                    ("Component3".to_string(), 1000.0),
1038                ]
1039                .into_iter()
1040                .collect(),
1041                ..Default::default()
1042            }),
1043            // "Component2" was removed, because it's tvl changed to 0.
1044            removed_components: [(
1045                "Component2".to_string(),
1046                ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1047            )]
1048            .into_iter()
1049            .collect(),
1050        };
1051        assert_eq!(first_msg, exp1);
1052        assert_eq!(second_msg, exp2);
1053        assert!(exit.is_ok());
1054    }
1055
1056    #[test(tokio::test)]
1057    async fn test_state_sync_with_tvl_range() {
1058        // Define the range for testing
1059        let remove_tvl_threshold = 5.0;
1060        let add_tvl_threshold = 7.0;
1061
1062        let mut rpc_client = MockRPCClient::new();
1063        let mut deltas_client = MockDeltasClient::new();
1064
1065        rpc_client
1066            .expect_get_protocol_components()
1067            .with(mockall::predicate::function(
1068                move |request_params: &ProtocolComponentsRequestBody| {
1069                    if let Some(ids) = request_params.component_ids.as_ref() {
1070                        ids.contains(&"Component3".to_string())
1071                    } else {
1072                        false
1073                    }
1074                },
1075            ))
1076            .returning(|_| {
1077                Ok(ProtocolComponentRequestResponse {
1078                    protocol_components: vec![ProtocolComponent {
1079                        id: "Component3".to_string(),
1080                        ..Default::default()
1081                    }],
1082                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1083                })
1084            });
1085
1086        rpc_client
1087            .expect_get_protocol_states()
1088            .with(mockall::predicate::function(move |request_params: &ProtocolStateRequestBody| {
1089                let expected_id = "Component3".to_string();
1090                if let Some(ids) = request_params.protocol_ids.as_ref() {
1091                    ids.contains(&expected_id)
1092                } else {
1093                    false
1094                }
1095            }))
1096            .returning(|_| {
1097                Ok(ProtocolStateRequestResponse {
1098                    states: vec![ResponseProtocolState {
1099                        component_id: "Component3".to_string(),
1100                        ..Default::default()
1101                    }],
1102                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1103                })
1104            });
1105
1106        // Mock for the initial snapshot retrieval
1107        rpc_client
1108            .expect_get_protocol_components()
1109            .returning(|_| {
1110                Ok(ProtocolComponentRequestResponse {
1111                    protocol_components: vec![
1112                        ProtocolComponent { id: "Component1".to_string(), ..Default::default() },
1113                        ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1114                    ],
1115                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1116                })
1117            });
1118
1119        rpc_client
1120            .expect_get_protocol_states()
1121            .returning(|_| {
1122                Ok(ProtocolStateRequestResponse {
1123                    states: vec![
1124                        ResponseProtocolState {
1125                            component_id: "Component1".to_string(),
1126                            ..Default::default()
1127                        },
1128                        ResponseProtocolState {
1129                            component_id: "Component2".to_string(),
1130                            ..Default::default()
1131                        },
1132                    ],
1133                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1134                })
1135            });
1136
1137        let (tx, rx) = channel(1);
1138        deltas_client
1139            .expect_subscribe()
1140            .return_once(move |_, _| Ok((Uuid::default(), rx)));
1141
1142        let mut state_sync = ProtocolStateSynchronizer::new(
1143            ExtractorIdentity::new(Chain::Ethereum, "uniswap-v2"),
1144            true,
1145            ComponentFilter::with_tvl_range(remove_tvl_threshold, add_tvl_threshold),
1146            1,
1147            true,
1148            ArcRPCClient(Arc::new(rpc_client)),
1149            ArcDeltasClient(Arc::new(deltas_client)),
1150            10_u64,
1151        );
1152        state_sync
1153            .initialize()
1154            .await
1155            .expect("Init failed");
1156
1157        // Simulate the incoming BlockChanges
1158        let deltas = [
1159            BlockChanges {
1160                extractor: "uniswap-v2".to_string(),
1161                chain: Chain::Ethereum,
1162                block: Block {
1163                    number: 1,
1164                    hash: Bytes::from("0x01"),
1165                    parent_hash: Bytes::from("0x00"),
1166                    chain: Chain::Ethereum,
1167                    ts: Default::default(),
1168                },
1169                revert: false,
1170                ..Default::default()
1171            },
1172            BlockChanges {
1173                extractor: "uniswap-v2".to_string(),
1174                chain: Chain::Ethereum,
1175                block: Block {
1176                    number: 2,
1177                    hash: Bytes::from("0x02"),
1178                    parent_hash: Bytes::from("0x01"),
1179                    chain: Chain::Ethereum,
1180                    ts: Default::default(),
1181                },
1182                revert: false,
1183                component_tvl: [
1184                    ("Component1".to_string(), 6.0), // Within range, should not trigger changes
1185                    ("Component2".to_string(), 2.0), // Below lower threshold, should be removed
1186                    ("Component3".to_string(), 10.0), // Above upper threshold, should be added
1187                ]
1188                .into_iter()
1189                .collect(),
1190                ..Default::default()
1191            },
1192        ];
1193
1194        let (jh, mut rx) = state_sync
1195            .start()
1196            .await
1197            .expect("Failed to start state synchronizer");
1198
1199        // Simulate sending delta messages
1200        tx.send(deltas[0].clone())
1201            .await
1202            .expect("deltas channel msg 0 closed!");
1203
1204        // Expecting to receive the initial state message
1205        let _ = timeout(Duration::from_millis(100), rx.recv())
1206            .await
1207            .expect("waiting for first state msg timed out!")
1208            .expect("state sync block sender closed!");
1209
1210        // Send the third message, which should trigger TVL-based changes
1211        tx.send(deltas[1].clone())
1212            .await
1213            .expect("deltas channel msg 1 closed!");
1214        let second_msg = timeout(Duration::from_millis(100), rx.recv())
1215            .await
1216            .expect("waiting for second state msg timed out!")
1217            .expect("state sync block sender closed!");
1218
1219        let _ = state_sync.close().await;
1220        let exit = jh
1221            .await
1222            .expect("state sync task panicked!");
1223
1224        let expected_second_msg = StateSyncMessage {
1225            header: Header {
1226                number: 2,
1227                hash: Bytes::from("0x02"),
1228                parent_hash: Bytes::from("0x01"),
1229                revert: false,
1230            },
1231            snapshots: Snapshot {
1232                states: [(
1233                    "Component3".to_string(),
1234                    ComponentWithState {
1235                        state: ResponseProtocolState {
1236                            component_id: "Component3".to_string(),
1237                            ..Default::default()
1238                        },
1239                        component: ProtocolComponent {
1240                            id: "Component3".to_string(),
1241                            ..Default::default()
1242                        },
1243                    },
1244                )]
1245                .into_iter()
1246                .collect(),
1247                vm_storage: HashMap::new(),
1248            },
1249            deltas: Some(BlockChanges {
1250                extractor: "uniswap-v2".to_string(),
1251                chain: Chain::Ethereum,
1252                block: Block {
1253                    number: 2,
1254                    hash: Bytes::from("0x02"),
1255                    parent_hash: Bytes::from("0x01"),
1256                    chain: Chain::Ethereum,
1257                    ts: Default::default(),
1258                },
1259                revert: false,
1260                component_tvl: [
1261                    ("Component1".to_string(), 6.0), // Within range, should not trigger changes
1262                    ("Component3".to_string(), 10.0), // Above upper threshold, should be added
1263                ]
1264                .into_iter()
1265                .collect(),
1266                ..Default::default()
1267            }),
1268            removed_components: [(
1269                "Component2".to_string(),
1270                ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1271            )]
1272            .into_iter()
1273            .collect(),
1274        };
1275
1276        assert_eq!(second_msg, expected_second_msg);
1277        assert!(exit.is_ok());
1278    }
1279}