Skip to main content

zlayer_consensus/network/
http_client.rs

1//! HTTP client for Raft network communication using postcard2 serialization.
2//!
3//! Implements `RaftNetworkFactory` and `RaftNetwork` traits from openraft,
4//! using reqwest with connection pooling and split timeouts (short for
5//! vote/append, long for snapshots).
6
7use std::future::Future;
8use std::io::Cursor;
9use std::sync::Arc;
10use std::time::Duration;
11
12use openraft::error::{
13    Fatal, InstallSnapshotError, RPCError, RaftError, ReplicationClosed, StreamingError,
14    Unreachable,
15};
16use openraft::network::{RPCOption, RaftNetwork, RaftNetworkFactory};
17use openraft::raft::{
18    AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest, InstallSnapshotResponse,
19    SnapshotResponse, VoteRequest, VoteResponse,
20};
21use openraft::{BasicNode, OptionalSend, RaftTypeConfig, Snapshot, SnapshotMeta, Vote};
22use reqwest::Client;
23use tracing::debug;
24
25use crate::types::NodeId;
26
/// Wire protocol version emitted on every Raft RPC and validated by the server.
///
/// The client sends this value in the `X-ZLayer-Raft-Protocol` request header.
pub const RAFT_PROTOCOL_VERSION: &str = "1";
29
30// ---------------------------------------------------------------------------
31// HTTP client
32// ---------------------------------------------------------------------------
33
/// HTTP client for Raft RPCs using **postcard2** serialization.
///
/// Holds two `reqwest` clients so that regular RPCs and snapshot transfers
/// can be given different request timeouts.  Optionally carries a bearer
/// token that is attached to outgoing requests for authentication against
/// the Raft service.
#[derive(Clone)]
pub struct RaftHttpClient {
    /// Client for regular RPCs (vote, `append_entries`).
    rpc_client: Client,
    /// Client for snapshot transfers (built with its own, separately
    /// configured timeout).
    snapshot_client: Client,
    /// Precomputed `Authorization: Bearer …` header value; `None` when no
    /// auth token was configured.
    auth_header: Option<reqwest::header::HeaderValue>,
}
48
49impl RaftHttpClient {
50    /// Create a new client with the specified timeouts and no auth token.
51    #[must_use]
52    pub fn new(rpc_timeout: Duration, snapshot_timeout: Duration) -> Self {
53        Self::with_auth(rpc_timeout, snapshot_timeout, None)
54    }
55
56    /// Create a new client with the specified timeouts and an optional auth token.
57    ///
58    /// # Panics
59    /// Panics if the HTTP client builders fail to build (should not happen in practice).
60    #[must_use]
61    pub fn with_auth(
62        rpc_timeout: Duration,
63        snapshot_timeout: Duration,
64        auth_token: Option<String>,
65    ) -> Self {
66        let rpc_client = Client::builder()
67            .timeout(rpc_timeout)
68            .pool_max_idle_per_host(2)
69            .pool_idle_timeout(Duration::from_secs(90))
70            .http2_prior_knowledge()
71            .http2_keep_alive_interval(std::time::Duration::from_secs(10))
72            .http2_keep_alive_while_idle(true)
73            .build()
74            .expect("Failed to build RPC HTTP client");
75
76        let snapshot_client = Client::builder()
77            .timeout(snapshot_timeout)
78            .pool_max_idle_per_host(2)
79            .pool_idle_timeout(Duration::from_secs(90))
80            .http2_prior_knowledge()
81            .http2_keep_alive_interval(std::time::Duration::from_secs(10))
82            .http2_keep_alive_while_idle(true)
83            .build()
84            .expect("Failed to build snapshot HTTP client");
85
86        let auth_header = auth_token.map(|token| {
87            let mut header = reqwest::header::HeaderValue::from_str(&format!("Bearer {token}"))
88                .expect("valid bearer token");
89            header.set_sensitive(true);
90            header
91        });
92
93        Self {
94            rpc_client,
95            snapshot_client,
96            auth_header,
97        }
98    }
99
100    /// Send a postcard2-encoded POST request and decode the response.
101    async fn postcard_post<Req, Resp>(
102        client: &Client,
103        url: &str,
104        request: &Req,
105        auth_header: Option<&reqwest::header::HeaderValue>,
106    ) -> Result<Resp, String>
107    where
108        Req: serde::Serialize,
109        Resp: serde::de::DeserializeOwned,
110    {
111        let body =
112            postcard2::to_vec(request).map_err(|e| format!("postcard2 serialize error: {e}"))?;
113
114        let mut builder = client
115            .post(url)
116            .header("Content-Type", "application/octet-stream");
117
118        builder = builder.header("X-ZLayer-Raft-Protocol", RAFT_PROTOCOL_VERSION);
119
120        if let Some(header) = auth_header {
121            builder = builder.header(reqwest::header::AUTHORIZATION, header.clone());
122        }
123
124        let response = builder.body(body).send().await.map_err(|e| {
125            if e.is_timeout() {
126                format!("timeout: {e}")
127            } else if e.is_connect() {
128                format!("unreachable: {e}")
129            } else {
130                format!("http error: {e}")
131            }
132        })?;
133
134        if !response.status().is_success() {
135            let status = response.status();
136            if status == reqwest::StatusCode::UPGRADE_REQUIRED {
137                let server_version = response
138                    .headers()
139                    .get("X-ZLayer-Raft-Protocol-Supported")
140                    .and_then(|v| v.to_str().ok())
141                    .unwrap_or("<unknown>")
142                    .to_string();
143                return Err(format!(
144                    "protocol version mismatch: server supports {server_version}"
145                ));
146            }
147            let text = response.text().await.unwrap_or_default();
148            return Err(format!("HTTP {status}: {text}"));
149        }
150
151        let bytes = response
152            .bytes()
153            .await
154            .map_err(|e| format!("read body error: {e}"))?;
155        postcard2::from_bytes(&bytes).map_err(|e| format!("postcard2 deserialize error: {e}"))
156    }
157}
158
159impl Default for RaftHttpClient {
160    fn default() -> Self {
161        Self::new(Duration::from_secs(5), Duration::from_secs(60))
162    }
163}
164
165// ---------------------------------------------------------------------------
166// Network factory + connection
167// ---------------------------------------------------------------------------
168
169/// Build the four per-connection URLs for a given peer address.
170fn build_urls(addr: &str) -> [String; 4] {
171    let base = normalize_addr(addr);
172    [
173        format!("{base}/raft/append"),
174        format!("{base}/raft/vote"),
175        format!("{base}/raft/snapshot"),
176        format!("{base}/raft/full-snapshot"),
177    ]
178}
179
/// Network factory that creates HTTP connections to Raft peers.
///
/// Generic over the `RaftTypeConfig` so any application can use it.  The
/// underlying [`RaftHttpClient`] is held behind an `Arc` and shared with
/// every connection the factory creates.
pub struct HttpNetwork<C: RaftTypeConfig<NodeId = NodeId>> {
    /// Shared HTTP client.
    client: Arc<RaftHttpClient>,
    /// Ties the factory to its type config without storing a `C` value.
    _phantom: std::marker::PhantomData<C>,
}
188
189impl<C: RaftTypeConfig<NodeId = NodeId>> HttpNetwork<C> {
190    /// Create a new network layer with default timeouts (5s RPC, 60s snapshot).
191    #[must_use]
192    pub fn new() -> Self {
193        Self::with_client(RaftHttpClient::default())
194    }
195
196    /// Create a new network layer with a custom client.
197    #[must_use]
198    pub fn with_client(client: RaftHttpClient) -> Self {
199        Self {
200            client: Arc::new(client),
201            _phantom: std::marker::PhantomData,
202        }
203    }
204
205    /// Create a new network layer with custom timeouts.
206    #[must_use]
207    pub fn with_timeouts(rpc_timeout: Duration, snapshot_timeout: Duration) -> Self {
208        Self::with_client(RaftHttpClient::new(rpc_timeout, snapshot_timeout))
209    }
210
211    /// Create a new network layer with custom timeouts and an optional auth token.
212    #[must_use]
213    pub fn with_timeouts_and_auth(
214        rpc_timeout: Duration,
215        snapshot_timeout: Duration,
216        auth_token: Option<String>,
217    ) -> Self {
218        Self::with_client(RaftHttpClient::with_auth(
219            rpc_timeout,
220            snapshot_timeout,
221            auth_token,
222        ))
223    }
224}
225
226impl<C: RaftTypeConfig<NodeId = NodeId>> Default for HttpNetwork<C> {
227    fn default() -> Self {
228        Self::new()
229    }
230}
231
232impl<C: RaftTypeConfig<NodeId = NodeId>> Clone for HttpNetwork<C> {
233    fn clone(&self) -> Self {
234        Self {
235            client: Arc::clone(&self.client),
236            _phantom: std::marker::PhantomData,
237        }
238    }
239}
240
241impl<C> RaftNetworkFactory<C> for HttpNetwork<C>
242where
243    C: RaftTypeConfig<NodeId = NodeId, Node = BasicNode, SnapshotData = Cursor<Vec<u8>>>,
244    C::D: serde::Serialize + serde::de::DeserializeOwned,
245    C::R: serde::Serialize + serde::de::DeserializeOwned,
246    C::Entry: serde::Serialize + serde::de::DeserializeOwned,
247{
248    type Network = HttpConnection<C>;
249
250    async fn new_client(&mut self, _target: NodeId, node: &BasicNode) -> Self::Network {
251        let [append, vote, snapshot, full_snapshot] = build_urls(&node.addr);
252        HttpConnection {
253            target_addr: Arc::<str>::from(node.addr.as_str()),
254            client: Arc::clone(&self.client),
255            auth_header: self.client.auth_header.clone(),
256            append_url: Arc::<str>::from(append.as_str()),
257            vote_url: Arc::<str>::from(vote.as_str()),
258            snapshot_url: Arc::<str>::from(snapshot.as_str()),
259            full_snapshot_url: Arc::<str>::from(full_snapshot.as_str()),
260            _phantom: std::marker::PhantomData,
261        }
262    }
263}
264
/// A single connection to a Raft peer.
///
/// Holds the shared HTTP client together with the peer's endpoint URLs,
/// which are computed once when the connection is created.
pub struct HttpConnection<C: RaftTypeConfig<NodeId = NodeId>> {
    /// Peer address as configured on the node; used in log output.
    target_addr: Arc<str>,
    /// Shared HTTP client (RPC and snapshot clients).
    client: Arc<RaftHttpClient>,
    /// Precomputed `Authorization` header for outgoing RPCs.
    auth_header: Option<reqwest::header::HeaderValue>,
    /// Endpoint for `append_entries` requests (`…/raft/append`).
    append_url: Arc<str>,
    /// Endpoint for `vote` requests (`…/raft/vote`).
    vote_url: Arc<str>,
    /// Endpoint for `install_snapshot` requests (`…/raft/snapshot`).
    snapshot_url: Arc<str>,
    /// Endpoint for `full_snapshot` requests (`…/raft/full-snapshot`).
    full_snapshot_url: Arc<str>,
    /// Ties the connection to its type config without storing a `C` value.
    _phantom: std::marker::PhantomData<C>,
}
277
/// Normalize a peer address into a base URL with an HTTP scheme.
///
/// - Trailing `/` characters are removed, so appending a route such as
///   `/raft/vote` never yields a `//raft/vote` path.
/// - An existing `http://` or `https://` prefix is recognized regardless of
///   ASCII case (URI schemes are case-insensitive) and left as written.
/// - Any other address is prefixed with `http://`.
fn normalize_addr(addr: &str) -> String {
    let addr = addr.trim_end_matches('/');

    // `str::get` returns `None` when the address is shorter than the scheme
    // or the cut would fall inside a multi-byte character.
    let has_http_scheme = ["http://", "https://"].iter().any(|scheme| {
        addr.get(..scheme.len())
            .is_some_and(|prefix| prefix.eq_ignore_ascii_case(scheme))
    });

    if has_http_scheme {
        addr.to_string()
    } else {
        format!("http://{addr}")
    }
}
286
287/// Convert a string error to an `RPCError::Unreachable`.
288fn to_unreachable<E: std::error::Error>(msg: String) -> RPCError<NodeId, BasicNode, E> {
289    RPCError::Unreachable(Unreachable::new(&std::io::Error::other(msg)))
290}
291
292impl<C> RaftNetwork<C> for HttpConnection<C>
293where
294    C: RaftTypeConfig<NodeId = NodeId, Node = BasicNode, SnapshotData = Cursor<Vec<u8>>>,
295    C::D: serde::Serialize + serde::de::DeserializeOwned,
296    C::R: serde::Serialize + serde::de::DeserializeOwned,
297    C::Entry: serde::Serialize + serde::de::DeserializeOwned,
298{
299    async fn append_entries(
300        &mut self,
301        rpc: AppendEntriesRequest<C>,
302        _option: RPCOption,
303    ) -> Result<AppendEntriesResponse<NodeId>, RPCError<NodeId, BasicNode, RaftError<NodeId>>> {
304        debug!(target_addr = %self.target_addr, "Sending append_entries RPC");
305
306        RaftHttpClient::postcard_post(
307            &self.client.rpc_client,
308            &self.append_url,
309            &rpc,
310            self.auth_header.as_ref(),
311        )
312        .await
313        .map_err(to_unreachable)
314    }
315
316    async fn install_snapshot(
317        &mut self,
318        rpc: InstallSnapshotRequest<C>,
319        _option: RPCOption,
320    ) -> Result<
321        InstallSnapshotResponse<NodeId>,
322        RPCError<NodeId, BasicNode, RaftError<NodeId, InstallSnapshotError>>,
323    > {
324        debug!(target_addr = %self.target_addr, "Sending install_snapshot RPC");
325
326        RaftHttpClient::postcard_post(
327            &self.client.snapshot_client,
328            &self.snapshot_url,
329            &rpc,
330            self.auth_header.as_ref(),
331        )
332        .await
333        .map_err(to_unreachable)
334    }
335
336    async fn vote(
337        &mut self,
338        rpc: VoteRequest<NodeId>,
339        _option: RPCOption,
340    ) -> Result<VoteResponse<NodeId>, RPCError<NodeId, BasicNode, RaftError<NodeId>>> {
341        debug!(target_addr = %self.target_addr, "Sending vote RPC");
342
343        RaftHttpClient::postcard_post(
344            &self.client.rpc_client,
345            &self.vote_url,
346            &rpc,
347            self.auth_header.as_ref(),
348        )
349        .await
350        .map_err(to_unreachable)
351    }
352
353    async fn full_snapshot(
354        &mut self,
355        vote: Vote<NodeId>,
356        snapshot: Snapshot<C>,
357        _cancel: impl Future<Output = ReplicationClosed> + OptionalSend + 'static,
358        _option: RPCOption,
359    ) -> Result<SnapshotResponse<NodeId>, StreamingError<C, Fatal<NodeId>>> {
360        #[derive(serde::Serialize)]
361        struct FullSnapshotReq {
362            vote: Vote<NodeId>,
363            meta: SnapshotMeta<NodeId, BasicNode>,
364            snapshot_data: Vec<u8>,
365        }
366
367        debug!(target_addr = %self.target_addr, "Sending full_snapshot RPC");
368
369        let snapshot_data = snapshot.snapshot.into_inner();
370
371        let req = FullSnapshotReq {
372            vote,
373            meta: snapshot.meta,
374            snapshot_data,
375        };
376
377        RaftHttpClient::postcard_post::<FullSnapshotReq, SnapshotResponse<NodeId>>(
378            &self.client.snapshot_client,
379            &self.full_snapshot_url,
380            &req,
381            self.auth_header.as_ref(),
382        )
383        .await
384        .map_err(|e| StreamingError::Unreachable(Unreachable::new(&std::io::Error::other(e))))
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Scheme-less addresses gain `http://`; addresses with a scheme pass through.
    #[test]
    fn test_normalize_addr() {
        let cases = [
            ("10.0.0.1:9000", "http://10.0.0.1:9000"),
            ("http://10.0.0.1:9000", "http://10.0.0.1:9000"),
            ("https://10.0.0.1:9000", "https://10.0.0.1:9000"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_addr(input), expected);
        }
    }

    /// Constructing a client with custom timeouts must not panic.
    #[test]
    fn test_client_creation() {
        let rpc_timeout = Duration::from_secs(3);
        let snapshot_timeout = Duration::from_secs(30);
        let _client = RaftHttpClient::new(rpc_timeout, snapshot_timeout);
    }

    /// Guards against accidental changes to the wire protocol version.
    #[test]
    fn test_protocol_version_constant() {
        assert_eq!(RAFT_PROTOCOL_VERSION, "1");
    }

    /// All four endpoint URLs are derived from the peer address, in order.
    #[test]
    fn test_connection_urls_precomputed() {
        let expected = [
            "http://10.0.0.1:9000/raft/append",
            "http://10.0.0.1:9000/raft/vote",
            "http://10.0.0.1:9000/raft/snapshot",
            "http://10.0.0.1:9000/raft/full-snapshot",
        ];
        assert_eq!(build_urls("10.0.0.1:9000"), expected);
    }
}