Skip to main content

sage_runtime/
agent.rs

1//! Agent spawning and lifecycle management.
2
3use crate::error::{SageError, SageResult};
4use crate::llm::LlmClient;
5use crate::session::{ProtocolViolation, SenderHandle, SessionId, SharedSessionRegistry};
6use std::future::Future;
7use tokio::sync::{mpsc, oneshot};
8use tokio::task::JoinHandle;
9
10/// Handle to a spawned agent.
11///
12/// This is returned by `spawn()` and can be awaited to get the agent's result.
13pub struct AgentHandle<T> {
14    join: JoinHandle<SageResult<T>>,
15    message_tx: mpsc::Sender<Message>,
16}
17
18impl<T> AgentHandle<T> {
19    /// Wait for the agent to complete and return its result.
20    pub async fn result(self) -> SageResult<T> {
21        self.join.await?
22    }
23
24    /// Send a message to the agent.
25    ///
26    /// The message will be serialized to JSON and placed in the agent's mailbox.
27    pub async fn send<M>(&self, msg: M) -> SageResult<()>
28    where
29        M: serde::Serialize,
30    {
31        let message = Message::new(msg)?;
32        self.message_tx
33            .send(message)
34            .await
35            .map_err(|e| SageError::Agent(format!("Failed to send message: {e}")))
36    }
37
38    /// Send a pre-built message to the agent.
39    ///
40    /// This is used by generated code when the message needs additional metadata
41    /// (like type_name for protocol tracking).
42    pub async fn send_message(&self, message: Message) -> SageResult<()> {
43        self.message_tx
44            .send(message)
45            .await
46            .map_err(|e| SageError::Agent(format!("Failed to send message: {e}")))
47    }
48}
49
50/// A message that can be sent to an agent.
51#[derive(Debug, Clone)]
52pub struct Message {
53    /// The message payload as a JSON value.
54    pub payload: serde_json::Value,
55    /// Phase 3: Session ID for protocol tracking.
56    pub session_id: Option<SessionId>,
57    /// Phase 3: Handle for replying to this message.
58    pub sender: Option<SenderHandle>,
59    /// Phase 3: Type name for protocol validation.
60    pub type_name: Option<String>,
61}
62
63impl Message {
64    /// Create a new message from a serializable value.
65    pub fn new<T: serde::Serialize>(value: T) -> SageResult<Self> {
66        Ok(Self {
67            payload: serde_json::to_value(value)?,
68            session_id: None,
69            sender: None,
70            type_name: None,
71        })
72    }
73
74    /// Create a new message with session context.
75    pub fn with_session<T: serde::Serialize>(
76        value: T,
77        session_id: SessionId,
78        sender: SenderHandle,
79        type_name: impl Into<String>,
80    ) -> SageResult<Self> {
81        Ok(Self {
82            payload: serde_json::to_value(value)?,
83            session_id: Some(session_id),
84            sender: Some(sender),
85            type_name: Some(type_name.into()),
86        })
87    }
88
89    /// Set the type name for this message.
90    #[must_use]
91    pub fn with_type_name(mut self, type_name: impl Into<String>) -> Self {
92        self.type_name = Some(type_name.into());
93        self
94    }
95}
96
97/// Context provided to agent handlers.
98///
99/// This gives agents access to LLM inference and the ability to emit results.
100pub struct AgentContext<T> {
101    /// LLM client for inference calls.
102    pub llm: LlmClient,
103    /// Channel to send the result to the awaiter.
104    result_tx: Option<oneshot::Sender<T>>,
105    /// Channel to receive messages from other agents.
106    message_rx: mpsc::Receiver<Message>,
107    /// Whether emit has been called (prevents double-emit).
108    emitted: bool,
109    /// Phase 3: The current message being handled (for reply()).
110    current_message: Option<Message>,
111    /// Phase 3: Session registry for protocol tracking.
112    session_registry: SharedSessionRegistry,
113    /// Phase 3: The role this agent plays in protocols.
114    agent_role: Option<String>,
115}
116
117impl<T> AgentContext<T> {
118    /// Create a new agent context.
119    fn new(
120        llm: LlmClient,
121        result_tx: oneshot::Sender<T>,
122        message_rx: mpsc::Receiver<Message>,
123        session_registry: SharedSessionRegistry,
124    ) -> Self {
125        Self {
126            llm,
127            result_tx: Some(result_tx),
128            message_rx,
129            emitted: false,
130            current_message: None,
131            session_registry,
132            agent_role: None,
133        }
134    }
135
136    /// Set the role this agent plays in protocols.
137    pub fn set_role(&mut self, role: impl Into<String>) {
138        self.agent_role = Some(role.into());
139    }
140
141    /// Get the session registry.
142    #[must_use]
143    pub fn session_registry(&self) -> &SharedSessionRegistry {
144        &self.session_registry
145    }
146
147    /// Emit a value to the awaiter.
148    ///
149    /// This should be called once at the end of the agent's execution.
150    /// Calling emit multiple times is a no-op after the first call.
151    pub fn emit(&mut self, value: T) -> SageResult<T>
152    where
153        T: Clone,
154    {
155        if self.emitted {
156            // Already emitted, just return the value
157            return Ok(value);
158        }
159        self.emitted = true;
160        if let Some(tx) = self.result_tx.take() {
161            // Ignore send errors - the receiver may have been dropped
162            let _ = tx.send(value.clone());
163        }
164        Ok(value)
165    }
166
167    /// Call the LLM with a prompt and parse the response.
168    pub async fn infer<R>(&self, prompt: &str) -> SageResult<R>
169    where
170        R: serde::de::DeserializeOwned,
171    {
172        self.llm.infer(prompt).await
173    }
174
175    /// Call the LLM with a prompt and return the raw string response.
176    pub async fn infer_string(&self, prompt: &str) -> SageResult<String> {
177        self.llm.infer_string(prompt).await
178    }
179
180    /// Receive a message from the agent's mailbox.
181    ///
182    /// This blocks until a message is available. The message is deserialized
183    /// into the specified type.
184    pub async fn receive<M>(&mut self) -> SageResult<M>
185    where
186        M: serde::de::DeserializeOwned,
187    {
188        let msg = self
189            .message_rx
190            .recv()
191            .await
192            .ok_or_else(|| SageError::Agent("Message channel closed".to_string()))?;
193
194        // Phase 3: Store current message for reply()
195        self.current_message = Some(msg.clone());
196
197        serde_json::from_value(msg.payload)
198            .map_err(|e| SageError::Agent(format!("Failed to deserialize message: {e}")))
199    }
200
201    /// Receive a message with a timeout.
202    ///
203    /// Returns `None` if the timeout expires before a message arrives.
204    pub async fn receive_timeout<M>(
205        &mut self,
206        timeout: std::time::Duration,
207    ) -> SageResult<Option<M>>
208    where
209        M: serde::de::DeserializeOwned,
210    {
211        match tokio::time::timeout(timeout, self.message_rx.recv()).await {
212            Ok(Some(msg)) => {
213                // Phase 3: Store current message for reply()
214                self.current_message = Some(msg.clone());
215
216                let value = serde_json::from_value(msg.payload)
217                    .map_err(|e| SageError::Agent(format!("Failed to deserialize message: {e}")))?;
218                Ok(Some(value))
219            }
220            Ok(None) => Err(SageError::Agent("Message channel closed".to_string())),
221            Err(_) => Ok(None), // Timeout
222        }
223    }
224
225    /// Receive the raw message from the agent's mailbox.
226    ///
227    /// This blocks until a message is available. Returns the full Message
228    /// including session context.
229    pub async fn receive_raw(&mut self) -> SageResult<Message> {
230        let msg = self
231            .message_rx
232            .recv()
233            .await
234            .ok_or_else(|| SageError::Agent("Message channel closed".to_string()))?;
235
236        // Store current message for reply()
237        self.current_message = Some(msg.clone());
238
239        Ok(msg)
240    }
241
242    /// Set the current message context (for use in message handlers).
243    ///
244    /// This is called by generated code when entering a message handler.
245    pub fn set_current_message(&mut self, msg: Message) {
246        self.current_message = Some(msg);
247    }
248
249    /// Clear the current message context (for use after message handlers).
250    pub fn clear_current_message(&mut self) {
251        self.current_message = None;
252    }
253
254    /// Phase 3: Reply to the current message.
255    ///
256    /// This sends a response back to the sender of the current message.
257    /// Can only be called inside a message handler.
258    ///
259    /// # Errors
260    ///
261    /// Returns an error if called outside a message handler or if
262    /// the current message has no sender handle.
263    pub async fn reply<M: serde::Serialize>(&mut self, msg: M) -> SageResult<()> {
264        let current = self
265            .current_message
266            .as_ref()
267            .ok_or_else(|| SageError::from(ProtocolViolation::ReplyOutsideHandler))?;
268
269        let sender = current
270            .sender
271            .as_ref()
272            .ok_or_else(|| SageError::Agent("Message has no sender handle".to_string()))?;
273
274        sender.send(msg).await
275    }
276
277    /// Phase 3: Reply to the current message with protocol state validation.
278    ///
279    /// This validates that the reply is allowed by the protocol state machine,
280    /// transitions the state, and then sends the reply.
281    ///
282    /// # Arguments
283    /// * `msg` - The message to send back
284    /// * `msg_type` - The type name of the message for protocol validation
285    /// * `role` - The role this agent plays in the protocol
286    ///
287    /// # Errors
288    ///
289    /// Returns an error if:
290    /// - Called outside a message handler
291    /// - The current message has no sender handle
292    /// - The protocol state doesn't allow this reply
293    pub async fn reply_with_protocol<M: serde::Serialize>(
294        &mut self,
295        msg: M,
296        msg_type: &str,
297        role: &str,
298    ) -> SageResult<()> {
299        let current = self
300            .current_message
301            .as_ref()
302            .ok_or_else(|| SageError::from(ProtocolViolation::ReplyOutsideHandler))?;
303
304        // If message has a session, validate protocol state
305        if let Some(session_id) = current.session_id {
306            let mut registry = self.session_registry.write().await;
307            if let Some(session) = registry.get_mut(&session_id) {
308                // Validate that we can send this message type from our role
309                if !session.state.can_send(msg_type, role) {
310                    return Err(SageError::from(ProtocolViolation::UnexpectedMessage {
311                        protocol: session.protocol.clone(),
312                        expected: "valid reply".to_string(),
313                        received: msg_type.to_string(),
314                        state: session.state.state_name().to_string(),
315                    }));
316                }
317                // Transition the state machine
318                session.state.transition(msg_type)?;
319            }
320        }
321
322        let sender = current
323            .sender
324            .as_ref()
325            .ok_or_else(|| SageError::Agent("Message has no sender handle".to_string()))?;
326
327        sender.send(msg).await
328    }
329
330    /// Phase 3: Validate incoming message against protocol state.
331    ///
332    /// Call this after receiving a message to validate it against the
333    /// protocol state machine and transition to the next state.
334    ///
335    /// # Arguments
336    /// * `msg_type` - The type name of the received message
337    /// * `role` - The role this agent plays in the protocol
338    ///
339    /// # Errors
340    ///
341    /// Returns an error if the message violates the protocol.
342    pub async fn validate_protocol_receive(
343        &mut self,
344        msg_type: &str,
345        role: &str,
346    ) -> SageResult<()> {
347        let current = match &self.current_message {
348            Some(msg) => msg,
349            None => return Ok(()), // No current message, nothing to validate
350        };
351
352        // If message has a session, validate protocol state
353        if let Some(session_id) = current.session_id {
354            let mut registry = self.session_registry.write().await;
355            if let Some(session) = registry.get_mut(&session_id) {
356                // Validate that we can receive this message type in our role
357                if !session.state.can_receive(msg_type, role) {
358                    return Err(SageError::from(ProtocolViolation::UnexpectedMessage {
359                        protocol: session.protocol.clone(),
360                        expected: "valid message for current state".to_string(),
361                        received: msg_type.to_string(),
362                        state: session.state.state_name().to_string(),
363                    }));
364                }
365                // Transition the state machine
366                session.state.transition(msg_type)?;
367
368                // If protocol is complete, remove the session
369                if session.state.is_terminal() {
370                    drop(registry);
371                    self.session_registry.write().await.remove(&session_id);
372                }
373            }
374        }
375
376        Ok(())
377    }
378
379    /// Phase 3: Start a new protocol session.
380    ///
381    /// Call this when initiating a protocol exchange with another agent.
382    ///
383    /// # Arguments
384    /// * `protocol` - The protocol name
385    /// * `role` - The role this agent plays
386    /// * `state` - The initial state machine for this protocol
387    /// * `partner` - Handle to send messages to the partner
388    ///
389    /// # Returns
390    /// The session ID for tracking this protocol session.
391    pub async fn start_session(
392        &self,
393        protocol: String,
394        role: String,
395        state: Box<dyn crate::session::ProtocolStateMachine>,
396        partner: SenderHandle,
397    ) -> SessionId {
398        let mut registry = self.session_registry.write().await;
399        let session_id = registry.next_id();
400        registry.start_session(session_id, protocol, role, state, partner);
401        session_id
402    }
403
404    /// Get the current message being handled (if any).
405    #[must_use]
406    pub fn current_message(&self) -> Option<&Message> {
407        self.current_message.as_ref()
408    }
409}
410
411/// Spawn an agent and return a handle to it.
412///
413/// The agent will run asynchronously in a separate task.
414pub fn spawn<A, T, F>(agent: A) -> AgentHandle<T>
415where
416    A: FnOnce(AgentContext<T>) -> F + Send + 'static,
417    F: Future<Output = SageResult<T>> + Send,
418    T: Send + 'static,
419{
420    spawn_with_llm_config(agent, crate::llm::LlmConfig::from_env())
421}
422
423/// Spawn an agent with a custom LLM configuration.
424///
425/// This is used by effect handlers to configure per-agent LLM settings.
426pub fn spawn_with_llm_config<A, T, F>(agent: A, llm_config: crate::llm::LlmConfig) -> AgentHandle<T>
427where
428    A: FnOnce(AgentContext<T>) -> F + Send + 'static,
429    F: Future<Output = SageResult<T>> + Send,
430    T: Send + 'static,
431{
432    let (result_tx, result_rx) = oneshot::channel();
433    let (message_tx, message_rx) = mpsc::channel(32);
434
435    let llm = LlmClient::new(llm_config);
436    let session_registry = crate::session::shared_registry();
437    let ctx = AgentContext::new(llm, result_tx, message_rx, session_registry);
438
439    let join = tokio::spawn(async move { agent(ctx).await });
440
441    // We need to handle the result_rx somewhere, but for now we just let
442    // the result come from the JoinHandle
443    drop(result_rx);
444
445    AgentHandle { join, message_tx }
446}
447
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    #[tokio::test]
    async fn spawn_simple_agent() {
        let handle = spawn(|mut ctx: AgentContext<i64>| async move { ctx.emit(42) });

        let result = handle.result().await.expect("agent should succeed");
        assert_eq!(result, 42);
    }

    #[tokio::test]
    async fn spawn_agent_with_computation() {
        let handle = spawn(|mut ctx: AgentContext<i64>| async move {
            let sum = (1..=10).sum();
            ctx.emit(sum)
        });

        let result = handle.result().await.expect("agent should succeed");
        assert_eq!(result, 55);
    }

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TaskMessage {
        id: u32,
        content: String,
    }

    #[tokio::test]
    async fn agent_receives_message() {
        let handle = spawn(|mut ctx: AgentContext<String>| async move {
            let msg: TaskMessage = ctx.receive().await?;
            ctx.emit(format!("Got task {}: {}", msg.id, msg.content))
        });

        handle
            .send(TaskMessage {
                id: 42,
                content: "Hello".to_string(),
            })
            .await
            .expect("send should succeed");

        let result = handle.result().await.expect("agent should succeed");
        assert_eq!(result, "Got task 42: Hello");
    }

    #[tokio::test]
    async fn agent_receives_multiple_messages() {
        let handle = spawn(|mut ctx: AgentContext<i32>| async move {
            let mut sum = 0;
            for _ in 0..3 {
                let n: i32 = ctx.receive().await?;
                sum += n;
            }
            ctx.emit(sum)
        });

        for n in [10, 20, 30] {
            handle.send(n).await.expect("send should succeed");
        }

        let result = handle.result().await.expect("agent should succeed");
        assert_eq!(result, 60);
    }

    #[tokio::test]
    async fn agent_receive_timeout() {
        let handle = spawn(|mut ctx: AgentContext<String>| async move {
            let result: Option<i32> = ctx
                .receive_timeout(std::time::Duration::from_millis(10))
                .await?;
            match result {
                Some(n) => ctx.emit(format!("Got {n}")),
                None => ctx.emit("Timeout".to_string()),
            }
        });

        // Don't send anything, let it timeout
        let result = handle.result().await.expect("agent should succeed");
        assert_eq!(result, "Timeout");
    }

    /// `reply()` must fail when no message is currently being handled.
    #[tokio::test]
    async fn reply_outside_handler_is_rejected() {
        let handle = spawn(|mut ctx: AgentContext<bool>| async move {
            let outcome = ctx.reply("unsolicited").await;
            ctx.emit(outcome.is_err())
        });

        let rejected = handle.result().await.expect("agent should succeed");
        assert!(rejected, "reply() with no current message must fail");
    }

    /// `receive_raw()` returns the full envelope and records it as current.
    #[tokio::test]
    async fn receive_raw_records_current_message() {
        let handle = spawn(|mut ctx: AgentContext<bool>| async move {
            let raw = ctx.receive_raw().await?;
            let tracked = match ctx.current_message() {
                Some(current) => current.payload == raw.payload,
                None => false,
            };
            let plain = raw.session_id.is_none() && raw.payload == serde_json::Value::from(7);
            ctx.emit(tracked && plain)
        });

        handle.send(7).await.expect("send should succeed");

        let ok = handle.result().await.expect("agent should succeed");
        assert!(ok);
    }

    /// A payload that fails to deserialize is still recorded as the current
    /// message, so handlers can inspect or answer it.
    #[tokio::test]
    async fn failed_receive_still_records_message() {
        let handle = spawn(|mut ctx: AgentContext<bool>| async move {
            let decoded = ctx.receive::<TaskMessage>().await;
            let recorded = ctx.current_message().is_some();
            ctx.emit(decoded.is_err() && recorded)
        });

        handle
            .send("not a task")
            .await
            .expect("send should succeed");

        let ok = handle.result().await.expect("agent should succeed");
        assert!(ok);
    }
}