1use std::sync::{Arc, RwLock};
2
3use anyhow::Result;
4use serde::{Deserialize, Serialize};
5use serde_json::{json, Value};
6
7use crate::chat::events::{
8 ChatEvent, EventSender, ToolExecutionResult, ToolRequest as ToolRequestEvent, ToolRequestType,
9};
10use crate::module::{ContextComponent, ContextComponentId};
11use crate::module::{Module, SessionStateComponent};
12use crate::module::{PromptComponent, PromptComponentId};
13use crate::settings::config::Settings;
14use crate::tools::r#trait::{
15 ContinuationPreference, ToolCallHandle, ToolCategory, ToolExecutor, ToolOutput, ToolRequest,
16};
17use crate::tools::ToolName;
18
/// Module that owns the session's task list and hands out the components built on it:
/// a prompt section, a context section, the `manage_task_list` tool and session state.
pub struct TaskListModule {
    inner: Arc<TaskListModuleInner>,
}

/// State shared (through an `Arc`) between the module and every component it creates.
pub(crate) struct TaskListModuleInner {
    /// The current task list; always replaced as a whole, never edited in place.
    pub(crate) task_list: RwLock<TaskList>,
    /// Used to publish `ChatEvent::TaskUpdate` whenever the list changes.
    pub(crate) event_sender: EventSender,
}
28
29impl TaskListModule {
30 pub fn new(event_sender: EventSender) -> Self {
31 let inner = Arc::new(TaskListModuleInner {
32 task_list: RwLock::new(TaskList::default()),
33 event_sender,
34 });
35 inner.emit_update();
36 Self { inner }
37 }
38
39 pub fn manage_tool(&self) -> Arc<dyn ToolExecutor> {
40 Arc::new(ManageTaskListTool {
41 inner: self.inner.clone(),
42 })
43 }
44
45 pub fn context_component(&self) -> Arc<dyn ContextComponent + Send + Sync> {
46 Arc::new(TaskListContextComponent {
47 inner: self.inner.clone(),
48 })
49 }
50
51 pub fn get(&self) -> TaskList {
52 self.inner.get()
53 }
54
55 pub fn replace(&self, title: String, tasks: Vec<TaskWithStatus>) {
56 self.inner.replace(title, tasks);
57 }
58}
59
60impl Module for TaskListModule {
61 fn prompt_components(&self) -> Vec<Arc<dyn PromptComponent>> {
62 vec![Arc::new(TaskListPromptComponent)]
63 }
64
65 fn context_components(&self) -> Vec<Arc<dyn ContextComponent>> {
66 vec![self.context_component()]
67 }
68
69 fn tools(&self) -> Vec<Arc<dyn ToolExecutor>> {
70 vec![self.manage_tool()]
71 }
72
73 fn session_state(&self) -> Option<Arc<dyn SessionStateComponent>> {
74 Some(Arc::new(TaskListSessionState {
75 inner: self.inner.clone(),
76 }))
77 }
78}
79
80struct TaskListSessionState {
81 inner: Arc<TaskListModuleInner>,
82}
83
84impl SessionStateComponent for TaskListSessionState {
85 fn key(&self) -> &str {
86 "task_list"
87 }
88
89 fn save(&self) -> Value {
90 serde_json::to_value(self.inner.get()).expect("TaskList serialization cannot fail")
91 }
92
93 fn load(&self, state: Value) -> Result<()> {
94 let task_list: TaskList = serde_json::from_value(state)?;
95 let tasks = task_list
96 .tasks
97 .iter()
98 .map(|t| TaskWithStatus {
99 description: t.description.clone(),
100 status: t.status,
101 })
102 .collect();
103 self.inner.replace(task_list.title, tasks);
104 Ok(())
105 }
106}
107
108impl TaskListModuleInner {
109 pub(crate) fn replace(&self, title: String, tasks: Vec<TaskWithStatus>) {
110 let new_list = TaskList::from_tasks_with_status(title, tasks);
111 *self.task_list.write().unwrap() = new_list;
112 self.emit_update();
113 }
114
115 pub(crate) fn get(&self) -> TaskList {
116 self.task_list.read().unwrap().clone()
117 }
118
119 fn emit_update(&self) {
120 self.event_sender.send(ChatEvent::TaskUpdate(self.get()));
121 }
122}
123
124pub const TASK_LIST_CONTEXT_ID: ContextComponentId = ContextComponentId("tasks");
125
126struct TaskListContextComponent {
127 inner: Arc<TaskListModuleInner>,
128}
129
130#[async_trait::async_trait(?Send)]
131impl ContextComponent for TaskListContextComponent {
132 fn id(&self) -> ContextComponentId {
133 TASK_LIST_CONTEXT_ID
134 }
135
136 async fn build_context_section(&self) -> Option<String> {
137 let task_list = self.inner.task_list.read().unwrap();
138
139 if task_list.tasks.is_empty() {
140 return None;
141 }
142
143 let mut output = format!("Task List: {}\n", task_list.title);
144
145 for task in &task_list.tasks {
146 let status_marker = match task.status {
147 TaskStatus::Pending => "[Pending]",
148 TaskStatus::InProgress => "[InProgress]",
149 TaskStatus::Completed => "[Completed]",
150 TaskStatus::Failed => "[Failed]",
151 };
152 output.push_str(&format!(
153 " - {} Task {}: {}\n",
154 status_marker, task.id, task.description
155 ));
156 }
157
158 Some(output)
159 }
160}
161
/// Identifier of the task-list prompt section (returned by `TaskListPromptComponent::id`).
pub const TASK_LIST_PROMPT_ID: PromptComponentId = PromptComponentId("tasks");
163
/// Prompt section telling the model how and when to use the `manage_task_list` tool.
///
/// This text is sent verbatim to the model, so typos here are user-visible behavior.
const TASK_LIST_MANAGEMENT: &str = r#"## Task List Management
• The 'context' will always include a task list. The task list is designed to help you break down large tasks into smaller chunks of work and to provide feedback to the user about what you are working on.
• When possible, design each step so that it can be validated (compile and pass tests). Some tasks may require multiple steps before validation is feasible.
• The task list can be updated with a special tool called "manage_task_list". Ensure the task list is always up to date.
• The "manage_task_list" tool is neither an "Execution" nor a "Meta" tool and may be combined with either type of response. "manage_task_list" may never be the only tool request; "manage_task_list" must always be combined with at least 1 other tool call.

## When to Update the Task List
• Set the task list once a plan has been presented to the user and approved. A new task list created with "manage_task_list" must be combined with "Execution" tools beginning work on the first task.
• Update the task list when a task has been completed. If there are additional tasks, "manage_task_list" must be combined with "Execution" tools beginning work on the next task. When completing the last task, "manage_task_list" must be combined with "complete_task".
• Before marking a task complete ensure changes: 1/ comply with style mandates 2/ compile and build (when possible) 3/ tests pass (when possible)
• "complete_task" should only be used when completing the final task in the task list.
"#;
176
177pub struct TaskListPromptComponent;
179
180impl PromptComponent for TaskListPromptComponent {
181 fn id(&self) -> PromptComponentId {
182 TASK_LIST_PROMPT_ID
183 }
184
185 fn build_prompt_section(&self, _settings: &Settings) -> Option<String> {
186 Some(TASK_LIST_MANAGEMENT.to_string())
187 }
188}
189
/// A task description paired with its status but no id.
///
/// Input to [`TaskList::from_tasks_with_status`], which assigns the ids.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskWithStatus {
    pub description: String,
    pub status: TaskStatus,
}

/// Lifecycle state of a task.
///
/// Serialized in snake_case: `pending`, `in_progress`, `completed`, `failed`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskStatus {
    Pending,
    InProgress,
    Completed,
    Failed,
}

/// A single entry of a [`TaskList`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Task {
    /// Set to the task's zero-based position when built through
    /// `TaskList::from_tasks_with_status`; a directly deserialized list keeps
    /// whatever ids were stored.
    pub id: usize,
    pub description: String,
    pub status: TaskStatus,
}

/// A titled, ordered list of tasks.
///
/// This is the value persisted as session state and carried by `ChatEvent::TaskUpdate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskList {
    pub title: String,
    pub tasks: Vec<Task>,
}
217
218impl TaskList {
219 pub fn from_tasks_with_status(title: String, tasks_with_status: Vec<TaskWithStatus>) -> Self {
220 let tasks = tasks_with_status
221 .into_iter()
222 .enumerate()
223 .map(|(id, task)| Task {
224 id,
225 description: task.description,
226 status: task.status,
227 })
228 .collect();
229
230 Self { title, tasks }
231 }
232}
233
234impl Default for TaskList {
235 fn default() -> Self {
236 Self {
237 title: "Understand user requirements".to_string(),
238 tasks: vec![
239 Task {
240 id: 0,
241 description: "Await user request".to_string(),
242 status: TaskStatus::InProgress,
243 },
244 Task {
245 id: 1,
246 description:
247 "Understand/Explore the code base and propose a comprehensive plan"
248 .to_string(),
249 status: TaskStatus::Pending,
250 },
251 ],
252 }
253 }
254}
255
/// One element of the `tasks` argument of a `manage_task_list` call; mirrors the
/// `items` schema in `ManageTaskListTool::input_schema`.
#[derive(Debug, Serialize, Deserialize)]
struct TaskInput {
    description: String,
    status: TaskStatus,
}

/// Deserialized arguments of a `manage_task_list` call.
#[derive(Debug, Serialize, Deserialize)]
struct ManageTaskListInput {
    title: String,
    tasks: Vec<TaskInput>,
}
270}
271
272pub struct ManageTaskListTool {
273 pub(crate) inner: Arc<TaskListModuleInner>,
274}
275
276impl ManageTaskListTool {
277 pub fn tool_name() -> ToolName {
278 ToolName::new("manage_task_list")
279 }
280}
281
282struct ManageTaskListHandle {
283 title: String,
284 tasks: Vec<TaskWithStatus>,
285 tool_use_id: String,
286 inner: Arc<TaskListModuleInner>,
287}
288
289#[async_trait::async_trait(?Send)]
290impl ToolCallHandle for ManageTaskListHandle {
291 fn tool_request(&self) -> ToolRequestEvent {
292 ToolRequestEvent {
293 tool_call_id: self.tool_use_id.clone(),
294 tool_name: "manage_task_list".to_string(),
295 tool_type: ToolRequestType::Other {
296 args: json!({ "title": self.title, "task_count": self.tasks.len() }),
297 },
298 }
299 }
300
301 async fn execute(self: Box<Self>) -> ToolOutput {
302 self.inner.replace(self.title.clone(), self.tasks);
303 ToolOutput::Result {
304 content: format!("Task list updated: {}", self.title),
305 is_error: false,
306 continuation: ContinuationPreference::Continue,
307 ui_result: ToolExecutionResult::Other {
308 result: json!({ "title": self.title }),
309 },
310 }
311 }
312}
313
314#[async_trait::async_trait(?Send)]
315impl ToolExecutor for ManageTaskListTool {
316 fn name(&self) -> String {
317 "manage_task_list".to_string()
318 }
319
320 fn description(&self) -> String {
321 "Create or update the task list. This tool must be combined with at least 1 other tool call.".to_string()
322 }
323
324 fn input_schema(&self) -> Value {
325 json!({
326 "type": "object",
327 "properties": {
328 "title": {
329 "type": "string",
330 "description": "Title for the task list (≤50 characters) describing the current work"
331 },
332 "tasks": {
333 "type": "array",
334 "items": {
335 "type": "object",
336 "properties": {
337 "description": {
338 "type": "string",
339 "description": "Task description"
340 },
341 "status": {
342 "type": "string",
343 "enum": ["pending", "in_progress", "completed", "failed"],
344 "description": "Current task status"
345 }
346 },
347 "required": ["description", "status"],
348 "additionalProperties": false
349 },
350 "description": "Complete list of tasks with current status"
351 }
352 },
353 "required": ["title", "tasks"]
354 })
355 }
356
357 fn category(&self) -> ToolCategory {
358 ToolCategory::TaskList
359 }
360
361 async fn process(&self, request: &ToolRequest) -> Result<Box<dyn ToolCallHandle>> {
362 let input: ManageTaskListInput = serde_json::from_value(request.arguments.clone())?;
363
364 if input.tasks.is_empty() {
365 return Err(anyhow::anyhow!("Task list cannot be empty"));
366 }
367
368 let tasks: Vec<TaskWithStatus> = input
369 .tasks
370 .into_iter()
371 .map(|t| TaskWithStatus {
372 description: t.description,
373 status: t.status,
374 })
375 .collect();
376
377 Ok(Box::new(ManageTaskListHandle {
378 title: input.title,
379 tasks,
380 tool_use_id: request.tool_use_id.clone(),
381 inner: self.inner.clone(),
382 }))
383 }
384}