1use alloc::string::String;
22use alloc::vec::Vec;
23
24use zerodds_hpack::{Decoder as HpackDecoder, Encoder as HpackEncoder, HeaderField};
25use zerodds_http2::{
26 Flags, Frame, FrameHeader, FrameType, Settings, StreamId, StreamState, decode_frame,
27 encode_frame,
28};
29
30use crate::frame::{decode_message, encode_message};
31use crate::path::parse_path;
32use crate::status::Status;
33
/// A gRPC request reassembled from a single HTTP/2 stream by
/// [`GrpcServer::process_frame`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GrpcRequest {
    /// HTTP/2 stream the request arrived on; the matching [`GrpcResponse`]
    /// should be encoded on the same stream.
    pub stream_id: StreamId,
    /// Raw `:path` pseudo-header value (for gRPC normally
    /// `/package.Service/Method`).
    pub path: String,
    /// Service component of `path`, as produced by `parse_path`.
    pub service: String,
    /// Method component of `path`, as produced by `parse_path`.
    pub method: String,
    /// Value of the `grpc-encoding` header, if the client sent one.
    pub encoding: Option<String>,
    /// Concatenated DATA payloads of the stream, still in gRPC
    /// length-prefixed message framing; see
    /// [`GrpcServer::decode_request_body`].
    pub body: Vec<u8>,
}
50
/// A gRPC response to be serialized into HTTP/2 frames by
/// [`GrpcServer::encode_response`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GrpcResponse {
    /// Stream the response frames are written on; normally the
    /// `stream_id` of the [`GrpcRequest`] being answered.
    pub stream_id: StreamId,
    /// Status whose numeric code is sent in the `grpc-status` trailer.
    pub status: Status,
    /// Optional human-readable detail, sent as the `grpc-message` trailer.
    pub message: Option<String>,
    /// Serialized response message *without* gRPC framing; the encoder
    /// adds the length prefix. An empty body means no DATA frame is sent.
    pub body: Vec<u8>,
}
63
/// Reassembly buffer for one stream whose request has not yet seen
/// `END_STREAM`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct StreamSlot {
    /// Stream lifecycle state.
    /// NOTE(review): the server creates slots as `Idle` and never updates
    /// this field; only tests set other values.
    state: StreamState,
    /// Header fields decoded so far from this stream's HEADERS frames.
    headers: Vec<HeaderField>,
    /// DATA payloads received so far, concatenated in arrival order.
    body: Vec<u8>,
}
71
/// Minimal gRPC-over-HTTP/2 server state machine for one connection.
///
/// It performs no I/O itself. Callers feed raw bytes to
/// [`GrpcServer::process_frame`] and write out the bytes returned by
/// [`GrpcServer::encode_response`]. The HPACK codecs are connection-level
/// state and are therefore shared by all streams.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GrpcServer {
    /// HTTP/2 settings; `max_frame_size` bounds both decoded and encoded
    /// frames.
    settings: Settings,
    /// HPACK decoder for incoming request header blocks.
    decoder: HpackDecoder,
    /// HPACK encoder for outgoing response headers and trailers.
    encoder: HpackEncoder,
    /// Requests still being received, keyed by stream id.
    streams: alloc::collections::BTreeMap<StreamId, StreamSlot>,
}
80
81impl GrpcServer {
82 #[must_use]
84 pub fn new() -> Self {
85 Self::default()
86 }
87
88 pub fn process_frame(
96 &mut self,
97 input: &[u8],
98 ) -> Result<(Option<GrpcRequest>, usize), &'static str> {
99 let (frame, consumed) =
100 decode_frame(input, self.settings.max_frame_size).map_err(|_| "decode frame failed")?;
101 let request = match frame.header.frame_type {
102 FrameType::Headers => self.handle_headers(&frame)?,
103 FrameType::Data => self.handle_data(&frame)?,
104 FrameType::Settings | FrameType::Ping | FrameType::WindowUpdate => None,
105 FrameType::RstStream => {
106 self.streams.remove(&frame.header.stream_id);
107 None
108 }
109 _ => None,
110 };
111 Ok((request, consumed))
112 }
113
114 fn handle_headers(&mut self, frame: &Frame<'_>) -> Result<Option<GrpcRequest>, &'static str> {
115 let headers = self
116 .decoder
117 .decode(frame.payload)
118 .map_err(|_| "hpack decode")?;
119 let slot = self
120 .streams
121 .entry(frame.header.stream_id)
122 .or_insert(StreamSlot {
123 state: StreamState::Idle,
124 headers: Vec::new(),
125 body: Vec::new(),
126 });
127 slot.headers.extend(headers);
128 if frame.header.flags.has(Flags::END_STREAM) {
129 return Ok(Some(self.finalize_request(frame.header.stream_id)?));
130 }
131 Ok(None)
132 }
133
134 fn handle_data(&mut self, frame: &Frame<'_>) -> Result<Option<GrpcRequest>, &'static str> {
135 let slot = self
136 .streams
137 .get_mut(&frame.header.stream_id)
138 .ok_or("data on unknown stream")?;
139 slot.body.extend_from_slice(frame.payload);
140 if frame.header.flags.has(Flags::END_STREAM) {
141 return Ok(Some(self.finalize_request(frame.header.stream_id)?));
142 }
143 Ok(None)
144 }
145
146 fn finalize_request(&mut self, stream_id: StreamId) -> Result<GrpcRequest, &'static str> {
147 let slot = self.streams.remove(&stream_id).ok_or("unknown stream")?;
148 let path = slot
149 .headers
150 .iter()
151 .find(|h| h.name == ":path")
152 .map(|h| h.value.clone())
153 .ok_or(":path missing")?;
154 let (service, method) = parse_path(&path).map_err(|_| "bad path")?;
155 let encoding = slot
156 .headers
157 .iter()
158 .find(|h| h.name == "grpc-encoding")
159 .map(|h| h.value.clone());
160 Ok(GrpcRequest {
161 stream_id,
162 path: path.clone(),
163 service,
164 method,
165 encoding,
166 body: slot.body,
167 })
168 }
169
170 pub fn encode_response(&mut self, resp: &GrpcResponse) -> Result<Vec<u8>, &'static str> {
176 let mut out = Vec::new();
177 let headers = alloc::vec![
179 HeaderField {
180 name: ":status".into(),
181 value: "200".into(),
182 },
183 HeaderField {
184 name: "content-type".into(),
185 value: "application/grpc".into(),
186 },
187 ];
188 let h_payload = self.encoder.encode(&headers);
189 let h = FrameHeader {
190 length: h_payload.len() as u32,
191 frame_type: FrameType::Headers,
192 flags: Flags(Flags::END_HEADERS),
193 stream_id: resp.stream_id,
194 };
195 let mut buf = alloc::vec![0u8; 9 + h_payload.len()];
196 encode_frame(&h, &h_payload, &mut buf, self.settings.max_frame_size)
197 .map_err(|_| "headers encode")?;
198 out.extend_from_slice(&buf);
199
200 if !resp.body.is_empty() {
202 let lpm = encode_message(&resp.body, false).map_err(|_| "lpm encode")?;
203 let d = FrameHeader {
204 length: lpm.len() as u32,
205 frame_type: FrameType::Data,
206 flags: Flags(0),
207 stream_id: resp.stream_id,
208 };
209 let mut dbuf = alloc::vec![0u8; 9 + lpm.len()];
210 encode_frame(&d, &lpm, &mut dbuf, self.settings.max_frame_size)
211 .map_err(|_| "data encode")?;
212 out.extend_from_slice(&dbuf);
213 }
214
215 let mut trailers = alloc::vec![HeaderField {
217 name: "grpc-status".into(),
218 value: alloc::format!("{}", resp.status.code()),
219 }];
220 if let Some(msg) = &resp.message {
221 trailers.push(HeaderField {
222 name: "grpc-message".into(),
223 value: msg.clone(),
224 });
225 }
226 let t_payload = self.encoder.encode(&trailers);
227 let t = FrameHeader {
228 length: t_payload.len() as u32,
229 frame_type: FrameType::Headers,
230 flags: Flags(Flags::END_HEADERS | Flags::END_STREAM),
231 stream_id: resp.stream_id,
232 };
233 let mut tbuf = alloc::vec![0u8; 9 + t_payload.len()];
234 encode_frame(&t, &t_payload, &mut tbuf, self.settings.max_frame_size)
235 .map_err(|_| "trailer encode")?;
236 out.extend_from_slice(&tbuf);
237
238 Ok(out)
239 }
240
241 pub fn decode_request_body(&self, req: &GrpcRequest) -> Result<Vec<u8>, &'static str> {
247 let (_, msg, _) = decode_message(&req.body).map_err(|_| "lpm decode")?;
248 Ok(msg)
249 }
250}
251
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Decodes every frame in `bytes` and returns the header fields of all
    /// HEADERS frames.
    ///
    /// A single HPACK decoder is used so that dynamic-table state carries
    /// over from the response headers to the trailers, as it would for a
    /// real client.
    fn header_fields(bytes: &[u8]) -> Vec<HeaderField> {
        let max_frame_size = Settings::default().max_frame_size;
        let mut decoder = HpackDecoder::default();
        let mut fields = Vec::new();
        let mut rest = bytes;
        while !rest.is_empty() {
            let Ok((frame, consumed)) = decode_frame(rest, max_frame_size) else {
                panic!("response contains an undecodable frame");
            };
            assert!(consumed > 0, "decoder made no progress");
            if matches!(frame.header.frame_type, FrameType::Headers) {
                let Ok(block) = decoder.decode(frame.payload) else {
                    panic!("response contains an undecodable HPACK block");
                };
                fields.extend(block);
            }
            rest = rest.get(consumed..).unwrap_or_default();
        }
        fields
    }

    /// True if `fields` contains a header with exactly this name and value.
    fn has_header(fields: &[HeaderField], name: &str, value: &str) -> bool {
        fields.iter().any(|h| h.name == name && h.value == value)
    }

    #[test]
    fn server_starts_with_no_streams() {
        let s = GrpcServer::new();
        assert!(s.streams.is_empty());
    }

    #[test]
    fn encode_response_includes_status_and_trailers() {
        let mut s = GrpcServer::new();
        let resp = GrpcResponse {
            stream_id: 1,
            status: Status::Ok,
            message: None,
            body: alloc::vec![1, 2, 3],
        };
        let bytes = s.encode_response(&resp).unwrap();
        // HEADERS + DATA + trailing HEADERS, each with a 9-byte frame header.
        assert!(bytes.len() > 9 * 3, "should have at least 3 frames");
        let fields = header_fields(&bytes);
        assert!(has_header(&fields, ":status", "200"));
        assert!(has_header(&fields, "content-type", "application/grpc"));
        assert!(has_header(&fields, "grpc-status", "0"));
    }

    #[test]
    fn encode_response_with_status_message_includes_it() {
        let mut s = GrpcServer::new();
        let resp = GrpcResponse {
            stream_id: 1,
            status: Status::Internal,
            message: Some("boom".into()),
            body: Vec::new(),
        };
        let bytes = s.encode_response(&resp).unwrap();
        let fields = header_fields(&bytes);
        assert!(has_header(&fields, "grpc-status", "13"));
        assert!(has_header(&fields, "grpc-message", "boom"));
    }

    #[test]
    fn headers_then_data_yields_request() {
        let mut s = GrpcServer::new();
        // HEADERS frame: length 15, type 0x1, flags END_HEADERS (0x4),
        // stream 1.
        // The HPACK block is a literal without indexing (0x04 = indexed
        // name 4, i.e. `:path`) followed by a raw 13-byte string value.
        let mut input: Vec<u8> = alloc::vec![
            0x00, 0x00, 0x0F, 0x01, 0x04, 0x00, 0x00, 0x00, 0x01, 0x04, 0x0D,
        ];
        input.extend_from_slice(b"/pkg.Svc/Call");
        // DATA frame: length 8, type 0x0, flags END_STREAM (0x1), stream 1.
        input.extend_from_slice(&[0x00, 0x00, 0x08, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01]);
        // One uncompressed length-prefixed message carrying [7, 8, 9].
        input.extend_from_slice(&[0x00, 0x00, 0x00, 0x00, 0x03, 0x07, 0x08, 0x09]);

        let (first, used) = s.process_frame(&input).unwrap();
        assert!(first.is_none(), "headers alone must not complete the request");
        assert_eq!(used, 9 + 15);

        let (second, used) = s.process_frame(input.get(used..).unwrap()).unwrap();
        assert_eq!(used, 9 + 8);
        let req = second.expect("END_STREAM completes the request");
        assert_eq!(req.stream_id, 1);
        assert_eq!(req.path, "/pkg.Svc/Call");
        assert_eq!(req.encoding, None);
        assert_eq!(req.body, [0x00, 0x00, 0x00, 0x00, 0x03, 0x07, 0x08, 0x09]);
        assert!(s.streams.is_empty(), "finalized stream must be released");
        assert_eq!(s.decode_request_body(&req).unwrap(), [0x07, 0x08, 0x09]);
    }

    #[test]
    fn rst_stream_clears_state() {
        let mut s = GrpcServer::new();
        s.streams.insert(
            1,
            StreamSlot {
                state: StreamState::Open,
                headers: alloc::vec![],
                body: alloc::vec![],
            },
        );
        let buf = alloc::vec![
            // Frame header: length 4, type 0x3 (RST_STREAM), flags 0, stream 1.
            0x00, 0x00, 0x04, 0x03, 0x00, 0x00, 0x00, 0x00, 0x01,
            // Error code 0x8 (CANCEL).
            0x00, 0x00, 0x00, 0x08,
        ];
        s.process_frame(&buf).unwrap();
        assert!(s.streams.is_empty());
    }
}