1use std::collections::{HashMap, HashSet};
2
3use tracing::{debug, instrument, warn};
4use tycho_common::{
5 dto::{BlockChanges, Chain, DCIUpdate, ProtocolComponent, ProtocolComponentsRequestBody},
6 models::{Address, ComponentId, ProtocolSystem},
7};
8
9use crate::{rpc::RPCClient, RPCError};
10
11#[derive(Clone, Debug)]
12pub(crate) enum ComponentFilterVariant {
13 Ids(Vec<ComponentId>),
14 MinimumTVLRange((f64, f64)),
20}
21
22#[derive(Clone, Debug)]
23pub struct ComponentFilter {
24 variant: ComponentFilterVariant,
25}
26
27impl ComponentFilter {
28 #[allow(non_snake_case)] #[deprecated(since = "0.9.2", note = "Please use with_tvl_range instead")]
37 pub fn MinimumTVL(min_tvl: f64) -> ComponentFilter {
38 ComponentFilter { variant: ComponentFilterVariant::MinimumTVLRange((min_tvl, min_tvl)) }
39 }
40
41 pub fn with_tvl_range(remove_tvl_threshold: f64, add_tvl_threshold: f64) -> ComponentFilter {
56 ComponentFilter {
57 variant: ComponentFilterVariant::MinimumTVLRange((
58 remove_tvl_threshold,
59 add_tvl_threshold,
60 )),
61 }
62 }
63
64 #[allow(non_snake_case)] pub fn Ids(ids: Vec<ComponentId>) -> ComponentFilter {
73 ComponentFilter {
74 variant: ComponentFilterVariant::Ids(
75 ids.into_iter()
76 .map(|id| id.to_lowercase())
77 .collect(),
78 ),
79 }
80 }
81}
82
83#[derive(Default)]
86struct EntrypointRelations {
87 components: HashSet<ComponentId>,
89 contracts: HashSet<Address>,
91}
92
93pub struct ComponentTracker<R: RPCClient> {
95 chain: Chain,
96 protocol_system: ProtocolSystem,
97 filter: ComponentFilter,
98 pub components: HashMap<ComponentId, ProtocolComponent>,
101 entrypoints: HashMap<String, EntrypointRelations>,
103 pub contracts: HashSet<Address>,
106 rpc_client: R,
108}
109
110impl<R> ComponentTracker<R>
111where
112 R: RPCClient,
113{
114 pub fn new(chain: Chain, protocol_system: &str, filter: ComponentFilter, rpc: R) -> Self {
115 Self {
116 chain,
117 protocol_system: protocol_system.to_string(),
118 filter,
119 components: Default::default(),
120 contracts: Default::default(),
121 rpc_client: rpc,
122 entrypoints: Default::default(),
123 }
124 }
125
126 pub async fn initialise_components(&mut self) -> Result<(), RPCError> {
129 let body = match &self.filter.variant {
130 ComponentFilterVariant::Ids(ids) => ProtocolComponentsRequestBody::id_filtered(
131 &self.protocol_system,
132 ids.clone(),
133 self.chain,
134 ),
135 ComponentFilterVariant::MinimumTVLRange((_, upper_tvl_threshold)) => {
136 ProtocolComponentsRequestBody::system_filtered(
137 &self.protocol_system,
138 Some(*upper_tvl_threshold),
139 self.chain,
140 )
141 }
142 };
143 self.components = self
144 .rpc_client
145 .get_protocol_components_paginated(&body, 500, 4)
146 .await?
147 .protocol_components
148 .into_iter()
149 .map(|pc| (pc.id.clone(), pc))
150 .collect::<HashMap<_, _>>();
151
152 self.reinitialize_contracts();
153
154 Ok(())
155 }
156
157 fn reinitialize_contracts(&mut self) {
159 self.contracts = self
161 .components
162 .values()
163 .flat_map(|comp| comp.contract_ids.iter().cloned())
164 .collect();
165
166 let tracked_component_ids = self
168 .components
169 .keys()
170 .cloned()
171 .collect::<HashSet<_>>();
172 for entrypoint in self.entrypoints.values() {
173 if !entrypoint
174 .components
175 .is_disjoint(&tracked_component_ids)
176 {
177 self.contracts
178 .extend(entrypoint.contracts.iter().cloned());
179 }
180 }
181 }
182
183 fn update_contracts(&mut self, components: Vec<ComponentId>) {
185 let tracked_component_ids = components
188 .into_iter()
189 .filter(|id| self.components.contains_key(id))
190 .collect::<HashSet<_>>();
191
192 for comp in &tracked_component_ids {
194 let component = self
195 .components
196 .get(comp)
197 .expect("Component should exist as it was filtered above");
198 self.contracts
199 .extend(component.contract_ids.iter().cloned());
200 }
201
202 for entrypoint in self.entrypoints.values() {
204 if !entrypoint
205 .components
206 .is_disjoint(&tracked_component_ids)
207 {
208 self.contracts
209 .extend(entrypoint.contracts.iter().cloned());
210 }
211 }
212 }
213
214 #[instrument(skip(self, new_components))]
216 pub async fn start_tracking(
217 &mut self,
218 new_components: &[&ComponentId],
219 ) -> Result<(), RPCError> {
220 if new_components.is_empty() {
221 return Ok(());
222 }
223
224 let request = ProtocolComponentsRequestBody::id_filtered(
226 &self.protocol_system,
227 new_components
228 .iter()
229 .map(|&id| id.to_string())
230 .collect(),
231 self.chain,
232 );
233 let components = self
234 .rpc_client
235 .get_protocol_components(&request)
236 .await?
237 .protocol_components
238 .into_iter()
239 .map(|pc| (pc.id.clone(), pc))
240 .collect::<HashMap<_, _>>();
241
242 let component_ids: Vec<_> = components.keys().cloned().collect();
244 let component_count = component_ids.len();
245 self.components.extend(components);
246 self.update_contracts(component_ids);
247
248 debug!(n_components = component_count, "StartedTracking");
249 Ok(())
250 }
251
252 #[instrument(skip(self, to_remove))]
254 pub fn stop_tracking<'a, I: IntoIterator<Item = &'a ComponentId> + std::fmt::Debug>(
255 &mut self,
256 to_remove: I,
257 ) -> HashMap<ComponentId, ProtocolComponent> {
258 let mut removed_components = HashMap::new();
259
260 for component_id in to_remove {
261 if let Some(component) = self.components.remove(component_id) {
262 removed_components.insert(component_id.clone(), component);
263 }
264 }
265
266 self.reinitialize_contracts();
269
270 debug!(n_components = removed_components.len(), "StoppedTracking");
271 removed_components
272 }
273
274 pub fn process_entrypoints(&mut self, dci_update: &DCIUpdate) {
276 for (entrypoint, traces) in &dci_update.trace_results {
278 self.entrypoints
279 .entry(entrypoint.clone())
280 .or_default()
281 .contracts
282 .extend(traces.accessed_slots.keys().cloned());
283 }
284
285 for (component, entrypoints) in &dci_update.new_entrypoints {
287 for entrypoint in entrypoints {
288 let entrypoint_info = self
289 .entrypoints
290 .entry(entrypoint.external_id.clone())
291 .or_default();
292 entrypoint_info
293 .components
294 .insert(component.clone());
295 if self.components.contains_key(component) {
297 self.contracts.extend(
298 entrypoint_info
299 .contracts
300 .iter()
301 .cloned(),
302 );
303 }
304 }
305 }
306 }
307
308 pub fn get_contracts_by_component<'a, I: IntoIterator<Item = &'a String>>(
319 &self,
320 ids: I,
321 ) -> HashSet<Address> {
322 ids.into_iter()
323 .filter_map(|cid| {
324 if let Some(comp) = self.components.get(cid) {
325 let dci_contracts: HashSet<Address> = self
327 .entrypoints
328 .values()
329 .filter(|ep| ep.components.contains(cid))
330 .flat_map(|ep| ep.contracts.iter().cloned())
331 .collect();
332 Some(
333 comp.contract_ids
334 .clone()
335 .into_iter()
336 .chain(dci_contracts)
337 .collect::<HashSet<_>>(),
338 )
339 } else {
340 warn!(
341 "Requested component is not tracked: {cid}. Skipping fetching contracts..."
342 );
343 None
344 }
345 })
346 .flatten()
347 .collect()
348 }
349
350 pub fn get_tracked_component_ids(&self) -> Vec<ComponentId> {
351 self.components
352 .keys()
353 .cloned()
354 .collect()
355 }
356
357 pub fn filter_updated_components(
360 &self,
361 deltas: &BlockChanges,
362 ) -> (Vec<ComponentId>, Vec<ComponentId>) {
363 match &self.filter.variant {
364 ComponentFilterVariant::Ids(_) => (Default::default(), Default::default()),
365 ComponentFilterVariant::MinimumTVLRange((remove_tvl, add_tvl)) => deltas
366 .component_tvl
367 .iter()
368 .filter(|(_, &tvl)| tvl < *remove_tvl || tvl > *add_tvl)
369 .map(|(id, _)| id.clone())
370 .partition(|id| deltas.component_tvl[id] > *add_tvl),
371 }
372 }
373}
374
#[cfg(test)]
mod test {
    use tycho_common::{
        dto::{PaginationResponse, ProtocolComponentRequestResponse},
        Bytes,
    };

    use super::*;
    use crate::rpc::MockRPCClient;

    /// Id of the single fixture component used by all tests.
    const COMPONENT_ID: &str = "Component1";

    /// Builds a tracker backed by a mocked RPC client without any expectations set.
    fn with_mocked_rpc() -> ComponentTracker<MockRPCClient> {
        ComponentTracker::new(
            Chain::Ethereum,
            "uniswap-v2",
            ComponentFilter::with_tvl_range(0.0, 0.0),
            MockRPCClient::new(),
        )
    }

    /// Returns the fixture component together with the contract ids it references.
    fn components_response() -> (Vec<Address>, ProtocolComponent) {
        let contract_ids = vec![Bytes::from("0x1234"), Bytes::from("0xbabe")];
        let component = ProtocolComponent {
            id: COMPONENT_ID.to_string(),
            contract_ids: contract_ids.clone(),
            ..Default::default()
        };
        (contract_ids, component)
    }

    /// Wraps `component` into a single-entry RPC response.
    fn single_component_page(component: &ProtocolComponent) -> ProtocolComponentRequestResponse {
        ProtocolComponentRequestResponse {
            protocol_components: vec![component.clone()],
            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
        }
    }

    #[tokio::test]
    async fn test_initialise_components() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let expected_component = component.clone();
        let expected_contracts: HashSet<Address> = contract_ids.into_iter().collect();
        tracker
            .rpc_client
            .expect_get_protocol_components_paginated()
            .returning(move |_, _, _| Ok(single_component_page(&component)));

        tracker
            .initialise_components()
            .await
            .expect("Retrieving components failed");

        let tracked_component = tracker
            .components
            .get(COMPONENT_ID)
            .expect("Component1 not tracked");
        assert_eq!(tracked_component, &expected_component);
        assert_eq!(tracker.contracts, expected_contracts);
    }

    #[tokio::test]
    async fn test_start_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let expected_contracts: HashSet<Address> = contract_ids.into_iter().collect();
        let requested_id = component.id.clone();
        let requested = [&requested_id];
        tracker
            .rpc_client
            .expect_get_protocol_components()
            .returning(move |_| Ok(single_component_page(&component)));

        tracker
            .start_tracking(&requested)
            .await
            .expect("Tracking components failed");

        assert_eq!(&tracker.contracts, &expected_contracts);
        assert!(tracker
            .components
            .contains_key(COMPONENT_ID));
    }

    #[test]
    fn test_stop_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        tracker
            .components
            .insert(COMPONENT_ID.to_string(), component.clone());
        tracker.contracts.extend(contract_ids);
        // The second id is not tracked and must simply be ignored.
        let to_remove = [COMPONENT_ID.to_string(), "Component2".to_string()];
        let expected: HashMap<ComponentId, ProtocolComponent> =
            HashMap::from([(COMPONENT_ID.to_string(), component)]);

        let removed = tracker.stop_tracking(&to_remove);

        assert_eq!(removed, expected);
        assert!(tracker.contracts.is_empty());
    }

    #[test]
    fn test_get_contracts_by_component() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let expected: HashSet<Address> = contract_ids.into_iter().collect();
        tracker
            .components
            .insert(COMPONENT_ID.to_string(), component);
        let requested = [COMPONENT_ID.to_string()];

        let contracts = tracker.get_contracts_by_component(&requested);

        assert_eq!(contracts, expected);
    }

    #[test]
    fn test_get_tracked_component_ids() {
        let mut tracker = with_mocked_rpc();
        let (_, component) = components_response();
        tracker
            .components
            .insert(COMPONENT_ID.to_string(), component);
        let expected = vec![COMPONENT_ID.to_string()];

        let tracked_ids = tracker.get_tracked_component_ids();

        assert_eq!(tracked_ids, expected);
    }
}