sklears_preprocessing/temporal/
temporal_features.rs

1//! Temporal feature extraction from datetime data
2//!
3//! This module provides comprehensive date/time feature extraction capabilities including:
4//! - Date component extraction (year, month, day, hour, minute, second)
5//! - Cyclical feature encoding (sin/cos transformations for periodic features)
6//! - Holiday and business day indicators
7//! - Time-based feature generation with timezone support
8
9use scirs2_core::ndarray::{Array1, Array2};
10use sklears_core::{
11    error::{Result, SklearsError},
12    traits::{Fit, Trained, Transform, Untrained},
13    types::Float,
14};
15use std::marker::PhantomData;
16
17use super::datetime_utils::{DateComponents, DateTime};
18
19/// Configuration for TemporalFeatureExtractor
20#[derive(Debug, Clone)]
21pub struct TemporalFeatureExtractorConfig {
22    /// Whether to extract year
23    pub extract_year: bool,
24    /// Whether to extract month
25    pub extract_month: bool,
26    /// Whether to extract day of month
27    pub extract_day: bool,
28    /// Whether to extract day of week (0=Monday, 6=Sunday)
29    pub extract_day_of_week: bool,
30    /// Whether to extract hour
31    pub extract_hour: bool,
32    /// Whether to extract minute
33    pub extract_minute: bool,
34    /// Whether to extract second
35    pub extract_second: bool,
36    /// Whether to extract quarter
37    pub extract_quarter: bool,
38    /// Whether to extract day of year
39    pub extract_day_of_year: bool,
40    /// Whether to extract week of year
41    pub extract_week_of_year: bool,
42    /// Whether to use cyclical encoding for periodic features (sin/cos)
43    pub cyclical_encoding: bool,
44    /// Whether to include holiday indicators
45    pub include_holidays: bool,
46    /// Whether to include business day indicators
47    pub include_business_days: bool,
48    /// Time zone offset in hours from UTC (for timezone-aware processing)
49    pub timezone_offset: Option<Float>,
50}
51
52impl Default for TemporalFeatureExtractorConfig {
53    fn default() -> Self {
54        Self {
55            extract_year: true,
56            extract_month: true,
57            extract_day: true,
58            extract_day_of_week: true,
59            extract_hour: false,
60            extract_minute: false,
61            extract_second: false,
62            extract_quarter: true,
63            extract_day_of_year: false,
64            extract_week_of_year: false,
65            cyclical_encoding: true,
66            include_holidays: false,
67            include_business_days: false,
68            timezone_offset: None,
69        }
70    }
71}
72
73/// TemporalFeatureExtractor for extracting features from datetime data
74#[derive(Debug, Clone)]
75pub struct TemporalFeatureExtractor<S> {
76    config: TemporalFeatureExtractorConfig,
77    feature_names_: Option<Vec<String>>,
78    n_features_out_: Option<usize>,
79    _phantom: PhantomData<S>,
80}
81
82impl TemporalFeatureExtractor<Untrained> {
83    /// Create a new TemporalFeatureExtractor
84    pub fn new() -> Self {
85        Self {
86            config: TemporalFeatureExtractorConfig::default(),
87            feature_names_: None,
88            n_features_out_: None,
89            _phantom: PhantomData,
90        }
91    }
92
93    /// Set whether to extract year
94    pub fn extract_year(mut self, extract_year: bool) -> Self {
95        self.config.extract_year = extract_year;
96        self
97    }
98
99    /// Set whether to extract month
100    pub fn extract_month(mut self, extract_month: bool) -> Self {
101        self.config.extract_month = extract_month;
102        self
103    }
104
105    /// Set whether to extract day
106    pub fn extract_day(mut self, extract_day: bool) -> Self {
107        self.config.extract_day = extract_day;
108        self
109    }
110
111    /// Set whether to extract day of week
112    pub fn extract_day_of_week(mut self, extract_day_of_week: bool) -> Self {
113        self.config.extract_day_of_week = extract_day_of_week;
114        self
115    }
116
117    /// Set whether to extract hour
118    pub fn extract_hour(mut self, extract_hour: bool) -> Self {
119        self.config.extract_hour = extract_hour;
120        self
121    }
122
123    /// Set whether to extract minute
124    pub fn extract_minute(mut self, extract_minute: bool) -> Self {
125        self.config.extract_minute = extract_minute;
126        self
127    }
128
129    /// Set whether to extract second
130    pub fn extract_second(mut self, extract_second: bool) -> Self {
131        self.config.extract_second = extract_second;
132        self
133    }
134
135    /// Set whether to extract quarter
136    pub fn extract_quarter(mut self, extract_quarter: bool) -> Self {
137        self.config.extract_quarter = extract_quarter;
138        self
139    }
140
141    /// Set whether to extract day of year
142    pub fn extract_day_of_year(mut self, extract_day_of_year: bool) -> Self {
143        self.config.extract_day_of_year = extract_day_of_year;
144        self
145    }
146
147    /// Set whether to extract week of year
148    pub fn extract_week_of_year(mut self, extract_week_of_year: bool) -> Self {
149        self.config.extract_week_of_year = extract_week_of_year;
150        self
151    }
152
153    /// Set whether to use cyclical encoding
154    pub fn cyclical_encoding(mut self, cyclical_encoding: bool) -> Self {
155        self.config.cyclical_encoding = cyclical_encoding;
156        self
157    }
158
159    /// Set whether to include holiday indicators
160    pub fn include_holidays(mut self, include_holidays: bool) -> Self {
161        self.config.include_holidays = include_holidays;
162        self
163    }
164
165    /// Set whether to include business day indicators
166    pub fn include_business_days(mut self, include_business_days: bool) -> Self {
167        self.config.include_business_days = include_business_days;
168        self
169    }
170
171    /// Set timezone offset in hours
172    pub fn timezone_offset(mut self, timezone_offset: Float) -> Self {
173        self.config.timezone_offset = Some(timezone_offset);
174        self
175    }
176
177    /// Calculate the number of output features based on configuration
178    fn calculate_n_features_out(&self) -> usize {
179        let mut count = 0;
180
181        if self.config.extract_year {
182            count += 1;
183        }
184
185        if self.config.extract_month {
186            count += if self.config.cyclical_encoding { 2 } else { 1 };
187        }
188
189        if self.config.extract_day {
190            count += if self.config.cyclical_encoding { 2 } else { 1 };
191        }
192
193        if self.config.extract_day_of_week {
194            count += if self.config.cyclical_encoding { 2 } else { 1 };
195        }
196
197        if self.config.extract_hour {
198            count += if self.config.cyclical_encoding { 2 } else { 1 };
199        }
200
201        if self.config.extract_minute {
202            count += if self.config.cyclical_encoding { 2 } else { 1 };
203        }
204
205        if self.config.extract_second {
206            count += if self.config.cyclical_encoding { 2 } else { 1 };
207        }
208
209        if self.config.extract_quarter {
210            count += if self.config.cyclical_encoding { 2 } else { 1 };
211        }
212
213        if self.config.extract_day_of_year {
214            count += if self.config.cyclical_encoding { 2 } else { 1 };
215        }
216
217        if self.config.extract_week_of_year {
218            count += if self.config.cyclical_encoding { 2 } else { 1 };
219        }
220
221        if self.config.include_holidays {
222            count += 1;
223        }
224
225        if self.config.include_business_days {
226            count += 1;
227        }
228
229        count
230    }
231
232    /// Generate feature names based on configuration
233    fn generate_feature_names(&self) -> Vec<String> {
234        let mut names = Vec::new();
235
236        if self.config.extract_year {
237            names.push("year".to_string());
238        }
239
240        if self.config.extract_month {
241            if self.config.cyclical_encoding {
242                names.push("month_sin".to_string());
243                names.push("month_cos".to_string());
244            } else {
245                names.push("month".to_string());
246            }
247        }
248
249        if self.config.extract_day {
250            if self.config.cyclical_encoding {
251                names.push("day_sin".to_string());
252                names.push("day_cos".to_string());
253            } else {
254                names.push("day".to_string());
255            }
256        }
257
258        if self.config.extract_day_of_week {
259            if self.config.cyclical_encoding {
260                names.push("day_of_week_sin".to_string());
261                names.push("day_of_week_cos".to_string());
262            } else {
263                names.push("day_of_week".to_string());
264            }
265        }
266
267        if self.config.extract_hour {
268            if self.config.cyclical_encoding {
269                names.push("hour_sin".to_string());
270                names.push("hour_cos".to_string());
271            } else {
272                names.push("hour".to_string());
273            }
274        }
275
276        if self.config.extract_minute {
277            if self.config.cyclical_encoding {
278                names.push("minute_sin".to_string());
279                names.push("minute_cos".to_string());
280            } else {
281                names.push("minute".to_string());
282            }
283        }
284
285        if self.config.extract_second {
286            if self.config.cyclical_encoding {
287                names.push("second_sin".to_string());
288                names.push("second_cos".to_string());
289            } else {
290                names.push("second".to_string());
291            }
292        }
293
294        if self.config.extract_quarter {
295            if self.config.cyclical_encoding {
296                names.push("quarter_sin".to_string());
297                names.push("quarter_cos".to_string());
298            } else {
299                names.push("quarter".to_string());
300            }
301        }
302
303        if self.config.extract_day_of_year {
304            if self.config.cyclical_encoding {
305                names.push("day_of_year_sin".to_string());
306                names.push("day_of_year_cos".to_string());
307            } else {
308                names.push("day_of_year".to_string());
309            }
310        }
311
312        if self.config.extract_week_of_year {
313            if self.config.cyclical_encoding {
314                names.push("week_of_year_sin".to_string());
315                names.push("week_of_year_cos".to_string());
316            } else {
317                names.push("week_of_year".to_string());
318            }
319        }
320
321        if self.config.include_holidays {
322            names.push("is_holiday".to_string());
323        }
324
325        if self.config.include_business_days {
326            names.push("is_business_day".to_string());
327        }
328
329        names
330    }
331}
332
333impl TemporalFeatureExtractor<Trained> {
334    /// Get the feature names
335    pub fn feature_names(&self) -> &[String] {
336        self.feature_names_
337            .as_ref()
338            .expect("Extractor should be fitted")
339    }
340
341    /// Get the number of output features
342    pub fn n_features_out(&self) -> usize {
343        self.n_features_out_.expect("Extractor should be fitted")
344    }
345
346    /// Check if a date is a holiday (simplified implementation)
347    fn is_holiday(&self, components: &DateComponents) -> bool {
348        // Simplified holiday detection - only major US holidays
349        match (components.month, components.day) {
350            (1, 1) => true,   // New Year's Day
351            (7, 4) => true,   // Independence Day
352            (12, 25) => true, // Christmas
353            _ => false,
354        }
355    }
356
357    /// Check if a date is a business day (Monday-Friday, not holiday)
358    fn is_business_day(&self, components: &DateComponents) -> bool {
359        let is_weekday = components.day_of_week < 5; // Monday (0) to Friday (4)
360        let is_not_holiday = if self.config.include_holidays {
361            !self.is_holiday(components)
362        } else {
363            true
364        };
365        is_weekday && is_not_holiday
366    }
367
368    /// Convert periodic value to cyclical encoding (sin/cos)
369    fn to_cyclical(&self, value: Float, period: Float) -> (Float, Float) {
370        let angle = 2.0 * std::f64::consts::PI * (value / period);
371        (angle.sin(), angle.cos())
372    }
373
374    /// Extract features from a single timestamp
375    fn extract_features_from_timestamp(&self, timestamp: Float) -> Array1<Float> {
376        let datetime = DateTime::from_timestamp(timestamp as i64);
377        let components = datetime.to_components(self.config.timezone_offset);
378
379        let mut features = Vec::new();
380
381        if self.config.extract_year {
382            features.push(components.year as Float);
383        }
384
385        if self.config.extract_month {
386            if self.config.cyclical_encoding {
387                let (sin, cos) = self.to_cyclical(components.month as Float, 12.0);
388                features.push(sin);
389                features.push(cos);
390            } else {
391                features.push(components.month as Float);
392            }
393        }
394
395        if self.config.extract_day {
396            if self.config.cyclical_encoding {
397                let (sin, cos) = self.to_cyclical(components.day as Float, 31.0);
398                features.push(sin);
399                features.push(cos);
400            } else {
401                features.push(components.day as Float);
402            }
403        }
404
405        if self.config.extract_day_of_week {
406            if self.config.cyclical_encoding {
407                let (sin, cos) = self.to_cyclical(components.day_of_week as Float, 7.0);
408                features.push(sin);
409                features.push(cos);
410            } else {
411                features.push(components.day_of_week as Float);
412            }
413        }
414
415        if self.config.extract_hour {
416            if self.config.cyclical_encoding {
417                let (sin, cos) = self.to_cyclical(components.hour as Float, 24.0);
418                features.push(sin);
419                features.push(cos);
420            } else {
421                features.push(components.hour as Float);
422            }
423        }
424
425        if self.config.extract_minute {
426            if self.config.cyclical_encoding {
427                let (sin, cos) = self.to_cyclical(components.minute as Float, 60.0);
428                features.push(sin);
429                features.push(cos);
430            } else {
431                features.push(components.minute as Float);
432            }
433        }
434
435        if self.config.extract_second {
436            if self.config.cyclical_encoding {
437                let (sin, cos) = self.to_cyclical(components.second as Float, 60.0);
438                features.push(sin);
439                features.push(cos);
440            } else {
441                features.push(components.second as Float);
442            }
443        }
444
445        if self.config.extract_quarter {
446            if self.config.cyclical_encoding {
447                let (sin, cos) = self.to_cyclical(components.quarter as Float, 4.0);
448                features.push(sin);
449                features.push(cos);
450            } else {
451                features.push(components.quarter as Float);
452            }
453        }
454
455        if self.config.extract_day_of_year {
456            if self.config.cyclical_encoding {
457                let (sin, cos) = self.to_cyclical(components.day_of_year as Float, 366.0);
458                features.push(sin);
459                features.push(cos);
460            } else {
461                features.push(components.day_of_year as Float);
462            }
463        }
464
465        if self.config.extract_week_of_year {
466            if self.config.cyclical_encoding {
467                let (sin, cos) = self.to_cyclical(components.week_of_year as Float, 53.0);
468                features.push(sin);
469                features.push(cos);
470            } else {
471                features.push(components.week_of_year as Float);
472            }
473        }
474
475        if self.config.include_holidays {
476            features.push(if self.is_holiday(&components) {
477                1.0
478            } else {
479                0.0
480            });
481        }
482
483        if self.config.include_business_days {
484            features.push(if self.is_business_day(&components) {
485                1.0
486            } else {
487                0.0
488            });
489        }
490
491        Array1::from_vec(features)
492    }
493}
494
495impl Default for TemporalFeatureExtractor<Untrained> {
496    fn default() -> Self {
497        Self::new()
498    }
499}
500
501impl Fit<Array1<Float>, ()> for TemporalFeatureExtractor<Untrained> {
502    type Fitted = TemporalFeatureExtractor<Trained>;
503
504    fn fit(self, _x: &Array1<Float>, _y: &()) -> Result<Self::Fitted> {
505        let n_features_out = self.calculate_n_features_out();
506        let feature_names = self.generate_feature_names();
507
508        if n_features_out == 0 {
509            return Err(SklearsError::InvalidParameter {
510                name: "feature_extraction".to_string(),
511                reason: "No features selected for extraction".to_string(),
512            });
513        }
514
515        Ok(TemporalFeatureExtractor {
516            config: self.config,
517            feature_names_: Some(feature_names),
518            n_features_out_: Some(n_features_out),
519            _phantom: PhantomData,
520        })
521    }
522}
523
524impl Transform<Array1<Float>, Array2<Float>> for TemporalFeatureExtractor<Trained> {
525    fn transform(&self, x: &Array1<Float>) -> Result<Array2<Float>> {
526        let n_samples = x.len();
527        let n_features_out = self.n_features_out();
528
529        let mut result = Array2::<Float>::zeros((n_samples, n_features_out));
530
531        for (i, &timestamp) in x.iter().enumerate() {
532            let features = self.extract_features_from_timestamp(timestamp);
533            for (j, &feature_value) in features.iter().enumerate() {
534                result[[i, j]] = feature_value;
535            }
536        }
537
538        Ok(result)
539    }
540}