1use std::sync::Arc;
7
8use vortex_dtype::StructFields;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_mask::MaskMut;
13
14use crate::ScalarOps;
15use crate::Vector;
16use crate::VectorMut;
17use crate::VectorMutOps;
18use crate::VectorOps;
19use crate::match_vector_pair;
20use crate::struct_::StructScalar;
21use crate::struct_::StructVector;
22
/// A mutable, growable vector of struct values.
///
/// Each row consists of one value taken from every child vector in `fields`, plus one validity
/// bit in `validity`. The checked constructors require every field to have exactly
/// `validity.len()` elements, and that shared length is cached in `len`.
#[derive(Debug, Clone)]
pub struct StructVectorMut {
    /// The child vectors, one per struct field, all of length `len`.
    pub(super) fields: Box<[VectorMut]>,

    /// Per-row validity; a `false` bit marks the struct row as null (see `append_nulls`).
    pub(super) validity: MaskMut,

    /// Number of rows. Kept equal to `validity.len()` and to the length of every field.
    pub(super) len: usize,
}
41
42impl StructVectorMut {
43 pub fn new(fields: Box<[VectorMut]>, validity: MaskMut) -> Self {
52 Self::try_new(fields, validity).vortex_expect("Failed to create `StructVectorMut`")
53 }
54
55 pub fn try_new(fields: Box<[VectorMut]>, validity: MaskMut) -> VortexResult<Self> {
64 let len = validity.len();
65
66 for (i, field) in fields.iter().enumerate() {
68 vortex_ensure!(
69 field.len() == len,
70 "Field {} has length {} but expected length {}",
71 i,
72 field.len(),
73 len
74 );
75 }
76
77 Ok(Self {
78 fields,
79 validity,
80 len,
81 })
82 }
83
84 pub unsafe fn new_unchecked(fields: Box<[VectorMut]>, validity: MaskMut) -> Self {
94 let len = validity.len();
95
96 if cfg!(debug_assertions) {
97 Self::new(fields, validity)
98 } else {
99 Self {
100 fields,
101 validity,
102 len,
103 }
104 }
105 }
106
107 pub fn with_capacity(struct_fields: &StructFields, capacity: usize) -> Self {
109 let fields: Vec<VectorMut> = struct_fields
110 .fields()
111 .map(|dtype| VectorMut::with_capacity(&dtype, capacity))
112 .collect();
113
114 let validity = MaskMut::with_capacity(capacity);
115 let len = validity.len();
116
117 Self {
118 fields: fields.into_boxed_slice(),
119 validity,
120 len,
121 }
122 }
123
124 pub fn into_parts(self) -> (Box<[VectorMut]>, MaskMut, usize) {
126 (self.fields, self.validity, self.len)
127 }
128
129 pub fn fields(&self) -> &[VectorMut] {
131 self.fields.as_ref()
132 }
133
134 pub unsafe fn fields_mut(&mut self) -> &mut [VectorMut] {
142 self.fields.as_mut()
143 }
144
145 pub unsafe fn validity_mut(&mut self) -> &mut MaskMut {
153 &mut self.validity
154 }
155
156 pub fn minimum_capacity(&self) -> usize {
165 self.fields
166 .iter()
167 .map(|field| field.capacity())
168 .min()
169 .unwrap_or(self.len)
170 }
171}
172
173impl VectorMutOps for StructVectorMut {
174 type Immutable = StructVector;
175
176 fn len(&self) -> usize {
177 self.len
178 }
179
180 fn validity(&self) -> &MaskMut {
181 &self.validity
182 }
183
184 fn capacity(&self) -> usize {
185 self.minimum_capacity()
186 }
187
188 fn reserve(&mut self, additional: usize) {
189 for field in &mut self.fields {
191 field.reserve(additional);
192
193 debug_assert_eq!(
194 field.len(),
195 self.len,
196 "Field length must match `StructVectorMut` length"
197 );
198 }
199
200 self.validity.reserve(additional);
201 }
202
203 fn clear(&mut self) {
204 for field in &mut self.fields {
205 field.clear();
206 }
207
208 self.validity.clear();
209 self.len = 0;
210 }
211
212 fn truncate(&mut self, len: usize) {
213 for field in &mut self.fields {
214 field.truncate(len);
215 }
216
217 self.validity.truncate(len);
218 self.len = self.validity.len();
219 }
220
221 fn extend_from_vector(&mut self, other: &StructVector) {
222 assert_eq!(
223 self.fields.len(),
224 other.fields().len(),
225 "Cannot extend StructVectorMut: field count mismatch (self had {} but other had {})",
226 self.fields.len(),
227 other.fields().len()
228 );
229
230 let pairs = self.fields.iter_mut().zip(other.fields().as_ref());
232 for (self_mut_vector, other_vec) in pairs {
233 match_vector_pair!(self_mut_vector, other_vec, |a: VectorMut, b: Vector| {
234 a.extend_from_vector(b)
235 })
236 }
237
238 self.validity.append_mask(other.validity());
240 self.len += other.len();
241
242 debug_assert_eq!(self.len, self.validity.len());
243 }
244
245 fn append_nulls(&mut self, n: usize) {
246 for field in &mut self.fields {
247 field.append_zeros(n);
248 }
249
250 self.validity.append_n(false, n);
251 self.len += n;
252 debug_assert_eq!(self.len, self.validity.len());
253 }
254
255 fn append_zeros(&mut self, n: usize) {
256 for field in &mut self.fields {
257 field.append_zeros(n);
258 }
259
260 self.validity.append_n(true, n);
261 self.len += n;
262 debug_assert_eq!(self.len, self.validity.len());
263 }
264
265 fn append_scalars(&mut self, scalar: &StructScalar, n: usize) {
266 if scalar.is_valid() {
267 for (v, s) in self.fields.iter_mut().zip(scalar.value().fields.iter()) {
268 v.append_scalars(&s.scalar_at(0), n)
269 }
270 self.validity.append_n(true, n)
271 } else {
272 for field in &mut self.fields {
273 field.append_zeros(n);
274 }
275 self.validity.append_n(false, n)
276 }
277 self.len += n;
278 }
279
280 fn freeze(self) -> StructVector {
281 let frozen_fields: Vec<Vector> = self
282 .fields
283 .into_iter()
284 .map(|mut_field| mut_field.freeze())
285 .collect();
286
287 StructVector {
288 fields: Arc::new(frozen_fields.into_boxed_slice()),
289 len: self.len,
290 validity: self.validity.freeze(),
291 }
292 }
293
294 fn split_off(&mut self, at: usize) -> Self {
295 assert!(
296 at <= self.capacity(),
297 "split_off out of bounds: {} > {}",
298 at,
299 self.capacity()
300 );
301
302 let split_fields: Vec<VectorMut> = self
303 .fields
304 .iter_mut()
305 .map(|field| field.split_off(at))
306 .collect();
307
308 let split_validity = self.validity.split_off(at);
309 let split_len = self.len.saturating_sub(at);
310 self.len = at;
311
312 debug_assert_eq!(self.len, self.validity.len());
313
314 Self {
315 fields: split_fields.into_boxed_slice(),
316 len: split_len,
317 validity: split_validity,
318 }
319 }
320
321 fn unsplit(&mut self, other: Self) {
322 assert_eq!(
323 self.fields.len(),
324 other.fields.len(),
325 "Cannot unsplit StructVectorMut: field count mismatch ({} vs {})",
326 self.fields.len(),
327 other.fields.len()
328 );
329
330 if self.is_empty() {
331 *self = other;
332 return;
333 }
334
335 let pairs = self.fields.iter_mut().zip(other.fields);
337 for (self_mut_vector, other_mut_vec) in pairs {
338 match_vector_pair!(
339 self_mut_vector,
340 other_mut_vec,
341 |a: VectorMut, b: VectorMut| a.unsplit(b)
342 )
343 }
344
345 self.validity.unsplit(other.validity);
346 self.len += other.len;
347 debug_assert_eq!(self.len, self.validity.len());
348 }
349}
350
#[cfg(test)]
mod tests {
    use vortex_dtype::DType;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_dtype::PTypeDowncast;
    use vortex_dtype::StructFields;
    use vortex_mask::Mask;
    use vortex_mask::MaskMut;

    use super::*;
    use crate::VectorMut;
    use crate::bool::BoolVectorMut;
    use crate::null::NullVector;
    use crate::null::NullVectorMut;
    use crate::primitive::PVectorMut;

    #[test]
    fn test_empty_fields() {
        let mut struct_vec = StructVectorMut::try_new(Box::new([]), MaskMut::new_true(10)).unwrap();
        let second_half = struct_vec.split_off(6);
        assert_eq!(struct_vec.len(), 6);
        assert_eq!(second_half.len(), 4);
    }

    #[test]
    fn test_try_into_mut_and_values() {
        let struct_vec = StructVector {
            fields: Arc::new(Box::new([
                NullVector::new(5).into(),
                BoolVectorMut::from_iter([true, false, true, false, true])
                    .freeze()
                    .into(),
                PVectorMut::<i32>::from_iter([10, 20, 30, 40, 50])
                    .freeze()
                    .into(),
            ])),
            len: 5,
            validity: Mask::AllTrue(5),
        };

        let mut_struct = struct_vec.try_into_mut().unwrap();
        assert_eq!(mut_struct.len(), 5);

        // Use `let ... else` so that a variant mismatch fails the test instead of silently
        // skipping the value checks.
        let VectorMut::Bool(bool_vec) = mut_struct.fields[1].clone() else {
            panic!("Expected field 1 to be a bool vector");
        };
        let values: Vec<_> = bool_vec.into_iter().map(|v| v.unwrap()).collect();
        assert_eq!(values, vec![true, false, true, false, true]);

        let VectorMut::Primitive(prim_vec) = mut_struct.fields[2].clone() else {
            panic!("Expected field 2 to be a primitive vector");
        };
        let values: Vec<_> = prim_vec
            .into_i32()
            .into_iter()
            .map(|v| v.unwrap())
            .collect();
        assert_eq!(values, vec![10, 20, 30, 40, 50]);
    }

    #[test]
    fn test_try_into_mut_shared_ownership() {
        let bool_field: Vector = BoolVectorMut::from_iter([true, false, true])
            .freeze()
            .into();
        // A second handle to the same bool data, so the struct does not own it exclusively.
        let bool_field_clone = bool_field.clone();

        let struct_vec = StructVector {
            fields: Arc::new(Box::new([
                NullVector::new(3).into(),
                bool_field_clone,
                PVectorMut::<i32>::from_iter([1, 2, 3]).freeze().into(),
            ])),
            len: 3,
            validity: Mask::AllTrue(3),
        };

        assert!(struct_vec.try_into_mut().is_err());

        // Release the outside handle only after the assertion, so the field stays shared while
        // `try_into_mut` runs.
        drop(bool_field);
    }

    #[test]
    fn test_split_unsplit_values() {
        let mut struct_vec = StructVectorMut::try_new(
            Box::new([
                NullVectorMut::new(8).into(),
                BoolVectorMut::from_iter([true, false, true, false, true, false, true, false])
                    .into(),
                PVectorMut::<i32>::from_iter([10, 20, 30, 40, 50, 60, 70, 80]).into(),
            ]),
            MaskMut::new_true(8),
        )
        .unwrap();

        let second_half = struct_vec.split_off(5);
        assert_eq!(struct_vec.len(), 5);
        assert_eq!(second_half.len(), 3);

        let VectorMut::Bool(bool_vec) = struct_vec.fields[1].clone() else {
            panic!("Expected field 1 of the head to be a bool vector");
        };
        let values: Vec<_> = bool_vec.into_iter().take(5).map(|v| v.unwrap()).collect();
        assert_eq!(values, vec![true, false, true, false, true]);

        let VectorMut::Bool(bool_vec) = second_half.fields[1].clone() else {
            panic!("Expected field 1 of the tail to be a bool vector");
        };
        let values: Vec<_> = bool_vec.into_iter().map(|v| v.unwrap()).collect();
        assert_eq!(values, vec![false, true, false]);

        struct_vec.unsplit(second_half);
        assert_eq!(struct_vec.len(), 8);

        let VectorMut::Bool(bool_vec) = struct_vec.fields[1].clone() else {
            panic!("Expected field 1 after unsplit to be a bool vector");
        };
        let values: Vec<_> = bool_vec.into_iter().map(|v| v.unwrap()).collect();
        assert_eq!(
            values,
            vec![true, false, true, false, true, false, true, false]
        );
    }

    #[test]
    fn test_extend_and_append_nulls() {
        let mut struct_vec = StructVectorMut::try_new(
            Box::new([
                NullVector::new(3).try_into_mut().unwrap().into(),
                BoolVectorMut::from_iter([true, false, true]).into(),
                PVectorMut::<i32>::from_iter([10, 20, 30]).into(),
            ]),
            MaskMut::new_true(3),
        )
        .unwrap();

        let to_extend = StructVector {
            fields: Arc::new(Box::new([
                NullVector::new(2).into(),
                BoolVectorMut::from_iter([false, true]).freeze().into(),
                PVectorMut::<i32>::from_iter([40, 50]).freeze().into(),
            ])),
            len: 2,
            validity: Mask::AllTrue(2),
        };

        struct_vec.extend_from_vector(&to_extend);
        assert_eq!(struct_vec.len(), 5);

        struct_vec.append_nulls(2);
        assert_eq!(struct_vec.len(), 7);

        // Struct-level nulls are backed by zero values in the child fields.
        let VectorMut::Bool(bool_vec) = struct_vec.fields[1].clone() else {
            panic!("Expected field 1 to be a bool vector");
        };
        let values: Vec<_> = bool_vec.into_iter().collect();
        assert_eq!(
            values,
            vec![
                Some(true),
                Some(false),
                Some(true),
                Some(false),
                Some(true),
                Some(false),
                Some(false)
            ]
        );
    }

    #[test]
    fn test_roundtrip() {
        let original_bool = vec![Some(true), None, Some(false), Some(true)];
        let original_int = vec![Some(100i32), None, Some(200), Some(300)];

        let struct_vec = StructVectorMut::try_new(
            Box::new([
                NullVector::new(4).try_into_mut().unwrap().into(),
                BoolVectorMut::from_iter(original_bool.clone()).into(),
                PVectorMut::<i32>::from_iter(original_int.clone()).into(),
            ]),
            MaskMut::new_true(4),
        )
        .unwrap();

        let VectorMut::Bool(bool_vec) = struct_vec.fields[1].clone() else {
            panic!("Expected field 1 to be a bool vector");
        };
        let roundtrip: Vec<_> = bool_vec.into_iter().collect();
        assert_eq!(roundtrip, original_bool);

        let VectorMut::Primitive(prim_vec) = struct_vec.fields[2].clone() else {
            panic!("Expected field 2 to be a primitive vector");
        };
        let roundtrip: Vec<_> = prim_vec.into_i32().into_iter().collect();
        assert_eq!(roundtrip, original_int);
    }

    #[test]
    fn test_nested_struct() {
        let inner1 = StructVectorMut::try_new(
            Box::new([
                NullVector::new(4).try_into_mut().unwrap().into(),
                BoolVectorMut::from_iter([true, false, true, false]).into(),
            ]),
            MaskMut::new_true(4),
        )
        .unwrap()
        .into();

        let inner2 = StructVectorMut::try_new(
            Box::new([PVectorMut::<u32>::from_iter([100, 200, 300, 400]).into()]),
            MaskMut::new_true(4),
        )
        .unwrap()
        .into();

        let mut outer =
            StructVectorMut::try_new(Box::new([inner1, inner2]), MaskMut::new_true(4)).unwrap();

        let second = outer.split_off(2);
        assert_eq!(outer.len(), 2);
        assert_eq!(second.len(), 2);

        outer.unsplit(second);
        assert_eq!(outer.len(), 4);
        assert!(matches!(outer.fields[0], VectorMut::Struct(_)));
    }

    #[test]
    fn test_reserve() {
        let mut struct_vec = StructVectorMut::try_new(
            Box::new([
                NullVectorMut::new(3).into(),
                BoolVectorMut::from_iter([true, false, true]).into(),
                PVectorMut::<i32>::from_iter([10, 20, 30]).into(),
            ]),
            MaskMut::new_true(3),
        )
        .unwrap();

        let initial_capacity = struct_vec.capacity();
        assert_eq!(struct_vec.len(), 3);

        struct_vec.reserve(50);

        // `reserve` guarantees room for `len + additional` rows; it does not guarantee
        // `initial_capacity + additional`, only that capacity never shrinks.
        assert!(struct_vec.capacity() >= struct_vec.len() + 50);
        assert!(struct_vec.capacity() >= initial_capacity);

        // The struct capacity is the minimum over its fields.
        let min_cap = struct_vec.minimum_capacity();
        for field in struct_vec.fields() {
            assert!(field.capacity() >= min_cap);
        }
        assert!(
            struct_vec
                .fields()
                .iter()
                .any(|field| field.capacity() == min_cap)
        );

        let mut empty_struct = StructVectorMut::try_new(
            Box::new([
                NullVectorMut::new(0).into(),
                BoolVectorMut::with_capacity(0).into(),
            ]),
            MaskMut::new_true(0),
        )
        .unwrap();

        empty_struct.reserve(100);
        assert!(empty_struct.capacity() >= 100);
    }

    #[test]
    fn test_freeze_and_new_unchecked() {
        let fields = Box::new([
            NullVectorMut::new(4).into(),
            BoolVectorMut::from_iter([Some(true), None, Some(false), Some(true)]).into(),
            PVectorMut::<i32>::from_iter([Some(100), Some(200), None, Some(400)]).into(),
        ]);

        let validity = Mask::from_iter([true, false, true, true])
            .try_into_mut()
            .unwrap();

        // SAFETY: every field and the validity mask have exactly 4 elements.
        let struct_vec = unsafe { StructVectorMut::new_unchecked(fields, validity) };

        assert_eq!(struct_vec.len(), 4);
        assert_eq!(struct_vec.fields().len(), 3);

        let frozen = struct_vec.freeze();

        assert_eq!(frozen.len(), 4);
        assert_eq!(frozen.fields().len(), 3);

        assert_eq!(frozen.validity().true_count(), 3);

        {
            // A cloned field shares its data with `frozen`, so it cannot be made mutable.
            let cloned_vector = frozen.fields()[1].clone();
            cloned_vector.try_into_mut().unwrap_err();
        }

        let mut fields = Arc::try_unwrap(frozen.into_parts().0).unwrap().into_vec();

        let Vector::Primitive(prim_vec) = fields.pop().unwrap() else {
            panic!("Expected primitive vector");
        };
        let prim_vec_mut = prim_vec.try_into_mut().unwrap();
        let values: Vec<_> = prim_vec_mut.into_i32().into_iter().collect();
        assert_eq!(values, vec![Some(100), Some(200), None, Some(400)]);

        let Vector::Bool(bool_vec) = fields.pop().unwrap() else {
            panic!("Expected bool vector");
        };
        let bool_vec_mut = bool_vec.try_into_mut().unwrap();
        let values: Vec<_> = bool_vec_mut.into_iter().collect();
        assert_eq!(values, vec![Some(true), None, Some(false), Some(true)]);
    }

    #[test]
    fn test_with_capacity_struct() {
        let struct_dtype = DType::Struct(
            StructFields::new(
                FieldNames::from(["null_field", "bool_field", "int_field"]),
                vec![
                    DType::Null,
                    DType::Bool(Nullability::NonNullable),
                    DType::Primitive(PType::I32, Nullability::Nullable),
                ],
            ),
            Nullability::Nullable,
        );

        let vector_mut = VectorMut::with_capacity(&struct_dtype, 100);

        match vector_mut {
            VectorMut::Struct(mut struct_vec) => {
                assert_eq!(struct_vec.len(), 0);
                assert_eq!(struct_vec.fields.len(), 3);

                assert!(matches!(struct_vec.fields[0], VectorMut::Null(_)));
                assert!(matches!(struct_vec.fields[1], VectorMut::Bool(_)));
                assert!(matches!(struct_vec.fields[2], VectorMut::Primitive(_)));

                assert!(struct_vec.capacity() >= 100);

                for _ in 0..50 {
                    struct_vec.append_nulls(1);
                }
                assert_eq!(struct_vec.len(), 50);

                // Appending within the pre-allocated capacity must not lose capacity.
                assert!(struct_vec.capacity() >= 100);
            }
            _ => panic!("Expected VectorMut::Struct"),
        }
    }
}