1use std::sync::Arc;
20use std::time::{Duration, Instant};
21
22use async_trait::async_trait;
23use serde_json::Value;
24use thulp_core::{ToolCall, ToolResult, Transport};
25
26use crate::{
27 calculate_delay, is_error_retryable, ExecutionConfig, ExecutionContext, ExecutionHooks,
28 NoOpHooks, RetryConfig, RetryableError, Skill, SkillError, SkillExecutor, SkillResult,
29 SkillStep, StepResult, TimeoutAction,
30};
31
32pub struct DefaultSkillExecutor<T, H = NoOpHooks> {
62 transport: Arc<T>,
63 hooks: Arc<H>,
64}
65
66impl<T: Transport> DefaultSkillExecutor<T, NoOpHooks> {
67 pub fn new(transport: T) -> Self {
69 Self {
70 transport: Arc::new(transport),
71 hooks: Arc::new(NoOpHooks),
72 }
73 }
74}
75
76impl<T: Transport, H: ExecutionHooks> DefaultSkillExecutor<T, H> {
77 pub fn with_hooks(transport: T, hooks: H) -> Self {
79 Self {
80 transport: Arc::new(transport),
81 hooks: Arc::new(hooks),
82 }
83 }
84
85 pub fn from_arcs(transport: Arc<T>, hooks: Arc<H>) -> Self {
90 Self { transport, hooks }
91 }
92
93 pub fn transport(&self) -> &T {
95 &self.transport
96 }
97
98 pub fn hooks(&self) -> &H {
100 &self.hooks
101 }
102
103 fn prepare_arguments(
109 &self,
110 args: &Value,
111 context: &ExecutionContext,
112 ) -> Result<Value, SkillError> {
113 self.substitute_value(args, &context.variables())
114 }
115
116 fn substitute_value(
118 &self,
119 value: &Value,
120 variables: &std::collections::HashMap<String, Value>,
121 ) -> Result<Value, SkillError> {
122 match value {
123 Value::String(s) => {
124 let trimmed = s.trim();
126 if trimmed.starts_with("{{") && trimmed.ends_with("}}") {
127 let inner = &trimmed[2..trimmed.len() - 2];
128 if !inner.contains("{{") && !inner.contains("}}") {
130 let var_name = inner.trim();
131 if let Some(var_value) = variables.get(var_name) {
132 return Ok(var_value.clone());
133 }
134 }
135 }
136
137 let mut result = s.clone();
139 for (key, var_value) in variables {
140 let placeholder = format!("{{{{{}}}}}", key);
141 if result.contains(&placeholder) {
142 let replacement = match var_value {
144 Value::String(s) => s.clone(),
145 Value::Null => "null".to_string(),
146 Value::Bool(b) => b.to_string(),
147 Value::Number(n) => n.to_string(),
148 _ => serde_json::to_string(var_value).map_err(|e| {
149 SkillError::InvalidConfig(format!(
150 "Failed to serialize value: {}",
151 e
152 ))
153 })?,
154 };
155 result = result.replace(&placeholder, &replacement);
156 }
157 }
158 Ok(Value::String(result))
159 }
160 Value::Array(arr) => {
161 let substituted: Result<Vec<Value>, SkillError> = arr
162 .iter()
163 .map(|v| self.substitute_value(v, variables))
164 .collect();
165 Ok(Value::Array(substituted?))
166 }
167 Value::Object(obj) => {
168 let mut new_obj = serde_json::Map::new();
169 for (k, v) in obj {
170 new_obj.insert(k.clone(), self.substitute_value(v, variables)?);
171 }
172 Ok(Value::Object(new_obj))
173 }
174 _ => Ok(value.clone()),
176 }
177 }
178
179 async fn execute_step_with_retry_timeout(
181 &self,
182 tool_call: &ToolCall,
183 step: &SkillStep,
184 timeout: Duration,
185 retry_config: &RetryConfig,
186 context: &ExecutionContext,
187 ) -> Result<(ToolResult, usize), SkillError> {
188 let mut attempts = 0;
189
190 loop {
191 attempts += 1;
192
193 let result = tokio::time::timeout(timeout, self.transport.call(tool_call)).await;
195
196 match result {
197 Ok(Ok(tool_result)) => {
198 return Ok((tool_result, attempts - 1)); }
201 Ok(Err(e)) => {
202 let error_msg = e.to_string();
204
205 if attempts > retry_config.max_retries
206 || !is_error_retryable(&error_msg, retry_config)
207 {
208 return Err(SkillError::RetryExhausted {
209 step: step.name.clone(),
210 attempts,
211 message: error_msg,
212 });
213 }
214
215 self.hooks.on_retry(step, attempts, &error_msg, context);
217
218 let delay = calculate_delay(retry_config, attempts);
219 tracing::warn!(
220 step = %step.name,
221 attempt = attempts,
222 max_retries = retry_config.max_retries,
223 delay_ms = delay.as_millis() as u64,
224 error = %e,
225 "Retrying step after error"
226 );
227 tokio::time::sleep(delay).await;
228 }
229 Err(_elapsed) => {
230 self.hooks
232 .on_timeout(step, timeout.as_millis() as u64, context);
233
234 if attempts > retry_config.max_retries
236 || !retry_config
237 .retryable_errors
238 .contains(&RetryableError::Timeout)
239 {
240 return Err(SkillError::StepTimeout {
241 step: step.name.clone(),
242 duration: timeout,
243 });
244 }
245
246 self.hooks.on_retry(step, attempts, "timeout", context);
248
249 let delay = calculate_delay(retry_config, attempts);
250 tracing::warn!(
251 step = %step.name,
252 attempt = attempts,
253 max_retries = retry_config.max_retries,
254 delay_ms = delay.as_millis() as u64,
255 "Retrying step after timeout"
256 );
257 tokio::time::sleep(delay).await;
258 }
259 }
260 }
261 }
262}
263
264#[async_trait]
265impl<T: Transport, H: ExecutionHooks> SkillExecutor for DefaultSkillExecutor<T, H> {
266 async fn execute(
267 &self,
268 skill: &Skill,
269 context: &mut ExecutionContext,
270 ) -> Result<SkillResult, SkillError> {
271 self.hooks.before_skill(skill, context);
273
274 let config = context.config().clone();
275 let skill_timeout = config.timeout.skill_timeout;
276
277 let result = tokio::time::timeout(skill_timeout, async {
279 self.execute_steps(skill, context, &config).await
280 })
281 .await;
282
283 let skill_result = match result {
284 Ok(inner_result) => inner_result,
285 Err(_elapsed) => {
286 match config.timeout.timeout_action {
288 TimeoutAction::Fail => {
289 let error = SkillError::SkillTimeout {
290 duration: skill_timeout,
291 };
292 self.hooks.on_error(&error, context);
293 Err(error)
294 }
295 TimeoutAction::Skip | TimeoutAction::Partial => {
296 Ok(SkillResult {
298 success: false,
299 step_results: vec![],
300 output: None,
301 error: Some(format!("Skill timed out after {:?}", skill_timeout)),
302 })
303 }
304 }
305 }
306 };
307
308 match &skill_result {
310 Ok(result) => {
311 self.hooks.after_skill(skill, result, context);
312 }
313 Err(e) => {
314 self.hooks.on_error(e, context);
315 let failure_result = SkillResult {
317 success: false,
318 step_results: vec![],
319 output: None,
320 error: Some(e.to_string()),
321 };
322 self.hooks.after_skill(skill, &failure_result, context);
323 }
324 }
325
326 skill_result
327 }
328
329 async fn execute_step(
330 &self,
331 step: &SkillStep,
332 context: &mut ExecutionContext,
333 ) -> Result<StepResult, SkillError> {
334 let config = context.config().clone();
335
336 let step_timeout = step
338 .timeout_secs
339 .map(Duration::from_secs)
340 .unwrap_or(config.timeout.step_timeout);
341
342 let max_retries = step.max_retries.unwrap_or(config.retry.max_retries);
344 let step_retry_config = RetryConfig {
345 max_retries,
346 ..config.retry.clone()
347 };
348
349 let prepared_args = self.prepare_arguments(&step.arguments, context)?;
351
352 let tool_call = ToolCall {
353 tool: step.tool.clone(),
354 arguments: prepared_args,
355 };
356
357 self.hooks.before_step(step, 0, context);
359
360 let start = Instant::now();
361
362 let result = self
364 .execute_step_with_retry_timeout(
365 &tool_call,
366 step,
367 step_timeout,
368 &step_retry_config,
369 context,
370 )
371 .await;
372
373 let duration_ms = start.elapsed().as_millis() as u64;
374
375 let step_result = match result {
376 Ok((tool_result, retry_attempts)) => {
377 if let Some(data) = &tool_result.data {
379 context.set_output(step.name.clone(), data.clone());
380 }
381
382 let is_success = tool_result.is_success();
383 StepResult {
384 step_name: step.name.clone(),
385 success: is_success,
386 output: tool_result.data,
387 error: if is_success { None } else { tool_result.error },
388 duration_ms,
389 retry_attempts,
390 }
391 }
392 Err(e) => {
393 self.hooks.on_error(&e, context);
394
395 StepResult {
396 step_name: step.name.clone(),
397 success: false,
398 output: None,
399 error: Some(e.to_string()),
400 duration_ms,
401 retry_attempts: 0,
402 }
403 }
404 };
405
406 self.hooks.after_step(step, 0, &step_result, context);
408
409 if step_result.success {
410 Ok(step_result)
411 } else {
412 Err(SkillError::Execution(
414 step_result.error.clone().unwrap_or_default(),
415 ))
416 }
417 }
418}
419
420impl<T: Transport, H: ExecutionHooks> DefaultSkillExecutor<T, H> {
421 async fn execute_steps(
423 &self,
424 skill: &Skill,
425 context: &mut ExecutionContext,
426 config: &ExecutionConfig,
427 ) -> Result<SkillResult, SkillError> {
428 let mut step_results: Vec<(String, ToolResult)> = Vec::new();
429
430 for (index, step) in skill.steps.iter().enumerate() {
431 let step_timeout = step
433 .timeout_secs
434 .map(Duration::from_secs)
435 .unwrap_or(config.timeout.step_timeout);
436
437 let max_retries = step.max_retries.unwrap_or(config.retry.max_retries);
439 let step_retry_config = RetryConfig {
440 max_retries,
441 ..config.retry.clone()
442 };
443
444 let prepared_args = self.prepare_arguments(&step.arguments, context)?;
446
447 let tool_call = ToolCall {
448 tool: step.tool.clone(),
449 arguments: prepared_args,
450 };
451
452 self.hooks.before_step(step, index, context);
454
455 let start = Instant::now();
456
457 let step_result = self
459 .execute_step_with_retry_timeout(
460 &tool_call,
461 step,
462 step_timeout,
463 &step_retry_config,
464 context,
465 )
466 .await;
467
468 let duration_ms = start.elapsed().as_millis() as u64;
469
470 match step_result {
471 Ok((tool_result, retry_attempts)) => {
472 let sr = StepResult {
474 step_name: step.name.clone(),
475 success: true,
476 output: tool_result.data.clone(),
477 error: None,
478 duration_ms,
479 retry_attempts,
480 };
481 self.hooks.after_step(step, index, &sr, context);
482
483 step_results.push((step.name.clone(), tool_result.clone()));
484
485 context.set_output(
487 step.name.clone(),
488 tool_result.data.clone().unwrap_or(Value::Null),
489 );
490
491 if step_results.len() == skill.steps.len() {
493 return Ok(SkillResult {
494 success: true,
495 step_results,
496 output: tool_result.data,
497 error: None,
498 });
499 }
500 }
501 Err(e) => {
502 let sr = StepResult::failure(&step.name, e.to_string(), duration_ms);
504 self.hooks.after_step(step, index, &sr, context);
505 self.hooks.on_error(&e, context);
506
507 if step.continue_on_error {
508 step_results.push((step.name.clone(), ToolResult::failure(e.to_string())));
510 } else {
511 match &config.timeout.timeout_action {
513 TimeoutAction::Skip => {
514 step_results
515 .push((step.name.clone(), ToolResult::failure(e.to_string())));
516 }
518 TimeoutAction::Partial => {
519 return Ok(SkillResult {
520 success: false,
521 step_results,
522 output: None,
523 error: Some(e.to_string()),
524 });
525 }
526 TimeoutAction::Fail => {
527 return Err(e);
528 }
529 }
530 }
531 }
532 }
533 }
534
535 Ok(SkillResult {
536 success: true,
537 step_results,
538 output: None,
539 error: None,
540 })
541 }
542}
543
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Transport double: answers each registered tool name with a canned result and
    /// fails with `ToolNotFound` for every other tool.
    struct MockTransport {
        responses: HashMap<String, ToolResult>,
    }

    impl MockTransport {
        fn new() -> Self {
            MockTransport {
                responses: HashMap::new(),
            }
        }

        /// Registers `result` as the answer for calls to `tool_name`.
        fn with_response(mut self, tool_name: &str, result: ToolResult) -> Self {
            self.responses.insert(tool_name.to_string(), result);
            self
        }
    }

    #[async_trait]
    impl Transport for MockTransport {
        async fn connect(&mut self) -> thulp_core::Result<()> {
            Ok(())
        }

        async fn disconnect(&mut self) -> thulp_core::Result<()> {
            Ok(())
        }

        fn is_connected(&self) -> bool {
            true
        }

        async fn list_tools(&self) -> thulp_core::Result<Vec<thulp_core::ToolDefinition>> {
            Ok(Vec::new())
        }

        async fn call(&self, call: &ToolCall) -> thulp_core::Result<ToolResult> {
            match self.responses.get(&call.tool) {
                Some(canned) => Ok(canned.clone()),
                None => Err(thulp_core::Error::ToolNotFound(call.tool.clone())),
            }
        }
    }

    /// Builds a step with empty-object arguments, `continue_on_error` disabled and no
    /// per-step timeout or retry override.
    fn make_step(name: &str, tool: &str) -> SkillStep {
        SkillStep {
            name: name.to_string(),
            tool: tool.to_string(),
            arguments: serde_json::json!({}),
            continue_on_error: false,
            timeout_secs: None,
            max_retries: None,
        }
    }

    #[tokio::test]
    async fn test_default_executor_basic() {
        let executor = DefaultSkillExecutor::new(MockTransport::new().with_response(
            "tool1",
            ToolResult::success(serde_json::json!({"result": 1})),
        ));
        let skill = Skill::new("test", "Test skill").with_step(make_step("step1", "tool1"));

        let mut context = ExecutionContext::new();
        let outcome = executor.execute(&skill, &mut context).await.unwrap();

        assert!(outcome.success);
        assert_eq!(outcome.step_results.len(), 1);
    }

    #[tokio::test]
    async fn test_default_executor_with_hooks() {
        /// Hooks that count how often each lifecycle callback fires.
        struct CountingHooks {
            before_skill_count: Arc<AtomicUsize>,
            after_skill_count: Arc<AtomicUsize>,
            before_step_count: Arc<AtomicUsize>,
            after_step_count: Arc<AtomicUsize>,
        }

        impl ExecutionHooks for CountingHooks {
            fn before_skill(&self, _skill: &Skill, _context: &ExecutionContext) {
                self.before_skill_count.fetch_add(1, Ordering::SeqCst);
            }

            fn after_skill(
                &self,
                _skill: &Skill,
                _result: &SkillResult,
                _context: &ExecutionContext,
            ) {
                self.after_skill_count.fetch_add(1, Ordering::SeqCst);
            }

            fn before_step(
                &self,
                _step: &SkillStep,
                _step_index: usize,
                _context: &ExecutionContext,
            ) {
                self.before_step_count.fetch_add(1, Ordering::SeqCst);
            }

            fn after_step(
                &self,
                _step: &SkillStep,
                _step_index: usize,
                _result: &StepResult,
                _context: &ExecutionContext,
            ) {
                self.after_step_count.fetch_add(1, Ordering::SeqCst);
            }
        }

        let new_counter = || Arc::new(AtomicUsize::new(0));
        let (before_skill, after_skill, before_step, after_step) =
            (new_counter(), new_counter(), new_counter(), new_counter());

        let hooks = CountingHooks {
            before_skill_count: Arc::clone(&before_skill),
            after_skill_count: Arc::clone(&after_skill),
            before_step_count: Arc::clone(&before_step),
            after_step_count: Arc::clone(&after_step),
        };

        let transport = MockTransport::new()
            .with_response("tool1", ToolResult::success(serde_json::json!({})))
            .with_response("tool2", ToolResult::success(serde_json::json!({})));
        let executor = DefaultSkillExecutor::with_hooks(transport, hooks);

        let skill = Skill::new("test", "Test skill")
            .with_step(make_step("step1", "tool1"))
            .with_step(make_step("step2", "tool2"));

        let mut context = ExecutionContext::new();
        let outcome = executor.execute(&skill, &mut context).await.unwrap();

        assert!(outcome.success);
        assert_eq!(before_skill.load(Ordering::SeqCst), 1);
        assert_eq!(after_skill.load(Ordering::SeqCst), 1);
        assert_eq!(before_step.load(Ordering::SeqCst), 2);
        assert_eq!(after_step.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_default_executor_context_propagation() {
        let transport = MockTransport::new()
            .with_response(
                "step1_tool",
                ToolResult::success(serde_json::json!({"value": 42})),
            )
            .with_response(
                "step2_tool",
                ToolResult::success(serde_json::json!({"doubled": 84})),
            );
        let executor = DefaultSkillExecutor::new(transport);

        // The second step references the first step's output by name.
        let mut consumer = make_step("step2", "step2_tool");
        consumer.arguments = serde_json::json!({"input": "{{step1}}"});

        let skill = Skill::new("test", "Test skill")
            .with_step(make_step("step1", "step1_tool"))
            .with_step(consumer);

        let mut context = ExecutionContext::new();
        let outcome = executor.execute(&skill, &mut context).await.unwrap();

        assert!(outcome.success);
        assert!(context.get_output("step1").is_some());
        assert!(context.get_output("step2").is_some());
    }

    #[tokio::test]
    async fn test_default_executor_continue_on_error() {
        // Only step2's tool is registered, so step1's call fails with `ToolNotFound`.
        let transport = MockTransport::new().with_response(
            "step2_tool",
            ToolResult::success(serde_json::json!({"ok": true})),
        );
        let executor = DefaultSkillExecutor::new(transport);

        let mut missing_tool_step = make_step("step1", "step1_tool");
        missing_tool_step.continue_on_error = true;
        missing_tool_step.max_retries = Some(0);

        let skill = Skill::new("test", "Test skill")
            .with_step(missing_tool_step)
            .with_step(make_step("step2", "step2_tool"));

        let config = ExecutionConfig::new().with_retry(crate::RetryConfig::no_retries());
        let mut context = ExecutionContext::new().with_config(config);

        let outcome = executor.execute(&skill, &mut context).await.unwrap();

        assert!(outcome.success);
        assert_eq!(outcome.step_results.len(), 2);
        assert!(!outcome.step_results[0].1.is_success());
        assert!(outcome.step_results[1].1.is_success());
    }

    #[tokio::test]
    async fn test_default_executor_from_arcs() {
        let transport = Arc::new(
            MockTransport::new().with_response("tool", ToolResult::success(serde_json::json!({}))),
        );
        let hooks = Arc::new(NoOpHooks);

        let executor = DefaultSkillExecutor::from_arcs(Arc::clone(&transport), Arc::clone(&hooks));
        let skill = Skill::new("test", "Test").with_step(make_step("s", "tool"));

        let mut context = ExecutionContext::new();
        let outcome = executor.execute(&skill, &mut context).await.unwrap();

        assert!(outcome.success);
    }
}