1use std::{
2 collections::{HashMap, HashSet},
3 sync::Arc,
4 time::Duration,
5};
6
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use tokio::{
10 select,
11 sync::{
12 mpsc::{channel, Receiver, Sender},
13 oneshot, Mutex,
14 },
15 task::JoinHandle,
16 time::timeout,
17};
18use tracing::{debug, error, info, instrument, trace, warn};
19use tycho_common::{
20 dto::{
21 BlockChanges, BlockParam, ExtractorIdentity, ProtocolComponent, ResponseAccount,
22 ResponseProtocolState, VersionParam,
23 },
24 Bytes,
25};
26
27use crate::{
28 deltas::{DeltasClient, SubscriptionOptions},
29 feed::{
30 component_tracker::{ComponentFilter, ComponentTracker},
31 Header,
32 },
33 rpc::RPCClient,
34};
35
/// Result type used throughout the synchronizer; errors are type-erased via `anyhow`.
pub type SyncResult<T> = anyhow::Result<T>;
37
/// Keeps the state of one extractor's protocol components in sync by combining
/// RPC snapshots with a stream of delta messages.
///
/// The struct is `Clone`; mutable state lives behind `Arc<Mutex<..>>` so clones
/// share the same tracker, shared state and close-signal slot.
#[derive(Clone)]
pub struct ProtocolStateSynchronizer<R: RPCClient, D: DeltasClient> {
    /// Identity (chain and name) of the extractor being synchronized.
    extractor_id: ExtractorIdentity,
    /// Forwarded to the protocol state RPC request when fetching snapshots.
    retrieve_balances: bool,
    /// Client used to fetch protocol and contract state snapshots.
    rpc_client: R,
    /// Client used to subscribe to the delta message stream.
    deltas_client: D,
    /// Upper bound on how many times `start` (re)runs the synchronization loop.
    max_retries: u64,
    /// When `false`, snapshot retrieval is skipped and empty snapshots are emitted.
    include_snapshots: bool,
    /// Tracks which components (and their contracts) are currently followed.
    component_tracker: Arc<Mutex<ComponentTracker<R>>>,
    /// State shared between clones, e.g. the last block that was forwarded.
    shared: Arc<Mutex<SharedState>>,
    /// Sender half of the close signal; populated by `start`, consumed by `close`.
    end_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
    /// Seconds to wait for the first delta message after subscribing.
    timeout: u64,
}
51
/// Mutable state shared between all clones of a synchronizer.
#[derive(Debug, Default)]
struct SharedState {
    /// Header of the last block whose message was sent downstream; reset to
    /// `None` when the deltas channel closes.
    last_synced_block: Option<Header>,
}
56
/// A protocol component paired with the protocol state retrieved for it.
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
pub struct ComponentWithState {
    /// State returned by the RPC for this component.
    pub state: ResponseProtocolState,
    /// Static description of the component the state belongs to.
    pub component: ProtocolComponent,
}
62
/// Full-state snapshot for a set of components at a given block.
#[derive(Clone, PartialEq, Debug, Default, Serialize, Deserialize)]
pub struct Snapshot {
    /// Component states, keyed by component id.
    pub states: HashMap<String, ComponentWithState>,
    /// Contract account states, keyed by account address.
    pub vm_storage: HashMap<Bytes, ResponseAccount>,
}
68
69impl Snapshot {
70 fn extend(&mut self, other: Snapshot) {
71 self.states.extend(other.states);
72 self.vm_storage.extend(other.vm_storage);
73 }
74
75 pub fn get_states(&self) -> &HashMap<String, ComponentWithState> {
76 &self.states
77 }
78
79 pub fn get_vm_storage(&self) -> &HashMap<Bytes, ResponseAccount> {
80 &self.vm_storage
81 }
82}
83
/// Message emitted by a synchronizer for each processed block.
#[derive(Clone, PartialEq, Debug, Default, Serialize, Deserialize)]
pub struct StateSyncMessage {
    /// Header of the block this message refers to.
    pub header: Header,
    /// Snapshots for components that require a full state (e.g. newly tracked ones).
    pub snapshots: Snapshot,
    /// Delta changes for the block, if any were received.
    pub deltas: Option<BlockChanges>,
    /// Components that stopped being tracked with this message.
    // NOTE(review): presumably keyed by component id — confirm against
    // `ComponentTracker::stop_tracking`.
    pub removed_components: HashMap<String, ProtocolComponent>,
}
97
98impl StateSyncMessage {
99 pub fn merge(mut self, other: Self) -> Self {
100 self.removed_components
102 .retain(|k, _| !other.snapshots.states.contains_key(k));
103 self.snapshots
104 .states
105 .retain(|k, _| !other.removed_components.contains_key(k));
106
107 self.snapshots.extend(other.snapshots);
108 let deltas = match (self.deltas, other.deltas) {
109 (Some(l), Some(r)) => Some(l.merge(r)),
110 (None, Some(r)) => Some(r),
111 (Some(l), None) => Some(l),
112 (None, None) => None,
113 };
114 self.removed_components
115 .extend(other.removed_components);
116 Self {
117 header: other.header,
118 snapshots: self.snapshots,
119 deltas,
120 removed_components: self.removed_components,
121 }
122 }
123}
124
/// Lifecycle of a component that produces [`StateSyncMessage`]s.
#[async_trait]
pub trait StateSynchronizer: Send + Sync + 'static {
    /// Performs any setup required before `start` can be called.
    async fn initialize(&self) -> SyncResult<()>;
    /// Starts synchronizing. Returns the handle of the background task and the
    /// receiving end of the channel on which messages are delivered.
    async fn start(&self) -> SyncResult<(JoinHandle<SyncResult<()>>, Receiver<StateSyncMessage>)>;
    /// Signals the background task to stop.
    async fn close(&mut self) -> SyncResult<()>;
}
141
142impl<R, D> ProtocolStateSynchronizer<R, D>
143where
144 R: RPCClient + Clone + Send + Sync + 'static,
147 D: DeltasClient + Clone + Send + Sync + 'static,
148{
149 #[allow(clippy::too_many_arguments)]
151 pub fn new(
152 extractor_id: ExtractorIdentity,
153 retrieve_balances: bool,
154 component_filter: ComponentFilter,
155 max_retries: u64,
156 include_snapshots: bool,
157 rpc_client: R,
158 deltas_client: D,
159 timeout: u64,
160 ) -> Self {
161 Self {
162 extractor_id: extractor_id.clone(),
163 retrieve_balances,
164 rpc_client: rpc_client.clone(),
165 include_snapshots,
166 deltas_client,
167 component_tracker: Arc::new(Mutex::new(ComponentTracker::new(
168 extractor_id.chain,
169 extractor_id.name.as_str(),
170 component_filter,
171 rpc_client,
172 ))),
173 max_retries,
174 shared: Arc::new(Mutex::new(SharedState::default())),
175 end_tx: Arc::new(Mutex::new(None)),
176 timeout,
177 }
178 }
179
180 #[allow(deprecated)]
188 async fn get_snapshots<'a, I: IntoIterator<Item = &'a String>>(
189 &self,
190 header: Header,
191 tracked_components: &ComponentTracker<R>,
192 ids: Option<I>,
193 ) -> SyncResult<StateSyncMessage> {
194 if !self.include_snapshots {
195 return Ok(StateSyncMessage { header, ..Default::default() });
196 }
197 let version = VersionParam::new(
198 None,
199 Some(BlockParam {
200 chain: Some(self.extractor_id.chain),
201 hash: None,
202 number: Some(header.number as i64),
203 }),
204 );
205
206 let request_ids = ids
208 .map(|it| {
209 it.into_iter()
210 .cloned()
211 .collect::<Vec<_>>()
212 })
213 .unwrap_or_else(|| tracked_components.get_tracked_component_ids());
214
215 let component_ids = request_ids
216 .iter()
217 .collect::<HashSet<_>>();
218
219 if component_ids.is_empty() {
220 return Ok(StateSyncMessage { header, ..Default::default() });
221 }
222
223 let mut protocol_states = self
224 .rpc_client
225 .get_protocol_states_paginated(
226 self.extractor_id.chain,
227 &request_ids,
228 &self.extractor_id.name,
229 self.retrieve_balances,
230 &version,
231 100,
232 4,
233 )
234 .await?
235 .states
236 .into_iter()
237 .map(|state| (state.component_id.clone(), state))
238 .collect::<HashMap<_, _>>();
239
240 trace!(states=?&protocol_states, "Retrieved ProtocolStates");
241 let states = tracked_components
242 .components
243 .values()
244 .filter_map(|component| {
245 if let Some(state) = protocol_states.remove(&component.id) {
246 Some((
247 component.id.clone(),
248 ComponentWithState { state, component: component.clone() },
249 ))
250 } else if component_ids.contains(&&component.id) {
251 let component_id = &component.id;
253 error!(?component_id, "Missing state for native component!");
254 None
255 } else {
256 None
257 }
258 })
259 .collect();
260
261 let contract_ids = tracked_components.get_contracts_by_component(component_ids.clone());
262 let vm_storage = if !contract_ids.is_empty() {
263 let ids: Vec<Bytes> = contract_ids
264 .clone()
265 .into_iter()
266 .collect();
267 let contract_states = self
268 .rpc_client
269 .get_contract_state_paginated(
270 self.extractor_id.chain,
271 ids.as_slice(),
272 &self.extractor_id.name,
273 &version,
274 100,
275 4,
276 )
277 .await?
278 .accounts
279 .into_iter()
280 .map(|acc| (acc.address.clone(), acc))
281 .collect::<HashMap<_, _>>();
282
283 trace!(states=?&contract_states, "Retrieved ContractState");
284
285 let contract_address_to_components = tracked_components
286 .components
287 .iter()
288 .filter_map(|(id, comp)| {
289 if component_ids.contains(&id) {
290 Some(
291 comp.contract_ids
292 .iter()
293 .map(|address| (address.clone(), comp.id.clone())),
294 )
295 } else {
296 None
297 }
298 })
299 .flatten()
300 .fold(HashMap::<Bytes, Vec<String>>::new(), |mut acc, (addr, c_id)| {
301 acc.entry(addr).or_default().push(c_id);
302 acc
303 });
304
305 contract_ids
306 .iter()
307 .filter_map(|address| {
308 if let Some(state) = contract_states.get(address) {
309 Some((address.clone(), state.clone()))
310 } else if let Some(ids) = contract_address_to_components.get(address) {
311 error!(
313 ?address,
314 ?ids,
315 "Component with lacking contract storage encountered!"
316 );
317 None
318 } else {
319 None
320 }
321 })
322 .collect()
323 } else {
324 HashMap::new()
325 };
326
327 Ok(StateSyncMessage {
328 header,
329 snapshots: Snapshot { states, vm_storage },
330 deltas: None,
331 removed_components: HashMap::new(),
332 })
333 }
334
335 #[instrument(skip(self, block_tx), fields(extractor_id = %self.extractor_id))]
337 async fn state_sync(self, block_tx: &mut Sender<StateSyncMessage>) -> SyncResult<()> {
338 let mut tracker = self.component_tracker.lock().await;
340
341 let subscription_options = SubscriptionOptions::new().with_state(self.include_snapshots);
342 let (_, mut msg_rx) = self
343 .deltas_client
344 .subscribe(self.extractor_id.clone(), subscription_options)
345 .await?;
346
347 info!("Waiting for deltas...");
348 let mut first_msg = timeout(Duration::from_secs(self.timeout), msg_rx.recv())
350 .await?
351 .ok_or_else(|| anyhow::format_err!("Subscription ended too soon"))?;
352 self.filter_deltas(&mut first_msg, &tracker);
353
354 let block = first_msg.get_block().clone();
356 info!(height = &block.number, "Deltas received. Retrieving snapshot");
357 let header = Header::from_block(first_msg.get_block(), first_msg.is_revert());
358 let snapshot = self
359 .get_snapshots::<Vec<&String>>(Header::from_block(&block, false), &tracker, None)
360 .await
361 .map_err(|rpc_err| anyhow::format_err!("failed to get initial snapshot: {}", rpc_err))?
362 .merge(StateSyncMessage {
363 header: Header::from_block(first_msg.get_block(), first_msg.is_revert()),
364 snapshots: Default::default(),
365 deltas: Some(first_msg),
366 removed_components: Default::default(),
367 });
368
369 let n_components = tracker.components.len();
370 let n_snapshots = snapshot.snapshots.states.len();
371 info!(n_components, n_snapshots, "Initial snapshot retrieved, starting delta message feed");
372
373 {
374 let mut shared = self.shared.lock().await;
375 block_tx.send(snapshot).await?;
376 shared.last_synced_block = Some(header.clone());
377 }
378
379 loop {
380 if let Some(mut deltas) = msg_rx.recv().await {
381 let header = Header::from_block(deltas.get_block(), deltas.is_revert());
382 debug!(block_number=?header.number, "Received delta message");
383 let (snapshots, removed_components) = {
384 let (to_add, to_remove) = tracker.filter_updated_components(&deltas);
387
388 let requiring_snapshot: Vec<_> = to_add
390 .iter()
391 .filter(|id| {
392 !tracker
393 .components
394 .contains_key(id.as_str())
395 })
396 .collect();
397 debug!(components=?requiring_snapshot, "SnapshotRequest");
398 tracker
399 .start_tracking(requiring_snapshot.as_slice())
400 .await?;
401 let snapshots = self
402 .get_snapshots(header.clone(), &tracker, Some(requiring_snapshot))
403 .await?
404 .snapshots;
405
406 let removed_components = if !to_remove.is_empty() {
407 tracker.stop_tracking(&to_remove)
408 } else {
409 Default::default()
410 };
411 (snapshots, removed_components)
412 };
413
414 self.filter_deltas(&mut deltas, &tracker);
416 let n_changes = deltas.n_changes();
417
418 let next = StateSyncMessage {
419 header: header.clone(),
420 snapshots,
421 deltas: Some(deltas),
422 removed_components,
423 };
424 block_tx.send(next).await?;
425 {
426 let mut shared = self.shared.lock().await;
427 shared.last_synced_block = Some(header.clone());
428 }
429
430 debug!(block_number=?header.number, n_changes, "Finished processing delta message");
431 } else {
432 let mut shared = self.shared.lock().await;
433 warn!(shared = ?&shared, "Deltas channel closed, resetting shared state.");
434 shared.last_synced_block = None;
435
436 return Err(anyhow::format_err!("Deltas channel closed!"));
437 }
438 }
439 }
440
441 fn filter_deltas(&self, second_msg: &mut BlockChanges, tracker: &ComponentTracker<R>) {
442 second_msg.filter_by_component(|id| tracker.components.contains_key(id));
443 second_msg.filter_by_contract(|id| tracker.contracts.contains(id));
444 }
445}
446
447#[async_trait]
448impl<R, D> StateSynchronizer for ProtocolStateSynchronizer<R, D>
449where
450 R: RPCClient + Clone + Send + Sync + 'static,
451 D: DeltasClient + Clone + Send + Sync + 'static,
452{
453 async fn initialize(&self) -> SyncResult<()> {
454 let mut tracker = self.component_tracker.lock().await;
455 info!("Retrieving relevant protocol components");
456 tracker.initialise_components().await?;
457 info!(
458 n_components = tracker.components.len(),
459 n_contracts = tracker.contracts.len(),
460 "Finished retrieving components",
461 );
462
463 Ok(())
464 }
465 async fn start(&self) -> SyncResult<(JoinHandle<SyncResult<()>>, Receiver<StateSyncMessage>)> {
466 let (mut tx, rx) = channel(15);
467
468 let this = self.clone();
469 let jh = tokio::spawn(async move {
470 let mut retry_count = 0;
471 while retry_count < this.max_retries {
472 info!(extractor_id=%&this.extractor_id, retry_count, "(Re)starting synchronization loop");
473 let (end_tx, end_rx) = oneshot::channel::<()>();
474 {
475 let mut end_tx_guard = this.end_tx.lock().await;
476 *end_tx_guard = Some(end_tx);
477 }
478
479 select! {
480 res = this.clone().state_sync(&mut tx) => {
481 match res
482 {
483 Err(e) => {
484 error!(
485 extractor_id=%&this.extractor_id,
486 retry_count,
487 error=%e,
488 "State synchronization errored!"
489 );
490 }
491 _ => {
492 warn!(
493 extractor_id=%&this.extractor_id,
494 retry_count,
495 "State sync exited with Ok(())"
496 );
497 }
498 }
499 },
500 _ = end_rx => {
501 info!(
502 extractor_id=%&this.extractor_id,
503 retry_count,
504 "StateSynchronizer received close signal. Stopping"
505 );
506 return Ok(())
507 }
508 }
509 retry_count += 1;
510 }
511 Err(anyhow::format_err!("Max retries exceeded giving up"))
512 });
513
514 Ok((jh, rx))
515 }
516
517 async fn close(&mut self) -> SyncResult<()> {
518 let mut end_tx = self.end_tx.lock().await;
519 if let Some(tx) = end_tx.take() {
520 let _ = tx.send(());
521 Ok(())
522 } else {
523 Err(anyhow::format_err!("Not started"))
524 }
525 }
526}
527
528#[cfg(test)]
529mod test {
530 use test_log::test;
531 use tycho_common::dto::{
532 Block, Chain, PaginationResponse, ProtocolComponentRequestResponse,
533 ProtocolComponentsRequestBody, ProtocolStateRequestBody, ProtocolStateRequestResponse,
534 ProtocolSystemsRequestBody, ProtocolSystemsRequestResponse, StateRequestBody,
535 StateRequestResponse, TokensRequestBody, TokensRequestResponse,
536 };
537 use uuid::Uuid;
538
539 use super::*;
540 use crate::{deltas::MockDeltasClient, rpc::MockRPCClient, DeltasError, RPCError};
541
542 struct ArcRPCClient<T>(Arc<T>);
544
545 impl<T> Clone for ArcRPCClient<T> {
547 fn clone(&self) -> Self {
548 ArcRPCClient(self.0.clone())
549 }
550 }
551
    /// Forwards every RPC call unchanged to the wrapped client.
    #[async_trait]
    impl<T> RPCClient for ArcRPCClient<T>
    where
        T: RPCClient + Sync + Send + 'static,
    {
        async fn get_tokens(
            &self,
            request: &TokensRequestBody,
        ) -> Result<TokensRequestResponse, RPCError> {
            self.0.get_tokens(request).await
        }

        async fn get_contract_state(
            &self,
            request: &StateRequestBody,
        ) -> Result<StateRequestResponse, RPCError> {
            self.0.get_contract_state(request).await
        }

        async fn get_protocol_components(
            &self,
            request: &ProtocolComponentsRequestBody,
        ) -> Result<ProtocolComponentRequestResponse, RPCError> {
            self.0
                .get_protocol_components(request)
                .await
        }

        async fn get_protocol_states(
            &self,
            request: &ProtocolStateRequestBody,
        ) -> Result<ProtocolStateRequestResponse, RPCError> {
            self.0
                .get_protocol_states(request)
                .await
        }

        async fn get_protocol_systems(
            &self,
            request: &ProtocolSystemsRequestBody,
        ) -> Result<ProtocolSystemsRequestResponse, RPCError> {
            self.0
                .get_protocol_systems(request)
                .await
        }
    }
598
599 struct ArcDeltasClient<T>(Arc<T>);
601
602 impl<T> Clone for ArcDeltasClient<T> {
604 fn clone(&self) -> Self {
605 ArcDeltasClient(self.0.clone())
606 }
607 }
608
    /// Forwards every deltas-client call unchanged to the wrapped client.
    #[async_trait]
    impl<T> DeltasClient for ArcDeltasClient<T>
    where
        T: DeltasClient + Sync + Send + 'static,
    {
        async fn subscribe(
            &self,
            extractor_id: ExtractorIdentity,
            options: SubscriptionOptions,
        ) -> Result<(Uuid, Receiver<BlockChanges>), DeltasError> {
            self.0
                .subscribe(extractor_id, options)
                .await
        }

        async fn unsubscribe(&self, subscription_id: Uuid) -> Result<(), DeltasError> {
            self.0
                .unsubscribe(subscription_id)
                .await
        }

        async fn connect(&self) -> Result<JoinHandle<Result<(), DeltasError>>, DeltasError> {
            self.0.connect().await
        }

        async fn close(&self) -> Result<(), DeltasError> {
            self.0.close().await
        }
    }
638
639 fn with_mocked_clients(
640 native: bool,
641 rpc_client: Option<MockRPCClient>,
642 deltas_client: Option<MockDeltasClient>,
643 ) -> ProtocolStateSynchronizer<ArcRPCClient<MockRPCClient>, ArcDeltasClient<MockDeltasClient>>
644 {
645 let rpc_client = ArcRPCClient(Arc::new(rpc_client.unwrap_or_default()));
646 let deltas_client = ArcDeltasClient(Arc::new(deltas_client.unwrap_or_default()));
647
648 ProtocolStateSynchronizer::new(
649 ExtractorIdentity::new(Chain::Ethereum, "uniswap-v2"),
650 native,
651 ComponentFilter::with_tvl_range(50.0, 50.0),
652 1,
653 true,
654 rpc_client,
655 deltas_client,
656 10_u64,
657 )
658 }
659
660 fn state_snapshot_native() -> ProtocolStateRequestResponse {
661 ProtocolStateRequestResponse {
662 states: vec![ResponseProtocolState {
663 component_id: "Component1".to_string(),
664 ..Default::default()
665 }],
666 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
667 }
668 }
669
670 #[test(tokio::test)]
671 async fn test_get_snapshots_native() {
672 let header = Header::default();
673 let mut rpc = MockRPCClient::new();
674 rpc.expect_get_protocol_states()
675 .returning(|_| Ok(state_snapshot_native()));
676 let state_sync = with_mocked_clients(true, Some(rpc), None);
677 let mut tracker = ComponentTracker::new(
678 Chain::Ethereum,
679 "uniswap-v2",
680 ComponentFilter::with_tvl_range(0.0, 0.0),
681 state_sync.rpc_client.clone(),
682 );
683 let component = ProtocolComponent { id: "Component1".to_string(), ..Default::default() };
684 tracker
685 .components
686 .insert("Component1".to_string(), component.clone());
687 let components_arg = ["Component1".to_string()];
688 let exp = StateSyncMessage {
689 header: header.clone(),
690 snapshots: Snapshot {
691 states: state_snapshot_native()
692 .states
693 .into_iter()
694 .map(|state| {
695 (
696 state.component_id.clone(),
697 ComponentWithState { state, component: component.clone() },
698 )
699 })
700 .collect(),
701 vm_storage: HashMap::new(),
702 },
703 deltas: None,
704 removed_components: Default::default(),
705 };
706
707 let snap = state_sync
708 .get_snapshots(header, &tracker, Some(&components_arg))
709 .await
710 .expect("Retrieving snapshot failed");
711
712 assert_eq!(snap, exp);
713 }
714
715 fn state_snapshot_vm() -> StateRequestResponse {
716 StateRequestResponse {
717 accounts: vec![
718 ResponseAccount { address: Bytes::from("0x0badc0ffee"), ..Default::default() },
719 ResponseAccount { address: Bytes::from("0xbabe42"), ..Default::default() },
720 ],
721 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
722 }
723 }
724
725 #[test(tokio::test)]
726 async fn test_get_snapshots_vm() {
727 let header = Header::default();
728 let mut rpc = MockRPCClient::new();
729 rpc.expect_get_protocol_states()
730 .returning(|_| Ok(state_snapshot_native()));
731 rpc.expect_get_contract_state()
732 .returning(|_| Ok(state_snapshot_vm()));
733 let state_sync = with_mocked_clients(false, Some(rpc), None);
734 let mut tracker = ComponentTracker::new(
735 Chain::Ethereum,
736 "uniswap-v2",
737 ComponentFilter::with_tvl_range(0.0, 0.0),
738 state_sync.rpc_client.clone(),
739 );
740 let component = ProtocolComponent {
741 id: "Component1".to_string(),
742 contract_ids: vec![Bytes::from("0x0badc0ffee"), Bytes::from("0xbabe42")],
743 ..Default::default()
744 };
745 tracker
746 .components
747 .insert("Component1".to_string(), component.clone());
748 let components_arg = ["Component1".to_string()];
749 let exp = StateSyncMessage {
750 header: header.clone(),
751 snapshots: Snapshot {
752 states: [(
753 component.id.clone(),
754 ComponentWithState {
755 state: ResponseProtocolState {
756 component_id: "Component1".to_string(),
757 ..Default::default()
758 },
759 component: component.clone(),
760 },
761 )]
762 .into_iter()
763 .collect(),
764 vm_storage: state_snapshot_vm()
765 .accounts
766 .into_iter()
767 .map(|state| (state.address.clone(), state))
768 .collect(),
769 },
770 deltas: None,
771 removed_components: Default::default(),
772 };
773
774 let snap = state_sync
775 .get_snapshots(header, &tracker, Some(&components_arg))
776 .await
777 .expect("Retrieving snapshot failed");
778
779 assert_eq!(snap, exp);
780 }
781
    /// Builds the mock RPC and deltas clients used by the state sync tests.
    ///
    /// Returns the two mocks plus the sender through which a test can push
    /// delta messages into the mocked subscription.
    fn mock_clients_for_state_sync() -> (MockRPCClient, MockDeltasClient, Sender<BlockChanges>) {
        let mut rpc_client = MockRPCClient::new();
        // Requests that explicitly ask for `Component3` get `Component3` back.
        // NOTE(review): relies on mockall trying expectations in the order they
        // were registered, so these specific ones must precede the catch-alls below.
        rpc_client
            .expect_get_protocol_components()
            .with(mockall::predicate::function(
                move |request_params: &ProtocolComponentsRequestBody| {
                    if let Some(ids) = request_params.component_ids.as_ref() {
                        ids.contains(&"Component3".to_string())
                    } else {
                        false
                    }
                },
            ))
            .returning(|_| {
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![
                        ProtocolComponent { id: "Component3".to_string(), ..Default::default() },
                    ],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });
        rpc_client
            .expect_get_protocol_states()
            .with(mockall::predicate::function(move |request_params: &ProtocolStateRequestBody| {
                let expected_id = "Component3".to_string();
                if let Some(ids) = request_params.protocol_ids.as_ref() {
                    ids.contains(&expected_id)
                } else {
                    false
                }
            }))
            .returning(|_| {
                Ok(ProtocolStateRequestResponse {
                    states: vec![ResponseProtocolState {
                        component_id: "Component3".to_string(),
                        ..Default::default()
                    }],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });

        // Catch-all expectations: any other request yields `Component1` and `Component2`.
        rpc_client
            .expect_get_protocol_components()
            .returning(|_| {
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![
                        ProtocolComponent { id: "Component1".to_string(), ..Default::default() },
                        ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
                    ],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });
        rpc_client
            .expect_get_protocol_states()
            .returning(|_| {
                Ok(ProtocolStateRequestResponse {
                    states: vec![
                        ResponseProtocolState {
                            component_id: "Component1".to_string(),
                            ..Default::default()
                        },
                        ResponseProtocolState {
                            component_id: "Component2".to_string(),
                            ..Default::default()
                        },
                    ],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });
        let mut deltas_client = MockDeltasClient::new();
        // The subscription hands out the receiving end of this channel; since
        // the receiver is moved into the closure, `subscribe` can succeed only once.
        let (tx, rx) = channel(1);
        deltas_client
            .expect_subscribe()
            .return_once(move |_, _| {
                Ok((Uuid::default(), rx))
            });
        (rpc_client, deltas_client, tx)
    }
873
    /// End-to-end test of the synchronizer with mocked clients.
    ///
    /// Feeds two delta messages and checks that:
    /// 1. the first emitted message carries snapshots for `Component1` and
    ///    `Component2` together with the first deltas, and
    /// 2. the second message snapshots the newly tracked `Component3`, reports
    ///    `Component2` as removed and no longer contains its TVL entry.
    #[test(tokio::test)]
    async fn test_state_sync() {
        let (rpc_client, deltas_client, tx) = mock_clients_for_state_sync();
        let deltas = [
            BlockChanges {
                extractor: "uniswap-v2".to_string(),
                chain: Chain::Ethereum,
                block: Block {
                    number: 1,
                    hash: Bytes::from("0x01"),
                    parent_hash: Bytes::from("0x00"),
                    chain: Chain::Ethereum,
                    ts: Default::default(),
                },
                revert: false,
                ..Default::default()
            },
            BlockChanges {
                extractor: "uniswap-v2".to_string(),
                chain: Chain::Ethereum,
                block: Block {
                    number: 2,
                    hash: Bytes::from("0x02"),
                    parent_hash: Bytes::from("0x01"),
                    chain: Chain::Ethereum,
                    ts: Default::default(),
                },
                revert: false,
                // Relative to the 50.0 TVL threshold configured in
                // `with_mocked_clients`: Component1 stays, Component2 drops
                // out, Component3 becomes relevant.
                component_tvl: [
                    ("Component1".to_string(), 100.0),
                    ("Component2".to_string(), 0.0),
                    ("Component3".to_string(), 1000.0),
                ]
                .into_iter()
                .collect(),
                ..Default::default()
            },
        ];
        let mut state_sync = with_mocked_clients(true, Some(rpc_client), Some(deltas_client));
        state_sync
            .initialize()
            .await
            .expect("Init failed");

        // Drive the synchronizer: one delta in, one message out, twice.
        let (jh, mut rx) = state_sync
            .start()
            .await
            .expect("Failed to start state synchronizer");
        tx.send(deltas[0].clone())
            .await
            .expect("deltas channel msg 0 closed!");
        let first_msg = timeout(Duration::from_millis(100), rx.recv())
            .await
            .expect("waiting for first state msg timed out!")
            .expect("state sync block sender closed!");
        tx.send(deltas[1].clone())
            .await
            .expect("deltas channel msg 1 closed!");
        let second_msg = timeout(Duration::from_millis(100), rx.recv())
            .await
            .expect("waiting for second state msg timed out!")
            .expect("state sync block sender closed!");
        let _ = state_sync.close().await;
        let exit = jh
            .await
            .expect("state sync task panicked!");

        // Expected first message: full snapshot of the initially tracked components.
        let exp1 = StateSyncMessage {
            header: Header {
                number: 1,
                hash: Bytes::from("0x01"),
                parent_hash: Bytes::from("0x00"),
                revert: false,
            },
            snapshots: Snapshot {
                states: [
                    (
                        "Component1".to_string(),
                        ComponentWithState {
                            state: ResponseProtocolState {
                                component_id: "Component1".to_string(),
                                ..Default::default()
                            },
                            component: ProtocolComponent {
                                id: "Component1".to_string(),
                                ..Default::default()
                            },
                        },
                    ),
                    (
                        "Component2".to_string(),
                        ComponentWithState {
                            state: ResponseProtocolState {
                                component_id: "Component2".to_string(),
                                ..Default::default()
                            },
                            component: ProtocolComponent {
                                id: "Component2".to_string(),
                                ..Default::default()
                            },
                        },
                    ),
                ]
                .into_iter()
                .collect(),
                vm_storage: HashMap::new(),
            },
            deltas: Some(deltas[0].clone()),
            removed_components: Default::default(),
        };

        // Expected second message: snapshot only for the newly added component.
        let exp2 = StateSyncMessage {
            header: Header {
                number: 2,
                hash: Bytes::from("0x02"),
                parent_hash: Bytes::from("0x01"),
                revert: false,
            },
            snapshots: Snapshot {
                states: [
                    (
                        "Component3".to_string(),
                        ComponentWithState {
                            state: ResponseProtocolState {
                                component_id: "Component3".to_string(),
                                ..Default::default()
                            },
                            component: ProtocolComponent {
                                id: "Component3".to_string(),
                                ..Default::default()
                            },
                        },
                    ),
                ]
                .into_iter()
                .collect(),
                vm_storage: HashMap::new(),
            },
            // The deltas no longer contain the entry of the removed Component2.
            deltas: Some(BlockChanges {
                extractor: "uniswap-v2".to_string(),
                chain: Chain::Ethereum,
                block: Block {
                    number: 2,
                    hash: Bytes::from("0x02"),
                    parent_hash: Bytes::from("0x01"),
                    chain: Chain::Ethereum,
                    ts: Default::default(),
                },
                revert: false,
                component_tvl: [
                    ("Component1".to_string(), 100.0),
                    ("Component3".to_string(), 1000.0),
                ]
                .into_iter()
                .collect(),
                ..Default::default()
            }),
            removed_components: [(
                "Component2".to_string(),
                ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
            )]
            .into_iter()
            .collect(),
        };
        assert_eq!(first_msg, exp1);
        assert_eq!(second_msg, exp2);
        assert!(exit.is_ok());
    }
1055
    /// Verifies tracking decisions when the component filter uses distinct
    /// remove/add TVL thresholds (5.0 and 7.0).
    ///
    /// After the second delta message the synchronizer is expected to keep
    /// `Component1` (TVL 6.0), remove `Component2` (TVL 2.0) and start
    /// tracking `Component3` (TVL 10.0).
    #[test(tokio::test)]
    async fn test_state_sync_with_tvl_range() {
        let remove_tvl_threshold = 5.0;
        let add_tvl_threshold = 7.0;

        let mut rpc_client = MockRPCClient::new();
        let mut deltas_client = MockDeltasClient::new();

        // Requests that explicitly ask for `Component3` get `Component3` back.
        // NOTE(review): relies on mockall trying expectations in registration
        // order, so these must be registered before the catch-alls below.
        rpc_client
            .expect_get_protocol_components()
            .with(mockall::predicate::function(
                move |request_params: &ProtocolComponentsRequestBody| {
                    if let Some(ids) = request_params.component_ids.as_ref() {
                        ids.contains(&"Component3".to_string())
                    } else {
                        false
                    }
                },
            ))
            .returning(|_| {
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![ProtocolComponent {
                        id: "Component3".to_string(),
                        ..Default::default()
                    }],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });

        rpc_client
            .expect_get_protocol_states()
            .with(mockall::predicate::function(move |request_params: &ProtocolStateRequestBody| {
                let expected_id = "Component3".to_string();
                if let Some(ids) = request_params.protocol_ids.as_ref() {
                    ids.contains(&expected_id)
                } else {
                    false
                }
            }))
            .returning(|_| {
                Ok(ProtocolStateRequestResponse {
                    states: vec![ResponseProtocolState {
                        component_id: "Component3".to_string(),
                        ..Default::default()
                    }],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });

        // Catch-all expectations: any other request yields `Component1` and `Component2`.
        rpc_client
            .expect_get_protocol_components()
            .returning(|_| {
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![
                        ProtocolComponent { id: "Component1".to_string(), ..Default::default() },
                        ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
                    ],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });

        rpc_client
            .expect_get_protocol_states()
            .returning(|_| {
                Ok(ProtocolStateRequestResponse {
                    states: vec![
                        ResponseProtocolState {
                            component_id: "Component1".to_string(),
                            ..Default::default()
                        },
                        ResponseProtocolState {
                            component_id: "Component2".to_string(),
                            ..Default::default()
                        },
                    ],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });

        // The mocked subscription hands out the receiving end of this channel.
        let (tx, rx) = channel(1);
        deltas_client
            .expect_subscribe()
            .return_once(move |_, _| Ok((Uuid::default(), rx)));

        let mut state_sync = ProtocolStateSynchronizer::new(
            ExtractorIdentity::new(Chain::Ethereum, "uniswap-v2"),
            true,
            ComponentFilter::with_tvl_range(remove_tvl_threshold, add_tvl_threshold),
            1,
            true,
            ArcRPCClient(Arc::new(rpc_client)),
            ArcDeltasClient(Arc::new(deltas_client)),
            10_u64,
        );
        state_sync
            .initialize()
            .await
            .expect("Init failed");

        let deltas = [
            BlockChanges {
                extractor: "uniswap-v2".to_string(),
                chain: Chain::Ethereum,
                block: Block {
                    number: 1,
                    hash: Bytes::from("0x01"),
                    parent_hash: Bytes::from("0x00"),
                    chain: Chain::Ethereum,
                    ts: Default::default(),
                },
                revert: false,
                ..Default::default()
            },
            BlockChanges {
                extractor: "uniswap-v2".to_string(),
                chain: Chain::Ethereum,
                block: Block {
                    number: 2,
                    hash: Bytes::from("0x02"),
                    parent_hash: Bytes::from("0x01"),
                    chain: Chain::Ethereum,
                    ts: Default::default(),
                },
                revert: false,
                component_tvl: [
                    // Between both thresholds: expected to stay tracked.
                    ("Component1".to_string(), 6.0),
                    // Below the remove threshold: expected to be removed.
                    ("Component2".to_string(), 2.0),
                    // Above the add threshold: expected to be added.
                    ("Component3".to_string(), 10.0),
                ]
                .into_iter()
                .collect(),
                ..Default::default()
            },
        ];

        let (jh, mut rx) = state_sync
            .start()
            .await
            .expect("Failed to start state synchronizer");

        tx.send(deltas[0].clone())
            .await
            .expect("deltas channel msg 0 closed!");

        // The first message (initial snapshot) is not under test here.
        let _ = timeout(Duration::from_millis(100), rx.recv())
            .await
            .expect("waiting for first state msg timed out!")
            .expect("state sync block sender closed!");

        tx.send(deltas[1].clone())
            .await
            .expect("deltas channel msg 1 closed!");
        let second_msg = timeout(Duration::from_millis(100), rx.recv())
            .await
            .expect("waiting for second state msg timed out!")
            .expect("state sync block sender closed!");

        let _ = state_sync.close().await;
        let exit = jh
            .await
            .expect("state sync task panicked!");

        let expected_second_msg = StateSyncMessage {
            header: Header {
                number: 2,
                hash: Bytes::from("0x02"),
                parent_hash: Bytes::from("0x01"),
                revert: false,
            },
            snapshots: Snapshot {
                states: [(
                    "Component3".to_string(),
                    ComponentWithState {
                        state: ResponseProtocolState {
                            component_id: "Component3".to_string(),
                            ..Default::default()
                        },
                        component: ProtocolComponent {
                            id: "Component3".to_string(),
                            ..Default::default()
                        },
                    },
                )]
                .into_iter()
                .collect(),
                vm_storage: HashMap::new(),
            },
            deltas: Some(BlockChanges {
                extractor: "uniswap-v2".to_string(),
                chain: Chain::Ethereum,
                block: Block {
                    number: 2,
                    hash: Bytes::from("0x02"),
                    parent_hash: Bytes::from("0x01"),
                    chain: Chain::Ethereum,
                    ts: Default::default(),
                },
                revert: false,
                component_tvl: [
                    // Still tracked.
                    ("Component1".to_string(), 6.0),
                    // Newly tracked.
                    ("Component3".to_string(), 10.0),
                ]
                .into_iter()
                .collect(),
                ..Default::default()
            }),
            removed_components: [(
                "Component2".to_string(),
                ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
            )]
            .into_iter()
            .collect(),
        };

        assert_eq!(second_msg, expected_second_msg);
        assert!(exit.is_ok());
    }
1278 }
1279}