1use std::{collections::HashMap, time::Duration};
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6use tokio::{
7 select,
8 sync::{
9 mpsc::{channel, error::SendError, Receiver, Sender},
10 oneshot,
11 },
12 task::JoinHandle,
13 time::{sleep, timeout},
14};
15use tracing::{debug, error, info, instrument, trace, warn};
16use tycho_common::{
17 dto::{
18 BlockChanges, BlockParam, Chain, ComponentTvlRequestBody, EntryPointWithTracingParams,
19 ExtractorIdentity, ProtocolComponent, ResponseAccount, ResponseProtocolState,
20 TracingResult, VersionParam,
21 },
22 Bytes,
23};
24
25use crate::{
26 deltas::{DeltasClient, SubscriptionOptions},
27 feed::{
28 component_tracker::{ComponentFilter, ComponentTracker},
29 BlockHeader, HeaderLike,
30 },
31 rpc::{RPCClient, RPCError},
32 DeltasError,
33};
34
35#[derive(Error, Debug)]
36pub enum SynchronizerError {
37 #[error("RPC error: {0}")]
39 RPCError(#[from] RPCError),
40
41 #[error("Failed to send channel message: {0}")]
43 ChannelError(String),
44
45 #[error("Timeout error: {0}")]
47 Timeout(String),
48
49 #[error("Failed to close synchronizer: {0}")]
51 CloseError(String),
52
53 #[error("Connection error: {0}")]
55 ConnectionError(String),
56
57 #[error("Connection closed")]
59 ConnectionClosed,
60}
61
62pub type SyncResult<T> = Result<T, SynchronizerError>;
63
64impl From<SendError<StateSyncMessage<BlockHeader>>> for SynchronizerError {
65 fn from(err: SendError<StateSyncMessage<BlockHeader>>) -> Self {
66 SynchronizerError::ChannelError(err.to_string())
67 }
68}
69
70impl From<DeltasError> for SynchronizerError {
71 fn from(err: DeltasError) -> Self {
72 match err {
73 DeltasError::NotConnected => SynchronizerError::ConnectionClosed,
74 _ => SynchronizerError::ConnectionError(err.to_string()),
75 }
76 }
77}
78
79pub struct ProtocolStateSynchronizer<R: RPCClient, D: DeltasClient> {
80 extractor_id: ExtractorIdentity,
81 retrieve_balances: bool,
82 rpc_client: R,
83 deltas_client: D,
84 max_retries: u64,
85 retry_cooldown: Duration,
86 include_snapshots: bool,
87 component_tracker: ComponentTracker<R>,
88 last_synced_block: Option<BlockHeader>,
89 timeout: u64,
90 include_tvl: bool,
91}
92
93#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
94pub struct ComponentWithState {
95 pub state: ResponseProtocolState,
96 pub component: ProtocolComponent,
97 pub component_tvl: Option<f64>,
98 pub entrypoints: Vec<(EntryPointWithTracingParams, TracingResult)>,
99}
100
101#[derive(Clone, PartialEq, Debug, Default, Serialize, Deserialize)]
102pub struct Snapshot {
103 pub states: HashMap<String, ComponentWithState>,
104 pub vm_storage: HashMap<Bytes, ResponseAccount>,
105}
106
107impl Snapshot {
108 fn extend(&mut self, other: Snapshot) {
109 self.states.extend(other.states);
110 self.vm_storage.extend(other.vm_storage);
111 }
112
113 pub fn get_states(&self) -> &HashMap<String, ComponentWithState> {
114 &self.states
115 }
116
117 pub fn get_vm_storage(&self) -> &HashMap<Bytes, ResponseAccount> {
118 &self.vm_storage
119 }
120}
121
122#[derive(Clone, PartialEq, Debug, Default, Serialize, Deserialize)]
123pub struct StateSyncMessage<H>
124where
125 H: HeaderLike,
126{
127 pub header: H,
129 pub snapshots: Snapshot,
131 pub deltas: Option<BlockChanges>,
135 pub removed_components: HashMap<String, ProtocolComponent>,
137}
138
139impl<H> StateSyncMessage<H>
140where
141 H: HeaderLike,
142{
143 pub fn merge(mut self, other: Self) -> Self {
144 self.removed_components
146 .retain(|k, _| !other.snapshots.states.contains_key(k));
147 self.snapshots
148 .states
149 .retain(|k, _| !other.removed_components.contains_key(k));
150
151 self.snapshots.extend(other.snapshots);
152 let deltas = match (self.deltas, other.deltas) {
153 (Some(l), Some(r)) => Some(l.merge(r)),
154 (None, Some(r)) => Some(r),
155 (Some(l), None) => Some(l),
156 (None, None) => None,
157 };
158 self.removed_components
159 .extend(other.removed_components);
160 Self {
161 header: other.header,
162 snapshots: self.snapshots,
163 deltas,
164 removed_components: self.removed_components,
165 }
166 }
167}
168
169pub struct SynchronizerTaskHandle {
174 join_handle: JoinHandle<SyncResult<()>>,
175 close_tx: oneshot::Sender<()>,
176}
177
178impl SynchronizerTaskHandle {
187 pub fn new(join_handle: JoinHandle<SyncResult<()>>, close_tx: oneshot::Sender<()>) -> Self {
188 Self { join_handle, close_tx }
189 }
190
191 pub fn split(self) -> (JoinHandle<SyncResult<()>>, oneshot::Sender<()>) {
197 (self.join_handle, self.close_tx)
198 }
199}
200
201#[async_trait]
202pub trait StateSynchronizer: Send + Sync + 'static {
203 async fn initialize(&mut self) -> SyncResult<()>;
204 async fn start(
207 mut self,
208 ) -> SyncResult<(SynchronizerTaskHandle, Receiver<StateSyncMessage<BlockHeader>>)>;
209}
210
211impl<R, D> ProtocolStateSynchronizer<R, D>
212where
213 R: RPCClient + Clone + Send + Sync + 'static,
216 D: DeltasClient + Clone + Send + Sync + 'static,
217{
218 #[allow(clippy::too_many_arguments)]
220 pub fn new(
221 extractor_id: ExtractorIdentity,
222 retrieve_balances: bool,
223 component_filter: ComponentFilter,
224 max_retries: u64,
225 retry_cooldown: Duration,
226 include_snapshots: bool,
227 include_tvl: bool,
228 rpc_client: R,
229 deltas_client: D,
230 timeout: u64,
231 ) -> Self {
232 Self {
233 extractor_id: extractor_id.clone(),
234 retrieve_balances,
235 rpc_client: rpc_client.clone(),
236 include_snapshots,
237 deltas_client,
238 component_tracker: ComponentTracker::new(
239 extractor_id.chain,
240 extractor_id.name.as_str(),
241 component_filter,
242 rpc_client,
243 ),
244 max_retries,
245 retry_cooldown,
246 last_synced_block: None,
247 timeout,
248 include_tvl,
249 }
250 }
251
252 #[allow(deprecated)]
254 async fn get_snapshots<'a, I: IntoIterator<Item = &'a String>>(
255 &mut self,
256 header: BlockHeader,
257 ids: Option<I>,
258 ) -> SyncResult<StateSyncMessage<BlockHeader>> {
259 if !self.include_snapshots {
260 return Ok(StateSyncMessage { header, ..Default::default() });
261 }
262 let version = VersionParam::new(
263 None,
264 Some(BlockParam {
265 chain: Some(self.extractor_id.chain),
266 hash: None,
267 number: Some(header.number as i64),
268 }),
269 );
270
271 let component_ids: Vec<_> = match ids {
273 Some(ids) => ids.into_iter().cloned().collect(),
274 None => self
275 .component_tracker
276 .get_tracked_component_ids(),
277 };
278
279 if component_ids.is_empty() {
280 return Ok(StateSyncMessage { header, ..Default::default() });
281 }
282
283 let component_tvl = if self.include_tvl {
284 let body = ComponentTvlRequestBody::id_filtered(
285 component_ids.clone(),
286 self.extractor_id.chain,
287 );
288 self.rpc_client
289 .get_component_tvl_paginated(&body, 100, 4)
290 .await?
291 .tvl
292 } else {
293 HashMap::new()
294 };
295
296 let entrypoints_result = if self.extractor_id.chain == Chain::Ethereum {
299 let result = self
301 .rpc_client
302 .get_traced_entry_points_paginated(
303 self.extractor_id.chain,
304 &self.extractor_id.name,
305 &component_ids,
306 100,
307 4,
308 )
309 .await?;
310 self.component_tracker
311 .process_entrypoints(&result.clone().into());
312 Some(result)
313 } else {
314 None
315 };
316
317 let mut protocol_states = self
319 .rpc_client
320 .get_protocol_states_paginated(
321 self.extractor_id.chain,
322 &component_ids,
323 &self.extractor_id.name,
324 self.retrieve_balances,
325 &version,
326 100,
327 4,
328 )
329 .await?
330 .states
331 .into_iter()
332 .map(|state| (state.component_id.clone(), state))
333 .collect::<HashMap<_, _>>();
334
335 trace!(states=?&protocol_states, "Retrieved ProtocolStates");
336 let states = self
337 .component_tracker
338 .components
339 .values()
340 .filter_map(|component| {
341 if let Some(state) = protocol_states.remove(&component.id) {
342 Some((
343 component.id.clone(),
344 ComponentWithState {
345 state,
346 component: component.clone(),
347 component_tvl: component_tvl
348 .get(&component.id)
349 .cloned(),
350 entrypoints: entrypoints_result
351 .as_ref()
352 .map(|r| {
353 r.traced_entry_points
354 .get(&component.id)
355 .cloned()
356 .unwrap_or_default()
357 })
358 .unwrap_or_default(),
359 },
360 ))
361 } else if component_ids.contains(&component.id) {
362 let component_id = &component.id;
364 error!(?component_id, "Missing state for native component!");
365 None
366 } else {
367 None
368 }
369 })
370 .collect();
371
372 let contract_ids = self
374 .component_tracker
375 .get_contracts_by_component(&component_ids);
376 let vm_storage = if !contract_ids.is_empty() {
377 let ids: Vec<Bytes> = contract_ids
378 .clone()
379 .into_iter()
380 .collect();
381 let contract_states = self
382 .rpc_client
383 .get_contract_state_paginated(
384 self.extractor_id.chain,
385 ids.as_slice(),
386 &self.extractor_id.name,
387 &version,
388 100,
389 4,
390 )
391 .await?
392 .accounts
393 .into_iter()
394 .map(|acc| (acc.address.clone(), acc))
395 .collect::<HashMap<_, _>>();
396
397 trace!(states=?&contract_states, "Retrieved ContractState");
398
399 let contract_address_to_components = self
400 .component_tracker
401 .components
402 .iter()
403 .filter_map(|(id, comp)| {
404 if component_ids.contains(id) {
405 Some(
406 comp.contract_ids
407 .iter()
408 .map(|address| (address.clone(), comp.id.clone())),
409 )
410 } else {
411 None
412 }
413 })
414 .flatten()
415 .fold(HashMap::<Bytes, Vec<String>>::new(), |mut acc, (addr, c_id)| {
416 acc.entry(addr).or_default().push(c_id);
417 acc
418 });
419
420 contract_ids
421 .iter()
422 .filter_map(|address| {
423 if let Some(state) = contract_states.get(address) {
424 Some((address.clone(), state.clone()))
425 } else if let Some(ids) = contract_address_to_components.get(address) {
426 error!(
428 ?address,
429 ?ids,
430 "Component with lacking contract storage encountered!"
431 );
432 None
433 } else {
434 None
435 }
436 })
437 .collect()
438 } else {
439 HashMap::new()
440 };
441
442 Ok(StateSyncMessage {
443 header,
444 snapshots: Snapshot { states, vm_storage },
445 deltas: None,
446 removed_components: HashMap::new(),
447 })
448 }
449
450 #[instrument(skip(self, block_tx, end_rx), fields(extractor_id = %self.extractor_id))]
463 async fn state_sync(
464 &mut self,
465 block_tx: &mut Sender<StateSyncMessage<BlockHeader>>,
466 mut end_rx: oneshot::Receiver<()>,
467 ) -> Result<(), (SynchronizerError, Option<oneshot::Receiver<()>>)> {
468 let subscription_options = SubscriptionOptions::new().with_state(self.include_snapshots);
471 let (subscription_id, mut msg_rx) = match self
472 .deltas_client
473 .subscribe(self.extractor_id.clone(), subscription_options)
474 .await
475 {
476 Ok(result) => result,
477 Err(e) => return Err((e.into(), Some(end_rx))),
478 };
479
480 let result = async {
481 info!("Waiting for deltas...");
482 let mut first_msg = select! {
484 deltas_result = timeout(Duration::from_secs(self.timeout), msg_rx.recv()) => {
485 deltas_result
486 .map_err(|_| {
487 SynchronizerError::Timeout(format!(
488 "First deltas took longer than {t}s to arrive",
489 t = self.timeout
490 ))
491 })?
492 .ok_or_else(|| {
493 SynchronizerError::ConnectionError(
494 "Deltas channel closed before first message".to_string(),
495 )
496 })?
497 },
498 _ = &mut end_rx => {
499 info!("Received close signal while waiting for first deltas");
500 return Ok(());
501 }
502 };
503 self.filter_deltas(&mut first_msg);
504
505 let block = first_msg.get_block().clone();
507 info!(height = &block.number, "Deltas received. Retrieving snapshot");
508 let header = BlockHeader::from_block(first_msg.get_block(), first_msg.is_revert());
509 let snapshot = self
510 .get_snapshots::<Vec<&String>>(
511 BlockHeader::from_block(&block, false),
512 None,
513 )
514 .await?
515 .merge(StateSyncMessage {
516 header: BlockHeader::from_block(first_msg.get_block(), first_msg.is_revert()),
517 snapshots: Default::default(),
518 deltas: Some(first_msg),
519 removed_components: Default::default(),
520 });
521
522 let n_components = self.component_tracker.components.len();
523 let n_snapshots = snapshot.snapshots.states.len();
524 info!(n_components, n_snapshots, "Initial snapshot retrieved, starting delta message feed");
525
526 block_tx.send(snapshot).await?;
527 self.last_synced_block = Some(header.clone());
528 loop {
529 select! {
530 deltas_opt = msg_rx.recv() => {
531 if let Some(mut deltas) = deltas_opt {
532 let header = BlockHeader::from_block(deltas.get_block(), deltas.is_revert());
533 debug!(block_number=?header.number, "Received delta message");
534
535 let (snapshots, removed_components) = {
536 let (to_add, to_remove) = self.component_tracker.filter_updated_components(&deltas);
539
540 let requiring_snapshot: Vec<_> = to_add
542 .iter()
543 .filter(|id| {
544 !self.component_tracker
545 .components
546 .contains_key(id.as_str())
547 })
548 .collect();
549 debug!(components=?requiring_snapshot, "SnapshotRequest");
550 self.component_tracker
551 .start_tracking(requiring_snapshot.as_slice())
552 .await?;
553 let snapshots = self
554 .get_snapshots(header.clone(), Some(requiring_snapshot))
555 .await?
556 .snapshots;
557
558 let removed_components = if !to_remove.is_empty() {
559 self.component_tracker.stop_tracking(&to_remove)
560 } else {
561 Default::default()
562 };
563
564 (snapshots, removed_components)
565 };
566
567 self.component_tracker.process_entrypoints(&deltas.dci_update);
569
570 self.filter_deltas(&mut deltas);
572 let n_changes = deltas.n_changes();
573
574 let next = StateSyncMessage {
576 header: header.clone(),
577 snapshots,
578 deltas: Some(deltas),
579 removed_components,
580 };
581 block_tx.send(next).await?;
582 self.last_synced_block = Some(header.clone());
583
584 debug!(block_number=?header.number, n_changes, "Finished processing delta message");
585 } else {
586 return Err(SynchronizerError::ConnectionError("Deltas channel closed".to_string()));
587 }
588 },
589 _ = &mut end_rx => {
590 info!("Received close signal during state_sync");
591 return Ok(());
592 }
593 }
594 }
595 }.await;
596
597 warn!(last_synced_block = ?&self.last_synced_block, "Deltas processing ended, resetting last synced block.");
599 self.last_synced_block = None;
600 let _ = self
602 .deltas_client
603 .unsubscribe(subscription_id)
604 .await
605 .map_err(|err| {
606 warn!(err=?err, "Unsubscribing from deltas on cleanup failed!");
607 });
608
609 match result {
612 Ok(()) => Ok(()), Err(e) => {
614 Err((e, Some(end_rx)))
618 }
619 }
620 }
621
622 fn filter_deltas(&self, deltas: &mut BlockChanges) {
623 deltas.filter_by_component(|id| {
624 self.component_tracker
625 .components
626 .contains_key(id)
627 });
628 deltas.filter_by_contract(|id| {
629 self.component_tracker
630 .contracts
631 .contains(id)
632 });
633 }
634}
635
636#[async_trait]
637impl<R, D> StateSynchronizer for ProtocolStateSynchronizer<R, D>
638where
639 R: RPCClient + Clone + Send + Sync + 'static,
640 D: DeltasClient + Clone + Send + Sync + 'static,
641{
642 async fn initialize(&mut self) -> SyncResult<()> {
643 info!("Retrieving relevant protocol components");
644 self.component_tracker
645 .initialise_components()
646 .await?;
647 info!(
648 n_components = self.component_tracker.components.len(),
649 n_contracts = self.component_tracker.contracts.len(),
650 "Finished retrieving components",
651 );
652
653 Ok(())
654 }
655
656 async fn start(
657 mut self,
658 ) -> SyncResult<(SynchronizerTaskHandle, Receiver<StateSyncMessage<BlockHeader>>)> {
659 let (mut tx, rx) = channel(15);
660 let (end_tx, end_rx) = oneshot::channel::<()>();
661
662 let jh = tokio::spawn(async move {
663 let mut retry_count = 0;
664 let mut current_end_rx = end_rx;
665
666 while retry_count < self.max_retries {
667 info!(extractor_id=%&self.extractor_id, retry_count, "(Re)starting synchronization loop");
668
669 let res = self
670 .state_sync(&mut tx, current_end_rx)
671 .await;
672 match res {
673 Ok(()) => {
674 info!(
675 extractor_id=%&self.extractor_id,
676 retry_count,
677 "State synchronization exited cleanly"
678 );
679 return Ok(());
680 }
681 Err((e, maybe_end_rx)) => {
682 error!(
683 extractor_id=%&self.extractor_id,
684 retry_count,
685 error=%e,
686 "State synchronization errored!"
687 );
688
689 if let Some(recovered_end_rx) = maybe_end_rx {
691 current_end_rx = recovered_end_rx;
692
693 if let SynchronizerError::ConnectionClosed = e {
694 return Err(e);
696 }
697 } else {
698 info!(extractor_id=%&self.extractor_id, "Received close signal, exiting");
700 return Ok(());
701 }
702 }
703 }
704 sleep(self.retry_cooldown).await;
705 retry_count += 1;
706 }
707 warn!(extractor_id=%&self.extractor_id, retry_count, "Max retries exceeded");
708 Err(SynchronizerError::ConnectionError("Max connection retries exceeded".to_string()))
709 });
710
711 let handle = SynchronizerTaskHandle::new(jh, end_tx);
712 Ok((handle, rx))
713 }
714}
715
716#[cfg(test)]
717mod test {
718 use std::{collections::HashSet, sync::Arc};
737
738 use test_log::test;
739 use tycho_common::dto::{
740 Block, Chain, ComponentTvlRequestBody, ComponentTvlRequestResponse, DCIUpdate, EntryPoint,
741 PaginationResponse, ProtocolComponentRequestResponse, ProtocolComponentsRequestBody,
742 ProtocolStateRequestBody, ProtocolStateRequestResponse, ProtocolSystemsRequestBody,
743 ProtocolSystemsRequestResponse, RPCTracerParams, StateRequestBody, StateRequestResponse,
744 TokensRequestBody, TokensRequestResponse, TracedEntryPointRequestBody,
745 TracedEntryPointRequestResponse, TracingParams,
746 };
747 use uuid::Uuid;
748
749 use super::*;
750 use crate::{deltas::MockDeltasClient, rpc::MockRPCClient, DeltasError, RPCError};
751
752 struct ArcRPCClient<T>(Arc<T>);
754
755 impl<T> Clone for ArcRPCClient<T> {
757 fn clone(&self) -> Self {
758 ArcRPCClient(self.0.clone())
759 }
760 }
761
762 #[async_trait]
763 impl<T> RPCClient for ArcRPCClient<T>
764 where
765 T: RPCClient + Sync + Send + 'static,
766 {
767 async fn get_tokens(
768 &self,
769 request: &TokensRequestBody,
770 ) -> Result<TokensRequestResponse, RPCError> {
771 self.0.get_tokens(request).await
772 }
773
774 async fn get_contract_state(
775 &self,
776 request: &StateRequestBody,
777 ) -> Result<StateRequestResponse, RPCError> {
778 self.0.get_contract_state(request).await
779 }
780
781 async fn get_protocol_components(
782 &self,
783 request: &ProtocolComponentsRequestBody,
784 ) -> Result<ProtocolComponentRequestResponse, RPCError> {
785 self.0
786 .get_protocol_components(request)
787 .await
788 }
789
790 async fn get_protocol_states(
791 &self,
792 request: &ProtocolStateRequestBody,
793 ) -> Result<ProtocolStateRequestResponse, RPCError> {
794 self.0
795 .get_protocol_states(request)
796 .await
797 }
798
799 async fn get_protocol_systems(
800 &self,
801 request: &ProtocolSystemsRequestBody,
802 ) -> Result<ProtocolSystemsRequestResponse, RPCError> {
803 self.0
804 .get_protocol_systems(request)
805 .await
806 }
807
808 async fn get_component_tvl(
809 &self,
810 request: &ComponentTvlRequestBody,
811 ) -> Result<ComponentTvlRequestResponse, RPCError> {
812 self.0.get_component_tvl(request).await
813 }
814
815 async fn get_traced_entry_points(
816 &self,
817 request: &TracedEntryPointRequestBody,
818 ) -> Result<TracedEntryPointRequestResponse, RPCError> {
819 self.0
820 .get_traced_entry_points(request)
821 .await
822 }
823 }
824
825 struct ArcDeltasClient<T>(Arc<T>);
827
828 impl<T> Clone for ArcDeltasClient<T> {
830 fn clone(&self) -> Self {
831 ArcDeltasClient(self.0.clone())
832 }
833 }
834
835 #[async_trait]
836 impl<T> DeltasClient for ArcDeltasClient<T>
837 where
838 T: DeltasClient + Sync + Send + 'static,
839 {
840 async fn subscribe(
841 &self,
842 extractor_id: ExtractorIdentity,
843 options: SubscriptionOptions,
844 ) -> Result<(Uuid, Receiver<BlockChanges>), DeltasError> {
845 self.0
846 .subscribe(extractor_id, options)
847 .await
848 }
849
850 async fn unsubscribe(&self, subscription_id: Uuid) -> Result<(), DeltasError> {
851 self.0
852 .unsubscribe(subscription_id)
853 .await
854 }
855
856 async fn connect(&self) -> Result<JoinHandle<Result<(), DeltasError>>, DeltasError> {
857 self.0.connect().await
858 }
859
860 async fn close(&self) -> Result<(), DeltasError> {
861 self.0.close().await
862 }
863 }
864
865 fn with_mocked_clients(
866 native: bool,
867 include_tvl: bool,
868 rpc_client: Option<MockRPCClient>,
869 deltas_client: Option<MockDeltasClient>,
870 ) -> ProtocolStateSynchronizer<ArcRPCClient<MockRPCClient>, ArcDeltasClient<MockDeltasClient>>
871 {
872 let rpc_client = ArcRPCClient(Arc::new(rpc_client.unwrap_or_default()));
873 let deltas_client = ArcDeltasClient(Arc::new(deltas_client.unwrap_or_default()));
874
875 ProtocolStateSynchronizer::new(
876 ExtractorIdentity::new(Chain::Ethereum, "uniswap-v2"),
877 native,
878 ComponentFilter::with_tvl_range(50.0, 50.0),
879 1,
880 Duration::from_secs(0),
881 true,
882 include_tvl,
883 rpc_client,
884 deltas_client,
885 10_u64,
886 )
887 }
888
889 fn state_snapshot_native() -> ProtocolStateRequestResponse {
890 ProtocolStateRequestResponse {
891 states: vec![ResponseProtocolState {
892 component_id: "Component1".to_string(),
893 ..Default::default()
894 }],
895 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
896 }
897 }
898
899 fn component_tvl_snapshot() -> ComponentTvlRequestResponse {
900 let tvl = HashMap::from([("Component1".to_string(), 100.0)]);
901
902 ComponentTvlRequestResponse {
903 tvl,
904 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
905 }
906 }
907
908 #[test(tokio::test)]
909 async fn test_get_snapshots_native() {
910 let header = BlockHeader::default();
911 let mut rpc = MockRPCClient::new();
912 rpc.expect_get_protocol_states()
913 .returning(|_| Ok(state_snapshot_native()));
914 rpc.expect_get_traced_entry_points()
915 .returning(|_| {
916 Ok(TracedEntryPointRequestResponse {
917 traced_entry_points: HashMap::new(),
918 pagination: PaginationResponse::new(0, 20, 0),
919 })
920 });
921 let mut state_sync = with_mocked_clients(true, false, Some(rpc), None);
922 let component = ProtocolComponent { id: "Component1".to_string(), ..Default::default() };
923 state_sync
924 .component_tracker
925 .components
926 .insert("Component1".to_string(), component.clone());
927 let components_arg = ["Component1".to_string()];
928 let exp = StateSyncMessage {
929 header: header.clone(),
930 snapshots: Snapshot {
931 states: state_snapshot_native()
932 .states
933 .into_iter()
934 .map(|state| {
935 (
936 state.component_id.clone(),
937 ComponentWithState {
938 state,
939 component: component.clone(),
940 entrypoints: vec![],
941 component_tvl: None,
942 },
943 )
944 })
945 .collect(),
946 vm_storage: HashMap::new(),
947 },
948 deltas: None,
949 removed_components: Default::default(),
950 };
951
952 let snap = state_sync
953 .get_snapshots(header, Some(&components_arg))
954 .await
955 .expect("Retrieving snapshot failed");
956
957 assert_eq!(snap, exp);
958 }
959
960 #[test(tokio::test)]
961 async fn test_get_snapshots_native_with_tvl() {
962 let header = BlockHeader::default();
963 let mut rpc = MockRPCClient::new();
964 rpc.expect_get_protocol_states()
965 .returning(|_| Ok(state_snapshot_native()));
966 rpc.expect_get_component_tvl()
967 .returning(|_| Ok(component_tvl_snapshot()));
968 rpc.expect_get_traced_entry_points()
969 .returning(|_| {
970 Ok(TracedEntryPointRequestResponse {
971 traced_entry_points: HashMap::new(),
972 pagination: PaginationResponse::new(0, 20, 0),
973 })
974 });
975 let mut state_sync = with_mocked_clients(true, true, Some(rpc), None);
976 let component = ProtocolComponent { id: "Component1".to_string(), ..Default::default() };
977 state_sync
978 .component_tracker
979 .components
980 .insert("Component1".to_string(), component.clone());
981 let components_arg = ["Component1".to_string()];
982 let exp = StateSyncMessage {
983 header: header.clone(),
984 snapshots: Snapshot {
985 states: state_snapshot_native()
986 .states
987 .into_iter()
988 .map(|state| {
989 (
990 state.component_id.clone(),
991 ComponentWithState {
992 state,
993 component: component.clone(),
994 component_tvl: Some(100.0),
995 entrypoints: vec![],
996 },
997 )
998 })
999 .collect(),
1000 vm_storage: HashMap::new(),
1001 },
1002 deltas: None,
1003 removed_components: Default::default(),
1004 };
1005
1006 let snap = state_sync
1007 .get_snapshots(header, Some(&components_arg))
1008 .await
1009 .expect("Retrieving snapshot failed");
1010
1011 assert_eq!(snap, exp);
1012 }
1013
1014 fn state_snapshot_vm() -> StateRequestResponse {
1015 StateRequestResponse {
1016 accounts: vec![
1017 ResponseAccount { address: Bytes::from("0x0badc0ffee"), ..Default::default() },
1018 ResponseAccount { address: Bytes::from("0xbabe42"), ..Default::default() },
1019 ],
1020 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1021 }
1022 }
1023
1024 fn traced_entry_point_response() -> TracedEntryPointRequestResponse {
1025 TracedEntryPointRequestResponse {
1026 traced_entry_points: HashMap::from([(
1027 "Component1".to_string(),
1028 vec![(
1029 EntryPointWithTracingParams {
1030 entry_point: EntryPoint {
1031 external_id: "entrypoint_a".to_string(),
1032 target: Bytes::from("0x0badc0ffee"),
1033 signature: "sig()".to_string(),
1034 },
1035 params: TracingParams::RPCTracer(RPCTracerParams {
1036 caller: Some(Bytes::from("0x0badc0ffee")),
1037 calldata: Bytes::from("0x0badc0ffee"),
1038 }),
1039 },
1040 TracingResult {
1041 retriggers: HashSet::from([(
1042 Bytes::from("0x0badc0ffee"),
1043 Bytes::from("0x0badc0ffee"),
1044 )]),
1045 accessed_slots: HashMap::from([(
1046 Bytes::from("0x0badc0ffee"),
1047 HashSet::from([Bytes::from("0xbadbeef0")]),
1048 )]),
1049 },
1050 )],
1051 )]),
1052 pagination: PaginationResponse::new(0, 20, 0),
1053 }
1054 }
1055
1056 #[test(tokio::test)]
1057 async fn test_get_snapshots_vm() {
1058 let header = BlockHeader::default();
1059 let mut rpc = MockRPCClient::new();
1060 rpc.expect_get_protocol_states()
1061 .returning(|_| Ok(state_snapshot_native()));
1062 rpc.expect_get_contract_state()
1063 .returning(|_| Ok(state_snapshot_vm()));
1064 rpc.expect_get_traced_entry_points()
1065 .returning(|_| Ok(traced_entry_point_response()));
1066 let mut state_sync = with_mocked_clients(false, false, Some(rpc), None);
1067 let component = ProtocolComponent {
1068 id: "Component1".to_string(),
1069 contract_ids: vec![Bytes::from("0x0badc0ffee"), Bytes::from("0xbabe42")],
1070 ..Default::default()
1071 };
1072 state_sync
1073 .component_tracker
1074 .components
1075 .insert("Component1".to_string(), component.clone());
1076 let components_arg = ["Component1".to_string()];
1077 let exp = StateSyncMessage {
1078 header: header.clone(),
1079 snapshots: Snapshot {
1080 states: [(
1081 component.id.clone(),
1082 ComponentWithState {
1083 state: ResponseProtocolState {
1084 component_id: "Component1".to_string(),
1085 ..Default::default()
1086 },
1087 component: component.clone(),
1088 component_tvl: None,
1089 entrypoints: vec![(
1090 EntryPointWithTracingParams {
1091 entry_point: EntryPoint {
1092 external_id: "entrypoint_a".to_string(),
1093 target: Bytes::from("0x0badc0ffee"),
1094 signature: "sig()".to_string(),
1095 },
1096 params: TracingParams::RPCTracer(RPCTracerParams {
1097 caller: Some(Bytes::from("0x0badc0ffee")),
1098 calldata: Bytes::from("0x0badc0ffee"),
1099 }),
1100 },
1101 TracingResult {
1102 retriggers: HashSet::from([(
1103 Bytes::from("0x0badc0ffee"),
1104 Bytes::from("0x0badc0ffee"),
1105 )]),
1106 accessed_slots: HashMap::from([(
1107 Bytes::from("0x0badc0ffee"),
1108 HashSet::from([Bytes::from("0xbadbeef0")]),
1109 )]),
1110 },
1111 )],
1112 },
1113 )]
1114 .into_iter()
1115 .collect(),
1116 vm_storage: state_snapshot_vm()
1117 .accounts
1118 .into_iter()
1119 .map(|state| (state.address.clone(), state))
1120 .collect(),
1121 },
1122 deltas: None,
1123 removed_components: Default::default(),
1124 };
1125
1126 let snap = state_sync
1127 .get_snapshots(header, Some(&components_arg))
1128 .await
1129 .expect("Retrieving snapshot failed");
1130
1131 assert_eq!(snap, exp);
1132 }
1133
1134 #[test(tokio::test)]
1135 async fn test_get_snapshots_vm_with_tvl() {
1136 let header = BlockHeader::default();
1137 let mut rpc = MockRPCClient::new();
1138 rpc.expect_get_protocol_states()
1139 .returning(|_| Ok(state_snapshot_native()));
1140 rpc.expect_get_contract_state()
1141 .returning(|_| Ok(state_snapshot_vm()));
1142 rpc.expect_get_component_tvl()
1143 .returning(|_| Ok(component_tvl_snapshot()));
1144 rpc.expect_get_traced_entry_points()
1145 .returning(|_| {
1146 Ok(TracedEntryPointRequestResponse {
1147 traced_entry_points: HashMap::new(),
1148 pagination: PaginationResponse::new(0, 20, 0),
1149 })
1150 });
1151 let mut state_sync = with_mocked_clients(false, true, Some(rpc), None);
1152 let component = ProtocolComponent {
1153 id: "Component1".to_string(),
1154 contract_ids: vec![Bytes::from("0x0badc0ffee"), Bytes::from("0xbabe42")],
1155 ..Default::default()
1156 };
1157 state_sync
1158 .component_tracker
1159 .components
1160 .insert("Component1".to_string(), component.clone());
1161 let components_arg = ["Component1".to_string()];
1162 let exp = StateSyncMessage {
1163 header: header.clone(),
1164 snapshots: Snapshot {
1165 states: [(
1166 component.id.clone(),
1167 ComponentWithState {
1168 state: ResponseProtocolState {
1169 component_id: "Component1".to_string(),
1170 ..Default::default()
1171 },
1172 component: component.clone(),
1173 component_tvl: Some(100.0),
1174 entrypoints: vec![],
1175 },
1176 )]
1177 .into_iter()
1178 .collect(),
1179 vm_storage: state_snapshot_vm()
1180 .accounts
1181 .into_iter()
1182 .map(|state| (state.address.clone(), state))
1183 .collect(),
1184 },
1185 deltas: None,
1186 removed_components: Default::default(),
1187 };
1188
1189 let snap = state_sync
1190 .get_snapshots(header, Some(&components_arg))
1191 .await
1192 .expect("Retrieving snapshot failed");
1193
1194 assert_eq!(snap, exp);
1195 }
1196
1197 fn mock_clients_for_state_sync() -> (MockRPCClient, MockDeltasClient, Sender<BlockChanges>) {
1198 let mut rpc_client = MockRPCClient::new();
1199 rpc_client
1202 .expect_get_protocol_components()
1203 .with(mockall::predicate::function(
1204 move |request_params: &ProtocolComponentsRequestBody| {
1205 if let Some(ids) = request_params.component_ids.as_ref() {
1206 ids.contains(&"Component3".to_string())
1207 } else {
1208 false
1209 }
1210 },
1211 ))
1212 .returning(|_| {
1213 Ok(ProtocolComponentRequestResponse {
1215 protocol_components: vec![
1216 ProtocolComponent { id: "Component3".to_string(), ..Default::default() },
1218 ],
1219 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1220 })
1221 });
1222 rpc_client
1223 .expect_get_protocol_states()
1224 .with(mockall::predicate::function(move |request_params: &ProtocolStateRequestBody| {
1225 let expected_id = "Component3".to_string();
1226 if let Some(ids) = request_params.protocol_ids.as_ref() {
1227 ids.contains(&expected_id)
1228 } else {
1229 false
1230 }
1231 }))
1232 .returning(|_| {
1233 Ok(ProtocolStateRequestResponse {
1235 states: vec![ResponseProtocolState {
1236 component_id: "Component3".to_string(),
1237 ..Default::default()
1238 }],
1239 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1240 })
1241 });
1242
1243 rpc_client
1245 .expect_get_protocol_components()
1246 .returning(|_| {
1247 Ok(ProtocolComponentRequestResponse {
1249 protocol_components: vec![
1250 ProtocolComponent { id: "Component1".to_string(), ..Default::default() },
1252 ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1254 ],
1256 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1257 })
1258 });
1259 rpc_client
1260 .expect_get_protocol_states()
1261 .returning(|_| {
1262 Ok(ProtocolStateRequestResponse {
1264 states: vec![
1265 ResponseProtocolState {
1266 component_id: "Component1".to_string(),
1267 ..Default::default()
1268 },
1269 ResponseProtocolState {
1270 component_id: "Component2".to_string(),
1271 ..Default::default()
1272 },
1273 ],
1274 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1275 })
1276 });
1277 rpc_client
1278 .expect_get_component_tvl()
1279 .returning(|_| {
1280 Ok(ComponentTvlRequestResponse {
1281 tvl: [
1282 ("Component1".to_string(), 100.0),
1283 ("Component2".to_string(), 0.0),
1284 ("Component3".to_string(), 1000.0),
1285 ]
1286 .into_iter()
1287 .collect(),
1288 pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
1289 })
1290 });
1291 rpc_client
1292 .expect_get_traced_entry_points()
1293 .returning(|_| {
1294 Ok(TracedEntryPointRequestResponse {
1295 traced_entry_points: HashMap::new(),
1296 pagination: PaginationResponse::new(0, 20, 0),
1297 })
1298 });
1299
1300 let mut deltas_client = MockDeltasClient::new();
1302 let (tx, rx) = channel(1);
1303 deltas_client
1304 .expect_subscribe()
1305 .return_once(move |_, _| {
1306 Ok((Uuid::default(), rx))
1308 });
1309
1310 deltas_client
1312 .expect_unsubscribe()
1313 .return_once(|_| Ok(()));
1314
1315 (rpc_client, deltas_client, tx)
1316 }
1317
1318 #[test(tokio::test)]
1325 async fn test_state_sync() {
1326 let (rpc_client, deltas_client, tx) = mock_clients_for_state_sync();
1327 let deltas = [
1328 BlockChanges {
1329 extractor: "uniswap-v2".to_string(),
1330 chain: Chain::Ethereum,
1331 block: Block {
1332 number: 1,
1333 hash: Bytes::from("0x01"),
1334 parent_hash: Bytes::from("0x00"),
1335 chain: Chain::Ethereum,
1336 ts: Default::default(),
1337 },
1338 revert: false,
1339 dci_update: DCIUpdate {
1340 new_entrypoints: HashMap::from([(
1341 "Component1".to_string(),
1342 HashSet::from([EntryPoint {
1343 external_id: "entrypoint_a".to_string(),
1344 target: Bytes::from("0x0badc0ffee"),
1345 signature: "sig()".to_string(),
1346 }]),
1347 )]),
1348 new_entrypoint_params: HashMap::from([(
1349 "entrypoint_a".to_string(),
1350 HashSet::from([(
1351 TracingParams::RPCTracer(RPCTracerParams {
1352 caller: Some(Bytes::from("0x0badc0ffee")),
1353 calldata: Bytes::from("0x0badc0ffee"),
1354 }),
1355 Some("Component1".to_string()),
1356 )]),
1357 )]),
1358 trace_results: HashMap::from([(
1359 "entrypoint_a".to_string(),
1360 TracingResult {
1361 retriggers: HashSet::from([(
1362 Bytes::from("0x0badc0ffee"),
1363 Bytes::from("0x0badc0ffee"),
1364 )]),
1365 accessed_slots: HashMap::from([(
1366 Bytes::from("0x0badc0ffee"),
1367 HashSet::from([Bytes::from("0xbadbeef0")]),
1368 )]),
1369 },
1370 )]),
1371 },
1372 ..Default::default()
1373 },
1374 BlockChanges {
1375 extractor: "uniswap-v2".to_string(),
1376 chain: Chain::Ethereum,
1377 block: Block {
1378 number: 2,
1379 hash: Bytes::from("0x02"),
1380 parent_hash: Bytes::from("0x01"),
1381 chain: Chain::Ethereum,
1382 ts: Default::default(),
1383 },
1384 revert: false,
1385 component_tvl: [
1386 ("Component1".to_string(), 100.0),
1387 ("Component2".to_string(), 0.0),
1388 ("Component3".to_string(), 1000.0),
1389 ]
1390 .into_iter()
1391 .collect(),
1392 ..Default::default()
1393 },
1394 ];
1395 let mut state_sync = with_mocked_clients(true, true, Some(rpc_client), Some(deltas_client));
1396 state_sync
1397 .initialize()
1398 .await
1399 .expect("Init failed");
1400
1401 let (handle, mut rx) = state_sync
1403 .start()
1404 .await
1405 .expect("Failed to start state synchronizer");
1406 let (jh, close_tx) = handle.split();
1407 tx.send(deltas[0].clone())
1408 .await
1409 .expect("deltas channel msg 0 closed!");
1410 let first_msg = timeout(Duration::from_millis(100), rx.recv())
1411 .await
1412 .expect("waiting for first state msg timed out!")
1413 .expect("state sync block sender closed!");
1414 tx.send(deltas[1].clone())
1415 .await
1416 .expect("deltas channel msg 1 closed!");
1417 let second_msg = timeout(Duration::from_millis(100), rx.recv())
1418 .await
1419 .expect("waiting for second state msg timed out!")
1420 .expect("state sync block sender closed!");
1421 let _ = close_tx.send(());
1422 let exit = jh
1423 .await
1424 .expect("state sync task panicked!");
1425
1426 let exp1 = StateSyncMessage {
1428 header: BlockHeader {
1429 number: 1,
1430 hash: Bytes::from("0x01"),
1431 parent_hash: Bytes::from("0x00"),
1432 revert: false,
1433 ..Default::default()
1434 },
1435 snapshots: Snapshot {
1436 states: [
1437 (
1438 "Component1".to_string(),
1439 ComponentWithState {
1440 state: ResponseProtocolState {
1441 component_id: "Component1".to_string(),
1442 ..Default::default()
1443 },
1444 component: ProtocolComponent {
1445 id: "Component1".to_string(),
1446 ..Default::default()
1447 },
1448 component_tvl: Some(100.0),
1449 entrypoints: vec![],
1450 },
1451 ),
1452 (
1453 "Component2".to_string(),
1454 ComponentWithState {
1455 state: ResponseProtocolState {
1456 component_id: "Component2".to_string(),
1457 ..Default::default()
1458 },
1459 component: ProtocolComponent {
1460 id: "Component2".to_string(),
1461 ..Default::default()
1462 },
1463 component_tvl: Some(0.0),
1464 entrypoints: vec![],
1465 },
1466 ),
1467 ]
1468 .into_iter()
1469 .collect(),
1470 vm_storage: HashMap::new(),
1471 },
1472 deltas: Some(deltas[0].clone()),
1473 removed_components: Default::default(),
1474 };
1475
1476 let exp2 = StateSyncMessage {
1477 header: BlockHeader {
1478 number: 2,
1479 hash: Bytes::from("0x02"),
1480 parent_hash: Bytes::from("0x01"),
1481 revert: false,
1482 ..Default::default()
1483 },
1484 snapshots: Snapshot {
1485 states: [
1486 (
1488 "Component3".to_string(),
1489 ComponentWithState {
1490 state: ResponseProtocolState {
1491 component_id: "Component3".to_string(),
1492 ..Default::default()
1493 },
1494 component: ProtocolComponent {
1495 id: "Component3".to_string(),
1496 ..Default::default()
1497 },
1498 component_tvl: Some(1000.0),
1499 entrypoints: vec![],
1500 },
1501 ),
1502 ]
1503 .into_iter()
1504 .collect(),
1505 vm_storage: HashMap::new(),
1506 },
1507 deltas: Some(BlockChanges {
1510 extractor: "uniswap-v2".to_string(),
1511 chain: Chain::Ethereum,
1512 block: Block {
1513 number: 2,
1514 hash: Bytes::from("0x02"),
1515 parent_hash: Bytes::from("0x01"),
1516 chain: Chain::Ethereum,
1517 ts: Default::default(),
1518 },
1519 revert: false,
1520 component_tvl: [
1521 ("Component1".to_string(), 100.0),
1523 ("Component3".to_string(), 1000.0),
1524 ]
1525 .into_iter()
1526 .collect(),
1527 ..Default::default()
1528 }),
1529 removed_components: [(
1531 "Component2".to_string(),
1532 ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1533 )]
1534 .into_iter()
1535 .collect(),
1536 };
1537 assert_eq!(first_msg, exp1);
1538 assert_eq!(second_msg, exp2);
1539 assert!(exit.is_ok());
1540 }
1541
1542 #[test(tokio::test)]
1543 async fn test_state_sync_with_tvl_range() {
1544 let remove_tvl_threshold = 5.0;
1546 let add_tvl_threshold = 7.0;
1547
1548 let mut rpc_client = MockRPCClient::new();
1549 let mut deltas_client = MockDeltasClient::new();
1550
1551 rpc_client
1552 .expect_get_protocol_components()
1553 .with(mockall::predicate::function(
1554 move |request_params: &ProtocolComponentsRequestBody| {
1555 if let Some(ids) = request_params.component_ids.as_ref() {
1556 ids.contains(&"Component3".to_string())
1557 } else {
1558 false
1559 }
1560 },
1561 ))
1562 .returning(|_| {
1563 Ok(ProtocolComponentRequestResponse {
1564 protocol_components: vec![ProtocolComponent {
1565 id: "Component3".to_string(),
1566 ..Default::default()
1567 }],
1568 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1569 })
1570 });
1571 rpc_client
1572 .expect_get_protocol_states()
1573 .with(mockall::predicate::function(move |request_params: &ProtocolStateRequestBody| {
1574 let expected_id = "Component3".to_string();
1575 if let Some(ids) = request_params.protocol_ids.as_ref() {
1576 ids.contains(&expected_id)
1577 } else {
1578 false
1579 }
1580 }))
1581 .returning(|_| {
1582 Ok(ProtocolStateRequestResponse {
1583 states: vec![ResponseProtocolState {
1584 component_id: "Component3".to_string(),
1585 ..Default::default()
1586 }],
1587 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1588 })
1589 });
1590
1591 rpc_client
1593 .expect_get_protocol_components()
1594 .returning(|_| {
1595 Ok(ProtocolComponentRequestResponse {
1596 protocol_components: vec![
1597 ProtocolComponent { id: "Component1".to_string(), ..Default::default() },
1598 ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1599 ],
1600 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1601 })
1602 });
1603 rpc_client
1604 .expect_get_protocol_states()
1605 .returning(|_| {
1606 Ok(ProtocolStateRequestResponse {
1607 states: vec![
1608 ResponseProtocolState {
1609 component_id: "Component1".to_string(),
1610 ..Default::default()
1611 },
1612 ResponseProtocolState {
1613 component_id: "Component2".to_string(),
1614 ..Default::default()
1615 },
1616 ],
1617 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1618 })
1619 });
1620 rpc_client
1621 .expect_get_traced_entry_points()
1622 .returning(|_| {
1623 Ok(TracedEntryPointRequestResponse {
1624 traced_entry_points: HashMap::new(),
1625 pagination: PaginationResponse::new(0, 20, 0),
1626 })
1627 });
1628
1629 rpc_client
1630 .expect_get_component_tvl()
1631 .returning(|_| {
1632 Ok(ComponentTvlRequestResponse {
1633 tvl: [
1634 ("Component1".to_string(), 6.0),
1635 ("Component2".to_string(), 2.0),
1636 ("Component3".to_string(), 10.0),
1637 ]
1638 .into_iter()
1639 .collect(),
1640 pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
1641 })
1642 });
1643
1644 rpc_client
1645 .expect_get_component_tvl()
1646 .returning(|_| {
1647 Ok(ComponentTvlRequestResponse {
1648 tvl: [
1649 ("Component1".to_string(), 6.0),
1650 ("Component2".to_string(), 2.0),
1651 ("Component3".to_string(), 10.0),
1652 ]
1653 .into_iter()
1654 .collect(),
1655 pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
1656 })
1657 });
1658
1659 let (tx, rx) = channel(1);
1660 deltas_client
1661 .expect_subscribe()
1662 .return_once(move |_, _| Ok((Uuid::default(), rx)));
1663
1664 deltas_client
1666 .expect_unsubscribe()
1667 .return_once(|_| Ok(()));
1668
1669 let mut state_sync = ProtocolStateSynchronizer::new(
1670 ExtractorIdentity::new(Chain::Ethereum, "uniswap-v2"),
1671 true,
1672 ComponentFilter::with_tvl_range(remove_tvl_threshold, add_tvl_threshold),
1673 1,
1674 Duration::from_secs(0),
1675 true,
1676 true,
1677 ArcRPCClient(Arc::new(rpc_client)),
1678 ArcDeltasClient(Arc::new(deltas_client)),
1679 10_u64,
1680 );
1681 state_sync
1682 .initialize()
1683 .await
1684 .expect("Init failed");
1685
1686 let deltas = [
1688 BlockChanges {
1689 extractor: "uniswap-v2".to_string(),
1690 chain: Chain::Ethereum,
1691 block: Block {
1692 number: 1,
1693 hash: Bytes::from("0x01"),
1694 parent_hash: Bytes::from("0x00"),
1695 chain: Chain::Ethereum,
1696 ts: Default::default(),
1697 },
1698 revert: false,
1699 ..Default::default()
1700 },
1701 BlockChanges {
1702 extractor: "uniswap-v2".to_string(),
1703 chain: Chain::Ethereum,
1704 block: Block {
1705 number: 2,
1706 hash: Bytes::from("0x02"),
1707 parent_hash: Bytes::from("0x01"),
1708 chain: Chain::Ethereum,
1709 ts: Default::default(),
1710 },
1711 revert: false,
1712 component_tvl: [
1713 ("Component1".to_string(), 6.0), ("Component2".to_string(), 2.0), ("Component3".to_string(), 10.0), ]
1717 .into_iter()
1718 .collect(),
1719 ..Default::default()
1720 },
1721 ];
1722
1723 let (handle, mut rx) = state_sync
1724 .start()
1725 .await
1726 .expect("Failed to start state synchronizer");
1727 let (jh, close_tx) = handle.split();
1728
1729 tx.send(deltas[0].clone())
1731 .await
1732 .expect("deltas channel msg 0 closed!");
1733
1734 let _ = timeout(Duration::from_millis(100), rx.recv())
1736 .await
1737 .expect("waiting for first state msg timed out!")
1738 .expect("state sync block sender closed!");
1739
1740 tx.send(deltas[1].clone())
1742 .await
1743 .expect("deltas channel msg 1 closed!");
1744 let second_msg = timeout(Duration::from_millis(100), rx.recv())
1745 .await
1746 .expect("waiting for second state msg timed out!")
1747 .expect("state sync block sender closed!");
1748
1749 let _ = close_tx.send(());
1750 let exit = jh
1751 .await
1752 .expect("state sync task panicked!");
1753
1754 let expected_second_msg = StateSyncMessage {
1755 header: BlockHeader {
1756 number: 2,
1757 hash: Bytes::from("0x02"),
1758 parent_hash: Bytes::from("0x01"),
1759 revert: false,
1760 ..Default::default()
1761 },
1762 snapshots: Snapshot {
1763 states: [(
1764 "Component3".to_string(),
1765 ComponentWithState {
1766 state: ResponseProtocolState {
1767 component_id: "Component3".to_string(),
1768 ..Default::default()
1769 },
1770 component: ProtocolComponent {
1771 id: "Component3".to_string(),
1772 ..Default::default()
1773 },
1774 component_tvl: Some(10.0),
1775 entrypoints: vec![], },
1777 )]
1778 .into_iter()
1779 .collect(),
1780 vm_storage: HashMap::new(),
1781 },
1782 deltas: Some(BlockChanges {
1783 extractor: "uniswap-v2".to_string(),
1784 chain: Chain::Ethereum,
1785 block: Block {
1786 number: 2,
1787 hash: Bytes::from("0x02"),
1788 parent_hash: Bytes::from("0x01"),
1789 chain: Chain::Ethereum,
1790 ts: Default::default(),
1791 },
1792 revert: false,
1793 component_tvl: [
1794 ("Component1".to_string(), 6.0), ("Component3".to_string(), 10.0), ]
1797 .into_iter()
1798 .collect(),
1799 ..Default::default()
1800 }),
1801 removed_components: [(
1802 "Component2".to_string(),
1803 ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1804 )]
1805 .into_iter()
1806 .collect(),
1807 };
1808
1809 assert_eq!(second_msg, expected_second_msg);
1810 assert!(exit.is_ok());
1811 }
1812
1813 #[test(tokio::test)]
1814 async fn test_public_close_api_functionality() {
1815 let mut rpc_client = MockRPCClient::new();
1822 let mut deltas_client = MockDeltasClient::new();
1823
1824 rpc_client
1826 .expect_get_protocol_components()
1827 .returning(|_| {
1828 Ok(ProtocolComponentRequestResponse {
1829 protocol_components: vec![],
1830 pagination: PaginationResponse { page: 0, page_size: 20, total: 0 },
1831 })
1832 });
1833
1834 let (_tx, rx) = channel(1);
1836 deltas_client
1837 .expect_subscribe()
1838 .return_once(move |_, _| Ok((Uuid::default(), rx)));
1839
1840 deltas_client
1842 .expect_unsubscribe()
1843 .return_once(|_| Ok(()));
1844
1845 let mut state_sync = ProtocolStateSynchronizer::new(
1846 ExtractorIdentity::new(Chain::Ethereum, "test-protocol"),
1847 true,
1848 ComponentFilter::with_tvl_range(0.0, 0.0),
1849 5, Duration::from_secs(0),
1851 true,
1852 false,
1853 ArcRPCClient(Arc::new(rpc_client)),
1854 ArcDeltasClient(Arc::new(deltas_client)),
1855 10000_u64, );
1857
1858 state_sync
1859 .initialize()
1860 .await
1861 .expect("Init should succeed");
1862
1863 let (handle, _rx) = state_sync
1865 .start()
1866 .await
1867 .expect("Failed to start state synchronizer");
1868 let (jh, close_tx) = handle.split();
1869
1870 tokio::time::sleep(Duration::from_millis(100)).await;
1872
1873 close_tx
1875 .send(())
1876 .expect("Should be able to send close signal");
1877 let task_result = jh.await.expect("Task should not panic");
1879 assert!(task_result.is_ok(), "Task should exit cleanly after close: {task_result:?}");
1880 }
1881
1882 #[test(tokio::test)]
1883 async fn test_cleanup_runs_when_state_sync_processing_errors() {
1884 let mut rpc_client = MockRPCClient::new();
1889 let mut deltas_client = MockDeltasClient::new();
1890
1891 rpc_client
1893 .expect_get_protocol_components()
1894 .returning(|_| {
1895 Ok(ProtocolComponentRequestResponse {
1896 protocol_components: vec![],
1897 pagination: PaginationResponse { page: 0, page_size: 20, total: 0 },
1898 })
1899 });
1900
1901 rpc_client
1903 .expect_get_protocol_states()
1904 .returning(|_| {
1905 Err(RPCError::HttpClient("Test error during snapshot retrieval".to_string()))
1906 });
1907
1908 let (tx, rx) = channel(10);
1910 deltas_client
1911 .expect_subscribe()
1912 .return_once(move |_, _| {
1913 let delta = BlockChanges {
1915 extractor: "test".to_string(),
1916 chain: Chain::Ethereum,
1917 block: Block {
1918 hash: Bytes::from("0x0123"),
1919 number: 1,
1920 parent_hash: Bytes::from("0x0000"),
1921 chain: Chain::Ethereum,
1922 ts: chrono::NaiveDateTime::from_timestamp_opt(1234567890, 0).unwrap(),
1923 },
1924 revert: false,
1925 new_protocol_components: [(
1927 "new_component".to_string(),
1928 ProtocolComponent {
1929 id: "new_component".to_string(),
1930 protocol_system: "test_protocol".to_string(),
1931 protocol_type_name: "test".to_string(),
1932 chain: Chain::Ethereum,
1933 tokens: vec![Bytes::from("0x0badc0ffee")],
1934 contract_ids: vec![Bytes::from("0x0badc0ffee")],
1935 static_attributes: Default::default(),
1936 creation_tx: Default::default(),
1937 created_at: Default::default(),
1938 change: Default::default(),
1939 },
1940 )]
1941 .into_iter()
1942 .collect(),
1943 component_tvl: [("new_component".to_string(), 100.0)]
1944 .into_iter()
1945 .collect(),
1946 ..Default::default()
1947 };
1948
1949 tokio::spawn(async move {
1950 let _ = tx.send(delta).await;
1951 });
1953
1954 Ok((Uuid::default(), rx))
1955 });
1956
1957 deltas_client
1959 .expect_unsubscribe()
1960 .return_once(|_| Ok(()));
1961
1962 let mut state_sync = ProtocolStateSynchronizer::new(
1963 ExtractorIdentity::new(Chain::Ethereum, "test-protocol"),
1964 true,
1965 ComponentFilter::with_tvl_range(0.0, 1000.0), 1,
1967 Duration::from_secs(0),
1968 true,
1969 false,
1970 ArcRPCClient(Arc::new(rpc_client)),
1971 ArcDeltasClient(Arc::new(deltas_client)),
1972 5000_u64,
1973 );
1974
1975 state_sync
1976 .initialize()
1977 .await
1978 .expect("Init should succeed");
1979
1980 state_sync.last_synced_block = Some(BlockHeader {
1982 hash: Bytes::from("0x0badc0ffee"),
1983 number: 42,
1984 parent_hash: Bytes::from("0xbadbeef0"),
1985 revert: false,
1986 timestamp: 123456789,
1987 });
1988
1989 let (mut block_tx, _block_rx) = channel(10);
1991
1992 let (_end_tx, end_rx) = oneshot::channel::<()>();
1994 let result = state_sync
1995 .state_sync(&mut block_tx, end_rx)
1996 .await;
1997 assert!(result.is_err(), "state_sync should have errored during processing");
1999
2000 }
2003
2004 #[test(tokio::test)]
2005 async fn test_close_signal_while_waiting_for_first_deltas() {
2006 let mut rpc_client = MockRPCClient::new();
2010 let mut deltas_client = MockDeltasClient::new();
2011
2012 rpc_client
2013 .expect_get_protocol_components()
2014 .returning(|_| {
2015 Ok(ProtocolComponentRequestResponse {
2016 protocol_components: vec![],
2017 pagination: PaginationResponse { page: 0, page_size: 20, total: 0 },
2018 })
2019 });
2020
2021 let (_tx, rx) = channel(1);
2022 deltas_client
2023 .expect_subscribe()
2024 .return_once(move |_, _| Ok((Uuid::default(), rx)));
2025
2026 deltas_client
2027 .expect_unsubscribe()
2028 .return_once(|_| Ok(()));
2029
2030 let mut state_sync = ProtocolStateSynchronizer::new(
2031 ExtractorIdentity::new(Chain::Ethereum, "test-protocol"),
2032 true,
2033 ComponentFilter::with_tvl_range(0.0, 0.0),
2034 1,
2035 Duration::from_secs(0),
2036 true,
2037 false,
2038 ArcRPCClient(Arc::new(rpc_client)),
2039 ArcDeltasClient(Arc::new(deltas_client)),
2040 10000_u64,
2041 );
2042
2043 state_sync
2044 .initialize()
2045 .await
2046 .expect("Init should succeed");
2047
2048 let (mut block_tx, _block_rx) = channel(10);
2049 let (end_tx, end_rx) = oneshot::channel::<()>();
2050
2051 let state_sync_handle = tokio::spawn(async move {
2053 state_sync
2054 .state_sync(&mut block_tx, end_rx)
2055 .await
2056 });
2057
2058 tokio::time::sleep(Duration::from_millis(100)).await;
2060
2061 let _ = end_tx.send(());
2063
2064 let result = state_sync_handle
2066 .await
2067 .expect("Task should not panic");
2068 assert!(result.is_ok(), "state_sync should exit cleanly when closed: {result:?}");
2069
2070 println!("SUCCESS: Close signal handled correctly while waiting for first deltas");
2071 }
2072
2073 #[test(tokio::test)]
2074 async fn test_close_signal_during_main_processing_loop() {
2075 let mut rpc_client = MockRPCClient::new();
2081 let mut deltas_client = MockDeltasClient::new();
2082
2083 rpc_client
2085 .expect_get_protocol_components()
2086 .returning(|_| {
2087 Ok(ProtocolComponentRequestResponse {
2088 protocol_components: vec![],
2089 pagination: PaginationResponse { page: 0, page_size: 20, total: 0 },
2090 })
2091 });
2092
2093 rpc_client
2095 .expect_get_protocol_states()
2096 .returning(|_| {
2097 Ok(ProtocolStateRequestResponse {
2098 states: vec![],
2099 pagination: PaginationResponse { page: 0, page_size: 20, total: 0 },
2100 })
2101 });
2102
2103 rpc_client
2104 .expect_get_component_tvl()
2105 .returning(|_| {
2106 Ok(ComponentTvlRequestResponse {
2107 tvl: HashMap::new(),
2108 pagination: PaginationResponse { page: 0, page_size: 20, total: 0 },
2109 })
2110 });
2111
2112 rpc_client
2113 .expect_get_traced_entry_points()
2114 .returning(|_| {
2115 Ok(TracedEntryPointRequestResponse {
2116 traced_entry_points: HashMap::new(),
2117 pagination: PaginationResponse::new(0, 20, 0),
2118 })
2119 });
2120
2121 let (tx, rx) = channel(10);
2123 deltas_client
2124 .expect_subscribe()
2125 .return_once(move |_, _| {
2126 let first_delta = BlockChanges {
2128 extractor: "test".to_string(),
2129 chain: Chain::Ethereum,
2130 block: Block {
2131 hash: Bytes::from("0x0123"),
2132 number: 1,
2133 parent_hash: Bytes::from("0x0000"),
2134 chain: Chain::Ethereum,
2135 ts: chrono::NaiveDateTime::from_timestamp_opt(1234567890, 0).unwrap(),
2136 },
2137 revert: false,
2138 ..Default::default()
2139 };
2140
2141 tokio::spawn(async move {
2142 let _ = tx.send(first_delta).await;
2143 tokio::time::sleep(Duration::from_secs(30)).await;
2146 });
2147
2148 Ok((Uuid::default(), rx))
2149 });
2150
2151 deltas_client
2152 .expect_unsubscribe()
2153 .return_once(|_| Ok(()));
2154
2155 let mut state_sync = ProtocolStateSynchronizer::new(
2156 ExtractorIdentity::new(Chain::Ethereum, "test-protocol"),
2157 true,
2158 ComponentFilter::with_tvl_range(0.0, 1000.0),
2159 1,
2160 Duration::from_secs(0),
2161 true,
2162 false,
2163 ArcRPCClient(Arc::new(rpc_client)),
2164 ArcDeltasClient(Arc::new(deltas_client)),
2165 10000_u64,
2166 );
2167
2168 state_sync
2169 .initialize()
2170 .await
2171 .expect("Init should succeed");
2172
2173 let (mut block_tx, mut block_rx) = channel(10);
2174 let (end_tx, end_rx) = oneshot::channel::<()>();
2175
2176 let state_sync_handle = tokio::spawn(async move {
2178 state_sync
2179 .state_sync(&mut block_tx, end_rx)
2180 .await
2181 });
2182
2183 let first_snapshot = block_rx
2185 .recv()
2186 .await
2187 .expect("Should receive first snapshot");
2188 assert!(
2189 !first_snapshot
2190 .snapshots
2191 .states
2192 .is_empty() ||
2193 first_snapshot.deltas.is_some()
2194 );
2195 let _ = end_tx.send(());
2197
2198 let result = state_sync_handle
2200 .await
2201 .expect("Task should not panic");
2202 assert!(
2203 result.is_ok(),
2204 "state_sync should exit cleanly when closed after first message: {result:?}"
2205 );
2206 println!("SUCCESS: Close signal handled correctly during main processing loop");
2207 }
2208}