1use ringkernel_core::runtime::KernelId;
15use std::collections::HashMap;
16use std::collections::hash_map::DefaultHasher;
17use std::hash::{Hash, Hasher};
18
19pub fn kernel_id_to_u64(id: &KernelId) -> u64 {
25 let mut hasher = DefaultHasher::new();
26 id.as_str().hash(&mut hasher);
27 hasher.finish()
28}
29
/// Convergence bookkeeping for an iterative computation.
///
/// Each call to [`IterativeState::update`] counts one iteration and records its delta.
/// Note that `converged` becomes `true` either when the delta drops below the
/// threshold or when `max_iterations` is reached. It therefore means "finished"
/// rather than strictly "below threshold"; see
/// [`IterativeConvergenceSummary::reached_max`] to tell the two apart.
#[derive(Debug, Clone)]
pub struct IterativeState {
    /// Number of updates recorded so far.
    pub iteration: u64,
    /// Delta passed to the most recent update; `f64::MAX` before the first one.
    pub last_delta: f64,
    /// A delta strictly below this value counts as convergence.
    pub convergence_threshold: f64,
    /// Iteration cap; reaching it also sets `converged`.
    pub max_iterations: u64,
    /// Set once the threshold is beaten or the cap is reached.
    pub converged: bool,
}

impl IterativeState {
    /// Creates a state with no iterations recorded.
    pub fn new(convergence_threshold: f64, max_iterations: u64) -> Self {
        Self {
            convergence_threshold,
            max_iterations,
            iteration: 0,
            last_delta: f64::MAX,
            converged: false,
        }
    }

    /// Records one iteration with the given `delta`.
    ///
    /// Returns the new value of `converged`. A NaN delta never compares below the
    /// threshold, so it only stops the loop via the iteration cap.
    pub fn update(&mut self, delta: f64) -> bool {
        self.iteration += 1;
        self.last_delta = delta;

        let below_threshold = delta < self.convergence_threshold;
        let hit_cap = self.iteration >= self.max_iterations;
        self.converged = below_threshold || hit_cap;
        self.converged
    }

    /// `true` while neither convergence nor the iteration cap has been reached.
    pub fn should_continue(&self) -> bool {
        !(self.converged || self.iteration >= self.max_iterations)
    }

    /// Clears progress while keeping the threshold and iteration cap.
    pub fn reset(&mut self) {
        *self = Self::new(self.convergence_threshold, self.max_iterations);
    }

    /// Snapshot of the current progress.
    pub fn summary(&self) -> IterativeConvergenceSummary {
        let reached_max = self.iteration >= self.max_iterations;
        IterativeConvergenceSummary {
            iterations: self.iteration,
            final_delta: self.last_delta,
            converged: self.converged,
            reached_max,
        }
    }
}

/// Point-in-time report produced by [`IterativeState::summary`].
#[derive(Debug, Clone)]
pub struct IterativeConvergenceSummary {
    /// Number of iterations recorded.
    pub iterations: u64,
    /// Most recent delta (`f64::MAX` if no iteration ran).
    pub final_delta: f64,
    /// Copy of [`IterativeState::converged`] (also `true` when the cap was hit).
    pub converged: bool,
    /// `true` when the iteration count has reached `max_iterations`.
    pub reached_max: bool,
}
107
/// Tracks progress through an ordered list of named pipeline stages.
///
/// Each call to [`PipelineTracker::advance`] records the elapsed time for the
/// current stage and moves on to the next one. Timings are stored per stage
/// *position* rather than per name. A pipeline that reuses a stage name
/// (e.g. `["load", "map", "load"]`) therefore still completes and reports a
/// correct total.
#[derive(Debug, Clone)]
pub struct PipelineTracker {
    /// Stage names in execution order.
    stages: Vec<String>,
    /// Index into `stages` of the stage currently executing.
    current_stage: usize,
    /// Elapsed time (µs) per stage position; `None` until that stage is advanced.
    stage_timings_us: Vec<Option<u64>>,
    /// Running total of counts passed to [`PipelineTracker::record_items`].
    total_items_processed: u64,
}

impl PipelineTracker {
    /// Creates a tracker positioned at the first stage with no timings recorded.
    pub fn new(stages: Vec<String>) -> Self {
        let stage_timings_us = vec![None; stages.len()];
        Self {
            stages,
            current_stage: 0,
            stage_timings_us,
            total_items_processed: 0,
        }
    }

    /// Name of the stage currently executing, or `None` if there are no stages.
    pub fn current_stage(&self) -> Option<&str> {
        self.stages.get(self.current_stage).map(String::as_str)
    }

    /// Name of the stage after the current one, or `None` at the last stage.
    pub fn next_stage(&self) -> Option<&str> {
        self.stages.get(self.current_stage + 1).map(String::as_str)
    }

    /// Records `elapsed_us` for the current stage and moves to the next one.
    ///
    /// Returns `true` if the tracker moved to a following stage. Returns `false`
    /// when it was already at the last stage. In that case the last stage's
    /// timing is still recorded, overwriting any earlier value.
    pub fn advance(&mut self, elapsed_us: u64) -> bool {
        if let Some(slot) = self.stage_timings_us.get_mut(self.current_stage) {
            *slot = Some(elapsed_us);
        }
        if self.current_stage + 1 < self.stages.len() {
            self.current_stage += 1;
            true
        } else {
            false
        }
    }

    /// Adds `count` to the processed-item total (saturating instead of overflowing).
    pub fn record_items(&mut self, count: u64) {
        self.total_items_processed = self.total_items_processed.saturating_add(count);
    }

    /// Total number of items recorded since creation or the last [`reset`](Self::reset).
    pub fn total_items_processed(&self) -> u64 {
        self.total_items_processed
    }

    /// `true` once every stage, including the last, has a recorded timing.
    ///
    /// An empty pipeline is trivially complete.
    pub fn is_complete(&self) -> bool {
        self.stage_timings_us.iter().all(Option::is_some)
    }

    /// Sum of all recorded stage timings, in microseconds.
    pub fn total_time_us(&self) -> u64 {
        self.stage_timings_us.iter().flatten().sum()
    }

    /// Recorded time for the stage called `stage`.
    ///
    /// Returns `None` if no stage with that name has been advanced. If the name
    /// occurs at several positions, the recorded timings of all occurrences are summed.
    pub fn stage_timing(&self, stage: &str) -> Option<u64> {
        self.stages
            .iter()
            .zip(&self.stage_timings_us)
            .filter(|(name, _)| name.as_str() == stage)
            .filter_map(|(_, timing)| *timing)
            .reduce(|acc, t| acc + t)
    }

    /// Returns to the first stage and clears all timings and the item count.
    pub fn reset(&mut self) {
        self.current_stage = 0;
        self.stage_timings_us.fill(None);
        self.total_items_processed = 0;
    }
}
183
184#[derive(Debug)]
190pub struct ScatterGatherState<T> {
191 pub worker_count: usize,
193 pub results: Vec<T>,
195 pub responded_workers: Vec<KernelId>,
197 pub start_time_us: u64,
199}
200
201impl<T> ScatterGatherState<T> {
202 pub fn new(worker_count: usize, start_time_us: u64) -> Self {
204 Self {
205 worker_count,
206 results: Vec::with_capacity(worker_count),
207 responded_workers: Vec::with_capacity(worker_count),
208 start_time_us,
209 }
210 }
211
212 pub fn receive_result(&mut self, worker: KernelId, result: T) {
214 if !self.responded_workers.contains(&worker) {
215 self.responded_workers.push(worker);
216 self.results.push(result);
217 }
218 }
219
220 pub fn is_complete(&self) -> bool {
222 self.responded_workers.len() >= self.worker_count
223 }
224
225 pub fn pending_count(&self) -> usize {
227 self.worker_count
228 .saturating_sub(self.responded_workers.len())
229 }
230
231 pub fn take_results(self) -> Vec<T> {
233 self.results
234 }
235}
236
237#[derive(Debug, Clone)]
243pub struct FanOutTracker {
244 destinations: Vec<KernelId>,
245 delivery_status: HashMap<String, bool>,
246 broadcast_count: u64,
247}
248
249impl FanOutTracker {
250 pub fn new() -> Self {
252 Self {
253 destinations: Vec::new(),
254 delivery_status: HashMap::new(),
255 broadcast_count: 0,
256 }
257 }
258
259 pub fn add_destination(&mut self, dest: KernelId) {
261 if !self
262 .destinations
263 .iter()
264 .any(|d| d.as_str() == dest.as_str())
265 {
266 self.destinations.push(dest);
267 }
268 }
269
270 pub fn remove_destination(&mut self, dest: &KernelId) {
272 self.destinations.retain(|d| d.as_str() != dest.as_str());
273 self.delivery_status.remove(dest.as_str());
274 }
275
276 pub fn destinations(&self) -> &[KernelId] {
278 &self.destinations
279 }
280
281 pub fn record_broadcast(&mut self) {
283 self.broadcast_count += 1;
284 for dest in &self.destinations {
286 self.delivery_status
287 .insert(dest.as_str().to_string(), false);
288 }
289 }
290
291 pub fn mark_delivered(&mut self, dest: &KernelId) {
293 self.delivery_status.insert(dest.as_str().to_string(), true);
294 }
295
296 pub fn delivery_count(&self) -> usize {
298 self.delivery_status.values().filter(|&&v| v).count()
299 }
300
301 pub fn broadcast_count(&self) -> u64 {
303 self.broadcast_count
304 }
305
306 pub fn destination_count(&self) -> usize {
308 self.destinations.len()
309 }
310}
311
312impl Default for FanOutTracker {
313 fn default() -> Self {
314 Self::new()
315 }
316}
317
/// Control messages exchanged between kernels (K2K).
///
/// NOTE(review): the per-variant descriptions below are inferred from variant
/// and field names. Confirm them against the code that produces and consumes
/// these messages.
#[derive(Debug, Clone)]
pub enum K2KControlMessage {
    /// Start request tagged with a caller-chosen correlation id.
    Start { correlation_id: u64 },
    /// Stop request carrying a human-readable reason.
    Stop { reason: String },
    /// Status query tagged with a caller-chosen correlation id.
    GetStatus { correlation_id: u64 },
    /// Report that a worker finished an iteration with the given delta.
    ///
    /// `worker_id` is a `u64`; presumably produced by `kernel_id_to_u64` (TODO confirm).
    IterationComplete {
        iteration: u64,
        delta: f64,
        worker_id: u64,
    },
    /// Report that an iterative computation stopped after `iterations` steps.
    Converged { iterations: u64, final_delta: f64 },
    /// Error report with a message and a numeric code.
    Error { message: String, code: u32 },
    /// Liveness ping with a sequence number and timestamp in microseconds.
    Heartbeat { sequence: u64, timestamp_us: u64 },
    /// Notice that a worker reached barrier `barrier_id`.
    Barrier { barrier_id: u64, worker_id: u64 },
}
378
379#[derive(Debug, Clone)]
385pub struct K2KWorkerResult<T> {
386 pub worker_id: KernelId,
388 pub correlation_id: u64,
390 pub result: T,
392 pub processing_time_us: u64,
394}
395
396impl<T> K2KWorkerResult<T> {
397 pub fn new(
399 worker_id: KernelId,
400 correlation_id: u64,
401 result: T,
402 processing_time_us: u64,
403 ) -> Self {
404 Self {
405 worker_id,
406 correlation_id,
407 result,
408 processing_time_us,
409 }
410 }
411}
412
/// Message priority with explicit `u8` discriminants spread across the byte range.
///
/// Variants are declared in ascending discriminant order, so the derived
/// ordering has `Low < Normal < High < Critical < RealTime`. The default is `Normal`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
#[repr(u8)]
pub enum K2KPriority {
    /// Discriminant 0.
    Low = 0,
    /// Discriminant 64; the `Default` value.
    #[default]
    Normal = 64,
    /// Discriminant 128.
    High = 128,
    /// Discriminant 192.
    Critical = 192,
    /// Discriminant 255.
    RealTime = 255,
}

impl From<K2KPriority> for u8 {
    /// Returns the priority's raw discriminant.
    fn from(priority: K2KPriority) -> Self {
        priority as Self
    }
}
440
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_iterative_state_convergence() {
        let mut state = IterativeState::new(1e-6, 100);

        assert!(state.should_continue());
        assert!(!state.converged);

        state.update(0.1);
        assert!(!state.converged);
        assert_eq!(state.iteration, 1);

        state.update(0.01);
        assert!(!state.converged);

        // A delta below the 1e-6 threshold converges before the cap is reached.
        state.update(1e-7);
        assert!(state.converged);
        assert!(!state.should_continue());

        let summary = state.summary();
        assert_eq!(summary.iterations, 3);
        assert!(summary.converged);
        assert!(!summary.reached_max);
    }

    #[test]
    fn test_iterative_state_max_iterations() {
        let mut state = IterativeState::new(1e-6, 3);

        state.update(0.1);
        state.update(0.05);
        // The third update hits the iteration cap. `converged` is set even
        // though the delta never dropped below the threshold.
        state.update(0.01);
        assert!(state.converged);

        let summary = state.summary();
        assert!(summary.reached_max);
    }

    #[test]
    fn test_iterative_state_reset() {
        let mut state = IterativeState::new(1e-3, 10);
        assert!(state.update(1e-4));

        state.reset();
        assert_eq!(state.iteration, 0);
        assert_eq!(state.last_delta, f64::MAX);
        assert!(!state.converged);
        assert!(state.should_continue());
        // Configuration survives a reset.
        assert_eq!(state.convergence_threshold, 1e-3);
        assert_eq!(state.max_iterations, 10);
    }

    #[test]
    fn test_pipeline_tracker() {
        let stages = vec![
            "ingest".to_string(),
            "transform".to_string(),
            "output".to_string(),
        ];
        let mut tracker = PipelineTracker::new(stages);

        assert_eq!(tracker.current_stage(), Some("ingest"));
        assert_eq!(tracker.next_stage(), Some("transform"));

        assert!(tracker.advance(1000));
        assert_eq!(tracker.current_stage(), Some("transform"));

        assert!(tracker.advance(2000));
        assert_eq!(tracker.current_stage(), Some("output"));
        assert!(!tracker.is_complete());

        // Advancing the last stage records its timing but does not move on.
        assert!(!tracker.advance(500));
        assert!(tracker.is_complete());
        assert_eq!(tracker.total_time_us(), 3500);
        assert_eq!(tracker.stage_timing("transform"), Some(2000));
    }

    #[test]
    fn test_scatter_gather_state() {
        let mut state: ScatterGatherState<i32> = ScatterGatherState::new(3, 0);

        assert!(!state.is_complete());
        assert_eq!(state.pending_count(), 3);

        state.receive_result(KernelId::new("worker1"), 10);
        state.receive_result(KernelId::new("worker2"), 20);
        assert_eq!(state.pending_count(), 1);

        state.receive_result(KernelId::new("worker3"), 30);
        assert!(state.is_complete());

        let results = state.take_results();
        assert_eq!(results, vec![10, 20, 30]);
    }

    #[test]
    fn test_scatter_gather_ignores_duplicate_worker() {
        let mut state: ScatterGatherState<i32> = ScatterGatherState::new(2, 0);

        state.receive_result(KernelId::new("worker1"), 10);
        state.receive_result(KernelId::new("worker1"), 99);
        assert_eq!(state.pending_count(), 1);
        assert!(!state.is_complete());
        assert_eq!(state.take_results(), vec![10]);
    }

    #[test]
    fn test_fan_out_tracker() {
        let mut tracker = FanOutTracker::new();

        tracker.add_destination(KernelId::new("dest1"));
        tracker.add_destination(KernelId::new("dest2"));
        // Adding an already-registered destination is a no-op.
        tracker.add_destination(KernelId::new("dest1"));
        assert_eq!(tracker.destination_count(), 2);

        tracker.record_broadcast();
        assert_eq!(tracker.broadcast_count(), 1);
        assert_eq!(tracker.delivery_count(), 0);

        tracker.mark_delivered(&KernelId::new("dest1"));
        assert_eq!(tracker.delivery_count(), 1);
    }

    #[test]
    fn test_fan_out_remove_destination() {
        let mut tracker = FanOutTracker::default();
        tracker.add_destination(KernelId::new("dest1"));
        tracker.add_destination(KernelId::new("dest2"));

        tracker.record_broadcast();
        tracker.mark_delivered(&KernelId::new("dest2"));
        assert_eq!(tracker.delivery_count(), 1);

        tracker.remove_destination(&KernelId::new("dest2"));
        assert_eq!(tracker.destination_count(), 1);
        assert_eq!(tracker.delivery_count(), 0);
        assert_eq!(tracker.destinations()[0].as_str(), "dest1");
    }

    #[test]
    fn test_priority_conversion_and_ordering() {
        assert_eq!(K2KPriority::default(), K2KPriority::Normal);
        assert_eq!(u8::from(K2KPriority::Low), 0);
        assert_eq!(u8::from(K2KPriority::RealTime), 255);
        assert!(K2KPriority::Low < K2KPriority::Normal);
        assert!(K2KPriority::Critical < K2KPriority::RealTime);
    }

    #[test]
    fn test_kernel_id_to_u64() {
        let id1 = KernelId::new("kernel-a");
        let id2 = KernelId::new("kernel-b");
        let id1_copy = KernelId::new("kernel-a");

        let hash1 = kernel_id_to_u64(&id1);
        let hash2 = kernel_id_to_u64(&id2);
        let hash1_copy = kernel_id_to_u64(&id1_copy);

        assert_ne!(hash1, hash2);
        assert_eq!(hash1, hash1_copy);
    }
}