syntaxdot_encoders/lemma/
edit_tree.rs

1// Copyright 2019 The edit_tree contributors
2// Copyright 2020 TensorDot
3//
4// Licensed under the Apache License, Version 2.0 or the MIT license,
5// at your option.
6//
7// Contributors:
8//
9// Tobias Pütz <tobias.puetz@uni-tuebingen.de>
10// Daniël de Kok <me@danieldk.eu>
11
12//! Edit trees
13
14use std::cmp::Eq;
15use std::cmp::Ordering;
16use std::fmt::Debug;
17
18use lazy_static::lazy_static;
19use seqalign::measures::LCSOp;
20use seqalign::measures::LCS;
21use seqalign::op::IndexedOperation;
22use seqalign::Align;
23use serde::{Deserialize, Serialize};
24
// Shared LCS alignment measure used by `build_tree` to align a form with its
// lemma. `lazy_static!` constructs it once, on first access.
// NOTE(review): the two `1` arguments are presumably the insertion and
// deletion costs -- confirm against the `seqalign` documentation.
lazy_static! {
    static ref MEASURE: LCS = LCS::new(1, 1);
}
28
/// A node of an edit tree.
///
/// An edit tree describes how one character sequence (the form) is rewritten
/// into another. It is built by `EditTree::create_tree` and replayed on a
/// form by `EditTree::apply`.
#[derive(Debug, PartialEq, Hash, Eq, Clone, Serialize, Deserialize)]
pub enum EditTree {
    /// Keeps the middle part of the form unchanged and delegates the
    /// surrounding prefix and suffix to the child nodes.
    MatchNode {
        /// Number of leading characters of the form that are handed to `left`.
        pre: usize,
        /// Number of trailing characters of the form that are handed to `right`.
        suf: usize,
        /// Subtree applied to the prefix; when `None`, nothing is emitted
        /// before the kept middle part.
        left: Option<Box<EditTree>>,
        /// Subtree applied to the suffix; when `None`, nothing is emitted
        /// after the kept middle part.
        right: Option<Box<EditTree>>,
    },

    /// Replaces a whole (sub)sequence by another one.
    ReplaceNode {
        /// Characters the node expects to see. An empty `replacee` makes the
        /// node apply to any input.
        replacee: Vec<char>,
        /// Characters emitted when the node applies.
        replacement: Vec<char>,
    },
}
44
45impl EditTree {
46    /// Returns a edit tree specifying how to derive `b` from `a`.
47    ///
48    /// **Caution:** when using with stringy types. UTF-8 multi byte
49    /// chars will not be treated well. Consider passing in &[char]
50    /// instead.
51    pub fn create_tree(a: &[char], b: &[char]) -> Option<Self> {
52        build_tree(a, b).map(|tree| *tree)
53    }
54
55    /// Recursively applies the nodes stored in the edit tree. Returns `None` if the tree is not applicable to
56    /// `form`.
57    pub fn apply(&self, form: &[char]) -> Option<Vec<char>> {
58        let form_len = form.len();
59        match self {
60            EditTree::MatchNode {
61                pre,
62                suf,
63                left,
64                right,
65            } => {
66                if pre + suf >= form_len {
67                    return None;
68                }
69
70                let mut left = match left {
71                    Some(left) => left.apply(&form[..*pre])?,
72                    None => vec![],
73                };
74
75                left.extend(form[*pre..form_len - *suf].iter().cloned());
76
77                if let Some(right) = right {
78                    left.extend(right.apply(&form[form_len - *suf..])?)
79                }
80
81                Some(left)
82            }
83
84            EditTree::ReplaceNode {
85                ref replacee,
86                ref replacement,
87            } => {
88                if form == &replacee[..] || replacee.is_empty() {
89                    Some(replacement.clone())
90                } else {
91                    None
92                }
93            }
94        }
95    }
96}
97
/// A contiguous run of matching elements shared by two sequences.
///
/// Note that the ordering implemented below looks at `length` only, whereas
/// the derived equality compares all three fields.
#[derive(Debug, Eq, Hash, PartialEq)]
struct LcsMatch {
    /// Index in the source sequence at which the run starts.
    start_src: usize,
    /// Index in the target sequence at which the run starts.
    start_targ: usize,
    /// Number of elements in the run.
    length: usize,
}

impl LcsMatch {
    /// Creates a match starting at the given source/target indices.
    fn new(start_src: usize, start_targ: usize, length: usize) -> Self {
        Self {
            start_src,
            start_targ,
            length,
        }
    }

    /// Creates the zero-length match located at the start of both sequences.
    fn empty() -> Self {
        Self::new(0, 0, 0)
    }
}

impl PartialOrd for LcsMatch {
    /// Orders matches by their length; the start indices are ignored.
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        self.length.partial_cmp(&rhs.length)
    }
}
124
125/// Returns the start and end index of the longest match. Returns none if no match is found.
126fn longest_match(script: &[IndexedOperation<LCSOp>]) -> Option<LcsMatch> {
127    let mut longest = LcsMatch::empty();
128
129    let mut script_slice = script;
130    while !script_slice.is_empty() {
131        let op = &script_slice[0];
132
133        match op.operation() {
134            LCSOp::Match => {
135                let in_start = op.source_idx();
136                let o_start = op.target_idx();
137                let end = match script_slice
138                    .iter()
139                    .position(|x| !matches!(x.operation(), LCSOp::Match))
140                {
141                    Some(idx) => idx,
142                    None => script_slice.len(),
143                };
144                if end > longest.length {
145                    longest = LcsMatch::new(in_start, o_start, end);
146                };
147
148                script_slice = &script_slice[end..];
149            }
150            _ => {
151                script_slice = &script_slice[1..];
152            }
153        }
154    }
155
156    if longest.length != 0 {
157        Some(longest)
158    } else {
159        None
160    }
161}
162
163/// Recursively builds an edit tree by applying itself to pre and suffix of the longest common substring.
164fn build_tree(form_ch: &[char], lem_ch: &[char]) -> Option<Box<EditTree>> {
165    if form_ch.is_empty() && lem_ch.is_empty() {
166        return None;
167    }
168
169    let alignment = MEASURE.align(form_ch, lem_ch);
170    let root = match longest_match(&alignment.edit_script()[..]) {
171        Some(m) => EditTree::MatchNode {
172            pre: m.start_src,
173            suf: (form_ch.len() - m.start_src) - m.length,
174            left: build_tree(&form_ch[..m.start_src], &lem_ch[..m.start_targ]),
175            right: build_tree(
176                &form_ch[m.start_src + m.length..],
177                &lem_ch[m.start_targ + m.length..],
178            ),
179        },
180        None => EditTree::ReplaceNode {
181            replacee: form_ch.to_vec(),
182            replacement: lem_ch.to_vec(),
183        },
184    };
185    Some(Box::new(root))
186}
187
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::EditTree;

    /// Utility trait to retrieve a lower-cased `Vec<char>`.
    pub trait ToLowerCharVec {
        fn to_lower_char_vec(&self) -> Vec<char>;
    }

    impl<'a> ToLowerCharVec for &'a str {
        fn to_lower_char_vec(&self) -> Vec<char> {
            self.to_lowercase().chars().collect()
        }
    }

    /// Builds the edit tree rewriting `form` into `lemma` (both lower-cased),
    /// panicking if no tree can be built.
    fn tree(form: &str, lemma: &str) -> EditTree {
        EditTree::create_tree(&form.to_lower_char_vec(), &lemma.to_lower_char_vec()).unwrap()
    }

    /// Applies `tree` to the lower-cased `form` and collects the result into a
    /// `String`; `None` if the tree is not applicable.
    fn apply_str(tree: &EditTree, form: &str) -> Option<String> {
        tree.apply(&form.to_lower_char_vec())
            .map(|chars| chars.into_iter().collect())
    }

    /// Applies every tree in `trees` to `form` and returns the distinct
    /// outputs of the trees that were applicable.
    fn apply_all(trees: &HashSet<EditTree>, form: &str) -> HashSet<String> {
        trees
            .iter()
            .filter_map(|tree| apply_str(tree, form))
            .collect()
    }

    #[test]
    fn test_graph_equality_outcome() {
        let g = tree("hates", "hate");
        let g1 = tree("loves", "love");

        assert_eq!(apply_str(&g, "hates").as_deref(), Some("hate"));
        assert_eq!(apply_str(&g1, "loves").as_deref(), Some("love"));
        assert_eq!(g, g1);
    }

    #[test]
    fn test_graph_equality_outcome_2() {
        let g = tree("machen", "gemacht");
        let g1 = tree("lachen", "gelacht");

        // Each tree is applied to the form the *other* tree was built from.
        assert_eq!(apply_str(&g, "lachen").as_deref(), Some("gelacht"));
        assert_eq!(apply_str(&g1, "machen").as_deref(), Some("gemacht"));
        assert_eq!(g, g1);
    }

    #[test]
    fn test_graph_equality_outcome_3() {
        let g = tree("aaaaaaaaen", "geaaaaaaaat");
        let g1 = tree("lachen", "gelacht");

        // Each tree is applied to the form the *other* tree was built from.
        assert_eq!(apply_str(&g, "lachen").as_deref(), Some("gelacht"));
        assert_eq!(apply_str(&g1, "aaaaaaaaen").as_deref(), Some("geaaaaaaaat"));
        assert_eq!(g, g1);
    }

    #[test]
    fn test_graph_equality_and_applicability() {
        // (form, lemma, number of distinct trees in the set after insertion)
        let cases: [(&str, &str, usize); 7] = [
            ("abc", "ab", 1),
            ("aaa", "aa", 2),
            ("cba", "ba", 3),
            // Same pair again: must hash/compare equal, so the set does not grow.
            ("cba", "ba", 3),
            ("aaa", "aac", 4),
            ("dec", "dec", 5),
            ("dec", "decc", 6),
        ];

        let mut set: HashSet<EditTree> = HashSet::default();
        for &(form, lemma, distinct) in cases.iter() {
            set.insert(tree(form, lemma));
            assert_eq!(set.len(), distinct);
        }

        let res = apply_all(&set, "yyyy");
        assert_eq!(res.len(), 2);

        let res = apply_all(&set, "yyy");
        assert!(res.contains("yyyc"));
        assert!(res.contains("yyy"));
        assert_eq!(res.len(), 2);

        let res = apply_all(&set, "bba");
        assert!(res.contains("bbac"));
        assert!(res.contains("bba"));
        assert!(res.contains("bb"));
        assert!(res.contains("bbc"));
        assert_eq!(res.len(), 4);

        let res = apply_all(&set, "dec");
        assert!(res.contains("dec"));
        assert!(res.contains("decc"));
        assert!(res.contains("de"));
        assert_eq!(res.len(), 3);

        let g = tree("die", "das");
        assert!(apply_str(&g, "die").is_some());
    }

    #[test]
    fn test_graphs_inapplicable() {
        let g = tree("abcdefg", "abc");
        assert!(apply_str(&g, "abc").is_none());

        let g = tree("abcdefg", "efg");
        assert!(apply_str(&g, "efg").is_none());
    }
}