// rustalib/util/file_utils.rs

1use polars::prelude::*;
2use std::fs::File;
3use std::io::BufRead;
4use std::path::Path;
5
/// Standardized mapping from a data file's actual column names to the
/// conventional OHLCV roles.
///
/// Each field holds the *source* column name identified for that role, or
/// `None` when no matching column was found. `Default` yields a fully
/// unmapped instance (all fields `None`).
#[derive(Debug, Clone, Default)]
pub struct FinancialColumns {
    /// Source column holding the date/time values, if identified.
    pub date: Option<String>,
    /// Source column holding opening prices, if identified.
    pub open: Option<String>,
    /// Source column holding high prices, if identified.
    pub high: Option<String>,
    /// Source column holding low prices, if identified.
    pub low: Option<String>,
    /// Source column holding closing prices, if identified.
    pub close: Option<String>,
    /// Source column holding traded volume, if identified.
    pub volume: Option<String>,
}
16
17/// Read a CSV file into a DataFrame
18///
19/// # Arguments
20///
21/// * `file_path` - Path to the CSV file
22/// * `has_header` - Whether the CSV file has a header row
23/// * `delimiter` - The delimiter character (default: ',')
24///
25/// # Returns
26///
27/// Returns a PolarsResult<DataFrame> containing the data if successful
28///
29/// # Example
30///
31/// ```
32/// use ta_lib_in_rust::util::file_utils::read_csv;
33///
34/// let df = read_csv("data/prices.csv", true, ',').unwrap();
35/// println!("{:?}", df);
36/// ```
37pub fn read_csv<P: AsRef<Path>>(
38    file_path: P,
39    has_header: bool,
40    delimiter: char,
41) -> PolarsResult<DataFrame> {
42    let file = File::open(file_path)?;
43
44    // Create CSV reader with options
45    let csv_options = CsvReadOptions::default()
46        .with_has_header(has_header)
47        .map_parse_options(|opts| opts.with_separator(delimiter as u8));
48
49    CsvReader::new(file).with_options(csv_options).finish()
50}
51
52/// Read a CSV file into a DataFrame with default settings (has header and comma delimiter)
53///
54/// # Arguments
55///
56/// * `file_path` - Path to the CSV file
57///
58/// # Returns
59///
60/// Returns a PolarsResult<DataFrame> containing the data if successful
61pub fn read_csv_default<P: AsRef<Path>>(file_path: P) -> PolarsResult<DataFrame> {
62    read_csv(file_path, true, ',')
63}
64
65/// Read a Parquet file into a DataFrame
66///
67/// # Arguments
68///
69/// * `file_path` - Path to the Parquet file
70///
71/// # Returns
72///
73/// Returns a PolarsResult<DataFrame> containing the data if successful
74///
75/// # Example
76///
77/// ```
78/// use ta_lib_in_rust::util::file_utils::read_parquet;
79///
80/// let df = read_parquet("data/prices.parquet").unwrap();
81/// println!("{:?}", df);
82/// ```
83pub fn read_parquet<P: AsRef<Path>>(file_path: P) -> PolarsResult<DataFrame> {
84    let file = File::open(file_path)?;
85    ParquetReader::new(file).finish()
86}
87
88/// Read a financial data file (CSV or Parquet) and standardize column names
89///
90/// This function automatically detects and handles various aspects of financial data files:
91/// - File type (CSV or Parquet) is detected from the file extension
92/// - For CSV files:
93///   - Automatically detects if the file has headers by checking for common financial column names
94///   - Tries multiple common delimiters (comma, semicolon, tab, pipe) until successful
95/// - For Parquet files:
96///   - Directly reads the file as Parquet format is self-describing
97///
98/// The function attempts to identify and standardize OHLCV (Open, High, Low, Close, Volume) columns
99/// whether or not the file has headers. It handles various common column name variations and
100/// automatically maps them to standardized names.
101///
102/// # Column Name Handling
103///
104/// - **Case-insensitive**: Column names are handled in a case-insensitive manner (DATE, Date, date all match)
105/// - **Abbreviations**: Common abbreviations are supported:
106///   - Date: "date", "dt", "time", "datetime", "timestamp"
107///   - Open: "open", "o", "opening"
108///   - High: "high", "h", "highest"
109///   - Low: "low", "l", "lowest"
110///   - Close: "close", "c", "closing"
111///   - Volume: "volume", "vol", "v"
112///
113/// # Arguments
114///
115/// * `file_path` - Path to the file (must have .csv or .parquet extension)
116///
117/// # Returns
118///
119/// A tuple with (DataFrame, FinancialColumns) where FinancialColumns contains the
120/// standardized column name mapping
121///
122/// # Example
123///
124/// ```
125/// use ta_lib_in_rust::util::file_utils::read_financial_data;
126///
127/// // Read a CSV file - automatically detects headers and delimiter
128/// let (df, columns) = read_financial_data("data/prices.csv").unwrap();
129/// println!("Close column: {:?}", columns.close);
130///
131/// // Read a Parquet file
132/// let (df, columns) = read_financial_data("data/prices.parquet").unwrap();
133/// println!("Volume column: {:?}", columns.volume);
134/// ```
135///
136/// # Supported File Types
137///
138/// - CSV files (`.csv` extension)
139///   - Automatically detects headers
140///   - Supports multiple delimiters: comma (,), semicolon (;), tab (\t), pipe (|)
141/// - Parquet files (`.parquet` extension)
142///
143/// # Error Handling
144///
145/// The function will return an error if:
146/// - The file extension is not supported
147/// - The file cannot be read
148/// - No valid delimiter is found for CSV files
149/// - The file format is invalid
150pub fn read_financial_data<P: AsRef<Path>>(
151    file_path: P,
152) -> PolarsResult<(DataFrame, FinancialColumns)> {
153    let path = file_path.as_ref();
154
155    // Detect file type from extension
156    let file_type = path
157        .extension()
158        .and_then(|ext| ext.to_str())
159        .map(|ext| ext.to_lowercase())
160        .ok_or_else(|| {
161            PolarsError::ComputeError("Could not determine file type from extension".into())
162        })?;
163
164    // Read the data file
165    let df = match file_type.as_str() {
166        "csv" => {
167            // Try to detect if file has headers by reading first line
168            let file = File::open(path)?;
169            let mut reader = std::io::BufReader::new(file);
170            let mut first_line = String::new();
171            reader.read_line(&mut first_line)?;
172
173            // Check if first line looks like headers (contains common column names)
174            let has_header = ["date", "time", "open", "high", "low", "close", "volume"]
175                .iter()
176                .any(|&name| first_line.to_lowercase().contains(name));
177
178            // Try different delimiters
179            let delimiters = [',', ';', '\t', '|'];
180            let mut last_error = None;
181
182            for &delimiter in &delimiters {
183                match read_csv(path, has_header, delimiter) {
184                    Ok(df) => return process_dataframe(df, has_header),
185                    Err(e) => last_error = Some(e),
186                }
187            }
188
189            // If all delimiters failed, return the last error
190            Err(last_error.unwrap_or_else(|| {
191                PolarsError::ComputeError("Failed to read CSV with any common delimiter".into())
192            }))?
193        }
194        "parquet" => read_parquet(path)?,
195        _ => {
196            return Err(PolarsError::ComputeError(
197                format!("Unsupported file type: {}", file_type).into(),
198            ))
199        }
200    };
201
202    // Map the columns
203    let columns = map_columns_with_headers(&df)?;
204    Ok((df, columns))
205}
206
207/// Helper function to process the DataFrame and map columns
208fn process_dataframe(
209    mut df: DataFrame,
210    has_header: bool,
211) -> PolarsResult<(DataFrame, FinancialColumns)> {
212    let columns = if has_header {
213        map_columns_with_headers(&df)?
214    } else {
215        rename_columns_without_headers(&mut df)?
216    };
217    Ok((df, columns))
218}
219
220/// Maps column names for files with headers
221fn map_columns_with_headers(df: &DataFrame) -> PolarsResult<FinancialColumns> {
222    let column_names: Vec<String> = df
223        .get_column_names()
224        .into_iter()
225        .map(|s| s.to_string())
226        .collect();
227
228    // Create mappings of common financial column names
229    let mut financial_columns = FinancialColumns {
230        date: None,
231        open: None,
232        high: None,
233        low: None,
234        close: None,
235        volume: None,
236    };
237
238    // Common variations of column names
239    let date_variations = ["date", "time", "datetime", "timestamp", "dt"];
240    let open_variations = ["open", "o", "opening"];
241    let high_variations = ["high", "h", "highest"];
242    let low_variations = ["low", "l", "lowest"];
243    let close_variations = ["close", "c", "closing"];
244    let volume_variations = ["volume", "vol", "v"];
245
246    for col in column_names {
247        let lower_col = col.to_lowercase();
248
249        if financial_columns.date.is_none()
250            && date_variations.iter().any(|&v| lower_col.contains(v))
251        {
252            financial_columns.date = Some(col.clone());
253        } else if financial_columns.open.is_none()
254            && open_variations.iter().any(|&v| lower_col.contains(v))
255        {
256            financial_columns.open = Some(col.clone());
257        } else if financial_columns.high.is_none()
258            && high_variations.iter().any(|&v| lower_col.contains(v))
259        {
260            financial_columns.high = Some(col.clone());
261        } else if financial_columns.low.is_none()
262            && low_variations.iter().any(|&v| lower_col.contains(v))
263        {
264            financial_columns.low = Some(col.clone());
265        } else if financial_columns.close.is_none()
266            && close_variations.iter().any(|&v| lower_col.contains(v))
267        {
268            financial_columns.close = Some(col.clone());
269        } else if financial_columns.volume.is_none()
270            && volume_variations.iter().any(|&v| lower_col.contains(v))
271        {
272            financial_columns.volume = Some(col.clone());
273        }
274    }
275
276    Ok(financial_columns)
277}
278
/// For files without headers, rename columns and identify OHLCV columns.
///
/// Works purely from the data, in three passes:
/// 1. The first string/date/datetime-typed column becomes `date`.
/// 2. A numeric column whose mean is more than 100x the mean of the other
///    numeric columns, with standard deviation above 10% of its own mean,
///    becomes `volume`.
/// 3. The remaining numeric columns are ranked by value range (then by
///    standard deviation) and assigned `high`, `low`, `close`, `open` in
///    that order; anything beyond four becomes "unknown".
///
/// Columns never identified are named `unknown_{index}`. The DataFrame's
/// columns are renamed in place and the resulting role mapping is returned.
///
/// NOTE(review): passes 2 and 3 are statistical heuristics — e.g. a
/// wide-range column is not necessarily "high", and low-volume data may
/// fail the 100x test — so assignments should be validated against known
/// data before being relied on. TODO confirm thresholds with real datasets.
fn rename_columns_without_headers(df: &mut DataFrame) -> PolarsResult<FinancialColumns> {
    let n_cols = df.width();
    // Pending new name per column index ("" = not yet assigned).
    let mut col_names = vec![String::new(); n_cols];
    // Tracks which columns have already been claimed by an earlier pass.
    let mut identified_cols = vec![false; n_cols];

    // Initialize financial columns structure
    let mut financial_columns = FinancialColumns {
        date: None,
        open: None,
        high: None,
        low: None,
        close: None,
        volume: None,
    };

    // First pass: identify date column (usually first column)
    // Any string column qualifies, not only parseable dates — assumes the
    // only text column in headerless OHLCV data is the timestamp.
    for i in 0..n_cols {
        if let Some(series) = df.select_at_idx(i) {
            if !identified_cols[i]
                && (series.dtype() == &DataType::String
                    || series.dtype() == &DataType::Date
                    || matches!(series.dtype(), DataType::Datetime(_, _)))
            {
                col_names[i] = "date".to_string();
                financial_columns.date = Some("date".to_string());
                identified_cols[i] = true;
                break; // Only identify one date column
            }
        }
    }

    // Second pass: identify volume column
    // Look for integer columns or columns with significantly larger values
    for (i, &identified) in identified_cols.iter().enumerate().take(n_cols) {
        if identified {
            continue;
        }

        if let Some(series) = df.select_at_idx(i) {
            if series.dtype().is_primitive_numeric() {
                // Cast to f64 so mean/std work uniformly regardless of the
                // source integer/float width.
                if let Ok(f64_series) = series.cast(&DataType::Float64) {
                    if let Ok(nums) = f64_series.f64() {
                        // Check if the column contains mostly large integers
                        let is_volume =
                            if let (Some(mean), Some(std_dev)) = (nums.mean(), nums.std(0)) {
                                // Volume typically has:
                                // 1. Much larger values than prices
                                // 2. Higher variance
                                // 3. Often contains round numbers
                                let other_cols_mean = get_numeric_columns_mean(df, i)?;
                                mean > other_cols_mean * 100.0 && std_dev > mean * 0.1
                            } else {
                                false
                            };

                        if is_volume {
                            col_names[i] = "volume".to_string();
                            financial_columns.volume = Some("volume".to_string());
                            identified_cols[i] = true;
                            break; // Only identify one volume column
                        }
                    }
                }
            }
        }
    }

    // Third pass: identify OHLC columns based on their statistical properties
    let mut price_stats: Vec<(usize, f64, f64, f64)> = Vec::new(); // (index, min, max, std_dev)

    for (i, &identified) in identified_cols.iter().enumerate().take(n_cols) {
        if identified {
            continue;
        }

        if let Some(series) = df.select_at_idx(i) {
            if series.dtype().is_primitive_numeric() {
                if let Ok(f64_series) = series.cast(&DataType::Float64) {
                    if let Ok(nums) = f64_series.f64() {
                        if let (Some(min), Some(max), Some(std)) =
                            (nums.min(), nums.max(), nums.std(0))
                        {
                            price_stats.push((i, min, max, std));
                        }
                    }
                }
            }
        }
    }

    // Sort by max values and standard deviation to identify columns
    // Descending by (max - min) range, ties broken by descending std_dev.
    price_stats.sort_by(|a, b| {
        let a_range = a.2 - a.1;
        let b_range = b.2 - b.1;
        b_range
            .partial_cmp(&a_range)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| b.3.partial_cmp(&a.3).unwrap_or(std::cmp::Ordering::Equal))
    });

    // Assign OHLC names based on statistical properties
    // NOTE(review): rank order high > low > close > open is a heuristic
    // choice, not a guarantee — verify on representative data.
    for (idx, stat) in price_stats.iter().enumerate() {
        let i = stat.0;
        if !identified_cols[i] {
            let col_name = match idx {
                0 => {
                    financial_columns.high = Some("high".to_string());
                    "high"
                }
                1 => {
                    financial_columns.low = Some("low".to_string());
                    "low"
                }
                2 => {
                    financial_columns.close = Some("close".to_string());
                    "close"
                }
                3 => {
                    financial_columns.open = Some("open".to_string());
                    "open"
                }
                _ => "unknown",
            };
            col_names[i] = col_name.to_string();
            identified_cols[i] = true;
        }
    }

    // Fill in any remaining unidentified columns
    for (i, name) in col_names.iter_mut().enumerate().take(n_cols) {
        if name.is_empty() {
            *name = format!("unknown_{}", i);
        }
    }

    // Rename the columns
    df.set_column_names(&col_names)?;

    Ok(financial_columns)
}
420
421/// Helper function to calculate mean of numeric columns excluding the specified column
422fn get_numeric_columns_mean(df: &DataFrame, exclude_idx: usize) -> PolarsResult<f64> {
423    let mut sum = 0.0;
424    let mut count = 0;
425
426    for i in 0..df.width() {
427        if i == exclude_idx {
428            continue;
429        }
430
431        if let Some(series) = df.select_at_idx(i) {
432            if series.dtype().is_primitive_numeric() {
433                if let Ok(f64_series) = series.cast(&DataType::Float64) {
434                    if let Ok(nums) = f64_series.f64() {
435                        if let Some(mean) = nums.mean() {
436                            sum += mean;
437                            count += 1;
438                        }
439                    }
440                }
441            }
442        }
443    }
444
445    if count > 0 {
446        Ok(sum / count as f64)
447    } else {
448        Ok(0.0)
449    }
450}