1use crate::{memoize, Cost, Edit, Tree};
2use arrayvec::ArrayVec;
3use derive_more::{Add, From};
4use itertools::Itertools;
5use pathfinding::{num_traits::Zero, prelude::*};
6use std::{collections::HashMap, ops::Add};
7
8#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, From, Add)]
9struct WholeNumber<T>(T);
10
11impl<T: Default + Eq + Add<Output = T>> Zero for WholeNumber<T> {
12 fn zero() -> Self {
13 Self::default()
14 }
15
16 fn is_zero(&self) -> bool {
17 *self == Self::zero()
18 }
19}
20
21fn levenshtein<'c, T>(a: &'c [T], b: &'c [T]) -> (Box<[Edit]>, T::Weight)
22where
23 T: Tree<Children<'c> = &'c [T]> + Cost<Output = T::Weight>,
24{
25 let mut edges = HashMap::new();
26
27 let (path, WholeNumber(cost)) = astar(
28 &(0, 0),
29 |&(i, j)| {
30 let x = a.get(i);
31 let y = b.get(j);
32
33 let mut successors = ArrayVec::<_, 3>::new();
34
35 if let Some(x) = x {
36 let next = (i + 1, j);
37 let none = edges.insert(((i, j), next), Edit::Remove);
38 debug_assert!(none.is_none());
39 successors.push((next, x.cost().into()));
40 }
41
42 if let Some(y) = y {
43 let next = (i, j + 1);
44 let none = edges.insert(((i, j), next), Edit::Insert);
45 debug_assert!(none.is_none());
46 successors.push((next, y.cost().into()));
47 }
48
49 if let (Some(x), Some(y)) = (x, y) {
50 if x.kind() == y.kind() {
51 let next = (i + 1, j + 1);
52 let (inner, cost) = levenshtein(x.children(), y.children());
53 let none = edges.insert(((i, j), next), Edit::Replace(inner));
54 debug_assert!(none.is_none());
55 successors.push((next, cost.into()));
56 }
57 }
58
59 successors
60 },
61 |&(i, j)| match (&a[i..], &b[j..]) {
62 (&[], rest) | (rest, &[]) => rest.cost().into(),
63
64 (a, b) if a.len() != b.len() => {
65 let rest = if a.len() > b.len() { a } else { b };
66 let nth = a.len().max(b.len()) - a.len().min(b.len());
67 let mut costs: Box<[_]> = rest.iter().map(T::cost).collect();
68 let (cheapest, _, _) = costs.select_nth_unstable(nth);
69 cheapest.cost().into()
70 }
71
72 _ => WholeNumber::default(),
73 },
74 |&p| p == (a.len(), b.len()),
75 )
76 .unwrap();
77
78 let patches = path
79 .into_iter()
80 .tuple_windows()
81 .flat_map(move |e| edges.remove(&e))
82 .collect();
83
84 (patches, cost)
85}
86
87pub fn diff<T: Tree>(a: &T, b: &T) -> (Box<[Edit]>, T::Weight) {
92 levenshtein(&[memoize(a)], &[memoize(b)])
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Fold, MockTree, Tree};
    use assert_matches::assert_matches;
    use proptest::collection::size_range;
    use test_strategy::{proptest, Arbitrary};

    /// Node kind whose values always compare equal (unit struct with derived `PartialEq`).
    #[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Arbitrary)]
    struct Eq;

    /// Node kind whose values never compare equal, not even to themselves.
    #[derive(Debug, Default, Copy, Clone, Arbitrary)]
    struct NotEq;

    impl PartialEq for NotEq {
        fn eq(&self, _: &Self) -> bool {
            false
        }
    }

    /// The edit script never contains more edits than both trees have nodes.
    #[proptest]
    fn the_number_of_edits_is_at_most_equal_to_the_total_number_of_nodes(
        a: MockTree<u8>,
        b: MockTree<u8>,
    ) {
        let (edits, _) = diff(&a, &b);
        assert_matches!(
            (edits.count(), a.count() + b.count()),
            (used, limit) if used <= limit
        );
    }

    /// The script is never more expensive than the two trees' costs added together.
    #[proptest]
    fn the_cost_is_at_most_equal_to_the_sum_of_costs(a: MockTree<u8>, b: MockTree<u8>) {
        let (_, price) = diff(&a, &b);
        assert_matches!(
            (price, a.cost() + b.cost()),
            (paid, limit) if paid <= limit
        );
    }

    /// Diffing a tree against itself is free and touches every node once.
    #[proptest]
    fn the_cost_between_identical_trees_is_zero(a: MockTree<u8>) {
        let (edits, price) = diff(&a, &a);
        assert_eq!(edits.count(), a.count());
        assert_eq!(price, 0);
    }

    /// Roots whose kinds differ can only be swapped by a removal plus an insertion.
    #[proptest]
    fn nodes_of_different_kinds_cannot_be_replaced(a: MockTree<NotEq>, b: MockTree<NotEq>) {
        let (edits, _) = diff(&a, &b);
        assert_matches!(
            &edits[..],
            [Edit::Remove, Edit::Insert] | [Edit::Insert, Edit::Remove]
        );
    }

    /// Roots of the same kind yield a single replacement wrapping the children's diff.
    #[proptest]
    fn nodes_of_equal_kinds_can_be_replaced(a: MockTree<Eq>, b: MockTree<Eq>) {
        let (outer, _) = diff(&a, &b);
        let (expected, _) = levenshtein(a.children(), b.children());

        assert_matches!(&outer[..], [Edit::Replace(nested)] => {
            assert_eq!(nested, &expected);
        });
    }

    /// Swapping two roots of different kinds costs exactly both trees' costs.
    #[proptest]
    fn the_cost_of_swapping_nodes_is_equal_to_the_sum_of_their_costs(
        a: MockTree<NotEq>,
        b: MockTree<NotEq>,
    ) {
        let (_, price) = diff(&a, &b);
        assert_eq!(price, a.cost() + b.cost());
    }

    /// Replacing a root costs only what diffing its children costs.
    #[proptest]
    fn the_cost_of_replacing_nodes_does_not_depend_on_their_weights(
        a: MockTree<Eq>,
        b: MockTree<Eq>,
    ) {
        let (_, whole) = diff(&a, &b);
        let (_, children_only) = levenshtein(a.children(), b.children());
        assert_eq!(whole, children_only);
    }

    /// Dropping one element from each list and paying for both separately is
    /// never cheaper than diffing the complete lists directly.
    #[proptest]
    fn the_cost_is_always_minimized(
        #[any(size_range(1..8).lift())] a: Vec<MockTree<u8>>,
        #[any(size_range(1..8).lift())] b: Vec<MockTree<u8>>,
        #[strategy(0..#a.len())] i: usize,
        #[strategy(0..#b.len())] j: usize,
    ) {
        let (mut shorter_a, mut shorter_b) = (a.clone(), b.clone());

        let dropped_a = shorter_a.remove(i);
        let dropped_b = shorter_b.remove(j);

        let (_, direct) = levenshtein(&a, &b);
        let (_, reduced) = levenshtein(&shorter_a, &shorter_b);

        assert_matches!(
            (direct, reduced + dropped_a.cost() + dropped_b.cost()),
            (best, detour) if best <= detour
        );
    }
}