ta_lib_in_rust/util/
file_utils.rs1use polars::prelude::*;
2use std::collections::HashMap;
3use std::fs::File;
4use std::path::Path;
5
/// Names of the columns in a loaded data frame that play each of the usual
/// financial roles (date plus OHLCV).
///
/// Every field holds the name of the matching column, or `None` when no
/// column was identified for that role. `FinancialColumns::default()` yields
/// a value with every role unassigned.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FinancialColumns {
    /// Name of the date/time column, if one was identified.
    pub date: Option<String>,
    /// Name of the opening-price column, if one was identified.
    pub open: Option<String>,
    /// Name of the high-price column, if one was identified.
    pub high: Option<String>,
    /// Name of the low-price column, if one was identified.
    pub low: Option<String>,
    /// Name of the closing-price column, if one was identified.
    pub close: Option<String>,
    /// Name of the volume column, if one was identified.
    pub volume: Option<String>,
}
16
17pub fn read_csv<P: AsRef<Path>>(
38 file_path: P,
39 has_header: bool,
40 delimiter: char,
41) -> PolarsResult<DataFrame> {
42 let file = File::open(file_path)?;
43
44 let csv_options = CsvReadOptions::default()
46 .with_has_header(has_header)
47 .map_parse_options(|opts| opts.with_separator(delimiter as u8));
48
49 CsvReader::new(file).with_options(csv_options).finish()
50}
51
52pub fn read_csv_default<P: AsRef<Path>>(file_path: P) -> PolarsResult<DataFrame> {
62 read_csv(file_path, true, ',')
63}
64
65pub fn read_parquet<P: AsRef<Path>>(file_path: P) -> PolarsResult<DataFrame> {
84 let file = File::open(file_path)?;
85 ParquetReader::new(file).finish()
86}
87
88pub fn read_financial_data<P: AsRef<Path>>(
114 file_path: P,
115 has_header: bool,
116 file_type: &str,
117 delimiter: char,
118) -> PolarsResult<(DataFrame, FinancialColumns)> {
119 let mut df = match file_type.to_lowercase().as_str() {
121 "csv" => read_csv(file_path, has_header, delimiter)?,
122 "parquet" => read_parquet(file_path)?,
123 _ => return Err(PolarsError::ComputeError("Unsupported file type".into())),
124 };
125
126 let columns = if has_header {
128 map_columns_with_headers(&df)?
129 } else {
130 rename_columns_without_headers(&mut df)?
132 };
133
134 Ok((df, columns))
135}
136
137fn map_columns_with_headers(df: &DataFrame) -> PolarsResult<FinancialColumns> {
139 let column_names: Vec<String> = df
140 .get_column_names()
141 .into_iter()
142 .map(|s| s.to_string())
143 .collect();
144
145 let mut financial_columns = FinancialColumns {
147 date: None,
148 open: None,
149 high: None,
150 low: None,
151 close: None,
152 volume: None,
153 };
154
155 let date_variations = ["date", "time", "datetime", "timestamp"];
157 let open_variations = ["open", "o", "opening"];
158 let high_variations = ["high", "h", "highest"];
159 let low_variations = ["low", "l", "lowest"];
160 let close_variations = ["close", "c", "closing"];
161 let volume_variations = ["volume", "vol", "v"];
162
163 for col in column_names {
164 let lower_col = col.to_lowercase();
165
166 if financial_columns.date.is_none()
167 && date_variations.iter().any(|&v| lower_col.contains(v))
168 {
169 financial_columns.date = Some(col.clone());
170 } else if financial_columns.open.is_none()
171 && open_variations.iter().any(|&v| lower_col.contains(v))
172 {
173 financial_columns.open = Some(col.clone());
174 } else if financial_columns.high.is_none()
175 && high_variations.iter().any(|&v| lower_col.contains(v))
176 {
177 financial_columns.high = Some(col.clone());
178 } else if financial_columns.low.is_none()
179 && low_variations.iter().any(|&v| lower_col.contains(v))
180 {
181 financial_columns.low = Some(col.clone());
182 } else if financial_columns.close.is_none()
183 && close_variations.iter().any(|&v| lower_col.contains(v))
184 {
185 financial_columns.close = Some(col.clone());
186 } else if financial_columns.volume.is_none()
187 && volume_variations.iter().any(|&v| lower_col.contains(v))
188 {
189 financial_columns.volume = Some(col.clone());
190 }
191 }
192
193 Ok(financial_columns)
194}
195
196fn rename_columns_without_headers(df: &mut DataFrame) -> PolarsResult<FinancialColumns> {
198 let n_cols = df.width();
199 let mut col_names = Vec::with_capacity(n_cols);
200
201 for i in 0..n_cols {
203 col_names.push(format!("col_{}", i));
204 }
205
206 df.set_column_names(&col_names)?;
208
209 let mut financial_columns = FinancialColumns {
211 date: None,
212 open: None,
213 high: None,
214 low: None,
215 close: None,
216 volume: None,
217 };
218
219 if n_cols >= 5 {
221 for (i, name) in col_names.iter().enumerate() {
223 if let Ok(series) = df.column(name) {
224 if i == 0
225 && (series.dtype() == &DataType::String
226 || series.dtype() == &DataType::Date
227 || matches!(series.dtype(), DataType::Datetime(_, _)))
228 {
229 financial_columns.date = Some(name.clone());
230 continue;
231 }
232
233 if Some(name.clone()) == financial_columns.date {
235 continue;
236 }
237
238 if series.dtype().is_primitive_numeric() {
240 if let Some(stats) = series.clone().cast(&DataType::Float64)?.f64()?.mean() {
242 if financial_columns.volume.is_none()
244 && (stats > 1000.0
245 || series.dtype() == &DataType::Int64
246 || series.dtype() == &DataType::UInt64)
247 {
248 financial_columns.volume = Some(name.clone());
249 continue;
250 }
251 }
252 }
253 }
254 }
255
256 let price_cols: Vec<String> = col_names
258 .iter()
259 .filter(|&name| {
260 Some(name.clone()) != financial_columns.date
261 && Some(name.clone()) != financial_columns.volume
262 })
263 .cloned()
264 .collect();
265
266 if price_cols.len() >= 4 {
268 financial_columns.open = Some(price_cols[0].clone());
269 financial_columns.high = Some(price_cols[1].clone());
270 financial_columns.low = Some(price_cols[2].clone());
271 financial_columns.close = Some(price_cols[3].clone());
272 }
273
274 if n_cols == 6 && financial_columns.date.is_some() && financial_columns.volume.is_none() {
276 financial_columns.volume = Some(col_names[5].clone());
278 }
279
280 if financial_columns.high.is_none() || financial_columns.low.is_none() {
282 identify_price_columns_by_statistics(df, &mut financial_columns, &price_cols)?;
283 }
284 }
285
286 Ok(financial_columns)
287}
288
289fn identify_price_columns_by_statistics(
291 df: &DataFrame,
292 financial_columns: &mut FinancialColumns,
293 price_cols: &[String],
294) -> PolarsResult<()> {
295 let mut col_stats: HashMap<String, (f64, f64)> = HashMap::new(); for col_name in price_cols {
299 if let Ok(series) = df.column(col_name) {
300 if series.dtype().is_primitive_numeric() {
301 let f64_series = series.clone().cast(&DataType::Float64)?;
302
303 if let Ok(f64_chunked) = f64_series.f64() {
305 let min_val = f64_chunked.min();
306 let max_val = f64_chunked.max();
307
308 if let (Some(min), Some(max)) = (min_val, max_val) {
309 col_stats.insert(col_name.clone(), (min, max));
310 }
311 }
312 }
313 }
314 }
315
316 let mut high_col = None;
318 let mut high_val = f64::MIN;
319 for (col, (_, max)) in &col_stats {
320 if *max > high_val {
321 high_val = *max;
322 high_col = Some(col.clone());
323 }
324 }
325
326 let mut low_col = None;
328 let mut low_val = f64::MAX;
329 for (col, (min, _)) in &col_stats {
330 if *min < low_val {
331 low_val = *min;
332 low_col = Some(col.clone());
333 }
334 }
335
336 if price_cols.len() >= 4 {
338 let remaining_cols: Vec<String> = price_cols
339 .iter()
340 .filter(|&col| Some(col.clone()) != high_col && Some(col.clone()) != low_col)
341 .cloned()
342 .collect();
343
344 if remaining_cols.len() >= 2 {
345 if financial_columns.open.is_none() {
346 financial_columns.open = Some(remaining_cols[0].clone());
347 }
348 if financial_columns.close.is_none() {
349 financial_columns.close = Some(remaining_cols[1].clone());
350 }
351 }
352 }
353
354 if financial_columns.high.is_none() && high_col.is_some() {
356 financial_columns.high = high_col;
357 }
358 if financial_columns.low.is_none() && low_col.is_some() {
359 financial_columns.low = low_col;
360 }
361
362 Ok(())
363}