// ruvector_graph/distributed/coordinator.rs

//! Query coordinator for distributed graph execution
//!
//! Coordinates distributed query execution across multiple shards:
//! - Query planning and optimization
//! - Query routing to relevant shards
//! - Result aggregation and merging
//! - Transaction coordination across shards
//! - Query caching and optimization

10use crate::distributed::shard::{EdgeData, GraphShard, NodeData, NodeId, ShardId};
11use crate::{GraphError, Result};
12use chrono::{DateTime, Utc};
13use dashmap::DashMap;
14use serde::{Deserialize, Serialize};
15use std::collections::{HashMap, HashSet};
16use std::sync::Arc;
17use tokio::sync::RwLock;
18use tracing::{debug, info, warn};
19use uuid::Uuid;
20
21/// Query execution plan
/// Query execution plan
///
/// Produced by [`ShardCoordinator::plan_query`]: records which shards the
/// query touches and the ordered steps to run against them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryPlan {
    /// Unique query ID (UUID v4 string)
    pub query_id: String,
    /// Original query (Cypher-like syntax); also used as the cache key
    pub query: String,
    /// Shards involved in this query
    pub target_shards: Vec<ShardId>,
    /// Execution steps, applied in order
    pub steps: Vec<QueryStep>,
    /// Estimated cost (unitless heuristic for plan comparison only)
    pub estimated_cost: f64,
    /// Whether this is a distributed query
    pub is_distributed: bool,
    /// Creation timestamp
    pub created_at: DateTime<Utc>,
}
39
40/// Individual step in query execution
/// Individual step in query execution
///
/// NOTE: `execute_query` currently implements `NodeScan`, `EdgeScan`,
/// `Aggregate` (count only) and `Limit`; `Join`, `Filter` and `Sort`
/// are accepted by the planner but ignored at execution time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QueryStep {
    /// Scan nodes with optional filter
    NodeScan {
        shard_id: ShardId,
        /// Keep only nodes carrying this label, if set
        label: Option<String>,
        /// Predicate expression — not yet evaluated by `execute_query`
        filter: Option<String>,
    },
    /// Scan edges
    EdgeScan {
        shard_id: ShardId,
        /// Keep only edges of this type, if set
        edge_type: Option<String>,
    },
    /// Join results from multiple shards (not yet executed)
    Join {
        left_shard: ShardId,
        right_shard: ShardId,
        join_key: String,
    },
    /// Aggregate results
    Aggregate {
        operation: AggregateOp,
        /// Grouping key — not yet supported by `execute_query`
        group_by: Option<String>,
    },
    /// Filter results (not yet executed)
    Filter { predicate: String },
    /// Sort results (not yet executed)
    Sort { key: String, ascending: bool },
    /// Limit results
    Limit { count: usize },
}
72
73/// Aggregate operations
/// Aggregate operations
///
/// The `String` payload names the property to aggregate over. Only
/// `Count` is implemented by `execute_query` today; the others fall
/// through to a no-op arm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregateOp {
    /// Count matched nodes
    Count,
    /// Sum of the named property (not yet implemented)
    Sum(String),
    /// Average of the named property (not yet implemented)
    Avg(String),
    /// Minimum of the named property (not yet implemented)
    Min(String),
    /// Maximum of the named property (not yet implemented)
    Max(String),
}
82
83/// Query result
/// Query result
///
/// Aggregated output of a distributed query, merged across all shards
/// that participated in execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResult {
    /// Query ID (copied from the plan)
    pub query_id: String,
    /// Result nodes
    pub nodes: Vec<NodeData>,
    /// Result edges
    pub edges: Vec<EdgeData>,
    /// Aggregate results keyed by aggregate name (e.g. "count")
    pub aggregates: HashMap<String, serde_json::Value>,
    /// Execution statistics
    pub stats: QueryStats,
}
97
98/// Query execution statistics
/// Query execution statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryStats {
    /// Execution time in milliseconds (wall-clock, measured by the coordinator)
    pub execution_time_ms: u64,
    /// Number of shards queried
    pub shards_queried: usize,
    /// Total nodes scanned (before filtering)
    pub nodes_scanned: usize,
    /// Total edges scanned (before filtering)
    pub edges_scanned: usize,
    /// Whether this result was served from the query cache
    pub cached: bool,
}
112
/// Shard coordinator for managing distributed queries
///
/// All maps are concurrent (`DashMap`), so the coordinator can be shared
/// across tasks behind an `Arc` without external locking.
pub struct ShardCoordinator {
    /// Map of shard_id to GraphShard
    shards: Arc<DashMap<ShardId, Arc<GraphShard>>>,
    /// Query cache keyed by raw query text. NOTE: unbounded and never
    /// expired here — only `clear_cache` empties it.
    query_cache: Arc<DashMap<String, QueryResult>>,
    /// Active transactions keyed by transaction ID
    transactions: Arc<DashMap<String, Transaction>>,
}
122
123impl ShardCoordinator {
124    /// Create a new shard coordinator
125    pub fn new() -> Self {
126        Self {
127            shards: Arc::new(DashMap::new()),
128            query_cache: Arc::new(DashMap::new()),
129            transactions: Arc::new(DashMap::new()),
130        }
131    }
132
133    /// Register a shard with the coordinator
134    pub fn register_shard(&self, shard_id: ShardId, shard: Arc<GraphShard>) {
135        info!("Registering shard {} with coordinator", shard_id);
136        self.shards.insert(shard_id, shard);
137    }
138
139    /// Unregister a shard
140    pub fn unregister_shard(&self, shard_id: ShardId) -> Result<()> {
141        info!("Unregistering shard {}", shard_id);
142        self.shards
143            .remove(&shard_id)
144            .ok_or_else(|| GraphError::ShardError(format!("Shard {} not found", shard_id)))?;
145        Ok(())
146    }
147
148    /// Get a shard by ID
149    pub fn get_shard(&self, shard_id: ShardId) -> Option<Arc<GraphShard>> {
150        self.shards.get(&shard_id).map(|s| Arc::clone(s.value()))
151    }
152
153    /// List all registered shards
154    pub fn list_shards(&self) -> Vec<ShardId> {
155        self.shards.iter().map(|e| *e.key()).collect()
156    }
157
158    /// Create a query plan from a Cypher-like query
159    pub fn plan_query(&self, query: &str) -> Result<QueryPlan> {
160        let query_id = Uuid::new_v4().to_string();
161
162        // Parse query and determine target shards
163        // For now, simple heuristic: query all shards for distributed queries
164        let target_shards: Vec<ShardId> = self.list_shards();
165
166        let steps = self.parse_query_steps(query)?;
167
168        let estimated_cost = self.estimate_cost(&steps, &target_shards);
169
170        Ok(QueryPlan {
171            query_id,
172            query: query.to_string(),
173            target_shards,
174            steps,
175            estimated_cost,
176            is_distributed: true,
177            created_at: Utc::now(),
178        })
179    }
180
181    /// Parse query into execution steps
182    fn parse_query_steps(&self, query: &str) -> Result<Vec<QueryStep>> {
183        // Simplified query parsing
184        // In production, use a proper Cypher parser
185        let mut steps = Vec::new();
186
187        // Example: "MATCH (n:Person) RETURN n"
188        if query.to_lowercase().contains("match") {
189            // Add node scan for each shard
190            for shard_id in self.list_shards() {
191                steps.push(QueryStep::NodeScan {
192                    shard_id,
193                    label: None,
194                    filter: None,
195                });
196            }
197        }
198
199        // Add aggregation if needed
200        if query.to_lowercase().contains("count") {
201            steps.push(QueryStep::Aggregate {
202                operation: AggregateOp::Count,
203                group_by: None,
204            });
205        }
206
207        // Add limit if specified
208        if let Some(limit_pos) = query.to_lowercase().find("limit") {
209            if let Some(count_str) = query[limit_pos..].split_whitespace().nth(1) {
210                if let Ok(count) = count_str.parse::<usize>() {
211                    steps.push(QueryStep::Limit { count });
212                }
213            }
214        }
215
216        Ok(steps)
217    }
218
219    /// Estimate query execution cost
220    fn estimate_cost(&self, steps: &[QueryStep], target_shards: &[ShardId]) -> f64 {
221        let mut cost = 0.0;
222
223        for step in steps {
224            match step {
225                QueryStep::NodeScan { .. } => cost += 10.0,
226                QueryStep::EdgeScan { .. } => cost += 15.0,
227                QueryStep::Join { .. } => cost += 50.0,
228                QueryStep::Aggregate { .. } => cost += 20.0,
229                QueryStep::Filter { .. } => cost += 5.0,
230                QueryStep::Sort { .. } => cost += 30.0,
231                QueryStep::Limit { .. } => cost += 1.0,
232            }
233        }
234
235        // Multiply by number of shards for distributed queries
236        cost * target_shards.len() as f64
237    }
238
239    /// Execute a query plan
240    pub async fn execute_query(&self, plan: QueryPlan) -> Result<QueryResult> {
241        let start = std::time::Instant::now();
242
243        info!(
244            "Executing query {} across {} shards",
245            plan.query_id,
246            plan.target_shards.len()
247        );
248
249        // Check cache first
250        if let Some(cached) = self.query_cache.get(&plan.query) {
251            debug!("Query cache hit for: {}", plan.query);
252            return Ok(cached.value().clone());
253        }
254
255        let mut nodes = Vec::new();
256        let mut edges = Vec::new();
257        let mut aggregates = HashMap::new();
258        let mut nodes_scanned = 0;
259        let mut edges_scanned = 0;
260
261        // Execute steps
262        for step in &plan.steps {
263            match step {
264                QueryStep::NodeScan {
265                    shard_id,
266                    label,
267                    filter,
268                } => {
269                    if let Some(shard) = self.get_shard(*shard_id) {
270                        let shard_nodes = shard.list_nodes();
271                        nodes_scanned += shard_nodes.len();
272
273                        // Apply label filter
274                        let filtered: Vec<_> = if let Some(label_filter) = label {
275                            shard_nodes
276                                .into_iter()
277                                .filter(|n| n.labels.contains(label_filter))
278                                .collect()
279                        } else {
280                            shard_nodes
281                        };
282
283                        nodes.extend(filtered);
284                    }
285                }
286                QueryStep::EdgeScan {
287                    shard_id,
288                    edge_type,
289                } => {
290                    if let Some(shard) = self.get_shard(*shard_id) {
291                        let shard_edges = shard.list_edges();
292                        edges_scanned += shard_edges.len();
293
294                        // Apply edge type filter
295                        let filtered: Vec<_> = if let Some(type_filter) = edge_type {
296                            shard_edges
297                                .into_iter()
298                                .filter(|e| &e.edge_type == type_filter)
299                                .collect()
300                        } else {
301                            shard_edges
302                        };
303
304                        edges.extend(filtered);
305                    }
306                }
307                QueryStep::Aggregate {
308                    operation,
309                    group_by,
310                } => {
311                    match operation {
312                        AggregateOp::Count => {
313                            aggregates.insert(
314                                "count".to_string(),
315                                serde_json::Value::Number(nodes.len().into()),
316                            );
317                        }
318                        _ => {
319                            // Implement other aggregations
320                        }
321                    }
322                }
323                QueryStep::Limit { count } => {
324                    nodes.truncate(*count);
325                }
326                _ => {
327                    // Implement other steps
328                }
329            }
330        }
331
332        let execution_time_ms = start.elapsed().as_millis() as u64;
333
334        let result = QueryResult {
335            query_id: plan.query_id.clone(),
336            nodes,
337            edges,
338            aggregates,
339            stats: QueryStats {
340                execution_time_ms,
341                shards_queried: plan.target_shards.len(),
342                nodes_scanned,
343                edges_scanned,
344                cached: false,
345            },
346        };
347
348        // Cache the result
349        self.query_cache.insert(plan.query.clone(), result.clone());
350
351        info!(
352            "Query {} completed in {}ms",
353            plan.query_id, execution_time_ms
354        );
355
356        Ok(result)
357    }
358
359    /// Begin a distributed transaction
360    pub fn begin_transaction(&self) -> String {
361        let tx_id = Uuid::new_v4().to_string();
362        let transaction = Transaction::new(tx_id.clone());
363        self.transactions.insert(tx_id.clone(), transaction);
364        info!("Started transaction: {}", tx_id);
365        tx_id
366    }
367
368    /// Commit a transaction
369    pub async fn commit_transaction(&self, tx_id: &str) -> Result<()> {
370        if let Some((_, tx)) = self.transactions.remove(tx_id) {
371            // In production, implement 2PC (Two-Phase Commit)
372            info!("Committing transaction: {}", tx_id);
373            Ok(())
374        } else {
375            Err(GraphError::CoordinatorError(format!(
376                "Transaction not found: {}",
377                tx_id
378            )))
379        }
380    }
381
382    /// Rollback a transaction
383    pub async fn rollback_transaction(&self, tx_id: &str) -> Result<()> {
384        if let Some((_, tx)) = self.transactions.remove(tx_id) {
385            warn!("Rolling back transaction: {}", tx_id);
386            Ok(())
387        } else {
388            Err(GraphError::CoordinatorError(format!(
389                "Transaction not found: {}",
390                tx_id
391            )))
392        }
393    }
394
395    /// Clear query cache
396    pub fn clear_cache(&self) {
397        self.query_cache.clear();
398        info!("Query cache cleared");
399    }
400}
401
/// Distributed transaction
///
/// NOTE(review): fields are populated at creation but not read back
/// anywhere in the code visible here; they exist for the planned
/// two-phase commit implementation.
#[derive(Debug, Clone)]
struct Transaction {
    /// Transaction ID (UUID v4 string; also the key in `ShardCoordinator::transactions`)
    id: String,
    /// Participating shards
    shards: HashSet<ShardId>,
    /// Transaction state
    state: TransactionState,
    /// Created timestamp
    created_at: DateTime<Utc>,
}
414
415impl Transaction {
416    fn new(id: String) -> Self {
417        Self {
418            id,
419            shards: HashSet::new(),
420            state: TransactionState::Active,
421            created_at: Utc::now(),
422        }
423    }
424}
425
/// Transaction state
///
/// Intended lifecycle: `Active` -> `Preparing` -> `Committed` / `Aborted`.
/// Only `Active` is ever assigned in the code visible here; the other
/// states are reserved for the planned two-phase commit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TransactionState {
    /// Transaction is open and accepting operations
    Active,
    /// Prepare phase of 2PC in progress
    Preparing,
    /// Transaction durably committed
    Committed,
    /// Transaction rolled back
    Aborted,
}
434
/// Main coordinator for the entire distributed graph system
///
/// Thin facade that owns a [`ShardCoordinator`] plus the runtime
/// configuration.
pub struct Coordinator {
    /// Shard coordinator (shared so callers can hold their own handle)
    shard_coordinator: Arc<ShardCoordinator>,
    /// Coordinator configuration
    config: CoordinatorConfig,
}
442
443impl Coordinator {
444    /// Create a new coordinator
445    pub fn new(config: CoordinatorConfig) -> Self {
446        Self {
447            shard_coordinator: Arc::new(ShardCoordinator::new()),
448            config,
449        }
450    }
451
452    /// Get the shard coordinator
453    pub fn shard_coordinator(&self) -> Arc<ShardCoordinator> {
454        Arc::clone(&self.shard_coordinator)
455    }
456
457    /// Execute a query
458    pub async fn execute(&self, query: &str) -> Result<QueryResult> {
459        let plan = self.shard_coordinator.plan_query(query)?;
460        self.shard_coordinator.execute_query(plan).await
461    }
462
463    /// Get configuration
464    pub fn config(&self) -> &CoordinatorConfig {
465        &self.config
466    }
467}
468
/// Coordinator configuration
///
/// NOTE(review): these knobs are not yet consulted by the execution path
/// visible in this file (caching is always on, no TTL or timeout is
/// enforced) — confirm wiring before relying on them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoordinatorConfig {
    /// Enable query caching
    pub enable_cache: bool,
    /// Cache TTL in seconds
    pub cache_ttl_seconds: u64,
    /// Maximum query execution time
    pub max_query_time_seconds: u64,
    /// Enable query optimization
    pub enable_optimization: bool,
}
481
482impl Default for CoordinatorConfig {
483    fn default() -> Self {
484        Self {
485            enable_cache: true,
486            cache_ttl_seconds: 300,
487            max_query_time_seconds: 60,
488            enable_optimization: true,
489        }
490    }
491}
492
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::shard::ShardMetadata;
    use crate::distributed::shard::ShardStrategy;

    /// Registering a shard makes it visible via `list_shards` and `get_shard`.
    #[tokio::test]
    async fn test_shard_coordinator() {
        let coordinator = ShardCoordinator::new();

        let metadata = ShardMetadata::new(0, "node-1".to_string(), ShardStrategy::Hash);
        let shard = Arc::new(GraphShard::new(metadata));

        coordinator.register_shard(0, shard);

        assert_eq!(coordinator.list_shards().len(), 1);
        assert!(coordinator.get_shard(0).is_some());
    }

    /// A MATCH query against one shard yields a plan with an ID and at
    /// least one step (the heuristic planner emits a NodeScan per shard).
    #[tokio::test]
    async fn test_query_planning() {
        let coordinator = ShardCoordinator::new();

        let metadata = ShardMetadata::new(0, "node-1".to_string(), ShardStrategy::Hash);
        let shard = Arc::new(GraphShard::new(metadata));
        coordinator.register_shard(0, shard);

        let plan = coordinator.plan_query("MATCH (n:Person) RETURN n").unwrap();

        assert!(!plan.query_id.is_empty());
        assert!(!plan.steps.is_empty());
    }

    /// Begin/commit round-trip: commit of a known transaction succeeds.
    #[tokio::test]
    async fn test_transaction() {
        let coordinator = ShardCoordinator::new();

        let tx_id = coordinator.begin_transaction();
        assert!(!tx_id.is_empty());

        coordinator.commit_transaction(&tx_id).await.unwrap();
    }
}