1use crate::{GraphError, Result};
9use blake3::Hasher;
10use chrono::{DateTime, Utc};
11use dashmap::DashMap;
12use serde::{Deserialize, Serialize};
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15use tracing::{debug, info, warn};
16use uuid::Uuid;
17use xxhash_rust::xxh3::xxh3_64;
18
/// Identifier of a graph node.
pub type NodeId = String;

/// Identifier of a graph edge.
pub type EdgeId = String;

/// Numeric identifier of a shard.
pub type ShardId = u32;
27
28#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
30pub enum ShardStrategy {
31 Hash,
33 Range,
35 EdgeCut,
37 Custom,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct ShardMetadata {
44 pub shard_id: ShardId,
46 pub node_count: usize,
48 pub edge_count: usize,
50 pub cross_shard_edges: usize,
52 pub primary_node: String,
54 pub replicas: Vec<String>,
56 pub created_at: DateTime<Utc>,
58 pub modified_at: DateTime<Utc>,
60 pub strategy: ShardStrategy,
62}
63
64impl ShardMetadata {
65 pub fn new(shard_id: ShardId, primary_node: String, strategy: ShardStrategy) -> Self {
67 Self {
68 shard_id,
69 node_count: 0,
70 edge_count: 0,
71 cross_shard_edges: 0,
72 primary_node,
73 replicas: Vec::new(),
74 created_at: Utc::now(),
75 modified_at: Utc::now(),
76 strategy,
77 }
78 }
79
80 pub fn edge_cut_ratio(&self) -> f64 {
82 if self.edge_count == 0 {
83 0.0
84 } else {
85 self.cross_shard_edges as f64 / self.edge_count as f64
86 }
87 }
88}
89
90pub struct HashPartitioner {
92 shard_count: u32,
94 virtual_nodes: u32,
96}
97
98impl HashPartitioner {
99 pub fn new(shard_count: u32) -> Self {
101 assert!(shard_count > 0, "shard_count must be greater than zero");
102 Self {
103 shard_count,
104 virtual_nodes: 150, }
106 }
107
108 pub fn get_shard(&self, node_id: &NodeId) -> ShardId {
110 let hash = xxh3_64(node_id.as_bytes());
111 (hash % self.shard_count as u64) as ShardId
112 }
113
114 pub fn get_shard_secure(&self, node_id: &NodeId) -> ShardId {
116 let mut hasher = Hasher::new();
117 hasher.update(node_id.as_bytes());
118 let hash = hasher.finalize();
119 let hash_bytes = hash.as_bytes();
120 let hash_u64 = u64::from_le_bytes([
121 hash_bytes[0],
122 hash_bytes[1],
123 hash_bytes[2],
124 hash_bytes[3],
125 hash_bytes[4],
126 hash_bytes[5],
127 hash_bytes[6],
128 hash_bytes[7],
129 ]);
130 (hash_u64 % self.shard_count as u64) as ShardId
131 }
132
133 pub fn get_replica_shards(&self, node_id: &NodeId, replica_count: usize) -> Vec<ShardId> {
135 let mut shards = Vec::with_capacity(replica_count);
136 let primary = self.get_shard(node_id);
137 shards.push(primary);
138
139 for i in 1..replica_count {
141 let salted_id = format!("{}-replica-{}", node_id, i);
142 let shard = self.get_shard(&salted_id);
143 if !shards.contains(&shard) {
144 shards.push(shard);
145 }
146 }
147
148 shards
149 }
150}
151
152pub struct RangePartitioner {
154 shard_count: u32,
156 ranges: Vec<String>,
158}
159
160impl RangePartitioner {
161 pub fn new(shard_count: u32) -> Self {
163 Self {
164 shard_count,
165 ranges: Vec::new(),
166 }
167 }
168
169 pub fn with_boundaries(boundaries: Vec<String>) -> Self {
171 Self {
172 shard_count: boundaries.len() as u32,
173 ranges: boundaries,
174 }
175 }
176
177 pub fn get_shard(&self, node_id: &NodeId) -> ShardId {
179 if self.ranges.is_empty() {
180 let hash = xxh3_64(node_id.as_bytes());
182 return (hash % self.shard_count as u64) as ShardId;
183 }
184
185 for (idx, boundary) in self.ranges.iter().enumerate() {
187 if node_id <= boundary {
188 return idx as ShardId;
189 }
190 }
191
192 (self.shard_count - 1) as ShardId
194 }
195
196 pub fn update_boundaries(&mut self, new_boundaries: Vec<String>) {
198 info!(
199 "Updating range boundaries: old={}, new={}",
200 self.ranges.len(),
201 new_boundaries.len()
202 );
203 self.ranges = new_boundaries;
204 self.shard_count = self.ranges.len() as u32;
205 }
206}
207
208pub struct EdgeCutMinimizer {
210 shard_count: u32,
212 node_assignments: Arc<DashMap<NodeId, ShardId>>,
214 edge_weights: Arc<DashMap<(NodeId, NodeId), f64>>,
216 adjacency: Arc<DashMap<NodeId, HashSet<NodeId>>>,
218}
219
220impl EdgeCutMinimizer {
221 pub fn new(shard_count: u32) -> Self {
223 Self {
224 shard_count,
225 node_assignments: Arc::new(DashMap::new()),
226 edge_weights: Arc::new(DashMap::new()),
227 adjacency: Arc::new(DashMap::new()),
228 }
229 }
230
231 pub fn add_edge(&self, from: NodeId, to: NodeId, weight: f64) {
233 self.edge_weights.insert((from.clone(), to.clone()), weight);
234
235 self.adjacency
237 .entry(from.clone())
238 .or_insert_with(HashSet::new)
239 .insert(to.clone());
240
241 self.adjacency
242 .entry(to)
243 .or_insert_with(HashSet::new)
244 .insert(from);
245 }
246
247 pub fn get_shard(&self, node_id: &NodeId) -> Option<ShardId> {
249 self.node_assignments.get(node_id).map(|r| *r.value())
250 }
251
252 pub fn compute_partitioning(&self) -> Result<HashMap<NodeId, ShardId>> {
254 info!("Computing edge-cut minimized partitioning");
255
256 let nodes: Vec<_> = self.adjacency.iter().map(|e| e.key().clone()).collect();
257
258 if nodes.is_empty() {
259 return Ok(HashMap::new());
260 }
261
262 let coarse_graph = self.coarsen_graph(&nodes);
264
265 let mut assignments = self.initial_partition(&coarse_graph);
267
268 self.refine_partition(&mut assignments);
270
271 for (node, shard) in &assignments {
273 self.node_assignments.insert(node.clone(), *shard);
274 }
275
276 info!(
277 "Partitioning complete: {} nodes across {} shards",
278 assignments.len(),
279 self.shard_count
280 );
281
282 Ok(assignments)
283 }
284
285 fn coarsen_graph(&self, nodes: &[NodeId]) -> HashMap<NodeId, Vec<NodeId>> {
287 let mut coarse: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
288 let mut visited = HashSet::new();
289
290 for node in nodes {
291 if visited.contains(node) {
292 continue;
293 }
294
295 let mut group = vec![node.clone()];
296 visited.insert(node.clone());
297
298 if let Some(neighbors) = self.adjacency.get(node) {
300 let mut best_neighbor: Option<(NodeId, f64)> = None;
301
302 for neighbor in neighbors.iter() {
303 if visited.contains(neighbor) {
304 continue;
305 }
306
307 let weight = self
308 .edge_weights
309 .get(&(node.clone(), neighbor.clone()))
310 .map(|w| *w.value())
311 .unwrap_or(1.0);
312
313 if let Some((_, best_weight)) = best_neighbor {
314 if weight > best_weight {
315 best_neighbor = Some((neighbor.clone(), weight));
316 }
317 } else {
318 best_neighbor = Some((neighbor.clone(), weight));
319 }
320 }
321
322 if let Some((neighbor, _)) = best_neighbor {
323 group.push(neighbor.clone());
324 visited.insert(neighbor);
325 }
326 }
327
328 let representative = node.clone();
329 coarse.insert(representative, group);
330 }
331
332 coarse
333 }
334
335 fn initial_partition(
337 &self,
338 coarse_graph: &HashMap<NodeId, Vec<NodeId>>,
339 ) -> HashMap<NodeId, ShardId> {
340 let mut assignments = HashMap::new();
341 let mut shard_sizes: Vec<usize> = vec![0; self.shard_count as usize];
342
343 for (representative, group) in coarse_graph {
344 let shard = shard_sizes
346 .iter()
347 .enumerate()
348 .min_by_key(|(_, size)| *size)
349 .map(|(idx, _)| idx as ShardId)
350 .unwrap_or(0);
351
352 for node in group {
353 assignments.insert(node.clone(), shard);
354 shard_sizes[shard as usize] += 1;
355 }
356 }
357
358 assignments
359 }
360
361 fn refine_partition(&self, assignments: &mut HashMap<NodeId, ShardId>) {
363 const MAX_ITERATIONS: usize = 10;
364 let mut improved = true;
365 let mut iteration = 0;
366
367 while improved && iteration < MAX_ITERATIONS {
368 improved = false;
369 iteration += 1;
370
371 for (node, current_shard) in assignments.clone().iter() {
372 let current_cost = self.compute_node_cost(node, *current_shard, assignments);
373
374 for target_shard in 0..self.shard_count {
376 if target_shard == *current_shard {
377 continue;
378 }
379
380 let new_cost = self.compute_node_cost(node, target_shard, assignments);
381
382 if new_cost < current_cost {
383 assignments.insert(node.clone(), target_shard);
384 improved = true;
385 break;
386 }
387 }
388 }
389
390 debug!("Refinement iteration {}: improved={}", iteration, improved);
391 }
392 }
393
394 fn compute_node_cost(
396 &self,
397 node: &NodeId,
398 shard: ShardId,
399 assignments: &HashMap<NodeId, ShardId>,
400 ) -> usize {
401 let mut cross_shard_edges = 0;
402
403 if let Some(neighbors) = self.adjacency.get(node) {
404 for neighbor in neighbors.iter() {
405 if let Some(neighbor_shard) = assignments.get(neighbor) {
406 if *neighbor_shard != shard {
407 cross_shard_edges += 1;
408 }
409 }
410 }
411 }
412
413 cross_shard_edges
414 }
415
416 pub fn calculate_edge_cut(&self, assignments: &HashMap<NodeId, ShardId>) -> usize {
418 let mut cut = 0;
419
420 for entry in self.edge_weights.iter() {
421 let ((from, to), _) = entry.pair();
422 let from_shard = assignments.get(from);
423 let to_shard = assignments.get(to);
424
425 if from_shard.is_some() && to_shard.is_some() && from_shard != to_shard {
426 cut += 1;
427 }
428 }
429
430 cut
431 }
432}
433
434pub struct GraphShard {
436 metadata: ShardMetadata,
438 nodes: Arc<DashMap<NodeId, NodeData>>,
440 edges: Arc<DashMap<EdgeId, EdgeData>>,
442 strategy: ShardStrategy,
444}
445
446#[derive(Debug, Clone, Serialize, Deserialize)]
448pub struct NodeData {
449 pub id: NodeId,
450 pub properties: HashMap<String, serde_json::Value>,
451 pub labels: Vec<String>,
452}
453
454#[derive(Debug, Clone, Serialize, Deserialize)]
456pub struct EdgeData {
457 pub id: EdgeId,
458 pub from: NodeId,
459 pub to: NodeId,
460 pub edge_type: String,
461 pub properties: HashMap<String, serde_json::Value>,
462}
463
464impl GraphShard {
465 pub fn new(metadata: ShardMetadata) -> Self {
467 let strategy = metadata.strategy;
468 Self {
469 metadata,
470 nodes: Arc::new(DashMap::new()),
471 edges: Arc::new(DashMap::new()),
472 strategy,
473 }
474 }
475
476 pub fn add_node(&self, node: NodeData) -> Result<()> {
478 self.nodes.insert(node.id.clone(), node);
479 Ok(())
480 }
481
482 pub fn add_edge(&self, edge: EdgeData) -> Result<()> {
484 self.edges.insert(edge.id.clone(), edge);
485 Ok(())
486 }
487
488 pub fn get_node(&self, node_id: &NodeId) -> Option<NodeData> {
490 self.nodes.get(node_id).map(|n| n.value().clone())
491 }
492
493 pub fn get_edge(&self, edge_id: &EdgeId) -> Option<EdgeData> {
495 self.edges.get(edge_id).map(|e| e.value().clone())
496 }
497
498 pub fn metadata(&self) -> &ShardMetadata {
500 &self.metadata
501 }
502
503 pub fn node_count(&self) -> usize {
505 self.nodes.len()
506 }
507
508 pub fn edge_count(&self) -> usize {
510 self.edges.len()
511 }
512
513 pub fn list_nodes(&self) -> Vec<NodeData> {
515 self.nodes.iter().map(|e| e.value().clone()).collect()
516 }
517
518 pub fn list_edges(&self) -> Vec<EdgeData> {
520 self.edges.iter().map(|e| e.value().clone()).collect()
521 }
522}
523
#[cfg(test)]
mod tests {
    use super::*;

    /// Hash placement stays in range, is repeatable, and replica lists start
    /// with the primary and contain no duplicates.
    #[test]
    fn test_hash_partitioner() {
        let partitioner = HashPartitioner::new(16);

        let node1 = "node-1".to_string();
        let node2 = "node-2".to_string();

        let shard1 = partitioner.get_shard(&node1);
        let shard2 = partitioner.get_shard(&node2);

        assert!(shard1 < 16);
        assert!(shard2 < 16);
        assert!(partitioner.get_shard_secure(&node1) < 16);

        // Same input, same shard.
        assert_eq!(shard1, partitioner.get_shard(&node1));
        assert_eq!(
            partitioner.get_shard_secure(&node1),
            partitioner.get_shard_secure(&node1)
        );

        let replicas = partitioner.get_replica_shards(&node1, 3);
        assert!(!replicas.is_empty());
        assert!(replicas.len() <= 3);
        assert_eq!(replicas[0], shard1);
        for (idx, shard) in replicas.iter().enumerate() {
            assert!(*shard < 16);
            assert!(!replicas[..idx].contains(shard));
        }
    }

    /// Range placement honours inclusive boundaries and sends ids past the
    /// last boundary to the final shard.
    #[test]
    fn test_range_partitioner() {
        let boundaries = vec!["m".to_string(), "z".to_string()];
        let mut partitioner = RangePartitioner::with_boundaries(boundaries);

        assert_eq!(partitioner.get_shard(&"apple".to_string()), 0);
        assert_eq!(partitioner.get_shard(&"m".to_string()), 0);
        assert_eq!(partitioner.get_shard(&"orange".to_string()), 1);
        assert_eq!(partitioner.get_shard(&"zebra".to_string()), 1);

        partitioner.update_boundaries(vec![
            "h".to_string(),
            "p".to_string(),
            "x".to_string(),
        ]);

        assert_eq!(partitioner.get_shard(&"apple".to_string()), 0);
        assert_eq!(partitioner.get_shard(&"orange".to_string()), 1);
        assert_eq!(partitioner.get_shard(&"zebra".to_string()), 2);
    }

    /// Partitioning a four-node path assigns every node to a valid shard and
    /// keeps the cut small.
    #[test]
    fn test_edge_cut_minimizer() {
        let minimizer = EdgeCutMinimizer::new(2);

        minimizer.add_edge("A".to_string(), "B".to_string(), 1.0);
        minimizer.add_edge("B".to_string(), "C".to_string(), 1.0);
        minimizer.add_edge("C".to_string(), "D".to_string(), 1.0);

        let assignments = minimizer.compute_partitioning().unwrap();
        let cut = minimizer.calculate_edge_cut(&assignments);

        assert_eq!(assignments.len(), 4);
        for (node, shard) in &assignments {
            assert!(*shard < 2);
            assert_eq!(minimizer.get_shard(node), Some(*shard));
        }
        assert!(minimizer.get_shard(&"missing".to_string()).is_none());

        assert!(cut <= 2);
    }

    /// Fresh metadata is empty, and the ratio follows the counters.
    #[test]
    fn test_shard_metadata() {
        let mut metadata = ShardMetadata::new(0, "node-1".to_string(), ShardStrategy::Hash);

        assert_eq!(metadata.shard_id, 0);
        assert_eq!(metadata.node_count, 0);
        assert!(metadata.replicas.is_empty());
        assert_eq!(metadata.strategy, ShardStrategy::Hash);
        assert_eq!(metadata.edge_cut_ratio(), 0.0);

        metadata.edge_count = 4;
        metadata.cross_shard_edges = 1;
        assert_eq!(metadata.edge_cut_ratio(), 0.25);
    }

    /// Nodes and edges can be stored, counted, listed and looked up.
    #[test]
    fn test_graph_shard() {
        let metadata = ShardMetadata::new(0, "node-1".to_string(), ShardStrategy::Hash);
        let shard = GraphShard::new(metadata);

        let node = NodeData {
            id: "test-node".to_string(),
            properties: HashMap::new(),
            labels: vec!["TestLabel".to_string()],
        };
        shard.add_node(node).unwrap();

        assert_eq!(shard.node_count(), 1);
        assert_eq!(shard.list_nodes().len(), 1);
        assert!(shard.get_node(&"test-node".to_string()).is_some());
        assert!(shard.get_node(&"absent".to_string()).is_none());

        let edge = EdgeData {
            id: "test-edge".to_string(),
            from: "test-node".to_string(),
            to: "test-node".to_string(),
            edge_type: "SELF".to_string(),
            properties: HashMap::new(),
        };
        shard.add_edge(edge).unwrap();

        assert_eq!(shard.edge_count(), 1);
        assert_eq!(shard.list_edges().len(), 1);
        let stored = shard.get_edge(&"test-edge".to_string()).unwrap();
        assert_eq!(stored.edge_type, "SELF");
        assert_eq!(shard.metadata().shard_id, 0);
    }
}