1use crate::client::TaskForceAI;
2use crate::error::TaskForceAIError;
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct Thread {
10 pub id: i64,
11 pub title: String,
12 #[serde(with = "chrono::serde::ts_seconds")]
13 pub created_at: DateTime<Utc>,
14 #[serde(with = "chrono::serde::ts_seconds")]
15 pub updated_at: DateTime<Utc>,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ThreadMessage {
21 pub id: i64,
22 pub thread_id: i64,
23 pub role: String, pub content: String,
25 #[serde(with = "chrono::serde::ts_seconds")]
26 pub created_at: DateTime<Utc>,
27}
28
29#[derive(Debug, Clone, Default, Serialize, Deserialize)]
31pub struct CreateThreadOptions {
32 #[serde(skip_serializing_if = "Option::is_none")]
33 pub title: Option<String>,
34 #[serde(skip_serializing_if = "Option::is_none")]
35 pub messages: Option<Vec<ThreadMessage>>,
36 #[serde(skip_serializing_if = "Option::is_none")]
37 pub metadata: Option<HashMap<String, serde_json::Value>>,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct ThreadListResponse {
43 pub threads: Vec<Thread>,
44 pub total: i64,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct ThreadMessagesResponse {
50 pub messages: Vec<ThreadMessage>,
51 pub total: i64,
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct ThreadRunOptions {
57 pub prompt: String,
58 #[serde(skip_serializing_if = "Option::is_none")]
59 pub model_id: Option<String>,
60 #[serde(skip_serializing_if = "Option::is_none")]
61 pub options: Option<HashMap<String, serde_json::Value>>,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct ThreadRunResponse {
67 pub task_id: String,
68 pub thread_id: i64,
69 pub message_id: i64,
70}
71
72impl TaskForceAI {
73 pub async fn create_thread(
75 &self,
76 options: Option<CreateThreadOptions>,
77 ) -> Result<Thread, TaskForceAIError> {
78 let body = options
79 .map(|o| serde_json::to_value(o))
80 .transpose()?
81 .unwrap_or_else(|| serde_json::json!({}));
82
83 self.request(reqwest::Method::POST, "/threads", Some(body))
84 .await
85 }
86
87 pub async fn list_threads(
89 &self,
90 limit: i32,
91 offset: i32,
92 ) -> Result<ThreadListResponse, TaskForceAIError> {
93 let path = format!("/threads?limit={}&offset={}", limit, offset);
94 self.request(reqwest::Method::GET, &path, None).await
95 }
96
97 pub async fn get_thread(&self, thread_id: i64) -> Result<Thread, TaskForceAIError> {
99 let path = format!("/threads/{}", thread_id);
100 self.request(reqwest::Method::GET, &path, None).await
101 }
102
103 pub async fn delete_thread(&self, thread_id: i64) -> Result<(), TaskForceAIError> {
105 let path = format!("/threads/{}", thread_id);
106 let _: serde_json::Value = self.request(reqwest::Method::DELETE, &path, None).await?;
107 Ok(())
108 }
109
110 pub async fn get_thread_messages(
112 &self,
113 thread_id: i64,
114 limit: i32,
115 offset: i32,
116 ) -> Result<ThreadMessagesResponse, TaskForceAIError> {
117 let path = format!(
118 "/threads/{}/messages?limit={}&offset={}",
119 thread_id, limit, offset
120 );
121 self.request(reqwest::Method::GET, &path, None).await
122 }
123
124 pub async fn run_in_thread(
126 &self,
127 thread_id: i64,
128 options: ThreadRunOptions,
129 ) -> Result<ThreadRunResponse, TaskForceAIError> {
130 if options.prompt.trim().is_empty() {
131 return Err(TaskForceAIError::EmptyPrompt);
132 }
133
134 let path = format!("/threads/{}/runs", thread_id);
135 let body = serde_json::to_value(options)?;
136
137 self.request(reqwest::Method::POST, &path, Some(body)).await
138 }
139}