1use std::collections::HashMap;
4
5use ruv_neural_core::topology::{CognitiveState, TopologyMetrics};
6use serde::{Deserialize, Serialize};
7
8pub struct TransitionDecoder {
13 current_state: CognitiveState,
14 transition_patterns: HashMap<(CognitiveState, CognitiveState), TransitionPattern>,
15 history: Vec<TopologyMetrics>,
16 window_size: usize,
17}
18
/// Expected metric signature of a transition between two cognitive states.
///
/// Every field is the change expected between the oldest and the newest
/// sample of the decoder's window. A field that is (effectively) zero means
/// "expect no significant change" for `mincut_delta` and `modularity_delta`,
/// and "any duration" for `duration_s`.
#[derive(Debug, Clone, PartialEq)]
pub struct TransitionPattern {
    /// Expected change in global min-cut across the window.
    pub mincut_delta: f64,
    /// Expected change in modularity across the window.
    pub modularity_delta: f64,
    /// Expected time spanned by the window, in the same units as the metrics'
    /// `timestamp` (presumably seconds, per the name — confirm with producers).
    pub duration_s: f64,
}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct StateTransition {
33 pub from: CognitiveState,
35 pub to: CognitiveState,
37 pub confidence: f64,
39 pub timestamp: f64,
41}
42
43impl TransitionDecoder {
44 pub fn new(window_size: usize) -> Self {
49 let window_size = if window_size < 2 { 2 } else { window_size };
50 Self {
51 current_state: CognitiveState::Unknown,
52 transition_patterns: HashMap::new(),
53 history: Vec::new(),
54 window_size,
55 }
56 }
57
58 pub fn register_pattern(
60 &mut self,
61 from: CognitiveState,
62 to: CognitiveState,
63 pattern: TransitionPattern,
64 ) {
65 self.transition_patterns.insert((from, to), pattern);
66 }
67
68 pub fn current_state(&self) -> CognitiveState {
70 self.current_state
71 }
72
73 pub fn set_current_state(&mut self, state: CognitiveState) {
75 self.current_state = state;
76 }
77
78 pub fn update(&mut self, metrics: TopologyMetrics) -> Option<StateTransition> {
83 self.history.push(metrics);
84
85 if self.history.len() > self.window_size {
87 let excess = self.history.len() - self.window_size;
88 self.history.drain(..excess);
89 }
90
91 if self.history.len() < 2 {
93 return None;
94 }
95
96 let oldest = &self.history[0];
97 let newest = self.history.last().unwrap();
98
99 let observed_mincut_delta = newest.global_mincut - oldest.global_mincut;
100 let observed_modularity_delta = newest.modularity - oldest.modularity;
101 let observed_duration = newest.timestamp - oldest.timestamp;
102
103 let mut best_match: Option<(CognitiveState, f64)> = None;
105
106 for (&(from, to), pattern) in &self.transition_patterns {
107 if from != self.current_state {
109 continue;
110 }
111
112 let score = pattern_match_score(
113 observed_mincut_delta,
114 observed_modularity_delta,
115 observed_duration,
116 pattern,
117 );
118
119 if score > 0.5 {
120 if let Some((_, best_score)) = &best_match {
121 if score > *best_score {
122 best_match = Some((to, score));
123 }
124 } else {
125 best_match = Some((to, score));
126 }
127 }
128 }
129
130 if let Some((to_state, confidence)) = best_match {
131 let transition = StateTransition {
132 from: self.current_state,
133 to: to_state,
134 confidence: confidence.clamp(0.0, 1.0),
135 timestamp: newest.timestamp,
136 };
137 self.current_state = to_state;
138 Some(transition)
139 } else {
140 None
141 }
142 }
143
144 pub fn num_patterns(&self) -> usize {
146 self.transition_patterns.len()
147 }
148
149 pub fn history_len(&self) -> usize {
151 self.history.len()
152 }
153}
154
155fn pattern_match_score(
159 observed_mincut_delta: f64,
160 observed_modularity_delta: f64,
161 observed_duration: f64,
162 pattern: &TransitionPattern,
163) -> f64 {
164 let mincut_score = if pattern.mincut_delta.abs() < 1e-10 {
165 if observed_mincut_delta.abs() < 0.5 {
166 1.0
167 } else {
168 0.5
169 }
170 } else {
171 let ratio = observed_mincut_delta / pattern.mincut_delta;
172 gaussian_score(ratio, 1.0, 0.5)
173 };
174
175 let modularity_score = if pattern.modularity_delta.abs() < 1e-10 {
176 if observed_modularity_delta.abs() < 0.05 {
177 1.0
178 } else {
179 0.5
180 }
181 } else {
182 let ratio = observed_modularity_delta / pattern.modularity_delta;
183 gaussian_score(ratio, 1.0, 0.5)
184 };
185
186 let duration_score = if pattern.duration_s.abs() < 1e-10 {
187 1.0
188 } else {
189 let ratio = observed_duration / pattern.duration_s;
190 gaussian_score(ratio, 1.0, 0.5)
191 };
192
193 (mincut_score + modularity_score + duration_score) / 3.0
194}
195
/// Unnormalised Gaussian kernel `exp(-0.5 * ((value - center) / sigma)^2)`.
///
/// For a non-zero `sigma` this is `1.0` at `value == center` and decays
/// toward `0.0` as `value` moves away; `sigma` controls the width.
fn gaussian_score(value: f64, center: f64, sigma: f64) -> f64 {
    let z = (value - center) / sigma;
    f64::exp(-0.5 * z.powi(2))
}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a metrics sample in which only min-cut, modularity and timestamp vary.
    fn make_metrics(mincut: f64, modularity: f64, timestamp: f64) -> TopologyMetrics {
        TopologyMetrics {
            global_mincut: mincut,
            modularity,
            global_efficiency: 0.3,
            local_efficiency: 0.0,
            graph_entropy: 2.0,
            fiedler_value: 0.0,
            num_modules: 4,
            timestamp,
        }
    }

    /// Rest -> Focused signature shared by several tests.
    fn rest_to_focused_pattern() -> TransitionPattern {
        TransitionPattern {
            mincut_delta: 3.0,
            modularity_delta: 0.2,
            duration_s: 2.0,
        }
    }

    #[test]
    fn test_detect_state_transition() {
        let mut decoder = TransitionDecoder::new(5);
        decoder.set_current_state(CognitiveState::Rest);
        decoder.register_pattern(
            CognitiveState::Rest,
            CognitiveState::Focused,
            rest_to_focused_pattern(),
        );

        // Min-cut and modularity rise steadily toward the pattern's deltas.
        let updates = [
            make_metrics(5.0, 0.4, 0.0),
            make_metrics(6.0, 0.45, 0.5),
            make_metrics(7.0, 0.5, 1.0),
            make_metrics(8.0, 0.6, 2.0),
        ];

        let mut detected: Option<StateTransition> = None;
        for m in updates {
            if let Some(t) = decoder.update(m) {
                detected = Some(t);
            }
        }

        let transition = detected.expect("Expected a transition to be detected");
        assert_eq!(transition.from, CognitiveState::Rest);
        assert_eq!(transition.to, CognitiveState::Focused);
        assert!(transition.confidence > 0.0 && transition.confidence <= 1.0);
        // The window first matches well enough at the third sample (t = 1.0).
        assert_eq!(transition.timestamp, 1.0);
        assert_eq!(decoder.current_state(), CognitiveState::Focused);
    }

    #[test]
    fn test_no_transition_without_pattern() {
        let mut decoder = TransitionDecoder::new(3);
        decoder.set_current_state(CognitiveState::Rest);

        assert!(decoder.update(make_metrics(5.0, 0.4, 0.0)).is_none());
        assert!(decoder.update(make_metrics(8.0, 0.6, 2.0)).is_none());
    }

    #[test]
    fn test_window_trimming() {
        let mut decoder = TransitionDecoder::new(3);
        for i in 0..10 {
            decoder.update(make_metrics(5.0, 0.4, i as f64));
        }
        assert_eq!(decoder.history_len(), 3);
    }

    #[test]
    fn test_window_size_clamped_to_two() {
        let mut decoder = TransitionDecoder::new(0);
        for i in 0..5 {
            decoder.update(make_metrics(5.0, 0.4, i as f64));
        }
        assert_eq!(decoder.history_len(), 2);
    }

    #[test]
    fn test_single_sample_no_transition() {
        let mut decoder = TransitionDecoder::new(5);
        decoder.register_pattern(
            CognitiveState::Rest,
            CognitiveState::Focused,
            rest_to_focused_pattern(),
        );
        decoder.set_current_state(CognitiveState::Rest);
        assert!(decoder.update(make_metrics(5.0, 0.4, 0.0)).is_none());
    }

    #[test]
    fn test_pattern_from_other_state_is_ignored() {
        let mut decoder = TransitionDecoder::new(5);
        decoder.register_pattern(
            CognitiveState::Rest,
            CognitiveState::Focused,
            rest_to_focused_pattern(),
        );
        assert_eq!(decoder.current_state(), CognitiveState::Unknown);

        // These samples match the Rest -> Focused signature, but the decoder
        // is not in Rest, so the pattern must not fire.
        assert!(decoder.update(make_metrics(5.0, 0.4, 0.0)).is_none());
        assert!(decoder.update(make_metrics(8.0, 0.6, 2.0)).is_none());
        assert_eq!(decoder.current_state(), CognitiveState::Unknown);
    }

    #[test]
    fn test_register_pattern_replaces_existing() {
        let mut decoder = TransitionDecoder::new(5);
        assert_eq!(decoder.num_patterns(), 0);

        decoder.register_pattern(
            CognitiveState::Rest,
            CognitiveState::Focused,
            rest_to_focused_pattern(),
        );
        decoder.register_pattern(
            CognitiveState::Rest,
            CognitiveState::Focused,
            TransitionPattern {
                mincut_delta: -1.0,
                ..rest_to_focused_pattern()
            },
        );
        assert_eq!(decoder.num_patterns(), 1);

        decoder.register_pattern(
            CognitiveState::Focused,
            CognitiveState::Rest,
            rest_to_focused_pattern(),
        );
        assert_eq!(decoder.num_patterns(), 2);
    }
}
298}