rustkernel_core/k2k.rs

//! K2K (Kernel-to-Kernel) coordination utilities.
//!
//! This module provides higher-level abstractions for K2K communication patterns
//! commonly used in financial analytics pipelines.
//!
//! ## Coordination Patterns
//!
//! - **Fan-out**: One kernel broadcasting to multiple downstream kernels
//! - **Fan-in**: Multiple kernels sending to one aggregator kernel
//! - **Pipeline**: Sequential multi-stage processing
//! - **Scatter-Gather**: Parallel processing with result aggregation
//! - **Iterative**: Convergence-based algorithms (PageRank, K-Means)

use ringkernel_core::runtime::KernelId;
use std::collections::HashMap;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// ============================================================================
// Kernel ID Utilities
// ============================================================================

/// Convert a KernelId to a u64 hash for message envelope addressing.
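///
/// The hash comes from the standard library's `DefaultHasher`, so identical IDs map
/// to identical values within a given build, but the mapping is not guaranteed to be
/// stable across Rust releases.
///
/// # Examples
///
/// A minimal sketch (the `use` paths below are assumed from this crate's layout):
///
/// ```ignore
/// use ringkernel_core::runtime::KernelId;
/// use rustkernel_core::k2k::kernel_id_to_u64;
///
/// let id = KernelId::new("pricing-kernel");
/// // The same ID always hashes to the same envelope address.
/// assert_eq!(kernel_id_to_u64(&id), kernel_id_to_u64(&id));
/// ```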
pub fn kernel_id_to_u64(id: &KernelId) -> u64 {
    let mut hasher = DefaultHasher::new();
    id.as_str().hash(&mut hasher);
    hasher.finish()
}

// ============================================================================
// Iterative Convergence Coordinator
// ============================================================================

/// State for tracking iterative algorithm convergence.
///
/// Used for algorithms such as PageRank, K-Means, and GARCH that iterate until convergence.
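///
/// # Examples
///
/// A minimal convergence-loop sketch (the `use` path is assumed; the per-iteration
/// deltas stand in for errors produced by the actual kernel computation):
///
/// ```ignore
/// use rustkernel_core::k2k::IterativeState;
///
/// let mut state = IterativeState::new(1e-6, 100);
/// for delta in [0.5, 0.05, 5e-7] {
///     // `update` returns `true` once the delta drops below the threshold
///     // or the iteration cap is reached.
///     if state.update(delta) {
///         break;
///     }
/// }
/// assert!(state.summary().converged);
/// assert_eq!(state.iteration, 3);
/// ```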
#[derive(Debug, Clone)]
pub struct IterativeState {
    /// Current iteration number.
    pub iteration: u64,
    /// Last computed delta/error value.
    pub last_delta: f64,
    /// Convergence threshold.
    pub convergence_threshold: f64,
    /// Maximum allowed iterations.
    pub max_iterations: u64,
    /// Whether algorithm has converged.
    pub converged: bool,
}

impl IterativeState {
    /// Create a new iterative state.
    pub fn new(convergence_threshold: f64, max_iterations: u64) -> Self {
        Self {
            iteration: 0,
            last_delta: f64::MAX,
            convergence_threshold,
            max_iterations,
            converged: false,
        }
    }

    /// Update the state with the delta from a completed iteration.
    ///
    /// Returns `true` if the algorithm has converged (delta below the threshold)
    /// or the maximum iteration count has been reached.
    pub fn update(&mut self, delta: f64) -> bool {
        self.iteration += 1;
        self.last_delta = delta;
        self.converged =
            delta < self.convergence_threshold || self.iteration >= self.max_iterations;
        self.converged
    }

    /// Check whether iteration should continue.
    pub fn should_continue(&self) -> bool {
        !self.converged && self.iteration < self.max_iterations
    }

    /// Reset state for a new run.
    pub fn reset(&mut self) {
        self.iteration = 0;
        self.last_delta = f64::MAX;
        self.converged = false;
    }

    /// Get convergence summary.
    pub fn summary(&self) -> IterativeConvergenceSummary {
        IterativeConvergenceSummary {
            iterations: self.iteration,
            final_delta: self.last_delta,
            converged: self.converged,
            reached_max: self.iteration >= self.max_iterations,
        }
    }
}

/// Summary of iterative algorithm convergence.
#[derive(Debug, Clone)]
pub struct IterativeConvergenceSummary {
    /// Number of iterations executed.
    pub iterations: u64,
    /// Final delta/error value.
    pub final_delta: f64,
    /// Whether convergence was achieved.
    pub converged: bool,
    /// Whether max iterations was reached.
    pub reached_max: bool,
}

// ============================================================================
// Pipeline Stage Tracker
// ============================================================================

/// Tracks progress through a multi-stage pipeline.
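///
/// # Examples
///
/// A minimal sketch of walking a three-stage pipeline (the `use` path is assumed;
/// stage names, item counts, and timings are illustrative):
///
/// ```ignore
/// use rustkernel_core::k2k::PipelineTracker;
///
/// let stages = vec!["ingest".to_string(), "transform".to_string(), "output".to_string()];
/// let mut tracker = PipelineTracker::new(stages);
///
/// assert_eq!(tracker.current_stage(), Some("ingest"));
/// tracker.record_items(1_000);
/// tracker.advance(250); // "ingest" took 250 us
///
/// assert_eq!(tracker.current_stage(), Some("transform"));
/// tracker.advance(400);
/// tracker.advance(100); // records timing for the final "output" stage
///
/// assert!(tracker.is_complete());
/// assert_eq!(tracker.total_time_us(), 750);
/// ```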
#[derive(Debug, Clone)]
pub struct PipelineTracker {
    stages: Vec<String>,
    current_stage: usize,
    stage_timings_us: HashMap<String, u64>,
    total_items_processed: u64,
}

impl PipelineTracker {
    /// Create a new pipeline tracker with the given stages.
    pub fn new(stages: Vec<String>) -> Self {
        Self {
            stages,
            current_stage: 0,
            stage_timings_us: HashMap::new(),
            total_items_processed: 0,
        }
    }

    /// Get the current stage name.
    pub fn current_stage(&self) -> Option<&str> {
        self.stages.get(self.current_stage).map(|s| s.as_str())
    }

    /// Get the next stage name.
    pub fn next_stage(&self) -> Option<&str> {
        self.stages.get(self.current_stage + 1).map(|s| s.as_str())
    }

    /// Advance to the next stage, recording timing for the completed stage.
    ///
    /// Returns `true` if there was a further stage to advance to, or `false`
    /// if the tracker was already on the final stage.
    pub fn advance(&mut self, elapsed_us: u64) -> bool {
        if let Some(stage) = self.stages.get(self.current_stage) {
            self.stage_timings_us.insert(stage.clone(), elapsed_us);
        }
        if self.current_stage + 1 < self.stages.len() {
            self.current_stage += 1;
            true
        } else {
            false
        }
    }

    /// Record items processed in the current stage.
    pub fn record_items(&mut self, count: u64) {
        self.total_items_processed += count;
    }

    /// Check if the pipeline is complete.
    ///
    /// The pipeline is complete once the final stage is current and every stage
    /// has a recorded timing.
    pub fn is_complete(&self) -> bool {
        self.current_stage >= self.stages.len().saturating_sub(1)
            && self.stage_timings_us.len() >= self.stages.len()
    }

    /// Get total pipeline timing across all recorded stages.
    pub fn total_time_us(&self) -> u64 {
        self.stage_timings_us.values().sum()
    }

    /// Get timing for a specific stage.
    pub fn stage_timing(&self, stage: &str) -> Option<u64> {
        self.stage_timings_us.get(stage).copied()
    }

    /// Reset pipeline for new processing.
    pub fn reset(&mut self) {
        self.current_stage = 0;
        self.stage_timings_us.clear();
        self.total_items_processed = 0;
    }
}

// ============================================================================
// Scatter-Gather State
// ============================================================================

/// Tracks scatter-gather operation state.
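///
/// # Examples
///
/// A minimal sketch of gathering results from three workers (the `use` paths are
/// assumed; worker names, the start timestamp, and values are illustrative):
///
/// ```ignore
/// use ringkernel_core::runtime::KernelId;
/// use rustkernel_core::k2k::ScatterGatherState;
///
/// let mut state: ScatterGatherState<f64> = ScatterGatherState::new(3, 0);
///
/// state.receive_result(KernelId::new("worker-0"), 1.5);
/// state.receive_result(KernelId::new("worker-1"), 2.5);
/// assert_eq!(state.pending_count(), 1);
///
/// state.receive_result(KernelId::new("worker-2"), 3.0);
/// assert!(state.is_complete());
///
/// let total: f64 = state.take_results().iter().sum();
/// assert_eq!(total, 7.0);
/// ```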
#[derive(Debug)]
pub struct ScatterGatherState<T> {
    /// Number of workers to scatter to.
    pub worker_count: usize,
    /// Results received so far.
    pub results: Vec<T>,
    /// Workers that have responded.
    pub responded_workers: Vec<KernelId>,
    /// Start timestamp (microseconds).
    pub start_time_us: u64,
}

impl<T> ScatterGatherState<T> {
    /// Create new scatter-gather state.
    pub fn new(worker_count: usize, start_time_us: u64) -> Self {
        Self {
            worker_count,
            results: Vec::with_capacity(worker_count),
            responded_workers: Vec::with_capacity(worker_count),
            start_time_us,
        }
    }

    /// Record a result from a worker.
    ///
    /// A second result from a worker that has already responded is ignored.
    pub fn receive_result(&mut self, worker: KernelId, result: T) {
        if !self.responded_workers.contains(&worker) {
            self.responded_workers.push(worker);
            self.results.push(result);
        }
    }

    /// Check if all workers have responded.
    pub fn is_complete(&self) -> bool {
        self.responded_workers.len() >= self.worker_count
    }

    /// Get count of pending responses.
    pub fn pending_count(&self) -> usize {
        self.worker_count
            .saturating_sub(self.responded_workers.len())
    }

    /// Get the results (consumes the state).
    pub fn take_results(self) -> Vec<T> {
        self.results
    }
}

// ============================================================================
// Fan-Out Destination Tracker
// ============================================================================

/// Tracks fan-out broadcast destinations and delivery status.
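///
/// # Examples
///
/// A minimal broadcast-tracking sketch (the `use` paths are assumed; destination
/// names are illustrative, and the actual message send is out of scope here):
///
/// ```ignore
/// use ringkernel_core::runtime::KernelId;
/// use rustkernel_core::k2k::FanOutTracker;
///
/// let mut tracker = FanOutTracker::new();
/// tracker.add_destination(KernelId::new("risk-engine"));
/// tracker.add_destination(KernelId::new("pnl-aggregator"));
///
/// // A new broadcast resets per-destination delivery status.
/// tracker.record_broadcast();
/// assert_eq!(tracker.delivery_count(), 0);
///
/// // Mark deliveries as the transport layer confirms them.
/// tracker.mark_delivered(&KernelId::new("risk-engine"));
/// assert_eq!(tracker.delivery_count(), 1);
/// assert_eq!(tracker.broadcast_count(), 1);
/// ```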
#[derive(Debug, Clone)]
pub struct FanOutTracker {
    destinations: Vec<KernelId>,
    delivery_status: HashMap<String, bool>,
    broadcast_count: u64,
}

impl FanOutTracker {
    /// Create new fan-out tracker.
    pub fn new() -> Self {
        Self {
            destinations: Vec::new(),
            delivery_status: HashMap::new(),
            broadcast_count: 0,
        }
    }

    /// Add a destination kernel.
    pub fn add_destination(&mut self, dest: KernelId) {
        if !self
            .destinations
            .iter()
            .any(|d| d.as_str() == dest.as_str())
        {
            self.destinations.push(dest);
        }
    }

    /// Remove a destination kernel.
    pub fn remove_destination(&mut self, dest: &KernelId) {
        self.destinations.retain(|d| d.as_str() != dest.as_str());
        self.delivery_status.remove(dest.as_str());
    }

    /// Get all destination IDs.
    pub fn destinations(&self) -> &[KernelId] {
        &self.destinations
    }

    /// Record broadcast attempt.
    pub fn record_broadcast(&mut self) {
        self.broadcast_count += 1;
        // Reset delivery status for new broadcast
        for dest in &self.destinations {
            self.delivery_status
                .insert(dest.as_str().to_string(), false);
        }
    }

    /// Mark delivery to a destination as successful.
    pub fn mark_delivered(&mut self, dest: &KernelId) {
        self.delivery_status.insert(dest.as_str().to_string(), true);
    }

    /// Get delivery success count for last broadcast.
    pub fn delivery_count(&self) -> usize {
        self.delivery_status.values().filter(|&&v| v).count()
    }

    /// Get total broadcast count.
    pub fn broadcast_count(&self) -> u64 {
        self.broadcast_count
    }

    /// Get destination count.
    pub fn destination_count(&self) -> usize {
        self.destinations.len()
    }
}

impl Default for FanOutTracker {
    fn default() -> Self {
        Self::new()
    }
}

// ============================================================================
// K2K Control Messages
// ============================================================================

/// Control messages for K2K coordination between kernels.
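///
/// # Examples
///
/// A minimal sketch of a coordinator handling a control message (the `use` path is
/// assumed; the values are illustrative):
///
/// ```ignore
/// use rustkernel_core::k2k::K2KControlMessage;
///
/// let msg = K2KControlMessage::IterationComplete {
///     iteration: 7,
///     delta: 3.2e-4,
///     worker_id: 42,
/// };
///
/// match msg {
///     K2KControlMessage::IterationComplete { iteration, delta, .. } => {
///         println!("iteration {iteration} finished with delta {delta}");
///     }
///     K2KControlMessage::Error { message, code } => {
///         eprintln!("worker error {code}: {message}");
///     }
///     _ => {}
/// }
/// ```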
#[derive(Debug, Clone)]
pub enum K2KControlMessage {
    /// Signal to start processing.
    Start {
        /// Correlation ID for tracking.
        correlation_id: u64,
    },
    /// Signal to stop processing.
    Stop {
        /// Reason for stopping.
        reason: String,
    },
    /// Request current state/status.
    GetStatus {
        /// Correlation ID for response.
        correlation_id: u64,
    },
    /// Signal iteration complete.
    IterationComplete {
        /// Iteration number.
        iteration: u64,
        /// Delta/error from this iteration.
        delta: f64,
        /// Worker ID that completed.
        worker_id: u64,
    },
    /// Signal convergence reached.
    Converged {
        /// Total iterations.
        iterations: u64,
        /// Final delta/error.
        final_delta: f64,
    },
    /// Signal processing error.
    Error {
        /// Error message.
        message: String,
        /// Error code.
        code: u32,
    },
    /// Heartbeat for liveness checking.
    Heartbeat {
        /// Sequence number.
        sequence: u64,
        /// Timestamp (microseconds).
        timestamp_us: u64,
    },
    /// Barrier synchronization.
    Barrier {
        /// Barrier ID.
        barrier_id: u64,
        /// Worker that reached barrier.
        worker_id: u64,
    },
}

// ============================================================================
// K2K Aggregation Result
// ============================================================================

/// Result from a worker in a scatter-gather operation.
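///
/// # Examples
///
/// A minimal aggregation sketch over gathered worker results (the `use` paths are
/// assumed; IDs, values, and timings are illustrative):
///
/// ```ignore
/// use ringkernel_core::runtime::KernelId;
/// use rustkernel_core::k2k::K2KWorkerResult;
///
/// let results = vec![
///     K2KWorkerResult::new(KernelId::new("worker-0"), 1, 10.0_f64, 120),
///     K2KWorkerResult::new(KernelId::new("worker-1"), 1, 14.0_f64, 180),
/// ];
///
/// // Combine the partial values and find the slowest worker.
/// let total: f64 = results.iter().map(|r| r.result).sum();
/// let slowest_us = results.iter().map(|r| r.processing_time_us).max().unwrap();
/// assert_eq!(total, 24.0);
/// assert_eq!(slowest_us, 180);
/// ```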
#[derive(Debug, Clone)]
pub struct K2KWorkerResult<T> {
    /// Worker that produced this result.
    pub worker_id: KernelId,
    /// Correlation ID linking to original request.
    pub correlation_id: u64,
    /// The result data.
    pub result: T,
    /// Processing time in microseconds.
    pub processing_time_us: u64,
}

impl<T> K2KWorkerResult<T> {
    /// Create a new worker result.
    pub fn new(
        worker_id: KernelId,
        correlation_id: u64,
        result: T,
        processing_time_us: u64,
    ) -> Self {
        Self {
            worker_id,
            correlation_id,
            result,
            processing_time_us,
        }
    }
}

// ============================================================================
// K2K Message Priority
// ============================================================================

/// Priority levels for K2K messages.
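///
/// # Examples
///
/// A minimal sketch of ordering and raw conversion (the `use` path is assumed):
///
/// ```ignore
/// use rustkernel_core::k2k::K2KPriority;
///
/// // Priorities are totally ordered, so they can be compared directly.
/// assert!(K2KPriority::Critical > K2KPriority::Normal);
/// assert_eq!(K2KPriority::default(), K2KPriority::Normal);
///
/// // Convert to the underlying `u8` discriminant.
/// let raw: u8 = K2KPriority::High.into();
/// assert_eq!(raw, 128);
/// ```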
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
#[repr(u8)]
pub enum K2KPriority {
    /// Low priority - background processing.
    Low = 0,
    /// Normal priority - default.
    #[default]
    Normal = 64,
    /// High priority - time-sensitive operations.
    High = 128,
    /// Critical priority - must process immediately.
    Critical = 192,
    /// Real-time priority - latency-critical paths.
    RealTime = 255,
}

impl From<K2KPriority> for u8 {
    fn from(p: K2KPriority) -> u8 {
        p as u8
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_iterative_state_convergence() {
        let mut state = IterativeState::new(1e-6, 100);

        assert!(state.should_continue());
        assert!(!state.converged);

        // Simulate iterations
        state.update(0.1);
        assert!(!state.converged);
        assert_eq!(state.iteration, 1);

        state.update(0.01);
        assert!(!state.converged);

        state.update(1e-7); // Below threshold
        assert!(state.converged);

        let summary = state.summary();
        assert_eq!(summary.iterations, 3);
        assert!(summary.converged);
    }

    #[test]
    fn test_iterative_state_max_iterations() {
        let mut state = IterativeState::new(1e-6, 3);

        state.update(0.1);
        state.update(0.05);
        state.update(0.01); // Reaches max iterations

        assert!(state.converged);
        let summary = state.summary();
        assert!(summary.reached_max);
    }

    #[test]
    fn test_pipeline_tracker() {
        let stages = vec![
            "ingest".to_string(),
            "transform".to_string(),
            "output".to_string(),
        ];
        let mut tracker = PipelineTracker::new(stages);

        assert_eq!(tracker.current_stage(), Some("ingest"));
        assert_eq!(tracker.next_stage(), Some("transform"));

        tracker.advance(1000);
        assert_eq!(tracker.current_stage(), Some("transform"));

        tracker.advance(2000);
        assert_eq!(tracker.current_stage(), Some("output"));

        tracker.advance(500);
        assert!(tracker.is_complete());
        assert_eq!(tracker.total_time_us(), 3500);
    }

    #[test]
    fn test_scatter_gather_state() {
        let mut state: ScatterGatherState<i32> = ScatterGatherState::new(3, 0);

        assert!(!state.is_complete());
        assert_eq!(state.pending_count(), 3);

        state.receive_result(KernelId::new("worker1"), 10);
        state.receive_result(KernelId::new("worker2"), 20);
        assert_eq!(state.pending_count(), 1);

        state.receive_result(KernelId::new("worker3"), 30);
        assert!(state.is_complete());

        let results = state.take_results();
        assert_eq!(results, vec![10, 20, 30]);
    }

    #[test]
    fn test_fan_out_tracker() {
        let mut tracker = FanOutTracker::new();

        tracker.add_destination(KernelId::new("dest1"));
        tracker.add_destination(KernelId::new("dest2"));
        tracker.add_destination(KernelId::new("dest1")); // Duplicate

        assert_eq!(tracker.destination_count(), 2);

        tracker.record_broadcast();
        assert_eq!(tracker.broadcast_count(), 1);
        assert_eq!(tracker.delivery_count(), 0);

        tracker.mark_delivered(&KernelId::new("dest1"));
        assert_eq!(tracker.delivery_count(), 1);
    }

    #[test]
    fn test_kernel_id_to_u64() {
        let id1 = KernelId::new("kernel-a");
        let id2 = KernelId::new("kernel-b");
        let id1_copy = KernelId::new("kernel-a");

        let hash1 = kernel_id_to_u64(&id1);
        let hash2 = kernel_id_to_u64(&id2);
        let hash1_copy = kernel_id_to_u64(&id1_copy);

        assert_ne!(hash1, hash2);
        assert_eq!(hash1, hash1_copy);
    }
}