1use hashbrown::{HashMap, HashSet};
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
10use super::errors::{A2aError, A2aResult};
11use super::rpc::{ListTasksParams, ListTasksResult, TaskPushNotificationConfig};
12use super::types::{Artifact, Message, Task, TaskState, TaskStatus};
13
14#[derive(Debug, Clone)]
16pub struct TaskManager {
17 state: Arc<RwLock<TaskManagerState>>,
19 max_tasks: usize,
21}
22
23#[derive(Debug, Default)]
24struct TaskManagerState {
25 tasks: HashMap<String, Task>,
26 contexts: HashMap<String, Vec<String>>,
27 webhook_configs: HashMap<String, TaskPushNotificationConfig>,
28}
29
30impl Default for TaskManager {
31 fn default() -> Self {
32 Self::new()
33 }
34}
35
36impl TaskManager {
37 pub fn new() -> Self {
39 Self {
40 state: Arc::new(RwLock::new(TaskManagerState::default())),
41 max_tasks: 1000,
42 }
43 }
44
45 pub fn with_capacity(max_tasks: usize) -> Self {
47 Self {
48 state: Arc::new(RwLock::new(TaskManagerState {
49 tasks: HashMap::with_capacity(max_tasks.min(100)),
50 contexts: HashMap::new(),
51 webhook_configs: HashMap::new(),
52 })),
53 max_tasks,
54 }
55 }
56
57 pub async fn create_task(&self, context_id: Option<String>) -> Task {
59 let mut task = Task::new();
60 if let Some(ref ctx_id) = context_id {
61 task = task.with_context_id(ctx_id);
62 }
63
64 let task_id = task.id.clone();
65 let mut state = self.state.write().await;
66
67 if state.tasks.len() >= self.max_tasks {
68 self.evict_oldest_tasks(&mut state);
69 }
70
71 state.tasks.insert(task_id.clone(), task.clone());
72 if let Some(ctx_id) = context_id {
73 state.contexts.entry(ctx_id).or_default().push(task_id);
74 }
75
76 task
77 }
78
79 fn evict_oldest_tasks(&self, state: &mut TaskManagerState) {
81 let mut completed_tasks: Vec<_> = state
82 .tasks
83 .iter()
84 .filter(|(_, task)| task.is_terminal())
85 .map(|(id, task)| (id.clone(), task.status.timestamp))
86 .collect();
87
88 completed_tasks.sort_by(|a, b| a.1.cmp(&b.1));
89
90 let evict_count = (self.max_tasks / 10).max(1);
91 let evicted_ids: HashSet<_> = completed_tasks
92 .into_iter()
93 .take(evict_count)
94 .map(|(id, _)| id)
95 .collect();
96
97 if evicted_ids.is_empty() {
98 return;
99 }
100
101 for id in &evicted_ids {
102 state.tasks.remove(id);
103 state.webhook_configs.remove(id);
104 }
105
106 state.contexts.retain(|_, task_ids| {
107 task_ids.retain(|task_id| !evicted_ids.contains(task_id));
108 !task_ids.is_empty()
109 });
110 }
111
112 pub async fn get_task(&self, task_id: &str) -> Option<Task> {
114 let state = self.state.read().await;
115 state.tasks.get(task_id).cloned()
116 }
117
118 pub async fn get_task_or_error(&self, task_id: &str) -> A2aResult<Task> {
120 self.get_task(task_id)
121 .await
122 .ok_or_else(|| A2aError::TaskNotFound(task_id.to_string()))
123 }
124
125 pub async fn update_status(
127 &self,
128 task_id: &str,
129 state: TaskState,
130 message: Option<Message>,
131 ) -> A2aResult<Task> {
132 let mut manager_state = self.state.write().await;
133 let task = manager_state
134 .tasks
135 .get_mut(task_id)
136 .ok_or_else(|| A2aError::TaskNotFound(task_id.to_string()))?;
137
138 task.status = match message {
139 Some(msg) => TaskStatus::with_message(state, msg),
140 None => TaskStatus::new(state),
141 };
142
143 Ok(task.clone())
144 }
145
146 pub async fn add_artifact(&self, task_id: &str, artifact: Artifact) -> A2aResult<Task> {
148 let mut state = self.state.write().await;
149 let task = state
150 .tasks
151 .get_mut(task_id)
152 .ok_or_else(|| A2aError::TaskNotFound(task_id.to_string()))?;
153
154 task.artifacts.push(artifact);
155 Ok(task.clone())
156 }
157
158 pub async fn add_message(&self, task_id: &str, message: Message) -> A2aResult<Task> {
160 let mut state = self.state.write().await;
161 let task = state
162 .tasks
163 .get_mut(task_id)
164 .ok_or_else(|| A2aError::TaskNotFound(task_id.to_string()))?;
165
166 task.history.push(message);
167 Ok(task.clone())
168 }
169
170 pub async fn cancel_task(&self, task_id: &str) -> A2aResult<Task> {
172 let mut state = self.state.write().await;
173 let task = state
174 .tasks
175 .get_mut(task_id)
176 .ok_or_else(|| A2aError::TaskNotFound(task_id.to_string()))?;
177
178 if !task.is_cancelable() {
179 return Err(A2aError::TaskNotCancelable(format!(
180 "Task {} is in state {:?} and cannot be canceled",
181 task_id, task.status.state
182 )));
183 }
184
185 task.status = TaskStatus::new(TaskState::Canceled);
186 Ok(task.clone())
187 }
188
189 fn matches_list_filters(
190 task: &Task,
191 status: Option<&TaskState>,
192 updated_after: Option<&chrono::DateTime<chrono::Utc>>,
193 ) -> bool {
194 if let Some(status) = status
195 && &task.status.state != status
196 {
197 return false;
198 }
199
200 if let Some(updated_after) = updated_after
201 && task.status.timestamp < *updated_after
202 {
203 return false;
204 }
205
206 true
207 }
208
209 fn clone_task_for_listing(
210 task: &Task,
211 include_artifacts: bool,
212 history_length: Option<usize>,
213 ) -> Task {
214 let mut task = task.clone();
215
216 if !include_artifacts {
217 task.artifacts.clear();
218 }
219
220 if let Some(history_length) = history_length
221 && task.history.len() > history_length
222 {
223 let trim_count = task.history.len() - history_length;
224 task.history.drain(..trim_count);
225 }
226
227 task
228 }
229
230 pub async fn list_tasks(&self, params: ListTasksParams) -> ListTasksResult {
232 let updated_after = params
233 .last_updated_after
234 .as_deref()
235 .and_then(|after| chrono::DateTime::parse_from_rfc3339(after).ok())
236 .map(|after| after.to_utc());
237
238 let mut matching_tasks: Vec<(String, chrono::DateTime<chrono::Utc>)> = {
239 let state = self.state.read().await;
240 if let Some(context_id) = params.context_id.as_deref() {
241 state
242 .contexts
243 .get(context_id)
244 .into_iter()
245 .flat_map(|task_ids| task_ids.iter())
246 .filter_map(|task_id| {
247 let task = state.tasks.get(task_id)?;
248 Self::matches_list_filters(
249 task,
250 params.status.as_ref(),
251 updated_after.as_ref(),
252 )
253 .then(|| (task_id.clone(), task.status.timestamp))
254 })
255 .collect()
256 } else {
257 state
258 .tasks
259 .iter()
260 .filter(|(_, task)| {
261 Self::matches_list_filters(
262 task,
263 params.status.as_ref(),
264 updated_after.as_ref(),
265 )
266 })
267 .map(|(task_id, task)| (task_id.clone(), task.status.timestamp))
268 .collect()
269 }
270 };
271
272 matching_tasks.sort_by(|a, b| b.1.cmp(&a.1));
273
274 let total_size = matching_tasks.len() as u32;
275 let page_size = params.page_size.unwrap_or(50).min(100);
276 let start_idx = params
277 .page_token
278 .as_ref()
279 .and_then(|token| token.parse::<usize>().ok())
280 .unwrap_or(0);
281
282 let end_idx = (start_idx + page_size as usize).min(matching_tasks.len());
283 let next_page_token = if end_idx < matching_tasks.len() {
284 Some(end_idx.to_string())
285 } else {
286 None
287 };
288
289 let include_artifacts = params.include_artifacts == Some(true);
290 let history_length = params.history_length.map(|len| len as usize);
291 let page_task_ids: Vec<_> = matching_tasks
292 .into_iter()
293 .skip(start_idx)
294 .take(page_size as usize)
295 .collect();
296 let result = if page_task_ids.is_empty() {
297 Vec::new()
298 } else {
299 let state = self.state.read().await;
300 page_task_ids
301 .into_iter()
302 .filter_map(|(task_id, _)| {
303 state.tasks.get(&task_id).map(|task| {
304 Self::clone_task_for_listing(task, include_artifacts, history_length)
305 })
306 })
307 .collect()
308 };
309
310 ListTasksResult {
311 tasks: result,
312 total_size: Some(total_size),
313 page_size: Some(page_size),
314 next_page_token,
315 }
316 }
317
318 pub async fn get_tasks_by_context(&self, context_id: &str) -> Vec<Task> {
320 let state = self.state.read().await;
321 state
322 .contexts
323 .get(context_id)
324 .map(|task_ids| {
325 task_ids
326 .iter()
327 .filter_map(|id| state.tasks.get(id).cloned())
328 .collect()
329 })
330 .unwrap_or_default()
331 }
332
333 pub async fn task_count(&self) -> usize {
335 self.state.read().await.tasks.len()
336 }
337
338 pub async fn clear(&self) {
340 let mut state = self.state.write().await;
341 state.tasks.clear();
342 state.contexts.clear();
343 state.webhook_configs.clear();
344 }
345
346 pub async fn set_webhook_config(&self, config: TaskPushNotificationConfig) -> A2aResult<()> {
348 if !config.url.starts_with("https://") && !config.url.starts_with("http://localhost") {
349 return Err(A2aError::UnsupportedOperation(
350 "Webhook URL must use HTTPS or be localhost".to_string(),
351 ));
352 }
353
354 let mut state = self.state.write().await;
355 if !state.tasks.contains_key(&config.task_id) {
356 return Err(A2aError::TaskNotFound(config.task_id));
357 }
358
359 state.webhook_configs.insert(config.task_id.clone(), config);
360 Ok(())
361 }
362
363 pub async fn get_webhook_config(&self, task_id: &str) -> Option<TaskPushNotificationConfig> {
365 let state = self.state.read().await;
366 state.webhook_configs.get(task_id).cloned()
367 }
368
369 pub async fn remove_webhook_config(&self, task_id: &str) {
371 let mut state = self.state.write().await;
372 state.webhook_configs.remove(task_id);
373 }
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use crate::a2a::types::MessageRole;

    #[tokio::test]
    async fn test_create_task() {
        let manager = TaskManager::new();
        let task = manager.create_task(None).await;

        assert!(!task.id.is_empty());
        assert_eq!(task.state(), TaskState::Submitted);
        assert_eq!(manager.task_count().await, 1);
    }

    #[tokio::test]
    async fn test_create_task_with_context() {
        let manager = TaskManager::new();
        let task = manager.create_task(Some("ctx-1".to_string())).await;

        assert_eq!(task.context_id, Some("ctx-1".to_string()));

        let tasks = manager.get_tasks_by_context("ctx-1").await;
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0].id, task.id);
    }

    #[tokio::test]
    async fn test_get_task() {
        let manager = TaskManager::new();
        let task = manager.create_task(None).await;

        let retrieved = manager.get_task(&task.id).await;
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id, task.id);

        let missing = manager.get_task("nonexistent").await;
        assert!(missing.is_none());
    }

    #[tokio::test]
    async fn test_missing_task_operations_return_not_found() {
        let manager = TaskManager::new();

        assert!(matches!(
            manager.get_task_or_error("missing").await,
            Err(A2aError::TaskNotFound(_))
        ));
        assert!(matches!(
            manager
                .update_status("missing", TaskState::Working, None)
                .await,
            Err(A2aError::TaskNotFound(_))
        ));
        assert!(matches!(
            manager.cancel_task("missing").await,
            Err(A2aError::TaskNotFound(_))
        ));
    }

    #[tokio::test]
    async fn test_update_status() {
        let manager = TaskManager::new();
        let task = manager.create_task(None).await;

        let updated = manager
            .update_status(&task.id, TaskState::Working, None)
            .await
            .expect("update");
        assert_eq!(updated.state(), TaskState::Working);

        let msg = Message::agent_text("Task completed successfully");
        let completed = manager
            .update_status(&task.id, TaskState::Completed, Some(msg))
            .await
            .expect("complete");
        assert_eq!(completed.state(), TaskState::Completed);
        assert!(completed.status.message.is_some());
    }

    #[tokio::test]
    async fn test_add_artifact() {
        let manager = TaskManager::new();
        let task = manager.create_task(None).await;

        let artifact = Artifact::text("art-1", "Generated content");
        let updated = manager
            .add_artifact(&task.id, artifact)
            .await
            .expect("add artifact");
        assert_eq!(updated.artifacts.len(), 1);
        assert_eq!(updated.artifacts[0].id, "art-1");
    }

    #[tokio::test]
    async fn test_cancel_task() {
        let manager = TaskManager::new();
        let task = manager.create_task(None).await;

        let canceled = manager.cancel_task(&task.id).await.expect("cancel");
        assert_eq!(canceled.state(), TaskState::Canceled);
    }

    #[tokio::test]
    async fn test_cancel_completed_task_fails() {
        let manager = TaskManager::new();
        let task = manager.create_task(None).await;

        manager
            .update_status(&task.id, TaskState::Completed, None)
            .await
            .expect("complete");

        // Pin the specific error, not just "some error".
        let err = manager.cancel_task(&task.id).await.unwrap_err();
        assert!(
            matches!(err, A2aError::TaskNotCancelable(_)),
            "unexpected error: {err:?}"
        );
    }

    #[tokio::test]
    async fn test_webhook_url_validation() {
        let manager = TaskManager::new();
        let task = manager.create_task(None).await;
        let config_for = |url: &str| TaskPushNotificationConfig {
            task_id: task.id.clone(),
            url: url.to_string(),
            authentication: None,
        };

        manager
            .set_webhook_config(config_for("https://example.com/hook"))
            .await
            .expect("https webhook");
        manager
            .set_webhook_config(config_for("http://localhost:8080/hook"))
            .await
            .expect("localhost webhook");
        assert!(matches!(
            manager
                .set_webhook_config(config_for("http://example.com/hook"))
                .await,
            Err(A2aError::UnsupportedOperation(_))
        ));
        assert!(matches!(
            manager
                .set_webhook_config(TaskPushNotificationConfig {
                    task_id: "missing".to_string(),
                    url: "https://example.com/hook".to_string(),
                    authentication: None,
                })
                .await,
            Err(A2aError::TaskNotFound(_))
        ));

        let stored = manager
            .get_webhook_config(&task.id)
            .await
            .expect("stored config");
        assert_eq!(stored.url, "http://localhost:8080/hook");

        manager.remove_webhook_config(&task.id).await;
        assert!(manager.get_webhook_config(&task.id).await.is_none());
    }

    #[tokio::test]
    async fn test_eviction_cleans_context_and_webhook_indexes() {
        let manager = TaskManager::with_capacity(1);
        let task = manager.create_task(Some("ctx-1".to_string())).await;

        manager
            .update_status(&task.id, TaskState::Completed, None)
            .await
            .expect("complete");
        manager
            .set_webhook_config(TaskPushNotificationConfig {
                task_id: task.id.clone(),
                url: "https://example.com/webhook".to_string(),
                authentication: None,
            })
            .await
            .expect("set webhook");

        let replacement = manager.create_task(None).await;

        assert_eq!(manager.task_count().await, 1);
        assert!(manager.get_task(&task.id).await.is_none());
        assert!(manager.get_webhook_config(&task.id).await.is_none());
        assert!(manager.get_tasks_by_context("ctx-1").await.is_empty());
        assert_eq!(
            manager.get_task(&replacement.id).await.unwrap().id,
            replacement.id
        );
    }

    #[tokio::test]
    async fn test_list_tasks() {
        let manager = TaskManager::new();

        let _task1 = manager.create_task(Some("ctx-1".to_string())).await;
        let _task2 = manager.create_task(Some("ctx-1".to_string())).await;
        let _task3 = manager.create_task(Some("ctx-2".to_string())).await;

        let all = manager.list_tasks(ListTasksParams::default()).await;
        assert_eq!(all.tasks.len(), 3);

        let ctx1_tasks = manager
            .list_tasks(ListTasksParams {
                context_id: Some("ctx-1".to_string()),
                ..Default::default()
            })
            .await;
        assert_eq!(ctx1_tasks.tasks.len(), 2);
    }

    #[tokio::test]
    async fn test_list_tasks_paginates_and_trims_after_sorting() {
        let manager = TaskManager::new();

        let older = manager.create_task(Some("ctx-1".to_string())).await;
        // Guarantees distinct status timestamps so the sort order is deterministic.
        tokio::time::sleep(std::time::Duration::from_millis(2)).await;
        let newer = manager.create_task(Some("ctx-1".to_string())).await;

        manager
            .add_artifact(&newer.id, Artifact::text("art-1", "Generated content"))
            .await
            .expect("add artifact");
        manager
            .add_message(&newer.id, Message::user_text("Hello"))
            .await
            .expect("add msg1");
        manager
            .add_message(&newer.id, Message::agent_text("Hi there"))
            .await
            .expect("add msg2");

        let first_page = manager
            .list_tasks(ListTasksParams {
                context_id: Some("ctx-1".to_string()),
                page_size: Some(1),
                history_length: Some(1),
                include_artifacts: Some(false),
                ..Default::default()
            })
            .await;

        assert_eq!(first_page.total_size, Some(2));
        assert_eq!(first_page.next_page_token.as_deref(), Some("1"));
        assert_eq!(first_page.tasks.len(), 1);
        assert_eq!(first_page.tasks[0].id, newer.id);
        assert!(first_page.tasks[0].artifacts.is_empty());
        assert_eq!(first_page.tasks[0].history.len(), 1);
        assert_eq!(first_page.tasks[0].history[0].role, MessageRole::Agent);

        let second_page = manager
            .list_tasks(ListTasksParams {
                context_id: Some("ctx-1".to_string()),
                page_size: Some(1),
                page_token: Some("1".to_string()),
                ..Default::default()
            })
            .await;

        assert_eq!(second_page.tasks.len(), 1);
        assert_eq!(second_page.tasks[0].id, older.id);
        assert!(second_page.next_page_token.is_none());
    }

    #[tokio::test]
    async fn test_add_message_to_history() {
        let manager = TaskManager::new();
        let task = manager.create_task(None).await;

        let msg1 = Message::user_text("Hello");
        let msg2 = Message::agent_text("Hi there!");

        manager.add_message(&task.id, msg1).await.expect("add msg1");
        let updated = manager.add_message(&task.id, msg2).await.expect("add msg2");

        assert_eq!(updated.history.len(), 2);
        assert_eq!(updated.history[0].role, MessageRole::User);
        assert_eq!(updated.history[1].role, MessageRole::Agent);
    }
}