syntaxdot_encoders/categorical/
number.rs

1use std::cell::RefCell;
2use std::hash::Hash;
3
4use numberer::Numberer;
5use serde_derive::{Deserialize, Serialize};
6
7/// Number a categorical variable.
8#[allow(clippy::len_without_is_empty)]
9pub trait Number<V>
10where
11    V: Clone + Eq + Hash,
12{
13    /// Construct a numberer for categorical variables.
14    fn new(numberer: Numberer<V>) -> Self;
15
16    /// Get the number of possible values in the categorical variable.
17    ///
18    /// This includes reserved numerical representations that do
19    /// not correspond to values in the categorial variable.
20    fn len(&self) -> usize;
21
22    /// Get the number of a value from a categorical variable.
23    ///
24    /// Mutable implementations of this trait must add the value if it
25    /// is unknown and always return [`Option::Some`].
26    fn number(&self, value: V) -> Option<usize>;
27
28    /// Get the value corresponding of a number.
29    ///
30    /// Returns [`Option::None`] if the number is unknown *or* a
31    /// reserved number.
32    fn value(&self, number: usize) -> Option<V>;
33}
34
35/// An immutable categorical variable numberer.
36#[derive(Deserialize, Serialize)]
37pub struct ImmutableNumberer<V>(Numberer<V>)
38where
39    V: Clone + Eq + Hash;
40
41impl<V> Number<V> for ImmutableNumberer<V>
42where
43    V: Clone + Eq + Hash,
44{
45    fn new(numberer: Numberer<V>) -> Self {
46        ImmutableNumberer(numberer)
47    }
48
49    fn len(&self) -> usize {
50        self.0.len()
51    }
52
53    fn number(&self, value: V) -> Option<usize> {
54        self.0.number(&value)
55    }
56
57    fn value(&self, number: usize) -> Option<V> {
58        self.0.value(number).cloned()
59    }
60}
61
62/// A mutable categorical variable numberer using interior mutability.
63#[derive(Deserialize, Serialize)]
64pub struct MutableNumberer<V>(RefCell<Numberer<V>>)
65where
66    V: Clone + Eq + Hash;
67
68impl<V> Number<V> for MutableNumberer<V>
69where
70    V: Clone + Eq + Hash,
71{
72    fn new(numberer: Numberer<V>) -> Self {
73        MutableNumberer(RefCell::new(numberer))
74    }
75
76    fn len(&self) -> usize {
77        self.0.borrow().len()
78    }
79
80    fn number(&self, value: V) -> Option<usize> {
81        Some(self.0.borrow_mut().add(value))
82    }
83
84    fn value(&self, number: usize) -> Option<V> {
85        self.0.borrow().value(number).cloned()
86    }
87}