1use std::collections::HashSet;
2use std::fs::File;
3use std::path::{Path, PathBuf};
4use std::str::FromStr;
5
6use anyhow::{anyhow, Context, Result};
7use arrow::array::{Array, Decimal128Array, StringArray, TimestampNanosecondArray};
8use arrow::record_batch::RecordBatch;
9use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc};
10use csv::{ReaderBuilder, WriterBuilder};
11use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
12use parquet::arrow::ArrowWriter;
13use parquet::basic::{Compression, ZstdLevel};
14use parquet::file::properties::WriterProperties;
15use rust_decimal::Decimal;
16
17use tesser_core::{Candle, Interval, Symbol, Tick};
18
19use crate::encoding::{candles_to_batch, ticks_to_batch};
20
/// On-disk encodings supported for candle datasets.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DatasetFormat {
    /// Comma-separated values with a header row.
    Csv,
    /// Apache Parquet (written with ZSTD compression).
    Parquet,
}

impl DatasetFormat {
    /// Infers the dataset format from the file extension.
    ///
    /// A case-insensitive `.parquet` extension selects [`Self::Parquet`];
    /// any other extension — or none at all — falls back to [`Self::Csv`].
    #[must_use]
    pub fn from_path(path: &Path) -> Self {
        let extension = path
            .extension()
            .and_then(|ext| ext.to_str())
            .map(str::to_ascii_lowercase);
        if extension.as_deref() == Some("parquet") {
            Self::Parquet
        } else {
            Self::Csv
        }
    }
}
43
/// A candle series loaded from disk, together with the format it was decoded from.
pub struct CandleDataset {
    /// Format detected from the file extension at load time.
    pub format: DatasetFormat,
    /// Candles sorted by ascending timestamp (readers sort before returning).
    pub candles: Vec<Candle>,
}
49
50pub fn read_dataset(path: &Path) -> Result<CandleDataset> {
52 let format = DatasetFormat::from_path(path);
53 let candles = match format {
54 DatasetFormat::Csv => read_csv(path)
55 .with_context(|| format!("failed to load CSV dataset {}", path.display()))?,
56 DatasetFormat::Parquet => read_parquet(path)
57 .with_context(|| format!("failed to load parquet dataset {}", path.display()))?,
58 };
59 Ok(CandleDataset { format, candles })
60}
61
62pub fn write_dataset(path: &Path, format: DatasetFormat, candles: &[Candle]) -> Result<()> {
64 match format {
65 DatasetFormat::Csv => write_csv(path, candles),
66 DatasetFormat::Parquet => write_parquet(path, candles),
67 }
68}
69
/// Buffers ticks in memory, de-duplicating by trade id, and writes them out
/// as a sorted parquet file on [`TicksWriter::finish`].
pub struct TicksWriter {
    // Destination parquet file; parent directories are created on finish.
    path: PathBuf,
    // Ticks accepted so far, in insertion order until `finish` sorts them.
    rows: Vec<Tick>,
    // Non-empty trade ids already seen; used to drop exchange-reported duplicates.
    seen_ids: HashSet<String>,
}
76
77impl TicksWriter {
78 pub fn new(path: impl Into<PathBuf>) -> Self {
80 Self {
81 path: path.into(),
82 rows: Vec::new(),
83 seen_ids: HashSet::new(),
84 }
85 }
86
87 pub fn len(&self) -> usize {
89 self.rows.len()
90 }
91
92 pub fn is_empty(&self) -> bool {
94 self.rows.is_empty()
95 }
96
97 pub fn push(&mut self, trade_id: Option<String>, tick: Tick) {
99 if let Some(id) = trade_id.filter(|value| !value.is_empty()) {
100 if !self.seen_ids.insert(id) {
101 return;
102 }
103 }
104 self.rows.push(tick);
105 }
106
107 pub fn extend<I>(&mut self, trades: I)
109 where
110 I: IntoIterator<Item = (Option<String>, Tick)>,
111 {
112 for (trade_id, tick) in trades {
113 self.push(trade_id, tick);
114 }
115 }
116
117 pub fn finish(mut self) -> Result<()> {
119 self.rows
120 .sort_by_key(|tick| tick.exchange_timestamp.timestamp_millis());
121 self.rows.dedup_by(|a, b| {
122 a.exchange_timestamp == b.exchange_timestamp
123 && a.price == b.price
124 && a.size == b.size
125 && a.side == b.side
126 });
127 write_ticks_parquet(&self.path, &self.rows)
128 }
129}
130
131fn read_csv(path: &Path) -> Result<Vec<Candle>> {
132 let mut reader = ReaderBuilder::new()
133 .flexible(true)
134 .from_path(path)
135 .with_context(|| format!("failed to open {}", path.display()))?;
136 let mut candles = Vec::new();
137 for row in reader.records() {
138 let record = row.with_context(|| format!("invalid row in {}", path.display()))?;
139 let timestamp = parse_timestamp(
140 record
141 .get(1)
142 .ok_or_else(|| anyhow!("missing timestamp column in {}", path.display()))?,
143 )?;
144 let symbol = match record.get(0) {
145 Some(value) if !value.trim().is_empty() => value.to_string(),
146 _ => infer_symbol(path).ok_or_else(|| {
147 anyhow!(
148 "missing symbol column and unable to infer from {}",
149 path.display()
150 )
151 })?,
152 };
153 let candle = Candle {
154 symbol: Symbol::from(symbol.as_str()),
155 interval: infer_interval(path).unwrap_or(Interval::OneMinute),
156 open: parse_decimal(record.get(2), "open", path)?,
157 high: parse_decimal(record.get(3), "high", path)?,
158 low: parse_decimal(record.get(4), "low", path)?,
159 close: parse_decimal(record.get(5), "close", path)?,
160 volume: parse_decimal(record.get(6), "volume", path)?,
161 timestamp,
162 };
163 candles.push(candle);
164 }
165 candles.sort_by_key(|c| c.timestamp);
166 Ok(candles)
167}
168
169fn read_parquet(path: &Path) -> Result<Vec<Candle>> {
170 let file = File::open(path)
171 .with_context(|| format!("failed to open parquet file {}", path.display()))?;
172 let reader = ParquetRecordBatchReaderBuilder::try_new(file)?
173 .with_batch_size(1024)
174 .build()?;
175 let mut columns: Option<CandleColumns> = None;
176 let mut candles = Vec::new();
177 for batch in reader {
178 let batch = batch?;
179 if columns.is_none() {
180 columns = Some(CandleColumns::from_batch(&batch)?);
181 }
182 let column_mapping = columns.as_ref().unwrap();
183 for row in 0..batch.num_rows() {
184 candles.push(column_mapping.decode(&batch, row)?);
185 }
186 }
187 candles.sort_by_key(|c| c.timestamp);
188 Ok(candles)
189}
190
191fn write_csv(path: &Path, candles: &[Candle]) -> Result<()> {
192 if let Some(parent) = path.parent() {
193 std::fs::create_dir_all(parent)
194 .with_context(|| format!("failed to create directory {}", parent.display()))?;
195 }
196 let mut writer = WriterBuilder::new()
197 .has_headers(true)
198 .from_path(path)
199 .with_context(|| format!("failed to create {}", path.display()))?;
200 writer.write_record([
201 "symbol",
202 "timestamp",
203 "open",
204 "high",
205 "low",
206 "close",
207 "volume",
208 ])?;
209 for candle in candles {
210 writer.write_record([
211 candle.symbol.code(),
212 &candle.timestamp.to_rfc3339(),
213 &candle.open.to_string(),
214 &candle.high.to_string(),
215 &candle.low.to_string(),
216 &candle.close.to_string(),
217 &candle.volume.to_string(),
218 ])?;
219 }
220 writer.flush()?;
221 Ok(())
222}
223
224fn write_parquet(path: &Path, candles: &[Candle]) -> Result<()> {
225 if let Some(parent) = path.parent() {
226 std::fs::create_dir_all(parent)
227 .with_context(|| format!("failed to create directory {}", parent.display()))?;
228 }
229 let batch = candles_to_batch(candles)?;
230 let file =
231 File::create(path).with_context(|| format!("failed to create {}", path.display()))?;
232 let props = WriterProperties::builder()
233 .set_compression(Compression::ZSTD(ZstdLevel::default()))
234 .build();
235 let mut writer = ArrowWriter::try_new(file, batch.schema(), Some(props))?;
236 writer.write(&batch)?;
237 writer.close()?;
238 Ok(())
239}
240
241fn write_ticks_parquet(path: &Path, ticks: &[Tick]) -> Result<()> {
242 if let Some(parent) = path.parent() {
243 std::fs::create_dir_all(parent)
244 .with_context(|| format!("failed to create directory {}", parent.display()))?;
245 }
246 let batch = ticks_to_batch(ticks)?;
247 let file =
248 File::create(path).with_context(|| format!("failed to create {}", path.display()))?;
249 let props = WriterProperties::builder()
250 .set_compression(Compression::ZSTD(ZstdLevel::default()))
251 .build();
252 let mut writer = ArrowWriter::try_new(file, batch.schema(), Some(props))?;
253 writer.write(&batch)?;
254 writer.close()?;
255 Ok(())
256}
257
258fn parse_decimal(value: Option<&str>, column: &str, path: &Path) -> Result<Decimal> {
259 let text = value.ok_or_else(|| anyhow!("missing {column} column in {}", path.display()))?;
260 Decimal::from_str(text)
261 .with_context(|| format!("invalid {column} value '{text}' in {}", path.display()))
262}
263
264fn parse_timestamp(value: &str) -> Result<DateTime<Utc>> {
265 if let Ok(ts) = DateTime::parse_from_rfc3339(value) {
266 return Ok(ts.with_timezone(&Utc));
267 }
268 if let Ok(dt) = NaiveDateTime::parse_from_str(value, "%Y-%m-%d %H:%M:%S") {
269 return Ok(DateTime::<Utc>::from_naive_utc_and_offset(dt, Utc));
270 }
271 if let Ok(date) = NaiveDate::parse_from_str(value, "%Y-%m-%d") {
272 let dt = date
273 .and_hms_opt(0, 0, 0)
274 .ok_or_else(|| anyhow!("invalid date '{value}'"))?;
275 return Ok(DateTime::<Utc>::from_naive_utc_and_offset(dt, Utc));
276 }
277 Err(anyhow!("unable to parse timestamp '{value}'"))
278}
279
/// Infers the trading symbol from the dataset's parent directory name,
/// e.g. `data/BTCUSDT/1m.csv` -> `BTCUSDT`.
fn infer_symbol(path: &Path) -> Option<String> {
    let directory = path.parent()?;
    let name = directory.file_name()?;
    Some(name.to_string_lossy().into_owned())
}
285
286fn infer_interval(path: &Path) -> Option<Interval> {
287 path.file_stem()
288 .and_then(|stem| stem.to_str())
289 .and_then(|stem| stem.split('_').next())
290 .and_then(|token| Interval::from_str(token).ok())
291}
292
/// Column indices resolved once per parquet file, used to decode candles
/// out of every subsequent record batch.
struct CandleColumns {
    symbol: usize,
    interval: usize,
    open: usize,
    high: usize,
    low: usize,
    close: usize,
    volume: usize,
    timestamp: usize,
}
303
304impl CandleColumns {
305 fn from_batch(batch: &RecordBatch) -> Result<Self> {
306 Ok(Self {
307 symbol: column_index(batch, "symbol")?,
308 interval: column_index(batch, "interval")?,
309 open: column_index(batch, "open")?,
310 high: column_index(batch, "high")?,
311 low: column_index(batch, "low")?,
312 close: column_index(batch, "close")?,
313 volume: column_index(batch, "volume")?,
314 timestamp: column_index(batch, "timestamp")?,
315 })
316 }
317
318 fn decode(&self, batch: &RecordBatch, row: usize) -> Result<Candle> {
319 let symbol = string_value(batch, self.symbol, row)?;
320 let interval_raw = string_value(batch, self.interval, row)?;
321 let interval =
322 Interval::from_str(&interval_raw).map_err(|err| anyhow!("{interval_raw}: {err}"))?;
323 Ok(Candle {
324 symbol: Symbol::from(symbol.as_str()),
325 interval,
326 open: decimal_value(batch, self.open, row)?,
327 high: decimal_value(batch, self.high, row)?,
328 low: decimal_value(batch, self.low, row)?,
329 close: decimal_value(batch, self.close, row)?,
330 volume: decimal_value(batch, self.volume, row)?,
331 timestamp: timestamp_value(batch, self.timestamp, row)?,
332 })
333 }
334}
335
336fn column_index(batch: &RecordBatch, name: &str) -> Result<usize> {
337 batch
338 .schema()
339 .column_with_name(name)
340 .map(|(idx, _)| idx)
341 .ok_or_else(|| anyhow!("column '{name}' missing from schema"))
342}
343
344fn string_value(batch: &RecordBatch, column: usize, row: usize) -> Result<String> {
345 let array = batch
346 .column(column)
347 .as_any()
348 .downcast_ref::<StringArray>()
349 .ok_or_else(|| anyhow!("column {column} is not Utf8"))?;
350 if array.is_null(row) {
351 return Err(anyhow!("column {column} contains null string"));
352 }
353 Ok(array.value(row).to_string())
354}
355
356fn decimal_value(batch: &RecordBatch, column: usize, row: usize) -> Result<Decimal> {
357 let array = batch
358 .column(column)
359 .as_any()
360 .downcast_ref::<Decimal128Array>()
361 .ok_or_else(|| anyhow!("column {column} is not decimal"))?;
362 if array.is_null(row) {
363 return Err(anyhow!("column {column} contains null decimal"));
364 }
365 Ok(Decimal::from_i128_with_scale(
366 array.value(row),
367 array.scale() as u32,
368 ))
369}
370
371fn timestamp_value(batch: &RecordBatch, column: usize, row: usize) -> Result<DateTime<Utc>> {
372 let array = batch
373 .column(column)
374 .as_any()
375 .downcast_ref::<TimestampNanosecondArray>()
376 .ok_or_else(|| anyhow!("column {column} is not timestamp"))?;
377 if array.is_null(row) {
378 return Err(anyhow!("column {column} contains null timestamp"));
379 }
380 let nanos = array.value(row);
381 let secs = nanos.div_euclid(1_000_000_000);
382 let sub = nanos.rem_euclid(1_000_000_000) as u32;
383 DateTime::<Utc>::from_timestamp(secs, sub)
384 .ok_or_else(|| anyhow!("timestamp overflow for value {nanos}"))
385}
386
#[cfg(test)]
mod tests {
    use std::fs::File;

    use chrono::{Duration, TimeZone, Utc};
    use rust_decimal::{prelude::FromPrimitive, Decimal};
    use tempfile::tempdir;
    use tesser_core::{Side, Symbol};

    use super::*;

    // Builds four one-minute BTCUSDT candles starting ten minutes ago.
    fn sample_candles() -> Vec<Candle> {
        let base = Utc::now() - Duration::minutes(10);
        (0..4)
            .map(|idx| Candle {
                symbol: Symbol::from("BTCUSDT"),
                interval: Interval::OneMinute,
                open: Decimal::new(10 + idx as i64, 0),
                high: Decimal::new(11 + idx as i64, 0),
                low: Decimal::new(9 + idx as i64, 0),
                close: Decimal::new(10 + idx as i64, 0),
                volume: Decimal::new(1, 0),
                timestamp: base + Duration::minutes(idx as i64),
            })
            .collect()
    }

    // Write then re-read a CSV dataset; every candle must survive the trip.
    #[test]
    fn round_trip_csv() -> Result<()> {
        let temp = tempdir()?;
        let path = temp.path().join("1m_BTCUSDT.csv");
        let candles = sample_candles();
        write_dataset(&path, DatasetFormat::Csv, &candles)?;
        let dataset = read_dataset(&path)?;
        assert_eq!(dataset.candles.len(), candles.len());
        Ok(())
    }

    // Same round trip through the parquet encoder/decoder.
    #[test]
    fn round_trip_parquet() -> Result<()> {
        let temp = tempdir()?;
        let path = temp.path().join("1m_BTCUSDT.parquet");
        let candles = sample_candles();
        write_dataset(&path, DatasetFormat::Parquet, &candles)?;
        let dataset = read_dataset(&path)?;
        assert_eq!(dataset.candles.len(), candles.len());
        Ok(())
    }

    // Two duplicate pushes with the same trade id plus two identical id-less
    // ticks must collapse to exactly two persisted rows.
    #[test]
    fn ticks_writer_dedupes_trade_ids_and_payloads() -> Result<()> {
        let temp = tempdir()?;
        let path = temp.path().join("ticks.parquet");
        let mut writer = TicksWriter::new(&path);
        writer.push(
            Some("trade-1".to_string()),
            sample_tick(1_000, 100.0, 1.0, Side::Buy),
        );
        // Same trade id: should be dropped by the seen-ids set.
        writer.push(
            Some("trade-1".to_string()),
            sample_tick(1_000, 100.0, 1.0, Side::Buy),
        );
        // No trade id: identical payloads should be collapsed by finish().
        writer.extend([
            (None, sample_tick(2_000, 101.0, 2.0, Side::Sell)),
            (None, sample_tick(2_000, 101.0, 2.0, Side::Sell)),
        ]);
        writer.finish()?;

        // Re-read the parquet file and count the surviving rows.
        let file = File::open(&path)?;
        let reader = ParquetRecordBatchReaderBuilder::try_new(file)?
            .with_batch_size(8)
            .build()?;
        let mut rows = 0;
        for batch in reader {
            rows += batch?.num_rows();
        }
        assert_eq!(rows, 2);
        Ok(())
    }

    // Builds a tick at the given epoch-millisecond timestamp; exchange and
    // receive timestamps are set to the same instant.
    fn sample_tick(ts_ms: i64, price: f64, size: f64, side: Side) -> Tick {
        let price = Decimal::from_f64(price).expect("valid price");
        let size = Decimal::from_f64(size).expect("valid size");
        let timestamp = Utc
            .timestamp_millis_opt(ts_ms)
            .single()
            .expect("valid timestamp");
        Tick {
            symbol: Symbol::from("BTCUSDT"),
            price,
            size,
            side,
            exchange_timestamp: timestamp,
            received_at: timestamp,
        }
    }
}