Skip to main content

tupa_codegen/
execution_plan.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use tupa_parser::{Comparator, Expr, ExprKind, Item, PipelineDecl, Program, Stmt, Type};
4use tupa_typecheck::analyze_effects;
5
/// Serializable description of a compiled pipeline.
///
/// `codegen_pipeline` builds one of these per pipeline declaration and emits
/// it as pretty-printed JSON. Field order here is the order of the keys in the
/// serialized output.
#[derive(Serialize, Deserialize)]
pub struct ExecutionPlan {
    /// Pipeline name, copied from the pipeline declaration.
    pub name: String,
    /// Version of the crate that generated the plan (`CARGO_PKG_VERSION` at
    /// compile time).
    pub version: String,
    /// Seed copied from the pipeline declaration, if one was given.
    pub seed: Option<u64>,
    /// Schema derived from the pipeline's declared input type.
    pub input_schema: TypeSchema,
    /// Schema derived from the pipeline's declared output type, if any.
    pub output_schema: Option<TypeSchema>,
    /// One entry per pipeline step, in declaration order.
    pub steps: Vec<StepPlan>,
    /// One entry per pipeline constraint, in declaration order.
    pub constraints: Vec<ConstraintPlan>,
    /// Validation-block `let` bindings whose value is an integer or float
    /// literal, keyed by binding name.
    // NOTE(review): `HashMap` iteration order is unspecified, so the key order
    // of this object in the emitted JSON may differ between runs.
    pub metrics: HashMap<String, f64>,
    /// Validation-block `let` bindings whose value is a call to a named
    /// function.
    pub metric_plans: Vec<MetricPlan>,
}
18
19#[derive(Serialize, Deserialize)]
20pub struct StepPlan {
21    pub name: String,
22    pub function_ref: String,
23    pub effects: Vec<String>,
24}
25
26#[derive(Serialize, Deserialize)]
27pub struct ConstraintPlan {
28    pub metric: String,
29    pub comparator: String,
30    pub threshold: f64,
31}
32
33#[derive(Serialize, Deserialize)]
34pub struct TypeSchema {
35    pub kind: String,
36    pub elem: Option<Box<TypeSchema>>,
37    #[serde(default, skip_serializing_if = "Option::is_none")]
38    pub fields: Option<HashMap<String, TypeSchema>>,
39    pub len: Option<i64>,
40    pub name: Option<String>,
41    pub tensor_shape: Option<Vec<Option<usize>>>,
42    pub tensor_dtype: Option<String>,
43}
44
45#[derive(Serialize, Deserialize)]
46pub struct MetricPlan {
47    pub name: String,
48    pub function_ref: String,
49    pub args: serde_json::Value,
50}
51
52pub fn type_to_schema(ty: &Type) -> TypeSchema {
53    match ty {
54        Type::Tensor(t) => TypeSchema {
55            kind: "tensor".into(),
56            elem: None,
57            fields: None,
58            len: None,
59            name: None,
60            tensor_shape: Some(t.shape.iter().map(|&x| x.map(|n| n as usize)).collect()),
61            tensor_dtype: Some(format!("{:?}", t.dtype)),
62        },
63        Type::Array { elem, len } => TypeSchema {
64            kind: "array".into(),
65            elem: Some(Box::new(type_to_schema(elem))),
66            fields: None,
67            len: Some(*len),
68            name: None,
69            tensor_shape: None,
70            tensor_dtype: None,
71        },
72        Type::Slice { elem } => TypeSchema {
73            kind: "slice".into(),
74            elem: Some(Box::new(type_to_schema(elem))),
75            fields: None,
76            len: None,
77            name: None,
78            tensor_shape: None,
79            tensor_dtype: None,
80        },
81        Type::Record(fields) => TypeSchema {
82            kind: "object".into(),
83            elem: None,
84            fields: Some(
85                fields
86                    .iter()
87                    .map(|(name, ty)| (name.clone(), type_to_schema(ty)))
88                    .collect(),
89            ),
90            len: None,
91            name: None,
92            tensor_shape: None,
93            tensor_dtype: None,
94        },
95        Type::Safe { base, .. } => type_to_schema(base),
96        Type::Ident(name) => match name.as_str() {
97            "i64" => TypeSchema {
98                kind: "i64".into(),
99                elem: None,
100                fields: None,
101                len: None,
102                name: None,
103                tensor_shape: None,
104                tensor_dtype: None,
105            },
106            "f64" => TypeSchema {
107                kind: "f64".into(),
108                elem: None,
109                fields: None,
110                len: None,
111                name: None,
112                tensor_shape: None,
113                tensor_dtype: None,
114            },
115            "bool" => TypeSchema {
116                kind: "bool".into(),
117                elem: None,
118                fields: None,
119                len: None,
120                name: None,
121                tensor_shape: None,
122                tensor_dtype: None,
123            },
124            "string" => TypeSchema {
125                kind: "string".into(),
126                elem: None,
127                fields: None,
128                len: None,
129                name: None,
130                tensor_shape: None,
131                tensor_dtype: None,
132            },
133            _ => TypeSchema {
134                kind: "ident".into(),
135                elem: None,
136                fields: None,
137                len: None,
138                name: Some(name.clone()),
139                tensor_shape: None,
140                tensor_dtype: None,
141            },
142        },
143        _ => TypeSchema {
144            kind: "unknown".into(),
145            elem: None,
146            fields: None,
147            len: None,
148            name: None,
149            tensor_shape: None,
150            tensor_dtype: None,
151        },
152    }
153}
154
155fn constraint_to_plan(c: &tupa_parser::Constraint) -> ConstraintPlan {
156    ConstraintPlan {
157        metric: c.metric.clone(),
158        comparator: match c.comparator {
159            Comparator::Lt => "lt".into(),
160            Comparator::Le => "le".into(),
161            Comparator::Eq => "eq".into(),
162            Comparator::Ge => "ge".into(),
163            Comparator::Gt => "gt".into(),
164        },
165        threshold: c.threshold,
166    }
167}
168
169fn extract_metrics(pipeline: &PipelineDecl) -> HashMap<String, f64> {
170    let mut map = HashMap::new();
171    if let Some(block) = &pipeline.validation {
172        for stmt in block {
173            if let Stmt::Let { name, expr, .. } = stmt {
174                match &expr.kind {
175                    ExprKind::Int(n) => {
176                        map.insert(name.clone(), *n as f64);
177                    }
178                    ExprKind::Float(f) => {
179                        map.insert(name.clone(), *f);
180                    }
181                    _ => {}
182                }
183            }
184        }
185    }
186    map
187}
188
189fn expr_to_json(expr: &Expr) -> Option<serde_json::Value> {
190    match &expr.kind {
191        ExprKind::Int(n) => Some(serde_json::json!(*n)),
192        ExprKind::Float(f) => Some(serde_json::json!(*f)),
193        ExprKind::Bool(b) => Some(serde_json::json!(*b)),
194        ExprKind::ArrayLiteral(items) => {
195            let mut arr = Vec::new();
196            for it in items {
197                if let Some(v) = expr_to_json(it) {
198                    arr.push(v);
199                } else {
200                    return None;
201                }
202            }
203            Some(serde_json::Value::Array(arr))
204        }
205        _ => None,
206    }
207}
208
209fn extract_metric_plans(module_name: &str, pipeline: &PipelineDecl) -> Vec<MetricPlan> {
210    let mut list = Vec::new();
211    if let Some(block) = &pipeline.validation {
212        for stmt in block {
213            if let Stmt::Let { name, expr, .. } = stmt {
214                if let ExprKind::Call { callee, args } = &expr.kind {
215                    if let ExprKind::Ident(func) = &callee.kind {
216                        // build args JSON as array of supported literals
217                        let json_args = if args.len() == 1 {
218                            expr_to_json(&args[0]).unwrap_or(serde_json::Value::Null)
219                        } else {
220                            let mut arr = Vec::new();
221                            for a in args {
222                                if let Some(v) = expr_to_json(a) {
223                                    arr.push(v);
224                                }
225                            }
226                            serde_json::Value::Array(arr)
227                        };
228                        list.push(MetricPlan {
229                            name: name.clone(),
230                            function_ref: format!("{module_name}::{func}"),
231                            args: json_args,
232                        });
233                    }
234                }
235            }
236        }
237    }
238    list
239}
240
241pub fn codegen_pipeline(
242    module_name: &str,
243    pipeline: &PipelineDecl,
244    program: &Program,
245) -> serde_json::Result<String> {
246    let steps: Vec<StepPlan> = pipeline
247        .steps
248        .iter()
249        .map(|step| {
250            let effects = analyze_effects(&step.body, &HashMap::new()).to_names();
251            let mut function_ref = format!("{module_name}::step_{}", step.name);
252
253            // Check if body is a direct call to an external function
254            if let ExprKind::Call { callee, args } = &step.body.kind {
255                if let ExprKind::Ident(func_name) = &callee.kind {
256                    // We only optimize direct calls with 'input' argument for now
257                    let is_simple_call = args.len() == 1
258                        && matches!(&args[0].kind, ExprKind::Ident(n) if n == "input");
259
260                    if is_simple_call {
261                        // Find function definition
262                        for item in &program.items {
263                            if let Item::Function(f) = item {
264                                if &f.name == func_name {
265                                    if let Some(spec) = &f.external_spec {
266                                        if let Some(py_target) = &spec.python {
267                                            function_ref = format!("py:{}", py_target);
268                                        }
269                                    }
270                                    break;
271                                }
272                            }
273                        }
274                    }
275                }
276            }
277
278            StepPlan {
279                name: step.name.clone(),
280                function_ref,
281                effects,
282            }
283        })
284        .collect();
285    let plan = ExecutionPlan {
286        name: pipeline.name.clone(),
287        version: env!("CARGO_PKG_VERSION").to_string(),
288        seed: pipeline.seed,
289        input_schema: type_to_schema(&pipeline.input_ty),
290        output_schema: pipeline.output_ty.as_ref().map(type_to_schema),
291        steps,
292        constraints: pipeline
293            .constraints
294            .iter()
295            .map(constraint_to_plan)
296            .collect(),
297        metrics: extract_metrics(pipeline),
298        metric_plans: extract_metric_plans(module_name, pipeline),
299    };
300    serde_json::to_string_pretty(&plan)
301}