Skip to main content

rust_lstar/knowledge_base/
tree.rs

//! Knowledge tree implementation.
//!
//! A tree-based structure that stores the relationship between input words and output words.
//! This implementation is based on the pylstar `KnowledgeTree`, which is used for caching
//! query results.
5use crate::letter::Letter;
6use crate::word::Word;
7use core::fmt;
8use std::collections::HashMap;
9
10/// A node in the knowledge tree
11#[derive(Clone, Debug)]
12pub struct KnowledgeNode {
13    input_letter: Letter,
14    output_letter: Letter,
15    children: HashMap<Letter, KnowledgeNode>,
16}
17
18impl KnowledgeNode {
19    /// Create a new knowledge node
20    pub fn new(input_letter: Letter, output_letter: Letter) -> Self {
21        KnowledgeNode {
22            input_letter,
23            output_letter,
24            children: HashMap::new(),
25        }
26    }
27
28    /// Get the input letter of this node
29    pub fn input_letter(&self) -> &Letter {
30        &self.input_letter
31    }
32
33    /// Get the output letter of this node
34    pub fn output_letter(&self) -> &Letter {
35        &self.output_letter
36    }
37
38    /// Get the children of this node
39    pub fn children(&self) -> &HashMap<Letter, KnowledgeNode> {
40        &self.children
41    }
42
43    pub fn serialize(&self) -> HashMap<String, String> {
44        let mut node = HashMap::new();
45        node.insert("input_letter".to_string(), self.input_letter.symbols());
46        node.insert("output_letter".to_string(), self.output_letter.symbols());
47        let children: Vec<_> = self.children.iter().map(|(_k, v)| v.serialize()).collect();
48        node.insert("children".to_string(), format!("{:?}", children));
49        node
50    }
51
52    pub fn deserialize(
53        dict_data: &HashMap<String, String>,
54        possible_letters: &[Letter],
55    ) -> Result<KnowledgeNode, String> {
56        let input_letter = Letter::deserialize(
57            dict_data
58                .get("input_letter")
59                .ok_or("Missing input_letter")?,
60            possible_letters,
61        )?;
62        let output_letter = Letter::deserialize(
63            dict_data
64                .get("output_letter")
65                .ok_or("Missing output_letter")?,
66            possible_letters,
67        )?;
68        let mut node = KnowledgeNode::new(input_letter, output_letter);
69
70        if let Some(children_str) = dict_data.get("children") {
71            if let Ok(children) = serde_json::from_str::<Vec<HashMap<String, String>>>(children_str)
72            {
73                for child_map in children {
74                    let child_node = KnowledgeNode::deserialize(&child_map, possible_letters)?;
75                    node.children
76                        .insert(child_node.input_letter.clone(), child_node);
77                }
78            }
79        }
80
81        Ok(node)
82    }
83
84    pub fn traverse(
85        &mut self,
86        input_letters: &[Letter],
87        output_letters: Option<&[Letter]>,
88    ) -> Result<Vec<Letter>, String> {
89        if input_letters[0] != self.input_letter {
90            return Err(format!(
91                "Node cannot be traversed with input letter '{}'",
92                input_letters[0]
93            ));
94        }
95        if let Some(output_letters) = output_letters {
96            if output_letters[0] != self.output_letter {
97                return Err(format!(
98                    "Node '{}' cannot be traversed with output letter '{}'",
99                    self.input_letter, output_letters[0]
100                ));
101            }
102            if input_letters.len() != output_letters.len() {
103                return Err(
104                    "Specified input and output letters do not have the same length".to_string(),
105                );
106            }
107        }
108
109        if input_letters.len() < 2 {
110            return Ok(vec![self.output_letter.clone()]);
111        }
112
113        let current_input_letter = &input_letters[1];
114        let current_output_letter = output_letters.map(|ol| &ol[1]);
115
116        if let Some(child) = self.children.get_mut(current_input_letter) {
117            if let Some(current_output) = current_output_letter {
118                if child.output_letter != *current_output {
119                    return Err(format!(
120                        "Incompatible path found, expected '{}' found '{}'",
121                        child.output_letter.symbols(),
122                        current_output.symbols()
123                    ));
124                }
125            }
126
127            let new_output_letters = output_letters.map(|ol| &ol[1..]);
128            let new_input_letters = &input_letters[1..];
129
130            let mut result = vec![self.output_letter.clone()];
131            result.extend(child.traverse(new_input_letters, new_output_letters)?);
132            Ok(result)
133        } else if output_letters.is_some() {
134            let mut new_child =
135                KnowledgeNode::new(input_letters[1].clone(), output_letters.unwrap()[1].clone());
136            let new_input_letters = &input_letters[1..];
137            let new_output_letters = &output_letters.unwrap()[1..];
138
139            let mut result = vec![self.output_letter.clone()];
140            result.extend(new_child.traverse(new_input_letters, Some(new_output_letters))?);
141
142            self.children
143                .insert(new_child.input_letter.clone(), new_child);
144            Ok(result)
145        } else {
146            let letters_str: Vec<String> = input_letters.iter().map(|l| l.to_string()).collect();
147            Err(format!(
148                "Cannot traverse node '{}' with subsequences '{}'",
149                self.input_letter,
150                letters_str.join(", ")
151            ))
152        }
153    }
154}
155
156impl fmt::Display for KnowledgeNode {
157    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158        let serialized = self.serialize();
159        match serde_json::to_string_pretty(&serialized) {
160            Ok(json) => write!(f, "{}", json),
161            Err(_) => write!(f, "{:?}", serialized),
162        }
163    }
164}
165
166/// A tree that stores the relationship between input and output words
167#[derive(Clone, Debug)]
168pub struct KnowledgeTree {
169    roots: Vec<KnowledgeNode>,
170    nb_added_words: usize,
171}
172
173impl KnowledgeTree {
174    /// Create a new empty knowledge tree
175    pub fn new() -> Self {
176        KnowledgeTree {
177            roots: Vec::new(),
178            nb_added_words: 0,
179        }
180    }
181
182    /// Get the roots of the tree
183    pub fn roots(&self) -> &Vec<KnowledgeNode> {
184        &self.roots
185    }
186
187    /// Get the number of words added
188    pub fn num_added_words(&self) -> usize {
189        self.nb_added_words
190    }
191
192    /// Get the output word for a given input word
193    ///
194    /// Returns an error if no path exists in the tree for the input.
195    pub fn get_output_word(&mut self, input_word: &Word) -> Result<Word, String> {
196        for root in &mut self.roots {
197            if let Ok(output_letters) = root.traverse(input_word.letters(), None) {
198                return Ok(Word::from_letters(output_letters));
199            }
200        }
201        Err("No path found".to_string())
202    }
203
204    /// Add a word mapping to the tree
205    ///
206    /// Creates or traverses the tree to establish the relationship between
207    /// the input and output words.
208    pub fn add_word(&mut self, input_word: &Word, output_word: &Word) -> Result<(), String> {
209        if input_word.len() != output_word.len() {
210            return Err("Input and output words do not have the same size".to_string());
211        }
212        self.add_letters(input_word.letters(), output_word.letters())?;
213        self.nb_added_words += 1;
214        Ok(())
215    }
216
217    /// Internal method to add letters to the tree
218    fn add_letters(
219        &mut self,
220        input_letters: &[Letter],
221        output_letters: &[Letter],
222    ) -> Result<(), String> {
223        let mut retained_root: Option<&mut KnowledgeNode> = None;
224
225        for root in &mut self.roots {
226            if root.input_letter == input_letters[0] {
227                if root.output_letter != output_letters[0] {
228                    return Err(format!(
229                        "Incompatible path found, expected '{}' found '{}'",
230                        root.output_letter.symbols(),
231                        output_letters[0].symbols()
232                    ));
233                }
234                retained_root = Some(root);
235                break;
236            }
237        }
238
239        let root = if let Some(root) = retained_root {
240            root
241        } else {
242            let new_root = KnowledgeNode::new(input_letters[0].clone(), output_letters[0].clone());
243            self.roots.push(new_root);
244            self.roots.last_mut().unwrap()
245        };
246
247        root.traverse(input_letters, Some(output_letters))?;
248        Ok(())
249    }
250}
251
252impl Default for KnowledgeTree {
253    fn default() -> Self {
254        Self::new()
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// Input word `a b`, shared by most tests.
    fn input_ab() -> Word {
        Word::from_letters(vec![Letter::new("a"), Letter::new("b")])
    }

    /// Output word `1 2`, the answer recorded for `a b`.
    fn output_12() -> Word {
        Word::from_letters(vec![Letter::new(1), Letter::new(2)])
    }

    /// Output word `1 3`, which conflicts with `1 2` on the second letter.
    fn output_13() -> Word {
        Word::from_letters(vec![Letter::new(1), Letter::new(3)])
    }

    /// A stored word can be read back unchanged.
    #[test]
    fn test_add_and_retrieve_word() {
        let mut tree = KnowledgeTree::new();
        let (input, output) = (input_ab(), output_12());

        tree.add_word(&input, &output).unwrap();

        let retrieved = tree.get_output_word(&input).unwrap();
        assert_eq!(retrieved, output);
    }

    /// Two words sharing their first letter are both retrievable.
    #[test]
    fn test_multiple_words() {
        let mut tree = KnowledgeTree::new();

        let (first_input, first_output) = (input_ab(), output_12());
        let second_input = Word::from_letters(vec![Letter::new("a"), Letter::new("c")]);
        let second_output = output_13();

        tree.add_word(&first_input, &first_output).unwrap();
        tree.add_word(&second_input, &second_output).unwrap();

        assert_eq!(tree.get_output_word(&first_input).unwrap(), first_output);
        assert_eq!(tree.get_output_word(&second_input).unwrap(), second_output);
    }

    /// Re-adding a known input with a different output is rejected.
    #[test]
    fn test_incompatible_path_error() {
        let mut tree = KnowledgeTree::new();

        tree.add_word(&input_ab(), &output_12()).unwrap();

        // Same input path, but the second output letter differs from the stored one.
        let result = tree.add_word(&input_ab(), &output_13());
        assert!(result.is_err());
    }

    /// Looking up a word in an empty tree fails.
    #[test]
    fn test_retrieve_nonexistent_word() {
        let mut tree = KnowledgeTree::new();
        let unknown = Word::from_letters(vec![Letter::new("x"), Letter::new("y")]);

        let result = tree.get_output_word(&unknown);
        assert!(result.is_err());
    }
}