// taskforceai_sdk/lib.rs

1pub mod client;
2pub mod error;
3pub mod files;
4pub mod stream;
5pub mod threads;
6pub mod types;
7
8pub use client::TaskForceAI;
9pub use error::TaskForceAIError;
10pub use files::{File, FileListResponse, FileUploadOptions};
11pub use threads::{
12    CreateThreadOptions, Thread, ThreadListResponse, ThreadMessage, ThreadMessagesResponse,
13    ThreadRunOptions, ThreadRunResponse,
14};
15pub use types::{TaskForceAIOptions, TaskStatus, TaskStatusValue, TaskSubmissionOptions};
16
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::{DEFAULT_BASE_URL, DEFAULT_TIMEOUT_SECS};
    use futures_util::StreamExt;
    use mockito::Server;
    use std::time::Duration;

    /// Builds a client with a placeholder API key and every other option left at its default.
    ///
    /// Intended for tests that fail input validation before any request is sent.
    fn offline_client() -> TaskForceAI {
        TaskForceAI::new(TaskForceAIOptions {
            api_key: Some("key".to_string()),
            ..Default::default()
        })
        .expect("a client with an API key should build")
    }

    /// Builds a client that sends its requests to `base_url` using `api_key`.
    fn client_with_key(base_url: String, api_key: &str) -> TaskForceAI {
        TaskForceAI::new(TaskForceAIOptions {
            base_url: Some(base_url),
            api_key: Some(api_key.to_string()),
            ..Default::default()
        })
        .expect("a client with an API key should build")
    }

    /// Builds a client that sends its requests to `base_url` with a placeholder API key.
    fn client_for(base_url: String) -> TaskForceAI {
        client_with_key(base_url, "key")
    }

    #[tokio::test]
    async fn test_new_client_defaults() {
        let client = TaskForceAI::new(TaskForceAIOptions {
            api_key: Some("test-key".to_string()),
            ..Default::default()
        })
        .unwrap();
        assert_eq!(client.base_url, DEFAULT_BASE_URL);
        assert_eq!(client.timeout, Duration::from_secs(DEFAULT_TIMEOUT_SECS));
    }

    #[tokio::test]
    async fn test_new_client_error() {
        let res = TaskForceAI::new(TaskForceAIOptions {
            api_key: None,
            mock_mode: Some(false),
            ..Default::default()
        });
        assert!(matches!(res, Err(TaskForceAIError::MissingApiKey)));
    }

    #[tokio::test]
    async fn test_mock_mode() {
        let opts = TaskForceAIOptions {
            mock_mode: Some(true),
            ..Default::default()
        };
        let client = TaskForceAI::new(opts).unwrap();

        // Test run_task
        let status = client.run_task("hello", None, None, None).await.unwrap();
        assert_eq!(status.task_id, "mock-task-123");
        assert_eq!(status.status, TaskStatusValue::Completed);

        // Test stream_task_status
        let mut stream = client.stream_task_status("mock-id").await.unwrap();
        let ev = stream.next().await.unwrap().unwrap();
        assert_eq!(ev.status, TaskStatusValue::Completed);
    }

    #[tokio::test]
    async fn test_submit_task_errors() {
        let client = offline_client();

        let res = client.submit_task("  ", None).await;
        assert!(matches!(res, Err(TaskForceAIError::EmptyPrompt)));
    }

    #[tokio::test]
    async fn test_api_error() {
        let mut server = Server::new_async().await;
        let _mock = server
            .mock("POST", "/run")
            .with_status(401)
            .with_body("Unauthorized")
            .create_async()
            .await;

        let client = client_with_key(server.url(), "wrong");

        let res = client.submit_task("hi", None).await;
        match res {
            Err(TaskForceAIError::Api { status, .. }) => assert_eq!(status, 401),
            _ => panic!("Expected API error"),
        }
    }

    #[tokio::test]
    async fn test_wait_for_completion_timeout() {
        let mut server = Server::new_async().await;
        let mock = server
            .mock("GET", "/status/task-1")
            .with_status(200)
            .with_body(r#"{"taskId": "task-1", "status": "processing"}"#)
            .expect(2)
            .create_async()
            .await;

        let client = client_for(server.url());

        let res = client
            .wait_for_completion("task-1", Some(Duration::from_millis(1)), Some(2))
            .await;
        assert!(matches!(res, Err(TaskForceAIError::Timeout)));
        // mockito only enforces `.expect(2)` when the mock is asserted. Without this call,
        // the expected poll count would go unchecked.
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_wait_for_completion_failed() {
        let mut server = Server::new_async().await;
        let _mock = server
            .mock("GET", "/status/task-1")
            .with_status(200)
            .with_body(r#"{"taskId": "task-1", "status": "failed", "error": "oops"}"#)
            .create_async()
            .await;

        let client = client_for(server.url());

        let res = client.wait_for_completion("task-1", None, None).await;
        match res {
            Err(TaskForceAIError::TaskFailed(msg)) => assert_eq!(msg, "oops"),
            _ => panic!("Expected TaskFailed error"),
        }
    }

    #[tokio::test]
    async fn test_stream_task_status() {
        let mut server = Server::new_async().await;
        let _mock = server.mock("GET", "/stream/task-1")
            .with_status(200)
            .with_header("content-type", "text/event-stream")
            .with_body("data: {\"taskId\": \"task-1\", \"status\": \"processing\"}\ndata: {\"taskId\": \"task-1\", \"status\": \"completed\", \"result\": \"stream-done\"}\n")
            .create_async().await;

        let client = client_for(server.url());

        let mut stream = client.stream_task_status("task-1").await.unwrap();

        let ev1 = stream.next().await.unwrap().unwrap();
        assert_eq!(ev1.status, TaskStatusValue::Processing);

        let ev2 = stream.next().await.unwrap().unwrap();
        assert_eq!(ev2.status, TaskStatusValue::Completed);
        assert_eq!(ev2.result.unwrap(), "stream-done");

        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_run_task_stream() {
        let mut server = Server::new_async().await;
        let _run_mock = server
            .mock("POST", "/run")
            .with_status(200)
            .with_body(r#"{"taskId": "task-2"}"#)
            .create_async()
            .await;
        let _stream_mock = server
            .mock("GET", "/stream/task-2")
            .with_status(200)
            .with_body("data: {\"taskId\": \"task-2\", \"status\": \"completed\"}\n")
            .create_async()
            .await;

        let client = client_for(server.url());

        let mut stream = client.run_task_stream("hi", None).await.unwrap();
        let ev = stream.next().await.unwrap().unwrap();
        assert_eq!(ev.status, TaskStatusValue::Completed);
    }

    #[tokio::test]
    async fn test_stream_error_status() {
        let mut server = Server::new_async().await;
        let _mock = server
            .mock("GET", "/stream/task-1")
            .with_status(403)
            .create_async()
            .await;

        let client = client_for(server.url());

        let res = client.stream_task_status("task-1").await;
        assert!(matches!(res, Err(TaskForceAIError::Api { .. })));
    }

    #[tokio::test]
    async fn test_stream_malformed_json() {
        let mut server = Server::new_async().await;
        let _mock = server
            .mock("GET", "/stream/task-1")
            .with_status(200)
            .with_body("data: {malformed}\n")
            .create_async()
            .await;

        let client = client_for(server.url());

        let mut stream = client.stream_task_status("task-1").await.unwrap();
        let res = stream.next().await.unwrap();
        assert!(matches!(res, Err(TaskForceAIError::Serialization(_))));
    }

    #[tokio::test]
    async fn test_get_task_status_errors() {
        let client = offline_client();

        let res = client.get_task_status("  ").await;
        assert!(matches!(res, Err(TaskForceAIError::EmptyTaskId)));
    }

    #[tokio::test]
    async fn test_stream_task_status_empty_id() {
        let client = offline_client();

        let res = client.stream_task_status("").await;
        assert!(matches!(res, Err(TaskForceAIError::EmptyTaskId)));
    }

    #[tokio::test]
    async fn test_stream_bytes_error() {
        // The final event has no trailing newline. It must still be parsed when the body ends.
        let mut server = Server::new_async().await;
        let _mock = server
            .mock("GET", "/stream/task-1")
            .with_status(200)
            .with_body("data: {\"taskId\": \"task-1\", \"status\": \"completed\"}")
            .create_async()
            .await;

        let client = client_for(server.url());

        let mut stream = client.stream_task_status("task-1").await.unwrap();
        let ev = stream.next().await.unwrap().unwrap();
        assert_eq!(ev.status, TaskStatusValue::Completed);
    }

    #[tokio::test]
    async fn test_submit_task_with_options() {
        let mut server = Server::new_async().await;
        let _mock = server
            .mock("POST", "/run")
            .with_status(200)
            .with_body(r#"{"taskId": "task-opts"}"#)
            .create_async()
            .await;

        let client = client_for(server.url());

        let opts = TaskSubmissionOptions {
            model_id: Some("gpt-4".to_string()),
            silent: Some(true),
            ..Default::default()
        };
        let task_id = client.submit_task("hello", Some(opts)).await.unwrap();
        assert_eq!(task_id, "task-opts");
    }

    #[tokio::test]
    async fn test_stream_empty_end_unique() {
        // A blank line after the last event must not produce an extra stream item.
        let mut server = Server::new_async().await;
        let _mock = server
            .mock("GET", "/stream/task-1")
            .with_status(200)
            .with_body("data: {\"taskId\": \"task-1\", \"status\": \"processing\"}\n\n")
            .create_async()
            .await;

        let client = client_for(server.url());

        let mut stream = client.stream_task_status("task-1").await.unwrap();
        let ev = stream.next().await.unwrap().unwrap();
        assert_eq!(ev.status, TaskStatusValue::Processing);
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_stream_none_with_empty_buffer() {
        let mut server = Server::new_async().await;
        let _mock = server
            .mock("GET", "/stream/task-1")
            .with_status(200)
            // The trailing newline means the buffer is empty once this line has been drained.
            .with_body("data: {\"taskId\": \"task-1\", \"status\": \"processing\"}\n")
            .create_async()
            .await;

        let client = client_for(server.url());

        let mut stream = client.stream_task_status("task-1").await.unwrap();
        let ev = stream.next().await.unwrap().unwrap();
        assert_eq!(ev.status, TaskStatusValue::Processing);
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_stream_non_data_line() {
        // SSE comment lines (": ...") and lines without a `data:` prefix are skipped.
        let mut server = Server::new_async().await;
        let _mock = server
            .mock("GET", "/stream/task-1")
            .with_status(200)
            .with_body(": comment\nnot-data: something\ndata: {\"taskId\": \"task-1\", \"status\": \"completed\"}\n")
            .create_async()
            .await;

        let client = client_for(server.url());

        let mut stream = client.stream_task_status("task-1").await.unwrap();
        let ev = stream.next().await.unwrap().unwrap();
        assert_eq!(ev.status, TaskStatusValue::Completed);
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_wait_for_completion_unknown_fail() {
        let mut server = Server::new_async().await;
        let _mock = server
            .mock("GET", "/status/task-1")
            .with_status(200)
            .with_body(r#"{"taskId": "task-1", "status": "failed"}"#)
            .create_async()
            .await;

        let client = client_for(server.url());

        let res = client.wait_for_completion("task-1", None, None).await;
        match res {
            Err(TaskForceAIError::TaskFailed(msg)) => assert_eq!(msg, "Unknown error"),
            _ => panic!("Expected TaskFailed error"),
        }
    }

    #[tokio::test]
    async fn test_api_error_no_body() {
        let mut server = Server::new_async().await;
        let _mock = server
            .mock("POST", "/run")
            .with_status(500)
            .create_async()
            .await;

        let client = client_for(server.url());

        let res = client.submit_task("hi", None).await;
        assert!(matches!(res, Err(TaskForceAIError::Api { status, .. }) if status == 500));
    }

    #[tokio::test]
    async fn test_stream_last_line_malformed_no_newline() {
        let mut server = Server::new_async().await;
        let _mock = server
            .mock("GET", "/stream/task-1")
            .with_status(200)
            .with_body("data: {malformed}")
            .create_async()
            .await;

        let client = client_for(server.url());

        let mut stream = client.stream_task_status("task-1").await.unwrap();
        let res = stream.next().await.unwrap();
        assert!(matches!(res, Err(TaskForceAIError::Serialization(_))));
    }

    #[tokio::test]
    async fn test_stream_empty_body() {
        let mut server = Server::new_async().await;
        let _mock = server
            .mock("GET", "/stream/task-1")
            .with_status(200)
            .with_body("")
            .create_async()
            .await;

        let client = client_for(server.url());

        let mut stream = client.stream_task_status("task-1").await.unwrap();
        assert!(stream.next().await.is_none());
    }

    // Pure formatting checks: no I/O and nothing awaited, so no async runtime is needed.
    #[test]
    fn test_serialization_error_enum_coverage() {
        let err = TaskForceAIError::Serialization(
            serde_json::from_str::<serde_json::Value>("{ ").unwrap_err(),
        );
        assert!(err.to_string().contains("Serialization error"));
    }

    #[test]
    fn test_error_variants() {
        let e = TaskForceAIError::EmptyTaskId;
        assert_eq!(e.to_string(), "Task ID must be a non-empty string");

        let e = TaskForceAIError::Stream("oops".to_string());
        assert_eq!(e.to_string(), "Stream error: oops");

        let e = TaskForceAIError::Other("oops".to_string());
        assert_eq!(e.to_string(), "Other error: oops");
    }

    #[tokio::test]
    async fn test_run_task_error() {
        let mut server = Server::new_async().await;
        let _mock = server
            .mock("POST", "/run")
            .with_status(500)
            .create_async()
            .await;

        let client = client_for(server.url());

        let res = client.run_task("hi", None, None, None).await;
        assert!(matches!(res, Err(TaskForceAIError::Api { status, .. }) if status == 500));
    }

    #[tokio::test]
    async fn test_wait_for_completion_error_path() {
        let mut server = Server::new_async().await;
        let _mock = server
            .mock("GET", "/status/task-1")
            .with_status(500)
            .create_async()
            .await;

        let client = client_for(server.url());

        let res = client.wait_for_completion("task-1", None, None).await;
        assert!(matches!(res, Err(TaskForceAIError::Api { status, .. }) if status == 500));
    }

    #[tokio::test]
    async fn test_stream_last_line_no_newline_garbage() {
        // A final fragment with no `data:` prefix and no newline is ignored, not reported.
        let mut server = Server::new_async().await;
        let _mock = server
            .mock("GET", "/stream/task-1")
            .with_status(200)
            .with_body("garbage-no-newline")
            .create_async()
            .await;

        let client = client_for(server.url());

        let mut stream = client.stream_task_status("task-1").await.unwrap();
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_stream_network_error_mid_stream() {
        let mut server = Server::new_async().await;
        let _mock = server
            .mock("GET", "/stream/task-1")
            .with_status(200)
            .with_body("data: {\"taskId\": \"task-1\", \"status\": \"processing\"}\n")
            .create_async()
            .await;

        let client = client_for(server.url());

        let mut stream = client.stream_task_status("task-1").await.unwrap();
        let ev = stream.next().await.unwrap().unwrap();
        assert_eq!(ev.status, TaskStatusValue::Processing);

        drop(server);
        // Mid-stream network error is hard to simulate with mockito reliably.
        // We've already verified the variant and other error paths.
    }
}