Skip to main content

shape_runtime/multi_table/
functions.rs

1//! Built-in functions for multi-table analysis
2
3use super::alignment::{align_intersection, align_union};
4use super::config::AlignmentMode;
5use crate::context::ExecutionContext;
6use crate::data::OwnedDataRow as RowValue;
7use crate::data::Timeframe;
8use crate::timeframe_utils::parse_timeframe_string;
9use shape_ast::error::{Result, ShapeError};
10use shape_value::ValueWord;
11use std::sync::Arc;
12
13fn parse_dataset_id(id: &str, default_timeframe: Timeframe) -> Result<(String, Timeframe)> {
14    if let Some((base_id, tf_str)) = id.rsplit_once('_') {
15        if let Ok(tf) = parse_timeframe_string(tf_str) {
16            return Ok((base_id.to_string(), tf));
17        }
18    }
19    Ok((id.to_string(), default_timeframe))
20}
21
22fn load_rows(_ctx: &ExecutionContext, _id: &str, _timeframe: Timeframe) -> Result<Vec<RowValue>> {
23    Err(ShapeError::RuntimeError {
24        message: "Data access requires prefetching data first via execute_async()".to_string(),
25        location: None,
26    })
27}
28
29/// Align multiple datasets
30pub fn align_tables(ctx: &mut ExecutionContext, args: &[ValueWord]) -> Result<ValueWord> {
31    if args.is_empty() || args.len() > 2 {
32        return Err(ShapeError::RuntimeError {
33            message: "align_tables() requires 1-2 arguments: ids, [mode]".into(),
34            location: None,
35        });
36    }
37
38    let dataset_ids = match args[0].as_any_array() {
39        Some(view) => {
40            let arr = view.to_generic();
41            arr.iter()
42                .map(|v| {
43                    if let Some(s) = v.as_str() {
44                        Ok(Arc::new(s.to_string()))
45                    } else {
46                        Err(ShapeError::RuntimeError {
47                            message: "IDs must be strings".into(),
48                            location: None,
49                        })
50                    }
51                })
52                .collect::<Result<Vec<_>>>()?
53        }
54        None => {
55            return Err(ShapeError::RuntimeError {
56                message: "First argument must be an array of IDs".into(),
57                location: None,
58            });
59        }
60    };
61
62    let mode = if args.len() > 1 {
63        match args[1].as_str() {
64            Some("intersection") => AlignmentMode::Intersection,
65            Some("union") => AlignmentMode::Union,
66            Some(s) => {
67                return Err(ShapeError::RuntimeError {
68                    message: format!("Unknown alignment mode: {}", s),
69                    location: None,
70                });
71            }
72            None => AlignmentMode::Intersection,
73        }
74    } else {
75        AlignmentMode::Intersection
76    };
77
78    let default_tf = ctx.get_current_timeframe().unwrap_or_default();
79    let mut datasets = Vec::with_capacity(dataset_ids.len());
80    for id in &dataset_ids {
81        let (base_id, timeframe) = parse_dataset_id(id, default_tf)?;
82        let rows = load_rows(ctx, &base_id, timeframe)?;
83        datasets.push(rows);
84    }
85
86    let aligned = match mode {
87        AlignmentMode::Intersection => align_intersection(&datasets)?,
88        AlignmentMode::Union => align_union(&datasets)?,
89        _ => {
90            return Err(ShapeError::RuntimeError {
91                message: "align_tables supports only intersection or union modes".to_string(),
92                location: None,
93            });
94        }
95    };
96
97    let ids_val = ValueWord::from_array(Arc::new(
98        dataset_ids
99            .iter()
100            .map(|s| ValueWord::from_string(s.clone()))
101            .collect(),
102    ));
103
104    // Convert aligned data to ValueWord
105    let mut aligned_data_val: Vec<ValueWord> = Vec::new();
106    for rows in aligned {
107        let rows_val: Vec<ValueWord> = rows
108            .into_iter()
109            .map(|r| {
110                let pairs: Vec<(&str, ValueWord)> = r
111                    .fields
112                    .iter()
113                    .map(|(k, v)| (k.as_str(), ValueWord::from_f64(*v)))
114                    .collect();
115                crate::type_schema::typed_object_from_nb_pairs(&pairs)
116            })
117            .collect();
118        aligned_data_val.push(ValueWord::from_array(Arc::new(rows_val)));
119    }
120
121    Ok(crate::type_schema::typed_object_from_nb_pairs(&[
122        ("ids", ids_val),
123        ("data", ValueWord::from_array(Arc::new(aligned_data_val))),
124    ]))
125}
126
127pub fn correlation(_ctx: &mut ExecutionContext, args: &[ValueWord]) -> Result<ValueWord> {
128    if args.len() != 2 {
129        return Err(ShapeError::RuntimeError {
130            message: "correlation() requires 2 series arguments".into(),
131            location: None,
132        });
133    }
134
135    // Placeholder
136    Ok(ValueWord::from_f64(0.0))
137}
138
139pub fn find_divergences(_ctx: &mut ExecutionContext, _args: &[ValueWord]) -> Result<ValueWord> {
140    Err(ShapeError::RuntimeError {
141        message: "find_divergences() not implemented".into(),
142        location: None,
143    })
144}
145
146pub fn spread(_ctx: &mut ExecutionContext, _args: &[ValueWord]) -> Result<ValueWord> {
147    Err(ShapeError::RuntimeError {
148        message: "spread() not implemented".into(),
149        location: None,
150    })
151}
152
153pub fn temporal_join(_ctx: &mut ExecutionContext, _args: &[ValueWord]) -> Result<ValueWord> {
154    Err(ShapeError::RuntimeError {
155        message: "temporal_join() not implemented".into(),
156        location: None,
157    })
158}