1use crate::error::{SageError, SageResult};
4use crate::llm::LlmClient;
5use std::future::Future;
6use tokio::sync::{mpsc, oneshot};
7use tokio::task::JoinHandle;
8
9pub struct AgentHandle<T> {
13 join: JoinHandle<SageResult<T>>,
14 #[allow(dead_code)]
15 message_tx: mpsc::Sender<Message>,
16}
17
18impl<T> AgentHandle<T> {
19 pub async fn result(self) -> SageResult<T> {
21 self.join.await?
22 }
23
24 #[allow(dead_code)]
26 pub async fn send(&self, msg: Message) -> SageResult<()> {
27 self.message_tx
28 .send(msg)
29 .await
30 .map_err(|e| SageError::Agent(format!("Failed to send message: {e}")))
31 }
32}
33
34#[derive(Debug, Clone)]
36pub struct Message {
37 pub payload: serde_json::Value,
39}
40
41impl Message {
42 pub fn new<T: serde::Serialize>(value: T) -> SageResult<Self> {
44 Ok(Self {
45 payload: serde_json::to_value(value)?,
46 })
47 }
48}
49
/// Execution context handed to an agent closure by [`spawn`].
///
/// Gives the agent access to an LLM client, a one-shot channel for publishing
/// its result via [`AgentContext::emit`], and the receiving half of its mailbox.
// Field order is also drop order; keep it stable.
pub struct AgentContext<T> {
    /// LLM client available to the agent (`spawn` builds it with `LlmClient::from_env()`).
    pub llm: LlmClient,
    /// One-shot sender for the agent's result; `Some` until `emit` takes it.
    result_tx: Option<oneshot::Sender<T>>,
    // Receiving half of the mailbox fed by `AgentHandle::send`. Nothing visible
    // here reads from it yet, hence the `dead_code` allowance.
    #[allow(dead_code)]
    message_rx: mpsc::Receiver<Message>,
}
62
63impl<T> AgentContext<T> {
64 fn new(
66 llm: LlmClient,
67 result_tx: oneshot::Sender<T>,
68 message_rx: mpsc::Receiver<Message>,
69 ) -> Self {
70 Self {
71 llm,
72 result_tx: Some(result_tx),
73 message_rx,
74 }
75 }
76
77 pub fn emit(mut self, value: T) -> SageResult<T>
81 where
82 T: Clone,
83 {
84 if let Some(tx) = self.result_tx.take() {
85 let _ = tx.send(value.clone());
87 }
88 Ok(value)
89 }
90
91 pub async fn infer<R>(&self, prompt: &str) -> SageResult<R>
93 where
94 R: serde::de::DeserializeOwned,
95 {
96 self.llm.infer(prompt).await
97 }
98
99 pub async fn infer_string(&self, prompt: &str) -> SageResult<String> {
101 self.llm.infer_string(prompt).await
102 }
103}
104
105pub fn spawn<A, T, F>(agent: A) -> AgentHandle<T>
109where
110 A: FnOnce(AgentContext<T>) -> F + Send + 'static,
111 F: Future<Output = SageResult<T>> + Send,
112 T: Send + 'static,
113{
114 let (result_tx, result_rx) = oneshot::channel();
115 let (message_tx, message_rx) = mpsc::channel(32);
116
117 let llm = LlmClient::from_env();
118 let ctx = AgentContext::new(llm, result_tx, message_rx);
119
120 let join = tokio::spawn(async move { agent(ctx).await });
121
122 drop(result_rx);
125
126 AgentHandle { join, message_tx }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn spawn_simple_agent() {
        let handle = spawn(|ctx: AgentContext<i64>| async move { ctx.emit(42) });

        let result = handle.result().await.expect("agent should succeed");
        assert_eq!(result, 42);
    }

    #[tokio::test]
    async fn spawn_agent_with_computation() {
        let handle = spawn(|ctx: AgentContext<i64>| async move {
            let sum = (1..=10).sum();
            ctx.emit(sum)
        });

        let result = handle.result().await.expect("agent should succeed");
        assert_eq!(result, 55);
    }

    /// An agent that returns `Err` must surface that exact error from `result()`.
    #[tokio::test]
    async fn spawn_agent_error_is_propagated() {
        let handle = spawn(|_ctx: AgentContext<i64>| async move {
            Err::<i64, _>(SageError::Agent("boom".to_string()))
        });

        let result = handle.result().await;
        assert!(
            matches!(result, Err(SageError::Agent(ref msg)) if msg == "boom"),
            "agent error should reach the caller unchanged"
        );
    }

    #[test]
    fn message_new_serializes_payload() {
        let msg = Message::new(vec![1, 2, 3]).expect("a Vec<i32> always serializes");
        assert_eq!(msg.payload, serde_json::json!([1, 2, 3]));
    }
}