sklears_preprocessing/temporal/
temporal_features.rs1use scirs2_core::ndarray::{Array1, Array2};
10use sklears_core::{
11 error::{Result, SklearsError},
12 traits::{Fit, Trained, Transform, Untrained},
13 types::Float,
14};
15use std::marker::PhantomData;
16
17use super::datetime_utils::{DateComponents, DateTime};
18
19#[derive(Debug, Clone)]
21pub struct TemporalFeatureExtractorConfig {
22 pub extract_year: bool,
24 pub extract_month: bool,
26 pub extract_day: bool,
28 pub extract_day_of_week: bool,
30 pub extract_hour: bool,
32 pub extract_minute: bool,
34 pub extract_second: bool,
36 pub extract_quarter: bool,
38 pub extract_day_of_year: bool,
40 pub extract_week_of_year: bool,
42 pub cyclical_encoding: bool,
44 pub include_holidays: bool,
46 pub include_business_days: bool,
48 pub timezone_offset: Option<Float>,
50}
51
52impl Default for TemporalFeatureExtractorConfig {
53 fn default() -> Self {
54 Self {
55 extract_year: true,
56 extract_month: true,
57 extract_day: true,
58 extract_day_of_week: true,
59 extract_hour: false,
60 extract_minute: false,
61 extract_second: false,
62 extract_quarter: true,
63 extract_day_of_year: false,
64 extract_week_of_year: false,
65 cyclical_encoding: true,
66 include_holidays: false,
67 include_business_days: false,
68 timezone_offset: None,
69 }
70 }
71}
72
73#[derive(Debug, Clone)]
75pub struct TemporalFeatureExtractor<S> {
76 config: TemporalFeatureExtractorConfig,
77 feature_names_: Option<Vec<String>>,
78 n_features_out_: Option<usize>,
79 _phantom: PhantomData<S>,
80}
81
82impl TemporalFeatureExtractor<Untrained> {
83 pub fn new() -> Self {
85 Self {
86 config: TemporalFeatureExtractorConfig::default(),
87 feature_names_: None,
88 n_features_out_: None,
89 _phantom: PhantomData,
90 }
91 }
92
93 pub fn extract_year(mut self, extract_year: bool) -> Self {
95 self.config.extract_year = extract_year;
96 self
97 }
98
99 pub fn extract_month(mut self, extract_month: bool) -> Self {
101 self.config.extract_month = extract_month;
102 self
103 }
104
105 pub fn extract_day(mut self, extract_day: bool) -> Self {
107 self.config.extract_day = extract_day;
108 self
109 }
110
111 pub fn extract_day_of_week(mut self, extract_day_of_week: bool) -> Self {
113 self.config.extract_day_of_week = extract_day_of_week;
114 self
115 }
116
117 pub fn extract_hour(mut self, extract_hour: bool) -> Self {
119 self.config.extract_hour = extract_hour;
120 self
121 }
122
123 pub fn extract_minute(mut self, extract_minute: bool) -> Self {
125 self.config.extract_minute = extract_minute;
126 self
127 }
128
129 pub fn extract_second(mut self, extract_second: bool) -> Self {
131 self.config.extract_second = extract_second;
132 self
133 }
134
135 pub fn extract_quarter(mut self, extract_quarter: bool) -> Self {
137 self.config.extract_quarter = extract_quarter;
138 self
139 }
140
141 pub fn extract_day_of_year(mut self, extract_day_of_year: bool) -> Self {
143 self.config.extract_day_of_year = extract_day_of_year;
144 self
145 }
146
147 pub fn extract_week_of_year(mut self, extract_week_of_year: bool) -> Self {
149 self.config.extract_week_of_year = extract_week_of_year;
150 self
151 }
152
153 pub fn cyclical_encoding(mut self, cyclical_encoding: bool) -> Self {
155 self.config.cyclical_encoding = cyclical_encoding;
156 self
157 }
158
159 pub fn include_holidays(mut self, include_holidays: bool) -> Self {
161 self.config.include_holidays = include_holidays;
162 self
163 }
164
165 pub fn include_business_days(mut self, include_business_days: bool) -> Self {
167 self.config.include_business_days = include_business_days;
168 self
169 }
170
171 pub fn timezone_offset(mut self, timezone_offset: Float) -> Self {
173 self.config.timezone_offset = Some(timezone_offset);
174 self
175 }
176
177 fn calculate_n_features_out(&self) -> usize {
179 let mut count = 0;
180
181 if self.config.extract_year {
182 count += 1;
183 }
184
185 if self.config.extract_month {
186 count += if self.config.cyclical_encoding { 2 } else { 1 };
187 }
188
189 if self.config.extract_day {
190 count += if self.config.cyclical_encoding { 2 } else { 1 };
191 }
192
193 if self.config.extract_day_of_week {
194 count += if self.config.cyclical_encoding { 2 } else { 1 };
195 }
196
197 if self.config.extract_hour {
198 count += if self.config.cyclical_encoding { 2 } else { 1 };
199 }
200
201 if self.config.extract_minute {
202 count += if self.config.cyclical_encoding { 2 } else { 1 };
203 }
204
205 if self.config.extract_second {
206 count += if self.config.cyclical_encoding { 2 } else { 1 };
207 }
208
209 if self.config.extract_quarter {
210 count += if self.config.cyclical_encoding { 2 } else { 1 };
211 }
212
213 if self.config.extract_day_of_year {
214 count += if self.config.cyclical_encoding { 2 } else { 1 };
215 }
216
217 if self.config.extract_week_of_year {
218 count += if self.config.cyclical_encoding { 2 } else { 1 };
219 }
220
221 if self.config.include_holidays {
222 count += 1;
223 }
224
225 if self.config.include_business_days {
226 count += 1;
227 }
228
229 count
230 }
231
232 fn generate_feature_names(&self) -> Vec<String> {
234 let mut names = Vec::new();
235
236 if self.config.extract_year {
237 names.push("year".to_string());
238 }
239
240 if self.config.extract_month {
241 if self.config.cyclical_encoding {
242 names.push("month_sin".to_string());
243 names.push("month_cos".to_string());
244 } else {
245 names.push("month".to_string());
246 }
247 }
248
249 if self.config.extract_day {
250 if self.config.cyclical_encoding {
251 names.push("day_sin".to_string());
252 names.push("day_cos".to_string());
253 } else {
254 names.push("day".to_string());
255 }
256 }
257
258 if self.config.extract_day_of_week {
259 if self.config.cyclical_encoding {
260 names.push("day_of_week_sin".to_string());
261 names.push("day_of_week_cos".to_string());
262 } else {
263 names.push("day_of_week".to_string());
264 }
265 }
266
267 if self.config.extract_hour {
268 if self.config.cyclical_encoding {
269 names.push("hour_sin".to_string());
270 names.push("hour_cos".to_string());
271 } else {
272 names.push("hour".to_string());
273 }
274 }
275
276 if self.config.extract_minute {
277 if self.config.cyclical_encoding {
278 names.push("minute_sin".to_string());
279 names.push("minute_cos".to_string());
280 } else {
281 names.push("minute".to_string());
282 }
283 }
284
285 if self.config.extract_second {
286 if self.config.cyclical_encoding {
287 names.push("second_sin".to_string());
288 names.push("second_cos".to_string());
289 } else {
290 names.push("second".to_string());
291 }
292 }
293
294 if self.config.extract_quarter {
295 if self.config.cyclical_encoding {
296 names.push("quarter_sin".to_string());
297 names.push("quarter_cos".to_string());
298 } else {
299 names.push("quarter".to_string());
300 }
301 }
302
303 if self.config.extract_day_of_year {
304 if self.config.cyclical_encoding {
305 names.push("day_of_year_sin".to_string());
306 names.push("day_of_year_cos".to_string());
307 } else {
308 names.push("day_of_year".to_string());
309 }
310 }
311
312 if self.config.extract_week_of_year {
313 if self.config.cyclical_encoding {
314 names.push("week_of_year_sin".to_string());
315 names.push("week_of_year_cos".to_string());
316 } else {
317 names.push("week_of_year".to_string());
318 }
319 }
320
321 if self.config.include_holidays {
322 names.push("is_holiday".to_string());
323 }
324
325 if self.config.include_business_days {
326 names.push("is_business_day".to_string());
327 }
328
329 names
330 }
331}
332
333impl TemporalFeatureExtractor<Trained> {
334 pub fn feature_names(&self) -> &[String] {
336 self.feature_names_
337 .as_ref()
338 .expect("Extractor should be fitted")
339 }
340
341 pub fn n_features_out(&self) -> usize {
343 self.n_features_out_.expect("Extractor should be fitted")
344 }
345
346 fn is_holiday(&self, components: &DateComponents) -> bool {
348 match (components.month, components.day) {
350 (1, 1) => true, (7, 4) => true, (12, 25) => true, _ => false,
354 }
355 }
356
357 fn is_business_day(&self, components: &DateComponents) -> bool {
359 let is_weekday = components.day_of_week < 5; let is_not_holiday = if self.config.include_holidays {
361 !self.is_holiday(components)
362 } else {
363 true
364 };
365 is_weekday && is_not_holiday
366 }
367
368 fn to_cyclical(&self, value: Float, period: Float) -> (Float, Float) {
370 let angle = 2.0 * std::f64::consts::PI * (value / period);
371 (angle.sin(), angle.cos())
372 }
373
374 fn extract_features_from_timestamp(&self, timestamp: Float) -> Array1<Float> {
376 let datetime = DateTime::from_timestamp(timestamp as i64);
377 let components = datetime.to_components(self.config.timezone_offset);
378
379 let mut features = Vec::new();
380
381 if self.config.extract_year {
382 features.push(components.year as Float);
383 }
384
385 if self.config.extract_month {
386 if self.config.cyclical_encoding {
387 let (sin, cos) = self.to_cyclical(components.month as Float, 12.0);
388 features.push(sin);
389 features.push(cos);
390 } else {
391 features.push(components.month as Float);
392 }
393 }
394
395 if self.config.extract_day {
396 if self.config.cyclical_encoding {
397 let (sin, cos) = self.to_cyclical(components.day as Float, 31.0);
398 features.push(sin);
399 features.push(cos);
400 } else {
401 features.push(components.day as Float);
402 }
403 }
404
405 if self.config.extract_day_of_week {
406 if self.config.cyclical_encoding {
407 let (sin, cos) = self.to_cyclical(components.day_of_week as Float, 7.0);
408 features.push(sin);
409 features.push(cos);
410 } else {
411 features.push(components.day_of_week as Float);
412 }
413 }
414
415 if self.config.extract_hour {
416 if self.config.cyclical_encoding {
417 let (sin, cos) = self.to_cyclical(components.hour as Float, 24.0);
418 features.push(sin);
419 features.push(cos);
420 } else {
421 features.push(components.hour as Float);
422 }
423 }
424
425 if self.config.extract_minute {
426 if self.config.cyclical_encoding {
427 let (sin, cos) = self.to_cyclical(components.minute as Float, 60.0);
428 features.push(sin);
429 features.push(cos);
430 } else {
431 features.push(components.minute as Float);
432 }
433 }
434
435 if self.config.extract_second {
436 if self.config.cyclical_encoding {
437 let (sin, cos) = self.to_cyclical(components.second as Float, 60.0);
438 features.push(sin);
439 features.push(cos);
440 } else {
441 features.push(components.second as Float);
442 }
443 }
444
445 if self.config.extract_quarter {
446 if self.config.cyclical_encoding {
447 let (sin, cos) = self.to_cyclical(components.quarter as Float, 4.0);
448 features.push(sin);
449 features.push(cos);
450 } else {
451 features.push(components.quarter as Float);
452 }
453 }
454
455 if self.config.extract_day_of_year {
456 if self.config.cyclical_encoding {
457 let (sin, cos) = self.to_cyclical(components.day_of_year as Float, 366.0);
458 features.push(sin);
459 features.push(cos);
460 } else {
461 features.push(components.day_of_year as Float);
462 }
463 }
464
465 if self.config.extract_week_of_year {
466 if self.config.cyclical_encoding {
467 let (sin, cos) = self.to_cyclical(components.week_of_year as Float, 53.0);
468 features.push(sin);
469 features.push(cos);
470 } else {
471 features.push(components.week_of_year as Float);
472 }
473 }
474
475 if self.config.include_holidays {
476 features.push(if self.is_holiday(&components) {
477 1.0
478 } else {
479 0.0
480 });
481 }
482
483 if self.config.include_business_days {
484 features.push(if self.is_business_day(&components) {
485 1.0
486 } else {
487 0.0
488 });
489 }
490
491 Array1::from_vec(features)
492 }
493}
494
495impl Default for TemporalFeatureExtractor<Untrained> {
496 fn default() -> Self {
497 Self::new()
498 }
499}
500
501impl Fit<Array1<Float>, ()> for TemporalFeatureExtractor<Untrained> {
502 type Fitted = TemporalFeatureExtractor<Trained>;
503
504 fn fit(self, _x: &Array1<Float>, _y: &()) -> Result<Self::Fitted> {
505 let n_features_out = self.calculate_n_features_out();
506 let feature_names = self.generate_feature_names();
507
508 if n_features_out == 0 {
509 return Err(SklearsError::InvalidParameter {
510 name: "feature_extraction".to_string(),
511 reason: "No features selected for extraction".to_string(),
512 });
513 }
514
515 Ok(TemporalFeatureExtractor {
516 config: self.config,
517 feature_names_: Some(feature_names),
518 n_features_out_: Some(n_features_out),
519 _phantom: PhantomData,
520 })
521 }
522}
523
524impl Transform<Array1<Float>, Array2<Float>> for TemporalFeatureExtractor<Trained> {
525 fn transform(&self, x: &Array1<Float>) -> Result<Array2<Float>> {
526 let n_samples = x.len();
527 let n_features_out = self.n_features_out();
528
529 let mut result = Array2::<Float>::zeros((n_samples, n_features_out));
530
531 for (i, ×tamp) in x.iter().enumerate() {
532 let features = self.extract_features_from_timestamp(timestamp);
533 for (j, &feature_value) in features.iter().enumerate() {
534 result[[i, j]] = feature_value;
535 }
536 }
537
538 Ok(result)
539 }
540}