1use crate::config::ExecutionConfig;
6use crate::errors::{ExecutionError, Result};
7use crate::events::EventHandler;
8use crate::executor::Executor;
9use crate::types::{
10 ExecutionRequest, ExecutionResult, ExecutionState, ExecutionStatus, ExecutionSummary,
11};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::{RwLock, Semaphore};
15use tokio_util::sync::CancellationToken;
16use uuid::Uuid;
17
18#[derive(Clone)]
26pub struct ExecutionEngine {
27 config: ExecutionConfig,
28 executions: Arc<RwLock<HashMap<Uuid, Arc<RwLock<ExecutionState>>>>>,
29 event_handler: Option<Arc<dyn EventHandler>>,
30 semaphore: Arc<Semaphore>,
31 executor: Arc<Executor>,
32}
33
34impl ExecutionEngine {
35 pub fn new(config: ExecutionConfig) -> Result<Self> {
37 config
39 .validate()
40 .map_err(ExecutionError::InvalidConfig)?;
41
42 let executor = Executor::new(config.clone());
43 let semaphore = Arc::new(Semaphore::new(config.max_concurrent_executions));
44
45 Ok(Self {
46 config,
47 executions: Arc::new(RwLock::new(HashMap::new())),
48 event_handler: None,
49 semaphore,
50 executor: Arc::new(executor),
51 })
52 }
53
54 pub fn with_event_handler(mut self, handler: Arc<dyn EventHandler>) -> Self {
56 self.event_handler = Some(handler.clone());
57
58 let executor = Executor::new(self.config.clone()).with_event_handler(handler);
60 self.executor = Arc::new(executor);
61
62 self
63 }
64
65 pub async fn execute(&self, request: ExecutionRequest) -> Result<Uuid> {
70 let execution_id = request.id;
71
72 let cancel_token = CancellationToken::new();
74 let state = Arc::new(RwLock::new(ExecutionState::new(request.clone())));
75
76 {
78 let mut executions = self.executions.write().await;
79 executions.insert(execution_id, state.clone());
80 }
81
82 let semaphore = self.semaphore.clone();
84 let current_permits = semaphore.available_permits();
85
86 if current_permits == 0 {
87 return Err(ExecutionError::ConcurrencyLimitReached(
89 self.config.max_concurrent_executions,
90 ));
91 }
92
93 let permit = semaphore
95 .clone()
96 .acquire_owned()
97 .await
98 .map_err(|_| ExecutionError::Internal("Semaphore closed".to_string()))?;
99
100 let executor = self.executor.clone();
102
103 tokio::spawn(async move {
104 let result = executor.execute(request, state.clone(), cancel_token).await;
106
107 if let Ok(ref exec_result) = result {
109 let _ = executor.write_logs(execution_id, exec_result).await;
110 }
111
112 drop(permit);
114
115 result
119 });
120
121 Ok(execution_id)
122 }
123
124 pub async fn get_status(&self, execution_id: Uuid) -> Result<ExecutionStatus> {
126 let executions = self.executions.read().await;
127 let state = executions
128 .get(&execution_id)
129 .ok_or(ExecutionError::NotFound(execution_id))?;
130
131 let state_lock = state.read().await;
132 Ok(state_lock.status)
133 }
134
135 pub async fn get_result(&self, execution_id: Uuid) -> Result<ExecutionResult> {
137 let executions = self.executions.read().await;
138 let state = executions
139 .get(&execution_id)
140 .ok_or(ExecutionError::NotFound(execution_id))?;
141
142 let state_lock = state.read().await;
143
144 if !state_lock.status.is_terminal() {
145 return Err(ExecutionError::Internal(format!(
146 "Execution {} is still running (status: {:?})",
147 execution_id, state_lock.status
148 )));
149 }
150
151 Ok(state_lock.to_result())
152 }
153
154 pub async fn wait_for_completion(&self, execution_id: Uuid) -> Result<ExecutionResult> {
156 loop {
158 let status = self.get_status(execution_id).await?;
159
160 if status.is_terminal() {
161 return self.get_result(execution_id).await;
162 }
163
164 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
166 }
167 }
168
169 pub async fn cancel(&self, execution_id: Uuid) -> Result<()> {
171 let executions = self.executions.read().await;
172 let state = executions
173 .get(&execution_id)
174 .ok_or(ExecutionError::NotFound(execution_id))?;
175
176 let state_lock = state.read().await;
177
178 if state_lock.status.is_terminal() {
180 return Err(ExecutionError::Internal(format!(
181 "Cannot cancel execution {} - already in terminal state: {:?}",
182 execution_id, state_lock.status
183 )));
184 }
185
186 Err(ExecutionError::Internal(
190 "Direct cancellation not yet implemented - use timeout instead".to_string(),
191 ))
192 }
193
194 pub async fn list_executions(&self) -> Vec<ExecutionSummary> {
196 let executions = self.executions.read().await;
197 let mut summaries = Vec::new();
198
199 for (id, state) in executions.iter() {
200 let state_lock = state.read().await;
201 let duration = state_lock.completed_at.map(|completed| (completed - state_lock.started_at)
202 .to_std()
203 .unwrap_or(std::time::Duration::from_secs(0)));
204
205 summaries.push(ExecutionSummary {
206 id: *id,
207 status: state_lock.status,
208 started_at: state_lock.started_at,
209 duration,
210 });
211 }
212
213 summaries.sort_by(|a, b| b.started_at.cmp(&a.started_at));
215
216 summaries
217 }
218
219 pub async fn running_count(&self) -> usize {
221 let executions = self.executions.read().await;
222 let mut count = 0;
223
224 for (_, state) in executions.iter() {
225 let state_lock = state.read().await;
226 if state_lock.status == ExecutionStatus::Running
227 || state_lock.status == ExecutionStatus::Pending
228 {
229 count += 1;
230 }
231 }
232
233 count
234 }
235
236 pub async fn total_count(&self) -> usize {
238 let executions = self.executions.read().await;
239 executions.len()
240 }
241
242 pub async fn read_logs(&self, execution_id: Uuid) -> Result<String> {
244 self.executor.read_logs(execution_id).await
245 }
246
247 pub fn config(&self) -> &ExecutionConfig {
249 &self.config
250 }
251
252 pub fn available_permits(&self) -> usize {
254 self.semaphore.available_permits()
255 }
256
257 pub async fn cleanup_old_executions(&self) -> usize {
265 crate::cleanup::cleanup_old_executions(
266 &self.executions,
267 self.config.execution_retention_secs,
268 self.config.max_in_memory_executions,
269 )
270 .await
271 }
272
273 pub async fn remove_execution(&self, execution_id: Uuid) -> Result<()> {
277 let removed = crate::cleanup::remove_execution(&self.executions, execution_id).await;
278
279 if removed {
280 Ok(())
281 } else {
282 Err(ExecutionError::NotFound(execution_id))
283 }
284 }
285
286 pub fn start_cleanup_task(self: Arc<Self>) {
293 if !self.config.enable_auto_cleanup {
294 return;
295 }
296
297 tokio::spawn(async move {
298 let mut interval = tokio::time::interval(std::time::Duration::from_secs(300)); loop {
301 interval.tick().await;
302
303 let removed = self.cleanup_old_executions().await;
304
305 if removed > 0 {
306 tracing::info!("Cleanup task removed {} old executions", removed);
307 }
308 }
309 });
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Command;
    use std::collections::HashMap;
    use std::time::Duration;

    /// Upper bound on waiting for a short command. A stuck execution fails the test instead
    /// of hanging the whole suite.
    const COMPLETION_TIMEOUT: Duration = Duration::from_secs(5);

    /// Builds a bash request that runs `command` with a 5 second timeout.
    fn shell_request(command: &str) -> ExecutionRequest {
        ExecutionRequest {
            id: Uuid::new_v4(),
            command: Command::Shell {
                command: command.to_string(),
                shell: "bash".to_string(),
            },
            env: HashMap::new(),
            working_dir: None,
            timeout_ms: Some(5000),
            metadata: Default::default(),
        }
    }

    /// A quick command that should complete successfully.
    fn create_test_request() -> ExecutionRequest {
        shell_request("echo 'test'")
    }

    /// A command that keeps its concurrency slot for about two seconds.
    fn long_running_request() -> ExecutionRequest {
        ExecutionRequest {
            timeout_ms: Some(10000),
            ..shell_request("sleep 2")
        }
    }

    #[tokio::test]
    async fn test_engine_creation() {
        let engine = ExecutionEngine::new(ExecutionConfig::default());
        assert!(engine.is_ok());
    }

    #[tokio::test]
    async fn test_engine_invalid_config() {
        let config = ExecutionConfig {
            max_concurrent_executions: 0,
            ..Default::default()
        };

        let engine = ExecutionEngine::new(config);
        assert!(matches!(engine, Err(ExecutionError::InvalidConfig(_))));
    }

    #[tokio::test]
    async fn test_engine_execute_simple() {
        let engine = ExecutionEngine::new(ExecutionConfig::default()).unwrap();
        let execution_id = engine.execute(create_test_request()).await.unwrap();

        // Wait for the run itself rather than sleeping a fixed amount; a fixed sleep is
        // flaky on slow machines.
        tokio::time::timeout(COMPLETION_TIMEOUT, engine.wait_for_completion(execution_id))
            .await
            .expect("execution did not finish in time")
            .unwrap();

        let status = engine.get_status(execution_id).await.unwrap();
        assert_eq!(status, ExecutionStatus::Completed);
    }

    #[tokio::test]
    async fn test_engine_wait_for_completion() {
        let engine = ExecutionEngine::new(ExecutionConfig::default()).unwrap();
        let execution_id = engine.execute(create_test_request()).await.unwrap();

        let result =
            tokio::time::timeout(COMPLETION_TIMEOUT, engine.wait_for_completion(execution_id))
                .await
                .expect("execution did not finish in time")
                .unwrap();
        assert_eq!(result.status, ExecutionStatus::Completed);
        assert_eq!(result.exit_code, 0);
    }

    #[tokio::test]
    async fn test_engine_get_result_before_complete() {
        let engine = ExecutionEngine::new(ExecutionConfig::default()).unwrap();
        let execution_id = engine.execute(shell_request("sleep 1")).await.unwrap();

        let result = engine.get_result(execution_id).await;
        assert!(matches!(result, Err(ExecutionError::Internal(_))));
    }

    #[tokio::test]
    async fn test_engine_list_executions() {
        let engine = ExecutionEngine::new(ExecutionConfig::default()).unwrap();

        // Executions are registered before `execute` returns, so no sleep is needed.
        engine.execute(create_test_request()).await.unwrap();
        engine.execute(create_test_request()).await.unwrap();

        let list = engine.list_executions().await;
        assert_eq!(list.len(), 2);
    }

    #[tokio::test]
    async fn test_engine_running_count() {
        let engine = ExecutionEngine::new(ExecutionConfig::default()).unwrap();
        assert_eq!(engine.running_count().await, 0);

        engine.execute(long_running_request()).await.unwrap();
        tokio::time::sleep(Duration::from_millis(100)).await;

        assert!(engine.running_count().await > 0);
    }

    #[tokio::test]
    async fn test_engine_concurrency_limit() {
        let config = ExecutionConfig {
            max_concurrent_executions: 2,
            ..Default::default()
        };
        let engine = ExecutionEngine::new(config).unwrap();

        // Occupy both slots. `execute` holds the slot before it returns, so the limit can
        // be probed immediately afterwards.
        engine.execute(long_running_request()).await.unwrap();
        engine.execute(long_running_request()).await.unwrap();

        let result = engine.execute(create_test_request()).await;
        assert!(matches!(
            result,
            Err(ExecutionError::ConcurrencyLimitReached(2))
        ));
    }

    #[tokio::test]
    async fn test_engine_available_permits() {
        let config = ExecutionConfig {
            max_concurrent_executions: 5,
            ..Default::default()
        };
        let engine = ExecutionEngine::new(config).unwrap();
        assert_eq!(engine.available_permits(), 5);

        // A long-running command keeps its slot, so exactly one permit is in use now.
        // (The previous `permits <= 5` assertion could never fail.)
        engine.execute(long_running_request()).await.unwrap();
        assert_eq!(engine.available_permits(), 4);
    }

    #[tokio::test]
    async fn test_engine_not_found() {
        let engine = ExecutionEngine::new(ExecutionConfig::default()).unwrap();
        let fake_id = Uuid::new_v4();

        assert!(matches!(
            engine.get_status(fake_id).await,
            Err(ExecutionError::NotFound(id)) if id == fake_id
        ));
        assert!(matches!(
            engine.get_result(fake_id).await,
            Err(ExecutionError::NotFound(id)) if id == fake_id
        ));
        assert!(matches!(
            engine.cancel(fake_id).await,
            Err(ExecutionError::NotFound(id)) if id == fake_id
        ));
    }
}
537}