1use std::collections::HashSet;
2
3use crate::patch::CopyFromPatch;
4use crate::tree::{SizedType, StateTreeSkeleton};
5
6pub fn take_diff<T: SizedType>(
7 old_skeleton: &StateTreeSkeleton<T>,
8 new_skeleton: &StateTreeSkeleton<T>,
9) -> HashSet<CopyFromPatch> {
10 build_patches_recursive(old_skeleton, new_skeleton, vec![], vec![])
11}
12
/// One step of an alignment between an `old` and a `new` sequence, as produced by
/// [`lcs_by_score`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiffResult {
    /// `old[old_index]` is paired with `new[new_index]`.
    Common { old_index: usize, new_index: usize },
    /// `old[old_index]` has no counterpart in `new`.
    Delete { old_index: usize },
    /// `new[new_index]` has no counterpart in `old`.
    Insert { new_index: usize },
}

/// Aligns `old` with `new` using a weighted longest-common-subsequence.
///
/// `score_fn(a, b)` rates how well `a` (from `old`) pairs with `b` (from `new`). Only pairs
/// whose score is strictly greater than `0.0` may be aligned (NaN never aligns). The returned
/// alignment preserves the relative order of both sequences and maximizes the sum of the
/// scores of its [`DiffResult::Common`] steps.
///
/// Every index of `old` appears exactly once (as `Common` or `Delete`), and every index of
/// `new` exactly once (as `Common` or `Insert`). Steps are in ascending index order. Among
/// equally good alignments, a pairing is preferred over skipping. When skipping, an insertion
/// is preferred over a deletion (tie-breaking applied while walking back from the end).
///
/// `score_fn` is called exactly once for every `(old, new)` pair, in row-major order.
pub fn lcs_by_score<T>(
    old: &[T],
    new: &[T],
    mut score_fn: impl FnMut(&T, &T) -> f64,
) -> Vec<DiffResult> {
    let old_len = old.len();
    let new_len = new.len();

    // Cache every pair score. The backtrack then sees exactly the values the forward pass
    // used, even if `score_fn` is not pure.
    let mut scores: Vec<Vec<f64>> = Vec::with_capacity(old_len);
    for o in old {
        scores.push(new.iter().map(|n| score_fn(o, n)).collect());
    }

    // dp[i][j] = best total score aligning old[..i] with new[..j].
    let mut dp = vec![vec![0.0_f64; new_len + 1]; old_len + 1];
    for i in 1..=old_len {
        for j in 1..=new_len {
            let skip = dp[i - 1][j].max(dp[i][j - 1]);
            let score = scores[i - 1][j - 1];
            dp[i][j] = if score > 0.0 {
                (dp[i - 1][j - 1] + score).max(skip)
            } else {
                skip
            };
        }
    }

    let mut results = Vec::with_capacity(old_len + new_len);
    let (mut i, mut j) = (old_len, new_len);
    while i > 0 && j > 0 {
        let score = scores[i - 1][j - 1];
        // Follow the diagonal only if the optimum at (i, j) was actually reached through it.
        // A merely positive score is not enough, since a better pairing may lie elsewhere.
        // Exact float equality is intentional: when the diagonal won, dp[i][j] holds the
        // bit-identical result of this very addition.
        if score > 0.0 && dp[i][j] == dp[i - 1][j - 1] + score {
            results.push(DiffResult::Common {
                old_index: i - 1,
                new_index: j - 1,
            });
            i -= 1;
            j -= 1;
        } else if dp[i][j - 1] >= dp[i - 1][j] {
            results.push(DiffResult::Insert { new_index: j - 1 });
            j -= 1;
        } else {
            results.push(DiffResult::Delete { old_index: i - 1 });
            i -= 1;
        }
    }
    // One side is exhausted; whatever remains on the other side is unmatched.
    while j > 0 {
        results.push(DiffResult::Insert { new_index: j - 1 });
        j -= 1;
    }
    while i > 0 {
        results.push(DiffResult::Delete { old_index: i - 1 });
        i -= 1;
    }

    results.reverse();
    results
}
85
86fn nodes_match<T: SizedType>(old: &StateTreeSkeleton<T>, new: &StateTreeSkeleton<T>) -> bool {
87 match (old, new) {
88 (StateTreeSkeleton::Delay { len: len1 }, StateTreeSkeleton::Delay { len: len2 }) => {
89 len1 == len2
90 }
91 (StateTreeSkeleton::Mem(t1), StateTreeSkeleton::Mem(t2)) => {
92 t1.word_size() == t2.word_size()
93 }
94 (StateTreeSkeleton::Feed(t1), StateTreeSkeleton::Feed(t2)) => {
95 t1.word_size() == t2.word_size()
96 }
97 (StateTreeSkeleton::FnCall(c1), StateTreeSkeleton::FnCall(c2)) => {
98 c1.len() == c2.len() && c1.iter().zip(c2.iter()).all(|(a, b)| nodes_match(a, b))
99 }
100 _ => false,
101 }
102}
103
104fn get_node_at_path<'a, T: SizedType>(
106 skeleton: &'a StateTreeSkeleton<T>,
107 path: &[usize],
108) -> Option<&'a StateTreeSkeleton<T>> {
109 if path.is_empty() {
110 return Some(skeleton);
111 }
112
113 match skeleton {
114 StateTreeSkeleton::FnCall(children) => {
115 let child = children.get(path[0])?;
116 get_node_at_path(child, &path[1..])
117 }
118 _ => None,
119 }
120}
121
122fn build_patches_recursive<T: SizedType>(
123 old_skeleton: &StateTreeSkeleton<T>,
124 new_skeleton: &StateTreeSkeleton<T>,
125 old_path: Vec<usize>,
126 new_path: Vec<usize>,
127) -> HashSet<CopyFromPatch> {
128 let old_node = get_node_at_path(old_skeleton, &old_path).expect("Invalid old_path");
130 let new_node = get_node_at_path(new_skeleton, &new_path).expect("Invalid new_path");
131
132 if nodes_match(old_node, new_node) {
134 let (src_addr, size) = old_skeleton
136 .path_to_address(&old_path)
137 .expect("Invalid old_path");
138 let (dst_addr, dst_size) = new_skeleton
139 .path_to_address(&new_path)
140 .expect("Invalid new_path");
141
142 debug_assert_eq!(
143 size, dst_size,
144 "Size mismatch between matched nodes at old_path {old_path:?} and new_path {new_path:?}"
145 );
146
147 return [CopyFromPatch {
148 src_addr,
149 dst_addr,
150 size,
151 }]
152 .into_iter()
153 .collect();
154 }
155
156 match (old_node, new_node) {
157 (StateTreeSkeleton::FnCall(old_children), StateTreeSkeleton::FnCall(new_children)) => {
158 let mut child_patches_map = Vec::new();
160 for old_idx in 0..old_children.len() {
161 for new_idx in 0..new_children.len() {
162 let child_old_path = [old_path.clone(), vec![old_idx]].concat();
163 let child_new_path = [new_path.clone(), vec![new_idx]].concat();
164 let patches = build_patches_recursive(
165 old_skeleton,
166 new_skeleton,
167 child_old_path,
168 child_new_path,
169 );
170 let score = if patches.is_empty() {
171 0.0
172 } else {
173 patches.len() as f64
174 };
175 child_patches_map.push(((old_idx, new_idx), patches, score));
176 }
177 }
178
179 let old_c_with_id: Vec<_> = old_children.iter().enumerate().collect();
181 let new_c_with_id: Vec<_> = new_children.iter().enumerate().collect();
182
183 let lcs_results = lcs_by_score(
184 &old_c_with_id,
185 &new_c_with_id,
186 |(oid, _old), (nid, _new)| {
187 child_patches_map
188 .iter()
189 .find(|((o, n), _, _)| o == oid && n == nid)
190 .map(|(_, _, score)| *score)
191 .unwrap_or(0.0)
192 },
193 );
194
195 let mut c_patches = HashSet::new();
197 for result in &lcs_results {
198 if let DiffResult::Common {
199 old_index,
200 new_index,
201 } = result
202 && let Some((_, patches, _)) = child_patches_map
203 .iter()
204 .find(|((o, n), _, _)| o == old_index && n == new_index)
205 {
206 c_patches.extend(patches.iter().cloned());
207 }
208 }
209
210 c_patches
211 }
212 _ => HashSet::new(),
213 }
214}