shape_runtime/
lookahead_guard.rs1use chrono::{DateTime, Utc};
7use shape_ast::error::{Result, ShapeError};
8use std::sync::RwLock;
9
/// Policy that decides whether a data access counts as a look-ahead violation.
///
/// The mode is a plain, field-less tag, so it also derives `Eq` and `Hash`;
/// this lets it be used as a map/set key and in generic code bounded on `Eq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataAccessMode {
    /// No look-ahead checking: every access is considered allowed.
    Unrestricted,
    /// Accesses beyond the guard's current time (or at a positive row index)
    /// are treated as violations.
    Restricted,
    /// Handled with the same rule as `Restricted` by the checks in this file.
    // NOTE(review): any intended difference from `Restricted` is not visible
    // in this file - confirm against callers.
    ForwardOnly,
}
20
21#[derive(Debug)]
23pub struct LookAheadGuard {
24 mode: DataAccessMode,
25 current_time: RwLock<Option<DateTime<Utc>>>,
26 strict_mode: bool,
27 access_log: RwLock<Vec<DataAccess>>,
28}
29
30#[derive(Debug, Clone)]
32pub struct DataAccess {
33 pub timestamp: DateTime<Utc>,
34 pub accessed_time: DateTime<Utc>,
35 pub access_type: String,
36 pub allowed: bool,
37}
38
39impl LookAheadGuard {
40 pub fn new(mode: DataAccessMode, strict_mode: bool) -> Self {
41 Self {
42 mode,
43 current_time: RwLock::new(None),
44 strict_mode,
45 access_log: RwLock::new(Vec::new()),
46 }
47 }
48
49 pub fn set_current_time(&self, time: DateTime<Utc>) {
51 *self.current_time.write().unwrap() = Some(time);
52 }
53
54 pub fn check_access(&self, access_time: DateTime<Utc>, access_type: &str) -> Result<()> {
56 let current =
57 self.current_time
58 .read()
59 .unwrap()
60 .ok_or_else(|| ShapeError::RuntimeError {
61 message: "Current time not set in LookAheadGuard".to_string(),
62 location: None,
63 })?;
64
65 let allowed = match self.mode {
66 DataAccessMode::Unrestricted => true, DataAccessMode::Restricted | DataAccessMode::ForwardOnly => access_time <= current,
68 };
69
70 self.access_log.write().unwrap().push(DataAccess {
72 timestamp: current,
73 accessed_time: access_time,
74 access_type: access_type.to_string(),
75 allowed,
76 });
77
78 if !allowed {
79 if self.strict_mode {
80 return Err(ShapeError::RuntimeError {
81 message: format!(
82 "Future data access violation: Attempted to access data at {} while current time is {}",
83 access_time, current
84 ),
85 location: None,
86 });
87 } else {
88 eprintln!(
90 "WARNING: Future data access - accessing {} at current time {}",
91 access_time, current
92 );
93 }
94 }
95
96 Ok(())
97 }
98
99 pub fn check_row_index(&self, index: i32, _access_type: &str) -> Result<()> {
101 match self.mode {
102 DataAccessMode::Unrestricted => Ok(()), DataAccessMode::Restricted | DataAccessMode::ForwardOnly => {
104 if index > 0 {
105 let msg = format!(
106 "Future data access violation: Attempted to access data[{}] in restricted mode",
107 index
108 );
109
110 if self.strict_mode {
111 return Err(ShapeError::RuntimeError {
112 message: msg,
113 location: None,
114 });
115 } else {
116 eprintln!("WARNING: {}", msg);
117 }
118 }
119 Ok(())
120 }
121 }
122 }
123
124 pub fn get_access_log(&self) -> Vec<DataAccess> {
126 self.access_log.read().unwrap().clone()
127 }
128
129 pub fn clear_log(&self) {
131 self.access_log.write().unwrap().clear();
132 }
133
134 pub fn get_violation_summary(&self) -> LookAheadSummary {
136 let log = self.access_log.read().unwrap();
137 let violations: Vec<_> = log
138 .iter()
139 .filter(|access| !access.allowed)
140 .cloned()
141 .collect();
142
143 LookAheadSummary {
144 total_accesses: log.len(),
145 violations: violations.len(),
146 violation_details: violations,
147 }
148 }
149}
150
151impl Clone for LookAheadGuard {
152 fn clone(&self) -> Self {
153 Self {
154 mode: self.mode,
155 current_time: RwLock::new(*self.current_time.read().unwrap()),
156 strict_mode: self.strict_mode,
157 access_log: RwLock::new(self.access_log.read().unwrap().clone()),
158 }
159 }
160}
161
162#[derive(Debug, Clone)]
164pub struct LookAheadSummary {
165 pub total_accesses: usize,
166 pub violations: usize,
167 pub violation_details: Vec<DataAccess>,
168}
169
170impl LookAheadSummary {
171 pub fn print_report(&self) {
172 println!("=== Data Access Validation Report ===");
173 println!("Total data accesses: {}", self.total_accesses);
174 println!("Violations found: {}", self.violations);
175
176 if self.violations > 0 {
177 println!("\nViolation Details:");
178 for (i, violation) in self.violation_details.iter().enumerate() {
179 println!(
180 " {}. At {}: Tried to access {} (type: {})",
181 i + 1,
182 violation.timestamp.format("%Y-%m-%d %H:%M:%S"),
183 violation.accessed_time.format("%Y-%m-%d %H:%M:%S"),
184 violation.access_type
185 );
186 }
187 }
188 }
189}