rustkernel_ml/
ring_messages.rs

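//! Ring message types for Statistical ML kernels.
//!
//! This module defines zero-copy Ring messages for GPU-native persistent actors.
//! Type IDs 700-799 are reserved for the Statistical ML domain.
//!
//! ## Type ID Allocation
//!
//! - 700-719: KMeans messages
//! - 720-739: DBSCAN messages (reserved; none defined in this module yet)
//! - 740-759: Anomaly detection messages (reserved; none defined in this module yet)
//! - 760-779: K2K parallel coordination messages
//! - 780-799: unallocated
//!
//! Messages are serialized with rkyv, so treat field order and field types as
//! part of the wire format. [`to_fixed_point`] and [`from_fixed_point`]
//! convert between `f64` and the 8-decimal fixed-point `i64` representation;
//! the tests in this module use them to fill the `_fp` fields.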
12
13use ringkernel_derive::RingMessage;
14use rkyv::{Archive, Deserialize, Serialize};
15use rustkernel_core::messages::MessageId;
16
17// ============================================================================
18// KMeans Ring Messages (700-719)
19// ============================================================================
20
/// Initialize KMeans with a set of starting centroids.
///
/// Ring message type ID 700. The corresponding reply type is
/// [`KMeansInitResponse`] (type ID 701).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 700)]
pub struct KMeansInitRing {
    /// Message ID of this request.
    pub id: MessageId,
    /// Number of clusters (K).
    pub k: u32,
    /// Number of features (dimensions) per point and per centroid.
    pub n_features: u32,
    /// Initial centroids as fixed-point values: `k * n_features` entries
    /// packed into a fixed 32-slot buffer.
    ///
    /// NOTE(review): the slot order (centroid-major vs. feature-major) is not
    /// fixed by this type — confirm against the producer and consumer.
    pub centroids_packed: [i64; 32], // Capacity: k * n_features <= 32 (e.g. k=8, n_features=4)
}
35
/// KMeans initialization response (reply to [`KMeansInitRing`]).
///
/// Ring message type ID 701.
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 701)]
pub struct KMeansInitResponse {
    /// ID of the request being answered, as a raw `u64`.
    /// NOTE(review): presumably the value wrapped by the request's `MessageId` — confirm.
    pub request_id: u64,
    /// Whether initialization succeeded.
    pub success: bool,
    /// Number of clusters configured.
    pub k: u32,
}
48
/// Assign points to clusters (E-step) for one iteration.
///
/// Ring message type ID 702; the corresponding reply type is
/// [`KMeansAssignResponse`] (type ID 703). The message carries no point data.
/// NOTE(review): the points are presumably already resident on the receiving
/// actor — confirm.
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 702)]
pub struct KMeansAssignRing {
    /// Message ID of this request.
    pub id: MessageId,
    /// Iteration number.
    pub iteration: u32,
}
59
/// Assignment (E-step) response, the reply to [`KMeansAssignRing`].
///
/// Ring message type ID 703.
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 703)]
pub struct KMeansAssignResponse {
    /// ID of the request being answered, as a raw `u64`.
    pub request_id: u64,
    /// Iteration number.
    pub iteration: u32,
    /// Total inertia (sum of squared distances), fixed-point.
    pub inertia_fp: i64,
    /// Number of points assigned.
    pub points_assigned: u32,
}
74
/// Update centroids (M-step) for one iteration.
///
/// Ring message type ID 704; the corresponding reply type is
/// [`KMeansUpdateResponse`] (type ID 705).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 704)]
pub struct KMeansUpdateRing {
    /// Message ID of this request.
    pub id: MessageId,
    /// Iteration number.
    pub iteration: u32,
}
85
/// Centroid update (M-step) response, the reply to [`KMeansUpdateRing`].
///
/// Ring message type ID 705. This message reports only the maximum centroid
/// shift and a convergence flag; the updated centroid coordinates themselves
/// are not included.
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 705)]
pub struct KMeansUpdateResponse {
    /// ID of the request being answered, as a raw `u64`.
    pub request_id: u64,
    /// Iteration number.
    pub iteration: u32,
    /// Maximum centroid shift (fixed-point).
    pub max_shift_fp: i64,
    /// Whether converged.
    pub converged: bool,
}
100
/// Query the cluster assignment for a single point.
///
/// Ring message type ID 706; the corresponding reply type is
/// [`KMeansQueryResponse`] (type ID 707).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 706)]
pub struct KMeansQueryRing {
    /// Message ID of this request.
    pub id: MessageId,
    /// Point coordinates (fixed-point), up to 8 dimensions.
    /// NOTE(review): presumably only the first `n_dims` entries are meaningful.
    pub point: [i64; 8], // Up to 8 dimensions
    /// Number of dimensions used in `point`.
    pub n_dims: u8,
}
113
/// Cluster query response, the reply to [`KMeansQueryRing`].
///
/// Ring message type ID 707.
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 707)]
pub struct KMeansQueryResponse {
    /// ID of the request being answered, as a raw `u64`.
    pub request_id: u64,
    /// Assigned cluster.
    pub cluster: u32,
    /// Distance to the assigned centroid (fixed-point).
    pub distance_fp: i64,
}
126
127// ============================================================================
128// K2K Parallel Centroid Update Messages (760-779)
129// ============================================================================
130
/// K2K partial centroid update from a worker.
///
/// In distributed KMeans, each worker computes partial sums for centroids
/// from its data partition. One message covers one cluster on one worker.
///
/// Ring message type ID 760.
///
/// NOTE(review): `iteration` is `u64` in the K2K messages but `u32` in the
/// single-node `KMeans*` messages; convert explicitly when bridging the two.
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 760)]
pub struct K2KPartialCentroid {
    /// Message ID.
    pub id: MessageId,
    /// Worker ID.
    pub worker_id: u64,
    /// Iteration number.
    pub iteration: u64,
    /// Cluster ID this update is for.
    pub cluster_id: u32,
    /// Number of points assigned to this cluster on this worker.
    pub point_count: u32,
    /// Partial sum of coordinates (fixed-point, up to 8 dimensions; see
    /// [`pack_coordinates`]).
    pub coord_sum_fp: [i64; 8],
    /// Number of dimensions used in `coord_sum_fp`.
    pub n_dims: u8,
}
154
/// K2K centroid aggregation response for a single cluster.
///
/// Ring message type ID 761.
///
/// NOTE(review): unlike [`K2KPartialCentroid`], this message has no `n_dims`
/// field, so the receiver must know how many entries of `new_centroid_fp`
/// are meaningful from context.
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 761)]
pub struct K2KCentroidAggregation {
    /// Original request ID, as a raw `u64`.
    pub request_id: u64,
    /// Cluster ID.
    pub cluster_id: u32,
    /// Iteration number.
    pub iteration: u64,
    /// New centroid coordinates (fixed-point, up to 8 dimensions).
    pub new_centroid_fp: [i64; 8],
    /// Total points in the cluster.
    pub total_points: u32,
    /// Centroid shift from the previous iteration (fixed-point).
    pub shift_fp: i64,
}
173
/// K2K per-worker iteration sync for distributed KMeans.
///
/// Ring message type ID 762; the corresponding reply type is
/// [`K2KKMeansSyncResponse`] (type ID 763).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 762)]
pub struct K2KKMeansSync {
    /// Message ID.
    pub id: MessageId,
    /// Worker ID.
    pub worker_id: u64,
    /// Iteration number.
    pub iteration: u64,
    /// Local inertia on this worker (fixed-point).
    pub local_inertia_fp: i64,
    /// Points processed on this worker.
    pub points_processed: u32,
    /// Maximum local centroid shift (fixed-point).
    pub max_shift_fp: i64,
}
192
/// K2K sync response, the reply to [`K2KKMeansSync`], carrying global totals.
///
/// Ring message type ID 763.
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 763)]
pub struct K2KKMeansSyncResponse {
    /// Original request ID, as a raw `u64`.
    pub request_id: u64,
    /// Iteration number.
    pub iteration: u64,
    /// Whether all workers have synced for this iteration.
    pub all_synced: bool,
    /// Global inertia (fixed-point).
    pub global_inertia_fp: i64,
    /// Global maximum centroid shift (fixed-point).
    pub global_max_shift_fp: i64,
    /// Whether the global run has converged.
    pub converged: bool,
}
211
/// K2K broadcast of new centroids to workers.
///
/// Ring message type ID 764; each worker acknowledges with
/// [`K2KCentroidBroadcastAck`] (type ID 765).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 764)]
pub struct K2KCentroidBroadcast {
    /// Message ID.
    pub id: MessageId,
    /// Iteration number.
    pub iteration: u64,
    /// Number of clusters.
    pub k: u32,
    /// Number of dimensions per centroid.
    pub n_dims: u8,
    /// Packed fixed-point centroids: `k * n_dims` entries in a fixed 32-slot
    /// buffer, so `k * n_dims <= 32` (e.g. k=4 clusters with 8 dims each).
    /// NOTE(review): the slot order is not fixed by this type — confirm.
    pub centroids_packed: [i64; 32],
}
228
/// K2K acknowledgment of a [`K2KCentroidBroadcast`] from one worker.
///
/// Ring message type ID 765.
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 765)]
pub struct K2KCentroidBroadcastAck {
    /// ID of the broadcast being acknowledged, as a raw `u64`.
    pub request_id: u64,
    /// Worker ID.
    pub worker_id: u64,
    /// Iteration received.
    pub iteration: u64,
    /// Whether the worker applied the centroids.
    pub applied: bool,
}
243
244// ============================================================================
245// Helper Functions
246// ============================================================================
247
/// Convert an `f64` to a fixed-point `i64` with 8 decimal places (scale `1e8`).
///
/// The scaled value is rounded to the nearest integer (halfway cases away from
/// zero) rather than truncated. Many decimals, such as `0.29`, are stored in
/// binary just below their true value, so truncation would lose one unit
/// (`28_999_999` instead of `29_000_000`).
///
/// The final `as` cast saturates: results beyond the `i64` range clamp to
/// `i64::MIN` / `i64::MAX`, and `NaN` maps to `0`.
#[inline]
pub fn to_fixed_point(value: f64) -> i64 {
    (value * 100_000_000.0).round() as i64
}
253
/// Convert a fixed-point `i64` with 8 decimal places back to an `f64`.
///
/// Reverses the scaling applied by [`to_fixed_point`] by dividing the raw
/// integer by `1e8`.
#[inline]
pub fn from_fixed_point(fp: i64) -> f64 {
    const SCALE: f64 = 1e8;
    let raw = fp as f64;
    raw / SCALE
}
259
260/// Pack coordinates into fixed-point array.
261pub fn pack_coordinates(coords: &[f64], output: &mut [i64; 8]) {
262    for (i, &c) in coords.iter().take(8).enumerate() {
263        output[i] = to_fixed_point(c);
264    }
265}
266
267/// Unpack coordinates from fixed-point array.
268pub fn unpack_coordinates(input: &[i64; 8], n_dims: usize) -> Vec<f64> {
269    input
270        .iter()
271        .take(n_dims)
272        .map(|&fp| from_fixed_point(fp))
273        .collect()
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// A simple decimal value survives a fixed-point round trip.
    #[test]
    fn test_fixed_point_conversion() {
        let original = 2.5;
        let round_tripped = from_fixed_point(to_fixed_point(original));
        assert!((original - round_tripped).abs() < 1e-8);
    }

    /// Packing then unpacking returns the original coordinates.
    #[test]
    fn test_pack_unpack_coordinates() {
        let coords = [1.5, 2.5, 3.5];
        let mut packed = [0i64; 8];
        pack_coordinates(&coords, &mut packed);

        let unpacked = unpack_coordinates(&packed, coords.len());
        assert_eq!(unpacked.len(), coords.len());
        assert!(coords
            .iter()
            .zip(&unpacked)
            .all(|(expected, actual)| (expected - actual).abs() < 1e-7));
    }

    #[test]
    fn test_kmeans_init_ring() {
        let centroids_packed = [0i64; 32];
        let msg = KMeansInitRing {
            id: MessageId(1),
            k: 3,
            n_features: 2,
            centroids_packed,
        };
        assert_eq!(msg.k, 3);
    }

    #[test]
    fn test_k2k_partial_centroid() {
        let mut coord_sum_fp = [0i64; 8];
        pack_coordinates(&[10.0, 20.0], &mut coord_sum_fp);

        let msg = K2KPartialCentroid {
            id: MessageId(2),
            worker_id: 1,
            iteration: 5,
            cluster_id: 0,
            point_count: 100,
            coord_sum_fp,
            n_dims: 2,
        };
        assert_eq!(msg.point_count, 100);
        assert_eq!(msg.iteration, 5);
    }

    #[test]
    fn test_k2k_kmeans_sync() {
        let local_inertia_fp = to_fixed_point(1234.5);
        let max_shift_fp = to_fixed_point(0.001);

        let msg = K2KKMeansSync {
            id: MessageId(3),
            worker_id: 2,
            iteration: 10,
            local_inertia_fp,
            points_processed: 5000,
            max_shift_fp,
        };
        assert_eq!(msg.iteration, 10);
        assert_eq!(msg.points_processed, 5000);
    }
}