1use std::collections::HashMap;
2
3use num_bigint::BigUint;
4use tycho_client::feed::{synchronizer::ComponentWithState, BlockHeader};
5use tycho_common::{models::token::Token, Bytes};
6
7use crate::{
8 evm::protocol::lido::state::{LidoPoolType, LidoState, StakeLimitState, StakingStatus},
9 protocol::{
10 errors::InvalidSnapshotError,
11 models::{DecoderContext, TryFromWithBlock},
12 },
13};
14
/// Hex string of the all-zero 20-byte address; the decoder below stores it as the
/// `native_address` of every `LidoState` it builds.
///
/// NOTE(review): presumably this is the placeholder address for native ETH — confirm against
/// the code that consumes `native_address`.
pub const ETH_ADDRESS: &str = "0x0000000000000000000000000000000000000000";
16
17impl TryFromWithBlock<ComponentWithState, BlockHeader> for LidoState {
18 type Error = InvalidSnapshotError;
19
20 async fn try_from_with_header(
22 snapshot: ComponentWithState,
23 _block: BlockHeader,
24 _account_balances: &HashMap<Bytes, HashMap<Bytes, Bytes>>,
25 _all_tokens: &HashMap<Bytes, Token>,
26 _decoder_context: &DecoderContext,
27 ) -> Result<Self, Self::Error> {
28 let id = snapshot.component.id.as_str();
29
30 let pool_type = match snapshot
31 .component
32 .static_attributes
33 .get("protocol_type_name")
34 .and_then(|bytes| std::str::from_utf8(bytes).ok())
35 .ok_or(InvalidSnapshotError::MissingAttribute(
36 "protocol_type_name is missing".to_owned(),
37 ))? {
38 "stETH" => LidoPoolType::StEth,
39 "wstETH" => LidoPoolType::WStEth,
40 _ => {
41 return Err(InvalidSnapshotError::ValueError(format!(
42 "Unknown protocol type name: {:?}",
43 snapshot.component.protocol_type_name
44 )))
45 }
46 };
47
48 let token_to_track_total_pooled_eth = snapshot
49 .component
50 .static_attributes
51 .get("token_to_track_total_pooled_eth")
52 .ok_or(InvalidSnapshotError::MissingAttribute(
53 "token_to_track_total_pooled_eth is missing".to_owned(),
54 ))?
55 .clone();
56
57 let tokens: [Bytes; 2] =
58 [snapshot.component.tokens[0].clone(), snapshot.component.tokens[1].clone()];
59
60 let total_shares = snapshot
61 .state
62 .attributes
63 .get("total_shares")
64 .ok_or(InvalidSnapshotError::MissingAttribute(
65 "Total shares field is missing".to_owned(),
66 ))?;
67
68 let total_pooled_eth = snapshot
69 .state
70 .balances
71 .get(&token_to_track_total_pooled_eth)
72 .ok_or(InvalidSnapshotError::MissingAttribute(
73 "Total pooled eth field is missing".to_owned(),
74 ))?;
75
76 let (staking_status_parsed, staking_limit) = if pool_type == LidoPoolType::StEth {
77 let staking_status = snapshot
78 .state
79 .attributes
80 .get("staking_status")
81 .ok_or(InvalidSnapshotError::MissingAttribute(
82 "Staking_status field is missing".to_owned(),
83 ))?;
84
85 let staking_status_parsed =
86 if let Ok(status_as_str) = std::str::from_utf8(staking_status) {
87 match status_as_str {
88 "Limited" => StakingStatus::Limited,
89 "Paused" => StakingStatus::Paused,
90 "Unlimited" => StakingStatus::Unlimited,
91 _ => {
92 return Err(InvalidSnapshotError::ValueError(
93 "status_as_str parsed to invalid status".to_owned(),
94 ))
95 }
96 }
97 } else {
98 return Err(InvalidSnapshotError::ValueError(
99 "status_as_str cannot be parsed".to_owned(),
100 ));
101 };
102
103 let staking_limit = snapshot
104 .state
105 .attributes
106 .get("staking_limit")
107 .ok_or(InvalidSnapshotError::MissingAttribute(
108 "Staking_limit field is missing".to_owned(),
109 ))?;
110 (staking_status_parsed, staking_limit)
111 } else {
112 (StakingStatus::Limited, &Bytes::from(vec![0; 32]))
113 };
114
115 let total_wrapped_st_eth = if pool_type == LidoPoolType::StEth {
116 None
117 } else {
118 Some(BigUint::from_bytes_be(
119 snapshot
120 .state
121 .attributes
122 .get("total_wstETH")
123 .ok_or(InvalidSnapshotError::MissingAttribute(
124 "Total pooled eth field is missing".to_owned(),
125 ))?,
126 ))
127 };
128
129 Ok(Self {
130 pool_type,
131 total_shares: BigUint::from_bytes_be(total_shares),
132 total_pooled_eth: BigUint::from_bytes_be(total_pooled_eth),
133 total_wrapped_st_eth,
134 id: id.into(),
135 native_address: ETH_ADDRESS.into(),
136 stake_limits_state: StakeLimitState {
137 staking_status: staking_status_parsed,
138 staking_limit: BigUint::from_bytes_be(staking_limit),
139 },
140 tokens,
141 token_to_track_total_pooled_eth,
142 })
143 }
144}
145
/// Unit tests for decoding stETH and wstETH snapshots into `LidoState`.
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, str::FromStr};

    use chrono::NaiveDateTime;
    use num_bigint::BigUint;
    use num_traits::Zero;
    use rstest::rstest;
    use tycho_client::feed::synchronizer::ComponentWithState;
    use tycho_common::{
        dto::{Chain, ChangeType, ProtocolComponent, ResponseProtocolState},
        Bytes,
    };

    use crate::{
        evm::protocol::{
            lido::{
                decoder::ETH_ADDRESS,
                state::{LidoPoolType, LidoState, StakeLimitState, StakingStatus},
            },
            test_utils::try_decode_snapshot_with_defaults,
        },
        protocol::errors::InvalidSnapshotError,
    };

    const ST_ETH_ADDRESS_PROXY: &str = "0xae7ab96520de3a18e5e111b5eaab095312d7fe84";
    const WST_ETH_ADDRESS: &str = "0x7f39c581f595b53c5cb19bd0b3f8da6c935e2ca0";

    /// A complete stETH snapshot decodes into the expected `LidoState`.
    #[tokio::test]
    async fn test_lido_steth_try_from() {
        let mut static_attr = HashMap::new();
        static_attr.insert(
            "token_to_track_total_pooled_eth".to_string(),
            Bytes::from_str(ETH_ADDRESS).unwrap(),
        );
        static_attr.insert("protocol_type_name".to_string(), "stETH".as_bytes().to_vec().into());

        let pc = ProtocolComponent {
            id: ST_ETH_ADDRESS_PROXY.to_string(),
            protocol_system: "protocol_system".to_owned(),
            protocol_type_name: "protocol_type_name".to_owned(),
            chain: Chain::Ethereum,
            tokens: vec![
                Bytes::from("0x0000000000000000000000000000000000000000"),
                Bytes::from("0xae7ab96520de3a18e5e111b5eaab095312d7fe84"),
            ],
            contract_ids: vec![],
            static_attributes: static_attr,
            change: ChangeType::Creation,
            creation_tx: Bytes::from(vec![0; 32]),
            created_at: NaiveDateTime::default(),
        };

        let snapshot = ComponentWithState {
            state: ResponseProtocolState {
                component_id: ST_ETH_ADDRESS_PROXY.to_owned(),
                attributes: HashMap::from([
                    ("total_shares".to_string(), Bytes::from(vec![0; 32])),
                    ("staking_status".to_string(), "Limited".as_bytes().to_vec().into()),
                    ("staking_limit".to_string(), Bytes::from(vec![0; 32])),
                ]),
                balances: HashMap::from([(Bytes::from(ETH_ADDRESS), Bytes::from(vec![0; 32]))]),
            },
            component: pc,
            component_tvl: None,
            entrypoints: Vec::new(),
        };

        let result = try_decode_snapshot_with_defaults::<LidoState>(snapshot).await;

        assert!(result.is_ok());
        assert_eq!(
            result.unwrap(),
            LidoState {
                pool_type: LidoPoolType::StEth,
                total_shares: BigUint::zero(),
                total_pooled_eth: BigUint::zero(),
                total_wrapped_st_eth: None,
                id: ST_ETH_ADDRESS_PROXY.into(),
                native_address: ETH_ADDRESS.into(),
                stake_limits_state: StakeLimitState {
                    staking_status: StakingStatus::Limited,
                    staking_limit: BigUint::zero(),
                },
                tokens: [
                    Bytes::from("0x0000000000000000000000000000000000000000"),
                    Bytes::from("0xae7ab96520de3a18e5e111b5eaab095312d7fe84"),
                ],
                token_to_track_total_pooled_eth: Bytes::from(ETH_ADDRESS)
            }
        );
    }

    /// Removing any single required state attribute from an otherwise valid stETH snapshot
    /// yields `MissingAttribute`.
    #[tokio::test]
    #[rstest]
    #[case::missing_total_shares("total_shares")]
    #[case::missing_staking_status("staking_status")]
    #[case::missing_staking_limit("staking_limit")]
    async fn test_lido_try_from_missing_attribute(#[case] missing_attribute: &str) {
        let mut static_attr = HashMap::new();
        static_attr.insert(
            "token_to_track_total_pooled_eth".to_string(),
            Bytes::from_str(ETH_ADDRESS).unwrap(),
        );
        // Without the pool type the decoder bails out before it ever looks at the state
        // attributes, which would make every case below pass for the wrong reason.
        static_attr.insert("protocol_type_name".to_string(), "stETH".as_bytes().to_vec().into());

        let pc = ProtocolComponent {
            id: ST_ETH_ADDRESS_PROXY.to_string(),
            protocol_system: "protocol_system".to_owned(),
            protocol_type_name: "protocol_type_name".to_owned(),
            chain: Chain::Ethereum,
            tokens: vec![
                Bytes::from("0x0000000000000000000000000000000000000000"),
                Bytes::from("0xae7ab96520de3a18e5e111b5eaab095312d7fe84"),
            ],
            contract_ids: vec![],
            static_attributes: static_attr,
            change: ChangeType::Creation,
            creation_tx: Bytes::from(vec![0; 32]),
            created_at: NaiveDateTime::default(),
        };

        let mut snapshot = ComponentWithState {
            state: ResponseProtocolState {
                component_id: ST_ETH_ADDRESS_PROXY.to_owned(),
                attributes: HashMap::from([
                    ("total_shares".to_string(), Bytes::from(vec![0; 32])),
                    ("staking_status".to_string(), "Limited".as_bytes().to_vec().into()),
                    ("staking_limit".to_string(), Bytes::from(vec![0; 32])),
                ]),
                balances: HashMap::from([(
                    Bytes::from_str(ETH_ADDRESS).unwrap(),
                    Bytes::from(vec![0; 32]),
                )]),
            },
            component: pc,
            component_tvl: None,
            entrypoints: Vec::new(),
        };
        snapshot
            .state
            .attributes
            .remove(missing_attribute);

        let result = try_decode_snapshot_with_defaults::<LidoState>(snapshot).await;

        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), InvalidSnapshotError::MissingAttribute(_)));
    }

    /// A complete wstETH snapshot decodes into the expected `LidoState`.
    #[tokio::test]
    async fn test_lido_wst_eth_try_from() {
        let mut static_attr = HashMap::new();
        static_attr.insert(
            "token_to_track_total_pooled_eth".to_string(),
            Bytes::from_str(ST_ETH_ADDRESS_PROXY).unwrap(),
        );
        static_attr.insert("protocol_type_name".to_string(), "wstETH".as_bytes().to_vec().into());

        let pc = ProtocolComponent {
            id: WST_ETH_ADDRESS.to_string(),
            protocol_system: "protocol_system".to_owned(),
            protocol_type_name: "protocol_type_name".to_owned(),
            chain: Chain::Ethereum,
            tokens: vec![
                Bytes::from("0x7f39c581f595b53c5cb19bd0b3f8da6c935e2ca0"),
                Bytes::from("0xae7ab96520de3a18e5e111b5eaab095312d7fe84"),
            ],
            contract_ids: vec![],
            static_attributes: static_attr,
            change: ChangeType::Creation,
            creation_tx: Bytes::from(vec![0; 32]),
            created_at: NaiveDateTime::default(),
        };

        let snapshot = ComponentWithState {
            state: ResponseProtocolState {
                component_id: ST_ETH_ADDRESS_PROXY.to_owned(),
                attributes: HashMap::from([
                    ("total_shares".to_string(), Bytes::from(vec![0; 32])),
                    ("total_wstETH".to_string(), Bytes::from(vec![0; 32])),
                ]),
                balances: HashMap::from([(
                    Bytes::from_str(ST_ETH_ADDRESS_PROXY).unwrap(),
                    Bytes::from(vec![0; 32]),
                )]),
            },
            component: pc,
            component_tvl: None,
            entrypoints: Vec::new(),
        };

        let result = try_decode_snapshot_with_defaults::<LidoState>(snapshot).await;

        assert!(result.is_ok());
        assert_eq!(
            result.unwrap(),
            LidoState {
                pool_type: LidoPoolType::WStEth,
                total_shares: BigUint::zero(),
                total_pooled_eth: BigUint::zero(),
                total_wrapped_st_eth: Some(BigUint::zero()),
                id: WST_ETH_ADDRESS.into(),
                native_address: ETH_ADDRESS.into(),
                stake_limits_state: StakeLimitState {
                    staking_status: StakingStatus::Limited,
                    staking_limit: BigUint::zero(),
                },
                tokens: [
                    Bytes::from("0x7f39c581f595b53c5cb19bd0b3f8da6c935e2ca0"),
                    Bytes::from("0xae7ab96520de3a18e5e111b5eaab095312d7fe84"),
                ],
                token_to_track_total_pooled_eth: Bytes::from(
                    "0xae7ab96520de3a18e5e111b5eaab095312d7fe84"
                )
            }
        );
    }

    /// Removing any single required state attribute from an otherwise valid wstETH snapshot
    /// yields `MissingAttribute`.
    #[tokio::test]
    #[rstest]
    #[case::missing_total_shares("total_shares")]
    #[case::missing_total_wst_eth("total_wstETH")]
    async fn test_lido_wst_try_from_missing_attribute(#[case] missing_attribute: &str) {
        // The static attributes must be valid so that the failure under test is really caused
        // by the removed state attribute rather than by an earlier static-attribute lookup.
        let mut static_attr = HashMap::new();
        static_attr.insert(
            "token_to_track_total_pooled_eth".to_string(),
            Bytes::from_str(ST_ETH_ADDRESS_PROXY).unwrap(),
        );
        static_attr.insert("protocol_type_name".to_string(), "wstETH".as_bytes().to_vec().into());

        let pc = ProtocolComponent {
            id: WST_ETH_ADDRESS.to_string(),
            protocol_system: "protocol_system".to_owned(),
            protocol_type_name: "protocol_type_name".to_owned(),
            chain: Chain::Ethereum,
            tokens: vec![
                Bytes::from("0x7f39c581f595b53c5cb19bd0b3f8da6c935e2ca0"),
                Bytes::from("0xae7ab96520de3a18e5e111b5eaab095312d7fe84"),
            ],
            contract_ids: vec![],
            static_attributes: static_attr,
            change: ChangeType::Creation,
            creation_tx: Bytes::from(vec![0; 32]),
            created_at: NaiveDateTime::default(),
        };

        let mut snapshot = ComponentWithState {
            state: ResponseProtocolState {
                component_id: ST_ETH_ADDRESS_PROXY.to_owned(),
                attributes: HashMap::from([
                    ("total_shares".to_string(), Bytes::from(vec![0; 32])),
                    ("total_wstETH".to_string(), Bytes::from(vec![0; 32])),
                ]),
                balances: HashMap::from([(
                    Bytes::from(ST_ETH_ADDRESS_PROXY),
                    Bytes::from(vec![0; 32]),
                )]),
            },
            component: pc,
            component_tvl: None,
            entrypoints: Vec::new(),
        };
        snapshot
            .state
            .attributes
            .remove(missing_attribute);

        let result = try_decode_snapshot_with_defaults::<LidoState>(snapshot).await;

        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), InvalidSnapshotError::MissingAttribute(_)));
    }
}