1use crate::{GraphError, Result};
9use blake3::Hasher;
10use chrono::{DateTime, Utc};
11use dashmap::DashMap;
12use serde::{Deserialize, Serialize};
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15use tracing::{debug, info, warn};
16use uuid::Uuid;
17use xxhash_rust::xxh3::xxh3_64;
18
/// Identifier of a graph node. This is a plain `String` alias, so any string
/// value is accepted and no format is enforced by the type.
pub type NodeId = String;

/// Identifier of a graph edge. Plain `String` alias, no format enforced.
pub type EdgeId = String;

/// Numeric identifier of a shard (`u32`).
pub type ShardId = u32;
27
/// Tag naming the placement strategy associated with a shard.
///
/// The enum only labels a strategy; it carries no behavior of its own. It is
/// stored in `ShardMetadata::strategy` and copied into `GraphShard`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ShardStrategy {
    /// Placement by hashing the node id.
    /// NOTE(review): presumably corresponds to `HashPartitioner` — confirm against callers.
    Hash,
    /// Placement by ordered key ranges.
    /// NOTE(review): presumably corresponds to `RangePartitioner` — confirm against callers.
    Range,
    /// Placement that tries to reduce edges crossing shards.
    /// NOTE(review): presumably corresponds to `EdgeCutMinimizer` — confirm against callers.
    EdgeCut,
    /// Caller-defined placement; no implementation is visible in this part of the file.
    Custom,
}
40
/// Bookkeeping record describing a single shard.
///
/// All fields are public and serializable. None of the methods visible on this
/// type update the counters, so the owner of the record is responsible for
/// keeping them current.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardMetadata {
    /// Identifier of the shard this record describes.
    pub shard_id: ShardId,
    /// Number of nodes recorded for the shard (initialized to 0 by `new`).
    pub node_count: usize,
    /// Number of edges recorded for the shard; denominator of `edge_cut_ratio`.
    pub edge_count: usize,
    /// Number of edges recorded as crossing shards; numerator of `edge_cut_ratio`.
    pub cross_shard_edges: usize,
    /// Name of the primary holder of this shard.
    /// NOTE(review): presumably a cluster host identifier rather than a graph `NodeId` — confirm.
    pub primary_node: String,
    /// Names of replica holders (empty after `new`).
    pub replicas: Vec<String>,
    /// Timestamp set when the record is constructed.
    pub created_at: DateTime<Utc>,
    /// Timestamp of the last modification; set to the construction time by `new`.
    pub modified_at: DateTime<Utc>,
    /// Placement strategy tag for this shard.
    pub strategy: ShardStrategy,
}
63
64impl ShardMetadata {
65 pub fn new(shard_id: ShardId, primary_node: String, strategy: ShardStrategy) -> Self {
67 Self {
68 shard_id,
69 node_count: 0,
70 edge_count: 0,
71 cross_shard_edges: 0,
72 primary_node,
73 replicas: Vec::new(),
74 created_at: Utc::now(),
75 modified_at: Utc::now(),
76 strategy,
77 }
78 }
79
80 pub fn edge_cut_ratio(&self) -> f64 {
82 if self.edge_count == 0 {
83 0.0
84 } else {
85 self.cross_shard_edges as f64 / self.edge_count as f64
86 }
87 }
88}
89
90pub struct HashPartitioner {
92 shard_count: u32,
94 virtual_nodes: u32,
96}
97
98impl HashPartitioner {
99 pub fn new(shard_count: u32) -> Self {
101 Self {
102 shard_count,
103 virtual_nodes: 150, }
105 }
106
107 pub fn get_shard(&self, node_id: &NodeId) -> ShardId {
109 let hash = xxh3_64(node_id.as_bytes());
110 (hash % self.shard_count as u64) as ShardId
111 }
112
113 pub fn get_shard_secure(&self, node_id: &NodeId) -> ShardId {
115 let mut hasher = Hasher::new();
116 hasher.update(node_id.as_bytes());
117 let hash = hasher.finalize();
118 let hash_bytes = hash.as_bytes();
119 let hash_u64 = u64::from_le_bytes([
120 hash_bytes[0],
121 hash_bytes[1],
122 hash_bytes[2],
123 hash_bytes[3],
124 hash_bytes[4],
125 hash_bytes[5],
126 hash_bytes[6],
127 hash_bytes[7],
128 ]);
129 (hash_u64 % self.shard_count as u64) as ShardId
130 }
131
132 pub fn get_replica_shards(&self, node_id: &NodeId, replica_count: usize) -> Vec<ShardId> {
134 let mut shards = Vec::with_capacity(replica_count);
135 let primary = self.get_shard(node_id);
136 shards.push(primary);
137
138 for i in 1..replica_count {
140 let salted_id = format!("{}-replica-{}", node_id, i);
141 let shard = self.get_shard(&salted_id);
142 if !shards.contains(&shard) {
143 shards.push(shard);
144 }
145 }
146
147 shards
148 }
149}
150
151pub struct RangePartitioner {
153 shard_count: u32,
155 ranges: Vec<String>,
157}
158
159impl RangePartitioner {
160 pub fn new(shard_count: u32) -> Self {
162 Self {
163 shard_count,
164 ranges: Vec::new(),
165 }
166 }
167
168 pub fn with_boundaries(boundaries: Vec<String>) -> Self {
170 Self {
171 shard_count: boundaries.len() as u32,
172 ranges: boundaries,
173 }
174 }
175
176 pub fn get_shard(&self, node_id: &NodeId) -> ShardId {
178 if self.ranges.is_empty() {
179 let hash = xxh3_64(node_id.as_bytes());
181 return (hash % self.shard_count as u64) as ShardId;
182 }
183
184 for (idx, boundary) in self.ranges.iter().enumerate() {
186 if node_id <= boundary {
187 return idx as ShardId;
188 }
189 }
190
191 (self.shard_count - 1) as ShardId
193 }
194
195 pub fn update_boundaries(&mut self, new_boundaries: Vec<String>) {
197 info!(
198 "Updating range boundaries: old={}, new={}",
199 self.ranges.len(),
200 new_boundaries.len()
201 );
202 self.ranges = new_boundaries;
203 self.shard_count = self.ranges.len() as u32;
204 }
205}
206
/// Partitioner that groups connected nodes and then greedily moves nodes
/// between shards to reduce the number of edges crossing shards.
pub struct EdgeCutMinimizer {
    /// Number of shards to distribute nodes over.
    shard_count: u32,
    /// Node-to-shard assignments written by `compute_partitioning` and read by `get_shard`.
    node_assignments: Arc<DashMap<NodeId, ShardId>>,
    /// Edge weights keyed by the `(from, to)` pair exactly as passed to `add_edge`.
    edge_weights: Arc<DashMap<(NodeId, NodeId), f64>>,
    /// Symmetric neighbor sets: `add_edge` records each endpoint in the other's set.
    adjacency: Arc<DashMap<NodeId, HashSet<NodeId>>>,
}
218
219impl EdgeCutMinimizer {
220 pub fn new(shard_count: u32) -> Self {
222 Self {
223 shard_count,
224 node_assignments: Arc::new(DashMap::new()),
225 edge_weights: Arc::new(DashMap::new()),
226 adjacency: Arc::new(DashMap::new()),
227 }
228 }
229
230 pub fn add_edge(&self, from: NodeId, to: NodeId, weight: f64) {
232 self.edge_weights.insert((from.clone(), to.clone()), weight);
233
234 self.adjacency
236 .entry(from.clone())
237 .or_insert_with(HashSet::new)
238 .insert(to.clone());
239
240 self.adjacency
241 .entry(to)
242 .or_insert_with(HashSet::new)
243 .insert(from);
244 }
245
246 pub fn get_shard(&self, node_id: &NodeId) -> Option<ShardId> {
248 self.node_assignments.get(node_id).map(|r| *r.value())
249 }
250
251 pub fn compute_partitioning(&self) -> Result<HashMap<NodeId, ShardId>> {
253 info!("Computing edge-cut minimized partitioning");
254
255 let nodes: Vec<_> = self.adjacency.iter().map(|e| e.key().clone()).collect();
256
257 if nodes.is_empty() {
258 return Ok(HashMap::new());
259 }
260
261 let coarse_graph = self.coarsen_graph(&nodes);
263
264 let mut assignments = self.initial_partition(&coarse_graph);
266
267 self.refine_partition(&mut assignments);
269
270 for (node, shard) in &assignments {
272 self.node_assignments.insert(node.clone(), *shard);
273 }
274
275 info!(
276 "Partitioning complete: {} nodes across {} shards",
277 assignments.len(),
278 self.shard_count
279 );
280
281 Ok(assignments)
282 }
283
284 fn coarsen_graph(&self, nodes: &[NodeId]) -> HashMap<NodeId, Vec<NodeId>> {
286 let mut coarse: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
287 let mut visited = HashSet::new();
288
289 for node in nodes {
290 if visited.contains(node) {
291 continue;
292 }
293
294 let mut group = vec![node.clone()];
295 visited.insert(node.clone());
296
297 if let Some(neighbors) = self.adjacency.get(node) {
299 let mut best_neighbor: Option<(NodeId, f64)> = None;
300
301 for neighbor in neighbors.iter() {
302 if visited.contains(neighbor) {
303 continue;
304 }
305
306 let weight = self
307 .edge_weights
308 .get(&(node.clone(), neighbor.clone()))
309 .map(|w| *w.value())
310 .unwrap_or(1.0);
311
312 if let Some((_, best_weight)) = best_neighbor {
313 if weight > best_weight {
314 best_neighbor = Some((neighbor.clone(), weight));
315 }
316 } else {
317 best_neighbor = Some((neighbor.clone(), weight));
318 }
319 }
320
321 if let Some((neighbor, _)) = best_neighbor {
322 group.push(neighbor.clone());
323 visited.insert(neighbor);
324 }
325 }
326
327 let representative = node.clone();
328 coarse.insert(representative, group);
329 }
330
331 coarse
332 }
333
334 fn initial_partition(
336 &self,
337 coarse_graph: &HashMap<NodeId, Vec<NodeId>>,
338 ) -> HashMap<NodeId, ShardId> {
339 let mut assignments = HashMap::new();
340 let mut shard_sizes: Vec<usize> = vec![0; self.shard_count as usize];
341
342 for (representative, group) in coarse_graph {
343 let shard = shard_sizes
345 .iter()
346 .enumerate()
347 .min_by_key(|(_, size)| *size)
348 .map(|(idx, _)| idx as ShardId)
349 .unwrap_or(0);
350
351 for node in group {
352 assignments.insert(node.clone(), shard);
353 shard_sizes[shard as usize] += 1;
354 }
355 }
356
357 assignments
358 }
359
360 fn refine_partition(&self, assignments: &mut HashMap<NodeId, ShardId>) {
362 const MAX_ITERATIONS: usize = 10;
363 let mut improved = true;
364 let mut iteration = 0;
365
366 while improved && iteration < MAX_ITERATIONS {
367 improved = false;
368 iteration += 1;
369
370 for (node, current_shard) in assignments.clone().iter() {
371 let current_cost = self.compute_node_cost(node, *current_shard, assignments);
372
373 for target_shard in 0..self.shard_count {
375 if target_shard == *current_shard {
376 continue;
377 }
378
379 let new_cost = self.compute_node_cost(node, target_shard, assignments);
380
381 if new_cost < current_cost {
382 assignments.insert(node.clone(), target_shard);
383 improved = true;
384 break;
385 }
386 }
387 }
388
389 debug!("Refinement iteration {}: improved={}", iteration, improved);
390 }
391 }
392
393 fn compute_node_cost(
395 &self,
396 node: &NodeId,
397 shard: ShardId,
398 assignments: &HashMap<NodeId, ShardId>,
399 ) -> usize {
400 let mut cross_shard_edges = 0;
401
402 if let Some(neighbors) = self.adjacency.get(node) {
403 for neighbor in neighbors.iter() {
404 if let Some(neighbor_shard) = assignments.get(neighbor) {
405 if *neighbor_shard != shard {
406 cross_shard_edges += 1;
407 }
408 }
409 }
410 }
411
412 cross_shard_edges
413 }
414
415 pub fn calculate_edge_cut(&self, assignments: &HashMap<NodeId, ShardId>) -> usize {
417 let mut cut = 0;
418
419 for entry in self.edge_weights.iter() {
420 let ((from, to), _) = entry.pair();
421 let from_shard = assignments.get(from);
422 let to_shard = assignments.get(to);
423
424 if from_shard.is_some() && to_shard.is_some() && from_shard != to_shard {
425 cut += 1;
426 }
427 }
428
429 cut
430 }
431}
432
/// In-memory container for the nodes and edges held by one shard.
pub struct GraphShard {
    /// Descriptive record supplied at construction. The methods visible on
    /// this type never modify it, so its counters do not track later inserts.
    metadata: ShardMetadata,
    /// Nodes stored in this shard, keyed by node id.
    nodes: Arc<DashMap<NodeId, NodeData>>,
    /// Edges stored in this shard, keyed by edge id.
    edges: Arc<DashMap<EdgeId, EdgeData>>,
    /// Copy of `metadata.strategy` taken at construction; not read by the
    /// methods visible in this file.
    strategy: ShardStrategy,
}
444
/// A graph node together with its properties and labels.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeData {
    /// Node identifier; used as the storage key by `GraphShard::add_node`.
    pub id: NodeId,
    /// Arbitrary JSON-valued properties keyed by property name.
    pub properties: HashMap<String, serde_json::Value>,
    /// Labels attached to the node.
    pub labels: Vec<String>,
}
452
/// A graph edge together with its endpoints, type and properties.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EdgeData {
    /// Edge identifier; used as the storage key by `GraphShard::add_edge`.
    pub id: EdgeId,
    /// Id of the node the edge starts at.
    pub from: NodeId,
    /// Id of the node the edge ends at.
    pub to: NodeId,
    /// Free-form edge type name.
    pub edge_type: String,
    /// Arbitrary JSON-valued properties keyed by property name.
    pub properties: HashMap<String, serde_json::Value>,
}
462
463impl GraphShard {
464 pub fn new(metadata: ShardMetadata) -> Self {
466 let strategy = metadata.strategy;
467 Self {
468 metadata,
469 nodes: Arc::new(DashMap::new()),
470 edges: Arc::new(DashMap::new()),
471 strategy,
472 }
473 }
474
475 pub fn add_node(&self, node: NodeData) -> Result<()> {
477 self.nodes.insert(node.id.clone(), node);
478 Ok(())
479 }
480
481 pub fn add_edge(&self, edge: EdgeData) -> Result<()> {
483 self.edges.insert(edge.id.clone(), edge);
484 Ok(())
485 }
486
487 pub fn get_node(&self, node_id: &NodeId) -> Option<NodeData> {
489 self.nodes.get(node_id).map(|n| n.value().clone())
490 }
491
492 pub fn get_edge(&self, edge_id: &EdgeId) -> Option<EdgeData> {
494 self.edges.get(edge_id).map(|e| e.value().clone())
495 }
496
497 pub fn metadata(&self) -> &ShardMetadata {
499 &self.metadata
500 }
501
502 pub fn node_count(&self) -> usize {
504 self.nodes.len()
505 }
506
507 pub fn edge_count(&self) -> usize {
509 self.edges.len()
510 }
511
512 pub fn list_nodes(&self) -> Vec<NodeData> {
514 self.nodes.iter().map(|e| e.value().clone()).collect()
515 }
516
517 pub fn list_edges(&self) -> Vec<EdgeData> {
519 self.edges.iter().map(|e| e.value().clone()).collect()
520 }
521}
522
// Unit tests for the partitioners, shard metadata and shard container.
#[cfg(test)]
mod tests {
    use super::*;

    // Hash partitioning must stay within the shard range and be repeatable.
    #[test]
    fn test_hash_partitioner() {
        let partitioner = HashPartitioner::new(16);

        let node1 = "node-1".to_string();
        let node2 = "node-2".to_string();

        let shard1 = partitioner.get_shard(&node1);
        let shard2 = partitioner.get_shard(&node2);

        // Every result must be a valid shard index.
        assert!(shard1 < 16);
        assert!(shard2 < 16);

        // The same id always maps to the same shard.
        assert_eq!(shard1, partitioner.get_shard(&node1));
    }

    // Two boundaries give two shards; ids beyond the last boundary land in the last shard.
    #[test]
    fn test_range_partitioner() {
        let boundaries = vec!["m".to_string(), "z".to_string()];
        let partitioner = RangePartitioner::with_boundaries(boundaries);

        assert_eq!(partitioner.get_shard(&"apple".to_string()), 0);
        assert_eq!(partitioner.get_shard(&"orange".to_string()), 1);
        // "zebra" sorts after "z", so it falls through to the last shard.
        assert_eq!(partitioner.get_shard(&"zebra".to_string()), 1);
    }

    // Partitioning a four-node path into two shards must cut at most two of its three edges.
    #[test]
    fn test_edge_cut_minimizer() {
        let minimizer = EdgeCutMinimizer::new(2);

        // Path graph A - B - C - D.
        minimizer.add_edge("A".to_string(), "B".to_string(), 1.0);
        minimizer.add_edge("B".to_string(), "C".to_string(), 1.0);
        minimizer.add_edge("C".to_string(), "D".to_string(), 1.0);

        let assignments = minimizer.compute_partitioning().unwrap();
        let cut = minimizer.calculate_edge_cut(&assignments);

        assert!(cut <= 2);
    }

    // Fresh metadata has zero edges, so the ratio takes the zero-edge branch.
    #[test]
    fn test_shard_metadata() {
        let metadata = ShardMetadata::new(0, "node-1".to_string(), ShardStrategy::Hash);

        assert_eq!(metadata.shard_id, 0);
        assert_eq!(metadata.edge_cut_ratio(), 0.0);
    }

    // A node added to a shard is counted and can be fetched by id.
    #[test]
    fn test_graph_shard() {
        let metadata = ShardMetadata::new(0, "node-1".to_string(), ShardStrategy::Hash);
        let shard = GraphShard::new(metadata);

        let node = NodeData {
            id: "test-node".to_string(),
            properties: HashMap::new(),
            labels: vec!["TestLabel".to_string()],
        };

        shard.add_node(node.clone()).unwrap();

        assert_eq!(shard.node_count(), 1);
        assert!(shard.get_node(&"test-node".to_string()).is_some());
    }
}