1use std::collections::BTreeMap;
14use std::fs::File;
15use std::io::{BufRead, BufReader, BufWriter, Write};
16use std::path::Path;
17
18use crate::error::{IoError, Result};
19
20use super::{parse_attribute, ArffValue, AttributeType};
21
22#[derive(Debug, Clone)]
24pub struct SparseInstance {
25 pub values: BTreeMap<usize, ArffValue>,
27}
28
29impl SparseInstance {
30 pub fn new() -> Self {
32 SparseInstance {
33 values: BTreeMap::new(),
34 }
35 }
36
37 pub fn set(&mut self, index: usize, value: ArffValue) {
39 self.values.insert(index, value);
40 }
41
42 pub fn get(&self, index: usize) -> Option<&ArffValue> {
44 self.values.get(&index)
45 }
46
47 pub fn get_or_default(&self, index: usize, attr_type: &AttributeType) -> ArffValue {
49 if let Some(val) = self.values.get(&index) {
50 val.clone()
51 } else {
52 match attr_type {
53 AttributeType::Numeric => ArffValue::Numeric(0.0),
54 AttributeType::String => ArffValue::String(String::new()),
55 AttributeType::Date(_) => ArffValue::Missing,
56 AttributeType::Nominal(_) => ArffValue::Missing,
57 }
58 }
59 }
60
61 pub fn nnz(&self) -> usize {
63 self.values.len()
64 }
65}
66
67impl Default for SparseInstance {
68 fn default() -> Self {
69 Self::new()
70 }
71}
72
73#[derive(Debug, Clone)]
75pub struct SparseArffData {
76 pub relation: String,
78 pub attributes: Vec<(String, AttributeType)>,
80 pub instances: Vec<SparseInstance>,
82}
83
84impl SparseArffData {
85 pub fn new(relation: impl Into<String>, attributes: Vec<(String, AttributeType)>) -> Self {
87 SparseArffData {
88 relation: relation.into(),
89 attributes,
90 instances: Vec::new(),
91 }
92 }
93
94 pub fn add_instance(&mut self, instance: SparseInstance) {
96 self.instances.push(instance);
97 }
98
99 pub fn num_instances(&self) -> usize {
101 self.instances.len()
102 }
103
104 pub fn num_attributes(&self) -> usize {
106 self.attributes.len()
107 }
108
109 pub fn total_nnz(&self) -> usize {
111 self.instances.iter().map(|inst| inst.nnz()).sum()
112 }
113
114 pub fn sparsity(&self) -> f64 {
116 if self.instances.is_empty() || self.attributes.is_empty() {
117 return 1.0;
118 }
119 let total_cells = self.instances.len() * self.attributes.len();
120 let nnz = self.total_nnz();
121 1.0 - (nnz as f64 / total_cells as f64)
122 }
123
124 pub fn to_dense(&self) -> super::ArffData {
126 use scirs2_core::ndarray::Array2;
127
128 let num_instances = self.instances.len();
129 let num_attributes = self.attributes.len();
130
131 let mut data = Array2::from_elem((num_instances, num_attributes), ArffValue::Missing);
132
133 for (i, instance) in self.instances.iter().enumerate() {
134 for j in 0..num_attributes {
135 data[[i, j]] = instance.get_or_default(j, &self.attributes[j].1);
136 }
137 }
138
139 super::ArffData {
140 relation: self.relation.clone(),
141 attributes: self.attributes.clone(),
142 data,
143 }
144 }
145
146 pub fn from_dense(dense: &super::ArffData) -> Self {
148 let num_instances = dense.data.shape()[0];
149 let mut instances = Vec::with_capacity(num_instances);
150
151 for i in 0..num_instances {
152 let mut inst = SparseInstance::new();
153 for (j, (_, attr_type)) in dense.attributes.iter().enumerate() {
154 let value = &dense.data[[i, j]];
155 let is_default = match (value, attr_type) {
156 (ArffValue::Numeric(v), AttributeType::Numeric) => *v == 0.0,
157 (ArffValue::String(s), AttributeType::String) => s.is_empty(),
158 (ArffValue::Missing, _) => true,
159 _ => false,
160 };
161 if !is_default {
162 inst.set(j, value.clone());
163 }
164 }
165 instances.push(inst);
166 }
167
168 SparseArffData {
169 relation: dense.relation.clone(),
170 attributes: dense.attributes.clone(),
171 instances,
172 }
173 }
174}
175
176pub fn read_sparse_arff<P: AsRef<Path>>(path: P) -> Result<SparseArffData> {
181 let file = File::open(path).map_err(|e| IoError::FileError(e.to_string()))?;
182 let reader = BufReader::new(file);
183
184 let mut relation = String::new();
185 let mut attributes = Vec::new();
186 let mut instances = Vec::new();
187 let mut in_data_section = false;
188
189 for (line_num, line_result) in reader.lines().enumerate() {
190 let line = line_result
191 .map_err(|e| IoError::FileError(format!("Error reading line {}: {e}", line_num + 1)))?;
192
193 let trimmed = line.trim();
194 if trimmed.is_empty() || trimmed.starts_with('%') {
195 continue;
196 }
197
198 if in_data_section {
199 let instance = parse_sparse_line(trimmed, &attributes)?;
200 instances.push(instance);
201 } else {
202 let lower = trimmed.to_lowercase();
203 if lower.starts_with("@relation") {
204 let parts: Vec<&str> = trimmed.splitn(2, ' ').collect();
205 if parts.len() < 2 {
206 return Err(IoError::FormatError("Invalid relation format".to_string()));
207 }
208 relation = strip_quotes_local(parts[1].trim());
209 } else if lower.starts_with("@attribute") {
210 let (name, attr_type) = parse_attribute(trimmed)?;
211 attributes.push((name, attr_type));
212 } else if lower.starts_with("@data") {
213 in_data_section = true;
214 } else {
215 return Err(IoError::FormatError(format!(
216 "Unexpected header line: {trimmed}"
217 )));
218 }
219 }
220 }
221
222 if !in_data_section {
223 return Err(IoError::FormatError("No @data section found".to_string()));
224 }
225
226 Ok(SparseArffData {
227 relation,
228 attributes,
229 instances,
230 })
231}
232
233fn parse_sparse_line(line: &str, attributes: &[(String, AttributeType)]) -> Result<SparseInstance> {
235 let trimmed = line.trim();
236
237 if trimmed.starts_with('{') {
238 let inner = trimmed.trim_start_matches('{').trim_end_matches('}').trim();
240
241 let mut inst = SparseInstance::new();
242
243 if inner.is_empty() {
244 return Ok(inst);
245 }
246
247 for pair in inner.split(',') {
248 let pair = pair.trim();
249 if pair.is_empty() {
250 continue;
251 }
252
253 let space_pos = pair
254 .find(' ')
255 .ok_or_else(|| IoError::FormatError(format!("Invalid sparse pair: '{}'", pair)))?;
256
257 let idx_str = &pair[..space_pos];
258 let val_str = pair[space_pos + 1..].trim();
259
260 let idx: usize = idx_str.parse().map_err(|_| {
261 IoError::FormatError(format!("Invalid sparse index: '{}'", idx_str))
262 })?;
263
264 if idx >= attributes.len() {
265 return Err(IoError::FormatError(format!(
266 "Sparse index {} out of range (max {})",
267 idx,
268 attributes.len() - 1
269 )));
270 }
271
272 if val_str != "?" {
273 let value = super::parse_value(val_str, &attributes[idx].1)?;
274 inst.set(idx, value);
275 }
276 }
277
278 Ok(inst)
279 } else {
280 let parts: Vec<&str> = trimmed.split(',').collect();
282 if parts.len() != attributes.len() {
283 return Err(IoError::FormatError(format!(
284 "Data line has {} values, expected {}",
285 parts.len(),
286 attributes.len()
287 )));
288 }
289
290 let mut inst = SparseInstance::new();
291 for (i, part) in parts.iter().enumerate() {
292 let part = part.trim();
293 if part == "?" {
294 continue; }
296
297 let value = super::parse_value(part, &attributes[i].1)?;
298
299 let is_default = match (&value, &attributes[i].1) {
301 (ArffValue::Numeric(v), AttributeType::Numeric) => *v == 0.0,
302 (ArffValue::String(s), AttributeType::String) => s.is_empty(),
303 _ => false,
304 };
305
306 if !is_default {
307 inst.set(i, value);
308 }
309 }
310
311 Ok(inst)
312 }
313}
314
315pub fn write_sparse_arff<P: AsRef<Path>>(path: P, data: &SparseArffData) -> Result<()> {
317 let file = File::create(path).map_err(|e| IoError::FileError(e.to_string()))?;
318 let mut writer = BufWriter::new(file);
319
320 writeln!(writer, "@relation {}", format_arff_str(&data.relation))
322 .map_err(|e| IoError::FileError(format!("Write error: {}", e)))?;
323 writeln!(writer).map_err(|e| IoError::FileError(format!("Write error: {}", e)))?;
324
325 for (name, attr_type) in &data.attributes {
326 let type_str = match attr_type {
327 AttributeType::Numeric => "numeric".to_string(),
328 AttributeType::String => "string".to_string(),
329 AttributeType::Date(fmt) => {
330 if fmt.is_empty() {
331 "date".to_string()
332 } else {
333 format!("date {}", format_arff_str(fmt))
334 }
335 }
336 AttributeType::Nominal(values) => {
337 let vals: Vec<String> = values.iter().map(|v| format_arff_str(v)).collect();
338 format!("{{{}}}", vals.join(", "))
339 }
340 };
341 writeln!(writer, "@attribute {} {}", format_arff_str(name), type_str)
342 .map_err(|e| IoError::FileError(format!("Write error: {}", e)))?;
343 }
344
345 writeln!(writer, "\n@data").map_err(|e| IoError::FileError(format!("Write error: {}", e)))?;
346
347 for instance in &data.instances {
349 let mut pairs = Vec::new();
350 for (&idx, value) in &instance.values {
351 let val_str = match value {
352 ArffValue::Missing => "?".to_string(),
353 ArffValue::Numeric(v) => v.to_string(),
354 ArffValue::String(s) => format_arff_str(s),
355 ArffValue::Date(s) => format_arff_str(s),
356 ArffValue::Nominal(s) => format_arff_str(s),
357 };
358 pairs.push(format!("{} {}", idx, val_str));
359 }
360 writeln!(writer, "{{{}}}", pairs.join(", "))
361 .map_err(|e| IoError::FileError(format!("Write error: {}", e)))?;
362 }
363
364 writer
365 .flush()
366 .map_err(|e| IoError::FileError(format!("Flush error: {}", e)))?;
367
368 Ok(())
369}
370
/// Formats `s` as an ARFF token, wrapping it in double quotes when it could not
/// be read back as a single bare token.
///
/// Quoting is applied when `s`:
///
/// * is empty (a bare empty token would leave `@relation ` or `idx ` with
///   nothing after it),
/// * is exactly `?` (a bare `?` means "missing value"),
/// * contains whitespace, or
/// * contains any of `,` `'` `"` `{` `}`.
///
/// Inside the quotes, every `"` is escaped as `\"`. All other characters are
/// written unchanged.
fn format_arff_str(s: &str) -> String {
    let needs_quoting = s.is_empty()
        || s == "?"
        || s
            .chars()
            .any(|c| c.is_whitespace() || matches!(c, ',' | '\'' | '"' | '{' | '}'));

    if needs_quoting {
        format!("\"{}\"", s.replace('"', "\\\""))
    } else {
        s.to_string()
    }
}
384
/// Trims `s` and removes one pair of matching surrounding quotes, if present.
///
/// Both `"..."` and `'...'` are recognised. The opening and closing quote must
/// be two distinct characters, so a lone `"` or `'` is returned unchanged
/// instead of being sliced out of range. Mismatched quotes are left in place
/// and escape sequences inside the quotes are not interpreted.
fn strip_quotes_local(s: &str) -> String {
    let s = s.trim();
    ['"', '\'']
        .iter()
        .find_map(|&quote| s.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(s)
        .to_string()
}
393
#[cfg(test)]
mod tests {
    use super::*;

    /// Declares one numeric attribute per name, in order.
    fn numeric_attrs(names: &[&str]) -> Vec<(String, AttributeType)> {
        names
            .iter()
            .map(|name| (name.to_string(), AttributeType::Numeric))
            .collect()
    }

    /// Builds an instance by applying `set` to each `(index, value)` entry in order.
    fn instance_of(entries: Vec<(usize, ArffValue)>) -> SparseInstance {
        let mut inst = SparseInstance::new();
        for (index, value) in entries {
            inst.set(index, value);
        }
        inst
    }

    /// Creates (best effort) a scratch directory under the system temp dir and
    /// returns it together with the path of `file_name` inside it.
    fn scratch(dir_name: &str, file_name: &str) -> (std::path::PathBuf, std::path::PathBuf) {
        let dir = std::env::temp_dir().join(dir_name);
        let _ = std::fs::create_dir_all(&dir);
        let path = dir.join(file_name);
        (dir, path)
    }

    /// Writing and re-reading keeps row/attribute counts and the stored entries.
    #[test]
    fn test_sparse_arff_roundtrip() {
        let (dir, path) = scratch("scirs2_arff_sparse_rt", "sparse.arff");

        let mut data = SparseArffData::new("sparse_test", numeric_attrs(&["x", "y", "z", "w"]));
        data.add_instance(instance_of(vec![
            (0, ArffValue::Numeric(1.0)),
            (3, ArffValue::Numeric(4.0)),
        ]));
        data.add_instance(instance_of(vec![(1, ArffValue::Numeric(2.5))]));
        // A row with no stored entries at all.
        data.add_instance(SparseInstance::new());

        write_sparse_arff(&path, &data).expect("Write failed");
        let loaded = read_sparse_arff(&path).expect("Read failed");

        assert_eq!(loaded.num_instances(), 3);
        assert_eq!(loaded.num_attributes(), 4);

        let first = &loaded.instances[0];
        assert_eq!(first.get(0), Some(&ArffValue::Numeric(1.0)));
        assert_eq!(first.get(1), None);
        assert_eq!(first.get(3), Some(&ArffValue::Numeric(4.0)));

        assert_eq!(loaded.instances[2].nnz(), 0);

        let _ = std::fs::remove_dir_all(&dir);
    }

    /// Densifying fills entries that are not stored with the numeric default.
    #[test]
    fn test_sparse_to_dense_conversion() {
        let mut data = SparseArffData::new("test", numeric_attrs(&["a", "b"]));
        data.add_instance(instance_of(vec![(0, ArffValue::Numeric(5.0))]));

        let dense = data.to_dense();
        assert_eq!(dense.data[[0, 0]], ArffValue::Numeric(5.0));
        assert_eq!(dense.data[[0, 1]], ArffValue::Numeric(0.0));
    }

    /// Sparsifying drops numeric zeros and keeps everything else.
    #[test]
    fn test_dense_to_sparse_conversion() {
        use scirs2_core::ndarray::Array2;

        let cells = vec![
            ArffValue::Numeric(1.0),
            ArffValue::Numeric(0.0),
            ArffValue::Numeric(3.0),
            ArffValue::Numeric(0.0),
            ArffValue::Numeric(0.0),
            ArffValue::Numeric(0.0),
        ];
        let dense = super::super::ArffData {
            relation: "test".to_string(),
            attributes: numeric_attrs(&["a", "b", "c"]),
            data: Array2::from_shape_vec((2, 3), cells).expect("Array creation failed"),
        };

        let sparse = SparseArffData::from_dense(&dense);
        assert_eq!(sparse.instances[0].nnz(), 2);
        assert_eq!(sparse.instances[1].nnz(), 0);
    }

    /// Ten rows with one stored entry out of four cells each give 75% sparsity.
    #[test]
    fn test_sparsity_calculation() {
        let mut data = SparseArffData::new("test", numeric_attrs(&["a", "b", "c", "d"]));

        for i in 0..10 {
            data.add_instance(instance_of(vec![(i % 4, ArffValue::Numeric(1.0))]));
        }

        let sparsity = data.sparsity();
        assert!((sparsity - 0.75).abs() < 1e-10);
    }

    /// Nominal values survive a write/read round trip alongside numeric ones.
    #[test]
    fn test_sparse_with_nominal() {
        let (dir, path) = scratch("scirs2_arff_sparse_nom", "sparse_nominal.arff");

        let attributes = vec![
            ("x".to_string(), AttributeType::Numeric),
            (
                "class".to_string(),
                AttributeType::Nominal(vec!["a".to_string(), "b".to_string()]),
            ),
        ];
        let mut data = SparseArffData::new("nominal_test", attributes);
        data.add_instance(instance_of(vec![
            (0, ArffValue::Numeric(42.0)),
            (1, ArffValue::Nominal("a".to_string())),
        ]));

        write_sparse_arff(&path, &data).expect("Write failed");
        let loaded = read_sparse_arff(&path).expect("Read failed");

        let first = &loaded.instances[0];
        assert_eq!(first.get(0), Some(&ArffValue::Numeric(42.0)));
        assert_eq!(first.get(1), Some(&ArffValue::Nominal("a".to_string())));

        let _ = std::fs::remove_dir_all(&dir);
    }

    /// A 100-attribute dataset with at most three entries per row round-trips
    /// and stays more than 90% sparse.
    #[test]
    fn test_sparse_high_dimensional() {
        let (dir, path) = scratch("scirs2_arff_sparse_hd", "high_dim.arff");

        let attrs: Vec<(String, AttributeType)> = (0..100)
            .map(|i| (format!("feat_{}", i), AttributeType::Numeric))
            .collect();
        let mut data = SparseArffData::new("high_dim", attrs);

        for i in 0..50 {
            // Later entries overwrite earlier ones when indices coincide.
            data.add_instance(instance_of(vec![
                (i % 100, ArffValue::Numeric(1.0)),
                ((i * 7) % 100, ArffValue::Numeric(2.0)),
                ((i * 13) % 100, ArffValue::Numeric(3.0)),
            ]));
        }

        write_sparse_arff(&path, &data).expect("Write failed");
        let loaded = read_sparse_arff(&path).expect("Read failed");

        assert_eq!(loaded.num_instances(), 50);
        assert_eq!(loaded.num_attributes(), 100);

        assert!(loaded.sparsity() > 0.9);

        let _ = std::fs::remove_dir_all(&dir);
    }
}