zeph_scheduler/
handlers.rs1use std::future::Future;
5use std::pin::Pin;
6
7use tokio::sync::mpsc;
8
9use crate::error::SchedulerError;
10use crate::sanitize::sanitize_task_prompt;
11use crate::task::TaskHandler;
12
13pub struct CustomTaskHandler {
46 tx: mpsc::Sender<String>,
47}
48
49impl CustomTaskHandler {
50 #[must_use]
52 pub fn new(tx: mpsc::Sender<String>) -> Self {
53 Self { tx }
54 }
55}
56
57impl TaskHandler for CustomTaskHandler {
58 fn execute(
59 &self,
60 config: &serde_json::Value,
61 ) -> Pin<Box<dyn Future<Output = Result<(), SchedulerError>> + Send + '_>> {
62 let raw = config
63 .get("task")
64 .and_then(|v| v.as_str())
65 .unwrap_or("Execute the following scheduled task now: check status");
66 let prompt = sanitize_task_prompt(raw);
67 let tx = self.tx.clone();
68 Box::pin(async move {
69 if tx.try_send(prompt).is_err() {
70 tracing::warn!("custom task handler: agent channel full or closed");
71 }
72 Ok(())
73 })
74 }
75}
76
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn custom_handler_sends_task_prompt() {
        let (tx, mut rx) = mpsc::channel(1);
        let handler = CustomTaskHandler::new(tx);
        let config = serde_json::json!({"task": "do something important"});
        handler.execute(&config).await.unwrap();
        let msg = rx.recv().await.unwrap();
        assert_eq!(msg, "do something important");
    }

    #[tokio::test]
    async fn custom_handler_uses_default_when_no_task_field() {
        let (tx, mut rx) = mpsc::channel(1);
        let handler = CustomTaskHandler::new(tx);
        handler.execute(&serde_json::Value::Null).await.unwrap();
        let msg = rx.recv().await.unwrap();
        assert!(msg.contains("Execute the following scheduled task now:"));
    }

    #[tokio::test]
    async fn custom_handler_uses_default_when_task_not_string() {
        let (tx, mut rx) = mpsc::channel(1);
        let handler = CustomTaskHandler::new(tx);
        let config = serde_json::json!({"task": 42});
        handler.execute(&config).await.unwrap();
        let msg = rx.recv().await.unwrap();
        assert!(msg.contains("Execute the following scheduled task now:"));
    }

    #[tokio::test]
    async fn custom_handler_sends_nothing_until_polled() {
        let (tx, mut rx) = mpsc::channel(1);
        let handler = CustomTaskHandler::new(tx);
        let config = serde_json::json!({"task": "lazy"});
        let fut = handler.execute(&config);
        // Creating the future must not deliver anything; the send happens on poll.
        assert!(matches!(
            rx.try_recv(),
            Err(mpsc::error::TryRecvError::Empty)
        ));
        fut.await.unwrap();
        assert_eq!(rx.recv().await.unwrap(), "lazy");
    }

    #[tokio::test]
    async fn custom_handler_ok_when_channel_full() {
        let (tx, mut rx) = mpsc::channel(1);
        tx.try_send("fill".to_owned())
            .expect("a fresh single-slot channel must accept the first message");
        let handler = CustomTaskHandler::new(tx);
        let config = serde_json::json!({"task": "overflow"});
        let result = handler.execute(&config).await;
        assert!(result.is_ok());
        // Only the pre-filled message is queued; the overflowing task was dropped.
        assert_eq!(rx.try_recv().unwrap(), "fill");
        assert!(matches!(
            rx.try_recv(),
            Err(mpsc::error::TryRecvError::Empty)
        ));
    }

    #[tokio::test]
    async fn custom_handler_ok_when_channel_closed() {
        let (tx, rx) = mpsc::channel(1);
        drop(rx);
        let handler = CustomTaskHandler::new(tx);
        let config = serde_json::json!({"task": "closed"});
        let result = handler.execute(&config).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn custom_handler_strips_control_chars() {
        let (tx, mut rx) = mpsc::channel(1);
        let handler = CustomTaskHandler::new(tx);
        let config = serde_json::json!({"task": "hello\x01\x00world"});
        handler.execute(&config).await.unwrap();
        let msg = rx.recv().await.unwrap();
        assert_eq!(msg, "helloworld");
    }

    #[tokio::test]
    async fn custom_handler_truncates_long_prompt() {
        let (tx, mut rx) = mpsc::channel(1);
        let handler = CustomTaskHandler::new(tx);
        let long_task = "a".repeat(1000);
        let config = serde_json::json!({"task": long_task});
        handler.execute(&config).await.unwrap();
        let msg = rx.recv().await.unwrap();
        assert_eq!(msg.chars().count(), 512);
    }
}