use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use chrono::{DateTime, Utc};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use uuid::Uuid;

use crate::distributed::shard::{EdgeData, GraphShard, NodeData, NodeId, ShardId};
use crate::{GraphError, Result};
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct QueryPlan {
24 pub query_id: String,
26 pub query: String,
28 pub target_shards: Vec<ShardId>,
30 pub steps: Vec<QueryStep>,
32 pub estimated_cost: f64,
34 pub is_distributed: bool,
36 pub created_at: DateTime<Utc>,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub enum QueryStep {
43 NodeScan {
45 shard_id: ShardId,
46 label: Option<String>,
47 filter: Option<String>,
48 },
49 EdgeScan {
51 shard_id: ShardId,
52 edge_type: Option<String>,
53 },
54 Join {
56 left_shard: ShardId,
57 right_shard: ShardId,
58 join_key: String,
59 },
60 Aggregate {
62 operation: AggregateOp,
63 group_by: Option<String>,
64 },
65 Filter { predicate: String },
67 Sort { key: String, ascending: bool },
69 Limit { count: usize },
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
75pub enum AggregateOp {
76 Count,
77 Sum(String),
78 Avg(String),
79 Min(String),
80 Max(String),
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct QueryResult {
86 pub query_id: String,
88 pub nodes: Vec<NodeData>,
90 pub edges: Vec<EdgeData>,
92 pub aggregates: HashMap<String, serde_json::Value>,
94 pub stats: QueryStats,
96}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct QueryStats {
101 pub execution_time_ms: u64,
103 pub shards_queried: usize,
105 pub nodes_scanned: usize,
107 pub edges_scanned: usize,
109 pub cached: bool,
111}
112
113pub struct ShardCoordinator {
115 shards: Arc<DashMap<ShardId, Arc<GraphShard>>>,
117 query_cache: Arc<DashMap<String, QueryResult>>,
119 transactions: Arc<DashMap<String, Transaction>>,
121}
122
123impl ShardCoordinator {
124 pub fn new() -> Self {
126 Self {
127 shards: Arc::new(DashMap::new()),
128 query_cache: Arc::new(DashMap::new()),
129 transactions: Arc::new(DashMap::new()),
130 }
131 }
132
133 pub fn register_shard(&self, shard_id: ShardId, shard: Arc<GraphShard>) {
135 info!("Registering shard {} with coordinator", shard_id);
136 self.shards.insert(shard_id, shard);
137 }
138
139 pub fn unregister_shard(&self, shard_id: ShardId) -> Result<()> {
141 info!("Unregistering shard {}", shard_id);
142 self.shards
143 .remove(&shard_id)
144 .ok_or_else(|| GraphError::ShardError(format!("Shard {} not found", shard_id)))?;
145 Ok(())
146 }
147
148 pub fn get_shard(&self, shard_id: ShardId) -> Option<Arc<GraphShard>> {
150 self.shards.get(&shard_id).map(|s| Arc::clone(s.value()))
151 }
152
153 pub fn list_shards(&self) -> Vec<ShardId> {
155 self.shards.iter().map(|e| *e.key()).collect()
156 }
157
158 pub fn plan_query(&self, query: &str) -> Result<QueryPlan> {
160 let query_id = Uuid::new_v4().to_string();
161
162 let target_shards: Vec<ShardId> = self.list_shards();
165
166 let steps = self.parse_query_steps(query)?;
167
168 let estimated_cost = self.estimate_cost(&steps, &target_shards);
169
170 Ok(QueryPlan {
171 query_id,
172 query: query.to_string(),
173 target_shards,
174 steps,
175 estimated_cost,
176 is_distributed: true,
177 created_at: Utc::now(),
178 })
179 }
180
181 fn parse_query_steps(&self, query: &str) -> Result<Vec<QueryStep>> {
183 let mut steps = Vec::new();
186
187 if query.to_lowercase().contains("match") {
189 for shard_id in self.list_shards() {
191 steps.push(QueryStep::NodeScan {
192 shard_id,
193 label: None,
194 filter: None,
195 });
196 }
197 }
198
199 if query.to_lowercase().contains("count") {
201 steps.push(QueryStep::Aggregate {
202 operation: AggregateOp::Count,
203 group_by: None,
204 });
205 }
206
207 if let Some(limit_pos) = query.to_lowercase().find("limit") {
209 if let Some(count_str) = query[limit_pos..].split_whitespace().nth(1) {
210 if let Ok(count) = count_str.parse::<usize>() {
211 steps.push(QueryStep::Limit { count });
212 }
213 }
214 }
215
216 Ok(steps)
217 }
218
219 fn estimate_cost(&self, steps: &[QueryStep], target_shards: &[ShardId]) -> f64 {
221 let mut cost = 0.0;
222
223 for step in steps {
224 match step {
225 QueryStep::NodeScan { .. } => cost += 10.0,
226 QueryStep::EdgeScan { .. } => cost += 15.0,
227 QueryStep::Join { .. } => cost += 50.0,
228 QueryStep::Aggregate { .. } => cost += 20.0,
229 QueryStep::Filter { .. } => cost += 5.0,
230 QueryStep::Sort { .. } => cost += 30.0,
231 QueryStep::Limit { .. } => cost += 1.0,
232 }
233 }
234
235 cost * target_shards.len() as f64
237 }
238
239 pub async fn execute_query(&self, plan: QueryPlan) -> Result<QueryResult> {
241 let start = std::time::Instant::now();
242
243 info!(
244 "Executing query {} across {} shards",
245 plan.query_id,
246 plan.target_shards.len()
247 );
248
249 if let Some(cached) = self.query_cache.get(&plan.query) {
251 debug!("Query cache hit for: {}", plan.query);
252 return Ok(cached.value().clone());
253 }
254
255 let mut nodes = Vec::new();
256 let mut edges = Vec::new();
257 let mut aggregates = HashMap::new();
258 let mut nodes_scanned = 0;
259 let mut edges_scanned = 0;
260
261 for step in &plan.steps {
263 match step {
264 QueryStep::NodeScan {
265 shard_id,
266 label,
267 filter,
268 } => {
269 if let Some(shard) = self.get_shard(*shard_id) {
270 let shard_nodes = shard.list_nodes();
271 nodes_scanned += shard_nodes.len();
272
273 let filtered: Vec<_> = if let Some(label_filter) = label {
275 shard_nodes
276 .into_iter()
277 .filter(|n| n.labels.contains(label_filter))
278 .collect()
279 } else {
280 shard_nodes
281 };
282
283 nodes.extend(filtered);
284 }
285 }
286 QueryStep::EdgeScan {
287 shard_id,
288 edge_type,
289 } => {
290 if let Some(shard) = self.get_shard(*shard_id) {
291 let shard_edges = shard.list_edges();
292 edges_scanned += shard_edges.len();
293
294 let filtered: Vec<_> = if let Some(type_filter) = edge_type {
296 shard_edges
297 .into_iter()
298 .filter(|e| &e.edge_type == type_filter)
299 .collect()
300 } else {
301 shard_edges
302 };
303
304 edges.extend(filtered);
305 }
306 }
307 QueryStep::Aggregate {
308 operation,
309 group_by,
310 } => {
311 match operation {
312 AggregateOp::Count => {
313 aggregates.insert(
314 "count".to_string(),
315 serde_json::Value::Number(nodes.len().into()),
316 );
317 }
318 _ => {
319 }
321 }
322 }
323 QueryStep::Limit { count } => {
324 nodes.truncate(*count);
325 }
326 _ => {
327 }
329 }
330 }
331
332 let execution_time_ms = start.elapsed().as_millis() as u64;
333
334 let result = QueryResult {
335 query_id: plan.query_id.clone(),
336 nodes,
337 edges,
338 aggregates,
339 stats: QueryStats {
340 execution_time_ms,
341 shards_queried: plan.target_shards.len(),
342 nodes_scanned,
343 edges_scanned,
344 cached: false,
345 },
346 };
347
348 self.query_cache.insert(plan.query.clone(), result.clone());
350
351 info!(
352 "Query {} completed in {}ms",
353 plan.query_id, execution_time_ms
354 );
355
356 Ok(result)
357 }
358
359 pub fn begin_transaction(&self) -> String {
361 let tx_id = Uuid::new_v4().to_string();
362 let transaction = Transaction::new(tx_id.clone());
363 self.transactions.insert(tx_id.clone(), transaction);
364 info!("Started transaction: {}", tx_id);
365 tx_id
366 }
367
368 pub async fn commit_transaction(&self, tx_id: &str) -> Result<()> {
370 if let Some((_, tx)) = self.transactions.remove(tx_id) {
371 info!("Committing transaction: {}", tx_id);
373 Ok(())
374 } else {
375 Err(GraphError::CoordinatorError(format!(
376 "Transaction not found: {}",
377 tx_id
378 )))
379 }
380 }
381
382 pub async fn rollback_transaction(&self, tx_id: &str) -> Result<()> {
384 if let Some((_, tx)) = self.transactions.remove(tx_id) {
385 warn!("Rolling back transaction: {}", tx_id);
386 Ok(())
387 } else {
388 Err(GraphError::CoordinatorError(format!(
389 "Transaction not found: {}",
390 tx_id
391 )))
392 }
393 }
394
395 pub fn clear_cache(&self) {
397 self.query_cache.clear();
398 info!("Query cache cleared");
399 }
400}
401
402#[derive(Debug, Clone)]
404struct Transaction {
405 id: String,
407 shards: HashSet<ShardId>,
409 state: TransactionState,
411 created_at: DateTime<Utc>,
413}
414
415impl Transaction {
416 fn new(id: String) -> Self {
417 Self {
418 id,
419 shards: HashSet::new(),
420 state: TransactionState::Active,
421 created_at: Utc::now(),
422 }
423 }
424}
425
/// Lifecycle of a [`Transaction`]. Only `Active` is constructed in this module;
/// the remaining states are reserved for a commit protocol not implemented here.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum TransactionState {
    /// Open and accepting work.
    Active,
    /// Commit has begun.
    Preparing,
    /// Commit finished.
    Committed,
    /// Rolled back.
    Aborted,
}
434
435pub struct Coordinator {
437 shard_coordinator: Arc<ShardCoordinator>,
439 config: CoordinatorConfig,
441}
442
443impl Coordinator {
444 pub fn new(config: CoordinatorConfig) -> Self {
446 Self {
447 shard_coordinator: Arc::new(ShardCoordinator::new()),
448 config,
449 }
450 }
451
452 pub fn shard_coordinator(&self) -> Arc<ShardCoordinator> {
454 Arc::clone(&self.shard_coordinator)
455 }
456
457 pub async fn execute(&self, query: &str) -> Result<QueryResult> {
459 let plan = self.shard_coordinator.plan_query(query)?;
460 self.shard_coordinator.execute_query(plan).await
461 }
462
463 pub fn config(&self) -> &CoordinatorConfig {
465 &self.config
466 }
467}
468
469#[derive(Debug, Clone, Serialize, Deserialize)]
471pub struct CoordinatorConfig {
472 pub enable_cache: bool,
474 pub cache_ttl_seconds: u64,
476 pub max_query_time_seconds: u64,
478 pub enable_optimization: bool,
480}
481
482impl Default for CoordinatorConfig {
483 fn default() -> Self {
484 Self {
485 enable_cache: true,
486 cache_ttl_seconds: 300,
487 max_query_time_seconds: 60,
488 enable_optimization: true,
489 }
490 }
491}
492
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::shard::{ShardMetadata, ShardStrategy};

    /// Builds a coordinator with one freshly created shard registered as id 0.
    fn coordinator_with_one_shard() -> ShardCoordinator {
        let coordinator = ShardCoordinator::new();
        let metadata = ShardMetadata::new(0, "node-1".to_string(), ShardStrategy::Hash);
        coordinator.register_shard(0, Arc::new(GraphShard::new(metadata)));
        coordinator
    }

    #[test]
    fn test_shard_coordinator() {
        let coordinator = coordinator_with_one_shard();

        assert_eq!(coordinator.list_shards().len(), 1);
        assert!(coordinator.get_shard(0).is_some());
    }

    #[test]
    fn test_unregister_shard() {
        let coordinator = coordinator_with_one_shard();

        coordinator.unregister_shard(0).unwrap();
        assert!(coordinator.get_shard(0).is_none());
        assert!(coordinator.list_shards().is_empty());
        // Removing an unknown shard is an error, not a silent no-op.
        assert!(coordinator.unregister_shard(0).is_err());
    }

    #[test]
    fn test_query_planning() {
        let coordinator = coordinator_with_one_shard();

        let plan = coordinator.plan_query("MATCH (n:Person) RETURN n").unwrap();

        assert!(!plan.query_id.is_empty());
        assert!(!plan.steps.is_empty());
        assert_eq!(plan.target_shards, vec![0]);
    }

    #[test]
    fn test_query_planning_parses_count_and_limit() {
        let coordinator = coordinator_with_one_shard();

        let plan = coordinator
            .plan_query("MATCH (n) RETURN count(n) LIMIT 5")
            .unwrap();

        assert!(plan
            .steps
            .iter()
            .any(|step| matches!(step, QueryStep::NodeScan { shard_id: 0, .. })));
        assert!(plan.steps.iter().any(|step| matches!(
            step,
            QueryStep::Aggregate {
                operation: AggregateOp::Count,
                ..
            }
        )));
        assert!(plan
            .steps
            .iter()
            .any(|step| matches!(step, QueryStep::Limit { count: 5 })));
    }

    #[test]
    fn test_query_planning_without_shards() {
        let coordinator = ShardCoordinator::new();

        let plan = coordinator.plan_query("MATCH (n) RETURN n").unwrap();

        assert!(plan.target_shards.is_empty());
        assert!(plan.steps.is_empty());
    }

    #[tokio::test]
    async fn test_execute_count_on_empty_shard() {
        let coordinator = coordinator_with_one_shard();
        let plan = coordinator.plan_query("MATCH (n) RETURN count(n)").unwrap();

        let result = coordinator.execute_query(plan).await.unwrap();

        assert!(result.nodes.is_empty());
        assert_eq!(
            result.aggregates.get("count"),
            Some(&serde_json::Value::from(0u64))
        );
        assert_eq!(result.stats.shards_queried, 1);
    }

    #[tokio::test]
    async fn test_transaction() {
        let coordinator = ShardCoordinator::new();

        let tx_id = coordinator.begin_transaction();
        assert!(!tx_id.is_empty());

        coordinator.commit_transaction(&tx_id).await.unwrap();
        // A finished transaction cannot be committed or rolled back again.
        assert!(coordinator.commit_transaction(&tx_id).await.is_err());
        assert!(coordinator.rollback_transaction(&tx_id).await.is_err());
    }

    #[tokio::test]
    async fn test_transaction_rollback() {
        let coordinator = ShardCoordinator::new();

        let tx_id = coordinator.begin_transaction();
        coordinator.rollback_transaction(&tx_id).await.unwrap();

        assert!(coordinator.commit_transaction(&tx_id).await.is_err());
    }
}