sdf_metadata/metadata/operator/
partition_operator.rs

1use std::collections::BTreeMap;
2
3use anyhow::Result;
4
5use crate::{
6    importer::{
7        function::{
8            imported_assign_key_config, imported_operator_config, imported_update_state_config,
9        },
10        states::inject_states,
11    },
12    metadata::io::topic::KVSchemaType,
13    util::{
14        sdf_types_map::SdfTypesMap, validation_error::ValidationError,
15        validation_failure::ValidationFailure,
16    },
17    wit::{
18        dataflow::{PackageDefinition, PackageImport},
19        operator::{OperatorType, PartitionOperator, StepInvocation},
20        states::State,
21    },
22};
23
24use super::transforms::validate_transforms_steps;
25
26impl PartitionOperator {
27    pub fn operators(&self) -> Vec<(StepInvocation, OperatorType)> {
28        let mut operators = vec![(self.assign_key.clone(), OperatorType::AssignKey)];
29        operators.extend(self.transforms.steps.iter().map(|step| {
30            let inner = step.inner();
31            (inner.clone(), step.clone().into())
32        }));
33        if let Some(update_state) = &self.update_state {
34            operators.push((update_state.clone(), OperatorType::UpdateState));
35        }
36        operators
37    }
38
39    pub fn add_operator(
40        &mut self,
41        operator_index: Option<usize>,
42        operator_type: OperatorType,
43        step_invocation: StepInvocation,
44    ) -> Result<()> {
45        if operator_index.is_some() {
46            self.transforms
47                .insert_operator(operator_index, operator_type, step_invocation)
48        } else {
49            todo!("cannot add assign key unless it is optional")
50        }
51    }
52
53    pub fn delete_operator(&mut self, operator_index: Option<usize>) -> Result<()> {
54        match operator_index {
55            Some(index) => self.transforms.delete_operator(index),
56            None => {
57                todo!("cannot delete assign key unless it is optional")
58            }
59        }
60    }
61
62    pub fn import_operator_configs(
63        &mut self,
64        imports: &[PackageImport],
65        packages: &[PackageDefinition],
66        service_states: &mut BTreeMap<String, State>,
67    ) -> Result<()> {
68        if self.assign_key.is_imported(imports) {
69            self.assign_key = imported_assign_key_config(&self.assign_key, imports, packages)?;
70            inject_states(service_states, &self.assign_key.states)?;
71        }
72
73        for step in &mut self.transforms.steps {
74            if step.is_imported(imports) {
75                *step = imported_operator_config(step, imports, packages)?;
76                inject_states(service_states, &step.inner().states)?;
77            }
78        }
79
80        if let Some(update_state) = &mut self.update_state {
81            if update_state.is_imported(imports) {
82                *update_state = imported_update_state_config(update_state, imports, packages)?;
83                inject_states(service_states, &update_state.states)?;
84            }
85        }
86
87        Ok(())
88    }
89
90    pub fn output_type(&self, input_type: KVSchemaType) -> Result<KVSchemaType, ValidationError> {
91        self.transforms.output_type(input_type)
92    }
93
94    pub fn validate(
95        &self,
96        types: &SdfTypesMap,
97        expected_input_type: &KVSchemaType,
98        input_provider_name: &str,
99    ) -> Result<(), ValidationFailure> {
100        let mut errors = ValidationFailure::new();
101
102        if let Err(assign_key_error) =
103            self.validate_assign_key(types, expected_input_type, input_provider_name)
104        {
105            errors.concat(&assign_key_error);
106        }
107
108        if let Err(transforms_error) = validate_transforms_steps(
109            &self.transforms.steps,
110            types,
111            expected_input_type.to_owned(),
112            input_provider_name.to_string(),
113        ) {
114            errors.concat_with_context("transforms block is invalid:", &transforms_error);
115        }
116
117        if let Err(update_state_error) = self.validate_update_state(types) {
118            errors.concat(&update_state_error);
119        }
120
121        if errors.any() {
122            Err(errors)
123        } else {
124            Ok(())
125        }
126    }
127
128    fn validate_assign_key(
129        &self,
130        types: &SdfTypesMap,
131        expected_type: &KVSchemaType,
132        input_provider_name: &str,
133    ) -> Result<(), ValidationFailure> {
134        let mut errors = ValidationFailure::new();
135
136        if let Err(assign_key_error) = self.assign_key.validate_assign_key(types) {
137            errors.concat(&assign_key_error);
138        }
139
140        let value_type = if self.assign_key.requires_key_param() {
141            let key = self
142                .assign_key
143                .inputs
144                .first()
145                .map(|input| input.type_.clone());
146
147            if let Some(key) = key {
148                if let Some(expected_key) = expected_type.key.as_ref() {
149                    if key.name != expected_key.name {
150                        errors.push_str(&format!(
151                            "assign-key function `{}` key type should match `{}` provided by `{}` but found `{}`",
152                            self.assign_key.uses,
153                            expected_key.name,
154                            &input_provider_name,
155                            key.name
156                        ));
157                    }
158                } else {
159                    errors.push_str("assign-key type function `assign-key-fn` requires a key type");
160                }
161
162                self.assign_key.inputs.get(1)
163            } else {
164                errors.push_str("assign-key type function `assign-key-fn` requires an input type");
165                None
166            }
167        } else {
168            self.assign_key.inputs.first()
169        };
170
171        //assert assign key first input matches the expected input type
172        if let Some(assign_key_input) = value_type {
173            if assign_key_input.type_.name != expected_type.value.name {
174                errors.push_str(&format!(
175                    "assign-key function `{}` input type should match `{}` provided by `{}` but found `{}`",
176                    self.assign_key.uses,
177                    expected_type.value.name,
178                    &input_provider_name,
179                    assign_key_input.type_.name
180                ));
181            }
182        }
183
184        if errors.any() {
185            Err(errors)
186        } else {
187            Ok(())
188        }
189    }
190
191    fn validate_update_state(&self, types: &SdfTypesMap) -> Result<(), ValidationFailure> {
192        if let Some(update_state) = &self.update_state {
193            update_state.validate_update_state(types)
194        } else {
195            Ok(())
196        }
197    }
198
199    #[cfg(feature = "parser")]
200    pub fn update_inline_operators(&mut self) -> Result<()> {
201        self.assign_key.update_signature_from_code()?;
202
203        for step in &mut self.transforms.steps {
204            step.update_signature_from_code()?;
205        }
206
207        if let Some(update_state) = &mut self.update_state {
208            update_state.update_signature_from_code()?;
209        }
210
211        Ok(())
212    }
213}
214
#[cfg(test)]
mod test {
    use std::collections::BTreeMap;

    use sdf_common::constants::DATAFLOW_STABLE_VERSION;

    use crate::{
        metadata::io::topic::KVSchemaType,
        util::{sdf_types_map::SdfTypesMap, validation_error::ValidationError},
        wit::{
            dataflow::{PackageDefinition, PackageImport},
            io::TypeRef,
            metadata::{NamedParameter, Parameter, ParameterKind, SdfKeyedStateValue},
            operator::{
                PartitionOperator, StepInvocation, StepState, TransformOperator, Transforms,
            },
            package_interface::{FunctionImport, Header, OperatorType},
            states::{SdfKeyedState, State, StateTyped},
        },
    };

    // ---- small builders shared by the fixtures below ----

    /// Builds a `TypeRef` with the given type name.
    fn type_ref(name: &str) -> TypeRef {
        TypeRef {
            name: name.to_owned(),
        }
    }

    /// Builds a required (non-optional) value parameter.
    fn value_param(name: &str, type_name: &str) -> NamedParameter {
        NamedParameter {
            name: name.to_owned(),
            type_: type_ref(type_name),
            optional: false,
            kind: ParameterKind::Value,
        }
    }

    /// Builds an output parameter of the given type; other fields are defaulted.
    fn output_of(type_name: &str) -> Option<Parameter> {
        Some(Parameter {
            type_: type_ref(type_name).into(),
            ..Default::default()
        })
    }

    /// Builds a step that only names the function it uses.
    fn step_using(uses: &str) -> StepInvocation {
        StepInvocation {
            uses: uses.to_owned(),
            ..Default::default()
        }
    }

    /// Header of the package used by both `packages()` and `imports()`.
    fn pkg_header() -> Header {
        Header {
            name: "my-pkg".to_owned(),
            namespace: "my-ns".to_owned(),
            version: "0.1.0".to_owned(),
        }
    }

    /// Import entry for a function without an alias.
    fn function_import(name: &str) -> FunctionImport {
        FunctionImport {
            name: name.to_owned(),
            alias: None,
        }
    }

    /// Partition with an `assign_key` step and a single map transform.
    fn single_map_partition() -> PartitionOperator {
        PartitionOperator {
            assign_key: step_using("assign_key"),
            transforms: Transforms {
                steps: vec![TransformOperator::Map(step_using("prospect_map_prospect2"))],
            },
            update_state: None,
        }
    }

    // ---- fixtures ----

    /// The single package definition that exports `map-fn` and `assign-key-fn`.
    fn packages() -> Vec<PackageDefinition> {
        let package = PackageDefinition {
            api_version: DATAFLOW_STABLE_VERSION.to_string(),
            meta: pkg_header(),
            functions: vec![map_fn(), assign_key_fn()],
            imports: vec![],
            types: vec![],
            states: vec![],
            dev: None,
        };

        vec![package]
    }

    /// Fully specified `map-fn` with one input, an output and one resolved state.
    fn map_fn() -> (StepInvocation, OperatorType) {
        let map_state = StepState::Resolved(StateTyped {
            name: "map-state".to_owned(),
            type_: SdfKeyedState {
                key: type_ref("string"),
                value: SdfKeyedStateValue::U32,
            },
        });

        let step = StepInvocation {
            uses: "map-fn".to_owned(),
            inputs: vec![value_param("map-input", "u8")],
            output: output_of("u8"),
            states: vec![map_state],
            ..Default::default()
        };

        (step, OperatorType::Map)
    }

    /// Fully specified `assign-key-fn` with one input and an output, no states.
    fn assign_key_fn() -> (StepInvocation, OperatorType) {
        let step = StepInvocation {
            uses: "assign-key-fn".to_owned(),
            inputs: vec![value_param("word-count", "U8")],
            output: output_of("U8"),
            ..Default::default()
        };

        (step, OperatorType::AssignKey)
    }

    /// Import of `map-fn` and `assign-key-fn` from the test package.
    fn imports() -> Vec<PackageImport> {
        let import = PackageImport {
            metadata: pkg_header(),
            functions: vec![function_import("map-fn"), function_import("assign-key-fn")],
            path: Some("path/to/my-pkg".to_owned()),
            types: vec![],
            states: vec![],
        };

        vec![import]
    }

    /// Partition whose operators only reference imported functions by name.
    fn partition_operator() -> PartitionOperator {
        let steps = vec![
            TransformOperator::Map(step_using("map-fn")),
            TransformOperator::Map(step_using("map-fn")),
        ];

        PartitionOperator {
            assign_key: step_using("assign-key-fn"),
            transforms: Transforms { steps },
            update_state: None,
        }
    }

    /// Expected input type: no key, `s16` value.
    fn expected_type() -> KVSchemaType {
        (None, type_ref("s16")).into()
    }

    // ---- tests ----

    #[test]
    fn test_import_operator_configs_merges_operator_signatures() {
        let mut service_states = BTreeMap::<String, State>::new();
        let mut partition = partition_operator();

        // Before the import nothing carries a signature or a state.
        assert!(partition.assign_key.inputs.is_empty());
        assert!(partition.assign_key.output.is_none());
        for step in &partition.transforms.steps[..2] {
            assert!(step.inner().inputs.is_empty());
            assert!(step.inner().output.is_none());
        }
        assert!(service_states.is_empty());

        partition
            .import_operator_configs(&imports(), &packages(), &mut service_states)
            .unwrap();

        // After the import every operator has the signature from the package.
        assert_eq!(partition.assign_key.inputs.len(), 1);
        assert!(partition.assign_key.output.is_some());
        for step in &partition.transforms.steps[..2] {
            assert_eq!(step.inner().inputs.len(), 1);
            assert!(step.inner().output.is_some());
        }
        assert_eq!(service_states.len(), 1);
    }

    #[test]
    fn test_validate_validates_assign_key_operator() {
        let types = SdfTypesMap::default();
        let mut partition = partition_operator();
        partition.assign_key.output = None;

        let failure = partition
            .validate(&types, &expected_type(), "service transforms block")
            .expect_err("should fail for invalid assign key operator");

        let expected = ValidationError::new(
            "assign-key type function `assign-key-fn` requires an output type",
        );
        assert!(failure.errors.contains(&expected));
    }

    #[test]
    fn test_validate_validates_assign_key_input_matches_expected_input() {
        let types = SdfTypesMap::default();
        let mut partition = partition_operator();
        partition.assign_key.inputs = vec![value_param("value", "u8")];

        let failure = partition
            .validate(&types, &expected_type(), "service transforms block")
            .expect_err("should fail for assign key operator with wrong input type");

        let expected = ValidationError::new(
            "assign-key function `assign-key-fn` input type should match `s16` provided by `service transforms block` but found `u8`",
        );
        assert!(failure.errors.contains(&expected));
    }

    #[test]
    fn test_validate_validates_transforms() {
        let types = SdfTypesMap::default();
        let mut partition = partition_operator();
        partition.transforms = Transforms {
            steps: vec![TransformOperator::Filter(step_using("filter-fn"))],
        };

        let failure = partition
            .validate(&types, &expected_type(), "transforms block")
            .expect_err("should fail for invalid filter function");

        let expected = ValidationError::new(
            "transforms block is invalid: filter type function `filter-fn` should have exactly 1 input type, found 0",
        );
        assert!(failure.errors.contains(&expected));
    }

    #[test]
    fn test_operators() {
        let operators = partition_operator().operators();

        let used: Vec<&str> = operators
            .iter()
            .map(|(step, _)| step.uses.as_str())
            .collect();

        assert_eq!(used, ["assign-key-fn", "map-fn", "map-fn"]);
    }

    #[test]
    fn test_add_operator() {
        let mut partition = single_map_partition();

        let outcome = partition.add_operator(
            Some(1),
            OperatorType::Map,
            step_using("prospect_map_prospect2"),
        );

        assert!(outcome.is_ok());
        assert_eq!(partition.transforms.steps.len(), 2);
    }

    #[test]
    fn test_delete_operator() {
        let mut partition = single_map_partition();

        let outcome = partition.delete_operator(Some(0));

        assert!(outcome.is_ok());
        assert!(partition.transforms.steps.is_empty());
    }
}