rust_logic_graph/multi_db/
parallel.rs1use serde_json::Value;
2use std::collections::HashMap;
3use std::future::Future;
4use std::pin::Pin;
5use std::time::Instant;
6use tokio::task::JoinSet;
7use tracing::{debug, info};
8
9use crate::error::{ErrorContext, RustLogicGraphError};
10
11type BoxedFuture = Pin<Box<dyn Future<Output = Result<Value, RustLogicGraphError>> + Send>>;
12
13#[derive(Debug, Clone)]
15pub struct QueryResult {
16 pub database: String,
17 pub query: String,
18 pub result: Value,
19 pub duration_ms: u128,
20 pub row_count: usize,
21}
22
23pub struct ParallelDBExecutor {
53 queries: Vec<(String, String, BoxedFuture)>,
54 max_concurrent: usize,
55}
56
57impl ParallelDBExecutor {
58 pub fn new() -> Self {
60 Self {
61 queries: Vec::new(),
62 max_concurrent: 10, }
64 }
65
66 pub fn with_max_concurrent(mut self, max: usize) -> Self {
68 self.max_concurrent = max;
69 self
70 }
71
72 pub fn add_query<F, Fut>(
79 &mut self,
80 database: impl Into<String>,
81 query_id: impl Into<String>,
82 query_fn: F,
83 ) -> &mut Self
84 where
85 F: FnOnce() -> Fut + Send + 'static,
86 Fut: Future<Output = Result<Value, RustLogicGraphError>> + Send + 'static,
87 {
88 let db = database.into();
89 let qid = query_id.into();
90
91 let boxed: BoxedFuture = Box::pin(query_fn());
93 self.queries.push((db, qid, boxed));
94 self
95 }
96
97 pub async fn execute_all(
104 &mut self,
105 ) -> Result<HashMap<String, QueryResult>, RustLogicGraphError> {
106 let total_start = Instant::now();
107 let query_count = self.queries.len();
108
109 info!(
110 "🚀 Parallel DB Executor: Starting {} queries across databases",
111 query_count
112 );
113
114 if query_count == 0 {
115 return Ok(HashMap::new());
116 }
117
118 let mut join_set = JoinSet::new();
119 let mut results = HashMap::new();
120
121 let queries = std::mem::take(&mut self.queries);
123
124 for (database, query_id, query_future) in queries {
126 let db_clone = database.clone();
127 let qid_clone = query_id.clone();
128
129 join_set.spawn(async move {
130 let start = Instant::now();
131 debug!(
132 "⏱️ Executing query '{}' on database '{}'",
133 qid_clone, db_clone
134 );
135
136 match query_future.await {
137 Ok(result) => {
138 let duration_ms = start.elapsed().as_millis();
139 let row_count = if result.is_array() {
140 result.as_array().map(|arr| arr.len()).unwrap_or(0)
141 } else if result.is_object() {
142 1
143 } else {
144 0
145 };
146
147 debug!(
148 "✅ Query '{}' completed in {}ms ({} rows)",
149 qid_clone, duration_ms, row_count
150 );
151
152 Ok((
153 query_id,
154 QueryResult {
155 database: db_clone,
156 query: qid_clone,
157 result,
158 duration_ms,
159 row_count,
160 },
161 ))
162 }
163 Err(e) => Err(e.with_context(
164 ErrorContext::new()
165 .with_service(&db_clone)
166 .add_metadata("query_id", &qid_clone),
167 )),
168 }
169 });
170 }
171
172 while let Some(task_result) = join_set.join_next().await {
174 match task_result {
175 Ok(Ok((query_id, query_result))) => {
176 results.insert(query_id, query_result);
177 }
178 Ok(Err(e)) => {
179 join_set.abort_all();
181 return Err(e);
182 }
183 Err(join_err) => {
184 join_set.abort_all();
185 return Err(RustLogicGraphError::node_execution_error(
186 "parallel_executor",
187 format!("Task join error: {}", join_err),
188 ));
189 }
190 }
191 }
192
193 let total_duration_ms = total_start.elapsed().as_millis();
194 let total_rows: usize = results.values().map(|r| r.row_count).sum();
195
196 info!(
197 "✅ Parallel DB Executor: Completed {} queries in {}ms ({} total rows)",
198 query_count, total_duration_ms, total_rows
199 );
200
201 let mut db_stats: HashMap<String, (usize, u128)> = HashMap::new();
203 for result in results.values() {
204 let entry = db_stats.entry(result.database.clone()).or_insert((0, 0));
205 entry.0 += 1; entry.1 += result.duration_ms; }
208
209 for (db, (count, duration)) in db_stats {
210 info!(" 📊 {}: {} queries, {}ms total", db, count, duration);
211 }
212
213 Ok(results)
214 }
215}
216
217impl Default for ParallelDBExecutor {
218 fn default() -> Self {
219 Self::new()
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Two succeeding queries on different databases both appear in the
    /// result map under their query ids.
    #[tokio::test]
    async fn test_parallel_execution() {
        let mut runner = ParallelDBExecutor::new();

        runner.add_query("db1", "query1", || async {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            Ok(json!({"result": 1}))
        });
        runner.add_query("db2", "query2", || async {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            Ok(json!({"result": 2}))
        });

        let outcome = runner.execute_all().await.unwrap();
        assert_eq!(outcome.len(), 2);
        assert!(outcome.contains_key("query1"));
        assert!(outcome.contains_key("query2"));
    }

    /// A single failing query makes the whole batch return an error.
    #[tokio::test]
    async fn test_error_propagation() {
        let mut runner = ParallelDBExecutor::new();

        runner.add_query("db1", "query1", || async { Ok(json!({"result": 1})) });
        runner.add_query("db2", "failing_query", || async {
            Err(RustLogicGraphError::database_connection_error("Test error"))
        });

        let outcome = runner.execute_all().await;
        assert!(outcome.is_err());
    }
}