1use arrow::array::RecordBatch;
2use arrow::ipc::reader::StreamReader;
3use arrow::ipc::writer::StreamWriter;
4
5use crate::error::WpArrowError;
6
7pub struct IpcFrame {
9 pub tag: String,
10 pub batch: RecordBatch,
11}
12
13pub fn encode_ipc(tag: &str, batch: &RecordBatch) -> Result<Vec<u8>, WpArrowError> {
15 let tag_bytes = tag.as_bytes();
16 let tag_len = tag_bytes.len() as u32;
17
18 let mut buf = Vec::new();
19 buf.extend_from_slice(&tag_len.to_be_bytes());
20 buf.extend_from_slice(tag_bytes);
21
22 let mut writer = StreamWriter::try_new(&mut buf, batch.schema().as_ref())
23 .map_err(|e| WpArrowError::IpcEncodeError(e.to_string()))?;
24 writer
25 .write(batch)
26 .map_err(|e| WpArrowError::IpcEncodeError(e.to_string()))?;
27 writer
28 .finish()
29 .map_err(|e| WpArrowError::IpcEncodeError(e.to_string()))?;
30
31 Ok(buf)
32}
33
34pub fn decode_ipc(data: &[u8]) -> Result<IpcFrame, WpArrowError> {
36 if data.len() < 4 {
37 return Err(WpArrowError::IpcDecodeError(format!(
38 "frame too short: {} bytes, minimum 4",
39 data.len()
40 )));
41 }
42
43 let tag_len = u32::from_be_bytes(data[0..4].try_into().unwrap()) as usize;
44 let tag_end = 4 + tag_len;
45 if data.len() < tag_end {
46 return Err(WpArrowError::IpcDecodeError(format!(
47 "frame truncated: tag_len={tag_len} but only {} bytes remain after header",
48 data.len() - 4
49 )));
50 }
51
52 let tag = String::from_utf8(data[4..tag_end].to_vec())
53 .map_err(|e| WpArrowError::IpcDecodeError(format!("invalid UTF-8 in tag: {e}")))?;
54
55 let ipc_payload = &data[tag_end..];
56 let mut reader = StreamReader::try_new(ipc_payload, None)
57 .map_err(|e| WpArrowError::IpcDecodeError(e.to_string()))?;
58
59 let batch = reader
60 .next()
61 .ok_or_else(|| WpArrowError::IpcDecodeError("no RecordBatch in IPC payload".to_string()))?
62 .map_err(|e| WpArrowError::IpcDecodeError(e.to_string()))?;
63
64 Ok(IpcFrame { tag, batch })
65}
66
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int32Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Builds a two-column batch (`id: Int32` non-null, `name: Utf8` nullable)
    /// with `num_rows` rows; `name` is `"even"` on even rows and null on odd rows.
    fn make_batch(num_rows: usize) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        let ids: Vec<i32> = (0..num_rows as i32).collect();
        let names: Vec<Option<&str>> = (0..num_rows)
            .map(|i| if i % 2 == 0 { Some("even") } else { None })
            .collect();
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from(ids)),
                Arc::new(StringArray::from(names)),
            ],
        )
        .unwrap()
    }

    #[test]
    fn ipc_roundtrip_basic() {
        let batch = make_batch(5);
        let encoded = encode_ipc("test-tag", &batch).unwrap();
        let frame = decode_ipc(&encoded).unwrap();
        assert_eq!(frame.tag, "test-tag");
        assert_eq!(frame.batch.num_rows(), batch.num_rows());
        assert_eq!(frame.batch.num_columns(), batch.num_columns());
        assert_eq!(frame.batch.schema(), batch.schema());
        assert_eq!(frame.batch, batch);
    }

    #[test]
    fn ipc_roundtrip_empty_batch() {
        let batch = make_batch(0);
        let encoded = encode_ipc("empty", &batch).unwrap();
        let frame = decode_ipc(&encoded).unwrap();
        assert_eq!(frame.batch.num_rows(), 0);
        assert_eq!(frame.batch, batch);
    }

    #[test]
    fn ipc_roundtrip_large_batch() {
        let batch = make_batch(1000);
        let encoded = encode_ipc("large", &batch).unwrap();
        let frame = decode_ipc(&encoded).unwrap();
        assert_eq!(frame.batch.num_rows(), 1000);
        assert_eq!(frame.batch, batch);
    }

    #[test]
    fn ipc_tag_preserved() {
        let batch = make_batch(1);
        let encoded = encode_ipc("my-tag", &batch).unwrap();
        let frame = decode_ipc(&encoded).unwrap();
        assert_eq!(frame.tag, "my-tag");
    }

    #[test]
    fn ipc_utf8_tag() {
        let batch = make_batch(1);
        let tag = "数据标签-🚀";
        let encoded = encode_ipc(tag, &batch).unwrap();
        let frame = decode_ipc(&encoded).unwrap();
        assert_eq!(frame.tag, tag);
    }

    #[test]
    fn ipc_empty_tag() {
        let batch = make_batch(1);
        let encoded = encode_ipc("", &batch).unwrap();
        let frame = decode_ipc(&encoded).unwrap();
        assert_eq!(frame.tag, "");
    }

    /// The frame must start with the big-endian tag length followed by the tag bytes.
    #[test]
    fn encode_ipc_header_layout() {
        let batch = make_batch(1);
        let encoded = encode_ipc("test-tag", &batch).unwrap();
        assert_eq!(&encoded[..4], &8u32.to_be_bytes()[..]);
        assert_eq!(&encoded[4..12], &b"test-tag"[..]);
    }

    #[test]
    fn decode_ipc_too_short() {
        let result = decode_ipc(&[0x00; 2]);
        assert!(matches!(result, Err(WpArrowError::IpcDecodeError(_))));
    }

    #[test]
    fn decode_ipc_truncated_tag() {
        // Header claims a 100-byte tag, but no bytes follow it.
        let mut data = Vec::new();
        data.extend_from_slice(&100u32.to_be_bytes());
        let result = decode_ipc(&data);
        assert!(matches!(result, Err(WpArrowError::IpcDecodeError(_))));
    }

    #[test]
    fn decode_ipc_invalid_utf8_tag() {
        // One-byte tag whose only byte (0xFF) is never valid UTF-8.
        let mut data = Vec::new();
        data.extend_from_slice(&1u32.to_be_bytes());
        data.push(0xFF);
        let result = decode_ipc(&data);
        assert!(matches!(result, Err(WpArrowError::IpcDecodeError(_))));
    }

    #[test]
    fn decode_ipc_missing_payload() {
        // Valid header with an empty tag, but no Arrow IPC stream after it.
        let data = 0u32.to_be_bytes();
        let result = decode_ipc(&data);
        assert!(matches!(result, Err(WpArrowError::IpcDecodeError(_))));
    }
}