Skip to main content

shape_runtime/
query_executor.rs

//! High-level query execution API for Shape.
//!
//! This module provides the main interface for executing Shape queries
//! against data and generating results with statistics.
5
6use chrono::{DateTime, Datelike, Timelike, Utc};
7use serde::{Deserialize, Serialize};
8use shape_ast::error::{Result, ResultExt, ShapeError};
9use std::collections::HashMap;
10
11use crate::data::DataFrame;
12use crate::{QueryResult as RuntimeQueryResult, Runtime};
13use shape_ast::parser;
14
15/// Main query executor that orchestrates the entire Shape pipeline
16pub struct QueryExecutor {
17    runtime: Runtime,
18}
19
20/// Result of executing a Shape query
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct QueryResult {
23    /// The original query string
24    pub query: String,
25
26    /// Type of query executed
27    pub query_type: QueryType,
28
29    /// Pattern matches found
30    pub matches: Vec<PatternMatch>,
31
32    /// Statistics about the results
33    pub statistics: QueryStatistics,
34
35    /// Execution metadata
36    pub metadata: ExecutionMetadata,
37}
38
39/// Types of Shape queries
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub enum QueryType {
42    Find,
43    Scan,
44    Analyze,
45    Alert,
46}
47
48/// A single pattern match result
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct PatternMatch {
51    /// Pattern name that matched
52    pub pattern_name: String,
53
54    /// ID (if applicable)
55    pub id: Option<String>,
56
57    /// Time when pattern was found
58    pub timestamp: DateTime<Utc>,
59
60    /// Row index where pattern starts
61    pub row_index: usize,
62
63    /// Number of elements in the pattern
64    pub pattern_length: usize,
65
66    /// Match confidence (0.0 to 1.0)
67    pub confidence: f64,
68
69    /// Additional pattern-specific data
70    pub attributes: serde_json::Value,
71}
72
73/// Statistics about query results
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct QueryStatistics {
76    /// Total number of matches
77    pub total_matches: usize,
78
79    /// Number of unique patterns found
80    pub unique_patterns: usize,
81
82    /// Time range analyzed
83    pub time_range: TimeRange,
84
85    /// Generic performance metrics
86    pub performance: PerformanceMetrics,
87
88    /// Pattern frequency
89    pub pattern_frequency: HashMap<String, usize>,
90
91    /// Time distribution of matches
92    pub time_distribution: TimeDistribution,
93}
94
95/// Time range information
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct TimeRange {
98    pub start: DateTime<Utc>,
99    pub end: DateTime<Utc>,
100    pub row_count: usize,
101}
102
103/// Generic metrics for pattern matches
104#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct PerformanceMetrics {
106    /// Average confidence of matches
107    pub avg_confidence: f64,
108
109    /// Success rate (confidence > threshold)
110    pub success_rate: f64,
111
112    /// Average duration in elements
113    pub avg_duration: f64,
114}
115
116/// Time distribution of pattern matches
117#[derive(Debug, Clone, Serialize, Deserialize)]
118pub struct TimeDistribution {
119    /// Matches by hour of day
120    pub hourly: HashMap<u32, usize>,
121
122    /// Matches by day of week
123    pub daily: HashMap<String, usize>,
124
125    /// Matches by month
126    pub monthly: HashMap<String, usize>,
127}
128
129/// Metadata about query execution
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct ExecutionMetadata {
132    /// When the query was executed
133    pub executed_at: DateTime<Utc>,
134
135    /// Execution time in milliseconds
136    pub execution_time_ms: u64,
137
138    /// Number of rows processed
139    pub rows_processed: usize,
140
141    /// Any warnings during execution
142    pub warnings: Vec<String>,
143}
144
145impl QueryExecutor {
146    /// Create a new query executor
147    pub fn new() -> Self {
148        Self {
149            runtime: Runtime::new(),
150        }
151    }
152
153    /// Execute a Shape query against data
154    pub fn execute(&mut self, query: &str, data: &DataFrame) -> Result<QueryResult> {
155        let start_time = std::time::Instant::now();
156        let executed_at = Utc::now();
157
158        // Parse the query
159        let program = parser::parse_program(query).with_context("Failed to parse Shape query")?;
160
161        // Load the program first
162        self.runtime
163            .load_program(&program, data)
164            .with_context("Failed to load program")?;
165
166        // Find and execute the first query item
167        let query_item = program
168            .items
169            .iter()
170            .find(|item| matches!(item, shape_ast::ast::Item::Query(_, _)))
171            .ok_or_else(|| ShapeError::RuntimeError {
172                message: "No query found in program".to_string(),
173                location: None,
174            })?;
175
176        let runtime_result = self
177            .runtime
178            .execute_query(query_item, data)
179            .with_context("Query execution failed")?;
180
181        // Convert runtime results to our result format
182        let query_result = self.build_query_result(
183            query,
184            runtime_result,
185            data,
186            executed_at,
187            start_time.elapsed(),
188        )?;
189
190        Ok(query_result)
191    }
192
193    /// Execute a query and return results in JSON format
194    pub fn execute_json(&mut self, query: &str, data: &DataFrame) -> Result<String> {
195        let result = self.execute(query, data)?;
196        let json = serde_json::to_string_pretty(&result).map_err(|e| ShapeError::RuntimeError {
197            message: format!("Failed to serialize result to JSON: {}", e),
198            location: None,
199        })?;
200        Ok(json)
201    }
202
203    /// Build the final query result from runtime results
204    fn build_query_result(
205        &self,
206        query: &str,
207        runtime_result: RuntimeQueryResult,
208        data: &DataFrame,
209        executed_at: DateTime<Utc>,
210        elapsed: std::time::Duration,
211    ) -> Result<QueryResult> {
212        // Extract matches from runtime result
213        let matches = self.extract_matches(&runtime_result, data)?;
214
215        // Calculate statistics
216        let statistics = self.calculate_statistics(&matches, data)?;
217
218        // Determine query type
219        let query_type = self.determine_query_type(query)?;
220
221        // Build metadata
222        let metadata = ExecutionMetadata {
223            executed_at,
224            execution_time_ms: elapsed.as_millis() as u64,
225            rows_processed: data.row_count(),
226            warnings: Vec::new(),
227        };
228
229        Ok(QueryResult {
230            query: query.to_string(),
231            query_type,
232            matches,
233            statistics,
234            metadata,
235        })
236    }
237
238    /// Extract pattern matches from runtime results
239    fn extract_matches(
240        &self,
241        runtime_result: &RuntimeQueryResult,
242        _data: &DataFrame,
243    ) -> Result<Vec<PatternMatch>> {
244        let mut matches = Vec::new();
245
246        if let Some(runtime_matches) = &runtime_result.matches {
247            for pm in runtime_matches {
248                matches.push(PatternMatch {
249                    pattern_name: pm.pattern_name.clone(),
250                    id: Some(pm.id.clone()),
251                    timestamp: pm.timestamp,
252                    row_index: pm.index,
253                    pattern_length: 1, // Default
254                    confidence: pm.confidence,
255                    attributes: pm.metadata.clone(),
256                });
257            }
258        }
259
260        Ok(matches)
261    }
262
263    /// Calculate statistics from matches
264    fn calculate_statistics(
265        &self,
266        matches: &[PatternMatch],
267        data: &DataFrame,
268    ) -> Result<QueryStatistics> {
269        // Calculate time range
270        let time_range = self.calculate_time_range(data)?;
271
272        // Calculate performance metrics
273        let performance = self.calculate_performance_metrics(matches)?;
274
275        // Calculate pattern frequency
276        let pattern_frequency = self.calculate_pattern_frequency(matches);
277
278        // Calculate time distribution
279        let time_distribution = self.calculate_time_distribution(matches)?;
280
281        Ok(QueryStatistics {
282            total_matches: matches.len(),
283            unique_patterns: pattern_frequency.len(),
284            time_range,
285            performance,
286            pattern_frequency,
287            time_distribution,
288        })
289    }
290
291    /// Calculate time range of data
292    fn calculate_time_range(&self, data: &DataFrame) -> Result<TimeRange> {
293        if data.is_empty() {
294            return Err(ShapeError::DataError {
295                message: "No rows in data".to_string(),
296                symbol: None,
297                timeframe: None,
298            });
299        }
300
301        let start_ts = data.get_timestamp(0).unwrap();
302        let last_ts = data.get_timestamp(data.row_count() - 1).unwrap();
303
304        Ok(TimeRange {
305            start: DateTime::from_timestamp(start_ts, 0).unwrap_or_else(Utc::now),
306            end: DateTime::from_timestamp(last_ts, 0).unwrap_or_else(Utc::now),
307            row_count: data.row_count(),
308        })
309    }
310
311    /// Calculate metrics from matches
312    fn calculate_performance_metrics(
313        &self,
314        matches: &[PatternMatch],
315    ) -> Result<PerformanceMetrics> {
316        if matches.is_empty() {
317            return Ok(PerformanceMetrics {
318                avg_confidence: 0.0,
319                success_rate: 0.0,
320                avg_duration: 0.0,
321            });
322        }
323
324        let mut confidences = Vec::new();
325        let mut successes = 0;
326        let mut durations = Vec::new();
327
328        for pattern_match in matches {
329            confidences.push(pattern_match.confidence);
330            if pattern_match.confidence > 0.5 {
331                successes += 1;
332            }
333            durations.push(pattern_match.pattern_length as f64);
334        }
335
336        let avg_confidence = confidences.iter().sum::<f64>() / confidences.len() as f64;
337        let success_rate = successes as f64 / matches.len() as f64;
338        let avg_duration = durations.iter().sum::<f64>() / durations.len() as f64;
339
340        Ok(PerformanceMetrics {
341            avg_confidence,
342            success_rate,
343            avg_duration,
344        })
345    }
346
347    /// Calculate pattern frequency
348    fn calculate_pattern_frequency(&self, matches: &[PatternMatch]) -> HashMap<String, usize> {
349        let mut frequency = HashMap::new();
350        for m in matches {
351            *frequency.entry(m.pattern_name.clone()).or_insert(0) += 1;
352        }
353        frequency
354    }
355
356    /// Calculate time distribution of matches
357    fn calculate_time_distribution(&self, matches: &[PatternMatch]) -> Result<TimeDistribution> {
358        let mut hourly = HashMap::new();
359        let mut daily = HashMap::new();
360        let mut monthly = HashMap::new();
361
362        for m in matches {
363            *hourly.entry(m.timestamp.hour()).or_insert(0) += 1;
364            *daily.entry(m.timestamp.weekday().to_string()).or_insert(0) += 1;
365            *monthly.entry(m.timestamp.month().to_string()).or_insert(0) += 1;
366        }
367
368        Ok(TimeDistribution {
369            hourly,
370            daily,
371            monthly,
372        })
373    }
374
375    /// Determine query type from query string
376    fn determine_query_type(&self, query: &str) -> Result<QueryType> {
377        let query_lower = query.to_lowercase();
378        if query_lower.contains("find") {
379            Ok(QueryType::Find)
380        } else if query_lower.contains("scan") {
381            Ok(QueryType::Scan)
382        } else if query_lower.contains("analyze") {
383            Ok(QueryType::Analyze)
384        } else if query_lower.contains("alert") {
385            Ok(QueryType::Alert)
386        } else {
387            Ok(QueryType::Find) // Default
388        }
389    }
390}
391
392impl Default for QueryExecutor {
393    fn default() -> Self {
394        Self::new()
395    }
396}