rust_lstar/knowledge_base/
tree.rs1use crate::letter::Letter;
6use crate::word::Word;
7use core::fmt;
8use std::collections::HashMap;
9
10#[derive(Clone, Debug)]
12pub struct KnowledgeNode {
13 input_letter: Letter,
14 output_letter: Letter,
15 children: HashMap<Letter, KnowledgeNode>,
16}
17
18impl KnowledgeNode {
19 pub fn new(input_letter: Letter, output_letter: Letter) -> Self {
21 KnowledgeNode {
22 input_letter,
23 output_letter,
24 children: HashMap::new(),
25 }
26 }
27
28 pub fn input_letter(&self) -> &Letter {
30 &self.input_letter
31 }
32
33 pub fn output_letter(&self) -> &Letter {
35 &self.output_letter
36 }
37
38 pub fn children(&self) -> &HashMap<Letter, KnowledgeNode> {
40 &self.children
41 }
42
43 pub fn serialize(&self) -> HashMap<String, String> {
44 let mut node = HashMap::new();
45 node.insert("input_letter".to_string(), self.input_letter.symbols());
46 node.insert("output_letter".to_string(), self.output_letter.symbols());
47 let children: Vec<_> = self.children.iter().map(|(_k, v)| v.serialize()).collect();
48 node.insert("children".to_string(), format!("{:?}", children));
49 node
50 }
51
52 pub fn deserialize(
53 dict_data: &HashMap<String, String>,
54 possible_letters: &[Letter],
55 ) -> Result<KnowledgeNode, String> {
56 let input_letter = Letter::deserialize(
57 dict_data
58 .get("input_letter")
59 .ok_or("Missing input_letter")?,
60 possible_letters,
61 )?;
62 let output_letter = Letter::deserialize(
63 dict_data
64 .get("output_letter")
65 .ok_or("Missing output_letter")?,
66 possible_letters,
67 )?;
68 let mut node = KnowledgeNode::new(input_letter, output_letter);
69
70 if let Some(children_str) = dict_data.get("children") {
71 if let Ok(children) = serde_json::from_str::<Vec<HashMap<String, String>>>(children_str)
72 {
73 for child_map in children {
74 let child_node = KnowledgeNode::deserialize(&child_map, possible_letters)?;
75 node.children
76 .insert(child_node.input_letter.clone(), child_node);
77 }
78 }
79 }
80
81 Ok(node)
82 }
83
84 pub fn traverse(
85 &mut self,
86 input_letters: &[Letter],
87 output_letters: Option<&[Letter]>,
88 ) -> Result<Vec<Letter>, String> {
89 if input_letters[0] != self.input_letter {
90 return Err(format!(
91 "Node cannot be traversed with input letter '{}'",
92 input_letters[0]
93 ));
94 }
95 if let Some(output_letters) = output_letters {
96 if output_letters[0] != self.output_letter {
97 return Err(format!(
98 "Node '{}' cannot be traversed with output letter '{}'",
99 self.input_letter, output_letters[0]
100 ));
101 }
102 if input_letters.len() != output_letters.len() {
103 return Err(
104 "Specified input and output letters do not have the same length".to_string(),
105 );
106 }
107 }
108
109 if input_letters.len() < 2 {
110 return Ok(vec![self.output_letter.clone()]);
111 }
112
113 let current_input_letter = &input_letters[1];
114 let current_output_letter = output_letters.map(|ol| &ol[1]);
115
116 if let Some(child) = self.children.get_mut(current_input_letter) {
117 if let Some(current_output) = current_output_letter {
118 if child.output_letter != *current_output {
119 return Err(format!(
120 "Incompatible path found, expected '{}' found '{}'",
121 child.output_letter.symbols(),
122 current_output.symbols()
123 ));
124 }
125 }
126
127 let new_output_letters = output_letters.map(|ol| &ol[1..]);
128 let new_input_letters = &input_letters[1..];
129
130 let mut result = vec![self.output_letter.clone()];
131 result.extend(child.traverse(new_input_letters, new_output_letters)?);
132 Ok(result)
133 } else if output_letters.is_some() {
134 let mut new_child =
135 KnowledgeNode::new(input_letters[1].clone(), output_letters.unwrap()[1].clone());
136 let new_input_letters = &input_letters[1..];
137 let new_output_letters = &output_letters.unwrap()[1..];
138
139 let mut result = vec![self.output_letter.clone()];
140 result.extend(new_child.traverse(new_input_letters, Some(new_output_letters))?);
141
142 self.children
143 .insert(new_child.input_letter.clone(), new_child);
144 Ok(result)
145 } else {
146 let letters_str: Vec<String> = input_letters.iter().map(|l| l.to_string()).collect();
147 Err(format!(
148 "Cannot traverse node '{}' with subsequences '{}'",
149 self.input_letter,
150 letters_str.join(", ")
151 ))
152 }
153 }
154}
155
156impl fmt::Display for KnowledgeNode {
157 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158 let serialized = self.serialize();
159 match serde_json::to_string_pretty(&serialized) {
160 Ok(json) => write!(f, "{}", json),
161 Err(_) => write!(f, "{:?}", serialized),
162 }
163 }
164}
165
166#[derive(Clone, Debug)]
168pub struct KnowledgeTree {
169 roots: Vec<KnowledgeNode>,
170 nb_added_words: usize,
171}
172
173impl KnowledgeTree {
174 pub fn new() -> Self {
176 KnowledgeTree {
177 roots: Vec::new(),
178 nb_added_words: 0,
179 }
180 }
181
182 pub fn roots(&self) -> &Vec<KnowledgeNode> {
184 &self.roots
185 }
186
187 pub fn num_added_words(&self) -> usize {
189 self.nb_added_words
190 }
191
192 pub fn get_output_word(&mut self, input_word: &Word) -> Result<Word, String> {
196 for root in &mut self.roots {
197 if let Ok(output_letters) = root.traverse(input_word.letters(), None) {
198 return Ok(Word::from_letters(output_letters));
199 }
200 }
201 Err("No path found".to_string())
202 }
203
204 pub fn add_word(&mut self, input_word: &Word, output_word: &Word) -> Result<(), String> {
209 if input_word.len() != output_word.len() {
210 return Err("Input and output words do not have the same size".to_string());
211 }
212 self.add_letters(input_word.letters(), output_word.letters())?;
213 self.nb_added_words += 1;
214 Ok(())
215 }
216
217 fn add_letters(
219 &mut self,
220 input_letters: &[Letter],
221 output_letters: &[Letter],
222 ) -> Result<(), String> {
223 let mut retained_root: Option<&mut KnowledgeNode> = None;
224
225 for root in &mut self.roots {
226 if root.input_letter == input_letters[0] {
227 if root.output_letter != output_letters[0] {
228 return Err(format!(
229 "Incompatible path found, expected '{}' found '{}'",
230 root.output_letter.symbols(),
231 output_letters[0].symbols()
232 ));
233 }
234 retained_root = Some(root);
235 break;
236 }
237 }
238
239 let root = if let Some(root) = retained_root {
240 root
241 } else {
242 let new_root = KnowledgeNode::new(input_letters[0].clone(), output_letters[0].clone());
243 self.roots.push(new_root);
244 self.roots.last_mut().unwrap()
245 };
246
247 root.traverse(input_letters, Some(output_letters))?;
248 Ok(())
249 }
250}
251
252impl Default for KnowledgeTree {
253 fn default() -> Self {
254 Self::new()
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// A recorded word can be read back unchanged.
    #[test]
    fn test_add_and_retrieve_word() {
        let input = Word::from_letters(vec![Letter::new("a"), Letter::new("b")]);
        let output = Word::from_letters(vec![Letter::new(1), Letter::new(2)]);
        let mut tree = KnowledgeTree::new();

        tree.add_word(&input, &output).unwrap();

        assert_eq!(tree.get_output_word(&input).unwrap(), output);
    }

    /// Two words sharing a first letter are both retrievable.
    #[test]
    fn test_multiple_words() {
        let pairs = [
            (
                Word::from_letters(vec![Letter::new("a"), Letter::new("b")]),
                Word::from_letters(vec![Letter::new(1), Letter::new(2)]),
            ),
            (
                Word::from_letters(vec![Letter::new("a"), Letter::new("c")]),
                Word::from_letters(vec![Letter::new(1), Letter::new(3)]),
            ),
        ];
        let mut tree = KnowledgeTree::new();

        for (input, output) in &pairs {
            tree.add_word(input, output).unwrap();
        }
        for (input, output) in &pairs {
            assert_eq!(&tree.get_output_word(input).unwrap(), output);
        }
    }

    /// Re-adding an input with a different output is rejected.
    #[test]
    fn test_incompatible_path_error() {
        let input = Word::from_letters(vec![Letter::new("a"), Letter::new("b")]);
        let mut tree = KnowledgeTree::new();
        tree.add_word(
            &input,
            &Word::from_letters(vec![Letter::new(1), Letter::new(2)]),
        )
        .unwrap();

        let conflicting = Word::from_letters(vec![Letter::new(1), Letter::new(3)]);
        assert!(tree.add_word(&input, &conflicting).is_err());
    }

    /// Looking up a word in an empty tree fails.
    #[test]
    fn test_retrieve_nonexistent_word() {
        let unknown = Word::from_letters(vec![Letter::new("x"), Letter::new("y")]);
        let mut tree = KnowledgeTree::new();

        assert!(tree.get_output_word(&unknown).is_err());
    }
}