Skip to main content

webnn_graph/onnx/ops/
conditional.rs

// Conditional operators: Where
2
3use crate::ast::Node;
4use crate::onnx::convert::{sanitize_identifier, OnnxError};
5use crate::onnx::ops::{ConversionContext, ConversionResult, OpHandler};
6use crate::protos::onnx::NodeProto;
7use serde_json::Map;
8
9pub struct ConditionalHandler;
10
11impl OpHandler for ConditionalHandler {
12    fn supports(&self, op_type: &str) -> bool {
13        matches!(op_type, "Where")
14    }
15
16    fn convert(
17        &self,
18        node: &NodeProto,
19        context: &ConversionContext,
20    ) -> Result<ConversionResult, OnnxError> {
21        let op_type = node.op_type.as_str();
22        let node_name = if !node.name.is_empty() {
23            node.name.as_str().to_string()
24        } else {
25            "unnamed".to_string()
26        };
27
28        let inputs = node.input.as_slice();
29        if inputs.len() != 3 {
30            return Err(OnnxError::InvalidShape(format!(
31                "{} expects 3 inputs (condition, x, y), got {}",
32                op_type,
33                inputs.len()
34            )));
35        }
36
37        let output_name = if node.output.as_slice().is_empty() {
38            format!("{}_output", node_name)
39        } else {
40            sanitize_identifier(&node.output.as_slice()[0].to_string())
41        };
42
43        // Resolve input names (respecting prior mappings)
44        let condition = context.resolve_input(&inputs[0]);
45        let true_value = context.resolve_input(&inputs[1]);
46        let false_value = context.resolve_input(&inputs[2]);
47
48        let mut result = ConversionResult::new(vec![Node {
49            id: output_name.clone(),
50            op: "where".to_string(),
51            inputs: vec![condition, true_value, false_value],
52            options: Map::new(),
53            outputs: None,
54        }]);
55
56        if let Some(output) = node.output.as_slice().first() {
57            result
58                .output_mappings
59                .insert(output.to_string(), output_name.clone());
60            // Where output type matches the input data type (x and y), not condition
61            if let Some(dtype) = context.value_types.get(&inputs[1]) {
62                result
63                    .output_types
64                    .insert(output.to_string(), dtype.clone());
65            }
66        }
67
68        Ok(result)
69    }
70}
71
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::DataType;
    use crate::protos::onnx::NodeProto;
    use std::collections::HashMap;

    /// Builds a `NodeProto` named `test_<lowercased op_type>` with the given
    /// input and output names.
    fn create_test_node(op_type: &str, inputs: Vec<&str>, outputs: Vec<&str>) -> NodeProto {
        NodeProto {
            op_type: op_type.to_string(),
            name: format!("test_{}", op_type.to_lowercase()),
            input: inputs.into_iter().map(String::from).collect(),
            output: outputs.into_iter().map(String::from).collect(),
            ..Default::default()
        }
    }

    /// Runs `f` against a `ConversionContext` in which only `value_types` is
    /// populated; all other lookup tables are empty.
    fn with_context<R>(
        value_types: HashMap<String, DataType>,
        f: impl FnOnce(&ConversionContext) -> R,
    ) -> R {
        let initializers = HashMap::new();
        let value_shapes = HashMap::new();
        let const_values = HashMap::new();
        let value_ids = HashMap::new();
        let context = ConversionContext {
            initializers: &initializers,
            value_shapes: &value_shapes,
            value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
            const_values: &const_values,
            value_ids: &value_ids,
            value_types: &value_types,
        };
        f(&context)
    }

    #[test]
    fn test_conditional_handler_supports() {
        let handler = ConditionalHandler;
        let cases = [("Where", true), ("Add", false), ("Greater", false)];
        for (op, expected) in cases.iter().copied() {
            assert_eq!(handler.supports(op), expected, "supports({})", op);
        }
    }

    #[test]
    fn test_where_conversion() {
        let node = create_test_node("Where", vec!["condition", "x", "y"], vec!["output"]);
        let value_types: HashMap<String, DataType> = ["x", "y"]
            .iter()
            .map(|name| (name.to_string(), DataType::Float32))
            .collect();

        let result = with_context(value_types, |ctx| ConditionalHandler.convert(&node, ctx))
            .unwrap();

        assert_eq!(result.nodes.len(), 1);
        let where_node = &result.nodes[0];
        assert_eq!(where_node.op, "where");
        assert_eq!(where_node.inputs, ["condition", "x", "y"]);

        // The output takes the data operands' type, not the condition's.
        assert_eq!(result.output_types.get("output"), Some(&DataType::Float32));
    }

    #[test]
    fn test_where_invalid_inputs() {
        // Only two inputs: `y` is missing.
        let node = create_test_node("Where", vec!["condition", "x"], vec!["output"]);

        let outcome = with_context(HashMap::new(), |ctx| ConditionalHandler.convert(&node, ctx));

        assert!(outcome.is_err());
        match outcome {
            Err(OnnxError::InvalidShape(msg)) => assert!(msg.contains("expects 3 inputs")),
            _ => panic!("Expected InvalidShape error"),
        }
    }
}