//! L* learning driver (`rust_lstar/lstar.rs`).

use std::fs;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use std::time::Instant;

use crate::automata::Automata;
use crate::eqtest::{Counterexample, EquivalenceTest, WMethodEQ};
use crate::knowledge_base::KnowledgeBaseTrait;
use crate::letter::Letter;
use crate::observation_table::ObservationTable;
9/// The L* algorithm implementation
10pub struct LSTAR {
11    pub input_vocabulary: Vec<Letter>,
12    pub knowledge_base: Arc<Mutex<dyn KnowledgeBaseTrait>>,
13    pub tmp_dir: Option<String>,
14    pub observation_table: Option<ObservationTable>,
15    pub max_states: usize,
16    pub equivalence_test: Arc<dyn EquivalenceTest>,
17    stop_flag: bool,
18}
19
20impl LSTAR {
21    /// Create a new LSTAR learner
22    pub fn new(
23        input_vocabulary: Vec<String>,
24        knowledge_base: Arc<Mutex<dyn KnowledgeBaseTrait>>,
25        max_states: usize,
26        tmp_dir: Option<String>,
27        eq_test: Option<Arc<dyn EquivalenceTest>>,
28    ) -> Self {
29        let input_letters = input_vocabulary
30            .into_iter()
31            .map(|s| Letter::new(s))
32            .collect::<Vec<_>>();
33
34        let observation_table =
35            ObservationTable::new(input_letters.clone(), knowledge_base.clone());
36
37        let equivalence_test = eq_test.unwrap_or_else(|| {
38            Arc::new(WMethodEQ::new(
39                knowledge_base.clone(),
40                input_letters.clone(),
41                max_states,
42            ))
43        });
44
45        LSTAR {
46            input_vocabulary: input_letters,
47            knowledge_base,
48            tmp_dir,
49            observation_table: Some(observation_table),
50            max_states,
51            equivalence_test,
52            stop_flag: false,
53        }
54    }
55
56    /// Set a custom equivalence test
57    pub fn with_equivalence_test(mut self, eq_test: Arc<dyn EquivalenceTest>) -> Self {
58        self.equivalence_test = eq_test;
59        self
60    }
61
62    /// Stop the learning process
63    pub fn stop(&mut self) {
64        self.stop_flag = true;
65    }
66
67    /// Run the L* algorithm
68    pub fn learn(&mut self) -> Result<Automata, String> {
69        let start_time = Instant::now();
70
71        println!("Starting L* learning process");
72        println!("Input alphabet size: {}", self.input_vocabulary.len());
73        println!("Max states: {}", self.max_states);
74
75        self.initialize()?;
76
77        let mut round = 1;
78        let mut hypothesis_valid = false;
79        let mut hypothesis: Option<Automata> = None;
80
81        while !hypothesis_valid && !self.stop_flag {
82            println!("\n--- Round {} ---", round);
83
84            // Build hypothesis
85            hypothesis = Some(self.build_hypothesis(round)?);
86            println!("Hypothesis built");
87
88            let _ = self.serialize_hypothesis(round, hypothesis.as_ref().unwrap());
89
90            let counter_example = self
91                .equivalence_test
92                .find_counterexample(hypothesis.as_mut().unwrap());
93            if let Some(ce) = counter_example {
94                println!("Counterexample found: {:?}", ce);
95                self.fix_hypothesis(ce)?;
96            } else {
97                println!("No counterexample found, hypothesis is correct!");
98                hypothesis_valid = true;
99            }
100
101            round += 1;
102        }
103
104        self.serialize_observation_table(round)?;
105        let duration = start_time.elapsed();
106        println!("\nLearning completed in {:.2?}", duration);
107
108        if let Some(hyp) = hypothesis {
109            Ok(hyp)
110        } else {
111            Err("Failed to build a valid hypothesis".into())
112        }
113    }
114
115    fn serialize_hypothesis(&self, round: usize, hypothesis: &Automata) -> Result<(), String> {
116        let dot_code = hypothesis.build_dot_code();
117        let filepath = if let Some(tmp_dir) = &self.tmp_dir {
118            format!("{}/hypothesis_round_{}.dot", tmp_dir, round)
119        } else {
120            format!("hypothesis_round_{}.dot", round)
121        };
122
123        if let Some(parent) = std::path::Path::new(&filepath).parent() {
124            std::fs::create_dir_all(parent)
125                .map_err(|e| format!("Failed to create directory: {}", e))?;
126        }
127        std::fs::write(&filepath, dot_code)
128            .map_err(|e| format!("Failed to write hypothesis DOT file: {}", e))?;
129        println!("Hypothesis for round {} serialized to {}", round, filepath);
130        Ok(())
131    }
132
133    fn serialize_observation_table(&self, round: usize) -> Result<(), String> {
134        let serialized_table = if let Some(ot) = &self.observation_table {
135            ot.serialize()
136        } else {
137            return Err("No observation table available".into());
138        };
139        let str_date = chrono::Local::now().format("%Y%m%d_%H%M%S").to_string();
140        let filepath = if let Some(tmp_dir) = &self.tmp_dir {
141            format!(
142                "{}/observation_table_round_{}_{}.raw",
143                tmp_dir, round, str_date
144            )
145        } else {
146            format!("tmp/observation_table_round_{}_{}.raw", round, str_date)
147        };
148
149        if let Some(parent) = std::path::Path::new(&filepath).parent() {
150            std::fs::create_dir_all(parent)
151                .map_err(|e| format!("Failed to create directory: {}", e))?;
152        }
153        std::fs::write(&filepath, serialized_table)
154            .map_err(|e| format!("Failed to write observation table file: {}", e))?;
155        println!(
156            "Observation table for round {} serialized to {}",
157            round, filepath
158        );
159        Ok(())
160    }
161
162    fn fix_hypothesis(&mut self, counter_example: Counterexample) -> Result<(), String> {
163        println!(
164            "Refining observation table with counterexample: {:?}",
165            counter_example
166        );
167
168        let input_word = counter_example.input_word;
169        let output_word = counter_example.output_word;
170        if let Some(ot) = &mut self.observation_table {
171            ot.add_counterexample(&input_word, &output_word)?;
172        } else {
173            return Err("No observation table available".into());
174        }
175        Ok(())
176    }
177
178    /// Build a hypothesis from the current observation table
179    fn build_hypothesis(&mut self, round: usize) -> Result<Automata, String> {
180        let mut f_consistent = false;
181        let mut f_closed = false;
182        while !f_consistent || !f_closed {
183            if let Some(ot) = &mut self.observation_table {
184                // Make table closed
185                if !ot.is_closed() {
186                    println!("  Closing table...");
187                    ot.close_table()?;
188
189                    f_closed = false;
190                } else {
191                    println!("  Table is closed");
192                    f_closed = true;
193                }
194
195                // Check consistency
196                if let Some(inconsistency) = ot.find_inconsistency() {
197                    // println!("  Making table consistent...");
198                    ot.make_consistent(inconsistency)?;
199                    f_consistent = false;
200                } else {
201                    f_consistent = true;
202                }
203            } else {
204                return Err("No observation table available".into());
205            }
206        }
207        self.serialize_observation_table(round)?;
208
209        if let Some(ot) = &mut self.observation_table {
210            ot.build_hypothesis()
211        } else {
212            Err("No observation table available".into())
213        }
214    }
215
216    fn initialize(&mut self) -> Result<(), String> {
217        if let Some(ot) = &mut self.observation_table {
218            ot.initialize()?;
219            Ok(())
220        } else {
221            Err("Failed to initialize observation table".into())
222        }
223    }
224}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::knowledge_base::KnowledgeBase;
    use std::sync::{Arc, Mutex};

    /// Builds a learner over the alphabet {"a", "b"} with a fresh knowledge base,
    /// `max_states = 5`, no `tmp_dir` and the default equivalence test.
    fn new_learner() -> LSTAR {
        let kb: Arc<Mutex<dyn crate::knowledge_base::KnowledgeBaseTrait>> =
            Arc::new(Mutex::new(KnowledgeBase::new()));
        let vocabulary = vec!["a".to_string(), "b".to_string()];
        LSTAR::new(vocabulary, kb, 5, None, None)
    }

    #[test]
    fn test_lstar_creation() {
        let lstar = new_learner();
        assert_eq!(lstar.input_vocabulary.len(), 2);
        assert_eq!(lstar.max_states, 5);
        assert!(lstar.tmp_dir.is_none());
        // `new` always builds the observation table.
        assert!(lstar.observation_table.is_some());
    }

    #[test]
    fn test_stop_sets_flag() {
        let mut lstar = new_learner();
        assert!(!lstar.stop_flag);
        lstar.stop();
        assert!(lstar.stop_flag);
    }
}