// xml2arrow/config.rs

1use std::{
2    fs::File,
3    io::{BufReader, BufWriter},
4    path::Path,
5};
6
7use crate::errors::{Error, Result};
8use arrow::datatypes::DataType;
9use serde::{Deserialize, Serialize};
10
11/// Configuration for the XML parser.
12#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
13pub struct ParserOptions {
14    /// Whether to trim whitespace from text nodes. Defaults to false.
15    #[serde(default)]
16    pub trim_text: bool,
17    /// Optional XML paths where parsing should stop after the closing tag.
18    #[serde(default)]
19    pub stop_at_paths: Vec<String>,
20}
21
22impl Default for ParserOptions {
23    fn default() -> Self {
24        Self {
25            trim_text: false,
26            stop_at_paths: Vec::new(),
27        }
28    }
29}
30
31/// Top-level configuration for XML to Arrow conversion.
32///
33/// This struct holds a collection of `TableConfig` structs, each defining how a specific
34/// part of the XML document should be parsed into an Arrow table.
35#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
36pub struct Config {
37    /// A vector of `TableConfig` structs, each defining a table to be extracted from the XML.
38    pub tables: Vec<TableConfig>,
39    /// Parser options.
40    #[serde(default)]
41    pub parser_options: ParserOptions,
42}
43
44impl Config {
45    /// Validates the configuration by checking all field configurations.
46    ///
47    /// Returns an error if any field uses an unsupported combination (e.g., scale/offset on non-float types).
48    pub fn validate(&self) -> Result<()> {
49        for table in &self.tables {
50            for field in &table.fields {
51                field.validate()?;
52            }
53        }
54        Ok(())
55    }
56
57    /// Creates a `Config` struct from a YAML configuration file.
58    ///
59    /// This function reads a YAML file at the given path and deserializes it into a `Config` struct.
60    ///
61    /// # Arguments
62    ///
63    /// *   `path`: The path to the YAML configuration file.
64    ///
65    /// # Returns
66    ///
67    /// A `Result` containing:
68    ///
69    /// *   `Ok(Config)`: The deserialized `Config` struct.
70    /// *   `Err(Error)`: An `Error` value if the file cannot be opened, read, or parsed as YAML.
71    ///
72    /// # Errors
73    ///
74    /// This function may return the following errors:
75    ///
76    /// *   `Error::Io`: If an I/O error occurs while opening or reading the file.
77    /// *   `Error::Yaml`: If there is an error parsing the YAML data.
78    pub fn from_yaml_file(path: impl AsRef<Path>) -> Result<Self> {
79        let file = File::open(path)?;
80        let reader = BufReader::new(file);
81        let config: Config = serde_yaml::from_reader(reader).map_err(Error::Yaml)?;
82        config.validate()?;
83        Ok(config)
84    }
85
86    /// Writes the `Config` struct to a YAML file.
87    ///
88    /// This function serializes the `Config` struct to YAML format and writes it to a file at the given path.
89    ///
90    /// # Arguments
91    ///
92    /// *   `path`: The path to the output YAML file.
93    ///
94    /// # Returns
95    ///
96    /// A `Result` containing:
97    ///
98    /// *   `Ok(())`: If the `Config` was successfully written to the file.
99    /// *   `Err(Error)`: An `Error` value if the file cannot be created or the `Config` cannot be serialized to YAML.
100    ///
101    /// # Errors
102    ///
103    /// This function may return the following errors:
104    ///
105    /// *   `Error::Io`: If an I/O error occurs while creating or writing to the file.
106    /// *   `Error::Yaml`: If there is an error serializing the `Config` to YAML.
107    pub fn to_yaml_file(&self, path: impl AsRef<Path>) -> Result<()> {
108        let file = File::create(path)?;
109        let writer = BufWriter::new(file);
110        serde_yaml::to_writer(writer, self).map_err(Error::Yaml)
111    }
112
113    /// Checks if the configuration contains any fields that require attribute parsing.
114    ///
115    /// This method iterates through all tables and their fields in the configuration and returns
116    /// `true` if any field's XML path contains the "@" symbol, indicating that it targets an attribute.
117    ///
118    /// # Returns
119    ///
120    /// `true` if the configuration contains at least one attribute to parse, `false` otherwise.
121    pub fn requires_attribute_parsing(&self) -> bool {
122        for table in &self.tables {
123            for field in &table.fields {
124                if field.xml_path.contains('@') {
125                    return true;
126                }
127            }
128        }
129        false
130    }
131}
132
133/// Configuration for an XML table to be parsed into an Arrow record batch.
134///
135/// This struct defines how an XML structure should be interpreted as a table, including
136/// the path to the table elements, the element representing a row, and the configuration
137/// of the fields (columns) within the table.
138#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
139pub struct TableConfig {
140    /// The name of the table.
141    pub name: String,
142    /// The XML path to the table elements. For example `/data/dataset/table`.
143    pub xml_path: String,
144    /// The levels of nesting for this table. This is used to create the indices for nested tables.
145    /// For example if the xml_path is `/data/dataset/table/item/properties` the levels should
146    /// be `["table", "properties"]`.
147    pub levels: Vec<String>,
148    /// A vector of `FieldConfig` structs, each defining a field (column) in the table.
149    pub fields: Vec<FieldConfig>,
150}
151
152impl TableConfig {
153    pub fn new(name: &str, xml_path: &str, levels: Vec<String>, fields: Vec<FieldConfig>) -> Self {
154        Self {
155            name: name.to_string(),
156            xml_path: xml_path.to_string(),
157            levels,
158            fields,
159        }
160    }
161}
162
163/// Configuration for a single field within an XML table.
164///
165/// This struct defines how a specific XML element or attribute should be extracted and
166/// converted into an Arrow column.
167#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
168pub struct FieldConfig {
169    /// The name of the field (and the name of the resulting Arrow column).
170    pub name: String,
171    /// The XML path to the element or attribute.
172    pub xml_path: String,
173    /// The data type of the field. This determines the Arrow data type of the resulting column.
174    pub data_type: DType,
175    /// Whether the field is nullable (can contain null values). Defaults to false.
176    #[serde(default)]
177    pub nullable: bool,
178    /// Scale for decimal types.
179    pub scale: Option<f64>,
180    /// Offset for decimal types.
181    pub offset: Option<f64>,
182}
183
184impl FieldConfig {
185    /// Validates that scale/offset are only used with floating point data types.
186    pub fn validate(&self) -> Result<()> {
187        match self.data_type {
188            DType::Float32 | DType::Float64 => Ok(()),
189            _ => {
190                if self.scale.is_some() {
191                    return Err(Error::UnsupportedConversion(format!(
192                        "Scaling is only supported for Float32 and Float64, not {:?}",
193                        self.data_type
194                    )));
195                }
196                if self.offset.is_some() {
197                    return Err(Error::UnsupportedConversion(format!(
198                        "Offset is only supported for Float32 and Float64, not {:?}",
199                        self.data_type
200                    )));
201                }
202                Ok(())
203            }
204        }
205    }
206}
207/// A builder for configuring a `FieldConfig` struct.
208///
209/// This builder allows you to set the various properties of a field
210/// definition within a table configuration for parsing XML data.
211#[derive(Default)]
212pub struct FieldConfigBuilder {
213    name: String,
214    xml_path: String,
215    data_type: DType,
216    nullable: bool,
217    scale: Option<f64>,
218    offset: Option<f64>,
219}
220
221impl FieldConfigBuilder {
222    /// Creates a new `FieldConfigBuilder` with the provided name, XML path, and data type.
223    ///
224    /// This is the starting point for building a `FieldConfig`.
225    ///
226    /// # Arguments
227    ///
228    /// * `name` - The name of the field.
229    /// * `xml_path` - The XML path that points to the location of the field data in the XML document.
230    /// * `data_type` - The data type of the field.
231    ///
232    /// # Returns
233    ///
234    /// A new `FieldConfigBuilder` instance with the provided properties.
235    pub fn new(name: &str, xml_path: &str, data_type: DType) -> Self {
236        Self {
237            name: name.to_string(),
238            xml_path: xml_path.to_string(),
239            data_type,
240            ..Default::default()
241        }
242    }
243
244    /// Sets the `nullable` flag for the field configuration being built.
245    ///
246    /// This method allows you to specify whether the field can be null (missing data) in the XML document.
247    ///
248    /// # Arguments
249    ///
250    /// * `nullable` - A boolean value indicating whether the field is nullable.
251    ///
252    /// # Returns
253    ///
254    /// The builder instance itself, allowing for method chaining.
255    pub fn nullable(mut self, nullable: bool) -> Self {
256        self.nullable = nullable;
257        self
258    }
259
260    /// Sets the `scale` factor for the field configuration being built.
261    ///
262    /// This method is typically used with float data types to specify the scale factor.
263    ///
264    /// # Arguments
265    ///
266    /// * `scale` - The scale factor as an f64 value.
267    ///
268    /// # Returns
269    ///
270    /// The builder instance itself, allowing for method chaining.
271    pub fn scale(mut self, scale: f64) -> Self {
272        self.scale = Some(scale);
273        self
274    }
275
276    /// Sets the `offset` value for the field configuration being built.
277    ///
278    /// This method can be used with float data types to specify an offset value.
279    ///
280    /// # Arguments
281    ///
282    /// * `offset` - The offset value as an f64 value.
283    ///
284    /// # Returns
285    ///
286    /// The builder instance itself, allowing for method chaining.
287    pub fn offset(mut self, offset: f64) -> Self {
288        self.offset = Some(offset);
289        self
290    }
291
292    /// Consumes the builder and builds the final `FieldConfig` struct.
293    ///
294    /// This method takes the configuration set on the builder and returns a new `FieldConfig` instance.
295    ///
296    /// # Returns
297    ///
298    /// A `FieldConfig` struct with the configured properties
299    pub fn build(self) -> Result<FieldConfig> {
300        let cfg = FieldConfig {
301            name: self.name,
302            xml_path: self.xml_path,
303            data_type: self.data_type,
304            nullable: self.nullable,
305            scale: self.scale,
306            offset: self.offset,
307        };
308        cfg.validate()?;
309        Ok(cfg)
310    }
311}
312
313/// Represents the data type of a field.
314#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
315pub enum DType {
316    Boolean,
317    Float32,
318    Float64,
319    Int8,
320    UInt8,
321    Int16,
322    UInt16,
323    Int32,
324    UInt32,
325    Int64,
326    UInt64,
327    #[default]
328    Utf8,
329}
330
331impl DType {
332    pub(crate) fn as_arrow_type(&self) -> DataType {
333        match self {
334            DType::Boolean => DataType::Boolean,
335            DType::Float32 => DataType::Float32,
336            DType::Float64 => DataType::Float64,
337            DType::Utf8 => DataType::Utf8,
338            DType::Int8 => DataType::Int8,
339            DType::UInt8 => DataType::UInt8,
340            DType::Int16 => DataType::Int16,
341            DType::UInt16 => DataType::UInt16,
342            DType::Int32 => DataType::Int32,
343            &DType::UInt32 => DataType::UInt32,
344            DType::Int64 => DataType::Int64,
345            DType::UInt64 => DataType::UInt64,
346        }
347    }
348}
349
/// Creates a validated `Config` from a YAML string.
///
/// The macro expands to an expression that parses the given YAML into a `Config` and
/// then validates it. Parsing and validation happen when the expanded expression is
/// evaluated, not during compilation.
///
/// The expansion refers to `serde_yaml` by its bare crate name, so that name must be
/// resolvable where the macro is invoked.
///
/// # Panics
///
/// Panics if the YAML cannot be parsed into a `Config`, or if the parsed configuration
/// fails validation.
#[macro_export]
macro_rules! config_from_yaml {
    ($yaml:expr) => {{
        let parsed: $crate::config::Config = match serde_yaml::from_str($yaml) {
            Ok(parsed) => parsed,
            Err(e) => panic!("Invalid YAML configuration: {}", e),
        };
        if let Err(e) = parsed.validate() {
            panic!("Invalid configuration: {:?}", e);
        }
        parsed
    }};
}
368
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::*;
    use rstest::rstest;

    /// Finishes `builder`, panicking with the build error if validation rejects it.
    fn built(builder: FieldConfigBuilder) -> FieldConfig {
        builder
            .build()
            .unwrap_or_else(|e| panic!("Failed to build field config: {:?}", e))
    }

    /// Wraps `fields` in a config holding a single table named "test" at `/root`.
    fn single_table_config(fields: Vec<FieldConfig>) -> Config {
        Config {
            tables: vec![TableConfig::new("test", "/root", vec![], fields)],
            parser_options: Default::default(),
        }
    }

    /// Writing a config to YAML and reading it back must yield an equal config.
    #[rstest]
    fn test_yaml_config_roundtrip(
        #[values(
            Config {
                parser_options: Default::default(),
                tables: vec![TableConfig::new(
                    "table1",
                    "/path/to",
                    vec![],
                    vec![
                        built(
                            FieldConfigBuilder::new("string_field", "/path/to/string_field", DType::Utf8)
                                .nullable(true)
                        ),
                        built(FieldConfigBuilder::new("int32_field", "/path/to/int32_field", DType::Int32)),
                        built(
                            FieldConfigBuilder::new("float64_field", "/path/to/float64_field", DType::Float64)
                                .nullable(true)
                                .scale(1.0e-9)
                                .offset(1.0e-3)
                        ),
                    ],
                )],
            },
            Config {
                parser_options: Default::default(),
                tables: vec![],
            }
        )]
        config: Config,
    ) {
        // Serialize into a temporary file...
        let temp_file = tempfile::NamedTempFile::new().unwrap();
        let path = temp_file.path();
        config.to_yaml_file(path).unwrap();

        // ...and load it back from the very same file.
        let read_config = Config::from_yaml_file(path).unwrap();

        assert_eq!(config, read_config);
    }

    /// YAML that does not match the `Config` schema is reported as a YAML error.
    #[test]
    fn test_yaml_from_file_invalid_content() {
        let invalid_yaml = "tables:\n  - name: table1\n    row_element: /path\n    fields:\n      - name: field1\n        xml_path: path\n        type: InvalidType\n        nullable: true";
        let temp_file = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(temp_file.path(), invalid_yaml).unwrap();

        let result = Config::from_yaml_file(temp_file.path());

        assert!(matches!(result, Err(Error::Yaml(_))));
    }

    /// A missing input file is reported as an I/O error.
    #[test]
    fn test_yaml_from_file_not_found() {
        let missing = PathBuf::from("not_existing.yaml");

        let result = Config::from_yaml_file(missing);

        assert!(matches!(result, Err(Error::Io(_))));
    }

    /// An output path inside a non-existent directory is reported as an I/O error.
    #[test]
    fn test_yaml_to_file_invalid_path() {
        let config = Config {
            tables: vec![],
            parser_options: Default::default(),
        };
        let unwritable = PathBuf::from("/not/existing/path/config.yaml");

        let result = config.to_yaml_file(unwritable);

        assert!(matches!(result, Err(Error::Io(_))));
    }

    /// `nullable` is false when the key is omitted from a field definition.
    #[test]
    fn test_yaml_field_nullable_default() {
        let yaml = r#"
            name: test_field
            xml_path: /path/to/field
            data_type: Utf8
            "#;

        let field: FieldConfig = serde_yaml::from_str(yaml).unwrap();

        assert!(!field.nullable);
    }

    /// `trim_text` is false when the whole `parser_options` section is omitted.
    #[test]
    fn test_yaml_parser_options_trim_text_default() {
        let yaml = r#"
            tables:
              - name: test_table
                xml_path: /root
                levels: []
                fields:
                  - name: bool_field
                    xml_path: /root/value
                    data_type: Boolean
                    nullable: true
            "#;

        let config: Config = serde_yaml::from_str(yaml).unwrap();

        assert!(
            !config.parser_options.trim_text,
            "trim_text should default to false"
        );
    }

    /// An explicit `trim_text: true` is honoured.
    #[test]
    fn test_yaml_parser_options_trim_text_explicit() {
        let yaml = r#"
            parser_options:
              trim_text: true
            tables: []
            "#;

        let config: Config = serde_yaml::from_str(yaml).unwrap();

        assert!(
            config.parser_options.trim_text,
            "trim_text should be true when explicitly set"
        );
    }

    /// `trim_text` is false when `parser_options` is present but empty.
    #[test]
    fn test_yaml_parser_options_empty_section() {
        let yaml = r#"
            parser_options: {}
            tables: []
            "#;

        let config: Config = serde_yaml::from_str(yaml).unwrap();

        assert!(
            !config.parser_options.trim_text,
            "trim_text should default to false when parser_options is empty"
        );
    }

    /// A single attribute path is enough to require attribute parsing.
    #[test]
    fn test_config_requires_attr_parsing_with_attributes() {
        let config = single_table_config(vec![built(FieldConfigBuilder::new(
            "id",
            "/root/item/@id",
            DType::Int32,
        ))]);

        assert!(config.requires_attribute_parsing());
    }

    /// Element-only paths do not require attribute parsing.
    #[test]
    fn test_config_requires_attr_parsing_without_attributes() {
        let config = single_table_config(vec![built(FieldConfigBuilder::new(
            "id",
            "/root/item/id",
            DType::Int32,
        ))]);

        assert!(!config.requires_attribute_parsing());
    }

    /// A mix of element and attribute paths requires attribute parsing.
    #[test]
    fn test_config_requires_attr_parsing_mixed() {
        let config = single_table_config(vec![
            built(FieldConfigBuilder::new("id", "/root/item/id", DType::Int32)),
            built(FieldConfigBuilder::new(
                "type",
                "/root/item/@type",
                DType::Utf8,
            )),
        ]);

        assert!(config.requires_attribute_parsing());
    }

    /// Every `DType` variant maps to the Arrow data type of the same name.
    #[test]
    fn test_dtype_as_arrow_type_all_variants() {
        use arrow::datatypes::DataType as ArrowDataType;

        let expected = [
            (DType::Boolean, ArrowDataType::Boolean),
            (DType::Float32, ArrowDataType::Float32),
            (DType::Float64, ArrowDataType::Float64),
            (DType::Utf8, ArrowDataType::Utf8),
            (DType::Int8, ArrowDataType::Int8),
            (DType::UInt8, ArrowDataType::UInt8),
            (DType::Int16, ArrowDataType::Int16),
            (DType::UInt16, ArrowDataType::UInt16),
            (DType::Int32, ArrowDataType::Int32),
            (DType::UInt32, ArrowDataType::UInt32),
            (DType::Int64, ArrowDataType::Int64),
            (DType::UInt64, ArrowDataType::UInt64),
        ];

        for (dtype, arrow_type) in expected {
            assert_eq!(dtype.as_arrow_type(), arrow_type);
        }
    }

    /// All builder setters can be chained and end up in the built field.
    #[test]
    fn test_builder_field_config_chaining() {
        let builder = FieldConfigBuilder::new("test_field", "/path/to/field", DType::Float64)
            .nullable(true)
            .scale(0.001)
            .offset(100.0);

        let field = builder.build().unwrap();

        assert_eq!(field.name, "test_field");
        assert_eq!(field.xml_path, "/path/to/field");
        assert_eq!(field.data_type, DType::Float64);
        assert!(field.nullable);
        assert_eq!(field.scale, Some(0.001));
        assert_eq!(field.offset, Some(100.0));
    }

    /// Setting only the scale leaves the offset unset.
    #[test]
    fn test_builder_field_config_scale_only() {
        let builder = FieldConfigBuilder::new("test", "/path", DType::Float32).scale(0.5);

        let field = builder.build().unwrap();

        assert_eq!(field.scale, Some(0.5));
        assert_eq!(field.offset, None);
    }

    /// Setting only the offset leaves the scale unset.
    #[test]
    fn test_builder_field_config_offset_only() {
        let builder = FieldConfigBuilder::new("test", "/path", DType::Float64).offset(5.0);

        let field = builder.build().unwrap();

        assert_eq!(field.scale, None);
        assert_eq!(field.offset, Some(5.0));
    }
}