1use arrow::datatypes::{DataType, Field, Schema as ArrowSchema, TimeUnit};
2use encoding_rs::{Encoding, UTF_8};
3use encoding_rs_io::DecodeReaderBytesBuilder;
4use indexmap::IndexMap;
5use polars::datatypes::TimeUnit as PolarsTimeUnit;
6use polars::datatypes::{DataType as PolarsDataType, PlSmallStr};
7use polars::prelude::Schema as PolarsSchema;
8#[cfg(feature = "python")]
9use pyo3::exceptions::PyValueError;
10#[cfg(feature = "python")]
11use pyo3::prelude::{PyAnyMethods, PyDictMethods};
12#[cfg(feature = "python")]
13use pyo3::types::{PyDict, PyString};
14#[cfg(feature = "python")]
15use pyo3::{Bound, FromPyObject, IntoPyObject, PyAny, PyResult, Python};
16use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
17use roxmltree::Document;
18use serde::{Deserialize, Serialize};
19use serde_json::json;
20use std::collections::HashMap;
21#[cfg(feature = "python")]
22use std::convert::Infallible;
23use std::fs::File;
24use std::io::Read;
25use std::str::FromStr;
26use std::sync::{Arc, Mutex};
27use std::{fmt, fs};
28use std::path::PathBuf;
29
30#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
34pub enum TimestampUnit {
35 Ms,
36 Us,
37 Ns,
38}
39
40impl FromStr for TimestampUnit {
41 type Err = String;
42 fn from_str(s: &str) -> Result<Self, Self::Err> {
43 match s {
44 "ns" => Ok(TimestampUnit::Ns),
45 "ms" => Ok(TimestampUnit::Ms),
46 "us" => Ok(TimestampUnit::Us),
47 "μs" => Ok(TimestampUnit::Us),
48 _ => Err(format!("Invalid precision: {}. Available: ms, us, ns", s)),
49 }
50 }
51}
52
#[cfg(feature = "python")]
impl<'py> IntoPyObject<'py> for TimestampUnit {
    type Target = <&'py str as IntoPyObject<'py>>::Target;
    type Output = <&'py str as IntoPyObject<'py>>::Output;
    type Error = Infallible;

    /// Converts the unit into a Python string: `"ms"`, `"us"` or `"ns"`.
    fn into_pyobject(self, py: Python<'py>) -> Result<pyo3::Bound<'py, PyString>, Infallible> {
        let label = match self {
            Self::Ns => "ns",
            Self::Us => "us",
            Self::Ms => "ms",
        };
        label.into_pyobject(py)
    }
}
68
69impl fmt::Display for TimestampUnit {
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71 write!(f, "{:?}", self)
72 }
73}
74
75#[derive(Serialize, Deserialize, Debug, Clone)]
76#[cfg_attr(feature = "python", derive(IntoPyObject))]
77pub struct TimestampOptions {
78 pub time_unit: Option<TimestampUnit>,
79 pub time_zone: Option<String>,
80}
81
#[cfg(feature = "python")]
impl<'source> FromPyObject<'source> for TimestampUnit {
    /// Extracts a Python string and parses it as a [`TimestampUnit`].
    ///
    /// A value that is not a string fails with the extraction error; a string that is
    /// not a valid unit label is reported as a Python `ValueError`.
    fn extract_bound(bound: &pyo3::Bound<'source, PyAny>) -> PyResult<Self> {
        bound
            .extract::<String>()?
            .parse::<TimestampUnit>()
            .map_err(PyValueError::new_err)
    }
}
92
/// Looks up `key` in a Python dict and extracts its value as a string.
///
/// Returns `Ok(None)` when the key is missing *or* when it is present with the value
/// `None`, so `{"time_zone": None}` behaves the same as omitting the key instead of
/// failing with a type error.
///
/// # Errors
///
/// Propagates the lookup error, or the extraction error when the value is neither a
/// string nor `None`.
#[cfg(feature = "python")]
fn _get_extracted_string(dict: &Bound<PyDict>, key: &str) -> PyResult<Option<String>> {
    match dict.get_item(key)? {
        // `Option<String>` extraction maps Python `None` to `None`.
        Some(item) => item.extract::<Option<String>>(),
        None => Ok(None),
    }
}
101
#[cfg(feature = "python")]
impl<'source> FromPyObject<'source> for TimestampOptions {
    /// Builds [`TimestampOptions`] from a Python dict with the optional keys
    /// `"time_unit"` and `"time_zone"`.
    ///
    /// Fails when the object is not a dict, when a value cannot be extracted as a
    /// string, or (with a `ValueError`) when `"time_unit"` is not a valid unit label.
    fn extract_bound(bound: &pyo3::Bound<'source, PyAny>) -> PyResult<Self> {
        let options = bound.downcast::<PyDict>()?;

        // Both values are extracted before the unit is parsed so extraction errors
        // take precedence over an invalid unit label.
        let raw_unit = _get_extracted_string(options, "time_unit")?;
        let time_zone = _get_extracted_string(options, "time_zone")?;

        let time_unit = raw_unit
            .map(|label| label.parse::<TimestampUnit>().map_err(PyValueError::new_err))
            .transpose()?;

        Ok(TimestampOptions {
            time_unit,
            time_zone,
        })
    }
}
122
123fn map_avro_data_type(dt: &str) -> AvroType {
124 match dt.to_lowercase().as_str() {
125 "null" | "xs:null" => AvroType::Simple("null".to_string()),
127 "boolean" | "xs:boolean" => AvroType::Simple("boolean".to_string()),
128 "int" | "xs:int" => AvroType::Simple("int".to_string()),
129 "long" | "xs:long" => AvroType::Simple("long".to_string()),
130 "float" | "xs:float" => AvroType::Simple("float".to_string()),
131 "double" | "xs:double" => AvroType::Simple("double".to_string()),
132 "bytes" | "xs:bytes" | "xs:base64binary" => AvroType::Simple("bytes".to_string()),
133 "string" | "xs:string" => AvroType::Simple("string".to_string()),
134
135 "date" | "xs:date" => AvroType::Logical { base: "int".to_string(), logical: "date".to_string() },
137 "time-millis" | "xs:time" => AvroType::Logical { base: "int".to_string(), logical: "time-millis".to_string() },
138 "timestamp-millis" | "xs:datetime" => AvroType::Logical { base: "long".to_string(), logical: "timestamp-millis".to_string() },
139 "timestamp-micros" => AvroType::Logical { base: "long".to_string(), logical: "timestamp-micros".to_string() },
140
141 "array" => AvroType::Simple("array".to_string()),
143 "map" => AvroType::Simple("map".to_string()),
144 "record" => AvroType::Simple("record".to_string()),
145 "enum" => AvroType::Simple("enum".to_string()),
146 "fixed" => AvroType::Simple("fixed".to_string()),
147
148 _ => AvroType::Simple("string".to_string()),
150 }
151}
152
153
154
155
156#[derive(Serialize, Deserialize, Debug)]
162#[cfg_attr(feature = "python", derive(IntoPyObject))]
163pub struct Schema {
164 pub namespace: Option<String>,
165 #[serde(rename = "schemaElement")]
166 pub schema_element: SchemaElement,
167 pub timestamp_options: Option<TimestampOptions>,
168 pub doc: Option<String>,
169 pub custom_types: Option<IndexMap<String, SimpleType>>,
170}
171
172impl Schema {
173 pub fn new(
174 namespace: Option<String>,
175 schema_element: SchemaElement,
176 timestamp_options: Option<TimestampOptions>,
177 doc: Option<String>,
178 custom_types: Option<IndexMap<String, SimpleType>>,
179 ) -> Self {
180 Schema {
181 namespace,
182 schema_element,
183 timestamp_options,
184 doc,
185 custom_types,
186 }
187 }
188
189 pub fn to_avro(&self) -> Result<AvroSchema, Box<dyn std::error::Error>> {
190 let schema = AvroSchema {
191 schema_type: "record".to_string(),
192 name: self.schema_element.name.clone(),
193 namespace: self.namespace.clone(),
194 aliases: None, fields: self.schema_element.to_avro_fields(),
196 doc: None,
197 };
198 Ok(schema)
199 }
200
201 pub fn to_arrow(&self) -> Result<ArrowSchema, Box<dyn std::error::Error>> {
202 let fields = self
203 .schema_element
204 .elements
205 .par_iter()
206 .map(|element| {
207 Field::new(
208 &element.name,
209 element.to_arrow().unwrap(),
210 element.nullable.unwrap_or(true),
211 )
212 .with_metadata(element.to_metadata())
213 })
214 .collect::<Vec<Field>>();
215
216 Ok(ArrowSchema::new(fields))
217 }
218
219 pub fn to_json(&self) -> Result<String, Box<dyn std::error::Error>> {
220 let json_output = serde_json::to_string(&self).expect("Failed to serialize JSON");
221 Ok(json_output)
222 }
223
224 pub fn write_to_json_file(&self, output_file: &str) -> Result<(), Box<dyn std::error::Error>> {
225 let json_output = serde_json::to_string_pretty(&self).expect("Failed to serialize JSON");
226 fs::write(output_file, json_output).expect("Failed to write JSON");
227 Ok(())
228 }
229
230 pub fn to_spark(&self) -> Result<SparkSchema, Box<dyn std::error::Error>> {
231 let mut fields = vec![];
232
233 for element in &self.schema_element.elements {
234 fields.push(element.to_spark()?);
235 }
236
237 let schema = SparkSchema::new("struct".to_string(), fields);
238
239 Ok(schema)
240 }
241
242 pub fn to_json_schema(&self) -> serde_json::Value {
243 let mut fields = vec![];
244 let mut required = vec![];
245
246 for element in &self.schema_element.elements {
247 let (field, nullable) = element.to_json_schema();
248 fields.push(field);
249 if !nullable {
250 required.push(element.name.clone());
251 }
252 }
253
254 json!({
255 "$schema": "http://json-schema.org/draft-07/schema#",
256 "type": "object",
257 "properties": {
258 format!("{}", &self.schema_element.name): {
259 "type": "object",
260 "properties": fields.iter().map(|field| {
261 let obj = field.as_object().unwrap();
262 let (key, value) = obj.iter().next().unwrap(); json!({ "key": key, "value": value })
264 }).collect::<Vec<_>>(),
265 }
266 },
267 "required": required
268 })
269 }
270
271 pub fn to_duckdb_schema(&self) -> IndexMap<String, String> {
272 let mut columns = IndexMap::new();
274 for element in &self.schema_element.elements {
275 let mut element_columns = element.to_duckdb_schema();
276 columns.append(&mut element_columns);
277 }
278 columns
279 }
280
281 pub fn to_polars(&self) -> PolarsSchema {
282 let mut schema: PolarsSchema = Default::default();
283 let to = self.timestamp_options.clone();
284
285 for element in &self.schema_element.elements {
286 schema.insert(PlSmallStr::from(&element.name), element.to_polars(&to));
288 }
289 schema
290 }
291}
292
293#[derive(Serialize, Deserialize, Debug, Clone)]
294#[cfg_attr(feature = "python", derive(IntoPyObject))]
295pub struct SchemaElement {
296 pub id: String,
297 pub name: String,
298 pub documentation: Option<String>,
299 #[serde(rename = "dataType")]
300 pub data_type: Option<String>,
301 #[serde(rename = "minOccurs")]
302 pub min_occurs: Option<String>,
303 #[serde(rename = "maxOccurs")]
304 pub max_occurs: Option<String>,
305 #[serde(rename = "minLength")]
306 pub min_length: Option<String>,
307 #[serde(rename = "maxLength")]
308 pub max_length: Option<String>,
309 #[serde(rename = "minExclusive")]
310 pub min_exclusive: Option<String>,
311 #[serde(rename = "maxExclusive")]
312 pub max_exclusive: Option<String>,
313 #[serde(rename = "minInclusive")]
314 pub min_inclusive: Option<String>,
315 #[serde(rename = "maxInclusive")]
316 pub max_inclusive: Option<String>,
317 pub pattern: Option<String>,
318 #[serde(rename = "fractionDigits")]
319 pub fraction_digits: Option<String>,
320 #[serde(rename = "totalDigits")]
321 pub total_digits: Option<String>,
322 pub values: Option<Vec<String>>,
323 #[serde(rename = "isCurrency")]
324 pub is_currency: bool,
325 pub xpath: String,
326 pub nullable: Option<bool>,
327 pub elements: Vec<SchemaElement>,
328}
329
330impl SchemaElement {
331 pub fn to_metadata(&self) -> HashMap<String, String> {
332 let mut metadata = HashMap::new();
333
334 if let Some(ref max_occurs) = self.max_occurs {
335 metadata.insert("maxOccurs".to_string(), max_occurs.clone());
336 }
337 if let Some(ref min_length) = self.min_length {
338 metadata.insert("minLength".to_string(), min_length.clone());
339 }
340 if let Some(ref max_length) = self.max_length {
341 metadata.insert("maxLength".to_string(), max_length.clone());
342 }
343 if let Some(ref min_exclusive) = self.min_exclusive {
344 metadata.insert("minExclusive".to_string(), min_exclusive.clone());
345 }
346 if let Some(ref max_exclusive) = self.max_exclusive {
347 metadata.insert("maxExclusive".to_string(), max_exclusive.clone());
348 }
349 if let Some(ref min_inclusive) = self.min_inclusive {
350 metadata.insert("minInclusive".to_string(), min_inclusive.clone());
351 }
352 if let Some(ref max_inclusive) = self.max_inclusive {
353 metadata.insert("maxInclusive".to_string(), max_inclusive.clone());
354 }
355 if let Some(ref pattern) = self.pattern {
356 metadata.insert("pattern".to_string(), pattern.clone());
357 }
358 if let Some(ref values) = self.values {
359 metadata.insert("values".to_string(), values.join(","));
361 }
362 if self.is_currency {
364 metadata.insert("isCurrency".to_string(), self.is_currency.to_string());
365 }
366
367 metadata
368 }
369
370 pub fn to_arrow(&self) -> Result<DataType, Box<dyn std::error::Error>> {
371 if let Some(ref data_type) = self.data_type {
372 match data_type.as_str() {
373 "string" => Ok(DataType::Utf8),
374 "integer" => Ok(DataType::Int32),
375 "decimal" => match (&self.total_digits, &self.fraction_digits) {
376 (Some(precision), Some(scale)) => Ok(DataType::Decimal128(
377 precision.parse::<u8>().unwrap(),
378 scale.parse::<i8>().unwrap(),
379 )),
380 _ => Ok(DataType::Float64),
381 },
382 "boolean" => Ok(DataType::Boolean),
383 "date" => Ok(DataType::Date32),
384 "dateTime" => Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)),
385
386 _ => Ok(DataType::Utf8),
387 }
388 } else {
389 Ok(DataType::Utf8)
390 }
391 }
392
393 pub fn to_spark(&self) -> Result<SparkField, Box<dyn std::error::Error>> {
394 let field_type = match &self.data_type.as_deref() {
395 Some("decimal") => {
396 if let (Some(total_digits), Some(fraction_digits)) = (
397 &self.total_digits.as_deref(),
398 &self.fraction_digits.as_deref(),
399 ) {
400 let precision = total_digits.parse::<u32>().unwrap_or(0);
401 let scale = fraction_digits.parse::<u32>().unwrap_or(0);
402 format!("decimal({}, {})", precision, scale)
403 } else {
404 "decimal".to_string()
405 }
406 }
407 Some("int") | Some("integer") => "integer".to_string(),
408 Some("long") => "long".to_string(),
409 Some("float") => "float".to_string(),
410 Some("double") => "double".to_string(),
411 Some("boolean") => "boolean".to_string(),
412 Some("dateTime") => "timestamp".to_string(),
413 Some("date") => "date".to_string(),
414 Some("string") => "string".to_string(),
415 Some(other) => other.to_string(), None => "string".to_string(),
417 };
418
419 let field = SparkField {
420 field_name: self.name.clone(),
421 field_type,
422 nullable: self.nullable.unwrap_or(true),
423 metadata: Some(self.to_metadata()),
424 };
425
426 Ok(field)
427 }
428
429 fn to_json_schema(&self) -> (serde_json::Value, bool) {
430 let mut field_type = serde_json::Map::new();
431 let base_type = match self.data_type.as_deref() {
432 Some("string") => json!("string"),
433 Some("integer") => json!("integer"),
434 Some("decimal") => json!("number"),
435 Some("date") => json!("string"),
436 Some("dateTime") => json!("string"),
437 _ => json!("string"),
438 };
439
440 let final_type = if self.nullable == Some(true) {
441 json!([base_type, "null"])
442 } else {
443 base_type
444 };
445
446 field_type.insert("type".to_string(), final_type);
447
448 if let Some(max_length) = &self.max_length {
449 field_type.insert(
450 "maxLength".to_string(),
451 json!(max_length.parse::<u64>().unwrap_or(255)),
452 );
453 }
454 if let Some(min_length) = &self.min_length {
455 field_type.insert(
456 "minLength".to_string(),
457 json!(min_length.parse::<u64>().unwrap_or(0)),
458 );
459 }
460 if let Some(pattern) = &self.pattern {
461 field_type.insert("pattern".to_string(), json!(pattern));
462 }
463 if let Some(values) = &self.values {
464 field_type.insert("enum".to_string(), json!(values));
465 }
466 if self.data_type.as_deref() == Some("decimal") {
467 if let (Some(fraction_digits), Some(total_digits)) = (
468 self.fraction_digits.as_deref(),
469 self.total_digits.as_deref(),
470 ) {
471 let fraction = fraction_digits.parse::<u64>().unwrap_or(0);
472 let total = total_digits.parse::<u64>().unwrap_or(0);
473 let multiple_of = 10f64.powi(-(fraction as i32));
474 let max_value = 10f64.powi(total as i32) - multiple_of;
475
476 field_type.insert("multipleOf".to_string(), json!(multiple_of));
477 field_type.insert("minimum".to_string(), json!(0));
478 field_type.insert("maximum".to_string(), json!(max_value));
479 }
480 }
481
482 (
483 json!({
484 &self.name: field_type
485
486 }),
487 self.nullable.unwrap_or(true),
488 )
489 }
490
491 pub fn to_avro_field(&self) -> AvroField {
492 let base_type = self.to_avro_type();
493 let field_type = if self.nullable.unwrap_or(false) {
494 AvroType::Union(vec![AvroType::Simple("null".to_string()), base_type])
495 } else {
496 base_type
497 };
498
499 AvroField {
500 name: self.name.clone(),
501 field_type,
502 doc: self.documentation.clone(),
503 }
504 }
505
506 pub fn to_avro_type(&self) -> AvroType {
507 if !self.elements.is_empty() {
508 let fields = self
509 .elements
510 .iter()
511 .map(|child| child.to_avro_field())
512 .collect();
513 let record = AvroSchema {
514 schema_type: "record".to_string(),
515 name: self.name.clone(),
516 namespace: None,
517 aliases: None,
518 doc: self.documentation.clone(),
519 fields,
520 };
521 AvroType::Record(record)
522 } else if let Some(symbols) = &self.values {
523 let avro_enum = AvroEnum {
524 schema_type: "enum".to_string(),
525 name: self.name.clone(),
526 symbols: symbols.clone(),
527 doc: self.documentation.clone(),
528 namespace: None,
529 };
530 AvroType::Enum(avro_enum)
531 } else if let Some(dt) = &self.data_type {
532 map_avro_data_type(dt)
533 } else {
534 AvroType::Simple("string".to_string())
536 }
537 }
538
539 pub fn to_avro_fields(&self) -> Vec<AvroField> {
540 self.elements
541 .iter()
542 .map(|child| child.to_avro_field())
543 .collect()
544 }
545
546 fn to_duckdb_schema(&self) -> IndexMap<String, String> {
547 let mut columns = IndexMap::new();
548
549 let column_type = match self.data_type.as_deref() {
550 Some("string") => format!("VARCHAR({})", self.max_length.as_deref().unwrap_or("255")),
551 Some("integer") => "INTEGER".to_string(),
552 Some("decimal") => {
553 let precision = self.total_digits.as_deref().unwrap_or("25");
554 let scale = self.fraction_digits.as_deref().unwrap_or("7");
555 format!("DECIMAL({}, {})", precision, scale)
556 }
557 Some("date") => "DATE".to_string(),
558 Some("dateTime") => "TIMESTAMP".to_string(),
559 _ => "VARCHAR(255)".to_string(),
560 };
561
562 columns.insert(self.name.clone(), column_type);
563
564 columns
565 }
566
567 fn to_polars(&self, timestamp_options: &Option<TimestampOptions>) -> PolarsDataType {
568 match self.data_type.as_deref() {
569 None => PolarsDataType::String,
570 Some("string") => PolarsDataType::String,
571 Some("int") | Some("integer") => PolarsDataType::Int64,
572 Some("float") | Some("double") => PolarsDataType::Float64,
573 Some("boolean") | Some("bool") => PolarsDataType::Boolean,
574 Some("date") => PolarsDataType::Date,
575 Some("datetime") | Some("dateTime") => {
576 let time_unit = timestamp_options
577 .as_ref()
578 .and_then(|options| options.time_unit.as_ref())
579 .map(|unit| match unit {
580 TimestampUnit::Ms => PolarsTimeUnit::Milliseconds,
581 TimestampUnit::Us => PolarsTimeUnit::Microseconds,
582 TimestampUnit::Ns => PolarsTimeUnit::Nanoseconds,
583 })
584 .unwrap_or(PolarsTimeUnit::Nanoseconds);
585 let timezone = timestamp_options
586 .as_ref()
587 .and_then(|options| options.time_zone.as_ref())
588 .map(|s| s.into());
589 PolarsDataType::Datetime(time_unit, timezone)
590 }
591 Some("time") => PolarsDataType::Time,
592 Some("decimal") => {
593 let precision = self
596 .total_digits
597 .as_ref()
598 .and_then(|s| s.parse::<usize>().ok())
599 .unwrap_or(38);
600 let scale = self
601 .fraction_digits
602 .as_ref()
603 .and_then(|s| s.parse::<usize>().ok())
604 .unwrap_or(10);
605 PolarsDataType::Decimal(Some(precision), Some(scale))
606 }
607 Some(other) => {
608 eprintln!(
609 "Warning: Unrecognized data type '{}', defaulting to String.",
610 other
611 );
612 PolarsDataType::String
613 }
614 }
615 }
616}
617
618#[derive(Serialize, Deserialize, Debug)]
619#[cfg_attr(feature = "python", derive(IntoPyObject))]
620pub struct AvroSchema {
621 #[serde(rename = "type")]
622 #[cfg_attr(feature = "python", pyo3(item("type")))]
623 pub schema_type: String,
624 pub name: String,
625 #[serde(skip_serializing_if = "Option::is_none")]
626 pub doc: Option<String>,
627 #[serde(skip_serializing_if = "Option::is_none")]
628 pub aliases: Option<Vec<String>>,
629 pub fields: Vec<AvroField>,
630 #[serde(skip_serializing_if = "Option::is_none")]
631 pub namespace: Option<String>,
632}
633
634#[derive(Serialize, Deserialize, Debug)]
635#[cfg_attr(feature = "python", derive(IntoPyObject))]
636pub struct AvroField {
637 pub name: String,
638 #[serde(rename = "type")]
639 #[cfg_attr(feature = "python", pyo3(item("type")))]
640 pub field_type: AvroType,
641 #[serde(skip_serializing_if = "Option::is_none")]
642 pub doc: Option<String>,
643}
644
645#[derive(Serialize, Deserialize, Debug)]
650#[cfg_attr(feature = "python", derive(IntoPyObject))]
651#[serde(untagged)]
652pub enum AvroType {
653 Simple(String),
655 Union(Vec<AvroType>),
657 Record(AvroSchema),
659 Enum(AvroEnum),
660 Logical {
661 #[serde(rename = "type")]
662 #[cfg_attr(feature = "python", pyo3(item("type")))]
663 base: String,
664 #[serde(rename = "logicalType")]
665 #[cfg_attr(feature = "python", pyo3(item("logicalType")))]
666 logical: String },
667}
668
669#[derive(Serialize, Deserialize, Debug)]
670#[cfg_attr(feature = "python", derive(IntoPyObject))]
671pub struct AvroEnum {
672 #[serde(rename = "type")]
673 #[cfg_attr(feature = "python", pyo3(item("type")))]
674 pub schema_type: String, #[serde(skip_serializing_if = "Option::is_none")]
676 pub doc: Option<String>,
677 pub name: String,
678 pub symbols: Vec<String>,
679 #[serde(skip_serializing_if = "Option::is_none")]
680 pub namespace: Option<String>,
681}
682
683#[derive(Serialize, Deserialize, Debug)]
684#[cfg_attr(feature = "python", derive(IntoPyObject))]
685pub struct SparkSchema {
686 #[serde(rename = "type")]
687 pub schema_type: String,
688 pub fields: Vec<SparkField>,
689}
690
691impl SparkSchema {
692 pub fn new(schema_type: String, fields: Vec<SparkField>) -> Self {
693 SparkSchema {
694 schema_type,
695 fields,
696 }
697 }
698
699 pub fn to_json(&self) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
700 let json_output = serde_json::to_value(self).expect("Failed to serialize JSON");
701 Ok(json_output)
702 }
703}
704
705#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
706#[cfg_attr(feature = "python", derive(IntoPyObject))]
707pub struct SparkField {
708 #[serde(rename = "name")]
709 #[cfg_attr(feature = "python", pyo3(item("name")))]
710 pub field_name: String,
711 #[serde(rename = "type")]
712 #[cfg_attr(feature = "python", pyo3(item("type")))]
713 pub field_type: String,
714 pub nullable: bool,
715 pub metadata: Option<HashMap<String, String>>,
716}
717
718impl SparkField {
719 pub fn to_json(&self) -> Result<String, Box<dyn std::error::Error>> {
720 let json_output = serde_json::to_string(&self).expect("Failed to serialize JSON");
721 Ok(json_output)
722 }
723}
724
725#[derive(Debug, Deserialize, Serialize, Clone)]
726#[cfg_attr(feature = "python", derive(IntoPyObject))]
727pub struct SimpleType {
728 name: Option<String>,
729 data_type: Option<String>,
730 min_length: Option<String>,
731 max_length: Option<String>,
732 min_inclusive: Option<String>,
733 max_inclusive: Option<String>,
734 min_exclusive: Option<String>,
735 max_exclusive: Option<String>,
736 fraction_digits: Option<String>,
737 total_digits: Option<String>,
738 pattern: Option<String>,
739 values: Option<Vec<String>>,
740 nullable: Option<bool>,
741 documentation: Option<String>,
742}
743
744
745fn extract_enum_values(node: roxmltree::Node) -> Option<Vec<String>> {
747 let mut values = Vec::new();
748 for child in node.children() {
749 if child.tag_name().name() == "enumeration" {
750 if let Some(value) = child.attribute("value") {
751 values.push(value.to_string());
752 }
753 }
754 }
755 if values.is_empty() {
756 None
757 } else {
758 Some(values)
759 }
760}
761
762fn extract_documentation(node: roxmltree::Node) -> Option<String> {
764 for child in node.children() {
765 if child.tag_name().name() == "documentation" {
766 return child.text().map(String::from);
768 }
769 }
770
771 None
772}
773
774fn extract_constraints(node: roxmltree::Node) -> SimpleType {
776 let mut simple_type = SimpleType {
777 name: None,
778 data_type: node.attribute("base").map(|s| s.replace("xs:", "")),
779 min_length: None,
780 max_length: None,
781 min_inclusive: None,
782 max_inclusive: None,
783 min_exclusive: None,
784 max_exclusive: None,
785 fraction_digits: None,
786 total_digits: None,
787 pattern: None,
788 values: extract_enum_values(node),
789 nullable: None,
790 documentation: extract_documentation(node),
791 };
792
793 for child in node.children() {
794 match child.tag_name().name() {
795 "minLength" => simple_type.min_length = child.attribute("value").map(String::from),
796 "maxLength" => simple_type.max_length = child.attribute("value").map(String::from),
797 "minInclusive" => {
798 simple_type.min_inclusive = child.attribute("value").map(String::from)
799 }
800 "maxInclusive" => {
801 simple_type.max_inclusive = child.attribute("value").map(String::from)
802 }
803 "minExclusive" => {
804 simple_type.min_exclusive = child.attribute("value").map(String::from)
805 }
806 "maxExclusive" => {
807 simple_type.max_exclusive = child.attribute("value").map(String::from)
808 }
809 "fractionDigits" => {
810 simple_type.fraction_digits = child.attribute("value").map(String::from)
811 }
812 "totalDigits" => simple_type.total_digits = child.attribute("value").map(String::from),
813 "pattern" => simple_type.pattern = child.attribute("value").map(String::from),
814 "nullable" => simple_type.nullable = Some(true),
815 _ => {}
816 }
817 }
818 simple_type
819}
820
821fn parse_element(
822 node: roxmltree::Node,
823 parent_xpath: &str,
824 global_types: &IndexMap<String, SimpleType>,
825 lowercase: Option<bool>,
826) -> Option<SchemaElement> {
827 if node.tag_name().name() != "element" {
828 return None;
829 }
830
831 let mut name = node.attribute("name")?.to_string();
832 if lowercase.is_some() && lowercase.unwrap() {
833 name = name.to_lowercase();
834 }
835
836 let nullable = node.attribute("nillable").map(|s| s == "true");
837 let xpath = format!("{}/{}", parent_xpath, name);
838 let mut data_type = node.attribute("type").map(|s| s.replace("xs:", ""));
839 let min_occurs = match node.attribute("minOccurs") {
840 None => Some("1".to_string()),
841 Some(m) => Some(m.to_string()),
842 };
843
844 let max_occurs = match node.attribute("maxOccurs") {
845 Some(m) => Some(m.to_string()),
846 None => Some("1".to_string()),
847 };
848
849 let mut documentation = None;
850
851 let mut min_length = None;
852 let mut max_length = None;
853 let mut min_inclusive = None;
854 let mut max_inclusive = None;
855 let mut min_exclusive = None;
856 let mut max_exclusive = None;
857 let mut fraction_digits = None;
858 let mut total_digits = None;
859 let mut pattern = None;
860 let mut values = None;
861 let mut elements = Vec::new();
862
863 if let Some(ref type_name) = data_type {
864 if let Some(global_type) = global_types.get(type_name) {
865 min_length = global_type.min_length.clone();
866 max_length = global_type.max_length.clone();
867 min_inclusive = global_type.min_inclusive.clone();
868 max_inclusive = global_type.max_inclusive.clone();
869 min_exclusive = global_type.min_exclusive.clone();
870 max_exclusive = global_type.max_exclusive.clone();
871 fraction_digits = global_type.fraction_digits.clone();
872 total_digits = global_type.total_digits.clone();
873 pattern = global_type.pattern.clone();
874 values = global_type.values.clone();
875 data_type = global_type.data_type.clone();
876 documentation = global_type.documentation.clone();
877 }
878 }
879
880 for child in node.children() {
881 match child.tag_name().name() {
882 "simpleType" => {
883 for subchild in child.children() {
884 if subchild.tag_name().name() == "restriction" {
885 let simple_type = extract_constraints(subchild);
886 if simple_type.data_type.is_some() {
887 data_type = simple_type.data_type;
888 }
889 min_length = simple_type.min_length;
890 max_length = simple_type.max_length;
891 min_inclusive = simple_type.min_inclusive;
892 max_inclusive = simple_type.max_inclusive;
893 min_exclusive = simple_type.min_exclusive;
894 max_exclusive = simple_type.max_exclusive;
895 fraction_digits = simple_type.fraction_digits;
896 total_digits = simple_type.total_digits;
897 pattern = simple_type.pattern;
898 values = simple_type.values;
899 }
900 }
901 }
902 "complexType" => {
903 for subchild in child.descendants() {
906 if let Some(sub_element) = parse_element(subchild, &xpath, global_types, lowercase) {
907 elements.push(sub_element);
908 }
909 }
910 }
911 _ => {}
912 }
913 }
914
915 let is_currency = name == "Currency";
916
917 Some(SchemaElement {
918 id: name.clone(),
919 name,
920 data_type,
921 min_occurs,
922 max_occurs,
923 min_length,
924 max_length,
925 min_inclusive,
926 max_inclusive,
927 min_exclusive,
928 max_exclusive,
929 pattern,
930 fraction_digits,
931 total_digits,
932 values,
933 is_currency,
934 xpath,
935 nullable,
936 elements,
937 documentation,
938 })
939}
940
941pub fn read_xsd_file(xsd_file: PathBuf, encoding: Option<&'static Encoding>) -> Result<String, Box<dyn std::error::Error>> {
942 let parsed_file = File::open(xsd_file);
943 if let Err(e) = parsed_file {
946 return Err(format!("Failed to read XSD file: {}", e).into());
947 }
948
949 let file = parsed_file.unwrap();
950
951 let use_encoding = encoding.unwrap_or(UTF_8);
952
953 let mut transcode_reader = DecodeReaderBytesBuilder::new()
954 .encoding(Some(use_encoding))
955 .build(file);
956
957 let mut xml_content = String::new();
958 transcode_reader.read_to_string(&mut xml_content)?;
959
960 Ok(xml_content)
961}
962
963pub fn parse_xsd_string(xsd_string: &str, timestamp_options: Option<TimestampOptions>, lowercase: Option<bool>) -> Result<Schema, Box<dyn std::error::Error>> {
964 let parse_doc = Document::parse(xsd_string);
965
966 if let Err(e) = parse_doc {
967 return Err(format!("Failed to parse XML: {}. Maybe try a different encoding (utf-16 ?).", e).into());
968 }
969
970 let doc = parse_doc.unwrap();
971 let mut schema_doc: Option<String> = None;
972
973
974 let global_types = Arc::new(Mutex::new(IndexMap::new()));
975
976 doc.root().descendants().for_each(|node| {
977 if node.tag_name().name() == "simpleType" {
978 if let Some(name) = node.attribute("name") {
979
980 let mut doc = None;
981 for child in node.children() {
982 if child.tag_name().name() == "annotation" {
983 let d = extract_documentation(child);
985 if let Some(o) = d {
986 doc = Some(o.clone());
987 }
988 }
989 if child.tag_name().name() == "restriction" {
990 let mut map = global_types.lock().unwrap();
991 let mut st = extract_constraints(child);
992 st.name = Some(name.to_string());
993 st.documentation = doc.clone();
994 map.insert(name.to_string(), st);
995 }
996 }
997 }
998 } else if node.tag_name().name() == "annotation" {
999 if schema_doc.is_none() {
1000 schema_doc = extract_documentation(node);
1001 }
1002 }
1003 });
1004
1005 let final_map = Arc::try_unwrap(global_types)
1006 .expect("Arc should have no other refs")
1007 .into_inner()
1008 .expect("Mutex should be unlocked");
1009
1010 let mut schema_element = None;
1011
1012 for node in doc.root().descendants() {
1013 if node.tag_name().name() == "element" {
1014 let mut element_name = "".to_string();
1015 if let Some(name) = node.attribute("name"){
1016 if lowercase.is_some() && lowercase.unwrap() {
1017 element_name = name.to_lowercase();
1018 } else {
1019 element_name = name.to_string();
1020 }
1021 }
1022 schema_element = parse_element(node, &element_name, &final_map, lowercase);
1023 break;
1024 }
1025 }
1026
1027 let mut custom_types_vec: Vec<_> = final_map.into_iter().collect();
1028 custom_types_vec.sort_by(|a, b| a.0.to_lowercase().cmp(&b.0.to_lowercase()));
1029 let final_map: IndexMap<_, _> = custom_types_vec.into_iter().collect();
1030
1031 if let Some(schema_element) = schema_element {
1032 let schema = Schema {
1033 namespace: None,
1034 schema_element,
1035 timestamp_options,
1036 doc: schema_doc,
1037 custom_types: Some(final_map),
1038 };
1039
1040 Ok(schema)
1041 } else {
1042 Err("Failed to find the main schema element in the XSD.".into())
1043 }
1044}
1045
1046pub fn parse_file(xsd_file: PathBuf, timestamp_options: Option<TimestampOptions>,
1047 encoding: Option<&'static Encoding>,
1048 lowercase: Option<bool>) -> Result<Schema, Box<dyn std::error::Error>> {
1049 let xml_content = read_xsd_file(xsd_file, encoding)?;
1050
1051 parse_xsd_string(&xml_content, timestamp_options, lowercase)
1052}
1053
1054
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use std::io::Write;
    use tempfile::tempdir;

    /// Minimal XSD document declaring a single `testElement` of type `xs:string`.
    const SINGLE_ELEMENT_XSD: &str = r#"
        <schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
            <xs:element name="testElement" type="xs:string"/>
        </schema>
        "#;

    /// Builds a `string`-typed [`SchemaElement`] fixture.
    ///
    /// Only the name, nullability, documentation and child elements vary between
    /// the fixtures used in this module; every other field is fixed
    /// (`minOccurs`/`maxOccurs` of `"1"`, no constraints, xpath `/name`).
    fn string_element(
        name: &str,
        nullable: bool,
        documentation: &str,
        elements: Vec<SchemaElement>,
    ) -> SchemaElement {
        SchemaElement {
            id: "id".to_string(),
            name: name.to_string(),
            data_type: Some("string".to_string()),
            min_occurs: Some("1".to_string()),
            max_occurs: Some("1".to_string()),
            min_length: None,
            max_length: None,
            min_inclusive: None,
            max_inclusive: None,
            min_exclusive: None,
            max_exclusive: None,
            pattern: None,
            fraction_digits: None,
            total_digits: None,
            values: None,
            is_currency: false,
            xpath: "/name".to_string(),
            nullable: Some(nullable),
            elements,
            documentation: Some(documentation.to_string()),
        }
    }

    /// Returns a schema whose root element `main_schema` holds three string
    /// fields (`field1` non-nullable, `field2` and `field3` nullable).
    fn create_test_schema() -> Schema {
        let fields = vec![
            string_element("field1", false, "This is the first test field", vec![]),
            string_element("field2", true, "This is the second test field", vec![]),
            string_element(
                "field3",
                true,
                "This is the third and last test field",
                vec![],
            ),
        ];

        Schema {
            namespace: None,
            schema_element: string_element(
                "main_schema",
                true,
                "This is the main schema",
                fields,
            ),
            timestamp_options: None,
            doc: Some("TestSchema".to_string()),
            custom_types: None,
        }
    }

    /// Writes [`SINGLE_ELEMENT_XSD`] into a fresh temporary directory and parses
    /// it with `parse_file`, passing `Some(lowercase)` as the lowercase option.
    fn parse_single_element_xsd(lowercase: bool) -> Schema {
        // `dir` must stay alive until parsing is done: dropping it removes the
        // temporary directory together with the file.
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.xsd");
        {
            // Scope the handle so the file is closed before it is read back.
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "{}", SINGLE_ELEMENT_XSD).unwrap();
        }

        parse_file(file_path, None, None, Some(lowercase)).unwrap()
    }

    /// Every supported unit string parses, including the `μs` alias; anything
    /// else is rejected.
    #[test]
    fn test_timestamp_unit_from_str() {
        assert_eq!(TimestampUnit::from_str("ns").unwrap(), TimestampUnit::Ns);
        assert_eq!(TimestampUnit::from_str("ms").unwrap(), TimestampUnit::Ms);
        assert_eq!(TimestampUnit::from_str("us").unwrap(), TimestampUnit::Us);
        assert_eq!(TimestampUnit::from_str("μs").unwrap(), TimestampUnit::Us);
        assert!(TimestampUnit::from_str("invalid").is_err());
    }

    /// The Arrow schema exposes one field per child of the root element.
    #[test]
    fn test_schema_to_arrow() {
        let schema = create_test_schema();
        let arrow_schema = schema.to_arrow().unwrap();
        assert_eq!(arrow_schema.fields().len(), 3);
        assert_eq!(arrow_schema.field(0).name(), "field1");
        assert_eq!(schema.doc, Some("TestSchema".to_string()));
    }

    /// With lowercasing disabled the element name keeps its original case.
    #[test]
    fn test_parse_file() {
        let schema = parse_single_element_xsd(false);
        assert_eq!(schema.schema_element.name, "testElement");
    }

    /// With lowercasing enabled the element name is lowercased.
    #[test]
    fn test_parse_file_lowercase() {
        let schema = parse_single_element_xsd(true);
        assert_eq!(schema.schema_element.name, "testelement");
    }

    /// A `string` element maps to Arrow's `Utf8`.
    #[test]
    fn test_schema_element_to_arrow() {
        let schema = create_test_schema();
        let element = &schema.schema_element.elements[0];

        let data_type = element.to_arrow().unwrap();
        assert_eq!(data_type, DataType::Utf8);
    }

    /// A `string` element maps to a Spark field named after it with type `string`.
    #[test]
    fn test_schema_element_to_spark() {
        let schema = create_test_schema();
        let element = &schema.schema_element.elements[0];

        let spark_field = element.to_spark().unwrap();
        assert_eq!(spark_field.field_name, "field1");
        assert_eq!(spark_field.field_type, "string");
    }

    /// The JSON schema entry is keyed by the element name and reports the
    /// element's nullability alongside it.
    #[test]
    fn test_schema_element_to_json_schema() {
        let schema = create_test_schema();
        let element = &schema.schema_element.elements[0];

        let (json_element, nullable) = element.to_json_schema();
        let json_type = json_element
            .get("field1")
            .and_then(|v| v.get("type"))
            .and_then(|v| v.as_str());
        assert_eq!(json_type, Some("string"));
        assert!(!nullable);
    }

    /// An unconstrained `string` element becomes `VARCHAR(255)` in DuckDB.
    #[test]
    fn test_schema_element_to_duckdb_schema() {
        let schema = create_test_schema();
        let element = &schema.schema_element.elements[0];
        let duckdb_schema = element.to_duckdb_schema();
        assert_eq!(
            duckdb_schema.get("field1").unwrap().to_string(),
            "VARCHAR(255)"
        );
    }

    /// The DuckDB schema lists columns in declaration order.
    #[test]
    fn test_duckdb_schema_ordered() {
        let schema = create_test_schema();

        let duckdb_schema = schema.to_duckdb_schema();
        let names = duckdb_schema
            .iter()
            .map(|entry| entry.0.clone())
            .collect::<Vec<_>>();
        assert_eq!(
            names,
            &[
                "field1".to_string(),
                "field2".to_string(),
                "field3".to_string()
            ]
        );
    }

    /// Enumeration values are collected in document order.
    #[test]
    fn test_extract_enum_values() {
        let xml = r#"
            <restriction base="xs:string">
                <enumeration value="A"/>
                <enumeration value="B"/>
            </restriction>
            "#;
        let doc = Document::parse(xml).unwrap();
        let node = doc.root().first_child().unwrap();
        let values = extract_enum_values(node).unwrap();
        assert_eq!(values, vec!["A", "B"]);
    }

    /// `minLength` / `maxLength` facets are captured as strings.
    #[test]
    fn test_extract_constraints() {
        let xml = r#"
            <restriction base="xs:string">
                <minLength value="1"/>
                <maxLength value="255"/>
            </restriction>
            "#;
        let doc = Document::parse(xml).unwrap();
        let node = doc.root().first_child().unwrap();
        let constraints = extract_constraints(node);
        assert_eq!(constraints.min_length, Some("1".to_string()));
        assert_eq!(constraints.max_length, Some("255".to_string()));
    }

    /// A bare element yields its name and the type with the `xs:` prefix removed.
    #[test]
    fn test_parse_element() {
        let xml = r#"
            <element name="testElement" type="xs:string"/>
            "#;
        let doc = Document::parse(xml).unwrap();
        let node = doc.root().first_child().unwrap();
        let element = parse_element(node, "", &IndexMap::new(), Some(false)).unwrap();
        assert_eq!(element.name, "testElement");
        assert_eq!(element.data_type, Some("string".to_string()));
    }

    /// The text of `annotation/documentation` is returned as the documentation.
    #[test]
    fn test_extract_documentation() {
        let xml = r#"
            <annotation>
                <documentation>This is a test element</documentation>
            </annotation>
            "#;
        let doc = Document::parse(xml).unwrap();
        let node = doc.root().first_child().unwrap();
        let documentation = extract_documentation(node);
        assert_eq!(documentation, Some("This is a test element".to_string()));
    }

    /// Builds an `AvroSchema` with a simple and a union-typed field.
    ///
    /// NOTE(review): despite its name this test only checks the constructed
    /// struct's fields; nothing is serialized here.
    #[test]
    fn test_avro_schema_serialization() {
        let schema = AvroSchema {
            schema_type: "record".to_string(),
            namespace: Some("example.avro".to_string()),
            name: "LongList".to_string(),
            doc: Some("Linked list of 64-bit longs.".to_string()),
            aliases: Some(vec!["LinkedLongs".to_string()]),
            fields: vec![
                AvroField {
                    name: "value".to_string(),
                    field_type: AvroType::Simple("long".to_string()),
                    doc: Some("The value of the node.".to_string()),
                },
                AvroField {
                    name: "next".to_string(),
                    field_type: AvroType::Union(vec![
                        AvroType::Simple("null".to_string()),
                        AvroType::Simple("LongList".to_string()),
                    ]),
                    doc: Some("The next node in the list.".to_string()),
                },
            ],
        };
        assert_eq!(schema.fields.len(), 2);
        assert_eq!(schema.doc, Some("Linked list of 64-bit longs.".to_string()));
    }
}