zlayer_consensus/network/
http_client.rs1use std::collections::HashMap;
8use std::future::Future;
9use std::io::Cursor;
10use std::sync::Arc;
11use std::time::Duration;
12
13use openraft::error::{
14 Fatal, InstallSnapshotError, RPCError, RaftError, ReplicationClosed, StreamingError,
15 Unreachable,
16};
17use openraft::network::{RPCOption, RaftNetwork, RaftNetworkFactory};
18use openraft::raft::{
19 AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest, InstallSnapshotResponse,
20 SnapshotResponse, VoteRequest, VoteResponse,
21};
22use openraft::{BasicNode, OptionalSend, RaftTypeConfig, Snapshot, SnapshotMeta, Vote};
23use reqwest::Client;
24use tokio::sync::RwLock;
25use tracing::debug;
26
27use crate::types::NodeId;
28
29#[derive(Clone)]
39pub struct RaftHttpClient {
40 rpc_client: Client,
42 snapshot_client: Client,
44 auth_token: Option<String>,
46}
47
48impl RaftHttpClient {
49 #[must_use]
51 pub fn new(rpc_timeout: Duration, snapshot_timeout: Duration) -> Self {
52 Self::with_auth(rpc_timeout, snapshot_timeout, None)
53 }
54
55 #[must_use]
60 pub fn with_auth(
61 rpc_timeout: Duration,
62 snapshot_timeout: Duration,
63 auth_token: Option<String>,
64 ) -> Self {
65 let rpc_client = Client::builder()
66 .timeout(rpc_timeout)
67 .pool_max_idle_per_host(10)
68 .pool_idle_timeout(Duration::from_secs(90))
69 .build()
70 .expect("Failed to build RPC HTTP client");
71
72 let snapshot_client = Client::builder()
73 .timeout(snapshot_timeout)
74 .pool_max_idle_per_host(5)
75 .pool_idle_timeout(Duration::from_secs(90))
76 .build()
77 .expect("Failed to build snapshot HTTP client");
78
79 Self {
80 rpc_client,
81 snapshot_client,
82 auth_token,
83 }
84 }
85
86 async fn postcard_post<Req, Resp>(
88 client: &Client,
89 url: &str,
90 request: &Req,
91 auth_token: Option<&str>,
92 ) -> Result<Resp, String>
93 where
94 Req: serde::Serialize,
95 Resp: serde::de::DeserializeOwned,
96 {
97 let body =
98 postcard2::to_vec(request).map_err(|e| format!("postcard2 serialize error: {e}"))?;
99
100 let mut builder = client
101 .post(url)
102 .header("Content-Type", "application/octet-stream");
103
104 if let Some(token) = auth_token {
105 builder = builder.header("Authorization", format!("Bearer {token}"));
106 }
107
108 let response = builder.body(body).send().await.map_err(|e| {
109 if e.is_timeout() {
110 format!("timeout: {e}")
111 } else if e.is_connect() {
112 format!("unreachable: {e}")
113 } else {
114 format!("http error: {e}")
115 }
116 })?;
117
118 if !response.status().is_success() {
119 let status = response.status();
120 let text = response.text().await.unwrap_or_default();
121 return Err(format!("HTTP {status}: {text}"));
122 }
123
124 let bytes = response
125 .bytes()
126 .await
127 .map_err(|e| format!("read body error: {e}"))?;
128 postcard2::from_bytes(&bytes).map_err(|e| format!("postcard2 deserialize error: {e}"))
129 }
130}
131
132impl Default for RaftHttpClient {
133 fn default() -> Self {
134 Self::new(Duration::from_secs(5), Duration::from_secs(60))
135 }
136}
137
138pub struct HttpNetwork<C: RaftTypeConfig<NodeId = NodeId>> {
146 peers: Arc<RwLock<HashMap<NodeId, String>>>,
148 client: Arc<RaftHttpClient>,
150 _phantom: std::marker::PhantomData<C>,
151}
152
153impl<C: RaftTypeConfig<NodeId = NodeId>> HttpNetwork<C> {
154 #[must_use]
156 pub fn new() -> Self {
157 Self::with_client(RaftHttpClient::default())
158 }
159
160 #[must_use]
162 pub fn with_client(client: RaftHttpClient) -> Self {
163 Self {
164 peers: Arc::new(RwLock::new(HashMap::new())),
165 client: Arc::new(client),
166 _phantom: std::marker::PhantomData,
167 }
168 }
169
170 #[must_use]
172 pub fn with_timeouts(rpc_timeout: Duration, snapshot_timeout: Duration) -> Self {
173 Self::with_client(RaftHttpClient::new(rpc_timeout, snapshot_timeout))
174 }
175
176 #[must_use]
178 pub fn with_timeouts_and_auth(
179 rpc_timeout: Duration,
180 snapshot_timeout: Duration,
181 auth_token: Option<String>,
182 ) -> Self {
183 Self::with_client(RaftHttpClient::with_auth(
184 rpc_timeout,
185 snapshot_timeout,
186 auth_token,
187 ))
188 }
189
190 pub async fn add_peer(&self, node_id: NodeId, address: String) {
192 self.peers.write().await.insert(node_id, address);
193 }
194
195 pub async fn remove_peer(&self, node_id: NodeId) {
197 self.peers.write().await.remove(&node_id);
198 }
199
200 pub async fn peers(&self) -> HashMap<NodeId, String> {
202 self.peers.read().await.clone()
203 }
204}
205
206impl<C: RaftTypeConfig<NodeId = NodeId>> Default for HttpNetwork<C> {
207 fn default() -> Self {
208 Self::new()
209 }
210}
211
212impl<C: RaftTypeConfig<NodeId = NodeId>> Clone for HttpNetwork<C> {
213 fn clone(&self) -> Self {
214 Self {
215 peers: Arc::clone(&self.peers),
216 client: Arc::clone(&self.client),
217 _phantom: std::marker::PhantomData,
218 }
219 }
220}
221
222impl<C> RaftNetworkFactory<C> for HttpNetwork<C>
223where
224 C: RaftTypeConfig<NodeId = NodeId, Node = BasicNode, SnapshotData = Cursor<Vec<u8>>>,
225 C::D: serde::Serialize + serde::de::DeserializeOwned,
226 C::R: serde::Serialize + serde::de::DeserializeOwned,
227 C::Entry: serde::Serialize + serde::de::DeserializeOwned,
228{
229 type Network = HttpConnection<C>;
230
231 async fn new_client(&mut self, _target: NodeId, node: &BasicNode) -> Self::Network {
232 HttpConnection {
233 target_addr: node.addr.clone(),
234 client: Arc::clone(&self.client),
235 auth_token: self.client.auth_token.clone(),
236 _phantom: std::marker::PhantomData,
237 }
238 }
239}
240
241pub struct HttpConnection<C: RaftTypeConfig<NodeId = NodeId>> {
243 target_addr: String,
244 client: Arc<RaftHttpClient>,
245 auth_token: Option<String>,
247 _phantom: std::marker::PhantomData<C>,
248}
249
/// Turns a peer address into a URL base that endpoint paths can be appended to.
///
/// * Surrounding whitespace is ignored.
/// * An `http://` or `https://` scheme is recognised case-insensitively and
///   emitted in lower case; anything else is treated as `host:port` and gets
///   an `http://` prefix.
/// * Trailing `/` characters are removed, so `format!("{base}/raft/...")`
///   never produces a `//` segment. Path-based routers treat such a path as a
///   different route.
fn normalize_addr(addr: &str) -> String {
    // Returns what follows `scheme` when `addr` starts with it (ASCII case-insensitive).
    fn strip_scheme<'a>(addr: &'a str, scheme: &str) -> Option<&'a str> {
        let head = addr.get(..scheme.len())?;
        if head.eq_ignore_ascii_case(scheme) {
            addr.get(scheme.len()..)
        } else {
            None
        }
    }

    let addr = addr.trim();
    let (scheme, rest) = if let Some(rest) = strip_scheme(addr, "https://") {
        ("https://", rest)
    } else if let Some(rest) = strip_scheme(addr, "http://") {
        ("http://", rest)
    } else {
        ("http://", addr)
    };
    format!("{scheme}{}", rest.trim_end_matches('/'))
}
258
259fn to_unreachable<E: std::error::Error>(msg: String) -> RPCError<NodeId, BasicNode, E> {
261 RPCError::Unreachable(Unreachable::new(&std::io::Error::other(msg)))
262}
263
264impl<C> RaftNetwork<C> for HttpConnection<C>
265where
266 C: RaftTypeConfig<NodeId = NodeId, Node = BasicNode, SnapshotData = Cursor<Vec<u8>>>,
267 C::D: serde::Serialize + serde::de::DeserializeOwned,
268 C::R: serde::Serialize + serde::de::DeserializeOwned,
269 C::Entry: serde::Serialize + serde::de::DeserializeOwned,
270{
271 async fn append_entries(
272 &mut self,
273 rpc: AppendEntriesRequest<C>,
274 _option: RPCOption,
275 ) -> Result<AppendEntriesResponse<NodeId>, RPCError<NodeId, BasicNode, RaftError<NodeId>>> {
276 let url = format!("{}/raft/append", normalize_addr(&self.target_addr));
277 debug!(target_addr = %self.target_addr, "Sending append_entries RPC");
278
279 RaftHttpClient::postcard_post(
280 &self.client.rpc_client,
281 &url,
282 &rpc,
283 self.auth_token.as_deref(),
284 )
285 .await
286 .map_err(to_unreachable)
287 }
288
289 async fn install_snapshot(
290 &mut self,
291 rpc: InstallSnapshotRequest<C>,
292 _option: RPCOption,
293 ) -> Result<
294 InstallSnapshotResponse<NodeId>,
295 RPCError<NodeId, BasicNode, RaftError<NodeId, InstallSnapshotError>>,
296 > {
297 let url = format!("{}/raft/snapshot", normalize_addr(&self.target_addr));
298 debug!(target_addr = %self.target_addr, "Sending install_snapshot RPC");
299
300 RaftHttpClient::postcard_post(
301 &self.client.snapshot_client,
302 &url,
303 &rpc,
304 self.auth_token.as_deref(),
305 )
306 .await
307 .map_err(to_unreachable)
308 }
309
310 async fn vote(
311 &mut self,
312 rpc: VoteRequest<NodeId>,
313 _option: RPCOption,
314 ) -> Result<VoteResponse<NodeId>, RPCError<NodeId, BasicNode, RaftError<NodeId>>> {
315 let url = format!("{}/raft/vote", normalize_addr(&self.target_addr));
316 debug!(target_addr = %self.target_addr, "Sending vote RPC");
317
318 RaftHttpClient::postcard_post(
319 &self.client.rpc_client,
320 &url,
321 &rpc,
322 self.auth_token.as_deref(),
323 )
324 .await
325 .map_err(to_unreachable)
326 }
327
328 async fn full_snapshot(
329 &mut self,
330 vote: Vote<NodeId>,
331 snapshot: Snapshot<C>,
332 _cancel: impl Future<Output = ReplicationClosed> + OptionalSend + 'static,
333 _option: RPCOption,
334 ) -> Result<SnapshotResponse<NodeId>, StreamingError<C, Fatal<NodeId>>> {
335 #[derive(serde::Serialize)]
336 struct FullSnapshotReq {
337 vote: Vote<NodeId>,
338 meta: SnapshotMeta<NodeId, BasicNode>,
339 snapshot_data: Vec<u8>,
340 }
341
342 let url = format!("{}/raft/full-snapshot", normalize_addr(&self.target_addr));
343 debug!(target_addr = %self.target_addr, "Sending full_snapshot RPC");
344
345 let snapshot_data = snapshot.snapshot.into_inner();
346
347 let req = FullSnapshotReq {
348 vote,
349 meta: snapshot.meta,
350 snapshot_data,
351 };
352
353 RaftHttpClient::postcard_post::<FullSnapshotReq, SnapshotResponse<NodeId>>(
354 &self.client.snapshot_client,
355 &url,
356 &req,
357 self.auth_token.as_deref(),
358 )
359 .await
360 .map_err(|e| StreamingError::Unreachable(Unreachable::new(&std::io::Error::other(e))))
361 }
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare `host:port` gets an `http://` prefix; an explicit scheme passes through.
    #[test]
    fn test_normalize_addr() {
        // (input, expected base URL)
        let cases = [
            ("10.0.0.1:9000", "http://10.0.0.1:9000"),
            ("http://10.0.0.1:9000", "http://10.0.0.1:9000"),
            ("https://10.0.0.1:9000", "https://10.0.0.1:9000"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(normalize_addr(input), expected);
        }
    }

    /// Building a client with ordinary timeouts must not panic.
    #[test]
    fn test_client_creation() {
        let rpc_timeout = Duration::from_secs(3);
        let snapshot_timeout = Duration::from_secs(30);
        let _client = RaftHttpClient::new(rpc_timeout, snapshot_timeout);
    }
}