taskforceai_sdk/
stream.rs

use std::pin::Pin;
use std::task::Poll;

use futures_util::{Stream, StreamExt};

use crate::client::TaskForceAI;
use crate::error::TaskForceAIError;
use crate::types::{TaskStatus, TaskSubmissionOptions};
6
7pub type TaskStatusStream =
8    Pin<Box<dyn Stream<Item = Result<TaskStatus, TaskForceAIError>> + Send>>;
9
10impl TaskForceAI {
11    pub async fn stream_task_status(
12        &self,
13        task_id: &str,
14    ) -> Result<TaskStatusStream, TaskForceAIError> {
15        if task_id.trim().is_empty() {
16            return Err(TaskForceAIError::EmptyTaskId);
17        }
18
19        if self.mock_mode {
20            let status = self.get_task_status(task_id).await?;
21            let stream = futures_util::stream::iter(vec![Ok(status)]);
22            return Ok(Box::pin(stream));
23        }
24
25        let url = format!("{}/stream/{}", self.base_url, task_id);
26        let mut request = self.client.get(&url);
27
28        if !self.api_key.is_empty() {
29            request = request.bearer_auth(&self.api_key);
30        }
31
32        request = request.header("Accept", "text/event-stream");
33
34        let response = request.send().await?;
35        if !response.status().is_success() {
36            let status = response.status();
37            let message = response.text().await.unwrap_or_default();
38            return Err(TaskForceAIError::Api { status, message });
39        }
40
41        let mut bytes_stream = response.bytes_stream();
42        let mut buffer = String::new();
43
44        let s = futures_util::stream::poll_fn(move |cx| {
45            loop {
46                if let Some(line_end) = buffer.find('\n') {
47                    let line = buffer.drain(..line_end + 1).collect::<String>();
48                    let line = line.trim();
49
50                    if let Some(data) = line.strip_prefix("data:") {
51                        let data = data.trim();
52                        match serde_json::from_str::<TaskStatus>(data) {
53                            Ok(status) => return std::task::Poll::Ready(Some(Ok(status))),
54                            Err(e) => {
55                                return std::task::Poll::Ready(Some(Err(
56                                    TaskForceAIError::Serialization(e),
57                                )))
58                            }
59                        }
60                    }
61                    continue;
62                }
63
64                match bytes_stream.poll_next_unpin(cx) {
65                    std::task::Poll::Ready(Some(Ok(bytes))) => {
66                        buffer.push_str(&String::from_utf8_lossy(&bytes));
67                        continue;
68                    }
69                    std::task::Poll::Ready(Some(Err(e))) => {
70                        return std::task::Poll::Ready(Some(Err(TaskForceAIError::Network(e))))
71                    }
72                    std::task::Poll::Ready(None) => {
73                        if buffer.is_empty() {
74                            return std::task::Poll::Ready(None);
75                        } else {
76                            // Handle potential last line without newline
77                            let line = std::mem::take(&mut buffer);
78                            let line = line.trim();
79                            if let Some(data) = line.strip_prefix("data:") {
80                                let data = data.trim();
81                                match serde_json::from_str::<TaskStatus>(data) {
82                                    Ok(status) => return std::task::Poll::Ready(Some(Ok(status))),
83                                    Err(e) => {
84                                        return std::task::Poll::Ready(Some(Err(
85                                            TaskForceAIError::Serialization(e),
86                                        )))
87                                    }
88                                }
89                            }
90                            return std::task::Poll::Ready(None);
91                        }
92                    }
93                    std::task::Poll::Pending => return std::task::Poll::Pending,
94                }
95            }
96        });
97
98        Ok(Box::pin(s))
99    }
100
101    pub async fn run_task_stream(
102        &self,
103        prompt: &str,
104        options: Option<TaskSubmissionOptions>,
105    ) -> Result<TaskStatusStream, TaskForceAIError> {
106        let task_id = self.submit_task(prompt, options).await?;
107        self.stream_task_status(&task_id).await
108    }
109}