1use bytes::{Buf, Bytes, BytesMut};
64use futures_util::stream;
65use tonic::Status;
66
67use super::streaming::MessageStream;
68
69pub async fn parse_grpc_client_stream(
106 body: axum::body::Body,
107 max_message_size: usize,
108) -> Result<MessageStream, Status> {
109 let body_bytes = axum::body::to_bytes(body, usize::MAX)
111 .await
112 .map_err(|e| Status::internal(format!("Failed to read body: {}", e)))?;
113
114 let buffer = BytesMut::from(&body_bytes[..]);
116
117 let messages = parse_all_frames(buffer, max_message_size)?;
119
120 Ok(Box::pin(stream::iter(messages.into_iter().map(Ok))))
122}
123
124fn parse_all_frames(mut buffer: BytesMut, max_message_size: usize) -> Result<Vec<Bytes>, Status> {
126 let mut messages = Vec::new();
127
128 while !buffer.is_empty() {
129 if buffer.len() < 5 {
131 return Err(Status::internal(
132 "Incomplete gRPC frame header: expected 5 bytes, got less",
133 ));
134 }
135
136 let compression_flag = buffer[0];
138 if compression_flag != 0 {
139 return Err(Status::unimplemented("Message compression not supported"));
140 }
141
142 let length_bytes = &buffer[1..5];
144 let message_length =
145 u32::from_be_bytes([length_bytes[0], length_bytes[1], length_bytes[2], length_bytes[3]]) as usize;
146
147 if message_length > max_message_size {
149 return Err(Status::resource_exhausted(format!(
150 "Message size {} exceeds maximum allowed size of {}",
151 message_length, max_message_size
152 )));
153 }
154
155 let total_frame_size = 5 + message_length;
157 if buffer.len() < total_frame_size {
158 return Err(Status::internal(
159 "Incomplete gRPC message: expected more bytes than available",
160 ));
161 }
162
163 let message = buffer[5..total_frame_size].to_vec();
165 messages.push(Bytes::from(message));
166
167 buffer.advance(total_frame_size);
169 }
170
171 Ok(messages)
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;

    /// Encodes `payload` as one uncompressed gRPC frame:
    /// flag byte `0x00`, 4-byte big-endian length, then the payload.
    fn encode_frame(payload: &[u8]) -> Vec<u8> {
        assert!(
            payload.len() <= u32::MAX as usize,
            "test payload does not fit in a gRPC frame"
        );
        let mut frame = Vec::with_capacity(5 + payload.len());
        frame.push(0x00);
        frame.extend_from_slice(&(payload.len() as u32).to_be_bytes());
        frame.extend_from_slice(payload);
        frame
    }

    /// Parses `wire` and drains the resulting stream, panicking on any error.
    /// Draining to the end also proves that the stream terminates.
    async fn parse_ok(wire: Vec<u8>, max_message_size: usize) -> Vec<Bytes> {
        let stream = parse_grpc_client_stream(axum::body::Body::from(wire), max_message_size)
            .await
            .expect("body should parse successfully");
        let items: Vec<_> = StreamExt::collect(stream).await;
        items
            .into_iter()
            .map(|item| item.expect("stream item should be Ok"))
            .collect()
    }

    /// Parses `wire` and returns the failure status, panicking if parsing succeeds.
    async fn parse_err(wire: Vec<u8>, max_message_size: usize) -> Status {
        match parse_grpc_client_stream(axum::body::Body::from(wire), max_message_size).await {
            Ok(_) => panic!("expected parsing to fail"),
            Err(status) => status,
        }
    }

    /// Asserts that `actual` holds exactly the payloads in `expected`, in order.
    fn assert_payloads(actual: &[Bytes], expected: &[&[u8]]) {
        assert_eq!(actual.len(), expected.len(), "unexpected number of messages");
        for (index, (got, want)) in actual.iter().zip(expected).enumerate() {
            assert_eq!(got[..], want[..], "payload mismatch at message {}", index);
        }
    }

    #[tokio::test]
    async fn test_single_frame_parsing() {
        // Hand-written bytes pin the exact wire layout independently of `encode_frame`.
        let wire: Vec<u8> = vec![0x00, 0x00, 0x00, 0x00, 0x05, b'h', b'e', b'l', b'l', b'o'];

        let messages = parse_ok(wire, 1024).await;

        assert_payloads(&messages, &[b"hello"]);
    }

    #[tokio::test]
    async fn test_multiple_frames() {
        let wire = [encode_frame(b"hello"), encode_frame(b"world")].concat();

        let messages = parse_ok(wire, 1024).await;

        assert_payloads(&messages, &[b"hello", b"world"]);
    }

    #[tokio::test]
    async fn test_empty_body() {
        let messages = parse_ok(Vec::new(), 1024).await;

        assert!(messages.is_empty());
    }

    #[tokio::test]
    async fn test_frame_size_at_limit() {
        // A payload of exactly `max_size` bytes must still be accepted.
        let max_size = 10;
        let payload = b"0123456789";

        let messages = parse_ok(encode_frame(payload), max_size).await;

        assert_payloads(&messages, &[payload]);
    }

    #[tokio::test]
    async fn test_frame_exceeds_limit() {
        let max_size = 5;

        let status = parse_err(encode_frame(b"toolong"), max_size).await;

        assert_eq!(status.code(), tonic::Code::ResourceExhausted);
    }

    #[tokio::test]
    async fn test_incomplete_frame_header() {
        // Only 3 of the 5 header bytes are present.
        let wire: Vec<u8> = vec![0x00, 0x00, 0x00];

        let status = parse_err(wire, 1024).await;

        assert_eq!(status.code(), tonic::Code::Internal);
    }

    #[tokio::test]
    async fn test_incomplete_frame_body() {
        // The header declares 10 payload bytes but only 5 follow.
        let mut wire: Vec<u8> = vec![0x00, 0x00, 0x00, 0x00, 0x0a];
        wire.extend_from_slice(b"short");

        let status = parse_err(wire, 1024).await;

        assert_eq!(status.code(), tonic::Code::Internal);
    }

    #[tokio::test]
    async fn test_compression_flag_set() {
        let mut wire = encode_frame(b"hello");
        wire[0] = 0x01;

        let status = parse_err(wire, 1024).await;

        assert_eq!(status.code(), tonic::Code::Unimplemented);
    }

    #[tokio::test]
    async fn test_large_message_length() {
        let payload = b"x".repeat(1000);

        let messages = parse_ok(encode_frame(&payload), 2000).await;

        assert_payloads(&messages, &[payload.as_slice()]);
    }

    #[tokio::test]
    async fn test_zero_length_message() {
        // A header with length 0 is a valid frame carrying an empty message.
        let messages = parse_ok(encode_frame(b""), 1024).await;

        assert_payloads(&messages, &[b""]);
    }

    #[tokio::test]
    async fn test_multiple_frames_with_mixed_sizes() {
        let wire = [
            encode_frame(b"abc"),
            encode_frame(b"defghij"),
            encode_frame(b""),
            encode_frame(b"x"),
        ]
        .concat();

        let messages = parse_ok(wire, 1024).await;

        assert_payloads(&messages, &[b"abc", b"defghij", b"", b"x"]);
    }

    #[test]
    fn test_big_endian_length_parsing() {
        // Length bytes 00 00 01 00 are 256 when read big-endian. A little-endian
        // read would give 65536 and trip the 1024-byte limit instead.
        let mut wire: Vec<u8> = vec![0x00, 0x00, 0x00, 0x01, 0x00];
        wire.extend_from_slice(&[0xab_u8; 256]);

        let messages =
            parse_all_frames(BytesMut::from(&wire[..]), 1024).expect("frame should parse");

        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].len(), 256);
    }

    #[test]
    fn test_big_endian_max_value() {
        // The largest declarable length must be decoded as `u32::MAX` and rejected
        // by the size limit before any payload bytes are looked at.
        let wire: Vec<u8> = vec![0x00, 0xff, 0xff, 0xff, 0xff];

        let status = parse_all_frames(BytesMut::from(&wire[..]), 1024)
            .expect_err("a u32::MAX-sized frame must be rejected");

        assert_eq!(status.code(), tonic::Code::ResourceExhausted);
        assert!(status.message().contains(&u32::MAX.to_string()));
    }

    #[tokio::test]
    async fn test_error_message_includes_size_info() {
        let max_size = 100;
        let payload = b"x".repeat(150);

        let status = parse_err(encode_frame(&payload), max_size).await;

        assert_eq!(status.code(), tonic::Code::ResourceExhausted);
        assert!(status.message().contains("150"));
        assert!(status.message().contains("100"));
    }

    #[tokio::test]
    async fn test_stream_collects_all_messages() {
        // Ten one-byte messages carrying the ASCII digits '0'..='9'.
        let digits: Vec<u8> = (b'0'..=b'9').collect();
        let wire: Vec<u8> = digits
            .iter()
            .flat_map(|digit| encode_frame(&[*digit]))
            .collect();

        let messages = parse_ok(wire, 1024).await;

        let expected: Vec<&[u8]> = digits.iter().map(std::slice::from_ref).collect();
        assert_payloads(&messages, &expected);
    }
}