1use arrow::array::RecordBatch;
17use arrow::ipc::reader::StreamReader;
18use arrow::ipc::writer::StreamWriter;
19use arrow_schema::SchemaRef;
20
21use crate::error::IpcError;
22
/// Arrow extension-type name that marks a field as holding secret handles.
///
/// A field whose `ARROW:extension:name` metadata equals this value is refused
/// by this module's encode and decode functions with
/// [`IpcError::SecretLeakAttempt`]. This includes fields nested inside struct,
/// list and map types.
pub const SECRET_HANDLE_EXTENSION: &str = "uni-db.secret-handle";

/// Field-metadata key under which Arrow records an extension type's name.
const ARROW_EXTENSION_KEY: &str = "ARROW:extension:name";
36
37fn reject_secret_handles(batch: &RecordBatch) -> Result<(), IpcError> {
47 fn walk(field: &arrow_schema::Field) -> Result<(), IpcError> {
48 use arrow_schema::DataType;
49 if field
50 .metadata()
51 .get(ARROW_EXTENSION_KEY)
52 .map(String::as_str)
53 == Some(SECRET_HANDLE_EXTENSION)
54 {
55 return Err(IpcError::SecretLeakAttempt {
56 column: field.name().clone(),
57 });
58 }
59 match field.data_type() {
60 DataType::Struct(fields) => fields.iter().try_for_each(|f| walk(f.as_ref())),
61 DataType::List(item) | DataType::LargeList(item) | DataType::FixedSizeList(item, _) => {
62 walk(item.as_ref())
63 }
64 DataType::Map(field, _) => walk(field.as_ref()),
65 _ => Ok(()),
66 }
67 }
68 batch
69 .schema()
70 .fields()
71 .iter()
72 .try_for_each(|f| walk(f.as_ref()))
73}
74
75fn reject_all(batches: &[RecordBatch]) -> Result<(), IpcError> {
79 batches.iter().try_for_each(reject_secret_handles)
80}
81
82pub fn encode_batch(batch: &RecordBatch) -> Result<Vec<u8>, IpcError> {
92 reject_secret_handles(batch)?;
93 let mut buf: Vec<u8> = Vec::with_capacity(estimate_size(batch));
94 write_stream(&mut buf, batch.schema(), std::slice::from_ref(batch))?;
95 Ok(buf)
96}
97
98pub fn encode_batches(batches: &[RecordBatch]) -> Result<Vec<u8>, IpcError> {
109 let first = batches.first().ok_or(IpcError::EmptyBatchInput)?;
110 reject_all(batches)?;
111 let mut buf: Vec<u8> = Vec::with_capacity(estimate_size(first).saturating_mul(batches.len()));
112 write_stream(&mut buf, first.schema(), batches)?;
113 Ok(buf)
114}
115
116fn write_stream(
118 buf: &mut Vec<u8>,
119 schema: SchemaRef,
120 batches: &[RecordBatch],
121) -> Result<(), IpcError> {
122 let mut w = StreamWriter::try_new(buf, schema.as_ref())
123 .map_err(|e| IpcError::Arrow(format!("writer setup: {e}")))?;
124 for b in batches {
125 w.write(b)
126 .map_err(|e| IpcError::Arrow(format!("write batch: {e}")))?;
127 }
128 w.finish()
129 .map_err(|e| IpcError::Arrow(format!("finish: {e}")))?;
130 Ok(())
131}
132
133pub fn decode_batch(bytes: &[u8]) -> Result<Option<RecordBatch>, IpcError> {
150 let batches = read_stream(bytes, "read batch")?;
151 reject_all(&batches)?;
155 match batches.len() {
156 0 => Ok(None),
157 1 => Ok(batches.into_iter().next()),
158 n => Err(IpcError::Arrow(format!(
159 "decode_batch expects a single-batch stream, got {n} batches"
160 ))),
161 }
162}
163
164pub fn decode_batches(bytes: &[u8]) -> Result<Vec<RecordBatch>, IpcError> {
170 let batches = read_stream(bytes, "read batches")?;
171 reject_all(&batches)?;
175 Ok(batches)
176}
177
178fn read_stream(bytes: &[u8], read_label: &str) -> Result<Vec<RecordBatch>, IpcError> {
182 let reader = StreamReader::try_new(bytes, None)
183 .map_err(|e| IpcError::Arrow(format!("reader setup: {e}")))?;
184 reader
185 .collect::<Result<Vec<_>, _>>()
186 .map_err(|e| IpcError::Arrow(format!("{read_label}: {e}")))
187}
188
189fn estimate_size(batch: &RecordBatch) -> usize {
190 let rows = batch.num_rows();
192 let cols = batch.num_columns();
193 rows.saturating_mul(cols).saturating_mul(16) + 4096
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::Arc;

    use arrow::array::{
        Array, BooleanArray, FixedSizeBinaryArray, Float32Array, Float64Array, Int32Array,
        Int64Array, LargeBinaryArray, ListArray, StringArray, StructArray,
        TimestampMillisecondArray,
    };
    use arrow::buffer::OffsetBuffer;
    use arrow_schema::{DataType, Field, Fields, Schema, TimeUnit};

    fn schema_for(name: &str, dt: DataType) -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new(name, dt, true)]))
    }

    fn one_col_batch(name: &str, col: Arc<dyn arrow::array::Array>) -> RecordBatch {
        let dt = col.data_type().clone();
        let schema = schema_for(name, dt);
        RecordBatch::try_new(schema, vec![col]).unwrap()
    }

    #[test]
    fn round_trip_int64() {
        let arr: Arc<dyn arrow::array::Array> = Arc::new(Int64Array::from(vec![1, 2, 3]));
        let batch = one_col_batch("x", arr);
        let encoded = encode_batch(&batch).unwrap();
        let decoded = decode_batch(&encoded).unwrap().unwrap();
        assert_eq!(decoded.num_rows(), 3);
    }

    #[test]
    fn round_trip_int32_float32_float64() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i32", DataType::Int32, true),
            Field::new("f32", DataType::Float32, true),
            Field::new("f64", DataType::Float64, true),
        ]));
        let i: Arc<dyn arrow::array::Array> = Arc::new(Int32Array::from(vec![1, 2]));
        let f32a: Arc<dyn arrow::array::Array> = Arc::new(Float32Array::from(vec![1.5_f32, 2.5]));
        let f64a: Arc<dyn arrow::array::Array> = Arc::new(Float64Array::from(vec![10.5_f64, 20.5]));
        let batch = RecordBatch::try_new(schema, vec![i, f32a, f64a]).unwrap();
        let encoded = encode_batch(&batch).unwrap();
        let decoded = decode_batch(&encoded).unwrap().unwrap();
        assert_eq!(decoded.num_rows(), 2);
        let f64_out = decoded
            .column(2)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!((f64_out.value(1) - 20.5).abs() < f64::EPSILON);
    }

    #[test]
    fn round_trip_utf8_strings_including_unicode() {
        let arr: Arc<dyn arrow::array::Array> =
            Arc::new(StringArray::from(vec!["hello", "naïve", "🌳", ""]));
        let batch = one_col_batch("s", arr);
        let encoded = encode_batch(&batch).unwrap();
        let decoded = decode_batch(&encoded).unwrap().unwrap();
        let col = decoded
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(col.value(2), "🌳");
        assert_eq!(col.value(3), "");
    }

    #[test]
    fn round_trip_booleans_with_nulls() {
        let arr: Arc<dyn arrow::array::Array> =
            Arc::new(BooleanArray::from(vec![Some(true), None, Some(false)]));
        let batch = one_col_batch("b", arr);
        let encoded = encode_batch(&batch).unwrap();
        let decoded = decode_batch(&encoded).unwrap().unwrap();
        let col = decoded
            .column(0)
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap();
        assert!(col.is_null(1));
        assert!(col.value(0));
        assert!(!col.value(2));
    }

    #[test]
    fn round_trip_timestamp_ms() {
        let arr: Arc<dyn arrow::array::Array> = Arc::new(
            TimestampMillisecondArray::from(vec![1_700_000_000_000_i64, 1_800_000_000_000])
                .with_timezone_opt::<&str>(None),
        );
        let batch = one_col_batch("ts", arr);
        let encoded = encode_batch(&batch).unwrap();
        let decoded = decode_batch(&encoded).unwrap().unwrap();
        assert!(matches!(
            decoded.schema().field(0).data_type(),
            DataType::Timestamp(TimeUnit::Millisecond, _)
        ));
    }

    #[test]
    fn round_trip_large_binary_for_cypher_values() {
        let arr: Arc<dyn arrow::array::Array> = Arc::new(LargeBinaryArray::from(vec![
            &[1_u8, 2, 3][..],
            &[4, 5, 6, 7],
        ]));
        let batch = one_col_batch("v", arr);
        let encoded = encode_batch(&batch).unwrap();
        let decoded = decode_batch(&encoded).unwrap().unwrap();
        let col = decoded
            .column(0)
            .as_any()
            .downcast_ref::<LargeBinaryArray>()
            .unwrap();
        assert_eq!(col.value(0), &[1, 2, 3]);
        assert_eq!(col.value(1), &[4, 5, 6, 7]);
    }

    #[test]
    fn round_trip_list_of_int64() {
        let values: Arc<dyn arrow::array::Array> =
            Arc::new(Int64Array::from(vec![1_i64, 2, 3, 4, 5, 6]));
        let offsets = OffsetBuffer::new(vec![0_i32, 2, 5, 6].into());
        let field = Arc::new(Field::new("item", DataType::Int64, true));
        let list = ListArray::new(field, offsets, values, None);
        let arr: Arc<dyn arrow::array::Array> = Arc::new(list);
        let batch = one_col_batch("xs", arr);
        let encoded = encode_batch(&batch).unwrap();
        let decoded = decode_batch(&encoded).unwrap().unwrap();
        let col = decoded
            .column(0)
            .as_any()
            .downcast_ref::<ListArray>()
            .unwrap();
        assert_eq!(col.len(), 3);
        assert_eq!(col.value_length(1), 3);
    }

    #[test]
    fn round_trip_struct_array() {
        let id: Arc<dyn arrow::array::Array> = Arc::new(Int64Array::from(vec![10, 20]));
        let label: Arc<dyn arrow::array::Array> = Arc::new(StringArray::from(vec!["a", "b"]));
        let fields = Fields::from(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("label", DataType::Utf8, false),
        ]);
        let s = StructArray::new(fields, vec![id, label], None);
        let arr: Arc<dyn arrow::array::Array> = Arc::new(s);
        let batch = one_col_batch("rec", arr);
        let encoded = encode_batch(&batch).unwrap();
        let decoded = decode_batch(&encoded).unwrap().unwrap();
        assert_eq!(decoded.num_rows(), 2);
        assert!(matches!(
            decoded.schema().field(0).data_type(),
            DataType::Struct(_)
        ));
    }

    #[test]
    fn decode_empty_stream_returns_none() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
        let mut buf: Vec<u8> = Vec::new();
        {
            let mut w = StreamWriter::try_new(&mut buf, schema.as_ref()).unwrap();
            w.finish().unwrap();
        }
        assert!(decode_batch(&buf).unwrap().is_none());
    }

    #[test]
    fn decode_garbage_bytes_is_arrow_ipc_error() {
        let err = decode_batch(b"not arrow ipc").unwrap_err();
        assert!(matches!(err, IpcError::Arrow(_)));
    }

    #[test]
    fn encode_batches_emits_multiple_in_one_stream() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, true)]));
        let a: Arc<dyn arrow::array::Array> = Arc::new(Int64Array::from(vec![1_i64, 2]));
        let b: Arc<dyn arrow::array::Array> = Arc::new(Int64Array::from(vec![3_i64, 4, 5]));
        let ba = RecordBatch::try_new(schema.clone(), vec![a]).unwrap();
        let bb = RecordBatch::try_new(schema, vec![b]).unwrap();
        let encoded = encode_batches(&[ba, bb]).unwrap();
        let all = decode_batches(&encoded).unwrap();
        assert_eq!(all.len(), 2);
        assert_eq!(all[0].num_rows(), 2);
        assert_eq!(all[1].num_rows(), 3);
    }

    #[test]
    fn encode_batches_rejects_empty_input() {
        let err = encode_batches(&[]).unwrap_err();
        assert!(matches!(err, IpcError::EmptyBatchInput));
    }

    #[test]
    fn decode_batch_rejects_multi_batch_stream() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, true)]));
        let a: Arc<dyn Array> = Arc::new(Int64Array::from(vec![1_i64]));
        let b: Arc<dyn Array> = Arc::new(Int64Array::from(vec![2_i64]));
        let ba = RecordBatch::try_new(schema.clone(), vec![a]).unwrap();
        let bb = RecordBatch::try_new(schema, vec![b]).unwrap();
        let encoded = encode_batches(&[ba, bb]).unwrap();
        assert!(matches!(decode_batch(&encoded), Err(IpcError::Arrow(_))));
    }

    /// A non-nullable `FixedSizeBinary(8)` field tagged with the secret-handle
    /// extension name.
    fn secret_tagged_field(name: &str) -> Field {
        Field::new(name, DataType::FixedSizeBinary(8), false).with_metadata(HashMap::from([(
            ARROW_EXTENSION_KEY.to_owned(),
            SECRET_HANDLE_EXTENSION.to_owned(),
        )]))
    }

    /// Null-free `FixedSizeBinary(8)` values, one per input array.
    fn handle_values(values: &[[u8; 8]]) -> FixedSizeBinaryArray {
        FixedSizeBinaryArray::try_from_iter(values.iter()).unwrap()
    }

    /// A one-column batch whose `api_key_handle` column carries no tag.
    fn plain_handle_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "api_key_handle",
            DataType::FixedSizeBinary(8),
            false,
        )]));
        let col: Arc<dyn Array> = Arc::new(handle_values(&[[0u8; 8]]));
        RecordBatch::try_new(schema, vec![col]).unwrap()
    }

    /// A one-column batch whose `api_key_handle` column is secret-tagged.
    fn secret_handle_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![secret_tagged_field("api_key_handle")]));
        let col: Arc<dyn Array> = Arc::new(handle_values(&[[0u8; 8], [1; 8]]));
        RecordBatch::try_new(schema, vec![col]).unwrap()
    }

    /// Writes `batch` with a raw `StreamWriter`, bypassing `encode_batch`'s
    /// check. This simulates a peer that emits secret-tagged streams.
    fn write_unchecked_stream(batch: &RecordBatch) -> Vec<u8> {
        let mut buf: Vec<u8> = Vec::new();
        {
            let mut w = StreamWriter::try_new(&mut buf, batch.schema().as_ref()).unwrap();
            w.write(batch).unwrap();
            w.finish().unwrap();
        }
        buf
    }

    /// Asserts that `result` is a `SecretLeakAttempt` naming `column`. `op`
    /// names the call under test and is used in the panic message.
    fn assert_secret_leak<T>(result: Result<T, IpcError>, op: &str, column: &str) {
        match result {
            Ok(_) => panic!("{op} must reject secret-handle columns"),
            Err(IpcError::SecretLeakAttempt { column: got }) => assert_eq!(got, column),
            Err(other) => panic!("expected SecretLeakAttempt, got {other:?}"),
        }
    }

    #[test]
    fn encode_batch_rejects_secret_handle_column() {
        assert_secret_leak(
            encode_batch(&secret_handle_batch()),
            "encode_batch",
            "api_key_handle",
        );
    }

    #[test]
    fn encode_batches_rejects_secret_handle_in_later_batch() {
        // Every batch is checked, not only the first one.
        assert_secret_leak(
            encode_batches(&[plain_handle_batch(), secret_handle_batch()]),
            "encode_batches",
            "api_key_handle",
        );
    }

    #[test]
    fn decode_batches_rejects_secret_handle_column() {
        // Control: the same data without the tag round-trips fine.
        let plain = encode_batch(&plain_handle_batch()).unwrap();
        assert_eq!(decode_batches(&plain).unwrap().len(), 1);

        let tagged = write_unchecked_stream(&secret_handle_batch());
        assert_secret_leak(decode_batches(&tagged), "decode_batches", "api_key_handle");
    }

    #[test]
    fn decode_batch_rejects_secret_handle_column() {
        let tagged = write_unchecked_stream(&secret_handle_batch());
        assert_secret_leak(decode_batch(&tagged), "decode_batch", "api_key_handle");
    }

    #[test]
    fn encode_batch_rejects_secret_handle_inside_struct() {
        let fields = Fields::from(vec![
            Field::new("id", DataType::Int64, false),
            secret_tagged_field("handle"),
        ]);
        let id_arr: Arc<dyn Array> = Arc::new(Int64Array::from(vec![1, 2]));
        let handles: Arc<dyn Array> = Arc::new(handle_values(&[[0u8; 8], [1; 8]]));
        let s = StructArray::new(fields, vec![id_arr, handles], None);
        let batch = one_col_batch("rec", Arc::new(s));
        assert_secret_leak(encode_batch(&batch), "encode_batch", "handle");
    }

    #[test]
    fn encode_batch_rejects_secret_handle_inside_list() {
        let item = Arc::new(secret_tagged_field("handle"));
        let values: Arc<dyn Array> = Arc::new(handle_values(&[[0u8; 8], [1; 8], [2; 8]]));
        let offsets = OffsetBuffer::new(vec![0_i32, 1, 3].into());
        let list = ListArray::new(item, offsets, values, None);
        let batch = one_col_batch("handles", Arc::new(list));
        assert_secret_leak(encode_batch(&batch), "encode_batch", "handle");
    }
}