1use async_trait::async_trait;
30use serde::de::DeserializeOwned;
31use std::collections::VecDeque;
32use subsoil::runtime::traits::{Block as BlockT, Header as HeaderT};
33
34pub use crate::api::{
35 author::AuthorApiClient as AuthorApi, chain::ChainApiClient as ChainApi,
36 child_state::ChildStateApiClient as ChildStateApi, dev::DevApiClient as DevApi,
37 offchain::OffchainApiClient as OffchainApi, state::StateApiClient as StateApi,
38 system::SystemApiClient as SystemApi,
39};
40pub use jsonrpsee::{
41 core::{
42 client::{ClientT, Error, Subscription, SubscriptionClientT},
43 params::BatchRequestBuilder,
44 RpcResult,
45 },
46 rpc_params,
47 ws_client::{WsClient, WsClientBuilder},
48};
49
50pub async fn ws_client(uri: impl AsRef<str>) -> Result<WsClient, String> {
52 WsClientBuilder::default()
53 .max_request_size(u32::MAX)
54 .max_response_size(u32::MAX)
55 .request_timeout(std::time::Duration::from_secs(60 * 10))
56 .connection_timeout(std::time::Duration::from_secs(60))
57 .max_buffer_capacity_per_subscription(1024)
58 .build(uri)
59 .await
60 .map_err(|e| format!("`WsClientBuilder` failed to build: {:?}", e))
61}
62
63#[async_trait]
65pub trait HeaderProvider<Block: BlockT>
66where
67 Block::Header: HeaderT,
68{
69 async fn get_header(&self, hash: Block::Hash) -> Block::Header;
71}
72
73#[async_trait]
74impl<Block: BlockT> HeaderProvider<Block> for WsClient
75where
76 Block::Header: DeserializeOwned,
77{
78 async fn get_header(&self, hash: Block::Hash) -> Block::Header {
79 ChainApi::<(), Block::Hash, Block::Header, ()>::header(self, Some(hash))
80 .await
81 .unwrap()
82 .unwrap()
83 }
84}
85
86#[async_trait]
88pub trait HeaderSubscription<Block: BlockT>
89where
90 Block::Header: HeaderT,
91{
92 async fn next_header(&mut self) -> Option<Block::Header>;
97}
98
99#[async_trait]
100impl<Block: BlockT> HeaderSubscription<Block> for Subscription<Block::Header>
101where
102 Block::Header: DeserializeOwned,
103{
104 async fn next_header(&mut self) -> Option<Block::Header> {
105 match self.next().await {
106 Some(Ok(header)) => Some(header),
107 None => {
108 log::warn!("subscription closed");
109 None
110 },
111 Some(Err(why)) => {
112 log::warn!("subscription returned error: {:?}. Probably decoding has failed.", why);
113 None
114 },
115 }
116 }
117}
118
119pub struct FinalizedHeaders<
124 'a,
125 Block: BlockT,
126 HP: HeaderProvider<Block>,
127 HS: HeaderSubscription<Block>,
128> {
129 header_provider: &'a HP,
130 subscription: HS,
131 fetched_headers: VecDeque<Block::Header>,
132 last_returned: Option<<Block::Header as HeaderT>::Hash>,
133}
134
135impl<'a, Block: BlockT, HP: HeaderProvider<Block>, HS: HeaderSubscription<Block>>
136 FinalizedHeaders<'a, Block, HP, HS>
137where
138 <Block as BlockT>::Header: DeserializeOwned,
139{
140 pub fn new(header_provider: &'a HP, subscription: HS) -> Self {
141 Self {
142 header_provider,
143 subscription,
144 fetched_headers: VecDeque::new(),
145 last_returned: None,
146 }
147 }
148
149 async fn fetch(&mut self) -> usize {
154 let last_finalized = match self.subscription.next_header().await {
155 Some(header) => header,
156 None => return 0,
157 };
158
159 self.fetched_headers.push_front(last_finalized.clone());
160
161 let mut last_finalized_parent = *last_finalized.parent_hash();
162 let last_returned = self.last_returned.unwrap_or(last_finalized_parent);
163
164 while last_finalized_parent != last_returned {
165 let parent_header = self.header_provider.get_header(last_finalized_parent).await;
166 self.fetched_headers.push_front(parent_header.clone());
167 last_finalized_parent = *parent_header.parent_hash();
168 }
169
170 self.fetched_headers.len()
171 }
172
173 pub async fn next(&mut self) -> Option<Block::Header> {
175 if self.fetched_headers.is_empty() {
176 self.fetch().await;
177 }
178
179 if let Some(header) = self.fetched_headers.pop_front() {
180 self.last_returned = Some(header.hash());
181 Some(header)
182 } else {
183 None
184 }
185 }
186}
187
#[cfg(test)]
mod tests {
	use super::*;
	use std::sync::Arc;
	use subsoil::runtime::testing::{Block as TBlock, Header, MockCallU64, TestXt, H256};
	use tokio::sync::Mutex;

	type UncheckedXt = TestXt<MockCallU64, ()>;
	type Block = TBlock<UncheckedXt>;
	type BlockNumber = u64;
	type Hash = H256;

	/// Header provider that ignores the requested hash. It serves the headers at the queued
	/// heights, in queue order.
	struct MockHeaderProvider(pub Arc<Mutex<VecDeque<BlockNumber>>>);

	/// Builds a linked chain of headers #0 through #10, each pointing at its predecessor.
	fn headers() -> Vec<Header> {
		let mut chain = Vec::with_capacity(11);
		chain.push(Header::new_from_number(0));
		for number in 1..=10 {
			let parent_hash = chain[chain.len() - 1].hash();
			chain.push(Header { parent_hash, ..Header::new_from_number(number) });
		}
		chain
	}

	#[async_trait]
	impl HeaderProvider<Block> for MockHeaderProvider {
		async fn get_header(&self, _hash: Hash) -> Header {
			let mut queue = self.0.lock().await;
			let height = queue.pop_front().unwrap();
			headers().remove(height as usize)
		}
	}

	/// Subscription that yields the headers at the queued heights, then ends.
	struct MockHeaderSubscription(pub VecDeque<BlockNumber>);

	#[async_trait]
	impl HeaderSubscription<Block> for MockHeaderSubscription {
		async fn next_header(&mut self) -> Option<Header> {
			let height = self.0.pop_front()?;
			Some(headers().remove(height as usize))
		}
	}

	#[tokio::test]
	async fn finalized_headers_works_when_every_block_comes_from_subscription() {
		let heights = vec![4, 5, 6, 7];

		let provider = MockHeaderProvider(Default::default());
		let subscription = MockHeaderSubscription(heights.iter().copied().collect());
		let mut stream = FinalizedHeaders::new(&provider, subscription);

		for expected in heights {
			let header = stream.next().await.unwrap();
			assert_eq!(expected, header.number);
		}
		assert_eq!(None, stream.next().await);
	}

	#[tokio::test]
	async fn finalized_headers_come_from_subscription_and_provider_if_in_need() {
		let announced: VecDeque<BlockNumber> = VecDeque::from(vec![3, 4, 6, 10]);
		// Order matters: the provider is queried walking backwards from each announced header.
		let backfilled: VecDeque<BlockNumber> = VecDeque::from(vec![5, 9, 8, 7]);

		let provider = MockHeaderProvider(Arc::new(Mutex::new(backfilled)));
		let subscription = MockHeaderSubscription(announced);
		let mut stream = FinalizedHeaders::new(&provider, subscription);

		for expected in 3..11 {
			let header = stream.next().await.unwrap();
			assert_eq!(expected, header.number);
		}
		assert_eq!(None, stream.next().await);
	}
}