// sage_runtime/agent.rs

//! Agent spawning and lifecycle management.
2
3use crate::error::{SageError, SageResult};
4use crate::llm::LlmClient;
5use std::future::Future;
6use tokio::sync::{mpsc, oneshot};
7use tokio::task::JoinHandle;
8
9/// Handle to a spawned agent.
10///
11/// This is returned by `spawn()` and can be awaited to get the agent's result.
12pub struct AgentHandle<T> {
13    join: JoinHandle<SageResult<T>>,
14    message_tx: mpsc::Sender<Message>,
15}
16
17impl<T> AgentHandle<T> {
18    /// Wait for the agent to complete and return its result.
19    pub async fn result(self) -> SageResult<T> {
20        self.join.await?
21    }
22
23    /// Send a message to the agent.
24    ///
25    /// The message will be serialized to JSON and placed in the agent's mailbox.
26    pub async fn send<M>(&self, msg: M) -> SageResult<()>
27    where
28        M: serde::Serialize,
29    {
30        let message = Message::new(msg)?;
31        self.message_tx
32            .send(message)
33            .await
34            .map_err(|e| SageError::Agent(format!("Failed to send message: {e}")))
35    }
36}
37
38/// A message that can be sent to an agent.
39#[derive(Debug, Clone)]
40pub struct Message {
41    /// The message payload as a JSON value.
42    pub payload: serde_json::Value,
43}
44
45impl Message {
46    /// Create a new message from a serializable value.
47    pub fn new<T: serde::Serialize>(value: T) -> SageResult<Self> {
48        Ok(Self {
49            payload: serde_json::to_value(value)?,
50        })
51    }
52}
53
54/// Context provided to agent handlers.
55///
56/// This gives agents access to LLM inference and the ability to emit results.
57pub struct AgentContext<T> {
58    /// LLM client for inference calls.
59    pub llm: LlmClient,
60    /// Channel to send the result to the awaiter.
61    result_tx: Option<oneshot::Sender<T>>,
62    /// Channel to receive messages from other agents.
63    message_rx: mpsc::Receiver<Message>,
64    /// Whether emit has been called (prevents double-emit).
65    emitted: bool,
66}
67
68impl<T> AgentContext<T> {
69    /// Create a new agent context.
70    fn new(
71        llm: LlmClient,
72        result_tx: oneshot::Sender<T>,
73        message_rx: mpsc::Receiver<Message>,
74    ) -> Self {
75        Self {
76            llm,
77            result_tx: Some(result_tx),
78            message_rx,
79            emitted: false,
80        }
81    }
82
83    /// Emit a value to the awaiter.
84    ///
85    /// This should be called once at the end of the agent's execution.
86    /// Calling emit multiple times is a no-op after the first call.
87    pub fn emit(&mut self, value: T) -> SageResult<T>
88    where
89        T: Clone,
90    {
91        if self.emitted {
92            // Already emitted, just return the value
93            return Ok(value);
94        }
95        self.emitted = true;
96        if let Some(tx) = self.result_tx.take() {
97            // Ignore send errors - the receiver may have been dropped
98            let _ = tx.send(value.clone());
99        }
100        Ok(value)
101    }
102
103    /// Call the LLM with a prompt and parse the response.
104    pub async fn infer<R>(&self, prompt: &str) -> SageResult<R>
105    where
106        R: serde::de::DeserializeOwned,
107    {
108        self.llm.infer(prompt).await
109    }
110
111    /// Call the LLM with a prompt and return the raw string response.
112    pub async fn infer_string(&self, prompt: &str) -> SageResult<String> {
113        self.llm.infer_string(prompt).await
114    }
115
116    /// Receive a message from the agent's mailbox.
117    ///
118    /// This blocks until a message is available. The message is deserialized
119    /// into the specified type.
120    pub async fn receive<M>(&mut self) -> SageResult<M>
121    where
122        M: serde::de::DeserializeOwned,
123    {
124        let msg = self
125            .message_rx
126            .recv()
127            .await
128            .ok_or_else(|| SageError::Agent("Message channel closed".to_string()))?;
129
130        serde_json::from_value(msg.payload)
131            .map_err(|e| SageError::Agent(format!("Failed to deserialize message: {e}")))
132    }
133
134    /// Receive a message with a timeout.
135    ///
136    /// Returns `None` if the timeout expires before a message arrives.
137    pub async fn receive_timeout<M>(
138        &mut self,
139        timeout: std::time::Duration,
140    ) -> SageResult<Option<M>>
141    where
142        M: serde::de::DeserializeOwned,
143    {
144        match tokio::time::timeout(timeout, self.message_rx.recv()).await {
145            Ok(Some(msg)) => {
146                let value = serde_json::from_value(msg.payload)
147                    .map_err(|e| SageError::Agent(format!("Failed to deserialize message: {e}")))?;
148                Ok(Some(value))
149            }
150            Ok(None) => Err(SageError::Agent("Message channel closed".to_string())),
151            Err(_) => Ok(None), // Timeout
152        }
153    }
154}
155
156/// Spawn an agent and return a handle to it.
157///
158/// The agent will run asynchronously in a separate task.
159pub fn spawn<A, T, F>(agent: A) -> AgentHandle<T>
160where
161    A: FnOnce(AgentContext<T>) -> F + Send + 'static,
162    F: Future<Output = SageResult<T>> + Send,
163    T: Send + 'static,
164{
165    let (result_tx, result_rx) = oneshot::channel();
166    let (message_tx, message_rx) = mpsc::channel(32);
167
168    let llm = LlmClient::from_env();
169    let ctx = AgentContext::new(llm, result_tx, message_rx);
170
171    let join = tokio::spawn(async move { agent(ctx).await });
172
173    // We need to handle the result_rx somewhere, but for now we just let
174    // the result come from the JoinHandle
175    drop(result_rx);
176
177    AgentHandle { join, message_tx }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Structured payload used to check typed message delivery.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TaskMessage {
        id: u32,
        content: String,
    }

    /// An agent that emits a constant yields that constant as its result.
    #[tokio::test]
    async fn spawn_simple_agent() {
        let handle = spawn(|mut ctx: AgentContext<i64>| async move { ctx.emit(42) });

        let outcome = handle.result().await;
        assert_eq!(outcome.expect("agent should succeed"), 42);
    }

    /// An agent may compute its value before emitting it.
    #[tokio::test]
    async fn spawn_agent_with_computation() {
        let handle = spawn(|mut ctx: AgentContext<i64>| async move {
            let total: i64 = (1..=10).sum();
            ctx.emit(total)
        });

        let outcome = handle.result().await;
        assert_eq!(outcome.expect("agent should succeed"), 55);
    }

    /// A struct sent through the handle arrives intact in the agent.
    #[tokio::test]
    async fn agent_receives_message() {
        let handle = spawn(|mut ctx: AgentContext<String>| async move {
            let task = ctx.receive::<TaskMessage>().await?;
            ctx.emit(format!("Got task {}: {}", task.id, task.content))
        });

        let task = TaskMessage {
            id: 42,
            content: "Hello".to_string(),
        };
        handle.send(task).await.expect("send should succeed");

        let outcome = handle.result().await;
        assert_eq!(outcome.expect("agent should succeed"), "Got task 42: Hello");
    }

    /// Several messages sent in sequence are all received by the agent.
    #[tokio::test]
    async fn agent_receives_multiple_messages() {
        let handle = spawn(|mut ctx: AgentContext<i32>| async move {
            let mut total = 0;
            let mut remaining = 3;
            while remaining > 0 {
                total += ctx.receive::<i32>().await?;
                remaining -= 1;
            }
            ctx.emit(total)
        });

        for n in [10, 20, 30] {
            handle.send(n).await.expect("send should succeed");
        }

        let outcome = handle.result().await;
        assert_eq!(outcome.expect("agent should succeed"), 60);
    }

    /// With no message sent, `receive_timeout` reports a timeout as `None`.
    #[tokio::test]
    async fn agent_receive_timeout() {
        let handle = spawn(|mut ctx: AgentContext<String>| async move {
            let wait = std::time::Duration::from_millis(10);
            let reply = match ctx.receive_timeout::<i32>(wait).await? {
                Some(n) => format!("Got {n}"),
                None => "Timeout".to_string(),
            };
            ctx.emit(reply)
        });

        // Deliberately send nothing so the agent's wait expires.
        let outcome = handle.result().await;
        assert_eq!(outcome.expect("agent should succeed"), "Timeout");
    }
}