vegafusion_runtime/transform/
timeunit.rs

1use crate::expression::compiler::config::CompilationConfig;
2use crate::transform::TransformTrait;
3use async_trait::async_trait;
4use datafusion::prelude::DataFrame;
5use datafusion_common::DFSchema;
6use datafusion_functions::expr_fn::{date_part, date_trunc};
7use std::collections::HashSet;
8use std::ops::{Add, Mul, Rem, Sub};
9use vegafusion_common::arrow::datatypes::{DataType, TimeUnit as ArrowTimeUnit};
10use vegafusion_core::error::{Result, ResultWithContext, VegaFusionError};
11use vegafusion_core::proto::gen::transforms::{TimeUnit, TimeUnitTimeZone, TimeUnitUnit};
12use vegafusion_core::task_graph::task_value::TaskValue;
13
14use crate::datafusion::udfs::datetime::make_timestamptz::make_timestamptz;
15use crate::datafusion::udfs::datetime::timeunit::TIMEUNIT_START_UDF;
16use crate::expression::compiler::utils::ExprHelpers;
17use crate::transform::utils::{from_epoch_millis, str_to_timestamp};
18use datafusion_expr::{interval_datetime_lit, interval_year_month_lit, lit, Expr, ExprSchemable};
19use itertools::Itertools;
20use vegafusion_common::column::{flat_col, unescaped_col};
21use vegafusion_common::datatypes::{cast_to, is_numeric_datatype};
22
23/// Implementation of timeunit start using the SQL date_trunc function
24fn timeunit_date_trunc(
25    field: &str,
26    smallest_unit: TimeUnitUnit,
27    schema: &DFSchema,
28    default_input_tz: &str,
29    tz: &str,
30) -> Result<(Expr, Expr)> {
31    // Convert field to timestamp in target timezone
32    let field_col = to_timestamp_col(unescaped_col(field), schema, default_input_tz)?.try_cast_to(
33        &DataType::Timestamp(ArrowTimeUnit::Millisecond, Some(tz.into())),
34        schema,
35    )?;
36
37    // Handle Sunday-based weeks as special case
38    if let TimeUnitUnit::Week = smallest_unit {
39        let day_interval = interval_datetime_lit("1 day");
40        let trunc_expr =
41            date_trunc(lit("week"), field_col.add(day_interval.clone())).sub(day_interval);
42        let interval = interval_datetime_lit("7 day");
43        return Ok((trunc_expr, interval));
44    }
45
46    // Handle uniform case
47    let (part_str, interval_expr) = match smallest_unit {
48        TimeUnitUnit::Year => ("year", interval_year_month_lit("1 year")),
49        TimeUnitUnit::Quarter => ("quarter", interval_year_month_lit("3 month")),
50        TimeUnitUnit::Month => ("month", interval_year_month_lit("1 month")),
51        TimeUnitUnit::Date => ("day", interval_datetime_lit("1 day")),
52        TimeUnitUnit::Hours => ("hour", interval_datetime_lit("1 hour")),
53        TimeUnitUnit::Minutes => ("minute", interval_datetime_lit("1 minute")),
54        TimeUnitUnit::Seconds => ("second", interval_datetime_lit("1 second")),
55        TimeUnitUnit::Milliseconds => ("millisecond", interval_datetime_lit("1 millisecond")),
56        _ => {
57            return Err(VegaFusionError::internal(format!(
58                "Unsupported date trunc unit: {smallest_unit:?}"
59            )))
60        }
61    };
62
63    // date_trunc after converting to the required timezone (will be the local_tz or UTC)
64    let trunc_expr = date_trunc(lit(part_str), field_col);
65
66    Ok((trunc_expr, interval_expr))
67}
68
69/// Implementation of timeunit start using make_timestamptz and the date_part functions
70fn timeunit_date_part_tz(
71    field: &str,
72    units_set: &HashSet<TimeUnitUnit>,
73    schema: &DFSchema,
74    default_input_tz: &str,
75    tz: &str,
76) -> Result<(Expr, Expr)> {
77    let mut year_arg = lit(2012);
78    let mut month_arg = lit(1);
79    let mut date_arg = lit(1);
80    let mut hour_arg = lit(0);
81    let mut minute_arg = lit(0);
82    let mut second_arg = lit(0);
83    let mut millisecond_arg = lit(0);
84
85    // Initialize interval string, this will be overwritten with the smallest specified unit
86    let mut interval = interval_year_month_lit("1 year");
87
88    // Convert field column to timestamp
89    let field_col = to_timestamp_col(unescaped_col(field), schema, default_input_tz)?.try_cast_to(
90        &DataType::Timestamp(ArrowTimeUnit::Millisecond, Some(tz.into())),
91        schema,
92    )?;
93
94    // Year
95    if units_set.contains(&TimeUnitUnit::Year) {
96        year_arg = date_part(lit("year"), field_col.clone());
97        interval = interval_year_month_lit("1 year");
98    }
99
100    // Quarter
101    if units_set.contains(&TimeUnitUnit::Quarter) {
102        // Compute month (1-based) from the extracted quarter (1-based)
103        let month_from_quarter = date_part(lit("quarter"), field_col.clone())
104            .sub(lit(1))
105            .mul(lit(3))
106            .add(lit(1));
107
108        month_arg = month_from_quarter;
109        interval = interval_year_month_lit("3 month");
110    }
111
112    // Month
113    if units_set.contains(&TimeUnitUnit::Month) {
114        month_arg = date_part(lit("month"), field_col.clone());
115        interval = interval_year_month_lit("1 month");
116    }
117
118    // Date
119    if units_set.contains(&TimeUnitUnit::Date) {
120        date_arg = date_part(lit("day"), field_col.clone());
121        interval = interval_datetime_lit("1 day");
122    }
123
124    // Hour
125    if units_set.contains(&TimeUnitUnit::Hours) {
126        hour_arg = date_part(lit("hour"), field_col.clone());
127        interval = interval_datetime_lit("1 hour");
128    }
129
130    // Minute
131    if units_set.contains(&TimeUnitUnit::Minutes) {
132        minute_arg = date_part(lit("minute"), field_col.clone());
133        interval = interval_datetime_lit("1 minute");
134    }
135
136    // Second
137    if units_set.contains(&TimeUnitUnit::Seconds) {
138        second_arg = date_part(lit("second"), field_col.clone());
139        interval = interval_datetime_lit("1 second");
140    }
141
142    // Millisecond
143    if units_set.contains(&TimeUnitUnit::Seconds) {
144        millisecond_arg = date_part(lit("millisecond"), field_col.clone()).rem(lit(1000));
145        interval = interval_datetime_lit("1 millisecond");
146    }
147
148    // Construct expression to make timestamp from components
149    let start_expr = make_timestamptz(
150        year_arg,
151        month_arg,
152        date_arg,
153        hour_arg,
154        minute_arg,
155        second_arg,
156        millisecond_arg,
157        tz,
158    );
159
160    Ok((start_expr, interval))
161}
162
163/// timeunit transform for 'day' unit (day of the week)
164fn timeunit_weekday(
165    field: &str,
166    schema: &DFSchema,
167    default_input_tz: &str,
168    tz: &str,
169) -> Result<(Expr, Expr)> {
170    let field_col = to_timestamp_col(unescaped_col(field), schema, default_input_tz)?.try_cast_to(
171        &DataType::Timestamp(ArrowTimeUnit::Millisecond, Some(tz.into())),
172        schema,
173    )?;
174
175    // Use DATE_PART_TZ to extract the weekday
176    // where Sunday is 0, Saturday is 6
177    let weekday0 = date_part(lit("dow"), field_col);
178
179    // Add one to line up with the signature of make_timestamptz
180    // where Sunday is 1 and Saturday is 7
181    let weekday1 = weekday0.add(lit(1));
182
183    let start_expr = make_timestamptz(
184        lit(2012),
185        lit(1),
186        weekday1,
187        lit(0),
188        lit(0),
189        lit(0),
190        lit(0),
191        tz,
192    );
193
194    Ok((start_expr, interval_datetime_lit("1 day")))
195}
196
197// Fallback implementation of timeunit that uses a custom DataFusion UDF
198fn timeunit_custom_udf(
199    field: &str,
200    units_set: &HashSet<TimeUnitUnit>,
201    schema: &DFSchema,
202    default_input_tz: &str,
203    tz: &str,
204) -> Result<(Expr, Expr)> {
205    let units_mask = [
206        units_set.contains(&TimeUnitUnit::Year),      // 0
207        units_set.contains(&TimeUnitUnit::Quarter),   // 1
208        units_set.contains(&TimeUnitUnit::Month),     // 2
209        units_set.contains(&TimeUnitUnit::Date),      // 3
210        units_set.contains(&TimeUnitUnit::Week),      // 4
211        units_set.contains(&TimeUnitUnit::Day),       // 5
212        units_set.contains(&TimeUnitUnit::DayOfYear), // 6
213        units_set.contains(&TimeUnitUnit::Hours),     // 7
214        units_set.contains(&TimeUnitUnit::Minutes),   // 8
215        units_set.contains(&TimeUnitUnit::Seconds),   // 9
216        units_set.contains(&TimeUnitUnit::Milliseconds),
217    ];
218
219    let timeunit_start_udf = &TIMEUNIT_START_UDF;
220
221    let field_col = to_timestamp_col(unescaped_col(field), schema, default_input_tz)?.try_cast_to(
222        &DataType::Timestamp(ArrowTimeUnit::Millisecond, Some("UTC".into())),
223        schema,
224    )?;
225
226    let timeunit_start_value = timeunit_start_udf.call(vec![
227        field_col,
228        lit(tz),
229        lit(units_mask[0]),
230        lit(units_mask[1]),
231        lit(units_mask[2]),
232        lit(units_mask[3]),
233        lit(units_mask[4]),
234        lit(units_mask[5]),
235        lit(units_mask[6]),
236        lit(units_mask[7]),
237        lit(units_mask[8]),
238        lit(units_mask[9]),
239        lit(units_mask[10]),
240    ]);
241
242    // Initialize interval string, this will be overwritten with the smallest specified unit
243    let mut interval = interval_year_month_lit("1 year");
244
245    // Year
246    if units_set.contains(&TimeUnitUnit::Year) {
247        interval = interval_year_month_lit("1 year");
248    }
249
250    // Quarter
251    if units_set.contains(&TimeUnitUnit::Quarter) {
252        interval = interval_year_month_lit("3 month");
253    }
254
255    // Month
256    if units_set.contains(&TimeUnitUnit::Month) {
257        interval = interval_year_month_lit("1 month");
258    }
259
260    // Week
261    if units_set.contains(&TimeUnitUnit::Week) {
262        interval = interval_datetime_lit("7 day");
263    }
264
265    // Day
266    if units_set.contains(&TimeUnitUnit::Date)
267        || units_set.contains(&TimeUnitUnit::DayOfYear)
268        || units_set.contains(&TimeUnitUnit::Day)
269    {
270        interval = interval_datetime_lit("1 day");
271    }
272
273    // Hour
274    if units_set.contains(&TimeUnitUnit::Hours) {
275        interval = interval_datetime_lit("1 hour");
276    }
277
278    // Minute
279    if units_set.contains(&TimeUnitUnit::Minutes) {
280        interval = interval_datetime_lit("1 minute");
281    }
282
283    // Second
284    if units_set.contains(&TimeUnitUnit::Seconds) {
285        interval = interval_datetime_lit("1 second");
286    }
287
288    Ok((timeunit_start_value, interval))
289}
290
291/// Convert a column to a timezone aware timestamp with Millisecond precision, in UTC
292pub fn to_timestamp_col(expr: Expr, schema: &DFSchema, default_input_tz: &str) -> Result<Expr> {
293    Ok(match expr.get_type(schema)? {
294        DataType::Timestamp(ArrowTimeUnit::Millisecond, Some(_)) => expr,
295        DataType::Timestamp(_, Some(tz)) => expr.try_cast_to(
296            &DataType::Timestamp(ArrowTimeUnit::Millisecond, Some(tz)),
297            schema,
298        )?,
299        DataType::Timestamp(_, None) => expr.try_cast_to(
300            &DataType::Timestamp(ArrowTimeUnit::Millisecond, Some(default_input_tz.into())),
301            schema,
302        )?,
303        DataType::Date32 | DataType::Date64 => cast_to(
304            expr,
305            &DataType::Timestamp(ArrowTimeUnit::Millisecond, None),
306            schema,
307        )?
308        .try_cast_to(
309            &DataType::Timestamp(ArrowTimeUnit::Millisecond, Some(default_input_tz.into())),
310            schema,
311        )?,
312        DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => {
313            str_to_timestamp(expr, default_input_tz, schema, None)?
314        }
315        dtype if is_numeric_datatype(&dtype) => from_epoch_millis(expr, schema)?,
316        dtype => {
317            return Err(VegaFusionError::compilation(format!(
318                "Invalid data type for timeunit transform: {dtype:?}"
319            )))
320        }
321    })
322}
323
324#[async_trait]
325impl TransformTrait for TimeUnit {
326    async fn eval(
327        &self,
328        dataframe: DataFrame,
329        config: &CompilationConfig,
330    ) -> Result<(DataFrame, Vec<TaskValue>)> {
331        let tz_config = config
332            .tz_config
333            .with_context(|| "No local timezone info provided".to_string())?;
334
335        let tz = if self.timezone != Some(TimeUnitTimeZone::Utc as i32) {
336            tz_config.local_tz.to_string()
337        } else {
338            "UTC".to_string()
339        };
340
341        let schema = dataframe.schema();
342        let default_input_tz = tz_config.default_input_tz.to_string();
343
344        // Compute Apply alias
345        let timeunit_start_alias = if let Some(alias_0) = &self.alias_0 {
346            alias_0.clone()
347        } else {
348            "unit0".to_string()
349        };
350
351        let units_vec = self
352            .units
353            .iter()
354            .sorted()
355            .map(|unit_i32| TimeUnitUnit::try_from(*unit_i32).unwrap())
356            .collect::<Vec<TimeUnitUnit>>();
357
358        // Add timeunit start
359        let (timeunit_start_expr, interval) = match *units_vec.as_slice() {
360            [TimeUnitUnit::Year] => timeunit_date_trunc(
361                &self.field,
362                TimeUnitUnit::Year,
363                schema,
364                &default_input_tz,
365                &tz,
366            )?,
367            [TimeUnitUnit::Year, TimeUnitUnit::Quarter] => timeunit_date_trunc(
368                &self.field,
369                TimeUnitUnit::Quarter,
370                schema,
371                &default_input_tz,
372                &tz,
373            )?,
374            [TimeUnitUnit::Year, TimeUnitUnit::Month] => timeunit_date_trunc(
375                &self.field,
376                TimeUnitUnit::Month,
377                schema,
378                &default_input_tz,
379                &tz,
380            )?,
381            [TimeUnitUnit::Year, TimeUnitUnit::Week] => timeunit_date_trunc(
382                &self.field,
383                TimeUnitUnit::Week,
384                schema,
385                &default_input_tz,
386                &tz,
387            )?,
388            [TimeUnitUnit::Year, TimeUnitUnit::Month, TimeUnitUnit::Date] => timeunit_date_trunc(
389                &self.field,
390                TimeUnitUnit::Date,
391                schema,
392                &default_input_tz,
393                &tz,
394            )?,
395            [TimeUnitUnit::Year, TimeUnitUnit::DayOfYear] => timeunit_date_trunc(
396                &self.field,
397                TimeUnitUnit::Date,
398                schema,
399                &default_input_tz,
400                &tz,
401            )?,
402            [TimeUnitUnit::Year, TimeUnitUnit::Month, TimeUnitUnit::Date, TimeUnitUnit::Hours] => {
403                timeunit_date_trunc(
404                    &self.field,
405                    TimeUnitUnit::Hours,
406                    schema,
407                    &default_input_tz,
408                    &tz,
409                )?
410            }
411            [TimeUnitUnit::Year, TimeUnitUnit::Month, TimeUnitUnit::Date, TimeUnitUnit::Hours, TimeUnitUnit::Minutes] => {
412                timeunit_date_trunc(
413                    &self.field,
414                    TimeUnitUnit::Minutes,
415                    schema,
416                    &default_input_tz,
417                    &tz,
418                )?
419            }
420            [TimeUnitUnit::Year, TimeUnitUnit::Month, TimeUnitUnit::Date, TimeUnitUnit::Hours, TimeUnitUnit::Minutes, TimeUnitUnit::Seconds] => {
421                timeunit_date_trunc(
422                    &self.field,
423                    TimeUnitUnit::Seconds,
424                    schema,
425                    &default_input_tz,
426                    &tz,
427                )?
428            }
429            [TimeUnitUnit::Day] => timeunit_weekday(&self.field, schema, &default_input_tz, &tz)?,
430            _ => {
431                // Check if timeunit can be handled by make_utc_timestamp
432                let units_set = units_vec.iter().cloned().collect::<HashSet<_>>();
433                let date_part_units = vec![
434                    TimeUnitUnit::Year,
435                    TimeUnitUnit::Quarter,
436                    TimeUnitUnit::Month,
437                    TimeUnitUnit::Date,
438                    TimeUnitUnit::Hours,
439                    TimeUnitUnit::Minutes,
440                    TimeUnitUnit::Seconds,
441                ]
442                .into_iter()
443                .collect::<HashSet<_>>();
444                if units_set.is_subset(&date_part_units) {
445                    timeunit_date_part_tz(&self.field, &units_set, schema, &default_input_tz, &tz)?
446                } else {
447                    // Fallback to custom UDF
448                    timeunit_custom_udf(&self.field, &units_set, schema, &default_input_tz, &tz)?
449                }
450            }
451        };
452
453        let timeunit_start_expr = timeunit_start_expr.alias(&timeunit_start_alias);
454
455        // Add timeunit start value to the dataframe
456        let mut select_exprs: Vec<_> = dataframe
457            .schema()
458            .fields()
459            .iter()
460            .filter_map(|field| {
461                if field.name() != &timeunit_start_alias {
462                    Some(flat_col(field.name()))
463                } else {
464                    None
465                }
466            })
467            .collect();
468        select_exprs.push(timeunit_start_expr);
469
470        let dataframe = dataframe.select(select_exprs)?;
471
472        // Add timeunit end value to the dataframe
473        let timeunit_end_alias = if let Some(alias_1) = &self.alias_1 {
474            alias_1.clone()
475        } else {
476            "unit1".to_string()
477        };
478
479        let timeunit_end_expr = flat_col(&timeunit_start_alias)
480            .add(interval)
481            .alias(&timeunit_end_alias);
482
483        let mut select_exprs: Vec<_> = dataframe
484            .schema()
485            .fields()
486            .iter()
487            .filter_map(|field| {
488                if field.name() != &timeunit_end_alias {
489                    Some(flat_col(field.name()))
490                } else {
491                    None
492                }
493            })
494            .collect();
495        select_exprs.push(timeunit_end_expr);
496        let dataframe = dataframe.select(select_exprs)?;
497
498        Ok((dataframe, Vec::new()))
499    }
500}