webnn_graph/onnx/ops/
conditional.rs1use crate::ast::Node;
4use crate::onnx::convert::{sanitize_identifier, OnnxError};
5use crate::onnx::ops::{ConversionContext, ConversionResult, OpHandler};
6use crate::protos::onnx::NodeProto;
7use serde_json::Map;
8
/// Stateless handler for conditional-selection ONNX operators.
///
/// Its [`OpHandler`] implementation in this file accepts only the `Where`
/// op type and emits a single `where` graph node for it.
pub struct ConditionalHandler;
10
11impl OpHandler for ConditionalHandler {
12 fn supports(&self, op_type: &str) -> bool {
13 matches!(op_type, "Where")
14 }
15
16 fn convert(
17 &self,
18 node: &NodeProto,
19 context: &ConversionContext,
20 ) -> Result<ConversionResult, OnnxError> {
21 let op_type = node.op_type.as_str();
22 let node_name = if !node.name.is_empty() {
23 node.name.as_str().to_string()
24 } else {
25 "unnamed".to_string()
26 };
27
28 let inputs = node.input.as_slice();
29 if inputs.len() != 3 {
30 return Err(OnnxError::InvalidShape(format!(
31 "{} expects 3 inputs (condition, x, y), got {}",
32 op_type,
33 inputs.len()
34 )));
35 }
36
37 let output_name = if node.output.as_slice().is_empty() {
38 format!("{}_output", node_name)
39 } else {
40 sanitize_identifier(&node.output.as_slice()[0].to_string())
41 };
42
43 let condition = context.resolve_input(&inputs[0]);
45 let true_value = context.resolve_input(&inputs[1]);
46 let false_value = context.resolve_input(&inputs[2]);
47
48 let mut result = ConversionResult::new(vec![Node {
49 id: output_name.clone(),
50 op: "where".to_string(),
51 inputs: vec![condition, true_value, false_value],
52 options: Map::new(),
53 outputs: None,
54 }]);
55
56 if let Some(output) = node.output.as_slice().first() {
57 result
58 .output_mappings
59 .insert(output.to_string(), output_name.clone());
60 if let Some(dtype) = context.value_types.get(&inputs[1]) {
62 result
63 .output_types
64 .insert(output.to_string(), dtype.clone());
65 }
66 }
67
68 Ok(result)
69 }
70}
71
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::DataType;
    use crate::protos::onnx::NodeProto;
    use std::collections::HashMap;

    /// Builds a `NodeProto` of the given op type, named `test_<lowercased op>`,
    /// with the supplied input and output names; all other fields are default.
    fn create_test_node(op_type: &str, inputs: Vec<&str>, outputs: Vec<&str>) -> NodeProto {
        NodeProto {
            name: format!("test_{}", op_type.to_lowercase()),
            op_type: op_type.to_owned(),
            input: inputs.into_iter().map(String::from).collect(),
            output: outputs.into_iter().map(String::from).collect(),
            ..Default::default()
        }
    }

    /// The handler claims `Where` and rejects unrelated op types.
    #[test]
    fn test_conditional_handler_supports() {
        let handler = ConditionalHandler;

        assert!(handler.supports("Where"));
        for other in &["Add", "Greater"] {
            assert!(!handler.supports(other));
        }
    }

    /// A well-formed `Where` node becomes one `where` node whose inputs keep
    /// their order, and the output type is taken from the recorded input types.
    #[test]
    fn test_where_conversion() {
        let where_node = create_test_node("Where", vec!["condition", "x", "y"], vec!["output"]);

        let initializers = HashMap::new();
        let value_shapes = HashMap::new();
        let const_values = HashMap::new();
        let value_ids = HashMap::new();
        let value_types: HashMap<String, DataType> = ["x", "y"]
            .iter()
            .map(|name| (name.to_string(), DataType::Float32))
            .collect();

        let context = ConversionContext {
            initializers: &initializers,
            value_shapes: &value_shapes,
            value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
            const_values: &const_values,
            value_ids: &value_ids,
            value_types: &value_types,
        };

        let converted = ConditionalHandler.convert(&where_node, &context).unwrap();

        assert_eq!(converted.nodes.len(), 1);
        let emitted = &converted.nodes[0];
        assert_eq!(emitted.op, "where");
        assert_eq!(emitted.inputs.len(), 3);
        assert_eq!(emitted.inputs, vec!["condition", "x", "y"]);

        assert_eq!(
            converted.output_types.get("output"),
            Some(&DataType::Float32)
        );
    }

    /// A `Where` node with only two inputs is rejected with `InvalidShape`.
    #[test]
    fn test_where_invalid_inputs() {
        let two_input_node = create_test_node("Where", vec!["condition", "x"], vec!["output"]);

        let initializers = HashMap::new();
        let value_shapes = HashMap::new();
        let const_values = HashMap::new();
        let value_ids = HashMap::new();
        let value_types = HashMap::new();

        let context = ConversionContext {
            initializers: &initializers,
            value_shapes: &value_shapes,
            value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
            const_values: &const_values,
            value_ids: &value_ids,
            value_types: &value_types,
        };

        let outcome = ConditionalHandler.convert(&two_input_node, &context);
        assert!(outcome.is_err());
        match outcome {
            Err(OnnxError::InvalidShape(msg)) => assert!(msg.contains("expects 3 inputs")),
            _ => panic!("Expected InvalidShape error"),
        }
    }
}