Skip to main content

shape_runtime/data/
load_query.rs

1//! Generic load query for industry-agnostic data loading
2//!
3//! LoadQuery represents a data loading request with arbitrary parameters.
4//! It's provider-agnostic - different providers interpret params differently.
5
6use super::{DataQuery, Timeframe};
7use shape_ast::error::{Result, ShapeError};
8use shape_value::ValueWord;
9use std::collections::HashMap;
10
11/// Generic data load request (industry-agnostic)
12///
13/// Supports arbitrary key-value parameters for maximum flexibility.
14///
15/// # Examples
16///
17/// Finance:
18/// ```ignore
19/// LoadQuery {
20///     provider: Some("data"),
21///     params: { "symbol": "ES", "from": "2023-01-01", "to": "2023-12-31" },
22///     target_type: Some("Candle"),
23/// }
24/// ```
25///
26/// Weather:
27/// ```ignore
28/// LoadQuery {
29///     provider: Some("weather_api"),
30///     params: { "station": "LAX", "metric": "temperature", "interval": "hourly" },
31///     target_type: Some("WeatherReading"),
32/// }
33/// ```
34#[derive(Debug, Clone)]
35pub struct LoadQuery {
36    /// Provider name (e.g., "data", "api", "warehouse")
37    /// If None, uses default provider
38    pub provider: Option<String>,
39
40    /// Generic parameters (arbitrary key-value)
41    pub params: HashMap<String, ValueWord>,
42
43    /// Target type name for validation (e.g., "Candle", "TickData")
44    /// If specified, validates DataFrame has required columns
45    pub target_type: Option<String>,
46
47    /// Optional column mapping override
48    /// Maps: target_field → source_column
49    pub column_mapping: Option<HashMap<String, String>>,
50}
51
52impl LoadQuery {
53    /// Create a new empty load query
54    pub fn new() -> Self {
55        Self {
56            provider: None,
57            params: HashMap::new(),
58            target_type: None,
59            column_mapping: None,
60        }
61    }
62
63    /// Set provider name
64    pub fn with_provider(mut self, name: &str) -> Self {
65        self.provider = Some(name.to_string());
66        self
67    }
68
69    /// Add a parameter
70    pub fn with_param(mut self, key: &str, value: ValueWord) -> Self {
71        self.params.insert(key.to_string(), value);
72        self
73    }
74
75    /// Set target type for validation
76    pub fn with_type(mut self, type_name: &str) -> Self {
77        self.target_type = Some(type_name.to_string());
78        self
79    }
80
81    /// Set column mapping
82    pub fn with_column_mapping(mut self, mapping: HashMap<String, String>) -> Self {
83        self.column_mapping = Some(mapping);
84        self
85    }
86
87    /// Convert to provider-specific DataQuery
88    ///
89    /// Extracts common parameters and builds a DataQuery.
90    /// Provider-specific logic for parameter interpretation.
91    ///
92    /// # Errors
93    ///
94    /// Returns error if required parameters are missing.
95    pub fn to_data_query(&self) -> Result<DataQuery> {
96        // Extract symbol (required)
97        let symbol = self
98            .params
99            .get("symbol")
100            .and_then(|v| v.as_str().map(|s| s.to_string()))
101            .ok_or_else(|| ShapeError::RuntimeError {
102                message: "data query requires 'symbol' parameter".to_string(),
103                location: None,
104            })?;
105
106        // Extract timeframe (optional, defaults to 1m)
107        let timeframe = self
108            .params
109            .get("timeframe")
110            .and_then(|v| {
111                if let Some(tf) = v.as_timeframe() {
112                    return Some(*tf);
113                }
114                if let Some(duration) = v.as_duration() {
115                    // Convert duration to timeframe
116                    let value = duration.value;
117                    let unit = duration.unit.clone();
118                    use crate::data::TimeframeUnit;
119                    use shape_ast::ast::DurationUnit;
120
121                    if value <= 0.0 || value.fract() != 0.0 {
122                        return None;
123                    }
124
125                    let tf_value = value as u32;
126                    let tf_unit = match unit {
127                        DurationUnit::Seconds => TimeframeUnit::Second,
128                        DurationUnit::Minutes => TimeframeUnit::Minute,
129                        DurationUnit::Hours => TimeframeUnit::Hour,
130                        DurationUnit::Days => TimeframeUnit::Day,
131                        DurationUnit::Weeks => TimeframeUnit::Week,
132                        DurationUnit::Months => TimeframeUnit::Month,
133                        _ => return None,
134                    };
135
136                    return Some(Timeframe::new(tf_value, tf_unit));
137                }
138                None
139            })
140            .unwrap_or(Timeframe::m1());
141
142        let mut query = DataQuery::new(&symbol, timeframe);
143
144        // Extract date range (from/to or start/end)
145        let start_ts = self
146            .params
147            .get("from")
148            .or_else(|| self.params.get("start"))
149            .and_then(|v| self.value_to_timestamp(v));
150
151        let end_ts = self
152            .params
153            .get("to")
154            .or_else(|| self.params.get("end"))
155            .and_then(|v| self.value_to_timestamp(v));
156
157        if let (Some(start), Some(end)) = (start_ts, end_ts) {
158            query = query.range(start, end);
159        }
160
161        // Extract limit
162        if let Some(limit) = self
163            .params
164            .get("limit")
165            .and_then(|v| v.as_f64().filter(|n| *n > 0.0).map(|n| n as usize))
166        {
167            query = query.limit(limit);
168        }
169
170        Ok(query)
171    }
172
173    /// Helper to convert ValueWord value to Unix timestamp
174    fn value_to_timestamp(&self, value: &ValueWord) -> Option<i64> {
175        if let Some(dt) = value.as_time() {
176            return Some(dt.timestamp());
177        }
178        if let Some(s) = value.as_str() {
179            // Parse date string "YYYY-MM-DD"
180            use chrono::{DateTime, NaiveDate, Utc};
181
182            if let Ok(date) = NaiveDate::parse_from_str(s, "%Y-%m-%d") {
183                if let Some(dt) = date.and_hms_opt(0, 0, 0) {
184                    let utc_dt = DateTime::<Utc>::from_naive_utc_and_offset(dt, Utc);
185                    return Some(utc_dt.timestamp());
186                }
187            }
188            return None;
189        }
190        if let Some(n) = value.as_f64() {
191            return Some(n as i64); // Assume Unix timestamp
192        }
193        None
194    }
195}
196
197impl Default for LoadQuery {
198    fn default() -> Self {
199        Self::new()
200    }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::TimeframeUnit;
    use std::sync::Arc;

    /// Wraps a string slice in a string-valued `ValueWord`.
    fn text(s: &str) -> ValueWord {
        ValueWord::from_string(Arc::new(s.to_string()))
    }

    /// Builder methods store provider, target type and parameters.
    #[test]
    fn test_basic_query() {
        let query = LoadQuery::new()
            .with_provider("data")
            .with_param("symbol", text("ES"))
            .with_type("Candle");

        assert_eq!(query.provider.as_deref(), Some("data"));
        assert_eq!(query.target_type.as_deref(), Some("Candle"));
        assert!(query.params.contains_key("symbol"));
    }

    /// Symbol and an explicit timeframe are carried into the data query.
    #[test]
    fn test_to_data_query() {
        let five_minutes = Timeframe::new(5, TimeframeUnit::Minute);
        let query = LoadQuery::new()
            .with_param("symbol", text("AAPL"))
            .with_param("timeframe", ValueWord::from_timeframe(five_minutes));

        let data_query = query.to_data_query().unwrap();

        assert_eq!(data_query.id, "AAPL");
        assert_eq!(
            data_query.timeframe,
            Timeframe::new(5, TimeframeUnit::Minute)
        );
    }

    /// Time-valued "from"/"to" parameters produce a range on the data query.
    #[test]
    fn test_to_data_query_with_range() {
        use chrono::{DateTime, Utc};

        let parse_utc = |raw: &str| {
            DateTime::parse_from_rfc3339(raw)
                .unwrap()
                .with_timezone(&Utc)
        };
        let start_ts = parse_utc("2023-01-01T00:00:00Z");
        let end_ts = parse_utc("2023-12-31T23:59:59Z");

        let query = LoadQuery::new()
            .with_param("symbol", text("ES"))
            .with_param("from", ValueWord::from_time_utc(start_ts))
            .with_param("to", ValueWord::from_time_utc(end_ts));

        let data_query = query.to_data_query().unwrap();

        assert_eq!(data_query.id, "ES");
        assert!(data_query.start.is_some());
        assert!(data_query.end.is_some());
    }

    /// "YYYY-MM-DD" strings are accepted as range bounds.
    #[test]
    fn test_to_data_query_date_strings() {
        let query = LoadQuery::new()
            .with_param("symbol", text("ES"))
            .with_param("from", text("2023-01-01"))
            .with_param("to", text("2023-12-31"));

        let data_query = query.to_data_query().unwrap();

        assert_eq!(data_query.id, "ES");
        assert!(data_query.start.is_some());
        assert!(data_query.end.is_some());
    }

    /// A query without a symbol cannot be converted.
    #[test]
    fn test_to_data_query_missing_symbol() {
        let result = LoadQuery::new().to_data_query();
        assert!(result.is_err());
    }

    /// A positive numeric "limit" is forwarded to the data query.
    #[test]
    fn test_to_data_query_with_limit() {
        let query = LoadQuery::new()
            .with_param("symbol", text("AAPL"))
            .with_param("limit", ValueWord::from_f64(100.0));

        let data_query = query.to_data_query().unwrap();

        assert_eq!(data_query.limit, Some(100));
    }
}