Skip to main content

all_eqtests_custom_kb/
all_eqtests_custom_kb.rs

1//! Compare multiple equivalence-test strategies on the same custom ATM-like
2//! system under learning and print runtime/statistics summaries.
3
4use rust_lstar::eqtest::{BDistMethod, MultipleEqtests, RandomWalkMethod, WMethodEQ};
5use rust_lstar::knowledge_base::{KnowledgeBaseStats, KnowledgeBaseTrait};
6use rust_lstar::query::OutputQuery;
7use rust_lstar::*;
8use std::sync::{Arc, Mutex};
9use std::time::Instant;
10
11type SharedKb = Arc<Mutex<dyn KnowledgeBaseTrait>>;
12type EqBuilder = dyn Fn(SharedKb, Vec<Letter>, usize) -> Arc<dyn EquivalenceTest>;
13
14fn main() -> Result<(), Box<dyn std::error::Error>> {
15    println!("=== Custom KB: All Equivalence Tests ===");
16    println!("System under learning: ATM protocol");
17
18    let mut reports = Vec::new();
19
20    let strategies: Vec<(&str, Box<EqBuilder>)> = vec![
21        (
22            "WMethodEQ",
23            Box::new(|kb, input_letters, max_states| {
24                Arc::new(WMethodEQ::new(kb, input_letters, max_states))
25            }),
26        ),
27        (
28            "RandomWalkMethod",
29            Box::new(|kb, input_letters, _| {
30                Arc::new(RandomWalkMethod::new(kb, input_letters, 10_000, 0.75))
31            }),
32        ),
33        (
34            "BDistMethod",
35            Box::new(|kb, input_letters, _| Arc::new(BDistMethod::new(kb, input_letters, 2))),
36        ),
37        (
38            "MultipleEqtests",
39            Box::new(|kb, input_letters, max_states| {
40                let eqtests: Vec<Arc<dyn EquivalenceTest>> = vec![
41                    Arc::new(WMethodEQ::new(
42                        kb.clone(),
43                        input_letters.clone(),
44                        max_states,
45                    )),
46                    Arc::new(RandomWalkMethod::new(
47                        kb.clone(),
48                        input_letters.clone(),
49                        5_000,
50                        0.65,
51                    )),
52                    Arc::new(BDistMethod::new(kb, input_letters, 2)),
53                ];
54                Arc::new(MultipleEqtests::new(eqtests))
55            }),
56        ),
57    ];
58
59    for (name, builder) in strategies {
60        reports.push(run_strategy(name, builder.as_ref()));
61    }
62
63    print_summary(&reports);
64
65    Ok(())
66}
67
68fn run_strategy(name: &str, builder: &EqBuilder) -> StrategyReport {
69    let vocabulary = vec![
70        "INSERT_CARD".to_string(),
71        "ENTER_PIN".to_string(),
72        "REQUEST_WITHDRAW".to_string(),
73        "EJECT_CARD".to_string(),
74        "TIMEOUT".to_string(),
75    ];
76    let input_letters = vocabulary
77        .iter()
78        .map(|s| Letter::new(s))
79        .collect::<Vec<_>>();
80    let max_states = 8;
81
82    let kb = Arc::new(Mutex::new(ATMKnowledgeBase::new()));
83    let knowledge_base: SharedKb = kb.clone();
84    let eqtest = builder(knowledge_base.clone(), input_letters, max_states);
85    let mut learner = LSTAR::new(vocabulary, knowledge_base, max_states, None, Some(eqtest));
86
87    println!("\n--- Running {name} ---");
88    let started = Instant::now();
89    let learn_result = learner.learn();
90    let elapsed_ms = started.elapsed().as_millis();
91
92    let stats = {
93        let kb_guard = kb.lock().unwrap();
94        StatsSnapshot::from_stats(&kb_guard.stats)
95    };
96
97    match learn_result {
98        Ok(automata) => {
99            let state_count = automata.get_states().len();
100            let transition_count = automata.transitions.len();
101            println!("{name}: success (states={state_count}, transitions={transition_count})");
102            StrategyReport {
103                name: name.to_string(),
104                elapsed_ms,
105                state_count: Some(state_count),
106                transition_count: Some(transition_count),
107                error: None,
108                stats,
109            }
110        }
111        Err(err) => {
112            println!("{name}: failed ({err})");
113            StrategyReport {
114                name: name.to_string(),
115                elapsed_ms,
116                state_count: None,
117                transition_count: None,
118                error: Some(err),
119                stats,
120            }
121        }
122    }
123}
124
125fn print_summary(reports: &[StrategyReport]) {
126    println!("\n=== Final Statistics Summary ===");
127    for report in reports {
128        println!("\n[{}]", report.name);
129        println!("  runtime_ms: {}", report.elapsed_ms);
130        match (&report.state_count, &report.transition_count) {
131            (Some(states), Some(transitions)) => {
132                println!("  states: {}", states);
133                println!("  transitions: {}", transitions);
134            }
135            _ => println!("  model: n/a"),
136        }
137        match &report.error {
138            Some(err) => println!("  status: failed ({err})"),
139            None => println!("  status: success"),
140        }
141        println!("  kb_nb_query: {}", report.stats.nb_query);
142        println!(
143            "  kb_nb_submitted_query: {}",
144            report.stats.nb_submitted_query
145        );
146        println!("  kb_nb_letter: {}", report.stats.nb_letter);
147        println!(
148            "  kb_nb_submitted_letter: {}",
149            report.stats.nb_submitted_letter
150        );
151    }
152}
153
154/// Minimal ATM simulator used as a custom knowledge base.
155struct ATMKnowledgeBase {
156    state: ATMState,
157    stats: KnowledgeBaseStats,
158}
159
/// States of the simulated ATM session.
///
/// `Debug`, `PartialEq` and `Eq` are derived in addition to `Clone`/`Copy` so
/// that states can be printed in diagnostics and compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ATMState {
    /// No card in the machine; the initial state of every query.
    Idle,
    /// A card was accepted and the PIN has not been verified yet.
    CardInserted,
    /// The PIN was verified.
    Authenticated,
    /// A withdrawal was requested from the authenticated state.
    Ready,
    /// Cash is being dispensed; only ejecting the card leaves this state.
    Dispensing,
}
168
169impl ATMKnowledgeBase {
170    fn new() -> Self {
171        Self {
172            state: ATMState::Idle,
173            stats: KnowledgeBaseStats::new(),
174        }
175    }
176
177    fn process_input(&self, command: &str, current_state: ATMState) -> (ATMState, &'static str) {
178        match current_state {
179            ATMState::Idle => match command {
180                "INSERT_CARD" => (ATMState::CardInserted, "CARD_ACCEPTED"),
181                _ => (ATMState::Idle, "INVALID_OP"),
182            },
183            ATMState::CardInserted => match command {
184                "ENTER_PIN" => (ATMState::Authenticated, "PIN_VERIFIED"),
185                "EJECT_CARD" => (ATMState::Idle, "CARD_EJECTED"),
186                _ => (ATMState::CardInserted, "RETRY"),
187            },
188            ATMState::Authenticated => match command {
189                "REQUEST_WITHDRAW" => (ATMState::Ready, "ENTER_AMOUNT"),
190                "EJECT_CARD" => (ATMState::Idle, "CARD_EJECTED"),
191                "TIMEOUT" => (ATMState::Idle, "SESSION_TIMEOUT"),
192                _ => (ATMState::Authenticated, "INVALID_COMMAND"),
193            },
194            ATMState::Ready => match command {
195                "REQUEST_WITHDRAW" => (ATMState::Dispensing, "DISPENSING"),
196                "EJECT_CARD" => (ATMState::Idle, "CARD_EJECTED"),
197                _ => (ATMState::Ready, "WAIT"),
198            },
199            ATMState::Dispensing => match command {
200                "EJECT_CARD" => (ATMState::Idle, "CARD_EJECTED"),
201                _ => (ATMState::Dispensing, "DISPENSING"),
202            },
203        }
204    }
205}
206
207impl KnowledgeBaseTrait for ATMKnowledgeBase {
208    fn resolve_query(&mut self, query: &mut OutputQuery) -> Result<(), String> {
209        self.stats.increment_nb_query();
210        self.stats.add_nb_letter(query.input_word.len());
211        self.stats.increment_nb_submitted_query();
212        self.stats.add_nb_submitted_letter(query.input_word.len());
213
214        self.state = ATMState::Idle;
215        let mut outputs = Vec::new();
216
217        for input_letter in query.input_word.letters() {
218            let command = input_letter.symbols();
219            let (next_state, response) = self.process_input(command.as_str(), self.state);
220            self.state = next_state;
221            outputs.push(Letter::new(response));
222        }
223
224        query.set_result(Word::from_letters(outputs));
225        Ok(())
226    }
227
228    fn add_word(&mut self, _input: &Word, _output: &Word) -> Result<(), String> {
229        Ok(())
230    }
231}
232
/// Point-in-time copy of the knowledge-base counters used for final reporting.
///
/// The struct is plain data, so the common traits are derived: snapshots can
/// be printed with `{:?}`, copied, and compared for equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct StatsSnapshot {
    /// Value of the knowledge base's `nb_query` counter.
    nb_query: usize,
    /// Value of the knowledge base's `nb_submitted_query` counter.
    nb_submitted_query: usize,
    /// Value of the knowledge base's `nb_letter` counter.
    nb_letter: usize,
    /// Value of the knowledge base's `nb_submitted_letter` counter.
    nb_submitted_letter: usize,
}
240
241impl StatsSnapshot {
242    fn from_stats(stats: &KnowledgeBaseStats) -> Self {
243        Self {
244            nb_query: stats.nb_query(),
245            nb_submitted_query: stats.nb_submitted_query(),
246            nb_letter: stats.nb_letter(),
247            nb_submitted_letter: stats.nb_submitted_letter(),
248        }
249    }
250}
251
252/// Summary row for one equivalence-test strategy execution.
253struct StrategyReport {
254    name: String,
255    elapsed_ms: u128,
256    state_count: Option<usize>,
257    transition_count: Option<usize>,
258    error: Option<String>,
259    stats: StatsSnapshot,
260}