Skip to main content

sage_runtime/
agent.rs

//! Agent spawning and lifecycle management.
2
3use crate::error::{SageError, SageResult};
4use crate::llm::LlmClient;
5use std::future::Future;
6use tokio::sync::{mpsc, oneshot};
7use tokio::task::JoinHandle;
8
9/// Handle to a spawned agent.
10///
11/// This is returned by `spawn()` and can be awaited to get the agent's result.
12pub struct AgentHandle<T> {
13    join: JoinHandle<SageResult<T>>,
14    #[allow(dead_code)]
15    message_tx: mpsc::Sender<Message>,
16}
17
18impl<T> AgentHandle<T> {
19    /// Wait for the agent to complete and return its result.
20    pub async fn result(self) -> SageResult<T> {
21        self.join.await?
22    }
23
24    /// Send a message to the agent.
25    #[allow(dead_code)]
26    pub async fn send(&self, msg: Message) -> SageResult<()> {
27        self.message_tx
28            .send(msg)
29            .await
30            .map_err(|e| SageError::Agent(format!("Failed to send message: {e}")))
31    }
32}
33
34/// A message that can be sent to an agent.
35#[derive(Debug, Clone)]
36pub struct Message {
37    /// The message payload as a JSON value.
38    pub payload: serde_json::Value,
39}
40
41impl Message {
42    /// Create a new message from a serializable value.
43    pub fn new<T: serde::Serialize>(value: T) -> SageResult<Self> {
44        Ok(Self {
45            payload: serde_json::to_value(value)?,
46        })
47    }
48}
49
50/// Context provided to agent handlers.
51///
52/// This gives agents access to LLM inference and the ability to emit results.
53pub struct AgentContext<T> {
54    /// LLM client for inference calls.
55    pub llm: LlmClient,
56    /// Channel to send the result to the awaiter.
57    result_tx: Option<oneshot::Sender<T>>,
58    /// Channel to receive messages.
59    #[allow(dead_code)]
60    message_rx: mpsc::Receiver<Message>,
61}
62
63impl<T> AgentContext<T> {
64    /// Create a new agent context.
65    fn new(
66        llm: LlmClient,
67        result_tx: oneshot::Sender<T>,
68        message_rx: mpsc::Receiver<Message>,
69    ) -> Self {
70        Self {
71            llm,
72            result_tx: Some(result_tx),
73            message_rx,
74        }
75    }
76
77    /// Emit a value to the awaiter.
78    ///
79    /// This should be called once at the end of the agent's execution.
80    pub fn emit(mut self, value: T) -> SageResult<T>
81    where
82        T: Clone,
83    {
84        if let Some(tx) = self.result_tx.take() {
85            // Ignore send errors - the receiver may have been dropped
86            let _ = tx.send(value.clone());
87        }
88        Ok(value)
89    }
90
91    /// Call the LLM with a prompt and parse the response.
92    pub async fn infer<R>(&self, prompt: &str) -> SageResult<R>
93    where
94        R: serde::de::DeserializeOwned,
95    {
96        self.llm.infer(prompt).await
97    }
98
99    /// Call the LLM with a prompt and return the raw string response.
100    pub async fn infer_string(&self, prompt: &str) -> SageResult<String> {
101        self.llm.infer_string(prompt).await
102    }
103}
104
105/// Spawn an agent and return a handle to it.
106///
107/// The agent will run asynchronously in a separate task.
108pub fn spawn<A, T, F>(agent: A) -> AgentHandle<T>
109where
110    A: FnOnce(AgentContext<T>) -> F + Send + 'static,
111    F: Future<Output = SageResult<T>> + Send,
112    T: Send + 'static,
113{
114    let (result_tx, result_rx) = oneshot::channel();
115    let (message_tx, message_rx) = mpsc::channel(32);
116
117    let llm = LlmClient::from_env();
118    let ctx = AgentContext::new(llm, result_tx, message_rx);
119
120    let join = tokio::spawn(async move { agent(ctx).await });
121
122    // We need to handle the result_rx somewhere, but for now we just let
123    // the result come from the JoinHandle
124    drop(result_rx);
125
126    AgentHandle { join, message_tx }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// An agent that emits a constant resolves to that constant.
    #[tokio::test]
    async fn spawn_simple_agent() {
        let agent = |ctx: AgentContext<i64>| async move { ctx.emit(42) };

        let outcome = spawn(agent).result().await;
        assert_eq!(outcome.expect("agent should succeed"), 42);
    }

    /// An agent may compute a value before emitting it.
    #[tokio::test]
    async fn spawn_agent_with_computation() {
        let handle = spawn(|ctx: AgentContext<i64>| async move {
            let total: i64 = (1..=10).sum();
            ctx.emit(total)
        });

        let value = handle.result().await.expect("agent should succeed");
        assert_eq!(value, 55);
    }
}
151}