1use crate::error::{SageError, SageResult};
4use crate::llm::LlmClient;
5use std::future::Future;
6use tokio::sync::{mpsc, oneshot};
7use tokio::task::JoinHandle;
8
9pub struct AgentHandle<T> {
13 join: JoinHandle<SageResult<T>>,
14 message_tx: mpsc::Sender<Message>,
15}
16
17impl<T> AgentHandle<T> {
18 pub async fn result(self) -> SageResult<T> {
20 self.join.await?
21 }
22
23 pub async fn send<M>(&self, msg: M) -> SageResult<()>
27 where
28 M: serde::Serialize,
29 {
30 let message = Message::new(msg)?;
31 self.message_tx
32 .send(message)
33 .await
34 .map_err(|e| SageError::Agent(format!("Failed to send message: {e}")))
35 }
36}
37
38#[derive(Debug, Clone)]
40pub struct Message {
41 pub payload: serde_json::Value,
43}
44
45impl Message {
46 pub fn new<T: serde::Serialize>(value: T) -> SageResult<Self> {
48 Ok(Self {
49 payload: serde_json::to_value(value)?,
50 })
51 }
52}
53
54pub struct AgentContext<T> {
58 pub llm: LlmClient,
60 result_tx: Option<oneshot::Sender<T>>,
62 message_rx: mpsc::Receiver<Message>,
64}
65
66impl<T> AgentContext<T> {
67 fn new(
69 llm: LlmClient,
70 result_tx: oneshot::Sender<T>,
71 message_rx: mpsc::Receiver<Message>,
72 ) -> Self {
73 Self {
74 llm,
75 result_tx: Some(result_tx),
76 message_rx,
77 }
78 }
79
80 pub fn emit(mut self, value: T) -> SageResult<T>
84 where
85 T: Clone,
86 {
87 if let Some(tx) = self.result_tx.take() {
88 let _ = tx.send(value.clone());
90 }
91 Ok(value)
92 }
93
94 pub async fn infer<R>(&self, prompt: &str) -> SageResult<R>
96 where
97 R: serde::de::DeserializeOwned,
98 {
99 self.llm.infer(prompt).await
100 }
101
102 pub async fn infer_string(&self, prompt: &str) -> SageResult<String> {
104 self.llm.infer_string(prompt).await
105 }
106
107 pub async fn receive<M>(&mut self) -> SageResult<M>
112 where
113 M: serde::de::DeserializeOwned,
114 {
115 let msg = self
116 .message_rx
117 .recv()
118 .await
119 .ok_or_else(|| SageError::Agent("Message channel closed".to_string()))?;
120
121 serde_json::from_value(msg.payload)
122 .map_err(|e| SageError::Agent(format!("Failed to deserialize message: {e}")))
123 }
124
125 pub async fn receive_timeout<M>(
129 &mut self,
130 timeout: std::time::Duration,
131 ) -> SageResult<Option<M>>
132 where
133 M: serde::de::DeserializeOwned,
134 {
135 match tokio::time::timeout(timeout, self.message_rx.recv()).await {
136 Ok(Some(msg)) => {
137 let value = serde_json::from_value(msg.payload)
138 .map_err(|e| SageError::Agent(format!("Failed to deserialize message: {e}")))?;
139 Ok(Some(value))
140 }
141 Ok(None) => Err(SageError::Agent("Message channel closed".to_string())),
142 Err(_) => Ok(None), }
144 }
145}
146
147pub fn spawn<A, T, F>(agent: A) -> AgentHandle<T>
151where
152 A: FnOnce(AgentContext<T>) -> F + Send + 'static,
153 F: Future<Output = SageResult<T>> + Send,
154 T: Send + 'static,
155{
156 let (result_tx, result_rx) = oneshot::channel();
157 let (message_tx, message_rx) = mpsc::channel(32);
158
159 let llm = LlmClient::from_env();
160 let ctx = AgentContext::new(llm, result_tx, message_rx);
161
162 let join = tokio::spawn(async move { agent(ctx).await });
163
164 drop(result_rx);
167
168 AgentHandle { join, message_tx }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use std::time::Duration;

    /// Structured payload used to exercise typed message delivery.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TaskMessage {
        id: u32,
        content: String,
    }

    /// An agent that emits right away yields that value through its handle.
    #[tokio::test]
    async fn spawn_simple_agent() {
        let agent = spawn(|ctx: AgentContext<i64>| async move { ctx.emit(42) });

        let outcome = agent.result().await;
        assert_eq!(outcome.expect("agent should succeed"), 42);
    }

    /// An agent may compute its value before emitting it.
    #[tokio::test]
    async fn spawn_agent_with_computation() {
        let agent = spawn(|ctx: AgentContext<i64>| async move {
            let total: i64 = (1..=10).sum();
            ctx.emit(total)
        });

        let outcome = agent.result().await;
        assert_eq!(outcome.expect("agent should succeed"), 55);
    }

    /// A structured message sent through the handle reaches the agent intact.
    #[tokio::test]
    async fn agent_receives_message() {
        let agent = spawn(|mut ctx: AgentContext<String>| async move {
            let task = ctx.receive::<TaskMessage>().await?;
            ctx.emit(format!("Got task {}: {}", task.id, task.content))
        });

        let task = TaskMessage {
            id: 42,
            content: "Hello".to_string(),
        };
        agent.send(task).await.expect("send should succeed");

        let outcome = agent.result().await;
        assert_eq!(outcome.expect("agent should succeed"), "Got task 42: Hello");
    }

    /// Several messages sent in a row are all delivered to the agent.
    #[tokio::test]
    async fn agent_receives_multiple_messages() {
        let agent = spawn(|mut ctx: AgentContext<i32>| async move {
            let mut total = 0;
            for _ in 0..3 {
                total += ctx.receive::<i32>().await?;
            }
            ctx.emit(total)
        });

        for value in [10, 20, 30] {
            agent.send(value).await.expect("send should succeed");
        }

        let outcome = agent.result().await;
        assert_eq!(outcome.expect("agent should succeed"), 60);
    }

    /// With nothing sent, a timed receive reports `None` instead of failing.
    #[tokio::test]
    async fn agent_receive_timeout() {
        let agent = spawn(|mut ctx: AgentContext<String>| async move {
            let reply = match ctx
                .receive_timeout::<i32>(Duration::from_millis(10))
                .await?
            {
                Some(n) => format!("Got {n}"),
                None => "Timeout".to_string(),
            };
            ctx.emit(reply)
        });

        let outcome = agent.result().await;
        assert_eq!(outcome.expect("agent should succeed"), "Timeout");
    }
}