yaxp_core/xsdp/
parser.rs

1use arrow::datatypes::{DataType, Field, Schema as ArrowSchema, TimeUnit};
2use encoding_rs::{Encoding, UTF_8};
3use encoding_rs_io::DecodeReaderBytesBuilder;
4use indexmap::IndexMap;
5use polars::datatypes::TimeUnit as PolarsTimeUnit;
6use polars::datatypes::{DataType as PolarsDataType, PlSmallStr};
7use polars::prelude::Schema as PolarsSchema;
8#[cfg(feature = "python")]
9use pyo3::exceptions::PyValueError;
10#[cfg(feature = "python")]
11use pyo3::prelude::{PyAnyMethods, PyDictMethods};
12#[cfg(feature = "python")]
13use pyo3::types::{PyDict, PyString};
14#[cfg(feature = "python")]
15use pyo3::{Bound, FromPyObject, IntoPyObject, PyAny, PyResult, Python};
16use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
17use roxmltree::Document;
18use serde::{Deserialize, Serialize};
19use serde_json::json;
20use std::collections::HashMap;
21#[cfg(feature = "python")]
22use std::convert::Infallible;
23use std::fs::File;
24use std::io::Read;
25use std::str::FromStr;
26use std::sync::{Arc, Mutex};
27use std::{fmt, fs};
28use std::path::PathBuf;
29
30/// Converting the `TimestampUnit` enum to a string representation for Polars breaks
31/// on "μs" when passing from rust to python. We handle that here by converting it to "us".
32
33#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
34pub enum TimestampUnit {
35    Ms,
36    Us,
37    Ns,
38}
39
40impl FromStr for TimestampUnit {
41    type Err = String;
42    fn from_str(s: &str) -> Result<Self, Self::Err> {
43        match s {
44            "ns" => Ok(TimestampUnit::Ns),
45            "ms" => Ok(TimestampUnit::Ms),
46            "us" => Ok(TimestampUnit::Us),
47            "μs" => Ok(TimestampUnit::Us),
48            _ => Err(format!("Invalid precision: {}. Available: ms, us, ns", s)),
49        }
50    }
51}
52
#[cfg(feature = "python")]
impl<'py> IntoPyObject<'py> for TimestampUnit {
    type Target = <&'py str as IntoPyObject<'py>>::Target;
    type Output = <&'py str as IntoPyObject<'py>>::Output;
    type Error = Infallible;

    /// Converts the unit into a Python `str`.
    ///
    /// Microseconds are always emitted as ASCII `"us"` rather than `"μs"`, since the
    /// non-ASCII spelling breaks when the value crosses from Rust into Python for Polars.
    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
        let label: &'static str = match self {
            Self::Ns => "ns",
            Self::Us => "us",
            Self::Ms => "ms",
        };
        label.into_pyobject(py)
    }
}
68
69impl fmt::Display for TimestampUnit {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        write!(f, "{:?}", self)
72    }
73}
74
75#[derive(Serialize, Deserialize, Debug, Clone)]
76#[cfg_attr(feature = "python", derive(IntoPyObject))]
77pub struct TimestampOptions {
78    pub time_unit: Option<TimestampUnit>,
79    pub time_zone: Option<String>,
80}
81
#[cfg(feature = "python")]
impl<'source> FromPyObject<'source> for TimestampUnit {
    /// Extracts a unit from a Python `str` such as `"ms"`, `"us"` or `"ns"`.
    ///
    /// Non-string objects propagate pyo3's extraction error; unknown unit names are
    /// raised as `ValueError` carrying the parser's message.
    fn extract_bound(bound: &pyo3::Bound<'source, PyAny>) -> PyResult<Self> {
        let text = bound.extract::<String>()?;
        text.parse::<TimestampUnit>().map_err(PyValueError::new_err)
    }
}
92
/// Looks up `key` in a Python dict and extracts its value as an optional string.
///
/// Returns `Ok(None)` when the key is missing **or** its value is Python `None`, so a
/// dict such as `{"time_unit": "ms", "time_zone": None}` is accepted. Any other
/// non-string value is reported as an extraction error.
#[cfg(feature = "python")]
fn _get_extracted_string(dict: &Bound<PyDict>, key: &str) -> PyResult<Option<String>> {
    match dict.get_item(key)? {
        // Extracting as `Option<String>` maps Python `None` to `None` instead of failing.
        Some(item) => item.extract::<Option<String>>(),
        None => Ok(None),
    }
}
101
#[cfg(feature = "python")]
impl<'source> FromPyObject<'source> for TimestampOptions {
    /// Builds options from a Python dict with optional `time_unit` / `time_zone` keys.
    ///
    /// Non-dict input fails the downcast. An unrecognised `time_unit` string is raised
    /// as `ValueError`.
    fn extract_bound(bound: &pyo3::Bound<'source, PyAny>) -> PyResult<Self> {
        let dict = bound.downcast::<PyDict>()?;

        // Both raw values are read before the unit string is validated.
        let raw_unit = _get_extracted_string(dict, "time_unit")?;
        let time_zone = _get_extracted_string(dict, "time_zone")?;

        let time_unit = raw_unit
            .map(|text| text.parse::<TimestampUnit>())
            .transpose()
            .map_err(PyValueError::new_err)?;

        Ok(Self {
            time_unit,
            time_zone,
        })
    }
}
122
123fn map_avro_data_type(dt: &str) -> AvroType {
124    match dt.to_lowercase().as_str() {
125        // primitive types
126        "null" | "xs:null" => AvroType::Simple("null".to_string()),
127        "boolean" | "xs:boolean" => AvroType::Simple("boolean".to_string()),
128        "int" | "xs:int" => AvroType::Simple("int".to_string()),
129        "long" | "xs:long" => AvroType::Simple("long".to_string()),
130        "float" | "xs:float" => AvroType::Simple("float".to_string()),
131        "double" | "xs:double" => AvroType::Simple("double".to_string()),
132        "bytes" | "xs:bytes" | "xs:base64binary" => AvroType::Simple("bytes".to_string()),
133        "string" | "xs:string" => AvroType::Simple("string".to_string()),
134
135        // logical types (these are typically built on top of primitive types)
136        "date" | "xs:date" => AvroType::Logical { base: "int".to_string(), logical: "date".to_string() },
137        "time-millis" | "xs:time" => AvroType::Logical { base: "int".to_string(), logical: "time-millis".to_string() },
138        "timestamp-millis" | "xs:datetime" => AvroType::Logical { base: "long".to_string(), logical: "timestamp-millis".to_string() },
139        "timestamp-micros" => AvroType::Logical { base: "long".to_string(), logical: "timestamp-micros".to_string() },
140
141        // complex types (placeholders – constructing full schemas for these requires extra info)
142        "array" => AvroType::Simple("array".to_string()),
143        "map" => AvroType::Simple("map".to_string()),
144        "record" => AvroType::Simple("record".to_string()),
145        "enum" => AvroType::Simple("enum".to_string()),
146        "fixed" => AvroType::Simple("fixed".to_string()),
147
148        // default to "string"
149        _ => AvroType::Simple("string".to_string()),
150    }
151}
152
153
154
155
156/// Parsing the XSD file and converting it to various formats begins with a Schema struct
157/// and a SchemaElement struct.  Currently supporting Arrow, Spark, JSON, JSON Schema,
158/// DuckDB, and Polars schemas.
159///
160
161#[derive(Serialize, Deserialize, Debug)]
162#[cfg_attr(feature = "python", derive(IntoPyObject))]
163pub struct Schema {
164    pub namespace: Option<String>,
165    #[serde(rename = "schemaElement")]
166    pub schema_element: SchemaElement,
167    pub timestamp_options: Option<TimestampOptions>,
168    pub doc: Option<String>,
169    pub custom_types: Option<IndexMap<String, SimpleType>>,
170}
171
172impl Schema {
173    pub fn new(
174        namespace: Option<String>,
175        schema_element: SchemaElement,
176        timestamp_options: Option<TimestampOptions>,
177        doc: Option<String>,
178        custom_types: Option<IndexMap<String, SimpleType>>,
179    ) -> Self {
180        Schema {
181            namespace,
182            schema_element,
183            timestamp_options,
184            doc,
185            custom_types,
186        }
187    }
188
189    pub fn to_avro(&self) -> Result<AvroSchema, Box<dyn std::error::Error>> {
190        let schema = AvroSchema {
191            schema_type: "record".to_string(),
192            name: self.schema_element.name.clone(),
193            namespace: self.namespace.clone(),
194            aliases: None, // Or derive aliases if available.
195            fields: self.schema_element.to_avro_fields(),
196            doc: None,
197        };
198        Ok(schema)
199    }
200
201    pub fn to_arrow(&self) -> Result<ArrowSchema, Box<dyn std::error::Error>> {
202        let fields = self
203            .schema_element
204            .elements
205            .par_iter()
206            .map(|element| {
207                Field::new(
208                    &element.name,
209                    element.to_arrow().unwrap(),
210                    element.nullable.unwrap_or(true),
211                )
212                .with_metadata(element.to_metadata())
213            })
214            .collect::<Vec<Field>>();
215
216        Ok(ArrowSchema::new(fields))
217    }
218
219    pub fn to_json(&self) -> Result<String, Box<dyn std::error::Error>> {
220        let json_output = serde_json::to_string(&self).expect("Failed to serialize JSON");
221        Ok(json_output)
222    }
223
224    pub fn write_to_json_file(&self, output_file: &str) -> Result<(), Box<dyn std::error::Error>> {
225        let json_output = serde_json::to_string_pretty(&self).expect("Failed to serialize JSON");
226        fs::write(output_file, json_output).expect("Failed to write JSON");
227        Ok(())
228    }
229
230    pub fn to_spark(&self) -> Result<SparkSchema, Box<dyn std::error::Error>> {
231        let mut fields = vec![];
232
233        for element in &self.schema_element.elements {
234            fields.push(element.to_spark()?);
235        }
236
237        let schema = SparkSchema::new("struct".to_string(), fields);
238
239        Ok(schema)
240    }
241
242    pub fn to_json_schema(&self) -> serde_json::Value {
243        let mut fields = vec![];
244        let mut required = vec![];
245
246        for element in &self.schema_element.elements {
247            let (field, nullable) = element.to_json_schema();
248            fields.push(field);
249            if !nullable {
250                required.push(element.name.clone());
251            }
252        }
253
254        json!({
255            "$schema": "http://json-schema.org/draft-07/schema#",
256            "type": "object",
257            "properties": {
258                format!("{}", &self.schema_element.name): {
259                    "type": "object",
260                    "properties": fields.iter().map(|field| {
261                        let obj = field.as_object().unwrap();
262                        let (key, value) = obj.iter().next().unwrap(); // Assumes each field has one key
263                        json!({ "key": key, "value": value })
264                    }).collect::<Vec<_>>(),
265                }
266            },
267            "required": required
268        })
269    }
270
271    pub fn to_duckdb_schema(&self) -> IndexMap<String, String> {
272        // self.schema_element.to_duckdb_schema()
273        let mut columns = IndexMap::new();
274        for element in &self.schema_element.elements {
275            let mut element_columns = element.to_duckdb_schema();
276            columns.append(&mut element_columns);
277        }
278        columns
279    }
280
281    pub fn to_polars(&self) -> PolarsSchema {
282        let mut schema: PolarsSchema = Default::default();
283        let to = self.timestamp_options.clone();
284
285        for element in &self.schema_element.elements {
286            //let field = polars::datatypes::Field::new(PlSmallStr::from(&element.name), element.to_polars());
287            schema.insert(PlSmallStr::from(&element.name), element.to_polars(&to));
288        }
289        schema
290    }
291}
292
293#[derive(Serialize, Deserialize, Debug, Clone)]
294#[cfg_attr(feature = "python", derive(IntoPyObject))]
295pub struct SchemaElement {
296    pub id: String,
297    pub name: String,
298    pub documentation: Option<String>,
299    #[serde(rename = "dataType")]
300    pub data_type: Option<String>,
301    #[serde(rename = "minOccurs")]
302    pub min_occurs: Option<String>,
303    #[serde(rename = "maxOccurs")]
304    pub max_occurs: Option<String>,
305    #[serde(rename = "minLength")]
306    pub min_length: Option<String>,
307    #[serde(rename = "maxLength")]
308    pub max_length: Option<String>,
309    #[serde(rename = "minExclusive")]
310    pub min_exclusive: Option<String>,
311    #[serde(rename = "maxExclusive")]
312    pub max_exclusive: Option<String>,
313    #[serde(rename = "minInclusive")]
314    pub min_inclusive: Option<String>,
315    #[serde(rename = "maxInclusive")]
316    pub max_inclusive: Option<String>,
317    pub pattern: Option<String>,
318    #[serde(rename = "fractionDigits")]
319    pub fraction_digits: Option<String>,
320    #[serde(rename = "totalDigits")]
321    pub total_digits: Option<String>,
322    pub values: Option<Vec<String>>,
323    #[serde(rename = "isCurrency")]
324    pub is_currency: bool,
325    pub xpath: String,
326    pub nullable: Option<bool>,
327    pub elements: Vec<SchemaElement>,
328}
329
330impl SchemaElement {
331    pub fn to_metadata(&self) -> HashMap<String, String> {
332        let mut metadata = HashMap::new();
333
334        if let Some(ref max_occurs) = self.max_occurs {
335            metadata.insert("maxOccurs".to_string(), max_occurs.clone());
336        }
337        if let Some(ref min_length) = self.min_length {
338            metadata.insert("minLength".to_string(), min_length.clone());
339        }
340        if let Some(ref max_length) = self.max_length {
341            metadata.insert("maxLength".to_string(), max_length.clone());
342        }
343        if let Some(ref min_exclusive) = self.min_exclusive {
344            metadata.insert("minExclusive".to_string(), min_exclusive.clone());
345        }
346        if let Some(ref max_exclusive) = self.max_exclusive {
347            metadata.insert("maxExclusive".to_string(), max_exclusive.clone());
348        }
349        if let Some(ref min_inclusive) = self.min_inclusive {
350            metadata.insert("minInclusive".to_string(), min_inclusive.clone());
351        }
352        if let Some(ref max_inclusive) = self.max_inclusive {
353            metadata.insert("maxInclusive".to_string(), max_inclusive.clone());
354        }
355        if let Some(ref pattern) = self.pattern {
356            metadata.insert("pattern".to_string(), pattern.clone());
357        }
358        if let Some(ref values) = self.values {
359            // have to join the vector of values into a single comma-separated string
360            metadata.insert("values".to_string(), values.join(","));
361        }
362        // may want to add explicitly, check with xsd specification
363        if self.is_currency {
364            metadata.insert("isCurrency".to_string(), self.is_currency.to_string());
365        }
366
367        metadata
368    }
369
370    pub fn to_arrow(&self) -> Result<DataType, Box<dyn std::error::Error>> {
371        if let Some(ref data_type) = self.data_type {
372            match data_type.as_str() {
373                "string" => Ok(DataType::Utf8),
374                "integer" => Ok(DataType::Int32),
375                "decimal" => match (&self.total_digits, &self.fraction_digits) {
376                    (Some(precision), Some(scale)) => Ok(DataType::Decimal128(
377                        precision.parse::<u8>().unwrap(),
378                        scale.parse::<i8>().unwrap(),
379                    )),
380                    _ => Ok(DataType::Float64),
381                },
382                "boolean" => Ok(DataType::Boolean),
383                "date" => Ok(DataType::Date32),
384                "dateTime" => Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)),
385
386                _ => Ok(DataType::Utf8),
387            }
388        } else {
389            Ok(DataType::Utf8)
390        }
391    }
392
393    pub fn to_spark(&self) -> Result<SparkField, Box<dyn std::error::Error>> {
394        let field_type = match &self.data_type.as_deref() {
395            Some("decimal") => {
396                if let (Some(total_digits), Some(fraction_digits)) = (
397                    &self.total_digits.as_deref(),
398                    &self.fraction_digits.as_deref(),
399                ) {
400                    let precision = total_digits.parse::<u32>().unwrap_or(0);
401                    let scale = fraction_digits.parse::<u32>().unwrap_or(0);
402                    format!("decimal({}, {})", precision, scale)
403                } else {
404                    "decimal".to_string()
405                }
406            }
407            Some("int") | Some("integer") => "integer".to_string(),
408            Some("long") => "long".to_string(),
409            Some("float") => "float".to_string(),
410            Some("double") => "double".to_string(),
411            Some("boolean") => "boolean".to_string(),
412            Some("dateTime") => "timestamp".to_string(),
413            Some("date") => "date".to_string(),
414            Some("string") => "string".to_string(),
415            Some(other) => other.to_string(), // todo: should we really just pass through the provided type?
416            None => "string".to_string(),
417        };
418
419        let field = SparkField {
420            field_name: self.name.clone(),
421            field_type,
422            nullable: self.nullable.unwrap_or(true),
423            metadata: Some(self.to_metadata()),
424        };
425
426        Ok(field)
427    }
428
429    fn to_json_schema(&self) -> (serde_json::Value, bool) {
430        let mut field_type = serde_json::Map::new();
431        let base_type = match self.data_type.as_deref() {
432            Some("string") => json!("string"),
433            Some("integer") => json!("integer"),
434            Some("decimal") => json!("number"),
435            Some("date") => json!("string"),
436            Some("dateTime") => json!("string"),
437            _ => json!("string"),
438        };
439
440        let final_type = if self.nullable == Some(true) {
441            json!([base_type, "null"])
442        } else {
443            base_type
444        };
445
446        field_type.insert("type".to_string(), final_type);
447
448        if let Some(max_length) = &self.max_length {
449            field_type.insert(
450                "maxLength".to_string(),
451                json!(max_length.parse::<u64>().unwrap_or(255)),
452            );
453        }
454        if let Some(min_length) = &self.min_length {
455            field_type.insert(
456                "minLength".to_string(),
457                json!(min_length.parse::<u64>().unwrap_or(0)),
458            );
459        }
460        if let Some(pattern) = &self.pattern {
461            field_type.insert("pattern".to_string(), json!(pattern));
462        }
463        if let Some(values) = &self.values {
464            field_type.insert("enum".to_string(), json!(values));
465        }
466        if self.data_type.as_deref() == Some("decimal") {
467            if let (Some(fraction_digits), Some(total_digits)) = (
468                self.fraction_digits.as_deref(),
469                self.total_digits.as_deref(),
470            ) {
471                let fraction = fraction_digits.parse::<u64>().unwrap_or(0);
472                let total = total_digits.parse::<u64>().unwrap_or(0);
473                let multiple_of = 10f64.powi(-(fraction as i32));
474                let max_value = 10f64.powi(total as i32) - multiple_of;
475
476                field_type.insert("multipleOf".to_string(), json!(multiple_of));
477                field_type.insert("minimum".to_string(), json!(0));
478                field_type.insert("maximum".to_string(), json!(max_value));
479            }
480        }
481
482        (
483            json!({
484                &self.name: field_type
485
486            }),
487            self.nullable.unwrap_or(true),
488        )
489    }
490
491    pub fn to_avro_field(&self) -> AvroField {
492        let base_type = self.to_avro_type();
493        let field_type = if self.nullable.unwrap_or(false) {
494            AvroType::Union(vec![AvroType::Simple("null".to_string()), base_type])
495        } else {
496            base_type
497        };
498
499        AvroField {
500            name: self.name.clone(),
501            field_type,
502            doc: self.documentation.clone(),
503        }
504    }
505
506    pub fn to_avro_type(&self) -> AvroType {
507        if !self.elements.is_empty() {
508            let fields = self
509                .elements
510                .iter()
511                .map(|child| child.to_avro_field())
512                .collect();
513            let record = AvroSchema {
514                schema_type: "record".to_string(),
515                name: self.name.clone(),
516                namespace: None,
517                aliases: None,
518                doc: self.documentation.clone(),
519                fields,
520            };
521            AvroType::Record(record)
522        } else if let Some(symbols) = &self.values {
523            let avro_enum = AvroEnum {
524                schema_type: "enum".to_string(),
525                name: self.name.clone(),
526                symbols: symbols.clone(),
527                doc: self.documentation.clone(),
528                namespace: None,
529            };
530            AvroType::Enum(avro_enum)
531        } else if let Some(dt) = &self.data_type {
532            map_avro_data_type(dt)
533        } else {
534            // default type.
535            AvroType::Simple("string".to_string())
536        }
537    }
538
539    pub fn to_avro_fields(&self) -> Vec<AvroField> {
540        self.elements
541            .iter()
542            .map(|child| child.to_avro_field())
543            .collect()
544    }
545
546    fn to_duckdb_schema(&self) -> IndexMap<String, String> {
547        let mut columns = IndexMap::new();
548
549        let column_type = match self.data_type.as_deref() {
550            Some("string") => format!("VARCHAR({})", self.max_length.as_deref().unwrap_or("255")),
551            Some("integer") => "INTEGER".to_string(),
552            Some("decimal") => {
553                let precision = self.total_digits.as_deref().unwrap_or("25");
554                let scale = self.fraction_digits.as_deref().unwrap_or("7");
555                format!("DECIMAL({}, {})", precision, scale)
556            }
557            Some("date") => "DATE".to_string(),
558            Some("dateTime") => "TIMESTAMP".to_string(),
559            _ => "VARCHAR(255)".to_string(),
560        };
561
562        columns.insert(self.name.clone(), column_type);
563
564        columns
565    }
566
567    fn to_polars(&self, timestamp_options: &Option<TimestampOptions>) -> PolarsDataType {
568        match self.data_type.as_deref() {
569            None => PolarsDataType::String,
570            Some("string") => PolarsDataType::String,
571            Some("int") | Some("integer") => PolarsDataType::Int64,
572            Some("float") | Some("double") => PolarsDataType::Float64,
573            Some("boolean") | Some("bool") => PolarsDataType::Boolean,
574            Some("date") => PolarsDataType::Date,
575            Some("datetime") | Some("dateTime") => {
576                let time_unit = timestamp_options
577                    .as_ref()
578                    .and_then(|options| options.time_unit.as_ref())
579                    .map(|unit| match unit {
580                        TimestampUnit::Ms => PolarsTimeUnit::Milliseconds,
581                        TimestampUnit::Us => PolarsTimeUnit::Microseconds,
582                        TimestampUnit::Ns => PolarsTimeUnit::Nanoseconds,
583                    })
584                    .unwrap_or(PolarsTimeUnit::Nanoseconds);
585                let timezone = timestamp_options
586                    .as_ref()
587                    .and_then(|options| options.time_zone.as_ref())
588                    .map(|s| s.into());
589                PolarsDataType::Datetime(time_unit, timezone)
590            }
591            Some("time") => PolarsDataType::Time,
592            Some("decimal") => {
593                // parsing the total_digits as precision and fraction_digits as scale
594                // fallback to defaults if parsing fails: 38|10
595                let precision = self
596                    .total_digits
597                    .as_ref()
598                    .and_then(|s| s.parse::<usize>().ok())
599                    .unwrap_or(38);
600                let scale = self
601                    .fraction_digits
602                    .as_ref()
603                    .and_then(|s| s.parse::<usize>().ok())
604                    .unwrap_or(10);
605                PolarsDataType::Decimal(Some(precision), Some(scale))
606            }
607            Some(other) => {
608                eprintln!(
609                    "Warning: Unrecognized data type '{}', defaulting to String.",
610                    other
611                );
612                PolarsDataType::String
613            }
614        }
615    }
616}
617
618#[derive(Serialize, Deserialize, Debug)]
619#[cfg_attr(feature = "python", derive(IntoPyObject))]
620pub struct AvroSchema {
621    #[serde(rename = "type")]
622    #[cfg_attr(feature = "python", pyo3(item("type")))]
623    pub schema_type: String,
624    pub name: String,
625    #[serde(skip_serializing_if = "Option::is_none")]
626    pub doc: Option<String>,
627    #[serde(skip_serializing_if = "Option::is_none")]
628    pub aliases: Option<Vec<String>>,
629    pub fields: Vec<AvroField>,
630    #[serde(skip_serializing_if = "Option::is_none")]
631    pub namespace: Option<String>,
632}
633
634#[derive(Serialize, Deserialize, Debug)]
635#[cfg_attr(feature = "python", derive(IntoPyObject))]
636pub struct AvroField {
637    pub name: String,
638    #[serde(rename = "type")]
639    #[cfg_attr(feature = "python", pyo3(item("type")))]
640    pub field_type: AvroType,
641    #[serde(skip_serializing_if = "Option::is_none")]
642    pub doc: Option<String>,
643}
644
645// #[derive(Serialize, Deserialize, Debug, IntoPyObject)]
646// pub struct AvroLogical {
647//
648// }
649#[derive(Serialize, Deserialize, Debug)]
650#[cfg_attr(feature = "python", derive(IntoPyObject))]
651#[serde(untagged)]
652pub enum AvroType {
653    /// simple type (like "string", "int", etc.)
654    Simple(String),
655    /// union of types (ie. a nullable field).
656    Union(Vec<AvroType>),
657    /// inline record.
658    Record(AvroSchema),
659    Enum(AvroEnum),
660    Logical {
661        #[serde(rename = "type")]
662        #[cfg_attr(feature = "python", pyo3(item("type")))]
663        base: String,
664        #[serde(rename = "logicalType")]
665        #[cfg_attr(feature = "python", pyo3(item("logicalType")))]
666        logical: String },
667}
668
669#[derive(Serialize, Deserialize, Debug)]
670#[cfg_attr(feature = "python", derive(IntoPyObject))]
671pub struct AvroEnum {
672    #[serde(rename = "type")]
673    #[cfg_attr(feature = "python", pyo3(item("type")))]
674    pub schema_type: String, // must be "enum"
675    #[serde(skip_serializing_if = "Option::is_none")]
676    pub doc: Option<String>,
677    pub name: String,
678    pub symbols: Vec<String>,
679    #[serde(skip_serializing_if = "Option::is_none")]
680    pub namespace: Option<String>,
681}
682
683#[derive(Serialize, Deserialize, Debug)]
684#[cfg_attr(feature = "python", derive(IntoPyObject))]
685pub struct SparkSchema {
686    #[serde(rename = "type")]
687    pub schema_type: String,
688    pub fields: Vec<SparkField>,
689}
690
691impl SparkSchema {
692    pub fn new(schema_type: String, fields: Vec<SparkField>) -> Self {
693        SparkSchema {
694            schema_type,
695            fields,
696        }
697    }
698
699    pub fn to_json(&self) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
700        let json_output = serde_json::to_value(self).expect("Failed to serialize JSON");
701        Ok(json_output)
702    }
703}
704
705#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
706#[cfg_attr(feature = "python", derive(IntoPyObject))]
707pub struct SparkField {
708    #[serde(rename = "name")]
709    #[cfg_attr(feature = "python", pyo3(item("name")))]
710    pub field_name: String,
711    #[serde(rename = "type")]
712    #[cfg_attr(feature = "python", pyo3(item("type")))]
713    pub field_type: String,
714    pub nullable: bool,
715    pub metadata: Option<HashMap<String, String>>,
716}
717
718impl SparkField {
719    pub fn to_json(&self) -> Result<String, Box<dyn std::error::Error>> {
720        let json_output = serde_json::to_string(&self).expect("Failed to serialize JSON");
721        Ok(json_output)
722    }
723}
724
725#[derive(Debug, Deserialize, Serialize, Clone)]
726#[cfg_attr(feature = "python", derive(IntoPyObject))]
727pub struct SimpleType {
728    name: Option<String>,
729    data_type: Option<String>,
730    min_length: Option<String>,
731    max_length: Option<String>,
732    min_inclusive: Option<String>,
733    max_inclusive: Option<String>,
734    min_exclusive: Option<String>,
735    max_exclusive: Option<String>,
736    fraction_digits: Option<String>,
737    total_digits: Option<String>,
738    pattern: Option<String>,
739    values: Option<Vec<String>>,
740    nullable: Option<bool>,
741    documentation: Option<String>,
742}
743
744
745// xs:enumeration
746fn extract_enum_values(node: roxmltree::Node) -> Option<Vec<String>> {
747    let mut values = Vec::new();
748    for child in node.children() {
749        if child.tag_name().name() == "enumeration" {
750            if let Some(value) = child.attribute("value") {
751                values.push(value.to_string());
752            }
753        }
754    }
755    if values.is_empty() {
756        None
757    } else {
758        Some(values)
759    }
760}
761
762// xs:annotation
763fn extract_documentation(node: roxmltree::Node) -> Option<String> {
764    for child in node.children() {
765        if child.tag_name().name() == "documentation" {
766            //dbg!(&child.text().map(String::from));
767            return child.text().map(String::from);
768        }
769    }
770
771    None
772}
773
774// from xs:restriction
775fn extract_constraints(node: roxmltree::Node) -> SimpleType {
776    let mut simple_type = SimpleType {
777        name: None,
778        data_type: node.attribute("base").map(|s| s.replace("xs:", "")),
779        min_length: None,
780        max_length: None,
781        min_inclusive: None,
782        max_inclusive: None,
783        min_exclusive: None,
784        max_exclusive: None,
785        fraction_digits: None,
786        total_digits: None,
787        pattern: None,
788        values: extract_enum_values(node),
789        nullable: None,
790        documentation: extract_documentation(node),
791    };
792
793    for child in node.children() {
794        match child.tag_name().name() {
795            "minLength" => simple_type.min_length = child.attribute("value").map(String::from),
796            "maxLength" => simple_type.max_length = child.attribute("value").map(String::from),
797            "minInclusive" => {
798                simple_type.min_inclusive = child.attribute("value").map(String::from)
799            }
800            "maxInclusive" => {
801                simple_type.max_inclusive = child.attribute("value").map(String::from)
802            }
803            "minExclusive" => {
804                simple_type.min_exclusive = child.attribute("value").map(String::from)
805            }
806            "maxExclusive" => {
807                simple_type.max_exclusive = child.attribute("value").map(String::from)
808            }
809            "fractionDigits" => {
810                simple_type.fraction_digits = child.attribute("value").map(String::from)
811            }
812            "totalDigits" => simple_type.total_digits = child.attribute("value").map(String::from),
813            "pattern" => simple_type.pattern = child.attribute("value").map(String::from),
814            "nullable" => simple_type.nullable = Some(true),
815            _ => {}
816        }
817    }
818    simple_type
819}
820
821fn parse_element(
822    node: roxmltree::Node,
823    parent_xpath: &str,
824    global_types: &IndexMap<String, SimpleType>,
825    lowercase: Option<bool>,
826) -> Option<SchemaElement> {
827    if node.tag_name().name() != "element" {
828        return None;
829    }
830
831    let mut name = node.attribute("name")?.to_string();
832    if lowercase.is_some() && lowercase.unwrap() {
833        name = name.to_lowercase();
834    }
835
836    let nullable = node.attribute("nillable").map(|s| s == "true");
837    let xpath = format!("{}/{}", parent_xpath, name);
838    let mut data_type = node.attribute("type").map(|s| s.replace("xs:", ""));
839    let min_occurs = match node.attribute("minOccurs") {
840        None => Some("1".to_string()),
841        Some(m) => Some(m.to_string()),
842    };
843
844    let max_occurs = match node.attribute("maxOccurs") {
845        Some(m) => Some(m.to_string()),
846        None => Some("1".to_string()),
847    };
848
849    let mut documentation = None;
850
851    let mut min_length = None;
852    let mut max_length = None;
853    let mut min_inclusive = None;
854    let mut max_inclusive = None;
855    let mut min_exclusive = None;
856    let mut max_exclusive = None;
857    let mut fraction_digits = None;
858    let mut total_digits = None;
859    let mut pattern = None;
860    let mut values = None;
861    let mut elements = Vec::new();
862
863    if let Some(ref type_name) = data_type {
864        if let Some(global_type) = global_types.get(type_name) {
865            min_length = global_type.min_length.clone();
866            max_length = global_type.max_length.clone();
867            min_inclusive = global_type.min_inclusive.clone();
868            max_inclusive = global_type.max_inclusive.clone();
869            min_exclusive = global_type.min_exclusive.clone();
870            max_exclusive = global_type.max_exclusive.clone();
871            fraction_digits = global_type.fraction_digits.clone();
872            total_digits = global_type.total_digits.clone();
873            pattern = global_type.pattern.clone();
874            values = global_type.values.clone();
875            data_type = global_type.data_type.clone();
876            documentation = global_type.documentation.clone();
877        }
878    }
879
880    for child in node.children() {
881        match child.tag_name().name() {
882            "simpleType" => {
883                for subchild in child.children() {
884                    if subchild.tag_name().name() == "restriction" {
885                        let simple_type = extract_constraints(subchild);
886                        if simple_type.data_type.is_some() {
887                            data_type = simple_type.data_type;
888                        }
889                        min_length = simple_type.min_length;
890                        max_length = simple_type.max_length;
891                        min_inclusive = simple_type.min_inclusive;
892                        max_inclusive = simple_type.max_inclusive;
893                        min_exclusive = simple_type.min_exclusive;
894                        max_exclusive = simple_type.max_exclusive;
895                        fraction_digits = simple_type.fraction_digits;
896                        total_digits = simple_type.total_digits;
897                        pattern = simple_type.pattern;
898                        values = simple_type.values;
899                    }
900                }
901            }
902            "complexType" => {
903                //let q = extract_documentation(child);
904
905                for subchild in child.descendants() {
906                    if let Some(sub_element) = parse_element(subchild, &xpath, global_types, lowercase) {
907                        elements.push(sub_element);
908                    }
909                }
910            }
911            _ => {}
912        }
913    }
914
915    let is_currency = name == "Currency";
916
917    Some(SchemaElement {
918        id: name.clone(),
919        name,
920        data_type,
921        min_occurs,
922        max_occurs,
923        min_length,
924        max_length,
925        min_inclusive,
926        max_inclusive,
927        min_exclusive,
928        max_exclusive,
929        pattern,
930        fraction_digits,
931        total_digits,
932        values,
933        is_currency,
934        xpath,
935        nullable,
936        elements,
937        documentation,
938    })
939}
940
941pub fn read_xsd_file(xsd_file: PathBuf, encoding: Option<&'static Encoding>) -> Result<String, Box<dyn std::error::Error>> {
942    let parsed_file = File::open(xsd_file);
943    //.expect("Failed to read XSD file");
944
945    if let Err(e) = parsed_file {
946        return Err(format!("Failed to read XSD file: {}", e).into());
947    }
948
949    let file = parsed_file.unwrap();
950
951    let use_encoding = encoding.unwrap_or(UTF_8);
952
953    let mut transcode_reader = DecodeReaderBytesBuilder::new()
954        .encoding(Some(use_encoding))
955        .build(file);
956
957    let mut xml_content = String::new();
958    transcode_reader.read_to_string(&mut xml_content)?;
959
960    Ok(xml_content)
961}
962
963pub fn parse_xsd_string(xsd_string: &str, timestamp_options: Option<TimestampOptions>, lowercase: Option<bool>) -> Result<Schema, Box<dyn std::error::Error>> {
964    let parse_doc = Document::parse(xsd_string);
965
966    if let Err(e) = parse_doc {
967        return Err(format!("Failed to parse XML: {}. Maybe try a different encoding (utf-16 ?).", e).into());
968    }
969
970    let doc = parse_doc.unwrap();
971    let mut schema_doc: Option<String> = None;
972
973
974    let global_types = Arc::new(Mutex::new(IndexMap::new()));
975
976    doc.root().descendants().for_each(|node| {
977        if node.tag_name().name() == "simpleType" {
978            if let Some(name) = node.attribute("name") {
979
980                let mut doc = None;
981                for child in node.children() {
982                    if child.tag_name().name() == "annotation" {
983                        //println!("----------> {:?}", &child);
984                        let d = extract_documentation(child);
985                        if let Some(o) = d {
986                            doc = Some(o.clone());
987                        }
988                    }
989                    if child.tag_name().name() == "restriction" {
990                        let mut map = global_types.lock().unwrap();
991                        let mut st = extract_constraints(child);
992                        st.name = Some(name.to_string());
993                        st.documentation = doc.clone();
994                        map.insert(name.to_string(), st);
995                    }
996                }
997            }
998        } else if node.tag_name().name() == "annotation" {
999            if schema_doc.is_none() {
1000                schema_doc = extract_documentation(node);
1001            }
1002        }
1003    });
1004
1005    let final_map = Arc::try_unwrap(global_types)
1006        .expect("Arc should have no other refs")
1007        .into_inner()
1008        .expect("Mutex should be unlocked");
1009
1010    let mut schema_element = None;
1011
1012    for node in doc.root().descendants() {
1013        if node.tag_name().name() == "element" {
1014            let mut element_name = "".to_string();
1015            if let Some(name) = node.attribute("name"){
1016                if lowercase.is_some() && lowercase.unwrap() {
1017                    element_name = name.to_lowercase();
1018                } else {
1019                    element_name = name.to_string();
1020                }
1021            }
1022            schema_element = parse_element(node, &element_name, &final_map, lowercase);
1023            break;
1024        }
1025    }
1026
1027    let mut custom_types_vec: Vec<_> = final_map.into_iter().collect();
1028    custom_types_vec.sort_by(|a, b| a.0.to_lowercase().cmp(&b.0.to_lowercase()));
1029    let final_map: IndexMap<_, _> = custom_types_vec.into_iter().collect();
1030
1031    if let Some(schema_element) = schema_element {
1032        let schema = Schema {
1033            namespace: None,
1034            schema_element,
1035            timestamp_options,
1036            doc: schema_doc,
1037            custom_types: Some(final_map),
1038        };
1039
1040        Ok(schema)
1041    } else {
1042        Err("Failed to find the main schema element in the XSD.".into())
1043    }
1044}
1045
1046pub fn parse_file(xsd_file: PathBuf, timestamp_options: Option<TimestampOptions>,
1047                  encoding: Option<&'static Encoding>,
1048                  lowercase: Option<bool>) -> Result<Schema, Box<dyn std::error::Error>> {
1049    let xml_content = read_xsd_file(xsd_file, encoding)?;
1050
1051    parse_xsd_string(&xml_content, timestamp_options, lowercase)
1052}
1053
1054
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use std::io::Write;
    use tempfile::{tempdir, TempDir};

    /// Minimal XSD whose only (and therefore root) element is `testElement: xs:string`.
    const SIMPLE_XSD: &str = r#"
            <schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
                <xs:element name="testElement" type="xs:string"/>
            </schema>
            "#;

    /// Writes `contents` (plus a trailing newline) to `test.xsd` in a fresh temp dir.
    ///
    /// The returned [`TempDir`] deletes the directory when dropped, so callers must
    /// keep it alive while they use the path.
    fn write_temp_xsd(contents: &str) -> (TempDir, PathBuf) {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.xsd");
        let mut file = File::create(&file_path).unwrap();
        writeln!(file, "{}", contents).unwrap();
        (dir, file_path)
    }

    /// Builds an unconstrained `string` element. The fixtures differ only in name,
    /// nullability, documentation and children.
    fn string_element(
        name: &str,
        nullable: bool,
        documentation: &str,
        elements: Vec<SchemaElement>,
    ) -> SchemaElement {
        SchemaElement {
            id: "id".to_string(),
            name: name.to_string(),
            data_type: Some("string".to_string()),
            min_occurs: Some("1".to_string()),
            max_occurs: Some("1".to_string()),
            min_length: None,
            max_length: None,
            min_inclusive: None,
            max_inclusive: None,
            min_exclusive: None,
            max_exclusive: None,
            pattern: None,
            fraction_digits: None,
            total_digits: None,
            values: None,
            is_currency: false,
            xpath: "/name".to_string(),
            nullable: Some(nullable),
            elements,
            documentation: Some(documentation.to_string()),
        }
    }

    /// Schema `main_schema` with three string fields: `field1` (not nullable),
    /// `field2` and `field3` (both nullable).
    fn create_test_schema() -> Schema {
        let fields = vec![
            string_element("field1", false, "This is the first test field", vec![]),
            string_element("field2", true, "This is the second test field", vec![]),
            string_element("field3", true, "This is the third and last test field", vec![]),
        ];

        Schema {
            namespace: None,
            schema_element: string_element("main_schema", true, "This is the main schema", fields),
            timestamp_options: None,
            doc: Some("TestSchema".to_string()),
            custom_types: None,
        }
    }

    #[test]
    fn test_timestamp_unit_from_str() {
        assert_eq!(TimestampUnit::from_str("ns").unwrap(), TimestampUnit::Ns);
        assert_eq!(TimestampUnit::from_str("ms").unwrap(), TimestampUnit::Ms);
        assert_eq!(TimestampUnit::from_str("us").unwrap(), TimestampUnit::Us);
        assert!(TimestampUnit::from_str("invalid").is_err());
    }

    #[test]
    fn test_schema_to_arrow() {
        let schema = create_test_schema();
        let arrow_schema = schema.to_arrow().unwrap();
        assert_eq!(arrow_schema.fields().len(), 3);
        assert_eq!(arrow_schema.field(0).name(), "field1");
        assert_eq!(schema.doc, Some("TestSchema".to_string()));
    }

    #[test]
    fn test_parse_file() {
        let (_dir, file_path) = write_temp_xsd(SIMPLE_XSD);

        let schema = parse_file(file_path, None, None, Some(false)).unwrap();
        assert_eq!(schema.schema_element.name, "testElement");
    }

    #[test]
    fn test_parse_file_lowercase() {
        let (_dir, file_path) = write_temp_xsd(SIMPLE_XSD);

        let schema = parse_file(file_path, None, None, Some(true)).unwrap();
        assert_eq!(schema.schema_element.name, "testelement");
    }

    #[test]
    fn test_schema_element_to_arrow() {
        let schema = create_test_schema();
        let element = &schema.schema_element.elements[0];

        let data_type = element.to_arrow().unwrap();
        assert_eq!(data_type, DataType::Utf8);
    }

    #[test]
    fn test_schema_element_to_spark() {
        let schema = create_test_schema();
        let element = &schema.schema_element.elements[0];

        let spark_field = element.to_spark().unwrap();
        assert_eq!(spark_field.field_name, "field1");
        assert_eq!(spark_field.field_type, "string");
    }

    #[test]
    fn test_schema_element_to_json_schema() {
        let schema = create_test_schema();
        let element = &schema.schema_element.elements[0];

        let (json_element, nullable) = element.to_json_schema();
        assert_eq!(
            json_element
                .get("field1")
                .and_then(|v| v.get("type"))
                .and_then(|v| v.as_str()),
            Some("string")
        );
        assert!(!nullable);
    }

    #[test]
    fn test_schema_element_to_duckdb_schema() {
        let schema = create_test_schema();
        let element = &schema.schema_element.elements[0];
        let duckdb_schema = element.to_duckdb_schema();
        assert_eq!(
            duckdb_schema.get("field1").unwrap().to_string(),
            "VARCHAR(255)"
        );
    }

    #[test]
    fn test_duckdb_schema_ordered() {
        let schema = create_test_schema();

        let duckdb_schema = schema.to_duckdb_schema();
        let names = duckdb_schema
            .iter()
            .map(|x| x.0.clone())
            .collect::<Vec<_>>();
        assert_eq!(
            names,
            &[
                "field1".to_string(),
                "field2".to_string(),
                "field3".to_string()
            ]
        );
    }

    #[test]
    fn test_extract_enum_values() {
        let xml = r#"
            <restriction base="xs:string">
                <enumeration value="A"/>
                <enumeration value="B"/>
            </restriction>
        "#;
        let doc = Document::parse(xml).unwrap();
        let node = doc.root().first_child().unwrap();
        let values = extract_enum_values(node).unwrap();
        assert_eq!(values, vec!["A", "B"]);
    }

    #[test]
    fn test_extract_constraints() {
        let xml = r#"
            <restriction base="xs:string">
                <minLength value="1"/>
                <maxLength value="255"/>
            </restriction>
        "#;
        let doc = Document::parse(xml).unwrap();
        let node = doc.root().first_child().unwrap();
        let constraints = extract_constraints(node);
        assert_eq!(constraints.min_length, Some("1".to_string()));
        assert_eq!(constraints.max_length, Some("255".to_string()));
    }

    #[test]
    fn test_parse_element() {
        let xml = r#"
            <element name="testElement" type="xs:string"/>
        "#;
        let doc = Document::parse(xml).unwrap();
        let node = doc.root().first_child().unwrap();
        let element = parse_element(node, "", &IndexMap::new(), Some(false)).unwrap();
        assert_eq!(element.name, "testElement");
        assert_eq!(element.data_type, Some("string".to_string()));
    }

    #[test]
    fn test_extract_documentation() {
        let xml = r#"
            <annotation>
                <documentation>This is a test element</documentation>
            </annotation>
        "#;
        let doc = Document::parse(xml).unwrap();
        let node = doc.root().first_child().unwrap();
        let documentation = extract_documentation(node);
        assert_eq!(documentation, Some("This is a test element".to_string()));
    }

    #[test]
    fn test_avro_schema_serialization() {
        let schema = AvroSchema {
            schema_type: "record".to_string(),
            namespace: Some("example.avro".to_string()),
            name: "LongList".to_string(),
            doc: Some("Linked list of 64-bit longs.".to_string()),
            aliases: Some(vec!["LinkedLongs".to_string()]),
            fields: vec![
                AvroField {
                    name: "value".to_string(),
                    field_type: AvroType::Simple("long".to_string()),
                    doc: Some("The value of the node.".to_string()),
                },
                AvroField {
                    name: "next".to_string(),
                    field_type: AvroType::Union(vec![
                        AvroType::Simple("null".to_string()),
                        AvroType::Simple("LongList".to_string()),
                    ]),
                    doc: Some("The next node in the list.".to_string()),
                },
            ],
        };
        assert_eq!(schema.fields.len(), 2);
        assert_eq!(schema.doc, Some("Linked list of 64-bit longs.".to_string()));
    }
}