Skip to main content

vtcode_core/a2a/
task_manager.rs

1//! A2A Task Manager
2//!
3//! Manages task lifecycle, storage, and queries for the A2A protocol.
4//! Provides an in-memory store with support for concurrent access.
5
6use hashbrown::{HashMap, HashSet};
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
10use super::errors::{A2aError, A2aResult};
11use super::rpc::{ListTasksParams, ListTasksResult, TaskPushNotificationConfig};
12use super::types::{Artifact, Message, Task, TaskState, TaskStatus};
13
14/// A2A Task Manager - handles task creation, updates, and queries
15#[derive(Debug, Clone)]
16pub struct TaskManager {
17    /// All mutable task manager state lives behind one lock so related indexes stay in sync.
18    state: Arc<RwLock<TaskManagerState>>,
19    /// Maximum tasks to retain (for memory management)
20    max_tasks: usize,
21}
22
23#[derive(Debug, Default)]
24struct TaskManagerState {
25    tasks: HashMap<String, Task>,
26    contexts: HashMap<String, Vec<String>>,
27    webhook_configs: HashMap<String, TaskPushNotificationConfig>,
28}
29
30impl Default for TaskManager {
31    fn default() -> Self {
32        Self::new()
33    }
34}
35
36impl TaskManager {
37    /// Create a new task manager
38    pub fn new() -> Self {
39        Self {
40            state: Arc::new(RwLock::new(TaskManagerState::default())),
41            max_tasks: 1000,
42        }
43    }
44
45    /// Create a new task manager with custom capacity
46    pub fn with_capacity(max_tasks: usize) -> Self {
47        Self {
48            state: Arc::new(RwLock::new(TaskManagerState {
49                tasks: HashMap::with_capacity(max_tasks.min(100)),
50                contexts: HashMap::new(),
51                webhook_configs: HashMap::new(),
52            })),
53            max_tasks,
54        }
55    }
56
57    /// Create a new task
58    pub async fn create_task(&self, context_id: Option<String>) -> Task {
59        let mut task = Task::new();
60        if let Some(ref ctx_id) = context_id {
61            task = task.with_context_id(ctx_id);
62        }
63
64        let task_id = task.id.clone();
65        let mut state = self.state.write().await;
66
67        if state.tasks.len() >= self.max_tasks {
68            self.evict_oldest_tasks(&mut state);
69        }
70
71        state.tasks.insert(task_id.clone(), task.clone());
72        if let Some(ctx_id) = context_id {
73            state.contexts.entry(ctx_id).or_default().push(task_id);
74        }
75
76        task
77    }
78
79    /// Evict oldest completed tasks when at capacity
80    fn evict_oldest_tasks(&self, state: &mut TaskManagerState) {
81        let mut completed_tasks: Vec<_> = state
82            .tasks
83            .iter()
84            .filter(|(_, task)| task.is_terminal())
85            .map(|(id, task)| (id.clone(), task.status.timestamp))
86            .collect();
87
88        completed_tasks.sort_by(|a, b| a.1.cmp(&b.1));
89
90        let evict_count = (self.max_tasks / 10).max(1);
91        let evicted_ids: HashSet<_> = completed_tasks
92            .into_iter()
93            .take(evict_count)
94            .map(|(id, _)| id)
95            .collect();
96
97        if evicted_ids.is_empty() {
98            return;
99        }
100
101        for id in &evicted_ids {
102            state.tasks.remove(id);
103            state.webhook_configs.remove(id);
104        }
105
106        state.contexts.retain(|_, task_ids| {
107            task_ids.retain(|task_id| !evicted_ids.contains(task_id));
108            !task_ids.is_empty()
109        });
110    }
111
112    /// Get a task by ID
113    pub async fn get_task(&self, task_id: &str) -> Option<Task> {
114        let state = self.state.read().await;
115        state.tasks.get(task_id).cloned()
116    }
117
118    /// Get a task by ID, returning an error if not found
119    pub async fn get_task_or_error(&self, task_id: &str) -> A2aResult<Task> {
120        self.get_task(task_id)
121            .await
122            .ok_or_else(|| A2aError::TaskNotFound(task_id.to_string()))
123    }
124
125    /// Update task status
126    pub async fn update_status(
127        &self,
128        task_id: &str,
129        state: TaskState,
130        message: Option<Message>,
131    ) -> A2aResult<Task> {
132        let mut manager_state = self.state.write().await;
133        let task = manager_state
134            .tasks
135            .get_mut(task_id)
136            .ok_or_else(|| A2aError::TaskNotFound(task_id.to_string()))?;
137
138        task.status = match message {
139            Some(msg) => TaskStatus::with_message(state, msg),
140            None => TaskStatus::new(state),
141        };
142
143        Ok(task.clone())
144    }
145
146    /// Add an artifact to a task
147    pub async fn add_artifact(&self, task_id: &str, artifact: Artifact) -> A2aResult<Task> {
148        let mut state = self.state.write().await;
149        let task = state
150            .tasks
151            .get_mut(task_id)
152            .ok_or_else(|| A2aError::TaskNotFound(task_id.to_string()))?;
153
154        task.artifacts.push(artifact);
155        Ok(task.clone())
156    }
157
158    /// Add a message to task history
159    pub async fn add_message(&self, task_id: &str, message: Message) -> A2aResult<Task> {
160        let mut state = self.state.write().await;
161        let task = state
162            .tasks
163            .get_mut(task_id)
164            .ok_or_else(|| A2aError::TaskNotFound(task_id.to_string()))?;
165
166        task.history.push(message);
167        Ok(task.clone())
168    }
169
170    /// Cancel a task
171    pub async fn cancel_task(&self, task_id: &str) -> A2aResult<Task> {
172        let mut state = self.state.write().await;
173        let task = state
174            .tasks
175            .get_mut(task_id)
176            .ok_or_else(|| A2aError::TaskNotFound(task_id.to_string()))?;
177
178        if !task.is_cancelable() {
179            return Err(A2aError::TaskNotCancelable(format!(
180                "Task {} is in state {:?} and cannot be canceled",
181                task_id, task.status.state
182            )));
183        }
184
185        task.status = TaskStatus::new(TaskState::Canceled);
186        Ok(task.clone())
187    }
188
189    fn matches_list_filters(
190        task: &Task,
191        status: Option<&TaskState>,
192        updated_after: Option<&chrono::DateTime<chrono::Utc>>,
193    ) -> bool {
194        if let Some(status) = status
195            && &task.status.state != status
196        {
197            return false;
198        }
199
200        if let Some(updated_after) = updated_after
201            && task.status.timestamp < *updated_after
202        {
203            return false;
204        }
205
206        true
207    }
208
209    fn clone_task_for_listing(
210        task: &Task,
211        include_artifacts: bool,
212        history_length: Option<usize>,
213    ) -> Task {
214        let mut task = task.clone();
215
216        if !include_artifacts {
217            task.artifacts.clear();
218        }
219
220        if let Some(history_length) = history_length
221            && task.history.len() > history_length
222        {
223            let trim_count = task.history.len() - history_length;
224            task.history.drain(..trim_count);
225        }
226
227        task
228    }
229
230    /// List tasks with optional filtering
231    pub async fn list_tasks(&self, params: ListTasksParams) -> ListTasksResult {
232        let updated_after = params
233            .last_updated_after
234            .as_deref()
235            .and_then(|after| chrono::DateTime::parse_from_rfc3339(after).ok())
236            .map(|after| after.to_utc());
237
238        let mut matching_tasks: Vec<(String, chrono::DateTime<chrono::Utc>)> = {
239            let state = self.state.read().await;
240            if let Some(context_id) = params.context_id.as_deref() {
241                state
242                    .contexts
243                    .get(context_id)
244                    .into_iter()
245                    .flat_map(|task_ids| task_ids.iter())
246                    .filter_map(|task_id| {
247                        let task = state.tasks.get(task_id)?;
248                        Self::matches_list_filters(
249                            task,
250                            params.status.as_ref(),
251                            updated_after.as_ref(),
252                        )
253                        .then(|| (task_id.clone(), task.status.timestamp))
254                    })
255                    .collect()
256            } else {
257                state
258                    .tasks
259                    .iter()
260                    .filter(|(_, task)| {
261                        Self::matches_list_filters(
262                            task,
263                            params.status.as_ref(),
264                            updated_after.as_ref(),
265                        )
266                    })
267                    .map(|(task_id, task)| (task_id.clone(), task.status.timestamp))
268                    .collect()
269            }
270        };
271
272        matching_tasks.sort_by(|a, b| b.1.cmp(&a.1));
273
274        let total_size = matching_tasks.len() as u32;
275        let page_size = params.page_size.unwrap_or(50).min(100);
276        let start_idx = params
277            .page_token
278            .as_ref()
279            .and_then(|token| token.parse::<usize>().ok())
280            .unwrap_or(0);
281
282        let end_idx = (start_idx + page_size as usize).min(matching_tasks.len());
283        let next_page_token = if end_idx < matching_tasks.len() {
284            Some(end_idx.to_string())
285        } else {
286            None
287        };
288
289        let include_artifacts = params.include_artifacts == Some(true);
290        let history_length = params.history_length.map(|len| len as usize);
291        let page_task_ids: Vec<_> = matching_tasks
292            .into_iter()
293            .skip(start_idx)
294            .take(page_size as usize)
295            .collect();
296        let result = if page_task_ids.is_empty() {
297            Vec::new()
298        } else {
299            let state = self.state.read().await;
300            page_task_ids
301                .into_iter()
302                .filter_map(|(task_id, _)| {
303                    state.tasks.get(&task_id).map(|task| {
304                        Self::clone_task_for_listing(task, include_artifacts, history_length)
305                    })
306                })
307                .collect()
308        };
309
310        ListTasksResult {
311            tasks: result,
312            total_size: Some(total_size),
313            page_size: Some(page_size),
314            next_page_token,
315        }
316    }
317
318    /// Get tasks by context ID
319    pub async fn get_tasks_by_context(&self, context_id: &str) -> Vec<Task> {
320        let state = self.state.read().await;
321        state
322            .contexts
323            .get(context_id)
324            .map(|task_ids| {
325                task_ids
326                    .iter()
327                    .filter_map(|id| state.tasks.get(id).cloned())
328                    .collect()
329            })
330            .unwrap_or_default()
331    }
332
333    /// Get the number of tasks
334    pub async fn task_count(&self) -> usize {
335        self.state.read().await.tasks.len()
336    }
337
338    /// Clear all tasks (for testing)
339    pub async fn clear(&self) {
340        let mut state = self.state.write().await;
341        state.tasks.clear();
342        state.contexts.clear();
343        state.webhook_configs.clear();
344    }
345
346    /// Set webhook configuration for a task
347    pub async fn set_webhook_config(&self, config: TaskPushNotificationConfig) -> A2aResult<()> {
348        if !config.url.starts_with("https://") && !config.url.starts_with("http://localhost") {
349            return Err(A2aError::UnsupportedOperation(
350                "Webhook URL must use HTTPS or be localhost".to_string(),
351            ));
352        }
353
354        let mut state = self.state.write().await;
355        if !state.tasks.contains_key(&config.task_id) {
356            return Err(A2aError::TaskNotFound(config.task_id));
357        }
358
359        state.webhook_configs.insert(config.task_id.clone(), config);
360        Ok(())
361    }
362
363    /// Get webhook configuration for a task
364    pub async fn get_webhook_config(&self, task_id: &str) -> Option<TaskPushNotificationConfig> {
365        let state = self.state.read().await;
366        state.webhook_configs.get(task_id).cloned()
367    }
368
369    /// Remove webhook configuration for a task
370    pub async fn remove_webhook_config(&self, task_id: &str) {
371        let mut state = self.state.write().await;
372        state.webhook_configs.remove(task_id);
373    }
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use crate::a2a::types::MessageRole;

    /// Shorthand for an owned, present context ID.
    fn ctx(id: &str) -> Option<String> {
        Some(id.to_string())
    }

    /// A new task has an ID, starts as `Submitted`, and is counted by the manager.
    #[tokio::test]
    async fn test_create_task() {
        let manager = TaskManager::new();
        let created = manager.create_task(None).await;

        assert!(!created.id.is_empty());
        assert_eq!(created.state(), TaskState::Submitted);
        assert_eq!(manager.task_count().await, 1);
    }

    /// A task created with a context is reachable through that context.
    #[tokio::test]
    async fn test_create_task_with_context() {
        let manager = TaskManager::new();
        let created = manager.create_task(ctx("ctx-1")).await;

        assert_eq!(created.context_id, Some("ctx-1".to_string()));

        let in_context = manager.get_tasks_by_context("ctx-1").await;
        assert_eq!(in_context.len(), 1);
        assert_eq!(in_context[0].id, created.id);
    }

    /// Lookup returns the stored task and `None` for an unknown ID.
    #[tokio::test]
    async fn test_get_task() {
        let manager = TaskManager::new();
        let created = manager.create_task(None).await;

        let found = manager.get_task(&created.id).await;
        assert!(found.is_some());
        assert_eq!(found.unwrap().id, created.id);

        let absent = manager.get_task("nonexistent").await;
        assert!(absent.is_none());
    }

    /// Status updates change the state and carry the optional message.
    #[tokio::test]
    async fn test_update_status() {
        let manager = TaskManager::new();
        let created = manager.create_task(None).await;

        let working = manager
            .update_status(&created.id, TaskState::Working, None)
            .await
            .expect("update");
        assert_eq!(working.state(), TaskState::Working);

        let done_message = Message::agent_text("Task completed successfully");
        let done = manager
            .update_status(&created.id, TaskState::Completed, Some(done_message))
            .await
            .expect("complete");
        assert_eq!(done.state(), TaskState::Completed);
        assert!(done.status.message.is_some());
    }

    /// Adding an artifact makes it visible on the returned task.
    #[tokio::test]
    async fn test_add_artifact() {
        let manager = TaskManager::new();
        let created = manager.create_task(None).await;

        let with_artifact = manager
            .add_artifact(&created.id, Artifact::text("art-1", "Generated content"))
            .await
            .expect("add artifact");

        assert_eq!(with_artifact.artifacts.len(), 1);
        assert_eq!(with_artifact.artifacts[0].id, "art-1");
    }

    /// A freshly submitted task can be canceled.
    #[tokio::test]
    async fn test_cancel_task() {
        let manager = TaskManager::new();
        let created = manager.create_task(None).await;

        let canceled = manager.cancel_task(&created.id).await.expect("cancel");
        assert_eq!(canceled.state(), TaskState::Canceled);
    }

    /// Canceling a task that already completed is rejected.
    #[tokio::test]
    async fn test_cancel_completed_task_fails() {
        let manager = TaskManager::new();
        let created = manager.create_task(None).await;

        manager
            .update_status(&created.id, TaskState::Completed, None)
            .await
            .expect("complete");

        let outcome = manager.cancel_task(&created.id).await;
        assert!(outcome.is_err());
    }

    /// Evicting a task also removes its context entry and webhook configuration.
    #[tokio::test]
    async fn test_eviction_cleans_context_and_webhook_indexes() {
        let manager = TaskManager::with_capacity(1);
        let evictable = manager.create_task(ctx("ctx-1")).await;

        manager
            .update_status(&evictable.id, TaskState::Completed, None)
            .await
            .expect("complete");

        let webhook = TaskPushNotificationConfig {
            task_id: evictable.id.clone(),
            url: "https://example.com/webhook".to_string(),
            authentication: None,
        };
        manager
            .set_webhook_config(webhook)
            .await
            .expect("set webhook");

        // The manager is at capacity, so this creation evicts the completed task.
        let replacement = manager.create_task(None).await;

        assert_eq!(manager.task_count().await, 1);
        assert!(manager.get_task(&evictable.id).await.is_none());
        assert!(manager.get_webhook_config(&evictable.id).await.is_none());
        assert!(manager.get_tasks_by_context("ctx-1").await.is_empty());

        let survivor = manager.get_task(&replacement.id).await.unwrap();
        assert_eq!(survivor.id, replacement.id);
    }

    /// Listing returns everything by default and honors the context filter.
    #[tokio::test]
    async fn test_list_tasks() {
        let manager = TaskManager::new();

        for context in ["ctx-1", "ctx-1", "ctx-2"] {
            let _created = manager.create_task(ctx(context)).await;
        }

        let everything = manager.list_tasks(ListTasksParams::default()).await;
        assert_eq!(everything.tasks.len(), 3);

        let only_ctx1 = ListTasksParams {
            context_id: ctx("ctx-1"),
            ..Default::default()
        };
        let ctx1_listing = manager.list_tasks(only_ctx1).await;
        assert_eq!(ctx1_listing.tasks.len(), 2);
    }

    /// Pages follow newest-first order, and trimming applies to the returned copies.
    #[tokio::test]
    async fn test_list_tasks_paginates_and_trims_after_sorting() {
        let manager = TaskManager::new();

        let older = manager.create_task(ctx("ctx-1")).await;
        // Make sure the two tasks get distinct status timestamps.
        tokio::time::sleep(std::time::Duration::from_millis(2)).await;
        let newer = manager.create_task(ctx("ctx-1")).await;

        manager
            .add_artifact(&newer.id, Artifact::text("art-1", "Generated content"))
            .await
            .expect("add artifact");
        manager
            .add_message(&newer.id, Message::user_text("Hello"))
            .await
            .expect("add msg1");
        manager
            .add_message(&newer.id, Message::agent_text("Hi there"))
            .await
            .expect("add msg2");

        let first_request = ListTasksParams {
            context_id: ctx("ctx-1"),
            page_size: Some(1),
            history_length: Some(1),
            include_artifacts: Some(false),
            ..Default::default()
        };
        let first_page = manager.list_tasks(first_request).await;

        assert_eq!(first_page.total_size, Some(2));
        assert_eq!(first_page.next_page_token.as_deref(), Some("1"));
        assert_eq!(first_page.tasks.len(), 1);

        let listed_newer = &first_page.tasks[0];
        assert_eq!(listed_newer.id, newer.id);
        assert!(listed_newer.artifacts.is_empty());
        assert_eq!(listed_newer.history.len(), 1);
        assert_eq!(listed_newer.history[0].role, MessageRole::Agent);

        let second_request = ListTasksParams {
            context_id: ctx("ctx-1"),
            page_size: Some(1),
            page_token: Some("1".to_string()),
            ..Default::default()
        };
        let second_page = manager.list_tasks(second_request).await;

        assert_eq!(second_page.tasks.len(), 1);
        assert_eq!(second_page.tasks[0].id, older.id);
        assert!(second_page.next_page_token.is_none());
    }

    /// Messages are appended to the history in the order they were added.
    #[tokio::test]
    async fn test_add_message_to_history() {
        let manager = TaskManager::new();
        let created = manager.create_task(None).await;

        manager
            .add_message(&created.id, Message::user_text("Hello"))
            .await
            .expect("add msg1");
        let after_reply = manager
            .add_message(&created.id, Message::agent_text("Hi there!"))
            .await
            .expect("add msg2");

        assert_eq!(after_reply.history.len(), 2);
        assert_eq!(after_reply.history[0].role, MessageRole::User);
        assert_eq!(after_reply.history[1].role, MessageRole::Agent);
    }
}