state_tree/
tree.rs

/// Number of header words (`readidx`, `writeidx`) that precede a `Delay` node's data words
/// in the untagged flat serialization, and that `StateTreeSkeleton::total_size` adds to a
/// Delay node's `len`.
pub const DELAY_ADDITIONAL_OFFSET: usize = 2;

/// State tree: a tree of state nodes whose leaves own their state words.
///
/// `serialize_tree_untagged` flattens it into a `Vec<u64>` without tags, and
/// `StateTree::to_skeleton` extracts the data-free shape (`StateTreeSkeleton`) needed to
/// read that flat form back.
///
/// NOTE(review): an earlier comment pointed to rkyv's `json_like_schema.rs` example "on
/// attributes", but no rkyv attributes are applied to this type — confirm whether they are
/// still intended.
#[derive(Clone, PartialEq, Eq)]
pub enum StateTree {
    /// Delay node: two index words followed by the delay buffer words.
    /// NOTE(review): `readidx`/`writeidx` are presumably positions into `data` — confirm.
    Delay {
        readidx: u64,
        writeidx: u64,
        data: Vec<u64>, // original note: assumes only mono f64 data is used
    },
    /// `Mem` node: a flat run of state words.
    Mem {
        data: Vec<u64>, // original note: assumes only mono f64 data is used
    },
    /// `Feed` node: a flat run of state words.
    Feed {
        data: Vec<u64>, // original note: generic data is assumed here, possibly a tuple of floats
    },
    /// Function-call scope holding its child nodes in order.
    FnCall(Vec<StateTree>),
}

22impl std::fmt::Display for StateTree {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        match self {
25            StateTree::Delay {
26                readidx,
27                writeidx,
28                data,
29            } => write!(
30                f,
31                "Delay(readidx: {}, writeidx: {}, data: {:?} ...)",
32                readidx,
33                writeidx,
34                data.iter().take(10).collect::<Vec<&u64>>()
35            ),
36            StateTree::Mem { data } => write!(f, "Mem(data: {data:?})"),
37            StateTree::Feed { data } => write!(f, "Feed(data: {data:?})"),
38            StateTree::FnCall(children) => {
39                let children_str: Vec<String> = children.iter().map(|c| format!("{c}")).collect();
40                write!(f, "FnCall([{}])", children_str.join(", "))
41            }
42        }
43    }
44}
45impl std::fmt::Debug for StateTree {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        match self {
48            StateTree::Delay {
49                readidx,
50                writeidx,
51                data,
52            } => write!(
53                f,
54                "Delay(readidx: {}, writeidx: {}, data: {:?} ...)",
55                readidx,
56                writeidx,
57                data.iter().take(10).collect::<Vec<&u64>>()
58            ),
59            StateTree::Mem { data } => write!(f, "Mem(data: {data:?})"),
60            StateTree::Feed { data } => write!(f, "Feed(data: {data:?})"),
61            StateTree::FnCall(children) => {
62                let children_str: Vec<String> = children.iter().map(|c| format!("{c:?}")).collect();
63                write!(f, "FnCall([{}])", children_str.join(", "))
64            }
65        }
66    }
67}
68impl<T: SizedType> From<StateTreeSkeleton<T>> for StateTree {
69    //create empty StateTree from StateTreeSkeleton
70    fn from(skeleton: StateTreeSkeleton<T>) -> Self {
71        match skeleton {
72            StateTreeSkeleton::Delay { len } => StateTree::Delay {
73                readidx: 0,
74                writeidx: 0,
75                data: vec![0; len as usize],
76            },
77            StateTreeSkeleton::Mem(t) => StateTree::Mem {
78                data: vec![0; t.word_size() as usize],
79            },
80            StateTreeSkeleton::Feed(t) => StateTree::Feed {
81                data: vec![0; t.word_size() as usize],
82            },
83            StateTreeSkeleton::FnCall(children_layout) => StateTree::FnCall(
84                children_layout
85                    .into_iter()
86                    .map(|child_layout| StateTree::from(*child_layout))
87                    .collect(),
88            ),
89        }
90    }
91}
92
93impl StateTree {
94    /// パスを指定して、イミュータブルなノードへの参照を取得する
95    pub fn get_node(&self, path: &[usize]) -> Option<&StateTree> {
96        let mut current = self;
97        for &index in path {
98            if let StateTree::FnCall(children) = current {
99                current = children.get(index)?;
100            } else {
101                // パスが深すぎるか、FnCallではないノードを指している
102                return None;
103            }
104        }
105        Some(current)
106    }
107
108    /// パスを指定して、ミュータブルなノードへの参照を取得する
109    pub fn get_node_mut(&mut self, path: &[usize]) -> Option<&mut StateTree> {
110        let mut current = self;
111        for &index in path {
112            if let StateTree::FnCall(children) = current {
113                current = children.get_mut(index)?;
114            } else {
115                // パスが深すぎるか、FnCallではないノードを指している
116                return None;
117            }
118        }
119        Some(current)
120    }
121
122    /// StateTree から StateTreeSkeleton への変換(データを除いた構造のみ)
123    pub fn to_skeleton(&self) -> StateTreeSkeleton<u64> {
124        match self {
125            StateTree::Delay { data, .. } => StateTreeSkeleton::Delay {
126                len: data.len() as u64,
127            },
128            StateTree::Mem { data } => StateTreeSkeleton::Mem(data.len() as u64),
129            StateTree::Feed { data } => StateTreeSkeleton::Feed(data.len() as u64),
130            StateTree::FnCall(children) => StateTreeSkeleton::FnCall(
131                children
132                    .iter()
133                    .map(|child| Box::new(child.to_skeleton()))
134                    .collect(),
135            ),
136        }
137    }
138}
139
140pub fn serialize_tree_untagged(tree: StateTree) -> Vec<u64> {
141    match tree {
142        StateTree::Delay {
143            readidx,
144            writeidx,
145            data,
146        } => itertools::concat([vec![readidx, writeidx], data]),
147        StateTree::Mem { data } | StateTree::Feed { data } => data,
148        StateTree::FnCall(state_trees) => {
149            itertools::concat(state_trees.into_iter().map(serialize_tree_untagged))
150        }
151    }
152}
153
/// A layout element whose size is expressed as a number of `u64` words.
///
/// `StateTreeSkeleton` uses it for its `Mem` and `Feed` slots. `Debug` is required so that
/// generic code bounded only by `SizedType` can still format layouts (e.g. for logging).
pub trait SizedType
where
    Self: std::fmt::Debug,
{
    /// Size of this element in `u64` words.
    fn word_size(&self) -> u64;
}

158impl SizedType for u64 {
159    fn word_size(&self) -> u64 {
160        *self
161    }
162}
163
164impl SizedType for usize {
165    fn word_size(&self) -> u64 {
166        *self as u64
167    }
168}
169
/// Memory layout of a state tree on a flat `u64` array; it describes structure only and
/// does not own any actual state data.
///
/// `T` supplies the size of `Mem` and `Feed` slots through `SizedType::word_size`.
/// `FnCall` children are laid out back to back, in order.
#[derive(Debug, Clone)]
pub enum StateTreeSkeleton<T: SizedType> {
    /// Delay slot with `len` data words, stored after `DELAY_ADDITIONAL_OFFSET` header words.
    Delay {
        len: u64, // original note: assumes only mono f64 data is used
    },
    /// `Mem` slot of `T::word_size()` words.
    Mem(T),
    /// `Feed` slot of `T::word_size()` words.
    Feed(T),
    /// Function-call scope containing the child layouts in order.
    FnCall(Vec<Box<StateTreeSkeleton<T>>>),
}
180impl<T: SizedType> StateTreeSkeleton<T> {
181    pub fn total_size(&self) -> u64 {
182        match self {
183            StateTreeSkeleton::Delay { len } => DELAY_ADDITIONAL_OFFSET as u64 + *len,
184            StateTreeSkeleton::Mem(t) | StateTreeSkeleton::Feed(t) => t.word_size(),
185            StateTreeSkeleton::FnCall(children_layout) => children_layout
186                .iter()
187                .map(|child_layout| child_layout.total_size())
188                .sum(),
189        }
190    }
191}
192impl<T: SizedType> PartialEq for StateTreeSkeleton<T> {
193    fn eq(&self, other: &Self) -> bool {
194        match (self, other) {
195            (Self::Delay { len: l_len }, Self::Delay { len: r_len }) => l_len == r_len,
196            (Self::Mem(l0), Self::Mem(r0)) => l0.word_size() == r0.word_size(),
197            (Self::Feed(l0), Self::Feed(r0)) => l0.word_size() == r0.word_size(),
198            (Self::FnCall(l0), Self::FnCall(r0)) => l0 == r0,
199            _ => false,
200        }
201    }
202}
203
204fn deserialize_tree_untagged_rec<T: SizedType>(
205    data: &[u64],
206    data_layout: &StateTreeSkeleton<T>,
207) -> Option<(StateTree, usize)> {
208    match data_layout {
209        StateTreeSkeleton::Delay { len } => {
210            let readidx = data.first().copied()?;
211            let writeidx = data.get(1).copied()?;
212            let d = data
213                .get(DELAY_ADDITIONAL_OFFSET..DELAY_ADDITIONAL_OFFSET + (*len as usize))?
214                .to_vec();
215            Some((
216                StateTree::Delay {
217                    readidx,
218                    writeidx,
219                    data: d,
220                },
221                DELAY_ADDITIONAL_OFFSET + (*len as usize),
222            ))
223        }
224        StateTreeSkeleton::Mem(t) => {
225            let size = t.word_size() as usize;
226            let data = data.get(0..size)?.to_vec();
227            Some((StateTree::Mem { data }, size))
228        }
229        StateTreeSkeleton::Feed(t) => {
230            let size = t.word_size() as usize;
231            let data = data.get(0..size)?.to_vec();
232            Some((StateTree::Feed { data }, size))
233        }
234        StateTreeSkeleton::FnCall(children_layout) => {
235            let (children, used) =
236                children_layout
237                    .iter()
238                    .try_fold((vec![], 0), |(v, last_used), child_layout| {
239                        let (child, used) =
240                            deserialize_tree_untagged_rec(&data[last_used..], child_layout)?;
241
242                        Some(([v, vec![child]].concat(), last_used + used))
243                    })?;
244
245            Some((StateTree::FnCall(children), used))
246        }
247    }
248}
249
250pub fn deserialize_tree_untagged<T: SizedType>(
251    data: &[u64],
252    data_layout: &StateTreeSkeleton<T>,
253) -> Option<StateTree> {
254    log::trace!("Deserializing  with layout: {data_layout:?}");
255    if let Some((tree, used)) = deserialize_tree_untagged_rec(data, data_layout) {
256        if used == data.len() { Some(tree) } else { None }
257    } else {
258        None
259    }
260}