1use std::cmp::Ordering;
2use std::collections::hash_map::Entry;
3use std::collections::HashMap;
4use std::mem;
5
6use chrono::{DateTime, Utc};
7use futures::StreamExt;
8use rust_decimal::Decimal;
9
10use tesser_core::{Candle, Interval, Symbol};
11
12pub struct Resampler {
24 interval: Interval,
25 interval_nanos: i64,
26 active: HashMap<Symbol, Bucket>,
27 output: Vec<Candle>,
28}
29
30impl Resampler {
31 pub fn new(interval: Interval) -> Self {
33 let interval_nanos = interval
34 .as_duration()
35 .num_nanoseconds()
36 .expect("interval nanoseconds fit into i64");
37 Self {
38 interval,
39 interval_nanos,
40 active: HashMap::new(),
41 output: Vec::new(),
42 }
43 }
44
45 pub fn resample(candles: Vec<Candle>, interval: Interval) -> Vec<Candle> {
47 let mut sorted = candles;
48 sorted.sort_by_key(|c| c.timestamp);
49 let mut resampler = Self::new(interval);
50 for candle in sorted {
51 resampler.push(candle);
52 }
53 resampler.finish()
54 }
55
56 pub async fn resample_stream<S>(interval: Interval, stream: &mut S) -> Vec<Candle>
58 where
59 S: futures::Stream<Item = Candle> + Unpin,
60 {
61 let mut resampler = Self::new(interval);
62 while let Some(candle) = stream.next().await {
63 resampler.push(candle);
64 }
65 resampler.finish()
66 }
67
68 pub fn push(&mut self, candle: Candle) {
70 let bucket_start = align_timestamp(candle.timestamp, self.interval_nanos);
71 let symbol = candle.symbol;
72 match self.active.entry(symbol) {
73 Entry::Vacant(slot) => {
74 slot.insert(Bucket::from_candle(
75 symbol,
76 bucket_start,
77 self.interval,
78 &candle,
79 ));
80 }
81 Entry::Occupied(mut slot) => {
82 let entry = slot.get_mut();
83 match bucket_start.cmp(&entry.start) {
84 Ordering::Less => {
85 let finished = mem::replace(
87 entry,
88 Bucket::from_candle(symbol, bucket_start, self.interval, &candle),
89 );
90 self.output.push(finished.into_candle());
91 }
92 Ordering::Equal => {
93 entry.update(&candle);
94 }
95 Ordering::Greater => {
96 let finished = mem::replace(
98 entry,
99 Bucket::from_candle(symbol, bucket_start, self.interval, &candle),
100 );
101 self.output.push(finished.into_candle());
102 }
103 }
104 }
105 }
106 }
107
108 pub fn finish(mut self) -> Vec<Candle> {
110 for bucket in self.active.into_values() {
111 self.output.push(bucket.into_candle());
112 }
113 self.output.sort_by(|a, b| {
114 let ts = a.timestamp.cmp(&b.timestamp);
115 if ts == Ordering::Equal {
116 (a.symbol.exchange.as_raw(), a.symbol.market_id)
117 .cmp(&(b.symbol.exchange.as_raw(), b.symbol.market_id))
118 } else {
119 ts
120 }
121 });
122 self.output
123 }
124}
125
126struct Bucket {
127 symbol: Symbol,
128 interval: Interval,
129 start: DateTime<Utc>,
130 open: Decimal,
131 high: Decimal,
132 low: Decimal,
133 close: Decimal,
134 volume: Decimal,
135}
136
137impl Bucket {
138 fn from_candle(
139 symbol: Symbol,
140 start: DateTime<Utc>,
141 interval: Interval,
142 candle: &Candle,
143 ) -> Self {
144 Self {
145 symbol,
146 interval,
147 start,
148 open: candle.open,
149 high: candle.high,
150 low: candle.low,
151 close: candle.close,
152 volume: candle.volume,
153 }
154 }
155
156 fn update(&mut self, candle: &Candle) {
157 if candle.high > self.high {
158 self.high = candle.high;
159 }
160 if candle.low < self.low {
161 self.low = candle.low;
162 }
163 self.close = candle.close;
164 self.volume += candle.volume;
165 }
166
167 fn into_candle(self) -> Candle {
168 Candle {
169 symbol: self.symbol,
170 interval: self.interval,
171 open: self.open,
172 high: self.high,
173 low: self.low,
174 close: self.close,
175 volume: self.volume,
176 timestamp: self.start,
177 }
178 }
179}
180
181fn align_timestamp(ts: DateTime<Utc>, step_nanos: i64) -> DateTime<Utc> {
182 let timestamp = ts
183 .timestamp_nanos_opt()
184 .expect("timestamp fits into i64 nanoseconds");
185 let aligned = timestamp - timestamp.rem_euclid(step_nanos);
186 let secs = aligned.div_euclid(1_000_000_000);
187 let nanos = aligned.rem_euclid(1_000_000_000) as u32;
188 DateTime::<Utc>::from_timestamp(secs, nanos)
189 .expect("aligned timestamp within chrono supported range")
190}
191
#[cfg(test)]
mod tests {
    use chrono::{Duration, TimeZone, Timelike, Utc};
    use rust_decimal::Decimal;
    use tesser_core::Interval;

    use super::*;

    /// Builds a one-minute BTCUSDT candle placed `minute` minutes after
    /// 2023-01-01 00:00:00 UTC, with open/high/low fixed at one, a volume of
    /// ten and the given integer `close`.
    fn candle_at(minute: i64, close: i64) -> Candle {
        let base = Utc.with_ymd_and_hms(2023, 1, 1, 0, 0, 0).unwrap();
        Candle {
            symbol: "BTCUSDT".into(),
            interval: Interval::OneMinute,
            open: Decimal::ONE,
            high: Decimal::ONE,
            low: Decimal::ONE,
            close: Decimal::new(close, 0),
            volume: Decimal::new(10, 0),
            timestamp: base + Duration::minutes(minute),
        }
    }

    /// Ten consecutive one-minute candles whose close equals their minute index.
    fn ten_minute_series() -> Vec<Candle> {
        (0..10).map(|idx| candle_at(idx, idx)).collect()
    }

    #[tokio::test]
    async fn resamples_stream() {
        let source = ten_minute_series();
        let mut stream = futures::stream::iter(source.clone());

        let buckets = Resampler::resample_stream(Interval::FiveMinutes, &mut stream).await;

        // Each five-minute bucket closes at the last one-minute candle it covers.
        assert_eq!(buckets.len(), 2);
        assert_eq!(buckets[0].close, source[4].close);
        assert_eq!(buckets[1].close, source[9].close);
    }

    #[test]
    fn resamples_vec() {
        let buckets = Resampler::resample(ten_minute_series(), Interval::FiveMinutes);

        // Output candles are stamped with the aligned bucket start.
        assert_eq!(buckets.len(), 2);
        assert_eq!(buckets[0].timestamp.minute(), 0);
        assert_eq!(buckets[1].timestamp.minute(), 5);
    }
}