Skip to main content

vector_ta/utilities/
data_loader.rs

1extern crate csv;
2extern crate serde;
3
4use csv::ReaderBuilder;
5use std::error::Error;
6use std::fs::File;
7
/// Records which OHLCV columns were actually present in the source data.
///
/// A `false` flag means the corresponding series in [`Candles`] was not
/// supplied by the input; the CSV loader in this module fills such series
/// with `f64::NAN`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct CandleFieldFlags {
    pub open: bool,
    pub high: bool,
    pub low: bool,
    pub close: bool,
    pub volume: bool,
}

/// Column-oriented candle series plus precomputed composite prices.
///
/// The composite series (`hl2`, `hlc3`, `ohlc4`, `hlcc4`) are computed once by
/// [`Candles::new`] / [`Candles::new_with_fields`]. Because every field is
/// `pub`, mutating a raw series afterwards does NOT refresh the composites.
#[derive(Debug, Clone)]
pub struct Candles {
    /// Candle timestamps, stored as provided.
    pub timestamp: Vec<i64>,
    pub open: Vec<f64>,
    pub high: Vec<f64>,
    pub low: Vec<f64>,
    pub close: Vec<f64>,
    pub volume: Vec<f64>,
    /// Which raw columns were present in the input.
    pub fields: CandleFieldFlags,
    /// `(high + low) / 2` per candle.
    pub hl2: Vec<f64>,
    /// `(high + low + close) / 3` per candle.
    pub hlc3: Vec<f64>,
    /// `(open + high + low + close) / 4` per candle.
    pub ohlc4: Vec<f64>,
    /// `(high + low + 2 * close) / 4` per candle.
    pub hlcc4: Vec<f64>,
}

impl Candles {
    /// Builds a candle set with every OHLCV column marked as present.
    ///
    /// Equivalent to [`Candles::new_with_fields`] with all flags `true`.
    pub fn new(
        timestamp: Vec<i64>,
        open: Vec<f64>,
        high: Vec<f64>,
        low: Vec<f64>,
        close: Vec<f64>,
        volume: Vec<f64>,
    ) -> Self {
        Self::new_with_fields(
            timestamp,
            open,
            high,
            low,
            close,
            volume,
            CandleFieldFlags {
                open: true,
                high: true,
                low: true,
                close: true,
                volume: true,
            },
        )
    }

    /// Builds a candle set with explicit column-presence flags and
    /// precomputes the composite price series.
    ///
    /// Each composite series has the length of the shortest raw column it is
    /// derived from; mismatched input lengths do not panic.
    pub fn new_with_fields(
        timestamp: Vec<i64>,
        open: Vec<f64>,
        high: Vec<f64>,
        low: Vec<f64>,
        close: Vec<f64>,
        volume: Vec<f64>,
        fields: CandleFieldFlags,
    ) -> Self {
        let mut candles = Candles {
            timestamp,
            open,
            high,
            low,
            close,
            volume,
            fields,
            hl2: Vec::new(),
            hlc3: Vec::new(),
            ohlc4: Vec::new(),
            hlcc4: Vec::new(),
        };

        candles.precompute_fields();

        candles
    }

    /// Returns the timestamp series.
    ///
    /// # Errors
    /// Never fails; the `Result` is kept for API compatibility.
    pub fn get_timestamp(&self) -> Result<&[i64], Box<dyn Error>> {
        Ok(&self.timestamp)
    }

    /// `(high + low) / 2`, truncated to the shorter of `high` / `low`.
    fn compute_hl2(&self) -> Vec<f64> {
        self.high
            .iter()
            .zip(self.low.iter())
            .map(|(&h, &l)| (h + l) / 2.0)
            .collect()
    }

    /// `(high + low + close) / 3`, truncated to the shortest input.
    fn compute_hlc3(&self) -> Vec<f64> {
        self.high
            .iter()
            .zip(self.low.iter())
            .zip(self.close.iter())
            .map(|((&h, &l), &c)| (h + l + c) / 3.0)
            .collect()
    }

    /// `(open + high + low + close) / 4`, truncated to the shortest input.
    fn compute_ohlc4(&self) -> Vec<f64> {
        self.open
            .iter()
            .zip(self.high.iter())
            .zip(self.low.iter())
            .zip(self.close.iter())
            .map(|(((&o, &h), &l), &c)| (o + h + l + c) / 4.0)
            .collect()
    }

    /// `(high + low + 2 * close) / 4`, truncated to the shortest input.
    fn compute_hlcc4(&self) -> Vec<f64> {
        self.high
            .iter()
            .zip(self.low.iter())
            .zip(self.close.iter())
            .map(|((&h, &l), &c)| (h + l + 2.0 * c) / 4.0)
            .collect()
    }

    /// Returns a precomputed composite series by name (ASCII case-insensitive):
    /// `hl2`, `hlc3`, `ohlc4` or `hlcc4`.
    ///
    /// # Errors
    /// Returns an error for any other name.
    pub fn get_calculated_field(&self, field: &str) -> Result<&[f64], Box<dyn Error>> {
        let series: &[f64] = match field {
            f if f.eq_ignore_ascii_case("hl2") => &self.hl2,
            f if f.eq_ignore_ascii_case("hlc3") => &self.hlc3,
            f if f.eq_ignore_ascii_case("ohlc4") => &self.ohlc4,
            f if f.eq_ignore_ascii_case("hlcc4") => &self.hlcc4,
            _ => return Err(format!("Invalid calculated field: {}", field).into()),
        };
        Ok(series)
    }

    /// Returns a raw series by name (ASCII case-insensitive): `open`, `high`,
    /// `low`, `close` or `volume`.
    ///
    /// # Errors
    /// Returns an error for any other name.
    pub fn select_candle_field(&self, field: &str) -> Result<&[f64], Box<dyn Error>> {
        let series: &[f64] = match field {
            f if f.eq_ignore_ascii_case("open") => &self.open,
            f if f.eq_ignore_ascii_case("high") => &self.high,
            f if f.eq_ignore_ascii_case("low") => &self.low,
            f if f.eq_ignore_ascii_case("close") => &self.close,
            f if f.eq_ignore_ascii_case("volume") => &self.volume,
            _ => return Err(format!("Invalid field: {}", field).into()),
        };
        Ok(series)
    }

    /// (Re)computes all composite series from the current raw columns.
    ///
    /// Uses the `compute_*` helpers (zip-based), so no index can go out of
    /// bounds when raw columns differ in length.
    fn precompute_fields(&mut self) {
        self.hl2 = self.compute_hl2();
        self.hlc3 = self.compute_hlc3();
        self.ohlc4 = self.compute_ohlc4();
        self.hlcc4 = self.compute_hlcc4();
    }
}
180
181pub fn read_candles_from_csv(file_path: &str) -> Result<Candles, Box<dyn Error>> {
182    use std::io;
183
184    let file = File::open(file_path)?;
185    let mut rdr = ReaderBuilder::new().has_headers(true).from_reader(file);
186
187    let header_len = rdr.headers().map(|h| h.len()).unwrap_or(0);
188    if header_len < 2 {
189        return Err("CSV must have at least 2 columns: timestamp, close".into());
190    }
191
192    let (fields, idx_open, idx_close, idx_high, idx_low, idx_volume) = if header_len >= 3 {
193        (
194            CandleFieldFlags {
195                open: true,
196                close: true,
197                high: header_len > 3,
198                low: header_len > 4,
199                volume: header_len > 5,
200            },
201            Some(1usize),
202            2usize,
203            if header_len > 3 { Some(3usize) } else { None },
204            if header_len > 4 { Some(4usize) } else { None },
205            if header_len > 5 { Some(5usize) } else { None },
206        )
207    } else {
208        (
209            CandleFieldFlags {
210                open: false,
211                close: true,
212                high: false,
213                low: false,
214                volume: false,
215            },
216            None,
217            1usize,
218            None,
219            None,
220            None,
221        )
222    };
223
224    let mut timestamp = Vec::new();
225    let mut open = Vec::new();
226    let mut high = Vec::new();
227    let mut low = Vec::new();
228    let mut close = Vec::new();
229    let mut volume = Vec::new();
230
231    for result in rdr.records() {
232        let record = result?;
233
234        let ts: i64 = record
235            .get(0)
236            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing timestamp column"))?
237            .parse()?;
238        let c: f64 = record
239            .get(idx_close)
240            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing close column"))?
241            .parse()?;
242        timestamp.push(ts);
243        close.push(c);
244
245        let o: f64 = match idx_open {
246            Some(i) => record
247                .get(i)
248                .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing open column"))?
249                .parse()?,
250            None => f64::NAN,
251        };
252        open.push(o);
253
254        let h: f64 = match idx_high {
255            Some(i) => record
256                .get(i)
257                .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing high column"))?
258                .parse()?,
259            None => f64::NAN,
260        };
261        high.push(h);
262
263        let l: f64 = match idx_low {
264            Some(i) => record
265                .get(i)
266                .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing low column"))?
267                .parse()?,
268            None => f64::NAN,
269        };
270        low.push(l);
271
272        let v: f64 = match idx_volume {
273            Some(i) => record
274                .get(i)
275                .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing volume column"))?
276                .parse()?,
277            None => f64::NAN,
278        };
279        volume.push(v);
280    }
281
282    Ok(Candles::new_with_fields(
283        timestamp, open, high, low, close, volume, fields,
284    ))
285}
286
287pub fn source_type<'a>(candles: &'a Candles, source: &str) -> &'a [f64] {
288    if source.eq_ignore_ascii_case("open") {
289        &candles.open
290    } else if source.eq_ignore_ascii_case("high") {
291        &candles.high
292    } else if source.eq_ignore_ascii_case("low") {
293        &candles.low
294    } else if source.eq_ignore_ascii_case("close") {
295        &candles.close
296    } else if source.eq_ignore_ascii_case("volume") {
297        &candles.volume
298    } else if source.eq_ignore_ascii_case("hl2") {
299        &candles.hl2
300    } else if source.eq_ignore_ascii_case("hlc3") {
301        &candles.hlc3
302    } else if source.eq_ignore_ascii_case("ohlc4") {
303        &candles.ohlc4
304    } else if source.eq_ignore_ascii_case("hlcc4") || source.eq_ignore_ascii_case("hlcc") {
305        &candles.hlcc4
306    } else {
307        eprintln!("Warning: Invalid price source '{source}'. Defaulting to 'close'.");
308        &candles.close
309    }
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    /// Bundled candle data used by the file-based tests (relative to the crate root).
    const TEST_CSV: &str = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";

    /// Asserts `actual` has exactly `expected.len()` elements and matches it
    /// element-wise within `tol`. Unlike a bare `zip`, a short `actual` fails.
    fn assert_close(actual: &[f64], expected: &[f64], tol: f64, field_name: &str) {
        assert_eq!(actual.len(), expected.len(), "{field_name} length mismatch");
        for (i, (&a, &e)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (a - e).abs() < tol,
                "Mismatch in {field_name} at index {i}: expected {e}, got {a}"
            );
        }
    }

    /// Returns the trailing (at most) five elements of `series`.
    fn last_five(series: &[f64]) -> &[f64] {
        &series[series.len().saturating_sub(5)..]
    }

    /// Small in-memory candle set with hand-checkable composite values.
    fn sample_candles() -> Candles {
        Candles::new(
            vec![1, 2, 3],
            vec![100.0, 200.0, 300.0],
            vec![110.0, 220.0, 330.0],
            vec![90.0, 180.0, 270.0],
            vec![105.0, 190.0, 310.0],
            vec![1000.0, 2000.0, 3000.0],
        )
    }

    #[test]
    fn test_field_congruency() {
        let candles = read_candles_from_csv(TEST_CSV).expect("Failed to load CSV for testing");

        let len = candles.timestamp.len();
        assert_eq!(candles.open.len(), len, "Open length mismatch");
        assert_eq!(candles.high.len(), len, "High length mismatch");
        assert_eq!(candles.low.len(), len, "Low length mismatch");
        assert_eq!(candles.close.len(), len, "Close length mismatch");
        assert_eq!(candles.volume.len(), len, "Volume length mismatch");
    }

    #[test]
    fn test_calculated_fields_accuracy() {
        let candles = read_candles_from_csv(TEST_CSV).expect("Failed to load CSV for testing");

        let hl2 = candles.get_calculated_field("hl2").expect("Failed to get HL2");
        let hlc3 = candles.get_calculated_field("hlc3").expect("Failed to get HLC3");
        let ohlc4 = candles.get_calculated_field("ohlc4").expect("Failed to get OHLC4");
        let hlcc4 = candles.get_calculated_field("hlcc4").expect("Failed to get HLCC4");

        let len = candles.timestamp.len();
        assert_eq!(hl2.len(), len, "HL2 length mismatch");
        assert_eq!(hlc3.len(), len, "HLC3 length mismatch");
        assert_eq!(ohlc4.len(), len, "OHLC4 length mismatch");
        assert_eq!(hlcc4.len(), len, "HLCC4 length mismatch");

        let expected_last_5_hl2 = [59166.0, 59244.5, 59118.0, 59146.5, 58767.5];
        let expected_last_5_hlc3 = [59205.7, 59223.3, 59091.7, 59149.3, 58730.0];
        let expected_last_5_ohlc4 = [59221.8, 59238.8, 59114.3, 59121.8, 58836.3];
        let expected_last_5_hlcc4 = [59225.5, 59212.8, 59078.5, 59150.8, 58711.3];

        // Expected values are rounded to one decimal place, hence the 0.1 tolerance.
        assert_close(last_five(hl2), &expected_last_5_hl2, 1e-1, "HL2");
        assert_close(last_five(hlc3), &expected_last_5_hlc3, 1e-1, "HLC3");
        assert_close(last_five(ohlc4), &expected_last_5_ohlc4, 1e-1, "OHLC4");
        assert_close(last_five(hlcc4), &expected_last_5_hlcc4, 1e-1, "HLCC4");
    }

    #[test]
    fn test_precompute_fields_direct() {
        let candles = sample_candles();

        let hl2 = candles.get_calculated_field("hl2").unwrap();
        assert_eq!(hl2, &[100.0, 200.0, 300.0]);

        let hlc3 = candles.get_calculated_field("hlc3").unwrap();
        assert_close(hlc3, &[101.6667, 196.6667, 303.3333], 1e-4, "HLC3");

        let ohlc4 = candles.get_calculated_field("ohlc4").unwrap();
        assert_close(ohlc4, &[101.25, 197.5, 302.5], 1e-4, "OHLC4");

        let hlcc4 = candles.get_calculated_field("hlcc4").unwrap();
        assert_close(hlcc4, &[102.5, 195.0, 305.0], 1e-4, "HLCC4");
    }

    #[test]
    fn test_source_type_lookup() {
        let candles = sample_candles();

        assert_eq!(source_type(&candles, "HIGH"), &[110.0, 220.0, 330.0]);
        assert_eq!(source_type(&candles, "hlcc"), candles.hlcc4.as_slice());
        // Unknown names fall back to the close series.
        assert_eq!(source_type(&candles, "bogus"), candles.close.as_slice());
    }
}