1extern crate csv;
2extern crate serde;
3
4use csv::ReaderBuilder;
5use std::error::Error;
6use std::fs::File;
7
/// Records which raw OHLCV series were actually supplied for a [`Candles`] set.
///
/// A `false` flag means the source did not provide that series. The CSV loader
/// [`read_candles_from_csv`] fills such series with `f64::NAN` placeholders.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CandleFieldFlags {
    /// Whether opening prices were supplied.
    pub open: bool,
    /// Whether high prices were supplied.
    pub high: bool,
    /// Whether low prices were supplied.
    pub low: bool,
    /// Whether closing prices were supplied.
    pub close: bool,
    /// Whether traded volume was supplied.
    pub volume: bool,
}
16
/// A column-oriented set of candles: one `Vec` per OHLCV field plus derived
/// price series.
///
/// Index `i` of every series is meant to describe the same candle. The
/// constructors do not check that the vectors have equal lengths. The derived
/// series (`hl2`, `hlc3`, `ohlc4`, `hlcc4`) are filled in by the constructors.
/// They are not refreshed if the public raw vectors are mutated afterwards.
#[derive(Debug, Clone)]
pub struct Candles {
    /// Per-candle timestamps, taken verbatim from the source as integers.
    /// NOTE(review): the unit (e.g. epoch seconds vs. milliseconds) is not
    /// established here — confirm against the data source.
    pub timestamp: Vec<i64>,
    /// Opening prices. Holds `f64::NAN` placeholders when loaded from a CSV
    /// without an open column.
    pub open: Vec<f64>,
    /// High prices. Holds `f64::NAN` placeholders when loaded from a CSV
    /// without a high column.
    pub high: Vec<f64>,
    /// Low prices. Holds `f64::NAN` placeholders when loaded from a CSV
    /// without a low column.
    pub low: Vec<f64>,
    /// Closing prices; the CSV loader always requires this column.
    pub close: Vec<f64>,
    /// Traded volume. Holds `f64::NAN` placeholders when loaded from a CSV
    /// without a volume column.
    pub volume: Vec<f64>,
    /// Which raw series were actually supplied by the source.
    pub fields: CandleFieldFlags,
    /// `(high + low) / 2` per candle.
    pub hl2: Vec<f64>,
    /// `(high + low + close) / 3` per candle.
    pub hlc3: Vec<f64>,
    /// `(open + high + low + close) / 4` per candle.
    pub ohlc4: Vec<f64>,
    /// `(high + low + 2 * close) / 4` per candle.
    pub hlcc4: Vec<f64>,
}
31
32impl Candles {
33 pub fn new(
34 timestamp: Vec<i64>,
35 open: Vec<f64>,
36 high: Vec<f64>,
37 low: Vec<f64>,
38 close: Vec<f64>,
39 volume: Vec<f64>,
40 ) -> Self {
41 let mut candles = Candles {
42 timestamp,
43 open,
44 high,
45 low,
46 close,
47 volume,
48 fields: CandleFieldFlags {
49 open: true,
50 high: true,
51 low: true,
52 close: true,
53 volume: true,
54 },
55 hl2: Vec::new(),
56 hlc3: Vec::new(),
57 ohlc4: Vec::new(),
58 hlcc4: Vec::new(),
59 };
60
61 candles.precompute_fields();
62
63 candles
64 }
65
66 pub fn new_with_fields(
67 timestamp: Vec<i64>,
68 open: Vec<f64>,
69 high: Vec<f64>,
70 low: Vec<f64>,
71 close: Vec<f64>,
72 volume: Vec<f64>,
73 fields: CandleFieldFlags,
74 ) -> Self {
75 let mut candles = Candles {
76 timestamp,
77 open,
78 high,
79 low,
80 close,
81 volume,
82 fields,
83 hl2: Vec::new(),
84 hlc3: Vec::new(),
85 ohlc4: Vec::new(),
86 hlcc4: Vec::new(),
87 };
88
89 candles.precompute_fields();
90
91 candles
92 }
93
94 pub fn get_timestamp(&self) -> Result<&[i64], Box<dyn Error>> {
95 Ok(&self.timestamp)
96 }
97
98 fn compute_hl2(&self) -> Vec<f64> {
99 self.high
100 .iter()
101 .zip(self.low.iter())
102 .map(|(h, l)| (h + l) / 2.0)
103 .collect()
104 }
105
106 fn compute_hlc3(&self) -> Vec<f64> {
107 self.high
108 .iter()
109 .zip(self.low.iter())
110 .zip(self.close.iter())
111 .map(|((&h, &l), &c)| (h + l + c) / 3.0)
112 .collect()
113 }
114
115 fn compute_ohlc4(&self) -> Vec<f64> {
116 self.open
117 .iter()
118 .zip(self.high.iter())
119 .zip(self.low.iter())
120 .zip(self.close.iter())
121 .map(|(((&o, &h), &l), &c)| (o + h + l + c) / 4.0)
122 .collect()
123 }
124
125 fn compute_hlcc4(&self) -> Vec<f64> {
126 self.high
127 .iter()
128 .zip(self.low.iter())
129 .zip(self.close.iter())
130 .map(|((&h, &l), &c)| (h + l + 2.0 * c) / 4.0)
131 .collect()
132 }
133
134 pub fn get_calculated_field(&self, field: &str) -> Result<&[f64], Box<dyn std::error::Error>> {
135 match field.to_lowercase().as_str() {
136 "hl2" => Ok(&self.hl2),
137 "hlc3" => Ok(&self.hlc3),
138 "ohlc4" => Ok(&self.ohlc4),
139 "hlcc4" => Ok(&self.hlcc4),
140 _ => Err(format!("Invalid calculated field: {}", field).into()),
141 }
142 }
143
144 pub fn select_candle_field(&self, field: &str) -> Result<&[f64], Box<dyn std::error::Error>> {
145 match field.to_lowercase().as_str() {
146 "open" => Ok(&self.open),
147 "high" => Ok(&self.high),
148 "low" => Ok(&self.low),
149 "close" => Ok(&self.close),
150 "volume" => Ok(&self.volume),
151 _ => Err(format!("Invalid field: {}", field).into()),
152 }
153 }
154
155 fn precompute_fields(&mut self) {
156 let len = self.high.len();
157 let mut hl2 = Vec::with_capacity(len);
158 let mut hlc3 = Vec::with_capacity(len);
159 let mut ohlc4 = Vec::with_capacity(len);
160 let mut hlcc4 = Vec::with_capacity(len);
161
162 for i in 0..len {
163 let o = self.open[i];
164 let h = self.high[i];
165 let l = self.low[i];
166 let c = self.close[i];
167
168 hl2.push((h + l) / 2.0);
169 hlc3.push((h + l + c) / 3.0);
170 ohlc4.push((o + h + l + c) / 4.0);
171 hlcc4.push((h + l + 2.0 * c) / 4.0);
172 }
173
174 self.hl2 = hl2;
175 self.hlc3 = hlc3;
176 self.ohlc4 = ohlc4;
177 self.hlcc4 = hlcc4;
178 }
179}
180
181pub fn read_candles_from_csv(file_path: &str) -> Result<Candles, Box<dyn Error>> {
182 use std::io;
183
184 let file = File::open(file_path)?;
185 let mut rdr = ReaderBuilder::new().has_headers(true).from_reader(file);
186
187 let header_len = rdr.headers().map(|h| h.len()).unwrap_or(0);
188 if header_len < 2 {
189 return Err("CSV must have at least 2 columns: timestamp, close".into());
190 }
191
192 let (fields, idx_open, idx_close, idx_high, idx_low, idx_volume) = if header_len >= 3 {
193 (
194 CandleFieldFlags {
195 open: true,
196 close: true,
197 high: header_len > 3,
198 low: header_len > 4,
199 volume: header_len > 5,
200 },
201 Some(1usize),
202 2usize,
203 if header_len > 3 { Some(3usize) } else { None },
204 if header_len > 4 { Some(4usize) } else { None },
205 if header_len > 5 { Some(5usize) } else { None },
206 )
207 } else {
208 (
209 CandleFieldFlags {
210 open: false,
211 close: true,
212 high: false,
213 low: false,
214 volume: false,
215 },
216 None,
217 1usize,
218 None,
219 None,
220 None,
221 )
222 };
223
224 let mut timestamp = Vec::new();
225 let mut open = Vec::new();
226 let mut high = Vec::new();
227 let mut low = Vec::new();
228 let mut close = Vec::new();
229 let mut volume = Vec::new();
230
231 for result in rdr.records() {
232 let record = result?;
233
234 let ts: i64 = record
235 .get(0)
236 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing timestamp column"))?
237 .parse()?;
238 let c: f64 = record
239 .get(idx_close)
240 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing close column"))?
241 .parse()?;
242 timestamp.push(ts);
243 close.push(c);
244
245 let o: f64 = match idx_open {
246 Some(i) => record
247 .get(i)
248 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing open column"))?
249 .parse()?,
250 None => f64::NAN,
251 };
252 open.push(o);
253
254 let h: f64 = match idx_high {
255 Some(i) => record
256 .get(i)
257 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing high column"))?
258 .parse()?,
259 None => f64::NAN,
260 };
261 high.push(h);
262
263 let l: f64 = match idx_low {
264 Some(i) => record
265 .get(i)
266 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing low column"))?
267 .parse()?,
268 None => f64::NAN,
269 };
270 low.push(l);
271
272 let v: f64 = match idx_volume {
273 Some(i) => record
274 .get(i)
275 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing volume column"))?
276 .parse()?,
277 None => f64::NAN,
278 };
279 volume.push(v);
280 }
281
282 Ok(Candles::new_with_fields(
283 timestamp,
284 open,
285 high,
286 low,
287 close,
288 volume,
289 fields,
290 ))
291}
292
293pub fn source_type<'a>(candles: &'a Candles, source: &str) -> &'a [f64] {
294 if source.eq_ignore_ascii_case("open") {
295 &candles.open
296 } else if source.eq_ignore_ascii_case("high") {
297 &candles.high
298 } else if source.eq_ignore_ascii_case("low") {
299 &candles.low
300 } else if source.eq_ignore_ascii_case("close") {
301 &candles.close
302 } else if source.eq_ignore_ascii_case("volume") {
303 &candles.volume
304 } else if source.eq_ignore_ascii_case("hl2") {
305 &candles.hl2
306 } else if source.eq_ignore_ascii_case("hlc3") {
307 &candles.hlc3
308 } else if source.eq_ignore_ascii_case("ohlc4") {
309 &candles.ohlc4
310 } else if source.eq_ignore_ascii_case("hlcc4") || source.eq_ignore_ascii_case("hlcc") {
311 &candles.hlcc4
312 } else {
313 eprintln!("Warning: Invalid price source '{source}'. Defaulting to 'close'.");
314 &candles.close
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// Candle fixture shared by the CSV-backed tests.
    const DATA_PATH: &str = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";

    /// Asserts that `actual` and `expected` have equal length and agree element-wise
    /// within `tol`.
    ///
    /// The length check comes first so that a truncated or empty series cannot pass
    /// vacuously through `zip`.
    fn assert_series_close(actual: &[f64], expected: &[f64], tol: f64, name: &str) {
        assert_eq!(actual.len(), expected.len(), "{name} length mismatch");
        for (i, (&a, &e)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (a - e).abs() < tol,
                "Mismatch in {name} at index {i}: expected {e}, got {a}"
            );
        }
    }

    #[test]
    fn test_field_congruency() {
        let candles = read_candles_from_csv(DATA_PATH).expect("Failed to load CSV for testing");

        let len = candles.timestamp.len();
        assert_eq!(candles.open.len(), len, "Open length mismatch");
        assert_eq!(candles.high.len(), len, "High length mismatch");
        assert_eq!(candles.low.len(), len, "Low length mismatch");
        assert_eq!(candles.close.len(), len, "Close length mismatch");
        assert_eq!(candles.volume.len(), len, "Volume length mismatch");
    }

    #[test]
    fn test_calculated_fields_accuracy() {
        let candles = read_candles_from_csv(DATA_PATH).expect("Failed to load CSV for testing");

        let len = candles.timestamp.len();
        assert!(len >= 5, "fixture must contain at least 5 candles, got {len}");

        // Reference values for the last five candles, rounded to one decimal place.
        let cases: [(&str, [f64; 5]); 4] = [
            ("hl2", [59166.0, 59244.5, 59118.0, 59146.5, 58767.5]),
            ("hlc3", [59205.7, 59223.3, 59091.7, 59149.3, 58730.0]),
            ("ohlc4", [59221.8, 59238.8, 59114.3, 59121.8, 58836.3]),
            ("hlcc4", [59225.5, 59212.8, 59078.5, 59150.8, 58711.3]),
        ];

        for (field, expected) in cases.iter() {
            let series = candles
                .get_calculated_field(field)
                .unwrap_or_else(|e| panic!("Failed to get {field}: {e}"));
            assert_eq!(series.len(), len, "{field} length mismatch");
            assert_series_close(&series[len - 5..], expected, 1e-1, field);
        }
    }

    #[test]
    fn test_precompute_fields_direct() {
        let candles = Candles::new(
            vec![1, 2, 3],
            vec![100.0, 200.0, 300.0],
            vec![110.0, 220.0, 330.0],
            vec![90.0, 180.0, 270.0],
            vec![105.0, 190.0, 310.0],
            vec![1000.0, 2000.0, 3000.0],
        );

        assert_eq!(
            candles.get_calculated_field("hl2").unwrap(),
            &[100.0, 200.0, 300.0]
        );
        assert_series_close(
            candles.get_calculated_field("hlc3").unwrap(),
            &[101.6667, 196.6667, 303.3333],
            1e-4,
            "hlc3",
        );
        assert_series_close(
            candles.get_calculated_field("ohlc4").unwrap(),
            &[101.25, 197.5, 302.5],
            1e-4,
            "ohlc4",
        );
        assert_series_close(
            candles.get_calculated_field("hlcc4").unwrap(),
            &[102.5, 195.0, 305.0],
            1e-4,
            "hlcc4",
        );
    }

    #[test]
    fn test_field_lookup_names() {
        let candles = Candles::new(
            vec![1],
            vec![1.0],
            vec![4.0],
            vec![2.0],
            vec![3.0],
            vec![10.0],
        );

        // Raw and derived lookups are case-insensitive and do not overlap.
        assert_eq!(candles.select_candle_field("HIGH").unwrap(), &[4.0]);
        assert!(candles.select_candle_field("hl2").is_err());
        assert!(candles.get_calculated_field("close").is_err());
        assert_eq!(candles.get_calculated_field("HL2").unwrap(), &[3.0]);

        // `source_type` covers both namespaces, accepts the `hlcc` alias and
        // falls back to close for unknown names.
        assert_eq!(source_type(&candles, "Volume"), &[10.0]);
        assert_eq!(source_type(&candles, "hlcc"), candles.hlcc4.as_slice());
        assert_eq!(source_type(&candles, "unknown"), candles.close.as_slice());
    }
}