state_tree/
tree_diff.rs

1use crate::patch::CopyFromPatch;
2use crate::tree::StateTree;
3
4pub fn take_diff(old_tree: &StateTree, new_tree: &StateTree) -> Vec<CopyFromPatch> {
5    let mut patches = Vec::new();
6    build_patches_recursive(old_tree, new_tree, Vec::new(), Vec::new(), &mut patches);
7    patches
8}
9
10/// 2つのノードが「同じ種類」であるかを判定する
11fn nodes_match(old: &StateTree, new: &StateTree) -> bool {
12    match (old, new) {
13        (StateTree::Delay { data: d1, .. }, StateTree::Delay { data: d2, .. }) => {
14            d1.len() == d2.len()
15        }
16        (StateTree::Mem { data: d1 }, StateTree::Mem { data: d2 }) => d1.len() == d2.len(),
17        (StateTree::Feed { data: d1 }, StateTree::Feed { data: d2 }) => d1.len() == d2.len(),
18        (StateTree::FnCall(_), StateTree::FnCall(_)) => true,
19        _ => false,
20    }
21}
22
/// One step of the edit script produced by [`lcs_by`].
///
/// Indices refer to positions in the `old` / `new` slices that were passed to [`lcs_by`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiffResult {
    /// An element present in both sequences (paired up by the common subsequence).
    Common { old_index: usize, new_index: usize },
    /// An element present only in the old sequence (it was deleted).
    Delete { old_index: usize },
    /// An element present only in the new sequence (it was inserted).
    Insert { new_index: usize },
}

/// Compares two slices and returns an edit script derived from their longest common
/// subsequence.
///
/// `compare` decides whether an element of `old` may be paired with an element of `new`.
/// It may be called more than once for the same pair.
///
/// The returned steps are in sequence order. Walking them visits every index of `old`
/// exactly once (as `Common` or `Delete`) and every index of `new` exactly once (as
/// `Common` or `Insert`), with both index sequences increasing. When stepping back through
/// the table could go either way without shortening the subsequence, the `new` element is
/// consumed first. For example, `old = [a]`, `new = [b]` yields
/// `[Delete { old_index: 0 }, Insert { new_index: 0 }]`.
///
/// Uses `O(old.len() * new.len())` time and memory.
pub fn lcs_by<T>(old: &[T], new: &[T], compare: impl Fn(&T, &T) -> bool) -> Vec<DiffResult> {
    let (old_len, new_len) = (old.len(), new.len());

    // `dp[at(i, j)]` is the LCS length of `old[..i]` and `new[..j]`. It is stored as one
    // flat (old_len + 1) x (new_len + 1) table instead of one allocation per row.
    let width = new_len + 1;
    let at = |i: usize, j: usize| i * width + j;
    let mut dp = vec![0usize; (old_len + 1) * width];
    for i in 1..=old_len {
        for j in 1..=new_len {
            dp[at(i, j)] = if compare(&old[i - 1], &new[j - 1]) {
                dp[at(i - 1, j - 1)] + 1
            } else {
                dp[at(i - 1, j)].max(dp[at(i, j - 1)])
            };
        }
    }

    // Backtrack from the bottom-right corner; this recovers one optimal script in reverse.
    let mut results = Vec::with_capacity(old_len + new_len);
    let (mut i, mut j) = (old_len, new_len);
    while i > 0 || j > 0 {
        if i > 0 && j > 0 && compare(&old[i - 1], &new[j - 1]) {
            results.push(DiffResult::Common {
                old_index: i - 1,
                new_index: j - 1,
            });
            i -= 1;
            j -= 1;
        } else if j > 0 && (i == 0 || dp[at(i, j - 1)] >= dp[at(i - 1, j)]) {
            results.push(DiffResult::Insert { new_index: j - 1 });
            j -= 1;
        } else {
            // Here i > 0 is guaranteed. If i were 0, then j > 0 by the loop condition,
            // and the Insert branch above would have been taken.
            results.push(DiffResult::Delete { old_index: i - 1 });
            i -= 1;
        }
    }

    results.reverse();
    results
}
78
79fn build_patches_recursive(
80    old_node: &StateTree,
81    new_node: &StateTree,
82    old_path: Vec<usize>,
83    new_path: Vec<usize>,
84    patches: &mut Vec<CopyFromPatch>,
85) {
86    if !nodes_match(old_node, new_node) {
87        return;
88    }
89
90    if !matches!(new_node, StateTree::FnCall(_)) {
91        patches.push(CopyFromPatch { old_path, new_path });
92        return;
93    }
94
95    if let (StateTree::FnCall(old_children), StateTree::FnCall(new_children)) = (old_node, new_node)
96    {
97        // 自作のlcs_by関数を使用
98        let lcs_results = lcs_by(old_children, new_children, nodes_match);
99
100        let mut unmatched_old = Vec::new();
101        let mut unmatched_new = Vec::new();
102
103        for result in &lcs_results {
104            match *result {
105                DiffResult::Common {
106                    old_index,
107                    new_index,
108                } => {
109                    let old_child = &old_children[old_index];
110                    let new_child = &new_children[new_index];
111
112                    let mut next_old_path = old_path.clone();
113                    next_old_path.push(old_index);
114                    let mut next_new_path = new_path.clone();
115                    next_new_path.push(new_index);
116
117                    build_patches_recursive(
118                        old_child,
119                        new_child,
120                        next_old_path,
121                        next_new_path,
122                        patches,
123                    );
124                }
125                DiffResult::Delete { old_index } => {
126                    unmatched_old.push(old_index);
127                }
128                DiffResult::Insert { new_index } => {
129                    unmatched_new.push(new_index);
130                }
131            }
132        }
133
134        for new_index in unmatched_new {
135            if let Some(position) = unmatched_old.iter().position(|&old_index| {
136                nodes_match(&old_children[old_index], &new_children[new_index])
137            }) {
138                let old_index = unmatched_old.remove(position);
139
140                let mut next_old_path = old_path.clone();
141                next_old_path.push(old_index);
142                let mut next_new_path = new_path.clone();
143                next_new_path.push(new_index);
144
145                build_patches_recursive(
146                    &old_children[old_index],
147                    &new_children[new_index],
148                    next_old_path,
149                    next_new_path,
150                    patches,
151                );
152            }
153        }
154    }
155}