zlayer_consensus/network/
http_service.rs1use std::io::Cursor;
20use std::sync::Arc;
21
22use axum::body::Bytes;
23use axum::extract::State;
24use axum::http::{Request, StatusCode};
25use axum::middleware::{self, Next};
26use axum::response::{IntoResponse, Response};
27use axum::routing::post;
28use axum::Router;
29use openraft::raft::{AppendEntriesRequest, InstallSnapshotRequest, VoteRequest};
30use openraft::storage::Snapshot;
31use openraft::{BasicNode, Raft, RaftTypeConfig, SnapshotMeta, Vote};
32use tracing::{debug, error, warn};
33
34use crate::types::NodeId;
35
36struct RaftState<C: RaftTypeConfig> {
38 raft: Raft<C>,
39}
40
41pub fn raft_service_router<C>(raft: Raft<C>, auth_token: Option<String>) -> Router
54where
55 C: RaftTypeConfig<NodeId = NodeId, Node = BasicNode, SnapshotData = Cursor<Vec<u8>>>,
56 C::D: serde::Serialize + serde::de::DeserializeOwned,
57 C::R: serde::Serialize + serde::de::DeserializeOwned,
58 C::Entry: serde::Serialize + serde::de::DeserializeOwned,
59{
60 let state = Arc::new(RaftState { raft });
61
62 let router = Router::new()
63 .route("/raft/vote", post(handle_vote::<C>))
64 .route("/raft/append", post(handle_append::<C>))
65 .route("/raft/snapshot", post(handle_snapshot::<C>))
66 .route("/raft/full-snapshot", post(handle_full_snapshot::<C>))
67 .with_state(state);
68
69 if let Some(token) = auth_token {
70 let expected = Arc::new(token);
71 router.layer(middleware::from_fn(move |req, next| {
72 let expected = Arc::clone(&expected);
73 bearer_auth_middleware(expected, req, next)
74 }))
75 } else {
76 router
77 }
78}
79
80async fn bearer_auth_middleware(
82 expected_token: Arc<String>,
83 req: Request<axum::body::Body>,
84 next: Next,
85) -> Response {
86 let auth_header = req
87 .headers()
88 .get(axum::http::header::AUTHORIZATION)
89 .and_then(|v| v.to_str().ok());
90
91 match auth_header {
92 Some(value) if value.starts_with("Bearer ") => {
93 let provided = &value["Bearer ".len()..];
94 if provided == expected_token.as_str() {
95 next.run(req).await
96 } else {
97 warn!("Raft RPC rejected: invalid bearer token");
98 StatusCode::UNAUTHORIZED.into_response()
99 }
100 }
101 _ => {
102 warn!("Raft RPC rejected: missing or malformed Authorization header");
103 StatusCode::UNAUTHORIZED.into_response()
104 }
105 }
106}
107
108async fn handle_vote<C>(State(state): State<Arc<RaftState<C>>>, body: Bytes) -> impl IntoResponse
113where
114 C: RaftTypeConfig<NodeId = NodeId, Node = BasicNode, SnapshotData = Cursor<Vec<u8>>>,
115 C::D: serde::Serialize + serde::de::DeserializeOwned,
116 C::R: serde::Serialize + serde::de::DeserializeOwned,
117 C::Entry: serde::Serialize + serde::de::DeserializeOwned,
118{
119 let req: VoteRequest<NodeId> = match postcard2::from_bytes(&body) {
120 Ok(r) => r,
121 Err(e) => {
122 error!("Failed to deserialize vote request: {e}");
123 return (StatusCode::BAD_REQUEST, Vec::new());
124 }
125 };
126
127 debug!("Received vote RPC");
128
129 match state.raft.vote(req).await {
130 Ok(resp) => match postcard2::to_vec(&resp) {
131 Ok(bytes) => (StatusCode::OK, bytes),
132 Err(e) => {
133 error!("Failed to serialize vote response: {e}");
134 (StatusCode::INTERNAL_SERVER_ERROR, Vec::new())
135 }
136 },
137 Err(e) => {
138 error!("Vote RPC failed: {e}");
139 (StatusCode::INTERNAL_SERVER_ERROR, Vec::new())
140 }
141 }
142}
143
144async fn handle_append<C>(State(state): State<Arc<RaftState<C>>>, body: Bytes) -> impl IntoResponse
145where
146 C: RaftTypeConfig<NodeId = NodeId, Node = BasicNode, SnapshotData = Cursor<Vec<u8>>>,
147 C::D: serde::Serialize + serde::de::DeserializeOwned,
148 C::R: serde::Serialize + serde::de::DeserializeOwned,
149 C::Entry: serde::Serialize + serde::de::DeserializeOwned,
150{
151 let req: AppendEntriesRequest<C> = match postcard2::from_bytes(&body) {
152 Ok(r) => r,
153 Err(e) => {
154 error!("Failed to deserialize append request: {e}");
155 return (StatusCode::BAD_REQUEST, Vec::new());
156 }
157 };
158
159 debug!("Received append_entries RPC");
160
161 match state.raft.append_entries(req).await {
162 Ok(resp) => match postcard2::to_vec(&resp) {
163 Ok(bytes) => (StatusCode::OK, bytes),
164 Err(e) => {
165 error!("Failed to serialize append response: {e}");
166 (StatusCode::INTERNAL_SERVER_ERROR, Vec::new())
167 }
168 },
169 Err(e) => {
170 error!("AppendEntries RPC failed: {e}");
171 (StatusCode::INTERNAL_SERVER_ERROR, Vec::new())
172 }
173 }
174}
175
176async fn handle_snapshot<C>(
177 State(state): State<Arc<RaftState<C>>>,
178 body: Bytes,
179) -> impl IntoResponse
180where
181 C: RaftTypeConfig<NodeId = NodeId, Node = BasicNode, SnapshotData = Cursor<Vec<u8>>>,
182 C::D: serde::Serialize + serde::de::DeserializeOwned,
183 C::R: serde::Serialize + serde::de::DeserializeOwned,
184 C::Entry: serde::Serialize + serde::de::DeserializeOwned,
185{
186 let req: InstallSnapshotRequest<C> = match postcard2::from_bytes(&body) {
187 Ok(r) => r,
188 Err(e) => {
189 error!("Failed to deserialize snapshot request: {e}");
190 return (StatusCode::BAD_REQUEST, Vec::new());
191 }
192 };
193
194 debug!("Received install_snapshot RPC");
195
196 match state.raft.install_snapshot(req).await {
197 Ok(resp) => match postcard2::to_vec(&resp) {
198 Ok(bytes) => (StatusCode::OK, bytes),
199 Err(e) => {
200 error!("Failed to serialize snapshot response: {e}");
201 (StatusCode::INTERNAL_SERVER_ERROR, Vec::new())
202 }
203 },
204 Err(e) => {
205 error!("InstallSnapshot RPC failed: {e}");
206 (StatusCode::INTERNAL_SERVER_ERROR, Vec::new())
207 }
208 }
209}
210
/// Body of a `POST /raft/full-snapshot` request, decoded with
/// `postcard2::from_bytes` by `handle_full_snapshot`.
///
/// NOTE(review): postcard-style formats are not self-describing and encode
/// struct fields positionally. The field order below is therefore
/// presumably part of the wire format and must match the sender's encoding.
/// Do not reorder the fields without updating the client.
#[derive(serde::Deserialize)]
struct FullSnapshotReq {
    /// Vote passed through to `Raft::install_full_snapshot`.
    vote: Vote<NodeId>,
    /// Metadata of the snapshot being installed.
    meta: SnapshotMeta<NodeId, BasicNode>,
    /// Raw snapshot bytes; wrapped in a `Cursor` before installation.
    snapshot_data: Vec<u8>,
}
217
218async fn handle_full_snapshot<C>(
219 State(state): State<Arc<RaftState<C>>>,
220 body: Bytes,
221) -> impl IntoResponse
222where
223 C: RaftTypeConfig<NodeId = NodeId, Node = BasicNode, SnapshotData = Cursor<Vec<u8>>>,
224 C::D: serde::Serialize + serde::de::DeserializeOwned,
225 C::R: serde::Serialize + serde::de::DeserializeOwned,
226 C::Entry: serde::Serialize + serde::de::DeserializeOwned,
227{
228 let req: FullSnapshotReq = match postcard2::from_bytes(&body) {
229 Ok(r) => r,
230 Err(e) => {
231 error!("Failed to deserialize full snapshot request: {e}");
232 return (StatusCode::BAD_REQUEST, Vec::new());
233 }
234 };
235
236 debug!("Received full_snapshot RPC");
237
238 let snapshot = Snapshot {
239 meta: req.meta,
240 snapshot: Box::new(Cursor::new(req.snapshot_data)),
241 };
242
243 match state.raft.install_full_snapshot(req.vote, snapshot).await {
244 Ok(resp) => match postcard2::to_vec(&resp) {
245 Ok(bytes) => (StatusCode::OK, bytes),
246 Err(e) => {
247 error!("Failed to serialize full snapshot response: {e}");
248 (StatusCode::INTERNAL_SERVER_ERROR, Vec::new())
249 }
250 },
251 Err(e) => {
252 error!("install_full_snapshot failed: {e}");
253 (StatusCode::INTERNAL_SERVER_ERROR, Vec::new())
254 }
255 }
256}