Skip to main content

yscv_model/dataset/
csv.rs

1use std::path::Path;
2
3use crate::ModelError;
4
5use super::helpers::{
6    adapter_sample_len, build_supervised_dataset_from_flat_values, load_dataset_text_file,
7    validate_adapter_sample_shape, validate_csv_delimiter, validate_finite_values,
8};
9
/// Configuration for parsing/loading supervised CSV datasets.
///
/// Created with `SupervisedCsvConfig::new` and refined with the `with_*` builder methods.
/// Every field type is `Eq`, so the config derives `Eq` as well as `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SupervisedCsvConfig {
    /// Shape of each sample's input; determines how many leading columns of a row are inputs.
    input_shape: Vec<usize>,
    /// Shape of each sample's target; determines how many trailing columns of a row are targets.
    target_shape: Vec<usize>,
    /// Character separating the fields of a row.
    delimiter: char,
    /// Whether the first non-blank, non-`#` line is a header row to skip.
    has_header: bool,
}
18
19impl SupervisedCsvConfig {
20    pub fn new(input_shape: Vec<usize>, target_shape: Vec<usize>) -> Result<Self, ModelError> {
21        validate_adapter_sample_shape("input_shape", &input_shape)?;
22        validate_adapter_sample_shape("target_shape", &target_shape)?;
23        Ok(Self {
24            input_shape,
25            target_shape,
26            delimiter: ',',
27            has_header: false,
28        })
29    }
30
31    pub fn with_delimiter(mut self, delimiter: char) -> Result<Self, ModelError> {
32        validate_csv_delimiter(delimiter)?;
33        self.delimiter = delimiter;
34        Ok(self)
35    }
36
37    pub fn with_header(mut self, has_header: bool) -> Self {
38        self.has_header = has_header;
39        self
40    }
41
42    pub fn input_shape(&self) -> &[usize] {
43        &self.input_shape
44    }
45
46    pub fn target_shape(&self) -> &[usize] {
47        &self.target_shape
48    }
49
50    pub fn delimiter(&self) -> char {
51        self.delimiter
52    }
53
54    pub fn has_header(&self) -> bool {
55        self.has_header
56    }
57}
58
59/// Parses supervised training samples from CSV text into a `SupervisedDataset`.
60///
61/// Each non-empty line is treated as one sample row with
62/// `input_len + target_len` numeric columns.
63/// Optional header row skipping is controlled by `config.has_header`.
64pub fn parse_supervised_dataset_csv(
65    content: &str,
66    config: &SupervisedCsvConfig,
67) -> Result<super::types::SupervisedDataset, ModelError> {
68    let input_row_len = adapter_sample_len("input_shape", config.input_shape())?;
69    let target_row_len = adapter_sample_len("target_shape", config.target_shape())?;
70    let expected_columns = input_row_len.checked_add(target_row_len).ok_or_else(|| {
71        ModelError::InvalidDatasetAdapterShape {
72            field: "row_columns",
73            shape: vec![input_row_len, target_row_len],
74            message: "column count overflow".to_string(),
75        }
76    })?;
77
78    let mut input_values = Vec::new();
79    let mut target_values = Vec::new();
80    let mut sample_count = 0usize;
81    let mut header_skipped = false;
82    let mut row_values = Vec::with_capacity(expected_columns);
83
84    for (line_idx, raw_line) in content.lines().enumerate() {
85        let line_number = line_idx + 1;
86        let line = raw_line.trim();
87        if line.is_empty() || line.starts_with('#') {
88            continue;
89        }
90        if config.has_header() && !header_skipped {
91            header_skipped = true;
92            continue;
93        }
94
95        let columns = line
96            .split(config.delimiter())
97            .map(str::trim)
98            .collect::<Vec<_>>();
99        if columns.len() != expected_columns {
100            return Err(ModelError::InvalidDatasetRecordColumns {
101                line: line_number,
102                expected: expected_columns,
103                got: columns.len(),
104            });
105        }
106
107        row_values.clear();
108        for (column_idx, value_str) in columns.iter().enumerate() {
109            let value = value_str
110                .parse::<f32>()
111                .map_err(|error| ModelError::DatasetCsvParse {
112                    line: line_number,
113                    column: column_idx + 1,
114                    message: error.to_string(),
115                })?;
116            row_values.push(value);
117        }
118
119        let (input_row, target_row) = row_values.split_at(input_row_len);
120        validate_finite_values(line_number, "input", input_row)?;
121        validate_finite_values(line_number, "target", target_row)?;
122        input_values.extend_from_slice(input_row);
123        target_values.extend_from_slice(target_row);
124
125        sample_count =
126            sample_count
127                .checked_add(1)
128                .ok_or_else(|| ModelError::InvalidDatasetAdapterShape {
129                    field: "sample_count",
130                    shape: vec![sample_count],
131                    message: "sample count overflow".to_string(),
132                })?;
133    }
134
135    if sample_count == 0 {
136        return Err(ModelError::EmptyDataset);
137    }
138
139    build_supervised_dataset_from_flat_values(
140        config.input_shape(),
141        config.target_shape(),
142        sample_count,
143        input_values,
144        target_values,
145    )
146}
147
148/// Loads supervised training samples from a CSV file.
149pub fn load_supervised_dataset_csv_file<P: AsRef<Path>>(
150    path: P,
151    config: &SupervisedCsvConfig,
152) -> Result<super::types::SupervisedDataset, ModelError> {
153    let content = load_dataset_text_file(path)?;
154    parse_supervised_dataset_csv(&content, config)
155}