Skip to main content

zlayer_consensus/network/
http_service.rs

//! Axum HTTP service for receiving Raft RPCs.
//!
//! Provides an Axum router with endpoints for all Raft RPC operations.
//! Uses **postcard2** serialization for request/response bodies.
//!
//! When an `auth_token` is provided, every request must include an
//! `Authorization: Bearer <token>` header matching the expected value.
//! Requests without a valid token receive HTTP 401.
//!
//! A successful RPC returns HTTP 200 with a postcard2-encoded response body.
//! A request body that cannot be decoded yields HTTP 400. A failed Raft call
//! or a response-encoding error yields HTTP 500. Both error cases carry an
//! empty body.
//!
//! ## Endpoints
//!
//! | Method | Path | Description |
//! |--------|------|-------------|
//! | POST | `/raft/vote` | RequestVote RPC |
//! | POST | `/raft/append` | AppendEntries RPC |
//! | POST | `/raft/snapshot` | InstallSnapshot RPC |
//! | POST | `/raft/full-snapshot` | Full snapshot transfer |

19use std::io::Cursor;
20use std::sync::Arc;
21
22use axum::body::Bytes;
23use axum::extract::State;
24use axum::http::{Request, StatusCode};
25use axum::middleware::{self, Next};
26use axum::response::{IntoResponse, Response};
27use axum::routing::post;
28use axum::Router;
29use openraft::raft::{AppendEntriesRequest, InstallSnapshotRequest, VoteRequest};
30use openraft::storage::Snapshot;
31use openraft::{BasicNode, Raft, RaftTypeConfig, SnapshotMeta, Vote};
32use tracing::{debug, error, warn};
33
34use crate::types::NodeId;
35
36/// Application state shared with Axum handlers.
37struct RaftState<C: RaftTypeConfig> {
38    raft: Raft<C>,
39}
40
41/// Create an Axum router for Raft RPC endpoints.
42///
43/// If `auth_token` is `Some`, a middleware layer is added that validates
44/// the `Authorization: Bearer <token>` header on every request.  Requests
45/// that do not carry a matching token are rejected with HTTP 401.
46///
47/// The router uses postcard2 for serialization. Mount it at any prefix:
48///
49/// ```ignore
50/// let raft_router = raft_service_router(raft_instance, Some("secret".into()));
51/// let app = Router::new().nest("/", raft_router);
52/// ```
53pub fn raft_service_router<C>(raft: Raft<C>, auth_token: Option<String>) -> Router
54where
55    C: RaftTypeConfig<NodeId = NodeId, Node = BasicNode, SnapshotData = Cursor<Vec<u8>>>,
56    C::D: serde::Serialize + serde::de::DeserializeOwned,
57    C::R: serde::Serialize + serde::de::DeserializeOwned,
58    C::Entry: serde::Serialize + serde::de::DeserializeOwned,
59{
60    let state = Arc::new(RaftState { raft });
61
62    let router = Router::new()
63        .route("/raft/vote", post(handle_vote::<C>))
64        .route("/raft/append", post(handle_append::<C>))
65        .route("/raft/snapshot", post(handle_snapshot::<C>))
66        .route("/raft/full-snapshot", post(handle_full_snapshot::<C>))
67        .with_state(state);
68
69    if let Some(token) = auth_token {
70        let expected = Arc::new(token);
71        router.layer(middleware::from_fn(move |req, next| {
72            let expected = Arc::clone(&expected);
73            bearer_auth_middleware(expected, req, next)
74        }))
75    } else {
76        router
77    }
78}
79
80/// Middleware that validates `Authorization: Bearer <token>`.
81async fn bearer_auth_middleware(
82    expected_token: Arc<String>,
83    req: Request<axum::body::Body>,
84    next: Next,
85) -> Response {
86    let auth_header = req
87        .headers()
88        .get(axum::http::header::AUTHORIZATION)
89        .and_then(|v| v.to_str().ok());
90
91    match auth_header {
92        Some(value) if value.starts_with("Bearer ") => {
93            let provided = &value["Bearer ".len()..];
94            if provided == expected_token.as_str() {
95                next.run(req).await
96            } else {
97                warn!("Raft RPC rejected: invalid bearer token");
98                StatusCode::UNAUTHORIZED.into_response()
99            }
100        }
101        _ => {
102            warn!("Raft RPC rejected: missing or malformed Authorization header");
103            StatusCode::UNAUTHORIZED.into_response()
104        }
105    }
106}
107
108// ---------------------------------------------------------------------------
109// Handlers
110// ---------------------------------------------------------------------------
111
112async fn handle_vote<C>(State(state): State<Arc<RaftState<C>>>, body: Bytes) -> impl IntoResponse
113where
114    C: RaftTypeConfig<NodeId = NodeId, Node = BasicNode, SnapshotData = Cursor<Vec<u8>>>,
115    C::D: serde::Serialize + serde::de::DeserializeOwned,
116    C::R: serde::Serialize + serde::de::DeserializeOwned,
117    C::Entry: serde::Serialize + serde::de::DeserializeOwned,
118{
119    let req: VoteRequest<NodeId> = match postcard2::from_bytes(&body) {
120        Ok(r) => r,
121        Err(e) => {
122            error!("Failed to deserialize vote request: {e}");
123            return (StatusCode::BAD_REQUEST, Vec::new());
124        }
125    };
126
127    debug!("Received vote RPC");
128
129    match state.raft.vote(req).await {
130        Ok(resp) => match postcard2::to_vec(&resp) {
131            Ok(bytes) => (StatusCode::OK, bytes),
132            Err(e) => {
133                error!("Failed to serialize vote response: {e}");
134                (StatusCode::INTERNAL_SERVER_ERROR, Vec::new())
135            }
136        },
137        Err(e) => {
138            error!("Vote RPC failed: {e}");
139            (StatusCode::INTERNAL_SERVER_ERROR, Vec::new())
140        }
141    }
142}
143
144async fn handle_append<C>(State(state): State<Arc<RaftState<C>>>, body: Bytes) -> impl IntoResponse
145where
146    C: RaftTypeConfig<NodeId = NodeId, Node = BasicNode, SnapshotData = Cursor<Vec<u8>>>,
147    C::D: serde::Serialize + serde::de::DeserializeOwned,
148    C::R: serde::Serialize + serde::de::DeserializeOwned,
149    C::Entry: serde::Serialize + serde::de::DeserializeOwned,
150{
151    let req: AppendEntriesRequest<C> = match postcard2::from_bytes(&body) {
152        Ok(r) => r,
153        Err(e) => {
154            error!("Failed to deserialize append request: {e}");
155            return (StatusCode::BAD_REQUEST, Vec::new());
156        }
157    };
158
159    debug!("Received append_entries RPC");
160
161    match state.raft.append_entries(req).await {
162        Ok(resp) => match postcard2::to_vec(&resp) {
163            Ok(bytes) => (StatusCode::OK, bytes),
164            Err(e) => {
165                error!("Failed to serialize append response: {e}");
166                (StatusCode::INTERNAL_SERVER_ERROR, Vec::new())
167            }
168        },
169        Err(e) => {
170            error!("AppendEntries RPC failed: {e}");
171            (StatusCode::INTERNAL_SERVER_ERROR, Vec::new())
172        }
173    }
174}
175
176async fn handle_snapshot<C>(
177    State(state): State<Arc<RaftState<C>>>,
178    body: Bytes,
179) -> impl IntoResponse
180where
181    C: RaftTypeConfig<NodeId = NodeId, Node = BasicNode, SnapshotData = Cursor<Vec<u8>>>,
182    C::D: serde::Serialize + serde::de::DeserializeOwned,
183    C::R: serde::Serialize + serde::de::DeserializeOwned,
184    C::Entry: serde::Serialize + serde::de::DeserializeOwned,
185{
186    let req: InstallSnapshotRequest<C> = match postcard2::from_bytes(&body) {
187        Ok(r) => r,
188        Err(e) => {
189            error!("Failed to deserialize snapshot request: {e}");
190            return (StatusCode::BAD_REQUEST, Vec::new());
191        }
192    };
193
194    debug!("Received install_snapshot RPC");
195
196    match state.raft.install_snapshot(req).await {
197        Ok(resp) => match postcard2::to_vec(&resp) {
198            Ok(bytes) => (StatusCode::OK, bytes),
199            Err(e) => {
200                error!("Failed to serialize snapshot response: {e}");
201                (StatusCode::INTERNAL_SERVER_ERROR, Vec::new())
202            }
203        },
204        Err(e) => {
205            error!("InstallSnapshot RPC failed: {e}");
206            (StatusCode::INTERNAL_SERVER_ERROR, Vec::new())
207        }
208    }
209}
210
211#[derive(serde::Deserialize)]
212struct FullSnapshotReq {
213    vote: Vote<NodeId>,
214    meta: SnapshotMeta<NodeId, BasicNode>,
215    snapshot_data: Vec<u8>,
216}
217
218async fn handle_full_snapshot<C>(
219    State(state): State<Arc<RaftState<C>>>,
220    body: Bytes,
221) -> impl IntoResponse
222where
223    C: RaftTypeConfig<NodeId = NodeId, Node = BasicNode, SnapshotData = Cursor<Vec<u8>>>,
224    C::D: serde::Serialize + serde::de::DeserializeOwned,
225    C::R: serde::Serialize + serde::de::DeserializeOwned,
226    C::Entry: serde::Serialize + serde::de::DeserializeOwned,
227{
228    let req: FullSnapshotReq = match postcard2::from_bytes(&body) {
229        Ok(r) => r,
230        Err(e) => {
231            error!("Failed to deserialize full snapshot request: {e}");
232            return (StatusCode::BAD_REQUEST, Vec::new());
233        }
234    };
235
236    debug!("Received full_snapshot RPC");
237
238    let snapshot = Snapshot {
239        meta: req.meta,
240        snapshot: Box::new(Cursor::new(req.snapshot_data)),
241    };
242
243    match state.raft.install_full_snapshot(req.vote, snapshot).await {
244        Ok(resp) => match postcard2::to_vec(&resp) {
245            Ok(bytes) => (StatusCode::OK, bytes),
246            Err(e) => {
247                error!("Failed to serialize full snapshot response: {e}");
248                (StatusCode::INTERNAL_SERVER_ERROR, Vec::new())
249            }
250        },
251        Err(e) => {
252            error!("install_full_snapshot failed: {e}");
253            (StatusCode::INTERNAL_SERVER_ERROR, Vec::new())
254        }
255    }
256}