1use rust_lstar::eqtest::{BDistMethod, MultipleEqtests, RandomWalkMethod, WMethodEQ};
5use rust_lstar::knowledge_base::{KnowledgeBaseStats, KnowledgeBaseTrait};
6use rust_lstar::query::OutputQuery;
7use rust_lstar::*;
8use std::sync::{Arc, Mutex};
9use std::time::Instant;
10
11type SharedKb = Arc<Mutex<dyn KnowledgeBaseTrait>>;
12type EqBuilder = dyn Fn(SharedKb, Vec<Letter>, usize) -> Arc<dyn EquivalenceTest>;
13
14fn main() -> Result<(), Box<dyn std::error::Error>> {
15 println!("=== Custom KB: All Equivalence Tests ===");
16 println!("System under learning: ATM protocol");
17
18 let mut reports = Vec::new();
19
20 let strategies: Vec<(&str, Box<EqBuilder>)> = vec![
21 (
22 "WMethodEQ",
23 Box::new(|kb, input_letters, max_states| {
24 Arc::new(WMethodEQ::new(kb, input_letters, max_states))
25 }),
26 ),
27 (
28 "RandomWalkMethod",
29 Box::new(|kb, input_letters, _| {
30 Arc::new(RandomWalkMethod::new(kb, input_letters, 10_000, 0.75))
31 }),
32 ),
33 (
34 "BDistMethod",
35 Box::new(|kb, input_letters, _| Arc::new(BDistMethod::new(kb, input_letters, 2))),
36 ),
37 (
38 "MultipleEqtests",
39 Box::new(|kb, input_letters, max_states| {
40 let eqtests: Vec<Arc<dyn EquivalenceTest>> = vec![
41 Arc::new(WMethodEQ::new(
42 kb.clone(),
43 input_letters.clone(),
44 max_states,
45 )),
46 Arc::new(RandomWalkMethod::new(
47 kb.clone(),
48 input_letters.clone(),
49 5_000,
50 0.65,
51 )),
52 Arc::new(BDistMethod::new(kb, input_letters, 2)),
53 ];
54 Arc::new(MultipleEqtests::new(eqtests))
55 }),
56 ),
57 ];
58
59 for (name, builder) in strategies {
60 reports.push(run_strategy(name, builder.as_ref()));
61 }
62
63 print_summary(&reports);
64
65 Ok(())
66}
67
68fn run_strategy(name: &str, builder: &EqBuilder) -> StrategyReport {
69 let vocabulary = vec![
70 "INSERT_CARD".to_string(),
71 "ENTER_PIN".to_string(),
72 "REQUEST_WITHDRAW".to_string(),
73 "EJECT_CARD".to_string(),
74 "TIMEOUT".to_string(),
75 ];
76 let input_letters = vocabulary
77 .iter()
78 .map(|s| Letter::new(s))
79 .collect::<Vec<_>>();
80 let max_states = 8;
81
82 let kb = Arc::new(Mutex::new(ATMKnowledgeBase::new()));
83 let knowledge_base: SharedKb = kb.clone();
84 let eqtest = builder(knowledge_base.clone(), input_letters, max_states);
85 let mut learner = LSTAR::new(vocabulary, knowledge_base, max_states, None, Some(eqtest));
86
87 println!("\n--- Running {name} ---");
88 let started = Instant::now();
89 let learn_result = learner.learn();
90 let elapsed_ms = started.elapsed().as_millis();
91
92 let stats = {
93 let kb_guard = kb.lock().unwrap();
94 StatsSnapshot::from_stats(&kb_guard.stats)
95 };
96
97 match learn_result {
98 Ok(automata) => {
99 let state_count = automata.get_states().len();
100 let transition_count = automata.transitions.len();
101 println!("{name}: success (states={state_count}, transitions={transition_count})");
102 StrategyReport {
103 name: name.to_string(),
104 elapsed_ms,
105 state_count: Some(state_count),
106 transition_count: Some(transition_count),
107 error: None,
108 stats,
109 }
110 }
111 Err(err) => {
112 println!("{name}: failed ({err})");
113 StrategyReport {
114 name: name.to_string(),
115 elapsed_ms,
116 state_count: None,
117 transition_count: None,
118 error: Some(err),
119 stats,
120 }
121 }
122 }
123}
124
125fn print_summary(reports: &[StrategyReport]) {
126 println!("\n=== Final Statistics Summary ===");
127 for report in reports {
128 println!("\n[{}]", report.name);
129 println!(" runtime_ms: {}", report.elapsed_ms);
130 match (&report.state_count, &report.transition_count) {
131 (Some(states), Some(transitions)) => {
132 println!(" states: {}", states);
133 println!(" transitions: {}", transitions);
134 }
135 _ => println!(" model: n/a"),
136 }
137 match &report.error {
138 Some(err) => println!(" status: failed ({err})"),
139 None => println!(" status: success"),
140 }
141 println!(" kb_nb_query: {}", report.stats.nb_query);
142 println!(
143 " kb_nb_submitted_query: {}",
144 report.stats.nb_submitted_query
145 );
146 println!(" kb_nb_letter: {}", report.stats.nb_letter);
147 println!(
148 " kb_nb_submitted_letter: {}",
149 report.stats.nb_submitted_letter
150 );
151 }
152}
153
154struct ATMKnowledgeBase {
156 state: ATMState,
157 stats: KnowledgeBaseStats,
158}
159
/// States of the simulated ATM used as the system under learning.
///
/// Besides `Clone`/`Copy` (needed because the current state is passed by value
/// on every step), `Debug`, `PartialEq` and `Eq` are derived. This lets a state
/// be printed while diagnosing a run and compared directly in assertions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ATMState {
    /// Initial state; every query is replayed starting from here.
    Idle,
    /// A card has been accepted and the PIN is expected.
    CardInserted,
    /// The PIN was verified.
    Authenticated,
    /// A withdrawal was requested and the amount is awaited.
    Ready,
    /// Cash is being dispensed until the card is ejected.
    Dispensing,
}
168
169impl ATMKnowledgeBase {
170 fn new() -> Self {
171 Self {
172 state: ATMState::Idle,
173 stats: KnowledgeBaseStats::new(),
174 }
175 }
176
177 fn process_input(&self, command: &str, current_state: ATMState) -> (ATMState, &'static str) {
178 match current_state {
179 ATMState::Idle => match command {
180 "INSERT_CARD" => (ATMState::CardInserted, "CARD_ACCEPTED"),
181 _ => (ATMState::Idle, "INVALID_OP"),
182 },
183 ATMState::CardInserted => match command {
184 "ENTER_PIN" => (ATMState::Authenticated, "PIN_VERIFIED"),
185 "EJECT_CARD" => (ATMState::Idle, "CARD_EJECTED"),
186 _ => (ATMState::CardInserted, "RETRY"),
187 },
188 ATMState::Authenticated => match command {
189 "REQUEST_WITHDRAW" => (ATMState::Ready, "ENTER_AMOUNT"),
190 "EJECT_CARD" => (ATMState::Idle, "CARD_EJECTED"),
191 "TIMEOUT" => (ATMState::Idle, "SESSION_TIMEOUT"),
192 _ => (ATMState::Authenticated, "INVALID_COMMAND"),
193 },
194 ATMState::Ready => match command {
195 "REQUEST_WITHDRAW" => (ATMState::Dispensing, "DISPENSING"),
196 "EJECT_CARD" => (ATMState::Idle, "CARD_EJECTED"),
197 _ => (ATMState::Ready, "WAIT"),
198 },
199 ATMState::Dispensing => match command {
200 "EJECT_CARD" => (ATMState::Idle, "CARD_EJECTED"),
201 _ => (ATMState::Dispensing, "DISPENSING"),
202 },
203 }
204 }
205}
206
207impl KnowledgeBaseTrait for ATMKnowledgeBase {
208 fn resolve_query(&mut self, query: &mut OutputQuery) -> Result<(), String> {
209 self.stats.increment_nb_query();
210 self.stats.add_nb_letter(query.input_word.len());
211 self.stats.increment_nb_submitted_query();
212 self.stats.add_nb_submitted_letter(query.input_word.len());
213
214 self.state = ATMState::Idle;
215 let mut outputs = Vec::new();
216
217 for input_letter in query.input_word.letters() {
218 let command = input_letter.symbols();
219 let (next_state, response) = self.process_input(command.as_str(), self.state);
220 self.state = next_state;
221 outputs.push(Letter::new(response));
222 }
223
224 query.set_result(Word::from_letters(outputs));
225 Ok(())
226 }
227
228 fn add_word(&mut self, _input: &Word, _output: &Word) -> Result<(), String> {
229 Ok(())
230 }
231}
232
233struct StatsSnapshot {
235 nb_query: usize,
236 nb_submitted_query: usize,
237 nb_letter: usize,
238 nb_submitted_letter: usize,
239}
240
241impl StatsSnapshot {
242 fn from_stats(stats: &KnowledgeBaseStats) -> Self {
243 Self {
244 nb_query: stats.nb_query(),
245 nb_submitted_query: stats.nb_submitted_query(),
246 nb_letter: stats.nb_letter(),
247 nb_submitted_letter: stats.nb_submitted_letter(),
248 }
249 }
250}
251
252struct StrategyReport {
254 name: String,
255 elapsed_ms: u128,
256 state_count: Option<usize>,
257 transition_count: Option<usize>,
258 error: Option<String>,
259 stats: StatsSnapshot,
260}