Skip to main content

unified_agent_sdk/
session.rs

1//! Session handles, metadata, and event-stream construction.
2//!
3//! [`AgentSession`] is a lightweight value returned by [`crate::executor::AgentExecutor`].
4//! It stores stable metadata and exposes a raw-log to unified-event pipeline.
5//! The event pipeline automatically emits `SessionStarted` and `SessionCompleted`
6//! around the normalized provider events.
7
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use futures::{Stream, StreamExt, stream};
11use std::collections::VecDeque;
12use std::path::PathBuf;
13use std::pin::Pin;
14use std::sync::Arc;
15
16use crate::{
17    error::Result,
18    event::{AgentEvent, EventStream, HookManager, converter},
19    log::LogNormalizer,
20    types::{ExecutorType, ExitStatus},
21};
22
/// Raw byte stream emitted by an executor process or test fixture.
///
/// Each provider normalizer decides how to split and buffer these chunks, so a
/// single item is not required to line up with a complete log record.
/// The stream is boxed and pinned so it can be stored as a struct field, and it
/// is `Send` so the owning pipeline can be moved across threads.
pub type RawLogStream = Pin<Box<dyn Stream<Item = Vec<u8>> + Send>>;
27
/// Serializable metadata snapshot for persistence, resume bookkeeping, or diagnostics.
///
/// The fields mirror the public data fields of [`AgentSession`]; a snapshot can be
/// turned back into a detached handle with [`AgentSession::from_metadata`].
///
/// NOTE(review): only `Debug` and `Clone` are derived here. If callers rely on
/// this type being serializable, the serde support must come from elsewhere —
/// confirm.
#[derive(Debug, Clone)]
pub struct SessionMetadata {
    /// Executor session identifier.
    pub session_id: String,
    /// Executor backend type.
    pub executor_type: ExecutorType,
    /// Metadata creation timestamp.
    pub created_at: DateTime<Utc>,
    /// Last known source message id, if available.
    pub last_message_id: Option<String>,
    /// Session working directory.
    pub working_dir: PathBuf,
    /// Optional context window capacity override passed at session creation.
    pub context_window_override_tokens: Option<u32>,
}
44
/// Resume descriptor used by higher-level orchestrators or persistence layers.
///
/// Nothing in this module constructs or consumes this type; it is plain data
/// intended for callers that need to describe which session to continue.
#[derive(Debug, Clone)]
pub struct SessionResume {
    /// Existing session identifier.
    pub session_id: String,
    /// Optional message id used for rewind/reset semantics.
    pub reset_to_message: Option<String>,
}
53
54/// Active agent session handle.
55///
56/// This value is intentionally lightweight and provider-agnostic. It stores stable
57/// metadata and a lifecycle controller abstraction, but does not expose provider-
58/// specific client internals.
59pub struct AgentSession {
60    /// Executor session identifier.
61    pub session_id: String,
62    /// Executor backend type.
63    pub executor_type: ExecutorType,
64    /// Working directory used by this session.
65    pub working_dir: PathBuf,
66    /// Session creation timestamp captured when the session is established.
67    pub created_at: DateTime<Utc>,
68    /// Last known source message id, if available.
69    pub last_message_id: Option<String>,
70    /// Optional context window capacity override for context usage normalization.
71    pub context_window_override_tokens: Option<u32>,
72    lifecycle_controller: SessionControllerRef,
73}
74
75impl AgentSession {
76    /// Creates a detached session handle with default lifecycle behavior.
77    ///
78    /// Detached sessions treat `wait` as immediately successful and `cancel` as
79    /// a no-op.
80    pub fn new(
81        session_id: impl Into<String>,
82        executor_type: ExecutorType,
83        working_dir: impl Into<PathBuf>,
84        context_window_override_tokens: Option<u32>,
85    ) -> Self {
86        Self::from_parts(
87            SessionMetadata {
88                session_id: session_id.into(),
89                executor_type,
90                created_at: Utc::now(),
91                last_message_id: None,
92                working_dir: working_dir.into(),
93                context_window_override_tokens,
94            },
95            Arc::new(DetachedSessionLifecycleController),
96        )
97    }
98
99    /// Restores a detached session from persisted metadata.
100    pub fn from_metadata(metadata: SessionMetadata) -> Self {
101        Self::from_parts(metadata, Arc::new(DetachedSessionLifecycleController))
102    }
103
104    pub(crate) fn from_metadata_with_exit_status(
105        metadata: SessionMetadata,
106        exit_status: ExitStatus,
107    ) -> Self {
108        Self::from_parts(
109            metadata,
110            Arc::new(CompletedSessionLifecycleController { exit_status }),
111        )
112    }
113
114    fn from_parts(metadata: SessionMetadata, lifecycle_controller: SessionControllerRef) -> Self {
115        Self {
116            session_id: metadata.session_id,
117            executor_type: metadata.executor_type,
118            working_dir: metadata.working_dir,
119            created_at: metadata.created_at,
120            last_message_id: metadata.last_message_id,
121            context_window_override_tokens: metadata.context_window_override_tokens,
122            lifecycle_controller,
123        }
124    }
125
    /// Build an event stream pipeline:
    /// raw logs -> normalized logs -> unified events.
    ///
    /// The returned stream yields, in order:
    ///
    /// 1. one `SessionStarted` carrying this session's id,
    /// 2. the events converted from the normalizer's output for each raw chunk,
    /// 3. the events converted from the normalizer's final `flush()` once the raw
    ///    stream ends,
    /// 4. one `SessionCompleted`, after which the stream terminates.
    ///
    /// The `SessionCompleted` exit status always has `code: None`; `success` is
    /// `false` if any `ErrorOccurred` event was queued earlier in the stream and
    /// `true` otherwise.
    ///
    /// The stream does not borrow `self`: the session id and the context window
    /// override are copied into the pipeline state up front.
    ///
    /// Hooks are triggered for each emitted event when provided. A hook is
    /// awaited right before its event is handed to the consumer.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use futures::stream;
    /// use std::path::PathBuf;
    /// use unified_agent_sdk::{AgentSession, CodexLogNormalizer, ExecutorType, session::RawLogStream};
    ///
    /// let session = AgentSession::new("s1", ExecutorType::Codex, PathBuf::from("."), None);
    ///
    /// let raw_logs: RawLogStream = Box::pin(stream::iter(vec![
    ///     br#"{"type":"item.completed","item":{"type":"agent_message","id":"m1","text":"hello"}}"#
    ///         .to_vec(),
    ///     b"\n".to_vec(),
    /// ]));
    ///
    /// let _events = session.event_stream(raw_logs, Box::new(CodexLogNormalizer::new()), None);
    /// ```
    pub fn event_stream(
        &self,
        raw_logs: RawLogStream,
        normalizer: Box<dyn LogNormalizer + Send>,
        hooks: Option<Arc<HookManager>>,
    ) -> EventStream {
        // Everything the pipeline needs is owned by this state value, which
        // `unfold` threads from one poll to the next.
        let state = EventPipelineState {
            session_id: self.session_id.clone(),
            raw_logs,
            normalizer,
            hooks,
            pending_events: VecDeque::new(),
            emitted_started: false,
            finished: false,
            saw_error: false,
            context_window_override_tokens: self.context_window_override_tokens,
        };

        let stream = stream::unfold(state, |mut state| async move {
            // Each pass through the loop either yields one queued event or
            // refills the queue. One raw chunk may produce zero or many events,
            // hence the loop rather than a single step.
            loop {
                // Drain queued events first so ordering is strictly FIFO.
                if let Some(event) = state.pending_events.pop_front() {
                    if let Some(hook_manager) = &state.hooks {
                        hook_manager.trigger(&event).await;
                    }
                    return Some((event, state));
                }

                // First poll: queue `SessionStarted` and loop back to emit it
                // through the same path (so hooks see it too).
                if !state.emitted_started {
                    state.emitted_started = true;
                    state.push_event(AgentEvent::SessionStarted {
                        session_id: state.session_id.clone(),
                    });
                    continue;
                }

                // Queue is empty and completion was already emitted: end of stream.
                if state.finished {
                    return None;
                }

                match state.raw_logs.next().await {
                    Some(chunk) => {
                        let logs = state.normalizer.normalize(&chunk);
                        state.push_logs(logs);
                    }
                    None => {
                        // Raw input is exhausted: give the normalizer a chance
                        // to emit whatever it still holds before completing.
                        let logs = state.normalizer.flush();
                        state.push_logs(logs);
                        // `saw_error` is read after the flushed logs were
                        // queued, so errors surfaced by `flush` count as well.
                        state.push_event(AgentEvent::SessionCompleted {
                            exit_status: ExitStatus {
                                code: None,
                                success: !state.saw_error,
                            },
                        });
                        state.finished = true;
                    }
                }
            }
        });

        EventStream::new(Box::pin(stream))
    }
209
210    /// Returns an immutable metadata snapshot for persistence or logging.
211    pub fn metadata(&self) -> SessionMetadata {
212        SessionMetadata {
213            session_id: self.session_id.clone(),
214            executor_type: self.executor_type,
215            created_at: self.created_at,
216            last_message_id: self.last_message_id.clone(),
217            working_dir: self.working_dir.clone(),
218            context_window_override_tokens: self.context_window_override_tokens,
219        }
220    }
221
222    /// Waits for session completion and returns a summarized exit status.
223    ///
224    /// Detached sessions created from metadata return immediately because there is
225    /// no live provider process attached to them.
226    pub async fn wait(&mut self) -> Result<ExitStatus> {
227        self.lifecycle_controller.wait().await
228    }
229
230    /// Requests cancellation of the active session.
231    ///
232    /// Detached sessions treat cancellation as a no-op.
233    pub async fn cancel(&mut self) -> Result<()> {
234        self.lifecycle_controller.cancel().await
235    }
236}
237
/// Crate-internal abstraction over how a session is awaited and cancelled.
///
/// [`AgentSession`] forwards its `wait` and `cancel` calls to an implementation
/// of this trait held behind an `Arc`. Both methods take `&self`, and
/// implementations must be `Send + Sync`.
#[async_trait]
pub(crate) trait SessionLifecycleController: Send + Sync {
    /// Resolves with the session's exit status.
    async fn wait(&self) -> Result<ExitStatus>;
    /// Requests cancellation of the session.
    async fn cancel(&self) -> Result<()>;
}
243
244struct DetachedSessionLifecycleController;
245
246#[async_trait]
247impl SessionLifecycleController for DetachedSessionLifecycleController {
248    async fn wait(&self) -> Result<ExitStatus> {
249        Ok(ExitStatus {
250            code: None,
251            success: true,
252        })
253    }
254
255    async fn cancel(&self) -> Result<()> {
256        Ok(())
257    }
258}
259
260struct CompletedSessionLifecycleController {
261    exit_status: ExitStatus,
262}
263
264#[async_trait]
265impl SessionLifecycleController for CompletedSessionLifecycleController {
266    async fn wait(&self) -> Result<ExitStatus> {
267        Ok(self.exit_status)
268    }
269
270    async fn cancel(&self) -> Result<()> {
271        Ok(())
272    }
273}
274
/// Shared, type-erased handle to a session's lifecycle controller.
type SessionControllerRef = Arc<dyn SessionLifecycleController>;
276
/// Owned state threaded through the `stream::unfold` loop built by
/// `AgentSession::event_stream`.
struct EventPipelineState {
    /// Session id reported in the `SessionStarted` event.
    session_id: String,
    /// Source of raw byte chunks.
    raw_logs: RawLogStream,
    /// Turns raw chunks into normalized logs; flushed once when `raw_logs` ends.
    normalizer: Box<dyn LogNormalizer + Send>,
    /// Optional hooks triggered for every event just before it is yielded.
    hooks: Option<Arc<HookManager>>,
    /// FIFO queue of events that are ready but not yet yielded.
    pending_events: VecDeque<AgentEvent>,
    /// Set once `SessionStarted` has been queued.
    emitted_started: bool,
    /// Set once `SessionCompleted` has been queued; the stream ends when the
    /// queue drains afterwards.
    finished: bool,
    /// Set when any `ErrorOccurred` event has been queued; drives the final
    /// `success` flag.
    saw_error: bool,
    /// Context window override forwarded to the log-to-event converter.
    context_window_override_tokens: Option<u32>,
}
288
289impl EventPipelineState {
290    fn push_logs(&mut self, logs: Vec<crate::log::NormalizedLog>) {
291        for log in logs {
292            for event in converter::from_normalized_log_with_context_override(
293                log,
294                self.context_window_override_tokens,
295            ) {
296                self.push_event(event);
297            }
298        }
299    }
300
301    fn push_event(&mut self, event: AgentEvent) {
302        if matches!(event, AgentEvent::ErrorOccurred { .. }) {
303            self.saw_error = true;
304        }
305        self.pending_events.push_back(event);
306    }
307}
308
309#[cfg(test)]
310mod tests {
311    use super::*;
312    use async_trait::async_trait;
313    use futures::{StreamExt, stream};
314    use serde_json::json;
315    use std::sync::atomic::{AtomicBool, Ordering};
316    use tokio::sync::Mutex;
317
318    use crate::{
319        event::EventType,
320        log::{ActionType, NormalizedLog},
321        types::{ContextUsageSource, Role, ToolStatus},
322    };
323
324    struct TestNormalizer;
325
326    impl LogNormalizer for TestNormalizer {
327        fn normalize(&mut self, chunk: &[u8]) -> Vec<NormalizedLog> {
328            match chunk {
329                b"message" => vec![NormalizedLog::Message {
330                    role: Role::Assistant,
331                    content: "hello".to_string(),
332                }],
333                b"tool-start" => vec![NormalizedLog::ToolCall {
334                    name: "bash".to_string(),
335                    args: json!({"cmd":"ls"}),
336                    status: ToolStatus::Started,
337                    action: ActionType::CommandRun {
338                        command: "ls".to_string(),
339                    },
340                }],
341                b"tool-done" => vec![NormalizedLog::ToolCall {
342                    name: "bash".to_string(),
343                    args: json!({"cmd":"ls"}),
344                    status: ToolStatus::Completed,
345                    action: ActionType::CommandRun {
346                        command: "ls".to_string(),
347                    },
348                }],
349                b"error" => vec![NormalizedLog::Error {
350                    error_type: "execution_failed".to_string(),
351                    message: "boom".to_string(),
352                }],
353                _ => Vec::new(),
354            }
355        }
356
357        fn flush(&mut self) -> Vec<NormalizedLog> {
358            vec![NormalizedLog::TokenUsage {
359                total: 10,
360                limit: 100,
361            }]
362        }
363    }
364
    /// End-to-end check of the pipeline: lifecycle events bracket the stream,
    /// normalized logs become unified events, the flushed token usage is
    /// reported, and a registered hook fires for the message event.
    #[tokio::test]
    async fn session_event_stream_builds_pipeline_and_triggers_hooks() {
        let session = AgentSession::new("session-1", ExecutorType::Codex, PathBuf::from("."), None);

        // Collects the content of every message the hook is invoked with.
        let received_messages = Arc::new(Mutex::new(Vec::<String>::new()));
        let hooks = Arc::new(HookManager::new());
        hooks.register(
            EventType::MessageReceived,
            Arc::new({
                let received_messages = Arc::clone(&received_messages);
                move |event| {
                    let received_messages = Arc::clone(&received_messages);
                    // Extract the content synchronously so the returned future
                    // does not borrow the event.
                    let content = match event {
                        AgentEvent::MessageReceived { content, .. } => Some(content.clone()),
                        _ => None,
                    };
                    Box::pin(async move {
                        if let Some(content) = content {
                            received_messages.lock().await.push(content);
                        }
                    })
                }
            }),
        );

        let raw_logs: RawLogStream = Box::pin(stream::iter(vec![
            b"message".to_vec(),
            b"tool-start".to_vec(),
            b"tool-done".to_vec(),
        ]));

        let events = session
            .event_stream(raw_logs, Box::new(TestNormalizer), Some(hooks))
            .collect::<Vec<_>>()
            .await;

        // The stream must open with `SessionStarted` for this session.
        assert!(matches!(
            events.first(),
            Some(AgentEvent::SessionStarted { session_id }) if session_id == "session-1"
        ));
        assert!(events
            .iter()
            .any(|event| matches!(event, AgentEvent::MessageReceived { content, .. } if content == "hello")));
        assert!(events.iter().any(
            |event| matches!(event, AgentEvent::ToolCallStarted { tool, .. } if tool == "bash")
        ));
        assert!(events.iter().any(
            |event| matches!(event, AgentEvent::ToolCallCompleted { tool, .. } if tool == "bash")
        ));
        // Token usage comes from `TestNormalizer::flush` (10 of 100).
        assert!(events.iter().any(|event| matches!(
            event,
            AgentEvent::ContextUsageUpdated { usage }
                if usage.used_tokens == 10
                    && usage.window_tokens == Some(100)
                    && usage.remaining_tokens == Some(90)
                    && usage.source == ContextUsageSource::ProviderReported
        )));
        // No error chunk was fed in, so completion must be successful.
        assert!(matches!(
            events.last(),
            Some(AgentEvent::SessionCompleted { exit_status }) if exit_status.success
        ));

        // The hook must have seen exactly the one message event.
        let captured = received_messages.lock().await.clone();
        assert_eq!(captured, vec!["hello".to_string()]);
    }
430
431    #[tokio::test]
432    async fn session_event_stream_marks_completion_as_failed_when_errors_seen() {
433        let session = AgentSession::new(
434            "session-2",
435            ExecutorType::ClaudeCode,
436            PathBuf::from("."),
437            None,
438        );
439
440        let raw_logs: RawLogStream = Box::pin(stream::iter(vec![b"error".to_vec()]));
441        let events = session
442            .event_stream(raw_logs, Box::new(TestNormalizer), None)
443            .collect::<Vec<_>>()
444            .await;
445
446        assert!(events.iter().any(
447            |event| matches!(event, AgentEvent::ErrorOccurred { error } if error.contains("boom"))
448        ));
449        assert!(matches!(
450            events.last(),
451            Some(AgentEvent::SessionCompleted { exit_status }) if !exit_status.success
452        ));
453    }
454
455    struct UnknownLimitNormalizer;
456
457    impl LogNormalizer for UnknownLimitNormalizer {
458        fn normalize(&mut self, _chunk: &[u8]) -> Vec<NormalizedLog> {
459            Vec::new()
460        }
461
462        fn flush(&mut self) -> Vec<NormalizedLog> {
463            vec![NormalizedLog::TokenUsage {
464                total: 15,
465                limit: 0,
466            }]
467        }
468    }
469
470    #[tokio::test]
471    async fn session_event_stream_applies_context_window_override() {
472        let session = AgentSession::new(
473            "session-3",
474            ExecutorType::Codex,
475            PathBuf::from("."),
476            Some(60),
477        );
478
479        let raw_logs: RawLogStream = Box::pin(stream::iter(Vec::<Vec<u8>>::new()));
480        let events = session
481            .event_stream(raw_logs, Box::new(UnknownLimitNormalizer), None)
482            .collect::<Vec<_>>()
483            .await;
484
485        assert!(events.iter().any(|event| matches!(
486            event,
487            AgentEvent::ContextUsageUpdated { usage }
488                if usage.used_tokens == 15
489                    && usage.window_tokens == Some(60)
490                    && usage.remaining_tokens == Some(45)
491                    && usage.source == ContextUsageSource::ConfigOverride
492        )));
493    }
494
495    #[tokio::test]
496    async fn wait_defaults_to_completed_success_when_unmanaged() {
497        let mut session = AgentSession::new(
498            "session-unmanaged",
499            ExecutorType::Codex,
500            PathBuf::from("."),
501            None,
502        );
503
504        let exit_status = session.wait().await.expect("wait should succeed");
505        assert_eq!(
506            exit_status,
507            ExitStatus {
508                code: None,
509                success: true
510            }
511        );
512    }
513
514    #[tokio::test]
515    async fn wait_uses_session_lifecycle_controller() {
516        let mut session = AgentSession::from_metadata_with_exit_status(
517            SessionMetadata {
518                session_id: "session-managed".to_string(),
519                executor_type: ExecutorType::ClaudeCode,
520                created_at: Utc::now(),
521                last_message_id: None,
522                working_dir: PathBuf::from("."),
523                context_window_override_tokens: None,
524            },
525            ExitStatus {
526                code: Some(17),
527                success: false,
528            },
529        );
530
531        let first = session.wait().await.expect("wait should use controller");
532        assert_eq!(
533            first,
534            ExitStatus {
535                code: Some(17),
536                success: false
537            }
538        );
539
540        let second = session
541            .wait()
542            .await
543            .expect("second wait should remain stable");
544        assert_eq!(
545            second,
546            ExitStatus {
547                code: Some(17),
548                success: false
549            }
550        );
551    }
552
553    struct CancelProbeController {
554        cancelled: Arc<AtomicBool>,
555    }
556
557    #[async_trait]
558    impl SessionLifecycleController for CancelProbeController {
559        async fn wait(&self) -> Result<ExitStatus> {
560            Ok(ExitStatus {
561                code: None,
562                success: true,
563            })
564        }
565
566        async fn cancel(&self) -> Result<()> {
567            self.cancelled.store(true, Ordering::Relaxed);
568            Ok(())
569        }
570    }
571
572    #[tokio::test]
573    async fn cancel_delegates_to_registered_lifecycle_controller() {
574        let session_id = "session-cancel".to_string();
575        let cancelled = Arc::new(AtomicBool::new(false));
576
577        let mut session = AgentSession::from_parts(
578            SessionMetadata {
579                session_id,
580                executor_type: ExecutorType::ClaudeCode,
581                created_at: Utc::now(),
582                last_message_id: None,
583                working_dir: PathBuf::from("."),
584                context_window_override_tokens: None,
585            },
586            Arc::new(CancelProbeController {
587                cancelled: Arc::clone(&cancelled),
588            }),
589        );
590
591        session.cancel().await.expect("cancel should succeed");
592        assert!(cancelled.load(Ordering::Relaxed));
593    }
594}