1use std::collections::HashMap;
27use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
28use std::sync::{Arc, RwLock};
29use std::time::{Duration, Instant};
30
31use crate::protocol::{CallToolResult, TaskObject, TaskStatus};
32
/// Retention period, in milliseconds, used when a task is created without an
/// explicit TTL. It is measured from the moment the task reaches a terminal
/// state (5 minutes).
const DEFAULT_TTL_MS: u64 = 5 * 60 * 1_000;

/// Polling interval, in milliseconds, reported for every task (2 seconds).
const DEFAULT_POLL_INTERVAL_MS: u64 = 2 * 1_000;
38
/// A single tool invocation tracked by [`TaskStore`].
#[derive(Debug)]
pub struct Task {
    /// Identifier assigned by `TaskStore` (of the form `task-<n>`).
    pub id: String,
    /// Name of the tool being run.
    pub tool_name: String,
    /// Arguments the tool was invoked with.
    pub arguments: serde_json::Value,
    /// Current lifecycle state.
    pub status: TaskStatus,
    /// Monotonic creation time.
    pub created_at: Instant,
    /// Creation time as an ISO-8601 UTC string produced by `chrono_now_iso8601`.
    pub created_at_str: String,
    /// Time of the most recent state change, in the same format as `created_at_str`.
    pub last_updated_at_str: String,
    /// Milliseconds the task is retained after `completed_at` (see `is_expired`).
    pub ttl: u64,
    /// Polling interval in milliseconds, reported via `TaskObject::poll_interval`.
    pub poll_interval: u64,
    /// Human-readable description of the current state.
    pub status_message: Option<String>,
    /// Tool output; set when the task completes successfully.
    pub result: Option<CallToolResult>,
    /// Error text; set when the task fails.
    pub error: Option<String>,
    /// Cancellation flag; a clone is handed back by `TaskStore::create_task`.
    pub cancellation_token: CancellationToken,
    /// Monotonic time the task entered a terminal state; `None` while still active.
    pub completed_at: Option<Instant>,
    /// Signalled (`notify_waiters`) when the task enters a terminal state.
    pub completion_notify: Arc<tokio::sync::Notify>,
}
73
74impl Task {
75 fn new(id: String, tool_name: String, arguments: serde_json::Value, ttl: Option<u64>) -> Self {
77 let cancelled = Arc::new(AtomicBool::new(false));
78 let now_str = chrono_now_iso8601();
79 Self {
80 id,
81 tool_name,
82 arguments,
83 status: TaskStatus::Working,
84 created_at: Instant::now(),
85 created_at_str: now_str.clone(),
86 last_updated_at_str: now_str,
87 ttl: ttl.unwrap_or(DEFAULT_TTL_MS),
88 poll_interval: DEFAULT_POLL_INTERVAL_MS,
89 status_message: Some("Task started".to_string()),
90 result: None,
91 error: None,
92 cancellation_token: CancellationToken { cancelled },
93 completed_at: None,
94 completion_notify: Arc::new(tokio::sync::Notify::new()),
95 }
96 }
97
98 pub fn to_task_object(&self) -> TaskObject {
100 TaskObject {
101 task_id: self.id.clone(),
102 status: self.status,
103 status_message: self.status_message.clone(),
104 created_at: self.created_at_str.clone(),
105 last_updated_at: self.last_updated_at_str.clone(),
106 ttl: Some(self.ttl),
107 poll_interval: Some(self.poll_interval),
108 meta: None,
109 }
110 }
111
112 pub fn is_expired(&self) -> bool {
114 if let Some(completed_at) = self.completed_at {
115 completed_at.elapsed() > Duration::from_millis(self.ttl)
116 } else {
117 false
118 }
119 }
120
121 pub fn is_cancelled(&self) -> bool {
123 self.cancellation_token.is_cancelled()
124 }
125}
126
/// Cheaply clonable cancellation flag; every clone shares the same state.
#[derive(Debug, Clone)]
pub struct CancellationToken {
    cancelled: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Returns `true` once [`cancel`](Self::cancel) has been called on this
    /// token or on any of its clones.
    pub fn is_cancelled(&self) -> bool {
        // Relaxed is enough: the flag is the only data the token carries.
        let flag: &AtomicBool = &self.cancelled;
        flag.load(Ordering::Relaxed)
    }

    /// Marks this token and all of its clones as cancelled. Calling it more
    /// than once has no further effect.
    pub fn cancel(&self) {
        let flag: &AtomicBool = &self.cancelled;
        flag.store(true, Ordering::Relaxed);
    }
}
144
/// Registry of [`Task`]s keyed by task id, guarded by an `RwLock`.
///
/// Both fields are `Arc`s, so cloning a `TaskStore` is cheap and the clone
/// shares the same task map and id counter.
#[derive(Debug, Clone)]
pub struct TaskStore {
    /// All known tasks, keyed by `Task::id`.
    tasks: Arc<RwLock<HashMap<String, Task>>>,
    /// Next numeric suffix handed out by `generate_id`.
    next_id: Arc<AtomicU64>,
}
151
152impl Default for TaskStore {
153 fn default() -> Self {
154 Self::new()
155 }
156}
157
158impl TaskStore {
159 pub fn new() -> Self {
161 Self {
162 tasks: Arc::new(RwLock::new(HashMap::new())),
163 next_id: Arc::new(AtomicU64::new(1)),
164 }
165 }
166
167 fn generate_id(&self) -> String {
169 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
170 format!("task-{}", id)
171 }
172
173 pub fn create_task(
177 &self,
178 tool_name: &str,
179 arguments: serde_json::Value,
180 ttl: Option<u64>,
181 ) -> (String, CancellationToken) {
182 let id = self.generate_id();
183 let task = Task::new(id.clone(), tool_name.to_string(), arguments, ttl);
184 let token = task.cancellation_token.clone();
185
186 if let Ok(mut tasks) = self.tasks.write() {
187 tasks.insert(id.clone(), task);
188 }
189
190 (id, token)
191 }
192
193 pub fn get_task(&self, task_id: &str) -> Option<TaskObject> {
195 if let Ok(tasks) = self.tasks.read() {
196 tasks.get(task_id).map(|t| t.to_task_object())
197 } else {
198 None
199 }
200 }
201
202 pub fn get_task_result(
204 &self,
205 task_id: &str,
206 ) -> Option<(TaskObject, Option<CallToolResult>, Option<String>)> {
207 if let Ok(tasks) = self.tasks.read() {
208 tasks
209 .get(task_id)
210 .map(|t| (t.to_task_object(), t.result.clone(), t.error.clone()))
211 } else {
212 None
213 }
214 }
215
216 pub async fn wait_for_completion(
221 &self,
222 task_id: &str,
223 ) -> Option<(TaskObject, Option<CallToolResult>, Option<String>)> {
224 let notify = {
226 let tasks = self.tasks.read().ok()?;
227 let task = tasks.get(task_id)?;
228 if task.status.is_terminal() {
229 return Some((
230 task.to_task_object(),
231 task.result.clone(),
232 task.error.clone(),
233 ));
234 }
235 task.completion_notify.clone()
236 };
237
238 notify.notified().await;
240
241 self.get_task_result(task_id)
243 }
244
245 pub fn list_tasks(&self, status_filter: Option<TaskStatus>) -> Vec<TaskObject> {
247 if let Ok(tasks) = self.tasks.read() {
248 tasks
249 .values()
250 .filter(|t| status_filter.is_none() || status_filter == Some(t.status))
251 .map(|t| t.to_task_object())
252 .collect()
253 } else {
254 vec![]
255 }
256 }
257
258 pub fn require_input(&self, task_id: &str, message: &str) -> bool {
260 let Ok(mut tasks) = self.tasks.write() else {
261 return false;
262 };
263 let Some(task) = tasks.get_mut(task_id) else {
264 return false;
265 };
266 if task.status.is_terminal() {
267 return false;
268 }
269 task.status = TaskStatus::InputRequired;
270 task.status_message = Some(message.to_string());
271 task.last_updated_at_str = chrono_now_iso8601();
272 true
273 }
274
275 pub fn complete_task(&self, task_id: &str, result: CallToolResult) -> bool {
277 let Ok(mut tasks) = self.tasks.write() else {
278 return false;
279 };
280 let Some(task) = tasks.get_mut(task_id) else {
281 return false;
282 };
283 if task.status.is_terminal() {
284 return false;
285 }
286 task.status = TaskStatus::Completed;
287 task.status_message = Some("Task completed".to_string());
288 task.result = Some(result);
289 task.completed_at = Some(Instant::now());
290 task.last_updated_at_str = chrono_now_iso8601();
291 task.completion_notify.notify_waiters();
292 true
293 }
294
295 pub fn fail_task(&self, task_id: &str, error: &str) -> bool {
297 let Ok(mut tasks) = self.tasks.write() else {
298 return false;
299 };
300 let Some(task) = tasks.get_mut(task_id) else {
301 return false;
302 };
303 if task.status.is_terminal() {
304 return false;
305 }
306 task.status = TaskStatus::Failed;
307 task.status_message = Some(format!("Task failed: {}", error));
308 task.error = Some(error.to_string());
309 task.completed_at = Some(Instant::now());
310 task.last_updated_at_str = chrono_now_iso8601();
311 task.completion_notify.notify_waiters();
312 true
313 }
314
315 pub fn cancel_task(&self, task_id: &str, reason: Option<&str>) -> Option<TaskObject> {
317 let mut tasks = self.tasks.write().ok()?;
318 let task = tasks.get_mut(task_id)?;
319
320 task.cancellation_token.cancel();
322
323 if !task.status.is_terminal() {
325 task.status = TaskStatus::Cancelled;
326 task.status_message = Some(
327 reason
328 .map(|r| format!("Cancelled: {}", r))
329 .unwrap_or_else(|| "Task cancelled".to_string()),
330 );
331 task.completed_at = Some(Instant::now());
332 task.last_updated_at_str = chrono_now_iso8601();
333 task.completion_notify.notify_waiters();
334 }
335 Some(task.to_task_object())
336 }
337
338 pub fn cleanup_expired(&self) -> usize {
340 if let Ok(mut tasks) = self.tasks.write() {
341 let before = tasks.len();
342 tasks.retain(|_, t| !t.is_expired());
343 before - tasks.len()
344 } else {
345 0
346 }
347 }
348
349 #[cfg(test)]
351 pub fn len(&self) -> usize {
352 if let Ok(tasks) = self.tasks.read() {
353 tasks.len()
354 } else {
355 0
356 }
357 }
358
359 #[cfg(test)]
361 pub fn is_empty(&self) -> bool {
362 self.len() == 0
363 }
364}
365
366fn chrono_now_iso8601() -> String {
368 use std::time::SystemTime;
369
370 let now = SystemTime::now();
371 let duration = now
372 .duration_since(SystemTime::UNIX_EPOCH)
373 .unwrap_or_default();
374
375 let secs = duration.as_secs();
376 let millis = duration.subsec_millis();
377
378 let days = secs / 86400;
381 let remaining = secs % 86400;
382 let hours = remaining / 3600;
383 let remaining = remaining % 3600;
384 let minutes = remaining / 60;
385 let seconds = remaining % 60;
386
387 let mut year = 1970i32;
390 let mut remaining_days = days as i32;
391
392 loop {
393 let days_in_year = if is_leap_year(year) { 366 } else { 365 };
394 if remaining_days < days_in_year {
395 break;
396 }
397 remaining_days -= days_in_year;
398 year += 1;
399 }
400
401 let days_in_months: [i32; 12] = if is_leap_year(year) {
402 [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
403 } else {
404 [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
405 };
406
407 let mut month = 1;
408 for days_in_month in days_in_months.iter() {
409 if remaining_days < *days_in_month {
410 break;
411 }
412 remaining_days -= days_in_month;
413 month += 1;
414 }
415
416 let day = remaining_days + 1;
417
418 format!(
419 "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}.{:03}Z",
420 year, month, day, hours, minutes, seconds, millis
421 )
422}
423
/// Gregorian leap-year rule: every 4th year, except centuries, except every
/// 400th year.
fn is_leap_year(year: i32) -> bool {
    match (year % 4, year % 100, year % 400) {
        (_, _, 0) => true,
        (_, 0, _) => false,
        (0, _, _) => true,
        _ => false,
    }
}
427
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll, Wake, Waker};

    /// Polls `fut` exactly once with a waker that does nothing, so the async
    /// API can be exercised without pulling in an executor.
    fn poll_once<F: Future>(fut: Pin<&mut F>) -> Poll<F::Output> {
        struct NoopWaker;
        impl Wake for NoopWaker {
            fn wake(self: Arc<Self>) {}
        }
        let waker = Waker::from(Arc::new(NoopWaker));
        let mut cx = Context::from_waker(&waker);
        fut.poll(&mut cx)
    }

    #[test]
    fn test_create_task() {
        let store = TaskStore::new();
        let (id, token) = store.create_task("test-tool", serde_json::json!({"a": 1}), None);

        assert!(id.starts_with("task-"));
        assert!(!token.is_cancelled());

        let info = store.get_task(&id).expect("task should exist");
        assert_eq!(info.task_id, id);
        assert_eq!(info.status, TaskStatus::Working);
        assert_eq!(info.created_at, info.last_updated_at);
        assert!(info.poll_interval.is_some());
    }

    #[test]
    fn test_create_task_custom_ttl() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), Some(1_234));

        let info = store.get_task(&id).unwrap();
        assert_eq!(info.ttl, Some(1_234));
    }

    #[test]
    fn test_task_lifecycle() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);

        assert!(store.complete_task(&id, CallToolResult::text("Done")));

        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::Completed);
    }

    #[test]
    fn test_task_cancellation() {
        let store = TaskStore::new();
        let (id, token) = store.create_task("test-tool", serde_json::json!({}), None);

        assert!(!token.is_cancelled());

        let task_obj = store.cancel_task(&id, Some("User requested"));
        assert!(task_obj.is_some());
        assert_eq!(task_obj.unwrap().status, TaskStatus::Cancelled);
        assert!(token.is_cancelled());

        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::Cancelled);
    }

    #[test]
    fn test_cancel_terminal_task_keeps_status() {
        let store = TaskStore::new();
        let (id, token) = store.create_task("test-tool", serde_json::json!({}), None);
        assert!(store.complete_task(&id, CallToolResult::text("Done")));

        let task_obj = store.cancel_task(&id, None).expect("task should exist");
        assert_eq!(task_obj.status, TaskStatus::Completed);
        // The token is tripped even though the recorded outcome is unchanged.
        assert!(token.is_cancelled());
    }

    #[test]
    fn test_task_failure() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);

        assert!(store.fail_task(&id, "Something went wrong"));

        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::Failed);
        assert!(info.status_message.as_ref().unwrap().contains("failed"));
    }

    #[test]
    fn test_failed_task_result_carries_error() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);
        assert!(store.fail_task(&id, "boom"));

        let (task_obj, result, error) = store.get_task_result(&id).unwrap();
        assert_eq!(task_obj.status, TaskStatus::Failed);
        assert!(result.is_none());
        assert_eq!(error.as_deref(), Some("boom"));
    }

    #[test]
    fn test_require_input() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);

        assert!(store.require_input(&id, "Need confirmation"));
        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::InputRequired);
        assert_eq!(info.status_message.as_deref(), Some("Need confirmation"));

        // A task waiting for input can still finish...
        assert!(store.complete_task(&id, CallToolResult::text("Done")));
        // ...but a finished task can no longer ask for input.
        assert!(!store.require_input(&id, "Too late"));
        assert_eq!(store.get_task(&id).unwrap().status, TaskStatus::Completed);
    }

    #[test]
    fn test_unknown_task_id() {
        let store = TaskStore::new();
        assert!(store.get_task("task-missing").is_none());
        assert!(store.get_task_result("task-missing").is_none());
        assert!(!store.complete_task("task-missing", CallToolResult::text("x")));
        assert!(!store.fail_task("task-missing", "x"));
        assert!(!store.require_input("task-missing", "x"));
        assert!(store.cancel_task("task-missing", None).is_none());
    }

    #[test]
    fn test_list_tasks() {
        let store = TaskStore::new();
        store.create_task("tool1", serde_json::json!({}), None);
        store.create_task("tool2", serde_json::json!({}), None);
        let (id3, _) = store.create_task("tool3", serde_json::json!({}), None);

        store.complete_task(&id3, CallToolResult::text("Done"));

        let all = store.list_tasks(None);
        assert_eq!(all.len(), 3);

        let working = store.list_tasks(Some(TaskStatus::Working));
        assert_eq!(working.len(), 2);

        let completed = store.list_tasks(Some(TaskStatus::Completed));
        assert_eq!(completed.len(), 1);
    }

    #[test]
    fn test_cleanup_keeps_unfinished_tasks() {
        let store = TaskStore::new();
        store.create_task("tool", serde_json::json!({}), Some(0));

        // Unfinished tasks never expire, even with a zero TTL.
        assert_eq!(store.cleanup_expired(), 0);
        assert_eq!(store.len(), 1);
        assert!(!store.is_empty());
    }

    #[test]
    fn test_terminal_state_immutable() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);

        store.complete_task(&id, CallToolResult::text("Done"));

        assert!(!store.fail_task(&id, "Error"));

        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::Completed);
    }

    #[test]
    fn test_task_ids_unique() {
        let store = TaskStore::new();
        let (id1, _) = store.create_task("tool", serde_json::json!({}), None);
        let (id2, _) = store.create_task("tool", serde_json::json!({}), None);
        let (id3, _) = store.create_task("tool", serde_json::json!({}), None);

        assert_ne!(id1, id2);
        assert_ne!(id2, id3);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_get_task_result() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);

        let result = CallToolResult::text("The result");
        store.complete_task(&id, result);

        let (task_obj, result, error) = store.get_task_result(&id).unwrap();
        assert_eq!(task_obj.status, TaskStatus::Completed);
        assert!(result.is_some());
        assert!(error.is_none());
    }

    #[test]
    fn test_wait_for_completion_already_terminal() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);
        assert!(store.complete_task(&id, CallToolResult::text("Done")));

        let mut fut = Box::pin(store.wait_for_completion(&id));
        match poll_once(fut.as_mut()) {
            Poll::Ready(Some((task_obj, result, error))) => {
                assert_eq!(task_obj.status, TaskStatus::Completed);
                assert!(result.is_some());
                assert!(error.is_none());
            }
            _ => panic!("a finished task should resolve on the first poll"),
        }
    }

    #[test]
    fn test_wait_for_completion_wakes_on_completion() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);

        let mut fut = Box::pin(store.wait_for_completion(&id));
        assert!(poll_once(fut.as_mut()).is_pending());

        assert!(store.fail_task(&id, "boom"));

        match poll_once(fut.as_mut()) {
            Poll::Ready(Some((task_obj, result, error))) => {
                assert_eq!(task_obj.status, TaskStatus::Failed);
                assert!(result.is_none());
                assert_eq!(error.as_deref(), Some("boom"));
            }
            _ => panic!("waiter should resolve once the task has failed"),
        }
    }

    #[test]
    fn test_wait_for_completion_unknown_task() {
        let store = TaskStore::new();
        let mut fut = Box::pin(store.wait_for_completion("task-missing"));
        assert!(matches!(poll_once(fut.as_mut()), Poll::Ready(None)));
    }

    #[test]
    fn test_iso8601_timestamp() {
        let ts = chrono_now_iso8601();
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
        assert_eq!(ts.len(), 24);

        // Fixed layout: YYYY-MM-DDTHH:MM:SS.mmmZ
        let bytes = ts.as_bytes();
        let separators = [(4, b'-'), (7, b'-'), (10, b'T'), (13, b':'), (16, b':'), (19, b'.'), (23, b'Z')];
        for (index, expected) in separators {
            assert_eq!(bytes[index], expected, "unexpected separator in {}", ts);
        }
        for (index, byte) in bytes.iter().enumerate() {
            if !separators.iter().any(|&(pos, _)| pos == index) {
                assert!(byte.is_ascii_digit(), "non-digit at {} in {}", index, ts);
            }
        }
    }

    #[test]
    fn test_task_status_display() {
        assert_eq!(TaskStatus::Working.to_string(), "working");
        assert_eq!(TaskStatus::InputRequired.to_string(), "input_required");
        assert_eq!(TaskStatus::Completed.to_string(), "completed");
        assert_eq!(TaskStatus::Failed.to_string(), "failed");
        assert_eq!(TaskStatus::Cancelled.to_string(), "cancelled");
    }

    #[test]
    fn test_task_status_is_terminal() {
        assert!(!TaskStatus::Working.is_terminal());
        assert!(!TaskStatus::InputRequired.is_terminal());
        assert!(TaskStatus::Completed.is_terminal());
        assert!(TaskStatus::Failed.is_terminal());
        assert!(TaskStatus::Cancelled.is_terminal());
    }
}