1use crate::ast::Node;
4use crate::ast::{ConstDecl, ConstInit, DataType};
5use crate::onnx::convert::{sanitize_identifier, OnnxError};
6use crate::onnx::ops::{
7 normalize_axis_best_effort, ConversionContext, ConversionResult, OpHandler,
8};
9use crate::protos::onnx::NodeProto;
10use serde_json::{json, Map};
11
/// Converts ONNX shape/indexing utility ops (`Shape`, `Gather`, `Slice`,
/// `ConstantOfShape`, `Range`, `Trilu`) into WebNN graph nodes or constants.
///
/// The handler is stateless, so it is trivially copyable and defaultable.
#[derive(Debug, Default, Clone, Copy)]
pub struct UtilityHandler;
13
14impl OpHandler for UtilityHandler {
15 fn supports(&self, op_type: &str) -> bool {
16 matches!(
17 op_type,
18 "Shape" | "Gather" | "Slice" | "ConstantOfShape" | "Range" | "Trilu"
19 )
20 }
21
22 fn convert(
23 &self,
24 node: &NodeProto,
25 context: &ConversionContext,
26 ) -> Result<ConversionResult, OnnxError> {
27 let op_type = node.op_type.as_str();
28 let node_name = if !node.name.is_empty() {
29 node.name.as_str().to_string()
30 } else {
31 "unnamed".to_string()
32 };
33
34 match op_type {
35 "Shape" => self.convert_shape(node, &node_name, context),
36 "Gather" => self.convert_gather(node, &node_name, context),
37 "Slice" => self.convert_slice(node, &node_name, context),
38 "ConstantOfShape" => self.convert_constant_of_shape(node, &node_name, context),
39 "Range" => self.convert_range(node, &node_name, context),
40 "Trilu" => self.convert_trilu(node, &node_name, context),
41 _ => Err(OnnxError::UnsupportedOp {
42 op: op_type.to_string(),
43 node: node_name,
44 }),
45 }
46 }
47}
48
49impl UtilityHandler {
50 fn convert_shape(
53 &self,
54 node: &NodeProto,
55 node_name: &str,
56 context: &ConversionContext,
57 ) -> Result<ConversionResult, OnnxError> {
58 let inputs = node.input.as_slice();
59 if inputs.len() != 1 {
60 return Err(OnnxError::InvalidShape(format!(
61 "Shape expects 1 input, got {}",
62 inputs.len()
63 )));
64 }
65
66 let output_name = if node.output.as_slice().is_empty() {
67 format!("{}_output", node_name)
68 } else {
69 sanitize_identifier(&node.output.as_slice()[0].to_string())
70 };
71
72 let input0 = context.resolve_input(&inputs[0]);
73
74 let options = Map::new();
75
76 let mut result = ConversionResult::new(vec![Node {
79 id: output_name.clone(),
80 op: "shape".to_string(),
81 inputs: vec![input0],
82 options,
83 outputs: None,
84 }]);
85
86 if let Some(output) = node.output.as_slice().first() {
87 result
88 .output_mappings
89 .insert(output.to_string(), output_name.clone());
90 }
91
92 Ok(result)
93 }
94
95 fn read_scalar_i64(&self, name: &str, context: &ConversionContext) -> Option<i64> {
96 if let Some(vals) = context.const_values.get(name) {
97 return vals.first().copied();
98 }
99 if let Some(t) = context.initializers.get(name) {
100 let raw = t.raw_data.as_slice();
101 if !raw.is_empty() {
102 if t.data_type == crate::protos::onnx::TensorProto_DataType::Int32 as i32 {
103 return Some(i32::from_le_bytes([raw[0], raw[1], raw[2], raw[3]]) as i64);
104 }
105 if raw.len() >= 8 {
106 return Some(i64::from_le_bytes([
107 raw[0], raw[1], raw[2], raw[3], raw[4], raw[5], raw[6], raw[7],
108 ]));
109 }
110 } else if !t.int64_data.as_slice().is_empty() {
111 return t.int64_data.as_slice().first().copied();
112 } else if !t.int32_data.as_slice().is_empty() {
113 return t.int32_data.as_slice().first().map(|v| *v as i64);
114 }
115 }
116 None
117 }
118
119 fn convert_range(
120 &self,
121 node: &NodeProto,
122 node_name: &str,
123 context: &ConversionContext,
124 ) -> Result<ConversionResult, OnnxError> {
125 let inputs = node.input.as_slice();
126 if inputs.len() != 3 {
127 return Err(OnnxError::InvalidShape(format!(
128 "Range expects 3 inputs (start, limit, delta), got {}",
129 inputs.len()
130 )));
131 }
132
133 let output_name = if node.output.as_slice().is_empty() {
134 format!("{}_output", node_name)
135 } else {
136 sanitize_identifier(&node.output.as_slice()[0].to_string())
137 };
138
139 let start = self.read_scalar_i64(&inputs[0], context);
140 let limit = self.read_scalar_i64(&inputs[1], context);
141 let delta = self.read_scalar_i64(&inputs[2], context);
142
143 let start_dim = crate::onnx::convert::dynamic_scalar_dimension_for_value(
144 &inputs[0],
145 context.value_shape_dims,
146 );
147 if let (Some(start), Some(delta), Some(limit_dim)) = (
148 start,
149 delta,
150 crate::onnx::convert::dynamic_scalar_dimension_for_value(
151 &inputs[1],
152 context.value_shape_dims,
153 ),
154 ) {
155 let range_dim = crate::onnx::convert::dynamic_range_length_dimension(
156 start,
157 delta,
158 start_dim.as_ref(),
159 &limit_dim,
160 )
161 .ok_or_else(|| {
162 OnnxError::InvalidShape(format!(
163 "Range {} requires dynamic range length to be representable as <dim> +/- const with delta=1",
164 node_name,
165 ))
166 })?;
167
168 let max_len = usize::try_from(range_dim.max_size).map_err(|_| {
169 OnnxError::InvalidShape(format!(
170 "Range {} max size {} does not fit in usize",
171 node_name, range_dim.max_size
172 ))
173 })?;
174
175 let use_runtime_start = start_dim.is_some();
176 let mut values = Vec::with_capacity(max_len.max(1));
177 let mut current = if use_runtime_start { 0 } else { start };
178 for _ in 0..max_len {
179 values.push(current);
180 current += delta;
181 }
182 if values.is_empty() {
183 values.push(if use_runtime_start { 0 } else { start });
184 }
185
186 let bytes: Vec<u8> = values
187 .iter()
188 .flat_map(|v| v.to_le_bytes().to_vec())
189 .collect();
190
191 let range_const_name = format!("{}_range_const", output_name);
192 let range_const = ConstDecl {
193 data_type: DataType::Int64,
194 shape: vec![values.len() as u32],
195 init: ConstInit::InlineBytes { bytes },
196 };
197
198 let mut options = Map::new();
199 options.insert("starts".to_string(), json!([0]));
200 options.insert(
201 "sizes".to_string(),
202 json!([{
203 "name": range_dim.name,
204 "maxSize": range_dim.max_size
205 }]),
206 );
207 options.insert("strides".to_string(), json!([1]));
208
209 let sliced_name = if use_runtime_start {
210 format!("{}_slice", output_name)
211 } else {
212 output_name.clone()
213 };
214 let mut nodes = vec![Node {
215 id: sliced_name.clone(),
216 op: "slice".to_string(),
217 inputs: vec![range_const_name.clone()],
218 options,
219 outputs: None,
220 }];
221 if use_runtime_start {
222 nodes.push(Node {
223 id: output_name.clone(),
224 op: "add".to_string(),
225 inputs: vec![sliced_name, context.resolve_input(&inputs[0])],
226 options: Map::new(),
227 outputs: None,
228 });
229 }
230
231 let mut result = ConversionResult::new(nodes);
232 result.consts.push((range_const_name, range_const));
233 if let Some(out) = node.output.as_slice().first() {
234 result
235 .output_mappings
236 .insert(out.to_string(), output_name.clone());
237 result.output_types.insert(out.to_string(), DataType::Int64);
238 }
239 return Ok(result);
240 }
241
242 let start = start.ok_or_else(|| {
243 OnnxError::InvalidShape(format!(
244 "Range {} requires a constant scalar start input",
245 node_name
246 ))
247 })?;
248 let limit = limit.ok_or_else(|| {
249 OnnxError::InvalidShape(format!(
250 "Range {} requires a constant scalar or supported dynamic limit input",
251 node_name
252 ))
253 })?;
254 let delta = delta.ok_or_else(|| {
255 OnnxError::InvalidShape(format!(
256 "Range {} requires a constant scalar delta input",
257 node_name
258 ))
259 })?;
260
261 if delta == 0 {
262 return Err(OnnxError::InvalidShape(
263 "Range delta cannot be zero".to_string(),
264 ));
265 }
266
267 let mut values = Vec::new();
268 let mut v = start;
269 if delta > 0 {
270 while v < limit {
271 values.push(v);
272 v += delta;
273 }
274 } else {
275 while v > limit {
276 values.push(v);
277 v += delta;
278 }
279 }
280
281 if values.is_empty() {
282 values.push(0);
283 }
284
285 let bytes: Vec<u8> = values
286 .iter()
287 .flat_map(|v| v.to_le_bytes().to_vec())
288 .collect();
289
290 let const_decl = ConstDecl {
291 data_type: DataType::Int64,
292 shape: vec![values.len() as u32],
293 init: ConstInit::InlineBytes { bytes },
294 };
295
296 let mut result = ConversionResult::new(vec![]);
297 result.consts.push((output_name.clone(), const_decl));
298 if let Some(out) = node.output.as_slice().first() {
299 result
300 .output_mappings
301 .insert(out.to_string(), output_name.clone());
302 result.output_types.insert(out.to_string(), DataType::Int64);
303 }
304
305 Ok(result)
306 }
307
308 fn convert_trilu(
309 &self,
310 node: &NodeProto,
311 node_name: &str,
312 context: &ConversionContext,
313 ) -> Result<ConversionResult, OnnxError> {
314 let inputs = node.input.as_slice();
315 if inputs.is_empty() {
316 return Err(OnnxError::InvalidShape(
317 "Trilu expects at least 1 input (data)".to_string(),
318 ));
319 }
320
321 if inputs.len() > 2 {
322 return Err(OnnxError::InvalidShape(format!(
323 "Trilu expects at most 2 inputs (data, k), got {}",
324 inputs.len()
325 )));
326 }
327
328 let mut upper = true;
329 for attr in node.attribute.as_slice() {
330 if attr.name.as_str() == "upper" {
331 upper = attr.i != 0;
332 }
333 }
334
335 let mut k: i64 = 0;
336 if inputs.len() == 2 {
337 let k_input = inputs[1].as_str();
338 if let Some(offset) = self.read_scalar_i64(k_input, context) {
339 k = offset;
340 } else {
341 return Err(OnnxError::InvalidShape(
342 "Trilu k input must be a constant scalar for WebNN".to_string(),
343 ));
344 }
345 }
346
347 let output_name = if node.output.as_slice().is_empty() {
348 format!("{}_output", node_name)
349 } else {
350 sanitize_identifier(&node.output.as_slice()[0].to_string())
351 };
352
353 let input0 = context.resolve_input(&inputs[0]);
354
355 let mut options = Map::new();
356 options.insert("upper".to_string(), json!(upper));
357 options.insert("k".to_string(), json!(k));
358
359 let mut result = ConversionResult::new(vec![Node {
360 id: output_name.clone(),
361 op: "triangular".to_string(),
362 inputs: vec![input0],
363 options,
364 outputs: None,
365 }]);
366
367 if let Some(output) = node.output.as_slice().first() {
368 result
369 .output_mappings
370 .insert(output.to_string(), output_name.clone());
371 if let Some(dtype) = context.value_types.get(&inputs[0]) {
372 result
373 .output_types
374 .insert(output.to_string(), dtype.clone());
375 }
376 }
377
378 Ok(result)
379 }
380
381 fn convert_constant_of_shape(
383 &self,
384 node: &NodeProto,
385 node_name: &str,
386 context: &ConversionContext,
387 ) -> Result<ConversionResult, OnnxError> {
388 let output_name = if node.output.as_slice().is_empty() {
389 format!("{}_output", node_name)
390 } else {
391 sanitize_identifier(&node.output.as_slice()[0].to_string())
392 };
393
394 let output_dim_shape = node
395 .output
396 .as_slice()
397 .first()
398 .and_then(|out| {
399 let out_s = out.to_string();
400 context
401 .value_shape_dims
402 .get(&out_s)
403 .or_else(|| context.value_shape_dims.get(&sanitize_identifier(&out_s)))
404 .or_else(|| context.value_shape_dims.get(out_s.trim_start_matches('/')))
405 })
406 .cloned();
407
408 let mut shape: Option<Vec<i64>> = None;
410 if let Some(out) = node.output.as_slice().first() {
411 if let Some(s) = context.value_shapes.get(out) {
412 shape = Some(s.clone());
413 } else {
414 let sanitized = sanitize_identifier(out);
415 if let Some(s) = context.value_shapes.get(&sanitized) {
416 shape = Some(s.clone());
417 }
418 }
419 }
420 if shape.is_none() {
421 if let Some(shape_input) = node.input.as_slice().first() {
422 if let Some(vals) = context.const_values.get(shape_input) {
423 shape = Some(vals.clone());
424 } else if let Some(len_shape) = context.value_shapes.get(shape_input) {
425 if len_shape.len() == 1 && len_shape[0] > 0 {
427 shape = Some(vec![1; len_shape[0] as usize]);
428 }
429 }
430 }
431 }
432
433 let mut fill_value_i64: i64 = 0;
435 let mut dtype = DataType::Int64;
436 for attr in node.attribute.as_slice() {
437 if attr.name.as_str() == "value" {
438 if let Some(t) = attr.t.as_ref() {
439 match t.data_type {
440 x if x == crate::protos::onnx::TensorProto_DataType::Float as i32 => {
442 dtype = DataType::Float32;
443 if !t.float_data.as_slice().is_empty() {
444 fill_value_i64 = t.float_data.as_slice()[0].to_bits() as i64;
445 } else if !t.raw_data.as_slice().is_empty()
446 && t.raw_data.as_slice().len() >= 4
447 {
448 let raw = &t.raw_data.as_slice()[..4];
449 let bits = u32::from_le_bytes([raw[0], raw[1], raw[2], raw[3]]);
450 fill_value_i64 = bits as i64;
451 } else {
452 fill_value_i64 = 0f32.to_bits() as i64;
453 }
454 }
455 x if x == crate::protos::onnx::TensorProto_DataType::Int64 as i32 => {
457 dtype = DataType::Int64;
458 if !t.int64_data.as_slice().is_empty() {
459 fill_value_i64 = t.int64_data.as_slice()[0];
460 } else if !t.raw_data.as_slice().is_empty()
461 && t.raw_data.as_slice().len() >= 8
462 {
463 let raw = &t.raw_data.as_slice()[..8];
464 fill_value_i64 = i64::from_le_bytes([
465 raw[0], raw[1], raw[2], raw[3], raw[4], raw[5], raw[6], raw[7],
466 ]);
467 }
468 }
469 _ => {}
470 }
471 }
472 }
473 }
474
475 if let Some(dims) = output_dim_shape.as_ref().filter(|dims| {
476 dims.iter()
477 .any(|d| matches!(d, crate::ast::Dimension::Dynamic(_)))
478 }) {
479 let scalar_name = format!("{}_fill", output_name);
480 let scalar_bytes = match dtype {
481 DataType::Float32 => {
482 let f = f32::from_bits(fill_value_i64 as u32);
483 f.to_le_bytes().to_vec()
484 }
485 _ => fill_value_i64.to_le_bytes().to_vec(),
486 };
487 let scalar_decl = ConstDecl {
488 data_type: dtype.clone(),
489 shape: vec![1],
490 init: ConstInit::InlineBytes {
491 bytes: scalar_bytes,
492 },
493 };
494
495 let new_shape: Vec<serde_json::Value> = dims
496 .iter()
497 .map(|d| match d {
498 crate::ast::Dimension::Static(v) => serde_json::json!(v),
499 crate::ast::Dimension::Dynamic(dd) => serde_json::json!({
500 "name": dd.name,
501 "maxSize": dd.max_size
502 }),
503 })
504 .collect();
505
506 let mut options = Map::new();
507 options.insert("newShape".to_string(), serde_json::json!(new_shape));
508
509 let mut result = ConversionResult::new(vec![Node {
510 id: output_name.clone(),
511 op: "expand".to_string(),
512 inputs: vec![scalar_name.clone()],
513 options,
514 outputs: None,
515 }]);
516 result.consts.push((scalar_name, scalar_decl));
517 if let Some(out) = node.output.as_slice().first() {
518 result
519 .output_mappings
520 .insert(out.to_string(), output_name.clone());
521 result.output_types.insert(out.to_string(), dtype);
522 }
523 return Ok(result);
524 }
525
526 let shape = shape.unwrap_or_else(|| vec![1]);
527
528 let mut numel: usize = 1;
529 for d in &shape {
530 if *d <= 0 {
531 return Err(OnnxError::InvalidShape(format!(
532 "ConstantOfShape '{}' has non-positive dimension {:?}",
533 node_name, shape
534 )));
535 }
536 numel = numel.saturating_mul(*d as usize);
537 }
538
539 let bytes = match dtype {
540 DataType::Float32 => {
541 let f = f32::from_bits(fill_value_i64 as u32);
542 let val = f.to_le_bytes();
543 val.repeat(numel)
544 }
545 _ => {
546 let val = fill_value_i64.to_le_bytes();
547 val.repeat(numel)
548 }
549 };
550
551 let const_decl = ConstDecl {
552 data_type: dtype.clone(),
553 shape: shape.iter().map(|d| *d as u32).collect(),
554 init: ConstInit::InlineBytes { bytes },
555 };
556
557 let mut result = ConversionResult::new(vec![]);
558 result.consts.push((output_name.clone(), const_decl));
559 if let Some(out) = node.output.as_slice().first() {
560 result
561 .output_mappings
562 .insert(out.to_string(), output_name.clone());
563 result.output_types.insert(out.to_string(), dtype);
564 }
565
566 Ok(result)
567 }
568
569 fn convert_gather(
572 &self,
573 node: &NodeProto,
574 node_name: &str,
575 context: &ConversionContext,
576 ) -> Result<ConversionResult, OnnxError> {
577 let inputs = node.input.as_slice();
578 if inputs.len() < 2 {
579 return Err(OnnxError::InvalidShape(format!(
580 "Gather expects 2 inputs (data, indices), got {}",
581 inputs.len()
582 )));
583 }
584
585 let mut axis = 0i64;
587 for attr in node.attribute.as_slice() {
588 if attr.name.as_str() == "axis" && attr.i != 0 {
589 axis = attr.i;
590 }
591 }
592
593 let output_name = if node.output.as_slice().is_empty() {
594 format!("{}_output", node_name)
595 } else {
596 sanitize_identifier(&node.output.as_slice()[0].to_string())
597 };
598
599 let input0 = context.resolve_input(&inputs[0]);
600 let input1 = context.resolve_input(&inputs[1]);
601
602 let axis = if let Some(rank) = context.input_rank(inputs[0].as_str()) {
603 normalize_axis_best_effort(axis, rank)
604 } else {
605 axis
606 };
607
608 let mut options = Map::new();
609 options.insert("axis".to_string(), serde_json::json!(axis));
610
611 if let (Some(data_shape), Some(indices_shape)) = (
613 context.value_shapes.get(&inputs[0]),
614 context.value_shapes.get(&inputs[1]),
615 ) {
616 let resolved_axis = axis;
617 if resolved_axis >= 0 && (resolved_axis as usize) < data_shape.len() {
618 let axis_idx = resolved_axis as usize;
619 let mut out_shape = Vec::new();
620 out_shape.extend_from_slice(&data_shape[..axis_idx]);
621 out_shape.extend(indices_shape.iter().cloned());
622 if axis_idx < data_shape.len() {
623 out_shape.extend_from_slice(&data_shape[axis_idx + 1..]);
624 }
625 options.insert("shape".to_string(), serde_json::json!(out_shape));
626 }
627 }
628
629 let mut result = ConversionResult::new(vec![Node {
630 id: output_name.clone(),
631 op: "gather".to_string(),
632 inputs: vec![input0, input1],
633 options,
634 outputs: None,
635 }]);
636
637 if let Some(output) = node.output.as_slice().first() {
638 result
639 .output_mappings
640 .insert(output.to_string(), output_name.clone());
641 if let Some(dtype) = context.value_types.get(&inputs[0]) {
642 result
643 .output_types
644 .insert(output.to_string(), dtype.clone());
645 }
646 }
647
648 Ok(result)
649 }
650
651 fn convert_slice(
654 &self,
655 node: &NodeProto,
656 node_name: &str,
657 context: &ConversionContext,
658 ) -> Result<ConversionResult, OnnxError> {
659 let inputs = node.input.as_slice();
660 if inputs.is_empty() {
661 return Err(OnnxError::InvalidShape(
662 "Slice expects at least 1 input".to_string(),
663 ));
664 }
665
666 let output_name = if node.output.as_slice().is_empty() {
667 format!("{}_output", node_name)
668 } else {
669 sanitize_identifier(&node.output.as_slice()[0].to_string())
670 };
671
672 let input0 = context.resolve_input(&inputs[0]);
673
674 let read_ints = |name: &str, context: &ConversionContext| -> Option<Vec<i64>> {
675 if let Some(vals) = context.const_values.get(name) {
676 return Some(vals.clone());
677 }
678 if let Some(t) = context.initializers.get(name) {
679 let raw = t.raw_data.as_slice();
680 if !raw.is_empty() {
681 if t.data_type == crate::protos::onnx::TensorProto_DataType::Int32 as i32 {
682 return Some(
683 raw.chunks_exact(4)
684 .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as i64)
685 .collect(),
686 );
687 }
688 return Some(
689 raw.chunks_exact(8)
690 .map(|c| {
691 i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
692 })
693 .collect(),
694 );
695 } else if !t.int64_data.as_slice().is_empty() {
696 return Some(t.int64_data.as_slice().to_vec());
697 } else if !t.int32_data.as_slice().is_empty() {
698 return Some(t.int32_data.as_slice().iter().map(|&v| v as i64).collect());
699 }
700 }
701 None
702 };
703
704 let mut options = Map::new();
705
706 if inputs.len() >= 3 {
709 let starts_name = inputs[1].as_str();
710 let ends_name = inputs[2].as_str();
711 let mut starts = read_ints(starts_name, context);
712 let mut ends = read_ints(ends_name, context);
713
714 if starts.is_none() || ends.is_none() {
715 if let Some(s) = context.const_values.get(starts_name) {
718 starts = Some(s.clone());
719 }
720 if let Some(e) = context.const_values.get(ends_name) {
721 ends = Some(e.clone());
722 }
723
724 let fallback_len = if let Some(axes_name) = inputs.get(3).map(|s| s.as_str()) {
725 read_ints(axes_name, context)
726 .map(|v| v.len())
727 .unwrap_or_else(|| {
728 starts
729 .as_ref()
730 .map(|v| v.len())
731 .or_else(|| {
732 context
733 .value_shapes
734 .get(inputs[0].as_str())
735 .map(|s| s.len())
736 })
737 .unwrap_or(1)
738 })
739 } else {
740 starts
741 .as_ref()
742 .map(|v| v.len())
743 .or_else(|| {
744 context
745 .value_shapes
746 .get(inputs[0].as_str())
747 .map(|s| s.len())
748 })
749 .unwrap_or(1)
750 };
751
752 starts.get_or_insert(vec![0; fallback_len]);
753 ends.get_or_insert(vec![i64::MAX; fallback_len]);
755
756 crate::debug_println!(
757 "[slice] using fallback starts/ends for {}, starts={:?} ends={:?}",
758 node_name,
759 starts,
760 ends
761 );
762 }
763
764 let starts = starts.ok_or_else(|| {
765 OnnxError::InvalidShape("Slice starts must be constant for WebNN".to_string())
766 })?;
767 let ends = ends.ok_or_else(|| {
768 OnnxError::InvalidShape("Slice ends must be constant for WebNN".to_string())
769 })?;
770
771 let mut axes_opt: Option<Vec<i64>> = None;
774 if inputs.len() >= 4 {
775 let axes_name = inputs[3].as_str();
776 if let Some(axes) = read_ints(axes_name, context) {
777 axes_opt = Some(axes);
778 }
779 }
780
781 let desired_len = axes_opt
782 .as_ref()
783 .map(|a| a.len())
784 .unwrap_or_else(|| starts.len().max(ends.len()));
785 let mut starts_norm = starts;
786 let mut ends_norm = ends;
787 if starts_norm.len() > desired_len {
788 starts_norm.truncate(desired_len);
789 } else {
790 starts_norm.resize(desired_len, 0);
791 }
792 if ends_norm.len() > desired_len {
793 ends_norm.truncate(desired_len);
794 } else {
795 let fill = context
797 .value_shapes
798 .get(inputs[0].as_str())
799 .and_then(|s| s.first())
800 .copied()
801 .unwrap_or(i64::MAX);
802 ends_norm.resize(desired_len, fill);
803 }
804
805 if let Some(input_shape) = context.resolve_shape(inputs[0].as_str()) {
806 let rank = input_shape.len();
807 let mut axes = if let Some(a) = axes_opt {
808 if a.is_empty() {
809 (0..desired_len as i64).collect::<Vec<_>>()
810 } else {
811 a
812 }
813 } else {
814 (0..desired_len as i64).collect::<Vec<_>>()
815 };
816 if axes.len() != desired_len {
817 axes.resize(desired_len, 0);
818 }
819 let axes: Vec<i64> = axes
820 .iter()
821 .map(|&a| normalize_axis_best_effort(a, rank))
822 .collect();
823
824 let mut steps = if inputs.len() >= 5 {
825 let steps_name = inputs[4].as_str();
826 read_ints(steps_name, context).unwrap_or_default()
827 } else {
828 Vec::new()
829 };
830 if steps.len() > desired_len {
831 steps.truncate(desired_len);
832 } else {
833 steps.resize(desired_len, 1);
834 }
835
836 let mut dense_starts = vec![0i64; rank];
837 let mut dense_sizes: Vec<i64> = input_shape.clone();
838 let mut dense_strides = vec![1i64; rank];
839
840 let ends_dims = context.value_shape_dims.get(ends_name).or_else(|| {
842 context
843 .value_shape_dims
844 .get(&sanitize_identifier(ends_name))
845 });
846
847 let mut dynamic_size_info: Vec<Option<crate::ast::DynamicDimension>> =
849 vec![None; rank];
850
851 for i in 0..desired_len {
852 let axis = axes[i] as usize;
853 let dim = input_shape[axis];
854 let step = steps[i];
855 if step <= 0 {
856 return Err(OnnxError::InvalidShape(
857 "Slice currently requires positive step values".to_string(),
858 ));
859 }
860
861 let mut start = starts_norm[i];
862 let mut end = ends_norm[i];
863 if start < 0 {
864 start += dim;
865 }
866 if end == i64::MAX {
867 end = dim;
868 } else if end < 0 {
869 end += dim;
870 }
871 start = start.clamp(0, dim);
872 end = end.clamp(0, dim);
873
874 let size = if end <= start {
875 0
876 } else {
877 (end - start + step - 1) / step
878 };
879
880 if let Some(dims) = ends_dims {
882 if let Some(crate::ast::Dimension::Dynamic(dd)) = dims.get(i) {
883 dynamic_size_info[axis] = Some(crate::ast::DynamicDimension {
884 name: dd.name.clone(),
885 max_size: size as u32,
886 });
887 }
888 }
889
890 dense_starts[axis] = start;
891 dense_sizes[axis] = size;
892 dense_strides[axis] = step;
893 }
894
895 options.insert("starts".to_string(), serde_json::json!(dense_starts));
896
897 let has_dynamic = dynamic_size_info.iter().any(|d| d.is_some());
899 if has_dynamic {
900 let sizes_json: Vec<serde_json::Value> = dense_sizes
901 .iter()
902 .zip(dynamic_size_info.iter())
903 .map(|(&sz, dyn_info)| match dyn_info {
904 Some(dd) => serde_json::json!({
905 "name": dd.name,
906 "maxSize": dd.max_size
907 }),
908 None => serde_json::json!(sz),
909 })
910 .collect();
911 options.insert("sizes".to_string(), serde_json::json!(sizes_json));
912 } else {
913 options.insert("sizes".to_string(), serde_json::json!(dense_sizes));
914 }
915
916 options.insert("strides".to_string(), serde_json::json!(dense_strides));
917 } else {
918 options.insert("starts".to_string(), serde_json::json!(starts_norm));
920 options.insert("ends".to_string(), serde_json::json!(ends_norm));
921 if let Some(axes) = axes_opt {
922 options.insert("axes".to_string(), serde_json::json!(axes));
923 }
924 if inputs.len() >= 5 {
925 let steps_name = inputs[4].as_str();
926 if let Some(steps) = read_ints(steps_name, context) {
927 options.insert("steps".to_string(), serde_json::json!(steps));
928 }
929 }
930 }
931 } else {
932 for attr in node.attribute.as_slice() {
934 match attr.name.as_str() {
935 "starts" => {
936 options
937 .insert("starts".to_string(), serde_json::json!(&attr.ints.to_vec()));
938 }
939 "ends" => {
940 options.insert("ends".to_string(), serde_json::json!(&attr.ints.to_vec()));
941 }
942 "axes" => {
943 options.insert("axes".to_string(), serde_json::json!(&attr.ints.to_vec()));
944 }
945 "steps" => {
946 options.insert("steps".to_string(), serde_json::json!(&attr.ints.to_vec()));
947 }
948 _ => {}
949 }
950 }
951 if !options.contains_key("starts") || !options.contains_key("ends") {
952 return Err(OnnxError::InvalidShape(
953 "Slice requires static starts/ends".to_string(),
954 ));
955 }
956
957 if let Some(input_shape) = context.resolve_shape(inputs[0].as_str()) {
958 let rank = input_shape.len();
959 let starts = options
960 .remove("starts")
961 .and_then(|v| serde_json::from_value::<Vec<i64>>(v).ok())
962 .ok_or_else(|| OnnxError::InvalidShape("Slice starts malformed".to_string()))?;
963 let ends = options
964 .remove("ends")
965 .and_then(|v| serde_json::from_value::<Vec<i64>>(v).ok())
966 .ok_or_else(|| OnnxError::InvalidShape("Slice ends malformed".to_string()))?;
967 let axes = options
968 .remove("axes")
969 .and_then(|v| serde_json::from_value::<Vec<i64>>(v).ok())
970 .unwrap_or_else(|| (0..starts.len() as i64).collect::<Vec<_>>());
971 let mut steps = options
972 .remove("steps")
973 .and_then(|v| serde_json::from_value::<Vec<i64>>(v).ok())
974 .unwrap_or_else(|| vec![1; starts.len()]);
975
976 let desired_len = starts.len().max(ends.len()).max(axes.len());
977 let mut starts = starts;
978 let mut ends = ends;
979 let mut axes = axes;
980 if starts.len() < desired_len {
981 starts.resize(desired_len, 0);
982 }
983 if ends.len() < desired_len {
984 ends.resize(desired_len, i64::MAX);
985 }
986 if axes.len() < desired_len {
987 axes.resize(desired_len, 0);
988 }
989 if steps.len() < desired_len {
990 steps.resize(desired_len, 1);
991 }
992
993 let axes: Vec<i64> = axes
994 .iter()
995 .map(|&a| normalize_axis_best_effort(a, rank))
996 .collect();
997 let mut dense_starts = vec![0i64; rank];
998 let mut dense_sizes: Vec<i64> = input_shape.clone();
999 let mut dense_strides = vec![1i64; rank];
1000
1001 for i in 0..desired_len {
1002 let axis = axes[i] as usize;
1003 let dim = input_shape[axis];
1004 let step = steps[i];
1005 if step <= 0 {
1006 return Err(OnnxError::InvalidShape(
1007 "Slice currently requires positive step values".to_string(),
1008 ));
1009 }
1010
1011 let mut start = starts[i];
1012 let mut end = ends[i];
1013 if start < 0 {
1014 start += dim;
1015 }
1016 if end == i64::MAX {
1017 end = dim;
1018 } else if end < 0 {
1019 end += dim;
1020 }
1021 start = start.clamp(0, dim);
1022 end = end.clamp(0, dim);
1023
1024 let size = if end <= start {
1025 0
1026 } else {
1027 (end - start + step - 1) / step
1028 };
1029
1030 dense_starts[axis] = start;
1031 dense_sizes[axis] = size;
1032 dense_strides[axis] = step;
1033 }
1034
1035 options.insert("starts".to_string(), serde_json::json!(dense_starts));
1036 options.insert("sizes".to_string(), serde_json::json!(dense_sizes));
1037 options.insert("strides".to_string(), serde_json::json!(dense_strides));
1038 }
1039 }
1040
1041 let mut result = ConversionResult::new(vec![Node {
1042 id: output_name.clone(),
1043 op: "slice".to_string(),
1044 inputs: vec![input0],
1045 options,
1046 outputs: None,
1047 }]);
1048
1049 if let Some(output) = node.output.as_slice().first() {
1050 result
1051 .output_mappings
1052 .insert(output.to_string(), output_name.clone());
1053 if let Some(dtype) = context.value_types.get(&inputs[0]) {
1054 result
1055 .output_types
1056 .insert(output.to_string(), dtype.clone());
1057 }
1058 }
1059
1060 Ok(result)
1061 }
1062}
1063
1064#[cfg(test)]
1065mod tests {
1066 use super::*;
1067 use crate::ast::DataType;
1068 use crate::protos::onnx::{AttributeProto, NodeProto, TensorProto, TensorProto_DataType};
1069 use serde_json::json;
1070
1071 fn create_test_node(op_type: &str, inputs: Vec<&str>, outputs: Vec<&str>) -> NodeProto {
1072 NodeProto {
1073 op_type: op_type.to_string(),
1074 name: format!("test_{}", op_type.to_lowercase()),
1075 input: inputs.iter().map(|s| s.to_string()).collect(),
1076 output: outputs.iter().map(|s| s.to_string()).collect(),
1077 ..Default::default()
1078 }
1079 }
1080
1081 fn add_int_attribute(node: &mut NodeProto, name: &str, value: i64) {
1082 let attr = AttributeProto {
1083 name: name.to_string(),
1084 i: value,
1085 ..Default::default()
1086 };
1087 node.attribute.push(attr);
1088 }
1089
1090 #[test]
1091 fn test_utility_handler_supports() {
1092 let handler = UtilityHandler;
1093 assert!(handler.supports("Shape"));
1094 assert!(handler.supports("Gather"));
1095 assert!(handler.supports("Slice"));
1096 assert!(!handler.supports("Add"));
1097 }
1098
1099 #[test]
1100 fn test_convert_shape() {
1101 let handler = UtilityHandler;
1102 let node = create_test_node("Shape", vec!["x"], vec!["shape"]);
1103 let initializers = std::collections::HashMap::new();
1104 let value_shapes = std::collections::HashMap::new();
1105 let const_values = std::collections::HashMap::new();
1106 let value_ids = std::collections::HashMap::new();
1107 let value_types = std::collections::HashMap::new();
1108 let context = ConversionContext {
1109 initializers: &initializers,
1110 value_shapes: &value_shapes,
1111 value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
1112 const_values: &const_values,
1113 value_ids: &value_ids,
1114 value_types: &value_types,
1115 };
1116
1117 let result = handler.convert(&node, &context).unwrap();
1118 assert_eq!(result.nodes.len(), 1);
1119 assert_eq!(result.nodes[0].op, "shape");
1120 assert_eq!(result.nodes[0].inputs, vec!["x"]);
1121 }
1122
1123 #[test]
1124 fn test_convert_gather() {
1125 let handler = UtilityHandler;
1126 let mut node = create_test_node("Gather", vec!["data", "indices"], vec!["output"]);
1127 add_int_attribute(&mut node, "axis", -1);
1128 let initializers = std::collections::HashMap::new();
1129 let mut value_shapes = std::collections::HashMap::new();
1130 value_shapes.insert("data".to_string(), vec![2, 3, 4]);
1131 value_shapes.insert("indices".to_string(), vec![2]);
1132 let const_values = std::collections::HashMap::new();
1133 let value_ids = std::collections::HashMap::new();
1134 let value_types = std::collections::HashMap::new();
1135 let context = ConversionContext {
1136 initializers: &initializers,
1137 value_shapes: &value_shapes,
1138 value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
1139 const_values: &const_values,
1140 value_ids: &value_ids,
1141 value_types: &value_types,
1142 };
1143
1144 let result = handler.convert(&node, &context).unwrap();
1145 assert_eq!(result.nodes.len(), 1);
1146 assert_eq!(result.nodes[0].op, "gather");
1147 assert_eq!(result.nodes[0].inputs.len(), 2);
1148 assert!(result.nodes[0].options.contains_key("axis"));
1149 assert_eq!(
1150 result.nodes[0].options.get("axis"),
1151 Some(&serde_json::json!(2))
1152 );
1153 }
1154
/// Slice whose `starts`/`ends`/`axes`/`steps` inputs are all compile-time
/// constants: the handler must fold them into per-axis `starts`/`sizes`/`strides`
/// options that cover every axis of `x`. Only `x` may remain as a runtime input,
/// and the raw ONNX option names must not leak through.
#[test]
fn test_convert_slice() {
    let slice = create_test_node(
        "Slice",
        vec!["x", "starts", "ends", "axes", "steps"],
        vec!["output"],
    );

    let value_shapes = std::collections::HashMap::from([("x".to_string(), vec![1, 128])]);
    // Slice axis 1 over [0, 128) with unit step; axis 0 is left untouched.
    let const_values = std::collections::HashMap::from([
        ("starts".to_string(), vec![0]),
        ("ends".to_string(), vec![128]),
        ("axes".to_string(), vec![1]),
        ("steps".to_string(), vec![1]),
    ]);
    let initializers = std::collections::HashMap::new();
    let value_ids = std::collections::HashMap::new();
    let value_types = std::collections::HashMap::new();
    let ctx = ConversionContext {
        initializers: &initializers,
        value_shapes: &value_shapes,
        value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
        const_values: &const_values,
        value_ids: &value_ids,
        value_types: &value_types,
    };

    let converted = UtilityHandler.convert(&slice, &ctx).unwrap();
    assert_eq!(converted.nodes.len(), 1);

    let emitted = &converted.nodes[0];
    let opts = &emitted.options;
    assert_eq!(emitted.op, "slice");
    assert_eq!(emitted.inputs, vec!["x"]);
    assert!(opts.contains_key("starts"));
    assert_eq!(opts.get("starts"), Some(&serde_json::json!([0, 0])));
    assert_eq!(opts.get("sizes"), Some(&serde_json::json!([1, 128])));
    assert_eq!(opts.get("strides"), Some(&serde_json::json!([1, 1])));
    for dropped in ["ends", "axes", "steps"] {
        assert!(!opts.contains_key(dropped));
    }
}
1203
/// ConstantOfShape whose output has symbolic dimensions recorded in
/// `value_shape_dims`: the handler must emit an `expand` of a one-element
/// constant. Its `newShape` must carry the dynamic dims (name + maxSize) rather
/// than the static `[4096, 4096]` shape taken from the constant `shape` input.
#[test]
fn test_convert_constant_of_shape_prefers_dynamic_output_dims() {
    // Scalar float 0.0 fill value, encoded as little-endian raw bytes.
    let fill_value = TensorProto {
        data_type: TensorProto_DataType::Float as i32,
        dims: vec![],
        raw_data: 0f32.to_le_bytes().to_vec(),
        ..Default::default()
    };
    let mut constant_of_shape =
        create_test_node("ConstantOfShape", vec!["shape"], vec!["output"]);
    constant_of_shape.attribute.push(AttributeProto {
        name: "value".to_string(),
        t: Some(fill_value),
        ..Default::default()
    });

    let dynamic_dim = |label: &str| {
        crate::ast::Dimension::Dynamic(crate::ast::DynamicDimension {
            name: label.to_string(),
            max_size: 4096,
        })
    };
    let value_shape_dims = std::collections::HashMap::from([(
        "output".to_string(),
        vec![
            dynamic_dim("sequence_length"),
            dynamic_dim("past_sequence_length + 1"),
        ],
    )]);
    let value_shapes =
        std::collections::HashMap::from([("output".to_string(), vec![4096, 4096])]);
    let const_values =
        std::collections::HashMap::from([("shape".to_string(), vec![4096, 4096])]);
    let initializers = std::collections::HashMap::new();
    let value_ids = std::collections::HashMap::new();
    let value_types = std::collections::HashMap::new();
    let ctx = ConversionContext {
        initializers: &initializers,
        value_shapes: &value_shapes,
        value_shape_dims: &value_shape_dims,
        const_values: &const_values,
        value_ids: &value_ids,
        value_types: &value_types,
    };

    let converted = UtilityHandler.convert(&constant_of_shape, &ctx).unwrap();
    assert_eq!(converted.nodes.len(), 1);

    let emitted = &converted.nodes[0];
    assert_eq!(emitted.op, "expand");
    assert_eq!(emitted.inputs.len(), 1);
    assert_eq!(converted.consts.len(), 1);
    assert_eq!(converted.consts[0].1.shape, vec![1]);

    let expected_shape = json!([
        {"name": "sequence_length", "maxSize": 4096},
        {"name": "past_sequence_length + 1", "maxSize": 4096}
    ]);
    assert_eq!(emitted.options.get("newShape"), Some(&expected_shape));
    assert_eq!(converted.output_types.get("output"), Some(&DataType::Float32));
}
1264
/// Trilu with neither an `upper` attribute nor a `k` input: the emitted
/// `triangular` node must default to `upper = true` and `k = 0`. The ONNX output
/// name must map to itself, and the output dtype must follow the input `x`.
#[test]
fn test_convert_trilu_defaults() {
    let trilu = create_test_node("Trilu", vec!["x"], vec!["y"]);

    let value_types = std::collections::HashMap::from([("x".to_string(), DataType::Float32)]);
    let initializers = std::collections::HashMap::new();
    let value_shapes = std::collections::HashMap::new();
    let const_values = std::collections::HashMap::new();
    let value_ids = std::collections::HashMap::new();
    let ctx = ConversionContext {
        initializers: &initializers,
        value_shapes: &value_shapes,
        value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
        const_values: &const_values,
        value_ids: &value_ids,
        value_types: &value_types,
    };

    let converted = UtilityHandler.convert(&trilu, &ctx).unwrap();
    assert_eq!(converted.nodes.len(), 1);

    let emitted = &converted.nodes[0];
    assert_eq!(emitted.op, "triangular");
    assert_eq!(emitted.inputs, vec!["x"]);
    assert_eq!(emitted.options.get("upper"), Some(&json!(true)));
    assert_eq!(emitted.options.get("k"), Some(&json!(0)));
    assert_eq!(converted.output_mappings.get("y"), Some(&"y".to_string()));
    assert_eq!(converted.output_types.get("y"), Some(&DataType::Float32));
}
1293
/// Trilu with `upper = 0` and a constant `k` input of 2: the emitted `triangular`
/// node must carry `upper = false` and `k = 2`. It must take only `x` as a runtime
/// input (the constant `k` is folded) and keep the Float16 dtype of `x` on the
/// output.
#[test]
fn test_convert_trilu_with_k_and_lower() {
    let mut trilu = create_test_node("Trilu", vec!["x", "k"], vec!["y"]);
    add_int_attribute(&mut trilu, "upper", 0);

    let const_values = std::collections::HashMap::from([("k".to_string(), vec![2])]);
    let value_types = std::collections::HashMap::from([("x".to_string(), DataType::Float16)]);
    let initializers = std::collections::HashMap::new();
    let value_shapes = std::collections::HashMap::new();
    let value_ids = std::collections::HashMap::new();
    let ctx = ConversionContext {
        initializers: &initializers,
        value_shapes: &value_shapes,
        value_shape_dims: crate::onnx::ops::empty_value_shape_dims(),
        const_values: &const_values,
        value_ids: &value_ids,
        value_types: &value_types,
    };

    let converted = UtilityHandler.convert(&trilu, &ctx).unwrap();
    assert_eq!(converted.nodes.len(), 1);

    let emitted = &converted.nodes[0];
    assert_eq!(emitted.op, "triangular");
    assert_eq!(emitted.inputs, vec!["x"]);
    assert_eq!(emitted.options.get("upper"), Some(&json!(false)));
    assert_eq!(emitted.options.get("k"), Some(&json!(2)));
    assert_eq!(converted.output_mappings.get("y"), Some(&"y".to_string()));
    assert_eq!(converted.output_types.get("y"), Some(&DataType::Float16));
}
1324}