rust_logic_graph/multi_db/
parallel.rs1use std::collections::HashMap;
2use std::time::Instant;
3use std::pin::Pin;
4use std::future::Future;
5use serde_json::Value;
6use tokio::task::JoinSet;
7use tracing::{info, debug};
8
9use crate::error::{RustLogicGraphError, ErrorContext};
10
11type BoxedFuture = Pin<Box<dyn Future<Output = Result<Value, RustLogicGraphError>> + Send>>;
12
13#[derive(Debug, Clone)]
15pub struct QueryResult {
16 pub database: String,
17 pub query: String,
18 pub result: Value,
19 pub duration_ms: u128,
20 pub row_count: usize,
21}
22
23pub struct ParallelDBExecutor {
53 queries: Vec<(String, String, BoxedFuture)>,
54 max_concurrent: usize,
55}
56
57impl ParallelDBExecutor {
58 pub fn new() -> Self {
60 Self {
61 queries: Vec::new(),
62 max_concurrent: 10, }
64 }
65
66 pub fn with_max_concurrent(mut self, max: usize) -> Self {
68 self.max_concurrent = max;
69 self
70 }
71
72 pub fn add_query<F, Fut>(
79 &mut self,
80 database: impl Into<String>,
81 query_id: impl Into<String>,
82 query_fn: F,
83 ) -> &mut Self
84 where
85 F: FnOnce() -> Fut + Send + 'static,
86 Fut: Future<Output = Result<Value, RustLogicGraphError>> + Send + 'static,
87 {
88 let db = database.into();
89 let qid = query_id.into();
90
91 let boxed: BoxedFuture = Box::pin(query_fn());
93 self.queries.push((db, qid, boxed));
94 self
95 }
96
97 pub async fn execute_all(&mut self) -> Result<HashMap<String, QueryResult>, RustLogicGraphError> {
104 let total_start = Instant::now();
105 let query_count = self.queries.len();
106
107 info!("🚀 Parallel DB Executor: Starting {} queries across databases", query_count);
108
109 if query_count == 0 {
110 return Ok(HashMap::new());
111 }
112
113 let mut join_set = JoinSet::new();
114 let mut results = HashMap::new();
115
116 let queries = std::mem::take(&mut self.queries);
118
119 for (database, query_id, query_future) in queries {
121 let db_clone = database.clone();
122 let qid_clone = query_id.clone();
123
124 join_set.spawn(async move {
125 let start = Instant::now();
126 debug!("⏱️ Executing query '{}' on database '{}'", qid_clone, db_clone);
127
128 match query_future.await {
129 Ok(result) => {
130 let duration_ms = start.elapsed().as_millis();
131 let row_count = if result.is_array() {
132 result.as_array().map(|arr| arr.len()).unwrap_or(0)
133 } else if result.is_object() {
134 1
135 } else {
136 0
137 };
138
139 debug!("✅ Query '{}' completed in {}ms ({} rows)", qid_clone, duration_ms, row_count);
140
141 Ok((query_id, QueryResult {
142 database: db_clone,
143 query: qid_clone,
144 result,
145 duration_ms,
146 row_count,
147 }))
148 }
149 Err(e) => {
150 Err(e.with_context(
151 ErrorContext::new()
152 .with_service(&db_clone)
153 .add_metadata("query_id", &qid_clone)
154 ))
155 }
156 }
157 });
158 }
159
160 while let Some(task_result) = join_set.join_next().await {
162 match task_result {
163 Ok(Ok((query_id, query_result))) => {
164 results.insert(query_id, query_result);
165 }
166 Ok(Err(e)) => {
167 join_set.abort_all();
169 return Err(e);
170 }
171 Err(join_err) => {
172 join_set.abort_all();
173 return Err(RustLogicGraphError::node_execution_error(
174 "parallel_executor",
175 format!("Task join error: {}", join_err)
176 ));
177 }
178 }
179 }
180
181 let total_duration_ms = total_start.elapsed().as_millis();
182 let total_rows: usize = results.values().map(|r| r.row_count).sum();
183
184 info!("✅ Parallel DB Executor: Completed {} queries in {}ms ({} total rows)",
185 query_count, total_duration_ms, total_rows);
186
187 let mut db_stats: HashMap<String, (usize, u128)> = HashMap::new();
189 for result in results.values() {
190 let entry = db_stats.entry(result.database.clone()).or_insert((0, 0));
191 entry.0 += 1; entry.1 += result.duration_ms; }
194
195 for (db, (count, duration)) in db_stats {
196 info!(" 📊 {}: {} queries, {}ms total", db, count, duration);
197 }
198
199 Ok(results)
200 }
201}
202
203impl Default for ParallelDBExecutor {
204 fn default() -> Self {
205 Self::new()
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::time::Duration;

    /// Two independent queries both complete and are keyed by their query ids.
    #[tokio::test]
    async fn test_parallel_execution() {
        let mut executor = ParallelDBExecutor::new();

        executor.add_query("db1", "query1", || async {
            tokio::time::sleep(Duration::from_millis(10)).await;
            Ok(json!({"result": 1}))
        });
        executor.add_query("db2", "query2", || async {
            tokio::time::sleep(Duration::from_millis(10)).await;
            Ok(json!({"result": 2}))
        });

        let results = executor.execute_all().await.unwrap();

        assert_eq!(results.len(), 2);
        for key in &["query1", "query2"] {
            assert!(results.contains_key(*key));
        }
    }

    /// A single failing query makes the whole batch fail.
    #[tokio::test]
    async fn test_error_propagation() {
        let mut executor = ParallelDBExecutor::new();

        executor.add_query("db1", "query1", || async { Ok(json!({"result": 1})) });
        executor.add_query("db2", "failing_query", || async {
            Err(RustLogicGraphError::database_connection_error("Test error"))
        });

        let outcome = executor.execute_all().await;
        assert!(outcome.is_err());
    }
}