1use crate::ast::DataType;
3use std::collections::HashMap;
4
/// The extent of a single axis inside a [`TensorShape`].
///
/// An axis is either a concrete number ([`Dim::Known`]) or not known here
/// ([`Dim::Unknown`]).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Dim {
    /// The axis has this concrete extent.
    Known(i64),
    /// The axis extent is not known. The payload is an optional label —
    /// presumably a symbolic dimension name; confirm with the code that builds these.
    Unknown(Option<String>),
}
10
11#[derive(Debug, Clone, PartialEq, Eq)]
12pub struct TensorShape {
13 pub dims: Vec<Dim>,
14}
15
16impl TensorShape {
17 pub fn from_known(dims: Vec<i64>) -> Self {
18 Self {
19 dims: dims.into_iter().map(Dim::Known).collect(),
20 }
21 }
22
23 pub fn is_static(&self) -> bool {
24 self.dims.iter().all(|d| matches!(d, Dim::Known(_)))
25 }
26
27 pub fn to_i64(&self) -> Option<Vec<i64>> {
28 self.dims
29 .iter()
30 .map(|d| match d {
31 Dim::Known(v) => Some(*v),
32 Dim::Unknown(_) => None,
33 })
34 .collect()
35 }
36}
37
38#[derive(Debug, Clone, PartialEq, Eq)]
39pub struct TensorType {
40 pub data_type: DataType,
41 pub shape: TensorShape,
42}
43
44#[derive(Debug, Clone)]
45pub struct ValueInfo {
46 pub id: String,
47 pub ty: Option<TensorType>,
48 pub producer: Option<String>,
49 pub consumers: Vec<String>,
50}
51
52impl ValueInfo {
53 pub fn new(id: String) -> Self {
54 Self {
55 id,
56 ty: None,
57 producer: None,
58 consumers: Vec::new(),
59 }
60 }
61}
62
63#[derive(Debug, Default, Clone)]
64pub struct OnnxIrGraph {
65 pub values: HashMap<String, ValueInfo>,
66}
67
68impl OnnxIrGraph {
69 pub fn value_or_insert(&mut self, id: &str) -> &mut ValueInfo {
70 self.values
71 .entry(id.to_string())
72 .or_insert_with(|| ValueInfo::new(id.to_string()))
73 }
74}