ruvector_graph/distributed/shard.rs

//! Graph sharding strategies for distributed hypergraphs
//!
//! Provides multiple partitioning strategies optimized for graph workloads:
//! - Hash-based node partitioning for uniform distribution
//! - Range-based partitioning for locality-aware queries
//! - Edge-cut minimization for reducing cross-shard communication

use crate::{GraphError, Result};
use blake3::Hasher;
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tracing::{debug, info, warn};
use uuid::Uuid;
use xxhash_rust::xxh3::xxh3_64;

/// Unique identifier for a graph node
pub type NodeId = String;

/// Unique identifier for a graph edge
pub type EdgeId = String;

/// Shard identifier
pub type ShardId = u32;

/// Graph sharding strategy
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ShardStrategy {
    /// Hash-based partitioning using consistent hashing
    Hash,
    /// Range-based partitioning for ordered node IDs
    Range,
    /// Edge-cut minimization for graph partitioning
    EdgeCut,
    /// Custom partitioning strategy
    Custom,
}

/// Metadata about a graph shard
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardMetadata {
    /// Shard identifier
    pub shard_id: ShardId,
    /// Number of nodes in this shard
    pub node_count: usize,
    /// Number of edges in this shard
    pub edge_count: usize,
    /// Number of edges crossing to other shards
    pub cross_shard_edges: usize,
    /// Primary node responsible for this shard
    pub primary_node: String,
    /// Replica nodes
    pub replicas: Vec<String>,
    /// Creation timestamp
    pub created_at: DateTime<Utc>,
    /// Last modification timestamp
    pub modified_at: DateTime<Utc>,
    /// Partitioning strategy used
    pub strategy: ShardStrategy,
}

impl ShardMetadata {
    /// Create new shard metadata
    pub fn new(shard_id: ShardId, primary_node: String, strategy: ShardStrategy) -> Self {
        Self {
            shard_id,
            node_count: 0,
            edge_count: 0,
            cross_shard_edges: 0,
            primary_node,
            replicas: Vec::new(),
            created_at: Utc::now(),
            modified_at: Utc::now(),
            strategy,
        }
    }

    /// Calculate edge cut ratio (cross-shard edges / total edges)
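    ///
    /// For example, a shard holding 10 edges of which 3 cross to other shards
    /// reports a ratio of 0.3; a shard with no edges reports 0.0.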
    pub fn edge_cut_ratio(&self) -> f64 {
        if self.edge_count == 0 {
            0.0
        } else {
            self.cross_shard_edges as f64 / self.edge_count as f64
        }
    }
}

/// Hash-based node partitioner
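///
/// A minimal usage sketch (doctest ignored here; the `ruvector_graph` import
/// path is assumed from this module's location):
///
/// ```ignore
/// use ruvector_graph::distributed::shard::HashPartitioner;
///
/// let partitioner = HashPartitioner::new(16);
///
/// // The same node ID always hashes to the same shard in [0, 16).
/// let shard = partitioner.get_shard(&"node-1".to_string());
/// assert!(shard < 16);
///
/// // Primary shard plus salted candidates for replica placement.
/// let replicas = partitioner.get_replica_shards(&"node-1".to_string(), 3);
/// assert!(!replicas.is_empty() && replicas.len() <= 3);
/// ```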
pub struct HashPartitioner {
    /// Total number of shards
    shard_count: u32,
    /// Virtual nodes per physical shard for better distribution
    /// (not yet used by the modulo-based `get_shard`)
    virtual_nodes: u32,
}

impl HashPartitioner {
    /// Create a new hash partitioner
    pub fn new(shard_count: u32) -> Self {
        Self {
            shard_count,
            virtual_nodes: 150, // Similar to consistent hashing best practices
        }
    }

    /// Get the shard ID for a given node ID using xxHash
    pub fn get_shard(&self, node_id: &NodeId) -> ShardId {
        let hash = xxh3_64(node_id.as_bytes());
        (hash % self.shard_count as u64) as ShardId
    }

    /// Get the shard ID using BLAKE3 for cryptographic strength (alternative)
    pub fn get_shard_secure(&self, node_id: &NodeId) -> ShardId {
        let mut hasher = Hasher::new();
        hasher.update(node_id.as_bytes());
        let hash = hasher.finalize();
        let hash_bytes = hash.as_bytes();
        let hash_u64 = u64::from_le_bytes([
            hash_bytes[0],
            hash_bytes[1],
            hash_bytes[2],
            hash_bytes[3],
            hash_bytes[4],
            hash_bytes[5],
            hash_bytes[6],
            hash_bytes[7],
        ]);
        (hash_u64 % self.shard_count as u64) as ShardId
    }

    /// Get multiple candidate shards for replication
    ///
    /// May return fewer than `replica_count` shards if the salted hashes
    /// collide with shards already selected.
    pub fn get_replica_shards(&self, node_id: &NodeId, replica_count: usize) -> Vec<ShardId> {
        let mut shards = Vec::with_capacity(replica_count);
        let primary = self.get_shard(node_id);
        shards.push(primary);

        // Generate additional shards using salted hashing
        for i in 1..replica_count {
            let salted_id = format!("{}-replica-{}", node_id, i);
            let shard = self.get_shard(&salted_id);
            if !shards.contains(&shard) {
                shards.push(shard);
            }
        }

        shards
    }
}

/// Range-based node partitioner for ordered node IDs
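///
/// A minimal usage sketch (doctest ignored here; the `ruvector_graph` import
/// path is assumed):
///
/// ```ignore
/// use ruvector_graph::distributed::shard::RangePartitioner;
///
/// // Boundaries are sorted upper bounds: IDs <= "m" go to shard 0,
/// // IDs <= "z" go to shard 1, anything beyond falls into the last shard.
/// let partitioner = RangePartitioner::with_boundaries(vec!["m".to_string(), "z".to_string()]);
/// assert_eq!(partitioner.get_shard(&"apple".to_string()), 0);
/// assert_eq!(partitioner.get_shard(&"orange".to_string()), 1);
/// ```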
pub struct RangePartitioner {
    /// Total number of shards
    shard_count: u32,
    /// Sorted range boundaries (index = shard ID, value = inclusive upper bound)
    ranges: Vec<String>,
}

impl RangePartitioner {
    /// Create a new range partitioner with automatic range distribution
    pub fn new(shard_count: u32) -> Self {
        Self {
            shard_count,
            ranges: Vec::new(),
        }
    }

    /// Create a range partitioner with explicit boundaries
    pub fn with_boundaries(boundaries: Vec<String>) -> Self {
        Self {
            shard_count: boundaries.len() as u32,
            ranges: boundaries,
        }
    }

    /// Get the shard ID for a node based on range boundaries
    pub fn get_shard(&self, node_id: &NodeId) -> ShardId {
        if self.ranges.is_empty() {
            // Fallback to simple modulo if no ranges defined
            let hash = xxh3_64(node_id.as_bytes());
            return (hash % self.shard_count as u64) as ShardId;
        }

        // Linear scan over the sorted boundaries: the first boundary >= node_id wins
        for (idx, boundary) in self.ranges.iter().enumerate() {
            if node_id <= boundary {
                return idx as ShardId;
            }
        }

        // Last shard for values beyond all boundaries
        (self.shard_count - 1) as ShardId
    }

    /// Update range boundaries based on data distribution
    pub fn update_boundaries(&mut self, new_boundaries: Vec<String>) {
        info!(
            "Updating range boundaries: old={}, new={}",
            self.ranges.len(),
            new_boundaries.len()
        );
        self.ranges = new_boundaries;
        self.shard_count = self.ranges.len() as u32;
    }
}

/// Edge-cut minimization using METIS-like graph partitioning
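///
/// A minimal usage sketch (doctest ignored here; the `ruvector_graph` import
/// path is assumed):
///
/// ```ignore
/// use ruvector_graph::distributed::shard::EdgeCutMinimizer;
///
/// let minimizer = EdgeCutMinimizer::new(2);
///
/// // Path graph A-B-C-D: a good 2-way split cuts only the middle edge.
/// minimizer.add_edge("A".to_string(), "B".to_string(), 1.0);
/// minimizer.add_edge("B".to_string(), "C".to_string(), 1.0);
/// minimizer.add_edge("C".to_string(), "D".to_string(), 1.0);
///
/// let assignments = minimizer.compute_partitioning().unwrap();
/// let cut = minimizer.calculate_edge_cut(&assignments);
/// assert!(cut <= 2);
/// ```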
pub struct EdgeCutMinimizer {
    /// Total number of shards
    shard_count: u32,
    /// Node to shard assignments
    node_assignments: Arc<DashMap<NodeId, ShardId>>,
    /// Edge information for partitioning decisions
    edge_weights: Arc<DashMap<(NodeId, NodeId), f64>>,
    /// Adjacency list representation
    adjacency: Arc<DashMap<NodeId, HashSet<NodeId>>>,
}

impl EdgeCutMinimizer {
    /// Create a new edge-cut minimizer
    pub fn new(shard_count: u32) -> Self {
        Self {
            shard_count,
            node_assignments: Arc::new(DashMap::new()),
            edge_weights: Arc::new(DashMap::new()),
            adjacency: Arc::new(DashMap::new()),
        }
    }

    /// Add an edge to the graph for partitioning consideration
    pub fn add_edge(&self, from: NodeId, to: NodeId, weight: f64) {
        self.edge_weights.insert((from.clone(), to.clone()), weight);

        // Update adjacency list
        self.adjacency
            .entry(from.clone())
            .or_insert_with(HashSet::new)
            .insert(to.clone());

        self.adjacency
            .entry(to)
            .or_insert_with(HashSet::new)
            .insert(from);
    }

    /// Get the shard assignment for a node
    pub fn get_shard(&self, node_id: &NodeId) -> Option<ShardId> {
        self.node_assignments.get(node_id).map(|r| *r.value())
    }

    /// Compute a partitioning using a simplified multilevel k-way scheme
    /// (coarsen, greedy initial partition, refine)
    pub fn compute_partitioning(&self) -> Result<HashMap<NodeId, ShardId>> {
        info!("Computing edge-cut minimized partitioning");

        let nodes: Vec<_> = self.adjacency.iter().map(|e| e.key().clone()).collect();

        if nodes.is_empty() {
            return Ok(HashMap::new());
        }

        // Phase 1: Coarsening - merge highly connected nodes
        let coarse_graph = self.coarsen_graph(&nodes);

        // Phase 2: Initial partitioning using greedy approach
        let mut assignments = self.initial_partition(&coarse_graph);

        // Phase 3: Refinement using Kernighan-Lin algorithm
        self.refine_partition(&mut assignments);

        // Store assignments
        for (node, shard) in &assignments {
            self.node_assignments.insert(node.clone(), *shard);
        }

        info!(
            "Partitioning complete: {} nodes across {} shards",
            assignments.len(),
            self.shard_count
        );

        Ok(assignments)
    }

    /// Coarsen the graph by merging highly connected nodes
    fn coarsen_graph(&self, nodes: &[NodeId]) -> HashMap<NodeId, Vec<NodeId>> {
        let mut coarse: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
        let mut visited = HashSet::new();

        for node in nodes {
            if visited.contains(node) {
                continue;
            }

            let mut group = vec![node.clone()];
            visited.insert(node.clone());

            // Find best matching neighbor based on edge weight
            if let Some(neighbors) = self.adjacency.get(node) {
                let mut best_neighbor: Option<(NodeId, f64)> = None;

                for neighbor in neighbors.iter() {
                    if visited.contains(neighbor) {
                        continue;
                    }

                    let weight = self
                        .edge_weights
                        .get(&(node.clone(), neighbor.clone()))
                        .map(|w| *w.value())
                        .unwrap_or(1.0);

                    if let Some((_, best_weight)) = best_neighbor {
                        if weight > best_weight {
                            best_neighbor = Some((neighbor.clone(), weight));
                        }
                    } else {
                        best_neighbor = Some((neighbor.clone(), weight));
                    }
                }

                if let Some((neighbor, _)) = best_neighbor {
                    group.push(neighbor.clone());
                    visited.insert(neighbor);
                }
            }

            let representative = node.clone();
            coarse.insert(representative, group);
        }

        coarse
    }

    /// Initial partition using greedy approach
    fn initial_partition(
        &self,
        coarse_graph: &HashMap<NodeId, Vec<NodeId>>,
    ) -> HashMap<NodeId, ShardId> {
        let mut assignments = HashMap::new();
        let mut shard_sizes: Vec<usize> = vec![0; self.shard_count as usize];

        for group in coarse_graph.values() {
            // Assign to least-loaded shard
            let shard = shard_sizes
                .iter()
                .enumerate()
                .min_by_key(|(_, size)| *size)
                .map(|(idx, _)| idx as ShardId)
                .unwrap_or(0);

            for node in group {
                assignments.insert(node.clone(), shard);
                shard_sizes[shard as usize] += 1;
            }
        }

        assignments
    }

    /// Refine partition using simplified Kernighan-Lin algorithm
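    ///
    /// Each pass greedily moves a node to the first shard that strictly lowers
    /// its cross-shard edge count, stopping after a pass with no moves or after
    /// a fixed iteration cap. Unlike full Kernighan-Lin, no balance constraint
    /// is enforced.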
    fn refine_partition(&self, assignments: &mut HashMap<NodeId, ShardId>) {
        const MAX_ITERATIONS: usize = 10;
        let mut improved = true;
        let mut iteration = 0;

        while improved && iteration < MAX_ITERATIONS {
            improved = false;
            iteration += 1;

            for (node, current_shard) in assignments.clone().iter() {
                let current_cost = self.compute_node_cost(node, *current_shard, assignments);

                // Try moving to each other shard
                for target_shard in 0..self.shard_count {
                    if target_shard == *current_shard {
                        continue;
                    }

                    let new_cost = self.compute_node_cost(node, target_shard, assignments);

                    if new_cost < current_cost {
                        assignments.insert(node.clone(), target_shard);
                        improved = true;
                        break;
                    }
                }
            }

            debug!("Refinement iteration {}: improved={}", iteration, improved);
        }
    }

    /// Compute the cost (number of cross-shard edges) for a node in a given shard
    fn compute_node_cost(
        &self,
        node: &NodeId,
        shard: ShardId,
        assignments: &HashMap<NodeId, ShardId>,
    ) -> usize {
        let mut cross_shard_edges = 0;

        if let Some(neighbors) = self.adjacency.get(node) {
            for neighbor in neighbors.iter() {
                if let Some(neighbor_shard) = assignments.get(neighbor) {
                    if *neighbor_shard != shard {
                        cross_shard_edges += 1;
                    }
                }
            }
        }

        cross_shard_edges
    }

    /// Calculate total edge cut across all shards
    pub fn calculate_edge_cut(&self, assignments: &HashMap<NodeId, ShardId>) -> usize {
        let mut cut = 0;

        for entry in self.edge_weights.iter() {
            let ((from, to), _) = entry.pair();
            let from_shard = assignments.get(from);
            let to_shard = assignments.get(to);

            if from_shard.is_some() && to_shard.is_some() && from_shard != to_shard {
                cut += 1;
            }
        }

        cut
    }
}

/// Graph shard containing partitioned data
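///
/// A minimal usage sketch (doctest ignored here; the `ruvector_graph` import
/// path is assumed):
///
/// ```ignore
/// use ruvector_graph::distributed::shard::{GraphShard, NodeData, ShardMetadata, ShardStrategy};
/// use std::collections::HashMap;
///
/// let metadata = ShardMetadata::new(0, "node-1".to_string(), ShardStrategy::Hash);
/// let shard = GraphShard::new(metadata);
///
/// shard.add_node(NodeData {
///     id: "n1".to_string(),
///     properties: HashMap::new(),
///     labels: vec!["Person".to_string()],
/// }).unwrap();
///
/// assert_eq!(shard.node_count(), 1);
/// assert!(shard.get_node(&"n1".to_string()).is_some());
/// ```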
pub struct GraphShard {
    /// Shard metadata
    metadata: ShardMetadata,
    /// Nodes in this shard
    nodes: Arc<DashMap<NodeId, NodeData>>,
    /// Edges in this shard (including cross-shard edges)
    edges: Arc<DashMap<EdgeId, EdgeData>>,
    /// Partitioning strategy
    strategy: ShardStrategy,
}

/// Node data in the graph
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeData {
    pub id: NodeId,
    pub properties: HashMap<String, serde_json::Value>,
    pub labels: Vec<String>,
}

/// Edge data in the graph
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EdgeData {
    pub id: EdgeId,
    pub from: NodeId,
    pub to: NodeId,
    pub edge_type: String,
    pub properties: HashMap<String, serde_json::Value>,
}

impl GraphShard {
    /// Create a new graph shard
    pub fn new(metadata: ShardMetadata) -> Self {
        let strategy = metadata.strategy;
        Self {
            metadata,
            nodes: Arc::new(DashMap::new()),
            edges: Arc::new(DashMap::new()),
            strategy,
        }
    }

    /// Add a node to this shard
    pub fn add_node(&self, node: NodeData) -> Result<()> {
        self.nodes.insert(node.id.clone(), node);
        Ok(())
    }

    /// Add an edge to this shard
    pub fn add_edge(&self, edge: EdgeData) -> Result<()> {
        self.edges.insert(edge.id.clone(), edge);
        Ok(())
    }

    /// Get a node by ID
    pub fn get_node(&self, node_id: &NodeId) -> Option<NodeData> {
        self.nodes.get(node_id).map(|n| n.value().clone())
    }

    /// Get an edge by ID
    pub fn get_edge(&self, edge_id: &EdgeId) -> Option<EdgeData> {
        self.edges.get(edge_id).map(|e| e.value().clone())
    }

    /// Get shard metadata
    pub fn metadata(&self) -> &ShardMetadata {
        &self.metadata
    }

    /// Get node count
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Get edge count
    pub fn edge_count(&self) -> usize {
        self.edges.len()
    }

    /// List all nodes in this shard
    pub fn list_nodes(&self) -> Vec<NodeData> {
        self.nodes.iter().map(|e| e.value().clone()).collect()
    }

    /// List all edges in this shard
    pub fn list_edges(&self) -> Vec<EdgeData> {
        self.edges.iter().map(|e| e.value().clone()).collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hash_partitioner() {
        let partitioner = HashPartitioner::new(16);

        let node1 = "node-1".to_string();
        let node2 = "node-2".to_string();

        let shard1 = partitioner.get_shard(&node1);
        let shard2 = partitioner.get_shard(&node2);

        assert!(shard1 < 16);
        assert!(shard2 < 16);

        // Same node should always map to same shard
        assert_eq!(shard1, partitioner.get_shard(&node1));
    }

    #[test]
    fn test_range_partitioner() {
        let boundaries = vec!["m".to_string(), "z".to_string()];
        let partitioner = RangePartitioner::with_boundaries(boundaries);

        assert_eq!(partitioner.get_shard(&"apple".to_string()), 0);
        assert_eq!(partitioner.get_shard(&"orange".to_string()), 1);
        assert_eq!(partitioner.get_shard(&"zebra".to_string()), 1);
    }

    #[test]
    fn test_edge_cut_minimizer() {
        let minimizer = EdgeCutMinimizer::new(2);

        // Create a simple graph: A-B-C-D
        minimizer.add_edge("A".to_string(), "B".to_string(), 1.0);
        minimizer.add_edge("B".to_string(), "C".to_string(), 1.0);
        minimizer.add_edge("C".to_string(), "D".to_string(), 1.0);

        let assignments = minimizer.compute_partitioning().unwrap();
        let cut = minimizer.calculate_edge_cut(&assignments);

        // Optimal partitioning should minimize edge cuts
        assert!(cut <= 2);
    }

    #[test]
    fn test_shard_metadata() {
        let metadata = ShardMetadata::new(0, "node-1".to_string(), ShardStrategy::Hash);

        assert_eq!(metadata.shard_id, 0);
        assert_eq!(metadata.edge_cut_ratio(), 0.0);
    }

    #[test]
    fn test_graph_shard() {
        let metadata = ShardMetadata::new(0, "node-1".to_string(), ShardStrategy::Hash);
        let shard = GraphShard::new(metadata);

        let node = NodeData {
            id: "test-node".to_string(),
            properties: HashMap::new(),
            labels: vec!["TestLabel".to_string()],
        };

        shard.add_node(node.clone()).unwrap();

        assert_eq!(shard.node_count(), 1);
        assert!(shard.get_node(&"test-node".to_string()).is_some());
    }
}