rustkov/
stats.rs

1use std::{collections::HashSet, vec};
2
3use super::{brain_prelude::StateElement, prelude::Brain};
4
5/// This struct let you compute some statistics of a [`Brain`].
6///
7/// Beware that it is not very optimised, nor complete for now.
8///
9/// # Example
10///
11/// ```
12/// use rustkov::prelude::Brain;
13///
14/// let brain = Brain::from_file("path/to/brain.yml").unwrap()
15///                         .get();
16///
17/// let stats = brain.stats();
18///
19/// println!("{}", stats.get_total_states());
20///
21/// ```
22///
23/// [`Brain`]: crate::brain::Brain
24pub struct BrainStats<'a> {
25    brain: &'a Brain,
26}
27
28impl<'a> BrainStats<'a> {
29    pub(crate) fn new(brain: &'a Brain) -> Self {
30        Self { brain }
31    }
32
33    /// Returns the length of states that the brain have.
34    pub fn get_total_states(&self) -> usize {
35        self.brain.state_transitions.len()
36    }
37
38    /// Returns the number of transitions that the brain have.
39    ///
40    /// A single state has multiple transitions
41    pub fn get_total_transitions(&self) -> usize {
42        self.brain
43            .state_transitions
44            .iter()
45            .map(|(_, transition)| transition.prev.len() + transition.next.len())
46            .sum()
47    }
48
49    /// Returns the average of the last two metrics.
50    ///
51    /// It is useful to see if your chatbot will be able to
52    /// construct unique sentences
53    pub fn avg_transition_per_state(&self) -> f32 {
54        return self.get_total_transitions() as f32 / self.get_total_states() as f32;
55    }
56
57    /// Retruns the total number of single words
58    /// known to the brain.
59    pub fn get_total_words(&self) -> usize {
60        let mut words: Vec<&str> = vec![];
61
62        self.brain.state_transitions.iter().for_each(|(state, _)| {
63            state
64                .0
65                .iter()
66                .filter_map(|elem| {
67                    if let StateElement::Word(e) = elem {
68                        Some(e.as_str())
69                    } else {
70                        None
71                    }
72                })
73                .for_each(|ref word| {
74                    words.push(word);
75                });
76        });
77
78        let set: HashSet<_> = words.drain(..).collect();
79        set.len()
80    }
81}