state_tree/
tree.rs

/// Number of header words stored in front of a `Delay` node's buffer in the flat
/// (untagged) layout: `readidx` first, then `writeidx`.
pub const DELAY_ADDITIONAL_OFFSET: usize = 2;
2
/// State tree: runtime state stored as nested nodes.
///
/// Leaf nodes keep their contents as raw `u64` words. `StateTreeSkeleton` describes the
/// same shape with sizes only; `serialize_tree_untagged` flattens a tree into one word
/// array and `deserialize_tree_untagged` rebuilds it from such an array.
//
// NOTE(review): an earlier comment pointed to rkyv's `json_like_schema.rs` example for
// attribute usage, but no rkyv attributes are applied to this type — confirm whether the
// reference is still relevant.
#[derive(Clone, PartialEq, Eq)]
pub enum StateTree {
    /// `Delay` state: a read index, a write index and a word buffer (presumably the
    /// buffer the indices point into — confirm).
    Delay {
        readidx: u64,
        writeidx: u64,
        /// Buffer words; the original note assumes only mono f64 data is stored.
        data: Vec<u64>,
    },
    /// `Mem` state; the original note assumes only mono f64 data is stored.
    Mem {
        data: Vec<u64>,
    },
    /// `Feed` state; per the original note the data is generic and may be a tuple of
    /// floats, so it can span several words.
    Feed {
        data: Vec<u64>,
    },
    /// Ordered child states grouped under a single function-call node.
    FnCall(Vec<StateTree>),
}
21
22impl std::fmt::Display for StateTree {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        match self {
25            StateTree::Delay {
26                readidx,
27                writeidx,
28                data,
29            } => write!(
30                f,
31                "Delay(readidx: {}, writeidx: {}, data: {:?} ...)",
32                readidx,
33                writeidx,
34                data.iter().take(10).collect::<Vec<&u64>>()
35            ),
36            StateTree::Mem { data } => write!(f, "Mem(data: {data:?})"),
37            StateTree::Feed { data } => write!(f, "Feed(data: {data:?})"),
38            StateTree::FnCall(children) => {
39                let children_str: Vec<String> = children.iter().map(|c| format!("{c}")).collect();
40                write!(f, "FnCall([{}])", children_str.join(", "))
41            }
42        }
43    }
44}
45impl std::fmt::Debug for StateTree {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        match self {
48            StateTree::Delay {
49                readidx,
50                writeidx,
51                data,
52            } => write!(
53                f,
54                "Delay(readidx: {}, writeidx: {}, data: {:?} ...)",
55                readidx,
56                writeidx,
57                data.iter().take(10).collect::<Vec<&u64>>()
58            ),
59            StateTree::Mem { data } => write!(f, "Mem(data: {data:?})"),
60            StateTree::Feed { data } => write!(f, "Feed(data: {data:?})"),
61            StateTree::FnCall(children) => {
62                let children_str: Vec<String> = children.iter().map(|c| format!("{c:?}")).collect();
63                write!(f, "FnCall([{}])", children_str.join(", "))
64            }
65        }
66    }
67}
68impl<T: SizedType> From<StateTreeSkeleton<T>> for StateTree {
69    //create empty StateTree from StateTreeSkeleton
70    fn from(skeleton: StateTreeSkeleton<T>) -> Self {
71        match skeleton {
72            StateTreeSkeleton::Delay { len } => StateTree::Delay {
73                readidx: 0,
74                writeidx: 0,
75                data: vec![0; len as usize],
76            },
77            StateTreeSkeleton::Mem(t) => StateTree::Mem {
78                data: vec![0; t.word_size() as usize],
79            },
80            StateTreeSkeleton::Feed(t) => StateTree::Feed {
81                data: vec![0; t.word_size() as usize],
82            },
83            StateTreeSkeleton::FnCall(children_layout) => StateTree::FnCall(
84                children_layout
85                    .into_iter()
86                    .map(|child_layout| StateTree::from(*child_layout))
87                    .collect(),
88            ),
89        }
90    }
91}
92
93impl StateTree {
94    /// パスを指定して、イミュータブルなノードへの参照を取得する
95    pub fn get_node(&self, path: &[usize]) -> Option<&StateTree> {
96        let mut current = self;
97        for &index in path {
98            if let StateTree::FnCall(children) = current {
99                current = children.get(index)?;
100            } else {
101                // パスが深すぎるか、FnCallではないノードを指している
102                return None;
103            }
104        }
105        Some(current)
106    }
107
108    /// パスを指定して、ミュータブルなノードへの参照を取得する
109    pub fn get_node_mut(&mut self, path: &[usize]) -> Option<&mut StateTree> {
110        let mut current = self;
111        for &index in path {
112            if let StateTree::FnCall(children) = current {
113                current = children.get_mut(index)?;
114            } else {
115                // パスが深すぎるか、FnCallではないノードを指している
116                return None;
117            }
118        }
119        Some(current)
120    }
121}
122
123pub fn serialize_tree_untagged(tree: StateTree) -> Vec<u64> {
124    match tree {
125        StateTree::Delay {
126            readidx,
127            writeidx,
128            data,
129        } => itertools::concat([vec![readidx, writeidx], data]),
130        StateTree::Mem { data } | StateTree::Feed { data } => data,
131        StateTree::FnCall(state_trees) => {
132            itertools::concat(state_trees.into_iter().map(serialize_tree_untagged))
133        }
134    }
135}
136
/// A type that reports how many `u64` words it occupies in the flat state buffer.
///
/// `StateTreeSkeleton` uses it to size its `Mem` and `Feed` slots.
pub trait SizedType: std::fmt::Debug {
    /// Number of `u64` words a value of this type needs.
    fn word_size(&self) -> u64;
}
140
141/// This data represents just a memory layout on a flat array, do not own actual data.
142#[derive(Debug, Clone)]
143pub enum StateTreeSkeleton<T: SizedType> {
144    Delay {
145        len: u64, //assume we are using only mono f64 data
146    },
147    Mem(T),
148    Feed(T),
149    FnCall(Vec<Box<StateTreeSkeleton<T>>>),
150}
151impl<T: SizedType> StateTreeSkeleton<T> {
152    pub fn total_size(&self) -> u64 {
153        match self {
154            StateTreeSkeleton::Delay { len } => DELAY_ADDITIONAL_OFFSET as u64 + *len,
155            StateTreeSkeleton::Mem(t) | StateTreeSkeleton::Feed(t) => t.word_size(),
156            StateTreeSkeleton::FnCall(children_layout) => children_layout
157                .iter()
158                .map(|child_layout| child_layout.total_size())
159                .sum(),
160        }
161    }
162}
163impl<T: SizedType> PartialEq for StateTreeSkeleton<T> {
164    fn eq(&self, other: &Self) -> bool {
165        match (self, other) {
166            (Self::Delay { len: l_len }, Self::Delay { len: r_len }) => l_len == r_len,
167            (Self::Mem(l0), Self::Mem(r0)) => l0.word_size() == r0.word_size(),
168            (Self::Feed(l0), Self::Feed(r0)) => l0.word_size() == r0.word_size(),
169            (Self::FnCall(l0), Self::FnCall(r0)) => l0 == r0,
170            _ => false,
171        }
172    }
173}
174
175fn deserialize_tree_untagged_rec<T: SizedType>(
176    data: &[u64],
177    data_layout: &StateTreeSkeleton<T>,
178) -> Option<(StateTree, usize)> {
179    match data_layout {
180        StateTreeSkeleton::Delay { len } => {
181            let readidx = data.first().copied()?;
182            let writeidx = data.get(1).copied()?;
183            let d = data
184                .get(DELAY_ADDITIONAL_OFFSET..DELAY_ADDITIONAL_OFFSET + (*len as usize))?
185                .to_vec();
186            Some((
187                StateTree::Delay {
188                    readidx,
189                    writeidx,
190                    data: d,
191                },
192                DELAY_ADDITIONAL_OFFSET + (*len as usize),
193            ))
194        }
195        StateTreeSkeleton::Mem(t) => {
196            let size = t.word_size() as usize;
197            let data = data.get(0..size)?.to_vec();
198            Some((StateTree::Mem { data }, size))
199        }
200        StateTreeSkeleton::Feed(t) => {
201            let size = t.word_size() as usize;
202            let data = data.get(0..size)?.to_vec();
203            Some((StateTree::Feed { data }, size))
204        }
205        StateTreeSkeleton::FnCall(children_layout) => {
206            let (children, used) =
207                children_layout
208                    .iter()
209                    .try_fold((vec![], 0), |(v, last_used), child_layout| {
210                        let (child, used) =
211                            deserialize_tree_untagged_rec(&data[last_used..], child_layout)?;
212
213                        Some(([v, vec![child]].concat(), last_used + used))
214                    })?;
215
216            Some((StateTree::FnCall(children), used))
217        }
218    }
219}
220
221pub fn deserialize_tree_untagged<T: SizedType>(
222    data: &[u64],
223    data_layout: &StateTreeSkeleton<T>,
224) -> Option<StateTree> {
225    log::trace!("Deserializing  with layout: {data_layout:?}");
226    if let Some((tree, used)) = deserialize_tree_untagged_rec(data, data_layout) {
227        if used == data.len() { Some(tree) } else { None }
228    } else {
229        None
230    }
231}