1#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5
6use crate::error::{Error, Result};
7
/// Configuration parameters.
///
/// Usually created through [`Config::builder`], whose `build` step fills in
/// defaults and runs [`Config::validate`]. The fields are public, so a value
/// constructed directly is not validated until `validate` is called.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[must_use]
pub struct Config {
    /// Number of clauses. `validate` requires it to be non-zero and even.
    pub n_clauses: usize,
    /// Number of input features. `validate` requires it to be non-zero.
    pub n_features: usize,
    /// Number of states. The builder defaults it to 100; `validate` does not
    /// check it. NOTE(review): presumably a per-automaton state count — confirm.
    pub n_states: i16,
    /// Specificity. `validate` rejects values of `1.0` or below; the builder
    /// defaults it to 3.9. [`Config::prob_strengthen`] and
    /// [`Config::prob_weaken`] are derived from it.
    pub s: f32
}
20
21impl Config {
22 #[inline]
26 #[must_use]
27 pub fn builder() -> ConfigBuilder {
28 ConfigBuilder::default()
29 }
30
31 pub fn validate(&self) -> Result<()> {
35 if self.n_clauses == 0 {
36 return Err(Error::MissingClauses);
37 }
38 if !self.n_clauses.is_multiple_of(2) {
39 return Err(Error::OddClauses);
40 }
41 if self.n_features == 0 {
42 return Err(Error::MissingFeatures);
43 }
44 if self.s <= 1.0 {
45 return Err(Error::InvalidSpecificity);
46 }
47 Ok(())
48 }
49
50 #[inline]
54 #[must_use]
55 pub fn prob_strengthen(&self) -> f32 {
56 (self.s - 1.0) / self.s
57 }
58
59 #[inline]
63 #[must_use]
64 pub fn prob_weaken(&self) -> f32 {
65 1.0 / self.s
66 }
67
68 #[inline]
75 #[must_use]
76 pub fn threshold_strengthen(&self) -> u32 {
77 prob_to_threshold(self.prob_strengthen())
78 }
79
80 #[inline]
84 #[must_use]
85 pub fn threshold_weaken(&self) -> u32 {
86 prob_to_threshold(self.prob_weaken())
87 }
88}
89
/// Maps a probability onto the full `u32` range: `0.0` gives `0` and `1.0`
/// gives `u32::MAX`.
///
/// The product is formed in `f64` because `f32` cannot represent `u32::MAX`
/// exactly. Its 24-bit significand rounds the value up to 2^32, which would
/// turn `0.5` into 2^31 instead of `u32::MAX / 2`.
///
/// Inputs outside `[0.0, 1.0]` saturate:
/// - anything above `1.0`, including `+inf`, yields `u32::MAX`;
/// - anything below `0.0` yields `0`;
/// - `NaN` yields `0`.
#[inline]
#[must_use]
pub fn prob_to_threshold(prob: f32) -> u32 {
    let full_range = f64::from(u32::MAX);
    // Clamping first makes the saturation explicit, so the final cast only
    // ever sees values in `[0.0, u32::MAX]`. A NaN survives `clamp`, and the
    // float-to-int `as` cast maps it to 0.
    let unit = f64::from(prob).clamp(0.0, 1.0);
    (unit * full_range) as u32
}
99
/// Step-by-step constructor for [`Config`], obtained from [`Config::builder`].
///
/// Every field starts as `None` (via `Default`) and is filled in by the
/// matching setter. [`ConfigBuilder::build`] substitutes defaults for the
/// optional fields and reports an error for missing required ones.
#[derive(Debug, Default)]
pub struct ConfigBuilder {
    /// Set by [`ConfigBuilder::clauses`]; required.
    n_clauses: Option<usize>,
    /// Set by [`ConfigBuilder::features`]; required.
    n_features: Option<usize>,
    /// Set by [`ConfigBuilder::states`]; optional.
    n_states: Option<i16>,
    /// Set by [`ConfigBuilder::specificity`]; optional.
    s: Option<f32>
}
110
111impl ConfigBuilder {
112 pub fn clauses(mut self, n: usize) -> Self {
116 self.n_clauses = Some(n);
117 self
118 }
119
120 pub fn features(mut self, n: usize) -> Self {
124 self.n_features = Some(n);
125 self
126 }
127
128 pub fn states(mut self, n: i16) -> Self {
132 self.n_states = Some(n);
133 self
134 }
135
136 pub fn specificity(mut self, s: f32) -> Self {
140 self.s = Some(s);
141 self
142 }
143
144 pub fn build(self) -> Result<Config> {
148 let config = Config {
149 n_clauses: self.n_clauses.ok_or(Error::MissingClauses)?,
150 n_features: self.n_features.ok_or(Error::MissingFeatures)?,
151 n_states: self.n_states.unwrap_or(100),
152 s: self.s.unwrap_or(3.9)
153 };
154 config.validate()?;
155 Ok(config)
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builder_with_defaults() {
        let config = Config::builder().clauses(20).features(4).build().unwrap();

        assert_eq!(config.n_clauses, 20);
        assert_eq!(config.n_features, 4);
        assert_eq!(config.n_states, 100);
        assert!((config.s - 3.9).abs() < 0.01);
    }

    #[test]
    fn builder_overrides_defaults() {
        let config = Config::builder()
            .clauses(2)
            .features(1)
            .states(50)
            .specificity(10.0)
            .build()
            .unwrap();

        assert_eq!(config.n_clauses, 2);
        assert_eq!(config.n_features, 1);
        assert_eq!(config.n_states, 50);
        assert!((config.s - 10.0).abs() < 0.001);
    }

    #[test]
    fn builder_requires_clauses_and_features() {
        assert_eq!(Config::builder().features(4).build(), Err(Error::MissingClauses));
        assert_eq!(Config::builder().clauses(20).build(), Err(Error::MissingFeatures));
        // Clauses are checked first when both required fields are missing.
        assert_eq!(Config::builder().build(), Err(Error::MissingClauses));
    }

    #[test]
    fn builder_rejects_zero_counts() {
        assert_eq!(
            Config::builder().clauses(0).features(4).build(),
            Err(Error::MissingClauses)
        );
        assert_eq!(
            Config::builder().clauses(20).features(0).build(),
            Err(Error::MissingFeatures)
        );
    }

    #[test]
    fn builder_rejects_odd_clauses() {
        let result = Config::builder().clauses(21).features(4).build();

        assert_eq!(result, Err(Error::OddClauses));
    }

    #[test]
    fn builder_rejects_specificity_not_above_one() {
        for s in [1.0, 0.5, 0.0, -3.0] {
            assert_eq!(
                Config::builder().clauses(20).features(4).specificity(s).build(),
                Err(Error::InvalidSpecificity)
            );
        }
    }

    #[test]
    fn prob_precomputed() {
        let config = Config::builder()
            .clauses(20)
            .features(4)
            .specificity(4.0)
            .build()
            .unwrap();

        assert!((config.prob_strengthen() - 0.75).abs() < 0.001);
        assert!((config.prob_weaken() - 0.25).abs() < 0.001);
    }

    #[test]
    fn integer_thresholds() {
        let config = Config::builder()
            .clauses(20)
            .features(4)
            .specificity(4.0)
            .build()
            .unwrap();

        // 0.75 * u32::MAX ≈ 3.22e9
        assert!(config.threshold_strengthen() > 3_000_000_000);
        assert!(config.threshold_strengthen() < 3_500_000_000);

        // 0.25 * u32::MAX ≈ 1.07e9
        assert!(config.threshold_weaken() > 1_000_000_000);
        assert!(config.threshold_weaken() < 1_200_000_000);
    }

    #[test]
    fn prob_to_threshold_boundaries() {
        assert_eq!(prob_to_threshold(0.0), 0);
        assert_eq!(prob_to_threshold(1.0), u32::MAX);
        assert_eq!(prob_to_threshold(0.5), u32::MAX / 2);
    }

    #[test]
    fn prob_to_threshold_saturates_out_of_range() {
        assert_eq!(prob_to_threshold(1.5), u32::MAX);
        assert_eq!(prob_to_threshold(f32::INFINITY), u32::MAX);
        assert_eq!(prob_to_threshold(-0.5), 0);
        assert_eq!(prob_to_threshold(f32::NEG_INFINITY), 0);
        assert_eq!(prob_to_threshold(f32::NAN), 0);
    }
}