Skip to main content

shape_runtime/
lookahead_guard.rs

//! Data access validation for time-series processing
//!
//! This module provides runtime guards to prevent accessing future data
//! during restricted execution modes (useful for any time-series domain).
5
6use chrono::{DateTime, Utc};
7use shape_ast::error::{Result, ShapeError};
8use std::sync::RwLock;
9
/// Data access mode - controls which data can be accessed during evaluation.
///
/// The type is a plain field-less enum, so it is `Copy` and supports full
/// equality (`Eq`) and hashing (`Hash`), which lets it be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataAccessMode {
    /// Unrestricted - can access all data (historical analysis).
    Unrestricted,
    /// Restricted - no future data access (simulation, backtesting, validation).
    Restricted,
    /// Forward-only - only current and future data (real-time streaming).
    ///
    /// NOTE(review): `LookAheadGuard::check_access` and
    /// `LookAheadGuard::check_row_index` currently treat this mode exactly
    /// like `Restricted` (future access is rejected); confirm which semantics
    /// are intended.
    ForwardOnly,
}
20
21/// Guards against accessing future data
22#[derive(Debug)]
23pub struct LookAheadGuard {
24    mode: DataAccessMode,
25    current_time: RwLock<Option<DateTime<Utc>>>,
26    strict_mode: bool,
27    access_log: RwLock<Vec<DataAccess>>,
28}
29
30/// Record of data access for auditing
31#[derive(Debug, Clone)]
32pub struct DataAccess {
33    pub timestamp: DateTime<Utc>,
34    pub accessed_time: DateTime<Utc>,
35    pub access_type: String,
36    pub allowed: bool,
37}
38
39impl LookAheadGuard {
40    pub fn new(mode: DataAccessMode, strict_mode: bool) -> Self {
41        Self {
42            mode,
43            current_time: RwLock::new(None),
44            strict_mode,
45            access_log: RwLock::new(Vec::new()),
46        }
47    }
48
49    /// Set the current processing time
50    pub fn set_current_time(&self, time: DateTime<Utc>) {
51        *self.current_time.write().unwrap() = Some(time);
52    }
53
54    /// Check if accessing data at a specific time is allowed
55    pub fn check_access(&self, access_time: DateTime<Utc>, access_type: &str) -> Result<()> {
56        let current =
57            self.current_time
58                .read()
59                .unwrap()
60                .ok_or_else(|| ShapeError::RuntimeError {
61                    message: "Current time not set in LookAheadGuard".to_string(),
62                    location: None,
63                })?;
64
65        let allowed = match self.mode {
66            DataAccessMode::Unrestricted => true, // Can access any data
67            DataAccessMode::Restricted | DataAccessMode::ForwardOnly => access_time <= current,
68        };
69
70        // Log the access
71        self.access_log.write().unwrap().push(DataAccess {
72            timestamp: current,
73            accessed_time: access_time,
74            access_type: access_type.to_string(),
75            allowed,
76        });
77
78        if !allowed {
79            if self.strict_mode {
80                return Err(ShapeError::RuntimeError {
81                    message: format!(
82                        "Future data access violation: Attempted to access data at {} while current time is {}",
83                        access_time, current
84                    ),
85                    location: None,
86                });
87            } else {
88                // In non-strict mode, log warning but continue
89                eprintln!(
90                    "WARNING: Future data access - accessing {} at current time {}",
91                    access_time, current
92                );
93            }
94        }
95
96        Ok(())
97    }
98
99    /// Check if accessing a row index is allowed
100    pub fn check_row_index(&self, index: i32, _access_type: &str) -> Result<()> {
101        match self.mode {
102            DataAccessMode::Unrestricted => Ok(()), // Can access any index
103            DataAccessMode::Restricted | DataAccessMode::ForwardOnly => {
104                if index > 0 {
105                    let msg = format!(
106                        "Future data access violation: Attempted to access data[{}] in restricted mode",
107                        index
108                    );
109
110                    if self.strict_mode {
111                        return Err(ShapeError::RuntimeError {
112                            message: msg,
113                            location: None,
114                        });
115                    } else {
116                        eprintln!("WARNING: {}", msg);
117                    }
118                }
119                Ok(())
120            }
121        }
122    }
123
124    /// Get access log for auditing
125    pub fn get_access_log(&self) -> Vec<DataAccess> {
126        self.access_log.read().unwrap().clone()
127    }
128
129    /// Clear access log
130    pub fn clear_log(&self) {
131        self.access_log.write().unwrap().clear();
132    }
133
134    /// Get summary of violations
135    pub fn get_violation_summary(&self) -> LookAheadSummary {
136        let log = self.access_log.read().unwrap();
137        let violations: Vec<_> = log
138            .iter()
139            .filter(|access| !access.allowed)
140            .cloned()
141            .collect();
142
143        LookAheadSummary {
144            total_accesses: log.len(),
145            violations: violations.len(),
146            violation_details: violations,
147        }
148    }
149}
150
151impl Clone for LookAheadGuard {
152    fn clone(&self) -> Self {
153        Self {
154            mode: self.mode,
155            current_time: RwLock::new(*self.current_time.read().unwrap()),
156            strict_mode: self.strict_mode,
157            access_log: RwLock::new(self.access_log.read().unwrap().clone()),
158        }
159    }
160}
161
162/// Summary of look-ahead violations
163#[derive(Debug, Clone)]
164pub struct LookAheadSummary {
165    pub total_accesses: usize,
166    pub violations: usize,
167    pub violation_details: Vec<DataAccess>,
168}
169
170impl LookAheadSummary {
171    pub fn print_report(&self) {
172        println!("=== Data Access Validation Report ===");
173        println!("Total data accesses: {}", self.total_accesses);
174        println!("Violations found: {}", self.violations);
175
176        if self.violations > 0 {
177            println!("\nViolation Details:");
178            for (i, violation) in self.violation_details.iter().enumerate() {
179                println!(
180                    "  {}. At {}: Tried to access {} (type: {})",
181                    i + 1,
182                    violation.timestamp.format("%Y-%m-%d %H:%M:%S"),
183                    violation.accessed_time.format("%Y-%m-%d %H:%M:%S"),
184                    violation.access_type
185                );
186            }
187        }
188    }
189}