vortex_vector/struct_/
vector_mut.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Definition and implementation of [`StructVectorMut`].
5
6use std::sync::Arc;
7
8use vortex_dtype::StructFields;
9use vortex_error::{VortexExpect, VortexResult, vortex_ensure};
10use vortex_mask::MaskMut;
11
12use crate::struct_::StructVector;
13use crate::{Vector, VectorMut, VectorMutOps, VectorOps, match_vector_pair};
14
15/// A mutable vector of struct values (values with named fields).
16///
17/// Struct values are stored column-wise in the vector, so values in the same field are stored next
18/// to each other (rather than values in the same struct stored next to each other).
19#[derive(Debug, Clone)]
20pub struct StructVectorMut {
21    /// The (owned) fields of the `StructVectorMut`, each stored column-wise as a [`VectorMut`].
22    pub(super) fields: Box<[VectorMut]>,
23
24    /// The validity mask (where `true` represents an element is **not** null).
25    pub(super) validity: MaskMut,
26
27    /// The length of the vector (which is the same as all field vectors).
28    ///
29    /// This is stored here as a convenience, and also helps in the case that the `StructVector` has
30    /// no fields.
31    pub(super) len: usize,
32}
33
34impl StructVectorMut {
35    /// Creates a new [`StructVectorMut`] with the given fields and validity mask.
36    ///
37    /// # Panics
38    ///
39    /// Panics if:
40    ///
41    /// - Any field vector has a length that does not match the length of other fields.
42    /// - The validity mask length does not match the field length.
43    pub fn new(fields: Box<[VectorMut]>, validity: MaskMut) -> Self {
44        Self::try_new(fields, validity).vortex_expect("Failed to create `StructVectorMut`")
45    }
46
47    /// Tries to create a new [`StructVectorMut`] with the given fields and validity mask.
48    ///
49    /// # Errors
50    ///
51    /// Returns an error if:
52    ///
53    /// - Any field vector has a length that does not match the length of other fields.
54    /// - The validity mask length does not match the field length.
55    pub fn try_new(fields: Box<[VectorMut]>, validity: MaskMut) -> VortexResult<Self> {
56        let len = validity.len();
57
58        // Validate that all fields have the correct length.
59        for (i, field) in fields.iter().enumerate() {
60            vortex_ensure!(
61                field.len() == len,
62                "Field {} has length {} but expected length {}",
63                i,
64                field.len(),
65                len
66            );
67        }
68
69        Ok(Self {
70            fields,
71            validity,
72            len,
73        })
74    }
75
76    /// Creates a new [`StructVectorMut`] with the given fields and validity mask without
77    /// validation.
78    ///
79    /// # Safety
80    ///
81    /// The caller must ensure that:
82    ///
83    /// - All field vectors have the same length.
84    /// - The validity mask has a length equal to the field length.
85    pub unsafe fn new_unchecked(fields: Box<[VectorMut]>, validity: MaskMut) -> Self {
86        let len = validity.len();
87
88        if cfg!(debug_assertions) {
89            Self::new(fields, validity)
90        } else {
91            Self {
92                fields,
93                validity,
94                len,
95            }
96        }
97    }
98
99    /// Creates a new [`StructVectorMut`] with the given fields and capacity.
100    pub fn with_capacity(struct_fields: &StructFields, capacity: usize) -> Self {
101        let fields: Vec<VectorMut> = struct_fields
102            .fields()
103            .map(|dtype| VectorMut::with_capacity(&dtype, capacity))
104            .collect();
105
106        let validity = MaskMut::with_capacity(capacity);
107        let len = validity.len();
108
109        Self {
110            fields: fields.into_boxed_slice(),
111            validity,
112            len,
113        }
114    }
115
116    /// Decomposes the struct vector into its constituent parts (fields, validity, and length).
117    pub fn into_parts(self) -> (Box<[VectorMut]>, MaskMut, usize) {
118        (self.fields, self.validity, self.len)
119    }
120
121    /// Returns the fields of the `StructVectorMut`, each stored column-wise as a [`VectorMut`].
122    pub fn fields(&self) -> &[VectorMut] {
123        self.fields.as_ref()
124    }
125
126    /// Returns a mutable handle to the field vectors.
127    ///
128    /// # Safety
129    ///
130    /// Callers must ensure that any modifications to the field vectors do not violate
131    /// the invariants of this type, namely that all field vectors are of the same length
132    /// and equal to the length of the validity.
133    pub unsafe fn fields_mut(&mut self) -> &mut [VectorMut] {
134        self.fields.as_mut()
135    }
136
137    /// Returns a mutable handle to the validity mask of the vector.
138    ///
139    /// # Safety
140    ///
141    /// Callers must ensure that if the length of the mask is modified, the lengths
142    /// of all of the field vectors should be updated accordingly to continue meeting
143    /// the invariants of the type.
144    pub unsafe fn validity_mut(&mut self) -> &mut MaskMut {
145        &mut self.validity
146    }
147
148    /// Finds the minimum capacity of all field vectors.
149    ///
150    /// This is equal to the maximum amount of scalars we can add before we need to reallocate at
151    /// least one of the child field vectors.
152    ///
153    /// If there are no fields, this returns the length of the vector.
154    ///
155    /// Note that this takes time in `O(f)`, where `f` is the number of fields.
156    pub fn minimum_capacity(&self) -> usize {
157        self.fields
158            .iter()
159            .map(|field| field.capacity())
160            .min()
161            .unwrap_or(self.len)
162    }
163}
164
165impl VectorMutOps for StructVectorMut {
166    type Immutable = StructVector;
167
168    fn len(&self) -> usize {
169        self.len
170    }
171
172    fn validity(&self) -> &MaskMut {
173        &self.validity
174    }
175
176    fn capacity(&self) -> usize {
177        self.minimum_capacity()
178    }
179
180    fn reserve(&mut self, additional: usize) {
181        // Reserve the additional capacity in each field vector.
182        for field in &mut self.fields {
183            field.reserve(additional);
184
185            debug_assert_eq!(
186                field.len(),
187                self.len,
188                "Field length must match `StructVectorMut` length"
189            );
190        }
191
192        self.validity.reserve(additional);
193    }
194
195    fn clear(&mut self) {
196        for field in &mut self.fields {
197            field.clear();
198        }
199
200        self.validity.clear();
201        self.len = 0;
202    }
203
204    fn truncate(&mut self, len: usize) {
205        for field in &mut self.fields {
206            field.truncate(len);
207        }
208
209        self.validity.truncate(len);
210        self.len = self.validity.len();
211    }
212
213    fn extend_from_vector(&mut self, other: &StructVector) {
214        assert_eq!(
215            self.fields.len(),
216            other.fields().len(),
217            "Cannot extend StructVectorMut: field count mismatch (self had {} but other had {})",
218            self.fields.len(),
219            other.fields().len()
220        );
221
222        // Extend each field vector.
223        let pairs = self.fields.iter_mut().zip(other.fields().as_ref());
224        for (self_mut_vector, other_vec) in pairs {
225            match_vector_pair!(self_mut_vector, other_vec, |a: VectorMut, b: Vector| {
226                a.extend_from_vector(b)
227            })
228        }
229
230        // Extend the validity mask.
231        self.validity.append_mask(other.validity());
232        self.len += other.len();
233
234        debug_assert_eq!(self.len, self.validity.len());
235    }
236
237    fn append_nulls(&mut self, n: usize) {
238        for field in &mut self.fields {
239            field.append_nulls(n); // Note that the value we push to each doesn't actually matter.
240        }
241
242        self.validity.append_n(false, n);
243        self.len += n;
244        debug_assert_eq!(self.len, self.validity.len());
245    }
246
247    fn freeze(self) -> StructVector {
248        let frozen_fields: Vec<Vector> = self
249            .fields
250            .into_iter()
251            .map(|mut_field| mut_field.freeze())
252            .collect();
253
254        StructVector {
255            fields: Arc::new(frozen_fields.into_boxed_slice()),
256            len: self.len,
257            validity: self.validity.freeze(),
258        }
259    }
260
261    fn split_off(&mut self, at: usize) -> Self {
262        assert!(
263            at <= self.capacity(),
264            "split_off out of bounds: {} > {}",
265            at,
266            self.capacity()
267        );
268
269        let split_fields: Vec<VectorMut> = self
270            .fields
271            .iter_mut()
272            .map(|field| field.split_off(at))
273            .collect();
274
275        let split_validity = self.validity.split_off(at);
276        let split_len = self.len.saturating_sub(at);
277        self.len = at;
278
279        debug_assert_eq!(self.len, self.validity.len());
280
281        Self {
282            fields: split_fields.into_boxed_slice(),
283            len: split_len,
284            validity: split_validity,
285        }
286    }
287
288    fn unsplit(&mut self, other: Self) {
289        assert_eq!(
290            self.fields.len(),
291            other.fields.len(),
292            "Cannot unsplit StructVectorMut: field count mismatch ({} vs {})",
293            self.fields.len(),
294            other.fields.len()
295        );
296
297        if self.is_empty() {
298            *self = other;
299            return;
300        }
301
302        // Unsplit each field vector.
303        let pairs = self.fields.iter_mut().zip(other.fields);
304        for (self_mut_vector, other_mut_vec) in pairs {
305            match_vector_pair!(
306                self_mut_vector,
307                other_mut_vec,
308                |a: VectorMut, b: VectorMut| a.unsplit(b)
309            )
310        }
311
312        self.validity.unsplit(other.validity);
313        self.len += other.len;
314        debug_assert_eq!(self.len, self.validity.len());
315    }
316}
317
318#[cfg(test)]
319mod tests {
320    use vortex_dtype::{DType, FieldNames, Nullability, PType, PTypeDowncast, StructFields};
321    use vortex_mask::{Mask, MaskMut};
322
323    use super::*;
324    use crate::VectorMut;
325    use crate::bool::BoolVectorMut;
326    use crate::null::{NullVector, NullVectorMut};
327    use crate::primitive::PVectorMut;
328
329    #[test]
330    fn test_empty_fields() {
331        let mut struct_vec = StructVectorMut::try_new(Box::new([]), MaskMut::new_true(10)).unwrap();
332        let second_half = struct_vec.split_off(6);
333        assert_eq!(struct_vec.len(), 6);
334        assert_eq!(second_half.len(), 4);
335    }
336
337    #[test]
338    fn test_try_into_mut_and_values() {
339        let struct_vec = StructVector {
340            fields: Arc::new(Box::new([
341                NullVector::new(5).into(),
342                BoolVectorMut::from_iter([true, false, true, false, true])
343                    .freeze()
344                    .into(),
345                PVectorMut::<i32>::from_iter([10, 20, 30, 40, 50])
346                    .freeze()
347                    .into(),
348            ])),
349            len: 5,
350            validity: Mask::AllTrue(5),
351        };
352
353        let mut_struct = struct_vec.try_into_mut().unwrap();
354        assert_eq!(mut_struct.len(), 5);
355
356        // Verify values are preserved.
357        if let VectorMut::Bool(bool_vec) = mut_struct.fields[1].clone() {
358            let values: Vec<_> = bool_vec.into_iter().map(|v| v.unwrap()).collect();
359            assert_eq!(values, vec![true, false, true, false, true]);
360        }
361
362        if let VectorMut::Primitive(prim_vec) = mut_struct.fields[2].clone() {
363            let values: Vec<_> = prim_vec
364                .into_i32()
365                .into_iter()
366                .map(|v| v.unwrap())
367                .collect();
368            assert_eq!(values, vec![10, 20, 30, 40, 50]);
369        }
370    }
371
372    #[test]
373    fn test_try_into_mut_shared_ownership() {
374        // Test that conversion fails when a field has shared ownership.
375        let bool_field: Vector = BoolVectorMut::from_iter([true, false, true])
376            .freeze()
377            .into();
378        let bool_field_clone = bool_field.clone();
379
380        let struct_vec = StructVector {
381            fields: Arc::new(Box::new([
382                NullVector::new(3).into(),
383                bool_field_clone,
384                PVectorMut::<i32>::from_iter([1, 2, 3]).freeze().into(),
385            ])),
386            len: 3,
387            validity: Mask::AllTrue(3),
388        };
389
390        assert!(struct_vec.try_into_mut().is_err());
391        drop(bool_field); // Keep original alive to maintain shared ownership
392    }
393
394    #[test]
395    fn test_split_unsplit_values() {
396        let mut struct_vec = StructVectorMut::try_new(
397            Box::new([
398                NullVectorMut::new(8).into(),
399                BoolVectorMut::from_iter([true, false, true, false, true, false, true, false])
400                    .into(),
401                PVectorMut::<i32>::from_iter([10, 20, 30, 40, 50, 60, 70, 80]).into(),
402            ]),
403            MaskMut::new_true(8),
404        )
405        .unwrap();
406
407        let second_half = struct_vec.split_off(5);
408        assert_eq!(struct_vec.len(), 5);
409        assert_eq!(second_half.len(), 3);
410
411        // Verify values after split.
412        if let VectorMut::Bool(bool_vec) = struct_vec.fields[1].clone() {
413            let values: Vec<_> = bool_vec.into_iter().take(5).map(|v| v.unwrap()).collect();
414            assert_eq!(values, vec![true, false, true, false, true]);
415        }
416
417        if let VectorMut::Bool(bool_vec) = second_half.fields[1].clone() {
418            let values: Vec<_> = bool_vec.into_iter().map(|v| v.unwrap()).collect();
419            assert_eq!(values, vec![false, true, false]);
420        }
421
422        // Unsplit and verify.
423        struct_vec.unsplit(second_half);
424        assert_eq!(struct_vec.len(), 8);
425
426        if let VectorMut::Bool(bool_vec) = struct_vec.fields[1].clone() {
427            let values: Vec<_> = bool_vec.into_iter().map(|v| v.unwrap()).collect();
428            assert_eq!(
429                values,
430                vec![true, false, true, false, true, false, true, false]
431            );
432        }
433    }
434
435    #[test]
436    fn test_extend_and_append_nulls() {
437        let mut struct_vec = StructVectorMut::try_new(
438            Box::new([
439                NullVector::new(3).try_into_mut().unwrap().into(),
440                BoolVectorMut::from_iter([true, false, true]).into(),
441                PVectorMut::<i32>::from_iter([10, 20, 30]).into(),
442            ]),
443            MaskMut::new_true(3),
444        )
445        .unwrap();
446
447        // Test extend.
448        let to_extend = StructVector {
449            fields: Arc::new(Box::new([
450                NullVector::new(2).into(),
451                BoolVectorMut::from_iter([false, true]).freeze().into(),
452                PVectorMut::<i32>::from_iter([40, 50]).freeze().into(),
453            ])),
454            len: 2,
455            validity: Mask::AllTrue(2),
456        };
457
458        struct_vec.extend_from_vector(&to_extend);
459        assert_eq!(struct_vec.len(), 5);
460
461        // Test append_nulls.
462        struct_vec.append_nulls(2);
463        assert_eq!(struct_vec.len(), 7);
464
465        // Verify final values include nulls.
466        if let VectorMut::Bool(bool_vec) = struct_vec.fields[1].clone() {
467            let values: Vec<_> = bool_vec.into_iter().collect();
468            assert_eq!(
469                values,
470                vec![
471                    Some(true),
472                    Some(false),
473                    Some(true),
474                    Some(false),
475                    Some(true),
476                    None,
477                    None
478                ]
479            );
480        }
481    }
482
483    #[test]
484    fn test_roundtrip() {
485        let original_bool = vec![Some(true), None, Some(false), Some(true)];
486        let original_int = vec![Some(100i32), None, Some(200), Some(300)];
487
488        let struct_vec = StructVectorMut::try_new(
489            Box::new([
490                NullVector::new(4).try_into_mut().unwrap().into(),
491                BoolVectorMut::from_iter(original_bool.clone()).into(),
492                PVectorMut::<i32>::from_iter(original_int.clone()).into(),
493            ]),
494            MaskMut::new_true(4),
495        )
496        .unwrap();
497
498        // Verify roundtrip preserves nulls.
499        if let VectorMut::Bool(bool_vec) = struct_vec.fields[1].clone() {
500            let roundtrip: Vec<_> = bool_vec.into_iter().collect();
501            assert_eq!(roundtrip, original_bool);
502        }
503
504        if let VectorMut::Primitive(prim_vec) = struct_vec.fields[2].clone() {
505            let roundtrip: Vec<_> = prim_vec.into_i32().into_iter().collect();
506            assert_eq!(roundtrip, original_int);
507        }
508    }
509
510    #[test]
511    fn test_nested_struct() {
512        let inner1 = StructVectorMut::try_new(
513            Box::new([
514                NullVector::new(4).try_into_mut().unwrap().into(),
515                BoolVectorMut::from_iter([true, false, true, false]).into(),
516            ]),
517            MaskMut::new_true(4),
518        )
519        .unwrap()
520        .into();
521
522        let inner2 = StructVectorMut::try_new(
523            Box::new([PVectorMut::<u32>::from_iter([100, 200, 300, 400]).into()]),
524            MaskMut::new_true(4),
525        )
526        .unwrap()
527        .into();
528
529        let mut outer =
530            StructVectorMut::try_new(Box::new([inner1, inner2]), MaskMut::new_true(4)).unwrap();
531
532        let second = outer.split_off(2);
533        assert_eq!(outer.len(), 2);
534        assert_eq!(second.len(), 2);
535
536        outer.unsplit(second);
537        assert_eq!(outer.len(), 4);
538        assert!(matches!(outer.fields[0], VectorMut::Struct(_)));
539    }
540
541    #[test]
542    fn test_reserve() {
543        // Test that reserve increases capacity for all fields correctly.
544        let mut struct_vec = StructVectorMut::try_new(
545            Box::new([
546                NullVectorMut::new(3).into(),
547                BoolVectorMut::from_iter([true, false, true]).into(),
548                PVectorMut::<i32>::from_iter([10, 20, 30]).into(),
549            ]),
550            MaskMut::new_true(3),
551        )
552        .unwrap();
553
554        let initial_capacity = struct_vec.capacity();
555        assert_eq!(struct_vec.len(), 3);
556
557        // Reserve additional capacity.
558        struct_vec.reserve(50);
559
560        // Capacity should now be at least len + 50.
561        assert!(struct_vec.capacity() >= 3 + 50);
562        assert!(struct_vec.capacity() >= initial_capacity + 50);
563
564        // Verify minimum_capacity returns the smallest field capacity.
565        let min_cap = struct_vec.minimum_capacity();
566        for field in struct_vec.fields() {
567            assert!(field.capacity() >= min_cap);
568        }
569
570        // Test reserve on an empty struct.
571        let mut empty_struct = StructVectorMut::try_new(
572            Box::new([
573                NullVectorMut::new(0).into(),
574                BoolVectorMut::with_capacity(0).into(),
575            ]),
576            MaskMut::new_true(0),
577        )
578        .unwrap();
579
580        empty_struct.reserve(100);
581        assert!(empty_struct.capacity() >= 100);
582    }
583
584    #[test]
585    fn test_freeze_and_new_unchecked() {
586        // Test new_unchecked creates a valid struct, and freeze preserves data correctly.
587        let fields = Box::new([
588            NullVectorMut::new(4).into(),
589            BoolVectorMut::from_iter([Some(true), None, Some(false), Some(true)]).into(),
590            PVectorMut::<i32>::from_iter([Some(100), Some(200), None, Some(400)]).into(),
591        ]);
592
593        let validity = Mask::from_iter([true, false, true, true])
594            .try_into_mut()
595            .unwrap();
596
597        // Use new_unchecked to create the struct.
598        // SAFETY: All fields have length 4 and validity has length 4.
599        let struct_vec = unsafe { StructVectorMut::new_unchecked(fields, validity) };
600
601        assert_eq!(struct_vec.len(), 4);
602        assert_eq!(struct_vec.fields().len(), 3);
603
604        // Freeze the struct and verify data preservation.
605        let frozen = struct_vec.freeze();
606
607        assert_eq!(frozen.len(), 4);
608        assert_eq!(frozen.fields().len(), 3);
609
610        // Verify validity is preserved (only indices 0, 2, 3 are valid at the struct level).
611        assert_eq!(frozen.validity().true_count(), 3);
612
613        // Verify that `try_into_mut` fails when data isn't owned.
614        {
615            let cloned_vector = frozen.fields()[1].clone();
616            cloned_vector.try_into_mut().unwrap_err();
617        }
618
619        // Verify field data is preserved.
620        let mut fields = Arc::try_unwrap(frozen.into_parts().0).unwrap().into_vec();
621
622        if let Vector::Primitive(prim_vec) = fields.pop().unwrap() {
623            let prim_vec_mut = prim_vec.try_into_mut().unwrap();
624            let values: Vec<_> = prim_vec_mut.into_i32().into_iter().collect();
625            assert_eq!(values, vec![Some(100), Some(200), None, Some(400)]);
626        } else {
627            panic!("Expected primitive vector");
628        }
629
630        if let Vector::Bool(bool_vec) = fields.pop().unwrap() {
631            let bool_vec_mut = bool_vec.try_into_mut().unwrap();
632            let values: Vec<_> = bool_vec_mut.into_iter().collect();
633            // Note: struct-level validity doesn't affect field-level data, only the interpretation.
634            assert_eq!(values, vec![Some(true), None, Some(false), Some(true)]);
635        } else {
636            panic!("Expected bool vector");
637        }
638    }
639
640    #[test]
641    fn test_with_capacity_struct() {
642        // Create a struct dtype with multiple field types.
643        let struct_dtype = DType::Struct(
644            StructFields::new(
645                FieldNames::from(["null_field", "bool_field", "int_field"]),
646                vec![
647                    DType::Null,
648                    DType::Bool(Nullability::NonNullable),
649                    DType::Primitive(PType::I32, Nullability::Nullable),
650                ],
651            ),
652            Nullability::Nullable,
653        );
654
655        // Create a VectorMut with capacity using the struct dtype.
656        let vector_mut = VectorMut::with_capacity(&struct_dtype, 100);
657
658        // Verify it's a struct vector.
659        match vector_mut {
660            VectorMut::Struct(mut struct_vec) => {
661                // Check initial state.
662                assert_eq!(struct_vec.len(), 0);
663                assert_eq!(struct_vec.fields.len(), 3);
664
665                // Verify each field has the correct type.
666                assert!(matches!(struct_vec.fields[0], VectorMut::Null(_)));
667                assert!(matches!(struct_vec.fields[1], VectorMut::Bool(_)));
668                assert!(matches!(struct_vec.fields[2], VectorMut::Primitive(_)));
669
670                // Check that capacity was reserved (minimum should be at least 100).
671                assert!(struct_vec.capacity() >= 100);
672
673                // Verify we can actually use the reserved capacity by pushing values.
674                for _ in 0..50 {
675                    struct_vec.append_nulls(1);
676                }
677                assert_eq!(struct_vec.len(), 50);
678
679                // Should not need reallocation since we reserved capacity.
680                assert!(struct_vec.capacity() >= 100);
681            }
682            _ => panic!("Expected VectorMut::Struct"),
683        }
684    }
685}