1use serde::{Deserialize, Serialize};
6
/// Key of a tree entry, held either as text or as a sequence of `u32` tokens.
///
/// Values of this type are serialized through serde (bincode and JSON are both
/// used in this file), so the variant order is part of the encoded form.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum TreeKey {
    /// Key expressed as a string.
    Text(String),
    /// Key expressed as a list of `u32` values (presumably token ids -- confirm with the producer).
    Tokens(Vec<u32>),
}
12
/// Payload of [`TreeOperation::Insert`]: a key together with the tenant it is recorded for.
///
/// `TreeState::from_snapshot` fills `key` with a node's full prefix and
/// `tenant` with the tenant URL found on that node.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct TreeInsertOp {
    /// Key being inserted.
    pub key: TreeKey,
    /// Tenant the key is inserted for (a worker URL in this file's usages).
    pub tenant: String,
}
19
/// Payload of [`TreeOperation::Remove`].
///
/// Carries only a tenant and no key.
/// NOTE(review): presumably this removes the tenant from the whole tree --
/// confirm with the code that applies the operation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct TreeRemoveOp {
    /// Tenant being removed (a worker URL in this file's usages).
    pub tenant: String,
}
25
/// A single entry of a tree operation log: either an insert or a remove.
///
/// Serialized through serde, so the variant order is part of the encoded form.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum TreeOperation {
    /// Insert a key for a tenant.
    Insert(TreeInsertOp),
    /// Remove a tenant.
    Remove(TreeRemoveOp),
}
32
/// A batch of tree operations for one model, tagged with two version numbers.
///
/// NOTE(review): presumably `operations` are the changes that take a
/// `TreeState` from `base_version` to `new_version` -- confirm with the
/// producer/consumer of this type; nothing in this file applies a delta.
///
/// Field order is part of the bincode encoding used by `to_bytes`/`from_bytes`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct TreeStateDelta {
    /// Identifier of the model the operations belong to.
    pub model_id: String,
    /// Operations carried by this delta, in order.
    pub operations: Vec<TreeOperation>,
    /// Version the delta starts from (see the review note above).
    pub base_version: u64,
    /// Version reached after the delta (see the review note above).
    pub new_version: u64,
}
44
45impl TreeStateDelta {
46 pub fn to_bytes(&self) -> Result<Vec<u8>, String> {
48 bincode::serialize(self).map_err(|e| format!("Failed to serialize TreeStateDelta: {e}"))
49 }
50
51 pub fn from_bytes(bytes: &[u8]) -> Result<Self, String> {
53 bincode::deserialize(bytes)
54 .map_err(|e| format!("Failed to deserialize TreeStateDelta: {e}"))
55 }
56}
57
/// Tenant-level changes for one model: a list of inserts and a list of
/// evictions, each keyed by a 64-bit node-path hash rather than the full path.
///
/// Field order is part of the bincode encoding used by `to_bytes`/`from_bytes`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct TenantDelta {
    /// Identifier of the model the changes belong to.
    pub model_id: String,
    /// Version number attached to this delta.
    pub version: u64,
    /// Tenant insertions carried by this delta.
    pub inserts: Vec<TenantInsert>,
    /// Tenant evictions carried by this delta.
    pub evictions: Vec<TenantEvict>,
}
70
/// One tenant insertion inside a [`TenantDelta`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct TenantInsert {
    /// Hash of the node path (the tests in this file build it with `hash_node_path`).
    pub node_path_hash: u64,
    /// URL of the worker being inserted at that node.
    pub worker_url: String,
    /// Epoch value recorded with the insertion (meaning defined by the producer -- confirm).
    pub epoch: u64,
}
83
/// Reserved `node_path_hash` value.
///
/// `hash_node_path` and `hash_token_path` never return it: a digest that would
/// map to this value is remapped to `1`.
/// NOTE(review): presumably a `TenantEvict` carrying this hash means "evict the
/// worker everywhere" -- confirm with the consumer of `TenantEvict`.
pub const GLOBAL_EVICTION_HASH: u64 = 0;
87
/// One tenant eviction inside a [`TenantDelta`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct TenantEvict {
    /// Hash of the node path the eviction refers to.
    /// NOTE(review): `GLOBAL_EVICTION_HASH` is presumably accepted here as a
    /// wildcard -- confirm with the consumer.
    pub node_path_hash: u64,
    /// URL of the worker being evicted.
    pub worker_url: String,
}
97
98#[expect(
101 clippy::unwrap_used,
102 reason = "blake3 always returns 32 bytes; [..8] into [u8; 8] cannot fail"
103)]
104pub fn hash_node_path(path: &str) -> u64 {
105 let hash = blake3::hash(path.as_bytes());
106 let h = u64::from_le_bytes(hash.as_bytes()[..8].try_into().unwrap());
107 if h == GLOBAL_EVICTION_HASH {
108 1
109 } else {
110 h
111 }
112}
113
114#[expect(
117 clippy::unwrap_used,
118 reason = "blake3 always returns 32 bytes; [..8] into [u8; 8] cannot fail"
119)]
120pub fn hash_token_path(tokens: &[u32]) -> u64 {
121 let bytes: Vec<u8> = tokens.iter().flat_map(|t| t.to_le_bytes()).collect();
122 let hash = blake3::hash(&bytes);
123 let h = u64::from_le_bytes(hash.as_bytes()[..8].try_into().unwrap());
124 if h == GLOBAL_EVICTION_HASH {
125 1
126 } else {
127 h
128 }
129}
130
131impl TenantDelta {
132 pub fn new(model_id: String, version: u64) -> Self {
133 Self {
134 model_id,
135 version,
136 inserts: Vec::new(),
137 evictions: Vec::new(),
138 }
139 }
140
141 pub fn is_empty(&self) -> bool {
142 self.inserts.is_empty() && self.evictions.is_empty()
143 }
144
145 pub fn to_bytes(&self) -> Result<Vec<u8>, String> {
146 bincode::serialize(self).map_err(|e| format!("Failed to serialize TenantDelta: {e}"))
147 }
148
149 pub fn from_bytes(bytes: &[u8]) -> Result<Self, String> {
150 bincode::deserialize(bytes).map_err(|e| format!("Failed to deserialize TenantDelta: {e}"))
151 }
152}
153
/// Compresses `data` with LZ4 via `lz4_flex::compress_prepend_size`.
///
/// The output is meant to be read back with [`lz4_decompress`], which inspects
/// the leading four bytes as the claimed decompressed size before decoding.
pub fn lz4_compress(data: &[u8]) -> Vec<u8> {
    lz4_flex::compress_prepend_size(data)
}
161
162pub fn lz4_decompress(data: &[u8]) -> Result<Vec<u8>, String> {
166 const MAX_DECOMPRESSED_SIZE: usize = 256 * 1024 * 1024; if data.len() >= 4 {
168 let claimed_size = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
169 if claimed_size > MAX_DECOMPRESSED_SIZE {
170 return Err(format!(
171 "LZ4 claimed decompressed size {claimed_size} exceeds limit {MAX_DECOMPRESSED_SIZE}"
172 ));
173 }
174 }
175 lz4_flex::decompress_size_prepended(data).map_err(|e| format!("LZ4 decompression failed: {e}"))
176}
177
/// Upper bound on the length of `TreeState::operations`.
///
/// Once `TreeState::add_operation` pushes the log past this length, the oldest
/// entries are dropped until only half this many remain.
const MAX_TREE_OPERATIONS: usize = 2048;
183
/// Operation log for one model's tree, with a version counter.
///
/// `Default` yields an empty `model_id`, no operations and version 0.
/// Field order is part of the bincode encoding used by `to_bytes`/`from_bytes`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
pub struct TreeState {
    /// Identifier of the model this state belongs to.
    pub model_id: String,
    /// Logged operations, oldest first; bounded by `add_operation`.
    pub operations: Vec<TreeOperation>,
    /// Incremented by one on every `add_operation`; overwritten by `from_snapshot`.
    pub version: u64,
}
192
193impl TreeState {
194 pub fn new(model_id: String) -> Self {
195 Self {
196 model_id,
197 operations: Vec::new(),
198 version: 0,
199 }
200 }
201
202 pub fn add_operation(&mut self, operation: TreeOperation) {
203 self.operations.push(operation);
204 self.version += 1;
205 if self.operations.len() > MAX_TREE_OPERATIONS {
206 let drain_count = self.operations.len() - MAX_TREE_OPERATIONS / 2;
208 self.operations.drain(..drain_count);
209 }
210 }
211
212 pub fn to_bytes(&self) -> Result<Vec<u8>, String> {
215 bincode::serialize(self).map_err(|e| format!("Failed to serialize TreeState: {e}"))
216 }
217
218 pub fn from_bytes(bytes: &[u8]) -> Result<Self, String> {
220 bincode::deserialize(bytes).map_err(|e| format!("Failed to deserialize TreeState: {e}"))
221 }
222
223 #[expect(
231 clippy::unwrap_used,
232 reason = "pop() after last_mut().is_some() is infallible"
233 )]
234 pub fn from_snapshot(
235 model_id: String,
236 snapshot: &kv_index::snapshot::TreeSnapshot,
237 version: u64,
238 ) -> Self {
239 let mut tree_state = Self::new(model_id);
240 let mut path_stack: Vec<(String, u32)> = Vec::new();
241 let mut current_prefix = String::new();
242
243 for node in &snapshot.nodes {
244 while let Some((_, remaining)) = path_stack.last_mut() {
246 if *remaining == 0 {
247 let (parent_prefix, _) = path_stack.pop().unwrap();
248 current_prefix = parent_prefix;
249 } else {
250 *remaining -= 1;
251 break;
252 }
253 }
254
255 let node_prefix = format!("{}{}", current_prefix, node.edge);
257
258 for (tenant_url, _epoch) in &node.tenants {
260 if !node_prefix.is_empty() {
261 tree_state.add_operation(TreeOperation::Insert(TreeInsertOp {
262 key: TreeKey::Text(node_prefix.clone()),
263 tenant: tenant_url.clone(),
264 }));
265 }
266 }
267
268 if node.child_count > 0 {
270 path_stack.push((current_prefix.clone(), node.child_count));
271 current_prefix = node_prefix;
272 }
273 }
274
275 tree_state.version = version;
276 tree_state
277 }
278}
279
// Unit tests for the operation types, `TreeState`, and `TenantDelta`.
#[cfg(test)]
mod tests {
    use super::*;

    // A `TreeInsertOp` exposes the key and tenant it was constructed with.
    #[test]
    fn test_tree_insert_op_creation() {
        let op = TreeInsertOp {
            key: TreeKey::Text("test_text".to_string()),
            tenant: "http://worker1:8000".to_string(),
        };
        assert_eq!(op.key, TreeKey::Text("test_text".to_string()));
        assert_eq!(op.tenant, "http://worker1:8000");
    }

    // A `TreeRemoveOp` exposes the tenant it was constructed with.
    #[test]
    fn test_tree_remove_op_creation() {
        let op = TreeRemoveOp {
            tenant: "http://worker1:8000".to_string(),
        };
        assert_eq!(op.tenant, "http://worker1:8000");
    }

    // Wrapping an insert payload yields the `Insert` variant with the payload intact.
    #[test]
    fn test_tree_operation_insert() {
        let insert_op = TreeInsertOp {
            key: TreeKey::Text("test_text".to_string()),
            tenant: "http://worker1:8000".to_string(),
        };
        let operation = TreeOperation::Insert(insert_op.clone());

        match &operation {
            TreeOperation::Insert(op) => {
                assert_eq!(op.key, TreeKey::Text("test_text".to_string()));
                assert_eq!(op.tenant, "http://worker1:8000");
            }
            TreeOperation::Remove(_) => panic!("Expected Insert operation"),
        }
    }

    // Wrapping a remove payload yields the `Remove` variant with the payload intact.
    #[test]
    fn test_tree_operation_remove() {
        let remove_op = TreeRemoveOp {
            tenant: "http://worker1:8000".to_string(),
        };
        let operation = TreeOperation::Remove(remove_op.clone());

        match &operation {
            TreeOperation::Insert(_) => panic!("Expected Remove operation"),
            TreeOperation::Remove(op) => {
                assert_eq!(op.tenant, "http://worker1:8000");
            }
        }
    }

    // An `Insert` with a text key survives a JSON round trip.
    #[test]
    fn test_tree_operation_serialization() {
        let insert_op = TreeInsertOp {
            key: TreeKey::Text("test_text".to_string()),
            tenant: "http://worker1:8000".to_string(),
        };
        let operation = TreeOperation::Insert(insert_op);

        let serialized = serde_json::to_string(&operation).unwrap();
        let deserialized: TreeOperation = serde_json::from_str(&serialized).unwrap();

        match (&operation, &deserialized) {
            (TreeOperation::Insert(a), TreeOperation::Insert(b)) => {
                assert_eq!(a.key, b.key);
                assert_eq!(a.tenant, b.tenant);
            }
            _ => panic!("Operations should match"),
        }
    }

    // An `Insert` with a token key survives a JSON round trip.
    #[test]
    fn test_tree_operation_token_serialization() {
        let insert_op = TreeInsertOp {
            key: TreeKey::Tokens(vec![1, 2, 3, 4]),
            tenant: "http://worker1:8000".to_string(),
        };
        let operation = TreeOperation::Insert(insert_op);

        let serialized = serde_json::to_string(&operation).unwrap();
        let deserialized: TreeOperation = serde_json::from_str(&serialized).unwrap();

        match (&operation, &deserialized) {
            (TreeOperation::Insert(a), TreeOperation::Insert(b)) => {
                assert_eq!(a.key, b.key);
                assert_eq!(a.tenant, b.tenant);
            }
            _ => panic!("Operations should match"),
        }
    }

    // A state mixing token inserts, text inserts and removes survives a
    // bincode round trip, including boundary token values (0 and `u32::MAX`).
    #[test]
    fn test_tree_state_bincode_round_trip_with_tokens() {
        let tokens = vec![12345u32, 67890, 0, u32::MAX, 42];
        let mut state = TreeState::new("test-model".to_string());
        state.add_operation(TreeOperation::Insert(TreeInsertOp {
            key: TreeKey::Tokens(tokens.clone()),
            tenant: "http://worker1:8000".to_string(),
        }));
        state.add_operation(TreeOperation::Insert(TreeInsertOp {
            key: TreeKey::Text("text_key".to_string()),
            tenant: "http://worker2:8000".to_string(),
        }));
        state.add_operation(TreeOperation::Remove(TreeRemoveOp {
            tenant: "http://worker3:8000".to_string(),
        }));

        let bytes = state.to_bytes().unwrap();
        let restored = TreeState::from_bytes(&bytes).unwrap();

        assert_eq!(restored.model_id, "test-model");
        assert_eq!(restored.version, state.version);
        assert_eq!(restored.operations.len(), 3);

        match &restored.operations[0] {
            TreeOperation::Insert(op) => {
                assert_eq!(op.key, TreeKey::Tokens(tokens));
                assert_eq!(op.tenant, "http://worker1:8000");
            }
            TreeOperation::Remove(_) => panic!("Expected Insert"),
        }
        match &restored.operations[1] {
            TreeOperation::Insert(op) => {
                assert_eq!(op.key, TreeKey::Text("text_key".to_string()));
            }
            TreeOperation::Remove(_) => panic!("Expected Insert"),
        }
        match &restored.operations[2] {
            TreeOperation::Remove(op) => {
                assert_eq!(op.tenant, "http://worker3:8000");
            }
            TreeOperation::Insert(_) => panic!("Expected Remove"),
        }
    }

    // 100 inserts of 1000 tokens each survive a bincode round trip; the first
    // and last operations are spot-checked at both ends of their token lists.
    // 100 operations stays below `MAX_TREE_OPERATIONS`, so nothing is trimmed.
    #[test]
    fn test_tree_state_bincode_round_trip_large_tokens() {
        let mut state = TreeState::new("large-model".to_string());
        for i in 0..100 {
            let tokens: Vec<u32> = (0..1000).map(|j| (i * 1000 + j) as u32).collect();
            state.add_operation(TreeOperation::Insert(TreeInsertOp {
                key: TreeKey::Tokens(tokens),
                tenant: format!("http://worker-{i}:8000"),
            }));
        }

        let bytes = state.to_bytes().unwrap();
        let restored = TreeState::from_bytes(&bytes).unwrap();

        assert_eq!(restored.operations.len(), 100);
        assert_eq!(restored.version, state.version);

        match &restored.operations[0] {
            TreeOperation::Insert(op) => {
                if let TreeKey::Tokens(tokens) = &op.key {
                    assert_eq!(tokens.len(), 1000);
                    assert_eq!(tokens[0], 0);
                    assert_eq!(tokens[999], 999);
                } else {
                    panic!("Expected Tokens key");
                }
            }
            TreeOperation::Remove(_) => panic!("Expected Insert"),
        }
        match &restored.operations[99] {
            TreeOperation::Insert(op) => {
                if let TreeKey::Tokens(tokens) = &op.key {
                    assert_eq!(tokens[0], 99000);
                    assert_eq!(tokens[999], 99999);
                } else {
                    panic!("Expected Tokens key");
                }
            }
            TreeOperation::Remove(_) => panic!("Expected Insert"),
        }
    }

    // A `Remove` survives a JSON round trip.
    #[test]
    fn test_tree_operation_remove_serialization() {
        let remove_op = TreeRemoveOp {
            tenant: "http://worker1:8000".to_string(),
        };
        let operation = TreeOperation::Remove(remove_op);

        let serialized = serde_json::to_string(&operation).unwrap();
        let deserialized: TreeOperation = serde_json::from_str(&serialized).unwrap();

        match (&operation, &deserialized) {
            (TreeOperation::Remove(a), TreeOperation::Remove(b)) => {
                assert_eq!(a.tenant, b.tenant);
            }
            _ => panic!("Operations should match"),
        }
    }

    // `TreeState::new` starts with the given model id, no operations, version 0.
    #[test]
    fn test_tree_state_new() {
        let state = TreeState::new("model1".to_string());
        assert_eq!(state.model_id, "model1");
        assert_eq!(state.operations.len(), 0);
        assert_eq!(state.version, 0);
    }

    // `TreeState::default` has an empty model id, no operations, version 0.
    #[test]
    fn test_tree_state_default() {
        let state = TreeState::default();
        assert_eq!(state.model_id, "");
        assert_eq!(state.operations.len(), 0);
        assert_eq!(state.version, 0);
    }

    // Each `add_operation` call grows the log by one and bumps the version by one.
    #[test]
    fn test_tree_state_add_operation() {
        let mut state = TreeState::new("model1".to_string());

        let insert_op = TreeInsertOp {
            key: TreeKey::Text("text1".to_string()),
            tenant: "http://worker1:8000".to_string(),
        };
        state.add_operation(TreeOperation::Insert(insert_op));

        assert_eq!(state.operations.len(), 1);
        assert_eq!(state.version, 1);

        let remove_op = TreeRemoveOp {
            tenant: "http://worker1:8000".to_string(),
        };
        state.add_operation(TreeOperation::Remove(remove_op));

        assert_eq!(state.operations.len(), 2);
        assert_eq!(state.version, 2);
    }

    // Five consecutive inserts give five operations and version 5.
    #[test]
    fn test_tree_state_add_multiple_operations() {
        let mut state = TreeState::new("model1".to_string());

        for i in 0..5 {
            let insert_op = TreeInsertOp {
                key: TreeKey::Text(format!("text_{i}")),
                tenant: format!("http://worker{i}:8000"),
            };
            state.add_operation(TreeOperation::Insert(insert_op));
        }

        assert_eq!(state.operations.len(), 5);
        assert_eq!(state.version, 5);
    }

    // A state with one insert and one remove keeps its model id, operation
    // count and version across a JSON round trip.
    #[test]
    fn test_tree_state_serialization() {
        let mut state = TreeState::new("model1".to_string());

        let insert_op = TreeInsertOp {
            key: TreeKey::Text("test_text".to_string()),
            tenant: "http://worker1:8000".to_string(),
        };
        state.add_operation(TreeOperation::Insert(insert_op));

        let remove_op = TreeRemoveOp {
            tenant: "http://worker1:8000".to_string(),
        };
        state.add_operation(TreeOperation::Remove(remove_op));

        let serialized = serde_json::to_string(&state).unwrap();
        let deserialized: TreeState = serde_json::from_str(&serialized).unwrap();

        assert_eq!(state.model_id, deserialized.model_id);
        assert_eq!(state.operations.len(), deserialized.operations.len());
        assert_eq!(state.version, deserialized.version);
    }

    // Cloning a state preserves its model id, operation count and version.
    #[test]
    fn test_tree_state_clone() {
        let mut state = TreeState::new("model1".to_string());

        let insert_op = TreeInsertOp {
            key: TreeKey::Text("test_text".to_string()),
            tenant: "http://worker1:8000".to_string(),
        };
        state.add_operation(TreeOperation::Insert(insert_op));

        let cloned = state.clone();
        assert_eq!(state.model_id, cloned.model_id);
        assert_eq!(state.operations.len(), cloned.operations.len());
        assert_eq!(state.version, cloned.version);
    }

    // Two states built from the same model id and the same operation compare equal.
    #[test]
    fn test_tree_state_equality() {
        let mut state1 = TreeState::new("model1".to_string());
        let mut state2 = TreeState::new("model1".to_string());

        let insert_op = TreeInsertOp {
            key: TreeKey::Text("test_text".to_string()),
            tenant: "http://worker1:8000".to_string(),
        };
        state1.add_operation(TreeOperation::Insert(insert_op.clone()));
        state2.add_operation(TreeOperation::Insert(insert_op));

        assert_eq!(state1, state2);
    }

    // Equal operations hash identically: inserting both into a `HashSet`
    // leaves a single element.
    #[test]
    fn test_tree_operation_hash() {
        use std::collections::HashSet;

        let insert_op1 = TreeInsertOp {
            key: TreeKey::Text("text1".to_string()),
            tenant: "http://worker1:8000".to_string(),
        };
        let insert_op2 = TreeInsertOp {
            key: TreeKey::Text("text1".to_string()),
            tenant: "http://worker1:8000".to_string(),
        };

        let op1 = TreeOperation::Insert(insert_op1);
        let op2 = TreeOperation::Insert(insert_op2);

        let mut set = HashSet::new();
        set.insert(op1.clone());
        set.insert(op2.clone());

        assert_eq!(set.len(), 1);
    }

    // A delta with one insert and one eviction is non-empty and survives a
    // bincode round trip with every checked field intact.
    #[test]
    fn test_tenant_delta_round_trip() {
        let path_hash = hash_node_path("Hello world, how are");
        let mut delta = TenantDelta::new("model1".to_string(), 42);
        delta.inserts.push(TenantInsert {
            node_path_hash: path_hash,
            worker_url: "grpc://w1:8000".to_string(),
            epoch: 1000,
        });
        delta.evictions.push(TenantEvict {
            node_path_hash: path_hash,
            worker_url: "grpc://w2:8000".to_string(),
        });

        assert!(!delta.is_empty());

        let bytes = delta.to_bytes().unwrap();
        let restored = TenantDelta::from_bytes(&bytes).unwrap();

        assert_eq!(restored.model_id, "model1");
        assert_eq!(restored.version, 42);
        assert_eq!(restored.inserts.len(), 1);
        assert_eq!(restored.inserts[0].worker_url, "grpc://w1:8000");
        assert_eq!(restored.inserts[0].node_path_hash, path_hash);
        assert_eq!(restored.inserts[0].epoch, 1000);
        assert_eq!(restored.evictions.len(), 1);
        assert_eq!(restored.evictions[0].worker_url, "grpc://w2:8000");
    }

    // A freshly constructed delta reports itself as empty.
    #[test]
    fn test_tenant_delta_empty() {
        let delta = TenantDelta::new("model1".to_string(), 0);
        assert!(delta.is_empty());
    }

    // Size comparison: a serialized `TenantDelta` with one hashed insert must
    // be less than a tenth of a serialized `TreeState` holding one insert
    // whose text key is 20,000 characters long.
    #[test]
    fn test_tenant_delta_size_vs_tree_operation() {
        let insert = TenantInsert {
            node_path_hash: hash_node_path(&"a".repeat(100)),
            worker_url: "grpc://worker1:8000".to_string(),
            epoch: 12345,
        };
        let delta = TenantDelta {
            model_id: "model1".to_string(),
            version: 1,
            inserts: vec![insert],
            evictions: vec![],
        };
        let delta_bytes = delta.to_bytes().unwrap();

        let tree_op = TreeOperation::Insert(TreeInsertOp {
            key: TreeKey::Text("x".repeat(20_000)),
            tenant: "grpc://worker1:8000".to_string(),
        });
        let tree_state = TreeState {
            model_id: "model1".to_string(),
            operations: vec![tree_op],
            version: 1,
        };
        let tree_bytes = tree_state.to_bytes().unwrap();

        assert!(
            delta_bytes.len() < tree_bytes.len() / 10,
            "TenantDelta ({} bytes) should be much smaller than TreeState ({} bytes)",
            delta_bytes.len(),
            tree_bytes.len()
        );
    }
}