1use crate::error::{SageError, SageResult};
4use crate::llm::LlmClient;
5use crate::session::{ProtocolViolation, SenderHandle, SessionId, SharedSessionRegistry};
6use std::future::Future;
7use tokio::sync::{mpsc, oneshot};
8use tokio::task::JoinHandle;
9
10pub struct AgentHandle<T> {
14 join: JoinHandle<SageResult<T>>,
15 message_tx: mpsc::Sender<Message>,
16}
17
18impl<T> AgentHandle<T> {
19 pub async fn result(self) -> SageResult<T> {
21 self.join.await?
22 }
23
24 pub async fn send<M>(&self, msg: M) -> SageResult<()>
28 where
29 M: serde::Serialize,
30 {
31 let message = Message::new(msg)?;
32 self.message_tx
33 .send(message)
34 .await
35 .map_err(|e| SageError::Agent(format!("Failed to send message: {e}")))
36 }
37
38 pub async fn send_message(&self, message: Message) -> SageResult<()> {
43 self.message_tx
44 .send(message)
45 .await
46 .map_err(|e| SageError::Agent(format!("Failed to send message: {e}")))
47 }
48}
49
50#[derive(Debug, Clone)]
52pub struct Message {
53 pub payload: serde_json::Value,
55 pub session_id: Option<SessionId>,
57 pub sender: Option<SenderHandle>,
59 pub type_name: Option<String>,
61}
62
63impl Message {
64 pub fn new<T: serde::Serialize>(value: T) -> SageResult<Self> {
66 Ok(Self {
67 payload: serde_json::to_value(value)?,
68 session_id: None,
69 sender: None,
70 type_name: None,
71 })
72 }
73
74 pub fn with_session<T: serde::Serialize>(
76 value: T,
77 session_id: SessionId,
78 sender: SenderHandle,
79 type_name: impl Into<String>,
80 ) -> SageResult<Self> {
81 Ok(Self {
82 payload: serde_json::to_value(value)?,
83 session_id: Some(session_id),
84 sender: Some(sender),
85 type_name: Some(type_name.into()),
86 })
87 }
88
89 #[must_use]
91 pub fn with_type_name(mut self, type_name: impl Into<String>) -> Self {
92 self.type_name = Some(type_name.into());
93 self
94 }
95}
96
97pub struct AgentContext<T> {
101 pub llm: LlmClient,
103 result_tx: Option<oneshot::Sender<T>>,
105 message_rx: mpsc::Receiver<Message>,
107 emitted: bool,
109 current_message: Option<Message>,
111 session_registry: SharedSessionRegistry,
113 agent_role: Option<String>,
115}
116
117impl<T> AgentContext<T> {
118 fn new(
120 llm: LlmClient,
121 result_tx: oneshot::Sender<T>,
122 message_rx: mpsc::Receiver<Message>,
123 session_registry: SharedSessionRegistry,
124 ) -> Self {
125 Self {
126 llm,
127 result_tx: Some(result_tx),
128 message_rx,
129 emitted: false,
130 current_message: None,
131 session_registry,
132 agent_role: None,
133 }
134 }
135
136 pub fn set_role(&mut self, role: impl Into<String>) {
138 self.agent_role = Some(role.into());
139 }
140
141 #[must_use]
143 pub fn session_registry(&self) -> &SharedSessionRegistry {
144 &self.session_registry
145 }
146
147 pub fn emit(&mut self, value: T) -> SageResult<T>
152 where
153 T: Clone,
154 {
155 if self.emitted {
156 return Ok(value);
158 }
159 self.emitted = true;
160 if let Some(tx) = self.result_tx.take() {
161 let _ = tx.send(value.clone());
163 }
164 Ok(value)
165 }
166
167 pub async fn infer<R>(&self, prompt: &str) -> SageResult<R>
169 where
170 R: serde::de::DeserializeOwned,
171 {
172 self.llm.infer(prompt).await
173 }
174
175 pub async fn infer_string(&self, prompt: &str) -> SageResult<String> {
177 self.llm.infer_string(prompt).await
178 }
179
180 pub async fn receive<M>(&mut self) -> SageResult<M>
185 where
186 M: serde::de::DeserializeOwned,
187 {
188 let msg = self
189 .message_rx
190 .recv()
191 .await
192 .ok_or_else(|| SageError::Agent("Message channel closed".to_string()))?;
193
194 self.current_message = Some(msg.clone());
196
197 serde_json::from_value(msg.payload)
198 .map_err(|e| SageError::Agent(format!("Failed to deserialize message: {e}")))
199 }
200
201 pub async fn receive_timeout<M>(
205 &mut self,
206 timeout: std::time::Duration,
207 ) -> SageResult<Option<M>>
208 where
209 M: serde::de::DeserializeOwned,
210 {
211 match tokio::time::timeout(timeout, self.message_rx.recv()).await {
212 Ok(Some(msg)) => {
213 self.current_message = Some(msg.clone());
215
216 let value = serde_json::from_value(msg.payload)
217 .map_err(|e| SageError::Agent(format!("Failed to deserialize message: {e}")))?;
218 Ok(Some(value))
219 }
220 Ok(None) => Err(SageError::Agent("Message channel closed".to_string())),
221 Err(_) => Ok(None), }
223 }
224
225 pub async fn receive_raw(&mut self) -> SageResult<Message> {
230 let msg = self
231 .message_rx
232 .recv()
233 .await
234 .ok_or_else(|| SageError::Agent("Message channel closed".to_string()))?;
235
236 self.current_message = Some(msg.clone());
238
239 Ok(msg)
240 }
241
242 pub fn set_current_message(&mut self, msg: Message) {
246 self.current_message = Some(msg);
247 }
248
249 pub fn clear_current_message(&mut self) {
251 self.current_message = None;
252 }
253
254 pub async fn reply<M: serde::Serialize>(&mut self, msg: M) -> SageResult<()> {
264 let current = self
265 .current_message
266 .as_ref()
267 .ok_or_else(|| SageError::from(ProtocolViolation::ReplyOutsideHandler))?;
268
269 let sender = current
270 .sender
271 .as_ref()
272 .ok_or_else(|| SageError::Agent("Message has no sender handle".to_string()))?;
273
274 sender.send(msg).await
275 }
276
277 pub async fn reply_with_protocol<M: serde::Serialize>(
294 &mut self,
295 msg: M,
296 msg_type: &str,
297 role: &str,
298 ) -> SageResult<()> {
299 let current = self
300 .current_message
301 .as_ref()
302 .ok_or_else(|| SageError::from(ProtocolViolation::ReplyOutsideHandler))?;
303
304 if let Some(session_id) = current.session_id {
306 let mut registry = self.session_registry.write().await;
307 if let Some(session) = registry.get_mut(&session_id) {
308 if !session.state.can_send(msg_type, role) {
310 return Err(SageError::from(ProtocolViolation::UnexpectedMessage {
311 protocol: session.protocol.clone(),
312 expected: "valid reply".to_string(),
313 received: msg_type.to_string(),
314 state: session.state.state_name().to_string(),
315 }));
316 }
317 session.state.transition(msg_type)?;
319 }
320 }
321
322 let sender = current
323 .sender
324 .as_ref()
325 .ok_or_else(|| SageError::Agent("Message has no sender handle".to_string()))?;
326
327 sender.send(msg).await
328 }
329
330 pub async fn validate_protocol_receive(
343 &mut self,
344 msg_type: &str,
345 role: &str,
346 ) -> SageResult<()> {
347 let current = match &self.current_message {
348 Some(msg) => msg,
349 None => return Ok(()), };
351
352 if let Some(session_id) = current.session_id {
354 let mut registry = self.session_registry.write().await;
355 if let Some(session) = registry.get_mut(&session_id) {
356 if !session.state.can_receive(msg_type, role) {
358 return Err(SageError::from(ProtocolViolation::UnexpectedMessage {
359 protocol: session.protocol.clone(),
360 expected: "valid message for current state".to_string(),
361 received: msg_type.to_string(),
362 state: session.state.state_name().to_string(),
363 }));
364 }
365 session.state.transition(msg_type)?;
367
368 if session.state.is_terminal() {
370 drop(registry);
371 self.session_registry.write().await.remove(&session_id);
372 }
373 }
374 }
375
376 Ok(())
377 }
378
379 pub async fn start_session(
392 &self,
393 protocol: String,
394 role: String,
395 state: Box<dyn crate::session::ProtocolStateMachine>,
396 partner: SenderHandle,
397 ) -> SessionId {
398 let mut registry = self.session_registry.write().await;
399 let session_id = registry.next_id();
400 registry.start_session(session_id, protocol, role, state, partner);
401 session_id
402 }
403
404 #[must_use]
406 pub fn current_message(&self) -> Option<&Message> {
407 self.current_message.as_ref()
408 }
409}
410
411pub fn spawn<A, T, F>(agent: A) -> AgentHandle<T>
415where
416 A: FnOnce(AgentContext<T>) -> F + Send + 'static,
417 F: Future<Output = SageResult<T>> + Send,
418 T: Send + 'static,
419{
420 spawn_with_llm_config(agent, crate::llm::LlmConfig::from_env())
421}
422
423pub fn spawn_with_llm_config<A, T, F>(agent: A, llm_config: crate::llm::LlmConfig) -> AgentHandle<T>
427where
428 A: FnOnce(AgentContext<T>) -> F + Send + 'static,
429 F: Future<Output = SageResult<T>> + Send,
430 T: Send + 'static,
431{
432 let (result_tx, result_rx) = oneshot::channel();
433 let (message_tx, message_rx) = mpsc::channel(32);
434
435 let llm = LlmClient::new(llm_config);
436 let session_registry = crate::session::shared_registry();
437 let ctx = AgentContext::new(llm, result_tx, message_rx, session_registry);
438
439 let join = tokio::spawn(async move { agent(ctx).await });
440
441 drop(result_rx);
444
445 AgentHandle { join, message_tx }
446}
447
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Structured payload used by the single-message test.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TaskMessage {
        id: u32,
        content: String,
    }

    #[tokio::test]
    async fn spawn_simple_agent() {
        let agent = spawn(|mut ctx: AgentContext<i64>| async move { ctx.emit(42) });

        let emitted = agent.result().await.expect("agent should succeed");
        assert_eq!(emitted, 42);
    }

    #[tokio::test]
    async fn spawn_agent_with_computation() {
        let agent = spawn(|mut ctx: AgentContext<i64>| async move {
            let total: i64 = (1..=10).sum();
            ctx.emit(total)
        });

        assert_eq!(agent.result().await.expect("agent should succeed"), 55);
    }

    #[tokio::test]
    async fn agent_receives_message() {
        let agent = spawn(|mut ctx: AgentContext<String>| async move {
            let task: TaskMessage = ctx.receive().await?;
            let summary = format!("Got task {}: {}", task.id, task.content);
            ctx.emit(summary)
        });

        let task = TaskMessage {
            id: 42,
            content: "Hello".to_string(),
        };
        agent.send(task).await.expect("send should succeed");

        let summary = agent.result().await.expect("agent should succeed");
        assert_eq!(summary, "Got task 42: Hello");
    }

    #[tokio::test]
    async fn agent_receives_multiple_messages() {
        let agent = spawn(|mut ctx: AgentContext<i32>| async move {
            let mut total = 0;
            for _ in 0..3 {
                total += ctx.receive::<i32>().await?;
            }
            ctx.emit(total)
        });

        for value in [10, 20, 30] {
            agent.send(value).await.expect("send should succeed");
        }

        assert_eq!(agent.result().await.expect("agent should succeed"), 60);
    }

    #[tokio::test]
    async fn agent_receive_timeout() {
        let agent = spawn(|mut ctx: AgentContext<String>| async move {
            let received: Option<i32> = ctx
                .receive_timeout(std::time::Duration::from_millis(10))
                .await?;
            let outcome = match received {
                Some(n) => format!("Got {n}"),
                None => "Timeout".to_string(),
            };
            ctx.emit(outcome)
        });

        // Nothing is ever sent (and the handle keeps the mailbox open), so the
        // agent must hit its receive deadline.
        let outcome = agent.result().await.expect("agent should succeed");
        assert_eq!(outcome, "Timeout");
    }
}