Skip to main content

webnn_graph/onnx/ops/
utility.rs

// Utility operators: Shape, Gather, Slice, ConstantOfShape, Range, Trilu
2
3use crate::ast::Node;
4use crate::ast::{ConstDecl, ConstInit, DataType};
5use crate::onnx::convert::{sanitize_identifier, OnnxError};
6use crate::onnx::ops::{
7    normalize_axis_best_effort, ConversionContext, ConversionResult, OpHandler,
8};
9use crate::protos::onnx::NodeProto;
10use serde_json::{json, Map};
11
12pub struct UtilityHandler;
13
14impl OpHandler for UtilityHandler {
15    fn supports(&self, op_type: &str) -> bool {
16        matches!(
17            op_type,
18            "Shape" | "Gather" | "Slice" | "ConstantOfShape" | "Range" | "Trilu"
19        )
20    }
21
22    fn convert(
23        &self,
24        node: &NodeProto,
25        context: &ConversionContext,
26    ) -> Result<ConversionResult, OnnxError> {
27        let op_type = node.op_type.as_str();
28        let node_name = if !node.name.is_empty() {
29            node.name.as_str().to_string()
30        } else {
31            "unnamed".to_string()
32        };
33
34        match op_type {
35            "Shape" => self.convert_shape(node, &node_name, context),
36            "Gather" => self.convert_gather(node, &node_name, context),
37            "Slice" => self.convert_slice(node, &node_name, context),
38            "ConstantOfShape" => self.convert_constant_of_shape(node, &node_name, context),
39            "Range" => self.convert_range(node, &node_name, context),
40            "Trilu" => self.convert_trilu(node, &node_name, context),
41            _ => Err(OnnxError::UnsupportedOp {
42                op: op_type.to_string(),
43                node: node_name,
44            }),
45        }
46    }
47}
48
49impl UtilityHandler {
50    /// Convert ONNX Shape to WebNN shape operation
51    /// Returns a 1D tensor containing the dimensions of the input
52    fn convert_shape(
53        &self,
54        node: &NodeProto,
55        node_name: &str,
56        context: &ConversionContext,
57    ) -> Result<ConversionResult, OnnxError> {
58        let inputs = node.input.as_slice();
59        if inputs.len() != 1 {
60            return Err(OnnxError::InvalidShape(format!(
61                "Shape expects 1 input, got {}",
62                inputs.len()
63            )));
64        }
65
66        let output_name = if node.output.as_slice().is_empty() {
67            format!("{}_output", node_name)
68        } else {
69            sanitize_identifier(&node.output.as_slice()[0].to_string())
70        };
71
72        let input0 = context.resolve_input(&inputs[0]);
73
74        let options = Map::new();
75
76        // WebNN doesn't have a direct shape operation, but we can use identity
77        // and mark it with metadata that this is a shape operation
78        let mut result = ConversionResult::new(vec![Node {
79            id: output_name.clone(),
80            op: "shape".to_string(),
81            inputs: vec![input0],
82            options,
83            outputs: None,
84        }]);
85
86        if let Some(output) = node.output.as_slice().first() {
87            result
88                .output_mappings
89                .insert(output.to_string(), output_name.clone());
90        }
91
92        Ok(result)
93    }
94
95    fn read_scalar_i64(&self, name: &str, context: &ConversionContext) -> Option<i64> {
96        if let Some(vals) = context.const_values.get(name) {
97            return vals.first().copied();
98        }
99        if let Some(t) = context.initializers.get(name) {
100            let raw = t.raw_data.as_slice();
101            if !raw.is_empty() {
102                if t.data_type == crate::protos::onnx::TensorProto_DataType::Int32 as i32 {
103                    return Some(i32::from_le_bytes([raw[0], raw[1], raw[2], raw[3]]) as i64);
104                }
105                if raw.len() >= 8 {
106                    return Some(i64::from_le_bytes([
107                        raw[0], raw[1], raw[2], raw[3], raw[4], raw[5], raw[6], raw[7],
108                    ]));
109                }
110            } else if !t.int64_data.as_slice().is_empty() {
111                return t.int64_data.as_slice().first().copied();
112            } else if !t.int32_data.as_slice().is_empty() {
113                return t.int32_data.as_slice().first().map(|v| *v as i64);
114            }
115        }
116        None
117    }
118
119    fn convert_range(
120        &self,
121        node: &NodeProto,
122        node_name: &str,
123        context: &ConversionContext,
124    ) -> Result<ConversionResult, OnnxError> {
125        let inputs = node.input.as_slice();
126        if inputs.len() != 3 {
127            return Err(OnnxError::InvalidShape(format!(
128                "Range expects 3 inputs (start, limit, delta), got {}",
129                inputs.len()
130            )));
131        }
132
133        let output_name = if node.output.as_slice().is_empty() {
134            format!("{}_output", node_name)
135        } else {
136            sanitize_identifier(&node.output.as_slice()[0].to_string())
137        };
138
139        let start = self.read_scalar_i64(&inputs[0], context);
140        let limit = self.read_scalar_i64(&inputs[1], context);
141        let delta = self.read_scalar_i64(&inputs[2], context);
142
143        let start_dim = crate::onnx::convert::dynamic_scalar_dimension_for_value(
144            &inputs[0],
145            context.value_shape_dims,
146        );
147        if let (Some(start), Some(delta), Some(limit_dim)) = (
148            start,
149            delta,
150            crate::onnx::convert::dynamic_scalar_dimension_for_value(
151                &inputs[1],
152                context.value_shape_dims,
153            ),
154        ) {
155            let range_dim = crate::onnx::convert::dynamic_range_length_dimension(
156                start,
157                delta,
158                start_dim.as_ref(),
159                &limit_dim,
160            )
161            .ok_or_else(|| {
162                OnnxError::InvalidShape(format!(
163                    "Range {} requires dynamic range length to be representable as <dim> +/- const with delta=1",
164                    node_name,
165                ))
166            })?;
167
168            let max_len = usize::try_from(range_dim.max_size).map_err(|_| {
169                OnnxError::InvalidShape(format!(
170                    "Range {} max size {} does not fit in usize",
171                    node_name, range_dim.max_size
172                ))
173            })?;
174
175            let use_runtime_start = start_dim.is_some();
176            let mut values = Vec::with_capacity(max_len.max(1));
177            let mut current = if use_runtime_start { 0 } else { start };
178            for _ in 0..max_len {
179                values.push(current);
180                current += delta;
181            }
182            if values.is_empty() {
183                values.push(if use_runtime_start { 0 } else { start });
184            }
185
186            let bytes: Vec<u8> = values
187                .iter()
188                .flat_map(|v| v.to_le_bytes().to_vec())
189                .collect();
190
191            let range_const_name = format!("{}_range_const", output_name);
192            let range_const = ConstDecl {
193                data_type: DataType::Int64,
194                shape: vec![values.len() as u32],
195                init: ConstInit::InlineBytes { bytes },
196            };
197
198            let mut options = Map::new();
199            options.insert("starts".to_string(), json!([0]));
200            options.insert(
201                "sizes".to_string(),
202                json!([{
203                    "name": range_dim.name,
204                    "maxSize": range_dim.max_size
205                }]),
206            );
207            options.insert("strides".to_string(), json!([1]));
208
209            let sliced_name = if use_runtime_start {
210                format!("{}_slice", output_name)
211            } else {
212                output_name.clone()
213            };
214            let mut nodes = vec![Node {
215                id: sliced_name.clone(),
216                op: "slice".to_string(),
217                inputs: vec![range_const_name.clone()],
218                options,
219                outputs: None,
220            }];
221            if use_runtime_start {
222                nodes.push(Node {
223                    id: output_name.clone(),
224                    op: "add".to_string(),
225                    inputs: vec![sliced_name, context.resolve_input(&inputs[0])],
226                    options: Map::new(),
227                    outputs: None,
228                });
229            }
230
231            let mut result = ConversionResult::new(nodes);
232            result.consts.push((range_const_name, range_const));
233            if let Some(out) = node.output.as_slice().first() {
234                result
235                    .output_mappings
236                    .insert(out.to_string(), output_name.clone());
237                result.output_types.insert(out.to_string(), DataType::Int64);
238            }
239            return Ok(result);
240        }
241
242        let start = start.ok_or_else(|| {
243            OnnxError::InvalidShape(format!(
244                "Range {} requires a constant scalar start input",
245                node_name
246            ))
247        })?;
248        let limit = limit.ok_or_else(|| {
249            OnnxError::InvalidShape(format!(
250                "Range {} requires a constant scalar or supported dynamic limit input",
251                node_name
252            ))
253        })?;
254        let delta = delta.ok_or_else(|| {
255            OnnxError::InvalidShape(format!(
256                "Range {} requires a constant scalar delta input",
257                node_name
258            ))
259        })?;
260
261        if delta == 0 {
262            return Err(OnnxError::InvalidShape(
263                "Range delta cannot be zero".to_string(),
264            ));
265        }
266
267        let mut values = Vec::new();
268        let mut v = start;
269        if delta > 0 {
270            while v < limit {
271                values.push(v);
272                v += delta;
273            }
274        } else {
275            while v > limit {
276                values.push(v);
277                v += delta;
278            }
279        }
280
281        if values.is_empty() {
282            values.push(0);
283        }
284
285        let bytes: Vec<u8> = values
286            .iter()
287            .flat_map(|v| v.to_le_bytes().to_vec())
288            .collect();
289
290        let const_decl = ConstDecl {
291            data_type: DataType::Int64,
292            shape: vec![values.len() as u32],
293            init: ConstInit::InlineBytes { bytes },
294        };
295
296        let mut result = ConversionResult::new(vec![]);
297        result.consts.push((output_name.clone(), const_decl));
298        if let Some(out) = node.output.as_slice().first() {
299            result
300                .output_mappings
301                .insert(out.to_string(), output_name.clone());
302            result.output_types.insert(out.to_string(), DataType::Int64);
303        }
304
305        Ok(result)
306    }
307
308    fn convert_trilu(
309        &self,
310        node: &NodeProto,
311        node_name: &str,
312        context: &ConversionContext,
313    ) -> Result<ConversionResult, OnnxError> {
314        let inputs = node.input.as_slice();
315        if inputs.is_empty() {
316            return Err(OnnxError::InvalidShape(
317                "Trilu expects at least 1 input (data)".to_string(),
318            ));
319        }
320
321        if inputs.len() > 2 {
322            return Err(OnnxError::InvalidShape(format!(
323                "Trilu expects at most 2 inputs (data, k), got {}",
324                inputs.len()
325            )));
326        }
327
328        let mut upper = true;
329        for attr in node.attribute.as_slice() {
330            if attr.name.as_str() == "upper" {
331                upper = attr.i != 0;
332            }
333        }
334
335        let mut k: i64 = 0;
336        if inputs.len() == 2 {
337            let k_input = inputs[1].as_str();
338            if let Some(offset) = self.read_scalar_i64(k_input, context) {
339                k = offset;
340            } else {
341                return Err(OnnxError::InvalidShape(
342                    "Trilu k input must be a constant scalar for WebNN".to_string(),
343                ));
344            }
345        }
346
347        let output_name = if node.output.as_slice().is_empty() {
348            format!("{}_output", node_name)
349        } else {
350            sanitize_identifier(&node.output.as_slice()[0].to_string())
351        };
352
353        let input0 = context.resolve_input(&inputs[0]);
354
355        let mut options = Map::new();
356        options.insert("upper".to_string(), json!(upper));
357        options.insert("k".to_string(), json!(k));
358
359        let mut result = ConversionResult::new(vec![Node {
360            id: output_name.clone(),
361            op: "triangular".to_string(),
362            inputs: vec![input0],
363            options,
364            outputs: None,
365        }]);
366
367        if let Some(output) = node.output.as_slice().first() {
368            result
369                .output_mappings
370                .insert(output.to_string(), output_name.clone());
371            if let Some(dtype) = context.value_types.get(&inputs[0]) {
372                result
373                    .output_types
374                    .insert(output.to_string(), dtype.clone());
375            }
376        }
377
378        Ok(result)
379    }
380
381    /// Convert ConstantOfShape into an inline constant when the output shape is statically known.
382    fn convert_constant_of_shape(
383        &self,
384        node: &NodeProto,
385        node_name: &str,
386        context: &ConversionContext,
387    ) -> Result<ConversionResult, OnnxError> {
388        let output_name = if node.output.as_slice().is_empty() {
389            format!("{}_output", node_name)
390        } else {
391            sanitize_identifier(&node.output.as_slice()[0].to_string())
392        };
393
394        let output_dim_shape = node
395            .output
396            .as_slice()
397            .first()
398            .and_then(|out| {
399                let out_s = out.to_string();
400                context
401                    .value_shape_dims
402                    .get(&out_s)
403                    .or_else(|| context.value_shape_dims.get(&sanitize_identifier(&out_s)))
404                    .or_else(|| context.value_shape_dims.get(out_s.trim_start_matches('/')))
405            })
406            .cloned();
407
408        // Determine the target shape: prefer inferred output shape, otherwise try the shape input const.
409        let mut shape: Option<Vec<i64>> = None;
410        if let Some(out) = node.output.as_slice().first() {
411            if let Some(s) = context.value_shapes.get(out) {
412                shape = Some(s.clone());
413            } else {
414                let sanitized = sanitize_identifier(out);
415                if let Some(s) = context.value_shapes.get(&sanitized) {
416                    shape = Some(s.clone());
417                }
418            }
419        }
420        if shape.is_none() {
421            if let Some(shape_input) = node.input.as_slice().first() {
422                if let Some(vals) = context.const_values.get(shape_input) {
423                    shape = Some(vals.clone());
424                } else if let Some(len_shape) = context.value_shapes.get(shape_input) {
425                    // If we only know the length of the shape tensor, default the dims to 1s.
426                    if len_shape.len() == 1 && len_shape[0] > 0 {
427                        shape = Some(vec![1; len_shape[0] as usize]);
428                    }
429                }
430            }
431        }
432
433        // Determine fill value and data type (default int64 zero)
434        let mut fill_value_i64: i64 = 0;
435        let mut dtype = DataType::Int64;
436        for attr in node.attribute.as_slice() {
437            if attr.name.as_str() == "value" {
438                if let Some(t) = attr.t.as_ref() {
439                    match t.data_type {
440                        // FLOAT
441                        x if x == crate::protos::onnx::TensorProto_DataType::Float as i32 => {
442                            dtype = DataType::Float32;
443                            if !t.float_data.as_slice().is_empty() {
444                                fill_value_i64 = t.float_data.as_slice()[0].to_bits() as i64;
445                            } else if !t.raw_data.as_slice().is_empty()
446                                && t.raw_data.as_slice().len() >= 4
447                            {
448                                let raw = &t.raw_data.as_slice()[..4];
449                                let bits = u32::from_le_bytes([raw[0], raw[1], raw[2], raw[3]]);
450                                fill_value_i64 = bits as i64;
451                            } else {
452                                fill_value_i64 = 0f32.to_bits() as i64;
453                            }
454                        }
455                        // INT64
456                        x if x == crate::protos::onnx::TensorProto_DataType::Int64 as i32 => {
457                            dtype = DataType::Int64;
458                            if !t.int64_data.as_slice().is_empty() {
459                                fill_value_i64 = t.int64_data.as_slice()[0];
460                            } else if !t.raw_data.as_slice().is_empty()
461                                && t.raw_data.as_slice().len() >= 8
462                            {
463                                let raw = &t.raw_data.as_slice()[..8];
464                                fill_value_i64 = i64::from_le_bytes([
465                                    raw[0], raw[1], raw[2], raw[3], raw[4], raw[5], raw[6], raw[7],
466                                ]);
467                            }
468                        }
469                        _ => {}
470                    }
471                }
472            }
473        }
474
475        if let Some(dims) = output_dim_shape.as_ref().filter(|dims| {
476            dims.iter()
477                .any(|d| matches!(d, crate::ast::Dimension::Dynamic(_)))
478        }) {
479            let scalar_name = format!("{}_fill", output_name);
480            let scalar_bytes = match dtype {
481                DataType::Float32 => {
482                    let f = f32::from_bits(fill_value_i64 as u32);
483                    f.to_le_bytes().to_vec()
484                }
485                _ => fill_value_i64.to_le_bytes().to_vec(),
486            };
487            let scalar_decl = ConstDecl {
488                data_type: dtype.clone(),
489                shape: vec![1],
490                init: ConstInit::InlineBytes {
491                    bytes: scalar_bytes,
492                },
493            };
494
495            let new_shape: Vec<serde_json::Value> = dims
496                .iter()
497                .map(|d| match d {
498                    crate::ast::Dimension::Static(v) => serde_json::json!(v),
499                    crate::ast::Dimension::Dynamic(dd) => serde_json::json!({
500                        "name": dd.name,
501                        "maxSize": dd.max_size
502                    }),
503                })
504                .collect();
505
506            let mut options = Map::new();
507            options.insert("newShape".to_string(), serde_json::json!(new_shape));
508
509            let mut result = ConversionResult::new(vec![Node {
510                id: output_name.clone(),
511                op: "expand".to_string(),
512                inputs: vec![scalar_name.clone()],
513                options,
514                outputs: None,
515            }]);
516            result.consts.push((scalar_name, scalar_decl));
517            if let Some(out) = node.output.as_slice().first() {
518                result
519                    .output_mappings
520                    .insert(out.to_string(), output_name.clone());
521                result.output_types.insert(out.to_string(), dtype);
522            }
523            return Ok(result);
524        }
525
526        let shape = shape.unwrap_or_else(|| vec![1]);
527
528        let mut numel: usize = 1;
529        for d in &shape {
530            if *d <= 0 {
531                return Err(OnnxError::InvalidShape(format!(
532                    "ConstantOfShape '{}' has non-positive dimension {:?}",
533                    node_name, shape
534                )));
535            }
536            numel = numel.saturating_mul(*d as usize);
537        }
538
539        let bytes = match dtype {
540            DataType::Float32 => {
541                let f = f32::from_bits(fill_value_i64 as u32);
542                let val = f.to_le_bytes();
543                val.repeat(numel)
544            }
545            _ => {
546                let val = fill_value_i64.to_le_bytes();
547                val.repeat(numel)
548            }
549        };
550
551        let const_decl = ConstDecl {
552            data_type: dtype.clone(),
553            shape: shape.iter().map(|d| *d as u32).collect(),
554            init: ConstInit::InlineBytes { bytes },
555        };
556
557        let mut result = ConversionResult::new(vec![]);
558        result.consts.push((output_name.clone(), const_decl));
559        if let Some(out) = node.output.as_slice().first() {
560            result
561                .output_mappings
562                .insert(out.to_string(), output_name.clone());
563            result.output_types.insert(out.to_string(), dtype);
564        }
565
566        Ok(result)
567    }
568
569    /// Convert ONNX Gather to WebNN gather
570    /// Gathers elements along a specified axis using indices
571    fn convert_gather(
572        &self,
573        node: &NodeProto,
574        node_name: &str,
575        context: &ConversionContext,
576    ) -> Result<ConversionResult, OnnxError> {
577        let inputs = node.input.as_slice();
578        if inputs.len() < 2 {
579            return Err(OnnxError::InvalidShape(format!(
580                "Gather expects 2 inputs (data, indices), got {}",
581                inputs.len()
582            )));
583        }
584
585        // Extract axis attribute (default: 0)
586        let mut axis = 0i64;
587        for attr in node.attribute.as_slice() {
588            if attr.name.as_str() == "axis" && attr.i != 0 {
589                axis = attr.i;
590            }
591        }
592
593        let output_name = if node.output.as_slice().is_empty() {
594            format!("{}_output", node_name)
595        } else {
596            sanitize_identifier(&node.output.as_slice()[0].to_string())
597        };
598
599        let input0 = context.resolve_input(&inputs[0]);
600        let input1 = context.resolve_input(&inputs[1]);
601
602        let axis = if let Some(rank) = context.input_rank(inputs[0].as_str()) {
603            normalize_axis_best_effort(axis, rank)
604        } else {
605            axis
606        };
607
608        let mut options = Map::new();
609        options.insert("axis".to_string(), serde_json::json!(axis));
610
611        // Propagate output shape metadata when available so downstream ops see correct ranks
612        if let (Some(data_shape), Some(indices_shape)) = (
613            context.value_shapes.get(&inputs[0]),
614            context.value_shapes.get(&inputs[1]),
615        ) {
616            let resolved_axis = axis;
617            if resolved_axis >= 0 && (resolved_axis as usize) < data_shape.len() {
618                let axis_idx = resolved_axis as usize;
619                let mut out_shape = Vec::new();
620                out_shape.extend_from_slice(&data_shape[..axis_idx]);
621                out_shape.extend(indices_shape.iter().cloned());
622                if axis_idx < data_shape.len() {
623                    out_shape.extend_from_slice(&data_shape[axis_idx + 1..]);
624                }
625                options.insert("shape".to_string(), serde_json::json!(out_shape));
626            }
627        }
628
629        let mut result = ConversionResult::new(vec![Node {
630            id: output_name.clone(),
631            op: "gather".to_string(),
632            inputs: vec![input0, input1],
633            options,
634            outputs: None,
635        }]);
636
637        if let Some(output) = node.output.as_slice().first() {
638            result
639                .output_mappings
640                .insert(output.to_string(), output_name.clone());
641            if let Some(dtype) = context.value_types.get(&inputs[0]) {
642                result
643                    .output_types
644                    .insert(output.to_string(), dtype.clone());
645            }
646        }
647
648        Ok(result)
649    }
650
651    /// Convert ONNX Slice to WebNN slice
652    /// Extracts a slice from the input tensor
653    fn convert_slice(
654        &self,
655        node: &NodeProto,
656        node_name: &str,
657        context: &ConversionContext,
658    ) -> Result<ConversionResult, OnnxError> {
659        let inputs = node.input.as_slice();
660        if inputs.is_empty() {
661            return Err(OnnxError::InvalidShape(
662                "Slice expects at least 1 input".to_string(),
663            ));
664        }
665
666        let output_name = if node.output.as_slice().is_empty() {
667            format!("{}_output", node_name)
668        } else {
669            sanitize_identifier(&node.output.as_slice()[0].to_string())
670        };
671
672        let input0 = context.resolve_input(&inputs[0]);
673
674        let read_ints = |name: &str, context: &ConversionContext| -> Option<Vec<i64>> {
675            if let Some(vals) = context.const_values.get(name) {
676                return Some(vals.clone());
677            }
678            if let Some(t) = context.initializers.get(name) {
679                let raw = t.raw_data.as_slice();
680                if !raw.is_empty() {
681                    if t.data_type == crate::protos::onnx::TensorProto_DataType::Int32 as i32 {
682                        return Some(
683                            raw.chunks_exact(4)
684                                .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as i64)
685                                .collect(),
686                        );
687                    }
688                    return Some(
689                        raw.chunks_exact(8)
690                            .map(|c| {
691                                i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
692                            })
693                            .collect(),
694                    );
695                } else if !t.int64_data.as_slice().is_empty() {
696                    return Some(t.int64_data.as_slice().to_vec());
697                } else if !t.int32_data.as_slice().is_empty() {
698                    return Some(t.int32_data.as_slice().iter().map(|&v| v as i64).collect());
699                }
700            }
701            None
702        };
703
704        let mut options = Map::new();
705
706        // In opset >= 10, starts/ends/axes/steps are inputs
707        // WebNN requires static values, so we enforce const-ness here.
708        if inputs.len() >= 3 {
709            let starts_name = inputs[1].as_str();
710            let ends_name = inputs[2].as_str();
711            let mut starts = read_ints(starts_name, context);
712            let mut ends = read_ints(ends_name, context);
713
714            if starts.is_none() || ends.is_none() {
715                // As a last resort, try to pull starts/ends from sibling consts
716                // produced by earlier shape inference passes.
717                if let Some(s) = context.const_values.get(starts_name) {
718                    starts = Some(s.clone());
719                }
720                if let Some(e) = context.const_values.get(ends_name) {
721                    ends = Some(e.clone());
722                }
723
724                let fallback_len = if let Some(axes_name) = inputs.get(3).map(|s| s.as_str()) {
725                    read_ints(axes_name, context)
726                        .map(|v| v.len())
727                        .unwrap_or_else(|| {
728                            starts
729                                .as_ref()
730                                .map(|v| v.len())
731                                .or_else(|| {
732                                    context
733                                        .value_shapes
734                                        .get(inputs[0].as_str())
735                                        .map(|s| s.len())
736                                })
737                                .unwrap_or(1)
738                        })
739                } else {
740                    starts
741                        .as_ref()
742                        .map(|v| v.len())
743                        .or_else(|| {
744                            context
745                                .value_shapes
746                                .get(inputs[0].as_str())
747                                .map(|s| s.len())
748                        })
749                        .unwrap_or(1)
750                };
751
752                starts.get_or_insert(vec![0; fallback_len]);
753                // Keep Slice dynamic when ONNX ends input is non-const.
754                ends.get_or_insert(vec![i64::MAX; fallback_len]);
755
756                crate::debug_println!(
757                    "[slice] using fallback starts/ends for {}, starts={:?} ends={:?}",
758                    node_name,
759                    starts,
760                    ends
761                );
762            }
763
764            let starts = starts.ok_or_else(|| {
765                OnnxError::InvalidShape("Slice starts must be constant for WebNN".to_string())
766            })?;
767            let ends = ends.ok_or_else(|| {
768                OnnxError::InvalidShape("Slice ends must be constant for WebNN".to_string())
769            })?;
770
771            // Normalize lengths: starts/ends must match axes length if provided,
772            // otherwise match each other.
773            let mut axes_opt: Option<Vec<i64>> = None;
774            if inputs.len() >= 4 {
775                let axes_name = inputs[3].as_str();
776                if let Some(axes) = read_ints(axes_name, context) {
777                    axes_opt = Some(axes);
778                }
779            }
780
781            let desired_len = axes_opt
782                .as_ref()
783                .map(|a| a.len())
784                .unwrap_or_else(|| starts.len().max(ends.len()));
785            let mut starts_norm = starts;
786            let mut ends_norm = ends;
787            if starts_norm.len() > desired_len {
788                starts_norm.truncate(desired_len);
789            } else {
790                starts_norm.resize(desired_len, 0);
791            }
792            if ends_norm.len() > desired_len {
793                ends_norm.truncate(desired_len);
794            } else {
795                // If we know data shape, use its dims; otherwise use max i64.
796                let fill = context
797                    .value_shapes
798                    .get(inputs[0].as_str())
799                    .and_then(|s| s.first())
800                    .copied()
801                    .unwrap_or(i64::MAX);
802                ends_norm.resize(desired_len, fill);
803            }
804
805            if let Some(input_shape) = context.resolve_shape(inputs[0].as_str()) {
806                let rank = input_shape.len();
807                let mut axes = if let Some(a) = axes_opt {
808                    if a.is_empty() {
809                        (0..desired_len as i64).collect::<Vec<_>>()
810                    } else {
811                        a
812                    }
813                } else {
814                    (0..desired_len as i64).collect::<Vec<_>>()
815                };
816                if axes.len() != desired_len {
817                    axes.resize(desired_len, 0);
818                }
819                let axes: Vec<i64> = axes
820                    .iter()
821                    .map(|&a| normalize_axis_best_effort(a, rank))
822                    .collect();
823
824                let mut steps = if inputs.len() >= 5 {
825                    let steps_name = inputs[4].as_str();
826                    read_ints(steps_name, context).unwrap_or_default()
827                } else {
828                    Vec::new()
829                };
830                if steps.len() > desired_len {
831                    steps.truncate(desired_len);
832                } else {
833                    steps.resize(desired_len, 1);
834                }
835
836                let mut dense_starts = vec![0i64; rank];
837                let mut dense_sizes: Vec<i64> = input_shape.clone();
838                let mut dense_strides = vec![1i64; rank];
839
840                // Check if ends input has dynamic dimension metadata
841                let ends_dims = context.value_shape_dims.get(ends_name).or_else(|| {
842                    context
843                        .value_shape_dims
844                        .get(&sanitize_identifier(ends_name))
845                });
846
847                // Track which dense axes have dynamic sizes
848                let mut dynamic_size_info: Vec<Option<crate::ast::DynamicDimension>> =
849                    vec![None; rank];
850
851                for i in 0..desired_len {
852                    let axis = axes[i] as usize;
853                    let dim = input_shape[axis];
854                    let step = steps[i];
855                    if step <= 0 {
856                        return Err(OnnxError::InvalidShape(
857                            "Slice currently requires positive step values".to_string(),
858                        ));
859                    }
860
861                    let mut start = starts_norm[i];
862                    let mut end = ends_norm[i];
863                    if start < 0 {
864                        start += dim;
865                    }
866                    if end == i64::MAX {
867                        end = dim;
868                    } else if end < 0 {
869                        end += dim;
870                    }
871                    start = start.clamp(0, dim);
872                    end = end.clamp(0, dim);
873
874                    let size = if end <= start {
875                        0
876                    } else {
877                        (end - start + step - 1) / step
878                    };
879
880                    // If this end value came from a dynamic dimension, mark the size as dynamic
881                    if let Some(dims) = ends_dims {
882                        if let Some(crate::ast::Dimension::Dynamic(dd)) = dims.get(i) {
883                            dynamic_size_info[axis] = Some(crate::ast::DynamicDimension {
884                                name: dd.name.clone(),
885                                max_size: size as u32,
886                            });
887                        }
888                    }
889
890                    dense_starts[axis] = start;
891                    dense_sizes[axis] = size;
892                    dense_strides[axis] = step;
893                }
894
895                options.insert("starts".to_string(), serde_json::json!(dense_starts));
896
897                // Emit sizes with dynamic dimension metadata when present
898                let has_dynamic = dynamic_size_info.iter().any(|d| d.is_some());
899                if has_dynamic {
900                    let sizes_json: Vec<serde_json::Value> = dense_sizes
901                        .iter()
902                        .zip(dynamic_size_info.iter())
903                        .map(|(&sz, dyn_info)| match dyn_info {
904                            Some(dd) => serde_json::json!({
905                                "name": dd.name,
906                                "maxSize": dd.max_size
907                            }),
908                            None => serde_json::json!(sz),
909                        })
910                        .collect();
911                    options.insert("sizes".to_string(), serde_json::json!(sizes_json));
912                } else {
913                    options.insert("sizes".to_string(), serde_json::json!(dense_sizes));
914                }
915
916                options.insert("strides".to_string(), serde_json::json!(dense_strides));
917            } else {
918                // Fallback for unknown-rank tensors: keep ONNX-style static slice options.
919                options.insert("starts".to_string(), serde_json::json!(starts_norm));
920                options.insert("ends".to_string(), serde_json::json!(ends_norm));
921                if let Some(axes) = axes_opt {
922                    options.insert("axes".to_string(), serde_json::json!(axes));
923                }
924                if inputs.len() >= 5 {
925                    let steps_name = inputs[4].as_str();
926                    if let Some(steps) = read_ints(steps_name, context) {
927                        options.insert("steps".to_string(), serde_json::json!(steps));
928                    }
929                }
930            }
931        } else {
932            // Extract from attributes (older opset)
933            for attr in node.attribute.as_slice() {
934                match attr.name.as_str() {
935                    "starts" => {
936                        options
937                            .insert("starts".to_string(), serde_json::json!(&attr.ints.to_vec()));
938                    }
939                    "ends" => {
940                        options.insert("ends".to_string(), serde_json::json!(&attr.ints.to_vec()));
941                    }
942                    "axes" => {
943                        options.insert("axes".to_string(), serde_json::json!(&attr.ints.to_vec()));
944                    }
945                    "steps" => {
946                        options.insert("steps".to_string(), serde_json::json!(&attr.ints.to_vec()));
947                    }
948                    _ => {}
949                }
950            }
951            if !options.contains_key("starts") || !options.contains_key("ends") {
952                return Err(OnnxError::InvalidShape(
953                    "Slice requires static starts/ends".to_string(),
954                ));
955            }
956
957            if let Some(input_shape) = context.resolve_shape(inputs[0].as_str()) {
958                let rank = input_shape.len();
959                let starts = options
960                    .remove("starts")
961                    .and_then(|v| serde_json::from_value::<Vec<i64>>(v).ok())
962                    .ok_or_else(|| OnnxError::InvalidShape("Slice starts malformed".to_string()))?;
963                let ends = options
964                    .remove("ends")
965                    .and_then(|v| serde_json::from_value::<Vec<i64>>(v).ok())
966                    .ok_or_else(|| OnnxError::InvalidShape("Slice ends malformed".to_string()))?;
967                let axes = options
968                    .remove("axes")
969                    .and_then(|v| serde_json::from_value::<Vec<i64>>(v).ok())
970                    .unwrap_or_else(|| (0..starts.len() as i64).collect::<Vec<_>>());
971                let mut steps = options
972                    .remove("steps")
973                    .and_then(|v| serde_json::from_value::<Vec<i64>>(v).ok())
974                    .unwrap_or_else(|| vec![1; starts.len()]);
975
976                let desired_len = starts.len().max(ends.len()).max(axes.len());
977                let mut starts = starts;
978                let mut ends = ends;
979                let mut axes = axes;
980                if starts.len() < desired_len {
981                    starts.resize(desired_len, 0);
982                }
983                if ends.len() < desired_len {
984                    ends.resize(desired_len, i64::MAX);
985                }
986                if axes.len() < desired_len {
987                    axes.resize(desired_len, 0);
988                }
989                if steps.len() < desired_len {
990                    steps.resize(desired_len, 1);
991                }
992
993                let axes: Vec<i64> = axes
994                    .iter()
995                    .map(|&a| normalize_axis_best_effort(a, rank))
996                    .collect();
997                let mut dense_starts = vec![0i64; rank];
998                let mut dense_sizes: Vec<i64> = input_shape.clone();
999                let mut dense_strides = vec![1i64; rank];
1000
1001                for i in 0..desired_len {
1002                    let axis = axes[i] as usize;
1003                    let dim = input_shape[axis];
1004                    let step = steps[i];
1005                    if step <= 0 {
1006                        return Err(OnnxError::InvalidShape(
1007                            "Slice currently requires positive step values".to_string(),
1008                        ));
1009                    }
1010
1011                    let mut start = starts[i];
1012                    let mut end = ends[i];
1013                    if start < 0 {
1014                        start += dim;
1015                    }
1016                    if end == i64::MAX {
1017                        end = dim;
1018                    } else if end < 0 {
1019                        end += dim;
1020                    }
1021                    start = start.clamp(0, dim);
1022                    end = end.clamp(0, dim);
1023
1024                    let size = if end <= start {
1025                        0
1026                    } else {
1027                        (end - start + step - 1) / step
1028                    };
1029
1030                    dense_starts[axis] = start;
1031                    dense_sizes[axis] = size;
1032                    dense_strides[axis] = step;
1033                }
1034
1035                options.insert("starts".to_string(), serde_json::json!(dense_starts));
1036                options.insert("sizes".to_string(), serde_json::json!(dense_sizes));
1037                options.insert("strides".to_string(), serde_json::json!(dense_strides));
1038            }
1039        }
1040
1041        let mut result = ConversionResult::new(vec![Node {
1042            id: output_name.clone(),
1043            op: "slice".to_string(),
1044            inputs: vec![input0],
1045            options,
1046            outputs: None,
1047        }]);
1048
1049        if let Some(output) = node.output.as_slice().first() {
1050            result
1051                .output_mappings
1052                .insert(output.to_string(), output_name.clone());
1053            if let Some(dtype) = context.value_types.get(&inputs[0]) {
1054                result
1055                    .output_types
1056                    .insert(output.to_string(), dtype.clone());
1057            }
1058        }
1059
1060        Ok(result)
1061    }
1062}
1063
#[cfg(test)]
mod tests {
    //! Unit tests for `UtilityHandler`: op dispatch (`supports` / `convert`)
    //! and the Shape, Gather, Slice, ConstantOfShape and Trilu conversions.

    use super::*;
    use crate::ast::DataType;
    use crate::protos::onnx::{AttributeProto, NodeProto, TensorProto, TensorProto_DataType};
    use serde_json::json;

    /// Builds an attribute-less `NodeProto` named `test_<op_type lowercased>`
    /// with the given input and output value names.
    fn create_test_node(op_type: &str, inputs: Vec<&str>, outputs: Vec<&str>) -> NodeProto {
        NodeProto {
            op_type: op_type.to_string(),
            name: format!("test_{}", op_type.to_lowercase()),
            input: inputs.iter().map(|s| s.to_string()).collect(),
            output: outputs.iter().map(|s| s.to_string()).collect(),
            ..Default::default()
        }
    }

    /// Appends an integer attribute (`AttributeProto::i`) named `name` to `node`.
    fn add_int_attribute(node: &mut NodeProto, name: &str, value: i64) {
        let attr = AttributeProto {
            name: name.to_string(),
            i: value,
            ..Default::default()
        };
        node.attribute.push(attr);
    }

    /// `supports` must accept exactly the op types dispatched in `convert`,
    /// matched case-sensitively.
    #[test]
    fn test_utility_handler_supports() {
        let handler = UtilityHandler;
        for op in ["Shape", "Gather", "Slice", "ConstantOfShape", "Range", "Trilu"] {
            assert!(handler.supports(op), "expected {op} to be supported");
        }
        assert!(!handler.supports("Add"));
        // Matching is on the exact ONNX op_type spelling.
        assert!(!handler.supports("shape"));
    }

    /// An op type outside the handler's set is reported as `UnsupportedOp`,
    /// carrying both the op type and the node name.
    #[test]
    fn test_convert_unsupported_op_is_rejected() {
        let handler = UtilityHandler;
        let node = create_test_node("Add", vec!["a", "b"], vec!["c"]);
        let initializers = std::collections::HashMap::new();
        let value_shapes = std::collections::HashMap::new();
        let const_values = std::collections::HashMap::new();
        let value_ids = std::collections::HashMap::new();
        let value_types = std::collections::HashMap::new();
        let context = ConversionContext {
            initializers: &initializers,
            value_shapes: &value_shapes,
            value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
            const_values: &const_values,
            value_ids: &value_ids,
            value_types: &value_types,
        };

        let result = handler.convert(&node, &context);
        assert!(matches!(
            result,
            Err(OnnxError::UnsupportedOp { op, node: name }) if op == "Add" && name == "test_add"
        ));
    }

    /// Shape lowers to a single WebNN `shape` node over the same input.
    #[test]
    fn test_convert_shape() {
        let handler = UtilityHandler;
        let node = create_test_node("Shape", vec!["x"], vec!["shape"]);
        let initializers = std::collections::HashMap::new();
        let value_shapes = std::collections::HashMap::new();
        let const_values = std::collections::HashMap::new();
        let value_ids = std::collections::HashMap::new();
        let value_types = std::collections::HashMap::new();
        let context = ConversionContext {
            initializers: &initializers,
            value_shapes: &value_shapes,
            value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
            const_values: &const_values,
            value_ids: &value_ids,
            value_types: &value_types,
        };

        let result = handler.convert(&node, &context).unwrap();
        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].op, "shape");
        assert_eq!(result.nodes[0].inputs, vec!["x"]);
    }

    /// Gather with a negative axis gets that axis normalized against the data rank
    /// (-1 on a rank-3 tensor becomes 2).
    #[test]
    fn test_convert_gather() {
        let handler = UtilityHandler;
        let mut node = create_test_node("Gather", vec!["data", "indices"], vec!["output"]);
        add_int_attribute(&mut node, "axis", -1);
        let initializers = std::collections::HashMap::new();
        let mut value_shapes = std::collections::HashMap::new();
        value_shapes.insert("data".to_string(), vec![2, 3, 4]);
        value_shapes.insert("indices".to_string(), vec![2]);
        let const_values = std::collections::HashMap::new();
        let value_ids = std::collections::HashMap::new();
        let value_types = std::collections::HashMap::new();
        let context = ConversionContext {
            initializers: &initializers,
            value_shapes: &value_shapes,
            value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
            const_values: &const_values,
            value_ids: &value_ids,
            value_types: &value_types,
        };

        let result = handler.convert(&node, &context).unwrap();
        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].op, "gather");
        assert_eq!(result.nodes[0].inputs.len(), 2);
        assert_eq!(result.nodes[0].options.get("axis"), Some(&json!(2)));
    }

    /// A Slice with constant starts/ends/axes/steps on a known-shape input is
    /// emitted in dense WebNN form (`starts`/`sizes`/`strides` per axis), with
    /// no ONNX-style `ends`/`axes`/`steps` options left behind.
    #[test]
    fn test_convert_slice() {
        let handler = UtilityHandler;
        let node = create_test_node(
            "Slice",
            vec!["x", "starts", "ends", "axes", "steps"],
            vec!["output"],
        );
        let initializers = std::collections::HashMap::new();
        let mut value_shapes = std::collections::HashMap::new();
        value_shapes.insert("x".to_string(), vec![1, 128]);
        let mut const_values = std::collections::HashMap::new();
        const_values.insert("starts".to_string(), vec![0]);
        const_values.insert("ends".to_string(), vec![128]);
        const_values.insert("axes".to_string(), vec![1]);
        const_values.insert("steps".to_string(), vec![1]);
        let value_ids = std::collections::HashMap::new();
        let value_types = std::collections::HashMap::new();
        let context = ConversionContext {
            initializers: &initializers,
            value_shapes: &value_shapes,
            value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
            const_values: &const_values,
            value_ids: &value_ids,
            value_types: &value_types,
        };

        let result = handler.convert(&node, &context).unwrap();
        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].op, "slice");
        assert_eq!(result.nodes[0].inputs, vec!["x"]);
        let options = &result.nodes[0].options;
        assert_eq!(options.get("starts"), Some(&json!([0, 0])));
        assert_eq!(options.get("sizes"), Some(&json!([1, 128])));
        assert_eq!(options.get("strides"), Some(&json!([1, 1])));
        assert!(!options.contains_key("ends"));
        assert!(!options.contains_key("axes"));
        assert!(!options.contains_key("steps"));
    }

    /// Slice normalization on one axis of a [2, 10] input:
    /// - a negative start wraps around: -3 becomes 7;
    /// - the `i64::MAX` end sentinel means "to the end of the axis": 10;
    /// - with step 2 the size is ceil((10 - 7) / 2) = 2 (elements 7 and 9).
    ///
    /// The untouched axis 0 keeps start 0, full size and stride 1.
    #[test]
    fn test_convert_slice_negative_start_open_end_and_step() {
        let handler = UtilityHandler;
        let node = create_test_node(
            "Slice",
            vec!["x", "starts", "ends", "axes", "steps"],
            vec!["output"],
        );
        let initializers = std::collections::HashMap::new();
        let mut value_shapes = std::collections::HashMap::new();
        value_shapes.insert("x".to_string(), vec![2, 10]);
        let mut const_values = std::collections::HashMap::new();
        const_values.insert("starts".to_string(), vec![-3]);
        const_values.insert("ends".to_string(), vec![i64::MAX]);
        const_values.insert("axes".to_string(), vec![1]);
        const_values.insert("steps".to_string(), vec![2]);
        let value_ids = std::collections::HashMap::new();
        let value_types = std::collections::HashMap::new();
        let context = ConversionContext {
            initializers: &initializers,
            value_shapes: &value_shapes,
            value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
            const_values: &const_values,
            value_ids: &value_ids,
            value_types: &value_types,
        };

        let result = handler.convert(&node, &context).unwrap();
        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].op, "slice");
        let options = &result.nodes[0].options;
        assert_eq!(options.get("starts"), Some(&json!([0, 7])));
        assert_eq!(options.get("sizes"), Some(&json!([2, 2])));
        assert_eq!(options.get("strides"), Some(&json!([1, 2])));
    }

    /// ConstantOfShape becomes an `expand` of a one-element constant.
    /// When the output has symbolic dims in `value_shape_dims`, those dims are
    /// emitted as `newShape` instead of the static `const_values` shape.
    #[test]
    fn test_convert_constant_of_shape_prefers_dynamic_output_dims() {
        let handler = UtilityHandler;
        let mut node = create_test_node("ConstantOfShape", vec!["shape"], vec!["output"]);
        node.attribute.push(AttributeProto {
            name: "value".to_string(),
            t: Some(TensorProto {
                data_type: TensorProto_DataType::Float as i32,
                dims: vec![],
                raw_data: 0f32.to_le_bytes().to_vec(),
                ..Default::default()
            }),
            ..Default::default()
        });

        let initializers = std::collections::HashMap::new();
        let mut value_shapes = std::collections::HashMap::new();
        value_shapes.insert("output".to_string(), vec![4096, 4096]);
        let mut value_shape_dims = std::collections::HashMap::new();
        value_shape_dims.insert(
            "output".to_string(),
            vec![
                crate::ast::Dimension::Dynamic(crate::ast::DynamicDimension {
                    name: "sequence_length".to_string(),
                    max_size: 4096,
                }),
                crate::ast::Dimension::Dynamic(crate::ast::DynamicDimension {
                    name: "past_sequence_length + 1".to_string(),
                    max_size: 4096,
                }),
            ],
        );
        let mut const_values = std::collections::HashMap::new();
        const_values.insert("shape".to_string(), vec![4096, 4096]);
        let value_ids = std::collections::HashMap::new();
        let value_types = std::collections::HashMap::new();
        let context = ConversionContext {
            initializers: &initializers,
            value_shapes: &value_shapes,
            value_shape_dims: &value_shape_dims,
            const_values: &const_values,
            value_ids: &value_ids,
            value_types: &value_types,
        };

        let result = handler.convert(&node, &context).unwrap();
        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].op, "expand");
        assert_eq!(result.nodes[0].inputs.len(), 1);
        assert_eq!(result.consts.len(), 1);
        assert_eq!(result.consts[0].1.shape, vec![1]);
        assert_eq!(
            result.nodes[0].options.get("newShape"),
            Some(&json!([
                {"name": "sequence_length", "maxSize": 4096},
                {"name": "past_sequence_length + 1", "maxSize": 4096}
            ]))
        );
        assert_eq!(result.output_types.get("output"), Some(&DataType::Float32));
    }

    /// Trilu with no `k` input and no `upper` attribute maps to
    /// `triangular { upper: true, k: 0 }` and keeps the input dtype.
    #[test]
    fn test_convert_trilu_defaults() {
        let handler = UtilityHandler;
        let node = create_test_node("Trilu", vec!["x"], vec!["y"]);
        let initializers = std::collections::HashMap::new();
        let value_shapes = std::collections::HashMap::new();
        let const_values = std::collections::HashMap::new();
        let value_ids = std::collections::HashMap::new();
        let mut value_types = std::collections::HashMap::new();
        value_types.insert("x".to_string(), DataType::Float32);
        let context = ConversionContext {
            initializers: &initializers,
            value_shapes: &value_shapes,
            value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
            const_values: &const_values,
            value_ids: &value_ids,
            value_types: &value_types,
        };

        let result = handler.convert(&node, &context).unwrap();
        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].op, "triangular");
        assert_eq!(result.nodes[0].inputs, vec!["x"]);
        assert_eq!(result.nodes[0].options.get("upper"), Some(&json!(true)));
        assert_eq!(result.nodes[0].options.get("k"), Some(&json!(0)));
        assert_eq!(result.output_mappings.get("y"), Some(&"y".to_string()));
        assert_eq!(result.output_types.get("y"), Some(&DataType::Float32));
    }

    /// Trilu with a constant `k` input and `upper = 0` maps to a lower-triangular
    /// `triangular` node with that `k`. The constant `k` is folded into options
    /// rather than kept as a graph input.
    #[test]
    fn test_convert_trilu_with_k_and_lower() {
        let handler = UtilityHandler;
        let mut node = create_test_node("Trilu", vec!["x", "k"], vec!["y"]);
        add_int_attribute(&mut node, "upper", 0);
        let initializers = std::collections::HashMap::new();
        let value_shapes = std::collections::HashMap::new();
        let mut const_values = std::collections::HashMap::new();
        const_values.insert("k".to_string(), vec![2]);
        let value_ids = std::collections::HashMap::new();
        let mut value_types = std::collections::HashMap::new();
        value_types.insert("x".to_string(), DataType::Float16);
        let context = ConversionContext {
            initializers: &initializers,
            value_shapes: &value_shapes,
            value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
            const_values: &const_values,
            value_ids: &value_ids,
            value_types: &value_types,
        };

        let result = handler.convert(&node, &context).unwrap();
        assert_eq!(result.nodes.len(), 1);
        assert_eq!(result.nodes[0].op, "triangular");
        assert_eq!(result.nodes[0].inputs, vec!["x"]);
        assert_eq!(result.nodes[0].options.get("upper"), Some(&json!(false)));
        assert_eq!(result.nodes[0].options.get("k"), Some(&json!(2)));
        assert_eq!(result.output_mappings.get("y"), Some(&"y".to_string()));
        assert_eq!(result.output_types.get("y"), Some(&DataType::Float16));
    }
}