1use std::path::Path;
13
14use serde::Deserialize;
15
16use crate::error::{Error, Result};
17use wickra_core::Candle;
18
/// Column names that `validate_headers` requires to be present in the CSV
/// header row; a header missing any of them is rejected as malformed.
const REQUIRED_COLUMNS: [&str; 6] = ["timestamp", "open", "high", "low", "close", "volume"];
24
25#[derive(Debug, Clone, Deserialize)]
30pub struct DefaultRow {
31 pub timestamp: i64,
32 pub open: f64,
33 pub high: f64,
34 pub low: f64,
35 pub close: f64,
36 pub volume: f64,
37}
38
39impl DefaultRow {
40 fn into_candle(self) -> Result<Candle> {
41 Candle::new(
42 self.open,
43 self.high,
44 self.low,
45 self.close,
46 self.volume,
47 self.timestamp,
48 )
49 .map_err(Error::from)
50 }
51}
52
/// A [`std::io::Read`] adapter that drops a leading UTF-8 byte-order mark
/// (`EF BB BF`) from the wrapped stream and passes every other byte through
/// unchanged.
///
/// The first (up to) three bytes of the inner reader are probed lazily on the
/// first call to `read`. If they are not exactly the BOM they are replayed to
/// the caller before any further data from the inner reader.
#[derive(Debug)]
pub struct BomStripReader<R> {
    /// The wrapped byte source.
    inner: R,
    /// `true` once the BOM probe has run to completion.
    checked: bool,
    /// While probing: the bytes pulled from `inner` so far. After probing:
    /// the non-BOM bytes that still have to be handed to the caller (empty
    /// when a BOM was found and discarded).
    leftover: Vec<u8>,
    /// Index into `leftover` of the next byte to replay.
    leftover_pos: usize,
}

impl<R: std::io::Read> BomStripReader<R> {
    /// Wraps `inner`. Nothing is read from it until the first `read` call.
    pub fn new(inner: R) -> Self {
        Self {
            inner,
            checked: false,
            leftover: Vec::new(),
            leftover_pos: 0,
        }
    }

    /// Runs the BOM probe if it has not completed yet.
    ///
    /// Probe progress is kept in `self.leftover` rather than in a local
    /// buffer, and `checked` is only set once the probe has actually
    /// finished. An error from the inner reader therefore never discards
    /// bytes that were already consumed: the next call resumes where this
    /// one stopped, and a BOM split across the failure is still recognised.
    ///
    /// # Errors
    ///
    /// Propagates any inner read error other than
    /// [`std::io::ErrorKind::Interrupted`], which is retried.
    fn check_bom(&mut self) -> std::io::Result<()> {
        const UTF8_BOM: [u8; 3] = [0xEF, 0xBB, 0xBF];

        if self.checked {
            return Ok(());
        }

        while self.leftover.len() < UTF8_BOM.len() {
            let mut chunk = [0u8; 3];
            let wanted = UTF8_BOM.len() - self.leftover.len();
            match self.inner.read(&mut chunk[..wanted]) {
                // End of stream before three bytes arrived: decide with what we have.
                Ok(0) => break,
                Ok(n) => self.leftover.extend_from_slice(&chunk[..n]),
                // A retryable condition, not a failure; try the read again.
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                // Leave `checked` false and keep `leftover` so a later call can resume.
                Err(e) => return Err(e),
            }
        }

        self.checked = true;
        if self.leftover == UTF8_BOM {
            // A BOM: drop it. Anything else stays queued for replay.
            self.leftover.clear();
        }
        Ok(())
    }
}

impl<R: std::io::Read> std::io::Read for BomStripReader<R> {
    /// Reads into `buf`, first replaying any probed non-BOM bytes and then
    /// delegating directly to the inner reader.
    ///
    /// A call that replays probed bytes returns only those bytes (a short
    /// read), even if `buf` has room for more.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        self.check_bom()?;

        let pending = &self.leftover[self.leftover_pos..];
        if !pending.is_empty() {
            let n = pending.len().min(buf.len());
            buf[..n].copy_from_slice(&pending[..n]);
            self.leftover_pos += n;
            return Ok(n);
        }

        self.inner.read(buf)
    }
}
121
122fn validate_headers<R: std::io::Read>(reader: &mut csv::Reader<R>) -> Result<()> {
124 let headers = reader.headers()?;
125 let present: Vec<String> = headers.iter().map(|h| h.trim().to_string()).collect();
126 let missing: Vec<&str> = REQUIRED_COLUMNS
127 .iter()
128 .copied()
129 .filter(|col| !present.iter().any(|h| h == col))
130 .collect();
131 if !missing.is_empty() {
132 return Err(Error::Malformed(format!(
133 "CSV header is missing required column(s) [{}]; found [{}] — \
134 the first line must be a header naming {}",
135 missing.join(", "),
136 present.join(", "),
137 REQUIRED_COLUMNS.join(",")
138 )));
139 }
140 Ok(())
141}
142
143#[derive(Debug)]
145pub struct CandleReader<R: std::io::Read> {
146 reader: csv::Reader<R>,
147}
148
149impl<R: std::io::Read> CandleReader<R> {
150 fn build(inner: R) -> Result<Self> {
152 let mut reader = csv::ReaderBuilder::new()
153 .has_headers(true)
154 .trim(csv::Trim::All)
155 .from_reader(inner);
156 validate_headers(&mut reader)?;
157 Ok(Self { reader })
158 }
159}
160
161impl CandleReader<BomStripReader<std::fs::File>> {
162 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
172 let file = std::fs::File::open(path)?;
173 Self::from_reader(file)
174 }
175}
176
177impl<R: std::io::Read> CandleReader<BomStripReader<R>> {
178 pub fn from_reader(inner: R) -> Result<Self> {
186 Self::build(BomStripReader::new(inner))
187 }
188}
189
190impl<R: std::io::Read> CandleReader<R> {
191 pub fn from_csv_reader(mut reader: csv::Reader<R>) -> Result<Self> {
202 validate_headers(&mut reader)?;
203 Ok(Self { reader })
204 }
205
206 pub fn candles(&mut self) -> impl Iterator<Item = Result<Candle>> + '_ {
208 self.reader.deserialize::<DefaultRow>().map(|row_res| {
209 let row = row_res?;
210 row.into_candle()
211 })
212 }
213
214 pub fn read_all(&mut self) -> Result<Vec<Candle>> {
216 self.candles().collect()
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Writes each entry of `lines` followed by a newline into a fresh
    /// temporary file, flushes it, and returns the handle (the file lives as
    /// long as the handle does).
    fn temp_csv(lines: &[&str]) -> tempfile::NamedTempFile {
        let mut file = tempfile::NamedTempFile::new().unwrap();
        for line in lines {
            writeln!(file, "{line}").unwrap();
        }
        file.flush().unwrap();
        file
    }

    /// Three valid rows on disk come back as three candles with the expected values.
    #[test]
    fn reads_well_formed_csv() {
        let file = temp_csv(&[
            "timestamp,open,high,low,close,volume",
            "1,10.0,11.0,9.0,10.5,100",
            "2,10.5,11.5,10.0,11.0,150",
            "3,11.0,12.0,10.5,11.5,200",
        ]);

        let parsed = CandleReader::open(file.path())
            .unwrap()
            .read_all()
            .unwrap();
        assert_eq!(parsed.len(), 3);
        assert_eq!(parsed[0].open, 10.0);
        assert_eq!(parsed[2].close, 11.5);
        assert_eq!(parsed[1].timestamp, 2);
    }

    /// A row whose high (8.0) is below its low (9.0) makes iteration yield an error.
    #[test]
    fn rejects_invalid_ohlc() {
        let file = temp_csv(&[
            "timestamp,open,high,low,close,volume",
            "1,10.0,8.0,9.0,9.5,100",
        ]);

        let mut reader = CandleReader::open(file.path()).unwrap();
        let outcome = reader.candles().collect::<Result<Vec<Candle>>>();
        assert!(outcome.is_err());
    }

    /// `from_reader` accepts an in-memory byte slice.
    #[test]
    fn from_reader_works_on_in_memory_data() {
        let csv_text = "timestamp,open,high,low,close,volume\n1,1,2,0,1,10\n2,1,2,0,1,10\n";
        let parsed = CandleReader::from_reader(csv_text.as_bytes())
            .unwrap()
            .read_all()
            .unwrap();
        assert_eq!(parsed.len(), 2);
    }

    /// Data with no header line is rejected as malformed at construction.
    #[test]
    fn rejects_file_without_header() {
        let csv_text = "1,10.0,11.0,9.0,10.5,100\n2,10.5,11.5,10.0,11.0,150\n";
        let failure = CandleReader::from_reader(csv_text.as_bytes()).unwrap_err();
        let is_malformed = matches!(failure, Error::Malformed(_));
        assert!(is_malformed);
    }

    /// A header lacking `volume` is rejected and the message names the column.
    #[test]
    fn rejects_header_missing_a_column() {
        let csv_text = "timestamp,open,high,low,close\n1,10.0,11.0,9.0,10.5\n";
        let err = CandleReader::from_reader(csv_text.as_bytes()).unwrap_err();
        assert!(
            matches!(&err, Error::Malformed(msg) if msg.contains("volume")),
            "expected Malformed mentioning 'volume', got {err:?}"
        );
    }

    /// A caller-built reader with a `;` delimiter is accepted by `from_csv_reader`.
    #[test]
    fn from_csv_reader_accepts_a_prebuilt_reader() {
        let csv_text = "timestamp;open;high;low;close;volume\n1;10.0;11.0;9.0;10.5;100\n";
        let mut builder = csv::ReaderBuilder::new();
        builder.delimiter(b';');
        let prebuilt = builder.from_reader(csv_text.as_bytes());

        let mut reader = CandleReader::from_csv_reader(prebuilt).unwrap();
        let parsed = reader.read_all().unwrap();
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].close, 10.5);
    }

    /// A leading U+FEFF does not prevent the header from being recognised.
    #[test]
    fn strips_leading_utf8_bom() {
        let csv_text = "\u{feff}timestamp,open,high,low,close,volume\n1,10.0,11.0,9.0,10.5,100\n";
        let parsed = CandleReader::from_reader(csv_text.as_bytes())
            .unwrap()
            .read_all()
            .unwrap();
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].timestamp, 1);
        assert_eq!(parsed[0].open, 10.0);
    }

    /// Spaces around header names and values are tolerated.
    #[test]
    fn tolerates_whitespace_around_fields() {
        let csv_text = " timestamp , open , high , low , close , volume \n\
                        1 , 10.0 , 11.0 , 9.0 , 10.5 , 100 \n";
        let parsed = CandleReader::from_reader(csv_text.as_bytes())
            .unwrap()
            .read_all()
            .unwrap();
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].close, 10.5);
        assert_eq!(parsed[0].volume, 100.0);
    }

    /// Input that does not start with a BOM is returned unchanged.
    #[test]
    fn bom_stripper_passes_through_non_bom_input() {
        use std::io::Read;
        let mut stripper = BomStripReader::new("hello".as_bytes());
        let mut text = String::new();
        stripper.read_to_string(&mut text).unwrap();
        assert_eq!(text, "hello");
    }

    /// Input shorter than a BOM is returned unchanged.
    #[test]
    fn bom_stripper_handles_short_input() {
        use std::io::Read;
        let mut stripper = BomStripReader::new([0x41u8, 0x42u8].as_slice());
        let mut bytes = Vec::new();
        stripper.read_to_end(&mut bytes).unwrap();
        assert_eq!(bytes, vec![0x41, 0x42]);
    }
}