Skip to main content

wp_arrow/
schema.rs

1use std::sync::Arc;
2
3use arrow::datatypes::{DataType as ArrowDataType, Field as ArrowField, Schema, TimeUnit};
4
5use crate::error::WpArrowError;
6
/// WPL data types that can be mapped to Apache Arrow types.
///
/// A value can be obtained from a WPL type string with [`parse_wp_type`] and
/// converted to an Arrow type with [`to_arrow_type`]. Each variant's docs give
/// the keyword the parser accepts and the Arrow type it is mapped to.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum WpDataType {
    /// `chars`: text, mapped to Arrow `Utf8`.
    Chars,
    /// `digit`: integer, mapped to Arrow `Int64`.
    Digit,
    /// `float`: mapped to Arrow `Float64`.
    Float,
    /// `bool`: mapped to Arrow `Boolean`.
    Bool,
    /// `time`: mapped to Arrow `Timestamp(Nanosecond, None)`, i.e. nanosecond
    /// precision with no time zone attached.
    Time,
    /// `ip`: kept as Arrow `Utf8` text rather than a dedicated address type.
    Ip,
    /// `hex`: kept as Arrow `Utf8` text.
    Hex,
    /// `array<T>`: a list whose elements all have type `T`. It is mapped to an
    /// Arrow `List` whose element field is named `"item"` and is nullable.
    /// Arrays may nest, e.g. `array<array<digit>>`.
    Array(Box<WpDataType>),
}
19
20/// A named, typed field definition for building Arrow schemas.
21#[derive(Debug, Clone, PartialEq, Eq)]
22pub struct FieldDef {
23    pub name: String,
24    pub data_type: WpDataType,
25    pub nullable: bool,
26}
27
28impl FieldDef {
29    pub fn new(name: impl Into<String>, data_type: WpDataType) -> Self {
30        Self {
31            name: name.into(),
32            data_type,
33            nullable: true,
34        }
35    }
36
37    pub fn with_nullable(mut self, nullable: bool) -> Self {
38        self.nullable = nullable;
39        self
40    }
41}
42
43/// Maps a [`WpDataType`] to the corresponding [`ArrowDataType`].
44pub fn to_arrow_type(wp_type: &WpDataType) -> ArrowDataType {
45    match wp_type {
46        WpDataType::Chars => ArrowDataType::Utf8,
47        WpDataType::Digit => ArrowDataType::Int64,
48        WpDataType::Float => ArrowDataType::Float64,
49        WpDataType::Bool => ArrowDataType::Boolean,
50        WpDataType::Time => ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
51        WpDataType::Ip => ArrowDataType::Utf8,
52        WpDataType::Hex => ArrowDataType::Utf8,
53        WpDataType::Array(inner) => {
54            let inner_arrow = to_arrow_type(inner);
55            ArrowDataType::List(Arc::new(ArrowField::new("item", inner_arrow, true)))
56        }
57    }
58}
59
60/// Converts a [`FieldDef`] into an Arrow [`ArrowField`].
61///
62/// Returns an error if the field name is empty.
63pub fn to_arrow_field(field: &FieldDef) -> Result<ArrowField, WpArrowError> {
64    if field.name.is_empty() {
65        return Err(WpArrowError::EmptyFieldName);
66    }
67    let arrow_type = to_arrow_type(&field.data_type);
68    Ok(ArrowField::new(&field.name, arrow_type, field.nullable))
69}
70
71/// Converts a slice of [`FieldDef`] into an Arrow [`Schema`].
72pub fn to_arrow_schema(fields: &[FieldDef]) -> Result<Schema, WpArrowError> {
73    let arrow_fields: Vec<ArrowField> = fields
74        .iter()
75        .map(to_arrow_field)
76        .collect::<Result<_, _>>()?;
77    Ok(Schema::new(arrow_fields))
78}
79
80/// Parses a WPL type string into a [`WpDataType`].
81///
82/// Supported formats:
83/// - Basic types: `"chars"`, `"digit"`, `"float"`, `"bool"`, `"time"`, `"ip"`, `"hex"`
84/// - Array types: `"array<chars>"`, `"array<array<digit>>"`
85///
86/// Type names are case-insensitive.
87pub fn parse_wp_type(s: &str) -> Result<WpDataType, WpArrowError> {
88    let s = s.trim();
89    let lower = s.to_ascii_lowercase();
90
91    match lower.as_str() {
92        "chars" => Ok(WpDataType::Chars),
93        "digit" => Ok(WpDataType::Digit),
94        "float" => Ok(WpDataType::Float),
95        "bool" => Ok(WpDataType::Bool),
96        "time" => Ok(WpDataType::Time),
97        "ip" => Ok(WpDataType::Ip),
98        "hex" => Ok(WpDataType::Hex),
99        _ if lower.starts_with("array<") && lower.ends_with('>') => {
100            let inner_str = &s[6..s.len() - 1];
101            let inner_trimmed = inner_str.trim();
102            if inner_trimmed.is_empty() {
103                return Err(WpArrowError::InvalidArrayInnerType(String::new()));
104            }
105            let inner = parse_wp_type(inner_trimmed)?;
106            Ok(WpDataType::Array(Box::new(inner)))
107        }
108        _ => Err(WpArrowError::UnsupportedDataType(s.to_string())),
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the Arrow list type that `to_arrow_type` produces for an array
    /// whose elements map to `item`.
    fn list_of(item: ArrowDataType) -> ArrowDataType {
        ArrowDataType::List(Arc::new(ArrowField::new("item", item, true)))
    }

    /// Shorthand for `WpDataType::Array(Box::new(inner))`.
    fn array_of(inner: WpDataType) -> WpDataType {
        WpDataType::Array(Box::new(inner))
    }

    // ---------------------------------------------------------------
    // to_arrow_type: basic types
    // ---------------------------------------------------------------

    #[test]
    fn arrow_type_chars() {
        assert_eq!(to_arrow_type(&WpDataType::Chars), ArrowDataType::Utf8);
    }

    #[test]
    fn arrow_type_digit() {
        assert_eq!(to_arrow_type(&WpDataType::Digit), ArrowDataType::Int64);
    }

    #[test]
    fn arrow_type_float() {
        assert_eq!(to_arrow_type(&WpDataType::Float), ArrowDataType::Float64);
    }

    #[test]
    fn arrow_type_bool() {
        assert_eq!(to_arrow_type(&WpDataType::Bool), ArrowDataType::Boolean);
    }

    #[test]
    fn arrow_type_time() {
        assert_eq!(
            to_arrow_type(&WpDataType::Time),
            ArrowDataType::Timestamp(TimeUnit::Nanosecond, None)
        );
    }

    #[test]
    fn arrow_type_ip() {
        assert_eq!(to_arrow_type(&WpDataType::Ip), ArrowDataType::Utf8);
    }

    #[test]
    fn arrow_type_hex() {
        assert_eq!(to_arrow_type(&WpDataType::Hex), ArrowDataType::Utf8);
    }

    // ---------------------------------------------------------------
    // to_arrow_type: array types
    // ---------------------------------------------------------------

    #[test]
    fn arrow_type_array_digit() {
        // Spelled out in full once to pin the element field's name and nullability.
        let wp = WpDataType::Array(Box::new(WpDataType::Digit));
        assert_eq!(
            to_arrow_type(&wp),
            ArrowDataType::List(Arc::new(ArrowField::new(
                "item",
                ArrowDataType::Int64,
                true
            )))
        );
    }

    #[test]
    fn arrow_type_array_chars() {
        assert_eq!(
            to_arrow_type(&array_of(WpDataType::Chars)),
            list_of(ArrowDataType::Utf8)
        );
    }

    #[test]
    fn arrow_type_array_time() {
        assert_eq!(
            to_arrow_type(&array_of(WpDataType::Time)),
            list_of(ArrowDataType::Timestamp(TimeUnit::Nanosecond, None))
        );
    }

    #[test]
    fn arrow_type_nested_array() {
        let wp = array_of(array_of(WpDataType::Float));
        assert_eq!(
            to_arrow_type(&wp),
            list_of(list_of(ArrowDataType::Float64))
        );
    }

    // ---------------------------------------------------------------
    // to_arrow_field
    // ---------------------------------------------------------------

    #[test]
    fn arrow_field_basic() {
        let fd = FieldDef::new("src_ip", WpDataType::Ip);
        let field = to_arrow_field(&fd).unwrap();
        assert_eq!(field.name(), "src_ip");
        assert_eq!(field.data_type(), &ArrowDataType::Utf8);
        assert!(field.is_nullable());
    }

    #[test]
    fn arrow_field_non_nullable() {
        let fd = FieldDef::new("count", WpDataType::Digit).with_nullable(false);
        let field = to_arrow_field(&fd).unwrap();
        assert!(!field.is_nullable());
    }

    #[test]
    fn arrow_field_array_type() {
        let fd = FieldDef::new("tags", array_of(WpDataType::Chars));
        let field = to_arrow_field(&fd).unwrap();
        assert_eq!(field.name(), "tags");
        assert_eq!(field.data_type(), &list_of(ArrowDataType::Utf8));
    }

    #[test]
    fn arrow_field_empty_name_errors() {
        let fd = FieldDef::new("", WpDataType::Chars);
        assert_eq!(to_arrow_field(&fd), Err(WpArrowError::EmptyFieldName));
    }

    // ---------------------------------------------------------------
    // to_arrow_schema
    // ---------------------------------------------------------------

    #[test]
    fn arrow_schema_firewall_log() {
        let fields = vec![
            FieldDef::new("src_ip", WpDataType::Ip),
            FieldDef::new("dst_ip", WpDataType::Ip),
            FieldDef::new("port", WpDataType::Digit),
            FieldDef::new("protocol", WpDataType::Chars),
            FieldDef::new("timestamp", WpDataType::Time),
            FieldDef::new("allowed", WpDataType::Bool),
        ];
        let schema = to_arrow_schema(&fields).unwrap();
        assert_eq!(schema.fields().len(), 6);
        assert_eq!(schema.field(0).name(), "src_ip");
        assert_eq!(schema.field(2).data_type(), &ArrowDataType::Int64);
        assert_eq!(
            schema.field(4).data_type(),
            &ArrowDataType::Timestamp(TimeUnit::Nanosecond, None)
        );
    }

    #[test]
    fn arrow_schema_with_array_field() {
        let fields = vec![
            FieldDef::new("name", WpDataType::Chars),
            FieldDef::new("tags", array_of(WpDataType::Chars)),
        ];
        let schema = to_arrow_schema(&fields).unwrap();
        assert_eq!(schema.fields().len(), 2);
        assert!(matches!(
            schema.field(1).data_type(),
            ArrowDataType::List(_)
        ));
    }

    #[test]
    fn arrow_schema_preserves_order_and_nullability() {
        let fields = vec![
            FieldDef::new("a", WpDataType::Chars),
            FieldDef::new("b", WpDataType::Digit).with_nullable(false),
        ];
        let schema = to_arrow_schema(&fields).unwrap();
        assert_eq!(schema.field(0).name(), "a");
        assert!(schema.field(0).is_nullable());
        assert_eq!(schema.field(1).name(), "b");
        assert!(!schema.field(1).is_nullable());
    }

    #[test]
    fn arrow_schema_empty_fields() {
        let schema = to_arrow_schema(&[]).unwrap();
        assert_eq!(schema.fields().len(), 0);
    }

    #[test]
    fn arrow_schema_error_propagation() {
        let fields = vec![
            FieldDef::new("ok", WpDataType::Chars),
            FieldDef::new("", WpDataType::Digit),
        ];
        assert_eq!(to_arrow_schema(&fields), Err(WpArrowError::EmptyFieldName));
    }

    #[test]
    fn arrow_schema_error_on_first_field() {
        let fields = vec![
            FieldDef::new("", WpDataType::Chars),
            FieldDef::new("ok", WpDataType::Digit),
        ];
        assert_eq!(to_arrow_schema(&fields), Err(WpArrowError::EmptyFieldName));
    }

    // ---------------------------------------------------------------
    // parse_wp_type: basic types
    // ---------------------------------------------------------------

    #[test]
    fn parse_chars() {
        assert_eq!(parse_wp_type("chars"), Ok(WpDataType::Chars));
    }

    #[test]
    fn parse_digit() {
        assert_eq!(parse_wp_type("digit"), Ok(WpDataType::Digit));
    }

    #[test]
    fn parse_float() {
        assert_eq!(parse_wp_type("float"), Ok(WpDataType::Float));
    }

    #[test]
    fn parse_bool() {
        assert_eq!(parse_wp_type("bool"), Ok(WpDataType::Bool));
    }

    #[test]
    fn parse_time() {
        assert_eq!(parse_wp_type("time"), Ok(WpDataType::Time));
    }

    #[test]
    fn parse_ip() {
        assert_eq!(parse_wp_type("ip"), Ok(WpDataType::Ip));
    }

    #[test]
    fn parse_hex() {
        assert_eq!(parse_wp_type("hex"), Ok(WpDataType::Hex));
    }

    #[test]
    fn parse_basic_with_surrounding_whitespace() {
        assert_eq!(parse_wp_type("\tdigit\n"), Ok(WpDataType::Digit));
    }

    // ---------------------------------------------------------------
    // parse_wp_type: case insensitivity
    // ---------------------------------------------------------------

    #[test]
    fn parse_case_insensitive() {
        assert_eq!(parse_wp_type("CHARS"), Ok(WpDataType::Chars));
        assert_eq!(parse_wp_type("Digit"), Ok(WpDataType::Digit));
        assert_eq!(parse_wp_type("BOOL"), Ok(WpDataType::Bool));
    }

    #[test]
    fn parse_array_keyword_case_insensitive() {
        assert_eq!(parse_wp_type("ARRAY<Chars>"), Ok(array_of(WpDataType::Chars)));
        assert_eq!(
            parse_wp_type("Array<ARRAY<hex>>"),
            Ok(array_of(array_of(WpDataType::Hex)))
        );
    }

    // ---------------------------------------------------------------
    // parse_wp_type: array types
    // ---------------------------------------------------------------

    #[test]
    fn parse_array_chars() {
        assert_eq!(parse_wp_type("array<chars>"), Ok(array_of(WpDataType::Chars)));
    }

    #[test]
    fn parse_array_digit() {
        assert_eq!(parse_wp_type("array<digit>"), Ok(array_of(WpDataType::Digit)));
    }

    #[test]
    fn parse_nested_array() {
        assert_eq!(
            parse_wp_type("array<array<float>>"),
            Ok(array_of(array_of(WpDataType::Float)))
        );
    }

    #[test]
    fn parse_array_with_whitespace() {
        assert_eq!(
            parse_wp_type("  array< chars >  "),
            Ok(array_of(WpDataType::Chars))
        );
    }

    #[test]
    fn parse_deeply_nested_array() {
        let depth = 64;
        let text = format!("{}bool{}", "array<".repeat(depth), ">".repeat(depth));
        let expected = (0..depth).fold(WpDataType::Bool, |inner, _| array_of(inner));
        assert_eq!(parse_wp_type(&text), Ok(expected));
    }

    #[test]
    fn parse_then_convert_to_arrow() {
        let wp = parse_wp_type("array<array<ip>>").unwrap();
        assert_eq!(to_arrow_type(&wp), list_of(list_of(ArrowDataType::Utf8)));
    }

    // ---------------------------------------------------------------
    // parse_wp_type: error cases
    // ---------------------------------------------------------------

    #[test]
    fn parse_unsupported_type() {
        let err = parse_wp_type("unknown").unwrap_err();
        assert_eq!(
            err,
            WpArrowError::UnsupportedDataType("unknown".to_string())
        );
    }

    #[test]
    fn parse_unsupported_keeps_original_case() {
        assert_eq!(
            parse_wp_type("  Unknown "),
            Err(WpArrowError::UnsupportedDataType("Unknown".to_string()))
        );
    }

    #[test]
    fn parse_empty_input() {
        assert_eq!(
            parse_wp_type(""),
            Err(WpArrowError::UnsupportedDataType(String::new()))
        );
        assert_eq!(
            parse_wp_type("   "),
            Err(WpArrowError::UnsupportedDataType(String::new()))
        );
    }

    #[test]
    fn parse_array_empty_inner() {
        let err = parse_wp_type("array<>").unwrap_err();
        assert_eq!(err, WpArrowError::InvalidArrayInnerType(String::new()));
    }

    #[test]
    fn parse_array_whitespace_only_inner() {
        assert_eq!(
            parse_wp_type("array<   >"),
            Err(WpArrowError::InvalidArrayInnerType(String::new()))
        );
    }

    #[test]
    fn parse_nested_array_empty_inner() {
        assert_eq!(
            parse_wp_type("array<array<>>"),
            Err(WpArrowError::InvalidArrayInnerType(String::new()))
        );
    }

    #[test]
    fn parse_array_invalid_inner() {
        let err = parse_wp_type("array<invalid>").unwrap_err();
        assert_eq!(
            err,
            WpArrowError::UnsupportedDataType("invalid".to_string())
        );
    }

    #[test]
    fn parse_unclosed_array() {
        assert_eq!(
            parse_wp_type("array<chars"),
            Err(WpArrowError::UnsupportedDataType("array<chars".to_string()))
        );
    }

    #[test]
    fn parse_array_extra_closing_bracket() {
        assert_eq!(
            parse_wp_type("array<chars>>"),
            Err(WpArrowError::UnsupportedDataType("chars>".to_string()))
        );
    }

    // ---------------------------------------------------------------
    // Property tests: Clone, Eq, Hash, FieldDef defaults
    // ---------------------------------------------------------------

    #[test]
    fn wp_data_type_clone_eq() {
        let a = array_of(WpDataType::Chars);
        let b = a.clone();
        assert_eq!(a, b);
    }

    #[test]
    fn wp_data_type_hash_consistent() {
        use std::collections::HashSet;
        let mut set = HashSet::new();
        set.insert(WpDataType::Digit);
        set.insert(WpDataType::Digit);
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn field_def_default_nullable() {
        let fd = FieldDef::new("test", WpDataType::Bool);
        assert!(fd.nullable);
    }

    #[test]
    fn field_def_with_nullable_last_call_wins() {
        let fd = FieldDef::new("test", WpDataType::Bool)
            .with_nullable(false)
            .with_nullable(true);
        assert!(fd.nullable);
    }
}