tycho_client/feed/synchronizer.rs

1use std::{collections::HashMap, sync::Arc, time::Duration};
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6use tokio::{
7    select,
8    sync::{
9        mpsc::{channel, error::SendError, Receiver, Sender},
10        oneshot, Mutex,
11    },
12    task::JoinHandle,
13    time::timeout,
14};
15use tracing::{debug, error, info, instrument, trace, warn};
16use tycho_common::{
17    dto::{
18        BlockChanges, BlockParam, Chain, ComponentTvlRequestBody, EntryPointWithTracingParams,
19        ExtractorIdentity, ProtocolComponent, ResponseAccount, ResponseProtocolState,
20        TracingResult, VersionParam,
21    },
22    Bytes,
23};
24
25use crate::{
26    deltas::{DeltasClient, SubscriptionOptions},
27    feed::{
28        component_tracker::{ComponentFilter, ComponentTracker},
29        BlockHeader, HeaderLike,
30    },
31    rpc::{RPCClient, RPCError},
32    DeltasError,
33};
34
35#[derive(Error, Debug)]
36pub enum SynchronizerError {
37    /// RPC client failures.
38    #[error("RPC error: {0}")]
39    RPCError(#[from] RPCError),
40
41    /// Failed to send channel message to the consumer.
42    #[error("Failed to send channel message: {0}")]
43    ChannelError(String),
44
45    /// Timeout elapsed errors.
46    #[error("Timeout error: {0}")]
47    Timeout(String),
48
49    /// Failed to close the synchronizer.
50    #[error("Failed to close synchronizer: {0}")]
51    CloseError(String),
52
53    /// Server connection failures or interruptions.
54    #[error("Connection error: {0}")]
55    ConnectionError(String),
56
57    /// Connection closed
58    #[error("Connection closed")]
59    ConnectionClosed,
60}
61
62pub type SyncResult<T> = Result<T, SynchronizerError>;
63
64impl From<SendError<StateSyncMessage<BlockHeader>>> for SynchronizerError {
65    fn from(err: SendError<StateSyncMessage<BlockHeader>>) -> Self {
66        SynchronizerError::ChannelError(err.to_string())
67    }
68}
69
70impl From<DeltasError> for SynchronizerError {
71    fn from(err: DeltasError) -> Self {
72        match err {
73            DeltasError::NotConnected => SynchronizerError::ConnectionClosed,
74            _ => SynchronizerError::ConnectionError(err.to_string()),
75        }
76    }
77}
78
79#[derive(Clone)]
80pub struct ProtocolStateSynchronizer<R: RPCClient, D: DeltasClient> {
81    extractor_id: ExtractorIdentity,
82    retrieve_balances: bool,
83    rpc_client: R,
84    deltas_client: D,
85    max_retries: u64,
86    include_snapshots: bool,
87    component_tracker: Arc<Mutex<ComponentTracker<R>>>,
88    shared: Arc<Mutex<SharedState>>,
89    end_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
90    timeout: u64,
91    include_tvl: bool,
92}
93
94#[derive(Debug, Default)]
95struct SharedState {
96    last_synced_block: Option<BlockHeader>,
97}
98
99#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
100pub struct ComponentWithState {
101    pub state: ResponseProtocolState,
102    pub component: ProtocolComponent,
103    pub component_tvl: Option<f64>,
104    pub entrypoints: Vec<(EntryPointWithTracingParams, TracingResult)>,
105}
106
107#[derive(Clone, PartialEq, Debug, Default, Serialize, Deserialize)]
108pub struct Snapshot {
109    pub states: HashMap<String, ComponentWithState>,
110    pub vm_storage: HashMap<Bytes, ResponseAccount>,
111}
112
113impl Snapshot {
114    fn extend(&mut self, other: Snapshot) {
115        self.states.extend(other.states);
116        self.vm_storage.extend(other.vm_storage);
117    }
118
119    pub fn get_states(&self) -> &HashMap<String, ComponentWithState> {
120        &self.states
121    }
122
123    pub fn get_vm_storage(&self) -> &HashMap<Bytes, ResponseAccount> {
124        &self.vm_storage
125    }
126}
127
128#[derive(Clone, PartialEq, Debug, Default, Serialize, Deserialize)]
129pub struct StateSyncMessage<H>
130where
131    H: HeaderLike,
132{
133    /// The block information for this update.
134    pub header: H,
135    /// Snapshot for new components.
136    pub snapshots: Snapshot,
137    /// A single delta contains state updates for all tracked components, as well as additional
138    /// information about the system components e.g. newly added components (even below tvl), tvl
139    /// updates, balance updates.
140    pub deltas: Option<BlockChanges>,
141    /// Components that stopped being tracked.
142    pub removed_components: HashMap<String, ProtocolComponent>,
143}
144
145impl<H> StateSyncMessage<H>
146where
147    H: HeaderLike,
148{
149    pub fn merge(mut self, other: Self) -> Self {
150        // be careful with removed and snapshots attributes here, these can be ambiguous.
151        self.removed_components
152            .retain(|k, _| !other.snapshots.states.contains_key(k));
153        self.snapshots
154            .states
155            .retain(|k, _| !other.removed_components.contains_key(k));
156
157        self.snapshots.extend(other.snapshots);
158        let deltas = match (self.deltas, other.deltas) {
159            (Some(l), Some(r)) => Some(l.merge(r)),
160            (None, Some(r)) => Some(r),
161            (Some(l), None) => Some(l),
162            (None, None) => None,
163        };
164        self.removed_components
165            .extend(other.removed_components);
166        Self {
167            header: other.header,
168            snapshots: self.snapshots,
169            deltas,
170            removed_components: self.removed_components,
171        }
172    }
173}
174
175/// StateSynchronizer
176///
177/// Used to synchronize the state of a single protocol. The synchronizer is responsible for
178/// delivering messages to the client that let him reconstruct subsets of the protocol state.
179///
180/// This involves deciding which components to track according to the clients preferences,
181/// retrieving & emitting snapshots of components which the client has not seen yet and subsequently
182/// delivering delta messages for the components that have changed.
183#[async_trait]
184pub trait StateSynchronizer: Send + Sync + 'static {
185    async fn initialize(&self) -> SyncResult<()>;
186    /// Starts the state synchronization.
187    async fn start(
188        &self,
189    ) -> SyncResult<(JoinHandle<SyncResult<()>>, Receiver<StateSyncMessage<BlockHeader>>)>;
190    /// Ends the synchronization loop.
191    async fn close(&mut self) -> SyncResult<()>;
192}
193
194impl<R, D> ProtocolStateSynchronizer<R, D>
195where
196    // TODO: Consider moving these constraints directly to the
197    // client...
198    R: RPCClient + Clone + Send + Sync + 'static,
199    D: DeltasClient + Clone + Send + Sync + 'static,
200{
201    /// Creates a new state synchronizer.
202    #[allow(clippy::too_many_arguments)]
203    pub fn new(
204        extractor_id: ExtractorIdentity,
205        retrieve_balances: bool,
206        component_filter: ComponentFilter,
207        max_retries: u64,
208        include_snapshots: bool,
209        include_tvl: bool,
210        rpc_client: R,
211        deltas_client: D,
212        timeout: u64,
213    ) -> Self {
214        Self {
215            extractor_id: extractor_id.clone(),
216            retrieve_balances,
217            rpc_client: rpc_client.clone(),
218            include_snapshots,
219            deltas_client,
220            component_tracker: Arc::new(Mutex::new(ComponentTracker::new(
221                extractor_id.chain,
222                extractor_id.name.as_str(),
223                component_filter,
224                rpc_client,
225            ))),
226            max_retries,
227            shared: Arc::new(Mutex::new(SharedState::default())),
228            end_tx: Arc::new(Mutex::new(None)),
229            timeout,
230            include_tvl,
231        }
232    }
233
234    /// Retrieves state snapshots of the requested components
235    #[allow(deprecated)]
236    async fn get_snapshots<'a, I: IntoIterator<Item = &'a String>>(
237        &self,
238        header: BlockHeader,
239        tracked_components: &mut ComponentTracker<R>,
240        ids: Option<I>,
241    ) -> SyncResult<StateSyncMessage<BlockHeader>> {
242        if !self.include_snapshots {
243            return Ok(StateSyncMessage { header, ..Default::default() });
244        }
245        let version = VersionParam::new(
246            None,
247            Some(BlockParam {
248                chain: Some(self.extractor_id.chain),
249                hash: None,
250                number: Some(header.number as i64),
251            }),
252        );
253
254        // Use given ids or use all if not passed
255        let component_ids: Vec<_> = match ids {
256            Some(ids) => ids.into_iter().cloned().collect(),
257            None => tracked_components.get_tracked_component_ids(),
258        };
259
260        if component_ids.is_empty() {
261            return Ok(StateSyncMessage { header, ..Default::default() });
262        }
263
264        let component_tvl = if self.include_tvl {
265            let body = ComponentTvlRequestBody::id_filtered(
266                component_ids.clone(),
267                self.extractor_id.chain,
268            );
269            self.rpc_client
270                .get_component_tvl_paginated(&body, 100, 4)
271                .await?
272                .tvl
273        } else {
274            HashMap::new()
275        };
276
277        //TODO: Improve this, we should not query for every component, but only for the ones that
278        // could have entrypoints. Maybe apply a filter per protocol?
279        let entrypoints_result = if self.extractor_id.chain == Chain::Ethereum {
280            // Fetch entrypoints
281            let result = self
282                .rpc_client
283                .get_traced_entry_points_paginated(
284                    self.extractor_id.chain,
285                    &self.extractor_id.name,
286                    &component_ids,
287                    100,
288                    4,
289                )
290                .await?;
291            tracked_components.process_entrypoints(&result.clone().into())?;
292            Some(result)
293        } else {
294            None
295        };
296
297        // Fetch protocol states
298        let mut protocol_states = self
299            .rpc_client
300            .get_protocol_states_paginated(
301                self.extractor_id.chain,
302                &component_ids,
303                &self.extractor_id.name,
304                self.retrieve_balances,
305                &version,
306                100,
307                4,
308            )
309            .await?
310            .states
311            .into_iter()
312            .map(|state| (state.component_id.clone(), state))
313            .collect::<HashMap<_, _>>();
314
315        trace!(states=?&protocol_states, "Retrieved ProtocolStates");
316        let states = tracked_components
317            .components
318            .values()
319            .filter_map(|component| {
320                if let Some(state) = protocol_states.remove(&component.id) {
321                    Some((
322                        component.id.clone(),
323                        ComponentWithState {
324                            state,
325                            component: component.clone(),
326                            component_tvl: component_tvl
327                                .get(&component.id)
328                                .cloned(),
329                            entrypoints: entrypoints_result
330                                .as_ref()
331                                .map(|r| {
332                                    r.traced_entry_points
333                                        .get(&component.id)
334                                        .cloned()
335                                        .unwrap_or_default()
336                                })
337                                .unwrap_or_default(),
338                        },
339                    ))
340                } else if component_ids.contains(&component.id) {
341                    // only emit error event if we requested this component
342                    let component_id = &component.id;
343                    error!(?component_id, "Missing state for native component!");
344                    None
345                } else {
346                    None
347                }
348            })
349            .collect();
350
351        // Fetch contract states
352        let contract_ids = tracked_components.get_contracts_by_component(&component_ids);
353        let vm_storage = if !contract_ids.is_empty() {
354            let ids: Vec<Bytes> = contract_ids
355                .clone()
356                .into_iter()
357                .collect();
358            let contract_states = self
359                .rpc_client
360                .get_contract_state_paginated(
361                    self.extractor_id.chain,
362                    ids.as_slice(),
363                    &self.extractor_id.name,
364                    &version,
365                    100,
366                    4,
367                )
368                .await?
369                .accounts
370                .into_iter()
371                .map(|acc| (acc.address.clone(), acc))
372                .collect::<HashMap<_, _>>();
373
374            trace!(states=?&contract_states, "Retrieved ContractState");
375
376            let contract_address_to_components = tracked_components
377                .components
378                .iter()
379                .filter_map(|(id, comp)| {
380                    if component_ids.contains(id) {
381                        Some(
382                            comp.contract_ids
383                                .iter()
384                                .map(|address| (address.clone(), comp.id.clone())),
385                        )
386                    } else {
387                        None
388                    }
389                })
390                .flatten()
391                .fold(HashMap::<Bytes, Vec<String>>::new(), |mut acc, (addr, c_id)| {
392                    acc.entry(addr).or_default().push(c_id);
393                    acc
394                });
395
396            contract_ids
397                .iter()
398                .filter_map(|address| {
399                    if let Some(state) = contract_states.get(address) {
400                        Some((address.clone(), state.clone()))
401                    } else if let Some(ids) = contract_address_to_components.get(address) {
402                        // only emit error even if we did actually request this address
403                        error!(
404                            ?address,
405                            ?ids,
406                            "Component with lacking contract storage encountered!"
407                        );
408                        None
409                    } else {
410                        None
411                    }
412                })
413                .collect()
414        } else {
415            HashMap::new()
416        };
417
418        Ok(StateSyncMessage {
419            header,
420            snapshots: Snapshot { states, vm_storage },
421            deltas: None,
422            removed_components: HashMap::new(),
423        })
424    }
425
426    /// Main method that does all the work.
427    #[instrument(skip(self, block_tx), fields(extractor_id = %self.extractor_id))]
428    async fn state_sync(
429        self,
430        block_tx: &mut Sender<StateSyncMessage<BlockHeader>>,
431    ) -> SyncResult<()> {
432        // initialisation
433        let mut tracker = self.component_tracker.lock().await;
434
435        let subscription_options = SubscriptionOptions::new().with_state(self.include_snapshots);
436        let (_, mut msg_rx) = self
437            .deltas_client
438            .subscribe(self.extractor_id.clone(), subscription_options)
439            .await?;
440
441        info!("Waiting for deltas...");
442        // wait for first deltas message
443        let mut first_msg = timeout(Duration::from_secs(self.timeout), msg_rx.recv())
444            .await
445            .map_err(|_| {
446                SynchronizerError::Timeout(format!(
447                    "First deltas took longer than {t}s to arrive",
448                    t = self.timeout
449                ))
450            })?
451            .ok_or_else(|| {
452                SynchronizerError::ConnectionError(
453                    "Deltas channel closed before first message".to_string(),
454                )
455            })?;
456        self.filter_deltas(&mut first_msg, &tracker);
457
458        // initial snapshot
459        let block = first_msg.get_block().clone();
460        info!(height = &block.number, "Deltas received. Retrieving snapshot");
461        let header = BlockHeader::from_block(first_msg.get_block(), first_msg.is_revert());
462        let snapshot = self
463            .get_snapshots::<Vec<&String>>(
464                BlockHeader::from_block(&block, false),
465                &mut tracker,
466                None,
467            )
468            .await?
469            .merge(StateSyncMessage {
470                header: BlockHeader::from_block(first_msg.get_block(), first_msg.is_revert()),
471                snapshots: Default::default(),
472                deltas: Some(first_msg),
473                removed_components: Default::default(),
474            });
475
476        let n_components = tracker.components.len();
477        let n_snapshots = snapshot.snapshots.states.len();
478        info!(n_components, n_snapshots, "Initial snapshot retrieved, starting delta message feed");
479
480        {
481            let mut shared = self.shared.lock().await;
482            block_tx.send(snapshot).await?;
483            shared.last_synced_block = Some(header.clone());
484        }
485
486        loop {
487            if let Some(mut deltas) = msg_rx.recv().await {
488                let header = BlockHeader::from_block(deltas.get_block(), deltas.is_revert());
489                debug!(block_number=?header.number, "Received delta message");
490
491                let (snapshots, removed_components) = {
492                    // 1. Remove components based on latest changes
493                    // 2. Add components based on latest changes, query those for snapshots
494                    let (to_add, to_remove) = tracker.filter_updated_components(&deltas);
495
496                    // Only components we don't track yet need a snapshot,
497                    let requiring_snapshot: Vec<_> = to_add
498                        .iter()
499                        .filter(|id| {
500                            !tracker
501                                .components
502                                .contains_key(id.as_str())
503                        })
504                        .collect();
505                    debug!(components=?requiring_snapshot, "SnapshotRequest");
506                    tracker
507                        .start_tracking(requiring_snapshot.as_slice())
508                        .await?;
509                    let snapshots = self
510                        .get_snapshots(header.clone(), &mut tracker, Some(requiring_snapshot))
511                        .await?
512                        .snapshots;
513
514                    let removed_components = if !to_remove.is_empty() {
515                        tracker.stop_tracking(&to_remove)
516                    } else {
517                        Default::default()
518                    };
519
520                    (snapshots, removed_components)
521                };
522
523                // 3. Update entrypoints on the tracker (affects which contracts are tracked)
524                tracker.process_entrypoints(&deltas.dci_update)?;
525
526                // 4. Filter deltas by currently tracked components / contracts
527                self.filter_deltas(&mut deltas, &tracker);
528                let n_changes = deltas.n_changes();
529
530                // 5. Send the message
531                let next = StateSyncMessage {
532                    header: header.clone(),
533                    snapshots,
534                    deltas: Some(deltas),
535                    removed_components,
536                };
537                block_tx.send(next).await?;
538                {
539                    let mut shared = self.shared.lock().await;
540                    shared.last_synced_block = Some(header.clone());
541                }
542
543                debug!(block_number=?header.number, n_changes, "Finished processing delta message");
544            } else {
545                let mut shared = self.shared.lock().await;
546                warn!(shared = ?&shared, "Deltas channel closed, resetting shared state.");
547                shared.last_synced_block = None;
548
549                return Err(SynchronizerError::ConnectionError("Deltas channel closed".to_string()));
550            }
551        }
552    }
553
554    fn filter_deltas(&self, second_msg: &mut BlockChanges, tracker: &ComponentTracker<R>) {
555        second_msg.filter_by_component(|id| tracker.components.contains_key(id));
556        second_msg.filter_by_contract(|id| tracker.contracts.contains(id));
557    }
558}
559
560#[async_trait]
561impl<R, D> StateSynchronizer for ProtocolStateSynchronizer<R, D>
562where
563    R: RPCClient + Clone + Send + Sync + 'static,
564    D: DeltasClient + Clone + Send + Sync + 'static,
565{
566    async fn initialize(&self) -> SyncResult<()> {
567        let mut tracker = self.component_tracker.lock().await;
568        info!("Retrieving relevant protocol components");
569        tracker.initialise_components().await?;
570        info!(
571            n_components = tracker.components.len(),
572            n_contracts = tracker.contracts.len(),
573            "Finished retrieving components",
574        );
575
576        Ok(())
577    }
578
579    async fn start(
580        &self,
581    ) -> SyncResult<(JoinHandle<SyncResult<()>>, Receiver<StateSyncMessage<BlockHeader>>)> {
582        let (mut tx, rx) = channel(15);
583
584        let this = self.clone();
585        let jh = tokio::spawn(async move {
586            let mut retry_count = 0;
587            while retry_count < this.max_retries {
588                info!(extractor_id=%&this.extractor_id, retry_count, "(Re)starting synchronization loop");
589                let (end_tx, end_rx) = oneshot::channel::<()>();
590                {
591                    let mut end_tx_guard = this.end_tx.lock().await;
592                    *end_tx_guard = Some(end_tx);
593                }
594
595                select! {
596                    res = this.clone().state_sync(&mut tx) => {
597                        match res {
598                            Err(e) => {
599                                error!(
600                                    extractor_id=%&this.extractor_id,
601                                    retry_count,
602                                    error=%e,
603                                    "State synchronization errored!"
604                                );
605                                if let SynchronizerError::ConnectionClosed = e {
606                                    // break synchronization loop if connection is closed
607                                    return Err(e);
608                                }
609                            }
610                            _ => {
611                                warn!(
612                                    extractor_id=%&this.extractor_id,
613                                    retry_count,
614                                    "State synchronization exited with Ok(())"
615                                );
616                            }
617                        }
618                    },
619                    _ = end_rx => {
620                        info!(
621                            extractor_id=%&this.extractor_id,
622                            retry_count,
623                            "StateSynchronizer received close signal. Stopping"
624                        );
625                        return Ok(())
626                    }
627                }
628                retry_count += 1;
629            }
630            Err(SynchronizerError::ConnectionError("Max connection retries exceeded".to_string()))
631        });
632
633        Ok((jh, rx))
634    }
635
636    async fn close(&mut self) -> SyncResult<()> {
637        let mut end_tx = self.end_tx.lock().await;
638        if let Some(tx) = end_tx.take() {
639            let _ = tx.send(());
640            Ok(())
641        } else {
642            Err(SynchronizerError::CloseError("Synchronizer not started".to_string()))
643        }
644    }
645}
646
647#[cfg(test)]
648mod test {
649    use std::collections::HashSet;
650
651    use test_log::test;
652    use tycho_common::dto::{
653        Block, Chain, ComponentTvlRequestBody, ComponentTvlRequestResponse, DCIUpdate, EntryPoint,
654        PaginationResponse, ProtocolComponentRequestResponse, ProtocolComponentsRequestBody,
655        ProtocolStateRequestBody, ProtocolStateRequestResponse, ProtocolSystemsRequestBody,
656        ProtocolSystemsRequestResponse, RPCTracerParams, StateRequestBody, StateRequestResponse,
657        TokensRequestBody, TokensRequestResponse, TracedEntryPointRequestBody,
658        TracedEntryPointRequestResponse, TracingParams,
659    };
660    use uuid::Uuid;
661
662    use super::*;
663    use crate::{deltas::MockDeltasClient, rpc::MockRPCClient, DeltasError, RPCError};
664
/// Shares a (non-`Clone`) mock RPC client behind an `Arc`, so it can be used where the
/// synchronizer requires `R: Clone`.
struct ArcRPCClient<T>(Arc<T>);

// Written by hand: `#[derive(Clone)]` would add an unwanted `T: Clone` bound.
impl<T> Clone for ArcRPCClient<T> {
    fn clone(&self) -> Self {
        Self(Arc::clone(&self.0))
    }
}
674
675    #[async_trait]
676    impl<T> RPCClient for ArcRPCClient<T>
677    where
678        T: RPCClient + Sync + Send + 'static,
679    {
680        async fn get_tokens(
681            &self,
682            request: &TokensRequestBody,
683        ) -> Result<TokensRequestResponse, RPCError> {
684            self.0.get_tokens(request).await
685        }
686
687        async fn get_contract_state(
688            &self,
689            request: &StateRequestBody,
690        ) -> Result<StateRequestResponse, RPCError> {
691            self.0.get_contract_state(request).await
692        }
693
694        async fn get_protocol_components(
695            &self,
696            request: &ProtocolComponentsRequestBody,
697        ) -> Result<ProtocolComponentRequestResponse, RPCError> {
698            self.0
699                .get_protocol_components(request)
700                .await
701        }
702
703        async fn get_protocol_states(
704            &self,
705            request: &ProtocolStateRequestBody,
706        ) -> Result<ProtocolStateRequestResponse, RPCError> {
707            self.0
708                .get_protocol_states(request)
709                .await
710        }
711
712        async fn get_protocol_systems(
713            &self,
714            request: &ProtocolSystemsRequestBody,
715        ) -> Result<ProtocolSystemsRequestResponse, RPCError> {
716            self.0
717                .get_protocol_systems(request)
718                .await
719        }
720
721        async fn get_component_tvl(
722            &self,
723            request: &ComponentTvlRequestBody,
724        ) -> Result<ComponentTvlRequestResponse, RPCError> {
725            self.0.get_component_tvl(request).await
726        }
727
728        async fn get_traced_entry_points(
729            &self,
730            request: &TracedEntryPointRequestBody,
731        ) -> Result<TracedEntryPointRequestResponse, RPCError> {
732            self.0
733                .get_traced_entry_points(request)
734                .await
735        }
736    }
737
/// Shares a (non-`Clone`) mock deltas client behind an `Arc`, so it can be used where the
/// synchronizer requires `D: Clone`.
struct ArcDeltasClient<T>(Arc<T>);

// Written by hand: `#[derive(Clone)]` would add an unwanted `T: Clone` bound.
impl<T> Clone for ArcDeltasClient<T> {
    fn clone(&self) -> Self {
        Self(Arc::clone(&self.0))
    }
}
747
748    #[async_trait]
749    impl<T> DeltasClient for ArcDeltasClient<T>
750    where
751        T: DeltasClient + Sync + Send + 'static,
752    {
753        async fn subscribe(
754            &self,
755            extractor_id: ExtractorIdentity,
756            options: SubscriptionOptions,
757        ) -> Result<(Uuid, Receiver<BlockChanges>), DeltasError> {
758            self.0
759                .subscribe(extractor_id, options)
760                .await
761        }
762
763        async fn unsubscribe(&self, subscription_id: Uuid) -> Result<(), DeltasError> {
764            self.0
765                .unsubscribe(subscription_id)
766                .await
767        }
768
769        async fn connect(&self) -> Result<JoinHandle<Result<(), DeltasError>>, DeltasError> {
770            self.0.connect().await
771        }
772
773        async fn close(&self) -> Result<(), DeltasError> {
774            self.0.close().await
775        }
776    }
777
778    fn with_mocked_clients(
779        native: bool,
780        include_tvl: bool,
781        rpc_client: Option<MockRPCClient>,
782        deltas_client: Option<MockDeltasClient>,
783    ) -> ProtocolStateSynchronizer<ArcRPCClient<MockRPCClient>, ArcDeltasClient<MockDeltasClient>>
784    {
785        let rpc_client = ArcRPCClient(Arc::new(rpc_client.unwrap_or_default()));
786        let deltas_client = ArcDeltasClient(Arc::new(deltas_client.unwrap_or_default()));
787
788        ProtocolStateSynchronizer::new(
789            ExtractorIdentity::new(Chain::Ethereum, "uniswap-v2"),
790            native,
791            ComponentFilter::with_tvl_range(50.0, 50.0),
792            1,
793            true,
794            include_tvl,
795            rpc_client,
796            deltas_client,
797            10_u64,
798        )
799    }
800
801    fn state_snapshot_native() -> ProtocolStateRequestResponse {
802        ProtocolStateRequestResponse {
803            states: vec![ResponseProtocolState {
804                component_id: "Component1".to_string(),
805                ..Default::default()
806            }],
807            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
808        }
809    }
810
811    fn component_tvl_snapshot() -> ComponentTvlRequestResponse {
812        let tvl = HashMap::from([("Component1".to_string(), 100.0)]);
813
814        ComponentTvlRequestResponse {
815            tvl,
816            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
817        }
818    }
819
820    #[test(tokio::test)]
821    async fn test_get_snapshots_native() {
822        let header = BlockHeader::default();
823        let mut rpc = MockRPCClient::new();
824        rpc.expect_get_protocol_states()
825            .returning(|_| Ok(state_snapshot_native()));
826        rpc.expect_get_traced_entry_points()
827            .returning(|_| {
828                Ok(TracedEntryPointRequestResponse {
829                    traced_entry_points: HashMap::new(),
830                    pagination: PaginationResponse::new(0, 20, 0),
831                })
832            });
833        let state_sync = with_mocked_clients(true, false, Some(rpc), None);
834        let mut tracker = ComponentTracker::new(
835            Chain::Ethereum,
836            "uniswap-v2",
837            ComponentFilter::with_tvl_range(0.0, 0.0),
838            state_sync.rpc_client.clone(),
839        );
840        let component = ProtocolComponent { id: "Component1".to_string(), ..Default::default() };
841        tracker
842            .components
843            .insert("Component1".to_string(), component.clone());
844        let components_arg = ["Component1".to_string()];
845        let exp = StateSyncMessage {
846            header: header.clone(),
847            snapshots: Snapshot {
848                states: state_snapshot_native()
849                    .states
850                    .into_iter()
851                    .map(|state| {
852                        (
853                            state.component_id.clone(),
854                            ComponentWithState {
855                                state,
856                                component: component.clone(),
857                                entrypoints: vec![],
858                                component_tvl: None,
859                            },
860                        )
861                    })
862                    .collect(),
863                vm_storage: HashMap::new(),
864            },
865            deltas: None,
866            removed_components: Default::default(),
867        };
868
869        let snap = state_sync
870            .get_snapshots(header, &mut tracker, Some(&components_arg))
871            .await
872            .expect("Retrieving snapshot failed");
873
874        assert_eq!(snap, exp);
875    }
876
877    #[test(tokio::test)]
878    async fn test_get_snapshots_native_with_tvl() {
879        let header = BlockHeader::default();
880        let mut rpc = MockRPCClient::new();
881        rpc.expect_get_protocol_states()
882            .returning(|_| Ok(state_snapshot_native()));
883        rpc.expect_get_component_tvl()
884            .returning(|_| Ok(component_tvl_snapshot()));
885        rpc.expect_get_traced_entry_points()
886            .returning(|_| {
887                Ok(TracedEntryPointRequestResponse {
888                    traced_entry_points: HashMap::new(),
889                    pagination: PaginationResponse::new(0, 20, 0),
890                })
891            });
892        let state_sync = with_mocked_clients(true, true, Some(rpc), None);
893        let mut tracker = ComponentTracker::new(
894            Chain::Ethereum,
895            "uniswap-v2",
896            ComponentFilter::with_tvl_range(0.0, 0.0),
897            state_sync.rpc_client.clone(),
898        );
899        let component = ProtocolComponent { id: "Component1".to_string(), ..Default::default() };
900        tracker
901            .components
902            .insert("Component1".to_string(), component.clone());
903        let components_arg = ["Component1".to_string()];
904        let exp = StateSyncMessage {
905            header: header.clone(),
906            snapshots: Snapshot {
907                states: state_snapshot_native()
908                    .states
909                    .into_iter()
910                    .map(|state| {
911                        (
912                            state.component_id.clone(),
913                            ComponentWithState {
914                                state,
915                                component: component.clone(),
916                                component_tvl: Some(100.0),
917                                entrypoints: vec![],
918                            },
919                        )
920                    })
921                    .collect(),
922                vm_storage: HashMap::new(),
923            },
924            deltas: None,
925            removed_components: Default::default(),
926        };
927
928        let snap = state_sync
929            .get_snapshots(header, &mut tracker, Some(&components_arg))
930            .await
931            .expect("Retrieving snapshot failed");
932
933        assert_eq!(snap, exp);
934    }
935
936    fn state_snapshot_vm() -> StateRequestResponse {
937        StateRequestResponse {
938            accounts: vec![
939                ResponseAccount { address: Bytes::from("0x0badc0ffee"), ..Default::default() },
940                ResponseAccount { address: Bytes::from("0xbabe42"), ..Default::default() },
941            ],
942            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
943        }
944    }
945
946    fn traced_entry_point_response() -> TracedEntryPointRequestResponse {
947        TracedEntryPointRequestResponse {
948            traced_entry_points: HashMap::from([(
949                "Component1".to_string(),
950                vec![(
951                    EntryPointWithTracingParams {
952                        entry_point: EntryPoint {
953                            external_id: "entrypoint_a".to_string(),
954                            target: Bytes::from("0x0badc0ffee"),
955                            signature: "sig()".to_string(),
956                        },
957                        params: TracingParams::RPCTracer(RPCTracerParams {
958                            caller: Some(Bytes::from("0x0badc0ffee")),
959                            calldata: Bytes::from("0x0badc0ffee"),
960                        }),
961                    },
962                    TracingResult {
963                        retriggers: HashSet::from([(
964                            Bytes::from("0x0badc0ffee"),
965                            Bytes::from("0x0badc0ffee"),
966                        )]),
967                        accessed_slots: HashMap::from([(
968                            Bytes::from("0x0badc0ffee"),
969                            HashSet::from([Bytes::from("0xbadbeef0")]),
970                        )]),
971                    },
972                )],
973            )]),
974            pagination: PaginationResponse::new(0, 20, 0),
975        }
976    }
977
978    #[test(tokio::test)]
979    async fn test_get_snapshots_vm() {
980        let header = BlockHeader::default();
981        let mut rpc = MockRPCClient::new();
982        rpc.expect_get_protocol_states()
983            .returning(|_| Ok(state_snapshot_native()));
984        rpc.expect_get_contract_state()
985            .returning(|_| Ok(state_snapshot_vm()));
986        rpc.expect_get_traced_entry_points()
987            .returning(|_| Ok(traced_entry_point_response()));
988        let state_sync = with_mocked_clients(false, false, Some(rpc), None);
989        let mut tracker = ComponentTracker::new(
990            Chain::Ethereum,
991            "uniswap-v2",
992            ComponentFilter::with_tvl_range(0.0, 0.0),
993            state_sync.rpc_client.clone(),
994        );
995        let component = ProtocolComponent {
996            id: "Component1".to_string(),
997            contract_ids: vec![Bytes::from("0x0badc0ffee"), Bytes::from("0xbabe42")],
998            ..Default::default()
999        };
1000        tracker
1001            .components
1002            .insert("Component1".to_string(), component.clone());
1003        let components_arg = ["Component1".to_string()];
1004        let exp = StateSyncMessage {
1005            header: header.clone(),
1006            snapshots: Snapshot {
1007                states: [(
1008                    component.id.clone(),
1009                    ComponentWithState {
1010                        state: ResponseProtocolState {
1011                            component_id: "Component1".to_string(),
1012                            ..Default::default()
1013                        },
1014                        component: component.clone(),
1015                        component_tvl: None,
1016                        entrypoints: vec![(
1017                            EntryPointWithTracingParams {
1018                                entry_point: EntryPoint {
1019                                    external_id: "entrypoint_a".to_string(),
1020                                    target: Bytes::from("0x0badc0ffee"),
1021                                    signature: "sig()".to_string(),
1022                                },
1023                                params: TracingParams::RPCTracer(RPCTracerParams {
1024                                    caller: Some(Bytes::from("0x0badc0ffee")),
1025                                    calldata: Bytes::from("0x0badc0ffee"),
1026                                }),
1027                            },
1028                            TracingResult {
1029                                retriggers: HashSet::from([(
1030                                    Bytes::from("0x0badc0ffee"),
1031                                    Bytes::from("0x0badc0ffee"),
1032                                )]),
1033                                accessed_slots: HashMap::from([(
1034                                    Bytes::from("0x0badc0ffee"),
1035                                    HashSet::from([Bytes::from("0xbadbeef0")]),
1036                                )]),
1037                            },
1038                        )],
1039                    },
1040                )]
1041                .into_iter()
1042                .collect(),
1043                vm_storage: state_snapshot_vm()
1044                    .accounts
1045                    .into_iter()
1046                    .map(|state| (state.address.clone(), state))
1047                    .collect(),
1048            },
1049            deltas: None,
1050            removed_components: Default::default(),
1051        };
1052
1053        let snap = state_sync
1054            .get_snapshots(header, &mut tracker, Some(&components_arg))
1055            .await
1056            .expect("Retrieving snapshot failed");
1057
1058        assert_eq!(snap, exp);
1059    }
1060
1061    #[test(tokio::test)]
1062    async fn test_get_snapshots_vm_with_tvl() {
1063        let header = BlockHeader::default();
1064        let mut rpc = MockRPCClient::new();
1065        rpc.expect_get_protocol_states()
1066            .returning(|_| Ok(state_snapshot_native()));
1067        rpc.expect_get_contract_state()
1068            .returning(|_| Ok(state_snapshot_vm()));
1069        rpc.expect_get_component_tvl()
1070            .returning(|_| Ok(component_tvl_snapshot()));
1071        rpc.expect_get_traced_entry_points()
1072            .returning(|_| {
1073                Ok(TracedEntryPointRequestResponse {
1074                    traced_entry_points: HashMap::new(),
1075                    pagination: PaginationResponse::new(0, 20, 0),
1076                })
1077            });
1078        let state_sync = with_mocked_clients(false, true, Some(rpc), None);
1079        let mut tracker = ComponentTracker::new(
1080            Chain::Ethereum,
1081            "uniswap-v2",
1082            ComponentFilter::with_tvl_range(0.0, 0.0),
1083            state_sync.rpc_client.clone(),
1084        );
1085        let component = ProtocolComponent {
1086            id: "Component1".to_string(),
1087            contract_ids: vec![Bytes::from("0x0badc0ffee"), Bytes::from("0xbabe42")],
1088            ..Default::default()
1089        };
1090        tracker
1091            .components
1092            .insert("Component1".to_string(), component.clone());
1093        let components_arg = ["Component1".to_string()];
1094        let exp = StateSyncMessage {
1095            header: header.clone(),
1096            snapshots: Snapshot {
1097                states: [(
1098                    component.id.clone(),
1099                    ComponentWithState {
1100                        state: ResponseProtocolState {
1101                            component_id: "Component1".to_string(),
1102                            ..Default::default()
1103                        },
1104                        component: component.clone(),
1105                        component_tvl: Some(100.0),
1106                        entrypoints: vec![],
1107                    },
1108                )]
1109                .into_iter()
1110                .collect(),
1111                vm_storage: state_snapshot_vm()
1112                    .accounts
1113                    .into_iter()
1114                    .map(|state| (state.address.clone(), state))
1115                    .collect(),
1116            },
1117            deltas: None,
1118            removed_components: Default::default(),
1119        };
1120
1121        let snap = state_sync
1122            .get_snapshots(header, &mut tracker, Some(&components_arg))
1123            .await
1124            .expect("Retrieving snapshot failed");
1125
1126        assert_eq!(snap, exp);
1127    }
1128
1129    fn mock_clients_for_state_sync() -> (MockRPCClient, MockDeltasClient, Sender<BlockChanges>) {
1130        let mut rpc_client = MockRPCClient::new();
1131        // Mocks for the start_tracking call, these need to come first because they are more
1132        // specific, see: https://docs.rs/mockall/latest/mockall/#matching-multiple-calls
1133        rpc_client
1134            .expect_get_protocol_components()
1135            .with(mockall::predicate::function(
1136                move |request_params: &ProtocolComponentsRequestBody| {
1137                    if let Some(ids) = request_params.component_ids.as_ref() {
1138                        ids.contains(&"Component3".to_string())
1139                    } else {
1140                        false
1141                    }
1142                },
1143            ))
1144            .returning(|_| {
1145                // return Component3
1146                Ok(ProtocolComponentRequestResponse {
1147                    protocol_components: vec![
1148                        // this component shall have a tvl update above threshold
1149                        ProtocolComponent { id: "Component3".to_string(), ..Default::default() },
1150                    ],
1151                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1152                })
1153            });
1154        rpc_client
1155            .expect_get_protocol_states()
1156            .with(mockall::predicate::function(move |request_params: &ProtocolStateRequestBody| {
1157                let expected_id = "Component3".to_string();
1158                if let Some(ids) = request_params.protocol_ids.as_ref() {
1159                    ids.contains(&expected_id)
1160                } else {
1161                    false
1162                }
1163            }))
1164            .returning(|_| {
1165                // return Component3 state
1166                Ok(ProtocolStateRequestResponse {
1167                    states: vec![ResponseProtocolState {
1168                        component_id: "Component3".to_string(),
1169                        ..Default::default()
1170                    }],
1171                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1172                })
1173            });
1174
1175        // mock calls for the initial state snapshots
1176        rpc_client
1177            .expect_get_protocol_components()
1178            .returning(|_| {
1179                // Initial sync of components
1180                Ok(ProtocolComponentRequestResponse {
1181                    protocol_components: vec![
1182                        // this component shall have a tvl update above threshold
1183                        ProtocolComponent { id: "Component1".to_string(), ..Default::default() },
1184                        // this component shall have a tvl update below threshold.
1185                        ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1186                        // a third component will have a tvl update above threshold
1187                    ],
1188                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1189                })
1190            });
1191        rpc_client
1192            .expect_get_protocol_states()
1193            .returning(|_| {
1194                // Initial state snapshot
1195                Ok(ProtocolStateRequestResponse {
1196                    states: vec![
1197                        ResponseProtocolState {
1198                            component_id: "Component1".to_string(),
1199                            ..Default::default()
1200                        },
1201                        ResponseProtocolState {
1202                            component_id: "Component2".to_string(),
1203                            ..Default::default()
1204                        },
1205                    ],
1206                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1207                })
1208            });
1209        rpc_client
1210            .expect_get_component_tvl()
1211            .returning(|_| {
1212                Ok(ComponentTvlRequestResponse {
1213                    tvl: [
1214                        ("Component1".to_string(), 100.0),
1215                        ("Component2".to_string(), 0.0),
1216                        ("Component3".to_string(), 1000.0),
1217                    ]
1218                    .into_iter()
1219                    .collect(),
1220                    pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
1221                })
1222            });
1223        rpc_client
1224            .expect_get_traced_entry_points()
1225            .returning(|_| {
1226                Ok(TracedEntryPointRequestResponse {
1227                    traced_entry_points: HashMap::new(),
1228                    pagination: PaginationResponse::new(0, 20, 0),
1229                })
1230            });
1231
1232        // Mock deltas client and messages
1233        let mut deltas_client = MockDeltasClient::new();
1234        let (tx, rx) = channel(1);
1235        deltas_client
1236            .expect_subscribe()
1237            .return_once(move |_, _| {
1238                // Return subscriber id and a channel
1239                Ok((Uuid::default(), rx))
1240            });
1241        (rpc_client, deltas_client, tx)
1242    }
1243
1244    /// Test strategy
1245    ///
1246    /// - initial snapshot retrieval returns two component1 and component2 as snapshots
1247    /// - send 2 dummy messages, containing only blocks
1248    /// - third message contains a new component with some significant tvl, one initial component
1249    ///   slips below tvl threshold, another one is above tvl but does not get re-requested.
1250    #[test(tokio::test)]
1251    async fn test_state_sync() {
1252        let (rpc_client, deltas_client, tx) = mock_clients_for_state_sync();
1253        let deltas = [
1254            BlockChanges {
1255                extractor: "uniswap-v2".to_string(),
1256                chain: Chain::Ethereum,
1257                block: Block {
1258                    number: 1,
1259                    hash: Bytes::from("0x01"),
1260                    parent_hash: Bytes::from("0x00"),
1261                    chain: Chain::Ethereum,
1262                    ts: Default::default(),
1263                },
1264                revert: false,
1265                dci_update: DCIUpdate {
1266                    new_entrypoints: HashMap::from([(
1267                        "Component1".to_string(),
1268                        HashSet::from([EntryPoint {
1269                            external_id: "entrypoint_a".to_string(),
1270                            target: Bytes::from("0x0badc0ffee"),
1271                            signature: "sig()".to_string(),
1272                        }]),
1273                    )]),
1274                    new_entrypoint_params: HashMap::from([(
1275                        "entrypoint_a".to_string(),
1276                        HashSet::from([(
1277                            TracingParams::RPCTracer(RPCTracerParams {
1278                                caller: Some(Bytes::from("0x0badc0ffee")),
1279                                calldata: Bytes::from("0x0badc0ffee"),
1280                            }),
1281                            Some("Component1".to_string()),
1282                        )]),
1283                    )]),
1284                    trace_results: HashMap::from([(
1285                        "entrypoint_a".to_string(),
1286                        TracingResult {
1287                            retriggers: HashSet::from([(
1288                                Bytes::from("0x0badc0ffee"),
1289                                Bytes::from("0x0badc0ffee"),
1290                            )]),
1291                            accessed_slots: HashMap::from([(
1292                                Bytes::from("0x0badc0ffee"),
1293                                HashSet::from([Bytes::from("0xbadbeef0")]),
1294                            )]),
1295                        },
1296                    )]),
1297                },
1298                ..Default::default()
1299            },
1300            BlockChanges {
1301                extractor: "uniswap-v2".to_string(),
1302                chain: Chain::Ethereum,
1303                block: Block {
1304                    number: 2,
1305                    hash: Bytes::from("0x02"),
1306                    parent_hash: Bytes::from("0x01"),
1307                    chain: Chain::Ethereum,
1308                    ts: Default::default(),
1309                },
1310                revert: false,
1311                component_tvl: [
1312                    ("Component1".to_string(), 100.0),
1313                    ("Component2".to_string(), 0.0),
1314                    ("Component3".to_string(), 1000.0),
1315                ]
1316                .into_iter()
1317                .collect(),
1318                ..Default::default()
1319            },
1320        ];
1321        let mut state_sync = with_mocked_clients(true, true, Some(rpc_client), Some(deltas_client));
1322        state_sync
1323            .initialize()
1324            .await
1325            .expect("Init failed");
1326
1327        // Test starts here
1328        let (jh, mut rx) = state_sync
1329            .start()
1330            .await
1331            .expect("Failed to start state synchronizer");
1332        tx.send(deltas[0].clone())
1333            .await
1334            .expect("deltas channel msg 0 closed!");
1335        let first_msg = timeout(Duration::from_millis(100), rx.recv())
1336            .await
1337            .expect("waiting for first state msg timed out!")
1338            .expect("state sync block sender closed!");
1339        tx.send(deltas[1].clone())
1340            .await
1341            .expect("deltas channel msg 1 closed!");
1342        let second_msg = timeout(Duration::from_millis(100), rx.recv())
1343            .await
1344            .expect("waiting for second state msg timed out!")
1345            .expect("state sync block sender closed!");
1346        let _ = state_sync.close().await;
1347        let exit = jh
1348            .await
1349            .expect("state sync task panicked!");
1350
1351        // assertions
1352        let exp1 = StateSyncMessage {
1353            header: BlockHeader {
1354                number: 1,
1355                hash: Bytes::from("0x01"),
1356                parent_hash: Bytes::from("0x00"),
1357                revert: false,
1358                ..Default::default()
1359            },
1360            snapshots: Snapshot {
1361                states: [
1362                    (
1363                        "Component1".to_string(),
1364                        ComponentWithState {
1365                            state: ResponseProtocolState {
1366                                component_id: "Component1".to_string(),
1367                                ..Default::default()
1368                            },
1369                            component: ProtocolComponent {
1370                                id: "Component1".to_string(),
1371                                ..Default::default()
1372                            },
1373                            component_tvl: Some(100.0),
1374                            entrypoints: vec![],
1375                        },
1376                    ),
1377                    (
1378                        "Component2".to_string(),
1379                        ComponentWithState {
1380                            state: ResponseProtocolState {
1381                                component_id: "Component2".to_string(),
1382                                ..Default::default()
1383                            },
1384                            component: ProtocolComponent {
1385                                id: "Component2".to_string(),
1386                                ..Default::default()
1387                            },
1388                            component_tvl: Some(0.0),
1389                            entrypoints: vec![],
1390                        },
1391                    ),
1392                ]
1393                .into_iter()
1394                .collect(),
1395                vm_storage: HashMap::new(),
1396            },
1397            deltas: Some(deltas[0].clone()),
1398            removed_components: Default::default(),
1399        };
1400
1401        let exp2 = StateSyncMessage {
1402            header: BlockHeader {
1403                number: 2,
1404                hash: Bytes::from("0x02"),
1405                parent_hash: Bytes::from("0x01"),
1406                revert: false,
1407                ..Default::default()
1408            },
1409            snapshots: Snapshot {
1410                states: [
1411                    // This is the new component we queried once it passed the tvl threshold.
1412                    (
1413                        "Component3".to_string(),
1414                        ComponentWithState {
1415                            state: ResponseProtocolState {
1416                                component_id: "Component3".to_string(),
1417                                ..Default::default()
1418                            },
1419                            component: ProtocolComponent {
1420                                id: "Component3".to_string(),
1421                                ..Default::default()
1422                            },
1423                            component_tvl: Some(1000.0),
1424                            entrypoints: vec![],
1425                        },
1426                    ),
1427                ]
1428                .into_iter()
1429                .collect(),
1430                vm_storage: HashMap::new(),
1431            },
1432            // Our deltas are empty and since merge methods are
1433            // tested in tycho-common we don't have much to do here.
1434            deltas: Some(BlockChanges {
1435                extractor: "uniswap-v2".to_string(),
1436                chain: Chain::Ethereum,
1437                block: Block {
1438                    number: 2,
1439                    hash: Bytes::from("0x02"),
1440                    parent_hash: Bytes::from("0x01"),
1441                    chain: Chain::Ethereum,
1442                    ts: Default::default(),
1443                },
1444                revert: false,
1445                component_tvl: [
1446                    // "Component2" should not show here.
1447                    ("Component1".to_string(), 100.0),
1448                    ("Component3".to_string(), 1000.0),
1449                ]
1450                .into_iter()
1451                .collect(),
1452                ..Default::default()
1453            }),
1454            // "Component2" was removed, because its tvl changed to 0.
1455            removed_components: [(
1456                "Component2".to_string(),
1457                ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1458            )]
1459            .into_iter()
1460            .collect(),
1461        };
1462        assert_eq!(first_msg, exp1);
1463        assert_eq!(second_msg, exp2);
1464        assert!(exit.is_ok());
1465    }
1466
1467    #[test(tokio::test)]
1468    async fn test_state_sync_with_tvl_range() {
1469        // Define the range for testing
1470        let remove_tvl_threshold = 5.0;
1471        let add_tvl_threshold = 7.0;
1472
1473        let mut rpc_client = MockRPCClient::new();
1474        let mut deltas_client = MockDeltasClient::new();
1475
1476        rpc_client
1477            .expect_get_protocol_components()
1478            .with(mockall::predicate::function(
1479                move |request_params: &ProtocolComponentsRequestBody| {
1480                    if let Some(ids) = request_params.component_ids.as_ref() {
1481                        ids.contains(&"Component3".to_string())
1482                    } else {
1483                        false
1484                    }
1485                },
1486            ))
1487            .returning(|_| {
1488                Ok(ProtocolComponentRequestResponse {
1489                    protocol_components: vec![ProtocolComponent {
1490                        id: "Component3".to_string(),
1491                        ..Default::default()
1492                    }],
1493                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1494                })
1495            });
1496        rpc_client
1497            .expect_get_protocol_states()
1498            .with(mockall::predicate::function(move |request_params: &ProtocolStateRequestBody| {
1499                let expected_id = "Component3".to_string();
1500                if let Some(ids) = request_params.protocol_ids.as_ref() {
1501                    ids.contains(&expected_id)
1502                } else {
1503                    false
1504                }
1505            }))
1506            .returning(|_| {
1507                Ok(ProtocolStateRequestResponse {
1508                    states: vec![ResponseProtocolState {
1509                        component_id: "Component3".to_string(),
1510                        ..Default::default()
1511                    }],
1512                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1513                })
1514            });
1515
1516        // Mock for the initial snapshot retrieval
1517        rpc_client
1518            .expect_get_protocol_components()
1519            .returning(|_| {
1520                Ok(ProtocolComponentRequestResponse {
1521                    protocol_components: vec![
1522                        ProtocolComponent { id: "Component1".to_string(), ..Default::default() },
1523                        ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1524                    ],
1525                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1526                })
1527            });
1528        rpc_client
1529            .expect_get_protocol_states()
1530            .returning(|_| {
1531                Ok(ProtocolStateRequestResponse {
1532                    states: vec![
1533                        ResponseProtocolState {
1534                            component_id: "Component1".to_string(),
1535                            ..Default::default()
1536                        },
1537                        ResponseProtocolState {
1538                            component_id: "Component2".to_string(),
1539                            ..Default::default()
1540                        },
1541                    ],
1542                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1543                })
1544            });
1545        rpc_client
1546            .expect_get_traced_entry_points()
1547            .returning(|_| {
1548                Ok(TracedEntryPointRequestResponse {
1549                    traced_entry_points: HashMap::new(),
1550                    pagination: PaginationResponse::new(0, 20, 0),
1551                })
1552            });
1553
1554        rpc_client
1555            .expect_get_component_tvl()
1556            .returning(|_| {
1557                Ok(ComponentTvlRequestResponse {
1558                    tvl: [
1559                        ("Component1".to_string(), 6.0),
1560                        ("Component2".to_string(), 2.0),
1561                        ("Component3".to_string(), 10.0),
1562                    ]
1563                    .into_iter()
1564                    .collect(),
1565                    pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
1566                })
1567            });
1568
1569        rpc_client
1570            .expect_get_component_tvl()
1571            .returning(|_| {
1572                Ok(ComponentTvlRequestResponse {
1573                    tvl: [
1574                        ("Component1".to_string(), 6.0),
1575                        ("Component2".to_string(), 2.0),
1576                        ("Component3".to_string(), 10.0),
1577                    ]
1578                    .into_iter()
1579                    .collect(),
1580                    pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
1581                })
1582            });
1583
1584        let (tx, rx) = channel(1);
1585        deltas_client
1586            .expect_subscribe()
1587            .return_once(move |_, _| Ok((Uuid::default(), rx)));
1588
1589        let mut state_sync = ProtocolStateSynchronizer::new(
1590            ExtractorIdentity::new(Chain::Ethereum, "uniswap-v2"),
1591            true,
1592            ComponentFilter::with_tvl_range(remove_tvl_threshold, add_tvl_threshold),
1593            1,
1594            true,
1595            true,
1596            ArcRPCClient(Arc::new(rpc_client)),
1597            ArcDeltasClient(Arc::new(deltas_client)),
1598            10_u64,
1599        );
1600        state_sync
1601            .initialize()
1602            .await
1603            .expect("Init failed");
1604
1605        // Simulate the incoming BlockChanges
1606        let deltas = [
1607            BlockChanges {
1608                extractor: "uniswap-v2".to_string(),
1609                chain: Chain::Ethereum,
1610                block: Block {
1611                    number: 1,
1612                    hash: Bytes::from("0x01"),
1613                    parent_hash: Bytes::from("0x00"),
1614                    chain: Chain::Ethereum,
1615                    ts: Default::default(),
1616                },
1617                revert: false,
1618                ..Default::default()
1619            },
1620            BlockChanges {
1621                extractor: "uniswap-v2".to_string(),
1622                chain: Chain::Ethereum,
1623                block: Block {
1624                    number: 2,
1625                    hash: Bytes::from("0x02"),
1626                    parent_hash: Bytes::from("0x01"),
1627                    chain: Chain::Ethereum,
1628                    ts: Default::default(),
1629                },
1630                revert: false,
1631                component_tvl: [
1632                    ("Component1".to_string(), 6.0), // Within range, should not trigger changes
1633                    ("Component2".to_string(), 2.0), // Below lower threshold, should be removed
1634                    ("Component3".to_string(), 10.0), // Above upper threshold, should be added
1635                ]
1636                .into_iter()
1637                .collect(),
1638                ..Default::default()
1639            },
1640        ];
1641
1642        let (jh, mut rx) = state_sync
1643            .start()
1644            .await
1645            .expect("Failed to start state synchronizer");
1646
1647        // Simulate sending delta messages
1648        tx.send(deltas[0].clone())
1649            .await
1650            .expect("deltas channel msg 0 closed!");
1651
1652        // Expecting to receive the initial state message
1653        let _ = timeout(Duration::from_millis(100), rx.recv())
1654            .await
1655            .expect("waiting for first state msg timed out!")
1656            .expect("state sync block sender closed!");
1657
        // Send the second delta (block 2), whose component TVLs should trigger filter-based changes
1659        tx.send(deltas[1].clone())
1660            .await
1661            .expect("deltas channel msg 1 closed!");
1662        let second_msg = timeout(Duration::from_millis(100), rx.recv())
1663            .await
1664            .expect("waiting for second state msg timed out!")
1665            .expect("state sync block sender closed!");
1666
1667        let _ = state_sync.close().await;
1668        let exit = jh
1669            .await
1670            .expect("state sync task panicked!");
1671
1672        let expected_second_msg = StateSyncMessage {
1673            header: BlockHeader {
1674                number: 2,
1675                hash: Bytes::from("0x02"),
1676                parent_hash: Bytes::from("0x01"),
1677                revert: false,
1678                ..Default::default()
1679            },
1680            snapshots: Snapshot {
1681                states: [(
1682                    "Component3".to_string(),
1683                    ComponentWithState {
1684                        state: ResponseProtocolState {
1685                            component_id: "Component3".to_string(),
1686                            ..Default::default()
1687                        },
1688                        component: ProtocolComponent {
1689                            id: "Component3".to_string(),
1690                            ..Default::default()
1691                        },
1692                        component_tvl: Some(10.0),
1693                        entrypoints: vec![], // TODO: add entrypoints?
1694                    },
1695                )]
1696                .into_iter()
1697                .collect(),
1698                vm_storage: HashMap::new(),
1699            },
1700            deltas: Some(BlockChanges {
1701                extractor: "uniswap-v2".to_string(),
1702                chain: Chain::Ethereum,
1703                block: Block {
1704                    number: 2,
1705                    hash: Bytes::from("0x02"),
1706                    parent_hash: Bytes::from("0x01"),
1707                    chain: Chain::Ethereum,
1708                    ts: Default::default(),
1709                },
1710                revert: false,
1711                component_tvl: [
1712                    ("Component1".to_string(), 6.0), // Within range, should not trigger changes
1713                    ("Component3".to_string(), 10.0), // Above upper threshold, should be added
1714                ]
1715                .into_iter()
1716                .collect(),
1717                ..Default::default()
1718            }),
1719            removed_components: [(
1720                "Component2".to_string(),
1721                ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1722            )]
1723            .into_iter()
1724            .collect(),
1725        };
1726
1727        assert_eq!(second_msg, expected_second_msg);
1728        assert!(exit.is_ok());
1729    }
1730}