// rust_logic_graph/multi_db/parallel.rs

use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::time::Instant;

use serde_json::Value;
use tokio::task::JoinSet;
use tracing::{debug, info, warn};

use crate::error::{ErrorContext, RustLogicGraphError};

11type BoxedFuture = Pin<Box<dyn Future<Output = Result<Value, RustLogicGraphError>> + Send>>;
12
13/// Result from a single database query
14#[derive(Debug, Clone)]
15pub struct QueryResult {
16    pub database: String,
17    pub query: String,
18    pub result: Value,
19    pub duration_ms: u128,
20    pub row_count: usize,
21}
22
23/// Parallel Database Executor
24/// 
25/// Executes multiple database queries concurrently across different databases,
26/// collecting results and providing detailed execution statistics.
27/// 
28/// # Example
29/// ```no_run
30/// use rust_logic_graph::multi_db::ParallelDBExecutor;
31/// 
32/// #[tokio::main]
33/// async fn main() -> anyhow::Result<()> {
34///     let mut executor = ParallelDBExecutor::new();
35///     
36///     // Register query closures for different databases
37///     executor
38///         .add_query("oms_db", "user_query", Box::new(|| Box::pin(async {
39///             // Execute query against OMS database
40///             Ok(serde_json::json!({"user_id": 123, "name": "John"}))
41///         })))
42///         .add_query("inventory_db", "stock_query", Box::new(|| Box::pin(async {
43///             // Execute query against Inventory database
44///             Ok(serde_json::json!({"product_id": "PROD-001", "qty": 50}))
45///         })));
46///     
47///     let results = executor.execute_all().await?;
48///     println!("Executed {} queries in parallel", results.len());
49///     Ok(())
50/// }
51/// ```
52pub struct ParallelDBExecutor {
53    queries: Vec<(String, String, BoxedFuture)>,
54    max_concurrent: usize,
55}
56
57impl ParallelDBExecutor {
58    /// Create a new parallel database executor
59    pub fn new() -> Self {
60        Self {
61            queries: Vec::new(),
62            max_concurrent: 10, // Default: 10 concurrent queries
63        }
64    }
65    
66    /// Set maximum number of concurrent queries
67    pub fn with_max_concurrent(mut self, max: usize) -> Self {
68        self.max_concurrent = max;
69        self
70    }
71    
72    /// Add a query to execute
73    /// 
74    /// # Arguments
75    /// * `database` - Database identifier (e.g., "oms_db", "inventory_db")
76    /// * `query_id` - Unique query identifier for tracking
77    /// * `query_fn` - Async closure that executes the query
78    pub fn add_query<F, Fut>(
79        &mut self,
80        database: impl Into<String>,
81        query_id: impl Into<String>,
82        query_fn: F,
83    ) -> &mut Self
84    where
85        F: FnOnce() -> Fut + Send + 'static,
86        Fut: Future<Output = Result<Value, RustLogicGraphError>> + Send + 'static,
87    {
88        let db = database.into();
89        let qid = query_id.into();
90        
91        // Box the future
92        let boxed: BoxedFuture = Box::pin(query_fn());
93        self.queries.push((db, qid, boxed));
94        self
95    }
96    
97    /// Execute all registered queries in parallel
98    /// 
99    /// Returns a HashMap with query_id as key and QueryResult as value.
100    /// 
101    /// # Errors
102    /// If any query fails, returns the first encountered error.
103    pub async fn execute_all(&mut self) -> Result<HashMap<String, QueryResult>, RustLogicGraphError> {
104        let total_start = Instant::now();
105        let query_count = self.queries.len();
106        
107        info!("🚀 Parallel DB Executor: Starting {} queries across databases", query_count);
108        
109        if query_count == 0 {
110            return Ok(HashMap::new());
111        }
112        
113        let mut join_set = JoinSet::new();
114        let mut results = HashMap::new();
115        
116        // Take ownership of queries
117        let queries = std::mem::take(&mut self.queries);
118        
119        // Spawn all queries as concurrent tasks
120        for (database, query_id, query_future) in queries {
121            let db_clone = database.clone();
122            let qid_clone = query_id.clone();
123            
124            join_set.spawn(async move {
125                let start = Instant::now();
126                debug!("⏱️  Executing query '{}' on database '{}'", qid_clone, db_clone);
127                
128                match query_future.await {
129                    Ok(result) => {
130                        let duration_ms = start.elapsed().as_millis();
131                        let row_count = if result.is_array() {
132                            result.as_array().map(|arr| arr.len()).unwrap_or(0)
133                        } else if result.is_object() {
134                            1
135                        } else {
136                            0
137                        };
138                        
139                        debug!("✅ Query '{}' completed in {}ms ({} rows)", qid_clone, duration_ms, row_count);
140                        
141                        Ok((query_id, QueryResult {
142                            database: db_clone,
143                            query: qid_clone,
144                            result,
145                            duration_ms,
146                            row_count,
147                        }))
148                    }
149                    Err(e) => {
150                        Err(e.with_context(
151                            ErrorContext::new()
152                                .with_service(&db_clone)
153                                .add_metadata("query_id", &qid_clone)
154                        ))
155                    }
156                }
157            });
158        }
159        
160        // Collect all results
161        while let Some(task_result) = join_set.join_next().await {
162            match task_result {
163                Ok(Ok((query_id, query_result))) => {
164                    results.insert(query_id, query_result);
165                }
166                Ok(Err(e)) => {
167                    // Cancel remaining tasks and return error
168                    join_set.abort_all();
169                    return Err(e);
170                }
171                Err(join_err) => {
172                    join_set.abort_all();
173                    return Err(RustLogicGraphError::node_execution_error(
174                        "parallel_executor",
175                        format!("Task join error: {}", join_err)
176                    ));
177                }
178            }
179        }
180        
181        let total_duration_ms = total_start.elapsed().as_millis();
182        let total_rows: usize = results.values().map(|r| r.row_count).sum();
183        
184        info!("✅ Parallel DB Executor: Completed {} queries in {}ms ({} total rows)", 
185            query_count, total_duration_ms, total_rows);
186        
187        // Log per-database statistics
188        let mut db_stats: HashMap<String, (usize, u128)> = HashMap::new();
189        for result in results.values() {
190            let entry = db_stats.entry(result.database.clone()).or_insert((0, 0));
191            entry.0 += 1; // query count
192            entry.1 += result.duration_ms; // total duration
193        }
194        
195        for (db, (count, duration)) in db_stats {
196            info!("  📊 {}: {} queries, {}ms total", db, count, duration);
197        }
198        
199        Ok(results)
200    }
201}
202
203impl Default for ParallelDBExecutor {
204    fn default() -> Self {
205        Self::new()
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn test_parallel_execution() {
        let mut executor = ParallelDBExecutor::new();

        executor
            .add_query("db1", "query1", || async {
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                Ok(json!({"result": 1}))
            })
            .add_query("db2", "query2", || async {
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                Ok(json!({"result": 2}))
            });

        let results = executor.execute_all().await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.contains_key("query1"));
        assert!(results.contains_key("query2"));
    }

    #[tokio::test]
    async fn test_error_propagation() {
        let mut executor = ParallelDBExecutor::new();

        executor
            .add_query("db1", "query1", || async {
                Ok(json!({"result": 1}))
            })
            .add_query("db2", "failing_query", || async {
                Err(RustLogicGraphError::database_connection_error("Test error"))
            });

        let result = executor.execute_all().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_empty_executor_returns_empty_map() {
        let mut executor = ParallelDBExecutor::new();
        let results = executor.execute_all().await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_low_concurrency_limit_still_runs_every_query() {
        // Includes 0, which must not stall execution.
        for limit in 0..3 {
            let mut executor = ParallelDBExecutor::new().with_max_concurrent(limit);
            for i in 0..5 {
                executor.add_query("db", format!("query{}", i), move || async move {
                    Ok(json!({ "result": i }))
                });
            }

            let results = executor.execute_all().await.unwrap();
            assert_eq!(results.len(), 5);
        }
    }

    #[tokio::test]
    async fn test_queries_are_drained_after_execution() {
        let mut executor = ParallelDBExecutor::new();
        executor.add_query("db", "query1", || async { Ok(json!({"result": 1})) });

        assert_eq!(executor.execute_all().await.unwrap().len(), 1);
        assert!(executor.execute_all().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_row_count_by_result_shape() {
        let mut executor = ParallelDBExecutor::new();

        executor
            .add_query("db", "array", || async { Ok(json!([1, 2, 3])) })
            .add_query("db", "object", || async { Ok(json!({"a": 1})) })
            .add_query("db", "scalar", || async { Ok(json!(42)) });

        let results = executor.execute_all().await.unwrap();
        assert_eq!(results["array"].row_count, 3);
        assert_eq!(results["object"].row_count, 1);
        assert_eq!(results["scalar"].row_count, 0);
        assert_eq!(results["array"].database, "db");
        assert_eq!(results["array"].query, "array");
    }
}
249}