tree_edit_distance/
diff.rs

1use crate::{memoize, Cost, Edit, Tree};
2use arrayvec::ArrayVec;
3use derive_more::{Add, From};
4use itertools::Itertools;
5use pathfinding::{num_traits::Zero, prelude::*};
6use std::{collections::HashMap, ops::Add};
7
8#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, From, Add)]
9struct WholeNumber<T>(T);
10
11impl<T: Default + Eq + Add<Output = T>> Zero for WholeNumber<T> {
12    fn zero() -> Self {
13        Self::default()
14    }
15
16    fn is_zero(&self) -> bool {
17        *self == Self::zero()
18    }
19}
20
21fn levenshtein<'c, T>(a: &'c [T], b: &'c [T]) -> (Box<[Edit]>, T::Weight)
22where
23    T: Tree<Children<'c> = &'c [T]> + Cost<Output = T::Weight>,
24{
25    let mut edges = HashMap::new();
26
27    let (path, WholeNumber(cost)) = astar(
28        &(0, 0),
29        |&(i, j)| {
30            let x = a.get(i);
31            let y = b.get(j);
32
33            let mut successors = ArrayVec::<_, 3>::new();
34
35            if let Some(x) = x {
36                let next = (i + 1, j);
37                let none = edges.insert(((i, j), next), Edit::Remove);
38                debug_assert!(none.is_none());
39                successors.push((next, x.cost().into()));
40            }
41
42            if let Some(y) = y {
43                let next = (i, j + 1);
44                let none = edges.insert(((i, j), next), Edit::Insert);
45                debug_assert!(none.is_none());
46                successors.push((next, y.cost().into()));
47            }
48
49            if let (Some(x), Some(y)) = (x, y) {
50                if x.kind() == y.kind() {
51                    let next = (i + 1, j + 1);
52                    let (inner, cost) = levenshtein(x.children(), y.children());
53                    let none = edges.insert(((i, j), next), Edit::Replace(inner));
54                    debug_assert!(none.is_none());
55                    successors.push((next, cost.into()));
56                }
57            }
58
59            successors
60        },
61        |&(i, j)| match (&a[i..], &b[j..]) {
62            (&[], rest) | (rest, &[]) => rest.cost().into(),
63
64            (a, b) if a.len() != b.len() => {
65                let rest = if a.len() > b.len() { a } else { b };
66                let nth = a.len().max(b.len()) - a.len().min(b.len());
67                let mut costs: Box<[_]> = rest.iter().map(T::cost).collect();
68                let (cheapest, _, _) = costs.select_nth_unstable(nth);
69                cheapest.cost().into()
70            }
71
72            _ => WholeNumber::default(),
73        },
74        |&p| p == (a.len(), b.len()),
75    )
76    .unwrap();
77
78    let patches = path
79        .into_iter()
80        .tuple_windows()
81        .flat_map(move |e| edges.remove(&e))
82        .collect();
83
84    (patches, cost)
85}
86
87/// Finds the lowest cost sequence of [Edit]s that transforms one [Tree] into the other.
88///
89/// The sequence of [Edit]s is understood to apply to the left-hand side so it becomes the
90/// right-hand side.
91pub fn diff<T: Tree>(a: &T, b: &T) -> (Box<[Edit]>, T::Weight) {
92    levenshtein(&[memoize(a)], &[memoize(b)])
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Fold, MockTree, Tree};
    use assert_matches::assert_matches;
    use proptest::collection::size_range;
    use test_strategy::{proptest, Arbitrary};

    /// Value that always compares equal to itself, presumably used as a node kind so that
    /// any two nodes are considered replaceable.
    #[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Arbitrary)]
    struct Eq;

    /// Value that never compares equal (not even to itself), presumably used as a node kind
    /// so that no two nodes are ever considered replaceable.
    #[derive(Debug, Default, Copy, Clone, Arbitrary)]
    struct NotEq;

    impl PartialEq for NotEq {
        fn eq(&self, _: &Self) -> bool {
            false
        }
    }

    // Removing every node of `a` and inserting every node of `b` is always possible, so the
    // optimal edit list can never be longer than that.
    #[proptest]
    fn the_number_of_edits_is_at_most_equal_to_the_total_number_of_nodes(
        a: MockTree<u8>,
        b: MockTree<u8>,
    ) {
        let (edits, _) = diff(&a, &b);
        let edit_count = edits.count();
        let node_count = a.count() + b.count();
        assert!(edit_count <= node_count);
    }

    // Likewise, the optimal cost is bounded by the cost of removing `a` and inserting `b`.
    #[proptest]
    fn the_cost_is_at_most_equal_to_the_sum_of_costs(a: MockTree<u8>, b: MockTree<u8>) {
        let (_, total) = diff(&a, &b);
        let upper_bound = a.cost() + b.cost();
        assert!(total <= upper_bound);
    }

    #[proptest]
    fn the_cost_between_identical_trees_is_zero(a: MockTree<u8>) {
        let (edits, total) = diff(&a, &a);
        assert_eq!(edits.count(), a.count());
        assert_eq!(total, 0);
    }

    // With kinds that never compare equal, the roots can only be removed and inserted.
    #[proptest]
    fn nodes_of_different_kinds_cannot_be_replaced(a: MockTree<NotEq>, b: MockTree<NotEq>) {
        let (edits, _) = diff(&a, &b);
        assert_matches!(
            &edits[..],
            [Edit::Remove, Edit::Insert] | [Edit::Insert, Edit::Remove]
        );
    }

    // With kinds that always compare equal, the roots are replaced and the nested edits are
    // exactly the diff of their children.
    #[proptest]
    fn nodes_of_equal_kinds_can_be_replaced(a: MockTree<Eq>, b: MockTree<Eq>) {
        let (edits, _) = diff(&a, &b);
        let (expected, _) = levenshtein(a.children(), b.children());

        assert_matches!(&edits[..], [Edit::Replace(inner)] => {
            assert_eq!(inner, &expected);
        });
    }

    #[proptest]
    fn the_cost_of_swapping_nodes_is_equal_to_the_sum_of_their_costs(
        a: MockTree<NotEq>,
        b: MockTree<NotEq>,
    ) {
        let (_, total) = diff(&a, &b);
        let expected = a.cost() + b.cost();
        assert_eq!(total, expected);
    }

    // Replacing a root costs only as much as diffing its children.
    #[proptest]
    fn the_cost_of_replacing_nodes_does_not_depend_on_their_weights(
        a: MockTree<Eq>,
        b: MockTree<Eq>,
    ) {
        let (_, whole) = diff(&a, &b);
        let (_, children_only) = levenshtein(a.children(), b.children());
        assert_eq!(whole, children_only);
    }

    // Removing `a[i]` and inserting `b[j]` on top of an optimal diff of the remaining forests
    // is one valid edit list, so the optimal diff of the full forests cannot cost more.
    #[proptest]
    fn the_cost_is_always_minimized(
        #[any(size_range(1..8).lift())] a: Vec<MockTree<u8>>,
        #[any(size_range(1..8).lift())] b: Vec<MockTree<u8>>,
        #[strategy(0..#a.len())] i: usize,
        #[strategy(0..#b.len())] j: usize,
    ) {
        let (mut rest_a, mut rest_b) = (a.clone(), b.clone());

        let removed = rest_a.remove(i);
        let inserted = rest_b.remove(j);

        let (_, full) = levenshtein(&a, &b);
        let (_, partial) = levenshtein(&rest_a, &rest_b);

        assert!(full <= partial + removed.cost() + inserted.cost());
    }
}