1use bit_set::BitSet;
2
3use rand::{rng, seq::IteratorRandom};
4
5use crate::{
6 automata::{TransitionType, NFA},
7 re::{deriv::DerivativeBuilder, ReBuilder, ReOp, Regex},
8 SmtString,
9};
10
11#[derive(Debug, PartialEq, Eq, Clone)]
13pub enum SampleResult {
14 Sampled(SmtString),
16 Empty,
18 MaxDepth,
20}
21
22impl SampleResult {
23 pub fn unwrap(self) -> SmtString {
26 match self {
27 SampleResult::Sampled(s) => s,
28 _ => panic!("called `unwrap` on empty value"),
29 }
30 }
31
32 pub fn success(&self) -> bool {
35 matches!(self, SampleResult::Sampled(_))
36 }
37}
38
39pub fn sample_regex(
44 regex: &Regex,
45 builder: &mut ReBuilder,
46 max_depth: usize,
47 comp: bool,
48) -> SampleResult {
49 fn fast_sample(re: &Regex, d: usize, max: usize) -> SampleResult {
50 if d > max {
51 return SampleResult::MaxDepth;
52 }
53 match re.op() {
54 ReOp::Literal(w) => SampleResult::Sampled(w.clone()),
55 ReOp::Range(r) => {
56 if let Some(r) = r.choose().map(|c| c.into()) {
57 SampleResult::Sampled(r)
58 } else {
59 SampleResult::Empty
60 }
61 }
62 ReOp::None => SampleResult::Empty,
63 ReOp::Any | ReOp::All => SampleResult::Sampled(SmtString::from("a")),
64 ReOp::Concat(rs) => {
65 let mut res = SmtString::empty();
66 for r in rs {
67 match fast_sample(r, d + 1, max) {
68 SampleResult::Sampled(s) => res.append(&s),
69 SampleResult::Empty => return SampleResult::Empty,
70 SampleResult::MaxDepth => return SampleResult::MaxDepth,
71 }
72 }
73 SampleResult::Sampled(res)
74 }
75 ReOp::Comp(comped) => match comped.op() {
76 ReOp::Literal(s) => {
77 if s.is_empty() {
78 SampleResult::Sampled("a".into())
79 } else {
80 SampleResult::Sampled(SmtString::empty())
81 }
82 }
83 ReOp::Range(range) => {
84 for c in range.complement() {
85 if let Some(c) = c.choose() {
86 return SampleResult::Sampled(c.into());
87 }
88 }
89 SampleResult::Empty
90 }
91 ReOp::None => SampleResult::Sampled("a".into()),
92 ReOp::Any => SampleResult::Sampled("aa".into()),
93 ReOp::All => SampleResult::Empty,
94 ReOp::Comp(r) => fast_sample(r, d + 1, max), _ => SampleResult::MaxDepth,
96 },
97 ReOp::Union(rs) => {
98 let mut max_reached = false;
99 for r in rs {
100 match fast_sample(r, d + 1, max) {
101 SampleResult::Sampled(s) => return SampleResult::Sampled(s),
102 SampleResult::Empty => (),
103 SampleResult::MaxDepth => max_reached = true,
104 }
105 }
106 if max_reached {
107 SampleResult::MaxDepth
108 } else {
109 SampleResult::Empty
110 }
111 }
112 ReOp::Star(_) | ReOp::Opt(_) => SampleResult::Sampled(SmtString::empty()),
113 ReOp::Plus(r) => fast_sample(r, d + 1, max),
114 ReOp::Pow(r, e) => match fast_sample(r, d + 1, max) {
115 SampleResult::Sampled(s) => SampleResult::Sampled(s.repeat(*e as usize)),
116 SampleResult::Empty => SampleResult::Empty,
117 SampleResult::MaxDepth => SampleResult::MaxDepth,
118 },
119 ReOp::Loop(r, l, u) if l <= u => match fast_sample(r, d + 1, max) {
120 SampleResult::Sampled(s) => SampleResult::Sampled(s.repeat(*l as usize)),
121 SampleResult::Empty => SampleResult::Empty,
122 SampleResult::MaxDepth => SampleResult::MaxDepth,
123 },
124 ReOp::Loop(_, _, _) => SampleResult::Empty,
125 _ => SampleResult::MaxDepth,
126 }
127 }
128
129 if !comp {
130 match fast_sample(regex, 0, max_depth) {
131 SampleResult::Sampled(s) => return SampleResult::Sampled(s),
132 SampleResult::Empty => return SampleResult::Empty,
133 SampleResult::MaxDepth => (),
134 }
135 }
136
137 let mut w = SmtString::empty();
138 let mut deriver = DerivativeBuilder::default();
139
140 let mut i = 0;
141 let mut re = regex.clone();
142
143 let done = |re: &Regex| {
144 if comp {
145 !re.nullable()
146 } else {
147 re.nullable()
148 }
149 };
150
151 if done(&re) {
152 return SampleResult::Sampled(w);
153 }
154
155 while !done(&re) && i < max_depth {
156 let next = if let Some(c) = re
157 .first()
158 .iter()
159 .choose(&mut rng())
160 .and_then(|c| c.choose())
161 {
162 c
163 } else {
164 return SampleResult::Empty;
165 };
166 w.push(next);
167 re = deriver.deriv(&re, next, builder);
168 i += 1;
169 }
170
171 if done(&re) {
172 SampleResult::Sampled(w)
173 } else {
174 SampleResult::MaxDepth
175 }
176}
177
178pub fn sample_nfa(nfa: &NFA, max: usize, comp: bool) -> SampleResult {
189 let mut w = SmtString::empty();
190 let mut states = BitSet::new();
191 if let Some(q0) = nfa.initial() {
192 states = BitSet::from_iter(nfa.epsilon_closure(q0).unwrap());
193 }
194 let mut i = 0;
195
196 let done = |s: &BitSet| {
197 if comp {
198 !s.iter().any(|q| nfa.is_final(q))
199 } else {
200 s.iter().any(|q| nfa.is_final(q))
201 }
202 };
203
204 while i <= max {
205 i += 1;
206 if done(&states) {
208 return SampleResult::Sampled(w);
209 }
210
211 let mut transitions = Vec::new();
213 for q in states.iter() {
214 transitions.extend(nfa.transitions_from(q).unwrap());
215 }
216 let transition = match transitions.iter().choose(&mut rng()) {
218 Some(t) => t,
219 None => return SampleResult::Empty,
220 };
221 let c = match transition.get_type() {
223 TransitionType::Range(r) => r.choose(),
224 TransitionType::NotRange(nr) => {
225 let r = nr.complement();
226 r.into_iter()
227 .filter(|r| !r.is_empty())
228 .choose(&mut rng())
229 .and_then(|r| r.choose())
230 }
231 TransitionType::Epsilon => None,
232 };
233 match c {
234 Some(c) => {
235 w.push(c);
236 states = BitSet::from_iter(
238 states
239 .iter()
240 .flat_map(|s| nfa.consume(s, c))
241 .flatten()
242 .flat_map(|q| nfa.epsilon_closure(q))
243 .flatten(),
244 );
245 }
246 None => continue,
247 }
248 }
249
250 SampleResult::MaxDepth
251}
252
#[cfg(test)]
mod tests {

    use quickcheck_macros::quickcheck;
    use smallvec::smallvec;

    use crate::alphabet::CharRange;

    use super::*;

    #[test]
    fn sample_const() {
        let mut builder = ReBuilder::default();
        let regex = builder.to_re("foo".into());

        assert_eq!(
            sample_regex(&regex, &mut builder, 3, false).unwrap(),
            "foo".into()
        );
        assert_eq!(
            sample_regex(&regex, &mut builder, 10, false).unwrap(),
            "foo".into()
        );
    }

    #[test]
    fn sample_with_optional_characters() {
        let mut builder = ReBuilder::default();

        let o = builder.to_re("o".into());
        let fo = builder.to_re("fo".into());
        let bar = builder.to_re("bar".into());
        let o_or_bar = builder.union(smallvec![o, bar]);
        let regex = builder.concat(smallvec![fo, o_or_bar]);

        assert!(sample_regex(&regex, &mut builder, 5, false).success());
    }

    #[quickcheck]
    fn sample_with_character_range(range: CharRange) {
        // A range inside a concatenation: the sample is the literal prefix plus one character.
        let mut builder = ReBuilder::default();
        let prefix = builder.to_re("a".into());
        let range = builder.range(range);
        let regex = builder.concat(smallvec![prefix, range]);

        for depth in [1, 3] {
            let res = sample_regex(&regex, &mut builder, depth, false);
            assert!(res.success());
            assert_eq!(res.unwrap().len(), 2);
        }
    }

    #[quickcheck]
    fn sample_character_range(range: CharRange) {
        let mut builder = ReBuilder::default();
        let regex = builder.range(range);

        assert!(sample_regex(&regex, &mut builder, 1, false).success());
        assert!(sample_regex(&regex, &mut builder, 3, false).success());
    }

    #[quickcheck]
    fn sample_character_range_pow(range: CharRange, n: u32) {
        let n = n % 100;
        let mut builder = ReBuilder::default();
        let regex = builder.range(range);
        let regex = builder.pow(regex, n);

        assert!(sample_regex(&regex, &mut builder, n as usize, false).success());
    }

    #[quickcheck]
    fn sample_alternatives(rs: Vec<CharRange>) {
        let n = rs.len();
        let mut builder = ReBuilder::default();
        let rs = rs.into_iter().map(|r| builder.range(r)).collect();
        let regex = builder.union(rs);

        if n > 0 {
            assert!(sample_regex(&regex, &mut builder, 1, false).success());
        } else {
            assert!(!sample_regex(&regex, &mut builder, 10, false).success());
        }
    }

    #[test]
    fn sampling_alternatives_bug() {
        // Regression test: a union of overlapping ranges must be sampleable.
        let rs = vec![
            CharRange::new(2u32, 5u32),
            CharRange::new(3u32, 6u32),
            CharRange::new(1u32, 4u32),
        ];

        let mut builder = ReBuilder::default();
        let rs = rs.into_iter().map(|r| builder.range(r)).collect();
        let regex = builder.union(rs);

        assert!(sample_regex(&regex, &mut builder, 1, false).success());
    }

    #[quickcheck]
    fn sample_opt(r: CharRange) {
        let mut builder = ReBuilder::default();
        let r = builder.range(r);
        let regex = builder.opt(r);

        assert!(sample_regex(&regex, &mut builder, 0, false).success());
        assert!(sample_regex(&regex, &mut builder, 1, false).success());
    }

    #[test]
    fn sample_empty_string() {
        let mut builder = ReBuilder::default();
        let regex = builder.epsilon();

        assert!(sample_regex(&regex, &mut builder, 0, false).success());
    }

    #[test]
    fn sample_empty_regex() {
        let mut builder = ReBuilder::default();
        let regex = builder.none();

        assert_eq!(
            sample_regex(&regex, &mut builder, 0, false),
            SampleResult::Empty
        );
        assert_eq!(
            sample_regex(&regex, &mut builder, 20, false),
            SampleResult::Empty
        );
    }

    #[test]
    fn sample_all() {
        let mut builder = ReBuilder::default();
        let regex = builder.all();

        assert!(sample_regex(&regex, &mut builder, 0, false).success());
        assert!(sample_regex(&regex, &mut builder, 20, false).success());
    }

    #[test]
    fn sample_any() {
        let mut builder = ReBuilder::default();
        let regex = builder.allchar();
        assert!(sample_regex(&regex, &mut builder, 20, false).success());
    }

    #[test]
    fn test_sample_nfa_accepts_word() {
        let mut nfa = NFA::new();
        let q0 = nfa.new_state();
        let q1 = nfa.new_state();

        nfa.set_initial(q0).unwrap();
        nfa.add_final(q1).unwrap();

        nfa.add_transition(q0, q1, TransitionType::Range(CharRange::new('a', 'a')))
            .unwrap();

        let sample = sample_nfa(&nfa, 10, false);
        assert_eq!(sample, SampleResult::Sampled(SmtString::from("a")));
    }

    #[test]
    fn test_sample_nfa_rejects_unreachable_final_state() {
        let mut nfa = NFA::new();
        let q0 = nfa.new_state();
        let q1 = nfa.new_state();
        nfa.set_initial(q0).unwrap();
        nfa.add_final(q1).unwrap();

        let sample = sample_nfa(&nfa, 10, false);
        assert_eq!(sample, SampleResult::Empty);
    }

    #[test]
    fn test_sample_nfa_handles_epsilon_transitions() {
        let mut nfa = NFA::new();
        let q0 = nfa.new_state();
        let q1 = nfa.new_state();
        let q2 = nfa.new_state();

        nfa.set_initial(q0).unwrap();
        nfa.add_final(q2).unwrap();

        nfa.add_transition(q0, q1, TransitionType::Epsilon).unwrap();
        nfa.add_transition(q1, q2, TransitionType::Range(CharRange::new('b', 'b')))
            .unwrap();

        let sample = sample_nfa(&nfa, 10, false);
        assert_eq!(sample, SampleResult::Sampled(SmtString::from("b")));
    }

    #[test]
    fn test_sample_nfa_stops_at_max_depth() {
        let mut nfa = NFA::new();
        let q0 = nfa.new_state();
        let q1 = nfa.new_state();
        let q2 = nfa.new_state();

        nfa.set_initial(q0).unwrap();
        nfa.add_final(q2).unwrap();

        nfa.add_transition(q0, q1, TransitionType::Range(CharRange::new('a', 'z')))
            .unwrap();
        nfa.add_transition(q1, q2, TransitionType::Range(CharRange::new('a', 'z')))
            .unwrap();

        // The shortest accepted word has length 2, which exceeds the budget of 1.
        let sample = sample_nfa(&nfa, 1, false);
        assert_eq!(sample, SampleResult::MaxDepth);
    }

    #[test]
    fn test_sample_nfa_handles_not_range_transitions() {
        let mut nfa = NFA::new();
        let q0 = nfa.new_state();
        let q1 = nfa.new_state();

        nfa.set_initial(q0).unwrap();
        nfa.add_final(q1).unwrap();

        nfa.add_transition(q0, q1, TransitionType::NotRange(CharRange::new('x', 'z')))
            .unwrap();

        let sample = sample_nfa(&nfa, 10, false);
        assert!(sample.success());
        if let SampleResult::Sampled(word) = sample {
            assert!(
                !word.contains_char('x') && !word.contains_char('y') && !word.contains_char('z')
            );
        }
    }

    #[test]
    fn test_sample_nfa_multiple_paths() {
        let mut nfa = NFA::new();
        let q0 = nfa.new_state();
        let q1 = nfa.new_state();
        let q2 = nfa.new_state();
        let q3 = nfa.new_state();

        nfa.set_initial(q0).unwrap();
        nfa.add_final(q3).unwrap();

        nfa.add_transition(q0, q1, TransitionType::Range(CharRange::new('a', 'a')))
            .unwrap();
        nfa.add_transition(q1, q3, TransitionType::Range(CharRange::new('b', 'b')))
            .unwrap();
        nfa.add_transition(q0, q2, TransitionType::Range(CharRange::new('x', 'x')))
            .unwrap();
        nfa.add_transition(q2, q3, TransitionType::Range(CharRange::new('y', 'y')))
            .unwrap();

        let sample = sample_nfa(&nfa, 10, false);
        assert!(
            sample == SampleResult::Sampled(SmtString::from("ab"))
                || sample == SampleResult::Sampled(SmtString::from("xy"))
        );
    }

    #[test]
    fn test_sample_nfa_leaves_loops() {
        let mut nfa = NFA::new();
        let q0 = nfa.new_state();
        let q1 = nfa.new_state();

        nfa.set_initial(q0).unwrap();
        nfa.add_final(q1).unwrap();

        nfa.add_transition(q0, q0, TransitionType::Range(CharRange::singleton('a')))
            .unwrap();
        nfa.add_transition(q0, q1, TransitionType::Range(CharRange::singleton('b')))
            .unwrap();

        match sample_nfa(&nfa, 100, false) {
            SampleResult::Sampled(w) => {
                // Every accepted word has the shape a^n b.
                let l = w.len();
                let mut expected = SmtString::from("a").repeat(l - 1);
                expected.push('b');
                assert_eq!(w, expected);
            }
            other => panic!("expected a sample, got {other:?}"),
        }
    }
}