1use crate::skills::cli_bridge::{CliToolBridge, CliToolResult};
7use anyhow::{Result, anyhow};
8use async_stream::stream;
9use futures::{Stream, StreamExt};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::pin::Pin;
13use std::process::Stdio;
14use std::time::{Duration, Instant};
15use tokio::io::{AsyncReadExt, BufReader};
16use tokio::process::Command as TokioCommand;
17use tokio::time::{interval, timeout};
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct StreamingConfig {
22 pub enable_streaming: bool,
24
25 pub buffer_size: usize,
27
28 pub update_interval_ms: u64,
30
31 pub max_execution_time_secs: u64,
33
34 pub enable_partial_json: bool,
36
37 pub enable_progress_reporting: bool,
39
40 pub include_stderr: bool,
42
43 pub line_based_streaming: bool,
45}
46
47impl Default for StreamingConfig {
48 fn default() -> Self {
49 Self {
50 enable_streaming: true,
51 buffer_size: 8192,
52 update_interval_ms: 100,
53 max_execution_time_secs: 300, enable_partial_json: true,
55 enable_progress_reporting: true,
56 include_stderr: true,
57 line_based_streaming: true,
58 }
59 }
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
64pub enum StreamEvent {
65 Progress {
67 percentage: f32,
69 message: String,
71 elapsed_ms: u64,
73 estimated_remaining_ms: Option<u64>,
75 },
76 Output {
78 data: String,
80 output_type: OutputType,
82 is_partial: bool,
84 },
85 JsonObject {
87 value: Value,
89 raw: String,
91 },
92 Completed {
94 exit_code: i32,
96 total_time_ms: u64,
98 result: Option<CliToolResult>,
100 },
101 Error {
103 message: String,
105 fatal: bool,
107 },
108 Started {
110 command: String,
112 args: Vec<String>,
114 start_time: chrono::DateTime<chrono::Utc>,
116 },
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize)]
121pub enum OutputType {
122 Stdout,
123 Stderr,
124}
125
126pub struct StreamingSkillExecutor {
128 config: StreamingConfig,
129}
130
131impl StreamingSkillExecutor {
132 pub fn new() -> Self {
134 Self::with_config(StreamingConfig::default())
135 }
136
137 pub fn with_config(config: StreamingConfig) -> Self {
139 Self { config }
140 }
141}
142
143impl Default for StreamingSkillExecutor {
144 fn default() -> Self {
145 Self::new()
146 }
147}
148
149impl StreamingSkillExecutor {
150 pub fn execute_cli_tool_streaming(
152 &self,
153 bridge: &CliToolBridge,
154 args: Value,
155 ) -> Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>> {
156 let config = self.config.clone();
157 let bridge = bridge.clone();
158 let args = args.clone();
159
160 Box::pin(stream! {
161 let start_time = Instant::now();
162 let start_datetime = chrono::Utc::now();
163
164 let mut cmd = TokioCommand::new(&bridge.config.executable_path);
166
167 if let Some(working_dir) = &bridge.config.working_dir {
169 cmd.current_dir(working_dir);
170 }
171
172 if let Some(env) = &bridge.config.environment {
174 for (key, value) in env {
175 cmd.env(key, value);
176 }
177 }
178
179 cmd.stdin(Stdio::null())
181 .stdout(Stdio::piped())
182 .stderr(Stdio::piped())
183 .kill_on_drop(true);
184
185 if let Err(e) = Self::configure_arguments(&mut cmd, &args) {
187 yield Err(anyhow!("Failed to configure arguments: {}", e));
188 return;
189 }
190
191 let _command_str = format!("{:?}", cmd);
193 let args: Vec<String> = cmd.as_std().get_args()
194 .map(|arg| arg.to_string_lossy().to_string())
195 .collect();
196
197 yield Ok(StreamEvent::Started {
199 command: bridge.config.executable_path.display().to_string(),
200 args,
201 start_time: start_datetime,
202 });
203
204 let mut child = match cmd.spawn() {
206 Ok(child) => child,
207 Err(e) => {
208 yield Err(anyhow!("Failed to spawn process: {}", e));
209 return;
210 }
211 };
212
213 let stdout = match child.stdout.take() {
214 Some(stdout) => stdout,
215 None => {
216 yield Err(anyhow!("Failed to capture stdout"));
217 return;
218 }
219 };
220
221 let stderr = if config.include_stderr {
222 child.stderr.take()
223 } else {
224 None
225 };
226
227 let progress_tracker = ProgressTracker::new(config.update_interval_ms);
229 let mut progress_interval = interval(Duration::from_millis(config.update_interval_ms));
230
231 let stdout_stream = Self::stream_output(
233 stdout,
234 OutputType::Stdout,
235 config.clone(),
236 progress_tracker.clone(),
237 );
238
239 let stderr_stream = stderr.map(|s| {
241 Self::stream_output(
242 s,
243 OutputType::Stderr,
244 config.clone(),
245 progress_tracker.clone(),
246 )
247 });
248
249 let mut stdout_stream = Box::pin(stdout_stream);
251 let mut stderr_stream = stderr_stream.map(Box::pin);
252
253 let mut output_buffer = String::new();
255 let mut json_buffer = String::new();
256 let _is_parsing_json = false;
257
258 loop {
259 tokio::select! {
260 _ = progress_interval.tick() => {
262 let progress = progress_tracker.get_progress();
263 yield Ok(StreamEvent::Progress {
264 percentage: progress.percentage,
265 message: progress.message.clone(),
266 elapsed_ms: start_time.elapsed().as_millis() as u64,
267 estimated_remaining_ms: progress.estimated_remaining_ms,
268 });
269 }
270
271 Some(event) = stdout_stream.next() => {
273 match event {
274 Ok(StreamEvent::Output { data, output_type, is_partial }) => {
275 output_buffer.push_str(&data);
276
277 if let Some(json_events) = Self::extract_json_objects(&mut json_buffer, &data).filter(|_| config.enable_partial_json) {
279 for json_event in json_events {
280 yield Ok(json_event);
281 }
282 }
283
284 yield Ok(StreamEvent::Output {
285 data,
286 output_type,
287 is_partial,
288 });
289 }
290 Ok(event) => yield Ok(event),
291 Err(e) => {
292 yield Err(anyhow!("Stdout stream error: {}", e));
293 break;
294 }
295 }
296 }
297
298 Some(event_result) = async {
300 match &mut stderr_stream {
301 Some(stream) => stream.next().await,
302 None => None,
303 }
304 } => {
305 match event_result {
306 Ok(StreamEvent::Output { data, output_type, is_partial }) => {
307 yield Ok(StreamEvent::Output {
308 data,
309 output_type,
310 is_partial,
311 });
312 }
313 Ok(event) => yield Ok(event),
314 Err(e) => {
315 yield Err(anyhow!("Stderr stream error: {}", e));
316 }
317 }
318 }
319
320 else => {
322 break;
323 }
324 }
325
326 if start_time.elapsed().as_secs() > config.max_execution_time_secs {
328 let _ = child.kill().await;
329 yield Err(anyhow!("Execution timed out after {} seconds", config.max_execution_time_secs));
330 return;
331 }
332 }
333
334 let exit_status = match timeout(
336 Duration::from_secs(config.max_execution_time_secs),
337 child.wait()
338 ).await {
339 Ok(Ok(status)) => status,
340 Ok(Err(e)) => {
341 yield Err(anyhow!("Failed to wait for process: {}", e));
342 return;
343 }
344 Err(_) => {
345 yield Err(anyhow!("Process wait timed out"));
346 return;
347 }
348 };
349
350 let exit_code = exit_status.code().unwrap_or(-1);
351 let total_time_ms = start_time.elapsed().as_millis() as u64;
352
353 let result = CliToolResult {
355 exit_code,
356 stdout: output_buffer.clone(),
357 stderr: String::new(), json_output: None,
359 execution_time_ms: total_time_ms,
360 };
361
362 yield Ok(StreamEvent::Completed {
363 exit_code,
364 total_time_ms,
365 result: Some(result),
366 });
367 })
368 }
369
370 fn stream_output<R>(
372 reader: R,
373 output_type: OutputType,
374 config: StreamingConfig,
375 mut progress_tracker: ProgressTracker,
376 ) -> Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>
377 where
378 R: AsyncReadExt + Send + Unpin + 'static,
379 {
380 Box::pin(stream! {
381 let mut reader = BufReader::new(reader);
382 let mut buffer = vec![0u8; config.buffer_size];
383 let mut line_buffer = String::new();
384
385 loop {
386 match reader.read(&mut buffer).await {
387 Ok(0) => {
388 if !line_buffer.is_empty() && config.line_based_streaming {
390 progress_tracker.update_with_output(&line_buffer);
391 yield Ok(StreamEvent::Output {
392 data: line_buffer.clone(),
393 output_type: output_type.clone(),
394 is_partial: false,
395 });
396 }
397 break;
398 }
399 Ok(n) => {
400 let data = String::from_utf8_lossy(&buffer[..n]);
401
402 if config.line_based_streaming {
403 line_buffer.push_str(&data);
404
405 while let Some(line_end) = line_buffer.find('\n') {
407 let line = line_buffer[..line_end + 1].to_string();
408 line_buffer.drain(..line_end + 1);
409
410 progress_tracker.update_with_output(&line);
411 yield Ok(StreamEvent::Output {
412 data: line,
413 output_type: output_type.clone(),
414 is_partial: false,
415 });
416 }
417 } else {
418 progress_tracker.update_with_output(&data);
420 yield Ok(StreamEvent::Output {
421 data: data.to_string(),
422 output_type: output_type.clone(),
423 is_partial: true,
424 });
425 }
426 }
427 Err(e) => {
428 yield Err(anyhow!("Read error: {}", e));
429 break;
430 }
431 }
432 }
433 })
434 }
435
436 fn extract_json_objects(json_buffer: &mut String, new_data: &str) -> Option<Vec<StreamEvent>> {
438 json_buffer.push_str(new_data);
439
440 let mut events = vec![];
441
442 while let Some(brace_start) = json_buffer.find('{') {
444 let mut brace_count = 0;
445 let mut end_pos = None;
446
447 for (i, ch) in json_buffer[brace_start..].chars().enumerate() {
448 match ch {
449 '{' => brace_count += 1,
450 '}' => {
451 brace_count -= 1;
452 if brace_count == 0 {
453 end_pos = Some(brace_start + i + 1);
454 break;
455 }
456 }
457 _ => {}
458 }
459 }
460
461 if let Some(end) = end_pos {
462 let json_str = &json_buffer[brace_start..end];
463
464 if let Ok(value) = serde_json::from_str::<Value>(json_str) {
465 events.push(StreamEvent::JsonObject {
466 value,
467 raw: json_str.to_string(),
468 });
469
470 json_buffer.drain(..end);
472 } else {
473 json_buffer.drain(..brace_start + 1);
475 }
476 } else {
477 break;
479 }
480 }
481
482 if events.is_empty() {
483 None
484 } else {
485 Some(events)
486 }
487 }
488
489 fn configure_arguments(cmd: &mut TokioCommand, args: &Value) -> Result<()> {
491 if args.is_null() {
492 return Ok(());
493 }
494
495 match args {
496 Value::String(s) => {
497 cmd.arg(s);
498 }
499 Value::Array(arr) => {
500 for arg in arr {
501 if let Some(s) = arg.as_str() {
502 cmd.arg(s);
503 }
504 }
505 }
506 Value::Object(map) => {
507 for (key, value) in map {
508 if let Some(s) = value.as_str() {
509 cmd.arg(format!("--{}", key));
510 cmd.arg(s);
511 } else if value.as_bool().is_some_and(|flag| flag) {
512 cmd.arg(format!("--{}", key));
513 }
514 }
515 }
516 _ => {
517 let json_str = serde_json::to_string(args)?;
518 cmd.arg(json_str);
519 }
520 }
521
522 Ok(())
523 }
524}
525
/// Heuristic progress estimator driven by elapsed time and output volume.
///
/// Until an output rate has been sampled, progress is the elapsed fraction of
/// a fixed five-minute window. Once more than five seconds have passed and the
/// process has produced output at a measurable rate, the total output size is
/// extrapolated from that rate and progress becomes the fraction of that
/// total already seen. Reported progress is capped at 95%.
#[derive(Debug, Clone)]
pub struct ProgressTracker {
    /// When tracking started.
    start_time: Instant,
    /// Configured update period; stored but not read by the tracker.
    #[expect(dead_code)]
    update_interval_ms: u64,
    /// Bytes of output recorded so far.
    total_output_bytes: usize,
    /// When output was last recorded.
    last_output_time: Instant,
    /// Extrapolated total output size, once a rate has been sampled.
    estimated_total_bytes: Option<usize>,
}

impl ProgressTracker {
    /// Seconds that must elapse before the output rate is sampled.
    const RATE_SAMPLE_AFTER_SECS: u64 = 5;
    /// Run length, in seconds, assumed when extrapolating total output.
    const ASSUMED_RUN_SECS: usize = 180;
    /// Window, in milliseconds, used for time-based progress.
    const FALLBACK_TOTAL_MS: u64 = 300_000;
    /// Upper bound on the reported percentage.
    const MAX_REPORTED_PERCENT: f32 = 95.0;

    /// Creates a tracker whose clock starts now.
    pub fn new(update_interval_ms: u64) -> Self {
        let now = Instant::now();
        Self {
            start_time: now,
            update_interval_ms,
            total_output_bytes: 0,
            last_output_time: now,
            estimated_total_bytes: None,
        }
    }

    /// Records `output` as newly produced and, if due, samples the output rate.
    pub fn update_with_output(&mut self, output: &str) {
        self.total_output_bytes = self.total_output_bytes.saturating_add(output.len());
        self.last_output_time = Instant::now();

        if self.estimated_total_bytes.is_some() {
            return;
        }

        let elapsed_secs = self.start_time.elapsed().as_secs();
        if elapsed_secs <= Self::RATE_SAMPLE_AFTER_SECS {
            return;
        }

        let elapsed_secs = usize::try_from(elapsed_secs).unwrap_or(usize::MAX);
        let bytes_per_second = self.total_output_bytes / elapsed_secs;

        // A rate that rounds down to zero must not be locked in: a zero estimate
        // would pin progress at 0% for the rest of the run. Leave the estimate
        // unset so a later update can sample a usable rate.
        if bytes_per_second > 0 {
            self.estimated_total_bytes =
                Some(bytes_per_second.saturating_mul(Self::ASSUMED_RUN_SECS));
        }
    }

    /// Returns the current progress estimate.
    pub fn get_progress(&self) -> ProgressInfo {
        let elapsed_ms = self.start_time.elapsed().as_millis() as u64;

        let percentage = match self.estimated_total_bytes {
            Some(estimated) if estimated > 0 => {
                ((self.total_output_bytes as f32 / estimated as f32) * 100.0)
                    .min(Self::MAX_REPORTED_PERCENT)
            }
            Some(_) => 0.0,
            None => ((elapsed_ms as f32 / Self::FALLBACK_TOTAL_MS as f32) * 100.0)
                .min(Self::MAX_REPORTED_PERCENT),
        };

        let estimated_remaining_ms = match self.estimated_total_bytes {
            // Both guards keep the rate finite and non-zero.
            Some(estimated) if self.total_output_bytes > 0 && elapsed_ms > 0 => {
                let bytes_remaining = estimated.saturating_sub(self.total_output_bytes);
                let bytes_per_ms = self.total_output_bytes as f32 / elapsed_ms as f32;
                Some((bytes_remaining as f32 / bytes_per_ms) as u64)
            }
            _ => None,
        };

        let message = if percentage < 10.0 {
            "Starting execution...".to_string()
        } else if percentage < 50.0 {
            "Processing...".to_string()
        } else if percentage < 90.0 {
            "Almost complete...".to_string()
        } else {
            "Finalizing...".to_string()
        };

        ProgressInfo {
            percentage,
            message,
            elapsed_ms,
            estimated_remaining_ms,
        }
    }
}

/// Snapshot returned by [`ProgressTracker::get_progress`].
#[derive(Debug, Clone)]
pub struct ProgressInfo {
    /// Estimated completion, in percent (never above 95).
    pub percentage: f32,

    /// Status text chosen from the percentage.
    pub message: String,

    /// Milliseconds since the tracker was created.
    pub elapsed_ms: u64,

    /// Estimated milliseconds left, when an output rate has been sampled.
    pub estimated_remaining_ms: Option<u64>,
}
625
626pub trait StreamingExecution {
628 fn execute_streaming(
630 &self,
631 args: Value,
632 ) -> Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>;
633}
634
635impl StreamingExecution for CliToolBridge {
636 fn execute_streaming(
637 &self,
638 args: Value,
639 ) -> Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>> {
640 let executor = StreamingSkillExecutor::new();
641 executor.execute_cli_tool_streaming(self, args)
642 }
643}
644
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration enables streaming with the documented sizes.
    #[test]
    fn test_streaming_config_default() {
        let StreamingConfig {
            enable_streaming,
            buffer_size,
            max_execution_time_secs,
            ..
        } = StreamingConfig::default();

        assert!(enable_streaming);
        assert_eq!(buffer_size, 8192);
        assert_eq!(max_execution_time_secs, 300);
    }

    /// A tracker reports a sane percentage before and after receiving output.
    #[test]
    fn test_progress_tracker() {
        let mut tracker = ProgressTracker::new(100);

        let initial = tracker.get_progress();
        assert!((0.0..=100.0).contains(&initial.percentage));

        tracker.update_with_output("test output");

        let updated = tracker.get_progress();
        assert!(!updated.message.is_empty() || updated.percentage > 0.0);
    }

    /// Two objects separated by plain text are both extracted, in order.
    #[tokio::test]
    async fn test_json_extraction() {
        let mut scratch = String::new();
        let input = r#"{"key": "value"} some text {"another": "object"}"#;

        let extracted = StreamingSkillExecutor::extract_json_objects(&mut scratch, input);
        assert!(extracted.is_some());

        let extracted = extracted.unwrap();
        assert_eq!(extracted.len(), 2);

        let StreamEvent::JsonObject { value, .. } = &extracted[0] else {
            panic!("Expected JsonObject event");
        };
        assert_eq!(value["key"], "value");
    }
}