Skip to main content

xenith_read/
reader.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use bytes::Bytes;
5use xenith_core::{ChainId, Result, XenithError};
6
7use crate::provider::ChainProvider;
8
9/// The result of comparing a single storage slot across multiple chains.
10#[derive(Debug, Clone)]
11pub struct DivergenceReport {
12    /// The slot that was inspected.
13    pub slot: [u8; 32],
14    /// Successful readings: one entry per responding chain.
15    pub readings: Vec<(ChainId, [u8; 32])>,
16    /// `true` if at least two chains returned different values.
17    pub is_diverged: bool,
18    /// Chains whose value differs from the first reading.
19    /// Empty when `is_diverged` is `false`.
20    pub diverged_chains: Vec<ChainId>,
21}
22
23/// Executes storage reads and contract calls across many chains in parallel.
24///
25/// # Example
26///
27/// ```rust,no_run
28/// use std::sync::Arc;
29/// use std::collections::HashMap;
30/// use xenith_core::ChainId;
31/// use xenith_read::MultiChainReader;
32///
33/// # async fn example() {
34/// let reader = MultiChainReader::new(HashMap::new());
35/// let results = reader.read_parallel(vec![ChainId::from(1)], [0u8; 20], [0u8; 32]).await;
36/// # }
37/// ```
38pub struct MultiChainReader {
39    pub providers: HashMap<ChainId, Arc<dyn ChainProvider>>,
40}
41
42impl MultiChainReader {
43    pub fn new(providers: HashMap<ChainId, Arc<dyn ChainProvider>>) -> Self {
44        Self { providers }
45    }
46
47    /// Read `slot` from `address` on every chain in `chains`, concurrently.
48    ///
49    /// Tasks are spawned for all chains simultaneously. Failures are logged as
50    /// warnings and omitted from the returned Vec; only successful readings are
51    /// included. The outer `Result` only fails if something prevents spawning
52    /// (which in practice never happens with tokio).
53    pub async fn read_parallel(
54        &self,
55        chains: Vec<ChainId>,
56        address: [u8; 20],
57        slot: [u8; 32],
58    ) -> Result<Vec<(ChainId, [u8; 32])>> {
59        // Spawn all tasks first so they run concurrently, then collect.
60        let handles: Vec<_> = chains
61            .into_iter()
62            .map(|chain| {
63                let provider = self.providers.get(&chain).cloned();
64                let handle = tokio::spawn(async move {
65                    let p = provider.ok_or(XenithError::UnsupportedChain(chain))?;
66                    let val = p.read_storage(address, slot).await.map_err(|e| {
67                        XenithError::Transport {
68                            chain,
69                            message: e.to_string(),
70                        }
71                    })?;
72                    Ok::<(ChainId, [u8; 32]), XenithError>((chain, val))
73                });
74                (chain, handle)
75            })
76            .collect();
77
78        let mut successes = Vec::with_capacity(handles.len());
79        for (chain, handle) in handles {
80            match handle.await {
81                Ok(Ok(reading)) => successes.push(reading),
82                Ok(Err(e)) => eprintln!("xenith-read [warn]: chain {chain} read failed: {e}"),
83                Err(e) => eprintln!("xenith-read [warn]: task panicked for chain {chain}: {e}"),
84            }
85        }
86        Ok(successes)
87    }
88
89    /// Call `address` with `calldata` on every chain in `chains`, concurrently.
90    ///
91    /// Same failure-tolerance semantics as [`read_parallel`][Self::read_parallel].
92    pub async fn call_parallel(
93        &self,
94        chains: Vec<ChainId>,
95        address: [u8; 20],
96        calldata: Bytes,
97    ) -> Result<Vec<(ChainId, Bytes)>> {
98        let handles: Vec<_> = chains
99            .into_iter()
100            .map(|chain| {
101                let provider = self.providers.get(&chain).cloned();
102                let data = calldata.clone();
103                let handle = tokio::spawn(async move {
104                    let p = provider.ok_or(XenithError::UnsupportedChain(chain))?;
105                    let result =
106                        p.call(address, data)
107                            .await
108                            .map_err(|e| XenithError::Transport {
109                                chain,
110                                message: e.to_string(),
111                            })?;
112                    Ok::<(ChainId, Bytes), XenithError>((chain, result))
113                });
114                (chain, handle)
115            })
116            .collect();
117
118        let mut successes = Vec::with_capacity(handles.len());
119        for (chain, handle) in handles {
120            match handle.await {
121                Ok(Ok(result)) => successes.push(result),
122                Ok(Err(e)) => eprintln!("xenith-read [warn]: chain {chain} call failed: {e}"),
123                Err(e) => eprintln!("xenith-read [warn]: task panicked for chain {chain}: {e}"),
124            }
125        }
126        Ok(successes)
127    }
128
129    /// Read `slot` across `chains` and report whether their values agree.
130    pub async fn check_divergence(
131        &self,
132        chains: Vec<ChainId>,
133        address: [u8; 20],
134        slot: [u8; 32],
135    ) -> Result<DivergenceReport> {
136        let readings = self.read_parallel(chains, address, slot).await?;
137
138        let is_diverged = readings.windows(2).any(|w| w[0].1 != w[1].1);
139
140        let diverged_chains = if is_diverged {
141            // Count how many chains hold each distinct value.
142            let mut counts: HashMap<[u8; 32], usize> = HashMap::new();
143            for (_, v) in &readings {
144                *counts.entry(*v).or_insert(0) += 1;
145            }
146            let max_count = counts.values().copied().max().unwrap_or(0);
147            // How many distinct values share the top count?
148            let n_at_max = counts.values().filter(|&&c| c == max_count).count();
149
150            if n_at_max == 1 {
151                // Clear majority exists — flag only minority chains.
152                let majority = *counts.iter().find(|(_, &c)| c == max_count).unwrap().0;
153                readings
154                    .iter()
155                    .filter(|(_, v)| *v != majority)
156                    .map(|(c, _)| *c)
157                    .collect()
158            } else {
159                // No consensus (e.g. 50/50 split) — flag every chain.
160                readings.iter().map(|(c, _)| *c).collect()
161            }
162        } else {
163            vec![]
164        };
165
166        Ok(DivergenceReport {
167            slot,
168            readings,
169            is_diverged,
170            diverged_chains,
171        })
172    }
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::MockProvider;
    use std::collections::HashMap;

    /// Build a reader whose chains are backed by `MockProvider`s.
    ///
    /// Each entry is `(chain_id, slot, value)`. Several entries for the same
    /// chain populate several slots on that chain's mock.
    fn make_reader(chain_slots: &[(u64, [u8; 32], [u8; 32])]) -> MultiChainReader {
        let mut by_chain: HashMap<u64, HashMap<[u8; 32], [u8; 32]>> = HashMap::new();
        for &(chain, slot, val) in chain_slots {
            by_chain.entry(chain).or_default().insert(slot, val);
        }
        let providers = by_chain
            .into_iter()
            .map(|(c, slots)| {
                (
                    ChainId(c),
                    Arc::new(MockProvider::new(slots)) as Arc<dyn ChainProvider>,
                )
            })
            .collect();
        MultiChainReader::new(providers)
    }

    const SLOT: [u8; 32] = [0xABu8; 32];
    const ADDR: [u8; 20] = [0u8; 20];

    #[tokio::test]
    async fn no_divergence_when_chains_agree() {
        let value = [0x42u8; 32];
        let reader = make_reader(&[(1, SLOT, value), (42161, SLOT, value)]);

        let report = reader
            .check_divergence(vec![ChainId(1), ChainId(42161)], ADDR, SLOT)
            .await
            .unwrap();

        assert!(!report.is_diverged);
        assert!(report.diverged_chains.is_empty());
        assert_eq!(report.readings.len(), 2);
        assert!(report.readings.iter().all(|(_, v)| *v == value));
    }

    #[tokio::test]
    async fn divergence_detected_when_chains_differ() {
        let reader = make_reader(&[(1, SLOT, [0x01u8; 32]), (42161, SLOT, [0x02u8; 32])]);

        let report = reader
            .check_divergence(vec![ChainId(1), ChainId(42161)], ADDR, SLOT)
            .await
            .unwrap();

        assert!(report.is_diverged);
        // 50/50 split: no majority, so both chains are flagged, in request order.
        assert_eq!(report.diverged_chains, vec![ChainId(1), ChainId(42161)]);
        assert_eq!(report.readings.len(), 2);
    }

    #[tokio::test]
    async fn three_chains_one_diverges() {
        let reader = make_reader(&[
            (1, SLOT, [0x01u8; 32]),
            (10, SLOT, [0x01u8; 32]),    // agrees with chain 1
            (42161, SLOT, [0xFFu8; 32]), // outlier
        ]);

        let report = reader
            .check_divergence(vec![ChainId(1), ChainId(10), ChainId(42161)], ADDR, SLOT)
            .await
            .unwrap();

        assert!(report.is_diverged);
        assert_eq!(report.diverged_chains, vec![ChainId(42161)]);
    }

    #[tokio::test]
    async fn majority_flags_every_minority_chain() {
        let reader = make_reader(&[
            (1, SLOT, [0x01u8; 32]),
            (10, SLOT, [0x01u8; 32]),
            (56, SLOT, [0x02u8; 32]), // outlier
            (137, SLOT, [0x01u8; 32]),
            (42161, SLOT, [0x03u8; 32]), // different outlier
        ]);

        let report = reader
            .check_divergence(
                vec![ChainId(1), ChainId(10), ChainId(56), ChainId(137), ChainId(42161)],
                ADDR,
                SLOT,
            )
            .await
            .unwrap();

        assert!(report.is_diverged);
        assert_eq!(report.diverged_chains, vec![ChainId(56), ChainId(42161)]);
    }

    #[tokio::test]
    async fn tied_top_count_flags_every_chain() {
        let reader = make_reader(&[
            (1, SLOT, [0x01u8; 32]),
            (10, SLOT, [0x01u8; 32]),
            (56, SLOT, [0x02u8; 32]),
            (137, SLOT, [0x02u8; 32]),
        ]);
        let chains = vec![ChainId(1), ChainId(10), ChainId(56), ChainId(137)];

        let report = reader
            .check_divergence(chains.clone(), ADDR, SLOT)
            .await
            .unwrap();

        assert!(report.is_diverged);
        assert_eq!(report.diverged_chains, chains);
    }

    #[tokio::test]
    async fn three_way_split_flags_every_chain() {
        let reader = make_reader(&[
            (1, SLOT, [0x01u8; 32]),
            (10, SLOT, [0x02u8; 32]),
            (56, SLOT, [0x03u8; 32]),
        ]);
        let chains = vec![ChainId(1), ChainId(10), ChainId(56)];

        let report = reader
            .check_divergence(chains.clone(), ADDR, SLOT)
            .await
            .unwrap();

        assert!(report.is_diverged);
        assert_eq!(report.diverged_chains, chains);
    }

    #[tokio::test]
    async fn empty_chain_list_yields_empty_report() {
        let reader = make_reader(&[(1, SLOT, [0x01u8; 32])]);

        let report = reader.check_divergence(vec![], ADDR, SLOT).await.unwrap();

        assert!(!report.is_diverged);
        assert!(report.readings.is_empty());
        assert!(report.diverged_chains.is_empty());
        assert_eq!(report.slot, SLOT);
    }

    #[tokio::test]
    async fn unregistered_chain_is_skipped_with_warning() {
        let reader = make_reader(&[(1, SLOT, [0x01u8; 32])]);

        // Chain 9999 has no provider: it should be omitted, not panic.
        let readings = reader
            .read_parallel(vec![ChainId(1), ChainId(9999)], ADDR, SLOT)
            .await
            .unwrap();

        assert_eq!(readings.len(), 1);
        assert_eq!(readings[0].0, ChainId(1));
    }

    #[tokio::test]
    async fn unregistered_chain_does_not_count_as_divergence() {
        let value = [0x07u8; 32];
        let reader = make_reader(&[(1, SLOT, value), (10, SLOT, value)]);

        let report = reader
            .check_divergence(vec![ChainId(1), ChainId(9999), ChainId(10)], ADDR, SLOT)
            .await
            .unwrap();

        assert!(!report.is_diverged);
        assert!(report.diverged_chains.is_empty());
        assert_eq!(
            report.readings,
            vec![(ChainId(1), value), (ChainId(10), value)]
        );
    }

    #[tokio::test]
    async fn call_parallel_returns_responses() {
        let reader = make_reader(&[(1, SLOT, [0u8; 32]), (42161, SLOT, [0u8; 32])]);
        let results = reader
            .call_parallel(vec![ChainId(1), ChainId(42161)], ADDR, Bytes::new())
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn call_parallel_skips_unregistered_chain() {
        let reader = make_reader(&[(1, SLOT, [0u8; 32])]);
        let results = reader
            .call_parallel(vec![ChainId(9999), ChainId(1)], ADDR, Bytes::new())
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, ChainId(1));
    }

    #[tokio::test]
    async fn single_chain_never_diverges() {
        let reader = make_reader(&[(1, SLOT, [0x01u8; 32])]);
        let report = reader
            .check_divergence(vec![ChainId(1)], ADDR, SLOT)
            .await
            .unwrap();
        assert!(!report.is_diverged);
        assert!(report.diverged_chains.is_empty());
    }
}
283}