1use ringkernel_derive::RingMessage;
14use rkyv::{Archive, Deserialize, Serialize};
15use rustkernel_core::messages::MessageId;
16
/// Ring message carrying an initial K-means configuration: the cluster count,
/// the feature count and a packed set of centroids.
///
/// NOTE(review): the layout of `centroids_packed` is not visible in this module;
/// presumably `k * n_features` fixed-point values (see `to_fixed_point`) stored
/// centroid by centroid, which would require `k * n_features <= 32` — confirm
/// against the kernel that consumes this message.
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 700)]
pub struct KMeansInitRing {
    /// Identifier of this message.
    pub id: MessageId,
    /// Number of clusters.
    pub k: u32,
    /// Number of features (dimensions) per point.
    pub n_features: u32,
    /// Initial centroids in fixed-point form (32 slots; layout assumed above).
    pub centroids_packed: [i64; 32],
}
35
/// Reply message for K-means initialization (presumably answering a
/// [`KMeansInitRing`]; the pairing is inferred from naming and type ids 700/701).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 701)]
pub struct KMeansInitResponse {
    /// Presumably the `id` of the request being answered — confirm with the sender.
    pub request_id: u64,
    /// Whether initialization succeeded.
    pub success: bool,
    /// Number of clusters reported back (presumably echoing the request's `k`).
    pub k: u32,
}
48
/// Ring message requesting a K-means assignment step for a given iteration.
///
/// NOTE(review): `iteration` is `u32` in the `KMeans*` messages but `u64` in the
/// `K2K*` messages; any conversion between the two must be checked.
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 702)]
pub struct KMeansAssignRing {
    /// Identifier of this message.
    pub id: MessageId,
    /// Iteration number this assignment step belongs to.
    pub iteration: u32,
}
59
/// Reply message for a K-means assignment step (presumably answering a
/// [`KMeansAssignRing`]).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 703)]
pub struct KMeansAssignResponse {
    /// Presumably the `id` of the request being answered — confirm with the sender.
    pub request_id: u64,
    /// Iteration number the result belongs to.
    pub iteration: u32,
    /// Inertia as a fixed-point integer (`_fp` suffix; presumably produced with
    /// `to_fixed_point`).
    pub inertia_fp: i64,
    /// Number of points assigned in this step.
    pub points_assigned: u32,
}
74
/// Ring message requesting a K-means centroid-update step for a given iteration.
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 704)]
pub struct KMeansUpdateRing {
    /// Identifier of this message.
    pub id: MessageId,
    /// Iteration number this update step belongs to.
    pub iteration: u32,
}
85
/// Reply message for a K-means update step (presumably answering a
/// [`KMeansUpdateRing`]).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 705)]
pub struct KMeansUpdateResponse {
    /// Presumably the `id` of the request being answered — confirm with the sender.
    pub request_id: u64,
    /// Iteration number the result belongs to.
    pub iteration: u32,
    /// Largest centroid shift as a fixed-point integer (`_fp` suffix; presumably
    /// produced with `to_fixed_point`).
    pub max_shift_fp: i64,
    /// Whether the kernel considers the clustering converged.
    pub converged: bool,
}
100
/// Ring message asking which cluster a single point belongs to.
///
/// The point uses the same `[i64; 8]` fixed-point shape as `pack_coordinates`;
/// presumably only the first `n_dims` entries are meaningful.
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 706)]
pub struct KMeansQueryRing {
    /// Identifier of this message.
    pub id: MessageId,
    /// Query point coordinates in fixed-point form (up to 8 dimensions).
    pub point: [i64; 8],
    /// Number of meaningful entries in `point`.
    pub n_dims: u8,
}
113
/// Reply message for a cluster query (presumably answering a [`KMeansQueryRing`]).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 707)]
pub struct KMeansQueryResponse {
    /// Presumably the `id` of the request being answered — confirm with the sender.
    pub request_id: u64,
    /// Index of the cluster the point was assigned to.
    pub cluster: u32,
    /// Distance to that cluster's centroid as a fixed-point integer (`_fp` suffix;
    /// presumably produced with `to_fixed_point`).
    pub distance_fp: i64,
}
126
/// Kernel-to-kernel message carrying one worker's partial sum for one cluster.
///
/// The coordinate sum is fixed-point encoded; in this module's tests it is built
/// with `pack_coordinates`, and only the first `n_dims` entries are meaningful.
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 760)]
pub struct K2KPartialCentroid {
    /// Identifier of this message.
    pub id: MessageId,
    /// Worker that produced this partial result.
    pub worker_id: u64,
    /// Iteration number the partial result belongs to.
    pub iteration: u64,
    /// Cluster the partial sum refers to.
    pub cluster_id: u32,
    /// Number of points contributing to `coord_sum_fp`.
    pub point_count: u32,
    /// Per-dimension coordinate sums in fixed-point form (up to 8 dimensions).
    pub coord_sum_fp: [i64; 8],
    /// Number of meaningful entries in `coord_sum_fp`.
    pub n_dims: u8,
}
154
/// Kernel-to-kernel message carrying an aggregated centroid for one cluster
/// (presumably combined from several [`K2KPartialCentroid`] messages).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 761)]
pub struct K2KCentroidAggregation {
    /// Presumably the `id` of the message being answered — confirm with the sender.
    pub request_id: u64,
    /// Cluster the aggregated centroid refers to.
    pub cluster_id: u32,
    /// Iteration number the aggregate belongs to.
    pub iteration: u64,
    /// New centroid coordinates in fixed-point form (up to 8 dimensions).
    pub new_centroid_fp: [i64; 8],
    /// Total number of points across all contributing workers.
    pub total_points: u32,
    /// Centroid shift as a fixed-point integer (`_fp` suffix; presumably produced
    /// with `to_fixed_point`).
    pub shift_fp: i64,
}
173
/// Kernel-to-kernel message reporting a worker's local statistics for one iteration.
///
/// The `_fp` fields are fixed-point values; this module's tests build them with
/// `to_fixed_point`.
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 762)]
pub struct K2KKMeansSync {
    /// Identifier of this message.
    pub id: MessageId,
    /// Worker reporting these statistics.
    pub worker_id: u64,
    /// Iteration number the statistics belong to.
    pub iteration: u64,
    /// Worker-local inertia (fixed-point).
    pub local_inertia_fp: i64,
    /// Number of points the worker processed.
    pub points_processed: u32,
    /// Largest centroid shift the worker observed (fixed-point).
    pub max_shift_fp: i64,
}
192
/// Reply message to worker synchronization (presumably answering
/// [`K2KKMeansSync`]) with globally combined statistics.
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 763)]
pub struct K2KKMeansSyncResponse {
    /// Presumably the `id` of the message being answered — confirm with the sender.
    pub request_id: u64,
    /// Iteration number the response belongs to.
    pub iteration: u64,
    /// Whether all workers have reported for this iteration.
    pub all_synced: bool,
    /// Global inertia as a fixed-point integer (`_fp` suffix).
    pub global_inertia_fp: i64,
    /// Largest centroid shift across workers as a fixed-point integer (`_fp` suffix).
    pub global_max_shift_fp: i64,
    /// Whether the clustering is considered converged.
    pub converged: bool,
}
211
/// Kernel-to-kernel message broadcasting the current centroid set to workers.
///
/// NOTE(review): as with `KMeansInitRing`, the layout of `centroids_packed` is not
/// visible here; presumably `k * n_dims` fixed-point values stored centroid by
/// centroid, which would require `k * n_dims <= 32` — confirm.
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 764)]
pub struct K2KCentroidBroadcast {
    /// Identifier of this message.
    pub id: MessageId,
    /// Iteration number the centroids belong to.
    pub iteration: u64,
    /// Number of clusters.
    pub k: u32,
    /// Number of dimensions per centroid.
    pub n_dims: u8,
    /// Centroids in fixed-point form (32 slots; layout assumed above).
    pub centroids_packed: [i64; 32],
}
228
/// Acknowledgement of a centroid broadcast (presumably answering
/// [`K2KCentroidBroadcast`]).
#[derive(Debug, Clone, Archive, Serialize, Deserialize, RingMessage)]
#[archive(check_bytes)]
#[message(type_id = 765)]
pub struct K2KCentroidBroadcastAck {
    /// Presumably the `id` of the broadcast being acknowledged — confirm with the sender.
    pub request_id: u64,
    /// Worker sending the acknowledgement.
    pub worker_id: u64,
    /// Iteration number of the acknowledged broadcast.
    pub iteration: u64,
    /// Whether the worker applied the broadcast centroids.
    pub applied: bool,
}
243
/// Encodes an `f64` as a fixed-point integer with 8 implied decimal places
/// (scale `100_000_000`), the representation used by the `*_fp` message fields.
///
/// The scaled value is rounded to the nearest integer (halfway cases away from
/// zero) instead of being truncated toward zero. Without rounding, a scaled
/// value like `0.999…` that arises from binary floating-point error would lose
/// a whole unit. With rounding, `to_fixed_point(from_fixed_point(fp)) == fp`
/// holds for fixed-point values well inside the `f64` integer range.
///
/// Non-finite and out-of-range inputs follow Rust's saturating float-to-int
/// `as` semantics:
/// - `NaN` becomes `0`.
/// - `+∞`, `-∞` and scaled magnitudes beyond `i64` clamp to `i64::MAX` or
///   `i64::MIN`.
#[inline]
pub fn to_fixed_point(value: f64) -> i64 {
    const SCALE: f64 = 100_000_000.0;
    (value * SCALE).round() as i64
}
253
/// Decodes a fixed-point integer with 8 implied decimal places back into an
/// `f64` by dividing by `100_000_000`.
///
/// Magnitudes above 2^53 cannot be represented exactly by the `i64` → `f64`
/// widening and lose precision.
#[inline]
pub fn from_fixed_point(fp: i64) -> f64 {
    const SCALE: f64 = 100_000_000.0;
    let widened = fp as f64;
    widened / SCALE
}
259
260pub fn pack_coordinates(coords: &[f64], output: &mut [i64; 8]) {
262 for (i, &c) in coords.iter().take(8).enumerate() {
263 output[i] = to_fixed_point(c);
264 }
265}
266
267pub fn unpack_coordinates(input: &[i64; 8], n_dims: usize) -> Vec<f64> {
269 input
270 .iter()
271 .take(n_dims)
272 .map(|&fp| from_fixed_point(fp))
273 .collect()
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// 2.5 survives a fixed-point encode/decode round trip.
    #[test]
    fn test_fixed_point_conversion() {
        let original = 2.5;
        let round_tripped = from_fixed_point(to_fixed_point(original));
        assert!((original - round_tripped).abs() < 1e-8);
    }

    /// Packing then unpacking three coordinates yields the same three values.
    #[test]
    fn test_pack_unpack_coordinates() {
        let coords = [1.5, 2.5, 3.5];
        let mut packed = [0i64; 8];
        pack_coordinates(&coords, &mut packed);

        let unpacked = unpack_coordinates(&packed, coords.len());
        assert_eq!(unpacked.len(), 3);
        for (expected, actual) in coords.iter().zip(&unpacked) {
            assert!((expected - actual).abs() < 1e-7);
        }
    }

    /// An init message can be built and keeps its cluster count.
    #[test]
    fn test_kmeans_init_ring() {
        let centroids_packed = [0; 32];
        let msg = KMeansInitRing {
            id: MessageId(1),
            k: 3,
            n_features: 2,
            centroids_packed,
        };
        assert_eq!(msg.k, 3);
    }

    /// A partial-centroid message can carry packed coordinate sums.
    #[test]
    fn test_k2k_partial_centroid() {
        let mut coord_sum_fp = [0i64; 8];
        pack_coordinates(&[10.0, 20.0], &mut coord_sum_fp);

        let msg = K2KPartialCentroid {
            id: MessageId(2),
            worker_id: 1,
            iteration: 5,
            cluster_id: 0,
            point_count: 100,
            coord_sum_fp,
            n_dims: 2,
        };
        assert_eq!(msg.point_count, 100);
        assert_eq!(msg.iteration, 5);
    }

    /// A sync message can carry fixed-point statistics.
    #[test]
    fn test_k2k_kmeans_sync() {
        let local_inertia_fp = to_fixed_point(1234.5);
        let max_shift_fp = to_fixed_point(0.001);

        let msg = K2KKMeansSync {
            id: MessageId(3),
            worker_id: 2,
            iteration: 10,
            local_inertia_fp,
            points_processed: 5000,
            max_shift_fp,
        };
        assert_eq!(msg.iteration, 10);
        assert_eq!(msg.points_processed, 5000);
    }
}