rust_logic_graph/multi_db/
parallel.rs

1use serde_json::Value;
2use std::collections::HashMap;
3use std::future::Future;
4use std::pin::Pin;
5use std::time::Instant;
6use tokio::task::JoinSet;
7use tracing::{debug, info};
8
9use crate::error::{ErrorContext, RustLogicGraphError};
10
11type BoxedFuture = Pin<Box<dyn Future<Output = Result<Value, RustLogicGraphError>> + Send>>;
12
/// Result from a single database query.
///
/// One value is produced for each query whose future resolves to `Ok`.
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Identifier of the database the query was registered against.
    pub database: String,
    /// Identifier the query was registered under (the `query_id`), not the query text.
    pub query: String,
    /// JSON value the query future resolved to.
    pub result: Value,
    /// Time spent awaiting the query future, in milliseconds.
    pub duration_ms: u128,
    /// Number of elements when `result` is a JSON array, `1` when it is a JSON
    /// object, and `0` for any other JSON value.
    pub row_count: usize,
}
22
/// Parallel Database Executor
///
/// Executes multiple database queries concurrently across different databases,
/// collecting results and providing detailed execution statistics.
///
/// Queries are registered with `add_query` and run together by `execute_all`,
/// which returns the results keyed by query id.
///
/// # Example
/// ```no_run
/// use rust_logic_graph::multi_db::ParallelDBExecutor;
///
/// #[tokio::main]
/// async fn main() -> anyhow::Result<()> {
///     let mut executor = ParallelDBExecutor::new();
///
///     // Register query closures for different databases. `add_query` is
///     // generic over the closure and its future, so no boxing is needed.
///     executor
///         .add_query("oms_db", "user_query", || async {
///             // Execute query against OMS database
///             Ok(serde_json::json!({"user_id": 123, "name": "John"}))
///         })
///         .add_query("inventory_db", "stock_query", || async {
///             // Execute query against Inventory database
///             Ok(serde_json::json!({"product_id": "PROD-001", "qty": 50}))
///         });
///
///     let results = executor.execute_all().await?;
///     println!("Executed {} queries in parallel", results.len());
///     Ok(())
/// }
/// ```
pub struct ParallelDBExecutor {
    /// Registered queries as `(database, query_id, future)` tuples, in
    /// registration order.
    queries: Vec<(String, String, BoxedFuture)>,
    /// Concurrency limit set through `with_max_concurrent`; `new` sets it to 10.
    max_concurrent: usize,
}
56
57impl ParallelDBExecutor {
58    /// Create a new parallel database executor
59    pub fn new() -> Self {
60        Self {
61            queries: Vec::new(),
62            max_concurrent: 10, // Default: 10 concurrent queries
63        }
64    }
65
66    /// Set maximum number of concurrent queries
67    pub fn with_max_concurrent(mut self, max: usize) -> Self {
68        self.max_concurrent = max;
69        self
70    }
71
72    /// Add a query to execute
73    ///
74    /// # Arguments
75    /// * `database` - Database identifier (e.g., "oms_db", "inventory_db")
76    /// * `query_id` - Unique query identifier for tracking
77    /// * `query_fn` - Async closure that executes the query
78    pub fn add_query<F, Fut>(
79        &mut self,
80        database: impl Into<String>,
81        query_id: impl Into<String>,
82        query_fn: F,
83    ) -> &mut Self
84    where
85        F: FnOnce() -> Fut + Send + 'static,
86        Fut: Future<Output = Result<Value, RustLogicGraphError>> + Send + 'static,
87    {
88        let db = database.into();
89        let qid = query_id.into();
90
91        // Box the future
92        let boxed: BoxedFuture = Box::pin(query_fn());
93        self.queries.push((db, qid, boxed));
94        self
95    }
96
97    /// Execute all registered queries in parallel
98    ///
99    /// Returns a HashMap with query_id as key and QueryResult as value.
100    ///
101    /// # Errors
102    /// If any query fails, returns the first encountered error.
103    pub async fn execute_all(
104        &mut self,
105    ) -> Result<HashMap<String, QueryResult>, RustLogicGraphError> {
106        let total_start = Instant::now();
107        let query_count = self.queries.len();
108
109        info!(
110            "🚀 Parallel DB Executor: Starting {} queries across databases",
111            query_count
112        );
113
114        if query_count == 0 {
115            return Ok(HashMap::new());
116        }
117
118        let mut join_set = JoinSet::new();
119        let mut results = HashMap::new();
120
121        // Take ownership of queries
122        let queries = std::mem::take(&mut self.queries);
123
124        // Spawn all queries as concurrent tasks
125        for (database, query_id, query_future) in queries {
126            let db_clone = database.clone();
127            let qid_clone = query_id.clone();
128
129            join_set.spawn(async move {
130                let start = Instant::now();
131                debug!(
132                    "⏱️  Executing query '{}' on database '{}'",
133                    qid_clone, db_clone
134                );
135
136                match query_future.await {
137                    Ok(result) => {
138                        let duration_ms = start.elapsed().as_millis();
139                        let row_count = if result.is_array() {
140                            result.as_array().map(|arr| arr.len()).unwrap_or(0)
141                        } else if result.is_object() {
142                            1
143                        } else {
144                            0
145                        };
146
147                        debug!(
148                            "✅ Query '{}' completed in {}ms ({} rows)",
149                            qid_clone, duration_ms, row_count
150                        );
151
152                        Ok((
153                            query_id,
154                            QueryResult {
155                                database: db_clone,
156                                query: qid_clone,
157                                result,
158                                duration_ms,
159                                row_count,
160                            },
161                        ))
162                    }
163                    Err(e) => Err(e.with_context(
164                        ErrorContext::new()
165                            .with_service(&db_clone)
166                            .add_metadata("query_id", &qid_clone),
167                    )),
168                }
169            });
170        }
171
172        // Collect all results
173        while let Some(task_result) = join_set.join_next().await {
174            match task_result {
175                Ok(Ok((query_id, query_result))) => {
176                    results.insert(query_id, query_result);
177                }
178                Ok(Err(e)) => {
179                    // Cancel remaining tasks and return error
180                    join_set.abort_all();
181                    return Err(e);
182                }
183                Err(join_err) => {
184                    join_set.abort_all();
185                    return Err(RustLogicGraphError::node_execution_error(
186                        "parallel_executor",
187                        format!("Task join error: {}", join_err),
188                    ));
189                }
190            }
191        }
192
193        let total_duration_ms = total_start.elapsed().as_millis();
194        let total_rows: usize = results.values().map(|r| r.row_count).sum();
195
196        info!(
197            "✅ Parallel DB Executor: Completed {} queries in {}ms ({} total rows)",
198            query_count, total_duration_ms, total_rows
199        );
200
201        // Log per-database statistics
202        let mut db_stats: HashMap<String, (usize, u128)> = HashMap::new();
203        for result in results.values() {
204            let entry = db_stats.entry(result.database.clone()).or_insert((0, 0));
205            entry.0 += 1; // query count
206            entry.1 += result.duration_ms; // total duration
207        }
208
209        for (db, (count, duration)) in db_stats {
210            info!("  📊 {}: {} queries, {}ms total", db, count, duration);
211        }
212
213        Ok(results)
214    }
215}
216
217impl Default for ParallelDBExecutor {
218    fn default() -> Self {
219        Self::new()
220    }
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::time::Duration;

    /// Two queries on different databases both finish and are keyed by query id.
    #[tokio::test]
    async fn test_parallel_execution() {
        let pause = Duration::from_millis(10);
        let mut executor = ParallelDBExecutor::new();

        executor.add_query("db1", "query1", move || async move {
            tokio::time::sleep(pause).await;
            Ok(json!({"result": 1}))
        });
        executor.add_query("db2", "query2", move || async move {
            tokio::time::sleep(pause).await;
            Ok(json!({"result": 2}))
        });

        let results = executor.execute_all().await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(["query1", "query2"]
            .iter()
            .all(|id| results.contains_key(*id)));
    }

    /// A single failing query makes the whole execution return an error.
    #[tokio::test]
    async fn test_error_propagation() {
        let mut executor = ParallelDBExecutor::new();

        executor.add_query("db1", "query1", || async { Ok(json!({"result": 1})) });
        executor.add_query("db2", "failing_query", || async {
            Err(RustLogicGraphError::database_connection_error("Test error"))
        });

        let outcome = executor.execute_all().await;
        assert!(outcome.is_err());
    }
}