Skip to main content

vgi_rpc/
wire.rs

1//! Low-level IPC stream helpers that preserve per-batch custom metadata.
2//!
3//! The standard `arrow-ipc` `StreamWriter` / `StreamReader` types do not
4//! expose per-message `custom_metadata`, but the vgi_rpc wire protocol
5//! relies on that field to carry `vgi_rpc.method`,
6//! `vgi_rpc.request_version`, log keys, externalisation pointers,
7//! state tokens, etc. This module hand-rolls the framing layer so the
8//! crate can depend on **stock** arrow-rs from crates.io rather than a
9//! patched fork — the published vgi-rpc crate is therefore directly
10//! installable without any `[patch.crates-io]` directives downstream.
11//!
12//! Internally we delegate column encoding / decoding to
13//! [`arrow_ipc::writer::IpcDataGenerator`] and the
14//! [`arrow_ipc::reader::read_record_batch`] / `read_dictionary`
15//! functions, and only intercept the flatbuffer `Message` wrapper to
16//! attach / extract `custom_metadata`. That keeps the code small and
17//! the on-wire bytes byte-for-byte compatible with arrow-rs's
18//! `StreamWriter`.
19//!
20//! ## DoS guard
21//!
22//! [`StreamReader::new`] pre-validates the schema-message length prefix
23//! against [`MAX_IPC_SCHEMA_BYTES`] *before* allocating; a remote
24//! client cannot trigger a multi-gigabyte alloc by sending a crafted
25//! 4-byte payload. Each subsequent message body is also bounded by
26//! [`MAX_IPC_MESSAGE_BYTES`] before we allocate, mitigating the
27//! flatbuffer-`bodyLength` overshoot that the fuzz harness surfaced.
28
29use std::collections::HashMap;
30use std::io::{Read, Write};
31use std::sync::Arc;
32
33use arrow_array::RecordBatch;
34use arrow_buffer::Buffer as ArrowBuffer;
35use arrow_ipc::reader as ipc_reader;
36use arrow_ipc::writer::{write_message, DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
37use arrow_ipc::{convert as ipc_convert, root_as_message, MessageHeader};
38use arrow_schema::{Schema, SchemaRef};
39use flatbuffers::FlatBufferBuilder;
40
41use crate::errors::{Result, RpcError};
42
/// Key/value pairs attached to a single batch. Insertion order is not
/// retained once the pairs have been serialised, which mirrors how
/// Python's `RecordBatch.custom_metadata` behaves.
pub type Metadata = HashMap<String, String>;

/// Fetch the value stored under `key` in `md` as a borrowed `&str`,
/// or `None` when the key is absent.
#[inline]
pub fn md_get<'a>(md: &'a Metadata, key: &str) -> Option<&'a str> {
    let value = md.get(key)?;
    Some(value.as_str())
}
53
/// Maximum permitted size, in bytes, of the schema-message flatbuffer
/// at the head of an IPC stream. Schemas are typically tens to
/// hundreds of bytes; 16 MiB is generous headroom that still rejects
/// the crafted 4-byte input `[0x1A, 0x2C, 0xF5, 0x2C]` that
/// `fuzz/wire_stream_reader` discovered would OOM the process by
/// claiming a ~720 MB schema. Applies to the *schema* message length
/// prefix on the wire; because `StreamReader::new` passes it as the
/// cap to `read_message_bytes`, it also bounds the schema message's
/// declared `bodyLength`.
pub const MAX_IPC_SCHEMA_BYTES: usize = 16 * 1024 * 1024;

/// Maximum permitted size, in bytes, of a per-batch IPC message:
/// 256 MiB — large enough for any reasonable Arrow batch, small
/// enough to refuse the `bodyLength = 0x4000000100000` overshoot the
/// fuzz harness surfaced.
///
/// NOTE(review): `read_message_bytes` compares the header flatbuffer
/// length and the declared `bodyLength` against this cap *separately*,
/// so the two parts of one message may together reach up to twice
/// this value — confirm whether a combined limit is intended.
pub const MAX_IPC_MESSAGE_BYTES: usize = 256 * 1024 * 1024;
69
70// ---------------------------------------------------------------------------
71// Writer
72// ---------------------------------------------------------------------------
73
/// Four `0xFF` bytes that may precede the little-endian `u32` length
/// word of a framed message. The reader accepts frames with or without
/// it; the writer emits it followed by a zero length word as the
/// end-of-stream signal.
const CONTINUATION_MARKER: [u8; 4] = [0xFF, 0xFF, 0xFF, 0xFF];
75
/// A streaming IPC writer that supports per-batch custom metadata.
///
/// The byte sequence written for a complete stream is:
///   `SchemaMessage → [DictionaryMessage]* → [RecordBatchMessage]* → EOS`,
/// where EOS is the four-byte `0xFF` continuation marker followed by
/// four zero bytes.
///
/// Each call to [`write`](Self::write) emits one record-batch message
/// (preceded by any newly-needed dictionary messages) with its
/// `custom_metadata` attached at the IPC Message level.
pub struct StreamWriter<W: Write> {
    /// Destination that receives every framed message.
    writer: W,
    /// Copy of the schema passed to `new`; handed back by `schema()`.
    schema: SchemaRef,
    /// IPC write options (the defaults) reused for every message.
    opts: IpcWriteOptions,
    /// Encoder used for the schema message and for each batch.
    data_gen: IpcDataGenerator,
    /// Dictionary tracker shared by the schema encoding and all
    /// subsequent batch encodings.
    dict_tracker: DictionaryTracker,
    /// Set once the end-of-stream marker has been written; further
    /// `write` calls are rejected afterwards.
    finished: bool,
}
92
93impl<W: Write> StreamWriter<W> {
94    /// Create a new writer and emit the schema message.
95    pub fn new(mut writer: W, schema: &Schema) -> Result<Self> {
96        let opts = IpcWriteOptions::default();
97        let data_gen = IpcDataGenerator::default();
98        let mut dict_tracker = DictionaryTracker::new(false);
99        let encoded =
100            data_gen.schema_to_bytes_with_dictionary_tracker(schema, &mut dict_tracker, &opts);
101        write_message(&mut writer, encoded, &opts)?;
102        Ok(Self {
103            writer,
104            schema: Arc::new(schema.clone()),
105            opts,
106            data_gen,
107            dict_tracker,
108            finished: false,
109        })
110    }
111
112    /// Write one RecordBatch carrying optional `metadata` as the IPC
113    /// Message-level `custom_metadata` field. Pass `None` to omit the
114    /// field (saves a few bytes per message).
115    pub fn write(&mut self, batch: &RecordBatch, metadata: Option<&Metadata>) -> Result<()> {
116        if self.finished {
117            return Err(RpcError::new("IOError", "writer already finished"));
118        }
119        let mut ctx = Default::default();
120        let (dicts, data) = self
121            .data_gen
122            .encode(batch, &mut self.dict_tracker, &self.opts, &mut ctx)
123            .map_err(RpcError::from)?;
124        for d in dicts {
125            write_message(&mut self.writer, d, &self.opts).map_err(RpcError::from)?;
126        }
127        if let Some(md) = metadata.filter(|m| !m.is_empty()) {
128            let new_msg = repack_record_batch_message_with_metadata(&data.ipc_message, md)?;
129            let encoded = arrow_ipc::writer::EncodedData {
130                ipc_message: new_msg,
131                arrow_data: data.arrow_data,
132            };
133            write_message(&mut self.writer, encoded, &self.opts).map_err(RpcError::from)?;
134        } else {
135            write_message(&mut self.writer, data, &self.opts).map_err(RpcError::from)?;
136        }
137        Ok(())
138    }
139
140    /// Return the schema this writer was opened with.
141    pub fn schema(&self) -> SchemaRef {
142        self.schema.clone()
143    }
144
145    /// Write the EOS continuation marker. Idempotent.
146    pub fn finish(&mut self) -> Result<()> {
147        if self.finished {
148            return Ok(());
149        }
150        self.writer.write_all(&CONTINUATION_MARKER)?;
151        self.writer.write_all(&[0u8; 4])?;
152        self.writer.flush()?;
153        self.finished = true;
154        Ok(())
155    }
156
157    /// Flush the underlying writer.
158    pub fn flush(&mut self) -> Result<()> {
159        self.writer.flush()?;
160        Ok(())
161    }
162
163    pub fn get_mut(&mut self) -> &mut W {
164        &mut self.writer
165    }
166}
167
impl<W: Write> Drop for StreamWriter<W> {
    /// Best-effort `finish` when the writer goes out of scope, so a
    /// stream that was never explicitly finished still gets its
    /// end-of-stream marker.
    fn drop(&mut self) {
        // `drop` cannot return an error, so a failed write is discarded.
        let _ = self.finish();
    }
}
173
174/// Rebuild a Message flatbuffer with `custom_metadata` added,
175/// preserving the embedded RecordBatch header unchanged.
176fn repack_record_batch_message_with_metadata(
177    msg_bytes: &[u8],
178    metadata: &Metadata,
179) -> Result<Vec<u8>> {
180    use arrow_ipc::{
181        Buffer as FbBuffer, FieldNode, KeyValue, KeyValueArgs, MessageBuilder, RecordBatchBuilder,
182    };
183
184    let msg = root_as_message(msg_bytes)
185        .map_err(|e| RpcError::new("IPC", format!("parsing message: {e}")))?;
186    let version = msg.version();
187    let header_type = msg.header_type();
188    let body_length = msg.bodyLength();
189    if header_type != MessageHeader::RecordBatch {
190        return Err(RpcError::new(
191            "IPC",
192            format!("repack expected RecordBatch header, got {header_type:?}"),
193        ));
194    }
195    let rb = msg
196        .header_as_record_batch()
197        .ok_or_else(|| RpcError::new("IPC", "missing RecordBatch header"))?;
198
199    let mut fbb = FlatBufferBuilder::new();
200
201    let src_nodes = rb
202        .nodes()
203        .ok_or_else(|| RpcError::new("IPC", "RecordBatch missing nodes"))?;
204    let nodes: Vec<FieldNode> = src_nodes.iter().copied().collect();
205    let nodes_vec = fbb.create_vector(&nodes);
206
207    let src_buffers = rb
208        .buffers()
209        .ok_or_else(|| RpcError::new("IPC", "RecordBatch missing buffers"))?;
210    let buffers: Vec<FbBuffer> = src_buffers.iter().copied().collect();
211    let buffers_vec = fbb.create_vector(&buffers);
212
213    let variadic_vec = rb.variadicBufferCounts().map(|v| {
214        let counts: Vec<i64> = v.iter().collect();
215        fbb.create_vector(&counts)
216    });
217
218    let new_rb = {
219        let mut b = RecordBatchBuilder::new(&mut fbb);
220        b.add_length(rb.length());
221        b.add_nodes(nodes_vec);
222        b.add_buffers(buffers_vec);
223        if let Some(v) = variadic_vec {
224            b.add_variadicBufferCounts(v);
225        }
226        // Note: we don't carry compression here; the conformance worker
227        // does not enable IPC batch compression, so this is safe.
228        b.finish()
229    };
230
231    // Build custom_metadata vector. Order matches HashMap iteration —
232    // not stable, but that matches both upstream arrow-ipc and Python
233    // `RecordBatch.custom_metadata` semantics.
234    let kvs: Vec<_> = metadata
235        .iter()
236        .map(|(k, v)| {
237            let k_off = fbb.create_string(k);
238            let v_off = fbb.create_string(v);
239            KeyValue::create(
240                &mut fbb,
241                &KeyValueArgs {
242                    key: Some(k_off),
243                    value: Some(v_off),
244                },
245            )
246        })
247        .collect();
248    let md_vec = fbb.create_vector(&kvs);
249
250    let mut mb = MessageBuilder::new(&mut fbb);
251    mb.add_version(version);
252    mb.add_header_type(header_type);
253    mb.add_header(new_rb.as_union_value());
254    mb.add_bodyLength(body_length);
255    mb.add_custom_metadata(md_vec);
256    let m = mb.finish();
257    fbb.finish(m, None);
258    Ok(fbb.finished_data().to_vec())
259}
260
261// ---------------------------------------------------------------------------
262// Reader
263// ---------------------------------------------------------------------------
264
/// A streaming IPC reader that surfaces per-message custom metadata.
///
/// [`read_next`](Self::read_next) returns `Some((batch, metadata))`
/// for each RecordBatch message and `None` on end-of-stream.
/// Dictionary messages are consumed transparently; the schema message
/// is consumed by [`new`](Self::new).
pub struct StreamReader<R: Read> {
    /// Source the framed messages are read from.
    reader: R,
    /// Schema decoded from the stream's leading schema message.
    schema: SchemaRef,
    /// Dictionaries decoded so far, keyed by dictionary id, as filled
    /// in by `arrow_ipc::reader::read_dictionary`.
    dictionaries: HashMap<i64, arrow_array::ArrayRef>,
    /// Set once end-of-stream has been observed; later reads return
    /// `None` without touching `reader`.
    finished: bool,
    /// When `Some`, every read batch is decoded against this relaxed
    /// schema before being returned to the caller (used by the
    /// conformance worker to accept Python's nullable-flag-lying
    /// `ArrowSerializableDataclass` outputs).
    relaxed_schema: Option<SchemaRef>,
}
281
282impl<R: Read> StreamReader<R> {
283    /// Create a new reader and consume the schema message.
284    ///
285    /// The schema-message length prefix is validated against
286    /// [`MAX_IPC_SCHEMA_BYTES`] *before* allocating, so a remote
287    /// client cannot trigger a multi-gigabyte alloc by sending a
288    /// crafted short payload.
289    pub fn new(mut reader: R) -> Result<Self> {
290        let msg = read_message_bytes(&mut reader, MAX_IPC_SCHEMA_BYTES)?
291            .ok_or_else(|| RpcError::new("IPC", "empty IPC stream (no schema)"))?;
292        let msg_fb = root_as_message(&msg.message_bytes)
293            .map_err(|e| RpcError::new("IPC", format!("parse schema message: {e}")))?;
294        if msg_fb.header_type() != MessageHeader::Schema {
295            return Err(RpcError::new(
296                "IPC",
297                format!("expected Schema, got {:?}", msg_fb.header_type()),
298            ));
299        }
300        let ipc_schema = msg_fb
301            .header_as_schema()
302            .ok_or_else(|| RpcError::new("IPC", "bad schema header"))?;
303        let schema = ipc_convert::fb_to_schema(ipc_schema);
304        Ok(Self {
305            reader,
306            schema: Arc::new(schema),
307            dictionaries: HashMap::new(),
308            finished: false,
309            relaxed_schema: None,
310        })
311    }
312
313    /// Get the schema of the stream (relaxed schema, if relaxation was
314    /// requested).
315    pub fn schema(&self) -> SchemaRef {
316        self.relaxed_schema
317            .clone()
318            .unwrap_or_else(|| self.schema.clone())
319    }
320
321    /// Promote every field in the stream's schema to `nullable = true`,
322    /// recursively (lists, structs, fixed-size lists). Use when a
323    /// producer declares a field non-nullable but legitimately sends
324    /// nulls — e.g. Python's `ArrowSerializableDataclass` for
325    /// `Annotated[T | None, ArrowType(...)]`.
326    pub fn relax_nullability(mut self) -> Self {
327        self.relaxed_schema = Some(Arc::new(relax_schema_nullability(self.schema.as_ref())));
328        self
329    }
330
331    /// Read the next record batch, or `None` on end-of-stream.
332    /// Returns `(batch, metadata)` where `metadata` carries the IPC
333    /// Message-level `custom_metadata` (empty when the producer
334    /// omitted the field).
335    pub fn read_next(&mut self) -> Result<Option<(RecordBatch, Metadata)>> {
336        if self.finished {
337            return Ok(None);
338        }
339        loop {
340            let msg = match read_message_bytes(&mut self.reader, MAX_IPC_MESSAGE_BYTES)? {
341                Some(m) => m,
342                None => {
343                    self.finished = true;
344                    return Ok(None);
345                }
346            };
347            let msg_fb = root_as_message(&msg.message_bytes)
348                .map_err(|e| RpcError::new("IPC", format!("parse message: {e}")))?;
349            let version = msg_fb.version();
350            match msg_fb.header_type() {
351                MessageHeader::DictionaryBatch => {
352                    let dict = msg_fb
353                        .header_as_dictionary_batch()
354                        .ok_or_else(|| RpcError::new("IPC", "bad dictionary header"))?;
355                    let body_buf = ArrowBuffer::from_vec(msg.body);
356                    // Reject buffer descriptors that point outside the
357                    // body *before* handing them to arrow-ipc, which
358                    // would otherwise panic on an out-of-bounds slice.
359                    if let Some(data) = dict.data() {
360                        validate_record_batch_buffers(&data, body_buf.len())?;
361                    }
362                    // arrow-ipc's decoder still has internal invariants
363                    // we don't re-check; `catch_unwind` is the backstop
364                    // that turns any residual panic into a clean error.
365                    decode_guard("dictionary batch", || {
366                        ipc_reader::read_dictionary(
367                            &body_buf,
368                            dict,
369                            self.schema.as_ref(),
370                            &mut self.dictionaries,
371                            &version,
372                        )
373                    })?
374                    .map_err(RpcError::from)?;
375                }
376                MessageHeader::RecordBatch => {
377                    let rb_fb = msg_fb
378                        .header_as_record_batch()
379                        .ok_or_else(|| RpcError::new("IPC", "bad record batch header"))?;
380                    let body_buf = ArrowBuffer::from_vec(msg.body);
381                    validate_record_batch_buffers(&rb_fb, body_buf.len())?;
382                    // When relaxation is in effect, feed the relaxed
383                    // schema directly to `read_record_batch` so its
384                    // validation accepts the legitimate null buffers
385                    // a producer (e.g. Python
386                    // `ArrowSerializableDataclass`) emits for fields
387                    // it declared `nullable=false`.
388                    let decode_schema = self
389                        .relaxed_schema
390                        .clone()
391                        .unwrap_or_else(|| self.schema.clone());
392                    let batch = decode_guard("record batch", || {
393                        ipc_reader::read_record_batch(
394                            &body_buf,
395                            rb_fb,
396                            decode_schema,
397                            &self.dictionaries,
398                            None,
399                            &version,
400                        )
401                    })?
402                    .map_err(RpcError::from)?;
403                    let metadata = parse_custom_metadata(&msg_fb);
404                    return Ok(Some((batch, metadata)));
405                }
406                MessageHeader::Schema => {
407                    return Err(RpcError::new("IPC", "unexpected schema message mid-stream"));
408                }
409                MessageHeader::NONE => continue,
410                other => {
411                    return Err(RpcError::new(
412                        "IPC",
413                        format!("unsupported message type {other:?}"),
414                    ));
415                }
416            }
417        }
418    }
419
420    /// Drain and discard any remaining messages.
421    pub fn drain(&mut self) -> Result<()> {
422        while self.read_next()?.is_some() {}
423        Ok(())
424    }
425
426    pub fn get_mut(&mut self) -> &mut R {
427        &mut self.reader
428    }
429}
430
431fn parse_custom_metadata(msg: &arrow_ipc::Message) -> Metadata {
432    let mut out = Metadata::new();
433    if let Some(md) = msg.custom_metadata() {
434        for kv in md.iter() {
435            let k = kv.key().unwrap_or("").to_string();
436            let v = kv.value().unwrap_or("").to_string();
437            out.insert(k, v);
438        }
439    }
440    out
441}
442
443/// Validate that every `(offset, length)` buffer descriptor in an IPC
444/// record-batch header references a region wholly inside the message
445/// body. arrow-ipc's column decoders index into the body using these
446/// descriptors verbatim and will panic (slice out-of-bounds / arithmetic
447/// overflow) on a crafted frame whose descriptors are inconsistent with
448/// the body it shipped. Catching that here turns a hostile frame into a
449/// clean `RpcError` instead of a thread panic.
450fn validate_record_batch_buffers(rb: &arrow_ipc::RecordBatch, body_len: usize) -> Result<()> {
451    if let Some(buffers) = rb.buffers() {
452        for buf in buffers.iter() {
453            let offset = buf.offset();
454            let length = buf.length();
455            if offset < 0 || length < 0 {
456                return Err(RpcError::new("IPC", "negative IPC buffer descriptor"));
457            }
458            let end = (offset as u64)
459                .checked_add(length as u64)
460                .ok_or_else(|| RpcError::new("IPC", "IPC buffer descriptor overflows"))?;
461            if end > body_len as u64 {
462                return Err(RpcError::new(
463                    "IPC",
464                    "IPC buffer descriptor exceeds message body",
465                ));
466            }
467        }
468    }
469    Ok(())
470}
471
472/// Run an arrow-ipc decode call, converting any panic into a clean
473/// `RpcError`. The descriptor pre-validation above catches the common
474/// crafted-frame cases; this is the defence-in-depth net for any other
475/// internal arrow-ipc invariant a hostile frame might trip.
476fn decode_guard<T>(what: &str, f: impl FnOnce() -> T) -> Result<T> {
477    std::panic::catch_unwind(std::panic::AssertUnwindSafe(f))
478        .map_err(|_| RpcError::new("IPC", format!("panic decoding {what} (malformed frame)")))
479}
480
/// One framed IPC message as read off the wire, before any decoding.
struct RawMessage {
    /// Bytes of the `Message` flatbuffer (the header part of the frame).
    message_bytes: Vec<u8>,
    /// Body bytes that follow the header; empty when the header
    /// declares a `bodyLength` of zero.
    body: Vec<u8>,
}
485
486fn read_exact(r: &mut impl Read, buf: &mut [u8]) -> Result<bool> {
487    let mut read = 0;
488    while read < buf.len() {
489        match r.read(&mut buf[read..]) {
490            Ok(0) => {
491                if read == 0 {
492                    return Ok(false);
493                }
494                return Err(RpcError::new("IOError", "unexpected EOF in IPC message"));
495            }
496            Ok(n) => read += n,
497            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
498            Err(e) => return Err(e.into()),
499        }
500    }
501    Ok(true)
502}
503
504/// Read one IPC message off `r`, capping the header and body at
505/// `max_bytes` so a crafted length prefix or flatbuffer
506/// `bodyLength` cannot trigger an unbounded allocation.
507fn read_message_bytes(r: &mut impl Read, max_bytes: usize) -> Result<Option<RawMessage>> {
508    let mut prefix = [0u8; 4];
509    if !read_exact(r, &mut prefix)? {
510        return Ok(None);
511    }
512    let size_bytes = if prefix == CONTINUATION_MARKER {
513        let mut sb = [0u8; 4];
514        if !read_exact(r, &mut sb)? {
515            return Ok(None);
516        }
517        sb
518    } else {
519        prefix
520    };
521    let size = u32::from_le_bytes(size_bytes) as usize;
522    if size == 0 {
523        // EOS
524        return Ok(None);
525    }
526    if size > max_bytes {
527        return Err(RpcError::new(
528            "IPC",
529            format!(
530                "IPC message header length {size} bytes exceeds cap {max_bytes} — \
531                 refusing to allocate before parsing"
532            ),
533        ));
534    }
535    let mut message_bytes = vec![0u8; size];
536    if !read_exact(r, &mut message_bytes)? {
537        return Err(RpcError::new("IOError", "unexpected EOF in message body"));
538    }
539    // Parse just enough to learn the body length, then cap it the same
540    // way before allocating. This blocks the `bodyLength = 1 TB`
541    // attack vector even when the header itself is small.
542    let msg = root_as_message(&message_bytes)
543        .map_err(|e| RpcError::new("IPC", format!("parse message header: {e}")))?;
544    let body_length_signed = msg.bodyLength();
545    if body_length_signed < 0 {
546        return Err(RpcError::new(
547            "IPC",
548            format!("IPC message has negative bodyLength ({body_length_signed})"),
549        ));
550    }
551    let body_length = body_length_signed as usize;
552    if body_length > max_bytes {
553        return Err(RpcError::new(
554            "IPC",
555            format!(
556                "IPC message bodyLength {body_length} bytes exceeds cap {max_bytes} — \
557                 refusing to allocate before parsing"
558            ),
559        ));
560    }
561    let mut body = vec![0u8; body_length];
562    if body_length > 0 && !read_exact(r, &mut body)? {
563        return Err(RpcError::new("IOError", "unexpected EOF in message body"));
564    }
565    Ok(Some(RawMessage {
566        message_bytes,
567        body,
568    }))
569}
570
571// ---------------------------------------------------------------------------
572// Utilities
573// ---------------------------------------------------------------------------
574
575/// Serialize one record batch as a complete IPC stream
576/// (schema + batch + EOS), with optional custom metadata on the batch.
577pub fn write_one_batch(batch: &RecordBatch, metadata: Option<&Metadata>) -> Result<Vec<u8>> {
578    let schema = batch.schema();
579    let mut buf = Vec::new();
580    {
581        let mut w = StreamWriter::new(&mut buf, schema.as_ref())?;
582        w.write(batch, metadata)?;
583        w.finish()?;
584    }
585    Ok(buf)
586}
587
/// Render `bytes` as lowercase hexadecimal, two digits per byte.
/// Internal helper — application code should use the `hex` crate.
pub(crate) fn bytes_to_hex(bytes: &[u8]) -> String {
    const DIGITS: [char; 16] = [
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
    ];
    let mut hex = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        let high = usize::from(byte >> 4);
        let low = usize::from(byte & 0x0f);
        hex.push(DIGITS[high]);
        hex.push(DIGITS[low]);
    }
    hex
}
599
600fn relax_field_nullability(f: &arrow_schema::Field) -> arrow_schema::Field {
601    use arrow_schema::DataType;
602    let dt = match f.data_type() {
603        DataType::List(inner) => DataType::List(Arc::new(relax_field_nullability(inner))),
604        DataType::LargeList(inner) => DataType::LargeList(Arc::new(relax_field_nullability(inner))),
605        DataType::FixedSizeList(inner, n) => {
606            DataType::FixedSizeList(Arc::new(relax_field_nullability(inner)), *n)
607        }
608        DataType::Struct(fields) => DataType::Struct(
609            fields
610                .iter()
611                .map(|child| Arc::new(relax_field_nullability(child)))
612                .collect(),
613        ),
614        // Map: leave the entries struct alone (Arrow requires
615        // entries/keys to be non-nullable); leaf nullability inside
616        // the values child is preserved by the original schema.
617        other => other.clone(),
618    };
619    #[allow(deprecated)]
620    let new_field = if let DataType::Dictionary(_, _) = f.data_type() {
621        arrow_schema::Field::new_dict(
622            f.name(),
623            dt,
624            true,
625            f.dict_id().unwrap_or(0),
626            f.dict_is_ordered().unwrap_or(false),
627        )
628    } else {
629        arrow_schema::Field::new(f.name(), dt, true)
630    };
631    new_field.with_metadata(f.metadata().clone())
632}
633
634fn relax_schema_nullability(s: &Schema) -> Schema {
635    let new_fields: Vec<arrow_schema::Field> = s
636        .fields()
637        .iter()
638        .map(|f| relax_field_nullability(f))
639        .collect();
640    Schema::new_with_metadata(new_fields, s.metadata().clone())
641}
642
643/// Build a zero-row `RecordBatch` matching the given schema.
644pub fn empty_batch(schema: &Schema) -> Result<RecordBatch> {
645    use arrow_array::array::new_empty_array;
646    use arrow_array::RecordBatchOptions;
647    let cols: Vec<arrow_array::ArrayRef> = schema
648        .fields()
649        .iter()
650        .map(|f| new_empty_array(f.data_type()))
651        .collect();
652    RecordBatch::try_new_with_options(
653        Arc::new(schema.clone()),
654        cols,
655        &RecordBatchOptions::new().with_row_count(Some(0)),
656    )
657    .map_err(RpcError::from)
658}
659
// Round-trip and hostile-input tests for the writer/reader pair.
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{DataType, Field};

    /// A batch written with metadata comes back with the same row count
    /// and metadata, followed by end-of-stream.
    #[test]
    fn roundtrip_with_metadata() {
        let schema = Schema::new(vec![
            Field::new("idx", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]);
        let batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])) as _,
                Arc::new(StringArray::from(vec!["a", "b", "c"])) as _,
            ],
        )
        .unwrap();

        let mut buf: Vec<u8> = Vec::new();
        {
            let mut w = StreamWriter::new(&mut buf, &schema).unwrap();
            let mut md = Metadata::new();
            md.insert("vgi_rpc.method".into(), "echo_string".into());
            w.write(&batch, Some(&md)).unwrap();
            w.finish().unwrap();
        }

        let mut r = StreamReader::new(buf.as_slice()).unwrap();
        let (rb, md) = r.read_next().unwrap().expect("batch");
        assert_eq!(rb.num_rows(), 3);
        assert_eq!(md_get(&md, "vgi_rpc.method"), Some("echo_string"));
        assert!(r.read_next().unwrap().is_none());
    }

    /// A zero-row, zero-column batch can still carry metadata.
    #[test]
    fn zero_row_metadata_only() {
        let schema = Schema::empty();
        let batch = empty_batch(&schema).unwrap();

        let mut buf: Vec<u8> = Vec::new();
        {
            let mut w = StreamWriter::new(&mut buf, &schema).unwrap();
            let mut md = Metadata::new();
            md.insert("vgi_rpc.log_level".into(), "INFO".into());
            w.write(&batch, Some(&md)).unwrap();
            w.finish().unwrap();
        }
        let mut r = StreamReader::new(buf.as_slice()).unwrap();
        let (rb, md) = r.read_next().unwrap().expect("batch");
        assert_eq!(rb.num_rows(), 0);
        assert_eq!(md_get(&md, "vgi_rpc.log_level"), Some("INFO"));
    }

    /// An oversized schema length prefix is rejected by the cap check.
    #[test]
    fn rejects_oversize_schema_length_prefix() {
        // The 4-byte payload `[0x1A, 0x2C, 0xF5, 0x2C]` parsed LE
        // claims ~720 MB of schema-message body — must be refused
        // before any allocation.
        let bomb: &[u8] = &[0x1A, 0x2C, 0xF5, 0x2C];
        let err = StreamReader::new(bomb).err().expect("must reject");
        assert!(
            err.message.contains("exceeds cap"),
            "unexpected error: {err:?}"
        );
    }

    /// A record-batch header declaring a `bodyLength` one byte over the
    /// cap is rejected without reading a body.
    #[test]
    fn rejects_oversize_message_bodylength() {
        // Encode a tiny but well-formed schema then send a record-
        // batch message whose flatbuffer claims a multi-GB
        // `bodyLength` — must be refused before allocating the body.
        use arrow_ipc::{Buffer as FbBuffer, FieldNode, MessageBuilder, RecordBatchBuilder};
        // Build a real schema first so the schema gate passes.
        let schema = Schema::new(vec![Field::new("v", DataType::Int64, false)]);
        let mut buf: Vec<u8> = Vec::new();
        {
            let w = StreamWriter::new(&mut buf, &schema).unwrap();
            // Don't write any batches; we'll append a hand-crafted
            // malicious message below.
            // Drop without finish so EOS is not written.
            std::mem::forget(w);
        }
        // Hand-craft a RecordBatch Message flatbuffer with absurd
        // bodyLength.
        let mut fbb = FlatBufferBuilder::new();
        let nodes_vec = fbb.create_vector(&[FieldNode::new(0, 0)]);
        let buffers_vec = fbb.create_vector(&[FbBuffer::new(0, 0)]);
        let rb_off = {
            let mut b = RecordBatchBuilder::new(&mut fbb);
            b.add_length(0);
            b.add_nodes(nodes_vec);
            b.add_buffers(buffers_vec);
            b.finish()
        };
        let msg_off = {
            let mut mb = MessageBuilder::new(&mut fbb);
            mb.add_version(arrow_ipc::MetadataVersion::V5);
            mb.add_header_type(MessageHeader::RecordBatch);
            mb.add_header(rb_off.as_union_value());
            mb.add_bodyLength(MAX_IPC_MESSAGE_BYTES as i64 + 1);
            mb.finish()
        };
        fbb.finish(msg_off, None);
        let msg_bytes = fbb.finished_data();
        // Frame: continuation + 4-byte LE length + flatbuffer body.
        buf.extend_from_slice(&CONTINUATION_MARKER);
        buf.extend_from_slice(&(msg_bytes.len() as u32).to_le_bytes());
        buf.extend_from_slice(msg_bytes);
        // No body — but we never get that far; the cap rejects first.

        let mut r = StreamReader::new(buf.as_slice()).unwrap();
        let err = r.read_next().expect_err("must reject");
        assert!(
            err.message.contains("bodyLength") && err.message.contains("exceeds cap"),
            "unexpected error: {err:?}"
        );
    }

    /// A buffer descriptor reaching past the shipped body yields a
    /// "buffer descriptor" error instead of a panic.
    #[test]
    fn rejects_buffer_descriptor_past_body() {
        // A record-batch message whose body is 8 bytes but whose buffer
        // descriptor claims offset 0 / length 1000. arrow-ipc would
        // index out of bounds and panic; the descriptor pre-check must
        // reject it as a clean `RpcError` first.
        use arrow_ipc::{Buffer as FbBuffer, FieldNode, MessageBuilder, RecordBatchBuilder};
        let schema = Schema::new(vec![Field::new("v", DataType::Int64, false)]);
        let mut buf: Vec<u8> = Vec::new();
        {
            let w = StreamWriter::new(&mut buf, &schema).unwrap();
            // Forget the writer so its `Drop` does not append EOS.
            std::mem::forget(w);
        }
        let mut fbb = FlatBufferBuilder::new();
        let nodes_vec = fbb.create_vector(&[FieldNode::new(1, 0)]);
        // offset 0, length 1000 — far past the 8-byte body below.
        let buffers_vec = fbb.create_vector(&[FbBuffer::new(0, 1000)]);
        let rb_off = {
            let mut b = RecordBatchBuilder::new(&mut fbb);
            b.add_length(1);
            b.add_nodes(nodes_vec);
            b.add_buffers(buffers_vec);
            b.finish()
        };
        let msg_off = {
            let mut mb = MessageBuilder::new(&mut fbb);
            mb.add_version(arrow_ipc::MetadataVersion::V5);
            mb.add_header_type(MessageHeader::RecordBatch);
            mb.add_header(rb_off.as_union_value());
            mb.add_bodyLength(8);
            mb.finish()
        };
        fbb.finish(msg_off, None);
        let msg_bytes = fbb.finished_data().to_vec();
        buf.extend_from_slice(&CONTINUATION_MARKER);
        buf.extend_from_slice(&(msg_bytes.len() as u32).to_le_bytes());
        buf.extend_from_slice(&msg_bytes);
        buf.extend_from_slice(&[0u8; 8]); // the 8-byte body

        let mut r = StreamReader::new(buf.as_slice()).unwrap();
        let err = r.read_next().expect_err("must reject");
        assert!(
            err.message.contains("buffer descriptor"),
            "unexpected error: {err:?}"
        );
    }
}