1use arrow::ipc::writer::StreamWriter;
2use arrow::record_batch::RecordBatch;
3use axum::body::Body;
4use axum::http::header;
5use axum::response::Response;
6use bytes::Bytes;
7use datafusion::execution::SendableRecordBatchStream;
8use futures_util::{StreamExt, stream};
9
10use super::handlers::IDENT;
11use crate::Error;
12use crate::error::Result;
13
/// Default buffering threshold, in bytes (8 KiB), used when
/// `ArrowIpcStreamState::try_new` is given no explicit chunk size.
pub const DEFAULT_CHUNK_SIZE: usize = 8_192;
15
/// Incremental encoder that turns a DataFusion record batch stream into Arrow IPC
/// stream bytes.
///
/// Batches are pulled from `stream`, encoded by `writer` into an in-memory `Vec<u8>`,
/// and handed out in chunks by [`ArrowIpcStreamState::take_chunk`] /
/// [`ArrowIpcStreamState::stream_chunks`].
pub struct ArrowIpcStreamState {
    /// Source of the record batches to encode.
    stream: SendableRecordBatchStream,
    /// Arrow IPC writer whose `Vec<u8>` sink is drained each time a chunk is taken.
    writer: StreamWriter<Vec<u8>>,
    /// Minimum number of buffered bytes before a non-forced `take_chunk` emits a chunk.
    chunk_size: usize,
    /// Set once the input stream is exhausted and the writer has been finished.
    finished: bool,
    /// Set once at least one chunk has been emitted; until then `chunk_size` is ignored.
    emitted: bool,
}
26
27impl ArrowIpcStreamState {
28 pub fn try_new(stream: SendableRecordBatchStream, chunk_size: Option<usize>) -> Result<Self> {
33 let chunk_size = chunk_size.unwrap_or(DEFAULT_CHUNK_SIZE);
34 let schema = stream.schema();
35 let writer = StreamWriter::try_new(Vec::with_capacity(chunk_size), &schema)?;
36 Ok(Self { stream, writer, chunk_size, finished: false, emitted: false })
37 }
38
39 pub fn take_chunk(&mut self, force: bool) -> Option<Bytes> {
40 let buffer = self.writer.get_mut();
41 if buffer.is_empty() || (!force && self.emitted && buffer.len() < self.chunk_size) {
42 return None;
43 }
44 let chunk = buffer.split_off(0);
45 self.emitted = true;
46 Some(Bytes::from(chunk))
47 }
48
49 #[must_use]
50 pub fn take_stream(mut self) -> SendableRecordBatchStream {
51 self.finished = true;
52 self.stream
53 }
54
55 #[must_use]
56 pub fn into_parts(mut self) -> (SendableRecordBatchStream, StreamWriter<Vec<u8>>) {
57 self.finished = true;
58 (self.stream, self.writer)
59 }
60
61 pub fn is_finished(&self) -> bool { self.finished }
62
63 pub fn is_emitted(&self) -> bool { self.emitted }
64
65 pub async fn stream_chunks(mut self) -> Result<Option<(Bytes, Self)>> {
73 loop {
74 if let Some(chunk) = self.take_chunk(false) {
75 return Ok::<_, Error>(Some((chunk, self)));
76 }
77
78 if self.finished {
79 return Ok(None);
80 }
81
82 if let Some(batch_result) = self.stream.next().await {
83 let batch: RecordBatch = batch_result?;
84 self.writer.write(&batch).inspect_err(log_err("stream_chunks - Writer.write"))?;
85 self.writer.flush().inspect_err(log_err("stream_chunks - Writer.flush"))?;
86 } else {
87 self.writer.finish().inspect_err(log_err("stream_chunks - Writer.finish"))?;
88 self.finished = true;
89 if let Some(chunk) = self.take_chunk(true) {
90 return Ok(Some((chunk, self)));
91 }
92 return Ok(None);
93 }
94 }
95 }
96}
97
98pub async fn arrow_ipc_response(stream: SendableRecordBatchStream) -> Result<Response> {
106 let state = ArrowIpcStreamState::try_new(stream, None)?;
107 let body =
108 Body::from_stream(stream::try_unfold(
109 state,
110 |state| async move { state.stream_chunks().await },
111 ));
112
113 Ok(Response::builder()
114 .header(header::CONTENT_TYPE, "application/vnd.apache.arrow.stream")
115 .header(header::TRANSFER_ENCODING, "chunked")
116 .body(body)
117 .unwrap())
118}
119
120fn log_err<E>(msg: &'static str) -> impl FnOnce(&E)
121where
122 E: std::error::Error,
123{
124 move |err| {
125 tracing::error!("{IDENT} {msg}: {err:?}");
126 }
127}