1use std::{collections::HashMap, time::Duration};
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6use tokio::{
7 select,
8 sync::{
9 mpsc::{channel, error::SendError, Receiver, Sender},
10 oneshot,
11 },
12 task::JoinHandle,
13 time::timeout,
14};
15use tracing::{debug, error, info, instrument, trace, warn};
16use tycho_common::{
17 dto::{
18 BlockChanges, BlockParam, Chain, ComponentTvlRequestBody, EntryPointWithTracingParams,
19 ExtractorIdentity, ProtocolComponent, ResponseAccount, ResponseProtocolState,
20 TracingResult, VersionParam,
21 },
22 Bytes,
23};
24
25use crate::{
26 deltas::{DeltasClient, SubscriptionOptions},
27 feed::{
28 component_tracker::{ComponentFilter, ComponentTracker},
29 BlockHeader, HeaderLike,
30 },
31 rpc::{RPCClient, RPCError},
32 DeltasError,
33};
34
35#[derive(Error, Debug)]
36pub enum SynchronizerError {
37 #[error("RPC error: {0}")]
39 RPCError(#[from] RPCError),
40
41 #[error("Failed to send channel message: {0}")]
43 ChannelError(String),
44
45 #[error("Timeout error: {0}")]
47 Timeout(String),
48
49 #[error("Failed to close synchronizer: {0}")]
51 CloseError(String),
52
53 #[error("Connection error: {0}")]
55 ConnectionError(String),
56
57 #[error("Connection closed")]
59 ConnectionClosed,
60}
61
62pub type SyncResult<T> = Result<T, SynchronizerError>;
63
64impl From<SendError<StateSyncMessage<BlockHeader>>> for SynchronizerError {
65 fn from(err: SendError<StateSyncMessage<BlockHeader>>) -> Self {
66 SynchronizerError::ChannelError(err.to_string())
67 }
68}
69
70impl From<DeltasError> for SynchronizerError {
71 fn from(err: DeltasError) -> Self {
72 match err {
73 DeltasError::NotConnected => SynchronizerError::ConnectionClosed,
74 _ => SynchronizerError::ConnectionError(err.to_string()),
75 }
76 }
77}
78
79pub struct ProtocolStateSynchronizer<R: RPCClient, D: DeltasClient> {
80 extractor_id: ExtractorIdentity,
81 retrieve_balances: bool,
82 rpc_client: R,
83 deltas_client: D,
84 max_retries: u64,
85 include_snapshots: bool,
86 component_tracker: ComponentTracker<R>,
87 last_synced_block: Option<BlockHeader>,
88 timeout: u64,
89 include_tvl: bool,
90}
91
92#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
93pub struct ComponentWithState {
94 pub state: ResponseProtocolState,
95 pub component: ProtocolComponent,
96 pub component_tvl: Option<f64>,
97 pub entrypoints: Vec<(EntryPointWithTracingParams, TracingResult)>,
98}
99
100#[derive(Clone, PartialEq, Debug, Default, Serialize, Deserialize)]
101pub struct Snapshot {
102 pub states: HashMap<String, ComponentWithState>,
103 pub vm_storage: HashMap<Bytes, ResponseAccount>,
104}
105
106impl Snapshot {
107 fn extend(&mut self, other: Snapshot) {
108 self.states.extend(other.states);
109 self.vm_storage.extend(other.vm_storage);
110 }
111
112 pub fn get_states(&self) -> &HashMap<String, ComponentWithState> {
113 &self.states
114 }
115
116 pub fn get_vm_storage(&self) -> &HashMap<Bytes, ResponseAccount> {
117 &self.vm_storage
118 }
119}
120
121#[derive(Clone, PartialEq, Debug, Default, Serialize, Deserialize)]
122pub struct StateSyncMessage<H>
123where
124 H: HeaderLike,
125{
126 pub header: H,
128 pub snapshots: Snapshot,
130 pub deltas: Option<BlockChanges>,
134 pub removed_components: HashMap<String, ProtocolComponent>,
136}
137
138impl<H> StateSyncMessage<H>
139where
140 H: HeaderLike,
141{
142 pub fn merge(mut self, other: Self) -> Self {
143 self.removed_components
145 .retain(|k, _| !other.snapshots.states.contains_key(k));
146 self.snapshots
147 .states
148 .retain(|k, _| !other.removed_components.contains_key(k));
149
150 self.snapshots.extend(other.snapshots);
151 let deltas = match (self.deltas, other.deltas) {
152 (Some(l), Some(r)) => Some(l.merge(r)),
153 (None, Some(r)) => Some(r),
154 (Some(l), None) => Some(l),
155 (None, None) => None,
156 };
157 self.removed_components
158 .extend(other.removed_components);
159 Self {
160 header: other.header,
161 snapshots: self.snapshots,
162 deltas,
163 removed_components: self.removed_components,
164 }
165 }
166}
167
168pub struct SynchronizerTaskHandle {
173 join_handle: JoinHandle<SyncResult<()>>,
174 close_tx: oneshot::Sender<()>,
175}
176
177impl SynchronizerTaskHandle {
186 pub fn new(join_handle: JoinHandle<SyncResult<()>>, close_tx: oneshot::Sender<()>) -> Self {
187 Self { join_handle, close_tx }
188 }
189
190 pub fn split(self) -> (JoinHandle<SyncResult<()>>, oneshot::Sender<()>) {
196 (self.join_handle, self.close_tx)
197 }
198}
199
200#[async_trait]
201pub trait StateSynchronizer: Send + Sync + 'static {
202 async fn initialize(&mut self) -> SyncResult<()>;
203 async fn start(
206 mut self,
207 ) -> SyncResult<(SynchronizerTaskHandle, Receiver<StateSyncMessage<BlockHeader>>)>;
208}
209
210impl<R, D> ProtocolStateSynchronizer<R, D>
211where
212 R: RPCClient + Clone + Send + Sync + 'static,
215 D: DeltasClient + Clone + Send + Sync + 'static,
216{
217 #[allow(clippy::too_many_arguments)]
219 pub fn new(
220 extractor_id: ExtractorIdentity,
221 retrieve_balances: bool,
222 component_filter: ComponentFilter,
223 max_retries: u64,
224 include_snapshots: bool,
225 include_tvl: bool,
226 rpc_client: R,
227 deltas_client: D,
228 timeout: u64,
229 ) -> Self {
230 Self {
231 extractor_id: extractor_id.clone(),
232 retrieve_balances,
233 rpc_client: rpc_client.clone(),
234 include_snapshots,
235 deltas_client,
236 component_tracker: ComponentTracker::new(
237 extractor_id.chain,
238 extractor_id.name.as_str(),
239 component_filter,
240 rpc_client,
241 ),
242 max_retries,
243 last_synced_block: None,
244 timeout,
245 include_tvl,
246 }
247 }
248
249 #[allow(deprecated)]
251 async fn get_snapshots<'a, I: IntoIterator<Item = &'a String>>(
252 &mut self,
253 header: BlockHeader,
254 ids: Option<I>,
255 ) -> SyncResult<StateSyncMessage<BlockHeader>> {
256 if !self.include_snapshots {
257 return Ok(StateSyncMessage { header, ..Default::default() });
258 }
259 let version = VersionParam::new(
260 None,
261 Some(BlockParam {
262 chain: Some(self.extractor_id.chain),
263 hash: None,
264 number: Some(header.number as i64),
265 }),
266 );
267
268 let component_ids: Vec<_> = match ids {
270 Some(ids) => ids.into_iter().cloned().collect(),
271 None => self
272 .component_tracker
273 .get_tracked_component_ids(),
274 };
275
276 if component_ids.is_empty() {
277 return Ok(StateSyncMessage { header, ..Default::default() });
278 }
279
280 let component_tvl = if self.include_tvl {
281 let body = ComponentTvlRequestBody::id_filtered(
282 component_ids.clone(),
283 self.extractor_id.chain,
284 );
285 self.rpc_client
286 .get_component_tvl_paginated(&body, 100, 4)
287 .await?
288 .tvl
289 } else {
290 HashMap::new()
291 };
292
293 let entrypoints_result = if self.extractor_id.chain == Chain::Ethereum {
296 let result = self
298 .rpc_client
299 .get_traced_entry_points_paginated(
300 self.extractor_id.chain,
301 &self.extractor_id.name,
302 &component_ids,
303 100,
304 4,
305 )
306 .await?;
307 self.component_tracker
308 .process_entrypoints(&result.clone().into());
309 Some(result)
310 } else {
311 None
312 };
313
314 let mut protocol_states = self
316 .rpc_client
317 .get_protocol_states_paginated(
318 self.extractor_id.chain,
319 &component_ids,
320 &self.extractor_id.name,
321 self.retrieve_balances,
322 &version,
323 100,
324 4,
325 )
326 .await?
327 .states
328 .into_iter()
329 .map(|state| (state.component_id.clone(), state))
330 .collect::<HashMap<_, _>>();
331
332 trace!(states=?&protocol_states, "Retrieved ProtocolStates");
333 let states = self
334 .component_tracker
335 .components
336 .values()
337 .filter_map(|component| {
338 if let Some(state) = protocol_states.remove(&component.id) {
339 Some((
340 component.id.clone(),
341 ComponentWithState {
342 state,
343 component: component.clone(),
344 component_tvl: component_tvl
345 .get(&component.id)
346 .cloned(),
347 entrypoints: entrypoints_result
348 .as_ref()
349 .map(|r| {
350 r.traced_entry_points
351 .get(&component.id)
352 .cloned()
353 .unwrap_or_default()
354 })
355 .unwrap_or_default(),
356 },
357 ))
358 } else if component_ids.contains(&component.id) {
359 let component_id = &component.id;
361 error!(?component_id, "Missing state for native component!");
362 None
363 } else {
364 None
365 }
366 })
367 .collect();
368
369 let contract_ids = self
371 .component_tracker
372 .get_contracts_by_component(&component_ids);
373 let vm_storage = if !contract_ids.is_empty() {
374 let ids: Vec<Bytes> = contract_ids
375 .clone()
376 .into_iter()
377 .collect();
378 let contract_states = self
379 .rpc_client
380 .get_contract_state_paginated(
381 self.extractor_id.chain,
382 ids.as_slice(),
383 &self.extractor_id.name,
384 &version,
385 100,
386 4,
387 )
388 .await?
389 .accounts
390 .into_iter()
391 .map(|acc| (acc.address.clone(), acc))
392 .collect::<HashMap<_, _>>();
393
394 trace!(states=?&contract_states, "Retrieved ContractState");
395
396 let contract_address_to_components = self
397 .component_tracker
398 .components
399 .iter()
400 .filter_map(|(id, comp)| {
401 if component_ids.contains(id) {
402 Some(
403 comp.contract_ids
404 .iter()
405 .map(|address| (address.clone(), comp.id.clone())),
406 )
407 } else {
408 None
409 }
410 })
411 .flatten()
412 .fold(HashMap::<Bytes, Vec<String>>::new(), |mut acc, (addr, c_id)| {
413 acc.entry(addr).or_default().push(c_id);
414 acc
415 });
416
417 contract_ids
418 .iter()
419 .filter_map(|address| {
420 if let Some(state) = contract_states.get(address) {
421 Some((address.clone(), state.clone()))
422 } else if let Some(ids) = contract_address_to_components.get(address) {
423 error!(
425 ?address,
426 ?ids,
427 "Component with lacking contract storage encountered!"
428 );
429 None
430 } else {
431 None
432 }
433 })
434 .collect()
435 } else {
436 HashMap::new()
437 };
438
439 Ok(StateSyncMessage {
440 header,
441 snapshots: Snapshot { states, vm_storage },
442 deltas: None,
443 removed_components: HashMap::new(),
444 })
445 }
446
447 #[instrument(skip(self, block_tx, end_rx), fields(extractor_id = %self.extractor_id))]
460 async fn state_sync(
461 &mut self,
462 block_tx: &mut Sender<StateSyncMessage<BlockHeader>>,
463 mut end_rx: oneshot::Receiver<()>,
464 ) -> Result<(), (SynchronizerError, Option<oneshot::Receiver<()>>)> {
465 let subscription_options = SubscriptionOptions::new().with_state(self.include_snapshots);
468 let (subscription_id, mut msg_rx) = match self
469 .deltas_client
470 .subscribe(self.extractor_id.clone(), subscription_options)
471 .await
472 {
473 Ok(result) => result,
474 Err(e) => return Err((e.into(), Some(end_rx))),
475 };
476
477 let result = async {
478 info!("Waiting for deltas...");
479 let mut first_msg = select! {
481 deltas_result = timeout(Duration::from_secs(self.timeout), msg_rx.recv()) => {
482 deltas_result
483 .map_err(|_| {
484 SynchronizerError::Timeout(format!(
485 "First deltas took longer than {t}s to arrive",
486 t = self.timeout
487 ))
488 })?
489 .ok_or_else(|| {
490 SynchronizerError::ConnectionError(
491 "Deltas channel closed before first message".to_string(),
492 )
493 })?
494 },
495 _ = &mut end_rx => {
496 info!("Received close signal while waiting for first deltas");
497 return Ok(());
498 }
499 };
500 self.filter_deltas(&mut first_msg);
501
502 let block = first_msg.get_block().clone();
504 info!(height = &block.number, "Deltas received. Retrieving snapshot");
505 let header = BlockHeader::from_block(first_msg.get_block(), first_msg.is_revert());
506 let snapshot = self
507 .get_snapshots::<Vec<&String>>(
508 BlockHeader::from_block(&block, false),
509 None,
510 )
511 .await?
512 .merge(StateSyncMessage {
513 header: BlockHeader::from_block(first_msg.get_block(), first_msg.is_revert()),
514 snapshots: Default::default(),
515 deltas: Some(first_msg),
516 removed_components: Default::default(),
517 });
518
519 let n_components = self.component_tracker.components.len();
520 let n_snapshots = snapshot.snapshots.states.len();
521 info!(n_components, n_snapshots, "Initial snapshot retrieved, starting delta message feed");
522
523 block_tx.send(snapshot).await?;
524 self.last_synced_block = Some(header.clone());
525 loop {
526 select! {
527 deltas_opt = msg_rx.recv() => {
528 if let Some(mut deltas) = deltas_opt {
529 let header = BlockHeader::from_block(deltas.get_block(), deltas.is_revert());
530 debug!(block_number=?header.number, "Received delta message");
531
532 let (snapshots, removed_components) = {
533 let (to_add, to_remove) = self.component_tracker.filter_updated_components(&deltas);
536
537 let requiring_snapshot: Vec<_> = to_add
539 .iter()
540 .filter(|id| {
541 !self.component_tracker
542 .components
543 .contains_key(id.as_str())
544 })
545 .collect();
546 debug!(components=?requiring_snapshot, "SnapshotRequest");
547 self.component_tracker
548 .start_tracking(requiring_snapshot.as_slice())
549 .await?;
550 let snapshots = self
551 .get_snapshots(header.clone(), Some(requiring_snapshot))
552 .await?
553 .snapshots;
554
555 let removed_components = if !to_remove.is_empty() {
556 self.component_tracker.stop_tracking(&to_remove)
557 } else {
558 Default::default()
559 };
560
561 (snapshots, removed_components)
562 };
563
564 self.component_tracker.process_entrypoints(&deltas.dci_update);
566
567 self.filter_deltas(&mut deltas);
569 let n_changes = deltas.n_changes();
570
571 let next = StateSyncMessage {
573 header: header.clone(),
574 snapshots,
575 deltas: Some(deltas),
576 removed_components,
577 };
578 block_tx.send(next).await?;
579 self.last_synced_block = Some(header.clone());
580
581 debug!(block_number=?header.number, n_changes, "Finished processing delta message");
582 } else {
583 return Err(SynchronizerError::ConnectionError("Deltas channel closed".to_string()));
584 }
585 },
586 _ = &mut end_rx => {
587 info!("Received close signal during state_sync");
588 return Ok(());
589 }
590 }
591 }
592 }.await;
593
594 warn!(last_synced_block = ?&self.last_synced_block, "Deltas processing ended, resetting last synced block.");
596 self.last_synced_block = None;
597 let _ = self
599 .deltas_client
600 .unsubscribe(subscription_id)
601 .await
602 .map_err(|err| {
603 warn!(err=?err, "Unsubscribing from deltas on cleanup failed!");
604 });
605
606 match result {
609 Ok(()) => Ok(()), Err(e) => {
611 Err((e, Some(end_rx)))
615 }
616 }
617 }
618
619 fn filter_deltas(&self, deltas: &mut BlockChanges) {
620 deltas.filter_by_component(|id| {
621 self.component_tracker
622 .components
623 .contains_key(id)
624 });
625 deltas.filter_by_contract(|id| {
626 self.component_tracker
627 .contracts
628 .contains(id)
629 });
630 }
631}
632
633#[async_trait]
634impl<R, D> StateSynchronizer for ProtocolStateSynchronizer<R, D>
635where
636 R: RPCClient + Clone + Send + Sync + 'static,
637 D: DeltasClient + Clone + Send + Sync + 'static,
638{
639 async fn initialize(&mut self) -> SyncResult<()> {
640 info!("Retrieving relevant protocol components");
641 self.component_tracker
642 .initialise_components()
643 .await?;
644 info!(
645 n_components = self.component_tracker.components.len(),
646 n_contracts = self.component_tracker.contracts.len(),
647 "Finished retrieving components",
648 );
649
650 Ok(())
651 }
652
653 async fn start(
654 mut self,
655 ) -> SyncResult<(SynchronizerTaskHandle, Receiver<StateSyncMessage<BlockHeader>>)> {
656 let (mut tx, rx) = channel(15);
657 let (end_tx, end_rx) = oneshot::channel::<()>();
658
659 let jh = tokio::spawn(async move {
660 let mut retry_count = 0;
661 let mut current_end_rx = end_rx;
662
663 while retry_count < self.max_retries {
664 info!(extractor_id=%&self.extractor_id, retry_count, "(Re)starting synchronization loop");
665
666 let res = self
667 .state_sync(&mut tx, current_end_rx)
668 .await;
669 match res {
670 Ok(()) => {
671 info!(
672 extractor_id=%&self.extractor_id,
673 retry_count,
674 "State synchronization exited cleanly"
675 );
676 return Ok(());
677 }
678 Err((e, maybe_end_rx)) => {
679 error!(
680 extractor_id=%&self.extractor_id,
681 retry_count,
682 error=%e,
683 "State synchronization errored!"
684 );
685
686 if let Some(recovered_end_rx) = maybe_end_rx {
688 current_end_rx = recovered_end_rx;
689
690 if let SynchronizerError::ConnectionClosed = e {
691 return Err(e);
693 }
694 } else {
695 info!(extractor_id=%&self.extractor_id, "Received close signal, exiting");
697 return Ok(());
698 }
699 }
700 }
701 retry_count += 1;
702 }
703 warn!(extractor_id=%&self.extractor_id, retry_count, "Max retries exceeded");
704 Err(SynchronizerError::ConnectionError("Max connection retries exceeded".to_string()))
705 });
706
707 let handle = SynchronizerTaskHandle::new(jh, end_tx);
708 Ok((handle, rx))
709 }
710}
711
712#[cfg(test)]
713mod test {
714 use std::{collections::HashSet, sync::Arc};
733
734 use test_log::test;
735 use tycho_common::dto::{
736 Block, Chain, ComponentTvlRequestBody, ComponentTvlRequestResponse, DCIUpdate, EntryPoint,
737 PaginationResponse, ProtocolComponentRequestResponse, ProtocolComponentsRequestBody,
738 ProtocolStateRequestBody, ProtocolStateRequestResponse, ProtocolSystemsRequestBody,
739 ProtocolSystemsRequestResponse, RPCTracerParams, StateRequestBody, StateRequestResponse,
740 TokensRequestBody, TokensRequestResponse, TracedEntryPointRequestBody,
741 TracedEntryPointRequestResponse, TracingParams,
742 };
743 use uuid::Uuid;
744
745 use super::*;
746 use crate::{deltas::MockDeltasClient, rpc::MockRPCClient, DeltasError, RPCError};
747
748 struct ArcRPCClient<T>(Arc<T>);
750
751 impl<T> Clone for ArcRPCClient<T> {
753 fn clone(&self) -> Self {
754 ArcRPCClient(self.0.clone())
755 }
756 }
757
758 #[async_trait]
759 impl<T> RPCClient for ArcRPCClient<T>
760 where
761 T: RPCClient + Sync + Send + 'static,
762 {
763 async fn get_tokens(
764 &self,
765 request: &TokensRequestBody,
766 ) -> Result<TokensRequestResponse, RPCError> {
767 self.0.get_tokens(request).await
768 }
769
770 async fn get_contract_state(
771 &self,
772 request: &StateRequestBody,
773 ) -> Result<StateRequestResponse, RPCError> {
774 self.0.get_contract_state(request).await
775 }
776
777 async fn get_protocol_components(
778 &self,
779 request: &ProtocolComponentsRequestBody,
780 ) -> Result<ProtocolComponentRequestResponse, RPCError> {
781 self.0
782 .get_protocol_components(request)
783 .await
784 }
785
786 async fn get_protocol_states(
787 &self,
788 request: &ProtocolStateRequestBody,
789 ) -> Result<ProtocolStateRequestResponse, RPCError> {
790 self.0
791 .get_protocol_states(request)
792 .await
793 }
794
795 async fn get_protocol_systems(
796 &self,
797 request: &ProtocolSystemsRequestBody,
798 ) -> Result<ProtocolSystemsRequestResponse, RPCError> {
799 self.0
800 .get_protocol_systems(request)
801 .await
802 }
803
804 async fn get_component_tvl(
805 &self,
806 request: &ComponentTvlRequestBody,
807 ) -> Result<ComponentTvlRequestResponse, RPCError> {
808 self.0.get_component_tvl(request).await
809 }
810
811 async fn get_traced_entry_points(
812 &self,
813 request: &TracedEntryPointRequestBody,
814 ) -> Result<TracedEntryPointRequestResponse, RPCError> {
815 self.0
816 .get_traced_entry_points(request)
817 .await
818 }
819 }
820
821 struct ArcDeltasClient<T>(Arc<T>);
823
824 impl<T> Clone for ArcDeltasClient<T> {
826 fn clone(&self) -> Self {
827 ArcDeltasClient(self.0.clone())
828 }
829 }
830
831 #[async_trait]
832 impl<T> DeltasClient for ArcDeltasClient<T>
833 where
834 T: DeltasClient + Sync + Send + 'static,
835 {
836 async fn subscribe(
837 &self,
838 extractor_id: ExtractorIdentity,
839 options: SubscriptionOptions,
840 ) -> Result<(Uuid, Receiver<BlockChanges>), DeltasError> {
841 self.0
842 .subscribe(extractor_id, options)
843 .await
844 }
845
846 async fn unsubscribe(&self, subscription_id: Uuid) -> Result<(), DeltasError> {
847 self.0
848 .unsubscribe(subscription_id)
849 .await
850 }
851
852 async fn connect(&self) -> Result<JoinHandle<Result<(), DeltasError>>, DeltasError> {
853 self.0.connect().await
854 }
855
856 async fn close(&self) -> Result<(), DeltasError> {
857 self.0.close().await
858 }
859 }
860
861 fn with_mocked_clients(
862 native: bool,
863 include_tvl: bool,
864 rpc_client: Option<MockRPCClient>,
865 deltas_client: Option<MockDeltasClient>,
866 ) -> ProtocolStateSynchronizer<ArcRPCClient<MockRPCClient>, ArcDeltasClient<MockDeltasClient>>
867 {
868 let rpc_client = ArcRPCClient(Arc::new(rpc_client.unwrap_or_default()));
869 let deltas_client = ArcDeltasClient(Arc::new(deltas_client.unwrap_or_default()));
870
871 ProtocolStateSynchronizer::new(
872 ExtractorIdentity::new(Chain::Ethereum, "uniswap-v2"),
873 native,
874 ComponentFilter::with_tvl_range(50.0, 50.0),
875 1,
876 true,
877 include_tvl,
878 rpc_client,
879 deltas_client,
880 10_u64,
881 )
882 }
883
884 fn state_snapshot_native() -> ProtocolStateRequestResponse {
885 ProtocolStateRequestResponse {
886 states: vec![ResponseProtocolState {
887 component_id: "Component1".to_string(),
888 ..Default::default()
889 }],
890 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
891 }
892 }
893
894 fn component_tvl_snapshot() -> ComponentTvlRequestResponse {
895 let tvl = HashMap::from([("Component1".to_string(), 100.0)]);
896
897 ComponentTvlRequestResponse {
898 tvl,
899 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
900 }
901 }
902
    /// `get_snapshots` for a tracked component without contracts returns its protocol state
    /// with no TVL, no entrypoints and empty VM storage.
    #[test(tokio::test)]
    async fn test_get_snapshots_native() {
        let header = BlockHeader::default();
        let mut rpc = MockRPCClient::new();
        rpc.expect_get_protocol_states()
            .returning(|_| Ok(state_snapshot_native()));
        // The mocked extractor is on Ethereum, so traced entry points are always queried.
        rpc.expect_get_traced_entry_points()
            .returning(|_| {
                Ok(TracedEntryPointRequestResponse {
                    traced_entry_points: HashMap::new(),
                    pagination: PaginationResponse::new(0, 20, 0),
                })
            });
        // TVL is disabled, so no `get_component_tvl` expectation is registered.
        let mut state_sync = with_mocked_clients(true, false, Some(rpc), None);
        let component = ProtocolComponent { id: "Component1".to_string(), ..Default::default() };
        // Register the component directly with the tracker so its state is joined in.
        state_sync
            .component_tracker
            .components
            .insert("Component1".to_string(), component.clone());
        let components_arg = ["Component1".to_string()];
        let exp = StateSyncMessage {
            header: header.clone(),
            snapshots: Snapshot {
                states: state_snapshot_native()
                    .states
                    .into_iter()
                    .map(|state| {
                        (
                            state.component_id.clone(),
                            ComponentWithState {
                                state,
                                component: component.clone(),
                                entrypoints: vec![],
                                component_tvl: None,
                            },
                        )
                    })
                    .collect(),
                vm_storage: HashMap::new(),
            },
            deltas: None,
            removed_components: Default::default(),
        };

        let snap = state_sync
            .get_snapshots(header, Some(&components_arg))
            .await
            .expect("Retrieving snapshot failed");

        assert_eq!(snap, exp);
    }
954
    /// Same as `test_get_snapshots_native`, but with TVL enabled: the component's TVL from
    /// the mocked TVL response (100.0) is attached to its snapshot.
    #[test(tokio::test)]
    async fn test_get_snapshots_native_with_tvl() {
        let header = BlockHeader::default();
        let mut rpc = MockRPCClient::new();
        rpc.expect_get_protocol_states()
            .returning(|_| Ok(state_snapshot_native()));
        rpc.expect_get_component_tvl()
            .returning(|_| Ok(component_tvl_snapshot()));
        rpc.expect_get_traced_entry_points()
            .returning(|_| {
                Ok(TracedEntryPointRequestResponse {
                    traced_entry_points: HashMap::new(),
                    pagination: PaginationResponse::new(0, 20, 0),
                })
            });
        let mut state_sync = with_mocked_clients(true, true, Some(rpc), None);
        let component = ProtocolComponent { id: "Component1".to_string(), ..Default::default() };
        state_sync
            .component_tracker
            .components
            .insert("Component1".to_string(), component.clone());
        let components_arg = ["Component1".to_string()];
        let exp = StateSyncMessage {
            header: header.clone(),
            snapshots: Snapshot {
                states: state_snapshot_native()
                    .states
                    .into_iter()
                    .map(|state| {
                        (
                            state.component_id.clone(),
                            ComponentWithState {
                                state,
                                component: component.clone(),
                                component_tvl: Some(100.0),
                                entrypoints: vec![],
                            },
                        )
                    })
                    .collect(),
                vm_storage: HashMap::new(),
            },
            deltas: None,
            removed_components: Default::default(),
        };

        let snap = state_sync
            .get_snapshots(header, Some(&components_arg))
            .await
            .expect("Retrieving snapshot failed");

        assert_eq!(snap, exp);
    }
1008
1009 fn state_snapshot_vm() -> StateRequestResponse {
1010 StateRequestResponse {
1011 accounts: vec![
1012 ResponseAccount { address: Bytes::from("0x0badc0ffee"), ..Default::default() },
1013 ResponseAccount { address: Bytes::from("0xbabe42"), ..Default::default() },
1014 ],
1015 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1016 }
1017 }
1018
1019 fn traced_entry_point_response() -> TracedEntryPointRequestResponse {
1020 TracedEntryPointRequestResponse {
1021 traced_entry_points: HashMap::from([(
1022 "Component1".to_string(),
1023 vec![(
1024 EntryPointWithTracingParams {
1025 entry_point: EntryPoint {
1026 external_id: "entrypoint_a".to_string(),
1027 target: Bytes::from("0x0badc0ffee"),
1028 signature: "sig()".to_string(),
1029 },
1030 params: TracingParams::RPCTracer(RPCTracerParams {
1031 caller: Some(Bytes::from("0x0badc0ffee")),
1032 calldata: Bytes::from("0x0badc0ffee"),
1033 }),
1034 },
1035 TracingResult {
1036 retriggers: HashSet::from([(
1037 Bytes::from("0x0badc0ffee"),
1038 Bytes::from("0x0badc0ffee"),
1039 )]),
1040 accessed_slots: HashMap::from([(
1041 Bytes::from("0x0badc0ffee"),
1042 HashSet::from([Bytes::from("0xbadbeef0")]),
1043 )]),
1044 },
1045 )],
1046 )]),
1047 pagination: PaginationResponse::new(0, 20, 0),
1048 }
1049 }
1050
    /// `get_snapshots` for a component with contracts returns its protocol state, its
    /// traced entry points and the storage of both linked contracts.
    #[test(tokio::test)]
    async fn test_get_snapshots_vm() {
        let header = BlockHeader::default();
        let mut rpc = MockRPCClient::new();
        rpc.expect_get_protocol_states()
            .returning(|_| Ok(state_snapshot_native()));
        rpc.expect_get_contract_state()
            .returning(|_| Ok(state_snapshot_vm()));
        rpc.expect_get_traced_entry_points()
            .returning(|_| Ok(traced_entry_point_response()));
        let mut state_sync = with_mocked_clients(false, false, Some(rpc), None);
        // The component references both accounts returned by `state_snapshot_vm`.
        let component = ProtocolComponent {
            id: "Component1".to_string(),
            contract_ids: vec![Bytes::from("0x0badc0ffee"), Bytes::from("0xbabe42")],
            ..Default::default()
        };
        state_sync
            .component_tracker
            .components
            .insert("Component1".to_string(), component.clone());
        let components_arg = ["Component1".to_string()];
        // Expected entrypoints mirror `traced_entry_point_response` for `Component1`.
        let exp = StateSyncMessage {
            header: header.clone(),
            snapshots: Snapshot {
                states: [(
                    component.id.clone(),
                    ComponentWithState {
                        state: ResponseProtocolState {
                            component_id: "Component1".to_string(),
                            ..Default::default()
                        },
                        component: component.clone(),
                        component_tvl: None,
                        entrypoints: vec![(
                            EntryPointWithTracingParams {
                                entry_point: EntryPoint {
                                    external_id: "entrypoint_a".to_string(),
                                    target: Bytes::from("0x0badc0ffee"),
                                    signature: "sig()".to_string(),
                                },
                                params: TracingParams::RPCTracer(RPCTracerParams {
                                    caller: Some(Bytes::from("0x0badc0ffee")),
                                    calldata: Bytes::from("0x0badc0ffee"),
                                }),
                            },
                            TracingResult {
                                retriggers: HashSet::from([(
                                    Bytes::from("0x0badc0ffee"),
                                    Bytes::from("0x0badc0ffee"),
                                )]),
                                accessed_slots: HashMap::from([(
                                    Bytes::from("0x0badc0ffee"),
                                    HashSet::from([Bytes::from("0xbadbeef0")]),
                                )]),
                            },
                        )],
                    },
                )]
                .into_iter()
                .collect(),
                vm_storage: state_snapshot_vm()
                    .accounts
                    .into_iter()
                    .map(|state| (state.address.clone(), state))
                    .collect(),
            },
            deltas: None,
            removed_components: Default::default(),
        };

        let snap = state_sync
            .get_snapshots(header, Some(&components_arg))
            .await
            .expect("Retrieving snapshot failed");

        assert_eq!(snap, exp);
    }
1128
    /// Contract-backed component with TVL enabled: the snapshot carries the TVL (100.0) and
    /// the contracts' storage, and no entrypoints (the mocked trace response is empty).
    #[test(tokio::test)]
    async fn test_get_snapshots_vm_with_tvl() {
        let header = BlockHeader::default();
        let mut rpc = MockRPCClient::new();
        rpc.expect_get_protocol_states()
            .returning(|_| Ok(state_snapshot_native()));
        rpc.expect_get_contract_state()
            .returning(|_| Ok(state_snapshot_vm()));
        rpc.expect_get_component_tvl()
            .returning(|_| Ok(component_tvl_snapshot()));
        rpc.expect_get_traced_entry_points()
            .returning(|_| {
                Ok(TracedEntryPointRequestResponse {
                    traced_entry_points: HashMap::new(),
                    pagination: PaginationResponse::new(0, 20, 0),
                })
            });
        let mut state_sync = with_mocked_clients(false, true, Some(rpc), None);
        let component = ProtocolComponent {
            id: "Component1".to_string(),
            contract_ids: vec![Bytes::from("0x0badc0ffee"), Bytes::from("0xbabe42")],
            ..Default::default()
        };
        state_sync
            .component_tracker
            .components
            .insert("Component1".to_string(), component.clone());
        let components_arg = ["Component1".to_string()];
        let exp = StateSyncMessage {
            header: header.clone(),
            snapshots: Snapshot {
                states: [(
                    component.id.clone(),
                    ComponentWithState {
                        state: ResponseProtocolState {
                            component_id: "Component1".to_string(),
                            ..Default::default()
                        },
                        component: component.clone(),
                        component_tvl: Some(100.0),
                        entrypoints: vec![],
                    },
                )]
                .into_iter()
                .collect(),
                vm_storage: state_snapshot_vm()
                    .accounts
                    .into_iter()
                    .map(|state| (state.address.clone(), state))
                    .collect(),
            },
            deltas: None,
            removed_components: Default::default(),
        };

        let snap = state_sync
            .get_snapshots(header, Some(&components_arg))
            .await
            .expect("Retrieving snapshot failed");

        assert_eq!(snap, exp);
    }
1191
    /// Builds mocks for end-to-end `state_sync` tests.
    ///
    /// Returns the RPC mock, the deltas mock and the sender half of the deltas channel, which
    /// the test uses to push `BlockChanges` into the subscription.
    ///
    /// NOTE(review): the predicate-restricted expectations for `Component3` are registered
    /// before the catch-all ones; presumably mockall matches in registration order so the
    /// specific ones take precedence — confirm before reordering.
    fn mock_clients_for_state_sync() -> (MockRPCClient, MockDeltasClient, Sender<BlockChanges>) {
        let mut rpc_client = MockRPCClient::new();
        // Component lookups that mention `Component3` (presumably when it starts being
        // tracked mid-stream) return only `Component3`.
        rpc_client
            .expect_get_protocol_components()
            .with(mockall::predicate::function(
                move |request_params: &ProtocolComponentsRequestBody| {
                    if let Some(ids) = request_params.component_ids.as_ref() {
                        ids.contains(&"Component3".to_string())
                    } else {
                        false
                    }
                },
            ))
            .returning(|_| {
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![
                        ProtocolComponent { id: "Component3".to_string(), ..Default::default() },
                    ],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });
        // State queries that mention `Component3` return only its state.
        rpc_client
            .expect_get_protocol_states()
            .with(mockall::predicate::function(move |request_params: &ProtocolStateRequestBody| {
                let expected_id = "Component3".to_string();
                if let Some(ids) = request_params.protocol_ids.as_ref() {
                    ids.contains(&expected_id)
                } else {
                    false
                }
            }))
            .returning(|_| {
                Ok(ProtocolStateRequestResponse {
                    states: vec![ResponseProtocolState {
                        component_id: "Component3".to_string(),
                        ..Default::default()
                    }],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });

        // Catch-all component lookup: `Component1` and `Component2`.
        rpc_client
            .expect_get_protocol_components()
            .returning(|_| {
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![
                        ProtocolComponent { id: "Component1".to_string(), ..Default::default() },
                        ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
                    ],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });
        // Catch-all state query: states for `Component1` and `Component2`.
        rpc_client
            .expect_get_protocol_states()
            .returning(|_| {
                Ok(ProtocolStateRequestResponse {
                    states: vec![
                        ResponseProtocolState {
                            component_id: "Component1".to_string(),
                            ..Default::default()
                        },
                        ResponseProtocolState {
                            component_id: "Component2".to_string(),
                            ..Default::default()
                        },
                    ],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });
        // TVLs relative to the `with_tvl_range(50.0, 50.0)` filter used by
        // `with_mocked_clients`: Component1 and Component3 are above 50, Component2 below.
        rpc_client
            .expect_get_component_tvl()
            .returning(|_| {
                Ok(ComponentTvlRequestResponse {
                    tvl: [
                        ("Component1".to_string(), 100.0),
                        ("Component2".to_string(), 0.0),
                        ("Component3".to_string(), 1000.0),
                    ]
                    .into_iter()
                    .collect(),
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
                })
            });
        rpc_client
            .expect_get_traced_entry_points()
            .returning(|_| {
                Ok(TracedEntryPointRequestResponse {
                    traced_entry_points: HashMap::new(),
                    pagination: PaginationResponse::new(0, 20, 0),
                })
            });

        let mut deltas_client = MockDeltasClient::new();
        // Capacity-1 channel; the test keeps `tx` and feeds deltas through it.
        let (tx, rx) = channel(1);
        // `return_once` hands over the receiver, so only a single subscription is supported.
        deltas_client
            .expect_subscribe()
            .return_once(move |_, _| {
                Ok((Uuid::default(), rx))
            });

        // Cleanup in `state_sync` unsubscribes exactly once.
        deltas_client
            .expect_unsubscribe()
            .return_once(|_| Ok(()));

        (rpc_client, deltas_client, tx)
    }
1312
1313 #[test(tokio::test)]
1320 async fn test_state_sync() {
1321 let (rpc_client, deltas_client, tx) = mock_clients_for_state_sync();
1322 let deltas = [
1323 BlockChanges {
1324 extractor: "uniswap-v2".to_string(),
1325 chain: Chain::Ethereum,
1326 block: Block {
1327 number: 1,
1328 hash: Bytes::from("0x01"),
1329 parent_hash: Bytes::from("0x00"),
1330 chain: Chain::Ethereum,
1331 ts: Default::default(),
1332 },
1333 revert: false,
1334 dci_update: DCIUpdate {
1335 new_entrypoints: HashMap::from([(
1336 "Component1".to_string(),
1337 HashSet::from([EntryPoint {
1338 external_id: "entrypoint_a".to_string(),
1339 target: Bytes::from("0x0badc0ffee"),
1340 signature: "sig()".to_string(),
1341 }]),
1342 )]),
1343 new_entrypoint_params: HashMap::from([(
1344 "entrypoint_a".to_string(),
1345 HashSet::from([(
1346 TracingParams::RPCTracer(RPCTracerParams {
1347 caller: Some(Bytes::from("0x0badc0ffee")),
1348 calldata: Bytes::from("0x0badc0ffee"),
1349 }),
1350 Some("Component1".to_string()),
1351 )]),
1352 )]),
1353 trace_results: HashMap::from([(
1354 "entrypoint_a".to_string(),
1355 TracingResult {
1356 retriggers: HashSet::from([(
1357 Bytes::from("0x0badc0ffee"),
1358 Bytes::from("0x0badc0ffee"),
1359 )]),
1360 accessed_slots: HashMap::from([(
1361 Bytes::from("0x0badc0ffee"),
1362 HashSet::from([Bytes::from("0xbadbeef0")]),
1363 )]),
1364 },
1365 )]),
1366 },
1367 ..Default::default()
1368 },
1369 BlockChanges {
1370 extractor: "uniswap-v2".to_string(),
1371 chain: Chain::Ethereum,
1372 block: Block {
1373 number: 2,
1374 hash: Bytes::from("0x02"),
1375 parent_hash: Bytes::from("0x01"),
1376 chain: Chain::Ethereum,
1377 ts: Default::default(),
1378 },
1379 revert: false,
1380 component_tvl: [
1381 ("Component1".to_string(), 100.0),
1382 ("Component2".to_string(), 0.0),
1383 ("Component3".to_string(), 1000.0),
1384 ]
1385 .into_iter()
1386 .collect(),
1387 ..Default::default()
1388 },
1389 ];
1390 let mut state_sync = with_mocked_clients(true, true, Some(rpc_client), Some(deltas_client));
1391 state_sync
1392 .initialize()
1393 .await
1394 .expect("Init failed");
1395
1396 let (handle, mut rx) = state_sync
1398 .start()
1399 .await
1400 .expect("Failed to start state synchronizer");
1401 let (jh, close_tx) = handle.split();
1402 tx.send(deltas[0].clone())
1403 .await
1404 .expect("deltas channel msg 0 closed!");
1405 let first_msg = timeout(Duration::from_millis(100), rx.recv())
1406 .await
1407 .expect("waiting for first state msg timed out!")
1408 .expect("state sync block sender closed!");
1409 tx.send(deltas[1].clone())
1410 .await
1411 .expect("deltas channel msg 1 closed!");
1412 let second_msg = timeout(Duration::from_millis(100), rx.recv())
1413 .await
1414 .expect("waiting for second state msg timed out!")
1415 .expect("state sync block sender closed!");
1416 let _ = close_tx.send(());
1417 let exit = jh
1418 .await
1419 .expect("state sync task panicked!");
1420
1421 let exp1 = StateSyncMessage {
1423 header: BlockHeader {
1424 number: 1,
1425 hash: Bytes::from("0x01"),
1426 parent_hash: Bytes::from("0x00"),
1427 revert: false,
1428 ..Default::default()
1429 },
1430 snapshots: Snapshot {
1431 states: [
1432 (
1433 "Component1".to_string(),
1434 ComponentWithState {
1435 state: ResponseProtocolState {
1436 component_id: "Component1".to_string(),
1437 ..Default::default()
1438 },
1439 component: ProtocolComponent {
1440 id: "Component1".to_string(),
1441 ..Default::default()
1442 },
1443 component_tvl: Some(100.0),
1444 entrypoints: vec![],
1445 },
1446 ),
1447 (
1448 "Component2".to_string(),
1449 ComponentWithState {
1450 state: ResponseProtocolState {
1451 component_id: "Component2".to_string(),
1452 ..Default::default()
1453 },
1454 component: ProtocolComponent {
1455 id: "Component2".to_string(),
1456 ..Default::default()
1457 },
1458 component_tvl: Some(0.0),
1459 entrypoints: vec![],
1460 },
1461 ),
1462 ]
1463 .into_iter()
1464 .collect(),
1465 vm_storage: HashMap::new(),
1466 },
1467 deltas: Some(deltas[0].clone()),
1468 removed_components: Default::default(),
1469 };
1470
1471 let exp2 = StateSyncMessage {
1472 header: BlockHeader {
1473 number: 2,
1474 hash: Bytes::from("0x02"),
1475 parent_hash: Bytes::from("0x01"),
1476 revert: false,
1477 ..Default::default()
1478 },
1479 snapshots: Snapshot {
1480 states: [
1481 (
1483 "Component3".to_string(),
1484 ComponentWithState {
1485 state: ResponseProtocolState {
1486 component_id: "Component3".to_string(),
1487 ..Default::default()
1488 },
1489 component: ProtocolComponent {
1490 id: "Component3".to_string(),
1491 ..Default::default()
1492 },
1493 component_tvl: Some(1000.0),
1494 entrypoints: vec![],
1495 },
1496 ),
1497 ]
1498 .into_iter()
1499 .collect(),
1500 vm_storage: HashMap::new(),
1501 },
1502 deltas: Some(BlockChanges {
1505 extractor: "uniswap-v2".to_string(),
1506 chain: Chain::Ethereum,
1507 block: Block {
1508 number: 2,
1509 hash: Bytes::from("0x02"),
1510 parent_hash: Bytes::from("0x01"),
1511 chain: Chain::Ethereum,
1512 ts: Default::default(),
1513 },
1514 revert: false,
1515 component_tvl: [
1516 ("Component1".to_string(), 100.0),
1518 ("Component3".to_string(), 1000.0),
1519 ]
1520 .into_iter()
1521 .collect(),
1522 ..Default::default()
1523 }),
1524 removed_components: [(
1526 "Component2".to_string(),
1527 ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1528 )]
1529 .into_iter()
1530 .collect(),
1531 };
1532 assert_eq!(first_msg, exp1);
1533 assert_eq!(second_msg, exp2);
1534 assert!(exit.is_ok());
1535 }
1536
1537 #[test(tokio::test)]
1538 async fn test_state_sync_with_tvl_range() {
1539 let remove_tvl_threshold = 5.0;
1541 let add_tvl_threshold = 7.0;
1542
1543 let mut rpc_client = MockRPCClient::new();
1544 let mut deltas_client = MockDeltasClient::new();
1545
1546 rpc_client
1547 .expect_get_protocol_components()
1548 .with(mockall::predicate::function(
1549 move |request_params: &ProtocolComponentsRequestBody| {
1550 if let Some(ids) = request_params.component_ids.as_ref() {
1551 ids.contains(&"Component3".to_string())
1552 } else {
1553 false
1554 }
1555 },
1556 ))
1557 .returning(|_| {
1558 Ok(ProtocolComponentRequestResponse {
1559 protocol_components: vec![ProtocolComponent {
1560 id: "Component3".to_string(),
1561 ..Default::default()
1562 }],
1563 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1564 })
1565 });
1566 rpc_client
1567 .expect_get_protocol_states()
1568 .with(mockall::predicate::function(move |request_params: &ProtocolStateRequestBody| {
1569 let expected_id = "Component3".to_string();
1570 if let Some(ids) = request_params.protocol_ids.as_ref() {
1571 ids.contains(&expected_id)
1572 } else {
1573 false
1574 }
1575 }))
1576 .returning(|_| {
1577 Ok(ProtocolStateRequestResponse {
1578 states: vec![ResponseProtocolState {
1579 component_id: "Component3".to_string(),
1580 ..Default::default()
1581 }],
1582 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1583 })
1584 });
1585
1586 rpc_client
1588 .expect_get_protocol_components()
1589 .returning(|_| {
1590 Ok(ProtocolComponentRequestResponse {
1591 protocol_components: vec![
1592 ProtocolComponent { id: "Component1".to_string(), ..Default::default() },
1593 ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1594 ],
1595 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1596 })
1597 });
1598 rpc_client
1599 .expect_get_protocol_states()
1600 .returning(|_| {
1601 Ok(ProtocolStateRequestResponse {
1602 states: vec![
1603 ResponseProtocolState {
1604 component_id: "Component1".to_string(),
1605 ..Default::default()
1606 },
1607 ResponseProtocolState {
1608 component_id: "Component2".to_string(),
1609 ..Default::default()
1610 },
1611 ],
1612 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1613 })
1614 });
1615 rpc_client
1616 .expect_get_traced_entry_points()
1617 .returning(|_| {
1618 Ok(TracedEntryPointRequestResponse {
1619 traced_entry_points: HashMap::new(),
1620 pagination: PaginationResponse::new(0, 20, 0),
1621 })
1622 });
1623
1624 rpc_client
1625 .expect_get_component_tvl()
1626 .returning(|_| {
1627 Ok(ComponentTvlRequestResponse {
1628 tvl: [
1629 ("Component1".to_string(), 6.0),
1630 ("Component2".to_string(), 2.0),
1631 ("Component3".to_string(), 10.0),
1632 ]
1633 .into_iter()
1634 .collect(),
1635 pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
1636 })
1637 });
1638
1639 rpc_client
1640 .expect_get_component_tvl()
1641 .returning(|_| {
1642 Ok(ComponentTvlRequestResponse {
1643 tvl: [
1644 ("Component1".to_string(), 6.0),
1645 ("Component2".to_string(), 2.0),
1646 ("Component3".to_string(), 10.0),
1647 ]
1648 .into_iter()
1649 .collect(),
1650 pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
1651 })
1652 });
1653
1654 let (tx, rx) = channel(1);
1655 deltas_client
1656 .expect_subscribe()
1657 .return_once(move |_, _| Ok((Uuid::default(), rx)));
1658
1659 deltas_client
1661 .expect_unsubscribe()
1662 .return_once(|_| Ok(()));
1663
1664 let mut state_sync = ProtocolStateSynchronizer::new(
1665 ExtractorIdentity::new(Chain::Ethereum, "uniswap-v2"),
1666 true,
1667 ComponentFilter::with_tvl_range(remove_tvl_threshold, add_tvl_threshold),
1668 1,
1669 true,
1670 true,
1671 ArcRPCClient(Arc::new(rpc_client)),
1672 ArcDeltasClient(Arc::new(deltas_client)),
1673 10_u64,
1674 );
1675 state_sync
1676 .initialize()
1677 .await
1678 .expect("Init failed");
1679
1680 let deltas = [
1682 BlockChanges {
1683 extractor: "uniswap-v2".to_string(),
1684 chain: Chain::Ethereum,
1685 block: Block {
1686 number: 1,
1687 hash: Bytes::from("0x01"),
1688 parent_hash: Bytes::from("0x00"),
1689 chain: Chain::Ethereum,
1690 ts: Default::default(),
1691 },
1692 revert: false,
1693 ..Default::default()
1694 },
1695 BlockChanges {
1696 extractor: "uniswap-v2".to_string(),
1697 chain: Chain::Ethereum,
1698 block: Block {
1699 number: 2,
1700 hash: Bytes::from("0x02"),
1701 parent_hash: Bytes::from("0x01"),
1702 chain: Chain::Ethereum,
1703 ts: Default::default(),
1704 },
1705 revert: false,
1706 component_tvl: [
1707 ("Component1".to_string(), 6.0), ("Component2".to_string(), 2.0), ("Component3".to_string(), 10.0), ]
1711 .into_iter()
1712 .collect(),
1713 ..Default::default()
1714 },
1715 ];
1716
1717 let (handle, mut rx) = state_sync
1718 .start()
1719 .await
1720 .expect("Failed to start state synchronizer");
1721 let (jh, close_tx) = handle.split();
1722
1723 tx.send(deltas[0].clone())
1725 .await
1726 .expect("deltas channel msg 0 closed!");
1727
1728 let _ = timeout(Duration::from_millis(100), rx.recv())
1730 .await
1731 .expect("waiting for first state msg timed out!")
1732 .expect("state sync block sender closed!");
1733
1734 tx.send(deltas[1].clone())
1736 .await
1737 .expect("deltas channel msg 1 closed!");
1738 let second_msg = timeout(Duration::from_millis(100), rx.recv())
1739 .await
1740 .expect("waiting for second state msg timed out!")
1741 .expect("state sync block sender closed!");
1742
1743 let _ = close_tx.send(());
1744 let exit = jh
1745 .await
1746 .expect("state sync task panicked!");
1747
1748 let expected_second_msg = StateSyncMessage {
1749 header: BlockHeader {
1750 number: 2,
1751 hash: Bytes::from("0x02"),
1752 parent_hash: Bytes::from("0x01"),
1753 revert: false,
1754 ..Default::default()
1755 },
1756 snapshots: Snapshot {
1757 states: [(
1758 "Component3".to_string(),
1759 ComponentWithState {
1760 state: ResponseProtocolState {
1761 component_id: "Component3".to_string(),
1762 ..Default::default()
1763 },
1764 component: ProtocolComponent {
1765 id: "Component3".to_string(),
1766 ..Default::default()
1767 },
1768 component_tvl: Some(10.0),
1769 entrypoints: vec![], },
1771 )]
1772 .into_iter()
1773 .collect(),
1774 vm_storage: HashMap::new(),
1775 },
1776 deltas: Some(BlockChanges {
1777 extractor: "uniswap-v2".to_string(),
1778 chain: Chain::Ethereum,
1779 block: Block {
1780 number: 2,
1781 hash: Bytes::from("0x02"),
1782 parent_hash: Bytes::from("0x01"),
1783 chain: Chain::Ethereum,
1784 ts: Default::default(),
1785 },
1786 revert: false,
1787 component_tvl: [
1788 ("Component1".to_string(), 6.0), ("Component3".to_string(), 10.0), ]
1791 .into_iter()
1792 .collect(),
1793 ..Default::default()
1794 }),
1795 removed_components: [(
1796 "Component2".to_string(),
1797 ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1798 )]
1799 .into_iter()
1800 .collect(),
1801 };
1802
1803 assert_eq!(second_msg, expected_second_msg);
1804 assert!(exit.is_ok());
1805 }
1806
1807 #[test(tokio::test)]
1808 async fn test_public_close_api_functionality() {
1809 let mut rpc_client = MockRPCClient::new();
1816 let mut deltas_client = MockDeltasClient::new();
1817
1818 rpc_client
1820 .expect_get_protocol_components()
1821 .returning(|_| {
1822 Ok(ProtocolComponentRequestResponse {
1823 protocol_components: vec![],
1824 pagination: PaginationResponse { page: 0, page_size: 20, total: 0 },
1825 })
1826 });
1827
1828 let (_tx, rx) = channel(1);
1830 deltas_client
1831 .expect_subscribe()
1832 .return_once(move |_, _| Ok((Uuid::default(), rx)));
1833
1834 deltas_client
1836 .expect_unsubscribe()
1837 .return_once(|_| Ok(()));
1838
1839 let mut state_sync = ProtocolStateSynchronizer::new(
1840 ExtractorIdentity::new(Chain::Ethereum, "test-protocol"),
1841 true,
1842 ComponentFilter::with_tvl_range(0.0, 0.0),
1843 5, true,
1845 false,
1846 ArcRPCClient(Arc::new(rpc_client)),
1847 ArcDeltasClient(Arc::new(deltas_client)),
1848 10000_u64, );
1850
1851 state_sync
1852 .initialize()
1853 .await
1854 .expect("Init should succeed");
1855
1856 let (handle, _rx) = state_sync
1858 .start()
1859 .await
1860 .expect("Failed to start state synchronizer");
1861 let (jh, close_tx) = handle.split();
1862
1863 tokio::time::sleep(Duration::from_millis(100)).await;
1865
1866 close_tx
1868 .send(())
1869 .expect("Should be able to send close signal");
1870 let task_result = jh.await.expect("Task should not panic");
1872 assert!(task_result.is_ok(), "Task should exit cleanly after close: {task_result:?}");
1873 }
1874
1875 #[test(tokio::test)]
1876 async fn test_cleanup_runs_when_state_sync_processing_errors() {
1877 let mut rpc_client = MockRPCClient::new();
1882 let mut deltas_client = MockDeltasClient::new();
1883
1884 rpc_client
1886 .expect_get_protocol_components()
1887 .returning(|_| {
1888 Ok(ProtocolComponentRequestResponse {
1889 protocol_components: vec![],
1890 pagination: PaginationResponse { page: 0, page_size: 20, total: 0 },
1891 })
1892 });
1893
1894 rpc_client
1896 .expect_get_protocol_states()
1897 .returning(|_| {
1898 Err(RPCError::HttpClient("Test error during snapshot retrieval".to_string()))
1899 });
1900
1901 let (tx, rx) = channel(10);
1903 deltas_client
1904 .expect_subscribe()
1905 .return_once(move |_, _| {
1906 let delta = BlockChanges {
1908 extractor: "test".to_string(),
1909 chain: Chain::Ethereum,
1910 block: Block {
1911 hash: Bytes::from("0x0123"),
1912 number: 1,
1913 parent_hash: Bytes::from("0x0000"),
1914 chain: Chain::Ethereum,
1915 ts: chrono::NaiveDateTime::from_timestamp_opt(1234567890, 0).unwrap(),
1916 },
1917 revert: false,
1918 new_protocol_components: [(
1920 "new_component".to_string(),
1921 ProtocolComponent {
1922 id: "new_component".to_string(),
1923 protocol_system: "test_protocol".to_string(),
1924 protocol_type_name: "test".to_string(),
1925 chain: Chain::Ethereum,
1926 tokens: vec![Bytes::from("0x0badc0ffee")],
1927 contract_ids: vec![Bytes::from("0x0badc0ffee")],
1928 static_attributes: Default::default(),
1929 creation_tx: Default::default(),
1930 created_at: Default::default(),
1931 change: Default::default(),
1932 },
1933 )]
1934 .into_iter()
1935 .collect(),
1936 component_tvl: [("new_component".to_string(), 100.0)]
1937 .into_iter()
1938 .collect(),
1939 ..Default::default()
1940 };
1941
1942 tokio::spawn(async move {
1943 let _ = tx.send(delta).await;
1944 });
1946
1947 Ok((Uuid::default(), rx))
1948 });
1949
1950 deltas_client
1952 .expect_unsubscribe()
1953 .return_once(|_| Ok(()));
1954
1955 let mut state_sync = ProtocolStateSynchronizer::new(
1956 ExtractorIdentity::new(Chain::Ethereum, "test-protocol"),
1957 true,
1958 ComponentFilter::with_tvl_range(0.0, 1000.0), 1,
1960 true,
1961 false,
1962 ArcRPCClient(Arc::new(rpc_client)),
1963 ArcDeltasClient(Arc::new(deltas_client)),
1964 5000_u64,
1965 );
1966
1967 state_sync
1968 .initialize()
1969 .await
1970 .expect("Init should succeed");
1971
1972 state_sync.last_synced_block = Some(BlockHeader {
1974 hash: Bytes::from("0x0badc0ffee"),
1975 number: 42,
1976 parent_hash: Bytes::from("0xbadbeef0"),
1977 revert: false,
1978 timestamp: 123456789,
1979 });
1980
1981 let (mut block_tx, _block_rx) = channel(10);
1983
1984 let (_end_tx, end_rx) = oneshot::channel::<()>();
1986 let result = state_sync
1987 .state_sync(&mut block_tx, end_rx)
1988 .await;
1989 assert!(result.is_err(), "state_sync should have errored during processing");
1991
1992 }
1995
1996 #[test(tokio::test)]
1997 async fn test_close_signal_while_waiting_for_first_deltas() {
1998 let mut rpc_client = MockRPCClient::new();
2002 let mut deltas_client = MockDeltasClient::new();
2003
2004 rpc_client
2005 .expect_get_protocol_components()
2006 .returning(|_| {
2007 Ok(ProtocolComponentRequestResponse {
2008 protocol_components: vec![],
2009 pagination: PaginationResponse { page: 0, page_size: 20, total: 0 },
2010 })
2011 });
2012
2013 let (_tx, rx) = channel(1);
2014 deltas_client
2015 .expect_subscribe()
2016 .return_once(move |_, _| Ok((Uuid::default(), rx)));
2017
2018 deltas_client
2019 .expect_unsubscribe()
2020 .return_once(|_| Ok(()));
2021
2022 let mut state_sync = ProtocolStateSynchronizer::new(
2023 ExtractorIdentity::new(Chain::Ethereum, "test-protocol"),
2024 true,
2025 ComponentFilter::with_tvl_range(0.0, 0.0),
2026 1,
2027 true,
2028 false,
2029 ArcRPCClient(Arc::new(rpc_client)),
2030 ArcDeltasClient(Arc::new(deltas_client)),
2031 10000_u64,
2032 );
2033
2034 state_sync
2035 .initialize()
2036 .await
2037 .expect("Init should succeed");
2038
2039 let (mut block_tx, _block_rx) = channel(10);
2040 let (end_tx, end_rx) = oneshot::channel::<()>();
2041
2042 let state_sync_handle = tokio::spawn(async move {
2044 state_sync
2045 .state_sync(&mut block_tx, end_rx)
2046 .await
2047 });
2048
2049 tokio::time::sleep(Duration::from_millis(100)).await;
2051
2052 let _ = end_tx.send(());
2054
2055 let result = state_sync_handle
2057 .await
2058 .expect("Task should not panic");
2059 assert!(result.is_ok(), "state_sync should exit cleanly when closed: {result:?}");
2060
2061 println!("SUCCESS: Close signal handled correctly while waiting for first deltas");
2062 }
2063
2064 #[test(tokio::test)]
2065 async fn test_close_signal_during_main_processing_loop() {
2066 let mut rpc_client = MockRPCClient::new();
2072 let mut deltas_client = MockDeltasClient::new();
2073
2074 rpc_client
2076 .expect_get_protocol_components()
2077 .returning(|_| {
2078 Ok(ProtocolComponentRequestResponse {
2079 protocol_components: vec![],
2080 pagination: PaginationResponse { page: 0, page_size: 20, total: 0 },
2081 })
2082 });
2083
2084 rpc_client
2086 .expect_get_protocol_states()
2087 .returning(|_| {
2088 Ok(ProtocolStateRequestResponse {
2089 states: vec![],
2090 pagination: PaginationResponse { page: 0, page_size: 20, total: 0 },
2091 })
2092 });
2093
2094 rpc_client
2095 .expect_get_component_tvl()
2096 .returning(|_| {
2097 Ok(ComponentTvlRequestResponse {
2098 tvl: HashMap::new(),
2099 pagination: PaginationResponse { page: 0, page_size: 20, total: 0 },
2100 })
2101 });
2102
2103 rpc_client
2104 .expect_get_traced_entry_points()
2105 .returning(|_| {
2106 Ok(TracedEntryPointRequestResponse {
2107 traced_entry_points: HashMap::new(),
2108 pagination: PaginationResponse::new(0, 20, 0),
2109 })
2110 });
2111
2112 let (tx, rx) = channel(10);
2114 deltas_client
2115 .expect_subscribe()
2116 .return_once(move |_, _| {
2117 let first_delta = BlockChanges {
2119 extractor: "test".to_string(),
2120 chain: Chain::Ethereum,
2121 block: Block {
2122 hash: Bytes::from("0x0123"),
2123 number: 1,
2124 parent_hash: Bytes::from("0x0000"),
2125 chain: Chain::Ethereum,
2126 ts: chrono::NaiveDateTime::from_timestamp_opt(1234567890, 0).unwrap(),
2127 },
2128 revert: false,
2129 ..Default::default()
2130 };
2131
2132 tokio::spawn(async move {
2133 let _ = tx.send(first_delta).await;
2134 tokio::time::sleep(Duration::from_secs(30)).await;
2137 });
2138
2139 Ok((Uuid::default(), rx))
2140 });
2141
2142 deltas_client
2143 .expect_unsubscribe()
2144 .return_once(|_| Ok(()));
2145
2146 let mut state_sync = ProtocolStateSynchronizer::new(
2147 ExtractorIdentity::new(Chain::Ethereum, "test-protocol"),
2148 true,
2149 ComponentFilter::with_tvl_range(0.0, 1000.0),
2150 1,
2151 true,
2152 false,
2153 ArcRPCClient(Arc::new(rpc_client)),
2154 ArcDeltasClient(Arc::new(deltas_client)),
2155 10000_u64,
2156 );
2157
2158 state_sync
2159 .initialize()
2160 .await
2161 .expect("Init should succeed");
2162
2163 let (mut block_tx, mut block_rx) = channel(10);
2164 let (end_tx, end_rx) = oneshot::channel::<()>();
2165
2166 let state_sync_handle = tokio::spawn(async move {
2168 state_sync
2169 .state_sync(&mut block_tx, end_rx)
2170 .await
2171 });
2172
2173 let first_snapshot = block_rx
2175 .recv()
2176 .await
2177 .expect("Should receive first snapshot");
2178 assert!(
2179 !first_snapshot
2180 .snapshots
2181 .states
2182 .is_empty() ||
2183 first_snapshot.deltas.is_some()
2184 );
2185 let _ = end_tx.send(());
2187
2188 let result = state_sync_handle
2190 .await
2191 .expect("Task should not panic");
2192 assert!(
2193 result.is_ok(),
2194 "state_sync should exit cleanly when closed after first message: {result:?}"
2195 );
2196 println!("SUCCESS: Close signal handled correctly during main processing loop");
2197 }
2198}