vortex_fsst/compute/
cast.rs1use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::arrays::VarBinArray;
9use vortex_array::arrays::varbin::VarBinArrayExt;
10use vortex_array::dtype::DType;
11use vortex_array::scalar_fn::fns::cast::CastKernel;
12use vortex_array::scalar_fn::fns::cast::CastReduce;
13use vortex_array::validity::Validity;
14use vortex_error::VortexResult;
15
16use crate::FSST;
17use crate::FSSTArrayExt;
18
19fn build_with_codes_validity(
20 array: ArrayView<'_, FSST>,
21 dtype: &DType,
22 new_codes_validity: Validity,
23) -> VortexResult<ArrayRef> {
24 let codes = array.codes();
25 let new_codes = VarBinArray::try_new(
26 codes.offsets().clone(),
27 codes.bytes().clone(),
28 codes.dtype().with_nullability(dtype.nullability()),
29 new_codes_validity,
30 )?;
31
32 Ok(unsafe {
33 FSST::new_unchecked(
34 dtype.clone(),
35 array.symbols().clone(),
36 array.symbol_lengths().clone(),
37 new_codes,
38 array.uncompressed_lengths().clone(),
39 )
40 }
41 .into_array())
42}
43
44impl CastReduce for FSST {
45 fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
46 if !array.dtype().eq_ignore_nullability(dtype) {
47 return Ok(None);
48 }
49
50 let codes = array.codes();
51 let Some(new_codes_validity) = codes
52 .validity()?
53 .trivial_cast_nullability(dtype.nullability(), codes.len())?
54 else {
55 return Ok(None);
56 };
57
58 Ok(Some(build_with_codes_validity(
59 array,
60 dtype,
61 new_codes_validity,
62 )?))
63 }
64}
65
66impl CastKernel for FSST {
67 fn cast(
68 array: ArrayView<'_, Self>,
69 dtype: &DType,
70 ctx: &mut ExecutionCtx,
71 ) -> VortexResult<Option<ArrayRef>> {
72 if !array.dtype().eq_ignore_nullability(dtype) {
73 return Ok(None);
74 }
75
76 let codes = array.codes();
77 let new_codes_validity =
78 codes
79 .validity()?
80 .cast_nullability(dtype.nullability(), codes.len(), ctx)?;
81
82 Ok(Some(build_with_codes_validity(
83 array,
84 dtype,
85 new_codes_validity,
86 )?))
87 }
88}
89
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::VarBinArray;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::session::ArraySession;
    use vortex_session::VortexSession;

    use crate::fsst_compress;
    use crate::fsst_train_compressor;

    /// Session shared by the tests in this module; each test creates its own
    /// execution context from it.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// Casting a non-nullable FSST-compressed UTF-8 array to nullable UTF-8 succeeds
    /// and reports the nullable dtype.
    #[test]
    fn test_cast_fsst_nullability() {
        let mut ctx = SESSION.create_execution_ctx();
        let input = VarBinArray::from_iter(
            vec![Some("hello"), Some("world"), Some("hello world")],
            DType::Utf8(Nullability::NonNullable),
        );

        let compressor = fsst_train_compressor(&input);
        let row_count = input.len();
        let input_dtype = input.dtype().clone();
        let compressed = fsst_compress(input, row_count, &input_dtype, &compressor, &mut ctx);

        let target = DType::Utf8(Nullability::Nullable);
        let widened = compressed.into_array().cast(target.clone()).unwrap();
        assert_eq!(widened.dtype(), &target);
    }

    /// Runs the shared cast conformance suite against FSST-compressed inputs:
    /// a non-nullable array, a nullable array containing a null, and a single element.
    #[rstest]
    #[case(VarBinArray::from_iter(
        vec![Some("hello"), Some("world"), Some("hello world")],
        DType::Utf8(Nullability::NonNullable)
    ))]
    #[case(VarBinArray::from_iter(
        vec![Some("foo"), None, Some("bar"), Some("foobar")],
        DType::Utf8(Nullability::Nullable)
    ))]
    #[case(VarBinArray::from_iter(
        vec![Some("test")],
        DType::Utf8(Nullability::NonNullable)
    ))]
    fn test_cast_fsst_conformance(#[case] array: VarBinArray) {
        let mut ctx = SESSION.create_execution_ctx();
        let compressor = fsst_train_compressor(&array);
        let encoded =
            fsst_compress(&array, array.len(), array.dtype(), &compressor, &mut ctx).into_array();
        test_cast_conformance(&encoded);
    }
}