Skip to main content

ruv_neural_graph/
constructor.rs

//! Graph construction from connectivity matrices and multi-channel time series.
//!
//! The [`BrainGraphConstructor`] converts pairwise connectivity values into
//! [`BrainGraph`] instances, with optional thresholding to remove weak edges.
//! It also supports sliding-window construction from raw time series, using
//! pairwise Pearson correlation (computed in this module) as the connectivity
//! measure for each window.
7
8use ruv_neural_core::brain::Parcellation;
9use ruv_neural_core::error::{Result, RuvNeuralError};
10use ruv_neural_core::graph::{BrainEdge, BrainGraph, BrainGraphSequence, ConnectivityMetric};
11use ruv_neural_core::signal::{FrequencyBand, MultiChannelTimeSeries};
12use ruv_neural_core::traits::GraphConstructor;
13
14use crate::atlas::{AtlasType, load_atlas};
15
16/// Constructs brain connectivity graphs from matrices or time series data.
17pub struct BrainGraphConstructor {
18    parcellation: Parcellation,
19    metric: ConnectivityMetric,
20    band: FrequencyBand,
21    /// Edge weight threshold: edges below this value are dropped.
22    threshold: f64,
23    /// Sliding window duration in seconds.
24    window_duration_s: f64,
25    /// Sliding window step in seconds.
26    window_step_s: f64,
27}
28
29impl BrainGraphConstructor {
30    /// Create a new constructor with default window parameters.
31    pub fn new(atlas: AtlasType, metric: ConnectivityMetric, band: FrequencyBand) -> Self {
32        Self {
33            parcellation: load_atlas(atlas),
34            metric,
35            band,
36            threshold: 0.0,
37            window_duration_s: 1.0,
38            window_step_s: 0.5,
39        }
40    }
41
42    /// Set the edge weight threshold. Edges with weight below this are excluded.
43    pub fn with_threshold(mut self, threshold: f64) -> Self {
44        self.threshold = threshold;
45        self
46    }
47
48    /// Set the sliding window duration in seconds.
49    pub fn with_window_duration(mut self, duration_s: f64) -> Self {
50        self.window_duration_s = duration_s;
51        self
52    }
53
54    /// Set the sliding window step in seconds.
55    pub fn with_window_step(mut self, step_s: f64) -> Self {
56        self.window_step_s = step_s;
57        self
58    }
59
60    /// Construct a brain graph from a pre-computed connectivity matrix.
61    ///
62    /// The matrix should be `n x n` where `n` matches the number of atlas regions.
63    /// The matrix is treated as symmetric; only the upper triangle is read.
64    pub fn construct_from_matrix(
65        &self,
66        connectivity: &[Vec<f64>],
67        timestamp: f64,
68    ) -> BrainGraph {
69        let n = self.parcellation.num_regions();
70        let mut edges = Vec::new();
71
72        for i in 0..n.min(connectivity.len()) {
73            for j in (i + 1)..n.min(connectivity[i].len()) {
74                let weight = connectivity[i][j];
75                if weight.abs() > self.threshold {
76                    edges.push(BrainEdge {
77                        source: i,
78                        target: j,
79                        weight,
80                        metric: self.metric,
81                        frequency_band: self.band,
82                    });
83                }
84            }
85        }
86
87        BrainGraph {
88            num_nodes: n,
89            edges,
90            timestamp,
91            window_duration_s: self.window_duration_s,
92            atlas: self.parcellation.atlas,
93        }
94    }
95
96    /// Construct a sequence of brain graphs from multi-channel time series
97    /// using a sliding window approach.
98    ///
99    /// For each window, computes pairwise Pearson correlation as connectivity,
100    /// then builds a graph with thresholding applied.
101    pub fn construct_sequence(
102        &self,
103        data: &MultiChannelTimeSeries,
104    ) -> BrainGraphSequence {
105        let n_samples = data.num_samples;
106        let sr = data.sample_rate_hz;
107
108        let window_samples = (self.window_duration_s * sr) as usize;
109        let step_samples = (self.window_step_s * sr) as usize;
110
111        if window_samples == 0 || step_samples == 0 || n_samples < window_samples {
112            return BrainGraphSequence {
113                graphs: Vec::new(),
114                window_step_s: self.window_step_s,
115            };
116        }
117
118        let mut graphs = Vec::new();
119        let mut offset = 0;
120
121        while offset + window_samples <= n_samples {
122            let timestamp = data.timestamp_start + offset as f64 / sr;
123
124            // Extract windowed data for each channel
125            let windowed: Vec<&[f64]> = data
126                .data
127                .iter()
128                .map(|ch| &ch[offset..offset + window_samples])
129                .collect();
130
131            // Compute pairwise Pearson correlation matrix
132            let connectivity = compute_correlation_matrix(&windowed);
133
134            let graph = self.construct_from_matrix(&connectivity, timestamp);
135            graphs.push(graph);
136
137            offset += step_samples;
138        }
139
140        BrainGraphSequence {
141            graphs,
142            window_step_s: self.window_step_s,
143        }
144    }
145}
146
147impl GraphConstructor for BrainGraphConstructor {
148    fn construct(&self, signals: &MultiChannelTimeSeries) -> Result<BrainGraph> {
149        let n_channels = signals.num_channels;
150        let expected = self.parcellation.num_regions();
151        if n_channels != expected {
152            return Err(RuvNeuralError::DimensionMismatch {
153                expected,
154                got: n_channels,
155            });
156        }
157
158        let windowed: Vec<&[f64]> = signals.data.iter().map(|ch| ch.as_slice()).collect();
159        let connectivity = compute_correlation_matrix(&windowed);
160        Ok(self.construct_from_matrix(&connectivity, signals.timestamp_start))
161    }
162}
163
/// Compute the pairwise Pearson correlation matrix for a set of channels.
///
/// Returns an `n x n` symmetric matrix (`n = channels.len()`) with `1.0` on the
/// diagonal. Off-diagonal entries are clamped to `[-1, 1]` to absorb rounding
/// error.
///
/// All channels are evaluated over their common prefix (the length of the
/// shortest channel), so means, standard deviations and covariances are computed
/// over the same samples. Any pair involving a zero-variance (constant) channel
/// gets a correlation of `0.0`.
fn compute_correlation_matrix(channels: &[&[f64]]) -> Vec<Vec<f64>> {
    let n = channels.len();
    let mut matrix = vec![vec![0.0; n]; n];

    // Mixing full-length statistics with a truncated covariance can produce
    // |r| > 1, so every statistic uses the same common length.
    let len = channels.iter().map(|ch| ch.len()).min().unwrap_or(0);

    // Per-channel (shift, shifted mean, std). Shifting by the first sample makes a
    // constant channel produce exactly zero deviations. It is then detected as
    // zero-variance instead of yielding rounding-noise correlations of ±1.
    let stats: Vec<(f64, f64, f64)> = channels
        .iter()
        .map(|ch| {
            let ch = &ch[..len];
            match ch.first() {
                None => (0.0, 0.0, 0.0),
                Some(&shift) => {
                    let count = len as f64;
                    let mean = ch.iter().map(|&x| x - shift).sum::<f64>() / count;
                    let var = ch
                        .iter()
                        .map(|&x| (x - shift - mean).powi(2))
                        .sum::<f64>()
                        / count;
                    (shift, mean, var.sqrt())
                }
            }
        })
        .collect();

    for i in 0..n {
        matrix[i][i] = 1.0;
        let (shift_i, mean_i, std_i) = stats[i];

        for j in (i + 1)..n {
            let (shift_j, mean_j, std_j) = stats[j];

            // Zero variance (or an underflowing product) leaves r undefined;
            // keep the 0.0 the matrix was initialised with.
            let denom = std_i * std_j;
            if denom == 0.0 {
                continue;
            }

            let cov = channels[i][..len]
                .iter()
                .zip(&channels[j][..len])
                .map(|(&a, &b)| (a - shift_i - mean_i) * (b - shift_j - mean_j))
                .sum::<f64>()
                / len as f64;

            let r = (cov / denom).clamp(-1.0, 1.0);
            matrix[i][j] = r;
            matrix[j][i] = r;
        }
    }

    matrix
}
211
#[cfg(test)]
mod tests {
    use super::*;
    use ruv_neural_core::graph::ConnectivityMetric;
    use ruv_neural_core::signal::FrequencyBand;

    /// Region count of the Desikan-Killiany atlas used throughout these tests.
    const DK_REGIONS: usize = 68;

    fn make_constructor() -> BrainGraphConstructor {
        BrainGraphConstructor::new(
            AtlasType::DesikanKilliany,
            ConnectivityMetric::PhaseLockingValue,
            FrequencyBand::Alpha,
        )
    }

    /// `n_ch` phase-shifted sinusoids, `n_samples` long.
    fn sine_channels(n_ch: usize, n_samples: usize) -> Vec<Vec<f64>> {
        (0..n_ch)
            .map(|i| {
                (0..n_samples)
                    .map(|j| ((j as f64 + i as f64) * 0.1).sin())
                    .collect()
            })
            .collect()
    }

    #[test]
    fn identity_matrix_fully_disconnected() {
        let ctor = make_constructor().with_threshold(0.01);
        let n = DK_REGIONS;
        // Identity matrix: diagonal = 1, off-diagonal = 0
        let identity: Vec<Vec<f64>> = (0..n)
            .map(|i| {
                let mut row = vec![0.0; n];
                row[i] = 1.0;
                row
            })
            .collect();

        let graph = ctor.construct_from_matrix(&identity, 0.0);
        assert_eq!(graph.num_nodes, DK_REGIONS);
        assert_eq!(graph.edges.len(), 0, "Identity matrix should produce no edges");
    }

    #[test]
    fn ones_matrix_fully_connected() {
        let ctor = make_constructor().with_threshold(0.01);
        let n = DK_REGIONS;
        let ones: Vec<Vec<f64>> = vec![vec![1.0; n]; n];

        let graph = ctor.construct_from_matrix(&ones, 0.0);
        let expected_edges = n * (n - 1) / 2;
        assert_eq!(graph.edges.len(), expected_edges);
    }

    #[test]
    fn threshold_filters_weak_edges() {
        let ctor = make_constructor().with_threshold(0.5);
        let n = DK_REGIONS;
        let mut matrix = vec![vec![0.0; n]; n];
        // Set a few strong edges
        matrix[0][1] = 0.8;
        matrix[1][0] = 0.8;
        // Set a weak edge
        matrix[2][3] = 0.3;
        matrix[3][2] = 0.3;

        let graph = ctor.construct_from_matrix(&matrix, 0.0);
        assert_eq!(graph.edges.len(), 1, "Only edge above threshold should survive");
        assert_eq!(graph.edges[0].source, 0);
        assert_eq!(graph.edges[0].target, 1);
    }

    #[test]
    fn threshold_is_strict_and_uses_magnitude() {
        let ctor = make_constructor().with_threshold(0.5);
        let n = DK_REGIONS;
        let mut matrix = vec![vec![0.0; n]; n];
        // Exactly at the threshold: dropped (comparison is strict).
        matrix[0][1] = 0.5;
        // Strong negative weight: kept, sign preserved.
        matrix[4][5] = -0.8;

        let graph = ctor.construct_from_matrix(&matrix, 0.0);
        assert_eq!(graph.edges.len(), 1);
        assert_eq!(graph.edges[0].source, 4);
        assert_eq!(graph.edges[0].target, 5);
        assert_eq!(graph.edges[0].weight, -0.8);
    }

    #[test]
    fn smaller_matrix_than_atlas_does_not_panic() {
        let ctor = make_constructor().with_threshold(0.01);
        let ones: Vec<Vec<f64>> = vec![vec![1.0; 3]; 3];

        let graph = ctor.construct_from_matrix(&ones, 0.0);
        assert_eq!(graph.num_nodes, DK_REGIONS);
        assert_eq!(graph.edges.len(), 3, "Only the 3x3 upper triangle is available");
    }

    #[test]
    fn construct_sequence_produces_graphs() {
        let ctor = make_constructor()
            .with_window_duration(0.5)
            .with_window_step(0.25);

        // 68 channels, 256 samples at 256 Hz = 1 second of data
        let ts = MultiChannelTimeSeries::new(sine_channels(DK_REGIONS, 256), 256.0, 0.0).unwrap();
        let seq = ctor.construct_sequence(&ts);

        // 1.0s data, 0.5s window, 0.25s step => 3 windows: [0,0.5], [0.25,0.75], [0.5,1.0]
        assert_eq!(seq.len(), 3, "Expected 3 full windows, got {}", seq.len());
    }

    #[test]
    fn construct_rejects_channel_count_mismatch() {
        let ctor = make_constructor();
        let ts = MultiChannelTimeSeries::new(sine_channels(3, 256), 256.0, 0.0).unwrap();

        let result = ctor.construct(&ts);
        assert!(matches!(
            result,
            Err(RuvNeuralError::DimensionMismatch { expected: 68, got: 3 })
        ));
    }
}