1use std::future::Future;
8use std::io::Cursor;
9use std::sync::Arc;
10use std::time::Duration;
11
12use openraft::error::{
13 Fatal, InstallSnapshotError, RPCError, RaftError, ReplicationClosed, StreamingError,
14 Unreachable,
15};
16use openraft::network::{RPCOption, RaftNetwork, RaftNetworkFactory};
17use openraft::raft::{
18 AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest, InstallSnapshotResponse,
19 SnapshotResponse, VoteRequest, VoteResponse,
20};
21use openraft::{BasicNode, OptionalSend, RaftTypeConfig, Snapshot, SnapshotMeta, Vote};
22use reqwest::Client;
23use tracing::debug;
24
25use crate::types::NodeId;
26
27pub const RAFT_PROTOCOL_VERSION: &str = "1";
29
30#[derive(Clone)]
40pub struct RaftHttpClient {
41 rpc_client: Client,
43 snapshot_client: Client,
45 auth_header: Option<reqwest::header::HeaderValue>,
47}
48
49impl RaftHttpClient {
50 #[must_use]
52 pub fn new(rpc_timeout: Duration, snapshot_timeout: Duration) -> Self {
53 Self::with_auth(rpc_timeout, snapshot_timeout, None)
54 }
55
56 #[must_use]
61 pub fn with_auth(
62 rpc_timeout: Duration,
63 snapshot_timeout: Duration,
64 auth_token: Option<String>,
65 ) -> Self {
66 let rpc_client = Client::builder()
67 .timeout(rpc_timeout)
68 .pool_max_idle_per_host(2)
69 .pool_idle_timeout(Duration::from_secs(90))
70 .http2_prior_knowledge()
71 .http2_keep_alive_interval(std::time::Duration::from_secs(10))
72 .http2_keep_alive_while_idle(true)
73 .build()
74 .expect("Failed to build RPC HTTP client");
75
76 let snapshot_client = Client::builder()
77 .timeout(snapshot_timeout)
78 .pool_max_idle_per_host(2)
79 .pool_idle_timeout(Duration::from_secs(90))
80 .http2_prior_knowledge()
81 .http2_keep_alive_interval(std::time::Duration::from_secs(10))
82 .http2_keep_alive_while_idle(true)
83 .build()
84 .expect("Failed to build snapshot HTTP client");
85
86 let auth_header = auth_token.map(|token| {
87 let mut header = reqwest::header::HeaderValue::from_str(&format!("Bearer {token}"))
88 .expect("valid bearer token");
89 header.set_sensitive(true);
90 header
91 });
92
93 Self {
94 rpc_client,
95 snapshot_client,
96 auth_header,
97 }
98 }
99
100 async fn postcard_post<Req, Resp>(
102 client: &Client,
103 url: &str,
104 request: &Req,
105 auth_header: Option<&reqwest::header::HeaderValue>,
106 ) -> Result<Resp, String>
107 where
108 Req: serde::Serialize,
109 Resp: serde::de::DeserializeOwned,
110 {
111 let body =
112 postcard2::to_vec(request).map_err(|e| format!("postcard2 serialize error: {e}"))?;
113
114 let mut builder = client
115 .post(url)
116 .header("Content-Type", "application/octet-stream");
117
118 builder = builder.header("X-ZLayer-Raft-Protocol", RAFT_PROTOCOL_VERSION);
119
120 if let Some(header) = auth_header {
121 builder = builder.header(reqwest::header::AUTHORIZATION, header.clone());
122 }
123
124 let response = builder.body(body).send().await.map_err(|e| {
125 if e.is_timeout() {
126 format!("timeout: {e}")
127 } else if e.is_connect() {
128 format!("unreachable: {e}")
129 } else {
130 format!("http error: {e}")
131 }
132 })?;
133
134 if !response.status().is_success() {
135 let status = response.status();
136 if status == reqwest::StatusCode::UPGRADE_REQUIRED {
137 let server_version = response
138 .headers()
139 .get("X-ZLayer-Raft-Protocol-Supported")
140 .and_then(|v| v.to_str().ok())
141 .unwrap_or("<unknown>")
142 .to_string();
143 return Err(format!(
144 "protocol version mismatch: server supports {server_version}"
145 ));
146 }
147 let text = response.text().await.unwrap_or_default();
148 return Err(format!("HTTP {status}: {text}"));
149 }
150
151 let bytes = response
152 .bytes()
153 .await
154 .map_err(|e| format!("read body error: {e}"))?;
155 postcard2::from_bytes(&bytes).map_err(|e| format!("postcard2 deserialize error: {e}"))
156 }
157}
158
159impl Default for RaftHttpClient {
160 fn default() -> Self {
161 Self::new(Duration::from_secs(5), Duration::from_secs(60))
162 }
163}
164
165fn build_urls(addr: &str) -> [String; 4] {
171 let base = normalize_addr(addr);
172 [
173 format!("{base}/raft/append"),
174 format!("{base}/raft/vote"),
175 format!("{base}/raft/snapshot"),
176 format!("{base}/raft/full-snapshot"),
177 ]
178}
179
180pub struct HttpNetwork<C: RaftTypeConfig<NodeId = NodeId>> {
184 client: Arc<RaftHttpClient>,
186 _phantom: std::marker::PhantomData<C>,
187}
188
189impl<C: RaftTypeConfig<NodeId = NodeId>> HttpNetwork<C> {
190 #[must_use]
192 pub fn new() -> Self {
193 Self::with_client(RaftHttpClient::default())
194 }
195
196 #[must_use]
198 pub fn with_client(client: RaftHttpClient) -> Self {
199 Self {
200 client: Arc::new(client),
201 _phantom: std::marker::PhantomData,
202 }
203 }
204
205 #[must_use]
207 pub fn with_timeouts(rpc_timeout: Duration, snapshot_timeout: Duration) -> Self {
208 Self::with_client(RaftHttpClient::new(rpc_timeout, snapshot_timeout))
209 }
210
211 #[must_use]
213 pub fn with_timeouts_and_auth(
214 rpc_timeout: Duration,
215 snapshot_timeout: Duration,
216 auth_token: Option<String>,
217 ) -> Self {
218 Self::with_client(RaftHttpClient::with_auth(
219 rpc_timeout,
220 snapshot_timeout,
221 auth_token,
222 ))
223 }
224}
225
226impl<C: RaftTypeConfig<NodeId = NodeId>> Default for HttpNetwork<C> {
227 fn default() -> Self {
228 Self::new()
229 }
230}
231
232impl<C: RaftTypeConfig<NodeId = NodeId>> Clone for HttpNetwork<C> {
233 fn clone(&self) -> Self {
234 Self {
235 client: Arc::clone(&self.client),
236 _phantom: std::marker::PhantomData,
237 }
238 }
239}
240
241impl<C> RaftNetworkFactory<C> for HttpNetwork<C>
242where
243 C: RaftTypeConfig<NodeId = NodeId, Node = BasicNode, SnapshotData = Cursor<Vec<u8>>>,
244 C::D: serde::Serialize + serde::de::DeserializeOwned,
245 C::R: serde::Serialize + serde::de::DeserializeOwned,
246 C::Entry: serde::Serialize + serde::de::DeserializeOwned,
247{
248 type Network = HttpConnection<C>;
249
250 async fn new_client(&mut self, _target: NodeId, node: &BasicNode) -> Self::Network {
251 let [append, vote, snapshot, full_snapshot] = build_urls(&node.addr);
252 HttpConnection {
253 target_addr: Arc::<str>::from(node.addr.as_str()),
254 client: Arc::clone(&self.client),
255 auth_header: self.client.auth_header.clone(),
256 append_url: Arc::<str>::from(append.as_str()),
257 vote_url: Arc::<str>::from(vote.as_str()),
258 snapshot_url: Arc::<str>::from(snapshot.as_str()),
259 full_snapshot_url: Arc::<str>::from(full_snapshot.as_str()),
260 _phantom: std::marker::PhantomData,
261 }
262 }
263}
264
265pub struct HttpConnection<C: RaftTypeConfig<NodeId = NodeId>> {
267 target_addr: Arc<str>,
268 client: Arc<RaftHttpClient>,
269 auth_header: Option<reqwest::header::HeaderValue>,
271 append_url: Arc<str>,
272 vote_url: Arc<str>,
273 snapshot_url: Arc<str>,
274 full_snapshot_url: Arc<str>,
275 _phantom: std::marker::PhantomData<C>,
276}
277
/// Turns a peer address into a base URL with no trailing slash.
///
/// * A bare `host:port` address gets an `http://` prefix.
/// * An address that already has an `http://` or `https://` scheme keeps it as
///   written; the scheme is matched case-insensitively, so `HTTPS://host` is not
///   mangled into `http://HTTPS://host`.
/// * Trailing `/` characters are removed, so that appending `/raft/...` never
///   produces a `//raft/...` path.
fn normalize_addr(addr: &str) -> String {
    match addr.split_once("://") {
        Some((scheme, rest))
            if scheme.eq_ignore_ascii_case("http") || scheme.eq_ignore_ascii_case("https") =>
        {
            format!("{scheme}://{}", rest.trim_end_matches('/'))
        }
        _ => format!("http://{}", addr.trim_end_matches('/')),
    }
}
286
287fn to_unreachable<E: std::error::Error>(msg: String) -> RPCError<NodeId, BasicNode, E> {
289 RPCError::Unreachable(Unreachable::new(&std::io::Error::other(msg)))
290}
291
292impl<C> RaftNetwork<C> for HttpConnection<C>
293where
294 C: RaftTypeConfig<NodeId = NodeId, Node = BasicNode, SnapshotData = Cursor<Vec<u8>>>,
295 C::D: serde::Serialize + serde::de::DeserializeOwned,
296 C::R: serde::Serialize + serde::de::DeserializeOwned,
297 C::Entry: serde::Serialize + serde::de::DeserializeOwned,
298{
299 async fn append_entries(
300 &mut self,
301 rpc: AppendEntriesRequest<C>,
302 _option: RPCOption,
303 ) -> Result<AppendEntriesResponse<NodeId>, RPCError<NodeId, BasicNode, RaftError<NodeId>>> {
304 debug!(target_addr = %self.target_addr, "Sending append_entries RPC");
305
306 RaftHttpClient::postcard_post(
307 &self.client.rpc_client,
308 &self.append_url,
309 &rpc,
310 self.auth_header.as_ref(),
311 )
312 .await
313 .map_err(to_unreachable)
314 }
315
316 async fn install_snapshot(
317 &mut self,
318 rpc: InstallSnapshotRequest<C>,
319 _option: RPCOption,
320 ) -> Result<
321 InstallSnapshotResponse<NodeId>,
322 RPCError<NodeId, BasicNode, RaftError<NodeId, InstallSnapshotError>>,
323 > {
324 debug!(target_addr = %self.target_addr, "Sending install_snapshot RPC");
325
326 RaftHttpClient::postcard_post(
327 &self.client.snapshot_client,
328 &self.snapshot_url,
329 &rpc,
330 self.auth_header.as_ref(),
331 )
332 .await
333 .map_err(to_unreachable)
334 }
335
336 async fn vote(
337 &mut self,
338 rpc: VoteRequest<NodeId>,
339 _option: RPCOption,
340 ) -> Result<VoteResponse<NodeId>, RPCError<NodeId, BasicNode, RaftError<NodeId>>> {
341 debug!(target_addr = %self.target_addr, "Sending vote RPC");
342
343 RaftHttpClient::postcard_post(
344 &self.client.rpc_client,
345 &self.vote_url,
346 &rpc,
347 self.auth_header.as_ref(),
348 )
349 .await
350 .map_err(to_unreachable)
351 }
352
353 async fn full_snapshot(
354 &mut self,
355 vote: Vote<NodeId>,
356 snapshot: Snapshot<C>,
357 _cancel: impl Future<Output = ReplicationClosed> + OptionalSend + 'static,
358 _option: RPCOption,
359 ) -> Result<SnapshotResponse<NodeId>, StreamingError<C, Fatal<NodeId>>> {
360 #[derive(serde::Serialize)]
361 struct FullSnapshotReq {
362 vote: Vote<NodeId>,
363 meta: SnapshotMeta<NodeId, BasicNode>,
364 snapshot_data: Vec<u8>,
365 }
366
367 debug!(target_addr = %self.target_addr, "Sending full_snapshot RPC");
368
369 let snapshot_data = snapshot.snapshot.into_inner();
370
371 let req = FullSnapshotReq {
372 vote,
373 meta: snapshot.meta,
374 snapshot_data,
375 };
376
377 RaftHttpClient::postcard_post::<FullSnapshotReq, SnapshotResponse<NodeId>>(
378 &self.client.snapshot_client,
379 &self.full_snapshot_url,
380 &req,
381 self.auth_header.as_ref(),
382 )
383 .await
384 .map_err(|e| StreamingError::Unreachable(Unreachable::new(&std::io::Error::other(e))))
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_addr() {
        assert_eq!(normalize_addr("10.0.0.1:9000"), "http://10.0.0.1:9000");
        assert_eq!(
            normalize_addr("http://10.0.0.1:9000"),
            "http://10.0.0.1:9000"
        );
        assert_eq!(
            normalize_addr("https://10.0.0.1:9000"),
            "https://10.0.0.1:9000"
        );
    }

    #[test]
    fn test_client_creation() {
        let client = RaftHttpClient::new(Duration::from_secs(3), Duration::from_secs(30));
        // `new` never authenticates.
        assert!(client.auth_header.is_none());
    }

    #[test]
    fn test_client_with_auth_builds_sensitive_bearer_header() {
        let client = RaftHttpClient::with_auth(
            Duration::from_secs(3),
            Duration::from_secs(30),
            Some("secret".to_string()),
        );
        let header = client
            .auth_header
            .as_ref()
            .expect("with_auth(Some(..)) must set an auth header");
        assert_eq!(header.to_str().unwrap(), "Bearer secret");
        assert!(header.is_sensitive());
    }

    #[test]
    fn test_protocol_version_constant() {
        assert_eq!(RAFT_PROTOCOL_VERSION, "1");
    }

    #[test]
    fn test_connection_urls_precomputed() {
        let [append, vote, snapshot, full_snapshot] = build_urls("10.0.0.1:9000");
        assert_eq!(append, "http://10.0.0.1:9000/raft/append");
        assert_eq!(vote, "http://10.0.0.1:9000/raft/vote");
        assert_eq!(snapshot, "http://10.0.0.1:9000/raft/snapshot");
        assert_eq!(full_snapshot, "http://10.0.0.1:9000/raft/full-snapshot");
    }
}