1use chrono::{DateTime, Duration, Timelike, Utc};
7use shape_ast::error::{Result, ShapeError};
8
9use super::context::ExecutionContext;
10use shape_ast::ast::{NamedTime, RelativeTime, TimeDirection, TimeReference, TimeUnit, TimeWindow};
11
/// Resolves [`TimeWindow`] expressions into half-open ranges of row indices
/// within an [`ExecutionContext`].
///
/// This is a stateless namespace type: all functionality is exposed through
/// associated functions such as `TimeWindowResolver::resolve_window`.
#[derive(Debug, Clone, Copy, Default)]
pub struct TimeWindowResolver;
14
15impl TimeWindowResolver {
16 pub fn resolve_window(
18 window: &TimeWindow,
19 ctx: &ExecutionContext,
20 ) -> Result<std::ops::Range<usize>> {
21 match window {
22 TimeWindow::Last { amount, unit } => {
23 Self::resolve_last_window(*amount as u32, unit, ctx)
24 }
25 TimeWindow::Between { start, end } => Self::resolve_between_window(start, end, ctx),
26 TimeWindow::Window { start, end } => Self::resolve_window_indices(*start, *end, ctx),
27 TimeWindow::Session { start, end } => Self::resolve_session_window(start, end, ctx),
28 }
29 }
30
31 fn resolve_last_window(
33 amount: u32,
34 unit: &TimeUnit,
35 ctx: &ExecutionContext,
36 ) -> Result<std::ops::Range<usize>> {
37 let row_count = ctx.row_count();
38 if row_count == 0 {
39 return Ok(0..0);
40 }
41
42 if matches!(unit, TimeUnit::Samples) {
44 let start = row_count.saturating_sub(amount as usize);
45 return Ok(start..row_count);
46 }
47
48 let current_ts = ctx.get_row_timestamp(row_count - 1)?;
50 let current_time = DateTime::from_timestamp(current_ts, 0).unwrap_or_else(Utc::now);
51
52 let duration = Self::time_unit_to_duration(amount, unit)?;
53 let start_time = current_time - duration;
54
55 let start_idx = Self::find_row_at_or_after(start_time, ctx)?;
57
58 Ok(start_idx..row_count)
59 }
60
61 fn resolve_between_window(
63 start_ref: &TimeReference,
64 end_ref: &TimeReference,
65 ctx: &ExecutionContext,
66 ) -> Result<std::ops::Range<usize>> {
67 let start_time = Self::resolve_time_reference(start_ref, ctx)?;
68 let end_time = Self::resolve_time_reference(end_ref, ctx)?;
69
70 if start_time > end_time {
71 return Err(ShapeError::RuntimeError {
72 message: "Invalid time window: start time is after end time".into(),
73 location: None,
74 });
75 }
76
77 let start_idx = Self::find_row_at_or_after(start_time, ctx)?;
78 let end_idx = Self::find_row_at_or_before(end_time, ctx)? + 1;
79
80 Ok(start_idx..end_idx)
81 }
82
83 fn resolve_window_indices(
85 start: i32,
86 end: Option<i32>,
87 ctx: &ExecutionContext,
88 ) -> Result<std::ops::Range<usize>> {
89 let row_count = ctx.row_count();
90
91 let start_idx = if start < 0 {
93 (row_count as i32 + start) as usize
94 } else {
95 start as usize
96 };
97
98 let end_idx = match end {
99 Some(e) => {
100 if e < 0 {
101 (row_count as i32 + e) as usize
102 } else {
103 e as usize
104 }
105 }
106 None => start_idx + 1,
107 };
108
109 if start_idx >= row_count || end_idx > row_count {
111 return Err(ShapeError::RuntimeError {
112 message: "Window indices out of range".into(),
113 location: None,
114 });
115 }
116
117 Ok(start_idx..end_idx)
118 }
119
120 fn resolve_session_window(
122 start_time: &str,
123 end_time: &str,
124 ctx: &ExecutionContext,
125 ) -> Result<std::ops::Range<usize>> {
126 if let (Some(start_hour), Some(end_hour)) = (
128 Self::parse_time_of_day(start_time),
129 Self::parse_time_of_day(end_time),
130 ) {
131 return Self::find_session_rows(start_hour, end_hour, ctx);
132 }
133
134 Self::resolve_named_session(start_time, ctx)
136 }
137
138 fn parse_time_of_day(time_str: &str) -> Option<u32> {
140 let parts: Vec<&str> = time_str.split(':').collect();
141 if parts.len() >= 2 {
142 let hour: u32 = parts[0].parse().ok()?;
143 Some(hour)
145 } else if let Ok(hour) = time_str.parse::<u32>() {
146 Some(hour)
148 } else {
149 None
150 }
151 }
152
153 fn resolve_named_session(
155 session_name: &str,
156 ctx: &ExecutionContext,
157 ) -> Result<std::ops::Range<usize>> {
158 match session_name.to_lowercase().as_str() {
159 "london" => {
160 Self::find_session_rows(8, 16, ctx)
162 }
163 "newyork" | "ny" => {
164 Self::find_session_rows(13, 21, ctx)
166 }
167 "tokyo" => {
168 Self::find_session_rows(0, 8, ctx)
170 }
171 "sydney" => {
172 Self::find_session_rows(22, 6, ctx)
174 }
175 _ => Err(ShapeError::RuntimeError {
176 message: format!("Unknown session: {}", session_name),
177 location: None,
178 }),
179 }
180 }
181
182 fn find_session_rows(
184 start_hour: u32,
185 end_hour: u32,
186 ctx: &ExecutionContext,
187 ) -> Result<std::ops::Range<usize>> {
188 let row_count = ctx.row_count();
189 if row_count == 0 {
190 return Ok(0..0);
191 }
192
193 let mut session_indices = Vec::new();
195
196 for i in (0..row_count).rev() {
197 let ts = ctx.get_row_timestamp(i)?;
198 let dt = DateTime::from_timestamp(ts, 0).unwrap_or_else(Utc::now);
199 let hour = dt.hour();
200
201 let in_session = if end_hour > start_hour {
202 hour >= start_hour && hour < end_hour
203 } else {
204 hour >= start_hour || hour < end_hour
206 };
207
208 if in_session {
209 session_indices.push(i);
210 } else if !session_indices.is_empty() {
211 break;
213 }
214 }
215
216 if session_indices.is_empty() {
217 return Ok(0..0);
218 }
219
220 session_indices.reverse();
221 let start = *session_indices.first().unwrap();
222 let end = *session_indices.last().unwrap() + 1;
223
224 Ok(start..end)
225 }
226
227 fn resolve_time_reference(
229 reference: &TimeReference,
230 ctx: &ExecutionContext,
231 ) -> Result<DateTime<Utc>> {
232 match reference {
233 TimeReference::Absolute(time_str) => {
234 Self::parse_time_string(time_str)
236 }
237 TimeReference::Named(named) => Self::resolve_named_time(named, ctx),
238 TimeReference::Relative(relative) => Self::resolve_relative_time(relative, ctx),
239 }
240 }
241
242 fn resolve_named_time(named: &NamedTime, ctx: &ExecutionContext) -> Result<DateTime<Utc>> {
244 let now = if ctx.row_count() > 0 {
245 let ts = ctx.get_row_timestamp(ctx.row_count() - 1)?;
246 DateTime::from_timestamp(ts, 0).unwrap_or_else(Utc::now)
247 } else {
248 Utc::now()
249 };
250
251 match named {
252 NamedTime::Today => Ok(now.date_naive().and_hms_opt(0, 0, 0).unwrap().and_utc()),
253 NamedTime::Yesterday => {
254 let yesterday = now - Duration::days(1);
255 Ok(yesterday
256 .date_naive()
257 .and_hms_opt(0, 0, 0)
258 .unwrap()
259 .and_utc())
260 }
261 NamedTime::Now => Ok(now),
262 }
263 }
264
265 fn resolve_relative_time(
267 relative: &RelativeTime,
268 ctx: &ExecutionContext,
269 ) -> Result<DateTime<Utc>> {
270 let now = if ctx.row_count() > 0 {
271 let ts = ctx.get_row_timestamp(ctx.row_count() - 1)?;
272 DateTime::from_timestamp(ts, 0).unwrap_or_else(Utc::now)
273 } else {
274 Utc::now()
275 };
276
277 let duration = Self::time_unit_to_duration(relative.amount as u32, &relative.unit)?;
278
279 match relative.direction {
280 TimeDirection::Ago => Ok(now - duration),
281 TimeDirection::Future => Ok(now + duration),
282 }
283 }
284
285 fn time_unit_to_duration(amount: u32, unit: &TimeUnit) -> Result<Duration> {
287 let amount = amount as i64;
288
289 match unit {
290 TimeUnit::Minutes => Ok(Duration::minutes(amount)),
291 TimeUnit::Hours => Ok(Duration::hours(amount)),
292 TimeUnit::Days => Ok(Duration::days(amount)),
293 TimeUnit::Weeks => Ok(Duration::weeks(amount)),
294 TimeUnit::Months => Ok(Duration::days(amount * 30)), TimeUnit::Samples => Err(ShapeError::RuntimeError {
296 message: "Cannot convert samples to duration".into(),
297 location: None,
298 }),
299 }
300 }
301
302 fn find_row_at_or_after(target_time: DateTime<Utc>, ctx: &ExecutionContext) -> Result<usize> {
304 let row_count = ctx.row_count();
305 let target_ts = target_time.timestamp();
306
307 let mut left = 0;
309 let mut right = row_count;
310
311 while left < right {
312 let mid = left + (right - left) / 2;
313 let mid_time = ctx.get_row_timestamp(mid)?;
314
315 if mid_time < target_ts {
316 left = mid + 1;
317 } else {
318 right = mid;
319 }
320 }
321
322 Ok(left)
323 }
324
325 fn find_row_at_or_before(target_time: DateTime<Utc>, ctx: &ExecutionContext) -> Result<usize> {
327 let row_count = ctx.row_count();
328 if row_count == 0 {
329 return Err(ShapeError::DataError {
330 message: "No rows available".into(),
331 symbol: None,
332 timeframe: None,
333 });
334 }
335
336 let target_ts = target_time.timestamp();
337
338 let mut left = 0;
340 let mut right = row_count;
341
342 while left < right {
343 let mid = left + (right - left).div_ceil(2);
344 let mid_time = ctx.get_row_timestamp(mid - 1)?;
345
346 if mid_time <= target_ts {
347 left = mid;
348 } else {
349 right = mid - 1;
350 }
351 }
352
353 if left > 0 { Ok(left - 1) } else { Ok(0) }
354 }
355
356 fn parse_time_string(time_str: &str) -> Result<DateTime<Utc>> {
358 if let Ok(dt) = DateTime::parse_from_rfc3339(time_str) {
361 return Ok(dt.with_timezone(&Utc));
362 }
363
364 let formats = [
366 "%Y-%m-%d %H:%M:%S",
367 "%Y-%m-%d %H:%M",
368 "%Y-%m-%d",
369 "%Y/%m/%d %H:%M:%S",
370 "%Y/%m/%d %H:%M",
371 "%Y/%m/%d",
372 "%d-%m-%Y %H:%M:%S",
373 "%d-%m-%Y %H:%M",
374 "%d-%m-%Y",
375 ];
376
377 for format in &formats {
378 if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(time_str, format) {
379 return Ok(dt.and_utc());
380 }
381 if let Ok(date) = chrono::NaiveDate::parse_from_str(time_str, format) {
382 return Ok(date.and_hms_opt(0, 0, 0).unwrap().and_utc());
383 }
384 }
385
386 Err(ShapeError::RuntimeError {
387 message: format!("Unable to parse time string: {}", time_str),
388 location: None,
389 })
390 }
391}
392
393#[cfg(test)]
394mod tests {
395 use super::*;
396 use crate::context::ExecutionContext;
397 use crate::data::OwnedDataRow as RowValue;
398 use crate::data::Timeframe;
399 use chrono::TimeZone;
400
401 fn create_test_context() -> ExecutionContext {
402 let mut ctx = ExecutionContext::new_empty();
403
404 let base_time = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap();
406 let tf = Timeframe::d1();
407 let mut rows = Vec::new();
408
409 for i in 0..100 {
410 let mut fields = std::collections::HashMap::new();
411 fields.insert("open".to_string(), 100.0);
412 fields.insert("high".to_string(), 110.0);
413 fields.insert("low".to_string(), 90.0);
414 fields.insert("close".to_string(), 105.0);
415 fields.insert("volume".to_string(), 1000.0);
416 rows.push(RowValue::from_hashmap(
417 (base_time + Duration::days(i as i64)).timestamp(),
418 fields,
419 ));
420 }
421
422 ctx.set_reference_datetime(base_time);
423
424 let df = crate::data::DataFrame::from_rows("TEST", tf, rows);
426 ctx.update_data(&df);
427
428 let mut cache_data = std::collections::HashMap::new();
429 cache_data.insert(
430 crate::data::cache::CacheKey::new("TEST".to_string(), tf),
431 df,
432 );
433 ctx.data_cache = Some(crate::data::DataCache::from_test_data(cache_data));
434
435 ctx
436 }
437
438 #[test]
439 fn test_resolve_last_samples() {
440 let ctx = create_test_context();
441 let window = TimeWindow::Last {
442 amount: 10,
443 unit: TimeUnit::Samples,
444 };
445
446 let range = TimeWindowResolver::resolve_window(&window, &ctx).unwrap();
447 assert_eq!(range, 90..100);
448 }
449
450 #[test]
451 fn test_resolve_last_days() {
452 let ctx = create_test_context();
453 let window = TimeWindow::Last {
454 amount: 5,
455 unit: TimeUnit::Days,
456 };
457
458 let range = TimeWindowResolver::resolve_window(&window, &ctx).unwrap();
459 assert!(range.len() >= 5);
460 assert_eq!(range.end, 100);
461 }
462
463 #[test]
464 fn test_resolve_between() {
465 let ctx = create_test_context();
466 let start_str = "2024-01-02"; let end_str = "2024-01-05"; let window = TimeWindow::Between {
470 start: TimeReference::Absolute(start_str.to_string()),
471 end: TimeReference::Absolute(end_str.to_string()),
472 };
473
474 let range = TimeWindowResolver::resolve_window(&window, &ctx).unwrap();
475 assert_eq!(range, 1..5);
477 }
478
479 #[test]
480 fn test_resolve_between_invalid() {
481 let ctx = create_test_context();
482 let window = TimeWindow::Between {
483 start: TimeReference::Absolute("2024-02-01".to_string()),
484 end: TimeReference::Absolute("2024-01-01".to_string()),
485 };
486
487 assert!(TimeWindowResolver::resolve_window(&window, &ctx).is_err());
488 }
489}