taskforceai_sdk/
client.rs1use crate::error::TaskForceAIError;
2use crate::types::{
3 SubmitTaskResponse, TaskForceAIOptions, TaskStatus, TaskStatusValue, TaskSubmissionOptions,
4};
5use std::time::Duration;
6use tokio::time::sleep;
7
/// Root URL of the TaskForceAI developer API, used when no `base_url` option is supplied.
pub const DEFAULT_BASE_URL: &str = "https://taskforceai.chat/api/developer";
/// Per-request HTTP timeout, in seconds, used when no `timeout` option is supplied.
pub const DEFAULT_TIMEOUT_SECS: u64 = 30;
/// Delay between consecutive status polls, in milliseconds, used when no interval is supplied.
pub const DEFAULT_POLL_INTERVAL_MS: u64 = 1_000;
/// Upper bound on status polls before reporting a timeout, used when no limit is supplied.
pub const DEFAULT_MAX_POLL_ATTEMPTS: u32 = 60;
12
13pub struct TaskForceAI {
14 pub(crate) api_key: String,
15 pub(crate) base_url: String,
16 #[allow(dead_code)]
17 pub(crate) timeout: Duration,
18 pub(crate) mock_mode: bool,
19 pub(crate) client: reqwest::Client,
20}
21
22impl TaskForceAI {
23 pub fn new(options: TaskForceAIOptions) -> Result<Self, TaskForceAIError> {
24 let mock_mode = options.mock_mode.unwrap_or(false);
25 let api_key = options.api_key.unwrap_or_default();
26
27 if !mock_mode && api_key.is_empty() {
28 return Err(TaskForceAIError::MissingApiKey);
29 }
30
31 let base_url = options
32 .base_url
33 .unwrap_or_else(|| DEFAULT_BASE_URL.to_string())
34 .trim_end_matches('/')
35 .to_string();
36
37 let timeout = Duration::from_secs(options.timeout.unwrap_or(DEFAULT_TIMEOUT_SECS));
38
39 let client = reqwest::Client::builder().timeout(timeout).build()?;
40
41 Ok(Self {
42 api_key,
43 base_url,
44 timeout,
45 mock_mode,
46 client,
47 })
48 }
49
50 pub(crate) async fn request<T>(
51 &self,
52 method: reqwest::Method,
53 path: &str,
54 body: Option<serde_json::Value>,
55 ) -> Result<T, TaskForceAIError>
56 where
57 T: serde::de::DeserializeOwned,
58 {
59 if self.mock_mode {
60 return self.mock_response(path, &method);
61 }
62
63 let url = format!("{}{}", self.base_url, path);
64 let mut request = self.client.request(method, &url);
65
66 if !self.api_key.is_empty() {
67 request = request.bearer_auth(&self.api_key);
68 }
69
70 request = request.header("X-SDK-Language", "rust");
71
72 if let Some(b) = body {
73 request = request.json(&b);
74 }
75
76 let response = request.send().await?;
77 let status = response.status();
78
79 if !status.is_success() {
80 let message = response
81 .text()
82 .await
83 .unwrap_or_else(|_| "Failed to read error message from response body".to_string());
84 return Err(TaskForceAIError::Api { status, message });
85 }
86
87 Ok(response.json().await?)
88 }
89
90 fn mock_response<T>(&self, path: &str, method: &reqwest::Method) -> Result<T, TaskForceAIError>
91 where
92 T: serde::de::DeserializeOwned,
93 {
94 let val = if method == reqwest::Method::POST && path == "/run" {
95 serde_json::json!({ "taskId": "mock-task-123" })
96 } else if path.starts_with("/status/") {
97 serde_json::json!({
98 "taskId": "mock-task-123",
99 "status": "completed",
100 "result": "This is a mock response. Configure your API key to get real results."
101 })
102 } else {
103 serde_json::json!({ "status": "ok" })
104 };
105
106 Ok(serde_json::from_value(val)?)
107 }
108
109 pub async fn submit_task(
110 &self,
111 prompt: &str,
112 options: Option<TaskSubmissionOptions>,
113 ) -> Result<String, TaskForceAIError> {
114 if prompt.trim().is_empty() {
115 return Err(TaskForceAIError::EmptyPrompt);
116 }
117
118 let mut body = serde_json::json!({ "prompt": prompt });
119 if let Some(opts) = options {
120 if let Some(obj) = body.as_object_mut() {
121 obj.insert("options".to_string(), serde_json::to_value(opts)?);
122 }
123 }
124
125 let response: SubmitTaskResponse = self
126 .request(reqwest::Method::POST, "/run", Some(body))
127 .await?;
128 Ok(response.task_id)
129 }
130
131 pub async fn get_task_status(&self, task_id: &str) -> Result<TaskStatus, TaskForceAIError> {
132 if task_id.trim().is_empty() {
133 return Err(TaskForceAIError::EmptyTaskId);
134 }
135 self.request(reqwest::Method::GET, &format!("/status/{}", task_id), None)
136 .await
137 }
138
139 pub async fn wait_for_completion(
140 &self,
141 task_id: &str,
142 poll_interval: Option<Duration>,
143 max_attempts: Option<u32>,
144 ) -> Result<TaskStatus, TaskForceAIError> {
145 let interval = poll_interval.unwrap_or(Duration::from_millis(DEFAULT_POLL_INTERVAL_MS));
146 let max = max_attempts.unwrap_or(DEFAULT_MAX_POLL_ATTEMPTS);
147
148 for _ in 0..max {
149 let status = self.get_task_status(task_id).await?;
150 match status.status {
151 TaskStatusValue::Completed => return Ok(status),
152 TaskStatusValue::Failed => {
153 return Err(TaskForceAIError::TaskFailed(
154 status.error.unwrap_or_else(|| "Unknown error".to_string()),
155 ))
156 }
157 TaskStatusValue::Processing => (),
158 }
159 sleep(interval).await;
160 }
161
162 Err(TaskForceAIError::Timeout)
163 }
164
165 pub async fn run_task(
166 &self,
167 prompt: &str,
168 options: Option<TaskSubmissionOptions>,
169 poll_interval: Option<Duration>,
170 max_attempts: Option<u32>,
171 ) -> Result<TaskStatus, TaskForceAIError> {
172 let task_id = self.submit_task(prompt, options).await?;
173 self.wait_for_completion(&task_id, poll_interval, max_attempts)
174 .await
175 }
176}