stately_arrow/api/
ipc.rs

1use arrow::ipc::writer::StreamWriter;
2use arrow::record_batch::RecordBatch;
3use axum::body::Body;
4use axum::http::header;
5use axum::response::Response;
6use bytes::Bytes;
7use datafusion::execution::SendableRecordBatchStream;
8use futures_util::{StreamExt, stream};
9
10use super::handlers::IDENT;
11use crate::Error;
12use crate::error::Result;
13
/// Default number of buffered IPC bytes (8 KiB) used as the chunk threshold when no explicit
/// chunk size is given to [`ArrowIpcStreamState::try_new`].
pub const DEFAULT_CHUNK_SIZE: usize = 8_192;
15
16/// Wrapper type to internally manage the state and buffering of an Arrow IPC stream.
17///
18/// Helps bridge the async/sync gap between arrow's sync primitives and async streaming.
19pub struct ArrowIpcStreamState {
20    stream:     SendableRecordBatchStream,
21    writer:     StreamWriter<Vec<u8>>,
22    chunk_size: usize,
23    finished:   bool,
24    emitted:    bool,
25}
26
27impl ArrowIpcStreamState {
28    /// Create a new `ArrowIpcStreamState`
29    ///
30    /// # Errors
31    /// - Returns an error if the `StreamWriter` cannot be created
32    pub fn try_new(stream: SendableRecordBatchStream, chunk_size: Option<usize>) -> Result<Self> {
33        let chunk_size = chunk_size.unwrap_or(DEFAULT_CHUNK_SIZE);
34        let schema = stream.schema();
35        let writer = StreamWriter::try_new(Vec::with_capacity(chunk_size), &schema)?;
36        Ok(Self { stream, writer, chunk_size, finished: false, emitted: false })
37    }
38
39    pub fn take_chunk(&mut self, force: bool) -> Option<Bytes> {
40        let buffer = self.writer.get_mut();
41        if buffer.is_empty() || (!force && self.emitted && buffer.len() < self.chunk_size) {
42            return None;
43        }
44        let chunk = buffer.split_off(0);
45        self.emitted = true;
46        Some(Bytes::from(chunk))
47    }
48
49    #[must_use]
50    pub fn take_stream(mut self) -> SendableRecordBatchStream {
51        self.finished = true;
52        self.stream
53    }
54
55    #[must_use]
56    pub fn into_parts(mut self) -> (SendableRecordBatchStream, StreamWriter<Vec<u8>>) {
57        self.finished = true;
58        (self.stream, self.writer)
59    }
60
61    pub fn is_finished(&self) -> bool { self.finished }
62
63    pub fn is_emitted(&self) -> bool { self.emitted }
64
65    /// Stream chunks from an Arrow IPC stream state.
66    ///
67    /// Useful inside of `stream::try_unfold` to convert an Arrow IPC stream state into a stream of
68    /// chunks.
69    ///
70    /// # Errors
71    /// - Returns an error if flushing fails
72    pub async fn stream_chunks(mut self) -> Result<Option<(Bytes, Self)>> {
73        loop {
74            if let Some(chunk) = self.take_chunk(false) {
75                return Ok::<_, Error>(Some((chunk, self)));
76            }
77
78            if self.finished {
79                return Ok(None);
80            }
81
82            if let Some(batch_result) = self.stream.next().await {
83                let batch: RecordBatch = batch_result?;
84                self.writer.write(&batch).inspect_err(log_err("stream_chunks - Writer.write"))?;
85                self.writer.flush().inspect_err(log_err("stream_chunks - Writer.flush"))?;
86            } else {
87                self.writer.finish().inspect_err(log_err("stream_chunks - Writer.finish"))?;
88                self.finished = true;
89                if let Some(chunk) = self.take_chunk(true) {
90                    return Ok(Some((chunk, self)));
91                }
92                return Ok(None);
93            }
94        }
95    }
96}
97
98/// Create an arrow ipc response from a stream of record batches
99///
100/// # Errors
101/// - Returns an error if the stream cannot be converted to an arrow ipc response
102///
103/// # Panics
104/// - Should not panic. The headers provided are valid.
105pub async fn arrow_ipc_response(stream: SendableRecordBatchStream) -> Result<Response> {
106    let state = ArrowIpcStreamState::try_new(stream, None)?;
107    let body =
108        Body::from_stream(stream::try_unfold(
109            state,
110            |state| async move { state.stream_chunks().await },
111        ));
112
113    Ok(Response::builder()
114        .header(header::CONTENT_TYPE, "application/vnd.apache.arrow.stream")
115        .header(header::TRANSFER_ENCODING, "chunked")
116        .body(body)
117        .unwrap())
118}
119
120fn log_err<E>(msg: &'static str) -> impl FnOnce(&E)
121where
122    E: std::error::Error,
123{
124    move |err| {
125        tracing::error!("{IDENT} {msg}: {err:?}");
126    }
127}