Skip to main content

uni_plugin_wasm_rt/
ipc.rs

//! Arrow IPC bridge — `RecordBatch` ↔ wire-stream bytes.
//!
//! Both the Extism loader (bytes-in/bytes-out via `Plugin::call`) and
//! the Component Model loader (alloc/copy/free through linear memory)
//! cross the host↔plugin boundary by shipping Arrow IPC stream bytes.
//! Standardizing on the wire format means the executor's columnar
//! contract is identical regardless of which ABI delivered a batch.
//!
//! Host call pattern:
//!
//! 1. Serialize arguments / state via [`encode_batch`].
//! 2. Pass the byte slice through the loader-specific call boundary.
//! 3. Read the returned bytes; deserialize via [`decode_batch`] (or
//!    [`decode_batches`] for procedure `YIELD` streaming).
//!
//! Every encode and decode entry point also refuses any column tagged with
//! the [`SECRET_HANDLE_EXTENSION`] Arrow extension, in both directions.

16use arrow::array::RecordBatch;
17use arrow::ipc::reader::StreamReader;
18use arrow::ipc::writer::StreamWriter;
19use arrow_schema::SchemaRef;
20
21use crate::error::IpcError;
22
/// FU-2: Arrow extension name that tags a `secret-handle` column.
///
/// Any field whose `Field::metadata` maps `"ARROW:extension:name"` to this
/// value is refused at the IPC boundary in both directions — see
/// `reject_secret_handles`, which also inspects nested child fields. The
/// host's `SecretStore` returns `secret-handle` resources via the
/// `host.secrets.acquire` WIT import; this membrane ensures those opaque
/// handles cannot be exfiltrated as raw bytes inside a plugin's output
/// `RecordBatch`.
pub const SECRET_HANDLE_EXTENSION: &str = "uni-db.secret-handle";
33
/// Arrow field-metadata key under which an extension-type name is stored.
/// `reject_secret_handles` compares its value against
/// [`SECRET_HANDLE_EXTENSION`].
const ARROW_EXTENSION_KEY: &str = "ARROW:extension:name";
36
37/// Walk every field of `batch.schema()` and return
38/// [`IpcError::SecretLeakAttempt`] if any field carries the
39/// `uni-db.secret-handle` extension marker.
40///
41/// Called on every encode and decode path ([`encode_batch`],
42/// [`encode_batches`], [`decode_batch`], [`decode_batches`]) via
43/// [`reject_all`] so neither direction can carry a secret-handle column across
44/// the wasm boundary. Nested children (struct fields, list items) are walked
45/// recursively.
46fn reject_secret_handles(batch: &RecordBatch) -> Result<(), IpcError> {
47    fn walk(field: &arrow_schema::Field) -> Result<(), IpcError> {
48        use arrow_schema::DataType;
49        if field
50            .metadata()
51            .get(ARROW_EXTENSION_KEY)
52            .map(String::as_str)
53            == Some(SECRET_HANDLE_EXTENSION)
54        {
55            return Err(IpcError::SecretLeakAttempt {
56                column: field.name().clone(),
57            });
58        }
59        match field.data_type() {
60            DataType::Struct(fields) => fields.iter().try_for_each(|f| walk(f.as_ref())),
61            DataType::List(item) | DataType::LargeList(item) | DataType::FixedSizeList(item, _) => {
62                walk(item.as_ref())
63            }
64            DataType::Map(field, _) => walk(field.as_ref()),
65            _ => Ok(()),
66        }
67    }
68    batch
69        .schema()
70        .fields()
71        .iter()
72        .try_for_each(|f| walk(f.as_ref()))
73}
74
75/// Run [`reject_secret_handles`] over every batch — the FU-2 membrane shared by
76/// all encode/decode paths so a secret-handle column is rejected regardless of
77/// single- vs multi-batch shape.
78fn reject_all(batches: &[RecordBatch]) -> Result<(), IpcError> {
79    batches.iter().try_for_each(reject_secret_handles)
80}
81
82/// Encode a `RecordBatch` as Arrow IPC stream bytes.
83///
84/// Output: schema header + one record batch + end-of-stream marker —
85/// suitable for one-shot transmission across a wasm boundary.
86///
87/// # Errors
88///
89/// Returns [`IpcError::Arrow`] if the writer cannot serialize the
90/// batch (e.g., schema-incompatible types).
91pub fn encode_batch(batch: &RecordBatch) -> Result<Vec<u8>, IpcError> {
92    reject_secret_handles(batch)?;
93    let mut buf: Vec<u8> = Vec::with_capacity(estimate_size(batch));
94    write_stream(&mut buf, batch.schema(), std::slice::from_ref(batch))?;
95    Ok(buf)
96}
97
98/// Encode multiple `RecordBatch`es sharing a schema as one IPC stream.
99///
100/// Useful for procedure plugins that ship a series of yielded rows in
101/// one call. All batches must use the same schema (Arrow IPC stream
102/// invariant).
103///
104/// # Errors
105///
106/// - [`IpcError::EmptyBatchInput`] if `batches` is empty.
107/// - [`IpcError::Arrow`] if the writer rejects the batches.
108pub fn encode_batches(batches: &[RecordBatch]) -> Result<Vec<u8>, IpcError> {
109    let first = batches.first().ok_or(IpcError::EmptyBatchInput)?;
110    reject_all(batches)?;
111    let mut buf: Vec<u8> = Vec::with_capacity(estimate_size(first).saturating_mul(batches.len()));
112    write_stream(&mut buf, first.schema(), batches)?;
113    Ok(buf)
114}
115
116/// Write `batches` (assumed to share `schema`) to `buf` as one IPC stream.
117fn write_stream(
118    buf: &mut Vec<u8>,
119    schema: SchemaRef,
120    batches: &[RecordBatch],
121) -> Result<(), IpcError> {
122    let mut w = StreamWriter::try_new(buf, schema.as_ref())
123        .map_err(|e| IpcError::Arrow(format!("writer setup: {e}")))?;
124    for b in batches {
125        w.write(b)
126            .map_err(|e| IpcError::Arrow(format!("write batch: {e}")))?;
127    }
128    w.finish()
129        .map_err(|e| IpcError::Arrow(format!("finish: {e}")))?;
130    Ok(())
131}
132
133/// Decode the single `RecordBatch` from Arrow IPC stream bytes.
134///
135/// `encode_batch` writes exactly one batch, so any well-formed stream
136/// from this codec carries one batch (or zero, when the plugin produced
137/// no rows). Multiple batches indicate a malformed or malicious sender
138/// and are rejected.
139///
140/// Returns `None` if the stream contained only an end-of-stream marker.
141///
142/// # Errors
143///
144/// Returns [`IpcError::Arrow`] if the bytes are malformed or if the
145/// stream contains more than one batch. The previous form used
146/// `Vec::pop()` and silently returned the *last* batch when more than
147/// one was present, contradicting the "first batch" contract its
148/// documentation promised.
149pub fn decode_batch(bytes: &[u8]) -> Result<Option<RecordBatch>, IpcError> {
150    let batches = read_stream(bytes, "read batch")?;
151    // FU-2: a single-batch stream is still an inbound boundary — reject any
152    // secret-handle column, symmetric with `decode_batches` / `encode_batch`.
153    // (decode_batch is the hot path used by every scalar/aggregate adapter.)
154    reject_all(&batches)?;
155    match batches.len() {
156        0 => Ok(None),
157        1 => Ok(batches.into_iter().next()),
158        n => Err(IpcError::Arrow(format!(
159            "decode_batch expects a single-batch stream, got {n} batches"
160        ))),
161    }
162}
163
164/// Decode every `RecordBatch` from Arrow IPC stream bytes.
165///
166/// # Errors
167///
168/// Returns [`IpcError::Arrow`] if the bytes are malformed.
169pub fn decode_batches(bytes: &[u8]) -> Result<Vec<RecordBatch>, IpcError> {
170    let batches = read_stream(bytes, "read batches")?;
171    // FU-2: reject any incoming batch that carries a secret-handle column.
172    // Symmetric with the encode path so a malicious plugin can't smuggle a
173    // handle back across the boundary either.
174    reject_all(&batches)?;
175    Ok(batches)
176}
177
178/// Build a `StreamReader` over `bytes` and collect all batches.
179/// `read_label` is used only for error messages so each caller's
180/// failure context (`"read batch"` vs `"read batches"`) is preserved.
181fn read_stream(bytes: &[u8], read_label: &str) -> Result<Vec<RecordBatch>, IpcError> {
182    let reader = StreamReader::try_new(bytes, None)
183        .map_err(|e| IpcError::Arrow(format!("reader setup: {e}")))?;
184    reader
185        .collect::<Result<Vec<_>, _>>()
186        .map_err(|e| IpcError::Arrow(format!("{read_label}: {e}")))
187}
188
189fn estimate_size(batch: &RecordBatch) -> usize {
190    // ~16 bytes/cell + 4 KiB schema overhead. Writer grows on demand.
191    let rows = batch.num_rows();
192    let cols = batch.num_columns();
193    rows.saturating_mul(cols).saturating_mul(16) + 4096
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::Arc;

    use arrow::array::{
        Array, ArrayRef, BooleanArray, FixedSizeBinaryArray, Float32Array, Float64Array,
        Int32Array, Int64Array, LargeBinaryArray, ListArray, StringArray, StructArray,
        TimestampMillisecondArray,
    };
    use arrow::buffer::OffsetBuffer;
    use arrow_schema::{DataType, Field, Fields, Schema, TimeUnit};

    // ── helpers ────────────────────────────────────────────────────

    /// Schema with a single nullable field `name` of type `dt`.
    fn schema_for(name: &str, dt: DataType) -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new(name, dt, true)]))
    }

    /// One-column batch whose schema is derived from `col`'s data type.
    fn one_col_batch(name: &str, col: ArrayRef) -> RecordBatch {
        let schema = schema_for(name, col.data_type().clone());
        RecordBatch::try_new(schema, vec![col]).unwrap()
    }

    /// `encode_batch` then `decode_batch`, expecting exactly one batch back.
    fn round_trip(batch: &RecordBatch) -> RecordBatch {
        let encoded = encode_batch(batch).unwrap();
        decode_batch(&encoded)
            .unwrap()
            .expect("one batch encoded, one batch decoded")
    }

    /// `FixedSizeBinary(8)` array with one row of eight `fill` bytes per entry.
    fn handle_array(fills: &[u8]) -> FixedSizeBinaryArray {
        FixedSizeBinaryArray::try_from_iter(fills.iter().map(|&b| [b; 8])).unwrap()
    }

    /// Serialize `batch` with a raw `StreamWriter`, bypassing the FU-2
    /// membrane in `encode_batch`, to simulate a hostile plugin's output.
    fn raw_stream(batch: &RecordBatch) -> Vec<u8> {
        let mut buf: Vec<u8> = Vec::new();
        {
            let mut w = StreamWriter::try_new(&mut buf, batch.schema().as_ref()).unwrap();
            w.write(batch).unwrap();
            w.finish().unwrap();
        }
        buf
    }

    /// Assert that `result` is a `SecretLeakAttempt` naming `expected_column`.
    #[track_caller]
    fn assert_secret_leak<T: std::fmt::Debug>(result: Result<T, IpcError>, expected_column: &str) {
        match result {
            Err(IpcError::SecretLeakAttempt { column }) => assert_eq!(column, expected_column),
            other => panic!(
                "expected SecretLeakAttempt {{ column: {expected_column:?} }}, got {other:?}"
            ),
        }
    }

    // ── round trips ────────────────────────────────────────────────

    #[test]
    fn round_trip_int64() {
        let batch = one_col_batch("x", Arc::new(Int64Array::from(vec![1, 2, 3])));
        assert_eq!(round_trip(&batch).num_rows(), 3);
    }

    #[test]
    fn round_trip_int32_float32_float64() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i32", DataType::Int32, true),
            Field::new("f32", DataType::Float32, true),
            Field::new("f64", DataType::Float64, true),
        ]));
        let columns: Vec<ArrayRef> = vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(Float32Array::from(vec![1.5_f32, 2.5])),
            Arc::new(Float64Array::from(vec![10.5_f64, 20.5])),
        ];
        let decoded = round_trip(&RecordBatch::try_new(schema, columns).unwrap());
        assert_eq!(decoded.num_rows(), 2);
        let f64_out = decoded
            .column(2)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!((f64_out.value(1) - 20.5).abs() < f64::EPSILON);
    }

    #[test]
    fn round_trip_utf8_strings_including_unicode() {
        let batch = one_col_batch(
            "s",
            Arc::new(StringArray::from(vec!["hello", "naïve", "🌳", ""])),
        );
        let decoded = round_trip(&batch);
        let col = decoded
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(col.value(2), "🌳");
        assert_eq!(col.value(3), "");
    }

    #[test]
    fn round_trip_booleans_with_nulls() {
        let batch = one_col_batch(
            "b",
            Arc::new(BooleanArray::from(vec![Some(true), None, Some(false)])),
        );
        let decoded = round_trip(&batch);
        let col = decoded
            .column(0)
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap();
        assert!(col.is_null(1));
        assert!(col.value(0));
        assert!(!col.value(2));
    }

    #[test]
    fn round_trip_timestamp_ms() {
        let arr: ArrayRef = Arc::new(
            TimestampMillisecondArray::from(vec![1_700_000_000_000_i64, 1_800_000_000_000])
                .with_timezone_opt::<&str>(None),
        );
        let decoded = round_trip(&one_col_batch("ts", arr));
        assert!(matches!(
            decoded.schema().field(0).data_type(),
            DataType::Timestamp(TimeUnit::Millisecond, _)
        ));
    }

    #[test]
    fn round_trip_large_binary_for_cypher_values() {
        let arr: ArrayRef = Arc::new(LargeBinaryArray::from(vec![
            &[1_u8, 2, 3][..],
            &[4, 5, 6, 7],
        ]));
        let decoded = round_trip(&one_col_batch("v", arr));
        let col = decoded
            .column(0)
            .as_any()
            .downcast_ref::<LargeBinaryArray>()
            .unwrap();
        assert_eq!(col.value(0), &[1, 2, 3]);
        assert_eq!(col.value(1), &[4, 5, 6, 7]);
    }

    #[test]
    fn round_trip_list_of_int64() {
        let values: ArrayRef = Arc::new(Int64Array::from(vec![1_i64, 2, 3, 4, 5, 6]));
        let offsets = OffsetBuffer::new(vec![0_i32, 2, 5, 6].into());
        let field = Arc::new(Field::new("item", DataType::Int64, true));
        let list = ListArray::new(field, offsets, values, None);
        let decoded = round_trip(&one_col_batch("xs", Arc::new(list)));
        let col = decoded
            .column(0)
            .as_any()
            .downcast_ref::<ListArray>()
            .unwrap();
        assert_eq!(col.len(), 3);
        assert_eq!(col.value_length(1), 3);
    }

    #[test]
    fn round_trip_struct_array() {
        let fields = Fields::from(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("label", DataType::Utf8, false),
        ]);
        let children: Vec<ArrayRef> = vec![
            Arc::new(Int64Array::from(vec![10, 20])),
            Arc::new(StringArray::from(vec!["a", "b"])),
        ];
        let s = StructArray::new(fields, children, None);
        let decoded = round_trip(&one_col_batch("rec", Arc::new(s)));
        assert_eq!(decoded.num_rows(), 2);
        assert!(matches!(
            decoded.schema().field(0).data_type(),
            DataType::Struct(_)
        ));
    }

    /// The membrane keys on extension metadata, not on the physical type:
    /// an untagged `FixedSizeBinary(8)` column must pass through.
    #[test]
    fn untagged_fixed_size_binary_round_trips() {
        let batch = one_col_batch("api_key_handle", Arc::new(handle_array(&[0, 1])));
        assert_eq!(round_trip(&batch).num_rows(), 2);
    }

    // ── stream shape ───────────────────────────────────────────────

    #[test]
    fn decode_empty_stream_returns_none() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
        let mut buf: Vec<u8> = Vec::new();
        {
            let mut w = StreamWriter::try_new(&mut buf, schema.as_ref()).unwrap();
            w.finish().unwrap();
        }
        assert!(decode_batch(&buf).unwrap().is_none());
    }

    #[test]
    fn decode_garbage_bytes_is_arrow_ipc_error() {
        let err = decode_batch(b"not arrow ipc").unwrap_err();
        assert!(matches!(err, IpcError::Arrow(_)));
    }

    #[test]
    fn encode_batches_emits_multiple_in_one_stream() {
        let schema = schema_for("x", DataType::Int64);
        let a: ArrayRef = Arc::new(Int64Array::from(vec![1_i64, 2]));
        let b: ArrayRef = Arc::new(Int64Array::from(vec![3_i64, 4, 5]));
        let ba = RecordBatch::try_new(schema.clone(), vec![a]).unwrap();
        let bb = RecordBatch::try_new(schema, vec![b]).unwrap();
        let encoded = encode_batches(&[ba, bb]).unwrap();
        let all = decode_batches(&encoded).unwrap();
        assert_eq!(all.len(), 2);
        assert_eq!(all[0].num_rows(), 2);
        assert_eq!(all[1].num_rows(), 3);
    }

    /// `decode_batch` must refuse a multi-batch stream instead of silently
    /// picking one of the batches.
    #[test]
    fn decode_batch_rejects_multi_batch_stream() {
        let schema = schema_for("x", DataType::Int64);
        let a: ArrayRef = Arc::new(Int64Array::from(vec![1_i64]));
        let b: ArrayRef = Arc::new(Int64Array::from(vec![2_i64]));
        let ba = RecordBatch::try_new(schema.clone(), vec![a]).unwrap();
        let bb = RecordBatch::try_new(schema, vec![b]).unwrap();
        let encoded = encode_batches(&[ba, bb]).unwrap();
        let err = decode_batch(&encoded).unwrap_err();
        assert!(matches!(err, IpcError::Arrow(_)), "got {err:?}");
    }

    #[test]
    fn encode_batches_rejects_empty_input() {
        let err = encode_batches(&[]).unwrap_err();
        assert!(matches!(err, IpcError::EmptyBatchInput));
    }

    // ── FU-2: secret-handle leak rejection ─────────────────────────

    /// `FixedSizeBinary(8)` field tagged as a secret handle. The metadata key
    /// is spelled out literally so the test pins the wire-level key string
    /// independently of the production constant.
    fn secret_tagged_field(name: &str) -> Field {
        Field::new(name, DataType::FixedSizeBinary(8), false).with_metadata(HashMap::from([(
            "ARROW:extension:name".to_owned(),
            SECRET_HANDLE_EXTENSION.to_owned(),
        )]))
    }

    /// One-column batch whose only field is a tagged secret handle.
    fn secret_batch(name: &str, fills: &[u8]) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![secret_tagged_field(name)]));
        let col: ArrayRef = Arc::new(handle_array(fills));
        RecordBatch::try_new(schema, vec![col]).unwrap()
    }

    /// FU-2 acceptance: `encode_batch` refuses any column tagged with
    /// the `uni-db.secret-handle` Arrow extension and returns
    /// `IpcError::SecretLeakAttempt` naming the offending column.
    #[test]
    fn encode_batch_rejects_secret_handle_column() {
        let batch = secret_batch("api_key_handle", &[0, 1]);
        assert_secret_leak(encode_batch(&batch), "api_key_handle");
    }

    /// Every batch handed to `encode_batches` is checked, not just the first.
    #[test]
    fn encode_batches_rejects_secret_handle_in_later_batch() {
        let plain_schema = Arc::new(Schema::new(vec![Field::new(
            "api_key_handle",
            DataType::FixedSizeBinary(8),
            false,
        )]));
        let plain_col: ArrayRef = Arc::new(handle_array(&[0]));
        let plain = RecordBatch::try_new(plain_schema, vec![plain_col]).unwrap();
        let tagged = secret_batch("api_key_handle", &[1]);
        assert_secret_leak(encode_batches(&[plain, tagged]), "api_key_handle");
    }

    /// FU-2 acceptance: `decode_batches` symmetrically rejects an
    /// incoming stream that smuggles a secret-handle column back
    /// across the boundary.
    #[test]
    fn decode_batches_rejects_secret_handle_column() {
        // A hostile plugin tags its output column; `raw_stream` skips the
        // encode-side check that would otherwise refuse to write it.
        let buf = raw_stream(&secret_batch("api_key_handle", &[0]));
        assert_secret_leak(decode_batches(&buf), "api_key_handle");
    }

    /// FU-2 regression: the single-batch `decode_batch` path (the hot path for
    /// every scalar/aggregate adapter) must reject a smuggled secret-handle
    /// column too — not just the multi-batch `decode_batches`.
    #[test]
    fn decode_batch_rejects_secret_handle_column() {
        let buf = raw_stream(&secret_batch("api_key_handle", &[0]));
        assert_secret_leak(decode_batch(&buf), "api_key_handle");
    }

    /// FU-2 acceptance: nested struct fields are walked, so a plugin can't
    /// bury a secret-handle inside a struct column.
    #[test]
    fn encode_batch_rejects_secret_handle_inside_struct() {
        let fields = Fields::from(vec![
            Field::new("id", DataType::Int64, false),
            secret_tagged_field("handle"),
        ]);
        let children: Vec<ArrayRef> = vec![
            Arc::new(Int64Array::from(vec![1, 2])),
            Arc::new(handle_array(&[0, 1])),
        ];
        let s = StructArray::new(fields, children, None);
        let batch = one_col_batch("rec", Arc::new(s));
        assert_secret_leak(encode_batch(&batch), "handle");
    }

    /// FU-2 acceptance: list item fields are walked, so a plugin can't
    /// bury a secret-handle as the item type of a list column.
    #[test]
    fn encode_batch_rejects_secret_handle_inside_list() {
        let item = Arc::new(secret_tagged_field("item"));
        let values: ArrayRef = Arc::new(handle_array(&[0, 1]));
        let offsets = OffsetBuffer::new(vec![0_i32, 1, 2].into());
        let list = ListArray::new(item, offsets, values, None);
        let batch = one_col_batch("handles", Arc::new(list));
        assert_secret_leak(encode_batch(&batch), "item");
    }
}