ta_lib_in_rust/util/
file_utils.rs

1use polars::prelude::*;
2use std::collections::HashMap;
3use std::fs::File;
4use std::path::Path;
5
/// Standardized OHLCV column mapping for a financial [`DataFrame`].
///
/// Each field holds the name of the `DataFrame` column identified as that
/// role, or `None` when no column could be matched. `Default` yields a mapping
/// with every role unassigned.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FinancialColumns {
    /// Date / timestamp column.
    pub date: Option<String>,
    /// Opening-price column.
    pub open: Option<String>,
    /// High-price column.
    pub high: Option<String>,
    /// Low-price column.
    pub low: Option<String>,
    /// Closing-price column.
    pub close: Option<String>,
    /// Traded-volume column.
    pub volume: Option<String>,
}
16
17/// Read a CSV file into a DataFrame
18///
19/// # Arguments
20///
21/// * `file_path` - Path to the CSV file
22/// * `has_header` - Whether the CSV file has a header row
23/// * `delimiter` - The delimiter character (default: ',')
24///
25/// # Returns
26///
27/// Returns a PolarsResult<DataFrame> containing the data if successful
28///
29/// # Example
30///
31/// ```
32/// use ta_lib_in_rust::util::file_utils::read_csv;
33///
34/// let df = read_csv("data/prices.csv", true, ',').unwrap();
35/// println!("{:?}", df);
36/// ```
37pub fn read_csv<P: AsRef<Path>>(
38    file_path: P,
39    has_header: bool,
40    delimiter: char,
41) -> PolarsResult<DataFrame> {
42    let file = File::open(file_path)?;
43
44    // Create CSV reader with options
45    let csv_options = CsvReadOptions::default()
46        .with_has_header(has_header)
47        .map_parse_options(|opts| opts.with_separator(delimiter as u8));
48
49    CsvReader::new(file).with_options(csv_options).finish()
50}
51
52/// Read a CSV file into a DataFrame with default settings (has header and comma delimiter)
53///
54/// # Arguments
55///
56/// * `file_path` - Path to the CSV file
57///
58/// # Returns
59///
60/// Returns a PolarsResult<DataFrame> containing the data if successful
61pub fn read_csv_default<P: AsRef<Path>>(file_path: P) -> PolarsResult<DataFrame> {
62    read_csv(file_path, true, ',')
63}
64
65/// Read a Parquet file into a DataFrame
66///
67/// # Arguments
68///
69/// * `file_path` - Path to the Parquet file
70///
71/// # Returns
72///
73/// Returns a PolarsResult<DataFrame> containing the data if successful
74///
75/// # Example
76///
77/// ```
78/// use ta_lib_in_rust::util::file_utils::read_parquet;
79///
80/// let df = read_parquet("data/prices.parquet").unwrap();
81/// println!("{:?}", df);
82/// ```
83pub fn read_parquet<P: AsRef<Path>>(file_path: P) -> PolarsResult<DataFrame> {
84    let file = File::open(file_path)?;
85    ParquetReader::new(file).finish()
86}
87
88/// Read a financial data file (CSV or Parquet) and standardize column names
89///
90/// This function attempts to identify and standardize OHLCV columns whether or not
91/// the file has headers.
92///
93/// # Arguments
94///
95/// * `file_path` - Path to the file
96/// * `has_header` - Whether the file has headers
97/// * `file_type` - "csv" or "parquet"
98/// * `delimiter` - Delimiter for CSV files (default: ',')
99///
100/// # Returns
101///
102/// A tuple with (DataFrame, FinancialColumns) where FinancialColumns contains the
103/// standardized column name mapping
104///
105/// # Example
106///
107/// ```
108/// use ta_lib_in_rust::util::file_utils::read_financial_data;
109///
110/// let (df, columns) = read_financial_data("data/prices.csv", true, "csv", ',').unwrap();
111/// println!("Close column: {:?}", columns.close);
112/// ```
113pub fn read_financial_data<P: AsRef<Path>>(
114    file_path: P,
115    has_header: bool,
116    file_type: &str,
117    delimiter: char,
118) -> PolarsResult<(DataFrame, FinancialColumns)> {
119    // Read the data file
120    let mut df = match file_type.to_lowercase().as_str() {
121        "csv" => read_csv(file_path, has_header, delimiter)?,
122        "parquet" => read_parquet(file_path)?,
123        _ => return Err(PolarsError::ComputeError("Unsupported file type".into())),
124    };
125
126    // Map the columns
127    let columns = if has_header {
128        map_columns_with_headers(&df)?
129    } else {
130        // For files without headers, generate column names and then map them
131        rename_columns_without_headers(&mut df)?
132    };
133
134    Ok((df, columns))
135}
136
137/// Maps column names for files with headers
138fn map_columns_with_headers(df: &DataFrame) -> PolarsResult<FinancialColumns> {
139    let column_names: Vec<String> = df
140        .get_column_names()
141        .into_iter()
142        .map(|s| s.to_string())
143        .collect();
144
145    // Create mappings of common financial column names
146    let mut financial_columns = FinancialColumns {
147        date: None,
148        open: None,
149        high: None,
150        low: None,
151        close: None,
152        volume: None,
153    };
154
155    // Common variations of column names
156    let date_variations = ["date", "time", "datetime", "timestamp"];
157    let open_variations = ["open", "o", "opening"];
158    let high_variations = ["high", "h", "highest"];
159    let low_variations = ["low", "l", "lowest"];
160    let close_variations = ["close", "c", "closing"];
161    let volume_variations = ["volume", "vol", "v"];
162
163    for col in column_names {
164        let lower_col = col.to_lowercase();
165
166        if financial_columns.date.is_none()
167            && date_variations.iter().any(|&v| lower_col.contains(v))
168        {
169            financial_columns.date = Some(col.clone());
170        } else if financial_columns.open.is_none()
171            && open_variations.iter().any(|&v| lower_col.contains(v))
172        {
173            financial_columns.open = Some(col.clone());
174        } else if financial_columns.high.is_none()
175            && high_variations.iter().any(|&v| lower_col.contains(v))
176        {
177            financial_columns.high = Some(col.clone());
178        } else if financial_columns.low.is_none()
179            && low_variations.iter().any(|&v| lower_col.contains(v))
180        {
181            financial_columns.low = Some(col.clone());
182        } else if financial_columns.close.is_none()
183            && close_variations.iter().any(|&v| lower_col.contains(v))
184        {
185            financial_columns.close = Some(col.clone());
186        } else if financial_columns.volume.is_none()
187            && volume_variations.iter().any(|&v| lower_col.contains(v))
188        {
189            financial_columns.volume = Some(col.clone());
190        }
191    }
192
193    Ok(financial_columns)
194}
195
196/// For files without headers, rename columns and identify OHLCV columns
197fn rename_columns_without_headers(df: &mut DataFrame) -> PolarsResult<FinancialColumns> {
198    let n_cols = df.width();
199    let mut col_names = Vec::with_capacity(n_cols);
200
201    // Basic column renaming
202    for i in 0..n_cols {
203        col_names.push(format!("col_{}", i));
204    }
205
206    // Rename the columns from col_0, col_1, etc.
207    df.set_column_names(&col_names)?;
208
209    // Try to identify financial columns through data patterns
210    let mut financial_columns = FinancialColumns {
211        date: None,
212        open: None,
213        high: None,
214        low: None,
215        close: None,
216        volume: None,
217    };
218
219    // If 5+ columns, assume typical OHLCV structure (date, open, high, low, close, volume)
220    if n_cols >= 5 {
221        // Check each column to see if it might contain date information
222        for (i, name) in col_names.iter().enumerate() {
223            if let Ok(series) = df.column(name) {
224                if i == 0
225                    && (series.dtype() == &DataType::String
226                        || series.dtype() == &DataType::Date
227                        || matches!(series.dtype(), DataType::Datetime(_, _)))
228                {
229                    financial_columns.date = Some(name.clone());
230                    continue;
231                }
232
233                // Skip if we've identified this as a date column
234                if Some(name.clone()) == financial_columns.date {
235                    continue;
236                }
237
238                // Now analyze numerical columns
239                if series.dtype().is_primitive_numeric() {
240                    // Get basic stats for this column
241                    if let Some(stats) = series.clone().cast(&DataType::Float64)?.f64()?.mean() {
242                        // Volume is typically much larger than price and often close to integers
243                        if financial_columns.volume.is_none()
244                            && (stats > 1000.0
245                                || series.dtype() == &DataType::Int64
246                                || series.dtype() == &DataType::UInt64)
247                        {
248                            financial_columns.volume = Some(name.clone());
249                            continue;
250                        }
251                    }
252                }
253            }
254        }
255
256        // Now identify remaining columns if we have at least 4 price columns
257        let price_cols: Vec<String> = col_names
258            .iter()
259            .filter(|&name| {
260                Some(name.clone()) != financial_columns.date
261                    && Some(name.clone()) != financial_columns.volume
262            })
263            .cloned()
264            .collect();
265
266        // Simple heuristic mapping
267        if price_cols.len() >= 4 {
268            financial_columns.open = Some(price_cols[0].clone());
269            financial_columns.high = Some(price_cols[1].clone());
270            financial_columns.low = Some(price_cols[2].clone());
271            financial_columns.close = Some(price_cols[3].clone());
272        }
273
274        // For typical 6-column format (date, open, high, low, close, volume)
275        if n_cols == 6 && financial_columns.date.is_some() && financial_columns.volume.is_none() {
276            // Last column is likely volume if not identified
277            financial_columns.volume = Some(col_names[5].clone());
278        }
279
280        // If we still haven't identified the price columns, try to use statistics
281        if financial_columns.high.is_none() || financial_columns.low.is_none() {
282            identify_price_columns_by_statistics(df, &mut financial_columns, &price_cols)?;
283        }
284    }
285
286    Ok(financial_columns)
287}
288
289/// Use statistical properties to identify high, low, open, close columns
290fn identify_price_columns_by_statistics(
291    df: &DataFrame,
292    financial_columns: &mut FinancialColumns,
293    price_cols: &[String],
294) -> PolarsResult<()> {
295    // Track min and max values by column
296    let mut col_stats: HashMap<String, (f64, f64)> = HashMap::new(); // (min, max)
297
298    for col_name in price_cols {
299        if let Ok(series) = df.column(col_name) {
300            if series.dtype().is_primitive_numeric() {
301                let f64_series = series.clone().cast(&DataType::Float64)?;
302
303                // Use proper Series methods with f64() to get ChunkedArray<Float64Type>
304                if let Ok(f64_chunked) = f64_series.f64() {
305                    let min_val = f64_chunked.min();
306                    let max_val = f64_chunked.max();
307
308                    if let (Some(min), Some(max)) = (min_val, max_val) {
309                        col_stats.insert(col_name.clone(), (min, max));
310                    }
311                }
312            }
313        }
314    }
315
316    // Find column with highest max values (likely high)
317    let mut high_col = None;
318    let mut high_val = f64::MIN;
319    for (col, (_, max)) in &col_stats {
320        if *max > high_val {
321            high_val = *max;
322            high_col = Some(col.clone());
323        }
324    }
325
326    // Find column with lowest min values (likely low)
327    let mut low_col = None;
328    let mut low_val = f64::MAX;
329    for (col, (min, _)) in &col_stats {
330        if *min < low_val {
331            low_val = *min;
332            low_col = Some(col.clone());
333        }
334    }
335
336    // Assign remaining columns to open and close if they haven't been assigned
337    if price_cols.len() >= 4 {
338        let remaining_cols: Vec<String> = price_cols
339            .iter()
340            .filter(|&col| Some(col.clone()) != high_col && Some(col.clone()) != low_col)
341            .cloned()
342            .collect();
343
344        if remaining_cols.len() >= 2 {
345            if financial_columns.open.is_none() {
346                financial_columns.open = Some(remaining_cols[0].clone());
347            }
348            if financial_columns.close.is_none() {
349                financial_columns.close = Some(remaining_cols[1].clone());
350            }
351        }
352    }
353
354    // Set high and low if they were identified and not already set
355    if financial_columns.high.is_none() && high_col.is_some() {
356        financial_columns.high = high_col;
357    }
358    if financial_columns.low.is_none() && low_col.is_some() {
359        financial_columns.low = low_col;
360    }
361
362    Ok(())
363}