scouter_dataframe/parquet/bifrost/
ipc.rs1use arrow::error::ArrowError;
2use arrow::ipc::reader::StreamReader;
3use arrow::ipc::writer::StreamWriter;
4use arrow_array::RecordBatch;
5use std::io::Cursor;
6
7pub fn ipc_bytes_to_batches(data: &[u8]) -> Result<Vec<RecordBatch>, ArrowError> {
12 if data.is_empty() {
13 return Ok(Vec::new());
14 }
15 let cursor = Cursor::new(data);
16 let reader = StreamReader::try_new(cursor, None)?;
17 reader.collect()
18}
19
20pub fn batches_to_ipc_bytes(batches: &[RecordBatch]) -> Result<Vec<u8>, ArrowError> {
24 if batches.is_empty() {
25 return Ok(Vec::new());
26 }
27 let schema = batches[0].schema();
28 let mut buf = Vec::new();
29 let mut writer = StreamWriter::try_new(&mut buf, &schema)?;
30 for batch in batches {
31 writer.write(batch)?;
32 }
33 writer.finish()?;
34 Ok(buf)
35}
36
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow_array::{Float64Array, Int64Array, StringArray};
    use std::sync::Arc;

    /// Three-column schema shared by every fixture: non-null int64 `id`,
    /// non-null utf8 `name`, nullable float64 `score`.
    fn test_schema() -> Schema {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("score", DataType::Float64, true),
        ];
        Schema::new(fields)
    }

    /// Builds a three-row batch, including a null in the nullable column,
    /// so round trips exercise validity bitmaps as well as values.
    fn test_batch() -> RecordBatch {
        let ids = Int64Array::from(vec![1, 2, 3]);
        let names = StringArray::from(vec!["alice", "bob", "charlie"]);
        let scores = Float64Array::from(vec![Some(0.9), None, Some(0.7)]);
        RecordBatch::try_new(
            Arc::new(test_schema()),
            vec![Arc::new(ids), Arc::new(names), Arc::new(scores)],
        )
        .unwrap()
    }

    #[test]
    fn test_round_trip() {
        let original = test_batch();
        let encoded = batches_to_ipc_bytes(std::slice::from_ref(&original)).unwrap();
        let decoded = ipc_bytes_to_batches(&encoded).unwrap();

        assert_eq!(decoded.len(), 1);
        let got = &decoded[0];
        assert_eq!(got.num_rows(), original.num_rows());
        assert_eq!(got.num_columns(), original.num_columns());
        assert_eq!(got.schema(), original.schema());

        // Spot-check the first column's raw values survived the round trip.
        let before: &Int64Array = original.column(0).as_any().downcast_ref().unwrap();
        let after: &Int64Array = got.column(0).as_any().downcast_ref().unwrap();
        assert_eq!(before.values(), after.values());
    }

    #[test]
    fn test_multiple_batches_round_trip() {
        // Two identical batches must come back as two distinct batches.
        let batch = test_batch();
        let encoded = batches_to_ipc_bytes(&[batch.clone(), batch.clone()]).unwrap();
        let decoded = ipc_bytes_to_batches(&encoded).unwrap();

        assert_eq!(decoded.len(), 2);
        for got in &decoded {
            assert_eq!(got.num_rows(), batch.num_rows());
        }
    }

    #[test]
    fn test_empty_batches_round_trip() {
        // Empty slice <-> empty byte buffer is the agreed sentinel in both directions.
        let encoded = batches_to_ipc_bytes(&[]).unwrap();
        assert!(encoded.is_empty());

        let decoded = ipc_bytes_to_batches(&encoded).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_malformed_bytes() {
        // Garbage that isn't an IPC stream must error, not panic or decode.
        assert!(ipc_bytes_to_batches(b"not valid ipc data").is_err());
    }
}