Skip to main content

webnn_graph/onnx/
ir.rs

1// Minimal ONNX IR structures used for static shape/type inference.
2use crate::ast::DataType;
3use std::collections::HashMap;
4
/// One axis of a tensor shape.
///
/// A dimension is either a concrete extent (`Known`) or not statically known
/// (`Unknown`). An unknown dimension may carry an optional label, presumably the
/// symbolic dimension name from the model (e.g. ONNX `dim_param`); confirm against
/// the importer that builds these.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Dim {
    /// Concrete extent of this axis.
    Known(i64),
    /// Extent not known statically; optional label for the axis.
    Unknown(Option<String>),
}
10
11#[derive(Debug, Clone, PartialEq, Eq)]
12pub struct TensorShape {
13    pub dims: Vec<Dim>,
14}
15
16impl TensorShape {
17    pub fn from_known(dims: Vec<i64>) -> Self {
18        Self {
19            dims: dims.into_iter().map(Dim::Known).collect(),
20        }
21    }
22
23    pub fn is_static(&self) -> bool {
24        self.dims.iter().all(|d| matches!(d, Dim::Known(_)))
25    }
26
27    pub fn to_i64(&self) -> Option<Vec<i64>> {
28        self.dims
29            .iter()
30            .map(|d| match d {
31                Dim::Known(v) => Some(*v),
32                Dim::Unknown(_) => None,
33            })
34            .collect()
35    }
36}
37
38#[derive(Debug, Clone, PartialEq, Eq)]
39pub struct TensorType {
40    pub data_type: DataType,
41    pub shape: TensorShape,
42}
43
44#[derive(Debug, Clone)]
45pub struct ValueInfo {
46    pub id: String,
47    pub ty: Option<TensorType>,
48    pub producer: Option<String>,
49    pub consumers: Vec<String>,
50}
51
52impl ValueInfo {
53    pub fn new(id: String) -> Self {
54        Self {
55            id,
56            ty: None,
57            producer: None,
58            consumers: Vec::new(),
59        }
60    }
61}
62
63#[derive(Debug, Default, Clone)]
64pub struct OnnxIrGraph {
65    pub values: HashMap<String, ValueInfo>,
66}
67
68impl OnnxIrGraph {
69    pub fn value_or_insert(&mut self, id: &str) -> &mut ValueInfo {
70        self.values
71            .entry(id.to_string())
72            .or_insert_with(|| ValueInfo::new(id.to_string()))
73    }
74}