Skip to main content

shape_runtime/
time_window.rs

1//! Time window support for Shape
2//!
3//! This module handles conversion between time-based windows and row indices,
4//! supporting queries like "last(5 days)" or "between(@yesterday, @today)".
5
6use chrono::{DateTime, Duration, Timelike, Utc};
7use shape_ast::error::{Result, ShapeError};
8
9use super::context::ExecutionContext;
10use shape_ast::ast::{NamedTime, RelativeTime, TimeDirection, TimeReference, TimeUnit, TimeWindow};
11
/// Stateless resolver that converts a time window into a range of row indices.
///
/// The type carries no data; all functionality lives in associated functions,
/// so it acts purely as a namespace. The derives make it usable in generic and
/// diagnostic contexts (`{:?}`, `Default::default()`, copying) at no cost.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct TimeWindowResolver;
14
15impl TimeWindowResolver {
16    /// Convert a time window to row index range
17    pub fn resolve_window(
18        window: &TimeWindow,
19        ctx: &ExecutionContext,
20    ) -> Result<std::ops::Range<usize>> {
21        match window {
22            TimeWindow::Last { amount, unit } => {
23                Self::resolve_last_window(*amount as u32, unit, ctx)
24            }
25            TimeWindow::Between { start, end } => Self::resolve_between_window(start, end, ctx),
26            TimeWindow::Window { start, end } => Self::resolve_window_indices(*start, *end, ctx),
27            TimeWindow::Session { start, end } => Self::resolve_session_window(start, end, ctx),
28        }
29    }
30
31    /// Resolve "last(N units)" window
32    fn resolve_last_window(
33        amount: u32,
34        unit: &TimeUnit,
35        ctx: &ExecutionContext,
36    ) -> Result<std::ops::Range<usize>> {
37        let row_count = ctx.row_count();
38        if row_count == 0 {
39            return Ok(0..0);
40        }
41
42        // For sample-based units, it's straightforward
43        if matches!(unit, TimeUnit::Samples) {
44            let start = row_count.saturating_sub(amount as usize);
45            return Ok(start..row_count);
46        }
47
48        // For time-based units, we need to calculate based on timestamps
49        let current_ts = ctx.get_row_timestamp(row_count - 1)?;
50        let current_time = DateTime::from_timestamp(current_ts, 0).unwrap_or_else(Utc::now);
51
52        let duration = Self::time_unit_to_duration(amount, unit)?;
53        let start_time = current_time - duration;
54
55        // Find the row index for start_time
56        let start_idx = Self::find_row_at_or_after(start_time, ctx)?;
57
58        Ok(start_idx..row_count)
59    }
60
61    /// Resolve "between(start, end)" window
62    fn resolve_between_window(
63        start_ref: &TimeReference,
64        end_ref: &TimeReference,
65        ctx: &ExecutionContext,
66    ) -> Result<std::ops::Range<usize>> {
67        let start_time = Self::resolve_time_reference(start_ref, ctx)?;
68        let end_time = Self::resolve_time_reference(end_ref, ctx)?;
69
70        if start_time > end_time {
71            return Err(ShapeError::RuntimeError {
72                message: "Invalid time window: start time is after end time".into(),
73                location: None,
74            });
75        }
76
77        let start_idx = Self::find_row_at_or_after(start_time, ctx)?;
78        let end_idx = Self::find_row_at_or_before(end_time, ctx)? + 1;
79
80        Ok(start_idx..end_idx)
81    }
82
83    /// Resolve window with explicit indices
84    fn resolve_window_indices(
85        start: i32,
86        end: Option<i32>,
87        ctx: &ExecutionContext,
88    ) -> Result<std::ops::Range<usize>> {
89        let row_count = ctx.row_count();
90
91        // Convert negative indices to positive
92        let start_idx = if start < 0 {
93            (row_count as i32 + start) as usize
94        } else {
95            start as usize
96        };
97
98        let end_idx = match end {
99            Some(e) => {
100                if e < 0 {
101                    (row_count as i32 + e) as usize
102                } else {
103                    e as usize
104                }
105            }
106            None => start_idx + 1,
107        };
108
109        // Validate range
110        if start_idx >= row_count || end_idx > row_count {
111            return Err(ShapeError::RuntimeError {
112                message: "Window indices out of range".into(),
113                location: None,
114            });
115        }
116
117        Ok(start_idx..end_idx)
118    }
119
120    /// Resolve session window with start and end times
121    fn resolve_session_window(
122        start_time: &str,
123        end_time: &str,
124        ctx: &ExecutionContext,
125    ) -> Result<std::ops::Range<usize>> {
126        // First try to parse as time strings (HH:MM or HH:MM:SS format)
127        if let (Some(start_hour), Some(end_hour)) = (
128            Self::parse_time_of_day(start_time),
129            Self::parse_time_of_day(end_time),
130        ) {
131            return Self::find_session_rows(start_hour, end_hour, ctx);
132        }
133
134        // If parsing fails, treat start_time as a session name
135        Self::resolve_named_session(start_time, ctx)
136    }
137
138    /// Parse a time of day string like "09:30" or "16:00" to hour (with minute fraction)
139    fn parse_time_of_day(time_str: &str) -> Option<u32> {
140        let parts: Vec<&str> = time_str.split(':').collect();
141        if parts.len() >= 2 {
142            let hour: u32 = parts[0].parse().ok()?;
143            // We only use hour for session matching
144            Some(hour)
145        } else if let Ok(hour) = time_str.parse::<u32>() {
146            // Allow just hour number
147            Some(hour)
148        } else {
149            None
150        }
151    }
152
153    /// Resolve session window by name (e.g., "london", "newyork", "tokyo")
154    fn resolve_named_session(
155        session_name: &str,
156        ctx: &ExecutionContext,
157    ) -> Result<std::ops::Range<usize>> {
158        match session_name.to_lowercase().as_str() {
159            "london" => {
160                // London session: 08:00 - 16:00 UTC
161                Self::find_session_rows(8, 16, ctx)
162            }
163            "newyork" | "ny" => {
164                // New York session: 13:00 - 21:00 UTC
165                Self::find_session_rows(13, 21, ctx)
166            }
167            "tokyo" => {
168                // Tokyo session: 00:00 - 08:00 UTC
169                Self::find_session_rows(0, 8, ctx)
170            }
171            "sydney" => {
172                // Sydney session: 22:00 - 06:00 UTC (next day)
173                Self::find_session_rows(22, 6, ctx)
174            }
175            _ => Err(ShapeError::RuntimeError {
176                message: format!("Unknown session: {}", session_name),
177                location: None,
178            }),
179        }
180    }
181
182    /// Find rows within a specific hour range
183    fn find_session_rows(
184        start_hour: u32,
185        end_hour: u32,
186        ctx: &ExecutionContext,
187    ) -> Result<std::ops::Range<usize>> {
188        let row_count = ctx.row_count();
189        if row_count == 0 {
190            return Ok(0..0);
191        }
192
193        // Find the most recent session
194        let mut session_indices = Vec::new();
195
196        for i in (0..row_count).rev() {
197            let ts = ctx.get_row_timestamp(i)?;
198            let dt = DateTime::from_timestamp(ts, 0).unwrap_or_else(Utc::now);
199            let hour = dt.hour();
200
201            let in_session = if end_hour > start_hour {
202                hour >= start_hour && hour < end_hour
203            } else {
204                // Handle sessions that cross midnight
205                hour >= start_hour || hour < end_hour
206            };
207
208            if in_session {
209                session_indices.push(i);
210            } else if !session_indices.is_empty() {
211                // We've found a complete session
212                break;
213            }
214        }
215
216        if session_indices.is_empty() {
217            return Ok(0..0);
218        }
219
220        session_indices.reverse();
221        let start = *session_indices.first().unwrap();
222        let end = *session_indices.last().unwrap() + 1;
223
224        Ok(start..end)
225    }
226
227    /// Resolve a time reference to an absolute timestamp
228    fn resolve_time_reference(
229        reference: &TimeReference,
230        ctx: &ExecutionContext,
231    ) -> Result<DateTime<Utc>> {
232        match reference {
233            TimeReference::Absolute(time_str) => {
234                // Parse various time formats
235                Self::parse_time_string(time_str)
236            }
237            TimeReference::Named(named) => Self::resolve_named_time(named, ctx),
238            TimeReference::Relative(relative) => Self::resolve_relative_time(relative, ctx),
239        }
240    }
241
242    /// Resolve named time references
243    fn resolve_named_time(named: &NamedTime, ctx: &ExecutionContext) -> Result<DateTime<Utc>> {
244        let now = if ctx.row_count() > 0 {
245            let ts = ctx.get_row_timestamp(ctx.row_count() - 1)?;
246            DateTime::from_timestamp(ts, 0).unwrap_or_else(Utc::now)
247        } else {
248            Utc::now()
249        };
250
251        match named {
252            NamedTime::Today => Ok(now.date_naive().and_hms_opt(0, 0, 0).unwrap().and_utc()),
253            NamedTime::Yesterday => {
254                let yesterday = now - Duration::days(1);
255                Ok(yesterday
256                    .date_naive()
257                    .and_hms_opt(0, 0, 0)
258                    .unwrap()
259                    .and_utc())
260            }
261            NamedTime::Now => Ok(now),
262        }
263    }
264
265    /// Resolve relative time references
266    fn resolve_relative_time(
267        relative: &RelativeTime,
268        ctx: &ExecutionContext,
269    ) -> Result<DateTime<Utc>> {
270        let now = if ctx.row_count() > 0 {
271            let ts = ctx.get_row_timestamp(ctx.row_count() - 1)?;
272            DateTime::from_timestamp(ts, 0).unwrap_or_else(Utc::now)
273        } else {
274            Utc::now()
275        };
276
277        let duration = Self::time_unit_to_duration(relative.amount as u32, &relative.unit)?;
278
279        match relative.direction {
280            TimeDirection::Ago => Ok(now - duration),
281            TimeDirection::Future => Ok(now + duration),
282        }
283    }
284
285    /// Convert time unit to chrono duration
286    fn time_unit_to_duration(amount: u32, unit: &TimeUnit) -> Result<Duration> {
287        let amount = amount as i64;
288
289        match unit {
290            TimeUnit::Minutes => Ok(Duration::minutes(amount)),
291            TimeUnit::Hours => Ok(Duration::hours(amount)),
292            TimeUnit::Days => Ok(Duration::days(amount)),
293            TimeUnit::Weeks => Ok(Duration::weeks(amount)),
294            TimeUnit::Months => Ok(Duration::days(amount * 30)), // Approximate
295            TimeUnit::Samples => Err(ShapeError::RuntimeError {
296                message: "Cannot convert samples to duration".into(),
297                location: None,
298            }),
299        }
300    }
301
302    /// Find the row at or after the given timestamp
303    fn find_row_at_or_after(target_time: DateTime<Utc>, ctx: &ExecutionContext) -> Result<usize> {
304        let row_count = ctx.row_count();
305        let target_ts = target_time.timestamp();
306
307        // Binary search for efficiency
308        let mut left = 0;
309        let mut right = row_count;
310
311        while left < right {
312            let mid = left + (right - left) / 2;
313            let mid_time = ctx.get_row_timestamp(mid)?;
314
315            if mid_time < target_ts {
316                left = mid + 1;
317            } else {
318                right = mid;
319            }
320        }
321
322        Ok(left)
323    }
324
325    /// Find the row at or before the given timestamp
326    fn find_row_at_or_before(target_time: DateTime<Utc>, ctx: &ExecutionContext) -> Result<usize> {
327        let row_count = ctx.row_count();
328        if row_count == 0 {
329            return Err(ShapeError::DataError {
330                message: "No rows available".into(),
331                symbol: None,
332                timeframe: None,
333            });
334        }
335
336        let target_ts = target_time.timestamp();
337
338        // Binary search
339        let mut left = 0;
340        let mut right = row_count;
341
342        while left < right {
343            let mid = left + (right - left).div_ceil(2);
344            let mid_time = ctx.get_row_timestamp(mid - 1)?;
345
346            if mid_time <= target_ts {
347                left = mid;
348            } else {
349                right = mid - 1;
350            }
351        }
352
353        if left > 0 { Ok(left - 1) } else { Ok(0) }
354    }
355
356    /// Parse a time string in various formats
357    fn parse_time_string(time_str: &str) -> Result<DateTime<Utc>> {
358        // Try different formats
359        // ISO 8601
360        if let Ok(dt) = DateTime::parse_from_rfc3339(time_str) {
361            return Ok(dt.with_timezone(&Utc));
362        }
363
364        // Common date formats
365        let formats = [
366            "%Y-%m-%d %H:%M:%S",
367            "%Y-%m-%d %H:%M",
368            "%Y-%m-%d",
369            "%Y/%m/%d %H:%M:%S",
370            "%Y/%m/%d %H:%M",
371            "%Y/%m/%d",
372            "%d-%m-%Y %H:%M:%S",
373            "%d-%m-%Y %H:%M",
374            "%d-%m-%Y",
375        ];
376
377        for format in &formats {
378            if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(time_str, format) {
379                return Ok(dt.and_utc());
380            }
381            if let Ok(date) = chrono::NaiveDate::parse_from_str(time_str, format) {
382                return Ok(date.and_hms_opt(0, 0, 0).unwrap().and_utc());
383            }
384        }
385
386        Err(ShapeError::RuntimeError {
387            message: format!("Unable to parse time string: {}", time_str),
388            location: None,
389        })
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::ExecutionContext;
    use crate::data::OwnedDataRow as RowValue;
    use crate::data::Timeframe;
    use chrono::TimeZone;

    /// Build a context holding 100 daily rows, the first dated 2024-01-01 00:00 UTC.
    fn create_test_context() -> ExecutionContext {
        let mut ctx = ExecutionContext::new_empty();

        let first_day = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap();
        let timeframe = Timeframe::d1();

        // Every row carries the same dummy OHLCV values; only the timestamp varies.
        let rows: Vec<_> = (0..100)
            .map(|day| {
                let fields = std::collections::HashMap::from([
                    ("open".to_string(), 100.0),
                    ("high".to_string(), 110.0),
                    ("low".to_string(), 90.0),
                    ("close".to_string(), 105.0),
                    ("volume".to_string(), 1000.0),
                ]);
                RowValue::from_hashmap((first_day + Duration::days(day)).timestamp(), fields)
            })
            .collect();

        ctx.set_reference_datetime(first_day);

        // Build a DataFrame from the rows and inject it into the DataCache.
        let frame = crate::data::DataFrame::from_rows("TEST", timeframe, rows);
        ctx.update_data(&frame);

        let mut cached = std::collections::HashMap::new();
        cached.insert(
            crate::data::cache::CacheKey::new("TEST".to_string(), timeframe),
            frame,
        );
        ctx.data_cache = Some(crate::data::DataCache::from_test_data(cached));

        ctx
    }

    /// Resolve `window` against a fresh test context.
    fn resolve(window: TimeWindow) -> Result<std::ops::Range<usize>> {
        let ctx = create_test_context();
        TimeWindowResolver::resolve_window(&window, &ctx)
    }

    #[test]
    fn test_resolve_last_samples() {
        let range = resolve(TimeWindow::Last {
            amount: 10,
            unit: TimeUnit::Samples,
        })
        .unwrap();

        assert_eq!(range, 90..100);
    }

    #[test]
    fn test_resolve_last_days() {
        let range = resolve(TimeWindow::Last {
            amount: 5,
            unit: TimeUnit::Days,
        })
        .unwrap();

        assert!(range.len() >= 5);
        assert_eq!(range.end, 100);
    }

    #[test]
    fn test_resolve_between() {
        // 2024-01-02 is row 1 and 2024-01-05 is row 4; both ends are inclusive.
        let range = resolve(TimeWindow::Between {
            start: TimeReference::Absolute("2024-01-02".to_string()),
            end: TimeReference::Absolute("2024-01-05".to_string()),
        })
        .unwrap();

        assert_eq!(range, 1..5);
    }

    #[test]
    fn test_resolve_between_invalid() {
        // Start after end must be rejected.
        let outcome = resolve(TimeWindow::Between {
            start: TimeReference::Absolute("2024-02-01".to_string()),
            end: TimeReference::Absolute("2024-01-01".to_string()),
        });

        assert!(outcome.is_err());
    }
}