1use std::{
2 fmt,
3 marker::PhantomData,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use bytes::{Buf, BufMut, BytesMut};
9use futures::{Stream, future};
10use futures_util::ready;
11use http::StatusCode;
12use http_body::Body;
13use pilota::pb::Message;
14use tracing::{debug, trace};
15
16use super::{BUFFER_SIZE, DefaultDecoder, PREFIX_LEN};
17use crate::{
18 Status,
19 body::BoxBody,
20 codec::{
21 Decoder,
22 compression::{CompressionEncoding, decompress},
23 },
24 metadata::MetadataMap,
25 status::Code,
26};
27
28pub struct RecvStream<T> {
32 body: BoxBody,
33 decoder: DefaultDecoder<T>,
34 trailers: Option<MetadataMap>,
35 buf: BytesMut,
36 state: State,
37 kind: Kind,
38 compression_encoding: Option<CompressionEncoding>,
39 decompress_buf: BytesMut,
40}
41
42impl<T> Unpin for RecvStream<T> {}
43
44#[derive(Debug, Clone)]
45enum State {
46 Header,
47 Body(Option<CompressionEncoding>, usize),
48 Error,
49}
50
51#[derive(Debug, PartialEq, Eq)]
52pub enum Kind {
53 Request,
54 Response(StatusCode),
55}
56
57impl<T> RecvStream<T> {
58 pub fn new(
59 body: BoxBody,
60 kind: Kind,
61 compression_encoding: Option<CompressionEncoding>,
62 ) -> Self {
63 RecvStream {
64 body,
65 decoder: DefaultDecoder(PhantomData),
66 trailers: None,
67 buf: BytesMut::with_capacity(BUFFER_SIZE),
68 state: State::Header,
69 kind,
70 compression_encoding,
71 decompress_buf: BytesMut::new(),
72 }
73 }
74}
75
76impl<T: Message + Default> RecvStream<T> {
77 async fn message(&mut self) -> Result<Option<T>, Status> {
79 match future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await {
80 Some(Ok(m)) => Ok(Some(m)),
81 Some(Err(e)) => Err(e),
82 None => Ok(None),
83 }
84 }
85
86 pub async fn trailers(&mut self) -> Result<Option<MetadataMap>, Status> {
88 if let Some(trailers) = self.trailers.take() {
89 return Ok(Some(trailers));
90 }
91
92 while self.message().await?.is_some() {}
95
96 if let Some(trailers) = self.trailers.take() {
97 return Ok(Some(trailers));
98 }
99
100 let maybe_trailer = future::poll_fn(|cx| Pin::new(&mut self.body).poll_frame(cx)).await;
101
102 match maybe_trailer {
103 Some(Ok(frame)) => match frame.into_trailers() {
104 Ok(headers) => Ok(Some(MetadataMap::from_headers(headers))),
105 Err(_frame) => {
106 debug!("[VOLO] unexpected data from stream");
108 Err(Status::new(
109 Code::Internal,
110 "Unexpected data from stream.".to_string(),
111 ))
112 }
113 },
114 Some(Err(err)) => Err(Status::from_error(Box::new(err))),
115 None => Ok(None),
116 }
117 }
118
119 #[allow(clippy::result_large_err)]
120 fn decode_chunk(&mut self) -> Result<Option<T>, Status> {
121 if let State::Header = self.state {
122 if self.buf.remaining() < PREFIX_LEN {
124 return Ok(None);
125 }
126 trace!("[VOLO-GRPC] streaming received buf: {:?}", self.buf);
127
128 let compression_encoding = match self.buf.get_u8() {
129 0 => None,
130 1 => {
131 if self.compression_encoding.is_some() {
132 self.compression_encoding
133 } else {
134 return Err(Status::new(
135 Code::Internal,
136 "protocol error: received message with compressed-flag but no \
137 grpc-encoding was specified"
138 .to_string(),
139 ));
140 }
141 }
142 flag => {
143 let message = format!(
144 "protocol error: received message with invalid compression flag: {flag} \
145 (valid flags are 0 and 1), while sending request"
146 );
147 return Err(Status::new(Code::Internal, message));
149 }
150 };
151 let len = self.buf.get_u32() as usize;
152 self.buf.reserve(len);
153
154 self.state = State::Body(compression_encoding, len);
155 }
156
157 if let State::Body(compression_encoding, len) = &self.state {
158 if self.buf.remaining() < *len || self.buf.len() < *len {
160 return Ok(None);
161 }
162 trace!("[VOLO-GRPC] streaming reading body: {:?}", self.buf);
163 let mut buf = self.buf.split_to(*len);
164 let decode_result = if let Some(encoding) = compression_encoding {
165 self.decompress_buf.clear();
166 if let Err(err) = decompress(*encoding, &mut buf, &mut self.decompress_buf) {
167 let message = if let Kind::Response(status) = self.kind {
168 format!(
169 "Error decompressing: {err}, while receiving response with status: \
170 {status}"
171 )
172 } else {
173 format!("Error decompressing: {err}, while sending request")
174 };
175 return Err(Status::new(Code::Internal, message));
176 }
177 DefaultDecoder::<T>::decode(&mut self.decoder, self.decompress_buf.split().freeze())
178 } else {
179 DefaultDecoder::<T>::decode(&mut self.decoder, buf.freeze())
180 };
181
182 return match decode_result {
183 Ok(Some(msg)) => {
184 self.state = State::Header;
185 Ok(Some(msg))
186 }
187 Ok(None) => Ok(None),
188 Err(e) => Err(e),
189 };
190 }
191
192 Ok(None)
193 }
194}
195
196impl<T: Message + Default> Stream for RecvStream<T> {
197 type Item = Result<T, Status>;
198
199 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
200 let trailer_frame = loop {
201 if let State::Error = &self.state {
202 return Poll::Ready(None);
203 }
204 if let Some(item) = self.decode_chunk()? {
205 return Poll::Ready(Some(Ok(item)));
206 }
207
208 match ready!(Pin::new(&mut self.body).poll_frame(cx)) {
209 Some(Ok(frame)) => match frame.into_data() {
210 Ok(data) => self.buf.put(data),
211 Err(trailer) => {
212 break Some(trailer);
213 }
214 },
215 Some(Err(e)) => {
216 let err: crate::BoxError = e.into();
217 let status = Status::from_error(err);
218 if self.kind == Kind::Request && status.code() == Code::Cancelled {
219 return Poll::Ready(None);
220 }
221 debug!("[VOLO] decoder inner stream error: {:?}", status);
222 let _ = std::mem::replace(&mut self.state, State::Error);
223 return Poll::Ready(Some(Err(status)));
224 }
225 None => {
226 if self.buf.has_remaining() {
227 debug!("[VOLO] unexpected EOF decoding stream");
228 return Poll::Ready(Some(Err(Status::new(
229 Code::Internal,
230 "Unexpected EOF decoding stream.".to_string(),
231 ))));
232 } else {
233 break None;
234 }
235 }
236 }
237 };
238
239 if let Kind::Response(status) = self.kind {
240 let trailer = match trailer_frame.map(|frame| frame.into_trailers()) {
241 Some(Ok(trailer)) => Some(trailer),
242 Some(Err(_frame)) => {
243 debug!("[VOLO] unexpected data from stream");
245 return Poll::Ready(Some(Err(Status::new(
246 Code::Internal,
247 "Unexpected data from stream.".to_string(),
248 ))));
249 }
250 None => None,
251 };
252
253 if let Err(e) = Status::infer_grpc_status(trailer.as_ref(), status) {
254 return if let Some(e) = e {
255 Some(Err(e)).into()
256 } else {
257 Poll::Ready(None)
258 };
259 } else {
260 self.trailers = trailer.map(MetadataMap::from_headers);
261 }
262 }
263
264 Poll::Ready(None)
265 }
266}
267
268impl<T> fmt::Debug for RecvStream<T> {
269 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
270 f.debug_struct("RecvStream").finish()
271 }
272}