Skip to main content

rust_lstar/eqtest/
random_walk.rs

1//! Random-walk equivalence-test implementation.
2//!
//! This strategy takes random walks through the hypothesis and compares the
//! outputs the hypothesis predicts with those observed on the system under learning.
5
6use crate::automata::{Automata, State};
7use crate::eqtest::{Counterexample, EquivalenceTest};
8use crate::knowledge_base::KnowledgeBaseTrait;
9use crate::letter::Letter;
10use crate::query::OutputQuery;
11use crate::word::Word;
12use std::collections::HashMap;
13use std::sync::{Arc, Mutex};
14use std::time::{SystemTime, UNIX_EPOCH};
15
16/// Random-walk based equivalence-test strategy.
17pub struct RandomWalkMethod {
18    knowledge_base: Arc<Mutex<dyn KnowledgeBaseTrait>>,
19    #[allow(dead_code)]
20    input_letters: Vec<Letter>,
21    max_steps: usize,
22    restart_probability: f64,
23}
24
/// Small, dependency-free pseudo-random generator (SplitMix64) that drives the walks.
///
/// This generator is not cryptographically secure; it only needs to spread
/// choices evenly. A bare power-of-two LCG is deliberately avoided: its low
/// bits are highly periodic (the lowest bit simply alternates), which made
/// `gen_range(0, 2)` alternate deterministically between two transitions.
struct SimpleRng {
    state: u64,
}

impl SimpleRng {
    /// Seeds the generator from the current wall-clock time.
    ///
    /// A clock set before the Unix epoch yields seed 0, which is still a valid seed.
    fn new() -> Self {
        // Truncating the u128 nanosecond count keeps its fast-changing low
        // bits, which is all a seed needs.
        let seed = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos() as u64;
        SimpleRng { state: seed }
    }

    /// Advances the state and returns the next 64 pseudo-random bits.
    fn next(&mut self) -> u64 {
        self.state = self.state.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.state;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Returns a value in `[0.0, 1.0)`.
    fn gen_f64(&mut self) -> f64 {
        // Keep the top 53 bits so every value is exactly representable and the
        // result can never round up to 1.0 (dividing by `u64::MAX as f64` could).
        (self.next() >> 11) as f64 * (1.0 / (1u64 << 53) as f64)
    }

    /// Returns a value in `[min, max)`, or `min` when the range is empty.
    fn gen_range(&mut self, min: usize, max: usize) -> usize {
        if max <= min {
            return min;
        }
        let span = (max - min) as u128;
        // Multiply-shift maps the 64-bit output onto `span` using its high bits
        // instead of `%`, which would depend on the low bits only.
        let offset = (u128::from(self.next()) * span) >> 64;
        // `offset < span <= usize::MAX`, so the cast is lossless.
        min + offset as usize
    }
}
54
55impl RandomWalkMethod {
56    /// Create a new random-walk equivalence tester.
57    ///
58    /// `max_steps` bounds the number of transitions explored, and
59    /// `restart_probability` controls how often the walk restarts at the
60    /// initial state.
61    pub fn new(
62        knowledge_base: Arc<Mutex<dyn KnowledgeBaseTrait>>,
63        input_letters: Vec<Letter>,
64        max_steps: usize,
65        restart_probability: f64,
66    ) -> Self {
67        RandomWalkMethod {
68            knowledge_base,
69            input_letters,
70            max_steps,
71            restart_probability,
72        }
73    }
74
75    fn build_flat_outgoing(
76        hypothesis: &Automata,
77    ) -> Option<HashMap<String, Vec<(String, Letter, Letter)>>> {
78        let can_use_flat_transitions = !hypothesis.transitions.is_empty()
79            && hypothesis
80                .transitions
81                .iter()
82                .all(|transition| !transition.source_state.is_empty());
83        if !can_use_flat_transitions {
84            return None;
85        }
86
87        let mut outgoing: HashMap<String, Vec<(String, Letter, Letter)>> = HashMap::new();
88        for transition in &hypothesis.transitions {
89            outgoing
90                .entry(transition.source_state.clone())
91                .or_default()
92                .push((
93                    transition.output_state.name.clone(),
94                    transition.input_letter.clone(),
95                    transition.output_letter.clone(),
96                ));
97        }
98        Some(outgoing)
99    }
100
101    fn check_equivalence(
102        &self,
103        input_word: &Word,
104        expected_output: &Word,
105    ) -> Option<Counterexample> {
106        let mut query = OutputQuery::new(input_word.clone());
107        if self
108            .knowledge_base
109            .lock()
110            .unwrap()
111            .resolve_query(&mut query)
112            .is_err()
113        {
114            return None;
115        }
116
117        if let Some(observed) = query.output_word() {
118            if observed != expected_output {
119                return Some(Counterexample {
120                    input_word: input_word.clone(),
121                    output_word: observed.clone(),
122                });
123            }
124        }
125
126        None
127    }
128
129    fn walk(
130        &self,
131        current_state: &State,
132        rng: &mut SimpleRng,
133        flat_outgoing: Option<&HashMap<String, Vec<(String, Letter, Letter)>>>,
134    ) -> Option<(State, Letter, Letter)> {
135        if let Some(flat_outgoing) = flat_outgoing {
136            let outgoing = flat_outgoing.get(&current_state.name)?;
137
138            if outgoing.is_empty() {
139                return None;
140            }
141
142            let idx = rng.gen_range(0, outgoing.len());
143            let picked_transition = &outgoing[idx];
144            return Some((
145                State::new(picked_transition.0.clone()),
146                picked_transition.1.clone(),
147                picked_transition.2.clone(),
148            ));
149        }
150
151        if current_state.transitions.is_empty() {
152            return None;
153        }
154
155        let idx = rng.gen_range(0, current_state.transitions.len());
156        let picked_transition = &current_state.transitions[idx];
157
158        Some((
159            picked_transition.output_state.clone(),
160            picked_transition.input_letter.clone(),
161            picked_transition.output_letter.clone(),
162        ))
163    }
164}
165
166impl EquivalenceTest for RandomWalkMethod {
167    fn find_counterexample(&self, hypothesis: &mut Automata) -> Option<Counterexample> {
168        let mut rng = SimpleRng::new();
169        let flat_outgoing = Self::build_flat_outgoing(hypothesis);
170        let mut i_step = 0;
171        let mut first_step_after_restart = true;
172        let mut current_state = hypothesis.initial_state.clone();
173        let mut input_word = Word::new();
174        let mut hypothesis_output_word = Word::new();
175        let mut force_restart = false;
176
177        while i_step < self.max_steps {
178            if !first_step_after_restart {
179                if force_restart || rng.gen_f64() < self.restart_probability {
180                    current_state = hypothesis.initial_state.clone();
181                    first_step_after_restart = true;
182
183                    if let Some(ce) = self.check_equivalence(&input_word, &hypothesis_output_word) {
184                        return Some(ce);
185                    }
186
187                    input_word = Word::new();
188                    hypothesis_output_word = Word::new();
189                    force_restart = false;
190                }
191            } else {
192                first_step_after_restart = false;
193            }
194
195            match self.walk(&current_state, &mut rng, flat_outgoing.as_ref()) {
196                Some((new_state, input_letter, output_letter)) => {
197                    current_state = new_state;
198                    input_word = input_word.append_letter(input_letter);
199                    hypothesis_output_word = hypothesis_output_word.append_letter(output_letter);
200                }
201                None => {
202                    force_restart = true;
203                }
204            }
205
206            i_step += 1;
207        }
208
209        None
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::automata::{Automata, State, Transition};
    use crate::knowledge_base::KnowledgeBaseTrait;
    use std::sync::{Arc, Mutex};

    /// Knowledge-base stub that answers every input letter with one fixed output symbol.
    struct ConstantOutputKb {
        output_symbol: String,
    }

    impl ConstantOutputKb {
        fn new(output_symbol: &str) -> Self {
            Self {
                output_symbol: output_symbol.to_owned(),
            }
        }
    }

    impl KnowledgeBaseTrait for ConstantOutputKb {
        fn resolve_query(&mut self, query: &mut OutputQuery) -> Result<(), String> {
            let symbol = &self.output_symbol;
            let outputs: Vec<_> = query
                .input_word
                .letters()
                .iter()
                .map(|_| Letter::new(symbol.clone()))
                .collect();
            query.set_result(Word::from_letters(outputs));
            Ok(())
        }

        fn add_word(&mut self, _input_word: &Word, _output_word: &Word) -> Result<(), String> {
            Ok(())
        }
    }

    /// One-state hypothesis with a single flat self-loop `0 --a/output_symbol--> 0`.
    fn build_flat_single_state_automata(output_symbol: &str) -> Automata {
        let mut automata = Automata::new(State::new("0".to_string()), "A".to_string());
        automata.transitions = vec![Transition::new_with_source(
            "t0".to_string(),
            "0".to_string(),
            State::new("0".to_string()),
            Letter::new("a"),
            Letter::new(output_symbol),
        )];
        automata
    }

    /// Runs a 16-step search with `restart_probability = 1.0`. The hypothesis
    /// outputs `hypothesis_symbol`; the stub SUL always outputs `sul_symbol`.
    fn run_single_state(sul_symbol: &str, hypothesis_symbol: &str) -> Option<Counterexample> {
        let kb: Arc<Mutex<dyn KnowledgeBaseTrait>> =
            Arc::new(Mutex::new(ConstantOutputKb::new(sul_symbol)));
        let eq = RandomWalkMethod::new(kb, vec![Letter::new("a")], 16, 1.0);
        let mut hypothesis = build_flat_single_state_automata(hypothesis_symbol);
        eq.find_counterexample(&mut hypothesis)
    }

    #[test]
    fn random_walk_finds_counterexample_on_flat_transition_automata() {
        assert!(run_single_state("1", "0").is_some());
    }

    #[test]
    fn random_walk_returns_none_when_outputs_match() {
        assert!(run_single_state("0", "0").is_none());
    }
}