typed_arrow_dyn/
validate.rs

1//! Validate nullability invariants in nested Arrow arrays using the schema.
2
3use std::{collections::HashMap, sync::Arc};
4
5use arrow_array::{
6    Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray, MapArray, StructArray,
7    UnionArray,
8};
9use arrow_buffer::{ArrowNativeType, OffsetBuffer};
10use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, UnionFields};
11
12use crate::{DynError, dyn_builder::array_key};
13
14/// Extract start and end offsets for a row from an offset buffer.
15fn offset_range<T: ArrowNativeType>(
16    offsets: &OffsetBuffer<T>,
17    row: usize,
18    col_name: &str,
19) -> Result<(usize, usize), DynError>
20where
21    usize: TryFrom<T>,
22{
23    let start_raw = offsets.get(row).ok_or_else(|| DynError::Builder {
24        message: format!("offset index {row} out of range for {col_name}"),
25    })?;
26    let end_raw = offsets.get(row + 1).ok_or_else(|| DynError::Builder {
27        message: format!("offset index {} out of range for {col_name}", row + 1),
28    })?;
29    let start = usize::try_from(*start_raw).map_err(|_| DynError::Builder {
30        message: format!("negative offset at index {row} for {col_name}"),
31    })?;
32    let end = usize::try_from(*end_raw).map_err(|_| DynError::Builder {
33        message: format!("negative offset at index {} for {col_name}", row + 1),
34    })?;
35    Ok((start, end))
36}
37
38/// Validate that arrays satisfy nullability constraints declared by `schema`.
39/// Returns the first violation encountered with a descriptive path.
40///
41/// # Errors
42/// Returns a `DynError::Nullability` describing the first violation encountered.
43pub fn validate_nullability(
44    schema: &Schema,
45    arrays: &[ArrayRef],
46    union_null_rows: &HashMap<usize, Vec<usize>>,
47) -> Result<(), DynError> {
48    for (col, (field, array)) in schema.fields().iter().zip(arrays.iter()).enumerate() {
49        // Top-level field nullability
50        if !field.is_nullable()
51            && array.null_count() > 0
52            && let Some(idx) = first_null_index(array.as_ref())
53        {
54            return Err(DynError::Nullability {
55                col,
56                path: field.name().to_string(),
57                index: idx,
58                message: "non-nullable field contains null".to_string(),
59            });
60        }
61
62        // Nested
63        validate_nested(
64            field.name(),
65            field.data_type(),
66            array,
67            col,
68            None,
69            field.is_nullable(),
70            union_null_rows,
71        )?;
72    }
73    Ok(())
74}
75
76fn validate_nested(
77    col_name: &str,
78    dt: &DataType,
79    array: &ArrayRef,
80    col: usize,
81    // An optional mask: when present, only indices with `true` are considered.
82    parent_valid_mask: Option<Vec<bool>>,
83    nullable: bool,
84    union_null_rows: &HashMap<usize, Vec<usize>>,
85) -> Result<(), DynError> {
86    match dt {
87        DataType::Struct(children) => validate_struct(
88            col_name,
89            children,
90            array,
91            col,
92            parent_valid_mask,
93            union_null_rows,
94        ),
95        DataType::List(item) => validate_list(
96            col_name,
97            item,
98            array,
99            col,
100            parent_valid_mask,
101            union_null_rows,
102        ),
103        DataType::LargeList(item) => validate_large_list(
104            col_name,
105            item,
106            array,
107            col,
108            parent_valid_mask,
109            union_null_rows,
110        ),
111        DataType::FixedSizeList(item, _len) => validate_fixed_list(
112            col_name,
113            item,
114            array,
115            col,
116            parent_valid_mask,
117            union_null_rows,
118        ),
119        DataType::Union(children, _) => validate_union(
120            col_name,
121            children,
122            array,
123            col,
124            parent_valid_mask,
125            nullable,
126            union_null_rows,
127        ),
128        DataType::Map(entry_field, _) => validate_map(
129            col_name,
130            entry_field,
131            array,
132            col,
133            parent_valid_mask,
134            union_null_rows,
135        ),
136        // Other data types have no nested children.
137        _ => Ok(()),
138    }
139}
140
141fn validate_union(
142    col_name: &str,
143    fields: &UnionFields,
144    array: &ArrayRef,
145    col: usize,
146    parent_mask: Option<Vec<bool>>,
147    nullable: bool,
148    union_null_rows: &HashMap<usize, Vec<usize>>,
149) -> Result<(), DynError> {
150    let union = array
151        .as_any()
152        .downcast_ref::<UnionArray>()
153        .ok_or_else(|| DynError::Builder {
154            message: format!("expected UnionArray for {col_name}"),
155        })?;
156
157    let parent_valid = parent_mask.unwrap_or_else(|| validity_mask(union));
158    let null_rows = union_null_rows
159        .get(&array_key(array))
160        .cloned()
161        .unwrap_or_default();
162    let null_row_mask = if null_rows.is_empty() {
163        None
164    } else {
165        let mut mask = vec![false; union.len()];
166        for &row in &null_rows {
167            if row >= mask.len() {
168                return Err(DynError::Builder {
169                    message: format!("union null row index {row} out of bounds"),
170                });
171            }
172            mask[row] = true;
173        }
174        Some(mask)
175    };
176    let is_union_null_row = |row: usize| {
177        null_row_mask
178            .as_ref()
179            .and_then(|mask| mask.get(row))
180            .copied()
181            .unwrap_or(false)
182    };
183
184    if !nullable
185        && let Some(&row) = null_rows
186            .iter()
187            .find(|&&row| parent_valid.get(row).copied().unwrap_or(false))
188    {
189        return Err(DynError::Nullability {
190            col,
191            path: col_name.to_string(),
192            index: row,
193            message: "non-nullable field contains null".to_string(),
194        });
195    }
196
197    let variants: Vec<(i8, FieldRef)> = fields
198        .iter()
199        .map(|(tag, field)| (tag, field.clone()))
200        .collect();
201
202    let mut tag_to_index = vec![None; 256];
203    for (idx, (tag, _)) in variants.iter().enumerate() {
204        tag_to_index[tag_slot(*tag)] = Some(idx);
205    }
206
207    let mut rows_per_variant: Vec<Vec<(usize, usize)>> =
208        variants.iter().map(|_| Vec::new()).collect();
209
210    for (row, &is_valid) in parent_valid.iter().enumerate() {
211        if !is_valid {
212            continue;
213        }
214        let tag = union.type_id(row);
215        let Some(idx) = tag_to_index[tag_slot(tag)] else {
216            return Err(DynError::Builder {
217                message: format!("union value uses unknown type id {tag}"),
218            });
219        };
220        let offset = union.value_offset(row);
221        rows_per_variant[idx].push((row, offset));
222    }
223
224    for (idx, rows) in rows_per_variant.iter().enumerate() {
225        if rows.is_empty() {
226            continue;
227        }
228        let (tag, field) = &variants[idx];
229        let child = union.child(*tag).clone();
230        let path = format!("{}.{}", col_name, field.name());
231        let child_len = child.len();
232        let mut child_mask = vec![false; child_len];
233
234        for &(row_index, child_index) in rows {
235            if child_index >= child_len {
236                return Err(DynError::Builder {
237                    message: format!(
238                        "union child index {} out of bounds for variant '{}'",
239                        child_index,
240                        field.name()
241                    ),
242                });
243            }
244
245            let union_row_is_null = is_union_null_row(row_index);
246
247            if !field.is_nullable() && !union_row_is_null && child.is_null(child_index) {
248                return Err(DynError::Nullability {
249                    col,
250                    path: path.clone(),
251                    index: row_index,
252                    message: "non-nullable union variant contains null".to_string(),
253                });
254            }
255
256            if !union_row_is_null {
257                child_mask[child_index] = true;
258            }
259        }
260
261        validate_nested(
262            &path,
263            field.data_type(),
264            &child,
265            col,
266            Some(child_mask),
267            field.is_nullable(),
268            union_null_rows,
269        )?;
270    }
271
272    Ok(())
273}
274
275fn validate_struct(
276    col_name: &str,
277    fields: &Fields,
278    array: &ArrayRef,
279    col: usize,
280    parent_mask: Option<Vec<bool>>,
281    union_null_rows: &HashMap<usize, Vec<usize>>,
282) -> Result<(), DynError> {
283    let s = array
284        .as_any()
285        .downcast_ref::<StructArray>()
286        .ok_or_else(|| DynError::Builder {
287            message: format!("expected StructArray for {col_name}"),
288        })?;
289
290    // Compute mask of valid parent rows: respect parent validity if provided, else
291    // derive from the struct's own validity.
292    let arr: &dyn Array = s;
293    let mask = parent_mask.unwrap_or_else(|| validity_mask(arr));
294
295    for (child_field, child_array) in fields.iter().zip(s.columns().iter()) {
296        // Enforce child field nullability only where parent struct is valid.
297        if !child_field.is_nullable() {
298            let child = child_array.as_ref();
299            for (i, &pvalid) in mask.iter().enumerate() {
300                if pvalid && child.is_null(i) {
301                    return Err(DynError::Nullability {
302                        col,
303                        path: format!("{}.{}", col_name, child_field.name()),
304                        index: i,
305                        message: "non-nullable struct field contains null".to_string(),
306                    });
307                }
308            }
309        }
310
311        // Recurse into nested children. For struct children, combine the current mask
312        // with the child's validity to handle nested nullable structs correctly.
313        // e.g., if parent.child is None, child's fields should not be validated.
314        let child_mask = if matches!(child_field.data_type(), DataType::Struct(_)) {
315            let child_arr: &dyn Array = child_array.as_ref();
316            let child_valid = validity_mask(child_arr);
317            // Combine: row is valid only if both parent and child struct are valid
318            Some(
319                mask.iter()
320                    .zip(child_valid.iter())
321                    .map(|(&p, &c)| p && c)
322                    .collect(),
323            )
324        } else {
325            Some(mask.clone())
326        };
327        validate_nested(
328            &format!("{}.{}", col_name, child_field.name()),
329            child_field.data_type(),
330            child_array,
331            col,
332            child_mask,
333            child_field.is_nullable(),
334            union_null_rows,
335        )?;
336    }
337    Ok(())
338}
339
340fn validate_list(
341    col_name: &str,
342    item: &Arc<Field>,
343    array: &ArrayRef,
344    col: usize,
345    parent_mask: Option<Vec<bool>>,
346    union_null_rows: &HashMap<usize, Vec<usize>>,
347) -> Result<(), DynError> {
348    let l = array
349        .as_any()
350        .downcast_ref::<ListArray>()
351        .ok_or_else(|| DynError::Builder {
352            message: format!("expected ListArray for {col_name}"),
353        })?;
354
355    let arr: &dyn Array = l;
356    let parent_valid = parent_mask.unwrap_or_else(|| validity_mask(arr));
357    let offsets: &OffsetBuffer<i32> = l.offsets();
358    let child = l.values().clone();
359
360    if !item.is_nullable() {
361        for (row, &pvalid) in parent_valid.iter().enumerate() {
362            if !pvalid {
363                continue;
364            }
365            let (start, end) = offset_range(offsets, row, col_name)?;
366            for idx in start..end {
367                if child.is_null(idx) {
368                    return Err(DynError::Nullability {
369                        col,
370                        path: format!("{col_name}[]"),
371                        index: idx,
372                        message: "non-nullable list item contains null".to_string(),
373                    });
374                }
375            }
376        }
377    }
378
379    // Recurse into child type. Construct mask of child indices belonging to
380    // valid parent rows.
381    let mut child_mask = vec![false; child.len()];
382    for (row, &pvalid) in parent_valid.iter().enumerate() {
383        if !pvalid {
384            continue;
385        }
386        let (start, end) = offset_range(offsets, row, col_name)?;
387        for item in child_mask.iter_mut().take(end).skip(start) {
388            *item = true;
389        }
390    }
391
392    validate_nested(
393        &format!("{col_name}[]"),
394        item.data_type(),
395        &child,
396        col,
397        Some(child_mask),
398        item.is_nullable(),
399        union_null_rows,
400    )
401}
402
403fn validate_large_list(
404    col_name: &str,
405    item: &Arc<Field>,
406    array: &ArrayRef,
407    col: usize,
408    parent_mask: Option<Vec<bool>>,
409    union_null_rows: &HashMap<usize, Vec<usize>>,
410) -> Result<(), DynError> {
411    let l = array
412        .as_any()
413        .downcast_ref::<LargeListArray>()
414        .ok_or_else(|| DynError::Builder {
415            message: format!("expected LargeListArray for {col_name}"),
416        })?;
417    let arr: &dyn Array = l;
418    let parent_valid = parent_mask.unwrap_or_else(|| validity_mask(arr));
419    let offsets = l.offsets();
420    let child = l.values().clone();
421
422    if !item.is_nullable() {
423        for (row, &pvalid) in parent_valid.iter().enumerate() {
424            if !pvalid {
425                continue;
426            }
427            let (start, end) = offset_range(offsets, row, col_name)?;
428            for idx in start..end {
429                if child.is_null(idx) {
430                    return Err(DynError::Nullability {
431                        col,
432                        path: format!("{col_name}[]"),
433                        index: idx,
434                        message: "non-nullable large-list item contains null".to_string(),
435                    });
436                }
437            }
438        }
439    }
440
441    let mut child_mask = vec![false; child.len()];
442    for (row, &pvalid) in parent_valid.iter().enumerate() {
443        if !pvalid {
444            continue;
445        }
446        let (start, end) = offset_range(offsets, row, col_name)?;
447        for item in child_mask.iter_mut().take(end).skip(start) {
448            *item = true;
449        }
450    }
451
452    validate_nested(
453        &format!("{col_name}[]"),
454        item.data_type(),
455        &child,
456        col,
457        Some(child_mask),
458        item.is_nullable(),
459        union_null_rows,
460    )
461}
462
463fn validate_fixed_list(
464    col_name: &str,
465    item: &Arc<Field>,
466    array: &ArrayRef,
467    col: usize,
468    parent_mask: Option<Vec<bool>>,
469    union_null_rows: &HashMap<usize, Vec<usize>>,
470) -> Result<(), DynError> {
471    let l = array
472        .as_any()
473        .downcast_ref::<FixedSizeListArray>()
474        .ok_or_else(|| DynError::Builder {
475            message: format!("expected FixedSizeListArray for {col_name}"),
476        })?;
477    let arr: &dyn Array = l;
478    let parent_valid = parent_mask.unwrap_or_else(|| validity_mask(arr));
479    let child = l.values().clone();
480    let width = usize::try_from(l.value_length()).map_err(|_| DynError::Builder {
481        message: format!("negative fixed-size list width for {col_name}"),
482    })?;
483
484    if !item.is_nullable() {
485        for (row, &pvalid) in parent_valid.iter().enumerate() {
486            if !pvalid {
487                continue;
488            }
489            let start = row * width;
490            let end = start + width;
491            for idx in start..end {
492                if child.is_null(idx) {
493                    return Err(DynError::Nullability {
494                        col,
495                        path: format!("{col_name}[{row}]"),
496                        index: idx,
497                        message: "non-nullable fixed-size list item contains null".to_string(),
498                    });
499                }
500            }
501        }
502    }
503
504    let mut child_mask = vec![false; child.len()];
505    for (row, &pvalid) in parent_valid.iter().enumerate() {
506        if !pvalid {
507            continue;
508        }
509        let start = row * width;
510        let end = start + width;
511        for item in child_mask.iter_mut().take(end).skip(start) {
512            *item = true;
513        }
514    }
515
516    validate_nested(
517        &format!("{col_name}[]"),
518        item.data_type(),
519        &child,
520        col,
521        Some(child_mask),
522        item.is_nullable(),
523        union_null_rows,
524    )
525}
526
527fn validate_map(
528    col_name: &str,
529    entry_field: &Arc<Field>,
530    array: &ArrayRef,
531    col: usize,
532    parent_mask: Option<Vec<bool>>,
533    union_null_rows: &HashMap<usize, Vec<usize>>,
534) -> Result<(), DynError> {
535    let map = array
536        .as_any()
537        .downcast_ref::<MapArray>()
538        .ok_or_else(|| DynError::Builder {
539            message: format!("expected MapArray for {col_name}"),
540        })?;
541
542    let arr: &dyn Array = map;
543    let parent_valid = parent_mask.unwrap_or_else(|| validity_mask(arr));
544    let offsets = map.offsets();
545    let keys = map.keys().clone();
546    let values = map.values().clone();
547
548    let DataType::Struct(children) = entry_field.data_type() else {
549        return Err(DynError::Builder {
550            message: "map entry field is not a struct".to_string(),
551        });
552    };
553    if children.len() != 2 {
554        return Err(DynError::Builder {
555            message: format!(
556                "map entry struct must have 2 fields, found {}",
557                children.len()
558            ),
559        });
560    }
561    let key_field = &children[0];
562    let value_field = &children[1];
563
564    for (row, &pvalid) in parent_valid.iter().enumerate() {
565        if !pvalid {
566            continue;
567        }
568        let (start, end) = offset_range(offsets, row, col_name)?;
569        for idx in start..end {
570            if keys.as_ref().is_null(idx) {
571                return Err(DynError::Nullability {
572                    col,
573                    path: format!("{col_name}.keys"),
574                    index: idx,
575                    message: "map keys cannot contain nulls".to_string(),
576                });
577            }
578            if !value_field.is_nullable() && values.as_ref().is_null(idx) {
579                return Err(DynError::Nullability {
580                    col,
581                    path: format!("{col_name}.values"),
582                    index: idx,
583                    message: "map values marked non-nullable contain null".to_string(),
584                });
585            }
586        }
587    }
588
589    let mut key_mask = vec![false; keys.len()];
590    let mut value_mask = vec![false; values.len()];
591    for (row, &pvalid) in parent_valid.iter().enumerate() {
592        if !pvalid {
593            continue;
594        }
595        let (start, end) = offset_range(offsets, row, col_name)?;
596        for idx in start..end {
597            key_mask[idx] = true;
598            if values.as_ref().is_valid(idx) {
599                value_mask[idx] = true;
600            }
601        }
602    }
603
604    validate_nested(
605        &format!("{col_name}.keys"),
606        key_field.data_type(),
607        &keys,
608        col,
609        Some(key_mask),
610        key_field.is_nullable(),
611        union_null_rows,
612    )?;
613    validate_nested(
614        &format!("{col_name}.values"),
615        value_field.data_type(),
616        &values,
617        col,
618        Some(value_mask),
619        value_field.is_nullable(),
620        union_null_rows,
621    )?;
622    Ok(())
623}
624
625fn validity_mask(array: &dyn Array) -> Vec<bool> {
626    (0..array.len()).map(|i| array.is_valid(i)).collect()
627}
628
629fn first_null_index(array: &dyn Array) -> Option<usize> {
630    (0..array.len()).find(|&i| array.is_null(i))
631}
632
/// Map an `i8` union type id onto a dense index in `0..=255`
/// (`i8::MIN` -> 0, `0` -> 128, `i8::MAX` -> 255) for table lookups.
fn tag_slot(tag: i8) -> usize {
    let biased = i16::from(tag) - i16::from(i8::MIN);
    usize::try_from(biased).expect("i8 biased by 128 is always within 0..=255")
}