// synaptic_middleware/todo_list.rs

use std::fmt::Write as _;
use std::sync::Arc;

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use synaptic_core::{Message, SynapticError};
use tokio::sync::Mutex;

use crate::{AgentMiddleware, ModelRequest};

10/// A single task in the agent's todo list.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct TodoItem {
13    pub id: usize,
14    pub task: String,
15    pub done: bool,
16}
17
18/// Adds task-planning capability to an agent by injecting a todo list
19/// into the system prompt.
20///
21/// The middleware maintains a shared todo list. Before each model call,
22/// it appends the current todo state to the system prompt, giving the
23/// model awareness of remaining tasks.
24pub struct TodoListMiddleware {
25    items: Arc<Mutex<Vec<TodoItem>>>,
26    next_id: Arc<Mutex<usize>>,
27}
28
29impl TodoListMiddleware {
30    pub fn new() -> Self {
31        Self {
32            items: Arc::new(Mutex::new(Vec::new())),
33            next_id: Arc::new(Mutex::new(1)),
34        }
35    }
36
37    /// Add a task to the todo list.
38    pub async fn add(&self, task: impl Into<String>) -> usize {
39        let mut id = self.next_id.lock().await;
40        let item_id = *id;
41        *id += 1;
42        drop(id);
43
44        let item = TodoItem {
45            id: item_id,
46            task: task.into(),
47            done: false,
48        };
49        self.items.lock().await.push(item);
50        item_id
51    }
52
53    /// Mark a task as done.
54    pub async fn complete(&self, id: usize) -> bool {
55        let mut items = self.items.lock().await;
56        if let Some(item) = items.iter_mut().find(|i| i.id == id) {
57            item.done = true;
58            true
59        } else {
60            false
61        }
62    }
63
64    /// Get all items.
65    pub async fn items(&self) -> Vec<TodoItem> {
66        self.items.lock().await.clone()
67    }
68
69    fn format_list(items: &[TodoItem]) -> String {
70        if items.is_empty() {
71            return "No tasks in the todo list.".to_string();
72        }
73        let mut s = String::from("Current TODO list:\n");
74        for item in items {
75            let mark = if item.done { "x" } else { " " };
76            s.push_str(&format!("  [{}] #{}: {}\n", mark, item.id, item.task));
77        }
78        s
79    }
80}
81
82impl Default for TodoListMiddleware {
83    fn default() -> Self {
84        Self::new()
85    }
86}
87
88#[async_trait]
89impl AgentMiddleware for TodoListMiddleware {
90    async fn before_model(&self, request: &mut ModelRequest) -> Result<(), SynapticError> {
91        let items = self.items.lock().await;
92        if items.is_empty() {
93            return Ok(());
94        }
95        let list_text = Self::format_list(&items);
96        drop(items);
97
98        // Inject at the beginning of messages as a system message
99        request.messages.insert(0, Message::system(list_text));
100        Ok(())
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn add_and_complete() {
        let mw = TodoListMiddleware::new();
        let id1 = mw.add("Write tests").await;
        let id2 = mw.add("Fix bug").await;
        assert_eq!(id1, 1);
        assert_eq!(id2, 2);

        assert!(mw.complete(id1).await);
        let items = mw.items().await;
        assert_eq!(items.len(), 2);
        assert!(items[0].done);
        assert!(!items[1].done);
    }

    #[tokio::test]
    async fn complete_unknown_id_returns_false() {
        let mw = TodoListMiddleware::new();
        mw.add("Only task").await;
        assert!(!mw.complete(42).await);
        assert!(!mw.items().await[0].done);
    }

    // `format_list` is synchronous, so a plain test needs no async runtime.
    #[test]
    fn format_empty() {
        let text = TodoListMiddleware::format_list(&[]);
        assert_eq!(text, "No tasks in the todo list.");
    }

    #[test]
    fn format_marks_done_and_pending() {
        let items = [
            TodoItem {
                id: 1,
                task: "Write tests".to_string(),
                done: true,
            },
            TodoItem {
                id: 2,
                task: "Fix bug".to_string(),
                done: false,
            },
        ];
        assert_eq!(
            TodoListMiddleware::format_list(&items),
            "Current TODO list:\n  [x] #1: Write tests\n  [ ] #2: Fix bug\n"
        );
    }
}