// tycho_client/feed/synchronizer.rs

use std::{
    collections::{HashMap, HashSet},
    sync::Arc,
    time::Duration,
};

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::{
    select,
    sync::{
        mpsc::{channel, error::SendError, Receiver, Sender},
        oneshot, Mutex,
    },
    task::JoinHandle,
    time::timeout,
};
use tracing::{debug, error, info, instrument, trace, warn};
use tycho_common::{
    dto::{
        BlockChanges, BlockParam, ComponentTvlRequestBody, EntryPointWithTracingParams,
        ExtractorIdentity, ProtocolComponent, ResponseAccount, ResponseProtocolState,
        TracingResult, VersionParam,
    },
    Bytes,
};

use crate::{
    deltas::{DeltasClient, SubscriptionOptions},
    feed::{
        component_tracker::{ComponentFilter, ComponentTracker},
        Header,
    },
    rpc::{RPCClient, RPCError},
    DeltasError,
};
34
35#[derive(Error, Debug)]
36pub enum SynchronizerError {
37    /// RPC client failures.
38    #[error("RPC error: {0}")]
39    RPCError(#[from] RPCError),
40
41    /// Failed to send channel message to the consumer.
42    #[error("Failed to send channel message: {0}")]
43    ChannelError(String),
44
45    /// Timeout elapsed errors.
46    #[error("Timeout error: {0}")]
47    Timeout(String),
48
49    /// Failed to close the synchronizer.
50    #[error("Failed to close synchronizer: {0}")]
51    CloseError(String),
52
53    /// Server connection failures or interruptions.
54    #[error("Connection error: {0}")]
55    ConnectionError(String),
56
57    /// Connection closed
58    #[error("Connection closed")]
59    ConnectionClosed,
60}
61
62pub type SyncResult<T> = Result<T, SynchronizerError>;
63
64impl From<SendError<StateSyncMessage>> for SynchronizerError {
65    fn from(err: SendError<StateSyncMessage>) -> Self {
66        SynchronizerError::ChannelError(err.to_string())
67    }
68}
69
70impl From<DeltasError> for SynchronizerError {
71    fn from(err: DeltasError) -> Self {
72        match err {
73            DeltasError::NotConnected => SynchronizerError::ConnectionClosed,
74            _ => SynchronizerError::ConnectionError(err.to_string()),
75        }
76    }
77}
78
79#[derive(Clone)]
80pub struct ProtocolStateSynchronizer<R: RPCClient, D: DeltasClient> {
81    extractor_id: ExtractorIdentity,
82    retrieve_balances: bool,
83    rpc_client: R,
84    deltas_client: D,
85    max_retries: u64,
86    include_snapshots: bool,
87    component_tracker: Arc<Mutex<ComponentTracker<R>>>,
88    shared: Arc<Mutex<SharedState>>,
89    end_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
90    timeout: u64,
91    include_tvl: bool,
92}
93
94#[derive(Debug, Default)]
95struct SharedState {
96    last_synced_block: Option<Header>,
97}
98
99#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
100pub struct ComponentWithState {
101    pub state: ResponseProtocolState,
102    pub component: ProtocolComponent,
103    pub component_tvl: Option<f64>,
104    pub entrypoints: Vec<(EntryPointWithTracingParams, TracingResult)>,
105}
106
107#[derive(Clone, PartialEq, Debug, Default, Serialize, Deserialize)]
108pub struct Snapshot {
109    pub states: HashMap<String, ComponentWithState>,
110    pub vm_storage: HashMap<Bytes, ResponseAccount>,
111}
112
113impl Snapshot {
114    fn extend(&mut self, other: Snapshot) {
115        self.states.extend(other.states);
116        self.vm_storage.extend(other.vm_storage);
117    }
118
119    pub fn get_states(&self) -> &HashMap<String, ComponentWithState> {
120        &self.states
121    }
122
123    pub fn get_vm_storage(&self) -> &HashMap<Bytes, ResponseAccount> {
124        &self.vm_storage
125    }
126}
127
128#[derive(Clone, PartialEq, Debug, Default, Serialize, Deserialize)]
129pub struct StateSyncMessage {
130    /// The block information for this update.
131    pub header: Header,
132    /// Snapshot for new components.
133    pub snapshots: Snapshot,
134    /// A single delta contains state updates for all tracked components, as well as additional
135    /// information about the system components e.g. newly added components (even below tvl), tvl
136    /// updates, balance updates.
137    pub deltas: Option<BlockChanges>,
138    /// Components that stopped being tracked.
139    pub removed_components: HashMap<String, ProtocolComponent>,
140}
141
142impl StateSyncMessage {
143    pub fn merge(mut self, other: Self) -> Self {
144        // be careful with removed and snapshots attributes here, these can be ambiguous.
145        self.removed_components
146            .retain(|k, _| !other.snapshots.states.contains_key(k));
147        self.snapshots
148            .states
149            .retain(|k, _| !other.removed_components.contains_key(k));
150
151        self.snapshots.extend(other.snapshots);
152        let deltas = match (self.deltas, other.deltas) {
153            (Some(l), Some(r)) => Some(l.merge(r)),
154            (None, Some(r)) => Some(r),
155            (Some(l), None) => Some(l),
156            (None, None) => None,
157        };
158        self.removed_components
159            .extend(other.removed_components);
160        Self {
161            header: other.header,
162            snapshots: self.snapshots,
163            deltas,
164            removed_components: self.removed_components,
165        }
166    }
167}
168
169/// StateSynchronizer
170///
171/// Used to synchronize the state of a single protocol. The synchronizer is responsible for
172/// delivering messages to the client that let him reconstruct subsets of the protocol state.
173///
174/// This involves deciding which components to track according to the clients preferences,
175/// retrieving & emitting snapshots of components which the client has not seen yet and subsequently
176/// delivering delta messages for the components that have changed.
177#[async_trait]
178pub trait StateSynchronizer: Send + Sync + 'static {
179    async fn initialize(&self) -> SyncResult<()>;
180    /// Starts the state synchronization.
181    async fn start(&self) -> SyncResult<(JoinHandle<SyncResult<()>>, Receiver<StateSyncMessage>)>;
182    /// Ends the sychronization loop.
183    async fn close(&mut self) -> SyncResult<()>;
184}
185
186impl<R, D> ProtocolStateSynchronizer<R, D>
187where
188    // TODO: Consider moving these constraints directly to the
189    // client...
190    R: RPCClient + Clone + Send + Sync + 'static,
191    D: DeltasClient + Clone + Send + Sync + 'static,
192{
193    /// Creates a new state synchronizer.
194    #[allow(clippy::too_many_arguments)]
195    pub fn new(
196        extractor_id: ExtractorIdentity,
197        retrieve_balances: bool,
198        component_filter: ComponentFilter,
199        max_retries: u64,
200        include_snapshots: bool,
201        include_tvl: bool,
202        rpc_client: R,
203        deltas_client: D,
204        timeout: u64,
205    ) -> Self {
206        Self {
207            extractor_id: extractor_id.clone(),
208            retrieve_balances,
209            rpc_client: rpc_client.clone(),
210            include_snapshots,
211            deltas_client,
212            component_tracker: Arc::new(Mutex::new(ComponentTracker::new(
213                extractor_id.chain,
214                extractor_id.name.as_str(),
215                component_filter,
216                rpc_client,
217            ))),
218            max_retries,
219            shared: Arc::new(Mutex::new(SharedState::default())),
220            end_tx: Arc::new(Mutex::new(None)),
221            timeout,
222            include_tvl,
223        }
224    }
225
226    /// Retrieves state snapshots of the requested components
227    #[allow(deprecated)]
228    async fn get_snapshots<'a, I: IntoIterator<Item = &'a String>>(
229        &self,
230        header: Header,
231        tracked_components: &mut ComponentTracker<R>,
232        ids: Option<I>,
233    ) -> SyncResult<StateSyncMessage> {
234        if !self.include_snapshots {
235            return Ok(StateSyncMessage { header, ..Default::default() });
236        }
237        let version = VersionParam::new(
238            None,
239            Some(BlockParam {
240                chain: Some(self.extractor_id.chain),
241                hash: None,
242                number: Some(header.number as i64),
243            }),
244        );
245
246        // Use given ids or use all if not passed
247        let component_ids: Vec<_> = match ids {
248            Some(ids) => ids.into_iter().cloned().collect(),
249            None => tracked_components.get_tracked_component_ids(),
250        };
251
252        if component_ids.is_empty() {
253            return Ok(StateSyncMessage { header, ..Default::default() });
254        }
255
256        let component_tvl = if self.include_tvl {
257            let body = ComponentTvlRequestBody::id_filtered(
258                component_ids.clone(),
259                self.extractor_id.chain,
260            );
261            self.rpc_client
262                .get_component_tvl_paginated(&body, 100, 4)
263                .await?
264                .tvl
265        } else {
266            HashMap::new()
267        };
268
269        // Fetch entrypoints
270        let entrypoints_result = self
271            .rpc_client
272            .get_traced_entry_points_paginated(
273                self.extractor_id.chain,
274                &self.extractor_id.name,
275                &component_ids,
276                100,
277                4,
278            )
279            .await?;
280        tracked_components.process_entrypoints(&entrypoints_result.clone().into())?;
281
282        // Fetch protocol states
283        let mut protocol_states = self
284            .rpc_client
285            .get_protocol_states_paginated(
286                self.extractor_id.chain,
287                &component_ids,
288                &self.extractor_id.name,
289                self.retrieve_balances,
290                &version,
291                100,
292                4,
293            )
294            .await?
295            .states
296            .into_iter()
297            .map(|state| (state.component_id.clone(), state))
298            .collect::<HashMap<_, _>>();
299
300        trace!(states=?&protocol_states, "Retrieved ProtocolStates");
301        let states = tracked_components
302            .components
303            .values()
304            .filter_map(|component| {
305                if let Some(state) = protocol_states.remove(&component.id) {
306                    Some((
307                        component.id.clone(),
308                        ComponentWithState {
309                            state,
310                            component: component.clone(),
311                            component_tvl: component_tvl
312                                .get(&component.id)
313                                .cloned(),
314                            entrypoints: entrypoints_result
315                                .traced_entry_points
316                                .get(&component.id)
317                                .cloned()
318                                .unwrap_or_default(),
319                        },
320                    ))
321                } else if component_ids.contains(&component.id) {
322                    // only emit error event if we requested this component
323                    let component_id = &component.id;
324                    error!(?component_id, "Missing state for native component!");
325                    None
326                } else {
327                    None
328                }
329            })
330            .collect();
331
332        // Fetch contract states
333        let contract_ids = tracked_components.get_contracts_by_component(&component_ids);
334        let vm_storage = if !contract_ids.is_empty() {
335            let ids: Vec<Bytes> = contract_ids
336                .clone()
337                .into_iter()
338                .collect();
339            let contract_states = self
340                .rpc_client
341                .get_contract_state_paginated(
342                    self.extractor_id.chain,
343                    ids.as_slice(),
344                    &self.extractor_id.name,
345                    &version,
346                    100,
347                    4,
348                )
349                .await?
350                .accounts
351                .into_iter()
352                .map(|acc| (acc.address.clone(), acc))
353                .collect::<HashMap<_, _>>();
354
355            trace!(states=?&contract_states, "Retrieved ContractState");
356
357            let contract_address_to_components = tracked_components
358                .components
359                .iter()
360                .filter_map(|(id, comp)| {
361                    if component_ids.contains(id) {
362                        Some(
363                            comp.contract_ids
364                                .iter()
365                                .map(|address| (address.clone(), comp.id.clone())),
366                        )
367                    } else {
368                        None
369                    }
370                })
371                .flatten()
372                .fold(HashMap::<Bytes, Vec<String>>::new(), |mut acc, (addr, c_id)| {
373                    acc.entry(addr).or_default().push(c_id);
374                    acc
375                });
376
377            contract_ids
378                .iter()
379                .filter_map(|address| {
380                    if let Some(state) = contract_states.get(address) {
381                        Some((address.clone(), state.clone()))
382                    } else if let Some(ids) = contract_address_to_components.get(address) {
383                        // only emit error even if we did actually request this address
384                        error!(
385                            ?address,
386                            ?ids,
387                            "Component with lacking contract storage encountered!"
388                        );
389                        None
390                    } else {
391                        None
392                    }
393                })
394                .collect()
395        } else {
396            HashMap::new()
397        };
398
399        Ok(StateSyncMessage {
400            header,
401            snapshots: Snapshot { states, vm_storage },
402            deltas: None,
403            removed_components: HashMap::new(),
404        })
405    }
406
407    /// Main method that does all the work.
408    #[instrument(skip(self, block_tx), fields(extractor_id = %self.extractor_id))]
409    async fn state_sync(self, block_tx: &mut Sender<StateSyncMessage>) -> SyncResult<()> {
410        // initialisation
411        let mut tracker = self.component_tracker.lock().await;
412
413        let subscription_options = SubscriptionOptions::new().with_state(self.include_snapshots);
414        let (_, mut msg_rx) = self
415            .deltas_client
416            .subscribe(self.extractor_id.clone(), subscription_options)
417            .await?;
418
419        info!("Waiting for deltas...");
420        // wait for first deltas message
421        let mut first_msg = timeout(Duration::from_secs(self.timeout), msg_rx.recv())
422            .await
423            .map_err(|_| {
424                SynchronizerError::Timeout(format!(
425                    "First deltas took longer than {t}s to arrive",
426                    t = self.timeout
427                ))
428            })?
429            .ok_or_else(|| {
430                SynchronizerError::ConnectionError(
431                    "Deltas channel closed before first message".to_string(),
432                )
433            })?;
434        self.filter_deltas(&mut first_msg, &tracker);
435
436        // initial snapshot
437        let block = first_msg.get_block().clone();
438        info!(height = &block.number, "Deltas received. Retrieving snapshot");
439        let header = Header::from_block(first_msg.get_block(), first_msg.is_revert());
440        let snapshot = self
441            .get_snapshots::<Vec<&String>>(Header::from_block(&block, false), &mut tracker, None)
442            .await?
443            .merge(StateSyncMessage {
444                header: Header::from_block(first_msg.get_block(), first_msg.is_revert()),
445                snapshots: Default::default(),
446                deltas: Some(first_msg),
447                removed_components: Default::default(),
448            });
449
450        let n_components = tracker.components.len();
451        let n_snapshots = snapshot.snapshots.states.len();
452        info!(n_components, n_snapshots, "Initial snapshot retrieved, starting delta message feed");
453
454        {
455            let mut shared = self.shared.lock().await;
456            block_tx.send(snapshot).await?;
457            shared.last_synced_block = Some(header.clone());
458        }
459
460        loop {
461            if let Some(mut deltas) = msg_rx.recv().await {
462                let header = Header::from_block(deltas.get_block(), deltas.is_revert());
463                debug!(block_number=?header.number, "Received delta message");
464
465                let (snapshots, removed_components) = {
466                    // 1. Remove components based on latest changes
467                    // 2. Add components based on latest changes, query those for snapshots
468                    let (to_add, to_remove) = tracker.filter_updated_components(&deltas);
469
470                    // Only components we don't track yet need a snapshot,
471                    let requiring_snapshot: Vec<_> = to_add
472                        .iter()
473                        .filter(|id| {
474                            !tracker
475                                .components
476                                .contains_key(id.as_str())
477                        })
478                        .collect();
479                    debug!(components=?requiring_snapshot, "SnapshotRequest");
480                    tracker
481                        .start_tracking(requiring_snapshot.as_slice())
482                        .await?;
483                    let snapshots = self
484                        .get_snapshots(header.clone(), &mut tracker, Some(requiring_snapshot))
485                        .await?
486                        .snapshots;
487
488                    let removed_components = if !to_remove.is_empty() {
489                        tracker.stop_tracking(&to_remove)
490                    } else {
491                        Default::default()
492                    };
493
494                    (snapshots, removed_components)
495                };
496
497                // 3. Update entrypoints on the tracker (affects which contracts are tracked)
498                tracker.process_entrypoints(&deltas.dci_update)?;
499
500                // 4. Filter deltas by currently tracked components / contracts
501                self.filter_deltas(&mut deltas, &tracker);
502                let n_changes = deltas.n_changes();
503
504                // 5. Send the message
505                let next = StateSyncMessage {
506                    header: header.clone(),
507                    snapshots,
508                    deltas: Some(deltas),
509                    removed_components,
510                };
511                block_tx.send(next).await?;
512                {
513                    let mut shared = self.shared.lock().await;
514                    shared.last_synced_block = Some(header.clone());
515                }
516
517                debug!(block_number=?header.number, n_changes, "Finished processing delta message");
518            } else {
519                let mut shared = self.shared.lock().await;
520                warn!(shared = ?&shared, "Deltas channel closed, resetting shared state.");
521                shared.last_synced_block = None;
522
523                return Err(SynchronizerError::ConnectionError("Deltas channel closed".to_string()));
524            }
525        }
526    }
527
528    fn filter_deltas(&self, second_msg: &mut BlockChanges, tracker: &ComponentTracker<R>) {
529        second_msg.filter_by_component(|id| tracker.components.contains_key(id));
530        second_msg.filter_by_contract(|id| tracker.contracts.contains(id));
531    }
532}
533
534#[async_trait]
535impl<R, D> StateSynchronizer for ProtocolStateSynchronizer<R, D>
536where
537    R: RPCClient + Clone + Send + Sync + 'static,
538    D: DeltasClient + Clone + Send + Sync + 'static,
539{
540    async fn initialize(&self) -> SyncResult<()> {
541        let mut tracker = self.component_tracker.lock().await;
542        info!("Retrieving relevant protocol components");
543        tracker.initialise_components().await?;
544        info!(
545            n_components = tracker.components.len(),
546            n_contracts = tracker.contracts.len(),
547            "Finished retrieving components",
548        );
549
550        Ok(())
551    }
552
553    async fn start(&self) -> SyncResult<(JoinHandle<SyncResult<()>>, Receiver<StateSyncMessage>)> {
554        let (mut tx, rx) = channel(15);
555
556        let this = self.clone();
557        let jh = tokio::spawn(async move {
558            let mut retry_count = 0;
559            while retry_count < this.max_retries {
560                info!(extractor_id=%&this.extractor_id, retry_count, "(Re)starting synchronization loop");
561                let (end_tx, end_rx) = oneshot::channel::<()>();
562                {
563                    let mut end_tx_guard = this.end_tx.lock().await;
564                    *end_tx_guard = Some(end_tx);
565                }
566
567                select! {
568                    res = this.clone().state_sync(&mut tx) => {
569                        match res {
570                            Err(e) => {
571                                error!(
572                                    extractor_id=%&this.extractor_id,
573                                    retry_count,
574                                    error=%e,
575                                    "State synchronization errored!"
576                                );
577                                if let SynchronizerError::ConnectionClosed = e {
578                                    // break synchronization loop if connection is closed
579                                    return Err(e);
580                                }
581                            }
582                            _ => {
583                                warn!(
584                                    extractor_id=%&this.extractor_id,
585                                    retry_count,
586                                    "State synchronization exited with Ok(())"
587                                );
588                            }
589                        }
590                    },
591                    _ = end_rx => {
592                        info!(
593                            extractor_id=%&this.extractor_id,
594                            retry_count,
595                            "StateSynchronizer received close signal. Stopping"
596                        );
597                        return Ok(())
598                    }
599                }
600                retry_count += 1;
601            }
602            Err(SynchronizerError::ConnectionError("Max connection retries exceeded".to_string()))
603        });
604
605        Ok((jh, rx))
606    }
607
608    async fn close(&mut self) -> SyncResult<()> {
609        let mut end_tx = self.end_tx.lock().await;
610        if let Some(tx) = end_tx.take() {
611            let _ = tx.send(());
612            Ok(())
613        } else {
614            Err(SynchronizerError::CloseError("Synchronizer not started".to_string()))
615        }
616    }
617}
618
619#[cfg(test)]
620mod test {
621    use std::collections::HashSet;
622
623    use test_log::test;
624    use tycho_common::dto::{
625        Block, Chain, ComponentTvlRequestBody, ComponentTvlRequestResponse, DCIUpdate, EntryPoint,
626        PaginationResponse, ProtocolComponentRequestResponse, ProtocolComponentsRequestBody,
627        ProtocolStateRequestBody, ProtocolStateRequestResponse, ProtocolSystemsRequestBody,
628        ProtocolSystemsRequestResponse, RPCTracerParams, StateRequestBody, StateRequestResponse,
629        TokensRequestBody, TokensRequestResponse, TracedEntryPointRequestBody,
630        TracedEntryPointRequestResponse, TracingParams,
631    };
632    use uuid::Uuid;
633
634    use super::*;
635    use crate::{deltas::MockDeltasClient, rpc::MockRPCClient, DeltasError, RPCError};
636
637    // Required for mock client to implement clone
638    struct ArcRPCClient<T>(Arc<T>);
639
640    // Default derive(Clone) does require T to be Clone as well.
641    impl<T> Clone for ArcRPCClient<T> {
642        fn clone(&self) -> Self {
643            ArcRPCClient(self.0.clone())
644        }
645    }
646
647    #[async_trait]
648    impl<T> RPCClient for ArcRPCClient<T>
649    where
650        T: RPCClient + Sync + Send + 'static,
651    {
652        async fn get_tokens(
653            &self,
654            request: &TokensRequestBody,
655        ) -> Result<TokensRequestResponse, RPCError> {
656            self.0.get_tokens(request).await
657        }
658
659        async fn get_contract_state(
660            &self,
661            request: &StateRequestBody,
662        ) -> Result<StateRequestResponse, RPCError> {
663            self.0.get_contract_state(request).await
664        }
665
666        async fn get_protocol_components(
667            &self,
668            request: &ProtocolComponentsRequestBody,
669        ) -> Result<ProtocolComponentRequestResponse, RPCError> {
670            self.0
671                .get_protocol_components(request)
672                .await
673        }
674
675        async fn get_protocol_states(
676            &self,
677            request: &ProtocolStateRequestBody,
678        ) -> Result<ProtocolStateRequestResponse, RPCError> {
679            self.0
680                .get_protocol_states(request)
681                .await
682        }
683
684        async fn get_protocol_systems(
685            &self,
686            request: &ProtocolSystemsRequestBody,
687        ) -> Result<ProtocolSystemsRequestResponse, RPCError> {
688            self.0
689                .get_protocol_systems(request)
690                .await
691        }
692
693        async fn get_component_tvl(
694            &self,
695            request: &ComponentTvlRequestBody,
696        ) -> Result<ComponentTvlRequestResponse, RPCError> {
697            self.0.get_component_tvl(request).await
698        }
699
700        async fn get_traced_entry_points(
701            &self,
702            request: &TracedEntryPointRequestBody,
703        ) -> Result<TracedEntryPointRequestResponse, RPCError> {
704            self.0
705                .get_traced_entry_points(request)
706                .await
707        }
708    }
709
710    // Required for mock client to implement clone
711    struct ArcDeltasClient<T>(Arc<T>);
712
713    // Default derive(Clone) does require T to be Clone as well.
714    impl<T> Clone for ArcDeltasClient<T> {
715        fn clone(&self) -> Self {
716            ArcDeltasClient(self.0.clone())
717        }
718    }
719
720    #[async_trait]
721    impl<T> DeltasClient for ArcDeltasClient<T>
722    where
723        T: DeltasClient + Sync + Send + 'static,
724    {
725        async fn subscribe(
726            &self,
727            extractor_id: ExtractorIdentity,
728            options: SubscriptionOptions,
729        ) -> Result<(Uuid, Receiver<BlockChanges>), DeltasError> {
730            self.0
731                .subscribe(extractor_id, options)
732                .await
733        }
734
735        async fn unsubscribe(&self, subscription_id: Uuid) -> Result<(), DeltasError> {
736            self.0
737                .unsubscribe(subscription_id)
738                .await
739        }
740
741        async fn connect(&self) -> Result<JoinHandle<Result<(), DeltasError>>, DeltasError> {
742            self.0.connect().await
743        }
744
745        async fn close(&self) -> Result<(), DeltasError> {
746            self.0.close().await
747        }
748    }
749
750    fn with_mocked_clients(
751        native: bool,
752        include_tvl: bool,
753        rpc_client: Option<MockRPCClient>,
754        deltas_client: Option<MockDeltasClient>,
755    ) -> ProtocolStateSynchronizer<ArcRPCClient<MockRPCClient>, ArcDeltasClient<MockDeltasClient>>
756    {
757        let rpc_client = ArcRPCClient(Arc::new(rpc_client.unwrap_or_default()));
758        let deltas_client = ArcDeltasClient(Arc::new(deltas_client.unwrap_or_default()));
759
760        ProtocolStateSynchronizer::new(
761            ExtractorIdentity::new(Chain::Ethereum, "uniswap-v2"),
762            native,
763            ComponentFilter::with_tvl_range(50.0, 50.0),
764            1,
765            true,
766            include_tvl,
767            rpc_client,
768            deltas_client,
769            10_u64,
770        )
771    }
772
773    fn state_snapshot_native() -> ProtocolStateRequestResponse {
774        ProtocolStateRequestResponse {
775            states: vec![ResponseProtocolState {
776                component_id: "Component1".to_string(),
777                ..Default::default()
778            }],
779            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
780        }
781    }
782
783    fn component_tvl_snapshot() -> ComponentTvlRequestResponse {
784        let tvl = HashMap::from([("Component1".to_string(), 100.0)]);
785
786        ComponentTvlRequestResponse {
787            tvl,
788            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
789        }
790    }
791
792    #[test(tokio::test)]
793    async fn test_get_snapshots_native() {
794        let header = Header::default();
795        let mut rpc = MockRPCClient::new();
796        rpc.expect_get_protocol_states()
797            .returning(|_| Ok(state_snapshot_native()));
798        rpc.expect_get_traced_entry_points()
799            .returning(|_| {
800                Ok(TracedEntryPointRequestResponse {
801                    traced_entry_points: HashMap::new(),
802                    pagination: PaginationResponse::new(0, 20, 0),
803                })
804            });
805        let state_sync = with_mocked_clients(true, false, Some(rpc), None);
806        let mut tracker = ComponentTracker::new(
807            Chain::Ethereum,
808            "uniswap-v2",
809            ComponentFilter::with_tvl_range(0.0, 0.0),
810            state_sync.rpc_client.clone(),
811        );
812        let component = ProtocolComponent { id: "Component1".to_string(), ..Default::default() };
813        tracker
814            .components
815            .insert("Component1".to_string(), component.clone());
816        let components_arg = ["Component1".to_string()];
817        let exp = StateSyncMessage {
818            header: header.clone(),
819            snapshots: Snapshot {
820                states: state_snapshot_native()
821                    .states
822                    .into_iter()
823                    .map(|state| {
824                        (
825                            state.component_id.clone(),
826                            ComponentWithState {
827                                state,
828                                component: component.clone(),
829                                entrypoints: vec![],
830                                component_tvl: None,
831                            },
832                        )
833                    })
834                    .collect(),
835                vm_storage: HashMap::new(),
836            },
837            deltas: None,
838            removed_components: Default::default(),
839        };
840
841        let snap = state_sync
842            .get_snapshots(header, &mut tracker, Some(&components_arg))
843            .await
844            .expect("Retrieving snapshot failed");
845
846        assert_eq!(snap, exp);
847    }
848
849    #[test(tokio::test)]
850    async fn test_get_snapshots_native_with_tvl() {
851        let header = Header::default();
852        let mut rpc = MockRPCClient::new();
853        rpc.expect_get_protocol_states()
854            .returning(|_| Ok(state_snapshot_native()));
855        rpc.expect_get_component_tvl()
856            .returning(|_| Ok(component_tvl_snapshot()));
857        rpc.expect_get_traced_entry_points()
858            .returning(|_| {
859                Ok(TracedEntryPointRequestResponse {
860                    traced_entry_points: HashMap::new(),
861                    pagination: PaginationResponse::new(0, 20, 0),
862                })
863            });
864        let state_sync = with_mocked_clients(true, true, Some(rpc), None);
865        let mut tracker = ComponentTracker::new(
866            Chain::Ethereum,
867            "uniswap-v2",
868            ComponentFilter::with_tvl_range(0.0, 0.0),
869            state_sync.rpc_client.clone(),
870        );
871        let component = ProtocolComponent { id: "Component1".to_string(), ..Default::default() };
872        tracker
873            .components
874            .insert("Component1".to_string(), component.clone());
875        let components_arg = ["Component1".to_string()];
876        let exp = StateSyncMessage {
877            header: header.clone(),
878            snapshots: Snapshot {
879                states: state_snapshot_native()
880                    .states
881                    .into_iter()
882                    .map(|state| {
883                        (
884                            state.component_id.clone(),
885                            ComponentWithState {
886                                state,
887                                component: component.clone(),
888                                component_tvl: Some(100.0),
889                                entrypoints: vec![],
890                            },
891                        )
892                    })
893                    .collect(),
894                vm_storage: HashMap::new(),
895            },
896            deltas: None,
897            removed_components: Default::default(),
898        };
899
900        let snap = state_sync
901            .get_snapshots(header, &mut tracker, Some(&components_arg))
902            .await
903            .expect("Retrieving snapshot failed");
904
905        assert_eq!(snap, exp);
906    }
907
908    fn state_snapshot_vm() -> StateRequestResponse {
909        StateRequestResponse {
910            accounts: vec![
911                ResponseAccount { address: Bytes::from("0x0badc0ffee"), ..Default::default() },
912                ResponseAccount { address: Bytes::from("0xbabe42"), ..Default::default() },
913            ],
914            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
915        }
916    }
917
918    fn traced_entry_point_response() -> TracedEntryPointRequestResponse {
919        TracedEntryPointRequestResponse {
920            traced_entry_points: HashMap::from([(
921                "Component1".to_string(),
922                vec![(
923                    EntryPointWithTracingParams {
924                        entry_point: EntryPoint {
925                            external_id: "entrypoint_a".to_string(),
926                            target: Bytes::from("0x0badc0ffee"),
927                            signature: "sig()".to_string(),
928                        },
929                        params: TracingParams::RPCTracer(RPCTracerParams {
930                            caller: Some(Bytes::from("0x0badc0ffee")),
931                            calldata: Bytes::from("0x0badc0ffee"),
932                        }),
933                    },
934                    TracingResult {
935                        retriggers: HashSet::from([(
936                            Bytes::from("0x0badc0ffee"),
937                            Bytes::from("0x0badc0ffee"),
938                        )]),
939                        accessed_slots: HashMap::from([(
940                            Bytes::from("0x0badc0ffee"),
941                            HashSet::from([Bytes::from("0xbadbeef0")]),
942                        )]),
943                    },
944                )],
945            )]),
946            pagination: PaginationResponse::new(0, 20, 0),
947        }
948    }
949
950    #[test(tokio::test)]
951    async fn test_get_snapshots_vm() {
952        let header = Header::default();
953        let mut rpc = MockRPCClient::new();
954        rpc.expect_get_protocol_states()
955            .returning(|_| Ok(state_snapshot_native()));
956        rpc.expect_get_contract_state()
957            .returning(|_| Ok(state_snapshot_vm()));
958        rpc.expect_get_traced_entry_points()
959            .returning(|_| Ok(traced_entry_point_response()));
960        let state_sync = with_mocked_clients(false, false, Some(rpc), None);
961        let mut tracker = ComponentTracker::new(
962            Chain::Ethereum,
963            "uniswap-v2",
964            ComponentFilter::with_tvl_range(0.0, 0.0),
965            state_sync.rpc_client.clone(),
966        );
967        let component = ProtocolComponent {
968            id: "Component1".to_string(),
969            contract_ids: vec![Bytes::from("0x0badc0ffee"), Bytes::from("0xbabe42")],
970            ..Default::default()
971        };
972        tracker
973            .components
974            .insert("Component1".to_string(), component.clone());
975        let components_arg = ["Component1".to_string()];
976        let exp = StateSyncMessage {
977            header: header.clone(),
978            snapshots: Snapshot {
979                states: [(
980                    component.id.clone(),
981                    ComponentWithState {
982                        state: ResponseProtocolState {
983                            component_id: "Component1".to_string(),
984                            ..Default::default()
985                        },
986                        component: component.clone(),
987                        component_tvl: None,
988                        entrypoints: vec![(
989                            EntryPointWithTracingParams {
990                                entry_point: EntryPoint {
991                                    external_id: "entrypoint_a".to_string(),
992                                    target: Bytes::from("0x0badc0ffee"),
993                                    signature: "sig()".to_string(),
994                                },
995                                params: TracingParams::RPCTracer(RPCTracerParams {
996                                    caller: Some(Bytes::from("0x0badc0ffee")),
997                                    calldata: Bytes::from("0x0badc0ffee"),
998                                }),
999                            },
1000                            TracingResult {
1001                                retriggers: HashSet::from([(
1002                                    Bytes::from("0x0badc0ffee"),
1003                                    Bytes::from("0x0badc0ffee"),
1004                                )]),
1005                                accessed_slots: HashMap::from([(
1006                                    Bytes::from("0x0badc0ffee"),
1007                                    HashSet::from([Bytes::from("0xbadbeef0")]),
1008                                )]),
1009                            },
1010                        )],
1011                    },
1012                )]
1013                .into_iter()
1014                .collect(),
1015                vm_storage: state_snapshot_vm()
1016                    .accounts
1017                    .into_iter()
1018                    .map(|state| (state.address.clone(), state))
1019                    .collect(),
1020            },
1021            deltas: None,
1022            removed_components: Default::default(),
1023        };
1024
1025        let snap = state_sync
1026            .get_snapshots(header, &mut tracker, Some(&components_arg))
1027            .await
1028            .expect("Retrieving snapshot failed");
1029
1030        assert_eq!(snap, exp);
1031    }
1032
1033    #[test(tokio::test)]
1034    async fn test_get_snapshots_vm_with_tvl() {
1035        let header = Header::default();
1036        let mut rpc = MockRPCClient::new();
1037        rpc.expect_get_protocol_states()
1038            .returning(|_| Ok(state_snapshot_native()));
1039        rpc.expect_get_contract_state()
1040            .returning(|_| Ok(state_snapshot_vm()));
1041        rpc.expect_get_component_tvl()
1042            .returning(|_| Ok(component_tvl_snapshot()));
1043        rpc.expect_get_traced_entry_points()
1044            .returning(|_| {
1045                Ok(TracedEntryPointRequestResponse {
1046                    traced_entry_points: HashMap::new(),
1047                    pagination: PaginationResponse::new(0, 20, 0),
1048                })
1049            });
1050        let state_sync = with_mocked_clients(false, true, Some(rpc), None);
1051        let mut tracker = ComponentTracker::new(
1052            Chain::Ethereum,
1053            "uniswap-v2",
1054            ComponentFilter::with_tvl_range(0.0, 0.0),
1055            state_sync.rpc_client.clone(),
1056        );
1057        let component = ProtocolComponent {
1058            id: "Component1".to_string(),
1059            contract_ids: vec![Bytes::from("0x0badc0ffee"), Bytes::from("0xbabe42")],
1060            ..Default::default()
1061        };
1062        tracker
1063            .components
1064            .insert("Component1".to_string(), component.clone());
1065        let components_arg = ["Component1".to_string()];
1066        let exp = StateSyncMessage {
1067            header: header.clone(),
1068            snapshots: Snapshot {
1069                states: [(
1070                    component.id.clone(),
1071                    ComponentWithState {
1072                        state: ResponseProtocolState {
1073                            component_id: "Component1".to_string(),
1074                            ..Default::default()
1075                        },
1076                        component: component.clone(),
1077                        component_tvl: Some(100.0),
1078                        entrypoints: vec![],
1079                    },
1080                )]
1081                .into_iter()
1082                .collect(),
1083                vm_storage: state_snapshot_vm()
1084                    .accounts
1085                    .into_iter()
1086                    .map(|state| (state.address.clone(), state))
1087                    .collect(),
1088            },
1089            deltas: None,
1090            removed_components: Default::default(),
1091        };
1092
1093        let snap = state_sync
1094            .get_snapshots(header, &mut tracker, Some(&components_arg))
1095            .await
1096            .expect("Retrieving snapshot failed");
1097
1098        assert_eq!(snap, exp);
1099    }
1100
    /// Builds the mocks used by `test_state_sync`.
    ///
    /// The RPC mock serves:
    /// - `Component1` and `Component2` (with default states) for the initial sync;
    /// - `Component3` (component and state) only when a request explicitly asks for it by id,
    ///   via id-filtered expectations registered before the catch-all ones;
    /// - TVLs of 100.0 / 0.0 / 1000.0 for Component1 / Component2 / Component3;
    /// - an empty traced-entry-point response.
    ///
    /// The deltas mock's `subscribe` (callable once) returns a default subscriber id and the
    /// receiving end of a capacity-1 channel. The sending end is returned to the caller so the
    /// test can push `BlockChanges` into the synchronizer.
    fn mock_clients_for_state_sync() -> (MockRPCClient, MockDeltasClient, Sender<BlockChanges>) {
        let mut rpc_client = MockRPCClient::new();
        // Mocks for the start_tracking call, these need to come first because they are more
        // specific, see: https://docs.rs/mockall/latest/mockall/#matching-multiple-calls
        rpc_client
            .expect_get_protocol_components()
            .with(mockall::predicate::function(
                // Only matches requests that explicitly list "Component3" among the ids.
                move |request_params: &ProtocolComponentsRequestBody| {
                    if let Some(ids) = request_params.component_ids.as_ref() {
                        ids.contains(&"Component3".to_string())
                    } else {
                        false
                    }
                },
            ))
            .returning(|_| {
                // return Component3
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![
                        // this component shall have a tvl update above threshold
                        ProtocolComponent { id: "Component3".to_string(), ..Default::default() },
                    ],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });
        rpc_client
            .expect_get_protocol_states()
            // Same filter as above, applied to state requests.
            .with(mockall::predicate::function(move |request_params: &ProtocolStateRequestBody| {
                let expected_id = "Component3".to_string();
                if let Some(ids) = request_params.protocol_ids.as_ref() {
                    ids.contains(&expected_id)
                } else {
                    false
                }
            }))
            .returning(|_| {
                // return Component3 state
                Ok(ProtocolStateRequestResponse {
                    states: vec![ResponseProtocolState {
                        component_id: "Component3".to_string(),
                        ..Default::default()
                    }],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });

        // mock calls for the initial state snapshots (catch-all: no argument filter)
        rpc_client
            .expect_get_protocol_components()
            .returning(|_| {
                // Initial sync of components
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![
                        // this component shall have a tvl update above threshold
                        ProtocolComponent { id: "Component1".to_string(), ..Default::default() },
                        // this component shall have a tvl update below threshold.
                        ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
                        // Component3 is deliberately absent from the initial sync; it is only
                        // returned by the id-filtered expectation registered above.
                    ],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });
        rpc_client
            .expect_get_protocol_states()
            .returning(|_| {
                // Initial state snapshot
                Ok(ProtocolStateRequestResponse {
                    states: vec![
                        ResponseProtocolState {
                            component_id: "Component1".to_string(),
                            ..Default::default()
                        },
                        ResponseProtocolState {
                            component_id: "Component2".to_string(),
                            ..Default::default()
                        },
                    ],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });
        rpc_client
            .expect_get_component_tvl()
            .returning(|_| {
                Ok(ComponentTvlRequestResponse {
                    tvl: [
                        ("Component1".to_string(), 100.0),
                        ("Component2".to_string(), 0.0),
                        ("Component3".to_string(), 1000.0),
                    ]
                    .into_iter()
                    .collect(),
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
                })
            });
        rpc_client
            .expect_get_traced_entry_points()
            .returning(|_| {
                Ok(TracedEntryPointRequestResponse {
                    traced_entry_points: HashMap::new(),
                    pagination: PaginationResponse::new(0, 20, 0),
                })
            });

        // Mock deltas client and messages
        let mut deltas_client = MockDeltasClient::new();
        // Capacity 1: each test send waits until the previous message has been received.
        let (tx, rx) = channel(1);
        deltas_client
            .expect_subscribe()
            .return_once(move |_, _| {
                // Return subscriber id and a channel
                Ok((Uuid::default(), rx))
            });
        (rpc_client, deltas_client, tx)
    }
1215
    /// End-to-end run of the synchronizer against the mocks from `mock_clients_for_state_sync`.
    ///
    /// Test strategy
    ///
    /// - initial snapshot retrieval returns Component1 and Component2 as snapshots.
    /// - the first delta (block 1) carries a DCI update adding an entry point for Component1;
    ///   the first emitted message must contain the initial snapshots of both components plus
    ///   that delta unchanged.
    /// - the second delta (block 2) carries TVL updates: Component3 appears with significant
    ///   TVL and gets snapshotted, Component2 drops to 0 and is reported in
    ///   `removed_components` (and filtered out of the forwarded TVL), while Component1 stays
    ///   tracked and is not re-snapshotted.
    #[test(tokio::test)]
    async fn test_state_sync() {
        let (rpc_client, deltas_client, tx) = mock_clients_for_state_sync();
        let deltas = [
            BlockChanges {
                extractor: "uniswap-v2".to_string(),
                chain: Chain::Ethereum,
                block: Block {
                    number: 1,
                    hash: Bytes::from("0x01"),
                    parent_hash: Bytes::from("0x00"),
                    chain: Chain::Ethereum,
                    ts: Default::default(),
                },
                revert: false,
                dci_update: DCIUpdate {
                    new_entrypoints: HashMap::from([(
                        "Component1".to_string(),
                        HashSet::from([EntryPoint {
                            external_id: "entrypoint_a".to_string(),
                            target: Bytes::from("0x0badc0ffee"),
                            signature: "sig()".to_string(),
                        }]),
                    )]),
                    new_entrypoint_params: HashMap::from([(
                        "entrypoint_a".to_string(),
                        HashSet::from([(
                            TracingParams::RPCTracer(RPCTracerParams {
                                caller: Some(Bytes::from("0x0badc0ffee")),
                                calldata: Bytes::from("0x0badc0ffee"),
                            }),
                            Some("Component1".to_string()),
                        )]),
                    )]),
                    trace_results: HashMap::from([(
                        "entrypoint_a".to_string(),
                        TracingResult {
                            retriggers: HashSet::from([(
                                Bytes::from("0x0badc0ffee"),
                                Bytes::from("0x0badc0ffee"),
                            )]),
                            accessed_slots: HashMap::from([(
                                Bytes::from("0x0badc0ffee"),
                                HashSet::from([Bytes::from("0xbadbeef0")]),
                            )]),
                        },
                    )]),
                },
                ..Default::default()
            },
            BlockChanges {
                extractor: "uniswap-v2".to_string(),
                chain: Chain::Ethereum,
                block: Block {
                    number: 2,
                    hash: Bytes::from("0x02"),
                    parent_hash: Bytes::from("0x01"),
                    chain: Chain::Ethereum,
                    ts: Default::default(),
                },
                revert: false,
                component_tvl: [
                    ("Component1".to_string(), 100.0),
                    ("Component2".to_string(), 0.0),
                    ("Component3".to_string(), 1000.0),
                ]
                .into_iter()
                .collect(),
                ..Default::default()
            },
        ];
        let mut state_sync = with_mocked_clients(true, true, Some(rpc_client), Some(deltas_client));
        state_sync
            .initialize()
            .await
            .expect("Init failed");

        // Test starts here
        let (jh, mut rx) = state_sync
            .start()
            .await
            .expect("Failed to start state synchronizer");
        tx.send(deltas[0].clone())
            .await
            .expect("deltas channel msg 0 closed!");
        let first_msg = timeout(Duration::from_millis(100), rx.recv())
            .await
            .expect("waiting for first state msg timed out!")
            .expect("state sync block sender closed!");
        tx.send(deltas[1].clone())
            .await
            .expect("deltas channel msg 1 closed!");
        let second_msg = timeout(Duration::from_millis(100), rx.recv())
            .await
            .expect("waiting for second state msg timed out!")
            .expect("state sync block sender closed!");
        // The close() result is ignored here; the task's exit status is asserted below.
        let _ = state_sync.close().await;
        let exit = jh
            .await
            .expect("state sync task panicked!");

        // assertions
        let exp1 = StateSyncMessage {
            header: Header {
                number: 1,
                hash: Bytes::from("0x01"),
                parent_hash: Bytes::from("0x00"),
                revert: false,
            },
            snapshots: Snapshot {
                states: [
                    (
                        "Component1".to_string(),
                        ComponentWithState {
                            state: ResponseProtocolState {
                                component_id: "Component1".to_string(),
                                ..Default::default()
                            },
                            component: ProtocolComponent {
                                id: "Component1".to_string(),
                                ..Default::default()
                            },
                            component_tvl: Some(100.0),
                            entrypoints: vec![],
                        },
                    ),
                    (
                        "Component2".to_string(),
                        ComponentWithState {
                            state: ResponseProtocolState {
                                component_id: "Component2".to_string(),
                                ..Default::default()
                            },
                            component: ProtocolComponent {
                                id: "Component2".to_string(),
                                ..Default::default()
                            },
                            component_tvl: Some(0.0),
                            entrypoints: vec![],
                        },
                    ),
                ]
                .into_iter()
                .collect(),
                vm_storage: HashMap::new(),
            },
            deltas: Some(deltas[0].clone()),
            removed_components: Default::default(),
        };

        let exp2 = StateSyncMessage {
            header: Header {
                number: 2,
                hash: Bytes::from("0x02"),
                parent_hash: Bytes::from("0x01"),
                revert: false,
            },
            snapshots: Snapshot {
                states: [
                    // This is the new component we queried once it passed the tvl threshold.
                    (
                        "Component3".to_string(),
                        ComponentWithState {
                            state: ResponseProtocolState {
                                component_id: "Component3".to_string(),
                                ..Default::default()
                            },
                            component: ProtocolComponent {
                                id: "Component3".to_string(),
                                ..Default::default()
                            },
                            component_tvl: Some(1000.0),
                            entrypoints: vec![],
                        },
                    ),
                ]
                .into_iter()
                .collect(),
                vm_storage: HashMap::new(),
            },
            // Our deltas are empty and since merge methods are
            // tested in tycho-common we don't have much to do here.
            deltas: Some(BlockChanges {
                extractor: "uniswap-v2".to_string(),
                chain: Chain::Ethereum,
                block: Block {
                    number: 2,
                    hash: Bytes::from("0x02"),
                    parent_hash: Bytes::from("0x01"),
                    chain: Chain::Ethereum,
                    ts: Default::default(),
                },
                revert: false,
                component_tvl: [
                    // "Component2" should not show here.
                    ("Component1".to_string(), 100.0),
                    ("Component3".to_string(), 1000.0),
                ]
                .into_iter()
                .collect(),
                ..Default::default()
            }),
            // "Component2" was removed, because its tvl changed to 0.
            removed_components: [(
                "Component2".to_string(),
                ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
            )]
            .into_iter()
            .collect(),
        };
        assert_eq!(first_msg, exp1);
        assert_eq!(second_msg, exp2);
        assert!(exit.is_ok());
    }
1436
1437    #[test(tokio::test)]
1438    async fn test_state_sync_with_tvl_range() {
1439        // Define the range for testing
1440        let remove_tvl_threshold = 5.0;
1441        let add_tvl_threshold = 7.0;
1442
1443        let mut rpc_client = MockRPCClient::new();
1444        let mut deltas_client = MockDeltasClient::new();
1445
1446        rpc_client
1447            .expect_get_protocol_components()
1448            .with(mockall::predicate::function(
1449                move |request_params: &ProtocolComponentsRequestBody| {
1450                    if let Some(ids) = request_params.component_ids.as_ref() {
1451                        ids.contains(&"Component3".to_string())
1452                    } else {
1453                        false
1454                    }
1455                },
1456            ))
1457            .returning(|_| {
1458                Ok(ProtocolComponentRequestResponse {
1459                    protocol_components: vec![ProtocolComponent {
1460                        id: "Component3".to_string(),
1461                        ..Default::default()
1462                    }],
1463                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1464                })
1465            });
1466        rpc_client
1467            .expect_get_protocol_states()
1468            .with(mockall::predicate::function(move |request_params: &ProtocolStateRequestBody| {
1469                let expected_id = "Component3".to_string();
1470                if let Some(ids) = request_params.protocol_ids.as_ref() {
1471                    ids.contains(&expected_id)
1472                } else {
1473                    false
1474                }
1475            }))
1476            .returning(|_| {
1477                Ok(ProtocolStateRequestResponse {
1478                    states: vec![ResponseProtocolState {
1479                        component_id: "Component3".to_string(),
1480                        ..Default::default()
1481                    }],
1482                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1483                })
1484            });
1485
1486        // Mock for the initial snapshot retrieval
1487        rpc_client
1488            .expect_get_protocol_components()
1489            .returning(|_| {
1490                Ok(ProtocolComponentRequestResponse {
1491                    protocol_components: vec![
1492                        ProtocolComponent { id: "Component1".to_string(), ..Default::default() },
1493                        ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1494                    ],
1495                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1496                })
1497            });
1498        rpc_client
1499            .expect_get_protocol_states()
1500            .returning(|_| {
1501                Ok(ProtocolStateRequestResponse {
1502                    states: vec![
1503                        ResponseProtocolState {
1504                            component_id: "Component1".to_string(),
1505                            ..Default::default()
1506                        },
1507                        ResponseProtocolState {
1508                            component_id: "Component2".to_string(),
1509                            ..Default::default()
1510                        },
1511                    ],
1512                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1513                })
1514            });
1515        rpc_client
1516            .expect_get_traced_entry_points()
1517            .returning(|_| {
1518                Ok(TracedEntryPointRequestResponse {
1519                    traced_entry_points: HashMap::new(),
1520                    pagination: PaginationResponse::new(0, 20, 0),
1521                })
1522            });
1523
1524        rpc_client
1525            .expect_get_component_tvl()
1526            .returning(|_| {
1527                Ok(ComponentTvlRequestResponse {
1528                    tvl: [
1529                        ("Component1".to_string(), 6.0),
1530                        ("Component2".to_string(), 2.0),
1531                        ("Component3".to_string(), 10.0),
1532                    ]
1533                    .into_iter()
1534                    .collect(),
1535                    pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
1536                })
1537            });
1538
1539        rpc_client
1540            .expect_get_component_tvl()
1541            .returning(|_| {
1542                Ok(ComponentTvlRequestResponse {
1543                    tvl: [
1544                        ("Component1".to_string(), 6.0),
1545                        ("Component2".to_string(), 2.0),
1546                        ("Component3".to_string(), 10.0),
1547                    ]
1548                    .into_iter()
1549                    .collect(),
1550                    pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
1551                })
1552            });
1553
1554        let (tx, rx) = channel(1);
1555        deltas_client
1556            .expect_subscribe()
1557            .return_once(move |_, _| Ok((Uuid::default(), rx)));
1558
1559        let mut state_sync = ProtocolStateSynchronizer::new(
1560            ExtractorIdentity::new(Chain::Ethereum, "uniswap-v2"),
1561            true,
1562            ComponentFilter::with_tvl_range(remove_tvl_threshold, add_tvl_threshold),
1563            1,
1564            true,
1565            true,
1566            ArcRPCClient(Arc::new(rpc_client)),
1567            ArcDeltasClient(Arc::new(deltas_client)),
1568            10_u64,
1569        );
1570        state_sync
1571            .initialize()
1572            .await
1573            .expect("Init failed");
1574
1575        // Simulate the incoming BlockChanges
1576        let deltas = [
1577            BlockChanges {
1578                extractor: "uniswap-v2".to_string(),
1579                chain: Chain::Ethereum,
1580                block: Block {
1581                    number: 1,
1582                    hash: Bytes::from("0x01"),
1583                    parent_hash: Bytes::from("0x00"),
1584                    chain: Chain::Ethereum,
1585                    ts: Default::default(),
1586                },
1587                revert: false,
1588                ..Default::default()
1589            },
1590            BlockChanges {
1591                extractor: "uniswap-v2".to_string(),
1592                chain: Chain::Ethereum,
1593                block: Block {
1594                    number: 2,
1595                    hash: Bytes::from("0x02"),
1596                    parent_hash: Bytes::from("0x01"),
1597                    chain: Chain::Ethereum,
1598                    ts: Default::default(),
1599                },
1600                revert: false,
1601                component_tvl: [
1602                    ("Component1".to_string(), 6.0), // Within range, should not trigger changes
1603                    ("Component2".to_string(), 2.0), // Below lower threshold, should be removed
1604                    ("Component3".to_string(), 10.0), // Above upper threshold, should be added
1605                ]
1606                .into_iter()
1607                .collect(),
1608                ..Default::default()
1609            },
1610        ];
1611
1612        let (jh, mut rx) = state_sync
1613            .start()
1614            .await
1615            .expect("Failed to start state synchronizer");
1616
1617        // Simulate sending delta messages
1618        tx.send(deltas[0].clone())
1619            .await
1620            .expect("deltas channel msg 0 closed!");
1621
1622        // Expecting to receive the initial state message
1623        let _ = timeout(Duration::from_millis(100), rx.recv())
1624            .await
1625            .expect("waiting for first state msg timed out!")
1626            .expect("state sync block sender closed!");
1627
        // Send the second delta message, which should trigger TVL-based component changes
1629        tx.send(deltas[1].clone())
1630            .await
1631            .expect("deltas channel msg 1 closed!");
1632        let second_msg = timeout(Duration::from_millis(100), rx.recv())
1633            .await
1634            .expect("waiting for second state msg timed out!")
1635            .expect("state sync block sender closed!");
1636
1637        let _ = state_sync.close().await;
1638        let exit = jh
1639            .await
1640            .expect("state sync task panicked!");
1641
1642        let expected_second_msg = StateSyncMessage {
1643            header: Header {
1644                number: 2,
1645                hash: Bytes::from("0x02"),
1646                parent_hash: Bytes::from("0x01"),
1647                revert: false,
1648            },
1649            snapshots: Snapshot {
1650                states: [(
1651                    "Component3".to_string(),
1652                    ComponentWithState {
1653                        state: ResponseProtocolState {
1654                            component_id: "Component3".to_string(),
1655                            ..Default::default()
1656                        },
1657                        component: ProtocolComponent {
1658                            id: "Component3".to_string(),
1659                            ..Default::default()
1660                        },
1661                        component_tvl: Some(10.0),
1662                        entrypoints: vec![], // TODO: add entrypoints?
1663                    },
1664                )]
1665                .into_iter()
1666                .collect(),
1667                vm_storage: HashMap::new(),
1668            },
1669            deltas: Some(BlockChanges {
1670                extractor: "uniswap-v2".to_string(),
1671                chain: Chain::Ethereum,
1672                block: Block {
1673                    number: 2,
1674                    hash: Bytes::from("0x02"),
1675                    parent_hash: Bytes::from("0x01"),
1676                    chain: Chain::Ethereum,
1677                    ts: Default::default(),
1678                },
1679                revert: false,
1680                component_tvl: [
1681                    ("Component1".to_string(), 6.0), // Within range, should not trigger changes
1682                    ("Component3".to_string(), 10.0), // Above upper threshold, should be added
1683                ]
1684                .into_iter()
1685                .collect(),
1686                ..Default::default()
1687            }),
1688            removed_components: [(
1689                "Component2".to_string(),
1690                ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1691            )]
1692            .into_iter()
1693            .collect(),
1694        };
1695
1696        assert_eq!(second_msg, expected_second_msg);
1697        assert!(exit.is_ok());
1698    }
1699}