rustkernel_graph/
centrality.rs

1//! Centrality measure kernels.
2//!
3//! This module provides GPU-accelerated centrality algorithms:
4//! - Degree centrality
5//! - Betweenness centrality (Brandes algorithm)
6//! - Closeness centrality (BFS-based)
7//! - Eigenvector centrality (power iteration)
8//! - PageRank (power iteration with teleport)
9//! - Katz centrality (attenuated paths)
10
11use crate::messages::{CentralityInput, CentralityOutput, CentralityParams};
12use crate::ring_messages::{
13    K2KBarrier, K2KBarrierRelease, K2KIterationSync, K2KIterationSyncResponse,
14    PageRankConvergeResponse, PageRankConvergeRing, PageRankIterateResponse, PageRankIterateRing,
15    PageRankQueryResponse, PageRankQueryRing, from_fixed_point, to_fixed_point,
16};
17use crate::types::{CentralityResult, CsrGraph, NodeScore};
18use async_trait::async_trait;
19use ringkernel_core::RingContext;
20use rustkernel_core::{
21    domain::Domain,
22    error::Result,
23    k2k::IterativeState,
24    kernel::KernelMetadata,
25    traits::{BatchKernel, GpuKernel, RingKernelHandler},
26};
27use std::collections::VecDeque;
28use std::time::Instant;
29
30// ============================================================================
31// PageRank Kernel
32// ============================================================================
33
/// PageRank kernel state.
///
/// Holds the double-buffered score vectors used by power iteration together
/// with the loaded graph and iteration bookkeeping.
#[derive(Debug, Clone, Default)]
pub struct PageRankState {
    /// Current scores, indexed by node id.
    pub scores: Vec<f64>,
    /// Scores from the previous iteration. Swapped with `scores` at the start
    /// of every iteration step and used for the convergence check.
    pub prev_scores: Vec<f64>,
    /// Graph in CSR format; `None` until the kernel has been initialized.
    pub graph: Option<CsrGraph>,
    /// Damping factor: weight of the link-following term; `1 - damping` is
    /// spread uniformly as the teleport term.
    pub damping: f32,
    /// Number of iteration steps performed so far.
    pub iteration: u32,
    /// Whether a convergence driver has reported convergence for this state.
    pub converged: bool,
}
50
/// PageRank centrality kernel.
///
/// Calculates PageRank centrality using power iteration with teleportation.
/// This is a Ring kernel for low-latency queries after graph is loaded.
///
/// The iteration state sits behind an `RwLock` so that the Ring handlers,
/// which only receive `&self`, can both read and advance it.
#[derive(Debug)]
pub struct PageRank {
    /// Kernel registration metadata (id, domain, performance hints).
    metadata: KernelMetadata,
    /// Internal state for Ring mode operations.
    state: std::sync::RwLock<PageRankState>,
}
61
62impl Clone for PageRank {
63    fn clone(&self) -> Self {
64        Self {
65            metadata: self.metadata.clone(),
66            state: std::sync::RwLock::new(self.state.read().unwrap().clone()),
67        }
68    }
69}
70
71impl PageRank {
72    /// Create a new PageRank kernel.
73    #[must_use]
74    pub fn new() -> Self {
75        Self {
76            metadata: KernelMetadata::ring("graph/pagerank", Domain::GraphAnalytics)
77                .with_description("PageRank centrality via power iteration")
78                .with_throughput(100_000)
79                .with_latency_us(1.0)
80                .with_gpu_native(true),
81            state: std::sync::RwLock::new(PageRankState::default()),
82        }
83    }
84
85    /// Initialize the kernel with a graph for Ring mode operations.
86    pub fn initialize(&self, graph: CsrGraph, damping: f32) {
87        let mut state = self.state.write().unwrap();
88        *state = Self::initialize_state(graph, damping);
89    }
90
91    /// Query the score for a specific node.
92    pub fn query_score(&self, node_id: u64) -> Option<f64> {
93        let state = self.state.read().unwrap();
94        state.scores.get(node_id as usize).copied()
95    }
96
97    /// Get current iteration count.
98    pub fn current_iteration(&self) -> u32 {
99        self.state.read().unwrap().iteration
100    }
101
102    /// Check if converged.
103    pub fn is_converged(&self) -> bool {
104        self.state.read().unwrap().converged
105    }
106
107    /// Perform one iteration step using internal state.
108    pub fn iterate(&self) -> f64 {
109        let mut state = self.state.write().unwrap();
110        Self::iterate_step(&mut state)
111    }
112
113    /// Perform one iteration of PageRank on the given state.
114    pub fn iterate_step(state: &mut PageRankState) -> f64 {
115        let Some(ref graph) = state.graph else {
116            return 0.0;
117        };
118
119        let n = graph.num_nodes;
120        if n == 0 {
121            return 0.0;
122        }
123
124        let d = state.damping as f64;
125        let teleport = (1.0 - d) / n as f64;
126
127        // Swap buffers
128        std::mem::swap(&mut state.scores, &mut state.prev_scores);
129
130        // Calculate new scores
131        let mut max_diff = 0.0f64;
132
133        for i in 0..n {
134            let mut rank_sum = 0.0f64;
135
136            // Sum contributions from incoming edges
137            for &neighbor in graph.neighbors(i as u64) {
138                let out_degree = graph.out_degree(neighbor) as f64;
139                if out_degree > 0.0 {
140                    rank_sum += state.prev_scores[neighbor as usize] / out_degree;
141                }
142            }
143
144            let new_score = teleport + d * rank_sum;
145            state.scores[i] = new_score;
146
147            let diff = (new_score - state.prev_scores[i]).abs();
148            if diff > max_diff {
149                max_diff = diff;
150            }
151        }
152
153        state.iteration += 1;
154        max_diff
155    }
156
157    /// Initialize state for a graph.
158    pub fn initialize_state(graph: CsrGraph, damping: f32) -> PageRankState {
159        let n = graph.num_nodes;
160        PageRankState {
161            scores: vec![1.0 / n as f64; n],
162            prev_scores: vec![0.0; n],
163            graph: Some(graph),
164            damping,
165            iteration: 0,
166            converged: false,
167        }
168    }
169
170    /// Run PageRank to convergence.
171    pub fn run_to_convergence(
172        graph: CsrGraph,
173        damping: f32,
174        max_iterations: u32,
175        threshold: f64,
176    ) -> Result<CentralityResult> {
177        let mut state = Self::initialize_state(graph, damping);
178
179        for _ in 0..max_iterations {
180            let diff = Self::iterate_step(&mut state);
181            if diff < threshold {
182                state.converged = true;
183                break;
184            }
185        }
186
187        Ok(CentralityResult {
188            scores: state
189                .scores
190                .iter()
191                .enumerate()
192                .map(|(i, &score)| NodeScore {
193                    node_id: i as u64,
194                    score,
195                })
196                .collect(),
197            iterations: Some(state.iteration),
198            converged: state.converged,
199        })
200    }
201}
202
203impl Default for PageRank {
204    fn default() -> Self {
205        Self::new()
206    }
207}
208
209impl GpuKernel for PageRank {
210    fn metadata(&self) -> &KernelMetadata {
211        &self.metadata
212    }
213}
214
215// ============================================================================
216// PageRank RingKernelHandler Implementations
217// ============================================================================
218
219/// RingKernelHandler for PageRank queries.
220///
221/// Enables low-latency score queries for individual nodes in Ring mode.
222#[async_trait]
223impl RingKernelHandler<PageRankQueryRing, PageRankQueryResponse> for PageRank {
224    async fn handle(
225        &self,
226        _ctx: &mut RingContext,
227        msg: PageRankQueryRing,
228    ) -> Result<PageRankQueryResponse> {
229        let state = self.state.read().unwrap();
230        let score = state
231            .scores
232            .get(msg.node_id as usize)
233            .copied()
234            .unwrap_or(0.0);
235
236        Ok(PageRankQueryResponse {
237            request_id: msg.id.0,
238            node_id: msg.node_id,
239            score_fp: to_fixed_point(score),
240            iteration: state.iteration,
241            converged: state.converged,
242        })
243    }
244}
245
246/// RingKernelHandler for PageRank single iteration.
247///
248/// Performs one power iteration step in Ring mode.
249#[async_trait]
250impl RingKernelHandler<PageRankIterateRing, PageRankIterateResponse> for PageRank {
251    async fn handle(
252        &self,
253        _ctx: &mut RingContext,
254        msg: PageRankIterateRing,
255    ) -> Result<PageRankIterateResponse> {
256        // Perform one iteration on internal state
257        let max_delta = self.iterate();
258
259        // Check convergence using default threshold
260        let state = self.state.read().unwrap();
261        let converged = max_delta < 1e-6;
262
263        Ok(PageRankIterateResponse {
264            request_id: msg.id.0,
265            iteration: state.iteration,
266            max_delta_fp: to_fixed_point(max_delta),
267            converged,
268        })
269    }
270}
271
272/// RingKernelHandler for PageRank convergence.
273///
274/// Runs PageRank to convergence using K2K coordination for iterative state.
275#[async_trait]
276impl RingKernelHandler<PageRankConvergeRing, PageRankConvergeResponse> for PageRank {
277    async fn handle(
278        &self,
279        _ctx: &mut RingContext,
280        msg: PageRankConvergeRing,
281    ) -> Result<PageRankConvergeResponse> {
282        let threshold = from_fixed_point(msg.threshold_fp);
283        let max_iterations = msg.max_iterations as u64;
284
285        // Use K2K IterativeState for convergence tracking
286        let mut iterative_state = IterativeState::new(threshold, max_iterations);
287
288        // Run actual iterations on internal state
289        while iterative_state.should_continue() {
290            let max_delta = self.iterate();
291            iterative_state.update(max_delta);
292        }
293
294        // Update convergence status in internal state
295        {
296            let mut state = self.state.write().unwrap();
297            state.converged = iterative_state.summary().converged;
298        }
299
300        let summary = iterative_state.summary();
301
302        Ok(PageRankConvergeResponse {
303            request_id: msg.id.0,
304            iterations: summary.iterations as u32,
305            final_delta_fp: to_fixed_point(summary.final_delta),
306            converged: summary.converged,
307        })
308    }
309}
310
311/// RingKernelHandler for K2K iteration synchronization.
312///
313/// Used in distributed PageRank to synchronize iterations across partitions.
314/// In a single-instance setting, this validates the worker's iteration state
315/// and returns convergence status based on the reported delta.
316#[async_trait]
317impl RingKernelHandler<K2KIterationSync, K2KIterationSyncResponse> for PageRank {
318    async fn handle(
319        &self,
320        _ctx: &mut RingContext,
321        msg: K2KIterationSync,
322    ) -> Result<K2KIterationSyncResponse> {
323        let state = self.state.read().unwrap();
324
325        // For single-instance, verify iteration matches internal state
326        // In distributed setting, would aggregate deltas from all workers
327        let current_iteration = state.iteration as u64;
328        let all_synced = msg.iteration <= current_iteration;
329
330        // Use reported local delta as global delta (single worker case)
331        // In distributed setting, would compute max across all workers
332        let local_delta = from_fixed_point(msg.local_delta_fp);
333        let global_converged = local_delta < 1e-6 || state.converged;
334
335        Ok(K2KIterationSyncResponse {
336            request_id: msg.id.0,
337            iteration: msg.iteration,
338            all_synced,
339            global_delta_fp: msg.local_delta_fp,
340            global_converged,
341        })
342    }
343}
344
345/// RingKernelHandler for K2K barrier synchronization.
346///
347/// Implements barrier synchronization for distributed PageRank iterations.
348#[async_trait]
349impl RingKernelHandler<K2KBarrier, K2KBarrierRelease> for PageRank {
350    async fn handle(&self, _ctx: &mut RingContext, msg: K2KBarrier) -> Result<K2KBarrierRelease> {
351        // In a distributed setting, this would:
352        // 1. Record this worker as ready
353        // 2. Check if all workers are ready
354        // 3. Release barrier when all ready
355        let all_ready = msg.ready_count >= msg.total_workers;
356
357        Ok(K2KBarrierRelease {
358            barrier_id: msg.barrier_id,
359            all_ready,
360            next_iteration: msg.barrier_id + 1,
361        })
362    }
363}
364
365// ============================================================================
366// Degree Centrality Kernel
367// ============================================================================
368
/// Degree centrality kernel.
///
/// Simple O(1) lookup of node degrees after graph is loaded. Scores are
/// out-degrees normalized by `n - 1`.
#[derive(Debug, Clone)]
pub struct DegreeCentrality {
    /// Kernel registration metadata.
    metadata: KernelMetadata,
}
376
377impl DegreeCentrality {
378    /// Create a new degree centrality kernel.
379    #[must_use]
380    pub fn new() -> Self {
381        Self {
382            metadata: KernelMetadata::ring("graph/degree-centrality", Domain::GraphAnalytics)
383                .with_description("Degree centrality (O(1) lookup)")
384                .with_throughput(1_000_000)
385                .with_latency_us(0.1),
386        }
387    }
388
389    /// Calculate degree centrality for all nodes.
390    ///
391    /// Returns normalized degree centrality (degree / (n-1)).
392    pub fn compute(graph: &CsrGraph) -> CentralityResult {
393        let n = graph.num_nodes;
394        let normalizer = if n > 1 { (n - 1) as f64 } else { 1.0 };
395
396        let scores: Vec<NodeScore> = (0..n)
397            .map(|i| NodeScore {
398                node_id: i as u64,
399                score: graph.out_degree(i as u64) as f64 / normalizer,
400            })
401            .collect();
402
403        CentralityResult {
404            scores,
405            iterations: None,
406            converged: true,
407        }
408    }
409}
410
411impl Default for DegreeCentrality {
412    fn default() -> Self {
413        Self::new()
414    }
415}
416
417impl GpuKernel for DegreeCentrality {
418    fn metadata(&self) -> &KernelMetadata {
419        &self.metadata
420    }
421}
422
423// ============================================================================
424// Betweenness Centrality Kernel (Brandes Algorithm)
425// ============================================================================
426
/// Betweenness centrality kernel.
///
/// Uses Brandes algorithm for efficient computation in O(VE) time on
/// unweighted graphs (one BFS plus one backward accumulation per source).
#[derive(Debug, Clone)]
pub struct BetweennessCentrality {
    /// Kernel registration metadata.
    metadata: KernelMetadata,
}
434
435impl BetweennessCentrality {
436    /// Create a new betweenness centrality kernel.
437    #[must_use]
438    pub fn new() -> Self {
439        Self {
440            metadata: KernelMetadata::batch("graph/betweenness-centrality", Domain::GraphAnalytics)
441                .with_description("Betweenness centrality (Brandes algorithm)")
442                .with_throughput(10_000)
443                .with_latency_us(100.0),
444        }
445    }
446
447    /// Compute betweenness centrality using Brandes algorithm.
448    ///
449    /// The algorithm runs BFS from each vertex and accumulates
450    /// dependency scores in a single backward pass.
451    pub fn compute(graph: &CsrGraph, normalized: bool) -> CentralityResult {
452        let n = graph.num_nodes;
453        let mut centrality = vec![0.0f64; n];
454
455        // Run Brandes algorithm from each source
456        for s in 0..n {
457            // BFS structures
458            let mut stack: Vec<usize> = Vec::with_capacity(n);
459            let mut predecessors: Vec<Vec<usize>> = vec![Vec::new(); n];
460            let mut sigma = vec![0.0f64; n]; // Number of shortest paths
461            let mut dist = vec![-1i64; n]; // Distance from source
462
463            sigma[s] = 1.0;
464            dist[s] = 0;
465
466            let mut queue = VecDeque::new();
467            queue.push_back(s);
468
469            // Forward BFS
470            while let Some(v) = queue.pop_front() {
471                stack.push(v);
472
473                for &w in graph.neighbors(v as u64) {
474                    let w = w as usize;
475
476                    // First time visiting w?
477                    if dist[w] < 0 {
478                        dist[w] = dist[v] + 1;
479                        queue.push_back(w);
480                    }
481
482                    // Is this a shortest path to w via v?
483                    if dist[w] == dist[v] + 1 {
484                        sigma[w] += sigma[v];
485                        predecessors[w].push(v);
486                    }
487                }
488            }
489
490            // Backward pass - accumulate dependencies
491            let mut delta = vec![0.0f64; n];
492
493            while let Some(w) = stack.pop() {
494                for &v in &predecessors[w] {
495                    let contribution = (sigma[v] / sigma[w]) * (1.0 + delta[w]);
496                    delta[v] += contribution;
497                }
498
499                if w != s {
500                    centrality[w] += delta[w];
501                }
502            }
503        }
504
505        // Normalize if requested
506        if normalized && n > 2 {
507            let scale = 1.0 / ((n - 1) * (n - 2)) as f64;
508            for c in &mut centrality {
509                *c *= scale;
510            }
511        }
512
513        CentralityResult {
514            scores: centrality
515                .into_iter()
516                .enumerate()
517                .map(|(i, score)| NodeScore {
518                    node_id: i as u64,
519                    score,
520                })
521                .collect(),
522            iterations: None,
523            converged: true,
524        }
525    }
526}
527
528impl Default for BetweennessCentrality {
529    fn default() -> Self {
530        Self::new()
531    }
532}
533
534impl GpuKernel for BetweennessCentrality {
535    fn metadata(&self) -> &KernelMetadata {
536        &self.metadata
537    }
538}
539
540// ============================================================================
541// Closeness Centrality Kernel
542// ============================================================================
543
/// Closeness centrality kernel.
///
/// BFS-based closeness centrality calculation. The classic score of a node is
/// `r / sum(shortest_path_distances)`, where `r` is the number of nodes
/// reachable from it (equal to `n - 1` on a strongly connected graph). The
/// harmonic variant is `sum(1 / d) / (n - 1)` over reachable nodes.
#[derive(Debug, Clone)]
pub struct ClosenessCentrality {
    /// Kernel registration metadata.
    metadata: KernelMetadata,
}
552
553impl ClosenessCentrality {
554    /// Create a new closeness centrality kernel.
555    #[must_use]
556    pub fn new() -> Self {
557        Self {
558            metadata: KernelMetadata::batch("graph/closeness-centrality", Domain::GraphAnalytics)
559                .with_description("Closeness centrality (BFS-based)")
560                .with_throughput(10_000)
561                .with_latency_us(100.0),
562        }
563    }
564
565    /// Compute closeness centrality using BFS from each node.
566    ///
567    /// For disconnected graphs, uses harmonic mean variant.
568    #[allow(clippy::needless_range_loop)]
569    pub fn compute(graph: &CsrGraph, harmonic: bool) -> CentralityResult {
570        let n = graph.num_nodes;
571        let mut centrality = vec![0.0f64; n];
572
573        for source in 0..n {
574            let distances = Self::bfs_distances(graph, source);
575
576            if harmonic {
577                // Harmonic centrality: sum(1/d) for all reachable nodes
578                let sum: f64 = distances
579                    .iter()
580                    .enumerate()
581                    .filter(|(i, d)| *i != source && **d > 0)
582                    .map(|(_, d)| 1.0 / *d as f64)
583                    .sum();
584                centrality[source] = sum / (n - 1) as f64;
585            } else {
586                // Classic closeness: (n-1) / sum(distances)
587                let sum: i64 = distances.iter().sum();
588                let reachable: usize = distances.iter().filter(|&&d| d > 0).count();
589
590                if sum > 0 && reachable > 0 {
591                    centrality[source] = reachable as f64 / sum as f64;
592                }
593            }
594        }
595
596        CentralityResult {
597            scores: centrality
598                .into_iter()
599                .enumerate()
600                .map(|(i, score)| NodeScore {
601                    node_id: i as u64,
602                    score,
603                })
604                .collect(),
605            iterations: None,
606            converged: true,
607        }
608    }
609
610    /// BFS to compute distances from source to all other nodes.
611    fn bfs_distances(graph: &CsrGraph, source: usize) -> Vec<i64> {
612        let n = graph.num_nodes;
613        let mut distances = vec![0i64; n];
614        let mut visited = vec![false; n];
615
616        let mut queue = VecDeque::new();
617        queue.push_back(source);
618        visited[source] = true;
619
620        while let Some(v) = queue.pop_front() {
621            for &w in graph.neighbors(v as u64) {
622                let w = w as usize;
623                if !visited[w] {
624                    visited[w] = true;
625                    distances[w] = distances[v] + 1;
626                    queue.push_back(w);
627                }
628            }
629        }
630
631        distances
632    }
633}
634
635impl Default for ClosenessCentrality {
636    fn default() -> Self {
637        Self::new()
638    }
639}
640
641impl GpuKernel for ClosenessCentrality {
642    fn metadata(&self) -> &KernelMetadata {
643        &self.metadata
644    }
645}
646
647// ============================================================================
648// Eigenvector Centrality Kernel
649// ============================================================================
650
/// Eigenvector centrality kernel.
///
/// Power iteration method for eigenvector centrality.
/// A node's score is proportional to the sum of its neighbors' scores
/// (neighbors as returned by `graph.neighbors`). The score vector is
/// normalized to unit L2 norm.
#[derive(Debug, Clone)]
pub struct EigenvectorCentrality {
    /// Kernel registration metadata.
    metadata: KernelMetadata,
}
659
660impl EigenvectorCentrality {
661    /// Create a new eigenvector centrality kernel.
662    #[must_use]
663    pub fn new() -> Self {
664        Self {
665            metadata: KernelMetadata::batch("graph/eigenvector-centrality", Domain::GraphAnalytics)
666                .with_description("Eigenvector centrality (power iteration)")
667                .with_throughput(50_000)
668                .with_latency_us(10.0),
669        }
670    }
671
672    /// Compute eigenvector centrality using power iteration.
673    #[allow(clippy::needless_range_loop)]
674    pub fn compute(graph: &CsrGraph, max_iterations: u32, tolerance: f64) -> CentralityResult {
675        let n = graph.num_nodes;
676        if n == 0 {
677            return CentralityResult {
678                scores: Vec::new(),
679                iterations: Some(0),
680                converged: true,
681            };
682        }
683
684        // Initialize with uniform scores
685        let mut scores = vec![1.0 / (n as f64).sqrt(); n];
686        let mut new_scores = vec![0.0f64; n];
687        let mut converged = false;
688        let mut iterations = 0u32;
689
690        for iter in 0..max_iterations {
691            iterations = iter + 1;
692
693            // Compute new scores: x_i = sum(A_ij * x_j)
694            for i in 0..n {
695                let mut sum = 0.0f64;
696                for &j in graph.neighbors(i as u64) {
697                    sum += scores[j as usize];
698                }
699                new_scores[i] = sum;
700            }
701
702            // Normalize
703            let norm: f64 = new_scores.iter().map(|x| x * x).sum::<f64>().sqrt();
704            if norm > 0.0 {
705                for x in &mut new_scores {
706                    *x /= norm;
707                }
708            }
709
710            // Check convergence
711            let diff: f64 = scores
712                .iter()
713                .zip(new_scores.iter())
714                .map(|(a, b)| (a - b).abs())
715                .fold(0.0f64, |acc, x| acc.max(x));
716
717            std::mem::swap(&mut scores, &mut new_scores);
718
719            if diff < tolerance {
720                converged = true;
721                break;
722            }
723        }
724
725        CentralityResult {
726            scores: scores
727                .into_iter()
728                .enumerate()
729                .map(|(i, score)| NodeScore {
730                    node_id: i as u64,
731                    score,
732                })
733                .collect(),
734            iterations: Some(iterations),
735            converged,
736        }
737    }
738}
739
740impl Default for EigenvectorCentrality {
741    fn default() -> Self {
742        Self::new()
743    }
744}
745
746impl GpuKernel for EigenvectorCentrality {
747    fn metadata(&self) -> &KernelMetadata {
748        &self.metadata
749    }
750}
751
752// ============================================================================
753// Katz Centrality Kernel
754// ============================================================================
755
/// Katz centrality kernel.
///
/// Measures influence through attenuated walks. Computed as the fixed point
/// of `x_i = alpha * sum_{j in neighbors(i)} x_j + beta`, i.e.
/// `Katz(i) = beta * sum_{k >= 0} alpha^k * (number of length-k walks from i
/// along `graph.neighbors` edges)`. Scores are then rescaled so the largest
/// one is 1.
#[derive(Debug, Clone)]
pub struct KatzCentrality {
    /// Kernel registration metadata.
    metadata: KernelMetadata,
}
764
765impl KatzCentrality {
766    /// Create a new Katz centrality kernel.
767    #[must_use]
768    pub fn new() -> Self {
769        Self {
770            metadata: KernelMetadata::batch("graph/katz-centrality", Domain::GraphAnalytics)
771                .with_description("Katz centrality (attenuated paths)")
772                .with_throughput(50_000)
773                .with_latency_us(10.0),
774        }
775    }
776
777    /// Compute Katz centrality.
778    ///
779    /// # Arguments
780    /// * `graph` - The input graph
781    /// * `alpha` - Attenuation factor (should be < 1/lambda_max)
782    /// * `beta` - Base score for each node (default 1.0)
783    /// * `max_iterations` - Maximum iterations for power iteration
784    /// * `tolerance` - Convergence threshold
785    #[allow(clippy::needless_range_loop)]
786    pub fn compute(
787        graph: &CsrGraph,
788        alpha: f64,
789        beta: f64,
790        max_iterations: u32,
791        tolerance: f64,
792    ) -> CentralityResult {
793        let n = graph.num_nodes;
794        if n == 0 {
795            return CentralityResult {
796                scores: Vec::new(),
797                iterations: Some(0),
798                converged: true,
799            };
800        }
801
802        // Initialize scores
803        let mut scores = vec![0.0f64; n];
804        let mut new_scores = vec![0.0f64; n];
805        let mut converged = false;
806        let mut iterations = 0u32;
807
808        // Power iteration: x = alpha * A * x + beta
809        for iter in 0..max_iterations {
810            iterations = iter + 1;
811
812            for i in 0..n {
813                let mut sum = 0.0f64;
814                for &j in graph.neighbors(i as u64) {
815                    sum += scores[j as usize];
816                }
817                new_scores[i] = alpha * sum + beta;
818            }
819
820            // Check convergence
821            let diff: f64 = scores
822                .iter()
823                .zip(new_scores.iter())
824                .map(|(a, b)| (a - b).abs())
825                .fold(0.0f64, |acc, x| acc.max(x));
826
827            std::mem::swap(&mut scores, &mut new_scores);
828
829            if diff < tolerance {
830                converged = true;
831                break;
832            }
833        }
834
835        // Normalize by maximum score
836        let max_score = scores.iter().cloned().fold(0.0f64, f64::max);
837        if max_score > 0.0 {
838            for s in &mut scores {
839                *s /= max_score;
840            }
841        }
842
843        CentralityResult {
844            scores: scores
845                .into_iter()
846                .enumerate()
847                .map(|(i, score)| NodeScore {
848                    node_id: i as u64,
849                    score,
850                })
851                .collect(),
852            iterations: Some(iterations),
853            converged,
854        }
855    }
856}
857
858impl Default for KatzCentrality {
859    fn default() -> Self {
860        Self::new()
861    }
862}
863
864impl GpuKernel for KatzCentrality {
865    fn metadata(&self) -> &KernelMetadata {
866        &self.metadata
867    }
868}
869
870// ============================================================================
871// BatchKernel Implementations
872// ============================================================================
873
874/// Batch execution wrapper for all centrality kernels.
875///
876/// Since centrality algorithms are computationally intensive,
877/// they benefit from batch execution with CPU orchestration.
878
879#[async_trait]
880impl BatchKernel<CentralityInput, CentralityOutput> for BetweennessCentrality {
881    async fn execute(&self, input: CentralityInput) -> Result<CentralityOutput> {
882        let start = Instant::now();
883        let normalized = input.normalize;
884        let result = Self::compute(&input.graph, normalized);
885        let compute_time_us = start.elapsed().as_micros() as u64;
886
887        Ok(CentralityOutput {
888            result,
889            compute_time_us,
890        })
891    }
892}
893
894#[async_trait]
895impl BatchKernel<CentralityInput, CentralityOutput> for ClosenessCentrality {
896    async fn execute(&self, input: CentralityInput) -> Result<CentralityOutput> {
897        let start = Instant::now();
898        let harmonic = match input.params {
899            CentralityParams::Closeness { harmonic } => harmonic,
900            _ => false,
901        };
902        let result = Self::compute(&input.graph, harmonic);
903        let compute_time_us = start.elapsed().as_micros() as u64;
904
905        Ok(CentralityOutput {
906            result,
907            compute_time_us,
908        })
909    }
910}
911
912#[async_trait]
913impl BatchKernel<CentralityInput, CentralityOutput> for EigenvectorCentrality {
914    async fn execute(&self, input: CentralityInput) -> Result<CentralityOutput> {
915        let start = Instant::now();
916        let max_iterations = input.max_iterations.unwrap_or(1000);
917        let tolerance = input.tolerance.unwrap_or(1e-6);
918        let result = Self::compute(&input.graph, max_iterations, tolerance);
919        let compute_time_us = start.elapsed().as_micros() as u64;
920
921        Ok(CentralityOutput {
922            result,
923            compute_time_us,
924        })
925    }
926}
927
928#[async_trait]
929impl BatchKernel<CentralityInput, CentralityOutput> for KatzCentrality {
930    async fn execute(&self, input: CentralityInput) -> Result<CentralityOutput> {
931        let start = Instant::now();
932        let (alpha, beta) = match input.params {
933            CentralityParams::Katz { alpha, beta } => (alpha, beta),
934            _ => (0.1, 1.0),
935        };
936        let max_iterations = input.max_iterations.unwrap_or(100);
937        let tolerance = input.tolerance.unwrap_or(1e-6);
938        let result = Self::compute(&input.graph, alpha, beta, max_iterations, tolerance);
939        let compute_time_us = start.elapsed().as_micros() as u64;
940
941        Ok(CentralityOutput {
942            result,
943            compute_time_us,
944        })
945    }
946}
947
948/// PageRank can be used in both batch and ring modes.
949/// This is the batch mode implementation.
950impl PageRank {
951    /// Execute PageRank as a batch operation.
952    ///
953    /// Convenience method that runs the algorithm to convergence.
954    pub async fn compute_batch(
955        &self,
956        graph: CsrGraph,
957        damping: f32,
958        max_iterations: u32,
959        threshold: f64,
960    ) -> Result<CentralityResult> {
961        Self::run_to_convergence(graph, damping, max_iterations, threshold)
962    }
963}
964
965#[async_trait]
966impl BatchKernel<CentralityInput, CentralityOutput> for PageRank {
967    async fn execute(&self, input: CentralityInput) -> Result<CentralityOutput> {
968        let start = Instant::now();
969        let damping = match input.params {
970            CentralityParams::PageRank { damping } => damping,
971            _ => 0.85,
972        };
973        let max_iterations = input.max_iterations.unwrap_or(100);
974        let tolerance = input.tolerance.unwrap_or(1e-6);
975        let result = Self::run_to_convergence(input.graph, damping, max_iterations, tolerance)?;
976        let compute_time_us = start.elapsed().as_micros() as u64;
977
978        Ok(CentralityOutput {
979            result,
980            compute_time_us,
981        })
982    }
983}
984
985/// Degree centrality batch implementation.
986#[async_trait]
987impl BatchKernel<CentralityInput, CentralityOutput> for DegreeCentrality {
988    async fn execute(&self, input: CentralityInput) -> Result<CentralityOutput> {
989        let start = Instant::now();
990        let result = Self::compute(&input.graph);
991        let compute_time_us = start.elapsed().as_micros() as u64;
992
993        Ok(CentralityOutput {
994            result,
995            compute_time_us,
996        })
997    }
998}
999
#[cfg(test)]
mod tests {
    use super::*;

    /// Directed 4-cycle: 0 -> 1 -> 2 -> 3 -> 0.
    fn create_test_graph() -> CsrGraph {
        CsrGraph::from_edges(4, &[(0, 1), (1, 2), (2, 3), (3, 0)])
    }

    /// Star graph on 5 nodes: center 0 linked in both directions to leaves 1..=4.
    fn create_star_graph() -> CsrGraph {
        CsrGraph::from_edges(
            5,
            &[
                (0, 1),
                (0, 2),
                (0, 3),
                (0, 4),
                (1, 0),
                (2, 0),
                (3, 0),
                (4, 0),
            ],
        )
    }

    #[test]
    fn test_pagerank_metadata() {
        let kernel = PageRank::new();
        assert_eq!(kernel.metadata().id, "graph/pagerank");
        assert_eq!(kernel.metadata().domain, Domain::GraphAnalytics);
    }

    #[test]
    fn test_pagerank_iteration() {
        let graph = create_test_graph();
        let mut state = PageRank::initialize_state(graph, 0.85);

        let diff = PageRank::iterate_step(&mut state);
        assert!(diff >= 0.0);
        assert_eq!(state.iteration, 1);
    }

    #[test]
    fn test_pagerank_convergence() {
        let graph = create_test_graph();
        let result = PageRank::run_to_convergence(graph, 0.85, 100, 1e-6).unwrap();

        assert!(result.converged);
        assert_eq!(result.scores.len(), 4);

        // In a cycle, all nodes should have equal PageRank
        let first_score = result.scores[0].score;
        for score in &result.scores {
            assert!((score.score - first_score).abs() < 0.01);
        }
    }

    #[test]
    fn test_degree_centrality() {
        let graph = create_star_graph();
        let result = DegreeCentrality::compute(&graph);

        assert_eq!(result.scores.len(), 5);

        // Center node (0) should have highest degree
        let center_score = result.scores[0].score;
        for score in &result.scores[1..] {
            assert!(center_score > score.score);
        }
    }

    #[test]
    fn test_betweenness_centrality() {
        // Line graph 0 - 1 - 2 - 3, with each edge stored in both directions.
        let graph = CsrGraph::from_edges(4, &[(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2)]);

        let result = BetweennessCentrality::compute(&graph, false);

        assert_eq!(result.scores.len(), 4);

        // Interior nodes (1, 2) lie on shortest paths between other nodes, while the
        // endpoints (0, 3) lie on none, so every interior node must outrank every endpoint.
        let score = |node: usize| result.scores[node].score;
        for interior in [1, 2] {
            for endpoint in [0, 3] {
                assert!(
                    score(interior) > score(endpoint),
                    "node {interior} should outrank endpoint {endpoint}"
                );
            }
        }

        // The line is symmetric under 0<->3, 1<->2, so both interior nodes score equally.
        assert!((score(1) - score(2)).abs() < 1e-9);
    }

    #[test]
    fn test_closeness_centrality() {
        let graph = create_star_graph();
        let result = ClosenessCentrality::compute(&graph, false);

        assert_eq!(result.scores.len(), 5);

        // The center is one hop from every leaf, while each leaf is two hops from the
        // other leaves, so the center must be strictly the closest node.
        let center_score = result.scores[0].score;
        for score in &result.scores[1..] {
            assert!(center_score > score.score);
        }
    }

    #[test]
    fn test_eigenvector_centrality() {
        let graph = create_star_graph();
        let result = EigenvectorCentrality::compute(&graph, 1000, 1e-4);

        // May or may not converge depending on graph structure
        assert_eq!(result.scores.len(), 5);

        // Center node should have high eigenvector centrality
        // (may not be highest due to star graph properties)
        let center_score = result.scores[0].score;
        assert!(center_score > 0.0);
    }

    #[test]
    fn test_katz_centrality() {
        let graph = create_star_graph();
        let result = KatzCentrality::compute(&graph, 0.1, 1.0, 100, 1e-6);

        assert!(result.converged);
        assert_eq!(result.scores.len(), 5);

        // The center collects attenuated walks from all four leaves, so it must
        // outrank each leaf.
        let center_score = result.scores[0].score;
        for score in &result.scores[1..] {
            assert!(center_score > score.score);
        }
    }
}