// taskforceai_sdk/stream.rs
use crate::client::TaskForceAI;
use crate::error::TaskForceAIError;
use crate::types::{TaskStatus, TaskSubmissionOptions};
use futures_util::{Stream, StreamExt};
use std::pin::Pin;

7pub type TaskStatusStream =
8 Pin<Box<dyn Stream<Item = Result<TaskStatus, TaskForceAIError>> + Send>>;
9
10impl TaskForceAI {
11 pub async fn stream_task_status(
12 &self,
13 task_id: &str,
14 ) -> Result<TaskStatusStream, TaskForceAIError> {
15 if task_id.trim().is_empty() {
16 return Err(TaskForceAIError::EmptyTaskId);
17 }
18
19 if self.mock_mode {
20 let status = self.get_task_status(task_id).await?;
21 let stream = futures_util::stream::iter(vec![Ok(status)]);
22 return Ok(Box::pin(stream));
23 }
24
25 let url = format!("{}/stream/{}", self.base_url, task_id);
26 let mut request = self.client.get(&url);
27
28 if !self.api_key.is_empty() {
29 request = request.bearer_auth(&self.api_key);
30 }
31
32 request = request.header("Accept", "text/event-stream");
33
34 let response = request.send().await?;
35 if !response.status().is_success() {
36 let status = response.status();
37 let message = response.text().await.unwrap_or_default();
38 return Err(TaskForceAIError::Api { status, message });
39 }
40
41 let mut bytes_stream = response.bytes_stream();
42 let mut buffer = String::new();
43
44 let s = futures_util::stream::poll_fn(move |cx| {
45 loop {
46 if let Some(line_end) = buffer.find('\n') {
47 let line = buffer.drain(..line_end + 1).collect::<String>();
48 let line = line.trim();
49
50 if let Some(data) = line.strip_prefix("data:") {
51 let data = data.trim();
52 match serde_json::from_str::<TaskStatus>(data) {
53 Ok(status) => return std::task::Poll::Ready(Some(Ok(status))),
54 Err(e) => {
55 return std::task::Poll::Ready(Some(Err(
56 TaskForceAIError::Serialization(e),
57 )))
58 }
59 }
60 }
61 continue;
62 }
63
64 match bytes_stream.poll_next_unpin(cx) {
65 std::task::Poll::Ready(Some(Ok(bytes))) => {
66 buffer.push_str(&String::from_utf8_lossy(&bytes));
67 continue;
68 }
69 std::task::Poll::Ready(Some(Err(e))) => {
70 return std::task::Poll::Ready(Some(Err(TaskForceAIError::Network(e))))
71 }
72 std::task::Poll::Ready(None) => {
73 if buffer.is_empty() {
74 return std::task::Poll::Ready(None);
75 } else {
76 let line = std::mem::take(&mut buffer);
78 let line = line.trim();
79 if let Some(data) = line.strip_prefix("data:") {
80 let data = data.trim();
81 match serde_json::from_str::<TaskStatus>(data) {
82 Ok(status) => return std::task::Poll::Ready(Some(Ok(status))),
83 Err(e) => {
84 return std::task::Poll::Ready(Some(Err(
85 TaskForceAIError::Serialization(e),
86 )))
87 }
88 }
89 }
90 return std::task::Poll::Ready(None);
91 }
92 }
93 std::task::Poll::Pending => return std::task::Poll::Pending,
94 }
95 }
96 });
97
98 Ok(Box::pin(s))
99 }
100
101 pub async fn run_task_stream(
102 &self,
103 prompt: &str,
104 options: Option<TaskSubmissionOptions>,
105 ) -> Result<TaskStatusStream, TaskForceAIError> {
106 let task_id = self.submit_task(prompt, options).await?;
107 self.stream_task_status(&task_id).await
108 }
109}