// tycho_client/feed/synchronizer.rs

use std::{
    collections::{HashMap, HashSet},
    sync::Arc,
    time::Duration,
};

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::{
    select,
    sync::{
        mpsc::{channel, error::SendError, Receiver, Sender},
        oneshot, Mutex,
    },
    task::JoinHandle,
    time::timeout,
};
use tracing::{debug, error, info, instrument, trace, warn};
use tycho_common::{
    dto::{
        BlockChanges, BlockParam, ComponentTvlRequestBody, EntryPointWithTracingParams,
        ExtractorIdentity, ProtocolComponent, ResponseAccount, ResponseProtocolState,
        TracingResult, VersionParam,
    },
    Bytes,
};

use crate::{
    deltas::{DeltasClient, SubscriptionOptions},
    feed::{
        component_tracker::{ComponentFilter, ComponentTracker},
        Header,
    },
    rpc::{RPCClient, RPCError},
    DeltasError,
};
34
35#[derive(Error, Debug)]
36pub enum SynchronizerError {
37    /// RPC client failures.
38    #[error("RPC error: {0}")]
39    RPCError(#[from] RPCError),
40
41    /// Failed to send channel message to the consumer.
42    #[error("Failed to send channel message: {0}")]
43    ChannelError(String),
44
45    /// Timeout elapsed errors.
46    #[error("Timeout error: {0}")]
47    Timeout(String),
48
49    /// Failed to close the synchronizer.
50    #[error("Failed to close synchronizer: {0}")]
51    CloseError(String),
52
53    /// Server connection failures or interruptions.
54    #[error("Connection error: {0}")]
55    ConnectionError(String),
56
57    /// Connection closed
58    #[error("Connection closed")]
59    ConnectionClosed,
60}
61
62pub type SyncResult<T> = Result<T, SynchronizerError>;
63
64impl From<SendError<StateSyncMessage>> for SynchronizerError {
65    fn from(err: SendError<StateSyncMessage>) -> Self {
66        SynchronizerError::ChannelError(err.to_string())
67    }
68}
69
70impl From<DeltasError> for SynchronizerError {
71    fn from(err: DeltasError) -> Self {
72        match err {
73            DeltasError::NotConnected => SynchronizerError::ConnectionClosed,
74            _ => SynchronizerError::ConnectionError(err.to_string()),
75        }
76    }
77}
78
79#[derive(Clone)]
80pub struct ProtocolStateSynchronizer<R: RPCClient, D: DeltasClient> {
81    extractor_id: ExtractorIdentity,
82    retrieve_balances: bool,
83    rpc_client: R,
84    deltas_client: D,
85    max_retries: u64,
86    include_snapshots: bool,
87    component_tracker: Arc<Mutex<ComponentTracker<R>>>,
88    shared: Arc<Mutex<SharedState>>,
89    end_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
90    timeout: u64,
91    include_tvl: bool,
92}
93
94#[derive(Debug, Default)]
95struct SharedState {
96    last_synced_block: Option<Header>,
97}
98
99#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
100pub struct ComponentWithState {
101    pub state: ResponseProtocolState,
102    pub component: ProtocolComponent,
103    pub component_tvl: Option<f64>,
104    pub entrypoints: Vec<(EntryPointWithTracingParams, TracingResult)>,
105}
106
107#[derive(Clone, PartialEq, Debug, Default, Serialize, Deserialize)]
108pub struct Snapshot {
109    pub states: HashMap<String, ComponentWithState>,
110    pub vm_storage: HashMap<Bytes, ResponseAccount>,
111}
112
113impl Snapshot {
114    fn extend(&mut self, other: Snapshot) {
115        self.states.extend(other.states);
116        self.vm_storage.extend(other.vm_storage);
117    }
118
119    pub fn get_states(&self) -> &HashMap<String, ComponentWithState> {
120        &self.states
121    }
122
123    pub fn get_vm_storage(&self) -> &HashMap<Bytes, ResponseAccount> {
124        &self.vm_storage
125    }
126}
127
128#[derive(Clone, PartialEq, Debug, Default, Serialize, Deserialize)]
129pub struct StateSyncMessage {
130    /// The block information for this update.
131    pub header: Header,
132    /// Snapshot for new components.
133    pub snapshots: Snapshot,
134    /// A single delta contains state updates for all tracked components, as well as additional
135    /// information about the system components e.g. newly added components (even below tvl), tvl
136    /// updates, balance updates.
137    pub deltas: Option<BlockChanges>,
138    /// Components that stopped being tracked.
139    pub removed_components: HashMap<String, ProtocolComponent>,
140}
141
142impl StateSyncMessage {
143    pub fn merge(mut self, other: Self) -> Self {
144        // be careful with removed and snapshots attributes here, these can be ambiguous.
145        self.removed_components
146            .retain(|k, _| !other.snapshots.states.contains_key(k));
147        self.snapshots
148            .states
149            .retain(|k, _| !other.removed_components.contains_key(k));
150
151        self.snapshots.extend(other.snapshots);
152        let deltas = match (self.deltas, other.deltas) {
153            (Some(l), Some(r)) => Some(l.merge(r)),
154            (None, Some(r)) => Some(r),
155            (Some(l), None) => Some(l),
156            (None, None) => None,
157        };
158        self.removed_components
159            .extend(other.removed_components);
160        Self {
161            header: other.header,
162            snapshots: self.snapshots,
163            deltas,
164            removed_components: self.removed_components,
165        }
166    }
167}
168
169/// StateSynchronizer
170///
171/// Used to synchronize the state of a single protocol. The synchronizer is responsible for
172/// delivering messages to the client that let him reconstruct subsets of the protocol state.
173///
174/// This involves deciding which components to track according to the clients preferences,
175/// retrieving & emitting snapshots of components which the client has not seen yet and subsequently
176/// delivering delta messages for the components that have changed.
177#[async_trait]
178pub trait StateSynchronizer: Send + Sync + 'static {
179    async fn initialize(&self) -> SyncResult<()>;
180    /// Starts the state synchronization.
181    async fn start(&self) -> SyncResult<(JoinHandle<SyncResult<()>>, Receiver<StateSyncMessage>)>;
182    /// Ends the synchronization loop.
183    async fn close(&mut self) -> SyncResult<()>;
184}
185
186impl<R, D> ProtocolStateSynchronizer<R, D>
187where
188    // TODO: Consider moving these constraints directly to the
189    // client...
190    R: RPCClient + Clone + Send + Sync + 'static,
191    D: DeltasClient + Clone + Send + Sync + 'static,
192{
193    /// Creates a new state synchronizer.
194    #[allow(clippy::too_many_arguments)]
195    pub fn new(
196        extractor_id: ExtractorIdentity,
197        retrieve_balances: bool,
198        component_filter: ComponentFilter,
199        max_retries: u64,
200        include_snapshots: bool,
201        include_tvl: bool,
202        rpc_client: R,
203        deltas_client: D,
204        timeout: u64,
205    ) -> Self {
206        Self {
207            extractor_id: extractor_id.clone(),
208            retrieve_balances,
209            rpc_client: rpc_client.clone(),
210            include_snapshots,
211            deltas_client,
212            component_tracker: Arc::new(Mutex::new(ComponentTracker::new(
213                extractor_id.chain,
214                extractor_id.name.as_str(),
215                component_filter,
216                rpc_client,
217            ))),
218            max_retries,
219            shared: Arc::new(Mutex::new(SharedState::default())),
220            end_tx: Arc::new(Mutex::new(None)),
221            timeout,
222            include_tvl,
223        }
224    }
225
226    /// Retrieves state snapshots of the requested components
227    #[allow(deprecated)]
228    async fn get_snapshots<'a, I: IntoIterator<Item = &'a String>>(
229        &self,
230        header: Header,
231        tracked_components: &mut ComponentTracker<R>,
232        ids: Option<I>,
233    ) -> SyncResult<StateSyncMessage> {
234        if !self.include_snapshots {
235            return Ok(StateSyncMessage { header, ..Default::default() });
236        }
237        let version = VersionParam::new(
238            None,
239            Some(BlockParam {
240                chain: Some(self.extractor_id.chain),
241                hash: None,
242                number: Some(header.number as i64),
243            }),
244        );
245
246        // Use given ids or use all if not passed
247        let component_ids: Vec<_> = match ids {
248            Some(ids) => ids.into_iter().cloned().collect(),
249            None => tracked_components.get_tracked_component_ids(),
250        };
251
252        if component_ids.is_empty() {
253            return Ok(StateSyncMessage { header, ..Default::default() });
254        }
255
256        let component_tvl = if self.include_tvl {
257            let body = ComponentTvlRequestBody::id_filtered(
258                component_ids.clone(),
259                self.extractor_id.chain,
260            );
261            self.rpc_client
262                .get_component_tvl_paginated(&body, 100, 4)
263                .await?
264                .tvl
265        } else {
266            HashMap::new()
267        };
268
269        // Fetch entrypoints
270        let entrypoints_result = self
271            .rpc_client
272            .get_traced_entry_points_paginated(
273                self.extractor_id.chain,
274                &self.extractor_id.name,
275                &component_ids,
276                100,
277                4,
278            )
279            .await?;
280        tracked_components.process_entrypoints(&entrypoints_result.clone().into())?;
281
282        // Fetch protocol states
283        let mut protocol_states = self
284            .rpc_client
285            .get_protocol_states_paginated(
286                self.extractor_id.chain,
287                &component_ids,
288                &self.extractor_id.name,
289                self.retrieve_balances,
290                &version,
291                100,
292                4,
293            )
294            .await?
295            .states
296            .into_iter()
297            .map(|state| (state.component_id.clone(), state))
298            .collect::<HashMap<_, _>>();
299
300        trace!(states=?&protocol_states, "Retrieved ProtocolStates");
301        let states = tracked_components
302            .components
303            .values()
304            .filter_map(|component| {
305                if let Some(state) = protocol_states.remove(&component.id) {
306                    Some((
307                        component.id.clone(),
308                        ComponentWithState {
309                            state,
310                            component: component.clone(),
311                            component_tvl: component_tvl
312                                .get(&component.id)
313                                .cloned(),
314                            entrypoints: entrypoints_result
315                                .traced_entry_points
316                                .get(&component.id)
317                                .cloned()
318                                .unwrap_or_default(),
319                        },
320                    ))
321                } else if component_ids.contains(&component.id) {
322                    // only emit error event if we requested this component
323                    let component_id = &component.id;
324                    error!(?component_id, "Missing state for native component!");
325                    None
326                } else {
327                    None
328                }
329            })
330            .collect();
331
332        // Fetch contract states
333        let contract_ids = tracked_components.get_contracts_by_component(&component_ids);
334        let vm_storage = if !contract_ids.is_empty() {
335            let ids: Vec<Bytes> = contract_ids
336                .clone()
337                .into_iter()
338                .collect();
339            let contract_states = self
340                .rpc_client
341                .get_contract_state_paginated(
342                    self.extractor_id.chain,
343                    ids.as_slice(),
344                    &self.extractor_id.name,
345                    &version,
346                    100,
347                    4,
348                )
349                .await?
350                .accounts
351                .into_iter()
352                .map(|acc| (acc.address.clone(), acc))
353                .collect::<HashMap<_, _>>();
354
355            trace!(states=?&contract_states, "Retrieved ContractState");
356
357            let contract_address_to_components = tracked_components
358                .components
359                .iter()
360                .filter_map(|(id, comp)| {
361                    if component_ids.contains(id) {
362                        Some(
363                            comp.contract_ids
364                                .iter()
365                                .map(|address| (address.clone(), comp.id.clone())),
366                        )
367                    } else {
368                        None
369                    }
370                })
371                .flatten()
372                .fold(HashMap::<Bytes, Vec<String>>::new(), |mut acc, (addr, c_id)| {
373                    acc.entry(addr).or_default().push(c_id);
374                    acc
375                });
376
377            contract_ids
378                .iter()
379                .filter_map(|address| {
380                    if let Some(state) = contract_states.get(address) {
381                        Some((address.clone(), state.clone()))
382                    } else if let Some(ids) = contract_address_to_components.get(address) {
383                        // only emit error even if we did actually request this address
384                        error!(
385                            ?address,
386                            ?ids,
387                            "Component with lacking contract storage encountered!"
388                        );
389                        None
390                    } else {
391                        None
392                    }
393                })
394                .collect()
395        } else {
396            HashMap::new()
397        };
398
399        Ok(StateSyncMessage {
400            header,
401            snapshots: Snapshot { states, vm_storage },
402            deltas: None,
403            removed_components: HashMap::new(),
404        })
405    }
406
407    /// Main method that does all the work.
408    #[instrument(skip(self, block_tx), fields(extractor_id = %self.extractor_id))]
409    async fn state_sync(self, block_tx: &mut Sender<StateSyncMessage>) -> SyncResult<()> {
410        // initialisation
411        let mut tracker = self.component_tracker.lock().await;
412
413        let subscription_options = SubscriptionOptions::new().with_state(self.include_snapshots);
414        let (_, mut msg_rx) = self
415            .deltas_client
416            .subscribe(self.extractor_id.clone(), subscription_options)
417            .await?;
418
419        info!("Waiting for deltas...");
420        // wait for first deltas message
421        let mut first_msg = timeout(Duration::from_secs(self.timeout), msg_rx.recv())
422            .await
423            .map_err(|_| {
424                SynchronizerError::Timeout(format!(
425                    "First deltas took longer than {t}s to arrive",
426                    t = self.timeout
427                ))
428            })?
429            .ok_or_else(|| {
430                SynchronizerError::ConnectionError(
431                    "Deltas channel closed before first message".to_string(),
432                )
433            })?;
434        self.filter_deltas(&mut first_msg, &tracker);
435
436        // initial snapshot
437        let block = first_msg.get_block().clone();
438        info!(height = &block.number, "Deltas received. Retrieving snapshot");
439        let header = Header::from_block(first_msg.get_block(), first_msg.is_revert());
440        let snapshot = self
441            .get_snapshots::<Vec<&String>>(Header::from_block(&block, false), &mut tracker, None)
442            .await?
443            .merge(StateSyncMessage {
444                header: Header::from_block(first_msg.get_block(), first_msg.is_revert()),
445                snapshots: Default::default(),
446                deltas: Some(first_msg),
447                removed_components: Default::default(),
448            });
449
450        let n_components = tracker.components.len();
451        let n_snapshots = snapshot.snapshots.states.len();
452        info!(n_components, n_snapshots, "Initial snapshot retrieved, starting delta message feed");
453
454        {
455            let mut shared = self.shared.lock().await;
456            block_tx.send(snapshot).await?;
457            shared.last_synced_block = Some(header.clone());
458        }
459
460        loop {
461            if let Some(mut deltas) = msg_rx.recv().await {
462                let header = Header::from_block(deltas.get_block(), deltas.is_revert());
463                debug!(block_number=?header.number, "Received delta message");
464
465                let (snapshots, removed_components) = {
466                    // 1. Remove components based on latest changes
467                    // 2. Add components based on latest changes, query those for snapshots
468                    let (to_add, to_remove) = tracker.filter_updated_components(&deltas);
469
470                    // Only components we don't track yet need a snapshot,
471                    let requiring_snapshot: Vec<_> = to_add
472                        .iter()
473                        .filter(|id| {
474                            !tracker
475                                .components
476                                .contains_key(id.as_str())
477                        })
478                        .collect();
479                    debug!(components=?requiring_snapshot, "SnapshotRequest");
480                    tracker
481                        .start_tracking(requiring_snapshot.as_slice())
482                        .await?;
483                    let snapshots = self
484                        .get_snapshots(header.clone(), &mut tracker, Some(requiring_snapshot))
485                        .await?
486                        .snapshots;
487
488                    let removed_components = if !to_remove.is_empty() {
489                        tracker.stop_tracking(&to_remove)
490                    } else {
491                        Default::default()
492                    };
493
494                    (snapshots, removed_components)
495                };
496
497                // 3. Update entrypoints on the tracker (affects which contracts are tracked)
498                tracker.process_entrypoints(&deltas.dci_update)?;
499
500                // 4. Filter deltas by currently tracked components / contracts
501                self.filter_deltas(&mut deltas, &tracker);
502                let n_changes = deltas.n_changes();
503
504                // 5. Send the message
505                let next = StateSyncMessage {
506                    header: header.clone(),
507                    snapshots,
508                    deltas: Some(deltas),
509                    removed_components,
510                };
511                block_tx.send(next).await?;
512                {
513                    let mut shared = self.shared.lock().await;
514                    shared.last_synced_block = Some(header.clone());
515                }
516
517                debug!(block_number=?header.number, n_changes, "Finished processing delta message");
518            } else {
519                let mut shared = self.shared.lock().await;
520                warn!(shared = ?&shared, "Deltas channel closed, resetting shared state.");
521                shared.last_synced_block = None;
522
523                return Err(SynchronizerError::ConnectionError("Deltas channel closed".to_string()));
524            }
525        }
526    }
527
528    fn filter_deltas(&self, second_msg: &mut BlockChanges, tracker: &ComponentTracker<R>) {
529        second_msg.filter_by_component(|id| tracker.components.contains_key(id));
530        second_msg.filter_by_contract(|id| tracker.contracts.contains(id));
531    }
532}
533
534#[async_trait]
535impl<R, D> StateSynchronizer for ProtocolStateSynchronizer<R, D>
536where
537    R: RPCClient + Clone + Send + Sync + 'static,
538    D: DeltasClient + Clone + Send + Sync + 'static,
539{
540    async fn initialize(&self) -> SyncResult<()> {
541        let mut tracker = self.component_tracker.lock().await;
542        info!("Retrieving relevant protocol components");
543        tracker.initialise_components().await?;
544        info!(
545            n_components = tracker.components.len(),
546            n_contracts = tracker.contracts.len(),
547            "Finished retrieving components",
548        );
549
550        Ok(())
551    }
552
553    async fn start(&self) -> SyncResult<(JoinHandle<SyncResult<()>>, Receiver<StateSyncMessage>)> {
554        let (mut tx, rx) = channel(15);
555
556        let this = self.clone();
557        let jh = tokio::spawn(async move {
558            let mut retry_count = 0;
559            while retry_count < this.max_retries {
560                info!(extractor_id=%&this.extractor_id, retry_count, "(Re)starting synchronization loop");
561                let (end_tx, end_rx) = oneshot::channel::<()>();
562                {
563                    let mut end_tx_guard = this.end_tx.lock().await;
564                    *end_tx_guard = Some(end_tx);
565                }
566
567                select! {
568                    res = this.clone().state_sync(&mut tx) => {
569                        match res {
570                            Err(e) => {
571                                error!(
572                                    extractor_id=%&this.extractor_id,
573                                    retry_count,
574                                    error=%e,
575                                    "State synchronization errored!"
576                                );
577                                if let SynchronizerError::ConnectionClosed = e {
578                                    // break synchronization loop if connection is closed
579                                    return Err(e);
580                                }
581                            }
582                            _ => {
583                                warn!(
584                                    extractor_id=%&this.extractor_id,
585                                    retry_count,
586                                    "State synchronization exited with Ok(())"
587                                );
588                            }
589                        }
590                    },
591                    _ = end_rx => {
592                        info!(
593                            extractor_id=%&this.extractor_id,
594                            retry_count,
595                            "StateSynchronizer received close signal. Stopping"
596                        );
597                        return Ok(())
598                    }
599                }
600                retry_count += 1;
601            }
602            Err(SynchronizerError::ConnectionError("Max connection retries exceeded".to_string()))
603        });
604
605        Ok((jh, rx))
606    }
607
608    async fn close(&mut self) -> SyncResult<()> {
609        let mut end_tx = self.end_tx.lock().await;
610        if let Some(tx) = end_tx.take() {
611            let _ = tx.send(());
612            Ok(())
613        } else {
614            Err(SynchronizerError::CloseError("Synchronizer not started".to_string()))
615        }
616    }
617}
618
619#[cfg(test)]
620mod test {
621    use std::collections::HashSet;
622
623    use test_log::test;
624    use tycho_common::dto::{
625        Block, Chain, ComponentTvlRequestBody, ComponentTvlRequestResponse, DCIUpdate, EntryPoint,
626        PaginationResponse, ProtocolComponentRequestResponse, ProtocolComponentsRequestBody,
627        ProtocolStateRequestBody, ProtocolStateRequestResponse, ProtocolSystemsRequestBody,
628        ProtocolSystemsRequestResponse, RPCTracerParams, StateRequestBody, StateRequestResponse,
629        TokensRequestBody, TokensRequestResponse, TracedEntryPointRequestBody,
630        TracedEntryPointRequestResponse, TracingParams,
631    };
632    use uuid::Uuid;
633
634    use super::*;
635    use crate::{deltas::MockDeltasClient, rpc::MockRPCClient, DeltasError, RPCError};
636
637    // Required for mock client to implement clone
638    struct ArcRPCClient<T>(Arc<T>);
639
640    // Default derive(Clone) does require T to be Clone as well.
641    impl<T> Clone for ArcRPCClient<T> {
642        fn clone(&self) -> Self {
643            ArcRPCClient(self.0.clone())
644        }
645    }
646
647    #[async_trait]
648    impl<T> RPCClient for ArcRPCClient<T>
649    where
650        T: RPCClient + Sync + Send + 'static,
651    {
652        async fn get_tokens(
653            &self,
654            request: &TokensRequestBody,
655        ) -> Result<TokensRequestResponse, RPCError> {
656            self.0.get_tokens(request).await
657        }
658
659        async fn get_contract_state(
660            &self,
661            request: &StateRequestBody,
662        ) -> Result<StateRequestResponse, RPCError> {
663            self.0.get_contract_state(request).await
664        }
665
666        async fn get_protocol_components(
667            &self,
668            request: &ProtocolComponentsRequestBody,
669        ) -> Result<ProtocolComponentRequestResponse, RPCError> {
670            self.0
671                .get_protocol_components(request)
672                .await
673        }
674
675        async fn get_protocol_states(
676            &self,
677            request: &ProtocolStateRequestBody,
678        ) -> Result<ProtocolStateRequestResponse, RPCError> {
679            self.0
680                .get_protocol_states(request)
681                .await
682        }
683
684        async fn get_protocol_systems(
685            &self,
686            request: &ProtocolSystemsRequestBody,
687        ) -> Result<ProtocolSystemsRequestResponse, RPCError> {
688            self.0
689                .get_protocol_systems(request)
690                .await
691        }
692
693        async fn get_component_tvl(
694            &self,
695            request: &ComponentTvlRequestBody,
696        ) -> Result<ComponentTvlRequestResponse, RPCError> {
697            self.0.get_component_tvl(request).await
698        }
699
700        async fn get_traced_entry_points(
701            &self,
702            request: &TracedEntryPointRequestBody,
703        ) -> Result<TracedEntryPointRequestResponse, RPCError> {
704            self.0
705                .get_traced_entry_points(request)
706                .await
707        }
708    }
709
710    // Required for mock client to implement clone
711    struct ArcDeltasClient<T>(Arc<T>);
712
713    // Default derive(Clone) does require T to be Clone as well.
714    impl<T> Clone for ArcDeltasClient<T> {
715        fn clone(&self) -> Self {
716            ArcDeltasClient(self.0.clone())
717        }
718    }
719
720    #[async_trait]
721    impl<T> DeltasClient for ArcDeltasClient<T>
722    where
723        T: DeltasClient + Sync + Send + 'static,
724    {
725        async fn subscribe(
726            &self,
727            extractor_id: ExtractorIdentity,
728            options: SubscriptionOptions,
729        ) -> Result<(Uuid, Receiver<BlockChanges>), DeltasError> {
730            self.0
731                .subscribe(extractor_id, options)
732                .await
733        }
734
735        async fn unsubscribe(&self, subscription_id: Uuid) -> Result<(), DeltasError> {
736            self.0
737                .unsubscribe(subscription_id)
738                .await
739        }
740
741        async fn connect(&self) -> Result<JoinHandle<Result<(), DeltasError>>, DeltasError> {
742            self.0.connect().await
743        }
744
745        async fn close(&self) -> Result<(), DeltasError> {
746            self.0.close().await
747        }
748    }
749
750    fn with_mocked_clients(
751        native: bool,
752        include_tvl: bool,
753        rpc_client: Option<MockRPCClient>,
754        deltas_client: Option<MockDeltasClient>,
755    ) -> ProtocolStateSynchronizer<ArcRPCClient<MockRPCClient>, ArcDeltasClient<MockDeltasClient>>
756    {
757        let rpc_client = ArcRPCClient(Arc::new(rpc_client.unwrap_or_default()));
758        let deltas_client = ArcDeltasClient(Arc::new(deltas_client.unwrap_or_default()));
759
760        ProtocolStateSynchronizer::new(
761            ExtractorIdentity::new(Chain::Ethereum, "uniswap-v2"),
762            native,
763            ComponentFilter::with_tvl_range(50.0, 50.0),
764            1,
765            true,
766            include_tvl,
767            rpc_client,
768            deltas_client,
769            10_u64,
770        )
771    }
772
773    fn state_snapshot_native() -> ProtocolStateRequestResponse {
774        ProtocolStateRequestResponse {
775            states: vec![ResponseProtocolState {
776                component_id: "Component1".to_string(),
777                ..Default::default()
778            }],
779            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
780        }
781    }
782
783    fn component_tvl_snapshot() -> ComponentTvlRequestResponse {
784        let tvl = HashMap::from([("Component1".to_string(), 100.0)]);
785
786        ComponentTvlRequestResponse {
787            tvl,
788            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
789        }
790    }
791
792    #[test(tokio::test)]
793    async fn test_get_snapshots_native() {
794        let header = Header::default();
795        let mut rpc = MockRPCClient::new();
796        rpc.expect_get_protocol_states()
797            .returning(|_| Ok(state_snapshot_native()));
798        rpc.expect_get_traced_entry_points()
799            .returning(|_| {
800                Ok(TracedEntryPointRequestResponse {
801                    traced_entry_points: HashMap::new(),
802                    pagination: PaginationResponse::new(0, 20, 0),
803                })
804            });
805        let state_sync = with_mocked_clients(true, false, Some(rpc), None);
806        let mut tracker = ComponentTracker::new(
807            Chain::Ethereum,
808            "uniswap-v2",
809            ComponentFilter::with_tvl_range(0.0, 0.0),
810            state_sync.rpc_client.clone(),
811        );
812        let component = ProtocolComponent { id: "Component1".to_string(), ..Default::default() };
813        tracker
814            .components
815            .insert("Component1".to_string(), component.clone());
816        let components_arg = ["Component1".to_string()];
817        let exp = StateSyncMessage {
818            header: header.clone(),
819            snapshots: Snapshot {
820                states: state_snapshot_native()
821                    .states
822                    .into_iter()
823                    .map(|state| {
824                        (
825                            state.component_id.clone(),
826                            ComponentWithState {
827                                state,
828                                component: component.clone(),
829                                entrypoints: vec![],
830                                component_tvl: None,
831                            },
832                        )
833                    })
834                    .collect(),
835                vm_storage: HashMap::new(),
836            },
837            deltas: None,
838            removed_components: Default::default(),
839        };
840
841        let snap = state_sync
842            .get_snapshots(header, &mut tracker, Some(&components_arg))
843            .await
844            .expect("Retrieving snapshot failed");
845
846        assert_eq!(snap, exp);
847    }
848
849    #[test(tokio::test)]
850    async fn test_get_snapshots_native_with_tvl() {
851        let header = Header::default();
852        let mut rpc = MockRPCClient::new();
853        rpc.expect_get_protocol_states()
854            .returning(|_| Ok(state_snapshot_native()));
855        rpc.expect_get_component_tvl()
856            .returning(|_| Ok(component_tvl_snapshot()));
857        rpc.expect_get_traced_entry_points()
858            .returning(|_| {
859                Ok(TracedEntryPointRequestResponse {
860                    traced_entry_points: HashMap::new(),
861                    pagination: PaginationResponse::new(0, 20, 0),
862                })
863            });
864        let state_sync = with_mocked_clients(true, true, Some(rpc), None);
865        let mut tracker = ComponentTracker::new(
866            Chain::Ethereum,
867            "uniswap-v2",
868            ComponentFilter::with_tvl_range(0.0, 0.0),
869            state_sync.rpc_client.clone(),
870        );
871        let component = ProtocolComponent { id: "Component1".to_string(), ..Default::default() };
872        tracker
873            .components
874            .insert("Component1".to_string(), component.clone());
875        let components_arg = ["Component1".to_string()];
876        let exp = StateSyncMessage {
877            header: header.clone(),
878            snapshots: Snapshot {
879                states: state_snapshot_native()
880                    .states
881                    .into_iter()
882                    .map(|state| {
883                        (
884                            state.component_id.clone(),
885                            ComponentWithState {
886                                state,
887                                component: component.clone(),
888                                component_tvl: Some(100.0),
889                                entrypoints: vec![],
890                            },
891                        )
892                    })
893                    .collect(),
894                vm_storage: HashMap::new(),
895            },
896            deltas: None,
897            removed_components: Default::default(),
898        };
899
900        let snap = state_sync
901            .get_snapshots(header, &mut tracker, Some(&components_arg))
902            .await
903            .expect("Retrieving snapshot failed");
904
905        assert_eq!(snap, exp);
906    }
907
908    fn state_snapshot_vm() -> StateRequestResponse {
909        StateRequestResponse {
910            accounts: vec![
911                ResponseAccount { address: Bytes::from("0x0badc0ffee"), ..Default::default() },
912                ResponseAccount { address: Bytes::from("0xbabe42"), ..Default::default() },
913            ],
914            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
915        }
916    }
917
918    fn traced_entry_point_response() -> TracedEntryPointRequestResponse {
919        TracedEntryPointRequestResponse {
920            traced_entry_points: HashMap::from([(
921                "Component1".to_string(),
922                vec![(
923                    EntryPointWithTracingParams {
924                        entry_point: EntryPoint {
925                            external_id: "entrypoint_a".to_string(),
926                            target: Bytes::from("0x0badc0ffee"),
927                            signature: "sig()".to_string(),
928                        },
929                        params: TracingParams::RPCTracer(RPCTracerParams {
930                            caller: Some(Bytes::from("0x0badc0ffee")),
931                            calldata: Bytes::from("0x0badc0ffee"),
932                        }),
933                    },
934                    TracingResult {
935                        retriggers: HashSet::from([(
936                            Bytes::from("0x0badc0ffee"),
937                            Bytes::from("0x0badc0ffee"),
938                        )]),
939                        accessed_slots: HashMap::from([(
940                            Bytes::from("0x0badc0ffee"),
941                            HashSet::from([Bytes::from("0xbadbeef0")]),
942                        )]),
943                    },
944                )],
945            )]),
946            pagination: PaginationResponse::new(0, 20, 0),
947        }
948    }
949
950    #[test(tokio::test)]
951    async fn test_get_snapshots_vm() {
952        let header = Header::default();
953        let mut rpc = MockRPCClient::new();
954        rpc.expect_get_protocol_states()
955            .returning(|_| Ok(state_snapshot_native()));
956        rpc.expect_get_contract_state()
957            .returning(|_| Ok(state_snapshot_vm()));
958        rpc.expect_get_traced_entry_points()
959            .returning(|_| Ok(traced_entry_point_response()));
960        let state_sync = with_mocked_clients(false, false, Some(rpc), None);
961        let mut tracker = ComponentTracker::new(
962            Chain::Ethereum,
963            "uniswap-v2",
964            ComponentFilter::with_tvl_range(0.0, 0.0),
965            state_sync.rpc_client.clone(),
966        );
967        let component = ProtocolComponent {
968            id: "Component1".to_string(),
969            contract_ids: vec![Bytes::from("0x0badc0ffee"), Bytes::from("0xbabe42")],
970            ..Default::default()
971        };
972        tracker
973            .components
974            .insert("Component1".to_string(), component.clone());
975        let components_arg = ["Component1".to_string()];
976        let exp = StateSyncMessage {
977            header: header.clone(),
978            snapshots: Snapshot {
979                states: [(
980                    component.id.clone(),
981                    ComponentWithState {
982                        state: ResponseProtocolState {
983                            component_id: "Component1".to_string(),
984                            ..Default::default()
985                        },
986                        component: component.clone(),
987                        component_tvl: None,
988                        entrypoints: vec![(
989                            EntryPointWithTracingParams {
990                                entry_point: EntryPoint {
991                                    external_id: "entrypoint_a".to_string(),
992                                    target: Bytes::from("0x0badc0ffee"),
993                                    signature: "sig()".to_string(),
994                                },
995                                params: TracingParams::RPCTracer(RPCTracerParams {
996                                    caller: Some(Bytes::from("0x0badc0ffee")),
997                                    calldata: Bytes::from("0x0badc0ffee"),
998                                }),
999                            },
1000                            TracingResult {
1001                                retriggers: HashSet::from([(
1002                                    Bytes::from("0x0badc0ffee"),
1003                                    Bytes::from("0x0badc0ffee"),
1004                                )]),
1005                                accessed_slots: HashMap::from([(
1006                                    Bytes::from("0x0badc0ffee"),
1007                                    HashSet::from([Bytes::from("0xbadbeef0")]),
1008                                )]),
1009                            },
1010                        )],
1011                    },
1012                )]
1013                .into_iter()
1014                .collect(),
1015                vm_storage: state_snapshot_vm()
1016                    .accounts
1017                    .into_iter()
1018                    .map(|state| (state.address.clone(), state))
1019                    .collect(),
1020            },
1021            deltas: None,
1022            removed_components: Default::default(),
1023        };
1024
1025        let snap = state_sync
1026            .get_snapshots(header, &mut tracker, Some(&components_arg))
1027            .await
1028            .expect("Retrieving snapshot failed");
1029
1030        assert_eq!(snap, exp);
1031    }
1032
1033    #[test(tokio::test)]
1034    async fn test_get_snapshots_vm_with_tvl() {
1035        let header = Header::default();
1036        let mut rpc = MockRPCClient::new();
1037        rpc.expect_get_protocol_states()
1038            .returning(|_| Ok(state_snapshot_native()));
1039        rpc.expect_get_contract_state()
1040            .returning(|_| Ok(state_snapshot_vm()));
1041        rpc.expect_get_component_tvl()
1042            .returning(|_| Ok(component_tvl_snapshot()));
1043        rpc.expect_get_traced_entry_points()
1044            .returning(|_| {
1045                Ok(TracedEntryPointRequestResponse {
1046                    traced_entry_points: HashMap::new(),
1047                    pagination: PaginationResponse::new(0, 20, 0),
1048                })
1049            });
1050        let state_sync = with_mocked_clients(false, true, Some(rpc), None);
1051        let mut tracker = ComponentTracker::new(
1052            Chain::Ethereum,
1053            "uniswap-v2",
1054            ComponentFilter::with_tvl_range(0.0, 0.0),
1055            state_sync.rpc_client.clone(),
1056        );
1057        let component = ProtocolComponent {
1058            id: "Component1".to_string(),
1059            contract_ids: vec![Bytes::from("0x0badc0ffee"), Bytes::from("0xbabe42")],
1060            ..Default::default()
1061        };
1062        tracker
1063            .components
1064            .insert("Component1".to_string(), component.clone());
1065        let components_arg = ["Component1".to_string()];
1066        let exp = StateSyncMessage {
1067            header: header.clone(),
1068            snapshots: Snapshot {
1069                states: [(
1070                    component.id.clone(),
1071                    ComponentWithState {
1072                        state: ResponseProtocolState {
1073                            component_id: "Component1".to_string(),
1074                            ..Default::default()
1075                        },
1076                        component: component.clone(),
1077                        component_tvl: Some(100.0),
1078                        entrypoints: vec![],
1079                    },
1080                )]
1081                .into_iter()
1082                .collect(),
1083                vm_storage: state_snapshot_vm()
1084                    .accounts
1085                    .into_iter()
1086                    .map(|state| (state.address.clone(), state))
1087                    .collect(),
1088            },
1089            deltas: None,
1090            removed_components: Default::default(),
1091        };
1092
1093        let snap = state_sync
1094            .get_snapshots(header, &mut tracker, Some(&components_arg))
1095            .await
1096            .expect("Retrieving snapshot failed");
1097
1098        assert_eq!(snap, exp);
1099    }
1100
1101    fn mock_clients_for_state_sync() -> (MockRPCClient, MockDeltasClient, Sender<BlockChanges>) {
1102        let mut rpc_client = MockRPCClient::new();
1103        // Mocks for the start_tracking call, these need to come first because they are more
1104        // specific, see: https://docs.rs/mockall/latest/mockall/#matching-multiple-calls
1105        rpc_client
1106            .expect_get_protocol_components()
1107            .with(mockall::predicate::function(
1108                move |request_params: &ProtocolComponentsRequestBody| {
1109                    if let Some(ids) = request_params.component_ids.as_ref() {
1110                        ids.contains(&"Component3".to_string())
1111                    } else {
1112                        false
1113                    }
1114                },
1115            ))
1116            .returning(|_| {
1117                // return Component3
1118                Ok(ProtocolComponentRequestResponse {
1119                    protocol_components: vec![
1120                        // this component shall have a tvl update above threshold
1121                        ProtocolComponent { id: "Component3".to_string(), ..Default::default() },
1122                    ],
1123                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1124                })
1125            });
1126        rpc_client
1127            .expect_get_protocol_states()
1128            .with(mockall::predicate::function(move |request_params: &ProtocolStateRequestBody| {
1129                let expected_id = "Component3".to_string();
1130                if let Some(ids) = request_params.protocol_ids.as_ref() {
1131                    ids.contains(&expected_id)
1132                } else {
1133                    false
1134                }
1135            }))
1136            .returning(|_| {
1137                // return Component3 state
1138                Ok(ProtocolStateRequestResponse {
1139                    states: vec![ResponseProtocolState {
1140                        component_id: "Component3".to_string(),
1141                        ..Default::default()
1142                    }],
1143                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1144                })
1145            });
1146
1147        // mock calls for the initial state snapshots
1148        rpc_client
1149            .expect_get_protocol_components()
1150            .returning(|_| {
1151                // Initial sync of components
1152                Ok(ProtocolComponentRequestResponse {
1153                    protocol_components: vec![
1154                        // this component shall have a tvl update above threshold
1155                        ProtocolComponent { id: "Component1".to_string(), ..Default::default() },
1156                        // this component shall have a tvl update below threshold.
1157                        ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1158                        // a third component will have a tvl update above threshold
1159                    ],
1160                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1161                })
1162            });
1163        rpc_client
1164            .expect_get_protocol_states()
1165            .returning(|_| {
1166                // Initial state snapshot
1167                Ok(ProtocolStateRequestResponse {
1168                    states: vec![
1169                        ResponseProtocolState {
1170                            component_id: "Component1".to_string(),
1171                            ..Default::default()
1172                        },
1173                        ResponseProtocolState {
1174                            component_id: "Component2".to_string(),
1175                            ..Default::default()
1176                        },
1177                    ],
1178                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1179                })
1180            });
1181        rpc_client
1182            .expect_get_component_tvl()
1183            .returning(|_| {
1184                Ok(ComponentTvlRequestResponse {
1185                    tvl: [
1186                        ("Component1".to_string(), 100.0),
1187                        ("Component2".to_string(), 0.0),
1188                        ("Component3".to_string(), 1000.0),
1189                    ]
1190                    .into_iter()
1191                    .collect(),
1192                    pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
1193                })
1194            });
1195        rpc_client
1196            .expect_get_traced_entry_points()
1197            .returning(|_| {
1198                Ok(TracedEntryPointRequestResponse {
1199                    traced_entry_points: HashMap::new(),
1200                    pagination: PaginationResponse::new(0, 20, 0),
1201                })
1202            });
1203
1204        // Mock deltas client and messages
1205        let mut deltas_client = MockDeltasClient::new();
1206        let (tx, rx) = channel(1);
1207        deltas_client
1208            .expect_subscribe()
1209            .return_once(move |_, _| {
1210                // Return subscriber id and a channel
1211                Ok((Uuid::default(), rx))
1212            });
1213        (rpc_client, deltas_client, tx)
1214    }
1215
1216    /// Test strategy
1217    ///
1218    /// - initial snapshot retrieval returns two component1 and component2 as snapshots
1219    /// - send 2 dummy messages, containing only blocks
1220    /// - third message contains a new component with some significant tvl, one initial component
1221    ///   slips below tvl threshold, another one is above tvl but does not get re-requested.
1222    #[test(tokio::test)]
1223    async fn test_state_sync() {
1224        let (rpc_client, deltas_client, tx) = mock_clients_for_state_sync();
1225        let deltas = [
1226            BlockChanges {
1227                extractor: "uniswap-v2".to_string(),
1228                chain: Chain::Ethereum,
1229                block: Block {
1230                    number: 1,
1231                    hash: Bytes::from("0x01"),
1232                    parent_hash: Bytes::from("0x00"),
1233                    chain: Chain::Ethereum,
1234                    ts: Default::default(),
1235                },
1236                revert: false,
1237                dci_update: DCIUpdate {
1238                    new_entrypoints: HashMap::from([(
1239                        "Component1".to_string(),
1240                        HashSet::from([EntryPoint {
1241                            external_id: "entrypoint_a".to_string(),
1242                            target: Bytes::from("0x0badc0ffee"),
1243                            signature: "sig()".to_string(),
1244                        }]),
1245                    )]),
1246                    new_entrypoint_params: HashMap::from([(
1247                        "entrypoint_a".to_string(),
1248                        HashSet::from([(
1249                            TracingParams::RPCTracer(RPCTracerParams {
1250                                caller: Some(Bytes::from("0x0badc0ffee")),
1251                                calldata: Bytes::from("0x0badc0ffee"),
1252                            }),
1253                            Some("Component1".to_string()),
1254                        )]),
1255                    )]),
1256                    trace_results: HashMap::from([(
1257                        "entrypoint_a".to_string(),
1258                        TracingResult {
1259                            retriggers: HashSet::from([(
1260                                Bytes::from("0x0badc0ffee"),
1261                                Bytes::from("0x0badc0ffee"),
1262                            )]),
1263                            accessed_slots: HashMap::from([(
1264                                Bytes::from("0x0badc0ffee"),
1265                                HashSet::from([Bytes::from("0xbadbeef0")]),
1266                            )]),
1267                        },
1268                    )]),
1269                },
1270                ..Default::default()
1271            },
1272            BlockChanges {
1273                extractor: "uniswap-v2".to_string(),
1274                chain: Chain::Ethereum,
1275                block: Block {
1276                    number: 2,
1277                    hash: Bytes::from("0x02"),
1278                    parent_hash: Bytes::from("0x01"),
1279                    chain: Chain::Ethereum,
1280                    ts: Default::default(),
1281                },
1282                revert: false,
1283                component_tvl: [
1284                    ("Component1".to_string(), 100.0),
1285                    ("Component2".to_string(), 0.0),
1286                    ("Component3".to_string(), 1000.0),
1287                ]
1288                .into_iter()
1289                .collect(),
1290                ..Default::default()
1291            },
1292        ];
1293        let mut state_sync = with_mocked_clients(true, true, Some(rpc_client), Some(deltas_client));
1294        state_sync
1295            .initialize()
1296            .await
1297            .expect("Init failed");
1298
1299        // Test starts here
1300        let (jh, mut rx) = state_sync
1301            .start()
1302            .await
1303            .expect("Failed to start state synchronizer");
1304        tx.send(deltas[0].clone())
1305            .await
1306            .expect("deltas channel msg 0 closed!");
1307        let first_msg = timeout(Duration::from_millis(100), rx.recv())
1308            .await
1309            .expect("waiting for first state msg timed out!")
1310            .expect("state sync block sender closed!");
1311        tx.send(deltas[1].clone())
1312            .await
1313            .expect("deltas channel msg 1 closed!");
1314        let second_msg = timeout(Duration::from_millis(100), rx.recv())
1315            .await
1316            .expect("waiting for second state msg timed out!")
1317            .expect("state sync block sender closed!");
1318        let _ = state_sync.close().await;
1319        let exit = jh
1320            .await
1321            .expect("state sync task panicked!");
1322
1323        // assertions
1324        let exp1 = StateSyncMessage {
1325            header: Header {
1326                number: 1,
1327                hash: Bytes::from("0x01"),
1328                parent_hash: Bytes::from("0x00"),
1329                revert: false,
1330                ..Default::default()
1331            },
1332            snapshots: Snapshot {
1333                states: [
1334                    (
1335                        "Component1".to_string(),
1336                        ComponentWithState {
1337                            state: ResponseProtocolState {
1338                                component_id: "Component1".to_string(),
1339                                ..Default::default()
1340                            },
1341                            component: ProtocolComponent {
1342                                id: "Component1".to_string(),
1343                                ..Default::default()
1344                            },
1345                            component_tvl: Some(100.0),
1346                            entrypoints: vec![],
1347                        },
1348                    ),
1349                    (
1350                        "Component2".to_string(),
1351                        ComponentWithState {
1352                            state: ResponseProtocolState {
1353                                component_id: "Component2".to_string(),
1354                                ..Default::default()
1355                            },
1356                            component: ProtocolComponent {
1357                                id: "Component2".to_string(),
1358                                ..Default::default()
1359                            },
1360                            component_tvl: Some(0.0),
1361                            entrypoints: vec![],
1362                        },
1363                    ),
1364                ]
1365                .into_iter()
1366                .collect(),
1367                vm_storage: HashMap::new(),
1368            },
1369            deltas: Some(deltas[0].clone()),
1370            removed_components: Default::default(),
1371        };
1372
1373        let exp2 = StateSyncMessage {
1374            header: Header {
1375                number: 2,
1376                hash: Bytes::from("0x02"),
1377                parent_hash: Bytes::from("0x01"),
1378                revert: false,
1379                ..Default::default()
1380            },
1381            snapshots: Snapshot {
1382                states: [
1383                    // This is the new component we queried once it passed the tvl threshold.
1384                    (
1385                        "Component3".to_string(),
1386                        ComponentWithState {
1387                            state: ResponseProtocolState {
1388                                component_id: "Component3".to_string(),
1389                                ..Default::default()
1390                            },
1391                            component: ProtocolComponent {
1392                                id: "Component3".to_string(),
1393                                ..Default::default()
1394                            },
1395                            component_tvl: Some(1000.0),
1396                            entrypoints: vec![],
1397                        },
1398                    ),
1399                ]
1400                .into_iter()
1401                .collect(),
1402                vm_storage: HashMap::new(),
1403            },
1404            // Our deltas are empty and since merge methods are
1405            // tested in tycho-common we don't have much to do here.
1406            deltas: Some(BlockChanges {
1407                extractor: "uniswap-v2".to_string(),
1408                chain: Chain::Ethereum,
1409                block: Block {
1410                    number: 2,
1411                    hash: Bytes::from("0x02"),
1412                    parent_hash: Bytes::from("0x01"),
1413                    chain: Chain::Ethereum,
1414                    ts: Default::default(),
1415                },
1416                revert: false,
1417                component_tvl: [
1418                    // "Component2" should not show here.
1419                    ("Component1".to_string(), 100.0),
1420                    ("Component3".to_string(), 1000.0),
1421                ]
1422                .into_iter()
1423                .collect(),
1424                ..Default::default()
1425            }),
1426            // "Component2" was removed, because its tvl changed to 0.
1427            removed_components: [(
1428                "Component2".to_string(),
1429                ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1430            )]
1431            .into_iter()
1432            .collect(),
1433        };
1434        assert_eq!(first_msg, exp1);
1435        assert_eq!(second_msg, exp2);
1436        assert!(exit.is_ok());
1437    }
1438
1439    #[test(tokio::test)]
1440    async fn test_state_sync_with_tvl_range() {
1441        // Define the range for testing
1442        let remove_tvl_threshold = 5.0;
1443        let add_tvl_threshold = 7.0;
1444
1445        let mut rpc_client = MockRPCClient::new();
1446        let mut deltas_client = MockDeltasClient::new();
1447
1448        rpc_client
1449            .expect_get_protocol_components()
1450            .with(mockall::predicate::function(
1451                move |request_params: &ProtocolComponentsRequestBody| {
1452                    if let Some(ids) = request_params.component_ids.as_ref() {
1453                        ids.contains(&"Component3".to_string())
1454                    } else {
1455                        false
1456                    }
1457                },
1458            ))
1459            .returning(|_| {
1460                Ok(ProtocolComponentRequestResponse {
1461                    protocol_components: vec![ProtocolComponent {
1462                        id: "Component3".to_string(),
1463                        ..Default::default()
1464                    }],
1465                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1466                })
1467            });
1468        rpc_client
1469            .expect_get_protocol_states()
1470            .with(mockall::predicate::function(move |request_params: &ProtocolStateRequestBody| {
1471                let expected_id = "Component3".to_string();
1472                if let Some(ids) = request_params.protocol_ids.as_ref() {
1473                    ids.contains(&expected_id)
1474                } else {
1475                    false
1476                }
1477            }))
1478            .returning(|_| {
1479                Ok(ProtocolStateRequestResponse {
1480                    states: vec![ResponseProtocolState {
1481                        component_id: "Component3".to_string(),
1482                        ..Default::default()
1483                    }],
1484                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1485                })
1486            });
1487
1488        // Mock for the initial snapshot retrieval
1489        rpc_client
1490            .expect_get_protocol_components()
1491            .returning(|_| {
1492                Ok(ProtocolComponentRequestResponse {
1493                    protocol_components: vec![
1494                        ProtocolComponent { id: "Component1".to_string(), ..Default::default() },
1495                        ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1496                    ],
1497                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1498                })
1499            });
1500        rpc_client
1501            .expect_get_protocol_states()
1502            .returning(|_| {
1503                Ok(ProtocolStateRequestResponse {
1504                    states: vec![
1505                        ResponseProtocolState {
1506                            component_id: "Component1".to_string(),
1507                            ..Default::default()
1508                        },
1509                        ResponseProtocolState {
1510                            component_id: "Component2".to_string(),
1511                            ..Default::default()
1512                        },
1513                    ],
1514                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1515                })
1516            });
1517        rpc_client
1518            .expect_get_traced_entry_points()
1519            .returning(|_| {
1520                Ok(TracedEntryPointRequestResponse {
1521                    traced_entry_points: HashMap::new(),
1522                    pagination: PaginationResponse::new(0, 20, 0),
1523                })
1524            });
1525
1526        rpc_client
1527            .expect_get_component_tvl()
1528            .returning(|_| {
1529                Ok(ComponentTvlRequestResponse {
1530                    tvl: [
1531                        ("Component1".to_string(), 6.0),
1532                        ("Component2".to_string(), 2.0),
1533                        ("Component3".to_string(), 10.0),
1534                    ]
1535                    .into_iter()
1536                    .collect(),
1537                    pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
1538                })
1539            });
1540
1541        rpc_client
1542            .expect_get_component_tvl()
1543            .returning(|_| {
1544                Ok(ComponentTvlRequestResponse {
1545                    tvl: [
1546                        ("Component1".to_string(), 6.0),
1547                        ("Component2".to_string(), 2.0),
1548                        ("Component3".to_string(), 10.0),
1549                    ]
1550                    .into_iter()
1551                    .collect(),
1552                    pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
1553                })
1554            });
1555
1556        let (tx, rx) = channel(1);
1557        deltas_client
1558            .expect_subscribe()
1559            .return_once(move |_, _| Ok((Uuid::default(), rx)));
1560
1561        let mut state_sync = ProtocolStateSynchronizer::new(
1562            ExtractorIdentity::new(Chain::Ethereum, "uniswap-v2"),
1563            true,
1564            ComponentFilter::with_tvl_range(remove_tvl_threshold, add_tvl_threshold),
1565            1,
1566            true,
1567            true,
1568            ArcRPCClient(Arc::new(rpc_client)),
1569            ArcDeltasClient(Arc::new(deltas_client)),
1570            10_u64,
1571        );
1572        state_sync
1573            .initialize()
1574            .await
1575            .expect("Init failed");
1576
1577        // Simulate the incoming BlockChanges
1578        let deltas = [
1579            BlockChanges {
1580                extractor: "uniswap-v2".to_string(),
1581                chain: Chain::Ethereum,
1582                block: Block {
1583                    number: 1,
1584                    hash: Bytes::from("0x01"),
1585                    parent_hash: Bytes::from("0x00"),
1586                    chain: Chain::Ethereum,
1587                    ts: Default::default(),
1588                },
1589                revert: false,
1590                ..Default::default()
1591            },
1592            BlockChanges {
1593                extractor: "uniswap-v2".to_string(),
1594                chain: Chain::Ethereum,
1595                block: Block {
1596                    number: 2,
1597                    hash: Bytes::from("0x02"),
1598                    parent_hash: Bytes::from("0x01"),
1599                    chain: Chain::Ethereum,
1600                    ts: Default::default(),
1601                },
1602                revert: false,
1603                component_tvl: [
1604                    ("Component1".to_string(), 6.0), // Within range, should not trigger changes
1605                    ("Component2".to_string(), 2.0), // Below lower threshold, should be removed
1606                    ("Component3".to_string(), 10.0), // Above upper threshold, should be added
1607                ]
1608                .into_iter()
1609                .collect(),
1610                ..Default::default()
1611            },
1612        ];
1613
1614        let (jh, mut rx) = state_sync
1615            .start()
1616            .await
1617            .expect("Failed to start state synchronizer");
1618
1619        // Simulate sending delta messages
1620        tx.send(deltas[0].clone())
1621            .await
1622            .expect("deltas channel msg 0 closed!");
1623
1624        // Expecting to receive the initial state message
1625        let _ = timeout(Duration::from_millis(100), rx.recv())
1626            .await
1627            .expect("waiting for first state msg timed out!")
1628            .expect("state sync block sender closed!");
1629
1630        // Send the third message, which should trigger TVL-based changes
1631        tx.send(deltas[1].clone())
1632            .await
1633            .expect("deltas channel msg 1 closed!");
1634        let second_msg = timeout(Duration::from_millis(100), rx.recv())
1635            .await
1636            .expect("waiting for second state msg timed out!")
1637            .expect("state sync block sender closed!");
1638
1639        let _ = state_sync.close().await;
1640        let exit = jh
1641            .await
1642            .expect("state sync task panicked!");
1643
1644        let expected_second_msg = StateSyncMessage {
1645            header: Header {
1646                number: 2,
1647                hash: Bytes::from("0x02"),
1648                parent_hash: Bytes::from("0x01"),
1649                revert: false,
1650                ..Default::default()
1651            },
1652            snapshots: Snapshot {
1653                states: [(
1654                    "Component3".to_string(),
1655                    ComponentWithState {
1656                        state: ResponseProtocolState {
1657                            component_id: "Component3".to_string(),
1658                            ..Default::default()
1659                        },
1660                        component: ProtocolComponent {
1661                            id: "Component3".to_string(),
1662                            ..Default::default()
1663                        },
1664                        component_tvl: Some(10.0),
1665                        entrypoints: vec![], // TODO: add entrypoints?
1666                    },
1667                )]
1668                .into_iter()
1669                .collect(),
1670                vm_storage: HashMap::new(),
1671            },
1672            deltas: Some(BlockChanges {
1673                extractor: "uniswap-v2".to_string(),
1674                chain: Chain::Ethereum,
1675                block: Block {
1676                    number: 2,
1677                    hash: Bytes::from("0x02"),
1678                    parent_hash: Bytes::from("0x01"),
1679                    chain: Chain::Ethereum,
1680                    ts: Default::default(),
1681                },
1682                revert: false,
1683                component_tvl: [
1684                    ("Component1".to_string(), 6.0), // Within range, should not trigger changes
1685                    ("Component3".to_string(), 10.0), // Above upper threshold, should be added
1686                ]
1687                .into_iter()
1688                .collect(),
1689                ..Default::default()
1690            }),
1691            removed_components: [(
1692                "Component2".to_string(),
1693                ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1694            )]
1695            .into_iter()
1696            .collect(),
1697        };
1698
1699        assert_eq!(second_msg, expected_second_msg);
1700        assert!(exit.is_ok());
1701    }
1702}