tycho_client/feed/
synchronizer.rs

1use std::{collections::HashMap, time::Duration};
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6use tokio::{
7    select,
8    sync::{
9        mpsc::{channel, error::SendError, Receiver, Sender},
10        oneshot,
11    },
12    task::JoinHandle,
13    time::timeout,
14};
15use tracing::{debug, error, info, instrument, trace, warn};
16use tycho_common::{
17    dto::{
18        BlockChanges, BlockParam, Chain, ComponentTvlRequestBody, EntryPointWithTracingParams,
19        ExtractorIdentity, ProtocolComponent, ResponseAccount, ResponseProtocolState,
20        TracingResult, VersionParam,
21    },
22    Bytes,
23};
24
25use crate::{
26    deltas::{DeltasClient, SubscriptionOptions},
27    feed::{
28        component_tracker::{ComponentFilter, ComponentTracker},
29        BlockHeader, HeaderLike,
30    },
31    rpc::{RPCClient, RPCError},
32    DeltasError,
33};
34
35#[derive(Error, Debug)]
36pub enum SynchronizerError {
37    /// RPC client failures.
38    #[error("RPC error: {0}")]
39    RPCError(#[from] RPCError),
40
41    /// Failed to send channel message to the consumer.
42    #[error("Failed to send channel message: {0}")]
43    ChannelError(String),
44
45    /// Timeout elapsed errors.
46    #[error("Timeout error: {0}")]
47    Timeout(String),
48
49    /// Failed to close the synchronizer.
50    #[error("Failed to close synchronizer: {0}")]
51    CloseError(String),
52
53    /// Server connection failures or interruptions.
54    #[error("Connection error: {0}")]
55    ConnectionError(String),
56
57    /// Connection closed
58    #[error("Connection closed")]
59    ConnectionClosed,
60}
61
62pub type SyncResult<T> = Result<T, SynchronizerError>;
63
64impl From<SendError<StateSyncMessage<BlockHeader>>> for SynchronizerError {
65    fn from(err: SendError<StateSyncMessage<BlockHeader>>) -> Self {
66        SynchronizerError::ChannelError(err.to_string())
67    }
68}
69
70impl From<DeltasError> for SynchronizerError {
71    fn from(err: DeltasError) -> Self {
72        match err {
73            DeltasError::NotConnected => SynchronizerError::ConnectionClosed,
74            _ => SynchronizerError::ConnectionError(err.to_string()),
75        }
76    }
77}
78
79pub struct ProtocolStateSynchronizer<R: RPCClient, D: DeltasClient> {
80    extractor_id: ExtractorIdentity,
81    retrieve_balances: bool,
82    rpc_client: R,
83    deltas_client: D,
84    max_retries: u64,
85    include_snapshots: bool,
86    component_tracker: ComponentTracker<R>,
87    last_synced_block: Option<BlockHeader>,
88    timeout: u64,
89    include_tvl: bool,
90}
91
92#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
93pub struct ComponentWithState {
94    pub state: ResponseProtocolState,
95    pub component: ProtocolComponent,
96    pub component_tvl: Option<f64>,
97    pub entrypoints: Vec<(EntryPointWithTracingParams, TracingResult)>,
98}
99
100#[derive(Clone, PartialEq, Debug, Default, Serialize, Deserialize)]
101pub struct Snapshot {
102    pub states: HashMap<String, ComponentWithState>,
103    pub vm_storage: HashMap<Bytes, ResponseAccount>,
104}
105
106impl Snapshot {
107    fn extend(&mut self, other: Snapshot) {
108        self.states.extend(other.states);
109        self.vm_storage.extend(other.vm_storage);
110    }
111
112    pub fn get_states(&self) -> &HashMap<String, ComponentWithState> {
113        &self.states
114    }
115
116    pub fn get_vm_storage(&self) -> &HashMap<Bytes, ResponseAccount> {
117        &self.vm_storage
118    }
119}
120
121#[derive(Clone, PartialEq, Debug, Default, Serialize, Deserialize)]
122pub struct StateSyncMessage<H>
123where
124    H: HeaderLike,
125{
126    /// The block information for this update.
127    pub header: H,
128    /// Snapshot for new components.
129    pub snapshots: Snapshot,
130    /// A single delta contains state updates for all tracked components, as well as additional
131    /// information about the system components e.g. newly added components (even below tvl), tvl
132    /// updates, balance updates.
133    pub deltas: Option<BlockChanges>,
134    /// Components that stopped being tracked.
135    pub removed_components: HashMap<String, ProtocolComponent>,
136}
137
138impl<H> StateSyncMessage<H>
139where
140    H: HeaderLike,
141{
142    pub fn merge(mut self, other: Self) -> Self {
143        // be careful with removed and snapshots attributes here, these can be ambiguous.
144        self.removed_components
145            .retain(|k, _| !other.snapshots.states.contains_key(k));
146        self.snapshots
147            .states
148            .retain(|k, _| !other.removed_components.contains_key(k));
149
150        self.snapshots.extend(other.snapshots);
151        let deltas = match (self.deltas, other.deltas) {
152            (Some(l), Some(r)) => Some(l.merge(r)),
153            (None, Some(r)) => Some(r),
154            (Some(l), None) => Some(l),
155            (None, None) => None,
156        };
157        self.removed_components
158            .extend(other.removed_components);
159        Self {
160            header: other.header,
161            snapshots: self.snapshots,
162            deltas,
163            removed_components: self.removed_components,
164        }
165    }
166}
167
168/// Handle for controlling a running synchronizer task.
169///
170/// This handle provides methods to gracefully shut down the synchronizer
171/// and await its completion with a timeout.
172pub struct SynchronizerTaskHandle {
173    join_handle: JoinHandle<SyncResult<()>>,
174    close_tx: oneshot::Sender<()>,
175}
176
177/// StateSynchronizer
178///
179/// Used to synchronize the state of a single protocol. The synchronizer is responsible for
180/// delivering messages to the client that let him reconstruct subsets of the protocol state.
181///
182/// This involves deciding which components to track according to the clients preferences,
183/// retrieving & emitting snapshots of components which the client has not seen yet and subsequently
184/// delivering delta messages for the components that have changed.
185impl SynchronizerTaskHandle {
186    pub fn new(join_handle: JoinHandle<SyncResult<()>>, close_tx: oneshot::Sender<()>) -> Self {
187        Self { join_handle, close_tx }
188    }
189
190    /// Splits the handle into its join handle and close sender.
191    ///
192    /// This allows monitoring the task completion separately from controlling shutdown.
193    /// The join handle can be used with FuturesUnordered for monitoring, while the
194    /// close sender can be used to signal graceful shutdown.
195    pub fn split(self) -> (JoinHandle<SyncResult<()>>, oneshot::Sender<()>) {
196        (self.join_handle, self.close_tx)
197    }
198}
199
200#[async_trait]
201pub trait StateSynchronizer: Send + Sync + 'static {
202    async fn initialize(&mut self) -> SyncResult<()>;
203    /// Starts the state synchronization, consuming the synchronizer.
204    /// Returns a handle for controlling the running task and a receiver for messages.
205    async fn start(
206        mut self,
207    ) -> SyncResult<(SynchronizerTaskHandle, Receiver<StateSyncMessage<BlockHeader>>)>;
208}
209
210impl<R, D> ProtocolStateSynchronizer<R, D>
211where
212    // TODO: Consider moving these constraints directly to the
213    // client...
214    R: RPCClient + Clone + Send + Sync + 'static,
215    D: DeltasClient + Clone + Send + Sync + 'static,
216{
217    /// Creates a new state synchronizer.
218    #[allow(clippy::too_many_arguments)]
219    pub fn new(
220        extractor_id: ExtractorIdentity,
221        retrieve_balances: bool,
222        component_filter: ComponentFilter,
223        max_retries: u64,
224        include_snapshots: bool,
225        include_tvl: bool,
226        rpc_client: R,
227        deltas_client: D,
228        timeout: u64,
229    ) -> Self {
230        Self {
231            extractor_id: extractor_id.clone(),
232            retrieve_balances,
233            rpc_client: rpc_client.clone(),
234            include_snapshots,
235            deltas_client,
236            component_tracker: ComponentTracker::new(
237                extractor_id.chain,
238                extractor_id.name.as_str(),
239                component_filter,
240                rpc_client,
241            ),
242            max_retries,
243            last_synced_block: None,
244            timeout,
245            include_tvl,
246        }
247    }
248
249    /// Retrieves state snapshots of the requested components
250    #[allow(deprecated)]
251    async fn get_snapshots<'a, I: IntoIterator<Item = &'a String>>(
252        &mut self,
253        header: BlockHeader,
254        ids: Option<I>,
255    ) -> SyncResult<StateSyncMessage<BlockHeader>> {
256        if !self.include_snapshots {
257            return Ok(StateSyncMessage { header, ..Default::default() });
258        }
259        let version = VersionParam::new(
260            None,
261            Some(BlockParam {
262                chain: Some(self.extractor_id.chain),
263                hash: None,
264                number: Some(header.number as i64),
265            }),
266        );
267
268        // Use given ids or use all if not passed
269        let component_ids: Vec<_> = match ids {
270            Some(ids) => ids.into_iter().cloned().collect(),
271            None => self
272                .component_tracker
273                .get_tracked_component_ids(),
274        };
275
276        if component_ids.is_empty() {
277            return Ok(StateSyncMessage { header, ..Default::default() });
278        }
279
280        let component_tvl = if self.include_tvl {
281            let body = ComponentTvlRequestBody::id_filtered(
282                component_ids.clone(),
283                self.extractor_id.chain,
284            );
285            self.rpc_client
286                .get_component_tvl_paginated(&body, 100, 4)
287                .await?
288                .tvl
289        } else {
290            HashMap::new()
291        };
292
293        //TODO: Improve this, we should not query for every component, but only for the ones that
294        // could have entrypoints. Maybe apply a filter per protocol?
295        let entrypoints_result = if self.extractor_id.chain == Chain::Ethereum {
296            // Fetch entrypoints
297            let result = self
298                .rpc_client
299                .get_traced_entry_points_paginated(
300                    self.extractor_id.chain,
301                    &self.extractor_id.name,
302                    &component_ids,
303                    100,
304                    4,
305                )
306                .await?;
307            self.component_tracker
308                .process_entrypoints(&result.clone().into());
309            Some(result)
310        } else {
311            None
312        };
313
314        // Fetch protocol states
315        let mut protocol_states = self
316            .rpc_client
317            .get_protocol_states_paginated(
318                self.extractor_id.chain,
319                &component_ids,
320                &self.extractor_id.name,
321                self.retrieve_balances,
322                &version,
323                100,
324                4,
325            )
326            .await?
327            .states
328            .into_iter()
329            .map(|state| (state.component_id.clone(), state))
330            .collect::<HashMap<_, _>>();
331
332        trace!(states=?&protocol_states, "Retrieved ProtocolStates");
333        let states = self
334            .component_tracker
335            .components
336            .values()
337            .filter_map(|component| {
338                if let Some(state) = protocol_states.remove(&component.id) {
339                    Some((
340                        component.id.clone(),
341                        ComponentWithState {
342                            state,
343                            component: component.clone(),
344                            component_tvl: component_tvl
345                                .get(&component.id)
346                                .cloned(),
347                            entrypoints: entrypoints_result
348                                .as_ref()
349                                .map(|r| {
350                                    r.traced_entry_points
351                                        .get(&component.id)
352                                        .cloned()
353                                        .unwrap_or_default()
354                                })
355                                .unwrap_or_default(),
356                        },
357                    ))
358                } else if component_ids.contains(&component.id) {
359                    // only emit error event if we requested this component
360                    let component_id = &component.id;
361                    error!(?component_id, "Missing state for native component!");
362                    None
363                } else {
364                    None
365                }
366            })
367            .collect();
368
369        // Fetch contract states
370        let contract_ids = self
371            .component_tracker
372            .get_contracts_by_component(&component_ids);
373        let vm_storage = if !contract_ids.is_empty() {
374            let ids: Vec<Bytes> = contract_ids
375                .clone()
376                .into_iter()
377                .collect();
378            let contract_states = self
379                .rpc_client
380                .get_contract_state_paginated(
381                    self.extractor_id.chain,
382                    ids.as_slice(),
383                    &self.extractor_id.name,
384                    &version,
385                    100,
386                    4,
387                )
388                .await?
389                .accounts
390                .into_iter()
391                .map(|acc| (acc.address.clone(), acc))
392                .collect::<HashMap<_, _>>();
393
394            trace!(states=?&contract_states, "Retrieved ContractState");
395
396            let contract_address_to_components = self
397                .component_tracker
398                .components
399                .iter()
400                .filter_map(|(id, comp)| {
401                    if component_ids.contains(id) {
402                        Some(
403                            comp.contract_ids
404                                .iter()
405                                .map(|address| (address.clone(), comp.id.clone())),
406                        )
407                    } else {
408                        None
409                    }
410                })
411                .flatten()
412                .fold(HashMap::<Bytes, Vec<String>>::new(), |mut acc, (addr, c_id)| {
413                    acc.entry(addr).or_default().push(c_id);
414                    acc
415                });
416
417            contract_ids
418                .iter()
419                .filter_map(|address| {
420                    if let Some(state) = contract_states.get(address) {
421                        Some((address.clone(), state.clone()))
422                    } else if let Some(ids) = contract_address_to_components.get(address) {
423                        // only emit error even if we did actually request this address
424                        error!(
425                            ?address,
426                            ?ids,
427                            "Component with lacking contract storage encountered!"
428                        );
429                        None
430                    } else {
431                        None
432                    }
433                })
434                .collect()
435        } else {
436            HashMap::new()
437        };
438
439        Ok(StateSyncMessage {
440            header,
441            snapshots: Snapshot { states, vm_storage },
442            deltas: None,
443            removed_components: HashMap::new(),
444        })
445    }
446
447    /// Main method that does all the work.
448    ///
449    /// ## Return Value
450    ///
451    /// Returns a `Result` where:
452    /// - `Ok(())` - Synchronization completed successfully (usually due to close signal)
453    /// - `Err((error, None))` - Error occurred AND close signal was received (don't retry)
454    /// - `Err((error, Some(end_rx)))` - Error occurred but close signal was NOT received (can
455    ///   retry)
456    ///
457    /// The returned `end_rx` (if any) should be reused for retry attempts since the close
458    /// signal may still arrive and we want to remain cancellable across retries.
459    #[instrument(skip(self, block_tx, end_rx), fields(extractor_id = %self.extractor_id))]
460    async fn state_sync(
461        &mut self,
462        block_tx: &mut Sender<StateSyncMessage<BlockHeader>>,
463        mut end_rx: oneshot::Receiver<()>,
464    ) -> Result<(), (SynchronizerError, Option<oneshot::Receiver<()>>)> {
465        // initialisation
466
467        let subscription_options = SubscriptionOptions::new().with_state(self.include_snapshots);
468        let (subscription_id, mut msg_rx) = match self
469            .deltas_client
470            .subscribe(self.extractor_id.clone(), subscription_options)
471            .await
472        {
473            Ok(result) => result,
474            Err(e) => return Err((e.into(), Some(end_rx))),
475        };
476
477        let result = async {
478            info!("Waiting for deltas...");
479            // wait for first deltas message
480            let mut first_msg = select! {
481                deltas_result = timeout(Duration::from_secs(self.timeout), msg_rx.recv()) => {
482                    deltas_result
483                        .map_err(|_| {
484                            SynchronizerError::Timeout(format!(
485                                "First deltas took longer than {t}s to arrive",
486                                t = self.timeout
487                            ))
488                        })?
489                        .ok_or_else(|| {
490                            SynchronizerError::ConnectionError(
491                                "Deltas channel closed before first message".to_string(),
492                            )
493                        })?
494                },
495                _ = &mut end_rx => {
496                    info!("Received close signal while waiting for first deltas");
497                    return Ok(());
498                }
499            };
500            self.filter_deltas(&mut first_msg);
501
502            // initial snapshot
503            let block = first_msg.get_block().clone();
504            info!(height = &block.number, "Deltas received. Retrieving snapshot");
505            let header = BlockHeader::from_block(first_msg.get_block(), first_msg.is_revert());
506            let snapshot = self
507                .get_snapshots::<Vec<&String>>(
508                    BlockHeader::from_block(&block, false),
509                    None,
510                )
511                .await?
512                .merge(StateSyncMessage {
513                    header: BlockHeader::from_block(first_msg.get_block(), first_msg.is_revert()),
514                    snapshots: Default::default(),
515                    deltas: Some(first_msg),
516                    removed_components: Default::default(),
517                });
518
519            let n_components = self.component_tracker.components.len();
520            let n_snapshots = snapshot.snapshots.states.len();
521            info!(n_components, n_snapshots, "Initial snapshot retrieved, starting delta message feed");
522
523            block_tx.send(snapshot).await?;
524            self.last_synced_block = Some(header.clone());
525            loop {
526                select! {
527                    deltas_opt = msg_rx.recv() => {
528                        if let Some(mut deltas) = deltas_opt {
529                            let header = BlockHeader::from_block(deltas.get_block(), deltas.is_revert());
530                            debug!(block_number=?header.number, "Received delta message");
531
532                            let (snapshots, removed_components) = {
533                                // 1. Remove components based on latest changes
534                                // 2. Add components based on latest changes, query those for snapshots
535                                let (to_add, to_remove) = self.component_tracker.filter_updated_components(&deltas);
536
537                                // Only components we don't track yet need a snapshot,
538                                let requiring_snapshot: Vec<_> = to_add
539                                    .iter()
540                                    .filter(|id| {
541                                        !self.component_tracker
542                                            .components
543                                            .contains_key(id.as_str())
544                                    })
545                                    .collect();
546                                debug!(components=?requiring_snapshot, "SnapshotRequest");
547                                self.component_tracker
548                                    .start_tracking(requiring_snapshot.as_slice())
549                                    .await?;
550                                let snapshots = self
551                                    .get_snapshots(header.clone(), Some(requiring_snapshot))
552                                    .await?
553                                    .snapshots;
554
555                                let removed_components = if !to_remove.is_empty() {
556                                    self.component_tracker.stop_tracking(&to_remove)
557                                } else {
558                                    Default::default()
559                                };
560
561                                (snapshots, removed_components)
562                            };
563
564                            // 3. Update entrypoints on the tracker (affects which contracts are tracked)
565                            self.component_tracker.process_entrypoints(&deltas.dci_update);
566
567                            // 4. Filter deltas by currently tracked components / contracts
568                            self.filter_deltas(&mut deltas);
569                            let n_changes = deltas.n_changes();
570
571                            // 5. Send the message
572                            let next = StateSyncMessage {
573                                header: header.clone(),
574                                snapshots,
575                                deltas: Some(deltas),
576                                removed_components,
577                            };
578                            block_tx.send(next).await?;
579                            self.last_synced_block = Some(header.clone());
580
581                            debug!(block_number=?header.number, n_changes, "Finished processing delta message");
582                        } else {
583                            return Err(SynchronizerError::ConnectionError("Deltas channel closed".to_string()));
584                        }
585                    },
586                    _ = &mut end_rx => {
587                        info!("Received close signal during state_sync");
588                        return Ok(());
589                    }
590                }
591            }
592        }.await;
593
594        // This cleanup code now runs regardless of how the function exits (error or channel close)
595        warn!(last_synced_block = ?&self.last_synced_block, "Deltas processing ended, resetting last synced block.");
596        self.last_synced_block = None;
597        //Ignore error
598        let _ = self
599            .deltas_client
600            .unsubscribe(subscription_id)
601            .await
602            .map_err(|err| {
603                warn!(err=?err, "Unsubscribing from deltas on cleanup failed!");
604            });
605
606        // Handle the result: if it succeeded, we're done. If it errored, we need to determine
607        // whether the end_rx was consumed (close signal received) or not
608        match result {
609            Ok(()) => Ok(()), // Success, likely due to close signal
610            Err(e) => {
611                // The error came from the inner async block. Since the async block
612                // can receive close signals (which would return Ok), any error means
613                // the close signal was NOT received, so we can return the end_rx for retry
614                Err((e, Some(end_rx)))
615            }
616        }
617    }
618
619    fn filter_deltas(&self, deltas: &mut BlockChanges) {
620        deltas.filter_by_component(|id| {
621            self.component_tracker
622                .components
623                .contains_key(id)
624        });
625        deltas.filter_by_contract(|id| {
626            self.component_tracker
627                .contracts
628                .contains(id)
629        });
630    }
631}
632
633#[async_trait]
634impl<R, D> StateSynchronizer for ProtocolStateSynchronizer<R, D>
635where
636    R: RPCClient + Clone + Send + Sync + 'static,
637    D: DeltasClient + Clone + Send + Sync + 'static,
638{
639    async fn initialize(&mut self) -> SyncResult<()> {
640        info!("Retrieving relevant protocol components");
641        self.component_tracker
642            .initialise_components()
643            .await?;
644        info!(
645            n_components = self.component_tracker.components.len(),
646            n_contracts = self.component_tracker.contracts.len(),
647            "Finished retrieving components",
648        );
649
650        Ok(())
651    }
652
653    async fn start(
654        mut self,
655    ) -> SyncResult<(SynchronizerTaskHandle, Receiver<StateSyncMessage<BlockHeader>>)> {
656        let (mut tx, rx) = channel(15);
657        let (end_tx, end_rx) = oneshot::channel::<()>();
658
659        let jh = tokio::spawn(async move {
660            let mut retry_count = 0;
661            let mut current_end_rx = end_rx;
662
663            while retry_count < self.max_retries {
664                info!(extractor_id=%&self.extractor_id, retry_count, "(Re)starting synchronization loop");
665
666                let res = self
667                    .state_sync(&mut tx, current_end_rx)
668                    .await;
669                match res {
670                    Ok(()) => {
671                        info!(
672                            extractor_id=%&self.extractor_id,
673                            retry_count,
674                            "State synchronization exited cleanly"
675                        );
676                        return Ok(());
677                    }
678                    Err((e, maybe_end_rx)) => {
679                        error!(
680                            extractor_id=%&self.extractor_id,
681                            retry_count,
682                            error=%e,
683                            "State synchronization errored!"
684                        );
685
686                        // If we have the end_rx back, we can retry
687                        if let Some(recovered_end_rx) = maybe_end_rx {
688                            current_end_rx = recovered_end_rx;
689
690                            if let SynchronizerError::ConnectionClosed = e {
691                                // break synchronization loop if connection is closed
692                                return Err(e);
693                            }
694                        } else {
695                            // Close signal was received, exit cleanly
696                            info!(extractor_id=%&self.extractor_id, "Received close signal, exiting");
697                            return Ok(());
698                        }
699                    }
700                }
701                retry_count += 1;
702            }
703            warn!(extractor_id=%&self.extractor_id, retry_count, "Max retries exceeded");
704            Err(SynchronizerError::ConnectionError("Max connection retries exceeded".to_string()))
705        });
706
707        let handle = SynchronizerTaskHandle::new(jh, end_tx);
708        Ok((handle, rx))
709    }
710}
711
712#[cfg(test)]
713mod test {
714    //! Test suite for ProtocolStateSynchronizer shutdown and cleanup behavior.
715    //!
716    //! ## Test Coverage Strategy:
717    //!
718    //! ### Shutdown & Close Signal Tests:
719    //! - `test_public_close_api_functionality` - Tests public API (start/close lifecycle)
720    //! - `test_close_signal_while_waiting_for_first_deltas` - Close during initial wait
721    //! - `test_close_signal_during_main_processing_loop` - Close during main processing
722    //!
723    //! ### Cleanup & Error Handling Tests:
724    //! - `test_cleanup_runs_when_state_sync_processing_errors` - Cleanup on processing errors
725    //!
726    //! ### Coverage Summary:
727    //! These tests ensure cleanup code (shared state reset + unsubscribe) runs on ALL exit paths:
728    //! ✓ Close signal before first deltas   ✓ Close signal during processing
729    //! ✓ Processing errors                  ✓ Channel closure
730    //! ✓ Public API close operations        ✓ Normal completion
731
732    use std::{collections::HashSet, sync::Arc};
733
734    use test_log::test;
735    use tycho_common::dto::{
736        Block, Chain, ComponentTvlRequestBody, ComponentTvlRequestResponse, DCIUpdate, EntryPoint,
737        PaginationResponse, ProtocolComponentRequestResponse, ProtocolComponentsRequestBody,
738        ProtocolStateRequestBody, ProtocolStateRequestResponse, ProtocolSystemsRequestBody,
739        ProtocolSystemsRequestResponse, RPCTracerParams, StateRequestBody, StateRequestResponse,
740        TokensRequestBody, TokensRequestResponse, TracedEntryPointRequestBody,
741        TracedEntryPointRequestResponse, TracingParams,
742    };
743    use uuid::Uuid;
744
745    use super::*;
746    use crate::{deltas::MockDeltasClient, rpc::MockRPCClient, DeltasError, RPCError};
747
748    // Required for mock client to implement clone
749    struct ArcRPCClient<T>(Arc<T>);
750
751    // Default derive(Clone) does require T to be Clone as well.
752    impl<T> Clone for ArcRPCClient<T> {
753        fn clone(&self) -> Self {
754            ArcRPCClient(self.0.clone())
755        }
756    }
757
758    #[async_trait]
759    impl<T> RPCClient for ArcRPCClient<T>
760    where
761        T: RPCClient + Sync + Send + 'static,
762    {
763        async fn get_tokens(
764            &self,
765            request: &TokensRequestBody,
766        ) -> Result<TokensRequestResponse, RPCError> {
767            self.0.get_tokens(request).await
768        }
769
770        async fn get_contract_state(
771            &self,
772            request: &StateRequestBody,
773        ) -> Result<StateRequestResponse, RPCError> {
774            self.0.get_contract_state(request).await
775        }
776
777        async fn get_protocol_components(
778            &self,
779            request: &ProtocolComponentsRequestBody,
780        ) -> Result<ProtocolComponentRequestResponse, RPCError> {
781            self.0
782                .get_protocol_components(request)
783                .await
784        }
785
786        async fn get_protocol_states(
787            &self,
788            request: &ProtocolStateRequestBody,
789        ) -> Result<ProtocolStateRequestResponse, RPCError> {
790            self.0
791                .get_protocol_states(request)
792                .await
793        }
794
795        async fn get_protocol_systems(
796            &self,
797            request: &ProtocolSystemsRequestBody,
798        ) -> Result<ProtocolSystemsRequestResponse, RPCError> {
799            self.0
800                .get_protocol_systems(request)
801                .await
802        }
803
804        async fn get_component_tvl(
805            &self,
806            request: &ComponentTvlRequestBody,
807        ) -> Result<ComponentTvlRequestResponse, RPCError> {
808            self.0.get_component_tvl(request).await
809        }
810
811        async fn get_traced_entry_points(
812            &self,
813            request: &TracedEntryPointRequestBody,
814        ) -> Result<TracedEntryPointRequestResponse, RPCError> {
815            self.0
816                .get_traced_entry_points(request)
817                .await
818        }
819    }
820
821    // Required for mock client to implement clone
822    struct ArcDeltasClient<T>(Arc<T>);
823
824    // Default derive(Clone) does require T to be Clone as well.
825    impl<T> Clone for ArcDeltasClient<T> {
826        fn clone(&self) -> Self {
827            ArcDeltasClient(self.0.clone())
828        }
829    }
830
831    #[async_trait]
832    impl<T> DeltasClient for ArcDeltasClient<T>
833    where
834        T: DeltasClient + Sync + Send + 'static,
835    {
836        async fn subscribe(
837            &self,
838            extractor_id: ExtractorIdentity,
839            options: SubscriptionOptions,
840        ) -> Result<(Uuid, Receiver<BlockChanges>), DeltasError> {
841            self.0
842                .subscribe(extractor_id, options)
843                .await
844        }
845
846        async fn unsubscribe(&self, subscription_id: Uuid) -> Result<(), DeltasError> {
847            self.0
848                .unsubscribe(subscription_id)
849                .await
850        }
851
852        async fn connect(&self) -> Result<JoinHandle<Result<(), DeltasError>>, DeltasError> {
853            self.0.connect().await
854        }
855
856        async fn close(&self) -> Result<(), DeltasError> {
857            self.0.close().await
858        }
859    }
860
861    fn with_mocked_clients(
862        native: bool,
863        include_tvl: bool,
864        rpc_client: Option<MockRPCClient>,
865        deltas_client: Option<MockDeltasClient>,
866    ) -> ProtocolStateSynchronizer<ArcRPCClient<MockRPCClient>, ArcDeltasClient<MockDeltasClient>>
867    {
868        let rpc_client = ArcRPCClient(Arc::new(rpc_client.unwrap_or_default()));
869        let deltas_client = ArcDeltasClient(Arc::new(deltas_client.unwrap_or_default()));
870
871        ProtocolStateSynchronizer::new(
872            ExtractorIdentity::new(Chain::Ethereum, "uniswap-v2"),
873            native,
874            ComponentFilter::with_tvl_range(50.0, 50.0),
875            1,
876            true,
877            include_tvl,
878            rpc_client,
879            deltas_client,
880            10_u64,
881        )
882    }
883
884    fn state_snapshot_native() -> ProtocolStateRequestResponse {
885        ProtocolStateRequestResponse {
886            states: vec![ResponseProtocolState {
887                component_id: "Component1".to_string(),
888                ..Default::default()
889            }],
890            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
891        }
892    }
893
894    fn component_tvl_snapshot() -> ComponentTvlRequestResponse {
895        let tvl = HashMap::from([("Component1".to_string(), 100.0)]);
896
897        ComponentTvlRequestResponse {
898            tvl,
899            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
900        }
901    }
902
903    #[test(tokio::test)]
904    async fn test_get_snapshots_native() {
905        let header = BlockHeader::default();
906        let mut rpc = MockRPCClient::new();
907        rpc.expect_get_protocol_states()
908            .returning(|_| Ok(state_snapshot_native()));
909        rpc.expect_get_traced_entry_points()
910            .returning(|_| {
911                Ok(TracedEntryPointRequestResponse {
912                    traced_entry_points: HashMap::new(),
913                    pagination: PaginationResponse::new(0, 20, 0),
914                })
915            });
916        let mut state_sync = with_mocked_clients(true, false, Some(rpc), None);
917        let component = ProtocolComponent { id: "Component1".to_string(), ..Default::default() };
918        state_sync
919            .component_tracker
920            .components
921            .insert("Component1".to_string(), component.clone());
922        let components_arg = ["Component1".to_string()];
923        let exp = StateSyncMessage {
924            header: header.clone(),
925            snapshots: Snapshot {
926                states: state_snapshot_native()
927                    .states
928                    .into_iter()
929                    .map(|state| {
930                        (
931                            state.component_id.clone(),
932                            ComponentWithState {
933                                state,
934                                component: component.clone(),
935                                entrypoints: vec![],
936                                component_tvl: None,
937                            },
938                        )
939                    })
940                    .collect(),
941                vm_storage: HashMap::new(),
942            },
943            deltas: None,
944            removed_components: Default::default(),
945        };
946
947        let snap = state_sync
948            .get_snapshots(header, Some(&components_arg))
949            .await
950            .expect("Retrieving snapshot failed");
951
952        assert_eq!(snap, exp);
953    }
954
955    #[test(tokio::test)]
956    async fn test_get_snapshots_native_with_tvl() {
957        let header = BlockHeader::default();
958        let mut rpc = MockRPCClient::new();
959        rpc.expect_get_protocol_states()
960            .returning(|_| Ok(state_snapshot_native()));
961        rpc.expect_get_component_tvl()
962            .returning(|_| Ok(component_tvl_snapshot()));
963        rpc.expect_get_traced_entry_points()
964            .returning(|_| {
965                Ok(TracedEntryPointRequestResponse {
966                    traced_entry_points: HashMap::new(),
967                    pagination: PaginationResponse::new(0, 20, 0),
968                })
969            });
970        let mut state_sync = with_mocked_clients(true, true, Some(rpc), None);
971        let component = ProtocolComponent { id: "Component1".to_string(), ..Default::default() };
972        state_sync
973            .component_tracker
974            .components
975            .insert("Component1".to_string(), component.clone());
976        let components_arg = ["Component1".to_string()];
977        let exp = StateSyncMessage {
978            header: header.clone(),
979            snapshots: Snapshot {
980                states: state_snapshot_native()
981                    .states
982                    .into_iter()
983                    .map(|state| {
984                        (
985                            state.component_id.clone(),
986                            ComponentWithState {
987                                state,
988                                component: component.clone(),
989                                component_tvl: Some(100.0),
990                                entrypoints: vec![],
991                            },
992                        )
993                    })
994                    .collect(),
995                vm_storage: HashMap::new(),
996            },
997            deltas: None,
998            removed_components: Default::default(),
999        };
1000
1001        let snap = state_sync
1002            .get_snapshots(header, Some(&components_arg))
1003            .await
1004            .expect("Retrieving snapshot failed");
1005
1006        assert_eq!(snap, exp);
1007    }
1008
1009    fn state_snapshot_vm() -> StateRequestResponse {
1010        StateRequestResponse {
1011            accounts: vec![
1012                ResponseAccount { address: Bytes::from("0x0badc0ffee"), ..Default::default() },
1013                ResponseAccount { address: Bytes::from("0xbabe42"), ..Default::default() },
1014            ],
1015            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1016        }
1017    }
1018
1019    fn traced_entry_point_response() -> TracedEntryPointRequestResponse {
1020        TracedEntryPointRequestResponse {
1021            traced_entry_points: HashMap::from([(
1022                "Component1".to_string(),
1023                vec![(
1024                    EntryPointWithTracingParams {
1025                        entry_point: EntryPoint {
1026                            external_id: "entrypoint_a".to_string(),
1027                            target: Bytes::from("0x0badc0ffee"),
1028                            signature: "sig()".to_string(),
1029                        },
1030                        params: TracingParams::RPCTracer(RPCTracerParams {
1031                            caller: Some(Bytes::from("0x0badc0ffee")),
1032                            calldata: Bytes::from("0x0badc0ffee"),
1033                        }),
1034                    },
1035                    TracingResult {
1036                        retriggers: HashSet::from([(
1037                            Bytes::from("0x0badc0ffee"),
1038                            Bytes::from("0x0badc0ffee"),
1039                        )]),
1040                        accessed_slots: HashMap::from([(
1041                            Bytes::from("0x0badc0ffee"),
1042                            HashSet::from([Bytes::from("0xbadbeef0")]),
1043                        )]),
1044                    },
1045                )],
1046            )]),
1047            pagination: PaginationResponse::new(0, 20, 0),
1048        }
1049    }
1050
1051    #[test(tokio::test)]
1052    async fn test_get_snapshots_vm() {
1053        let header = BlockHeader::default();
1054        let mut rpc = MockRPCClient::new();
1055        rpc.expect_get_protocol_states()
1056            .returning(|_| Ok(state_snapshot_native()));
1057        rpc.expect_get_contract_state()
1058            .returning(|_| Ok(state_snapshot_vm()));
1059        rpc.expect_get_traced_entry_points()
1060            .returning(|_| Ok(traced_entry_point_response()));
1061        let mut state_sync = with_mocked_clients(false, false, Some(rpc), None);
1062        let component = ProtocolComponent {
1063            id: "Component1".to_string(),
1064            contract_ids: vec![Bytes::from("0x0badc0ffee"), Bytes::from("0xbabe42")],
1065            ..Default::default()
1066        };
1067        state_sync
1068            .component_tracker
1069            .components
1070            .insert("Component1".to_string(), component.clone());
1071        let components_arg = ["Component1".to_string()];
1072        let exp = StateSyncMessage {
1073            header: header.clone(),
1074            snapshots: Snapshot {
1075                states: [(
1076                    component.id.clone(),
1077                    ComponentWithState {
1078                        state: ResponseProtocolState {
1079                            component_id: "Component1".to_string(),
1080                            ..Default::default()
1081                        },
1082                        component: component.clone(),
1083                        component_tvl: None,
1084                        entrypoints: vec![(
1085                            EntryPointWithTracingParams {
1086                                entry_point: EntryPoint {
1087                                    external_id: "entrypoint_a".to_string(),
1088                                    target: Bytes::from("0x0badc0ffee"),
1089                                    signature: "sig()".to_string(),
1090                                },
1091                                params: TracingParams::RPCTracer(RPCTracerParams {
1092                                    caller: Some(Bytes::from("0x0badc0ffee")),
1093                                    calldata: Bytes::from("0x0badc0ffee"),
1094                                }),
1095                            },
1096                            TracingResult {
1097                                retriggers: HashSet::from([(
1098                                    Bytes::from("0x0badc0ffee"),
1099                                    Bytes::from("0x0badc0ffee"),
1100                                )]),
1101                                accessed_slots: HashMap::from([(
1102                                    Bytes::from("0x0badc0ffee"),
1103                                    HashSet::from([Bytes::from("0xbadbeef0")]),
1104                                )]),
1105                            },
1106                        )],
1107                    },
1108                )]
1109                .into_iter()
1110                .collect(),
1111                vm_storage: state_snapshot_vm()
1112                    .accounts
1113                    .into_iter()
1114                    .map(|state| (state.address.clone(), state))
1115                    .collect(),
1116            },
1117            deltas: None,
1118            removed_components: Default::default(),
1119        };
1120
1121        let snap = state_sync
1122            .get_snapshots(header, Some(&components_arg))
1123            .await
1124            .expect("Retrieving snapshot failed");
1125
1126        assert_eq!(snap, exp);
1127    }
1128
1129    #[test(tokio::test)]
1130    async fn test_get_snapshots_vm_with_tvl() {
1131        let header = BlockHeader::default();
1132        let mut rpc = MockRPCClient::new();
1133        rpc.expect_get_protocol_states()
1134            .returning(|_| Ok(state_snapshot_native()));
1135        rpc.expect_get_contract_state()
1136            .returning(|_| Ok(state_snapshot_vm()));
1137        rpc.expect_get_component_tvl()
1138            .returning(|_| Ok(component_tvl_snapshot()));
1139        rpc.expect_get_traced_entry_points()
1140            .returning(|_| {
1141                Ok(TracedEntryPointRequestResponse {
1142                    traced_entry_points: HashMap::new(),
1143                    pagination: PaginationResponse::new(0, 20, 0),
1144                })
1145            });
1146        let mut state_sync = with_mocked_clients(false, true, Some(rpc), None);
1147        let component = ProtocolComponent {
1148            id: "Component1".to_string(),
1149            contract_ids: vec![Bytes::from("0x0badc0ffee"), Bytes::from("0xbabe42")],
1150            ..Default::default()
1151        };
1152        state_sync
1153            .component_tracker
1154            .components
1155            .insert("Component1".to_string(), component.clone());
1156        let components_arg = ["Component1".to_string()];
1157        let exp = StateSyncMessage {
1158            header: header.clone(),
1159            snapshots: Snapshot {
1160                states: [(
1161                    component.id.clone(),
1162                    ComponentWithState {
1163                        state: ResponseProtocolState {
1164                            component_id: "Component1".to_string(),
1165                            ..Default::default()
1166                        },
1167                        component: component.clone(),
1168                        component_tvl: Some(100.0),
1169                        entrypoints: vec![],
1170                    },
1171                )]
1172                .into_iter()
1173                .collect(),
1174                vm_storage: state_snapshot_vm()
1175                    .accounts
1176                    .into_iter()
1177                    .map(|state| (state.address.clone(), state))
1178                    .collect(),
1179            },
1180            deltas: None,
1181            removed_components: Default::default(),
1182        };
1183
1184        let snap = state_sync
1185            .get_snapshots(header, Some(&components_arg))
1186            .await
1187            .expect("Retrieving snapshot failed");
1188
1189        assert_eq!(snap, exp);
1190    }
1191
    /// Builds the mocks used by `test_state_sync`.
    ///
    /// Returns the RPC mock, the deltas mock, and the sending half of the channel whose
    /// receiver the deltas mock hands out from `subscribe`. The test pushes `BlockChanges`
    /// through this sender.
    ///
    /// RPC expectations:
    /// - `get_protocol_components` / `get_protocol_states` requests whose ids contain
    ///   `Component3` return only Component3.
    /// - Any other `get_protocol_components` / `get_protocol_states` request returns Component1
    ///   and Component2, which form the initial snapshot.
    /// - `get_component_tvl` returns 100.0 / 0.0 / 1000.0 for Component1 / 2 / 3.
    /// - `get_traced_entry_points` returns no entrypoints.
    ///
    /// Deltas expectations: `subscribe` and `unsubscribe` are each answered once (`return_once`).
    fn mock_clients_for_state_sync() -> (MockRPCClient, MockDeltasClient, Sender<BlockChanges>) {
        let mut rpc_client = MockRPCClient::new();
        // Mocks for the start_tracking call. They must be registered first because they are
        // more specific than the catch-all mocks below, see:
        // https://docs.rs/mockall/latest/mockall/#matching-multiple-calls
        rpc_client
            .expect_get_protocol_components()
            .with(mockall::predicate::function(
                move |request_params: &ProtocolComponentsRequestBody| {
                    if let Some(ids) = request_params.component_ids.as_ref() {
                        ids.contains(&"Component3".to_string())
                    } else {
                        false
                    }
                },
            ))
            .returning(|_| {
                // return Component3
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![
                        // this component shall have a tvl update above threshold
                        ProtocolComponent { id: "Component3".to_string(), ..Default::default() },
                    ],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });
        rpc_client
            .expect_get_protocol_states()
            .with(mockall::predicate::function(move |request_params: &ProtocolStateRequestBody| {
                let expected_id = "Component3".to_string();
                if let Some(ids) = request_params.protocol_ids.as_ref() {
                    ids.contains(&expected_id)
                } else {
                    false
                }
            }))
            .returning(|_| {
                // return Component3 state
                Ok(ProtocolStateRequestResponse {
                    states: vec![ResponseProtocolState {
                        component_id: "Component3".to_string(),
                        ..Default::default()
                    }],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });

        // Catch-all mocks for the initial state snapshots (any request not matched above).
        rpc_client
            .expect_get_protocol_components()
            .returning(|_| {
                // Initial sync of components
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![
                        // this component shall have a tvl update above threshold
                        ProtocolComponent { id: "Component1".to_string(), ..Default::default() },
                        // this component shall have a tvl update below threshold.
                        ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
                        // Component3 is deliberately absent here: it is only fetched (by the
                        // more specific mock above) once its tvl rises above the threshold.
                    ],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });
        rpc_client
            .expect_get_protocol_states()
            .returning(|_| {
                // Initial state snapshot
                Ok(ProtocolStateRequestResponse {
                    states: vec![
                        ResponseProtocolState {
                            component_id: "Component1".to_string(),
                            ..Default::default()
                        },
                        ResponseProtocolState {
                            component_id: "Component2".to_string(),
                            ..Default::default()
                        },
                    ],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });
        // Same tvl values for every request, including Component3.
        rpc_client
            .expect_get_component_tvl()
            .returning(|_| {
                Ok(ComponentTvlRequestResponse {
                    tvl: [
                        ("Component1".to_string(), 100.0),
                        ("Component2".to_string(), 0.0),
                        ("Component3".to_string(), 1000.0),
                    ]
                    .into_iter()
                    .collect(),
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
                })
            });
        rpc_client
            .expect_get_traced_entry_points()
            .returning(|_| {
                Ok(TracedEntryPointRequestResponse {
                    traced_entry_points: HashMap::new(),
                    pagination: PaginationResponse::new(0, 20, 0),
                })
            });

        // Mock deltas client and messages
        let mut deltas_client = MockDeltasClient::new();
        // Capacity 1: the test sends one delta and then awaits the resulting message.
        let (tx, rx) = channel(1);
        deltas_client
            .expect_subscribe()
            .return_once(move |_, _| {
                // Return subscriber id and a channel
                Ok((Uuid::default(), rx))
            });

        // Expect unsubscribe call during cleanup
        deltas_client
            .expect_unsubscribe()
            .return_once(|_| Ok(()));

        (rpc_client, deltas_client, tx)
    }
1312
    /// End-to-end test of the synchronizer's main loop against the mocks from
    /// `mock_clients_for_state_sync`.
    ///
    /// Test strategy:
    /// - The initial snapshot returns Component1 and Component2.
    /// - The first delta (block 1) carries a DCI update for Component1.
    ///   - The emitted message must contain the full initial snapshot plus this delta unchanged.
    /// - The second delta (block 2) carries tvl updates:
    ///   - Component3 rises above the tvl threshold, so its snapshot is fetched.
    ///   - Component2 drops to 0, so it is reported as removed and filtered out of the delta.
    ///   - Component1 stays above the threshold but is not re-requested.
    /// - Finally the synchronizer is closed and its task must exit with `Ok`.
    #[test(tokio::test)]
    async fn test_state_sync() {
        let (rpc_client, deltas_client, tx) = mock_clients_for_state_sync();
        let deltas = [
            BlockChanges {
                extractor: "uniswap-v2".to_string(),
                chain: Chain::Ethereum,
                block: Block {
                    number: 1,
                    hash: Bytes::from("0x01"),
                    parent_hash: Bytes::from("0x00"),
                    chain: Chain::Ethereum,
                    ts: Default::default(),
                },
                revert: false,
                dci_update: DCIUpdate {
                    new_entrypoints: HashMap::from([(
                        "Component1".to_string(),
                        HashSet::from([EntryPoint {
                            external_id: "entrypoint_a".to_string(),
                            target: Bytes::from("0x0badc0ffee"),
                            signature: "sig()".to_string(),
                        }]),
                    )]),
                    new_entrypoint_params: HashMap::from([(
                        "entrypoint_a".to_string(),
                        HashSet::from([(
                            TracingParams::RPCTracer(RPCTracerParams {
                                caller: Some(Bytes::from("0x0badc0ffee")),
                                calldata: Bytes::from("0x0badc0ffee"),
                            }),
                            Some("Component1".to_string()),
                        )]),
                    )]),
                    trace_results: HashMap::from([(
                        "entrypoint_a".to_string(),
                        TracingResult {
                            retriggers: HashSet::from([(
                                Bytes::from("0x0badc0ffee"),
                                Bytes::from("0x0badc0ffee"),
                            )]),
                            accessed_slots: HashMap::from([(
                                Bytes::from("0x0badc0ffee"),
                                HashSet::from([Bytes::from("0xbadbeef0")]),
                            )]),
                        },
                    )]),
                },
                ..Default::default()
            },
            BlockChanges {
                extractor: "uniswap-v2".to_string(),
                chain: Chain::Ethereum,
                block: Block {
                    number: 2,
                    hash: Bytes::from("0x02"),
                    parent_hash: Bytes::from("0x01"),
                    chain: Chain::Ethereum,
                    ts: Default::default(),
                },
                revert: false,
                component_tvl: [
                    ("Component1".to_string(), 100.0),
                    ("Component2".to_string(), 0.0),
                    ("Component3".to_string(), 1000.0),
                ]
                .into_iter()
                .collect(),
                ..Default::default()
            },
        ];
        let mut state_sync = with_mocked_clients(true, true, Some(rpc_client), Some(deltas_client));
        state_sync
            .initialize()
            .await
            .expect("Init failed");

        // Test starts here
        let (handle, mut rx) = state_sync
            .start()
            .await
            .expect("Failed to start state synchronizer");
        let (jh, close_tx) = handle.split();
        // Feed one delta at a time and wait (bounded by 100ms) for the corresponding message.
        tx.send(deltas[0].clone())
            .await
            .expect("deltas channel msg 0 closed!");
        let first_msg = timeout(Duration::from_millis(100), rx.recv())
            .await
            .expect("waiting for first state msg timed out!")
            .expect("state sync block sender closed!");
        tx.send(deltas[1].clone())
            .await
            .expect("deltas channel msg 1 closed!");
        let second_msg = timeout(Duration::from_millis(100), rx.recv())
            .await
            .expect("waiting for second state msg timed out!")
            .expect("state sync block sender closed!");
        // Request shutdown, then join the synchronizer task to inspect its exit result.
        let _ = close_tx.send(());
        let exit = jh
            .await
            .expect("state sync task panicked!");

        // assertions
        let exp1 = StateSyncMessage {
            header: BlockHeader {
                number: 1,
                hash: Bytes::from("0x01"),
                parent_hash: Bytes::from("0x00"),
                revert: false,
                ..Default::default()
            },
            snapshots: Snapshot {
                states: [
                    (
                        "Component1".to_string(),
                        ComponentWithState {
                            state: ResponseProtocolState {
                                component_id: "Component1".to_string(),
                                ..Default::default()
                            },
                            component: ProtocolComponent {
                                id: "Component1".to_string(),
                                ..Default::default()
                            },
                            component_tvl: Some(100.0),
                            entrypoints: vec![],
                        },
                    ),
                    (
                        "Component2".to_string(),
                        ComponentWithState {
                            state: ResponseProtocolState {
                                component_id: "Component2".to_string(),
                                ..Default::default()
                            },
                            component: ProtocolComponent {
                                id: "Component2".to_string(),
                                ..Default::default()
                            },
                            component_tvl: Some(0.0),
                            entrypoints: vec![],
                        },
                    ),
                ]
                .into_iter()
                .collect(),
                vm_storage: HashMap::new(),
            },
            deltas: Some(deltas[0].clone()),
            removed_components: Default::default(),
        };

        let exp2 = StateSyncMessage {
            header: BlockHeader {
                number: 2,
                hash: Bytes::from("0x02"),
                parent_hash: Bytes::from("0x01"),
                revert: false,
                ..Default::default()
            },
            snapshots: Snapshot {
                states: [
                    // This is the new component we queried once it passed the tvl threshold.
                    (
                        "Component3".to_string(),
                        ComponentWithState {
                            state: ResponseProtocolState {
                                component_id: "Component3".to_string(),
                                ..Default::default()
                            },
                            component: ProtocolComponent {
                                id: "Component3".to_string(),
                                ..Default::default()
                            },
                            component_tvl: Some(1000.0),
                            entrypoints: vec![],
                        },
                    ),
                ]
                .into_iter()
                .collect(),
                vm_storage: HashMap::new(),
            },
            // This delta only carries tvl changes. Merge methods are tested in tycho-common,
            // so here we only check that the removed component's entry is filtered out.
            deltas: Some(BlockChanges {
                extractor: "uniswap-v2".to_string(),
                chain: Chain::Ethereum,
                block: Block {
                    number: 2,
                    hash: Bytes::from("0x02"),
                    parent_hash: Bytes::from("0x01"),
                    chain: Chain::Ethereum,
                    ts: Default::default(),
                },
                revert: false,
                component_tvl: [
                    // "Component2" should not show here.
                    ("Component1".to_string(), 100.0),
                    ("Component3".to_string(), 1000.0),
                ]
                .into_iter()
                .collect(),
                ..Default::default()
            }),
            // "Component2" was removed, because its tvl changed to 0.
            removed_components: [(
                "Component2".to_string(),
                ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
            )]
            .into_iter()
            .collect(),
        };
        assert_eq!(first_msg, exp1);
        assert_eq!(second_msg, exp2);
        assert!(exit.is_ok());
    }
1536
1537    #[test(tokio::test)]
1538    async fn test_state_sync_with_tvl_range() {
1539        // Define the range for testing
1540        let remove_tvl_threshold = 5.0;
1541        let add_tvl_threshold = 7.0;
1542
1543        let mut rpc_client = MockRPCClient::new();
1544        let mut deltas_client = MockDeltasClient::new();
1545
1546        rpc_client
1547            .expect_get_protocol_components()
1548            .with(mockall::predicate::function(
1549                move |request_params: &ProtocolComponentsRequestBody| {
1550                    if let Some(ids) = request_params.component_ids.as_ref() {
1551                        ids.contains(&"Component3".to_string())
1552                    } else {
1553                        false
1554                    }
1555                },
1556            ))
1557            .returning(|_| {
1558                Ok(ProtocolComponentRequestResponse {
1559                    protocol_components: vec![ProtocolComponent {
1560                        id: "Component3".to_string(),
1561                        ..Default::default()
1562                    }],
1563                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1564                })
1565            });
1566        rpc_client
1567            .expect_get_protocol_states()
1568            .with(mockall::predicate::function(move |request_params: &ProtocolStateRequestBody| {
1569                let expected_id = "Component3".to_string();
1570                if let Some(ids) = request_params.protocol_ids.as_ref() {
1571                    ids.contains(&expected_id)
1572                } else {
1573                    false
1574                }
1575            }))
1576            .returning(|_| {
1577                Ok(ProtocolStateRequestResponse {
1578                    states: vec![ResponseProtocolState {
1579                        component_id: "Component3".to_string(),
1580                        ..Default::default()
1581                    }],
1582                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1583                })
1584            });
1585
1586        // Mock for the initial snapshot retrieval
1587        rpc_client
1588            .expect_get_protocol_components()
1589            .returning(|_| {
1590                Ok(ProtocolComponentRequestResponse {
1591                    protocol_components: vec![
1592                        ProtocolComponent { id: "Component1".to_string(), ..Default::default() },
1593                        ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1594                    ],
1595                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1596                })
1597            });
1598        rpc_client
1599            .expect_get_protocol_states()
1600            .returning(|_| {
1601                Ok(ProtocolStateRequestResponse {
1602                    states: vec![
1603                        ResponseProtocolState {
1604                            component_id: "Component1".to_string(),
1605                            ..Default::default()
1606                        },
1607                        ResponseProtocolState {
1608                            component_id: "Component2".to_string(),
1609                            ..Default::default()
1610                        },
1611                    ],
1612                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1613                })
1614            });
1615        rpc_client
1616            .expect_get_traced_entry_points()
1617            .returning(|_| {
1618                Ok(TracedEntryPointRequestResponse {
1619                    traced_entry_points: HashMap::new(),
1620                    pagination: PaginationResponse::new(0, 20, 0),
1621                })
1622            });
1623
1624        rpc_client
1625            .expect_get_component_tvl()
1626            .returning(|_| {
1627                Ok(ComponentTvlRequestResponse {
1628                    tvl: [
1629                        ("Component1".to_string(), 6.0),
1630                        ("Component2".to_string(), 2.0),
1631                        ("Component3".to_string(), 10.0),
1632                    ]
1633                    .into_iter()
1634                    .collect(),
1635                    pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
1636                })
1637            });
1638
1639        rpc_client
1640            .expect_get_component_tvl()
1641            .returning(|_| {
1642                Ok(ComponentTvlRequestResponse {
1643                    tvl: [
1644                        ("Component1".to_string(), 6.0),
1645                        ("Component2".to_string(), 2.0),
1646                        ("Component3".to_string(), 10.0),
1647                    ]
1648                    .into_iter()
1649                    .collect(),
1650                    pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
1651                })
1652            });
1653
1654        let (tx, rx) = channel(1);
1655        deltas_client
1656            .expect_subscribe()
1657            .return_once(move |_, _| Ok((Uuid::default(), rx)));
1658
1659        // Expect unsubscribe call during cleanup
1660        deltas_client
1661            .expect_unsubscribe()
1662            .return_once(|_| Ok(()));
1663
1664        let mut state_sync = ProtocolStateSynchronizer::new(
1665            ExtractorIdentity::new(Chain::Ethereum, "uniswap-v2"),
1666            true,
1667            ComponentFilter::with_tvl_range(remove_tvl_threshold, add_tvl_threshold),
1668            1,
1669            true,
1670            true,
1671            ArcRPCClient(Arc::new(rpc_client)),
1672            ArcDeltasClient(Arc::new(deltas_client)),
1673            10_u64,
1674        );
1675        state_sync
1676            .initialize()
1677            .await
1678            .expect("Init failed");
1679
1680        // Simulate the incoming BlockChanges
1681        let deltas = [
1682            BlockChanges {
1683                extractor: "uniswap-v2".to_string(),
1684                chain: Chain::Ethereum,
1685                block: Block {
1686                    number: 1,
1687                    hash: Bytes::from("0x01"),
1688                    parent_hash: Bytes::from("0x00"),
1689                    chain: Chain::Ethereum,
1690                    ts: Default::default(),
1691                },
1692                revert: false,
1693                ..Default::default()
1694            },
1695            BlockChanges {
1696                extractor: "uniswap-v2".to_string(),
1697                chain: Chain::Ethereum,
1698                block: Block {
1699                    number: 2,
1700                    hash: Bytes::from("0x02"),
1701                    parent_hash: Bytes::from("0x01"),
1702                    chain: Chain::Ethereum,
1703                    ts: Default::default(),
1704                },
1705                revert: false,
1706                component_tvl: [
1707                    ("Component1".to_string(), 6.0), // Within range, should not trigger changes
1708                    ("Component2".to_string(), 2.0), // Below lower threshold, should be removed
1709                    ("Component3".to_string(), 10.0), // Above upper threshold, should be added
1710                ]
1711                .into_iter()
1712                .collect(),
1713                ..Default::default()
1714            },
1715        ];
1716
1717        let (handle, mut rx) = state_sync
1718            .start()
1719            .await
1720            .expect("Failed to start state synchronizer");
1721        let (jh, close_tx) = handle.split();
1722
1723        // Simulate sending delta messages
1724        tx.send(deltas[0].clone())
1725            .await
1726            .expect("deltas channel msg 0 closed!");
1727
1728        // Expecting to receive the initial state message
1729        let _ = timeout(Duration::from_millis(100), rx.recv())
1730            .await
1731            .expect("waiting for first state msg timed out!")
1732            .expect("state sync block sender closed!");
1733
1734        // Send the third message, which should trigger TVL-based changes
1735        tx.send(deltas[1].clone())
1736            .await
1737            .expect("deltas channel msg 1 closed!");
1738        let second_msg = timeout(Duration::from_millis(100), rx.recv())
1739            .await
1740            .expect("waiting for second state msg timed out!")
1741            .expect("state sync block sender closed!");
1742
1743        let _ = close_tx.send(());
1744        let exit = jh
1745            .await
1746            .expect("state sync task panicked!");
1747
1748        let expected_second_msg = StateSyncMessage {
1749            header: BlockHeader {
1750                number: 2,
1751                hash: Bytes::from("0x02"),
1752                parent_hash: Bytes::from("0x01"),
1753                revert: false,
1754                ..Default::default()
1755            },
1756            snapshots: Snapshot {
1757                states: [(
1758                    "Component3".to_string(),
1759                    ComponentWithState {
1760                        state: ResponseProtocolState {
1761                            component_id: "Component3".to_string(),
1762                            ..Default::default()
1763                        },
1764                        component: ProtocolComponent {
1765                            id: "Component3".to_string(),
1766                            ..Default::default()
1767                        },
1768                        component_tvl: Some(10.0),
1769                        entrypoints: vec![], // TODO: add entrypoints?
1770                    },
1771                )]
1772                .into_iter()
1773                .collect(),
1774                vm_storage: HashMap::new(),
1775            },
1776            deltas: Some(BlockChanges {
1777                extractor: "uniswap-v2".to_string(),
1778                chain: Chain::Ethereum,
1779                block: Block {
1780                    number: 2,
1781                    hash: Bytes::from("0x02"),
1782                    parent_hash: Bytes::from("0x01"),
1783                    chain: Chain::Ethereum,
1784                    ts: Default::default(),
1785                },
1786                revert: false,
1787                component_tvl: [
1788                    ("Component1".to_string(), 6.0), // Within range, should not trigger changes
1789                    ("Component3".to_string(), 10.0), // Above upper threshold, should be added
1790                ]
1791                .into_iter()
1792                .collect(),
1793                ..Default::default()
1794            }),
1795            removed_components: [(
1796                "Component2".to_string(),
1797                ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1798            )]
1799            .into_iter()
1800            .collect(),
1801        };
1802
1803        assert_eq!(second_msg, expected_second_msg);
1804        assert!(exit.is_ok());
1805    }
1806
1807    #[test(tokio::test)]
1808    async fn test_public_close_api_functionality() {
1809        // Tests the public close() API through the StateSynchronizer trait:
1810        // - close() fails before start() is called
1811        // - close() succeeds while synchronizer is running
1812        // - close() fails after already closed
1813        // This tests the full start/close lifecycle via the public API
1814
1815        let mut rpc_client = MockRPCClient::new();
1816        let mut deltas_client = MockDeltasClient::new();
1817
1818        // Mock the initial components call
1819        rpc_client
1820            .expect_get_protocol_components()
1821            .returning(|_| {
1822                Ok(ProtocolComponentRequestResponse {
1823                    protocol_components: vec![],
1824                    pagination: PaginationResponse { page: 0, page_size: 20, total: 0 },
1825                })
1826            });
1827
1828        // Set up deltas client that will wait for messages (blocking in state_sync)
1829        let (_tx, rx) = channel(1);
1830        deltas_client
1831            .expect_subscribe()
1832            .return_once(move |_, _| Ok((Uuid::default(), rx)));
1833
1834        // Expect unsubscribe call during cleanup
1835        deltas_client
1836            .expect_unsubscribe()
1837            .return_once(|_| Ok(()));
1838
1839        let mut state_sync = ProtocolStateSynchronizer::new(
1840            ExtractorIdentity::new(Chain::Ethereum, "test-protocol"),
1841            true,
1842            ComponentFilter::with_tvl_range(0.0, 0.0),
1843            5, // Enough retries
1844            true,
1845            false,
1846            ArcRPCClient(Arc::new(rpc_client)),
1847            ArcDeltasClient(Arc::new(deltas_client)),
1848            10000_u64, // Long timeout so task doesn't exit on its own
1849        );
1850
1851        state_sync
1852            .initialize()
1853            .await
1854            .expect("Init should succeed");
1855
1856        // Start the synchronizer and test the new split-based close mechanism
1857        let (handle, _rx) = state_sync
1858            .start()
1859            .await
1860            .expect("Failed to start state synchronizer");
1861        let (jh, close_tx) = handle.split();
1862
1863        // Give it time to start up and enter state_sync
1864        tokio::time::sleep(Duration::from_millis(100)).await;
1865
1866        // Send close signal should succeed
1867        close_tx
1868            .send(())
1869            .expect("Should be able to send close signal");
1870        // Task should stop cleanly
1871        let task_result = jh.await.expect("Task should not panic");
1872        assert!(task_result.is_ok(), "Task should exit cleanly after close: {task_result:?}");
1873    }
1874
    #[test(tokio::test)]
    async fn test_cleanup_runs_when_state_sync_processing_errors() {
        // Tests that state_sync() returns an error when an RPC call fails while it is
        // processing deltas: get_protocol_states is mocked to always fail.
        // NOTE(review): only the Err result is asserted below. The reset of
        // last_synced_block is not checked, and the unsubscribe expectation carries no
        // call-count constraint, so an unsubscribe on error is not verified either.

        let mut rpc_client = MockRPCClient::new();
        let mut deltas_client = MockDeltasClient::new();

        // Mock the initial components call: no components are returned.
        rpc_client
            .expect_get_protocol_components()
            .returning(|_| {
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 0 },
                })
            });

        // Every protocol-state request fails. This is the error state_sync is expected to
        // surface.
        rpc_client
            .expect_get_protocol_states()
            .returning(|_| {
                Err(RPCError::HttpClient("Test error during snapshot retrieval".to_string()))
            });

        // Subscribing hands back a receiver that gets exactly one delta message.
        let (tx, rx) = channel(10);
        deltas_client
            .expect_subscribe()
            .return_once(move |_, _| {
                // Delta for block 1 announcing a new component with a TVL of 100.0.
                let delta = BlockChanges {
                    extractor: "test".to_string(),
                    chain: Chain::Ethereum,
                    block: Block {
                        hash: Bytes::from("0x0123"),
                        number: 1,
                        parent_hash: Bytes::from("0x0000"),
                        chain: Chain::Ethereum,
                        ts: chrono::NaiveDateTime::from_timestamp_opt(1234567890, 0).unwrap(),
                    },
                    revert: false,
                    // New component, intended to trigger a snapshot request
                    new_protocol_components: [(
                        "new_component".to_string(),
                        ProtocolComponent {
                            id: "new_component".to_string(),
                            protocol_system: "test_protocol".to_string(),
                            protocol_type_name: "test".to_string(),
                            chain: Chain::Ethereum,
                            tokens: vec![Bytes::from("0x0badc0ffee")],
                            contract_ids: vec![Bytes::from("0x0badc0ffee")],
                            static_attributes: Default::default(),
                            creation_tx: Default::default(),
                            created_at: Default::default(),
                            change: Default::default(),
                        },
                    )]
                    .into_iter()
                    .collect(),
                    component_tvl: [("new_component".to_string(), 100.0)]
                        .into_iter()
                        .collect(),
                    ..Default::default()
                };

                tokio::spawn(async move {
                    let _ = tx.send(delta).await;
                    // `tx` is dropped when this task ends, which closes the channel
                });

                Ok((Uuid::default(), rx))
            });

        // Unsubscribe may be called during cleanup. `return_once` allows at most one call,
        // but a missing call is not flagged.
        deltas_client
            .expect_unsubscribe()
            .return_once(|_| Ok(()));

        let mut state_sync = ProtocolStateSynchronizer::new(
            ExtractorIdentity::new(Chain::Ethereum, "test-protocol"),
            true,
            // NOTE(review): new_component's TVL (100.0) is below this add threshold (1000.0),
            // so the filter may not start tracking it. Confirm the intended range.
            ComponentFilter::with_tvl_range(0.0, 1000.0),
            1,
            true,
            false,
            ArcRPCClient(Arc::new(rpc_client)),
            ArcDeltasClient(Arc::new(deltas_client)),
            5000_u64,
        );

        state_sync
            .initialize()
            .await
            .expect("Init should succeed");

        // Seed last_synced_block with a non-default value before calling state_sync.
        // NOTE(review): it is not inspected afterwards, so its reset is not verified.
        state_sync.last_synced_block = Some(BlockHeader {
            hash: Bytes::from("0x0badc0ffee"),
            number: 42,
            parent_hash: Bytes::from("0xbadbeef0"),
            revert: false,
            timestamp: 123456789,
        });

        // Channel that state_sync sends its messages into; `_block_rx` keeps it open.
        let (mut block_tx, _block_rx) = channel(10);

        // Call state_sync directly; this should error during processing.
        // `_end_tx` stays bound (not `_`) so the close sender is not dropped during the call.
        let (_end_tx, end_rx) = oneshot::channel::<()>();
        let result = state_sync
            .state_sync(&mut block_tx, end_rx)
            .await;
        // Verify that state_sync returned an error
        assert!(result.is_err(), "state_sync should have errored during processing");

        // Note: internal state cleanup is not verified here. The stated reason is that
        // state_sync consumes self. TODO confirm against its signature and assert the reset
        // if it does not.
    }
1995
1996    #[test(tokio::test)]
1997    async fn test_close_signal_while_waiting_for_first_deltas() {
1998        // Tests close signal handling during the initial "waiting for deltas" phase.
1999        // This is the earliest possible close scenario - before any deltas are received.
2000        // Verifies: close signal received while waiting for first message triggers cleanup
2001        let mut rpc_client = MockRPCClient::new();
2002        let mut deltas_client = MockDeltasClient::new();
2003
2004        rpc_client
2005            .expect_get_protocol_components()
2006            .returning(|_| {
2007                Ok(ProtocolComponentRequestResponse {
2008                    protocol_components: vec![],
2009                    pagination: PaginationResponse { page: 0, page_size: 20, total: 0 },
2010                })
2011            });
2012
2013        let (_tx, rx) = channel(1);
2014        deltas_client
2015            .expect_subscribe()
2016            .return_once(move |_, _| Ok((Uuid::default(), rx)));
2017
2018        deltas_client
2019            .expect_unsubscribe()
2020            .return_once(|_| Ok(()));
2021
2022        let mut state_sync = ProtocolStateSynchronizer::new(
2023            ExtractorIdentity::new(Chain::Ethereum, "test-protocol"),
2024            true,
2025            ComponentFilter::with_tvl_range(0.0, 0.0),
2026            1,
2027            true,
2028            false,
2029            ArcRPCClient(Arc::new(rpc_client)),
2030            ArcDeltasClient(Arc::new(deltas_client)),
2031            10000_u64,
2032        );
2033
2034        state_sync
2035            .initialize()
2036            .await
2037            .expect("Init should succeed");
2038
2039        let (mut block_tx, _block_rx) = channel(10);
2040        let (end_tx, end_rx) = oneshot::channel::<()>();
2041
2042        // Start state_sync in a task
2043        let state_sync_handle = tokio::spawn(async move {
2044            state_sync
2045                .state_sync(&mut block_tx, end_rx)
2046                .await
2047        });
2048
2049        // Give it a moment to start
2050        tokio::time::sleep(Duration::from_millis(100)).await;
2051
2052        // Send close signal
2053        let _ = end_tx.send(());
2054
2055        // state_sync should exit cleanly
2056        let result = state_sync_handle
2057            .await
2058            .expect("Task should not panic");
2059        assert!(result.is_ok(), "state_sync should exit cleanly when closed: {result:?}");
2060
2061        println!("SUCCESS: Close signal handled correctly while waiting for first deltas");
2062    }
2063
2064    #[test(tokio::test)]
2065    async fn test_close_signal_during_main_processing_loop() {
2066        // Tests close signal handling during the main delta processing loop.
2067        // This tests the scenario where first message is processed successfully,
2068        // then close signal is received while waiting for subsequent deltas.
2069        // Verifies: close signal in main loop (after initialization) triggers cleanup
2070
2071        let mut rpc_client = MockRPCClient::new();
2072        let mut deltas_client = MockDeltasClient::new();
2073
2074        // Mock the initial components call
2075        rpc_client
2076            .expect_get_protocol_components()
2077            .returning(|_| {
2078                Ok(ProtocolComponentRequestResponse {
2079                    protocol_components: vec![],
2080                    pagination: PaginationResponse { page: 0, page_size: 20, total: 0 },
2081                })
2082            });
2083
2084        // Mock the snapshot retrieval that happens after first message
2085        rpc_client
2086            .expect_get_protocol_states()
2087            .returning(|_| {
2088                Ok(ProtocolStateRequestResponse {
2089                    states: vec![],
2090                    pagination: PaginationResponse { page: 0, page_size: 20, total: 0 },
2091                })
2092            });
2093
2094        rpc_client
2095            .expect_get_component_tvl()
2096            .returning(|_| {
2097                Ok(ComponentTvlRequestResponse {
2098                    tvl: HashMap::new(),
2099                    pagination: PaginationResponse { page: 0, page_size: 20, total: 0 },
2100                })
2101            });
2102
2103        rpc_client
2104            .expect_get_traced_entry_points()
2105            .returning(|_| {
2106                Ok(TracedEntryPointRequestResponse {
2107                    traced_entry_points: HashMap::new(),
2108                    pagination: PaginationResponse::new(0, 20, 0),
2109                })
2110            });
2111
2112        // Set up deltas client to send one message, then keep channel open
2113        let (tx, rx) = channel(10);
2114        deltas_client
2115            .expect_subscribe()
2116            .return_once(move |_, _| {
2117                // Send first message immediately
2118                let first_delta = BlockChanges {
2119                    extractor: "test".to_string(),
2120                    chain: Chain::Ethereum,
2121                    block: Block {
2122                        hash: Bytes::from("0x0123"),
2123                        number: 1,
2124                        parent_hash: Bytes::from("0x0000"),
2125                        chain: Chain::Ethereum,
2126                        ts: chrono::NaiveDateTime::from_timestamp_opt(1234567890, 0).unwrap(),
2127                    },
2128                    revert: false,
2129                    ..Default::default()
2130                };
2131
2132                tokio::spawn(async move {
2133                    let _ = tx.send(first_delta).await;
2134                    // Keep the sender alive but don't send more messages
2135                    // This will make the recv() block waiting for the next message
2136                    tokio::time::sleep(Duration::from_secs(30)).await;
2137                });
2138
2139                Ok((Uuid::default(), rx))
2140            });
2141
2142        deltas_client
2143            .expect_unsubscribe()
2144            .return_once(|_| Ok(()));
2145
2146        let mut state_sync = ProtocolStateSynchronizer::new(
2147            ExtractorIdentity::new(Chain::Ethereum, "test-protocol"),
2148            true,
2149            ComponentFilter::with_tvl_range(0.0, 1000.0),
2150            1,
2151            true,
2152            false,
2153            ArcRPCClient(Arc::new(rpc_client)),
2154            ArcDeltasClient(Arc::new(deltas_client)),
2155            10000_u64,
2156        );
2157
2158        state_sync
2159            .initialize()
2160            .await
2161            .expect("Init should succeed");
2162
2163        let (mut block_tx, mut block_rx) = channel(10);
2164        let (end_tx, end_rx) = oneshot::channel::<()>();
2165
2166        // Start state_sync in a task
2167        let state_sync_handle = tokio::spawn(async move {
2168            state_sync
2169                .state_sync(&mut block_tx, end_rx)
2170                .await
2171        });
2172
2173        // Wait for the first message to be processed (snapshot sent)
2174        let first_snapshot = block_rx
2175            .recv()
2176            .await
2177            .expect("Should receive first snapshot");
2178        assert!(
2179            !first_snapshot
2180                .snapshots
2181                .states
2182                .is_empty() ||
2183                first_snapshot.deltas.is_some()
2184        );
2185        // Now send close signal - this should be handled in the main processing loop
2186        let _ = end_tx.send(());
2187
2188        // state_sync should exit cleanly after receiving close signal in main loop
2189        let result = state_sync_handle
2190            .await
2191            .expect("Task should not panic");
2192        assert!(
2193            result.is_ok(),
2194            "state_sync should exit cleanly when closed after first message: {result:?}"
2195        );
2196        println!("SUCCESS: Close signal handled correctly during main processing loop");
2197    }
2198}