Skip to main content

ruv_neural_decoder/
transition_decoder.rs

1//! Transition decoder for detecting cognitive state changes from topology dynamics.
2
3use std::collections::HashMap;
4
5use ruv_neural_core::topology::{CognitiveState, TopologyMetrics};
6use serde::{Deserialize, Serialize};
7
/// Detect cognitive state transitions from topology change patterns.
///
/// Monitors a sliding window of topology metrics and compares observed
/// deltas against registered transition patterns to detect state changes.
/// Deltas are taken between the oldest and the newest snapshot currently
/// held in the window.
pub struct TransitionDecoder {
    /// State the decoder currently assumes; only patterns whose `from`
    /// state equals this value are scored.
    current_state: CognitiveState,
    /// Registered patterns, keyed by the `(from, to)` state pair. Registering
    /// the same pair again replaces the earlier pattern.
    transition_patterns: HashMap<(CognitiveState, CognitiveState), TransitionPattern>,
    /// Retained snapshots in arrival order (oldest first), trimmed to at most
    /// `window_size` entries after every push.
    history: Vec<TopologyMetrics>,
    /// Maximum number of snapshots kept in `history`; the constructor raises
    /// values below 2 to 2 so a delta can always be formed.
    window_size: usize,
}
18
/// A pattern describing the expected topology change during a state transition.
///
/// Each field is compared against the corresponding change observed across
/// the decoder's window (newest snapshot minus oldest snapshot). A field whose
/// absolute value is below `1e-10` is treated as "no change expected" rather
/// than being used as a divisor.
#[derive(Debug, Clone)]
pub struct TransitionPattern {
    /// Expected change in global minimum cut value.
    pub mincut_delta: f64,
    /// Expected change in modularity.
    pub modularity_delta: f64,
    /// Expected duration of the transition in seconds.
    ///
    /// A (near-)zero value disables the duration check: the duration
    /// component then always scores a perfect match.
    pub duration_s: f64,
}
29
/// A detected state transition.
///
/// Produced by `TransitionDecoder::update` when a registered pattern that
/// starts from the decoder's current state scores above the detection
/// threshold.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateTransition {
    /// State before the transition.
    pub from: CognitiveState,
    /// State after the transition.
    pub to: CognitiveState,
    /// Confidence of the detection in `[0, 1]`.
    ///
    /// This is the winning pattern's match score, clamped to that range.
    pub confidence: f64,
    /// Timestamp when the transition was detected.
    ///
    /// Copied from the newest snapshot in the window at detection time.
    pub timestamp: f64,
}
42
43impl TransitionDecoder {
44    /// Create a new transition decoder with a given sliding window size.
45    ///
46    /// The window size determines how many recent topology snapshots are
47    /// retained for computing deltas.
48    pub fn new(window_size: usize) -> Self {
49        let window_size = if window_size < 2 { 2 } else { window_size };
50        Self {
51            current_state: CognitiveState::Unknown,
52            transition_patterns: HashMap::new(),
53            history: Vec::new(),
54            window_size,
55        }
56    }
57
58    /// Register a transition pattern between two states.
59    pub fn register_pattern(
60        &mut self,
61        from: CognitiveState,
62        to: CognitiveState,
63        pattern: TransitionPattern,
64    ) {
65        self.transition_patterns.insert((from, to), pattern);
66    }
67
68    /// Get the current estimated cognitive state.
69    pub fn current_state(&self) -> CognitiveState {
70        self.current_state
71    }
72
73    /// Set the current state explicitly (e.g., from an external decoder).
74    pub fn set_current_state(&mut self, state: CognitiveState) {
75        self.current_state = state;
76    }
77
78    /// Push a new topology snapshot and check for state transitions.
79    ///
80    /// Returns `Some(StateTransition)` if a transition is detected,
81    /// `None` otherwise.
82    pub fn update(&mut self, metrics: TopologyMetrics) -> Option<StateTransition> {
83        self.history.push(metrics);
84
85        // Trim history to window size.
86        if self.history.len() > self.window_size {
87            let excess = self.history.len() - self.window_size;
88            self.history.drain(..excess);
89        }
90
91        // Need at least 2 samples to compute deltas.
92        if self.history.len() < 2 {
93            return None;
94        }
95
96        let oldest = &self.history[0];
97        let newest = self.history.last().unwrap();
98
99        let observed_mincut_delta = newest.global_mincut - oldest.global_mincut;
100        let observed_modularity_delta = newest.modularity - oldest.modularity;
101        let observed_duration = newest.timestamp - oldest.timestamp;
102
103        // Score each registered pattern.
104        let mut best_match: Option<(CognitiveState, f64)> = None;
105
106        for (&(from, to), pattern) in &self.transition_patterns {
107            // Only consider patterns starting from the current state.
108            if from != self.current_state {
109                continue;
110            }
111
112            let score = pattern_match_score(
113                observed_mincut_delta,
114                observed_modularity_delta,
115                observed_duration,
116                pattern,
117            );
118
119            if score > 0.5 {
120                if let Some((_, best_score)) = &best_match {
121                    if score > *best_score {
122                        best_match = Some((to, score));
123                    }
124                } else {
125                    best_match = Some((to, score));
126                }
127            }
128        }
129
130        if let Some((to_state, confidence)) = best_match {
131            let transition = StateTransition {
132                from: self.current_state,
133                to: to_state,
134                confidence: confidence.clamp(0.0, 1.0),
135                timestamp: newest.timestamp,
136            };
137            self.current_state = to_state;
138            Some(transition)
139        } else {
140            None
141        }
142    }
143
144    /// Number of registered transition patterns.
145    pub fn num_patterns(&self) -> usize {
146        self.transition_patterns.len()
147    }
148
149    /// Number of topology snapshots in the history buffer.
150    pub fn history_len(&self) -> usize {
151        self.history.len()
152    }
153}
154
155/// Compute a similarity score between observed deltas and a transition pattern.
156///
157/// Returns a value in `[0, 1]` where 1.0 means a perfect match.
158fn pattern_match_score(
159    observed_mincut_delta: f64,
160    observed_modularity_delta: f64,
161    observed_duration: f64,
162    pattern: &TransitionPattern,
163) -> f64 {
164    let mincut_score = if pattern.mincut_delta.abs() < 1e-10 {
165        if observed_mincut_delta.abs() < 0.5 {
166            1.0
167        } else {
168            0.5
169        }
170    } else {
171        let ratio = observed_mincut_delta / pattern.mincut_delta;
172        gaussian_score(ratio, 1.0, 0.5)
173    };
174
175    let modularity_score = if pattern.modularity_delta.abs() < 1e-10 {
176        if observed_modularity_delta.abs() < 0.05 {
177            1.0
178        } else {
179            0.5
180        }
181    } else {
182        let ratio = observed_modularity_delta / pattern.modularity_delta;
183        gaussian_score(ratio, 1.0, 0.5)
184    };
185
186    let duration_score = if pattern.duration_s.abs() < 1e-10 {
187        1.0
188    } else {
189        let ratio = observed_duration / pattern.duration_s;
190        gaussian_score(ratio, 1.0, 0.5)
191    };
192
193    (mincut_score + modularity_score + duration_score) / 3.0
194}
195
/// Unnormalised Gaussian bump: 1.0 at `center`, decaying with distance.
///
/// Evaluates `exp(-0.5 * ((value - center) / sigma)^2)`, so `sigma` sets the
/// width of the bump.
fn gaussian_score(value: f64, center: f64, sigma: f64) -> f64 {
    let z = (value - center) / sigma;
    f64::exp(-0.5 * z.powi(2))
}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a snapshot with the given minimum cut, modularity and timestamp;
    /// every other field gets a fixed filler value.
    fn make_metrics(mincut: f64, modularity: f64, timestamp: f64) -> TopologyMetrics {
        TopologyMetrics {
            global_mincut: mincut,
            modularity,
            timestamp,
            global_efficiency: 0.3,
            local_efficiency: 0.0,
            graph_entropy: 2.0,
            fiedler_value: 0.0,
            num_modules: 4,
        }
    }

    /// Rest -> Focused pattern: minimum cut and modularity both increase
    /// over roughly two seconds.
    fn rest_to_focused() -> TransitionPattern {
        TransitionPattern {
            mincut_delta: 3.0,
            modularity_delta: 0.2,
            duration_s: 2.0,
        }
    }

    #[test]
    fn test_detect_state_transition() {
        let mut decoder = TransitionDecoder::new(5);
        decoder.set_current_state(CognitiveState::Rest);
        decoder.register_pattern(
            CognitiveState::Rest,
            CognitiveState::Focused,
            rest_to_focused(),
        );

        // Feed metrics that progressively match the pattern. The transition
        // may fire on any update once the deltas are large enough, so every
        // snapshot is fed and the last reported transition is kept.
        let snapshots = vec![
            make_metrics(5.0, 0.4, 0.0),
            make_metrics(6.0, 0.45, 0.5),
            make_metrics(7.0, 0.5, 1.0),
            make_metrics(8.0, 0.6, 2.0),
        ];
        let detected = snapshots
            .into_iter()
            .filter_map(|snapshot| decoder.update(snapshot))
            .last();

        let transition = detected.expect("Expected a transition to be detected");
        assert_eq!(transition.from, CognitiveState::Rest);
        assert_eq!(transition.to, CognitiveState::Focused);
        assert!(transition.confidence > 0.0 && transition.confidence <= 1.0);
    }

    #[test]
    fn test_no_transition_without_pattern() {
        let mut decoder = TransitionDecoder::new(3);
        decoder.set_current_state(CognitiveState::Rest);

        // With nothing registered, even a large change must not be reported.
        assert!(decoder.update(make_metrics(5.0, 0.4, 0.0)).is_none());
        assert!(decoder.update(make_metrics(8.0, 0.6, 2.0)).is_none());
    }

    #[test]
    fn test_window_trimming() {
        let mut decoder = TransitionDecoder::new(3);
        for step in 0..10 {
            decoder.update(make_metrics(5.0, 0.4, step as f64));
        }
        assert_eq!(decoder.history_len(), 3);
    }

    #[test]
    fn test_single_sample_no_transition() {
        let mut decoder = TransitionDecoder::new(5);
        decoder.register_pattern(
            CognitiveState::Rest,
            CognitiveState::Focused,
            rest_to_focused(),
        );
        decoder.set_current_state(CognitiveState::Rest);

        // A single snapshot cannot produce a delta, so nothing may fire.
        let outcome = decoder.update(make_metrics(5.0, 0.4, 0.0));
        assert!(outcome.is_none());
    }
}