1use std::io;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11
12use anyhow::{anyhow, bail, Error, Result};
13use bytes::{Buf, Bytes, BytesMut};
14use futures::{Stream, StreamExt};
15use tokio::sync::mpsc;
16use tokio::task::JoinHandle;
17use tokio_util::codec::{Decoder, FramedRead};
18use tokio_util::io::StreamReader;
19
20use crate::msgpack_codec::{subscription_frame_from_value, ErrorResponseBody, SubscriptionFrame};
21use crate::ops::DataType;
22use crate::protocol::DEFAULT_MAX_MESSAGE_SIZE;
23use crate::transaction::TxKey;
24
25type ByteStream = Pin<Box<dyn Stream<Item = io::Result<Bytes>> + Send>>;
26type FrameStream = FramedRead<StreamReader<ByteStream, Bytes>, MsgpackFrameDecoder>;
27
28const SUBSCRIPTION_QUEUE_CAPACITY: usize = 128;
29
/// Contents of one delta frame received on a subscription.
#[derive(Debug, Clone, PartialEq)]
pub struct Delta {
    /// Transaction key carried by the delta frame.
    pub tx_key: TxKey,
    /// Changed rows; each entry pairs a row's column values with an `i64`.
    // NOTE(review): the `i64` is presumably a signed multiplicity (e.g. +1 insert,
    // -1 retraction) — confirm against the server's subscription protocol.
    pub rows: Vec<(Vec<DataType>, i64)>,
}
38
/// A live subscription.
///
/// Holds the transaction key announced by the server's open frame, plus a queue of
/// [`Delta`]s filled by a spawned background reader task.
///
/// Implements [`Stream`]. The stream ends after the first error item has been
/// yielded, or once the upstream byte stream ends. Dropping the subscription
/// aborts the background reader.
pub struct Subscription {
    /// Transaction key taken from the open frame.
    tx_key: TxKey,
    /// Receiving half of the bounded queue (`SUBSCRIPTION_QUEUE_CAPACITY`) that the
    /// reader task sends into.
    deltas: mpsc::Receiver<Result<Delta>>,
    /// Handle of the spawned reader task; aborted in `Drop`.
    reader: JoinHandle<()>,
}
48
49fn error_frame_to_error(err: ErrorResponseBody) -> Error {
50 anyhow!("subscription error (code {}): {}", err.code, err.message)
51}
52
53async fn read_deltas(mut frames: FrameStream, sender: mpsc::Sender<Result<Delta>>) {
54 while let Some(frame) = frames.next().await {
55 let (item, terminal) = match frame {
56 Ok(SubscriptionFrame::Delta { tx_key, rows }) => (Ok(Delta { tx_key, rows }), false),
57 Ok(SubscriptionFrame::Error(err)) => (Err(error_frame_to_error(err)), true),
58 Ok(SubscriptionFrame::Open { .. }) => {
59 (Err(anyhow!("unexpected open frame mid-stream")), true)
60 }
61 Err(err) => (Err(err), true),
62 };
63
64 if sender.send(item).await.is_err() || terminal {
65 break;
66 }
67 }
68}
69
70impl Subscription {
71 pub fn tx_key(&self) -> TxKey {
73 self.tx_key
74 }
75
76 pub(crate) async fn connect(resp: reqwest::Response) -> Result<Self> {
79 let byte_stream = resp
80 .bytes_stream()
81 .map(|chunk| chunk.map_err(io::Error::other));
82 Self::from_byte_stream(byte_stream).await
83 }
84
85 async fn from_byte_stream<S>(stream: S) -> Result<Self>
86 where
87 S: Stream<Item = io::Result<Bytes>> + Send + 'static,
88 {
89 let reader = StreamReader::new(Box::pin(stream) as ByteStream);
90 let mut frames = FramedRead::new(reader, MsgpackFrameDecoder::default());
91 let tx_key = match frames.next().await {
92 Some(Ok(SubscriptionFrame::Open { tx_key, .. })) => tx_key,
93 Some(Ok(SubscriptionFrame::Error(err))) => return Err(error_frame_to_error(err)),
94 Some(Ok(other)) => bail!("expected open frame, got {other:?}"),
95 Some(Err(err)) => return Err(err),
96 None => bail!("subscription stream closed before the open frame"),
97 };
98 let (sender, deltas) = mpsc::channel(SUBSCRIPTION_QUEUE_CAPACITY);
99 let reader = tokio::spawn(read_deltas(frames, sender));
100 Ok(Subscription {
101 tx_key,
102 deltas,
103 reader,
104 })
105 }
106}
107
108impl Stream for Subscription {
109 type Item = Result<Delta>;
110
111 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
112 let this = self.get_mut();
113 this.deltas.poll_recv(cx)
114 }
115}
116
117impl Drop for Subscription {
118 fn drop(&mut self) {
119 self.reader.abort();
120 }
121}
122
/// `tokio_util` decoder that reads consecutive msgpack values from a byte buffer
/// and converts each one into a [`SubscriptionFrame`].
pub(crate) struct MsgpackFrameDecoder {
    /// Frame size limit in bytes. `decode` returns an error when the buffered bytes
    /// of an incomplete frame exceed it.
    max_frame_size: usize,
}
129
130impl Default for MsgpackFrameDecoder {
131 fn default() -> Self {
132 Self {
133 max_frame_size: DEFAULT_MAX_MESSAGE_SIZE as usize,
134 }
135 }
136}
137
138fn needs_more_data(err: &rmpv::decode::Error) -> bool {
141 match err {
142 rmpv::decode::Error::InvalidMarkerRead(e) | rmpv::decode::Error::InvalidDataRead(e) => {
143 e.kind() == io::ErrorKind::UnexpectedEof
144 }
145 rmpv::decode::Error::DepthLimitExceeded => false,
146 }
147}
148
149impl Decoder for MsgpackFrameDecoder {
150 type Item = SubscriptionFrame;
151 type Error = Error;
152
153 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
154 if src.is_empty() {
155 return Ok(None);
156 }
157 let mut cursor: &[u8] = &src[..];
158 let remaining_before = cursor.len();
159 match rmpv::decode::read_value(&mut cursor) {
160 Ok(value) => {
161 let consumed = remaining_before - cursor.len();
162 let frame = subscription_frame_from_value(value)?;
163 src.advance(consumed);
164 Ok(Some(frame))
165 }
166 Err(err) if needs_more_data(&err) => {
167 if src.len() > self.max_frame_size {
168 bail!(
169 "subscription frame exceeds maximum size of {} bytes",
170 self.max_frame_size
171 );
172 }
173 Ok(None)
174 }
175 Err(err) => Err(anyhow!("msgpack frame decode error: {err}")),
176 }
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::msgpack_codec::encode_subscription_frame;
    use crate::protocol::ColumnDescription;
    use crate::transaction::TxKey;
    use chrono::{TimeZone, Utc};
    use std::collections::VecDeque;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::{Arc, LazyLock};
    use tokio::time::{sleep, timeout, Duration};

    // Transaction key stamped on every frame these tests build.
    static SAMPLE_TX_KEY: LazyLock<TxKey> = LazyLock::new(|| TxKey {
        tx_id: 3,
        system_time: Utc.timestamp_opt(1_700_000_000, 0).unwrap(),
    });

    // Encoded open frame for SAMPLE_TX_KEY describing a single column named "n".
    fn open_bytes() -> Vec<u8> {
        encode_subscription_frame(&SubscriptionFrame::Open {
            tx_key: *SAMPLE_TX_KEY,
            columns: vec![ColumnDescription {
                name: "n".to_string(),
                data_type: 255,
                members: None,
            }],
        })
        .unwrap()
    }

    // Encoded delta frame with one single-column row `[name]` paired with `1`.
    fn delta_bytes(name: &str) -> Vec<u8> {
        encode_subscription_frame(&SubscriptionFrame::Delta {
            tx_key: *SAMPLE_TX_KEY,
            rows: vec![(vec![DataType::String(name.to_string())], 1)],
        })
        .unwrap()
    }

    // Encoded error frame with code 4000 and message "boom".
    fn error_bytes() -> Vec<u8> {
        encode_subscription_frame(&SubscriptionFrame::Error(ErrorResponseBody {
            severity: b'F',
            code: 4000,
            message: "boom".to_string(),
            detail: None,
            hint: None,
        }))
        .unwrap()
    }

    // Well-formed msgpack map `{"kind": "heartbeat"}`. Its `kind` is not a
    // recognised frame kind; the test below expects it to be rejected.
    fn unknown_bytes() -> Vec<u8> {
        let mut buf = Vec::new();
        rmp::encode::write_map_len(&mut buf, 1).unwrap();
        rmp::encode::write_str(&mut buf, "kind").unwrap();
        rmp::encode::write_str(&mut buf, "heartbeat").unwrap();
        buf
    }

    // A frame missing its last byte yields `None`; appending the byte yields the
    // frame and consumes the whole buffer.
    #[test]
    fn decoder_needs_more_then_completes() {
        let bytes = open_bytes();
        let mut decoder = MsgpackFrameDecoder::default();
        let mut buf = BytesMut::from(&bytes[..bytes.len() - 1]);
        assert!(decoder.decode(&mut buf).unwrap().is_none(), "truncated");
        buf.extend_from_slice(&bytes[bytes.len() - 1..]);
        let frame = decoder.decode(&mut buf).unwrap().expect("complete frame");
        assert!(matches!(frame, SubscriptionFrame::Open { .. }));
        assert!(buf.is_empty());
    }

    // A valid msgpack value that is not a frame (a bare integer) is an error.
    #[test]
    fn decoder_rejects_non_map_frame() {
        let mut v = Vec::new();
        rmp::encode::write_uint(&mut v, 5).unwrap();
        let mut buf = BytesMut::from(&v[..]);
        assert!(MsgpackFrameDecoder::default().decode(&mut buf).is_err());
    }

    // 0xd9 is the msgpack str8 marker. The header announces 100 bytes, but only 2
    // follow, so the frame is incomplete. Its 4 buffered bytes exceed the limit of 3.
    #[test]
    fn decoder_rejects_oversize_frame() {
        let mut buf = BytesMut::from(&[0xd9u8, 100, 0x00, 0x00][..]);
        let mut decoder = MsgpackFrameDecoder { max_frame_size: 3 };
        assert!(decoder.decode(&mut buf).is_err());
    }

    // An unrecognised frame after the open frame is surfaced as one error item,
    // after which the stream ends.
    #[tokio::test]
    async fn subscription_surfaces_unknown_frame_kind_error() {
        let mut payload = Vec::new();
        payload.extend(open_bytes());
        payload.extend(unknown_bytes());
        let stream =
            futures::stream::once(async move { Ok::<Bytes, io::Error>(Bytes::from(payload)) });

        let mut sub = Subscription::from_byte_stream(stream).await.unwrap();
        assert_eq!(sub.tx_key(), *SAMPLE_TX_KEY);

        let err = sub.next().await.expect("an item").unwrap_err();
        assert!(err
            .to_string()
            .contains("unknown subscription frame kind: heartbeat"));
        assert!(sub.next().await.is_none(), "done after error");
    }

    // A server error frame is surfaced as one error item mentioning its code,
    // after which the stream ends.
    #[tokio::test]
    async fn subscription_surfaces_error_frame() {
        let mut payload = Vec::new();
        payload.extend(open_bytes());
        payload.extend(error_bytes());
        let stream =
            futures::stream::once(async move { Ok::<Bytes, io::Error>(Bytes::from(payload)) });

        let mut sub = Subscription::from_byte_stream(stream).await.unwrap();
        let err = sub.next().await.expect("an item").unwrap_err();
        assert!(err.to_string().contains("4000"));
        assert!(sub.next().await.is_none(), "done after error");
    }

    // A delta queued before an error frame is still delivered, ahead of the error.
    #[tokio::test]
    async fn subscription_preserves_queued_delta_before_error_frame() {
        let mut payload = Vec::new();
        payload.extend(open_bytes());
        payload.extend(delta_bytes("Alice"));
        payload.extend(error_bytes());
        let stream =
            futures::stream::once(async move { Ok::<Bytes, io::Error>(Bytes::from(payload)) });

        let mut sub = Subscription::from_byte_stream(stream).await.unwrap();
        let delta = sub.next().await.expect("first item").unwrap();
        assert_eq!(
            delta.rows,
            vec![(vec![DataType::String("Alice".to_string())], 1)]
        );

        let err = sub.next().await.expect("second item").unwrap_err();
        assert!(err.to_string().contains("4000"));
        assert!(sub.next().await.is_none(), "done after error");
    }

    // Upstream that is always ready: it yields its chunks in order (never
    // `Pending`) and counts how many have been pulled.
    struct CountingByteStream {
        chunks: VecDeque<Bytes>,
        yielded: Arc<AtomicUsize>,
    }

    impl Stream for CountingByteStream {
        type Item = io::Result<Bytes>;

        fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            match self.chunks.pop_front() {
                Some(chunk) => {
                    self.yielded.fetch_add(1, Ordering::SeqCst);
                    Poll::Ready(Some(Ok(chunk)))
                }
                None => Poll::Ready(None),
            }
        }
    }

    // Backpressure check: with one frame per chunk, the reader pulls
    //   1 open chunk + SUBSCRIPTION_QUEUE_CAPACITY queued deltas
    //   + 1 delta held in a pending `send`
    // = CAPACITY + 2 chunks, then stops. Draining one delta frees exactly one slot.
    #[tokio::test]
    async fn subscription_reads_ahead_until_delta_queue_is_full() {
        let yielded = Arc::new(AtomicUsize::new(0));
        let mut chunks = VecDeque::new();
        chunks.push_back(Bytes::from(open_bytes()));
        for idx in 0..SUBSCRIPTION_QUEUE_CAPACITY + 3 {
            chunks.push_back(Bytes::from(delta_bytes(&format!("Alice {idx}"))));
        }
        let stream = CountingByteStream {
            chunks,
            yielded: yielded.clone(),
        };

        let mut sub = Subscription::from_byte_stream(stream).await.unwrap();
        let blocked_at = SUBSCRIPTION_QUEUE_CAPACITY + 2;

        timeout(Duration::from_secs(1), async {
            while yielded.load(Ordering::SeqCst) < blocked_at {
                tokio::task::yield_now().await;
            }
        })
        .await
        .expect("reader should fill the bounded delta queue");

        // Give the reader a chance to (incorrectly) pull further before asserting.
        sleep(Duration::from_millis(25)).await;
        assert_eq!(
            yielded.load(Ordering::SeqCst),
            blocked_at,
            "reader should stop pulling once the delta queue is full"
        );

        let delta = sub.next().await.expect("buffered delta").unwrap();
        assert_eq!(
            delta.rows,
            vec![(vec![DataType::String("Alice 0".to_string())], 1)]
        );
        timeout(Duration::from_secs(1), async {
            while yielded.load(Ordering::SeqCst) < blocked_at + 1 {
                tokio::task::yield_now().await;
            }
        })
        .await
        .expect("draining one delta should let the reader pull one more frame");
    }

    // Upstream that yields its chunks and then stays `Pending` forever without
    // registering a waker, so a reader polling it can only be freed by an abort.
    // It records when it has been polled into that state and when it is dropped.
    struct DropNotifyStream {
        chunks: VecDeque<Bytes>,
        polled_pending: Arc<AtomicBool>,
        dropped: Arc<AtomicBool>,
    }

    impl Stream for DropNotifyStream {
        type Item = io::Result<Bytes>;

        fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            match self.chunks.pop_front() {
                Some(chunk) => Poll::Ready(Some(Ok(chunk))),
                None => {
                    self.polled_pending.store(true, Ordering::SeqCst);
                    Poll::Pending
                }
            }
        }
    }

    impl Drop for DropNotifyStream {
        fn drop(&mut self) {
            self.dropped.store(true, Ordering::SeqCst);
        }
    }

    // Dropping the subscription while its reader is parked on an idle upstream
    // must abort the reader and thereby drop the upstream stream.
    #[tokio::test]
    async fn dropping_subscription_aborts_reader_and_drops_stream() {
        let polled_pending = Arc::new(AtomicBool::new(false));
        let dropped = Arc::new(AtomicBool::new(false));
        let mut chunks = VecDeque::new();
        chunks.push_back(Bytes::from(open_bytes()));
        let stream = DropNotifyStream {
            chunks,
            polled_pending: polled_pending.clone(),
            dropped: dropped.clone(),
        };

        let sub = Subscription::from_byte_stream(stream).await.unwrap();
        timeout(Duration::from_secs(1), async {
            while !polled_pending.load(Ordering::SeqCst) {
                tokio::task::yield_now().await;
            }
        })
        .await
        .expect("reader should poll the upstream stream");

        drop(sub);

        timeout(Duration::from_secs(1), async {
            while !dropped.load(Ordering::SeqCst) {
                tokio::task::yield_now().await;
            }
        })
        .await
        .expect("dropping subscription should abort the reader task");
    }
}