1use std::{
2 collections::{HashMap, HashSet},
3 sync::Arc,
4 time::Duration,
5};
6
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10use tokio::{
11 select,
12 sync::{
13 mpsc::{channel, error::SendError, Receiver, Sender},
14 oneshot, Mutex,
15 },
16 task::JoinHandle,
17 time::timeout,
18};
19use tracing::{debug, error, info, instrument, trace, warn};
20use tycho_common::{
21 dto::{
22 BlockChanges, BlockParam, ComponentTvlRequestBody, ExtractorIdentity, ProtocolComponent,
23 ResponseAccount, ResponseProtocolState, VersionParam,
24 },
25 Bytes,
26};
27
28use crate::{
29 deltas::{DeltasClient, SubscriptionOptions},
30 feed::{
31 component_tracker::{ComponentFilter, ComponentTracker},
32 Header,
33 },
34 rpc::{RPCClient, RPCError},
35 DeltasError,
36};
37
38#[derive(Error, Debug)]
39pub enum SynchronizerError {
40 #[error("RPC error: {0}")]
42 RPCError(#[from] RPCError),
43
44 #[error("Failed to send channel message: {0}")]
46 ChannelError(String),
47
48 #[error("Timeout error: {0}")]
50 Timeout(String),
51
52 #[error("Failed to close synchronizer: {0}")]
54 CloseError(String),
55
56 #[error("Connection error: {0}")]
58 ConnectionError(String),
59
60 #[error("Connection closed")]
62 ConnectionClosed,
63}
64
65pub type SyncResult<T> = Result<T, SynchronizerError>;
66
67impl From<SendError<StateSyncMessage>> for SynchronizerError {
68 fn from(err: SendError<StateSyncMessage>) -> Self {
69 SynchronizerError::ChannelError(err.to_string())
70 }
71}
72
73impl From<DeltasError> for SynchronizerError {
74 fn from(err: DeltasError) -> Self {
75 match err {
76 DeltasError::NotConnected => SynchronizerError::ConnectionClosed,
77 _ => SynchronizerError::ConnectionError(err.to_string()),
78 }
79 }
80}
81
82#[derive(Clone)]
83pub struct ProtocolStateSynchronizer<R: RPCClient, D: DeltasClient> {
84 extractor_id: ExtractorIdentity,
85 retrieve_balances: bool,
86 rpc_client: R,
87 deltas_client: D,
88 max_retries: u64,
89 include_snapshots: bool,
90 component_tracker: Arc<Mutex<ComponentTracker<R>>>,
91 shared: Arc<Mutex<SharedState>>,
92 end_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
93 timeout: u64,
94 include_tvl: bool,
95}
96
97#[derive(Debug, Default)]
98struct SharedState {
99 last_synced_block: Option<Header>,
100}
101
102#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
103pub struct ComponentWithState {
104 pub state: ResponseProtocolState,
105 pub component: ProtocolComponent,
106 pub component_tvl: Option<f64>,
107}
108
109#[derive(Clone, PartialEq, Debug, Default, Serialize, Deserialize)]
110pub struct Snapshot {
111 pub states: HashMap<String, ComponentWithState>,
112 pub vm_storage: HashMap<Bytes, ResponseAccount>,
113}
114
115impl Snapshot {
116 fn extend(&mut self, other: Snapshot) {
117 self.states.extend(other.states);
118 self.vm_storage.extend(other.vm_storage);
119 }
120
121 pub fn get_states(&self) -> &HashMap<String, ComponentWithState> {
122 &self.states
123 }
124
125 pub fn get_vm_storage(&self) -> &HashMap<Bytes, ResponseAccount> {
126 &self.vm_storage
127 }
128}
129
130#[derive(Clone, PartialEq, Debug, Default, Serialize, Deserialize)]
131pub struct StateSyncMessage {
132 pub header: Header,
134 pub snapshots: Snapshot,
136 pub deltas: Option<BlockChanges>,
140 pub removed_components: HashMap<String, ProtocolComponent>,
142}
143
144impl StateSyncMessage {
145 pub fn merge(mut self, other: Self) -> Self {
146 self.removed_components
148 .retain(|k, _| !other.snapshots.states.contains_key(k));
149 self.snapshots
150 .states
151 .retain(|k, _| !other.removed_components.contains_key(k));
152
153 self.snapshots.extend(other.snapshots);
154 let deltas = match (self.deltas, other.deltas) {
155 (Some(l), Some(r)) => Some(l.merge(r)),
156 (None, Some(r)) => Some(r),
157 (Some(l), None) => Some(l),
158 (None, None) => None,
159 };
160 self.removed_components
161 .extend(other.removed_components);
162 Self {
163 header: other.header,
164 snapshots: self.snapshots,
165 deltas,
166 removed_components: self.removed_components,
167 }
168 }
169}
170
171#[async_trait]
180pub trait StateSynchronizer: Send + Sync + 'static {
181 async fn initialize(&self) -> SyncResult<()>;
182 async fn start(&self) -> SyncResult<(JoinHandle<SyncResult<()>>, Receiver<StateSyncMessage>)>;
184 async fn close(&mut self) -> SyncResult<()>;
186}
187
188impl<R, D> ProtocolStateSynchronizer<R, D>
189where
190 R: RPCClient + Clone + Send + Sync + 'static,
193 D: DeltasClient + Clone + Send + Sync + 'static,
194{
195 #[allow(clippy::too_many_arguments)]
197 pub fn new(
198 extractor_id: ExtractorIdentity,
199 retrieve_balances: bool,
200 component_filter: ComponentFilter,
201 max_retries: u64,
202 include_snapshots: bool,
203 include_tvl: bool,
204 rpc_client: R,
205 deltas_client: D,
206 timeout: u64,
207 ) -> Self {
208 Self {
209 extractor_id: extractor_id.clone(),
210 retrieve_balances,
211 rpc_client: rpc_client.clone(),
212 include_snapshots,
213 deltas_client,
214 component_tracker: Arc::new(Mutex::new(ComponentTracker::new(
215 extractor_id.chain,
216 extractor_id.name.as_str(),
217 component_filter,
218 rpc_client,
219 ))),
220 max_retries,
221 shared: Arc::new(Mutex::new(SharedState::default())),
222 end_tx: Arc::new(Mutex::new(None)),
223 timeout,
224 include_tvl,
225 }
226 }
227
228 #[allow(deprecated)]
236 async fn get_snapshots<'a, I: IntoIterator<Item = &'a String>>(
237 &self,
238 header: Header,
239 tracked_components: &ComponentTracker<R>,
240 ids: Option<I>,
241 ) -> SyncResult<StateSyncMessage> {
242 if !self.include_snapshots {
243 return Ok(StateSyncMessage { header, ..Default::default() });
244 }
245 let version = VersionParam::new(
246 None,
247 Some(BlockParam {
248 chain: Some(self.extractor_id.chain),
249 hash: None,
250 number: Some(header.number as i64),
251 }),
252 );
253
254 let request_ids = ids
256 .map(|it| {
257 it.into_iter()
258 .cloned()
259 .collect::<Vec<_>>()
260 })
261 .unwrap_or_else(|| tracked_components.get_tracked_component_ids());
262
263 let component_ids = request_ids
264 .iter()
265 .collect::<HashSet<_>>();
266
267 if component_ids.is_empty() {
268 return Ok(StateSyncMessage { header, ..Default::default() });
269 }
270
271 let component_tvl = if self.include_tvl {
272 let body =
273 ComponentTvlRequestBody::id_filtered(request_ids.clone(), self.extractor_id.chain);
274 self.rpc_client
275 .get_component_tvl_paginated(&body, 100, 4)
276 .await?
277 .tvl
278 } else {
279 HashMap::new()
280 };
281
282 let mut protocol_states = self
283 .rpc_client
284 .get_protocol_states_paginated(
285 self.extractor_id.chain,
286 &request_ids,
287 &self.extractor_id.name,
288 self.retrieve_balances,
289 &version,
290 100,
291 4,
292 )
293 .await?
294 .states
295 .into_iter()
296 .map(|state| (state.component_id.clone(), state))
297 .collect::<HashMap<_, _>>();
298
299 trace!(states=?&protocol_states, "Retrieved ProtocolStates");
300 let states = tracked_components
301 .components
302 .values()
303 .filter_map(|component| {
304 if let Some(state) = protocol_states.remove(&component.id) {
305 Some((
306 component.id.clone(),
307 ComponentWithState {
308 state,
309 component: component.clone(),
310 component_tvl: component_tvl
311 .get(&component.id)
312 .cloned(),
313 },
314 ))
315 } else if component_ids.contains(&&component.id) {
316 let component_id = &component.id;
318 error!(?component_id, "Missing state for native component!");
319 None
320 } else {
321 None
322 }
323 })
324 .collect();
325
326 let contract_ids = tracked_components.get_contracts_by_component(component_ids.clone());
327 let vm_storage = if !contract_ids.is_empty() {
328 let ids: Vec<Bytes> = contract_ids
329 .clone()
330 .into_iter()
331 .collect();
332 let contract_states = self
333 .rpc_client
334 .get_contract_state_paginated(
335 self.extractor_id.chain,
336 ids.as_slice(),
337 &self.extractor_id.name,
338 &version,
339 100,
340 4,
341 )
342 .await?
343 .accounts
344 .into_iter()
345 .map(|acc| (acc.address.clone(), acc))
346 .collect::<HashMap<_, _>>();
347
348 trace!(states=?&contract_states, "Retrieved ContractState");
349
350 let contract_address_to_components = tracked_components
351 .components
352 .iter()
353 .filter_map(|(id, comp)| {
354 if component_ids.contains(&id) {
355 Some(
356 comp.contract_ids
357 .iter()
358 .map(|address| (address.clone(), comp.id.clone())),
359 )
360 } else {
361 None
362 }
363 })
364 .flatten()
365 .fold(HashMap::<Bytes, Vec<String>>::new(), |mut acc, (addr, c_id)| {
366 acc.entry(addr).or_default().push(c_id);
367 acc
368 });
369
370 contract_ids
371 .iter()
372 .filter_map(|address| {
373 if let Some(state) = contract_states.get(address) {
374 Some((address.clone(), state.clone()))
375 } else if let Some(ids) = contract_address_to_components.get(address) {
376 error!(
378 ?address,
379 ?ids,
380 "Component with lacking contract storage encountered!"
381 );
382 None
383 } else {
384 None
385 }
386 })
387 .collect()
388 } else {
389 HashMap::new()
390 };
391
392 Ok(StateSyncMessage {
393 header,
394 snapshots: Snapshot { states, vm_storage },
395 deltas: None,
396 removed_components: HashMap::new(),
397 })
398 }
399
400 #[instrument(skip(self, block_tx), fields(extractor_id = %self.extractor_id))]
402 async fn state_sync(self, block_tx: &mut Sender<StateSyncMessage>) -> SyncResult<()> {
403 let mut tracker = self.component_tracker.lock().await;
405
406 let subscription_options = SubscriptionOptions::new().with_state(self.include_snapshots);
407 let (_, mut msg_rx) = self
408 .deltas_client
409 .subscribe(self.extractor_id.clone(), subscription_options)
410 .await?;
411
412 info!("Waiting for deltas...");
413 let mut first_msg = timeout(Duration::from_secs(self.timeout), msg_rx.recv())
415 .await
416 .map_err(|_| {
417 SynchronizerError::Timeout(format!(
418 "First deltas took longer than {t}s to arrive",
419 t = self.timeout
420 ))
421 })?
422 .ok_or_else(|| {
423 SynchronizerError::ConnectionError(
424 "Deltas channel closed before first message".to_string(),
425 )
426 })?;
427 self.filter_deltas(&mut first_msg, &tracker);
428
429 let block = first_msg.get_block().clone();
431 info!(height = &block.number, "Deltas received. Retrieving snapshot");
432 let header = Header::from_block(first_msg.get_block(), first_msg.is_revert());
433 let snapshot = self
434 .get_snapshots::<Vec<&String>>(Header::from_block(&block, false), &tracker, None)
435 .await?
436 .merge(StateSyncMessage {
437 header: Header::from_block(first_msg.get_block(), first_msg.is_revert()),
438 snapshots: Default::default(),
439 deltas: Some(first_msg),
440 removed_components: Default::default(),
441 });
442
443 let n_components = tracker.components.len();
444 let n_snapshots = snapshot.snapshots.states.len();
445 info!(n_components, n_snapshots, "Initial snapshot retrieved, starting delta message feed");
446
447 {
448 let mut shared = self.shared.lock().await;
449 block_tx.send(snapshot).await?;
450 shared.last_synced_block = Some(header.clone());
451 }
452
453 loop {
454 if let Some(mut deltas) = msg_rx.recv().await {
455 let header = Header::from_block(deltas.get_block(), deltas.is_revert());
456 debug!(block_number=?header.number, "Received delta message");
457 let (snapshots, removed_components) = {
458 let (to_add, to_remove) = tracker.filter_updated_components(&deltas);
461
462 let requiring_snapshot: Vec<_> = to_add
464 .iter()
465 .filter(|id| {
466 !tracker
467 .components
468 .contains_key(id.as_str())
469 })
470 .collect();
471 debug!(components=?requiring_snapshot, "SnapshotRequest");
472 tracker
473 .start_tracking(requiring_snapshot.as_slice())
474 .await?;
475 let snapshots = self
476 .get_snapshots(header.clone(), &tracker, Some(requiring_snapshot))
477 .await?
478 .snapshots;
479
480 let removed_components = if !to_remove.is_empty() {
481 tracker.stop_tracking(&to_remove)
482 } else {
483 Default::default()
484 };
485 (snapshots, removed_components)
486 };
487
488 self.filter_deltas(&mut deltas, &tracker);
490 let n_changes = deltas.n_changes();
491
492 let next = StateSyncMessage {
493 header: header.clone(),
494 snapshots,
495 deltas: Some(deltas),
496 removed_components,
497 };
498 block_tx.send(next).await?;
499 {
500 let mut shared = self.shared.lock().await;
501 shared.last_synced_block = Some(header.clone());
502 }
503
504 debug!(block_number=?header.number, n_changes, "Finished processing delta message");
505 } else {
506 let mut shared = self.shared.lock().await;
507 warn!(shared = ?&shared, "Deltas channel closed, resetting shared state.");
508 shared.last_synced_block = None;
509
510 return Err(SynchronizerError::ConnectionError("Deltas channel closed".to_string()));
511 }
512 }
513 }
514
515 fn filter_deltas(&self, second_msg: &mut BlockChanges, tracker: &ComponentTracker<R>) {
516 second_msg.filter_by_component(|id| tracker.components.contains_key(id));
517 second_msg.filter_by_contract(|id| tracker.contracts.contains(id));
518 }
519}
520
521#[async_trait]
522impl<R, D> StateSynchronizer for ProtocolStateSynchronizer<R, D>
523where
524 R: RPCClient + Clone + Send + Sync + 'static,
525 D: DeltasClient + Clone + Send + Sync + 'static,
526{
527 async fn initialize(&self) -> SyncResult<()> {
528 let mut tracker = self.component_tracker.lock().await;
529 info!("Retrieving relevant protocol components");
530 tracker.initialise_components().await?;
531 info!(
532 n_components = tracker.components.len(),
533 n_contracts = tracker.contracts.len(),
534 "Finished retrieving components",
535 );
536
537 Ok(())
538 }
539
540 async fn start(&self) -> SyncResult<(JoinHandle<SyncResult<()>>, Receiver<StateSyncMessage>)> {
541 let (mut tx, rx) = channel(15);
542
543 let this = self.clone();
544 let jh = tokio::spawn(async move {
545 let mut retry_count = 0;
546 while retry_count < this.max_retries {
547 info!(extractor_id=%&this.extractor_id, retry_count, "(Re)starting synchronization loop");
548 let (end_tx, end_rx) = oneshot::channel::<()>();
549 {
550 let mut end_tx_guard = this.end_tx.lock().await;
551 *end_tx_guard = Some(end_tx);
552 }
553
554 select! {
555 res = this.clone().state_sync(&mut tx) => {
556 match res {
557 Err(e) => {
558 error!(
559 extractor_id=%&this.extractor_id,
560 retry_count,
561 error=%e,
562 "State synchronization errored!"
563 );
564 if let SynchronizerError::ConnectionClosed = e {
565 return Err(e);
567 }
568 }
569 _ => {
570 warn!(
571 extractor_id=%&this.extractor_id,
572 retry_count,
573 "State synchronization exited with Ok(())"
574 );
575 }
576 }
577 },
578 _ = end_rx => {
579 info!(
580 extractor_id=%&this.extractor_id,
581 retry_count,
582 "StateSynchronizer received close signal. Stopping"
583 );
584 return Ok(())
585 }
586 }
587 retry_count += 1;
588 }
589 Err(SynchronizerError::ConnectionError("Max connection retries exceeded".to_string()))
590 });
591
592 Ok((jh, rx))
593 }
594
595 async fn close(&mut self) -> SyncResult<()> {
596 let mut end_tx = self.end_tx.lock().await;
597 if let Some(tx) = end_tx.take() {
598 let _ = tx.send(());
599 Ok(())
600 } else {
601 Err(SynchronizerError::CloseError("Synchronizer not started".to_string()))
602 }
603 }
604}
605
606#[cfg(test)]
607mod test {
608 use test_log::test;
609 use tycho_common::dto::{
610 Block, Chain, ComponentTvlRequestBody, ComponentTvlRequestResponse, PaginationResponse,
611 ProtocolComponentRequestResponse, ProtocolComponentsRequestBody, ProtocolStateRequestBody,
612 ProtocolStateRequestResponse, ProtocolSystemsRequestBody, ProtocolSystemsRequestResponse,
613 StateRequestBody, StateRequestResponse, TokensRequestBody, TokensRequestResponse,
614 };
615 use uuid::Uuid;
616
617 use super::*;
618 use crate::{deltas::MockDeltasClient, rpc::MockRPCClient, DeltasError, RPCError};
619
/// Test wrapper sharing an RPC client through an `Arc`, which makes it cloneable even when
/// `T` itself is not.
struct ArcRPCClient<T>(Arc<T>);

impl<T> Clone for ArcRPCClient<T> {
    /// Clones the handle only; the wrapped client is shared.
    fn clone(&self) -> Self {
        let ArcRPCClient(inner) = self;
        Self(inner.clone())
    }
}
629
630 #[async_trait]
631 impl<T> RPCClient for ArcRPCClient<T>
632 where
633 T: RPCClient + Sync + Send + 'static,
634 {
635 async fn get_tokens(
636 &self,
637 request: &TokensRequestBody,
638 ) -> Result<TokensRequestResponse, RPCError> {
639 self.0.get_tokens(request).await
640 }
641
642 async fn get_contract_state(
643 &self,
644 request: &StateRequestBody,
645 ) -> Result<StateRequestResponse, RPCError> {
646 self.0.get_contract_state(request).await
647 }
648
649 async fn get_protocol_components(
650 &self,
651 request: &ProtocolComponentsRequestBody,
652 ) -> Result<ProtocolComponentRequestResponse, RPCError> {
653 self.0
654 .get_protocol_components(request)
655 .await
656 }
657
658 async fn get_protocol_states(
659 &self,
660 request: &ProtocolStateRequestBody,
661 ) -> Result<ProtocolStateRequestResponse, RPCError> {
662 self.0
663 .get_protocol_states(request)
664 .await
665 }
666
667 async fn get_protocol_systems(
668 &self,
669 request: &ProtocolSystemsRequestBody,
670 ) -> Result<ProtocolSystemsRequestResponse, RPCError> {
671 self.0
672 .get_protocol_systems(request)
673 .await
674 }
675
676 async fn get_component_tvl(
677 &self,
678 request: &ComponentTvlRequestBody,
679 ) -> Result<ComponentTvlRequestResponse, RPCError> {
680 self.0.get_component_tvl(request).await
681 }
682 }
683
/// Test wrapper sharing a deltas client through an `Arc`, which makes it cloneable even
/// when `T` itself is not.
struct ArcDeltasClient<T>(Arc<T>);

impl<T> Clone for ArcDeltasClient<T> {
    /// Clones the handle only; the wrapped client is shared.
    fn clone(&self) -> Self {
        let ArcDeltasClient(inner) = self;
        Self(inner.clone())
    }
}
693
694 #[async_trait]
695 impl<T> DeltasClient for ArcDeltasClient<T>
696 where
697 T: DeltasClient + Sync + Send + 'static,
698 {
699 async fn subscribe(
700 &self,
701 extractor_id: ExtractorIdentity,
702 options: SubscriptionOptions,
703 ) -> Result<(Uuid, Receiver<BlockChanges>), DeltasError> {
704 self.0
705 .subscribe(extractor_id, options)
706 .await
707 }
708
709 async fn unsubscribe(&self, subscription_id: Uuid) -> Result<(), DeltasError> {
710 self.0
711 .unsubscribe(subscription_id)
712 .await
713 }
714
715 async fn connect(&self) -> Result<JoinHandle<Result<(), DeltasError>>, DeltasError> {
716 self.0.connect().await
717 }
718
719 async fn close(&self) -> Result<(), DeltasError> {
720 self.0.close().await
721 }
722 }
723
724 fn with_mocked_clients(
725 native: bool,
726 include_tvl: bool,
727 rpc_client: Option<MockRPCClient>,
728 deltas_client: Option<MockDeltasClient>,
729 ) -> ProtocolStateSynchronizer<ArcRPCClient<MockRPCClient>, ArcDeltasClient<MockDeltasClient>>
730 {
731 let rpc_client = ArcRPCClient(Arc::new(rpc_client.unwrap_or_default()));
732 let deltas_client = ArcDeltasClient(Arc::new(deltas_client.unwrap_or_default()));
733
734 ProtocolStateSynchronizer::new(
735 ExtractorIdentity::new(Chain::Ethereum, "uniswap-v2"),
736 native,
737 ComponentFilter::with_tvl_range(50.0, 50.0),
738 1,
739 true,
740 include_tvl,
741 rpc_client,
742 deltas_client,
743 10_u64,
744 )
745 }
746
747 fn state_snapshot_native() -> ProtocolStateRequestResponse {
748 ProtocolStateRequestResponse {
749 states: vec![ResponseProtocolState {
750 component_id: "Component1".to_string(),
751 ..Default::default()
752 }],
753 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
754 }
755 }
756
757 fn component_tvl_snapshot() -> ComponentTvlRequestResponse {
758 let tvl = HashMap::from([("Component1".to_string(), 100.0)]);
759
760 ComponentTvlRequestResponse {
761 tvl,
762 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
763 }
764 }
765
766 #[test(tokio::test)]
767 async fn test_get_snapshots_native() {
768 let header = Header::default();
769 let mut rpc = MockRPCClient::new();
770 rpc.expect_get_protocol_states()
771 .returning(|_| Ok(state_snapshot_native()));
772 let state_sync = with_mocked_clients(true, false, Some(rpc), None);
773 let mut tracker = ComponentTracker::new(
774 Chain::Ethereum,
775 "uniswap-v2",
776 ComponentFilter::with_tvl_range(0.0, 0.0),
777 state_sync.rpc_client.clone(),
778 );
779 let component = ProtocolComponent { id: "Component1".to_string(), ..Default::default() };
780 tracker
781 .components
782 .insert("Component1".to_string(), component.clone());
783 let components_arg = ["Component1".to_string()];
784 let exp = StateSyncMessage {
785 header: header.clone(),
786 snapshots: Snapshot {
787 states: state_snapshot_native()
788 .states
789 .into_iter()
790 .map(|state| {
791 (
792 state.component_id.clone(),
793 ComponentWithState {
794 state,
795 component: component.clone(),
796 component_tvl: None,
797 },
798 )
799 })
800 .collect(),
801 vm_storage: HashMap::new(),
802 },
803 deltas: None,
804 removed_components: Default::default(),
805 };
806
807 let snap = state_sync
808 .get_snapshots(header, &tracker, Some(&components_arg))
809 .await
810 .expect("Retrieving snapshot failed");
811
812 assert_eq!(snap, exp);
813 }
814
815 #[test(tokio::test)]
816 async fn test_get_snapshots_native_with_tvl() {
817 let header = Header::default();
818 let mut rpc = MockRPCClient::new();
819 rpc.expect_get_protocol_states()
820 .returning(|_| Ok(state_snapshot_native()));
821 rpc.expect_get_component_tvl()
822 .returning(|_| Ok(component_tvl_snapshot()));
823 let state_sync = with_mocked_clients(true, true, Some(rpc), None);
824 let mut tracker = ComponentTracker::new(
825 Chain::Ethereum,
826 "uniswap-v2",
827 ComponentFilter::with_tvl_range(0.0, 0.0),
828 state_sync.rpc_client.clone(),
829 );
830 let component = ProtocolComponent { id: "Component1".to_string(), ..Default::default() };
831 tracker
832 .components
833 .insert("Component1".to_string(), component.clone());
834 let components_arg = ["Component1".to_string()];
835 let exp = StateSyncMessage {
836 header: header.clone(),
837 snapshots: Snapshot {
838 states: state_snapshot_native()
839 .states
840 .into_iter()
841 .map(|state| {
842 (
843 state.component_id.clone(),
844 ComponentWithState {
845 state,
846 component: component.clone(),
847 component_tvl: Some(100.0),
848 },
849 )
850 })
851 .collect(),
852 vm_storage: HashMap::new(),
853 },
854 deltas: None,
855 removed_components: Default::default(),
856 };
857
858 let snap = state_sync
859 .get_snapshots(header, &tracker, Some(&components_arg))
860 .await
861 .expect("Retrieving snapshot failed");
862
863 assert_eq!(snap, exp);
864 }
865
866 fn state_snapshot_vm() -> StateRequestResponse {
867 StateRequestResponse {
868 accounts: vec![
869 ResponseAccount { address: Bytes::from("0x0badc0ffee"), ..Default::default() },
870 ResponseAccount { address: Bytes::from("0xbabe42"), ..Default::default() },
871 ],
872 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
873 }
874 }
875
876 #[test(tokio::test)]
877 async fn test_get_snapshots_vm() {
878 let header = Header::default();
879 let mut rpc = MockRPCClient::new();
880 rpc.expect_get_protocol_states()
881 .returning(|_| Ok(state_snapshot_native()));
882 rpc.expect_get_contract_state()
883 .returning(|_| Ok(state_snapshot_vm()));
884 let state_sync = with_mocked_clients(false, false, Some(rpc), None);
885 let mut tracker = ComponentTracker::new(
886 Chain::Ethereum,
887 "uniswap-v2",
888 ComponentFilter::with_tvl_range(0.0, 0.0),
889 state_sync.rpc_client.clone(),
890 );
891 let component = ProtocolComponent {
892 id: "Component1".to_string(),
893 contract_ids: vec![Bytes::from("0x0badc0ffee"), Bytes::from("0xbabe42")],
894 ..Default::default()
895 };
896 tracker
897 .components
898 .insert("Component1".to_string(), component.clone());
899 let components_arg = ["Component1".to_string()];
900 let exp = StateSyncMessage {
901 header: header.clone(),
902 snapshots: Snapshot {
903 states: [(
904 component.id.clone(),
905 ComponentWithState {
906 state: ResponseProtocolState {
907 component_id: "Component1".to_string(),
908 ..Default::default()
909 },
910 component: component.clone(),
911 component_tvl: None,
912 },
913 )]
914 .into_iter()
915 .collect(),
916 vm_storage: state_snapshot_vm()
917 .accounts
918 .into_iter()
919 .map(|state| (state.address.clone(), state))
920 .collect(),
921 },
922 deltas: None,
923 removed_components: Default::default(),
924 };
925
926 let snap = state_sync
927 .get_snapshots(header, &tracker, Some(&components_arg))
928 .await
929 .expect("Retrieving snapshot failed");
930
931 assert_eq!(snap, exp);
932 }
933
934 #[test(tokio::test)]
935 async fn test_get_snapshots_vm_with_tvl() {
936 let header = Header::default();
937 let mut rpc = MockRPCClient::new();
938 rpc.expect_get_protocol_states()
939 .returning(|_| Ok(state_snapshot_native()));
940 rpc.expect_get_contract_state()
941 .returning(|_| Ok(state_snapshot_vm()));
942 rpc.expect_get_component_tvl()
943 .returning(|_| Ok(component_tvl_snapshot()));
944 let state_sync = with_mocked_clients(false, true, Some(rpc), None);
945 let mut tracker = ComponentTracker::new(
946 Chain::Ethereum,
947 "uniswap-v2",
948 ComponentFilter::with_tvl_range(0.0, 0.0),
949 state_sync.rpc_client.clone(),
950 );
951 let component = ProtocolComponent {
952 id: "Component1".to_string(),
953 contract_ids: vec![Bytes::from("0x0badc0ffee"), Bytes::from("0xbabe42")],
954 ..Default::default()
955 };
956 tracker
957 .components
958 .insert("Component1".to_string(), component.clone());
959 let components_arg = ["Component1".to_string()];
960 let exp = StateSyncMessage {
961 header: header.clone(),
962 snapshots: Snapshot {
963 states: [(
964 component.id.clone(),
965 ComponentWithState {
966 state: ResponseProtocolState {
967 component_id: "Component1".to_string(),
968 ..Default::default()
969 },
970 component: component.clone(),
971 component_tvl: Some(100.0),
972 },
973 )]
974 .into_iter()
975 .collect(),
976 vm_storage: state_snapshot_vm()
977 .accounts
978 .into_iter()
979 .map(|state| (state.address.clone(), state))
980 .collect(),
981 },
982 deltas: None,
983 removed_components: Default::default(),
984 };
985
986 let snap = state_sync
987 .get_snapshots(header, &tracker, Some(&components_arg))
988 .await
989 .expect("Retrieving snapshot failed");
990
991 assert_eq!(snap, exp);
992 }
993
994 fn mock_clients_for_state_sync() -> (MockRPCClient, MockDeltasClient, Sender<BlockChanges>) {
995 let mut rpc_client = MockRPCClient::new();
996 rpc_client
999 .expect_get_protocol_components()
1000 .with(mockall::predicate::function(
1001 move |request_params: &ProtocolComponentsRequestBody| {
1002 if let Some(ids) = request_params.component_ids.as_ref() {
1003 ids.contains(&"Component3".to_string())
1004 } else {
1005 false
1006 }
1007 },
1008 ))
1009 .returning(|_| {
1010 Ok(ProtocolComponentRequestResponse {
1012 protocol_components: vec![
1013 ProtocolComponent { id: "Component3".to_string(), ..Default::default() },
1015 ],
1016 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1017 })
1018 });
1019 rpc_client
1020 .expect_get_protocol_states()
1021 .with(mockall::predicate::function(move |request_params: &ProtocolStateRequestBody| {
1022 let expected_id = "Component3".to_string();
1023 if let Some(ids) = request_params.protocol_ids.as_ref() {
1024 ids.contains(&expected_id)
1025 } else {
1026 false
1027 }
1028 }))
1029 .returning(|_| {
1030 Ok(ProtocolStateRequestResponse {
1032 states: vec![ResponseProtocolState {
1033 component_id: "Component3".to_string(),
1034 ..Default::default()
1035 }],
1036 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1037 })
1038 });
1039
1040 rpc_client
1042 .expect_get_protocol_components()
1043 .returning(|_| {
1044 Ok(ProtocolComponentRequestResponse {
1046 protocol_components: vec![
1047 ProtocolComponent { id: "Component1".to_string(), ..Default::default() },
1049 ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1051 ],
1053 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1054 })
1055 });
1056 rpc_client
1057 .expect_get_protocol_states()
1058 .returning(|_| {
1059 Ok(ProtocolStateRequestResponse {
1061 states: vec![
1062 ResponseProtocolState {
1063 component_id: "Component1".to_string(),
1064 ..Default::default()
1065 },
1066 ResponseProtocolState {
1067 component_id: "Component2".to_string(),
1068 ..Default::default()
1069 },
1070 ],
1071 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1072 })
1073 });
1074 rpc_client
1075 .expect_get_component_tvl()
1076 .returning(|_| {
1077 Ok(ComponentTvlRequestResponse {
1078 tvl: [
1079 ("Component1".to_string(), 100.0),
1080 ("Component2".to_string(), 0.0),
1081 ("Component3".to_string(), 1000.0),
1082 ]
1083 .into_iter()
1084 .collect(),
1085 pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
1086 })
1087 });
1088 let mut deltas_client = MockDeltasClient::new();
1090 let (tx, rx) = channel(1);
1091 deltas_client
1092 .expect_subscribe()
1093 .return_once(move |_, _| {
1094 Ok((Uuid::default(), rx))
1096 });
1097 (rpc_client, deltas_client, tx)
1098 }
1099
1100 #[test(tokio::test)]
1107 async fn test_state_sync() {
1108 let (rpc_client, deltas_client, tx) = mock_clients_for_state_sync();
1109 let deltas = [
1110 BlockChanges {
1111 extractor: "uniswap-v2".to_string(),
1112 chain: Chain::Ethereum,
1113 block: Block {
1114 number: 1,
1115 hash: Bytes::from("0x01"),
1116 parent_hash: Bytes::from("0x00"),
1117 chain: Chain::Ethereum,
1118 ts: Default::default(),
1119 },
1120 revert: false,
1121 ..Default::default()
1122 },
1123 BlockChanges {
1124 extractor: "uniswap-v2".to_string(),
1125 chain: Chain::Ethereum,
1126 block: Block {
1127 number: 2,
1128 hash: Bytes::from("0x02"),
1129 parent_hash: Bytes::from("0x01"),
1130 chain: Chain::Ethereum,
1131 ts: Default::default(),
1132 },
1133 revert: false,
1134 component_tvl: [
1135 ("Component1".to_string(), 100.0),
1136 ("Component2".to_string(), 0.0),
1137 ("Component3".to_string(), 1000.0),
1138 ]
1139 .into_iter()
1140 .collect(),
1141 ..Default::default()
1142 },
1143 ];
1144 let mut state_sync = with_mocked_clients(true, true, Some(rpc_client), Some(deltas_client));
1145 state_sync
1146 .initialize()
1147 .await
1148 .expect("Init failed");
1149
1150 let (jh, mut rx) = state_sync
1152 .start()
1153 .await
1154 .expect("Failed to start state synchronizer");
1155 tx.send(deltas[0].clone())
1156 .await
1157 .expect("deltas channel msg 0 closed!");
1158 let first_msg = timeout(Duration::from_millis(100), rx.recv())
1159 .await
1160 .expect("waiting for first state msg timed out!")
1161 .expect("state sync block sender closed!");
1162 tx.send(deltas[1].clone())
1163 .await
1164 .expect("deltas channel msg 1 closed!");
1165 let second_msg = timeout(Duration::from_millis(100), rx.recv())
1166 .await
1167 .expect("waiting for second state msg timed out!")
1168 .expect("state sync block sender closed!");
1169 let _ = state_sync.close().await;
1170 let exit = jh
1171 .await
1172 .expect("state sync task panicked!");
1173
1174 let exp1 = StateSyncMessage {
1176 header: Header {
1177 number: 1,
1178 hash: Bytes::from("0x01"),
1179 parent_hash: Bytes::from("0x00"),
1180 revert: false,
1181 },
1182 snapshots: Snapshot {
1183 states: [
1184 (
1185 "Component1".to_string(),
1186 ComponentWithState {
1187 state: ResponseProtocolState {
1188 component_id: "Component1".to_string(),
1189 ..Default::default()
1190 },
1191 component: ProtocolComponent {
1192 id: "Component1".to_string(),
1193 ..Default::default()
1194 },
1195 component_tvl: Some(100.0),
1196 },
1197 ),
1198 (
1199 "Component2".to_string(),
1200 ComponentWithState {
1201 state: ResponseProtocolState {
1202 component_id: "Component2".to_string(),
1203 ..Default::default()
1204 },
1205 component: ProtocolComponent {
1206 id: "Component2".to_string(),
1207 ..Default::default()
1208 },
1209 component_tvl: Some(0.0),
1210 },
1211 ),
1212 ]
1213 .into_iter()
1214 .collect(),
1215 vm_storage: HashMap::new(),
1216 },
1217 deltas: Some(deltas[0].clone()),
1218 removed_components: Default::default(),
1219 };
1220
1221 let exp2 = StateSyncMessage {
1222 header: Header {
1223 number: 2,
1224 hash: Bytes::from("0x02"),
1225 parent_hash: Bytes::from("0x01"),
1226 revert: false,
1227 },
1228 snapshots: Snapshot {
1229 states: [
1230 (
1232 "Component3".to_string(),
1233 ComponentWithState {
1234 state: ResponseProtocolState {
1235 component_id: "Component3".to_string(),
1236 ..Default::default()
1237 },
1238 component: ProtocolComponent {
1239 id: "Component3".to_string(),
1240 ..Default::default()
1241 },
1242 component_tvl: Some(1000.0),
1243 },
1244 ),
1245 ]
1246 .into_iter()
1247 .collect(),
1248 vm_storage: HashMap::new(),
1249 },
1250 deltas: Some(BlockChanges {
1253 extractor: "uniswap-v2".to_string(),
1254 chain: Chain::Ethereum,
1255 block: Block {
1256 number: 2,
1257 hash: Bytes::from("0x02"),
1258 parent_hash: Bytes::from("0x01"),
1259 chain: Chain::Ethereum,
1260 ts: Default::default(),
1261 },
1262 revert: false,
1263 component_tvl: [
1264 ("Component1".to_string(), 100.0),
1266 ("Component3".to_string(), 1000.0),
1267 ]
1268 .into_iter()
1269 .collect(),
1270 ..Default::default()
1271 }),
1272 removed_components: [(
1274 "Component2".to_string(),
1275 ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1276 )]
1277 .into_iter()
1278 .collect(),
1279 };
1280 assert_eq!(first_msg, exp1);
1281 assert_eq!(second_msg, exp2);
1282 assert!(exit.is_ok());
1283 }
1284
1285 #[test(tokio::test)]
1286 async fn test_state_sync_with_tvl_range() {
1287 let remove_tvl_threshold = 5.0;
1289 let add_tvl_threshold = 7.0;
1290
1291 let mut rpc_client = MockRPCClient::new();
1292 let mut deltas_client = MockDeltasClient::new();
1293
1294 rpc_client
1295 .expect_get_protocol_components()
1296 .with(mockall::predicate::function(
1297 move |request_params: &ProtocolComponentsRequestBody| {
1298 if let Some(ids) = request_params.component_ids.as_ref() {
1299 ids.contains(&"Component3".to_string())
1300 } else {
1301 false
1302 }
1303 },
1304 ))
1305 .returning(|_| {
1306 Ok(ProtocolComponentRequestResponse {
1307 protocol_components: vec![ProtocolComponent {
1308 id: "Component3".to_string(),
1309 ..Default::default()
1310 }],
1311 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1312 })
1313 });
1314
1315 rpc_client
1316 .expect_get_protocol_states()
1317 .with(mockall::predicate::function(move |request_params: &ProtocolStateRequestBody| {
1318 let expected_id = "Component3".to_string();
1319 if let Some(ids) = request_params.protocol_ids.as_ref() {
1320 ids.contains(&expected_id)
1321 } else {
1322 false
1323 }
1324 }))
1325 .returning(|_| {
1326 Ok(ProtocolStateRequestResponse {
1327 states: vec![ResponseProtocolState {
1328 component_id: "Component3".to_string(),
1329 ..Default::default()
1330 }],
1331 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1332 })
1333 });
1334
1335 rpc_client
1337 .expect_get_protocol_components()
1338 .returning(|_| {
1339 Ok(ProtocolComponentRequestResponse {
1340 protocol_components: vec![
1341 ProtocolComponent { id: "Component1".to_string(), ..Default::default() },
1342 ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1343 ],
1344 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1345 })
1346 });
1347
1348 rpc_client
1349 .expect_get_protocol_states()
1350 .returning(|_| {
1351 Ok(ProtocolStateRequestResponse {
1352 states: vec![
1353 ResponseProtocolState {
1354 component_id: "Component1".to_string(),
1355 ..Default::default()
1356 },
1357 ResponseProtocolState {
1358 component_id: "Component2".to_string(),
1359 ..Default::default()
1360 },
1361 ],
1362 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1363 })
1364 });
1365
1366 rpc_client
1367 .expect_get_component_tvl()
1368 .returning(|_| {
1369 Ok(ComponentTvlRequestResponse {
1370 tvl: [
1371 ("Component1".to_string(), 6.0),
1372 ("Component2".to_string(), 2.0),
1373 ("Component3".to_string(), 10.0),
1374 ]
1375 .into_iter()
1376 .collect(),
1377 pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
1378 })
1379 });
1380
1381 let (tx, rx) = channel(1);
1382 deltas_client
1383 .expect_subscribe()
1384 .return_once(move |_, _| Ok((Uuid::default(), rx)));
1385
1386 let mut state_sync = ProtocolStateSynchronizer::new(
1387 ExtractorIdentity::new(Chain::Ethereum, "uniswap-v2"),
1388 true,
1389 ComponentFilter::with_tvl_range(remove_tvl_threshold, add_tvl_threshold),
1390 1,
1391 true,
1392 true,
1393 ArcRPCClient(Arc::new(rpc_client)),
1394 ArcDeltasClient(Arc::new(deltas_client)),
1395 10_u64,
1396 );
1397 state_sync
1398 .initialize()
1399 .await
1400 .expect("Init failed");
1401
1402 let deltas = [
1404 BlockChanges {
1405 extractor: "uniswap-v2".to_string(),
1406 chain: Chain::Ethereum,
1407 block: Block {
1408 number: 1,
1409 hash: Bytes::from("0x01"),
1410 parent_hash: Bytes::from("0x00"),
1411 chain: Chain::Ethereum,
1412 ts: Default::default(),
1413 },
1414 revert: false,
1415 ..Default::default()
1416 },
1417 BlockChanges {
1418 extractor: "uniswap-v2".to_string(),
1419 chain: Chain::Ethereum,
1420 block: Block {
1421 number: 2,
1422 hash: Bytes::from("0x02"),
1423 parent_hash: Bytes::from("0x01"),
1424 chain: Chain::Ethereum,
1425 ts: Default::default(),
1426 },
1427 revert: false,
1428 component_tvl: [
1429 ("Component1".to_string(), 6.0), ("Component2".to_string(), 2.0), ("Component3".to_string(), 10.0), ]
1433 .into_iter()
1434 .collect(),
1435 ..Default::default()
1436 },
1437 ];
1438
1439 let (jh, mut rx) = state_sync
1440 .start()
1441 .await
1442 .expect("Failed to start state synchronizer");
1443
1444 tx.send(deltas[0].clone())
1446 .await
1447 .expect("deltas channel msg 0 closed!");
1448
1449 let _ = timeout(Duration::from_millis(100), rx.recv())
1451 .await
1452 .expect("waiting for first state msg timed out!")
1453 .expect("state sync block sender closed!");
1454
1455 tx.send(deltas[1].clone())
1457 .await
1458 .expect("deltas channel msg 1 closed!");
1459 let second_msg = timeout(Duration::from_millis(100), rx.recv())
1460 .await
1461 .expect("waiting for second state msg timed out!")
1462 .expect("state sync block sender closed!");
1463
1464 let _ = state_sync.close().await;
1465 let exit = jh
1466 .await
1467 .expect("state sync task panicked!");
1468
1469 let expected_second_msg = StateSyncMessage {
1470 header: Header {
1471 number: 2,
1472 hash: Bytes::from("0x02"),
1473 parent_hash: Bytes::from("0x01"),
1474 revert: false,
1475 },
1476 snapshots: Snapshot {
1477 states: [(
1478 "Component3".to_string(),
1479 ComponentWithState {
1480 state: ResponseProtocolState {
1481 component_id: "Component3".to_string(),
1482 ..Default::default()
1483 },
1484 component: ProtocolComponent {
1485 id: "Component3".to_string(),
1486 ..Default::default()
1487 },
1488 component_tvl: Some(10.0),
1489 },
1490 )]
1491 .into_iter()
1492 .collect(),
1493 vm_storage: HashMap::new(),
1494 },
1495 deltas: Some(BlockChanges {
1496 extractor: "uniswap-v2".to_string(),
1497 chain: Chain::Ethereum,
1498 block: Block {
1499 number: 2,
1500 hash: Bytes::from("0x02"),
1501 parent_hash: Bytes::from("0x01"),
1502 chain: Chain::Ethereum,
1503 ts: Default::default(),
1504 },
1505 revert: false,
1506 component_tvl: [
1507 ("Component1".to_string(), 6.0), ("Component3".to_string(), 10.0), ]
1510 .into_iter()
1511 .collect(),
1512 ..Default::default()
1513 }),
1514 removed_components: [(
1515 "Component2".to_string(),
1516 ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1517 )]
1518 .into_iter()
1519 .collect(),
1520 };
1521
1522 assert_eq!(second_msg, expected_second_msg);
1523 assert!(exit.is_ok());
1524 }
1525}