Skip to main content

wp_arrow/
ipc.rs

1use arrow::array::RecordBatch;
2use arrow::ipc::reader::StreamReader;
3use arrow::ipc::writer::StreamWriter;
4
5use crate::error::WpArrowError;
6
7/// Decoded IPC data frame.
8pub struct IpcFrame {
9    pub tag: String,
10    pub batch: RecordBatch,
11}
12
13/// Encode a RecordBatch into an IPC frame: `[4B tag_len BE][tag bytes][Arrow IPC stream]`.
14pub fn encode_ipc(tag: &str, batch: &RecordBatch) -> Result<Vec<u8>, WpArrowError> {
15    let tag_bytes = tag.as_bytes();
16    let tag_len = tag_bytes.len() as u32;
17
18    let mut buf = Vec::new();
19    buf.extend_from_slice(&tag_len.to_be_bytes());
20    buf.extend_from_slice(tag_bytes);
21
22    let mut writer = StreamWriter::try_new(&mut buf, batch.schema().as_ref())
23        .map_err(|e| WpArrowError::IpcEncodeError(e.to_string()))?;
24    writer
25        .write(batch)
26        .map_err(|e| WpArrowError::IpcEncodeError(e.to_string()))?;
27    writer
28        .finish()
29        .map_err(|e| WpArrowError::IpcEncodeError(e.to_string()))?;
30
31    Ok(buf)
32}
33
34/// Decode a complete IPC frame from bytes.
35pub fn decode_ipc(data: &[u8]) -> Result<IpcFrame, WpArrowError> {
36    if data.len() < 4 {
37        return Err(WpArrowError::IpcDecodeError(format!(
38            "frame too short: {} bytes, minimum 4",
39            data.len()
40        )));
41    }
42
43    let tag_len = u32::from_be_bytes(data[0..4].try_into().unwrap()) as usize;
44    let tag_end = 4 + tag_len;
45    if data.len() < tag_end {
46        return Err(WpArrowError::IpcDecodeError(format!(
47            "frame truncated: tag_len={tag_len} but only {} bytes remain after header",
48            data.len() - 4
49        )));
50    }
51
52    let tag = String::from_utf8(data[4..tag_end].to_vec())
53        .map_err(|e| WpArrowError::IpcDecodeError(format!("invalid UTF-8 in tag: {e}")))?;
54
55    let ipc_payload = &data[tag_end..];
56    let mut reader = StreamReader::try_new(ipc_payload, None)
57        .map_err(|e| WpArrowError::IpcDecodeError(e.to_string()))?;
58
59    let batch = reader
60        .next()
61        .ok_or_else(|| WpArrowError::IpcDecodeError("no RecordBatch in IPC payload".to_string()))?
62        .map_err(|e| WpArrowError::IpcDecodeError(e.to_string()))?;
63
64    Ok(IpcFrame { tag, batch })
65}
66
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int32Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Builds a batch with `id: Int32` (non-null, 0..num_rows) and
    /// `name: Utf8` (nullable; "even" on even rows, null on odd rows).
    fn make_batch(num_rows: usize) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        let ids: Vec<i32> = (0..num_rows as i32).collect();
        let names: Vec<Option<&str>> = (0..num_rows)
            .map(|i| if i % 2 == 0 { Some("even") } else { None })
            .collect();
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from(ids)),
                Arc::new(StringArray::from(names)),
            ],
        )
        .unwrap()
    }

    #[test]
    fn ipc_roundtrip_basic() {
        let batch = make_batch(5);
        let encoded = encode_ipc("test-tag", &batch).unwrap();
        let frame = decode_ipc(&encoded).unwrap();
        assert_eq!(frame.tag, "test-tag");
        assert_eq!(frame.batch.num_rows(), batch.num_rows());
        assert_eq!(frame.batch.num_columns(), batch.num_columns());
        assert_eq!(frame.batch.schema(), batch.schema());
        assert_eq!(frame.batch, batch);
    }

    #[test]
    fn ipc_roundtrip_empty_batch() {
        let batch = make_batch(0);
        let encoded = encode_ipc("empty", &batch).unwrap();
        let frame = decode_ipc(&encoded).unwrap();
        assert_eq!(frame.batch.num_rows(), 0);
        assert_eq!(frame.batch, batch);
    }

    #[test]
    fn ipc_roundtrip_large_batch() {
        let batch = make_batch(1000);
        let encoded = encode_ipc("large", &batch).unwrap();
        let frame = decode_ipc(&encoded).unwrap();
        assert_eq!(frame.batch.num_rows(), 1000);
        assert_eq!(frame.batch, batch);
    }

    #[test]
    fn ipc_tag_preserved() {
        let batch = make_batch(1);
        let encoded = encode_ipc("my-tag", &batch).unwrap();
        let frame = decode_ipc(&encoded).unwrap();
        assert_eq!(frame.tag, "my-tag");
    }

    #[test]
    fn ipc_utf8_tag() {
        let batch = make_batch(1);
        let tag = "数据标签-🚀";
        let encoded = encode_ipc(tag, &batch).unwrap();
        let frame = decode_ipc(&encoded).unwrap();
        assert_eq!(frame.tag, tag);
    }

    #[test]
    fn ipc_empty_tag() {
        let batch = make_batch(1);
        let encoded = encode_ipc("", &batch).unwrap();
        let frame = decode_ipc(&encoded).unwrap();
        assert_eq!(frame.tag, "");
    }

    #[test]
    fn decode_ipc_too_short() {
        let result = decode_ipc(&[0x00; 2]);
        assert!(matches!(result, Err(WpArrowError::IpcDecodeError(_))));
    }

    #[test]
    fn decode_ipc_truncated_tag() {
        let mut data = Vec::new();
        data.extend_from_slice(&100u32.to_be_bytes()); // tag_len = 100 but no data
        let result = decode_ipc(&data);
        assert!(matches!(result, Err(WpArrowError::IpcDecodeError(_))));
    }

    #[test]
    fn decode_ipc_max_tag_len() {
        // Largest representable tag length with almost no data behind it must be
        // rejected cleanly, not panic on length arithmetic or slicing.
        let mut data = Vec::new();
        data.extend_from_slice(&u32::MAX.to_be_bytes());
        data.extend_from_slice(b"abc");
        let result = decode_ipc(&data);
        assert!(matches!(result, Err(WpArrowError::IpcDecodeError(_))));
    }

    #[test]
    fn decode_ipc_invalid_utf8_tag() {
        let mut data = Vec::new();
        data.extend_from_slice(&2u32.to_be_bytes());
        data.extend_from_slice(&[0xff, 0xfe]); // not valid UTF-8
        let result = decode_ipc(&data);
        assert!(matches!(result, Err(WpArrowError::IpcDecodeError(_))));
    }

    #[test]
    fn decode_ipc_missing_payload() {
        // Valid (empty) tag header, but no Arrow IPC stream after it.
        let result = decode_ipc(&0u32.to_be_bytes());
        assert!(matches!(result, Err(WpArrowError::IpcDecodeError(_))));
    }

    #[test]
    fn decode_ipc_truncated_payload() {
        let batch = make_batch(5);
        let mut encoded = encode_ipc("t", &batch).unwrap();
        encoded.truncate(encoded.len() / 2);
        let result = decode_ipc(&encoded);
        assert!(matches!(result, Err(WpArrowError::IpcDecodeError(_))));
    }
}