1use crate::expression::compiler::config::CompilationConfig;
2use crate::transform::TransformTrait;
3use async_trait::async_trait;
4use datafusion::prelude::DataFrame;
5use datafusion_common::DFSchema;
6use datafusion_functions::expr_fn::{date_part, date_trunc};
7use std::collections::HashSet;
8use std::ops::{Add, Mul, Rem, Sub};
9use vegafusion_common::arrow::datatypes::{DataType, TimeUnit as ArrowTimeUnit};
10use vegafusion_core::error::{Result, ResultWithContext, VegaFusionError};
11use vegafusion_core::proto::gen::transforms::{TimeUnit, TimeUnitTimeZone, TimeUnitUnit};
12use vegafusion_core::task_graph::task_value::TaskValue;
13
14use crate::datafusion::udfs::datetime::make_timestamptz::make_timestamptz;
15use crate::datafusion::udfs::datetime::timeunit::TIMEUNIT_START_UDF;
16use crate::expression::compiler::utils::ExprHelpers;
17use crate::transform::utils::{from_epoch_millis, str_to_timestamp};
18use datafusion_expr::{interval_datetime_lit, interval_year_month_lit, lit, Expr, ExprSchemable};
19use itertools::Itertools;
20use vegafusion_common::column::{flat_col, unescaped_col};
21use vegafusion_common::datatypes::{cast_to, is_numeric_datatype};
22
23fn timeunit_date_trunc(
25 field: &str,
26 smallest_unit: TimeUnitUnit,
27 schema: &DFSchema,
28 default_input_tz: &str,
29 tz: &str,
30) -> Result<(Expr, Expr)> {
31 let field_col = to_timestamp_col(unescaped_col(field), schema, default_input_tz)?.try_cast_to(
33 &DataType::Timestamp(ArrowTimeUnit::Millisecond, Some(tz.into())),
34 schema,
35 )?;
36
37 if let TimeUnitUnit::Week = smallest_unit {
39 let day_interval = interval_datetime_lit("1 day");
40 let trunc_expr =
41 date_trunc(lit("week"), field_col.add(day_interval.clone())).sub(day_interval);
42 let interval = interval_datetime_lit("7 day");
43 return Ok((trunc_expr, interval));
44 }
45
46 let (part_str, interval_expr) = match smallest_unit {
48 TimeUnitUnit::Year => ("year", interval_year_month_lit("1 year")),
49 TimeUnitUnit::Quarter => ("quarter", interval_year_month_lit("3 month")),
50 TimeUnitUnit::Month => ("month", interval_year_month_lit("1 month")),
51 TimeUnitUnit::Date => ("day", interval_datetime_lit("1 day")),
52 TimeUnitUnit::Hours => ("hour", interval_datetime_lit("1 hour")),
53 TimeUnitUnit::Minutes => ("minute", interval_datetime_lit("1 minute")),
54 TimeUnitUnit::Seconds => ("second", interval_datetime_lit("1 second")),
55 TimeUnitUnit::Milliseconds => ("millisecond", interval_datetime_lit("1 millisecond")),
56 _ => {
57 return Err(VegaFusionError::internal(format!(
58 "Unsupported date trunc unit: {smallest_unit:?}"
59 )))
60 }
61 };
62
63 let trunc_expr = date_trunc(lit(part_str), field_col);
65
66 Ok((trunc_expr, interval_expr))
67}
68
69fn timeunit_date_part_tz(
71 field: &str,
72 units_set: &HashSet<TimeUnitUnit>,
73 schema: &DFSchema,
74 default_input_tz: &str,
75 tz: &str,
76) -> Result<(Expr, Expr)> {
77 let mut year_arg = lit(2012);
78 let mut month_arg = lit(1);
79 let mut date_arg = lit(1);
80 let mut hour_arg = lit(0);
81 let mut minute_arg = lit(0);
82 let mut second_arg = lit(0);
83 let mut millisecond_arg = lit(0);
84
85 let mut interval = interval_year_month_lit("1 year");
87
88 let field_col = to_timestamp_col(unescaped_col(field), schema, default_input_tz)?.try_cast_to(
90 &DataType::Timestamp(ArrowTimeUnit::Millisecond, Some(tz.into())),
91 schema,
92 )?;
93
94 if units_set.contains(&TimeUnitUnit::Year) {
96 year_arg = date_part(lit("year"), field_col.clone());
97 interval = interval_year_month_lit("1 year");
98 }
99
100 if units_set.contains(&TimeUnitUnit::Quarter) {
102 let month_from_quarter = date_part(lit("quarter"), field_col.clone())
104 .sub(lit(1))
105 .mul(lit(3))
106 .add(lit(1));
107
108 month_arg = month_from_quarter;
109 interval = interval_year_month_lit("3 month");
110 }
111
112 if units_set.contains(&TimeUnitUnit::Month) {
114 month_arg = date_part(lit("month"), field_col.clone());
115 interval = interval_year_month_lit("1 month");
116 }
117
118 if units_set.contains(&TimeUnitUnit::Date) {
120 date_arg = date_part(lit("day"), field_col.clone());
121 interval = interval_datetime_lit("1 day");
122 }
123
124 if units_set.contains(&TimeUnitUnit::Hours) {
126 hour_arg = date_part(lit("hour"), field_col.clone());
127 interval = interval_datetime_lit("1 hour");
128 }
129
130 if units_set.contains(&TimeUnitUnit::Minutes) {
132 minute_arg = date_part(lit("minute"), field_col.clone());
133 interval = interval_datetime_lit("1 minute");
134 }
135
136 if units_set.contains(&TimeUnitUnit::Seconds) {
138 second_arg = date_part(lit("second"), field_col.clone());
139 interval = interval_datetime_lit("1 second");
140 }
141
142 if units_set.contains(&TimeUnitUnit::Seconds) {
144 millisecond_arg = date_part(lit("millisecond"), field_col.clone()).rem(lit(1000));
145 interval = interval_datetime_lit("1 millisecond");
146 }
147
148 let start_expr = make_timestamptz(
150 year_arg,
151 month_arg,
152 date_arg,
153 hour_arg,
154 minute_arg,
155 second_arg,
156 millisecond_arg,
157 tz,
158 );
159
160 Ok((start_expr, interval))
161}
162
163fn timeunit_weekday(
165 field: &str,
166 schema: &DFSchema,
167 default_input_tz: &str,
168 tz: &str,
169) -> Result<(Expr, Expr)> {
170 let field_col = to_timestamp_col(unescaped_col(field), schema, default_input_tz)?.try_cast_to(
171 &DataType::Timestamp(ArrowTimeUnit::Millisecond, Some(tz.into())),
172 schema,
173 )?;
174
175 let weekday0 = date_part(lit("dow"), field_col);
178
179 let weekday1 = weekday0.add(lit(1));
182
183 let start_expr = make_timestamptz(
184 lit(2012),
185 lit(1),
186 weekday1,
187 lit(0),
188 lit(0),
189 lit(0),
190 lit(0),
191 tz,
192 );
193
194 Ok((start_expr, interval_datetime_lit("1 day")))
195}
196
197fn timeunit_custom_udf(
199 field: &str,
200 units_set: &HashSet<TimeUnitUnit>,
201 schema: &DFSchema,
202 default_input_tz: &str,
203 tz: &str,
204) -> Result<(Expr, Expr)> {
205 let units_mask = [
206 units_set.contains(&TimeUnitUnit::Year), units_set.contains(&TimeUnitUnit::Quarter), units_set.contains(&TimeUnitUnit::Month), units_set.contains(&TimeUnitUnit::Date), units_set.contains(&TimeUnitUnit::Week), units_set.contains(&TimeUnitUnit::Day), units_set.contains(&TimeUnitUnit::DayOfYear), units_set.contains(&TimeUnitUnit::Hours), units_set.contains(&TimeUnitUnit::Minutes), units_set.contains(&TimeUnitUnit::Seconds), units_set.contains(&TimeUnitUnit::Milliseconds),
217 ];
218
219 let timeunit_start_udf = &TIMEUNIT_START_UDF;
220
221 let field_col = to_timestamp_col(unescaped_col(field), schema, default_input_tz)?.try_cast_to(
222 &DataType::Timestamp(ArrowTimeUnit::Millisecond, Some("UTC".into())),
223 schema,
224 )?;
225
226 let timeunit_start_value = timeunit_start_udf.call(vec![
227 field_col,
228 lit(tz),
229 lit(units_mask[0]),
230 lit(units_mask[1]),
231 lit(units_mask[2]),
232 lit(units_mask[3]),
233 lit(units_mask[4]),
234 lit(units_mask[5]),
235 lit(units_mask[6]),
236 lit(units_mask[7]),
237 lit(units_mask[8]),
238 lit(units_mask[9]),
239 lit(units_mask[10]),
240 ]);
241
242 let mut interval = interval_year_month_lit("1 year");
244
245 if units_set.contains(&TimeUnitUnit::Year) {
247 interval = interval_year_month_lit("1 year");
248 }
249
250 if units_set.contains(&TimeUnitUnit::Quarter) {
252 interval = interval_year_month_lit("3 month");
253 }
254
255 if units_set.contains(&TimeUnitUnit::Month) {
257 interval = interval_year_month_lit("1 month");
258 }
259
260 if units_set.contains(&TimeUnitUnit::Week) {
262 interval = interval_datetime_lit("7 day");
263 }
264
265 if units_set.contains(&TimeUnitUnit::Date)
267 || units_set.contains(&TimeUnitUnit::DayOfYear)
268 || units_set.contains(&TimeUnitUnit::Day)
269 {
270 interval = interval_datetime_lit("1 day");
271 }
272
273 if units_set.contains(&TimeUnitUnit::Hours) {
275 interval = interval_datetime_lit("1 hour");
276 }
277
278 if units_set.contains(&TimeUnitUnit::Minutes) {
280 interval = interval_datetime_lit("1 minute");
281 }
282
283 if units_set.contains(&TimeUnitUnit::Seconds) {
285 interval = interval_datetime_lit("1 second");
286 }
287
288 Ok((timeunit_start_value, interval))
289}
290
291pub fn to_timestamp_col(expr: Expr, schema: &DFSchema, default_input_tz: &str) -> Result<Expr> {
293 Ok(match expr.get_type(schema)? {
294 DataType::Timestamp(ArrowTimeUnit::Millisecond, Some(_)) => expr,
295 DataType::Timestamp(_, Some(tz)) => expr.try_cast_to(
296 &DataType::Timestamp(ArrowTimeUnit::Millisecond, Some(tz)),
297 schema,
298 )?,
299 DataType::Timestamp(_, None) => expr.try_cast_to(
300 &DataType::Timestamp(ArrowTimeUnit::Millisecond, Some(default_input_tz.into())),
301 schema,
302 )?,
303 DataType::Date32 | DataType::Date64 => cast_to(
304 expr,
305 &DataType::Timestamp(ArrowTimeUnit::Millisecond, None),
306 schema,
307 )?
308 .try_cast_to(
309 &DataType::Timestamp(ArrowTimeUnit::Millisecond, Some(default_input_tz.into())),
310 schema,
311 )?,
312 DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => {
313 str_to_timestamp(expr, default_input_tz, schema, None)?
314 }
315 dtype if is_numeric_datatype(&dtype) => from_epoch_millis(expr, schema)?,
316 dtype => {
317 return Err(VegaFusionError::compilation(format!(
318 "Invalid data type for timeunit transform: {dtype:?}"
319 )))
320 }
321 })
322}
323
324#[async_trait]
325impl TransformTrait for TimeUnit {
326 async fn eval(
327 &self,
328 dataframe: DataFrame,
329 config: &CompilationConfig,
330 ) -> Result<(DataFrame, Vec<TaskValue>)> {
331 let tz_config = config
332 .tz_config
333 .with_context(|| "No local timezone info provided".to_string())?;
334
335 let tz = if self.timezone != Some(TimeUnitTimeZone::Utc as i32) {
336 tz_config.local_tz.to_string()
337 } else {
338 "UTC".to_string()
339 };
340
341 let schema = dataframe.schema();
342 let default_input_tz = tz_config.default_input_tz.to_string();
343
344 let timeunit_start_alias = if let Some(alias_0) = &self.alias_0 {
346 alias_0.clone()
347 } else {
348 "unit0".to_string()
349 };
350
351 let units_vec = self
352 .units
353 .iter()
354 .sorted()
355 .map(|unit_i32| TimeUnitUnit::try_from(*unit_i32).unwrap())
356 .collect::<Vec<TimeUnitUnit>>();
357
358 let (timeunit_start_expr, interval) = match *units_vec.as_slice() {
360 [TimeUnitUnit::Year] => timeunit_date_trunc(
361 &self.field,
362 TimeUnitUnit::Year,
363 schema,
364 &default_input_tz,
365 &tz,
366 )?,
367 [TimeUnitUnit::Year, TimeUnitUnit::Quarter] => timeunit_date_trunc(
368 &self.field,
369 TimeUnitUnit::Quarter,
370 schema,
371 &default_input_tz,
372 &tz,
373 )?,
374 [TimeUnitUnit::Year, TimeUnitUnit::Month] => timeunit_date_trunc(
375 &self.field,
376 TimeUnitUnit::Month,
377 schema,
378 &default_input_tz,
379 &tz,
380 )?,
381 [TimeUnitUnit::Year, TimeUnitUnit::Week] => timeunit_date_trunc(
382 &self.field,
383 TimeUnitUnit::Week,
384 schema,
385 &default_input_tz,
386 &tz,
387 )?,
388 [TimeUnitUnit::Year, TimeUnitUnit::Month, TimeUnitUnit::Date] => timeunit_date_trunc(
389 &self.field,
390 TimeUnitUnit::Date,
391 schema,
392 &default_input_tz,
393 &tz,
394 )?,
395 [TimeUnitUnit::Year, TimeUnitUnit::DayOfYear] => timeunit_date_trunc(
396 &self.field,
397 TimeUnitUnit::Date,
398 schema,
399 &default_input_tz,
400 &tz,
401 )?,
402 [TimeUnitUnit::Year, TimeUnitUnit::Month, TimeUnitUnit::Date, TimeUnitUnit::Hours] => {
403 timeunit_date_trunc(
404 &self.field,
405 TimeUnitUnit::Hours,
406 schema,
407 &default_input_tz,
408 &tz,
409 )?
410 }
411 [TimeUnitUnit::Year, TimeUnitUnit::Month, TimeUnitUnit::Date, TimeUnitUnit::Hours, TimeUnitUnit::Minutes] => {
412 timeunit_date_trunc(
413 &self.field,
414 TimeUnitUnit::Minutes,
415 schema,
416 &default_input_tz,
417 &tz,
418 )?
419 }
420 [TimeUnitUnit::Year, TimeUnitUnit::Month, TimeUnitUnit::Date, TimeUnitUnit::Hours, TimeUnitUnit::Minutes, TimeUnitUnit::Seconds] => {
421 timeunit_date_trunc(
422 &self.field,
423 TimeUnitUnit::Seconds,
424 schema,
425 &default_input_tz,
426 &tz,
427 )?
428 }
429 [TimeUnitUnit::Day] => timeunit_weekday(&self.field, schema, &default_input_tz, &tz)?,
430 _ => {
431 let units_set = units_vec.iter().cloned().collect::<HashSet<_>>();
433 let date_part_units = vec![
434 TimeUnitUnit::Year,
435 TimeUnitUnit::Quarter,
436 TimeUnitUnit::Month,
437 TimeUnitUnit::Date,
438 TimeUnitUnit::Hours,
439 TimeUnitUnit::Minutes,
440 TimeUnitUnit::Seconds,
441 ]
442 .into_iter()
443 .collect::<HashSet<_>>();
444 if units_set.is_subset(&date_part_units) {
445 timeunit_date_part_tz(&self.field, &units_set, schema, &default_input_tz, &tz)?
446 } else {
447 timeunit_custom_udf(&self.field, &units_set, schema, &default_input_tz, &tz)?
449 }
450 }
451 };
452
453 let timeunit_start_expr = timeunit_start_expr.alias(&timeunit_start_alias);
454
455 let mut select_exprs: Vec<_> = dataframe
457 .schema()
458 .fields()
459 .iter()
460 .filter_map(|field| {
461 if field.name() != &timeunit_start_alias {
462 Some(flat_col(field.name()))
463 } else {
464 None
465 }
466 })
467 .collect();
468 select_exprs.push(timeunit_start_expr);
469
470 let dataframe = dataframe.select(select_exprs)?;
471
472 let timeunit_end_alias = if let Some(alias_1) = &self.alias_1 {
474 alias_1.clone()
475 } else {
476 "unit1".to_string()
477 };
478
479 let timeunit_end_expr = flat_col(&timeunit_start_alias)
480 .add(interval)
481 .alias(&timeunit_end_alias);
482
483 let mut select_exprs: Vec<_> = dataframe
484 .schema()
485 .fields()
486 .iter()
487 .filter_map(|field| {
488 if field.name() != &timeunit_end_alias {
489 Some(flat_col(field.name()))
490 } else {
491 None
492 }
493 })
494 .collect();
495 select_exprs.push(timeunit_end_expr);
496 let dataframe = dataframe.select(select_exprs)?;
497
498 Ok((dataframe, Vec::new()))
499 }
500}