snarky_parrot/
lib.rs

1use error::Error;
2use rand::{seq::SliceRandom, thread_rng, Rng};
3use std::collections::HashMap;
4
5mod error;
6
/// Markov-chain transition table.
///
/// Each key is a sequence of consecutive words. Its value is the list of words
/// seen right after that sequence, as filled in by `generate_rule_from_data`.
/// Words are borrowed `&str` slices tied to lifetime `'w`, so a rule cannot
/// outlive the text it borrows from.
pub type MarkovChainRule<'w> = HashMap<Vec<&'w str>, Vec<&'w str>>;
8
9pub fn generate_rule_from_data(content: &str, key_size: usize) -> Result<MarkovChainRule, Error> {
10    if key_size < 1 {
11        return Err(Error::InvalidKeySize);
12    }
13
14    let words: Vec<&str> = content.split_whitespace().collect();
15
16    let mut dict: MarkovChainRule = HashMap::new();
17
18    for slice in words.windows(key_size + 1) {
19        let (key, value) = slice.split_at(key_size);
20        let value = value[0];
21        match dict.get_mut(key) {
22            Some(e) => {
23                e.push(value);
24            }
25            None => {
26                dict.insert(key.to_vec(), vec![value]);
27            }
28        }
29    }
30
31    Ok(dict)
32}
33
34pub fn generate_text(rule: &MarkovChainRule, length: usize) -> String {
35    let mut rng = thread_rng();
36
37    let start = rule.keys().nth(rng.gen_range(0..rule.len())).unwrap();
38
39    let mut chain = start.clone();
40    let key_size = chain.len();
41
42    for _ in 0..length {
43        let key = &chain[chain.len() - key_size..];
44        let nexts = match rule.get(key) {
45            None => break,
46            Some(e) => e,
47        };
48        let next = nexts.choose(&mut rng).unwrap();
49        chain.push(next);
50    }
51
52    chain.join(" ")
53}