1use std::collections::{HashMap, HashSet};
2
3use tracing::{debug, instrument, warn};
4use tycho_common::{
5 dto::{BlockChanges, Chain, ProtocolComponent, ProtocolComponentsRequestBody},
6 Bytes,
7};
8
9use crate::{rpc::RPCClient, RPCError};
10
/// Internal representation of a [`ComponentFilter`].
#[derive(Clone, Debug)]
pub(crate) enum ComponentFilterVariant {
    /// Track exactly the components with these ids.
    Ids(Vec<String>),
    /// Track components by TVL, stored as `(remove_threshold, add_threshold)`.
    ///
    /// A component is reported for addition when its TVL exceeds the second value and for
    /// removal when it drops below the first; values in between are ignored.
    MinimumTVLRange((f64, f64)),
}
21
/// Selects which protocol components a [`ComponentTracker`] follows.
///
/// Build one with [`ComponentFilter::with_tvl_range`] (TVL based) or [`ComponentFilter::Ids`]
/// (explicit id list).
#[derive(Clone, Debug)]
pub struct ComponentFilter {
    variant: ComponentFilterVariant,
}
26
27impl ComponentFilter {
28 #[allow(non_snake_case)] #[deprecated(since = "0.9.2", note = "Please use with_tvl_range instead")]
37 pub fn MinimumTVL(min_tvl: f64) -> ComponentFilter {
38 ComponentFilter { variant: ComponentFilterVariant::MinimumTVLRange((min_tvl, min_tvl)) }
39 }
40
41 pub fn with_tvl_range(remove_tvl_threshold: f64, add_tvl_threshold: f64) -> ComponentFilter {
56 ComponentFilter {
57 variant: ComponentFilterVariant::MinimumTVLRange((
58 remove_tvl_threshold,
59 add_tvl_threshold,
60 )),
61 }
62 }
63
64 #[allow(non_snake_case)] pub fn Ids(ids: Vec<String>) -> ComponentFilter {
73 ComponentFilter { variant: ComponentFilterVariant::Ids(ids) }
74 }
75}
76
/// Tracks the protocol components of one protocol system on one chain that match a
/// [`ComponentFilter`], together with the contract ids those components reference.
pub struct ComponentTracker<R: RPCClient> {
    /// Chain sent with every component request.
    chain: Chain,
    /// Protocol system name sent with every component request.
    protocol_system: String,
    /// Decides which components are selected on initialisation and on TVL updates.
    filter: ComponentFilter,
    /// Currently tracked components, keyed by component id.
    pub components: HashMap<String, ProtocolComponent>,
    /// Contract ids referenced by the tracked components (their `contract_ids`).
    pub contracts: HashSet<Bytes>,
    /// Client used to fetch component data.
    rpc_client: R,
}
91
92impl<R> ComponentTracker<R>
93where
94 R: RPCClient,
95{
96 pub fn new(chain: Chain, protocol_system: &str, filter: ComponentFilter, rpc: R) -> Self {
97 Self {
98 chain,
99 protocol_system: protocol_system.to_string(),
100 filter,
101 components: Default::default(),
102 contracts: Default::default(),
103 rpc_client: rpc,
104 }
105 }
106 pub async fn initialise_components(&mut self) -> Result<(), RPCError> {
108 let body = match &self.filter.variant {
109 ComponentFilterVariant::Ids(ids) => ProtocolComponentsRequestBody::id_filtered(
110 &self.protocol_system,
111 ids.clone(),
112 self.chain,
113 ),
114 ComponentFilterVariant::MinimumTVLRange((_, upper_tvl_threshold)) => {
115 ProtocolComponentsRequestBody::system_filtered(
116 &self.protocol_system,
117 Some(*upper_tvl_threshold),
118 self.chain,
119 )
120 }
121 };
122
123 self.components = self
124 .rpc_client
125 .get_protocol_components_paginated(&body, 500, 4)
126 .await?
127 .protocol_components
128 .into_iter()
129 .map(|pc| (pc.id.clone(), pc))
130 .collect::<HashMap<_, _>>();
131 self.update_contracts();
132 Ok(())
133 }
134
135 fn update_contracts(&mut self) {
136 self.contracts.extend(
137 self.components
138 .values()
139 .flat_map(|comp| comp.contract_ids.iter().cloned()),
140 );
141 }
142
143 #[instrument(skip(self, new_components))]
145 pub async fn start_tracking(&mut self, new_components: &[&String]) -> Result<(), RPCError> {
146 if new_components.is_empty() {
147 return Ok(());
148 }
149 let request = ProtocolComponentsRequestBody::id_filtered(
150 &self.protocol_system,
151 new_components
152 .iter()
153 .map(|pc_id| pc_id.to_string())
154 .collect(),
155 self.chain,
156 );
157
158 self.components.extend(
159 self.rpc_client
160 .get_protocol_components(&request)
161 .await?
162 .protocol_components
163 .into_iter()
164 .map(|pc| (pc.id.clone(), pc)),
165 );
166 self.update_contracts();
167 debug!(n_components = new_components.len(), "StartedTracking");
168 Ok(())
169 }
170
171 #[instrument(skip(self, to_remove))]
173 pub fn stop_tracking<'a, I: IntoIterator<Item = &'a String> + std::fmt::Debug>(
174 &mut self,
175 to_remove: I,
176 ) -> HashMap<String, ProtocolComponent> {
177 let mut n_components = 0;
178 let res = to_remove
179 .into_iter()
180 .filter_map(|k| {
181 let comp = self.components.remove(k);
182 if let Some(component) = &comp {
183 n_components += 1;
184 for contract in component.contract_ids.iter() {
185 self.contracts.remove(contract);
186 }
187 }
188 comp.map(|c| (k.clone(), c))
189 })
190 .collect();
191 debug!(n_components, "StoppedTracking");
192 res
193 }
194
195 pub fn get_contracts_by_component<'a, I: IntoIterator<Item = &'a String>>(
196 &self,
197 ids: I,
198 ) -> HashSet<Bytes> {
199 ids.into_iter()
200 .flat_map(|cid| {
201 let comp = self
202 .components
203 .get(cid)
204 .unwrap_or_else(|| panic!("requested component that is not present: {cid}"));
205 comp.contract_ids.iter().cloned()
206 })
207 .collect()
208 }
209
210 pub fn get_tracked_component_ids(&self) -> Vec<String> {
211 self.components
212 .keys()
213 .cloned()
214 .collect()
215 }
216
217 pub fn filter_updated_components(&self, deltas: &BlockChanges) -> (Vec<String>, Vec<String>) {
220 match &self.filter.variant {
221 ComponentFilterVariant::Ids(_) => (Default::default(), Default::default()),
222 ComponentFilterVariant::MinimumTVLRange((remove_tvl, add_tvl)) => deltas
223 .component_tvl
224 .iter()
225 .filter(|(_, &tvl)| tvl < *remove_tvl || tvl > *add_tvl)
226 .map(|(id, _)| id.clone())
227 .partition(|id| deltas.component_tvl[id] > *add_tvl),
228 }
229 }
230}
231
#[cfg(test)]
mod test {
    use tycho_common::dto::{PaginationResponse, ProtocolComponentRequestResponse};

    use super::*;
    use crate::rpc::MockRPCClient;

    /// Builds a `uniswap-v2` tracker on Ethereum backed by a fresh `MockRPCClient`.
    ///
    /// Each test registers its own expectations on `tracker.rpc_client`.
    fn with_mocked_rpc() -> ComponentTracker<MockRPCClient> {
        let rpc = MockRPCClient::new();
        ComponentTracker::new(
            Chain::Ethereum,
            "uniswap-v2",
            ComponentFilter::with_tvl_range(0.0, 0.0),
            rpc,
        )
    }

    /// Returns a component `Component1` together with the two contract ids it references.
    fn components_response() -> (Vec<Bytes>, ProtocolComponent) {
        let contract_ids = vec![Bytes::from("0x1234"), Bytes::from("0xbabe")];
        let component = ProtocolComponent {
            id: "Component1".to_string(),
            contract_ids: contract_ids.clone(),
            ..Default::default()
        };
        (contract_ids, component)
    }

    // `initialise_components` goes through the paginated endpoint; the returned component
    // must be tracked and its contracts recorded.
    #[tokio::test]
    async fn test_initialise_components() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let exp_component = component.clone();
        tracker
            .rpc_client
            .expect_get_protocol_components_paginated()
            .returning(move |_, _, _| {
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![component.clone()],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });

        tracker
            .initialise_components()
            .await
            .expect("Retrieving components failed");

        assert_eq!(
            tracker
                .components
                .get("Component1")
                .expect("Component1 not tracked"),
            &exp_component
        );
        assert_eq!(tracker.contracts, contract_ids.into_iter().collect());
    }

    // `start_tracking` uses the non-paginated endpoint; the fetched component and its
    // contracts must end up in the tracker.
    #[tokio::test]
    async fn test_start_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let exp_contracts = contract_ids.into_iter().collect();
        let component_id = component.id.clone();
        let components_arg = [&component_id];
        tracker
            .rpc_client
            .expect_get_protocol_components()
            .returning(move |_| {
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![component.clone()],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });

        tracker
            .start_tracking(&components_arg)
            .await
            .expect("Tracking components failed");

        assert_eq!(&tracker.contracts, &exp_contracts);
        assert!(tracker
            .components
            .contains_key("Component1"));
    }

    // "Component2" was never tracked and must be ignored. "Component1" is the only tracked
    // component, so afterwards none of its contracts should remain tracked.
    #[test]
    fn test_stop_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component.clone());
        tracker.contracts.extend(contract_ids);
        let components_arg = ["Component1".to_string(), "Component2".to_string()];
        let exp = [("Component1".to_string(), component)]
            .into_iter()
            .collect();

        let res = tracker.stop_tracking(&components_arg);

        assert_eq!(res, exp);
        assert!(tracker.contracts.is_empty());
    }

    // Contracts are looked up from the tracked component's `contract_ids`.
    #[test]
    fn test_get_contracts_by_component() {
        let mut tracker = with_mocked_rpc();
        let (exp_contracts, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component);
        let components_arg = ["Component1".to_string()];

        let res = tracker.get_contracts_by_component(&components_arg);

        assert_eq!(res, exp_contracts.into_iter().collect());
    }

    // With a single tracked component the id list has exactly one entry, so ordering does
    // not matter for the comparison.
    #[test]
    fn test_get_tracked_component_ids() {
        let mut tracker = with_mocked_rpc();
        let (_, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component);
        let exp = vec!["Component1".to_string()];

        let res = tracker.get_tracked_component_ids();

        assert_eq!(res, exp);
    }
}