Skip to main content

zlayer_consensus/network/
http_client.rs

1//! HTTP client for Raft network communication using postcard2 serialization.
2//!
3//! Implements `RaftNetworkFactory` and `RaftNetwork` traits from openraft,
4//! using reqwest with connection pooling and split timeouts (short for
5//! vote/append, long for snapshots).
6
7use std::collections::HashMap;
8use std::future::Future;
9use std::io::Cursor;
10use std::sync::Arc;
11use std::time::Duration;
12
13use openraft::error::{
14    Fatal, InstallSnapshotError, RPCError, RaftError, ReplicationClosed, StreamingError,
15    Unreachable,
16};
17use openraft::network::{RPCOption, RaftNetwork, RaftNetworkFactory};
18use openraft::raft::{
19    AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest, InstallSnapshotResponse,
20    SnapshotResponse, VoteRequest, VoteResponse,
21};
22use openraft::{BasicNode, OptionalSend, RaftTypeConfig, Snapshot, SnapshotMeta, Vote};
23use reqwest::Client;
24use tokio::sync::RwLock;
25use tracing::debug;
26
27use crate::types::NodeId;
28
29// ---------------------------------------------------------------------------
30// HTTP client
31// ---------------------------------------------------------------------------
32
33/// HTTP client for Raft RPCs using **postcard2** serialization.
34///
35/// Maintains separate timeout configurations for regular RPCs and
36/// snapshot transfers.  Optionally attaches a bearer token to every
37/// outgoing request for authentication against the Raft service.
38#[derive(Clone)]
39pub struct RaftHttpClient {
40    /// Client for regular RPCs (vote, `append_entries`).
41    rpc_client: Client,
42    /// Client for snapshot transfers (longer timeout).
43    snapshot_client: Client,
44    /// Optional bearer token for Raft RPC authentication.
45    auth_token: Option<String>,
46}
47
48impl RaftHttpClient {
49    /// Create a new client with the specified timeouts and no auth token.
50    #[must_use]
51    pub fn new(rpc_timeout: Duration, snapshot_timeout: Duration) -> Self {
52        Self::with_auth(rpc_timeout, snapshot_timeout, None)
53    }
54
55    /// Create a new client with the specified timeouts and an optional auth token.
56    ///
57    /// # Panics
58    /// Panics if the HTTP client builders fail to build (should not happen in practice).
59    #[must_use]
60    pub fn with_auth(
61        rpc_timeout: Duration,
62        snapshot_timeout: Duration,
63        auth_token: Option<String>,
64    ) -> Self {
65        let rpc_client = Client::builder()
66            .timeout(rpc_timeout)
67            .pool_max_idle_per_host(10)
68            .pool_idle_timeout(Duration::from_secs(90))
69            .build()
70            .expect("Failed to build RPC HTTP client");
71
72        let snapshot_client = Client::builder()
73            .timeout(snapshot_timeout)
74            .pool_max_idle_per_host(5)
75            .pool_idle_timeout(Duration::from_secs(90))
76            .build()
77            .expect("Failed to build snapshot HTTP client");
78
79        Self {
80            rpc_client,
81            snapshot_client,
82            auth_token,
83        }
84    }
85
86    /// Send a postcard2-encoded POST request and decode the response.
87    async fn postcard_post<Req, Resp>(
88        client: &Client,
89        url: &str,
90        request: &Req,
91        auth_token: Option<&str>,
92    ) -> Result<Resp, String>
93    where
94        Req: serde::Serialize,
95        Resp: serde::de::DeserializeOwned,
96    {
97        let body =
98            postcard2::to_vec(request).map_err(|e| format!("postcard2 serialize error: {e}"))?;
99
100        let mut builder = client
101            .post(url)
102            .header("Content-Type", "application/octet-stream");
103
104        if let Some(token) = auth_token {
105            builder = builder.header("Authorization", format!("Bearer {token}"));
106        }
107
108        let response = builder.body(body).send().await.map_err(|e| {
109            if e.is_timeout() {
110                format!("timeout: {e}")
111            } else if e.is_connect() {
112                format!("unreachable: {e}")
113            } else {
114                format!("http error: {e}")
115            }
116        })?;
117
118        if !response.status().is_success() {
119            let status = response.status();
120            let text = response.text().await.unwrap_or_default();
121            return Err(format!("HTTP {status}: {text}"));
122        }
123
124        let bytes = response
125            .bytes()
126            .await
127            .map_err(|e| format!("read body error: {e}"))?;
128        postcard2::from_bytes(&bytes).map_err(|e| format!("postcard2 deserialize error: {e}"))
129    }
130}
131
132impl Default for RaftHttpClient {
133    fn default() -> Self {
134        Self::new(Duration::from_secs(5), Duration::from_secs(60))
135    }
136}
137
138// ---------------------------------------------------------------------------
139// Network factory + connection
140// ---------------------------------------------------------------------------
141
142/// Network factory that creates HTTP connections to Raft peers.
143///
144/// Generic over the `RaftTypeConfig` so any application can use it.
145pub struct HttpNetwork<C: RaftTypeConfig<NodeId = NodeId>> {
146    /// Known peers (for informational purposes / peer management).
147    peers: Arc<RwLock<HashMap<NodeId, String>>>,
148    /// Shared HTTP client.
149    client: Arc<RaftHttpClient>,
150    _phantom: std::marker::PhantomData<C>,
151}
152
153impl<C: RaftTypeConfig<NodeId = NodeId>> HttpNetwork<C> {
154    /// Create a new network layer with default timeouts (5s RPC, 60s snapshot).
155    #[must_use]
156    pub fn new() -> Self {
157        Self::with_client(RaftHttpClient::default())
158    }
159
160    /// Create a new network layer with a custom client.
161    #[must_use]
162    pub fn with_client(client: RaftHttpClient) -> Self {
163        Self {
164            peers: Arc::new(RwLock::new(HashMap::new())),
165            client: Arc::new(client),
166            _phantom: std::marker::PhantomData,
167        }
168    }
169
170    /// Create a new network layer with custom timeouts.
171    #[must_use]
172    pub fn with_timeouts(rpc_timeout: Duration, snapshot_timeout: Duration) -> Self {
173        Self::with_client(RaftHttpClient::new(rpc_timeout, snapshot_timeout))
174    }
175
176    /// Create a new network layer with custom timeouts and an optional auth token.
177    #[must_use]
178    pub fn with_timeouts_and_auth(
179        rpc_timeout: Duration,
180        snapshot_timeout: Duration,
181        auth_token: Option<String>,
182    ) -> Self {
183        Self::with_client(RaftHttpClient::with_auth(
184            rpc_timeout,
185            snapshot_timeout,
186            auth_token,
187        ))
188    }
189
190    /// Add a peer address.
191    pub async fn add_peer(&self, node_id: NodeId, address: String) {
192        self.peers.write().await.insert(node_id, address);
193    }
194
195    /// Remove a peer.
196    pub async fn remove_peer(&self, node_id: NodeId) {
197        self.peers.write().await.remove(&node_id);
198    }
199
200    /// Get all known peers.
201    pub async fn peers(&self) -> HashMap<NodeId, String> {
202        self.peers.read().await.clone()
203    }
204}
205
206impl<C: RaftTypeConfig<NodeId = NodeId>> Default for HttpNetwork<C> {
207    fn default() -> Self {
208        Self::new()
209    }
210}
211
212impl<C: RaftTypeConfig<NodeId = NodeId>> Clone for HttpNetwork<C> {
213    fn clone(&self) -> Self {
214        Self {
215            peers: Arc::clone(&self.peers),
216            client: Arc::clone(&self.client),
217            _phantom: std::marker::PhantomData,
218        }
219    }
220}
221
222impl<C> RaftNetworkFactory<C> for HttpNetwork<C>
223where
224    C: RaftTypeConfig<NodeId = NodeId, Node = BasicNode, SnapshotData = Cursor<Vec<u8>>>,
225    C::D: serde::Serialize + serde::de::DeserializeOwned,
226    C::R: serde::Serialize + serde::de::DeserializeOwned,
227    C::Entry: serde::Serialize + serde::de::DeserializeOwned,
228{
229    type Network = HttpConnection<C>;
230
231    async fn new_client(&mut self, _target: NodeId, node: &BasicNode) -> Self::Network {
232        HttpConnection {
233            target_addr: node.addr.clone(),
234            client: Arc::clone(&self.client),
235            auth_token: self.client.auth_token.clone(),
236            _phantom: std::marker::PhantomData,
237        }
238    }
239}
240
241/// A single connection to a Raft peer.
242pub struct HttpConnection<C: RaftTypeConfig<NodeId = NodeId>> {
243    target_addr: String,
244    client: Arc<RaftHttpClient>,
245    /// Optional bearer token for authenticating outgoing RPCs.
246    auth_token: Option<String>,
247    _phantom: std::marker::PhantomData<C>,
248}
249
/// Normalize a peer address into a base URL that `/raft/...` paths can be appended to.
///
/// * An explicit `http://` or `https://` scheme is kept as-is. Any other input
///   gets `http://` prepended.
/// * Trailing `/` characters after the scheme are removed. Appending an
///   absolute path such as `/raft/vote` therefore never yields `//raft/vote`,
///   which a path-matching router may not treat as `/raft/vote`.
fn normalize_addr(addr: &str) -> String {
    let (scheme, rest) = if let Some(rest) = addr.strip_prefix("https://") {
        ("https://", rest)
    } else if let Some(rest) = addr.strip_prefix("http://") {
        ("http://", rest)
    } else {
        ("http://", addr)
    };
    format!("{scheme}{}", rest.trim_end_matches('/'))
}
258
259/// Convert a string error to an `RPCError::Unreachable`.
260fn to_unreachable<E: std::error::Error>(msg: String) -> RPCError<NodeId, BasicNode, E> {
261    RPCError::Unreachable(Unreachable::new(&std::io::Error::other(msg)))
262}
263
264impl<C> RaftNetwork<C> for HttpConnection<C>
265where
266    C: RaftTypeConfig<NodeId = NodeId, Node = BasicNode, SnapshotData = Cursor<Vec<u8>>>,
267    C::D: serde::Serialize + serde::de::DeserializeOwned,
268    C::R: serde::Serialize + serde::de::DeserializeOwned,
269    C::Entry: serde::Serialize + serde::de::DeserializeOwned,
270{
271    async fn append_entries(
272        &mut self,
273        rpc: AppendEntriesRequest<C>,
274        _option: RPCOption,
275    ) -> Result<AppendEntriesResponse<NodeId>, RPCError<NodeId, BasicNode, RaftError<NodeId>>> {
276        let url = format!("{}/raft/append", normalize_addr(&self.target_addr));
277        debug!(target_addr = %self.target_addr, "Sending append_entries RPC");
278
279        RaftHttpClient::postcard_post(
280            &self.client.rpc_client,
281            &url,
282            &rpc,
283            self.auth_token.as_deref(),
284        )
285        .await
286        .map_err(to_unreachable)
287    }
288
289    async fn install_snapshot(
290        &mut self,
291        rpc: InstallSnapshotRequest<C>,
292        _option: RPCOption,
293    ) -> Result<
294        InstallSnapshotResponse<NodeId>,
295        RPCError<NodeId, BasicNode, RaftError<NodeId, InstallSnapshotError>>,
296    > {
297        let url = format!("{}/raft/snapshot", normalize_addr(&self.target_addr));
298        debug!(target_addr = %self.target_addr, "Sending install_snapshot RPC");
299
300        RaftHttpClient::postcard_post(
301            &self.client.snapshot_client,
302            &url,
303            &rpc,
304            self.auth_token.as_deref(),
305        )
306        .await
307        .map_err(to_unreachable)
308    }
309
310    async fn vote(
311        &mut self,
312        rpc: VoteRequest<NodeId>,
313        _option: RPCOption,
314    ) -> Result<VoteResponse<NodeId>, RPCError<NodeId, BasicNode, RaftError<NodeId>>> {
315        let url = format!("{}/raft/vote", normalize_addr(&self.target_addr));
316        debug!(target_addr = %self.target_addr, "Sending vote RPC");
317
318        RaftHttpClient::postcard_post(
319            &self.client.rpc_client,
320            &url,
321            &rpc,
322            self.auth_token.as_deref(),
323        )
324        .await
325        .map_err(to_unreachable)
326    }
327
328    async fn full_snapshot(
329        &mut self,
330        vote: Vote<NodeId>,
331        snapshot: Snapshot<C>,
332        _cancel: impl Future<Output = ReplicationClosed> + OptionalSend + 'static,
333        _option: RPCOption,
334    ) -> Result<SnapshotResponse<NodeId>, StreamingError<C, Fatal<NodeId>>> {
335        #[derive(serde::Serialize)]
336        struct FullSnapshotReq {
337            vote: Vote<NodeId>,
338            meta: SnapshotMeta<NodeId, BasicNode>,
339            snapshot_data: Vec<u8>,
340        }
341
342        let url = format!("{}/raft/full-snapshot", normalize_addr(&self.target_addr));
343        debug!(target_addr = %self.target_addr, "Sending full_snapshot RPC");
344
345        let snapshot_data = snapshot.snapshot.into_inner();
346
347        let req = FullSnapshotReq {
348            vote,
349            meta: snapshot.meta,
350            snapshot_data,
351        };
352
353        RaftHttpClient::postcard_post::<FullSnapshotReq, SnapshotResponse<NodeId>>(
354            &self.client.snapshot_client,
355            &url,
356            &req,
357            self.auth_token.as_deref(),
358        )
359        .await
360        .map_err(|e| StreamingError::Unreachable(Unreachable::new(&std::io::Error::other(e))))
361    }
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_addr() {
        // (input, expected base URL)
        let cases = [
            ("10.0.0.1:9000", "http://10.0.0.1:9000"),
            ("http://10.0.0.1:9000", "http://10.0.0.1:9000"),
            ("https://10.0.0.1:9000", "https://10.0.0.1:9000"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_addr(input), expected, "input: {input}");
        }
    }

    #[test]
    fn test_client_creation() {
        // Building the client with ordinary timeouts must not panic.
        let _client = RaftHttpClient::new(Duration::from_secs(3), Duration::from_secs(30));
    }
}