1use crate::automata::Automata;
2use crate::eqtest::{Counterexample, EquivalenceTest, WMethodEQ};
3use crate::knowledge_base::KnowledgeBaseTrait;
4use crate::letter::Letter;
5use crate::observation_table::ObservationTable;
6use std::sync::{Arc, Mutex};
7use std::time::Instant;
8
/// Driver state for an L* learning run.
///
/// Construct with [`LSTAR::new`] and run with [`LSTAR::learn`].
pub struct LSTAR {
    /// Input alphabet: one [`Letter`] per vocabulary string passed to `new`.
    pub input_vocabulary: Vec<Letter>,
    /// Shared knowledge base; `new` also hands clones of it to the observation
    /// table and to the default `WMethodEQ` equivalence test.
    pub knowledge_base: Arc<Mutex<dyn KnowledgeBaseTrait>>,
    /// Directory for hypothesis / observation-table snapshot files; when `None`,
    /// built-in default locations are used.
    pub tmp_dir: Option<String>,
    /// Observation table. `new` always sets this to `Some`; the learning methods
    /// return an error if it has been set to `None`.
    pub observation_table: Option<ObservationTable>,
    /// State bound given to the default `WMethodEQ` equivalence test (and logged
    /// by `learn`); unused by a custom equivalence test supplied to `new`.
    pub max_states: usize,
    /// Equivalence oracle asked for a counterexample after each hypothesis.
    pub equivalence_test: Arc<dyn EquivalenceTest>,
    /// Set by `stop`; `learn` starts no round while it is `true`.
    stop_flag: bool,
}
19
20impl LSTAR {
21 pub fn new(
23 input_vocabulary: Vec<String>,
24 knowledge_base: Arc<Mutex<dyn KnowledgeBaseTrait>>,
25 max_states: usize,
26 tmp_dir: Option<String>,
27 eq_test: Option<Arc<dyn EquivalenceTest>>,
28 ) -> Self {
29 let input_letters = input_vocabulary
30 .into_iter()
31 .map(|s| Letter::new(s))
32 .collect::<Vec<_>>();
33
34 let observation_table =
35 ObservationTable::new(input_letters.clone(), knowledge_base.clone());
36
37 let equivalence_test = eq_test.unwrap_or_else(|| {
38 Arc::new(WMethodEQ::new(
39 knowledge_base.clone(),
40 input_letters.clone(),
41 max_states,
42 ))
43 });
44
45 LSTAR {
46 input_vocabulary: input_letters,
47 knowledge_base,
48 tmp_dir,
49 observation_table: Some(observation_table),
50 max_states,
51 equivalence_test,
52 stop_flag: false,
53 }
54 }
55
56 pub fn with_equivalence_test(mut self, eq_test: Arc<dyn EquivalenceTest>) -> Self {
58 self.equivalence_test = eq_test;
59 self
60 }
61
62 pub fn stop(&mut self) {
64 self.stop_flag = true;
65 }
66
67 pub fn learn(&mut self) -> Result<Automata, String> {
69 let start_time = Instant::now();
70
71 println!("Starting L* learning process");
72 println!("Input alphabet size: {}", self.input_vocabulary.len());
73 println!("Max states: {}", self.max_states);
74
75 self.initialize()?;
76
77 let mut round = 1;
78 let mut hypothesis_valid = false;
79 let mut hypothesis: Option<Automata> = None;
80
81 while !hypothesis_valid && !self.stop_flag {
82 println!("\n--- Round {} ---", round);
83
84 hypothesis = Some(self.build_hypothesis(round)?);
86 println!("Hypothesis built");
87
88 let _ = self.serialize_hypothesis(round, hypothesis.as_ref().unwrap());
89
90 let counter_example = self
91 .equivalence_test
92 .find_counterexample(hypothesis.as_mut().unwrap());
93 if let Some(ce) = counter_example {
94 println!("Counterexample found: {:?}", ce);
95 self.fix_hypothesis(ce)?;
96 } else {
97 println!("No counterexample found, hypothesis is correct!");
98 hypothesis_valid = true;
99 }
100
101 round += 1;
102 }
103
104 self.serialize_observation_table(round)?;
105 let duration = start_time.elapsed();
106 println!("\nLearning completed in {:.2?}", duration);
107
108 if let Some(hyp) = hypothesis {
109 Ok(hyp)
110 } else {
111 Err("Failed to build a valid hypothesis".into())
112 }
113 }
114
115 fn serialize_hypothesis(&self, round: usize, hypothesis: &Automata) -> Result<(), String> {
116 let dot_code = hypothesis.build_dot_code();
117 let filepath = if let Some(tmp_dir) = &self.tmp_dir {
118 format!("{}/hypothesis_round_{}.dot", tmp_dir, round)
119 } else {
120 format!("hypothesis_round_{}.dot", round)
121 };
122
123 if let Some(parent) = std::path::Path::new(&filepath).parent() {
124 std::fs::create_dir_all(parent)
125 .map_err(|e| format!("Failed to create directory: {}", e))?;
126 }
127 std::fs::write(&filepath, dot_code)
128 .map_err(|e| format!("Failed to write hypothesis DOT file: {}", e))?;
129 println!("Hypothesis for round {} serialized to {}", round, filepath);
130 Ok(())
131 }
132
133 fn serialize_observation_table(&self, round: usize) -> Result<(), String> {
134 let serialized_table = if let Some(ot) = &self.observation_table {
135 ot.serialize()
136 } else {
137 return Err("No observation table available".into());
138 };
139 let str_date = chrono::Local::now().format("%Y%m%d_%H%M%S").to_string();
140 let filepath = if let Some(tmp_dir) = &self.tmp_dir {
141 format!(
142 "{}/observation_table_round_{}_{}.raw",
143 tmp_dir, round, str_date
144 )
145 } else {
146 format!("tmp/observation_table_round_{}_{}.raw", round, str_date)
147 };
148
149 if let Some(parent) = std::path::Path::new(&filepath).parent() {
150 std::fs::create_dir_all(parent)
151 .map_err(|e| format!("Failed to create directory: {}", e))?;
152 }
153 std::fs::write(&filepath, serialized_table)
154 .map_err(|e| format!("Failed to write observation table file: {}", e))?;
155 println!(
156 "Observation table for round {} serialized to {}",
157 round, filepath
158 );
159 Ok(())
160 }
161
162 fn fix_hypothesis(&mut self, counter_example: Counterexample) -> Result<(), String> {
163 println!(
164 "Refining observation table with counterexample: {:?}",
165 counter_example
166 );
167
168 let input_word = counter_example.input_word;
169 let output_word = counter_example.output_word;
170 if let Some(ot) = &mut self.observation_table {
171 ot.add_counterexample(&input_word, &output_word)?;
172 } else {
173 return Err("No observation table available".into());
174 }
175 Ok(())
176 }
177
178 fn build_hypothesis(&mut self, round: usize) -> Result<Automata, String> {
180 let mut f_consistent = false;
181 let mut f_closed = false;
182 while !f_consistent || !f_closed {
183 if let Some(ot) = &mut self.observation_table {
184 if !ot.is_closed() {
186 println!(" Closing table...");
187 ot.close_table()?;
188
189 f_closed = false;
190 } else {
191 println!(" Table is closed");
192 f_closed = true;
193 }
194
195 if let Some(inconsistency) = ot.find_inconsistency() {
197 ot.make_consistent(inconsistency)?;
199 f_consistent = false;
200 } else {
201 f_consistent = true;
202 }
203 } else {
204 return Err("No observation table available".into());
205 }
206 }
207 self.serialize_observation_table(round)?;
208
209 if let Some(ot) = &mut self.observation_table {
210 ot.build_hypothesis()
211 } else {
212 Err("No observation table available".into())
213 }
214 }
215
216 fn initialize(&mut self) -> Result<(), String> {
217 if let Some(ot) = &mut self.observation_table {
218 ot.initialize()?;
219 Ok(())
220 } else {
221 Err("Failed to initialize observation table".into())
222 }
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use crate::knowledge_base::KnowledgeBase;
    use std::sync::{Arc, Mutex};

    /// Builds a learner over the two-letter alphabet {a, b} with `max_states = 5`.
    fn make_learner() -> LSTAR {
        let kb: Arc<Mutex<dyn crate::knowledge_base::KnowledgeBaseTrait>> =
            Arc::new(Mutex::new(KnowledgeBase::new()));
        let vocabulary = vec!["a".to_string(), "b".to_string()];
        LSTAR::new(vocabulary, kb, 5, None, None)
    }

    /// `new` wraps every vocabulary entry, keeps its configuration, and starts
    /// with a fresh observation table and a lowered stop flag.
    #[test]
    fn test_lstar_creation() {
        let lstar = make_learner();
        assert_eq!(lstar.input_vocabulary.len(), 2);
        assert_eq!(lstar.max_states, 5);
        assert!(lstar.tmp_dir.is_none());
        assert!(lstar.observation_table.is_some());
        assert!(!lstar.stop_flag);
    }

    /// `stop` raises the flag that `learn` checks before every round.
    #[test]
    fn test_stop_sets_flag() {
        let mut lstar = make_learner();
        lstar.stop();
        assert!(lstar.stop_flag);
    }
}