tycho_client/feed/
synchronizer.rs

1use std::{collections::HashMap, time::Duration};
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6use tokio::{
7    select,
8    sync::{
9        mpsc::{channel, error::SendError, Receiver, Sender},
10        oneshot,
11    },
12    task::JoinHandle,
13    time::{sleep, timeout},
14};
15use tracing::{debug, error, info, instrument, trace, warn};
16use tycho_common::{
17    dto::{
18        BlockChanges, BlockParam, Chain, ComponentTvlRequestBody, EntryPointWithTracingParams,
19        ExtractorIdentity, ProtocolComponent, ResponseAccount, ResponseProtocolState,
20        TracingResult, VersionParam,
21    },
22    Bytes,
23};
24
25use crate::{
26    deltas::{DeltasClient, SubscriptionOptions},
27    feed::{
28        component_tracker::{ComponentFilter, ComponentTracker},
29        BlockHeader, HeaderLike,
30    },
31    rpc::{RPCClient, RPCError},
32    DeltasError,
33};
34
/// Errors reported by the state synchronizer.
#[derive(Error, Debug)]
pub enum SynchronizerError {
    /// RPC client failures.
    ///
    /// Created automatically from [`RPCError`] through the `#[from]` conversion.
    #[error("RPC error: {0}")]
    RPCError(#[from] RPCError),

    /// Failed to send channel message to the consumer.
    ///
    /// Carries the display text of the underlying channel send error.
    #[error("Failed to send channel message: {0}")]
    ChannelError(String),

    /// Timeout elapsed errors.
    ///
    /// The payload is a human readable description of what timed out.
    #[error("Timeout error: {0}")]
    Timeout(String),

    /// Failed to close the synchronizer.
    #[error("Failed to close synchronizer: {0}")]
    CloseError(String),

    /// Server connection failures or interruptions.
    ///
    /// The payload is a human readable description of the failure.
    #[error("Connection error: {0}")]
    ConnectionError(String),

    /// The connection to the server is closed.
    ///
    /// Unlike [`SynchronizerError::ConnectionError`] this variant carries no message.
    #[error("Connection closed")]
    ConnectionClosed,
}
61
/// Result alias whose error type is fixed to [`SynchronizerError`].
pub type SyncResult<T> = Result<T, SynchronizerError>;
63
64impl From<SendError<StateSyncMessage<BlockHeader>>> for SynchronizerError {
65    fn from(err: SendError<StateSyncMessage<BlockHeader>>) -> Self {
66        SynchronizerError::ChannelError(err.to_string())
67    }
68}
69
70impl From<DeltasError> for SynchronizerError {
71    fn from(err: DeltasError) -> Self {
72        match err {
73            DeltasError::NotConnected => SynchronizerError::ConnectionClosed,
74            _ => SynchronizerError::ConnectionError(err.to_string()),
75        }
76    }
77}
78
/// Synchronizes the state of the protocol served by a single extractor.
///
/// Snapshots are fetched through the RPC client `R`, while block-by-block updates are
/// received through the deltas client `D`.
pub struct ProtocolStateSynchronizer<R: RPCClient, D: DeltasClient> {
    /// Extractor to subscribe to; its chain and name are also used in RPC requests.
    extractor_id: ExtractorIdentity,
    /// Forwarded to the protocol state request when fetching snapshots.
    retrieve_balances: bool,
    /// Client used for all snapshot related RPC requests.
    rpc_client: R,
    /// Client used to subscribe to and unsubscribe from delta messages.
    deltas_client: D,
    /// Upper bound on the number of synchronization attempts made by `start`.
    max_retries: u64,
    /// Pause between two synchronization attempts.
    retry_cooldown: Duration,
    /// When `false`, no snapshots are retrieved and emitted messages carry empty snapshots.
    include_snapshots: bool,
    /// Keeps track of the components and contracts this synchronizer follows.
    component_tracker: ComponentTracker<R>,
    /// Header of the last block for which a message was sent; `None` while not synced.
    last_synced_block: Option<BlockHeader>,
    /// Seconds to wait for the first deltas message after subscribing.
    timeout: u64,
    /// When `true`, snapshots are enriched with the components' TVL.
    include_tvl: bool,
}
92
/// A protocol component bundled with the state retrieved for it in a snapshot.
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
pub struct ComponentWithState {
    /// Protocol state retrieved for the component.
    pub state: ResponseProtocolState,
    /// The component the state belongs to.
    pub component: ProtocolComponent,
    /// TVL of the component; `None` if no TVL value is available for it.
    pub component_tvl: Option<f64>,
    /// Traced entry points of the component, each paired with its tracing result; empty if
    /// none are available.
    pub entrypoints: Vec<(EntryPointWithTracingParams, TracingResult)>,
}
100
/// State snapshot for a set of components and the contracts they depend on.
#[derive(Clone, PartialEq, Debug, Default, Serialize, Deserialize)]
pub struct Snapshot {
    /// Component states, keyed by component id.
    pub states: HashMap<String, ComponentWithState>,
    /// Contract (account) states, keyed by contract address.
    pub vm_storage: HashMap<Bytes, ResponseAccount>,
}
106
107impl Snapshot {
108    fn extend(&mut self, other: Snapshot) {
109        self.states.extend(other.states);
110        self.vm_storage.extend(other.vm_storage);
111    }
112
113    pub fn get_states(&self) -> &HashMap<String, ComponentWithState> {
114        &self.states
115    }
116
117    pub fn get_vm_storage(&self) -> &HashMap<Bytes, ResponseAccount> {
118        &self.vm_storage
119    }
120}
121
/// Message emitted by a synchronizer for a single block.
///
/// It combines the snapshots of newly tracked components with the block's deltas and the
/// components that are no longer tracked.
#[derive(Clone, PartialEq, Debug, Default, Serialize, Deserialize)]
pub struct StateSyncMessage<H>
where
    H: HeaderLike,
{
    /// The block information for this update.
    pub header: H,
    /// Snapshot for new components.
    pub snapshots: Snapshot,
    /// A single delta contains state updates for all tracked components, as well as additional
    /// information about the system components e.g. newly added components (even below tvl), tvl
    /// updates, balance updates.
    ///
    /// `None` if the message carries no deltas.
    pub deltas: Option<BlockChanges>,
    /// Components that stopped being tracked, keyed by component id.
    pub removed_components: HashMap<String, ProtocolComponent>,
}
138
139impl<H> StateSyncMessage<H>
140where
141    H: HeaderLike,
142{
143    pub fn merge(mut self, other: Self) -> Self {
144        // be careful with removed and snapshots attributes here, these can be ambiguous.
145        self.removed_components
146            .retain(|k, _| !other.snapshots.states.contains_key(k));
147        self.snapshots
148            .states
149            .retain(|k, _| !other.removed_components.contains_key(k));
150
151        self.snapshots.extend(other.snapshots);
152        let deltas = match (self.deltas, other.deltas) {
153            (Some(l), Some(r)) => Some(l.merge(r)),
154            (None, Some(r)) => Some(r),
155            (Some(l), None) => Some(l),
156            (None, None) => None,
157        };
158        self.removed_components
159            .extend(other.removed_components);
160        Self {
161            header: other.header,
162            snapshots: self.snapshots,
163            deltas,
164            removed_components: self.removed_components,
165        }
166    }
167}
168
/// Handle for controlling a running synchronizer task.
///
/// The handle bundles the join handle of the spawned task with the one-shot sender that is
/// used to signal the task to shut down.
///
/// NOTE(review): the methods visible next to this definition only construct and split the
/// handle; confirm whether a close-with-timeout helper exists elsewhere.
pub struct SynchronizerTaskHandle {
    /// Join handle of the spawned synchronizer task; resolves to the task's result.
    join_handle: JoinHandle<SyncResult<()>>,
    /// Sender used to deliver the close signal to the task.
    close_tx: oneshot::Sender<()>,
}
177
178/// StateSynchronizer
179///
180/// Used to synchronize the state of a single protocol. The synchronizer is responsible for
181/// delivering messages to the client that let him reconstruct subsets of the protocol state.
182///
183/// This involves deciding which components to track according to the clients preferences,
184/// retrieving & emitting snapshots of components which the client has not seen yet and subsequently
185/// delivering delta messages for the components that have changed.
186impl SynchronizerTaskHandle {
187    pub fn new(join_handle: JoinHandle<SyncResult<()>>, close_tx: oneshot::Sender<()>) -> Self {
188        Self { join_handle, close_tx }
189    }
190
191    /// Splits the handle into its join handle and close sender.
192    ///
193    /// This allows monitoring the task completion separately from controlling shutdown.
194    /// The join handle can be used with FuturesUnordered for monitoring, while the
195    /// close sender can be used to signal graceful shutdown.
196    pub fn split(self) -> (JoinHandle<SyncResult<()>>, oneshot::Sender<()>) {
197        (self.join_handle, self.close_tx)
198    }
199}
200
201#[async_trait]
202pub trait StateSynchronizer: Send + Sync + 'static {
203    async fn initialize(&mut self) -> SyncResult<()>;
204    /// Starts the state synchronization, consuming the synchronizer.
205    /// Returns a handle for controlling the running task and a receiver for messages.
206    async fn start(
207        mut self,
208    ) -> SyncResult<(SynchronizerTaskHandle, Receiver<StateSyncMessage<BlockHeader>>)>;
209}
210
211impl<R, D> ProtocolStateSynchronizer<R, D>
212where
213    // TODO: Consider moving these constraints directly to the
214    // client...
215    R: RPCClient + Clone + Send + Sync + 'static,
216    D: DeltasClient + Clone + Send + Sync + 'static,
217{
218    /// Creates a new state synchronizer.
219    #[allow(clippy::too_many_arguments)]
220    pub fn new(
221        extractor_id: ExtractorIdentity,
222        retrieve_balances: bool,
223        component_filter: ComponentFilter,
224        max_retries: u64,
225        retry_cooldown: Duration,
226        include_snapshots: bool,
227        include_tvl: bool,
228        rpc_client: R,
229        deltas_client: D,
230        timeout: u64,
231    ) -> Self {
232        Self {
233            extractor_id: extractor_id.clone(),
234            retrieve_balances,
235            rpc_client: rpc_client.clone(),
236            include_snapshots,
237            deltas_client,
238            component_tracker: ComponentTracker::new(
239                extractor_id.chain,
240                extractor_id.name.as_str(),
241                component_filter,
242                rpc_client,
243            ),
244            max_retries,
245            retry_cooldown,
246            last_synced_block: None,
247            timeout,
248            include_tvl,
249        }
250    }
251
252    /// Retrieves state snapshots of the requested components
253    #[allow(deprecated)]
254    async fn get_snapshots<'a, I: IntoIterator<Item = &'a String>>(
255        &mut self,
256        header: BlockHeader,
257        ids: Option<I>,
258    ) -> SyncResult<StateSyncMessage<BlockHeader>> {
259        if !self.include_snapshots {
260            return Ok(StateSyncMessage { header, ..Default::default() });
261        }
262        let version = VersionParam::new(
263            None,
264            Some(BlockParam {
265                chain: Some(self.extractor_id.chain),
266                hash: None,
267                number: Some(header.number as i64),
268            }),
269        );
270
271        // Use given ids or use all if not passed
272        let component_ids: Vec<_> = match ids {
273            Some(ids) => ids.into_iter().cloned().collect(),
274            None => self
275                .component_tracker
276                .get_tracked_component_ids(),
277        };
278
279        if component_ids.is_empty() {
280            return Ok(StateSyncMessage { header, ..Default::default() });
281        }
282
283        let component_tvl = if self.include_tvl {
284            let body = ComponentTvlRequestBody::id_filtered(
285                component_ids.clone(),
286                self.extractor_id.chain,
287            );
288            self.rpc_client
289                .get_component_tvl_paginated(&body, 100, 4)
290                .await?
291                .tvl
292        } else {
293            HashMap::new()
294        };
295
296        //TODO: Improve this, we should not query for every component, but only for the ones that
297        // could have entrypoints. Maybe apply a filter per protocol?
298        let entrypoints_result = if self.extractor_id.chain == Chain::Ethereum {
299            // Fetch entrypoints
300            let result = self
301                .rpc_client
302                .get_traced_entry_points_paginated(
303                    self.extractor_id.chain,
304                    &self.extractor_id.name,
305                    &component_ids,
306                    100,
307                    4,
308                )
309                .await?;
310            self.component_tracker
311                .process_entrypoints(&result.clone().into());
312            Some(result)
313        } else {
314            None
315        };
316
317        // Fetch protocol states
318        let mut protocol_states = self
319            .rpc_client
320            .get_protocol_states_paginated(
321                self.extractor_id.chain,
322                &component_ids,
323                &self.extractor_id.name,
324                self.retrieve_balances,
325                &version,
326                100,
327                4,
328            )
329            .await?
330            .states
331            .into_iter()
332            .map(|state| (state.component_id.clone(), state))
333            .collect::<HashMap<_, _>>();
334
335        trace!(states=?&protocol_states, "Retrieved ProtocolStates");
336        let states = self
337            .component_tracker
338            .components
339            .values()
340            .filter_map(|component| {
341                if let Some(state) = protocol_states.remove(&component.id) {
342                    Some((
343                        component.id.clone(),
344                        ComponentWithState {
345                            state,
346                            component: component.clone(),
347                            component_tvl: component_tvl
348                                .get(&component.id)
349                                .cloned(),
350                            entrypoints: entrypoints_result
351                                .as_ref()
352                                .map(|r| {
353                                    r.traced_entry_points
354                                        .get(&component.id)
355                                        .cloned()
356                                        .unwrap_or_default()
357                                })
358                                .unwrap_or_default(),
359                        },
360                    ))
361                } else if component_ids.contains(&component.id) {
362                    // only emit error event if we requested this component
363                    let component_id = &component.id;
364                    error!(?component_id, "Missing state for native component!");
365                    None
366                } else {
367                    None
368                }
369            })
370            .collect();
371
372        // Fetch contract states
373        let contract_ids = self
374            .component_tracker
375            .get_contracts_by_component(&component_ids);
376        let vm_storage = if !contract_ids.is_empty() {
377            let ids: Vec<Bytes> = contract_ids
378                .clone()
379                .into_iter()
380                .collect();
381            let contract_states = self
382                .rpc_client
383                .get_contract_state_paginated(
384                    self.extractor_id.chain,
385                    ids.as_slice(),
386                    &self.extractor_id.name,
387                    &version,
388                    100,
389                    4,
390                )
391                .await?
392                .accounts
393                .into_iter()
394                .map(|acc| (acc.address.clone(), acc))
395                .collect::<HashMap<_, _>>();
396
397            trace!(states=?&contract_states, "Retrieved ContractState");
398
399            let contract_address_to_components = self
400                .component_tracker
401                .components
402                .iter()
403                .filter_map(|(id, comp)| {
404                    if component_ids.contains(id) {
405                        Some(
406                            comp.contract_ids
407                                .iter()
408                                .map(|address| (address.clone(), comp.id.clone())),
409                        )
410                    } else {
411                        None
412                    }
413                })
414                .flatten()
415                .fold(HashMap::<Bytes, Vec<String>>::new(), |mut acc, (addr, c_id)| {
416                    acc.entry(addr).or_default().push(c_id);
417                    acc
418                });
419
420            contract_ids
421                .iter()
422                .filter_map(|address| {
423                    if let Some(state) = contract_states.get(address) {
424                        Some((address.clone(), state.clone()))
425                    } else if let Some(ids) = contract_address_to_components.get(address) {
426                        // only emit error even if we did actually request this address
427                        error!(
428                            ?address,
429                            ?ids,
430                            "Component with lacking contract storage encountered!"
431                        );
432                        None
433                    } else {
434                        None
435                    }
436                })
437                .collect()
438        } else {
439            HashMap::new()
440        };
441
442        Ok(StateSyncMessage {
443            header,
444            snapshots: Snapshot { states, vm_storage },
445            deltas: None,
446            removed_components: HashMap::new(),
447        })
448    }
449
    /// Main method that does all the work.
    ///
    /// Subscribes to the deltas stream, waits for the first deltas message, emits the initial
    /// snapshot merged with that message and then forwards one message per received deltas
    /// message until an error occurs or the close signal arrives. Before returning, the last
    /// synced block is reset and the subscription is cancelled (best effort).
    ///
    /// ## Return Value
    ///
    /// Returns a `Result` where:
    /// - `Ok(())` - Synchronization completed successfully (usually due to close signal)
    /// - `Err((error, None))` - Error occurred AND close signal was received (don't retry)
    /// - `Err((error, Some(end_rx)))` - Error occurred but close signal was NOT received (can
    ///   retry)
    ///
    /// The returned `end_rx` (if any) should be reused for retry attempts since the close
    /// signal may still arrive and we want to remain cancellable across retries.
    ///
    /// Note: every error path of the current implementation returns `Some(end_rx)`; the
    /// `Err((error, None))` case is not produced by this method.
    #[instrument(skip(self, block_tx, end_rx), fields(extractor_id = %self.extractor_id))]
    async fn state_sync(
        &mut self,
        block_tx: &mut Sender<StateSyncMessage<BlockHeader>>,
        mut end_rx: oneshot::Receiver<()>,
    ) -> Result<(), (SynchronizerError, Option<oneshot::Receiver<()>>)> {
        // initialisation

        // Subscribe to the deltas of our extractor. If that fails nothing needs to be cleaned
        // up yet, so hand the untouched close receiver back to allow a retry.
        let subscription_options = SubscriptionOptions::new().with_state(self.include_snapshots);
        let (subscription_id, mut msg_rx) = match self
            .deltas_client
            .subscribe(self.extractor_id.clone(), subscription_options)
            .await
        {
            Ok(result) => result,
            Err(e) => return Err((e.into(), Some(end_rx))),
        };

        // The actual work runs in an inner async block so that `?` and `return` only leave the
        // block and the cleanup code below is executed on every exit path.
        let result = async {
            info!("Waiting for deltas...");
            // wait for first deltas message
            let mut first_msg = select! {
                deltas_result = timeout(Duration::from_secs(self.timeout), msg_rx.recv()) => {
                    deltas_result
                        // Outer error: the timeout elapsed before a message arrived.
                        .map_err(|_| {
                            SynchronizerError::Timeout(format!(
                                "First deltas took longer than {t}s to arrive",
                                t = self.timeout
                            ))
                        })?
                        // `None`: the deltas channel was closed by the sender.
                        .ok_or_else(|| {
                            SynchronizerError::ConnectionError(
                                "Deltas channel closed before first message".to_string(),
                            )
                        })?
                },
                _ = &mut end_rx => {
                    info!("Received close signal while waiting for first deltas");
                    return Ok(());
                }
            };
            // Restrict the first message to what is currently tracked.
            self.filter_deltas(&mut first_msg);

            // initial snapshot
            let block = first_msg.get_block().clone();
            info!(height = &block.number, "Deltas received. Retrieving snapshot");
            let header = BlockHeader::from_block(first_msg.get_block(), first_msg.is_revert());
            // The snapshot of all tracked components is requested with a non-revert header;
            // merging then adopts the header of the deltas message and attaches its deltas.
            let snapshot = self
                .get_snapshots::<Vec<&String>>(
                    BlockHeader::from_block(&block, false),
                    None,
                )
                .await?
                .merge(StateSyncMessage {
                    header: BlockHeader::from_block(first_msg.get_block(), first_msg.is_revert()),
                    snapshots: Default::default(),
                    deltas: Some(first_msg),
                    removed_components: Default::default(),
                });

            let n_components = self.component_tracker.components.len();
            let n_snapshots = snapshot.snapshots.states.len();
            info!(n_components, n_snapshots, "Initial snapshot retrieved, starting delta message feed");

            block_tx.send(snapshot).await?;
            self.last_synced_block = Some(header.clone());
            // Steady state: one outgoing message per incoming deltas message.
            loop {
                select! {
                    deltas_opt = msg_rx.recv() => {
                        if let Some(mut deltas) = deltas_opt {
                            let header = BlockHeader::from_block(deltas.get_block(), deltas.is_revert());
                            debug!(block_number=?header.number, "Received delta message");

                            let (snapshots, removed_components) = {
                                // 1. Remove components based on latest changes
                                // 2. Add components based on latest changes, query those for snapshots
                                let (to_add, to_remove) = self.component_tracker.filter_updated_components(&deltas);

                                // Only components we don't track yet need a snapshot,
                                let requiring_snapshot: Vec<_> = to_add
                                    .iter()
                                    .filter(|id| {
                                        !self.component_tracker
                                            .components
                                            .contains_key(id.as_str())
                                    })
                                    .collect();
                                debug!(components=?requiring_snapshot, "SnapshotRequest");
                                // Tracking starts before the snapshot request.
                                self.component_tracker
                                    .start_tracking(requiring_snapshot.as_slice())
                                    .await?;
                                let snapshots = self
                                    .get_snapshots(header.clone(), Some(requiring_snapshot))
                                    .await?
                                    .snapshots;

                                let removed_components = if !to_remove.is_empty() {
                                    self.component_tracker.stop_tracking(&to_remove)
                                } else {
                                    Default::default()
                                };

                                (snapshots, removed_components)
                            };

                            // 3. Update entrypoints on the tracker (affects which contracts are tracked)
                            self.component_tracker.process_entrypoints(&deltas.dci_update);

                            // 4. Filter deltas by currently tracked components / contracts
                            self.filter_deltas(&mut deltas);
                            let n_changes = deltas.n_changes();

                            // 5. Send the message
                            let next = StateSyncMessage {
                                header: header.clone(),
                                snapshots,
                                deltas: Some(deltas),
                                removed_components,
                            };
                            block_tx.send(next).await?;
                            self.last_synced_block = Some(header.clone());

                            debug!(block_number=?header.number, n_changes, "Finished processing delta message");
                        } else {
                            // `recv` returned `None`: the deltas channel was closed.
                            return Err(SynchronizerError::ConnectionError("Deltas channel closed".to_string()));
                        }
                    },
                    _ = &mut end_rx => {
                        info!("Received close signal during state_sync");
                        return Ok(());
                    }
                }
            }
        }.await;

        // This cleanup code now runs regardless of how the function exits (error or channel close)
        warn!(last_synced_block = ?&self.last_synced_block, "Deltas processing ended, resetting last synced block.");
        self.last_synced_block = None;
        //Ignore error
        // Unsubscribing is best effort: a failure is only logged.
        let _ = self
            .deltas_client
            .unsubscribe(subscription_id)
            .await
            .map_err(|err| {
                warn!(err=?err, "Unsubscribing from deltas on cleanup failed!");
            });

        // Handle the result: if it succeeded, we're done. If it errored, we need to determine
        // whether the end_rx was consumed (close signal received) or not
        match result {
            Ok(()) => Ok(()), // Success, likely due to close signal
            Err(e) => {
                // The error came from the inner async block. Since the async block
                // can receive close signals (which would return Ok), any error means
                // the close signal was NOT received, so we can return the end_rx for retry
                Err((e, Some(end_rx)))
            }
        }
    }
621
622    fn filter_deltas(&self, deltas: &mut BlockChanges) {
623        deltas.filter_by_component(|id| {
624            self.component_tracker
625                .components
626                .contains_key(id)
627        });
628        deltas.filter_by_contract(|id| {
629            self.component_tracker
630                .contracts
631                .contains(id)
632        });
633    }
634}
635
636#[async_trait]
637impl<R, D> StateSynchronizer for ProtocolStateSynchronizer<R, D>
638where
639    R: RPCClient + Clone + Send + Sync + 'static,
640    D: DeltasClient + Clone + Send + Sync + 'static,
641{
642    async fn initialize(&mut self) -> SyncResult<()> {
643        info!("Retrieving relevant protocol components");
644        self.component_tracker
645            .initialise_components()
646            .await?;
647        info!(
648            n_components = self.component_tracker.components.len(),
649            n_contracts = self.component_tracker.contracts.len(),
650            "Finished retrieving components",
651        );
652
653        Ok(())
654    }
655
656    async fn start(
657        mut self,
658    ) -> SyncResult<(SynchronizerTaskHandle, Receiver<StateSyncMessage<BlockHeader>>)> {
659        let (mut tx, rx) = channel(15);
660        let (end_tx, end_rx) = oneshot::channel::<()>();
661
662        let jh = tokio::spawn(async move {
663            let mut retry_count = 0;
664            let mut current_end_rx = end_rx;
665
666            while retry_count < self.max_retries {
667                info!(extractor_id=%&self.extractor_id, retry_count, "(Re)starting synchronization loop");
668
669                let res = self
670                    .state_sync(&mut tx, current_end_rx)
671                    .await;
672                match res {
673                    Ok(()) => {
674                        info!(
675                            extractor_id=%&self.extractor_id,
676                            retry_count,
677                            "State synchronization exited cleanly"
678                        );
679                        return Ok(());
680                    }
681                    Err((e, maybe_end_rx)) => {
682                        error!(
683                            extractor_id=%&self.extractor_id,
684                            retry_count,
685                            error=%e,
686                            "State synchronization errored!"
687                        );
688
689                        // If we have the end_rx back, we can retry
690                        if let Some(recovered_end_rx) = maybe_end_rx {
691                            current_end_rx = recovered_end_rx;
692
693                            if let SynchronizerError::ConnectionClosed = e {
694                                // break synchronization loop if connection is closed
695                                return Err(e);
696                            }
697                        } else {
698                            // Close signal was received, exit cleanly
699                            info!(extractor_id=%&self.extractor_id, "Received close signal, exiting");
700                            return Ok(());
701                        }
702                    }
703                }
704                sleep(self.retry_cooldown).await;
705                retry_count += 1;
706            }
707            warn!(extractor_id=%&self.extractor_id, retry_count, "Max retries exceeded");
708            Err(SynchronizerError::ConnectionError("Max connection retries exceeded".to_string()))
709        });
710
711        let handle = SynchronizerTaskHandle::new(jh, end_tx);
712        Ok((handle, rx))
713    }
714}
715
716#[cfg(test)]
717mod test {
718    //! Test suite for ProtocolStateSynchronizer shutdown and cleanup behavior.
719    //!
720    //! ## Test Coverage Strategy:
721    //!
722    //! ### Shutdown & Close Signal Tests:
723    //! - `test_public_close_api_functionality` - Tests public API (start/close lifecycle)
724    //! - `test_close_signal_while_waiting_for_first_deltas` - Close during initial wait
725    //! - `test_close_signal_during_main_processing_loop` - Close during main processing
726    //!
727    //! ### Cleanup & Error Handling Tests:
728    //! - `test_cleanup_runs_when_state_sync_processing_errors` - Cleanup on processing errors
729    //!
730    //! ### Coverage Summary:
731    //! These tests ensure cleanup code (shared state reset + unsubscribe) runs on ALL exit paths:
732    //! ✓ Close signal before first deltas   ✓ Close signal during processing
733    //! ✓ Processing errors                  ✓ Channel closure
734    //! ✓ Public API close operations        ✓ Normal completion
735
736    use std::{collections::HashSet, sync::Arc};
737
738    use test_log::test;
739    use tycho_common::dto::{
740        Block, Chain, ComponentTvlRequestBody, ComponentTvlRequestResponse, DCIUpdate, EntryPoint,
741        PaginationResponse, ProtocolComponentRequestResponse, ProtocolComponentsRequestBody,
742        ProtocolStateRequestBody, ProtocolStateRequestResponse, ProtocolSystemsRequestBody,
743        ProtocolSystemsRequestResponse, RPCTracerParams, StateRequestBody, StateRequestResponse,
744        TokensRequestBody, TokensRequestResponse, TracedEntryPointRequestBody,
745        TracedEntryPointRequestResponse, TracingParams,
746    };
747    use uuid::Uuid;
748
749    use super::*;
750    use crate::{deltas::MockDeltasClient, rpc::MockRPCClient, DeltasError, RPCError};
751
752    // Required for mock client to implement clone
753    struct ArcRPCClient<T>(Arc<T>);
754
755    // Default derive(Clone) does require T to be Clone as well.
756    impl<T> Clone for ArcRPCClient<T> {
757        fn clone(&self) -> Self {
758            ArcRPCClient(self.0.clone())
759        }
760    }
761
762    #[async_trait]
763    impl<T> RPCClient for ArcRPCClient<T>
764    where
765        T: RPCClient + Sync + Send + 'static,
766    {
767        async fn get_tokens(
768            &self,
769            request: &TokensRequestBody,
770        ) -> Result<TokensRequestResponse, RPCError> {
771            self.0.get_tokens(request).await
772        }
773
774        async fn get_contract_state(
775            &self,
776            request: &StateRequestBody,
777        ) -> Result<StateRequestResponse, RPCError> {
778            self.0.get_contract_state(request).await
779        }
780
781        async fn get_protocol_components(
782            &self,
783            request: &ProtocolComponentsRequestBody,
784        ) -> Result<ProtocolComponentRequestResponse, RPCError> {
785            self.0
786                .get_protocol_components(request)
787                .await
788        }
789
790        async fn get_protocol_states(
791            &self,
792            request: &ProtocolStateRequestBody,
793        ) -> Result<ProtocolStateRequestResponse, RPCError> {
794            self.0
795                .get_protocol_states(request)
796                .await
797        }
798
799        async fn get_protocol_systems(
800            &self,
801            request: &ProtocolSystemsRequestBody,
802        ) -> Result<ProtocolSystemsRequestResponse, RPCError> {
803            self.0
804                .get_protocol_systems(request)
805                .await
806        }
807
808        async fn get_component_tvl(
809            &self,
810            request: &ComponentTvlRequestBody,
811        ) -> Result<ComponentTvlRequestResponse, RPCError> {
812            self.0.get_component_tvl(request).await
813        }
814
815        async fn get_traced_entry_points(
816            &self,
817            request: &TracedEntryPointRequestBody,
818        ) -> Result<TracedEntryPointRequestResponse, RPCError> {
819            self.0
820                .get_traced_entry_points(request)
821                .await
822        }
823    }
824
825    // Required for mock client to implement clone
826    struct ArcDeltasClient<T>(Arc<T>);
827
828    // Default derive(Clone) does require T to be Clone as well.
829    impl<T> Clone for ArcDeltasClient<T> {
830        fn clone(&self) -> Self {
831            ArcDeltasClient(self.0.clone())
832        }
833    }
834
835    #[async_trait]
836    impl<T> DeltasClient for ArcDeltasClient<T>
837    where
838        T: DeltasClient + Sync + Send + 'static,
839    {
840        async fn subscribe(
841            &self,
842            extractor_id: ExtractorIdentity,
843            options: SubscriptionOptions,
844        ) -> Result<(Uuid, Receiver<BlockChanges>), DeltasError> {
845            self.0
846                .subscribe(extractor_id, options)
847                .await
848        }
849
850        async fn unsubscribe(&self, subscription_id: Uuid) -> Result<(), DeltasError> {
851            self.0
852                .unsubscribe(subscription_id)
853                .await
854        }
855
856        async fn connect(&self) -> Result<JoinHandle<Result<(), DeltasError>>, DeltasError> {
857            self.0.connect().await
858        }
859
860        async fn close(&self) -> Result<(), DeltasError> {
861            self.0.close().await
862        }
863    }
864
865    fn with_mocked_clients(
866        native: bool,
867        include_tvl: bool,
868        rpc_client: Option<MockRPCClient>,
869        deltas_client: Option<MockDeltasClient>,
870    ) -> ProtocolStateSynchronizer<ArcRPCClient<MockRPCClient>, ArcDeltasClient<MockDeltasClient>>
871    {
872        let rpc_client = ArcRPCClient(Arc::new(rpc_client.unwrap_or_default()));
873        let deltas_client = ArcDeltasClient(Arc::new(deltas_client.unwrap_or_default()));
874
875        ProtocolStateSynchronizer::new(
876            ExtractorIdentity::new(Chain::Ethereum, "uniswap-v2"),
877            native,
878            ComponentFilter::with_tvl_range(50.0, 50.0),
879            1,
880            Duration::from_secs(0),
881            true,
882            include_tvl,
883            rpc_client,
884            deltas_client,
885            10_u64,
886        )
887    }
888
889    fn state_snapshot_native() -> ProtocolStateRequestResponse {
890        ProtocolStateRequestResponse {
891            states: vec![ResponseProtocolState {
892                component_id: "Component1".to_string(),
893                ..Default::default()
894            }],
895            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
896        }
897    }
898
899    fn component_tvl_snapshot() -> ComponentTvlRequestResponse {
900        let tvl = HashMap::from([("Component1".to_string(), 100.0)]);
901
902        ComponentTvlRequestResponse {
903            tvl,
904            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
905        }
906    }
907
908    #[test(tokio::test)]
909    async fn test_get_snapshots_native() {
910        let header = BlockHeader::default();
911        let mut rpc = MockRPCClient::new();
912        rpc.expect_get_protocol_states()
913            .returning(|_| Ok(state_snapshot_native()));
914        rpc.expect_get_traced_entry_points()
915            .returning(|_| {
916                Ok(TracedEntryPointRequestResponse {
917                    traced_entry_points: HashMap::new(),
918                    pagination: PaginationResponse::new(0, 20, 0),
919                })
920            });
921        let mut state_sync = with_mocked_clients(true, false, Some(rpc), None);
922        let component = ProtocolComponent { id: "Component1".to_string(), ..Default::default() };
923        state_sync
924            .component_tracker
925            .components
926            .insert("Component1".to_string(), component.clone());
927        let components_arg = ["Component1".to_string()];
928        let exp = StateSyncMessage {
929            header: header.clone(),
930            snapshots: Snapshot {
931                states: state_snapshot_native()
932                    .states
933                    .into_iter()
934                    .map(|state| {
935                        (
936                            state.component_id.clone(),
937                            ComponentWithState {
938                                state,
939                                component: component.clone(),
940                                entrypoints: vec![],
941                                component_tvl: None,
942                            },
943                        )
944                    })
945                    .collect(),
946                vm_storage: HashMap::new(),
947            },
948            deltas: None,
949            removed_components: Default::default(),
950        };
951
952        let snap = state_sync
953            .get_snapshots(header, Some(&components_arg))
954            .await
955            .expect("Retrieving snapshot failed");
956
957        assert_eq!(snap, exp);
958    }
959
960    #[test(tokio::test)]
961    async fn test_get_snapshots_native_with_tvl() {
962        let header = BlockHeader::default();
963        let mut rpc = MockRPCClient::new();
964        rpc.expect_get_protocol_states()
965            .returning(|_| Ok(state_snapshot_native()));
966        rpc.expect_get_component_tvl()
967            .returning(|_| Ok(component_tvl_snapshot()));
968        rpc.expect_get_traced_entry_points()
969            .returning(|_| {
970                Ok(TracedEntryPointRequestResponse {
971                    traced_entry_points: HashMap::new(),
972                    pagination: PaginationResponse::new(0, 20, 0),
973                })
974            });
975        let mut state_sync = with_mocked_clients(true, true, Some(rpc), None);
976        let component = ProtocolComponent { id: "Component1".to_string(), ..Default::default() };
977        state_sync
978            .component_tracker
979            .components
980            .insert("Component1".to_string(), component.clone());
981        let components_arg = ["Component1".to_string()];
982        let exp = StateSyncMessage {
983            header: header.clone(),
984            snapshots: Snapshot {
985                states: state_snapshot_native()
986                    .states
987                    .into_iter()
988                    .map(|state| {
989                        (
990                            state.component_id.clone(),
991                            ComponentWithState {
992                                state,
993                                component: component.clone(),
994                                component_tvl: Some(100.0),
995                                entrypoints: vec![],
996                            },
997                        )
998                    })
999                    .collect(),
1000                vm_storage: HashMap::new(),
1001            },
1002            deltas: None,
1003            removed_components: Default::default(),
1004        };
1005
1006        let snap = state_sync
1007            .get_snapshots(header, Some(&components_arg))
1008            .await
1009            .expect("Retrieving snapshot failed");
1010
1011        assert_eq!(snap, exp);
1012    }
1013
1014    fn state_snapshot_vm() -> StateRequestResponse {
1015        StateRequestResponse {
1016            accounts: vec![
1017                ResponseAccount { address: Bytes::from("0x0badc0ffee"), ..Default::default() },
1018                ResponseAccount { address: Bytes::from("0xbabe42"), ..Default::default() },
1019            ],
1020            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1021        }
1022    }
1023
1024    fn traced_entry_point_response() -> TracedEntryPointRequestResponse {
1025        TracedEntryPointRequestResponse {
1026            traced_entry_points: HashMap::from([(
1027                "Component1".to_string(),
1028                vec![(
1029                    EntryPointWithTracingParams {
1030                        entry_point: EntryPoint {
1031                            external_id: "entrypoint_a".to_string(),
1032                            target: Bytes::from("0x0badc0ffee"),
1033                            signature: "sig()".to_string(),
1034                        },
1035                        params: TracingParams::RPCTracer(RPCTracerParams {
1036                            caller: Some(Bytes::from("0x0badc0ffee")),
1037                            calldata: Bytes::from("0x0badc0ffee"),
1038                        }),
1039                    },
1040                    TracingResult {
1041                        retriggers: HashSet::from([(
1042                            Bytes::from("0x0badc0ffee"),
1043                            Bytes::from("0x0badc0ffee"),
1044                        )]),
1045                        accessed_slots: HashMap::from([(
1046                            Bytes::from("0x0badc0ffee"),
1047                            HashSet::from([Bytes::from("0xbadbeef0")]),
1048                        )]),
1049                    },
1050                )],
1051            )]),
1052            pagination: PaginationResponse::new(0, 20, 0),
1053        }
1054    }
1055
1056    #[test(tokio::test)]
1057    async fn test_get_snapshots_vm() {
1058        let header = BlockHeader::default();
1059        let mut rpc = MockRPCClient::new();
1060        rpc.expect_get_protocol_states()
1061            .returning(|_| Ok(state_snapshot_native()));
1062        rpc.expect_get_contract_state()
1063            .returning(|_| Ok(state_snapshot_vm()));
1064        rpc.expect_get_traced_entry_points()
1065            .returning(|_| Ok(traced_entry_point_response()));
1066        let mut state_sync = with_mocked_clients(false, false, Some(rpc), None);
1067        let component = ProtocolComponent {
1068            id: "Component1".to_string(),
1069            contract_ids: vec![Bytes::from("0x0badc0ffee"), Bytes::from("0xbabe42")],
1070            ..Default::default()
1071        };
1072        state_sync
1073            .component_tracker
1074            .components
1075            .insert("Component1".to_string(), component.clone());
1076        let components_arg = ["Component1".to_string()];
1077        let exp = StateSyncMessage {
1078            header: header.clone(),
1079            snapshots: Snapshot {
1080                states: [(
1081                    component.id.clone(),
1082                    ComponentWithState {
1083                        state: ResponseProtocolState {
1084                            component_id: "Component1".to_string(),
1085                            ..Default::default()
1086                        },
1087                        component: component.clone(),
1088                        component_tvl: None,
1089                        entrypoints: vec![(
1090                            EntryPointWithTracingParams {
1091                                entry_point: EntryPoint {
1092                                    external_id: "entrypoint_a".to_string(),
1093                                    target: Bytes::from("0x0badc0ffee"),
1094                                    signature: "sig()".to_string(),
1095                                },
1096                                params: TracingParams::RPCTracer(RPCTracerParams {
1097                                    caller: Some(Bytes::from("0x0badc0ffee")),
1098                                    calldata: Bytes::from("0x0badc0ffee"),
1099                                }),
1100                            },
1101                            TracingResult {
1102                                retriggers: HashSet::from([(
1103                                    Bytes::from("0x0badc0ffee"),
1104                                    Bytes::from("0x0badc0ffee"),
1105                                )]),
1106                                accessed_slots: HashMap::from([(
1107                                    Bytes::from("0x0badc0ffee"),
1108                                    HashSet::from([Bytes::from("0xbadbeef0")]),
1109                                )]),
1110                            },
1111                        )],
1112                    },
1113                )]
1114                .into_iter()
1115                .collect(),
1116                vm_storage: state_snapshot_vm()
1117                    .accounts
1118                    .into_iter()
1119                    .map(|state| (state.address.clone(), state))
1120                    .collect(),
1121            },
1122            deltas: None,
1123            removed_components: Default::default(),
1124        };
1125
1126        let snap = state_sync
1127            .get_snapshots(header, Some(&components_arg))
1128            .await
1129            .expect("Retrieving snapshot failed");
1130
1131        assert_eq!(snap, exp);
1132    }
1133
1134    #[test(tokio::test)]
1135    async fn test_get_snapshots_vm_with_tvl() {
1136        let header = BlockHeader::default();
1137        let mut rpc = MockRPCClient::new();
1138        rpc.expect_get_protocol_states()
1139            .returning(|_| Ok(state_snapshot_native()));
1140        rpc.expect_get_contract_state()
1141            .returning(|_| Ok(state_snapshot_vm()));
1142        rpc.expect_get_component_tvl()
1143            .returning(|_| Ok(component_tvl_snapshot()));
1144        rpc.expect_get_traced_entry_points()
1145            .returning(|_| {
1146                Ok(TracedEntryPointRequestResponse {
1147                    traced_entry_points: HashMap::new(),
1148                    pagination: PaginationResponse::new(0, 20, 0),
1149                })
1150            });
1151        let mut state_sync = with_mocked_clients(false, true, Some(rpc), None);
1152        let component = ProtocolComponent {
1153            id: "Component1".to_string(),
1154            contract_ids: vec![Bytes::from("0x0badc0ffee"), Bytes::from("0xbabe42")],
1155            ..Default::default()
1156        };
1157        state_sync
1158            .component_tracker
1159            .components
1160            .insert("Component1".to_string(), component.clone());
1161        let components_arg = ["Component1".to_string()];
1162        let exp = StateSyncMessage {
1163            header: header.clone(),
1164            snapshots: Snapshot {
1165                states: [(
1166                    component.id.clone(),
1167                    ComponentWithState {
1168                        state: ResponseProtocolState {
1169                            component_id: "Component1".to_string(),
1170                            ..Default::default()
1171                        },
1172                        component: component.clone(),
1173                        component_tvl: Some(100.0),
1174                        entrypoints: vec![],
1175                    },
1176                )]
1177                .into_iter()
1178                .collect(),
1179                vm_storage: state_snapshot_vm()
1180                    .accounts
1181                    .into_iter()
1182                    .map(|state| (state.address.clone(), state))
1183                    .collect(),
1184            },
1185            deltas: None,
1186            removed_components: Default::default(),
1187        };
1188
1189        let snap = state_sync
1190            .get_snapshots(header, Some(&components_arg))
1191            .await
1192            .expect("Retrieving snapshot failed");
1193
1194        assert_eq!(snap, exp);
1195    }
1196
1197    fn mock_clients_for_state_sync() -> (MockRPCClient, MockDeltasClient, Sender<BlockChanges>) {
1198        let mut rpc_client = MockRPCClient::new();
1199        // Mocks for the start_tracking call, these need to come first because they are more
1200        // specific, see: https://docs.rs/mockall/latest/mockall/#matching-multiple-calls
1201        rpc_client
1202            .expect_get_protocol_components()
1203            .with(mockall::predicate::function(
1204                move |request_params: &ProtocolComponentsRequestBody| {
1205                    if let Some(ids) = request_params.component_ids.as_ref() {
1206                        ids.contains(&"Component3".to_string())
1207                    } else {
1208                        false
1209                    }
1210                },
1211            ))
1212            .returning(|_| {
1213                // return Component3
1214                Ok(ProtocolComponentRequestResponse {
1215                    protocol_components: vec![
1216                        // this component shall have a tvl update above threshold
1217                        ProtocolComponent { id: "Component3".to_string(), ..Default::default() },
1218                    ],
1219                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1220                })
1221            });
1222        rpc_client
1223            .expect_get_protocol_states()
1224            .with(mockall::predicate::function(move |request_params: &ProtocolStateRequestBody| {
1225                let expected_id = "Component3".to_string();
1226                if let Some(ids) = request_params.protocol_ids.as_ref() {
1227                    ids.contains(&expected_id)
1228                } else {
1229                    false
1230                }
1231            }))
1232            .returning(|_| {
1233                // return Component3 state
1234                Ok(ProtocolStateRequestResponse {
1235                    states: vec![ResponseProtocolState {
1236                        component_id: "Component3".to_string(),
1237                        ..Default::default()
1238                    }],
1239                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1240                })
1241            });
1242
1243        // mock calls for the initial state snapshots
1244        rpc_client
1245            .expect_get_protocol_components()
1246            .returning(|_| {
1247                // Initial sync of components
1248                Ok(ProtocolComponentRequestResponse {
1249                    protocol_components: vec![
1250                        // this component shall have a tvl update above threshold
1251                        ProtocolComponent { id: "Component1".to_string(), ..Default::default() },
1252                        // this component shall have a tvl update below threshold.
1253                        ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1254                        // a third component will have a tvl update above threshold
1255                    ],
1256                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1257                })
1258            });
1259        rpc_client
1260            .expect_get_protocol_states()
1261            .returning(|_| {
1262                // Initial state snapshot
1263                Ok(ProtocolStateRequestResponse {
1264                    states: vec![
1265                        ResponseProtocolState {
1266                            component_id: "Component1".to_string(),
1267                            ..Default::default()
1268                        },
1269                        ResponseProtocolState {
1270                            component_id: "Component2".to_string(),
1271                            ..Default::default()
1272                        },
1273                    ],
1274                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1275                })
1276            });
1277        rpc_client
1278            .expect_get_component_tvl()
1279            .returning(|_| {
1280                Ok(ComponentTvlRequestResponse {
1281                    tvl: [
1282                        ("Component1".to_string(), 100.0),
1283                        ("Component2".to_string(), 0.0),
1284                        ("Component3".to_string(), 1000.0),
1285                    ]
1286                    .into_iter()
1287                    .collect(),
1288                    pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
1289                })
1290            });
1291        rpc_client
1292            .expect_get_traced_entry_points()
1293            .returning(|_| {
1294                Ok(TracedEntryPointRequestResponse {
1295                    traced_entry_points: HashMap::new(),
1296                    pagination: PaginationResponse::new(0, 20, 0),
1297                })
1298            });
1299
1300        // Mock deltas client and messages
1301        let mut deltas_client = MockDeltasClient::new();
1302        let (tx, rx) = channel(1);
1303        deltas_client
1304            .expect_subscribe()
1305            .return_once(move |_, _| {
1306                // Return subscriber id and a channel
1307                Ok((Uuid::default(), rx))
1308            });
1309
1310        // Expect unsubscribe call during cleanup
1311        deltas_client
1312            .expect_unsubscribe()
1313            .return_once(|_| Ok(()));
1314
1315        (rpc_client, deltas_client, tx)
1316    }
1317
1318    /// Test strategy
1319    ///
1320    /// - initial snapshot retrieval returns two component1 and component2 as snapshots
1321    /// - send 2 dummy messages, containing only blocks
1322    /// - third message contains a new component with some significant tvl, one initial component
1323    ///   slips below tvl threshold, another one is above tvl but does not get re-requested.
1324    #[test(tokio::test)]
1325    async fn test_state_sync() {
1326        let (rpc_client, deltas_client, tx) = mock_clients_for_state_sync();
1327        let deltas = [
1328            BlockChanges {
1329                extractor: "uniswap-v2".to_string(),
1330                chain: Chain::Ethereum,
1331                block: Block {
1332                    number: 1,
1333                    hash: Bytes::from("0x01"),
1334                    parent_hash: Bytes::from("0x00"),
1335                    chain: Chain::Ethereum,
1336                    ts: Default::default(),
1337                },
1338                revert: false,
1339                dci_update: DCIUpdate {
1340                    new_entrypoints: HashMap::from([(
1341                        "Component1".to_string(),
1342                        HashSet::from([EntryPoint {
1343                            external_id: "entrypoint_a".to_string(),
1344                            target: Bytes::from("0x0badc0ffee"),
1345                            signature: "sig()".to_string(),
1346                        }]),
1347                    )]),
1348                    new_entrypoint_params: HashMap::from([(
1349                        "entrypoint_a".to_string(),
1350                        HashSet::from([(
1351                            TracingParams::RPCTracer(RPCTracerParams {
1352                                caller: Some(Bytes::from("0x0badc0ffee")),
1353                                calldata: Bytes::from("0x0badc0ffee"),
1354                            }),
1355                            Some("Component1".to_string()),
1356                        )]),
1357                    )]),
1358                    trace_results: HashMap::from([(
1359                        "entrypoint_a".to_string(),
1360                        TracingResult {
1361                            retriggers: HashSet::from([(
1362                                Bytes::from("0x0badc0ffee"),
1363                                Bytes::from("0x0badc0ffee"),
1364                            )]),
1365                            accessed_slots: HashMap::from([(
1366                                Bytes::from("0x0badc0ffee"),
1367                                HashSet::from([Bytes::from("0xbadbeef0")]),
1368                            )]),
1369                        },
1370                    )]),
1371                },
1372                ..Default::default()
1373            },
1374            BlockChanges {
1375                extractor: "uniswap-v2".to_string(),
1376                chain: Chain::Ethereum,
1377                block: Block {
1378                    number: 2,
1379                    hash: Bytes::from("0x02"),
1380                    parent_hash: Bytes::from("0x01"),
1381                    chain: Chain::Ethereum,
1382                    ts: Default::default(),
1383                },
1384                revert: false,
1385                component_tvl: [
1386                    ("Component1".to_string(), 100.0),
1387                    ("Component2".to_string(), 0.0),
1388                    ("Component3".to_string(), 1000.0),
1389                ]
1390                .into_iter()
1391                .collect(),
1392                ..Default::default()
1393            },
1394        ];
1395        let mut state_sync = with_mocked_clients(true, true, Some(rpc_client), Some(deltas_client));
1396        state_sync
1397            .initialize()
1398            .await
1399            .expect("Init failed");
1400
1401        // Test starts here
1402        let (handle, mut rx) = state_sync
1403            .start()
1404            .await
1405            .expect("Failed to start state synchronizer");
1406        let (jh, close_tx) = handle.split();
1407        tx.send(deltas[0].clone())
1408            .await
1409            .expect("deltas channel msg 0 closed!");
1410        let first_msg = timeout(Duration::from_millis(100), rx.recv())
1411            .await
1412            .expect("waiting for first state msg timed out!")
1413            .expect("state sync block sender closed!");
1414        tx.send(deltas[1].clone())
1415            .await
1416            .expect("deltas channel msg 1 closed!");
1417        let second_msg = timeout(Duration::from_millis(100), rx.recv())
1418            .await
1419            .expect("waiting for second state msg timed out!")
1420            .expect("state sync block sender closed!");
1421        let _ = close_tx.send(());
1422        let exit = jh
1423            .await
1424            .expect("state sync task panicked!");
1425
1426        // assertions
1427        let exp1 = StateSyncMessage {
1428            header: BlockHeader {
1429                number: 1,
1430                hash: Bytes::from("0x01"),
1431                parent_hash: Bytes::from("0x00"),
1432                revert: false,
1433                ..Default::default()
1434            },
1435            snapshots: Snapshot {
1436                states: [
1437                    (
1438                        "Component1".to_string(),
1439                        ComponentWithState {
1440                            state: ResponseProtocolState {
1441                                component_id: "Component1".to_string(),
1442                                ..Default::default()
1443                            },
1444                            component: ProtocolComponent {
1445                                id: "Component1".to_string(),
1446                                ..Default::default()
1447                            },
1448                            component_tvl: Some(100.0),
1449                            entrypoints: vec![],
1450                        },
1451                    ),
1452                    (
1453                        "Component2".to_string(),
1454                        ComponentWithState {
1455                            state: ResponseProtocolState {
1456                                component_id: "Component2".to_string(),
1457                                ..Default::default()
1458                            },
1459                            component: ProtocolComponent {
1460                                id: "Component2".to_string(),
1461                                ..Default::default()
1462                            },
1463                            component_tvl: Some(0.0),
1464                            entrypoints: vec![],
1465                        },
1466                    ),
1467                ]
1468                .into_iter()
1469                .collect(),
1470                vm_storage: HashMap::new(),
1471            },
1472            deltas: Some(deltas[0].clone()),
1473            removed_components: Default::default(),
1474        };
1475
1476        let exp2 = StateSyncMessage {
1477            header: BlockHeader {
1478                number: 2,
1479                hash: Bytes::from("0x02"),
1480                parent_hash: Bytes::from("0x01"),
1481                revert: false,
1482                ..Default::default()
1483            },
1484            snapshots: Snapshot {
1485                states: [
1486                    // This is the new component we queried once it passed the tvl threshold.
1487                    (
1488                        "Component3".to_string(),
1489                        ComponentWithState {
1490                            state: ResponseProtocolState {
1491                                component_id: "Component3".to_string(),
1492                                ..Default::default()
1493                            },
1494                            component: ProtocolComponent {
1495                                id: "Component3".to_string(),
1496                                ..Default::default()
1497                            },
1498                            component_tvl: Some(1000.0),
1499                            entrypoints: vec![],
1500                        },
1501                    ),
1502                ]
1503                .into_iter()
1504                .collect(),
1505                vm_storage: HashMap::new(),
1506            },
1507            // Our deltas are empty and since merge methods are
1508            // tested in tycho-common we don't have much to do here.
1509            deltas: Some(BlockChanges {
1510                extractor: "uniswap-v2".to_string(),
1511                chain: Chain::Ethereum,
1512                block: Block {
1513                    number: 2,
1514                    hash: Bytes::from("0x02"),
1515                    parent_hash: Bytes::from("0x01"),
1516                    chain: Chain::Ethereum,
1517                    ts: Default::default(),
1518                },
1519                revert: false,
1520                component_tvl: [
1521                    // "Component2" should not show here.
1522                    ("Component1".to_string(), 100.0),
1523                    ("Component3".to_string(), 1000.0),
1524                ]
1525                .into_iter()
1526                .collect(),
1527                ..Default::default()
1528            }),
1529            // "Component2" was removed, because its tvl changed to 0.
1530            removed_components: [(
1531                "Component2".to_string(),
1532                ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1533            )]
1534            .into_iter()
1535            .collect(),
1536        };
1537        assert_eq!(first_msg, exp1);
1538        assert_eq!(second_msg, exp2);
1539        assert!(exit.is_ok());
1540    }
1541
1542    #[test(tokio::test)]
1543    async fn test_state_sync_with_tvl_range() {
1544        // Define the range for testing
1545        let remove_tvl_threshold = 5.0;
1546        let add_tvl_threshold = 7.0;
1547
1548        let mut rpc_client = MockRPCClient::new();
1549        let mut deltas_client = MockDeltasClient::new();
1550
1551        rpc_client
1552            .expect_get_protocol_components()
1553            .with(mockall::predicate::function(
1554                move |request_params: &ProtocolComponentsRequestBody| {
1555                    if let Some(ids) = request_params.component_ids.as_ref() {
1556                        ids.contains(&"Component3".to_string())
1557                    } else {
1558                        false
1559                    }
1560                },
1561            ))
1562            .returning(|_| {
1563                Ok(ProtocolComponentRequestResponse {
1564                    protocol_components: vec![ProtocolComponent {
1565                        id: "Component3".to_string(),
1566                        ..Default::default()
1567                    }],
1568                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1569                })
1570            });
1571        rpc_client
1572            .expect_get_protocol_states()
1573            .with(mockall::predicate::function(move |request_params: &ProtocolStateRequestBody| {
1574                let expected_id = "Component3".to_string();
1575                if let Some(ids) = request_params.protocol_ids.as_ref() {
1576                    ids.contains(&expected_id)
1577                } else {
1578                    false
1579                }
1580            }))
1581            .returning(|_| {
1582                Ok(ProtocolStateRequestResponse {
1583                    states: vec![ResponseProtocolState {
1584                        component_id: "Component3".to_string(),
1585                        ..Default::default()
1586                    }],
1587                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1588                })
1589            });
1590
1591        // Mock for the initial snapshot retrieval
1592        rpc_client
1593            .expect_get_protocol_components()
1594            .returning(|_| {
1595                Ok(ProtocolComponentRequestResponse {
1596                    protocol_components: vec![
1597                        ProtocolComponent { id: "Component1".to_string(), ..Default::default() },
1598                        ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1599                    ],
1600                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1601                })
1602            });
1603        rpc_client
1604            .expect_get_protocol_states()
1605            .returning(|_| {
1606                Ok(ProtocolStateRequestResponse {
1607                    states: vec![
1608                        ResponseProtocolState {
1609                            component_id: "Component1".to_string(),
1610                            ..Default::default()
1611                        },
1612                        ResponseProtocolState {
1613                            component_id: "Component2".to_string(),
1614                            ..Default::default()
1615                        },
1616                    ],
1617                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1618                })
1619            });
1620        rpc_client
1621            .expect_get_traced_entry_points()
1622            .returning(|_| {
1623                Ok(TracedEntryPointRequestResponse {
1624                    traced_entry_points: HashMap::new(),
1625                    pagination: PaginationResponse::new(0, 20, 0),
1626                })
1627            });
1628
1629        rpc_client
1630            .expect_get_component_tvl()
1631            .returning(|_| {
1632                Ok(ComponentTvlRequestResponse {
1633                    tvl: [
1634                        ("Component1".to_string(), 6.0),
1635                        ("Component2".to_string(), 2.0),
1636                        ("Component3".to_string(), 10.0),
1637                    ]
1638                    .into_iter()
1639                    .collect(),
1640                    pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
1641                })
1642            });
1643
1644        rpc_client
1645            .expect_get_component_tvl()
1646            .returning(|_| {
1647                Ok(ComponentTvlRequestResponse {
1648                    tvl: [
1649                        ("Component1".to_string(), 6.0),
1650                        ("Component2".to_string(), 2.0),
1651                        ("Component3".to_string(), 10.0),
1652                    ]
1653                    .into_iter()
1654                    .collect(),
1655                    pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
1656                })
1657            });
1658
1659        let (tx, rx) = channel(1);
1660        deltas_client
1661            .expect_subscribe()
1662            .return_once(move |_, _| Ok((Uuid::default(), rx)));
1663
1664        // Expect unsubscribe call during cleanup
1665        deltas_client
1666            .expect_unsubscribe()
1667            .return_once(|_| Ok(()));
1668
1669        let mut state_sync = ProtocolStateSynchronizer::new(
1670            ExtractorIdentity::new(Chain::Ethereum, "uniswap-v2"),
1671            true,
1672            ComponentFilter::with_tvl_range(remove_tvl_threshold, add_tvl_threshold),
1673            1,
1674            Duration::from_secs(0),
1675            true,
1676            true,
1677            ArcRPCClient(Arc::new(rpc_client)),
1678            ArcDeltasClient(Arc::new(deltas_client)),
1679            10_u64,
1680        );
1681        state_sync
1682            .initialize()
1683            .await
1684            .expect("Init failed");
1685
1686        // Simulate the incoming BlockChanges
1687        let deltas = [
1688            BlockChanges {
1689                extractor: "uniswap-v2".to_string(),
1690                chain: Chain::Ethereum,
1691                block: Block {
1692                    number: 1,
1693                    hash: Bytes::from("0x01"),
1694                    parent_hash: Bytes::from("0x00"),
1695                    chain: Chain::Ethereum,
1696                    ts: Default::default(),
1697                },
1698                revert: false,
1699                ..Default::default()
1700            },
1701            BlockChanges {
1702                extractor: "uniswap-v2".to_string(),
1703                chain: Chain::Ethereum,
1704                block: Block {
1705                    number: 2,
1706                    hash: Bytes::from("0x02"),
1707                    parent_hash: Bytes::from("0x01"),
1708                    chain: Chain::Ethereum,
1709                    ts: Default::default(),
1710                },
1711                revert: false,
1712                component_tvl: [
1713                    ("Component1".to_string(), 6.0), // Within range, should not trigger changes
1714                    ("Component2".to_string(), 2.0), // Below lower threshold, should be removed
1715                    ("Component3".to_string(), 10.0), // Above upper threshold, should be added
1716                ]
1717                .into_iter()
1718                .collect(),
1719                ..Default::default()
1720            },
1721        ];
1722
1723        let (handle, mut rx) = state_sync
1724            .start()
1725            .await
1726            .expect("Failed to start state synchronizer");
1727        let (jh, close_tx) = handle.split();
1728
1729        // Simulate sending delta messages
1730        tx.send(deltas[0].clone())
1731            .await
1732            .expect("deltas channel msg 0 closed!");
1733
1734        // Expecting to receive the initial state message
1735        let _ = timeout(Duration::from_millis(100), rx.recv())
1736            .await
1737            .expect("waiting for first state msg timed out!")
1738            .expect("state sync block sender closed!");
1739
1740        // Send the third message, which should trigger TVL-based changes
1741        tx.send(deltas[1].clone())
1742            .await
1743            .expect("deltas channel msg 1 closed!");
1744        let second_msg = timeout(Duration::from_millis(100), rx.recv())
1745            .await
1746            .expect("waiting for second state msg timed out!")
1747            .expect("state sync block sender closed!");
1748
1749        let _ = close_tx.send(());
1750        let exit = jh
1751            .await
1752            .expect("state sync task panicked!");
1753
1754        let expected_second_msg = StateSyncMessage {
1755            header: BlockHeader {
1756                number: 2,
1757                hash: Bytes::from("0x02"),
1758                parent_hash: Bytes::from("0x01"),
1759                revert: false,
1760                ..Default::default()
1761            },
1762            snapshots: Snapshot {
1763                states: [(
1764                    "Component3".to_string(),
1765                    ComponentWithState {
1766                        state: ResponseProtocolState {
1767                            component_id: "Component3".to_string(),
1768                            ..Default::default()
1769                        },
1770                        component: ProtocolComponent {
1771                            id: "Component3".to_string(),
1772                            ..Default::default()
1773                        },
1774                        component_tvl: Some(10.0),
1775                        entrypoints: vec![], // TODO: add entrypoints?
1776                    },
1777                )]
1778                .into_iter()
1779                .collect(),
1780                vm_storage: HashMap::new(),
1781            },
1782            deltas: Some(BlockChanges {
1783                extractor: "uniswap-v2".to_string(),
1784                chain: Chain::Ethereum,
1785                block: Block {
1786                    number: 2,
1787                    hash: Bytes::from("0x02"),
1788                    parent_hash: Bytes::from("0x01"),
1789                    chain: Chain::Ethereum,
1790                    ts: Default::default(),
1791                },
1792                revert: false,
1793                component_tvl: [
1794                    ("Component1".to_string(), 6.0), // Within range, should not trigger changes
1795                    ("Component3".to_string(), 10.0), // Above upper threshold, should be added
1796                ]
1797                .into_iter()
1798                .collect(),
1799                ..Default::default()
1800            }),
1801            removed_components: [(
1802                "Component2".to_string(),
1803                ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1804            )]
1805            .into_iter()
1806            .collect(),
1807        };
1808
1809        assert_eq!(second_msg, expected_second_msg);
1810        assert!(exit.is_ok());
1811    }
1812
1813    #[test(tokio::test)]
1814    async fn test_public_close_api_functionality() {
1815        // Tests the public close() API through the StateSynchronizer trait:
1816        // - close() fails before start() is called
1817        // - close() succeeds while synchronizer is running
1818        // - close() fails after already closed
1819        // This tests the full start/close lifecycle via the public API
1820
1821        let mut rpc_client = MockRPCClient::new();
1822        let mut deltas_client = MockDeltasClient::new();
1823
1824        // Mock the initial components call
1825        rpc_client
1826            .expect_get_protocol_components()
1827            .returning(|_| {
1828                Ok(ProtocolComponentRequestResponse {
1829                    protocol_components: vec![],
1830                    pagination: PaginationResponse { page: 0, page_size: 20, total: 0 },
1831                })
1832            });
1833
1834        // Set up deltas client that will wait for messages (blocking in state_sync)
1835        let (_tx, rx) = channel(1);
1836        deltas_client
1837            .expect_subscribe()
1838            .return_once(move |_, _| Ok((Uuid::default(), rx)));
1839
1840        // Expect unsubscribe call during cleanup
1841        deltas_client
1842            .expect_unsubscribe()
1843            .return_once(|_| Ok(()));
1844
1845        let mut state_sync = ProtocolStateSynchronizer::new(
1846            ExtractorIdentity::new(Chain::Ethereum, "test-protocol"),
1847            true,
1848            ComponentFilter::with_tvl_range(0.0, 0.0),
1849            5, // Enough retries
1850            Duration::from_secs(0),
1851            true,
1852            false,
1853            ArcRPCClient(Arc::new(rpc_client)),
1854            ArcDeltasClient(Arc::new(deltas_client)),
1855            10000_u64, // Long timeout so task doesn't exit on its own
1856        );
1857
1858        state_sync
1859            .initialize()
1860            .await
1861            .expect("Init should succeed");
1862
1863        // Start the synchronizer and test the new split-based close mechanism
1864        let (handle, _rx) = state_sync
1865            .start()
1866            .await
1867            .expect("Failed to start state synchronizer");
1868        let (jh, close_tx) = handle.split();
1869
1870        // Give it time to start up and enter state_sync
1871        tokio::time::sleep(Duration::from_millis(100)).await;
1872
1873        // Send close signal should succeed
1874        close_tx
1875            .send(())
1876            .expect("Should be able to send close signal");
1877        // Task should stop cleanly
1878        let task_result = jh.await.expect("Task should not panic");
1879        assert!(task_result.is_ok(), "Task should exit cleanly after close: {task_result:?}");
1880    }
1881
1882    #[test(tokio::test)]
1883    async fn test_cleanup_runs_when_state_sync_processing_errors() {
1884        // Tests that cleanup code runs when state_sync() errors during delta processing.
1885        // Specifically tests: RPC errors during snapshot retrieval cause proper cleanup.
1886        // Verifies: shared.last_synced_block reset + subscription unsubscribe on errors
1887
1888        let mut rpc_client = MockRPCClient::new();
1889        let mut deltas_client = MockDeltasClient::new();
1890
1891        // Mock the initial components call
1892        rpc_client
1893            .expect_get_protocol_components()
1894            .returning(|_| {
1895                Ok(ProtocolComponentRequestResponse {
1896                    protocol_components: vec![],
1897                    pagination: PaginationResponse { page: 0, page_size: 20, total: 0 },
1898                })
1899            });
1900
1901        // Mock to fail during snapshot retrieval (this will cause an error during processing)
1902        rpc_client
1903            .expect_get_protocol_states()
1904            .returning(|_| {
1905                Err(RPCError::HttpClient("Test error during snapshot retrieval".to_string()))
1906            });
1907
1908        // Set up deltas client to send one message that will trigger snapshot retrieval
1909        let (tx, rx) = channel(10);
1910        deltas_client
1911            .expect_subscribe()
1912            .return_once(move |_, _| {
1913                // Send a delta message that will require a snapshot
1914                let delta = BlockChanges {
1915                    extractor: "test".to_string(),
1916                    chain: Chain::Ethereum,
1917                    block: Block {
1918                        hash: Bytes::from("0x0123"),
1919                        number: 1,
1920                        parent_hash: Bytes::from("0x0000"),
1921                        chain: Chain::Ethereum,
1922                        ts: chrono::NaiveDateTime::from_timestamp_opt(1234567890, 0).unwrap(),
1923                    },
1924                    revert: false,
1925                    // Add a new component to trigger snapshot request
1926                    new_protocol_components: [(
1927                        "new_component".to_string(),
1928                        ProtocolComponent {
1929                            id: "new_component".to_string(),
1930                            protocol_system: "test_protocol".to_string(),
1931                            protocol_type_name: "test".to_string(),
1932                            chain: Chain::Ethereum,
1933                            tokens: vec![Bytes::from("0x0badc0ffee")],
1934                            contract_ids: vec![Bytes::from("0x0badc0ffee")],
1935                            static_attributes: Default::default(),
1936                            creation_tx: Default::default(),
1937                            created_at: Default::default(),
1938                            change: Default::default(),
1939                        },
1940                    )]
1941                    .into_iter()
1942                    .collect(),
1943                    component_tvl: [("new_component".to_string(), 100.0)]
1944                        .into_iter()
1945                        .collect(),
1946                    ..Default::default()
1947                };
1948
1949                tokio::spawn(async move {
1950                    let _ = tx.send(delta).await;
1951                    // Close the channel after sending one message
1952                });
1953
1954                Ok((Uuid::default(), rx))
1955            });
1956
1957        // Expect unsubscribe call during cleanup
1958        deltas_client
1959            .expect_unsubscribe()
1960            .return_once(|_| Ok(()));
1961
1962        let mut state_sync = ProtocolStateSynchronizer::new(
1963            ExtractorIdentity::new(Chain::Ethereum, "test-protocol"),
1964            true,
1965            ComponentFilter::with_tvl_range(0.0, 1000.0), // Include the component
1966            1,
1967            Duration::from_secs(0),
1968            true,
1969            false,
1970            ArcRPCClient(Arc::new(rpc_client)),
1971            ArcDeltasClient(Arc::new(deltas_client)),
1972            5000_u64,
1973        );
1974
1975        state_sync
1976            .initialize()
1977            .await
1978            .expect("Init should succeed");
1979
1980        // Before calling state_sync, set a value in last_synced_block
1981        state_sync.last_synced_block = Some(BlockHeader {
1982            hash: Bytes::from("0x0badc0ffee"),
1983            number: 42,
1984            parent_hash: Bytes::from("0xbadbeef0"),
1985            revert: false,
1986            timestamp: 123456789,
1987        });
1988
1989        // Create a channel for state_sync to send messages to
1990        let (mut block_tx, _block_rx) = channel(10);
1991
1992        // Call state_sync directly - this should error during processing
1993        let (_end_tx, end_rx) = oneshot::channel::<()>();
1994        let result = state_sync
1995            .state_sync(&mut block_tx, end_rx)
1996            .await;
1997        // Verify that state_sync returned an error
1998        assert!(result.is_err(), "state_sync should have errored during processing");
1999
2000        // Note: We can't verify internal state cleanup since state_sync consumes self,
2001        // but the cleanup logic is still tested by the fact that the method returns properly.
2002    }
2003
2004    #[test(tokio::test)]
2005    async fn test_close_signal_while_waiting_for_first_deltas() {
2006        // Tests close signal handling during the initial "waiting for deltas" phase.
2007        // This is the earliest possible close scenario - before any deltas are received.
2008        // Verifies: close signal received while waiting for first message triggers cleanup
2009        let mut rpc_client = MockRPCClient::new();
2010        let mut deltas_client = MockDeltasClient::new();
2011
2012        rpc_client
2013            .expect_get_protocol_components()
2014            .returning(|_| {
2015                Ok(ProtocolComponentRequestResponse {
2016                    protocol_components: vec![],
2017                    pagination: PaginationResponse { page: 0, page_size: 20, total: 0 },
2018                })
2019            });
2020
2021        let (_tx, rx) = channel(1);
2022        deltas_client
2023            .expect_subscribe()
2024            .return_once(move |_, _| Ok((Uuid::default(), rx)));
2025
2026        deltas_client
2027            .expect_unsubscribe()
2028            .return_once(|_| Ok(()));
2029
2030        let mut state_sync = ProtocolStateSynchronizer::new(
2031            ExtractorIdentity::new(Chain::Ethereum, "test-protocol"),
2032            true,
2033            ComponentFilter::with_tvl_range(0.0, 0.0),
2034            1,
2035            Duration::from_secs(0),
2036            true,
2037            false,
2038            ArcRPCClient(Arc::new(rpc_client)),
2039            ArcDeltasClient(Arc::new(deltas_client)),
2040            10000_u64,
2041        );
2042
2043        state_sync
2044            .initialize()
2045            .await
2046            .expect("Init should succeed");
2047
2048        let (mut block_tx, _block_rx) = channel(10);
2049        let (end_tx, end_rx) = oneshot::channel::<()>();
2050
2051        // Start state_sync in a task
2052        let state_sync_handle = tokio::spawn(async move {
2053            state_sync
2054                .state_sync(&mut block_tx, end_rx)
2055                .await
2056        });
2057
2058        // Give it a moment to start
2059        tokio::time::sleep(Duration::from_millis(100)).await;
2060
2061        // Send close signal
2062        let _ = end_tx.send(());
2063
2064        // state_sync should exit cleanly
2065        let result = state_sync_handle
2066            .await
2067            .expect("Task should not panic");
2068        assert!(result.is_ok(), "state_sync should exit cleanly when closed: {result:?}");
2069
2070        println!("SUCCESS: Close signal handled correctly while waiting for first deltas");
2071    }
2072
2073    #[test(tokio::test)]
2074    async fn test_close_signal_during_main_processing_loop() {
2075        // Tests close signal handling during the main delta processing loop.
2076        // This tests the scenario where first message is processed successfully,
2077        // then close signal is received while waiting for subsequent deltas.
2078        // Verifies: close signal in main loop (after initialization) triggers cleanup
2079
2080        let mut rpc_client = MockRPCClient::new();
2081        let mut deltas_client = MockDeltasClient::new();
2082
2083        // Mock the initial components call
2084        rpc_client
2085            .expect_get_protocol_components()
2086            .returning(|_| {
2087                Ok(ProtocolComponentRequestResponse {
2088                    protocol_components: vec![],
2089                    pagination: PaginationResponse { page: 0, page_size: 20, total: 0 },
2090                })
2091            });
2092
2093        // Mock the snapshot retrieval that happens after first message
2094        rpc_client
2095            .expect_get_protocol_states()
2096            .returning(|_| {
2097                Ok(ProtocolStateRequestResponse {
2098                    states: vec![],
2099                    pagination: PaginationResponse { page: 0, page_size: 20, total: 0 },
2100                })
2101            });
2102
2103        rpc_client
2104            .expect_get_component_tvl()
2105            .returning(|_| {
2106                Ok(ComponentTvlRequestResponse {
2107                    tvl: HashMap::new(),
2108                    pagination: PaginationResponse { page: 0, page_size: 20, total: 0 },
2109                })
2110            });
2111
2112        rpc_client
2113            .expect_get_traced_entry_points()
2114            .returning(|_| {
2115                Ok(TracedEntryPointRequestResponse {
2116                    traced_entry_points: HashMap::new(),
2117                    pagination: PaginationResponse::new(0, 20, 0),
2118                })
2119            });
2120
2121        // Set up deltas client to send one message, then keep channel open
2122        let (tx, rx) = channel(10);
2123        deltas_client
2124            .expect_subscribe()
2125            .return_once(move |_, _| {
2126                // Send first message immediately
2127                let first_delta = BlockChanges {
2128                    extractor: "test".to_string(),
2129                    chain: Chain::Ethereum,
2130                    block: Block {
2131                        hash: Bytes::from("0x0123"),
2132                        number: 1,
2133                        parent_hash: Bytes::from("0x0000"),
2134                        chain: Chain::Ethereum,
2135                        ts: chrono::NaiveDateTime::from_timestamp_opt(1234567890, 0).unwrap(),
2136                    },
2137                    revert: false,
2138                    ..Default::default()
2139                };
2140
2141                tokio::spawn(async move {
2142                    let _ = tx.send(first_delta).await;
2143                    // Keep the sender alive but don't send more messages
2144                    // This will make the recv() block waiting for the next message
2145                    tokio::time::sleep(Duration::from_secs(30)).await;
2146                });
2147
2148                Ok((Uuid::default(), rx))
2149            });
2150
2151        deltas_client
2152            .expect_unsubscribe()
2153            .return_once(|_| Ok(()));
2154
2155        let mut state_sync = ProtocolStateSynchronizer::new(
2156            ExtractorIdentity::new(Chain::Ethereum, "test-protocol"),
2157            true,
2158            ComponentFilter::with_tvl_range(0.0, 1000.0),
2159            1,
2160            Duration::from_secs(0),
2161            true,
2162            false,
2163            ArcRPCClient(Arc::new(rpc_client)),
2164            ArcDeltasClient(Arc::new(deltas_client)),
2165            10000_u64,
2166        );
2167
2168        state_sync
2169            .initialize()
2170            .await
2171            .expect("Init should succeed");
2172
2173        let (mut block_tx, mut block_rx) = channel(10);
2174        let (end_tx, end_rx) = oneshot::channel::<()>();
2175
2176        // Start state_sync in a task
2177        let state_sync_handle = tokio::spawn(async move {
2178            state_sync
2179                .state_sync(&mut block_tx, end_rx)
2180                .await
2181        });
2182
2183        // Wait for the first message to be processed (snapshot sent)
2184        let first_snapshot = block_rx
2185            .recv()
2186            .await
2187            .expect("Should receive first snapshot");
2188        assert!(
2189            !first_snapshot
2190                .snapshots
2191                .states
2192                .is_empty() ||
2193                first_snapshot.deltas.is_some()
2194        );
2195        // Now send close signal - this should be handled in the main processing loop
2196        let _ = end_tx.send(());
2197
2198        // state_sync should exit cleanly after receiving close signal in main loop
2199        let result = state_sync_handle
2200            .await
2201            .expect("Task should not panic");
2202        assert!(
2203            result.is_ok(),
2204            "state_sync should exit cleanly when closed after first message: {result:?}"
2205        );
2206        println!("SUCCESS: Close signal handled correctly during main processing loop");
2207    }
2208}