Skip to main content

random_walk_eq/
random_walk_eq.rs

//! Example: using the random-walk equivalence test with the L* learner.
2use rust_lstar::eqtest::RandomWalkMethod;
3use rust_lstar::knowledge_base::{KnowledgeBaseStats, KnowledgeBaseTrait};
4use rust_lstar::query::OutputQuery;
5use rust_lstar::*;
6use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8
9fn main() -> Result<(), Box<dyn std::error::Error>> {
10    println!("=== Random Walk Equivalence Test Example ===\n");
11
12    // Create a simple system
13    let mut kb = DemoKnowledgeBase::new();
14
15    // Define a simple state machine:
16    // S0 -a/0-> S1   (initial state)
17    // S0 -b/1-> S0
18    // S1 -a/0-> S2
19    // S1 -b/1-> S0
20    // S2 -a/0-> S2
21    // S2 -b/1-> S1
22
23    // Make outputs state-dependent so learner must distinguish states
24    kb.add_transition("s0", "a", "0", "s1");
25    kb.add_transition("s0", "b", "1", "s0");
26    kb.add_transition("s1", "a", "1", "s2");
27    kb.add_transition("s1", "b", "0", "s0");
28    kb.add_transition("s2", "a", "1", "s2");
29    kb.add_transition("s2", "b", "0", "s1");
30
31    let knowledge_base = Arc::new(Mutex::new(kb));
32
33    // Create vocabulary
34    let vocabulary = vec!["a".to_string(), "b".to_string()];
35
36    // Create learner
37    let mut lstar = LSTAR::new(vocabulary.clone(), knowledge_base.clone(), 5, None, None);
38
39    // Use random walk equivalence test instead of W-method
40    let input_letters = vocabulary
41        .iter()
42        .map(|s| Letter::new(s))
43        .collect::<Vec<_>>();
44    let random_walk = RandomWalkMethod::new(
45        knowledge_base.clone(),
46        input_letters,
47        10000, // max_steps: 10000 steps
48        0.75,  // restart_probability: 75%
49    );
50
51    lstar = lstar.with_equivalence_test(Arc::new(random_walk));
52
53    // Run learning
54    match lstar.learn() {
55        Ok(automata) => {
56            println!("\n=== Learned Automaton (Random Walk Test) ===\n");
57            println!("{}", automata.build_dot_code());
58            println!("\nLearning completed successfully!");
59
60            // Print statistics
61            println!("\nStatistics:");
62            println!("  Number of states: {}", automata.get_states().len());
63            println!("  Number of transitions: {}", automata.transitions.len());
64        }
65        Err(e) => eprintln!("Error during learning: {}", e),
66    }
67
68    let kb_guard = knowledge_base.lock().unwrap();
69    println!("\nKnowledge Base Statistics:\n{}", kb_guard.stats);
70
71    Ok(())
72}
73
74struct DemoKnowledgeBase {
75    transitions: HashMap<(String, String), (String, String)>,
76    current_state: String,
77    stats: KnowledgeBaseStats,
78}
79
80impl DemoKnowledgeBase {
81    fn new() -> Self {
82        Self {
83            transitions: std::collections::HashMap::new(),
84            current_state: "s0".to_string(),
85            stats: KnowledgeBaseStats::new(),
86        }
87    }
88
89    fn add_transition(&mut self, from_state: &str, input: &str, output: &str, to_state: &str) {
90        self.transitions.insert(
91            (from_state.to_string(), input.to_string()),
92            (output.to_string(), to_state.to_string()),
93        );
94    }
95}
96
97impl KnowledgeBaseTrait for DemoKnowledgeBase {
98    fn resolve_query(&mut self, query: &mut OutputQuery) -> Result<(), String> {
99        self.stats.increment_nb_query();
100        self.stats.add_nb_letter(query.input_word.len());
101        self.stats.increment_nb_submitted_query();
102        self.stats.add_nb_submitted_letter(query.input_word.len());
103
104        self.current_state = "s0".to_string();
105        let mut output_letters = Vec::new();
106
107        for input_letter in query.input_word.letters() {
108            let input = input_letter.symbols();
109            let key = (self.current_state.clone(), input);
110            let (output, next_state) = self
111                .transitions
112                .get(&key)
113                .cloned()
114                .ok_or_else(|| format!("No transition for ({}, {})", key.0, key.1))?;
115
116            output_letters.push(Letter::new(output));
117            self.current_state = next_state;
118        }
119
120        query.set_result(Word::from_letters(output_letters));
121        Ok(())
122    }
123
124    fn add_word(&mut self, _input_word: &Word, _output_word: &Word) -> Result<(), String> {
125        Ok(())
126    }
127}