1use error::Error;
2use rand::{seq::SliceRandom, thread_rng, Rng};
3use std::collections::HashMap;
4
5mod error;
6
/// Markov-chain transition table.
///
/// Each key is a run of consecutive words. The value lists every word observed
/// directly after that run; when built by `generate_rule_from_data`, repeats are
/// kept in order of appearance. All words are borrowed (lifetime `'a`) from the
/// source text rather than copied.
pub type MarkovChainRule<'a> = std::collections::HashMap<Vec<&'a str>, Vec<&'a str>>;
8
9pub fn generate_rule_from_data(content: &str, key_size: usize) -> Result<MarkovChainRule, Error> {
10 if key_size < 1 {
11 return Err(Error::InvalidKeySize);
12 }
13
14 let words: Vec<&str> = content.split_whitespace().collect();
15
16 let mut dict: MarkovChainRule = HashMap::new();
17
18 for slice in words.windows(key_size + 1) {
19 let (key, value) = slice.split_at(key_size);
20 let value = value[0];
21 match dict.get_mut(key) {
22 Some(e) => {
23 e.push(value);
24 }
25 None => {
26 dict.insert(key.to_vec(), vec![value]);
27 }
28 }
29 }
30
31 Ok(dict)
32}
33
34pub fn generate_text(rule: &MarkovChainRule, length: usize) -> String {
35 let mut rng = thread_rng();
36
37 let start = rule.keys().nth(rng.gen_range(0..rule.len())).unwrap();
38
39 let mut chain = start.clone();
40 let key_size = chain.len();
41
42 for _ in 0..length {
43 let key = &chain[chain.len() - key_size..];
44 let nexts = match rule.get(key) {
45 None => break,
46 Some(e) => e,
47 };
48 let next = nexts.choose(&mut rng).unwrap();
49 chain.push(next);
50 }
51
52 chain.join(" ")
53}