taskforceai_sdk/
client.rs

1use crate::error::TaskForceAIError;
2use crate::types::{
3    SubmitTaskResponse, TaskForceAIOptions, TaskStatus, TaskStatusValue, TaskSubmissionOptions,
4};
5use std::time::Duration;
6use tokio::time::sleep;
7
/// Base URL used when the caller does not supply one.
pub const DEFAULT_BASE_URL: &str = "https://taskforceai.chat/api/developer";
/// Request timeout, in seconds, used when the caller does not supply one.
pub const DEFAULT_TIMEOUT_SECS: u64 = 30;
/// Delay between status polls, in milliseconds, used when the caller does not supply one.
pub const DEFAULT_POLL_INTERVAL_MS: u64 = 1_000;
/// Maximum number of status polls performed when the caller does not supply a limit.
pub const DEFAULT_MAX_POLL_ATTEMPTS: u32 = 60;
12
13pub struct TaskForceAI {
14    pub(crate) api_key: String,
15    pub(crate) base_url: String,
16    #[allow(dead_code)]
17    pub(crate) timeout: Duration,
18    pub(crate) mock_mode: bool,
19    pub(crate) client: reqwest::Client,
20}
21
22impl TaskForceAI {
23    pub fn new(options: TaskForceAIOptions) -> Result<Self, TaskForceAIError> {
24        let mock_mode = options.mock_mode.unwrap_or(false);
25        let api_key = options.api_key.unwrap_or_default();
26
27        if !mock_mode && api_key.is_empty() {
28            return Err(TaskForceAIError::MissingApiKey);
29        }
30
31        let base_url = options
32            .base_url
33            .unwrap_or_else(|| DEFAULT_BASE_URL.to_string())
34            .trim_end_matches('/')
35            .to_string();
36
37        let timeout = Duration::from_secs(options.timeout.unwrap_or(DEFAULT_TIMEOUT_SECS));
38
39        let client = reqwest::Client::builder().timeout(timeout).build()?;
40
41        Ok(Self {
42            api_key,
43            base_url,
44            timeout,
45            mock_mode,
46            client,
47        })
48    }
49
50    pub(crate) async fn request<T>(
51        &self,
52        method: reqwest::Method,
53        path: &str,
54        body: Option<serde_json::Value>,
55    ) -> Result<T, TaskForceAIError>
56    where
57        T: serde::de::DeserializeOwned,
58    {
59        if self.mock_mode {
60            return self.mock_response(path, &method);
61        }
62
63        let url = format!("{}{}", self.base_url, path);
64        let mut request = self.client.request(method, &url);
65
66        if !self.api_key.is_empty() {
67            request = request.bearer_auth(&self.api_key);
68        }
69
70        request = request.header("X-SDK-Language", "rust");
71
72        if let Some(b) = body {
73            request = request.json(&b);
74        }
75
76        let response = request.send().await?;
77        let status = response.status();
78
79        if !status.is_success() {
80            let message = response
81                .text()
82                .await
83                .unwrap_or_else(|_| "Failed to read error message from response body".to_string());
84            return Err(TaskForceAIError::Api { status, message });
85        }
86
87        Ok(response.json().await?)
88    }
89
90    fn mock_response<T>(&self, path: &str, method: &reqwest::Method) -> Result<T, TaskForceAIError>
91    where
92        T: serde::de::DeserializeOwned,
93    {
94        let val = if method == reqwest::Method::POST && path == "/run" {
95            serde_json::json!({ "taskId": "mock-task-123" })
96        } else if path.starts_with("/status/") {
97            serde_json::json!({
98                "taskId": "mock-task-123",
99                "status": "completed",
100                "result": "This is a mock response. Configure your API key to get real results."
101            })
102        } else {
103            serde_json::json!({ "status": "ok" })
104        };
105
106        Ok(serde_json::from_value(val)?)
107    }
108
109    pub async fn submit_task(
110        &self,
111        prompt: &str,
112        options: Option<TaskSubmissionOptions>,
113    ) -> Result<String, TaskForceAIError> {
114        if prompt.trim().is_empty() {
115            return Err(TaskForceAIError::EmptyPrompt);
116        }
117
118        let mut body = serde_json::json!({ "prompt": prompt });
119        if let Some(opts) = options {
120            if let Some(obj) = body.as_object_mut() {
121                obj.insert("options".to_string(), serde_json::to_value(opts)?);
122            }
123        }
124
125        let response: SubmitTaskResponse = self
126            .request(reqwest::Method::POST, "/run", Some(body))
127            .await?;
128        Ok(response.task_id)
129    }
130
131    pub async fn get_task_status(&self, task_id: &str) -> Result<TaskStatus, TaskForceAIError> {
132        if task_id.trim().is_empty() {
133            return Err(TaskForceAIError::EmptyTaskId);
134        }
135        self.request(reqwest::Method::GET, &format!("/status/{}", task_id), None)
136            .await
137    }
138
139    pub async fn wait_for_completion(
140        &self,
141        task_id: &str,
142        poll_interval: Option<Duration>,
143        max_attempts: Option<u32>,
144    ) -> Result<TaskStatus, TaskForceAIError> {
145        let interval = poll_interval.unwrap_or(Duration::from_millis(DEFAULT_POLL_INTERVAL_MS));
146        let max = max_attempts.unwrap_or(DEFAULT_MAX_POLL_ATTEMPTS);
147
148        for _ in 0..max {
149            let status = self.get_task_status(task_id).await?;
150            match status.status {
151                TaskStatusValue::Completed => return Ok(status),
152                TaskStatusValue::Failed => {
153                    return Err(TaskForceAIError::TaskFailed(
154                        status.error.unwrap_or_else(|| "Unknown error".to_string()),
155                    ))
156                }
157                TaskStatusValue::Processing => (),
158            }
159            sleep(interval).await;
160        }
161
162        Err(TaskForceAIError::Timeout)
163    }
164
165    pub async fn run_task(
166        &self,
167        prompt: &str,
168        options: Option<TaskSubmissionOptions>,
169        poll_interval: Option<Duration>,
170        max_attempts: Option<u32>,
171    ) -> Result<TaskStatus, TaskForceAIError> {
172        let task_id = self.submit_task(prompt, options).await?;
173        self.wait_for_completion(&task_id, poll_interval, max_attempts)
174            .await
175    }
176}