Skip to main content

sage_runtime/
agent.rs

1//! Agent spawning and lifecycle management.
2
3use crate::error::{SageError, SageResult};
4use crate::llm::LlmClient;
5use std::future::Future;
6use tokio::sync::{mpsc, oneshot};
7use tokio::task::JoinHandle;
8
9/// Handle to a spawned agent.
10///
11/// This is returned by `spawn()` and can be awaited to get the agent's result.
12pub struct AgentHandle<T> {
13    join: JoinHandle<SageResult<T>>,
14    message_tx: mpsc::Sender<Message>,
15}
16
17impl<T> AgentHandle<T> {
18    /// Wait for the agent to complete and return its result.
19    pub async fn result(self) -> SageResult<T> {
20        self.join.await?
21    }
22
23    /// Send a message to the agent.
24    ///
25    /// The message will be serialized to JSON and placed in the agent's mailbox.
26    pub async fn send<M>(&self, msg: M) -> SageResult<()>
27    where
28        M: serde::Serialize,
29    {
30        let message = Message::new(msg)?;
31        self.message_tx
32            .send(message)
33            .await
34            .map_err(|e| SageError::Agent(format!("Failed to send message: {e}")))
35    }
36}
37
38/// A message that can be sent to an agent.
39#[derive(Debug, Clone)]
40pub struct Message {
41    /// The message payload as a JSON value.
42    pub payload: serde_json::Value,
43}
44
45impl Message {
46    /// Create a new message from a serializable value.
47    pub fn new<T: serde::Serialize>(value: T) -> SageResult<Self> {
48        Ok(Self {
49            payload: serde_json::to_value(value)?,
50        })
51    }
52}
53
54/// Context provided to agent handlers.
55///
56/// This gives agents access to LLM inference and the ability to emit results.
57pub struct AgentContext<T> {
58    /// LLM client for inference calls.
59    pub llm: LlmClient,
60    /// Channel to send the result to the awaiter.
61    result_tx: Option<oneshot::Sender<T>>,
62    /// Channel to receive messages from other agents.
63    message_rx: mpsc::Receiver<Message>,
64}
65
66impl<T> AgentContext<T> {
67    /// Create a new agent context.
68    fn new(
69        llm: LlmClient,
70        result_tx: oneshot::Sender<T>,
71        message_rx: mpsc::Receiver<Message>,
72    ) -> Self {
73        Self {
74            llm,
75            result_tx: Some(result_tx),
76            message_rx,
77        }
78    }
79
80    /// Emit a value to the awaiter.
81    ///
82    /// This should be called once at the end of the agent's execution.
83    pub fn emit(mut self, value: T) -> SageResult<T>
84    where
85        T: Clone,
86    {
87        if let Some(tx) = self.result_tx.take() {
88            // Ignore send errors - the receiver may have been dropped
89            let _ = tx.send(value.clone());
90        }
91        Ok(value)
92    }
93
94    /// Call the LLM with a prompt and parse the response.
95    pub async fn infer<R>(&self, prompt: &str) -> SageResult<R>
96    where
97        R: serde::de::DeserializeOwned,
98    {
99        self.llm.infer(prompt).await
100    }
101
102    /// Call the LLM with a prompt and return the raw string response.
103    pub async fn infer_string(&self, prompt: &str) -> SageResult<String> {
104        self.llm.infer_string(prompt).await
105    }
106
107    /// Receive a message from the agent's mailbox.
108    ///
109    /// This blocks until a message is available. The message is deserialized
110    /// into the specified type.
111    pub async fn receive<M>(&mut self) -> SageResult<M>
112    where
113        M: serde::de::DeserializeOwned,
114    {
115        let msg = self
116            .message_rx
117            .recv()
118            .await
119            .ok_or_else(|| SageError::Agent("Message channel closed".to_string()))?;
120
121        serde_json::from_value(msg.payload)
122            .map_err(|e| SageError::Agent(format!("Failed to deserialize message: {e}")))
123    }
124
125    /// Receive a message with a timeout.
126    ///
127    /// Returns `None` if the timeout expires before a message arrives.
128    pub async fn receive_timeout<M>(
129        &mut self,
130        timeout: std::time::Duration,
131    ) -> SageResult<Option<M>>
132    where
133        M: serde::de::DeserializeOwned,
134    {
135        match tokio::time::timeout(timeout, self.message_rx.recv()).await {
136            Ok(Some(msg)) => {
137                let value = serde_json::from_value(msg.payload)
138                    .map_err(|e| SageError::Agent(format!("Failed to deserialize message: {e}")))?;
139                Ok(Some(value))
140            }
141            Ok(None) => Err(SageError::Agent("Message channel closed".to_string())),
142            Err(_) => Ok(None), // Timeout
143        }
144    }
145}
146
147/// Spawn an agent and return a handle to it.
148///
149/// The agent will run asynchronously in a separate task.
150pub fn spawn<A, T, F>(agent: A) -> AgentHandle<T>
151where
152    A: FnOnce(AgentContext<T>) -> F + Send + 'static,
153    F: Future<Output = SageResult<T>> + Send,
154    T: Send + 'static,
155{
156    let (result_tx, result_rx) = oneshot::channel();
157    let (message_tx, message_rx) = mpsc::channel(32);
158
159    let llm = LlmClient::from_env();
160    let ctx = AgentContext::new(llm, result_tx, message_rx);
161
162    let join = tokio::spawn(async move { agent(ctx).await });
163
164    // We need to handle the result_rx somewhere, but for now we just let
165    // the result come from the JoinHandle
166    drop(result_rx);
167
168    AgentHandle { join, message_tx }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    #[tokio::test]
    async fn spawn_simple_agent() {
        let handle = spawn(|ctx: AgentContext<i64>| async move { ctx.emit(42) });

        let result = handle.result().await.expect("agent should succeed");
        assert_eq!(result, 42);
    }

    #[tokio::test]
    async fn spawn_agent_with_computation() {
        let handle = spawn(|ctx: AgentContext<i64>| async move {
            let sum = (1..=10).sum();
            ctx.emit(sum)
        });

        let result = handle.result().await.expect("agent should succeed");
        assert_eq!(result, 55);
    }

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TaskMessage {
        id: u32,
        content: String,
    }

    #[tokio::test]
    async fn agent_receives_message() {
        let handle = spawn(|mut ctx: AgentContext<String>| async move {
            let msg: TaskMessage = ctx.receive().await?;
            ctx.emit(format!("Got task {}: {}", msg.id, msg.content))
        });

        handle
            .send(TaskMessage {
                id: 42,
                content: "Hello".to_string(),
            })
            .await
            .expect("send should succeed");

        let result = handle.result().await.expect("agent should succeed");
        assert_eq!(result, "Got task 42: Hello");
    }

    #[tokio::test]
    async fn agent_receives_multiple_messages() {
        let handle = spawn(|mut ctx: AgentContext<i32>| async move {
            let mut sum = 0;
            for _ in 0..3 {
                let n: i32 = ctx.receive().await?;
                sum += n;
            }
            ctx.emit(sum)
        });

        for n in [10, 20, 30] {
            handle.send(n).await.expect("send should succeed");
        }

        let result = handle.result().await.expect("agent should succeed");
        assert_eq!(result, 60);
    }

    #[tokio::test]
    async fn agent_receive_timeout() {
        let handle = spawn(|mut ctx: AgentContext<String>| async move {
            let result: Option<i32> = ctx
                .receive_timeout(std::time::Duration::from_millis(10))
                .await?;
            match result {
                Some(n) => ctx.emit(format!("Got {n}")),
                None => ctx.emit("Timeout".to_string()),
            }
        });

        // Don't send anything, let it timeout
        let result = handle.result().await.expect("agent should succeed");
        assert_eq!(result, "Timeout");
    }

    #[tokio::test]
    async fn agent_receive_timeout_gets_message() {
        let handle = spawn(|mut ctx: AgentContext<String>| async move {
            // Generous timeout: the message is sent right after spawning.
            let result: Option<i32> = ctx
                .receive_timeout(std::time::Duration::from_secs(5))
                .await?;
            match result {
                Some(n) => ctx.emit(format!("Got {n}")),
                None => ctx.emit("Timeout".to_string()),
            }
        });

        handle.send(7).await.expect("send should succeed");

        let result = handle.result().await.expect("agent should succeed");
        assert_eq!(result, "Got 7");
    }

    #[tokio::test]
    async fn agent_receive_wrong_type_fails() {
        let handle = spawn(|mut ctx: AgentContext<i32>| async move {
            let n: i32 = ctx.receive().await?;
            ctx.emit(n)
        });

        // A JSON string cannot be deserialized into an i32.
        handle
            .send("not a number")
            .await
            .expect("send should succeed");

        let result = handle.result().await;
        assert!(matches!(result, Err(SageError::Agent(_))));
    }
}