tower_a2a/protocol/task.rs

1//! A2A task types and lifecycle management
2
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5
6use super::{error::TaskError, message::Message, Artifact};
7
8/// A task in the A2A protocol
9///
10/// Tasks represent asynchronous operations performed by agents.
11/// They have a lifecycle from submitted to completion, with various intermediate states.
12#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
13pub struct Task {
14    /// Unique identifier for the task
15    pub id: String,
16
17    /// Current status of the task
18    pub status: TaskStatus,
19
20    /// Input message that created this task
21    pub input: Message,
22
23    /// Artifacts produced by this task (replaces legacy 'output' field)
24    #[serde(default, skip_serializing_if = "Vec::is_empty")]
25    pub artifacts: Vec<Artifact>,
26
27    /// Message exchange history for this task
28    #[serde(default, skip_serializing_if = "Vec::is_empty")]
29    pub history: Vec<Message>,
30
31    /// Error information (present if task failed)
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub error: Option<TaskError>,
34
35    /// When the task was created
36    #[serde(rename = "createdAt")]
37    pub created_at: DateTime<Utc>,
38
39    /// When the task was last updated
40    #[serde(rename = "updatedAt", skip_serializing_if = "Option::is_none")]
41    pub updated_at: Option<DateTime<Utc>>,
42
43    /// Optional context ID for grouping related tasks/messages
44    #[serde(rename = "contextId", skip_serializing_if = "Option::is_none")]
45    pub context_id: Option<String>,
46}
47
48impl Task {
49    /// Create a new task
50    pub fn new(id: impl Into<String>, input: Message) -> Self {
51        Self {
52            id: id.into(),
53            status: TaskStatus::Submitted,
54            input,
55            artifacts: Vec::new(),
56            history: Vec::new(),
57            error: None,
58            created_at: Utc::now(),
59            updated_at: None,
60            context_id: None,
61        }
62    }
63
64    /// Check if the task is in a terminal state
65    pub fn is_terminal(&self) -> bool {
66        matches!(
67            self.status,
68            TaskStatus::Completed
69                | TaskStatus::Failed
70                | TaskStatus::Cancelled
71                | TaskStatus::Rejected
72        )
73    }
74
75    /// Check if the task is still processing
76    pub fn is_processing(&self) -> bool {
77        matches!(self.status, TaskStatus::Submitted | TaskStatus::Working)
78    }
79
80    /// Check if the task requires input
81    pub fn requires_input(&self) -> bool {
82        matches!(
83            self.status,
84            TaskStatus::InputRequired | TaskStatus::AuthRequired
85        )
86    }
87
88    /// Update the task status
89    pub fn with_status(mut self, status: TaskStatus) -> Self {
90        self.status = status;
91        self.updated_at = Some(Utc::now());
92        self
93    }
94
95    /// Add an artifact to the task
96    pub fn with_artifact(mut self, artifact: Artifact) -> Self {
97        self.artifacts.push(artifact);
98        self.updated_at = Some(Utc::now());
99        self
100    }
101
102    /// Add a message to the history
103    pub fn with_history_message(mut self, message: Message) -> Self {
104        self.history.push(message);
105        self.updated_at = Some(Utc::now());
106        self
107    }
108
109    /// Set the task error
110    pub fn with_error(mut self, error: TaskError) -> Self {
111        self.error = Some(error);
112        self.updated_at = Some(Utc::now());
113        self
114    }
115
116    /// Set the context ID
117    pub fn with_context_id(mut self, context_id: impl Into<String>) -> Self {
118        self.context_id = Some(context_id.into());
119        self
120    }
121}
122
123/// Task status in the A2A protocol lifecycle
124///
125/// Task lifecycle: submitted → working → completed/failed/cancelled/rejected
126/// Non-terminal states: input-required, auth-required (awaiting client input)
127#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
128#[serde(rename_all = "kebab-case")]
129pub enum TaskStatus {
130    /// Task has been received and is queued for processing
131    Submitted,
132
133    /// Task is currently being processed
134    Working,
135
136    /// Task requires additional input from the client
137    InputRequired,
138
139    /// Task requires authentication or authorization
140    AuthRequired,
141
142    /// Task completed successfully
143    Completed,
144
145    /// Task failed with an error
146    Failed,
147
148    /// Task was cancelled by the client
149    Cancelled,
150
151    /// Task was rejected by the agent (e.g., invalid request)
152    Rejected,
153}
154
155impl TaskStatus {
156    /// Check if this is a terminal status
157    pub fn is_terminal(&self) -> bool {
158        matches!(
159            self,
160            TaskStatus::Completed
161                | TaskStatus::Failed
162                | TaskStatus::Cancelled
163                | TaskStatus::Rejected
164        )
165    }
166
167    /// Check if this status requires client action
168    pub fn requires_action(&self) -> bool {
169        matches!(self, TaskStatus::InputRequired | TaskStatus::AuthRequired)
170    }
171}
172
173/// Request to send a message to an agent
174#[derive(Debug, Clone, Serialize, Deserialize)]
175pub struct SendMessageRequest {
176    /// The message to send
177    pub message: Message,
178
179    /// Whether to stream the response
180    #[serde(default)]
181    pub stream: bool,
182
183    /// Optional context ID for multi-turn conversations
184    #[serde(rename = "contextId", skip_serializing_if = "Option::is_none")]
185    pub context_id: Option<String>,
186
187    /// Optional task ID to continue from
188    #[serde(rename = "taskId", skip_serializing_if = "Option::is_none")]
189    pub task_id: Option<String>,
190}
191
192/// Response from listing tasks
193#[derive(Debug, Clone, Serialize, Deserialize)]
194pub struct TaskListResponse {
195    /// List of tasks
196    pub tasks: Vec<Task>,
197
198    /// Total number of tasks matching the query
199    pub total: usize,
200
201    /// Optional continuation token for pagination
202    #[serde(rename = "nextToken", skip_serializing_if = "Option::is_none")]
203    pub next_token: Option<String>,
204}
205
#[cfg(test)]
mod tests {
    use crate::protocol::message::Message;

    use super::*;

    /// Every status whose lifecycle has ended.
    const TERMINAL: [TaskStatus; 4] = [
        TaskStatus::Completed,
        TaskStatus::Failed,
        TaskStatus::Cancelled,
        TaskStatus::Rejected,
    ];

    /// Every status from which the task can still progress.
    const NON_TERMINAL: [TaskStatus; 4] = [
        TaskStatus::Submitted,
        TaskStatus::Working,
        TaskStatus::InputRequired,
        TaskStatus::AuthRequired,
    ];

    #[test]
    fn test_task_creation() {
        let msg = Message::user("Test");
        let task = Task::new("task-123", msg);

        assert_eq!(task.id, "task-123");
        assert_eq!(task.status, TaskStatus::Submitted);
        assert!(!task.is_terminal());
        assert!(task.is_processing());
        assert!(!task.requires_input());
        assert!(task.updated_at.is_none());
    }

    #[test]
    fn test_task_lifecycle() {
        let msg = Message::user("Test");
        let task = Task::new("task-123", msg);

        let task = task.with_status(TaskStatus::Working);
        assert_eq!(task.status, TaskStatus::Working);
        assert!(task.is_processing());
        assert!(task.updated_at.is_some());

        let task = task.with_status(TaskStatus::Completed);
        assert!(task.is_terminal());
        assert!(!task.is_processing());
    }

    #[test]
    fn test_task_requires_input() {
        for status in [TaskStatus::InputRequired, TaskStatus::AuthRequired] {
            let task = Task::new("task-123", Message::user("Test")).with_status(status);
            assert!(task.requires_input(), "{:?} should require input", status);
            assert!(!task.is_terminal());
            assert!(!task.is_processing());
        }
    }

    #[test]
    fn test_context_id() {
        let task = Task::new("task-123", Message::user("Test")).with_context_id("ctx-1");
        assert_eq!(task.context_id.as_deref(), Some("ctx-1"));
    }

    #[test]
    fn test_task_status() {
        for status in TERMINAL {
            assert!(status.is_terminal(), "{:?} should be terminal", status);
            assert!(!status.requires_action());
        }
        for status in NON_TERMINAL {
            assert!(!status.is_terminal(), "{:?} should not be terminal", status);
        }

        assert!(TaskStatus::InputRequired.requires_action());
        assert!(TaskStatus::AuthRequired.requires_action());
        assert!(!TaskStatus::Working.requires_action());
        assert!(!TaskStatus::Submitted.requires_action());
    }

    #[test]
    fn test_status_wire_format() {
        let cases = [
            (TaskStatus::Submitted, "\"submitted\""),
            (TaskStatus::Working, "\"working\""),
            (TaskStatus::InputRequired, "\"input-required\""),
            (TaskStatus::AuthRequired, "\"auth-required\""),
            (TaskStatus::Completed, "\"completed\""),
            (TaskStatus::Failed, "\"failed\""),
            (TaskStatus::Cancelled, "\"cancelled\""),
            (TaskStatus::Rejected, "\"rejected\""),
        ];
        for (status, json) in cases {
            assert_eq!(serde_json::to_string(&status).unwrap(), json);
            assert_eq!(serde_json::from_str::<TaskStatus>(json).unwrap(), status);
        }
    }

    #[test]
    fn test_task_serialization() {
        let msg = Message::user("Test");
        let task = Task::new("task-123", msg);

        let json = serde_json::to_string(&task).unwrap();
        assert!(json.contains("\"id\":\"task-123\""));
        assert!(json.contains("\"status\":\"submitted\""));

        // Inspect top-level keys only, so the nested message cannot interfere.
        let value = serde_json::to_value(&task).unwrap();
        assert!(value.get("createdAt").is_some());
        assert!(value.get("updatedAt").is_none());
        assert!(value.get("contextId").is_none());
        assert!(value.get("artifacts").is_none());
        assert!(value.get("history").is_none());
        assert!(value.get("error").is_none());

        let deserialized: Task = serde_json::from_str(&json).unwrap();
        assert_eq!(task.id, deserialized.id);
        assert_eq!(task.status, deserialized.status);
    }
}