1use std::collections::{HashMap, HashSet};
2
3use tracing::{debug, instrument, warn};
4use tycho_common::{
5 dto::{BlockChanges, Chain, DCIUpdate, ProtocolComponent, ProtocolComponentsRequestBody},
6 models::{Address, ComponentId, ProtocolSystem},
7};
8
9use crate::{rpc::RPCClient, RPCError};
10
/// Internal representation of the filters a [`ComponentFilter`] can carry.
#[derive(Clone, Debug)]
pub(crate) enum ComponentFilterVariant {
    /// An explicit list of component ids.
    Ids(Vec<ComponentId>),
    /// A `(remove_tvl_threshold, add_tvl_threshold)` pair, stored in that order
    /// (see `ComponentFilter::with_tvl_range`).
    MinimumTVLRange((f64, f64)),
}
21
/// Filter describing which protocol components a `ComponentTracker` works with.
///
/// The variant is private; build a filter through `ComponentFilter::with_tvl_range`
/// or `ComponentFilter::Ids`.
#[derive(Clone, Debug)]
pub struct ComponentFilter {
    // The concrete filter; read by `ComponentTracker` when building requests and
    // when classifying TVL updates.
    variant: ComponentFilterVariant,
}
26
27impl ComponentFilter {
28 #[allow(non_snake_case)] #[deprecated(since = "0.9.2", note = "Please use with_tvl_range instead")]
37 pub fn MinimumTVL(min_tvl: f64) -> ComponentFilter {
38 ComponentFilter { variant: ComponentFilterVariant::MinimumTVLRange((min_tvl, min_tvl)) }
39 }
40
41 pub fn with_tvl_range(remove_tvl_threshold: f64, add_tvl_threshold: f64) -> ComponentFilter {
56 ComponentFilter {
57 variant: ComponentFilterVariant::MinimumTVLRange((
58 remove_tvl_threshold,
59 add_tvl_threshold,
60 )),
61 }
62 }
63
64 #[allow(non_snake_case)] pub fn Ids(ids: Vec<ComponentId>) -> ComponentFilter {
73 ComponentFilter {
74 variant: ComponentFilterVariant::Ids(
75 ids.into_iter()
76 .map(|id| id.to_lowercase())
77 .collect(),
78 ),
79 }
80 }
81}
82
/// Bookkeeping for a single entrypoint: which components reference it and which
/// contracts its traces touched.
#[derive(Default)]
pub struct EntrypointRelations {
    /// Ids of the components linked to this entrypoint
    /// (filled from `DCIUpdate::new_entrypoints`).
    components: HashSet<ComponentId>,
    /// Contract addresses taken from the keys of the entrypoint's traced
    /// `accessed_slots` (filled from `DCIUpdate::trace_results`).
    contracts: HashSet<Address>,
}
92
/// Keeps track of a set of protocol components and of the contracts related to them.
pub struct ComponentTracker<R: RPCClient> {
    // Chain passed along with every component request.
    chain: Chain,
    // Protocol system passed along with every component request.
    protocol_system: ProtocolSystem,
    // Filter deciding the initial request and how TVL updates are classified.
    filter: ComponentFilter,
    /// Currently tracked components, keyed by component id.
    pub components: HashMap<ComponentId, ProtocolComponent>,
    // Entrypoint records, keyed by the entrypoint id used in `DCIUpdate`.
    entrypoints: HashMap<String, EntrypointRelations>,
    /// Contract addresses of the tracked components: their own `contract_ids` plus
    /// the contracts of every entrypoint linked to a tracked component.
    pub contracts: HashSet<Address>,
    // Client used to fetch protocol components.
    rpc_client: R,
}
109
110impl<R> ComponentTracker<R>
111where
112 R: RPCClient,
113{
114 pub fn new(chain: Chain, protocol_system: &str, filter: ComponentFilter, rpc: R) -> Self {
115 Self {
116 chain,
117 protocol_system: protocol_system.to_string(),
118 filter,
119 components: Default::default(),
120 contracts: Default::default(),
121 rpc_client: rpc,
122 entrypoints: Default::default(),
123 }
124 }
125
126 pub async fn initialise_components(&mut self) -> Result<(), RPCError> {
129 let body = match &self.filter.variant {
130 ComponentFilterVariant::Ids(ids) => ProtocolComponentsRequestBody::id_filtered(
131 &self.protocol_system,
132 ids.clone(),
133 self.chain,
134 ),
135 ComponentFilterVariant::MinimumTVLRange((_, upper_tvl_threshold)) => {
136 ProtocolComponentsRequestBody::system_filtered(
137 &self.protocol_system,
138 Some(*upper_tvl_threshold),
139 self.chain,
140 )
141 }
142 };
143 self.components = self
144 .rpc_client
145 .get_protocol_components_paginated(&body, 500, 4)
146 .await?
147 .protocol_components
148 .into_iter()
149 .map(|pc| (pc.id.clone(), pc))
150 .collect::<HashMap<_, _>>();
151
152 self.reinitialize_contracts();
153
154 Ok(())
155 }
156
157 fn reinitialize_contracts(&mut self) {
159 self.contracts = self
161 .components
162 .values()
163 .flat_map(|comp| comp.contract_ids.iter().cloned())
164 .collect();
165
166 let tracked_component_ids = self
168 .components
169 .keys()
170 .cloned()
171 .collect::<HashSet<_>>();
172 for entrypoint in self.entrypoints.values() {
173 if !entrypoint
174 .components
175 .is_disjoint(&tracked_component_ids)
176 {
177 self.contracts
178 .extend(entrypoint.contracts.iter().cloned());
179 }
180 }
181 }
182
183 fn update_contracts(&mut self, components: Vec<ComponentId>) {
185 let mut tracked_component_ids = HashSet::new();
187
188 for comp in components {
190 if let Some(component) = self.components.get(&comp) {
191 self.contracts
192 .extend(component.contract_ids.iter().cloned());
193 tracked_component_ids.insert(comp);
194 }
195 }
196
197 for entrypoint in self.entrypoints.values() {
199 if !entrypoint
200 .components
201 .is_disjoint(&tracked_component_ids)
202 {
203 self.contracts
204 .extend(entrypoint.contracts.iter().cloned());
205 }
206 }
207 }
208
209 #[instrument(skip(self, new_components))]
211 pub async fn start_tracking(
212 &mut self,
213 new_components: &[&ComponentId],
214 ) -> Result<(), RPCError> {
215 if new_components.is_empty() {
216 return Ok(());
217 }
218
219 let request = ProtocolComponentsRequestBody::id_filtered(
221 &self.protocol_system,
222 new_components
223 .iter()
224 .map(|&id| id.to_string())
225 .collect(),
226 self.chain,
227 );
228 let components = self
229 .rpc_client
230 .get_protocol_components(&request)
231 .await?
232 .protocol_components
233 .into_iter()
234 .map(|pc| (pc.id.clone(), pc))
235 .collect::<HashMap<_, _>>();
236
237 let component_ids: Vec<_> = components.keys().cloned().collect();
239 let component_count = component_ids.len();
240 self.components.extend(components);
241 self.update_contracts(component_ids);
242
243 debug!(n_components = component_count, "StartedTracking");
244 Ok(())
245 }
246
247 #[instrument(skip(self, to_remove))]
249 pub fn stop_tracking<'a, I: IntoIterator<Item = &'a ComponentId> + std::fmt::Debug>(
250 &mut self,
251 to_remove: I,
252 ) -> HashMap<ComponentId, ProtocolComponent> {
253 let mut removed_components = HashMap::new();
254
255 for component_id in to_remove {
256 if let Some(component) = self.components.remove(component_id) {
257 removed_components.insert(component_id.clone(), component);
258 }
259 }
260
261 self.reinitialize_contracts();
264
265 debug!(n_components = removed_components.len(), "StoppedTracking");
266 removed_components
267 }
268
269 pub fn process_entrypoints(&mut self, dci_update: &DCIUpdate) {
271 for (entrypoint, traces) in &dci_update.trace_results {
273 self.entrypoints
274 .entry(entrypoint.clone())
275 .or_default()
276 .contracts
277 .extend(traces.accessed_slots.keys().cloned());
278 }
279
280 for (component, entrypoints) in &dci_update.new_entrypoints {
282 for entrypoint in entrypoints {
283 let entrypoint_info = self
284 .entrypoints
285 .entry(entrypoint.external_id.clone())
286 .or_default();
287 entrypoint_info
288 .components
289 .insert(component.clone());
290 if self.components.contains_key(component) {
292 self.contracts.extend(
293 entrypoint_info
294 .contracts
295 .iter()
296 .cloned(),
297 );
298 }
299 }
300 }
301 }
302
303 pub fn get_contracts_by_component<'a, I: IntoIterator<Item = &'a String>>(
314 &self,
315 ids: I,
316 ) -> HashSet<Address> {
317 ids.into_iter()
318 .filter_map(|cid| {
319 if let Some(comp) = self.components.get(cid) {
320 let dci_contracts: HashSet<Address> = self
322 .entrypoints
323 .values()
324 .filter(|ep| ep.components.contains(cid))
325 .flat_map(|ep| ep.contracts.iter().cloned())
326 .collect();
327 Some(
328 comp.contract_ids
329 .clone()
330 .into_iter()
331 .chain(dci_contracts)
332 .collect::<HashSet<_>>(),
333 )
334 } else {
335 warn!(
336 "Requested component is not tracked: {cid}. Skipping fetching contracts..."
337 );
338 None
339 }
340 })
341 .flatten()
342 .collect()
343 }
344
345 pub fn get_tracked_component_ids(&self) -> Vec<ComponentId> {
346 self.components
347 .keys()
348 .cloned()
349 .collect()
350 }
351
352 pub fn filter_updated_components(
355 &self,
356 deltas: &BlockChanges,
357 ) -> (Vec<ComponentId>, Vec<ComponentId>) {
358 match &self.filter.variant {
359 ComponentFilterVariant::Ids(_) => (Default::default(), Default::default()),
360 ComponentFilterVariant::MinimumTVLRange((remove_tvl, add_tvl)) => deltas
361 .component_tvl
362 .iter()
363 .filter(|(_, &tvl)| tvl < *remove_tvl || tvl > *add_tvl)
364 .map(|(id, _)| id.clone())
365 .partition(|id| deltas.component_tvl[id] > *add_tvl),
366 }
367 }
368}
369
370#[cfg(test)]
371mod test {
372 use tycho_common::{
373 dto::{PaginationResponse, ProtocolComponentRequestResponse},
374 Bytes,
375 };
376
377 use super::*;
378 use crate::rpc::MockRPCClient;
379
380 fn with_mocked_rpc() -> ComponentTracker<MockRPCClient> {
381 let rpc = MockRPCClient::new();
382 ComponentTracker::new(
383 Chain::Ethereum,
384 "uniswap-v2",
385 ComponentFilter::with_tvl_range(0.0, 0.0),
386 rpc,
387 )
388 }
389
390 fn components_response() -> (Vec<Address>, ProtocolComponent) {
391 let contract_ids = vec![Bytes::from("0x1234"), Bytes::from("0xbabe")];
392 let component = ProtocolComponent {
393 id: "Component1".to_string(),
394 contract_ids: contract_ids.clone(),
395 ..Default::default()
396 };
397 (contract_ids, component)
398 }
399
400 #[tokio::test]
401 async fn test_initialise_components() {
402 let mut tracker = with_mocked_rpc();
403 let (contract_ids, component) = components_response();
404 let exp_component = component.clone();
405 tracker
406 .rpc_client
407 .expect_get_protocol_components_paginated()
408 .returning(move |_, _, _| {
409 Ok(ProtocolComponentRequestResponse {
410 protocol_components: vec![component.clone()],
411 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
412 })
413 });
414
415 tracker
416 .initialise_components()
417 .await
418 .expect("Retrieving components failed");
419
420 assert_eq!(
421 tracker
422 .components
423 .get("Component1")
424 .expect("Component1 not tracked"),
425 &exp_component
426 );
427 assert_eq!(tracker.contracts, contract_ids.into_iter().collect());
428 }
429
430 #[tokio::test]
431 async fn test_start_tracking() {
432 let mut tracker = with_mocked_rpc();
433 let (contract_ids, component) = components_response();
434 let exp_contracts = contract_ids.into_iter().collect();
435 let component_id = component.id.clone();
436 let components_arg = [&component_id];
437 tracker
438 .rpc_client
439 .expect_get_protocol_components()
440 .returning(move |_| {
441 Ok(ProtocolComponentRequestResponse {
442 protocol_components: vec![component.clone()],
443 pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
444 })
445 });
446
447 tracker
448 .start_tracking(&components_arg)
449 .await
450 .expect("Tracking components failed");
451
452 assert_eq!(&tracker.contracts, &exp_contracts);
453 assert!(tracker
454 .components
455 .contains_key("Component1"));
456 }
457
458 #[test]
459 fn test_stop_tracking() {
460 let mut tracker = with_mocked_rpc();
461 let (contract_ids, component) = components_response();
462 tracker
463 .components
464 .insert("Component1".to_string(), component.clone());
465 tracker.contracts.extend(contract_ids);
466 let components_arg = ["Component1".to_string(), "Component2".to_string()];
467 let exp = [("Component1".to_string(), component)]
468 .into_iter()
469 .collect();
470
471 let res = tracker.stop_tracking(&components_arg);
472
473 assert_eq!(res, exp);
474 assert!(tracker.contracts.is_empty());
475 }
476
477 #[test]
478 fn test_get_contracts_by_component() {
479 let mut tracker = with_mocked_rpc();
480 let (exp_contracts, component) = components_response();
481 tracker
482 .components
483 .insert("Component1".to_string(), component);
484 let components_arg = ["Component1".to_string()];
485
486 let res = tracker.get_contracts_by_component(&components_arg);
487
488 assert_eq!(res, exp_contracts.into_iter().collect());
489 }
490
491 #[test]
492 fn test_get_tracked_component_ids() {
493 let mut tracker = with_mocked_rpc();
494 let (_, component) = components_response();
495 tracker
496 .components
497 .insert("Component1".to_string(), component);
498 let exp = vec!["Component1".to_string()];
499
500 let res = tracker.get_tracked_component_ids();
501
502 assert_eq!(res, exp);
503 }
504}