1use std::collections::HashMap;
2use std::sync::Arc;
3
4use bytes::Bytes;
5use xenith_core::{ChainId, Result, XenithError};
6
7use crate::provider::ChainProvider;
8
9#[derive(Debug, Clone)]
11pub struct DivergenceReport {
12 pub slot: [u8; 32],
14 pub readings: Vec<(ChainId, [u8; 32])>,
16 pub is_diverged: bool,
18 pub diverged_chains: Vec<ChainId>,
21}
22
23pub struct MultiChainReader {
39 pub providers: HashMap<ChainId, Arc<dyn ChainProvider>>,
40}
41
42impl MultiChainReader {
43 pub fn new(providers: HashMap<ChainId, Arc<dyn ChainProvider>>) -> Self {
44 Self { providers }
45 }
46
47 pub async fn read_parallel(
54 &self,
55 chains: Vec<ChainId>,
56 address: [u8; 20],
57 slot: [u8; 32],
58 ) -> Result<Vec<(ChainId, [u8; 32])>> {
59 let handles: Vec<_> = chains
61 .into_iter()
62 .map(|chain| {
63 let provider = self.providers.get(&chain).cloned();
64 let handle = tokio::spawn(async move {
65 let p = provider.ok_or(XenithError::UnsupportedChain(chain))?;
66 let val = p.read_storage(address, slot).await.map_err(|e| {
67 XenithError::Transport {
68 chain,
69 message: e.to_string(),
70 }
71 })?;
72 Ok::<(ChainId, [u8; 32]), XenithError>((chain, val))
73 });
74 (chain, handle)
75 })
76 .collect();
77
78 let mut successes = Vec::with_capacity(handles.len());
79 for (chain, handle) in handles {
80 match handle.await {
81 Ok(Ok(reading)) => successes.push(reading),
82 Ok(Err(e)) => eprintln!("xenith-read [warn]: chain {chain} read failed: {e}"),
83 Err(e) => eprintln!("xenith-read [warn]: task panicked for chain {chain}: {e}"),
84 }
85 }
86 Ok(successes)
87 }
88
89 pub async fn call_parallel(
93 &self,
94 chains: Vec<ChainId>,
95 address: [u8; 20],
96 calldata: Bytes,
97 ) -> Result<Vec<(ChainId, Bytes)>> {
98 let handles: Vec<_> = chains
99 .into_iter()
100 .map(|chain| {
101 let provider = self.providers.get(&chain).cloned();
102 let data = calldata.clone();
103 let handle = tokio::spawn(async move {
104 let p = provider.ok_or(XenithError::UnsupportedChain(chain))?;
105 let result =
106 p.call(address, data)
107 .await
108 .map_err(|e| XenithError::Transport {
109 chain,
110 message: e.to_string(),
111 })?;
112 Ok::<(ChainId, Bytes), XenithError>((chain, result))
113 });
114 (chain, handle)
115 })
116 .collect();
117
118 let mut successes = Vec::with_capacity(handles.len());
119 for (chain, handle) in handles {
120 match handle.await {
121 Ok(Ok(result)) => successes.push(result),
122 Ok(Err(e)) => eprintln!("xenith-read [warn]: chain {chain} call failed: {e}"),
123 Err(e) => eprintln!("xenith-read [warn]: task panicked for chain {chain}: {e}"),
124 }
125 }
126 Ok(successes)
127 }
128
129 pub async fn check_divergence(
131 &self,
132 chains: Vec<ChainId>,
133 address: [u8; 20],
134 slot: [u8; 32],
135 ) -> Result<DivergenceReport> {
136 let readings = self.read_parallel(chains, address, slot).await?;
137
138 let is_diverged = readings.windows(2).any(|w| w[0].1 != w[1].1);
139
140 let diverged_chains = if is_diverged {
141 let mut counts: HashMap<[u8; 32], usize> = HashMap::new();
143 for (_, v) in &readings {
144 *counts.entry(*v).or_insert(0) += 1;
145 }
146 let max_count = counts.values().copied().max().unwrap_or(0);
147 let n_at_max = counts.values().filter(|&&c| c == max_count).count();
149
150 if n_at_max == 1 {
151 let majority = *counts.iter().find(|(_, &c)| c == max_count).unwrap().0;
153 readings
154 .iter()
155 .filter(|(_, v)| *v != majority)
156 .map(|(c, _)| *c)
157 .collect()
158 } else {
159 readings.iter().map(|(c, _)| *c).collect()
161 }
162 } else {
163 vec![]
164 };
165
166 Ok(DivergenceReport {
167 slot,
168 readings,
169 is_diverged,
170 diverged_chains,
171 })
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::MockProvider;
    use std::collections::HashMap;

    /// Builds a reader backed by one `MockProvider` per distinct chain id in
    /// `chain_slots`, each seeded with that chain's `(slot, value)` entries.
    fn make_reader(chain_slots: &[(u64, [u8; 32], [u8; 32])]) -> MultiChainReader {
        let mut by_chain: HashMap<u64, HashMap<[u8; 32], [u8; 32]>> = HashMap::new();
        for &(chain, slot, val) in chain_slots {
            by_chain.entry(chain).or_default().insert(slot, val);
        }
        let providers = by_chain
            .into_iter()
            .map(|(c, slots)| {
                (
                    ChainId(c),
                    Arc::new(MockProvider::new(slots)) as Arc<dyn ChainProvider>,
                )
            })
            .collect();
        MultiChainReader::new(providers)
    }

    const SLOT: [u8; 32] = [0xABu8; 32];
    const ADDR: [u8; 20] = [0u8; 20];

    #[tokio::test]
    async fn no_divergence_when_chains_agree() {
        let value = [0x42u8; 32];
        let reader = make_reader(&[(1, SLOT, value), (42161, SLOT, value)]);

        let report = reader
            .check_divergence(vec![ChainId(1), ChainId(42161)], ADDR, SLOT)
            .await
            .unwrap();

        assert!(!report.is_diverged);
        assert!(report.diverged_chains.is_empty());
        assert_eq!(report.readings.len(), 2);
        assert!(report.readings.iter().all(|(_, v)| *v == value));
    }

    #[tokio::test]
    async fn divergence_detected_when_chains_differ() {
        let reader = make_reader(&[(1, SLOT, [0x01u8; 32]), (42161, SLOT, [0x02u8; 32])]);

        let report = reader
            .check_divergence(vec![ChainId(1), ChainId(42161)], ADDR, SLOT)
            .await
            .unwrap();

        assert!(report.is_diverged);
        // A 1-vs-1 split has no majority, so both chains are flagged, in request order.
        assert_eq!(report.diverged_chains, vec![ChainId(1), ChainId(42161)]);
        assert_eq!(report.readings.len(), 2);
    }

    #[tokio::test]
    async fn three_chains_one_diverges() {
        let reader = make_reader(&[
            (1, SLOT, [0x01u8; 32]),
            (10, SLOT, [0x01u8; 32]),
            (42161, SLOT, [0xFFu8; 32]),
        ]);

        let report = reader
            .check_divergence(vec![ChainId(1), ChainId(10), ChainId(42161)], ADDR, SLOT)
            .await
            .unwrap();

        assert!(report.is_diverged);
        assert_eq!(report.diverged_chains, vec![ChainId(42161)]);
    }

    #[tokio::test]
    async fn tied_majority_flags_every_chain() {
        let a = [0x01u8; 32];
        let b = [0x02u8; 32];
        let reader = make_reader(&[(1, SLOT, a), (10, SLOT, a), (137, SLOT, b), (42161, SLOT, b)]);
        let chains = vec![ChainId(1), ChainId(10), ChainId(137), ChainId(42161)];

        let report = reader
            .check_divergence(chains.clone(), ADDR, SLOT)
            .await
            .unwrap();

        assert!(report.is_diverged);
        assert_eq!(report.diverged_chains, chains);
    }

    #[tokio::test]
    async fn empty_chain_list_yields_empty_report() {
        let reader = make_reader(&[]);

        let report = reader.check_divergence(vec![], ADDR, SLOT).await.unwrap();

        assert!(!report.is_diverged);
        assert!(report.readings.is_empty());
        assert!(report.diverged_chains.is_empty());
    }

    #[tokio::test]
    async fn failed_chain_is_excluded_from_divergence() {
        let reader = make_reader(&[(1, SLOT, [0x01u8; 32]), (10, SLOT, [0x01u8; 32])]);

        let report = reader
            .check_divergence(vec![ChainId(1), ChainId(9999), ChainId(10)], ADDR, SLOT)
            .await
            .unwrap();

        assert!(!report.is_diverged);
        assert_eq!(report.readings.len(), 2);
        assert!(report.readings.iter().all(|(c, _)| *c != ChainId(9999)));
    }

    #[tokio::test]
    async fn unregistered_chain_is_skipped_with_warning() {
        let reader = make_reader(&[(1, SLOT, [0x01u8; 32])]);

        let readings = reader
            .read_parallel(vec![ChainId(1), ChainId(9999)], ADDR, SLOT)
            .await
            .unwrap();

        assert_eq!(readings.len(), 1);
        assert_eq!(readings[0].0, ChainId(1));
    }

    #[tokio::test]
    async fn call_parallel_returns_responses() {
        let reader = make_reader(&[(1, SLOT, [0u8; 32]), (42161, SLOT, [0u8; 32])]);
        let results = reader
            .call_parallel(vec![ChainId(1), ChainId(42161)], ADDR, Bytes::new())
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn single_chain_never_diverges() {
        let reader = make_reader(&[(1, SLOT, [0x01u8; 32])]);
        let report = reader
            .check_divergence(vec![ChainId(1)], ADDR, SLOT)
            .await
            .unwrap();
        assert!(!report.is_diverged);
    }
}