1pub mod sparse;
24
25use scirs2_core::ndarray::Array2;
26use std::fs::File;
27use std::io::{BufRead, BufReader, BufWriter, Write};
28use std::path::Path;
29
30use crate::error::{IoError, Result};
31
32pub use sparse::{read_sparse_arff, write_sparse_arff, SparseArffData, SparseInstance};
33
/// Declared type of an ARFF attribute, as written in its `@attribute` line.
///
/// All payloads are plain strings, so the type is `Eq` + `Hash` and can be
/// used as a map/set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AttributeType {
    /// `numeric`, `real` or `integer`; values are stored as `f64`.
    Numeric,
    /// `string`: free-form text values.
    String,
    /// `date [<format>]`; holds the date format pattern. Values are kept as
    /// the raw (unquoted) text and are not parsed into a date type.
    Date(String),
    /// `{a, b, ...}`: the ordered list of allowed labels.
    Nominal(Vec<String>),
}
46
47#[derive(Debug, Clone)]
49pub struct ArffData {
50 pub relation: String,
52 pub attributes: Vec<(String, AttributeType)>,
54 pub data: Array2<ArffValue>,
56}
57
/// A single cell of an ARFF dataset.
#[derive(Debug, Clone, PartialEq)]
pub enum ArffValue {
    /// Value of a numeric (`numeric`/`real`/`integer`) attribute.
    Numeric(f64),
    /// Value of a `string` attribute, without surrounding quotes.
    String(String),
    /// Value of a `date` attribute, kept as unparsed text.
    Date(String),
    /// Label of a nominal attribute.
    Nominal(String),
    /// The `?` missing-value marker.
    Missing,
}

impl ArffValue {
    /// Returns the number for `Numeric` cells and `None` for every other kind.
    pub fn to_f64(&self) -> Option<f64> {
        if let ArffValue::Numeric(number) = self {
            Some(*number)
        } else {
            None
        }
    }

    /// Renders the cell as text: numbers via `Display`, textual payloads
    /// verbatim (no quoting), and `Missing` as `"?"`.
    pub fn as_string(&self) -> String {
        match self {
            ArffValue::Missing => String::from("?"),
            ArffValue::Numeric(number) => number.to_string(),
            ArffValue::String(text) | ArffValue::Date(text) | ArffValue::Nominal(text) => {
                text.clone()
            }
        }
    }

    /// `true` only for the `Missing` marker.
    pub fn is_missing(&self) -> bool {
        *self == ArffValue::Missing
    }

    /// `true` for a `Numeric` cell equal to zero (`-0.0` included, NaN excluded).
    pub fn is_numeric_zero(&self) -> bool {
        self.to_f64() == Some(0.0)
    }
}
103
104fn parse_attribute(line: &str) -> Result<(String, AttributeType)> {
106 let trimmed = line.trim();
107 if !trimmed.to_lowercase().starts_with("@attribute") {
108 return Err(IoError::FormatError("Invalid attribute format".to_string()));
109 }
110
111 let rest = trimmed["@attribute".len()..].trim_start();
113
114 let (name, type_part) = if rest.starts_with('\'') || rest.starts_with('"') {
116 let quote = rest.as_bytes()[0];
117 let end = rest[1..]
118 .find(|c: char| c as u8 == quote)
119 .ok_or_else(|| IoError::FormatError("Unterminated attribute name quote".to_string()))?;
120 let name = rest[1..end + 1].to_string();
121 let remaining = rest[end + 2..].trim_start();
122 (name, remaining)
123 } else {
124 let parts: Vec<&str> = rest.splitn(2, ' ').collect();
125 if parts.len() < 2 {
126 return Err(IoError::FormatError("Invalid attribute format".to_string()));
127 }
128 (parts[0].trim().to_string(), parts[1].trim())
129 };
130
131 let attr_type = if type_part.eq_ignore_ascii_case("numeric")
133 || type_part.eq_ignore_ascii_case("real")
134 || type_part.eq_ignore_ascii_case("integer")
135 {
136 AttributeType::Numeric
137 } else if type_part.eq_ignore_ascii_case("string") {
138 AttributeType::String
139 } else if type_part.to_lowercase().starts_with("date") {
140 let format = if type_part.len() > 4 && type_part.contains(' ') {
141 let format_str = type_part.split_once(' ').map(|x| x.1).unwrap_or("").trim();
142 if (format_str.starts_with('"') && format_str.ends_with('"'))
143 || (format_str.starts_with('\'') && format_str.ends_with('\''))
144 {
145 format_str[1..format_str.len() - 1].to_string()
146 } else {
147 format_str.to_string()
148 }
149 } else {
150 "yyyy-MM-dd'T'HH:mm:ss".to_string()
151 };
152 AttributeType::Date(format)
153 } else if type_part.starts_with('{') && type_part.ends_with('}') {
154 let values_str = &type_part[1..type_part.len() - 1];
155 let values: Vec<String> = values_str
156 .split(',')
157 .map(|s| {
158 let s = s.trim();
159 if (s.starts_with('"') && s.ends_with('"'))
160 || (s.starts_with('\'') && s.ends_with('\''))
161 {
162 s[1..s.len() - 1].to_string()
163 } else {
164 s.to_string()
165 }
166 })
167 .collect();
168 AttributeType::Nominal(values)
169 } else {
170 return Err(IoError::FormatError(format!(
171 "Unknown attribute type: {type_part}"
172 )));
173 };
174
175 Ok((name, attr_type))
176}
177
178fn parse_data_line(line: &str, attributes: &[(String, AttributeType)]) -> Result<Vec<ArffValue>> {
180 let trimmed = line.trim();
181 if trimmed.is_empty() {
182 return Err(IoError::FormatError("Empty data line".to_string()));
183 }
184
185 if trimmed.starts_with('{') {
187 return parse_sparse_data_line(trimmed, attributes);
188 }
189
190 let mut values = Vec::new();
191 let parts: Vec<&str> = trimmed.split(',').collect();
192
193 if parts.len() != attributes.len() {
194 return Err(IoError::FormatError(format!(
195 "Data line has {} values but expected {}",
196 parts.len(),
197 attributes.len()
198 )));
199 }
200
201 for (i, part) in parts.iter().enumerate() {
202 let part = part.trim();
203 if part == "?" {
204 values.push(ArffValue::Missing);
205 continue;
206 }
207
208 let attr_type = &attributes[i].1;
209 let value = parse_value(part, attr_type)?;
210 values.push(value);
211 }
212
213 Ok(values)
214}
215
216fn parse_sparse_data_line(
218 line: &str,
219 attributes: &[(String, AttributeType)],
220) -> Result<Vec<ArffValue>> {
221 let mut values: Vec<ArffValue> = Vec::new();
222 for (_, attr_type) in attributes {
224 let default = match attr_type {
225 AttributeType::Numeric => ArffValue::Numeric(0.0),
226 AttributeType::String => ArffValue::String(String::new()),
227 AttributeType::Date(_) => ArffValue::Missing,
228 AttributeType::Nominal(_) => ArffValue::Missing,
229 };
230 values.push(default);
231 }
232
233 let inner = line
235 .trim()
236 .trim_start_matches('{')
237 .trim_end_matches('}')
238 .trim();
239
240 if inner.is_empty() {
241 return Ok(values);
242 }
243
244 for pair in inner.split(',') {
246 let pair = pair.trim();
247 if pair.is_empty() {
248 continue;
249 }
250
251 let space_pos = pair
252 .find(' ')
253 .ok_or_else(|| IoError::FormatError(format!("Invalid sparse pair: '{}'", pair)))?;
254
255 let idx_str = &pair[..space_pos];
256 let val_str = pair[space_pos + 1..].trim();
257
258 let idx: usize = idx_str
259 .parse()
260 .map_err(|_| IoError::FormatError(format!("Invalid sparse index: '{}'", idx_str)))?;
261
262 if idx >= attributes.len() {
263 return Err(IoError::FormatError(format!(
264 "Sparse index {} out of range (max {})",
265 idx,
266 attributes.len() - 1
267 )));
268 }
269
270 if val_str == "?" {
271 values[idx] = ArffValue::Missing;
272 } else {
273 values[idx] = parse_value(val_str, &attributes[idx].1)?;
274 }
275 }
276
277 Ok(values)
278}
279
280fn parse_value(part: &str, attr_type: &AttributeType) -> Result<ArffValue> {
282 match attr_type {
283 AttributeType::Numeric => {
284 let num = part
285 .parse::<f64>()
286 .map_err(|_| IoError::FormatError(format!("Invalid numeric value: {part}")))?;
287 Ok(ArffValue::Numeric(num))
288 }
289 AttributeType::String => {
290 let s = strip_quotes(part);
291 Ok(ArffValue::String(s))
292 }
293 AttributeType::Date(_) => {
294 let s = strip_quotes(part);
295 Ok(ArffValue::Date(s))
296 }
297 AttributeType::Nominal(allowed_values) => {
298 let s = strip_quotes(part);
299 if !allowed_values.contains(&s) {
300 return Err(IoError::FormatError(format!(
301 "Invalid nominal value: {s}, expected one of {allowed_values:?}"
302 )));
303 }
304 Ok(ArffValue::Nominal(s))
305 }
306 }
307}
308
/// Trims `s` and, if it is wrapped in a matching pair of `'` or `"` quotes,
/// removes them.
///
/// Inside a quoted string, the escapes `\\`, `\"` and `\'` are resolved to
/// the bare character, which matches how values are escaped when written.
/// Any other backslash sequence is kept verbatim, so paths like `'C:\dir'`
/// survive. Unquoted input is returned trimmed but otherwise unchanged. A
/// lone quote character is not a quoted string and is returned as-is
/// (previously this panicked).
fn strip_quotes(s: &str) -> String {
    let s = s.trim();
    let inner = ['"', '\'']
        .iter()
        .find_map(|&q| s.strip_prefix(q).and_then(|rest| rest.strip_suffix(q)));

    let Some(inner) = inner else {
        return s.to_string();
    };
    if !inner.contains('\\') {
        return inner.to_string();
    }

    let mut out = String::with_capacity(inner.len());
    let mut chars = inner.chars();
    while let Some(c) = chars.next() {
        if c != '\\' {
            out.push(c);
            continue;
        }
        match chars.next() {
            Some(escaped @ ('\\' | '"' | '\'')) => out.push(escaped),
            Some(other) => {
                out.push('\\');
                out.push(other);
            }
            None => out.push('\\'),
        }
    }
    out
}
318
319pub fn read_arff<P: AsRef<Path>>(path: P) -> Result<ArffData> {
321 let file = File::open(path).map_err(|e| IoError::FileError(e.to_string()))?;
322 let reader = BufReader::new(file);
323
324 let mut relation = String::new();
325 let mut attributes = Vec::new();
326 let mut data_lines = Vec::new();
327 let mut in_data_section = false;
328
329 for (line_num, line_result) in reader.lines().enumerate() {
330 let line = line_result
331 .map_err(|e| IoError::FileError(format!("Error reading line {}: {e}", line_num + 1)))?;
332
333 let trimmed = line.trim();
334 if trimmed.is_empty() || trimmed.starts_with('%') {
335 continue;
336 }
337
338 if in_data_section {
339 data_lines.push(trimmed.to_string());
340 } else {
341 let lower = trimmed.to_lowercase();
342 if lower.starts_with("@relation") {
343 let parts: Vec<&str> = trimmed.splitn(2, ' ').collect();
344 if parts.len() < 2 {
345 return Err(IoError::FormatError("Invalid relation format".to_string()));
346 }
347 relation = strip_quotes(parts[1].trim());
348 } else if lower.starts_with("@attribute") {
349 let (name, attr_type) = parse_attribute(trimmed)?;
350 attributes.push((name, attr_type));
351 } else if lower.starts_with("@data") {
352 in_data_section = true;
353 } else {
354 return Err(IoError::FormatError(format!(
355 "Unexpected line in header section: {trimmed}"
356 )));
357 }
358 }
359 }
360
361 if !in_data_section {
362 return Err(IoError::FormatError("No @data section found".to_string()));
363 }
364
365 if attributes.is_empty() {
366 return Err(IoError::FormatError("No attributes defined".to_string()));
367 }
368
369 let mut data_values = Vec::new();
371 for (i, line) in data_lines.iter().enumerate() {
372 let values = parse_data_line(line, &attributes)
373 .map_err(|e| IoError::FormatError(format!("Error parsing data line {}: {e}", i + 1)))?;
374 data_values.push(values);
375 }
376
377 let num_instances = data_values.len();
379 let num_attributes = attributes.len();
380 let mut data = Array2::from_elem((num_instances, num_attributes), ArffValue::Missing);
381
382 for (i, row) in data_values.iter().enumerate() {
383 for (j, value) in row.iter().enumerate() {
384 data[[i, j]] = value.clone();
385 }
386 }
387
388 Ok(ArffData {
389 relation,
390 attributes,
391 data,
392 })
393}
394
395pub fn get_numeric_matrix(
397 arff_data: &ArffData,
398 numeric_attributes: &[String],
399) -> Result<Array2<f64>> {
400 let mut indices = Vec::new();
401 let mut attr_names = Vec::new();
402
403 for attr_name in numeric_attributes {
404 let mut found = false;
405 for (i, (name, attr_type)) in arff_data.attributes.iter().enumerate() {
406 if name == attr_name {
407 match attr_type {
408 AttributeType::Numeric => {
409 indices.push(i);
410 attr_names.push(name.clone());
411 found = true;
412 break;
413 }
414 _ => {
415 return Err(IoError::FormatError(format!(
416 "Attribute '{name}' is not numeric"
417 )));
418 }
419 }
420 }
421 }
422
423 if !found {
424 return Err(IoError::FormatError(format!(
425 "Attribute '{attr_name}' not found"
426 )));
427 }
428 }
429
430 let num_instances = arff_data.data.shape()[0];
431 let num_selected = indices.len();
432 let mut output = Array2::from_elem((num_instances, num_selected), f64::NAN);
433
434 for (out_col, &in_col) in indices.iter().enumerate() {
435 for row in 0..num_instances {
436 match &arff_data.data[[row, in_col]] {
437 ArffValue::Numeric(val) => {
438 output[[row, out_col]] = *val;
439 }
440 ArffValue::Missing => {} _ => {
442 return Err(IoError::FormatError(format!(
443 "Non-numeric value found in numeric attribute '{}' at row {}",
444 attr_names[out_col], row
445 )));
446 }
447 }
448 }
449 }
450
451 Ok(output)
452}
453
454pub fn write_arff<P: AsRef<Path>>(path: P, arff_data: &ArffData) -> Result<()> {
456 let file = File::create(path).map_err(|e| IoError::FileError(e.to_string()))?;
457 let mut writer = BufWriter::new(file);
458
459 writeln!(
460 writer,
461 "@relation {}",
462 format_arff_string(&arff_data.relation)
463 )
464 .map_err(|e| IoError::FileError(format!("Failed to write relation: {e}")))?;
465
466 writeln!(writer).map_err(|e| IoError::FileError(format!("Failed to write newline: {e}")))?;
467
468 for (name, attr_type) in &arff_data.attributes {
469 let type_str = match attr_type {
470 AttributeType::Numeric => "numeric".to_string(),
471 AttributeType::String => "string".to_string(),
472 AttributeType::Date(format) => {
473 if format.is_empty() {
474 "date".to_string()
475 } else {
476 format!("date {}", format_arff_string(format))
477 }
478 }
479 AttributeType::Nominal(values) => {
480 let values_str: Vec<String> =
481 values.iter().map(|v| format_arff_string(v)).collect();
482 format!("{{{}}}", values_str.join(", "))
483 }
484 };
485
486 writeln!(
487 writer,
488 "@attribute {} {}",
489 format_arff_string(name),
490 type_str
491 )
492 .map_err(|e| IoError::FileError(format!("Failed to write attribute: {e}")))?;
493 }
494
495 writeln!(writer, "\n@data")
496 .map_err(|e| IoError::FileError(format!("Failed to write data header: {e}")))?;
497
498 let shape = arff_data.data.shape();
499 let num_instances = shape[0];
500 let num_attributes = shape[1];
501
502 for i in 0..num_instances {
503 let mut line = String::new();
504 for j in 0..num_attributes {
505 let value = &arff_data.data[[i, j]];
506 let value_str = match value {
507 ArffValue::Missing => "?".to_string(),
508 ArffValue::Numeric(val) => val.to_string(),
509 ArffValue::String(val) => format_arff_string(val),
510 ArffValue::Date(val) => format_arff_string(val),
511 ArffValue::Nominal(val) => format_arff_string(val),
512 };
513 if j > 0 {
514 line.push(',');
515 }
516 line.push_str(&value_str);
517 }
518 writeln!(writer, "{line}")
519 .map_err(|e| IoError::FileError(format!("Failed to write data line: {e}")))?;
520 }
521
522 Ok(())
523}
524
525pub fn numeric_matrix_to_arff(
527 relation: &str,
528 attribute_names: &[String],
529 data: &Array2<f64>,
530) -> ArffData {
531 let shape = data.shape();
532 let num_instances = shape[0];
533 let num_attributes = shape[1];
534
535 let mut attributes = Vec::with_capacity(num_attributes);
536 for name in attribute_names {
537 attributes.push((name.clone(), AttributeType::Numeric));
538 }
539
540 let mut arff_data = Array2::from_elem((num_instances, num_attributes), ArffValue::Missing);
541
542 for i in 0..num_instances {
543 for j in 0..num_attributes {
544 let val = data[[i, j]];
545 arff_data[[i, j]] = if val.is_nan() {
546 ArffValue::Missing
547 } else {
548 ArffValue::Numeric(val)
549 };
550 }
551 }
552
553 ArffData {
554 relation: relation.to_string(),
555 attributes,
556 data: arff_data,
557 }
558}
559
/// Formats a name, label or textual value for an ARFF file, double-quoting it
/// when it could otherwise be misread.
///
/// Quoting is applied when the string:
/// - is empty (it would vanish, e.g. `@relation` with no name);
/// - is exactly `?` (it would read back as a missing value);
/// - contains whitespace or any of `, ' " { } %`. Commas and braces act as
///   separators, quotes as delimiters, and a leading `%` makes a data line
///   look like a comment.
///
/// Inside quotes, `\` and `"` are backslash-escaped. Other strings are
/// returned unchanged. Backslashes alone do not force quoting, because
/// unquoted values are never unescaped.
///
/// Limitation: a line break inside a quoted value is written raw and will
/// split the record, since files are read line by line.
fn format_arff_string(s: &str) -> String {
    let needs_quotes = s.is_empty()
        || s == "?"
        || s.chars()
            .any(|c| c.is_whitespace() || matches!(c, ',' | '\'' | '"' | '{' | '}' | '%'));
    if !needs_quotes {
        return s.to_string();
    }

    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('"');
    for c in s.chars() {
        if c == '"' || c == '\\' {
            quoted.push('\\');
        }
        quoted.push(c);
    }
    quoted.push('"');
    quoted
}
574
#[cfg(test)]
mod tests {
    use super::*;

    /// Ensures the scratch directory `dir_name` exists under the system temp
    /// dir and returns it together with the path of `file_name` inside it.
    fn scratch_path(dir_name: &str, file_name: &str) -> (std::path::PathBuf, std::path::PathBuf) {
        let dir = std::env::temp_dir().join(dir_name);
        let _ = std::fs::create_dir_all(&dir);
        let file = dir.join(file_name);
        (dir, file)
    }

    /// Best-effort removal of a scratch directory; failures are ignored.
    fn cleanup(dir: &std::path::Path) {
        let _ = std::fs::remove_dir_all(dir);
    }

    #[test]
    fn test_arff_roundtrip_dense() {
        let (dir, path) = scratch_path("scirs2_arff_test_dense", "test.arff");

        let outlook = AttributeType::Nominal(
            ["sunny", "overcast", "rainy"]
                .iter()
                .map(|label| label.to_string())
                .collect(),
        );
        let cells = vec![
            ArffValue::Numeric(85.0),
            ArffValue::Nominal("sunny".to_string()),
            ArffValue::Numeric(72.0),
            ArffValue::Nominal("overcast".to_string()),
        ];
        let original = ArffData {
            relation: "test_relation".to_string(),
            attributes: vec![
                ("temp".to_string(), AttributeType::Numeric),
                ("outlook".to_string(), outlook),
            ],
            data: Array2::from_shape_vec((2, 2), cells).expect("Array creation failed"),
        };

        write_arff(&path, &original).expect("Write failed");
        let loaded = read_arff(&path).expect("Read failed");

        assert_eq!(loaded.relation, "test_relation");
        assert_eq!(loaded.attributes.len(), 2);
        assert_eq!(loaded.data.shape(), &[2, 2]);
        assert_eq!(loaded.data[[0, 0]], ArffValue::Numeric(85.0));
        assert_eq!(loaded.data[[0, 1]], ArffValue::Nominal("sunny".to_string()));

        cleanup(&dir);
    }

    #[test]
    fn test_arff_missing_values() {
        let (dir, path) = scratch_path("scirs2_arff_test_missing", "missing.arff");

        let cells = vec![
            ArffValue::Numeric(1.0),
            ArffValue::Missing,
            ArffValue::Missing,
            ArffValue::Numeric(2.0),
        ];
        let original = ArffData {
            relation: "test".to_string(),
            attributes: vec![
                ("x".to_string(), AttributeType::Numeric),
                ("y".to_string(), AttributeType::Numeric),
            ],
            data: Array2::from_shape_vec((2, 2), cells).expect("Array creation failed"),
        };

        write_arff(&path, &original).expect("Write failed");
        let loaded = read_arff(&path).expect("Read failed");

        assert!(loaded.data[[0, 1]].is_missing());
        assert!(loaded.data[[1, 0]].is_missing());
        assert_eq!(loaded.data[[0, 0]], ArffValue::Numeric(1.0));

        cleanup(&dir);
    }

    #[test]
    fn test_arff_with_date_and_string() {
        let (dir, path) = scratch_path("scirs2_arff_test_mixed", "mixed.arff");

        let cells = vec![
            ArffValue::String("sensor_1".to_string()),
            ArffValue::Date("2025-01-15".to_string()),
            ArffValue::Numeric(42.5),
        ];
        let original = ArffData {
            relation: "mixed_types".to_string(),
            attributes: vec![
                ("name".to_string(), AttributeType::String),
                (
                    "timestamp".to_string(),
                    AttributeType::Date("yyyy-MM-dd".to_string()),
                ),
                ("value".to_string(), AttributeType::Numeric),
            ],
            data: Array2::from_shape_vec((1, 3), cells).expect("Array creation failed"),
        };

        write_arff(&path, &original).expect("Write failed");
        let loaded = read_arff(&path).expect("Read failed");

        assert_eq!(loaded.attributes.len(), 3);
        assert_eq!(
            loaded.data[[0, 0]],
            ArffValue::String("sensor_1".to_string())
        );

        cleanup(&dir);
    }

    #[test]
    fn test_arff_numeric_matrix_conversion() {
        let matrix_in = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, f64::NAN, 5.0, 6.0])
            .expect("Array creation failed");
        let names = ["a".to_string(), "b".to_string()];

        let arff = numeric_matrix_to_arff("test", &names, &matrix_in);
        assert_eq!(arff.data[[0, 0]], ArffValue::Numeric(1.0));
        assert!(arff.data[[1, 1]].is_missing());

        let matrix = get_numeric_matrix(&arff, &names).expect("Conversion failed");
        assert!((matrix[[0, 0]] - 1.0).abs() < 1e-10);
        assert!(matrix[[1, 1]].is_nan());
    }

    #[test]
    fn test_arff_sparse_read() {
        let (dir, path) = scratch_path("scirs2_arff_test_sparse_read", "sparse.arff");

        let content = "\
@relation sparse_test

@attribute x numeric
@attribute y numeric
@attribute z numeric

@data
{0 1.0, 2 3.0}
{1 2.5}
{}
";
        std::fs::write(&path, content).expect("Write failed");

        let loaded = read_arff(&path).expect("Read failed");
        assert_eq!(loaded.data.shape(), &[3, 3]);

        // Row 0: explicit x and z, implicit zero y.
        assert_eq!(loaded.data[[0, 0]], ArffValue::Numeric(1.0));
        assert_eq!(loaded.data[[0, 1]], ArffValue::Numeric(0.0));
        assert_eq!(loaded.data[[0, 2]], ArffValue::Numeric(3.0));

        // Row 1: only y given.
        assert_eq!(loaded.data[[1, 0]], ArffValue::Numeric(0.0));
        assert_eq!(loaded.data[[1, 1]], ArffValue::Numeric(2.5));

        // Row 2: empty sparse row is all zeros.
        assert_eq!(loaded.data[[2, 0]], ArffValue::Numeric(0.0));

        cleanup(&dir);
    }

    #[test]
    fn test_arff_parse_attribute_types() {
        let (name, attr) = parse_attribute("@attribute temp numeric").expect("Parse failed");
        assert_eq!(name, "temp");
        assert_eq!(attr, AttributeType::Numeric);

        let (name, attr) = parse_attribute("@attribute name string").expect("Parse failed");
        assert_eq!(name, "name");
        assert_eq!(attr, AttributeType::String);

        let (name, attr) =
            parse_attribute("@attribute class {yes, no, maybe}").expect("Parse failed");
        assert_eq!(name, "class");
        assert!(matches!(attr, AttributeType::Nominal(_)));
    }

    #[test]
    fn test_arff_no_data_section() {
        let (dir, path) = scratch_path("scirs2_arff_test_nodata", "nodata.arff");

        std::fs::write(&path, "@relation test\n@attribute x numeric\n").expect("Write failed");
        assert!(read_arff(&path).is_err());

        cleanup(&dir);
    }

    #[test]
    fn test_arff_no_attributes() {
        let (dir, path) = scratch_path("scirs2_arff_test_noattr", "noattr.arff");

        std::fs::write(&path, "@relation test\n@data\n").expect("Write failed");
        assert!(read_arff(&path).is_err());

        cleanup(&dir);
    }
}