Skip to main content

stakpak_server/
session_actor.rs

1use crate::{
2    context::{ContextFile, EnvironmentContext, ProjectContext, SessionContextBuilder},
3    message_bridge,
4    sandbox::{SandboxConfig, SandboxMode, SandboxedMcpServer},
5    state::AppState,
6    types::{RunConfig, SessionHandle},
7};
8use async_trait::async_trait;
9use rmcp::model::{
10    CallToolRequestParam, CancelledNotification, CancelledNotificationMethod,
11    CancelledNotificationParam, ServerResult,
12};
13use serde_json::json;
14use stakai::{ContentPart, Message, MessageContent, Role};
15use stakpak_agent_core::{
16    AgentCommand, AgentConfig, AgentEvent, AgentHook, AgentRunContext, BudgetAwareContextReducer,
17    CheckpointEnvelopeV1, CompactionConfig, PassthroughCompactionEngine, ProposedToolCall,
18    RetryConfig, ToolExecutionResult, ToolExecutor, run_agent,
19};
20use stakpak_api::CreateCheckpointRequest;
21use stakpak_mcp_client::McpClient;
22use stakpak_shared::utils::sanitize_text_output;
23use std::{path::Path, sync::Arc};
24use tokio::sync::{Mutex, mpsc};
25use tokio_util::sync::CancellationToken;
26use uuid::Uuid;
27
/// Period of the background task that flushes the in-memory checkpoint state while an agent
/// run is in progress.
const CHECKPOINT_FLUSH_INTERVAL: std::time::Duration = std::time::Duration::from_secs(5);
/// Key under which the active model (written as `provider/id`) is stored in the metadata of a
/// persisted checkpoint envelope.
pub(crate) const ACTIVE_MODEL_METADATA_KEY: &str = "active_model";
30
31pub fn build_run_context(session_id: Uuid, run_id: Uuid) -> AgentRunContext {
32    AgentRunContext { run_id, session_id }
33}
34
35pub fn build_checkpoint_envelope(
36    run_id: Uuid,
37    messages: Vec<stakai::Message>,
38    metadata: serde_json::Value,
39) -> CheckpointEnvelopeV1 {
40    CheckpointEnvelopeV1::new(Some(run_id), messages, metadata)
41}
42
43pub fn spawn_session_actor(
44    state: AppState,
45    session_id: Uuid,
46    run_id: Uuid,
47    run_config: RunConfig,
48    user_message: Message,
49    caller_context: Vec<ContextFile>,
50    sandbox_config: Option<SandboxConfig>,
51) -> Result<SessionHandle, String> {
52    let (command_tx, command_rx) = mpsc::channel(128);
53    let cancel = CancellationToken::new();
54
55    let handle = SessionHandle::new(command_tx, cancel.clone());
56
57    let state_for_task = state.clone();
58    tokio::spawn(async move {
59        let actor_result = run_session_actor(
60            state_for_task.clone(),
61            session_id,
62            run_id,
63            run_config,
64            user_message,
65            caller_context,
66            command_rx,
67            cancel,
68            sandbox_config,
69        )
70        .await;
71
72        let finish_result = actor_result.map(|_| ());
73        let _ = state_for_task
74            .run_manager
75            .mark_run_finished(session_id, run_id, finish_result)
76            .await;
77    });
78
79    Ok(handle)
80}
81
/// Drives a single agent run for `session_id` from start to finish.
///
/// Steps, in order:
/// 1. Load prior conversation state: the latest checkpoint envelope if one exists, otherwise
///    the session's active checkpoint converted via `message_bridge::chat_to_stakai`, otherwise
///    an empty history with `{}` metadata.
/// 2. Pick the tool list and tool executor (persistent sandbox, ephemeral sandbox, or the
///    server's own MCP client) from `sandbox_config` and `state.persistent_sandbox`.
/// 3. Build the session context and, for a new session or when `caller_context` is non-empty,
///    prepend the user-context block to `user_message`.
/// 4. Persist a baseline checkpoint, then start the periodic checkpoint flush task and the
///    event forwarder task.
/// 5. Run the agent; afterwards stop the periodic task, shut down any ephemeral sandbox, clear
///    pending tools and persist a terminal checkpoint.
///
/// # Errors
///
/// Returns a `String` error when the checkpoint envelope cannot be loaded, when persistent
/// sandbox mode is configured but no persistent sandbox exists, when the ephemeral sandbox
/// fails to start, when the baseline checkpoint (or, after a successful run, the terminal
/// checkpoint) cannot be persisted, or when the agent run itself fails.
// The argument list mirrors `spawn_session_actor` plus the channel/cancellation plumbing.
#[allow(clippy::too_many_arguments)]
async fn run_session_actor(
    state: AppState,
    session_id: Uuid,
    run_id: Uuid,
    run_config: RunConfig,
    mut user_message: Message,
    caller_context: Vec<ContextFile>,
    command_rx: mpsc::Receiver<AgentCommand>,
    cancel: CancellationToken,
    sandbox_config: Option<SandboxConfig>,
) -> Result<(), String> {
    // A failed lookup is treated the same as "no active checkpoint".
    let active_checkpoint = state
        .session_store
        .get_active_checkpoint(session_id)
        .await
        .ok();
    let parent_checkpoint_id = active_checkpoint.as_ref().map(|checkpoint| checkpoint.id);

    // Prefer the latest checkpoint envelope; fall back to the active checkpoint's chat
    // messages and metadata when no envelope has been saved yet.
    let (initial_messages, mut initial_metadata) =
        match state.checkpoint_store.load_latest(session_id).await {
            Ok(Some(envelope)) => (envelope.messages, envelope.metadata),
            Ok(None) => {
                let messages = active_checkpoint
                    .as_ref()
                    .map(|checkpoint| {
                        message_bridge::chat_to_stakai(checkpoint.state.messages.clone())
                    })
                    .unwrap_or_default();
                let metadata = active_checkpoint
                    .as_ref()
                    .and_then(|checkpoint| checkpoint.state.metadata.clone())
                    .unwrap_or_else(|| json!({}));
                (messages, metadata)
            }
            Err(error) => {
                return Err(format!("Failed to load checkpoint envelope: {error}"));
            }
        };

    // If sandbox is requested, determine how to provide it based on the configured mode:
    // - Persistent: reuse the pre-spawned sandbox from AppState (no per-session overhead)
    // - Ephemeral: spawn a new sandbox container for this session
    //
    // `ephemeral_sandbox` holds the owned sandbox for ephemeral mode so we can
    // shut it down at the end. Persistent sandboxes are not owned by the session.
    let mut ephemeral_sandbox: Option<SandboxedMcpServer> = None;

    let (run_tools, tool_executor): (Vec<stakai::Tool>, Box<dyn ToolExecutor + Send + Sync>) =
        if let Some(ref sandbox_cfg) = sandbox_config {
            if let Some(ref persistent) = state.persistent_sandbox {
                // Persistent mode: reuse the pre-spawned sandbox
                tracing::info!(session_id = %session_id, "Using persistent sandbox for session");
                (
                    persistent.tools().await,
                    Box::new(SandboxedToolExecutor {
                        mcp_client: persistent.client().await,
                    }),
                )
            } else if sandbox_cfg.mode == SandboxMode::Persistent {
                // Persistent mode was configured but the sandbox is not available.
                // This should not happen because the server hard-fails on startup
                // if the persistent sandbox cannot be spawned. Fail explicitly rather
                // than silently falling back to ephemeral mode.
                return Err(format!(
                    "Sandbox mode is 'persistent' but no persistent sandbox is available for session {session_id}. \
                     This indicates the server started without a healthy sandbox. Restart the autopilot to fix."
                ));
            } else {
                // Ephemeral mode: spawn a new sandbox for this session
                tracing::info!(session_id = %session_id, image = %sandbox_cfg.image, "Spawning ephemeral sandbox container for session");
                let sandbox = SandboxedMcpServer::spawn(sandbox_cfg).await.map_err(|e| {
                    format!("Failed to start sandbox for session {session_id}: {e}")
                })?;
                let tools = sandbox.tools.clone();
                let client = sandbox.client.clone();
                ephemeral_sandbox = Some(sandbox);
                (
                    tools,
                    Box::new(SandboxedToolExecutor { mcp_client: client }),
                )
            }
        } else {
            // No sandbox requested: use the server's own MCP tools and client.
            (
                state.current_mcp_tools().await,
                Box::new(ServerToolExecutor {
                    state: state.clone(),
                }),
            )
        };

    let is_new_session = is_new_session_history(&initial_messages);
    let session_cwd = resolve_session_cwd(&state, session_id).await;
    let environment = EnvironmentContext::snapshot(&session_cwd).await;

    // Combine caller context with pre-loaded remote skills context from AppState.
    // Explicit caller context should force per-turn injection, even on resumed
    // sessions, while startup remote skills remain baseline context.
    let has_runtime_caller_context = !caller_context.is_empty();
    let mut all_caller_context = caller_context;
    all_caller_context.extend(state.current_skills().await);

    let project =
        ProjectContext::discover(Path::new(&session_cwd)).with_caller_context(all_caller_context);

    // A per-run system prompt overrides the server-wide base prompt; with neither set the
    // prompt is empty.
    let session_context = SessionContextBuilder::new()
        .base_system_prompt(
            run_config
                .system_prompt
                .clone()
                .or_else(|| state.base_system_prompt.clone())
                .unwrap_or_default(),
        )
        .environment(environment)
        .project(project)
        .tools(&run_tools)
        .budget(state.context_budget.clone())
        .build();

    if (is_new_session || has_runtime_caller_context)
        && let Some(context_block) = session_context.user_context_block.as_deref()
    {
        user_message = prepend_context_to_user_message(user_message, context_block);
    }

    // The baseline snapshot is the prior history plus the (possibly context-prefixed) user
    // message, persisted before the agent starts.
    let mut baseline_messages = initial_messages.clone();
    baseline_messages.push(user_message.clone());

    let checkpoint_runtime = Arc::new(CheckpointRuntime::new(
        state.clone(),
        session_id,
        run_id,
        run_config.model.clone(),
        parent_checkpoint_id,
        baseline_messages,
        initial_metadata.clone(),
    ));

    // NOTE(review): returning early here skips the explicit `shutdown()` further down for an
    // ephemeral sandbox spawned above — confirm `SandboxedMcpServer` cleans up when dropped.
    checkpoint_runtime
        .persist_snapshot()
        .await
        .map_err(|error| format!("Failed to persist baseline checkpoint: {error}"))?;

    // Background task that flushes the checkpoint every `CHECKPOINT_FLUSH_INTERVAL` until it is
    // cancelled after the run. Flush errors are ignored here.
    let periodic_checkpoint_cancel = CancellationToken::new();
    let periodic_checkpoint_runtime = checkpoint_runtime.clone();
    let periodic_checkpoint_cancel_task = periodic_checkpoint_cancel.clone();
    let periodic_task = tokio::spawn(async move {
        let mut interval = tokio::time::interval(CHECKPOINT_FLUSH_INTERVAL);
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);

        loop {
            tokio::select! {
                _ = periodic_checkpoint_cancel_task.cancelled() => break,
                _ = interval.tick() => {
                    let _ = periodic_checkpoint_runtime.persist_snapshot().await;
                }
            }
        }
    });

    let (core_event_tx, mut core_event_rx) = mpsc::channel::<AgentEvent>(256);

    // Forward every agent event to `handle_core_event` until the sending side is closed.
    let event_state = state.clone();
    let event_forwarder = tokio::spawn(async move {
        while let Some(event) = core_event_rx.recv().await {
            handle_core_event(&event_state, session_id, run_id, event).await;
        }
    });

    // Use the model's maximum output capacity as the output budget for context
    // window calculations. This is conservative — the actual response may be shorter,
    // but reserving the full limit avoids mid-response context truncation.
    // NOTE(review): `as u32` truncates silently if `limit.output` does not fit in a `u32` —
    // confirm the source type and range.
    let max_output_tokens = run_config.model.limit.output as u32;
    let agent_config = AgentConfig {
        model: run_config.model.clone(),
        system_prompt: session_context.system_prompt,
        max_turns: run_config.max_turns,
        max_output_tokens,
        provider_options: None,
        tool_approval: run_config.tool_approval_policy.clone(),
        retry: RetryConfig::default(),
        compaction: CompactionConfig::default(),
        tools: run_tools,
    };

    // The checkpoint hook keeps `checkpoint_runtime` in sync with the agent's message list.
    let hooks: Vec<Box<dyn AgentHook>> = vec![Box::new(ServerCheckpointHook {
        checkpoint_runtime: checkpoint_runtime.clone(),
    })];

    let compactor = PassthroughCompactionEngine;
    let context_reducer = BudgetAwareContextReducer::new(5, 0.8);
    let run_context = build_run_context(session_id, run_id);

    let run_result = run_agent(
        run_context,
        run_config.inference.as_ref(),
        &agent_config,
        initial_messages,
        &mut initial_metadata,
        user_message,
        tool_executor.as_ref(),
        &hooks,
        core_event_tx,
        command_rx,
        cancel,
        &compactor,
        &context_reducer,
    )
    .await;

    // Stop the periodic flusher and wait for it to exit before the final persist below.
    periodic_checkpoint_cancel.cancel();
    let _ = periodic_task.await;

    // Shut down ephemeral sandbox container if one was started.
    // Persistent sandboxes are NOT shut down here — they live for the process lifetime.
    if let Some(sandbox) = ephemeral_sandbox {
        sandbox.shutdown().await;
    }

    state.clear_pending_tools(session_id, run_id).await;

    match &run_result {
        Ok(result) => {
            // Successful run: a failure to persist the final state is surfaced to the caller.
            checkpoint_runtime.update_messages(&result.messages).await;
            checkpoint_runtime.update_metadata(&result.metadata).await;
            checkpoint_runtime
                .persist_snapshot()
                .await
                .map_err(|error| format!("Failed to persist terminal checkpoint: {error}"))?;
        }
        Err(_) => {
            // Failed run: persist best-effort; the run error below is what gets reported.
            checkpoint_runtime.update_metadata(&initial_metadata).await;
            let _ = checkpoint_runtime.persist_snapshot().await;
        }
    }

    // Wait at most 2 seconds for the event forwarder task to finish; a timeout is ignored.
    let _ = tokio::time::timeout(std::time::Duration::from_secs(2), event_forwarder).await;

    run_result
        .map(|_| ())
        .map_err(|error| format!("Agent run failed: {error}"))
}
324
325fn is_new_session_history(messages: &[Message]) -> bool {
326    !messages
327        .iter()
328        .any(|message| matches!(message.role, Role::User | Role::Assistant | Role::Tool))
329}
330
331async fn resolve_session_cwd(state: &AppState, session_id: Uuid) -> String {
332    // 1. Session-specific cwd (set by API caller)
333    if let Ok(session) = state.session_store.get_session(session_id).await
334        && let Some(cwd) = session.cwd
335        && !cwd.trim().is_empty()
336    {
337        return cwd;
338    }
339
340    // 2. Configured project directory (set at server startup, e.g. from `stakpak up`)
341    if let Some(project_dir) = &state.project_dir {
342        return project_dir.clone();
343    }
344
345    // 3. Process working directory
346    std::env::current_dir()
347        .ok()
348        .map(|path| path.to_string_lossy().to_string())
349        .unwrap_or_else(|| ".".to_string())
350}
351
352fn prepend_context_to_user_message(mut message: Message, context_block: &str) -> Message {
353    if context_block.trim().is_empty() {
354        return message;
355    }
356
357    match &mut message.content {
358        MessageContent::Text(text) => {
359            let existing = std::mem::take(text);
360            *text = if existing.trim().is_empty() {
361                context_block.to_string()
362            } else {
363                format!("{context_block}\n\n{existing}")
364            };
365        }
366        MessageContent::Parts(parts) => {
367            let mut prefixed = Vec::with_capacity(parts.len() + 1);
368            prefixed.push(ContentPart::text(context_block));
369            prefixed.append(parts);
370            *parts = prefixed;
371        }
372    }
373
374    message
375}
376
377async fn handle_core_event(state: &AppState, session_id: Uuid, run_id: Uuid, event: AgentEvent) {
378    match &event {
379        AgentEvent::ToolCallsProposed { tool_calls, .. } => {
380            state
381                .set_pending_tools(session_id, run_id, tool_calls.clone())
382                .await;
383        }
384        AgentEvent::TurnCompleted { .. }
385        | AgentEvent::RunCompleted { .. }
386        | AgentEvent::RunError { .. } => {
387            state.clear_pending_tools(session_id, run_id).await;
388        }
389        _ => {}
390    }
391
392    state.events.publish(session_id, Some(run_id), event).await;
393}
394
395#[derive(Clone)]
396struct ServerToolExecutor {
397    state: AppState,
398}
399
400#[async_trait]
401impl ToolExecutor for ServerToolExecutor {
402    async fn execute_tool_call(
403        &self,
404        run: &AgentRunContext,
405        tool_call: &ProposedToolCall,
406        cancel: &CancellationToken,
407    ) -> Result<ToolExecutionResult, stakpak_agent_core::AgentError> {
408        Ok(execute_mcp_tool_call(&self.state, run.session_id, run.run_id, tool_call, cancel).await)
409    }
410}
411
412/// Tool executor that routes calls through a per-session sandboxed MCP client.
413#[derive(Clone)]
414struct SandboxedToolExecutor {
415    mcp_client: Arc<McpClient>,
416}
417
418#[async_trait]
419impl ToolExecutor for SandboxedToolExecutor {
420    async fn execute_tool_call(
421        &self,
422        run: &AgentRunContext,
423        tool_call: &ProposedToolCall,
424        cancel: &CancellationToken,
425    ) -> Result<ToolExecutionResult, stakpak_agent_core::AgentError> {
426        Ok(execute_mcp_tool_call_with_client(
427            &self.mcp_client,
428            run.session_id,
429            run.run_id,
430            tool_call,
431            cancel,
432        )
433        .await)
434    }
435}
436
/// Tracks the latest messages and metadata of a run and persists them as checkpoints.
///
/// The mutable part lives in `inner` behind an async mutex; the remaining fields are fixed for
/// the lifetime of the run.
struct CheckpointRuntime {
    // Shared server state; gives access to the session and checkpoint stores.
    state: AppState,
    session_id: Uuid,
    run_id: Uuid,
    // Model recorded (as `provider/id`) in the persisted envelope metadata.
    active_model: stakai::Model,
    inner: Mutex<CheckpointRuntimeInner>,
}

/// Mutable checkpoint state guarded by `CheckpointRuntime::inner`.
struct CheckpointRuntimeInner {
    // Id of the most recently persisted checkpoint; used as the parent of the next one.
    parent_checkpoint_id: Option<Uuid>,
    latest_messages: Vec<Message>,
    latest_metadata: serde_json::Value,
    // Serialized (messages, metadata) of the last persisted snapshot, used to skip no-op writes.
    last_persisted_signature: Option<String>,
    // Set whenever messages or metadata are updated; cleared after a persist attempt settles.
    dirty: bool,
}
452
453impl CheckpointRuntime {
454    fn new(
455        state: AppState,
456        session_id: Uuid,
457        run_id: Uuid,
458        active_model: stakai::Model,
459        parent_checkpoint_id: Option<Uuid>,
460        latest_messages: Vec<Message>,
461        latest_metadata: serde_json::Value,
462    ) -> Self {
463        Self {
464            state,
465            session_id,
466            run_id,
467            active_model,
468            inner: Mutex::new(CheckpointRuntimeInner {
469                parent_checkpoint_id,
470                latest_messages,
471                latest_metadata,
472                last_persisted_signature: None,
473                dirty: true,
474            }),
475        }
476    }
477
478    async fn update_messages(&self, messages: &[Message]) {
479        let mut guard = self.inner.lock().await;
480        guard.latest_messages = messages.to_vec();
481        guard.dirty = true;
482    }
483
484    async fn update_metadata(&self, metadata: &serde_json::Value) {
485        let mut guard = self.inner.lock().await;
486        guard.latest_metadata = metadata.clone();
487        guard.dirty = true;
488    }
489
490    async fn persist_snapshot(&self) -> Result<Uuid, String> {
491        let mut guard = self.inner.lock().await;
492        self.persist_if_needed(&mut guard).await
493    }
494
495    async fn persist_if_needed(&self, guard: &mut CheckpointRuntimeInner) -> Result<Uuid, String> {
496        if !guard.dirty
497            && let Some(checkpoint_id) = guard.parent_checkpoint_id
498        {
499            return Ok(checkpoint_id);
500        }
501
502        let signature = checkpoint_signature(&guard.latest_messages, &guard.latest_metadata)?;
503        let changed = guard.last_persisted_signature.as_deref() != Some(signature.as_str());
504        let should_persist = guard.parent_checkpoint_id.is_none() || (guard.dirty && changed);
505
506        if !should_persist {
507            guard.dirty = false;
508            if let Some(checkpoint_id) = guard.parent_checkpoint_id {
509                return Ok(checkpoint_id);
510            }
511        }
512
513        let checkpoint_id = persist_checkpoint(
514            &self.state,
515            self.session_id,
516            self.run_id,
517            &self.active_model,
518            guard.parent_checkpoint_id,
519            &guard.latest_messages,
520            &guard.latest_metadata,
521        )
522        .await?;
523
524        guard.parent_checkpoint_id = Some(checkpoint_id);
525        guard.last_persisted_signature = Some(signature);
526        guard.dirty = false;
527
528        Ok(checkpoint_id)
529    }
530}
531
532struct ServerCheckpointHook {
533    checkpoint_runtime: Arc<CheckpointRuntime>,
534}
535
536#[async_trait]
537impl AgentHook for ServerCheckpointHook {
538    async fn before_inference(
539        &self,
540        _run: &AgentRunContext,
541        messages: &[Message],
542        _model: &stakai::Model,
543    ) -> Result<(), stakpak_agent_core::AgentError> {
544        self.checkpoint_runtime.update_messages(messages).await;
545        Ok(())
546    }
547
548    async fn after_inference(
549        &self,
550        _run: &AgentRunContext,
551        messages: &[Message],
552        _model: &stakai::Model,
553    ) -> Result<(), stakpak_agent_core::AgentError> {
554        self.checkpoint_runtime.update_messages(messages).await;
555        Ok(())
556    }
557
558    async fn after_tool_execution(
559        &self,
560        _run: &AgentRunContext,
561        _tool_call: &ProposedToolCall,
562        messages: &[Message],
563    ) -> Result<(), stakpak_agent_core::AgentError> {
564        self.checkpoint_runtime.update_messages(messages).await;
565        Ok(())
566    }
567
568    async fn on_error(
569        &self,
570        _run: &AgentRunContext,
571        _error: &stakpak_agent_core::AgentError,
572        messages: &[Message],
573    ) -> Result<(), stakpak_agent_core::AgentError> {
574        self.checkpoint_runtime.update_messages(messages).await;
575        let _ = self.checkpoint_runtime.persist_snapshot().await;
576        Ok(())
577    }
578}
579
580async fn execute_mcp_tool_call(
581    state: &AppState,
582    session_id: Uuid,
583    run_id: Uuid,
584    tool_call: &ProposedToolCall,
585    cancel: &CancellationToken,
586) -> ToolExecutionResult {
587    let Some(mcp_client) = state.mcp_client.as_ref() else {
588        return ToolExecutionResult::Completed {
589            result: "MCP client is not initialized".to_string(),
590            is_error: true,
591        };
592    };
593
594    execute_mcp_tool_call_with_client(mcp_client, session_id, run_id, tool_call, cancel).await
595}
596
597async fn execute_mcp_tool_call_with_client(
598    mcp_client: &McpClient,
599    session_id: Uuid,
600    run_id: Uuid,
601    tool_call: &ProposedToolCall,
602    cancel: &CancellationToken,
603) -> ToolExecutionResult {
604    let metadata = Some(serde_json::Map::from_iter([
605        (
606            "session_id".to_string(),
607            serde_json::Value::String(session_id.to_string()),
608        ),
609        (
610            "run_id".to_string(),
611            serde_json::Value::String(run_id.to_string()),
612        ),
613        (
614            "tool_call_id".to_string(),
615            serde_json::Value::String(tool_call.id.clone()),
616        ),
617    ]));
618
619    let arguments = match &tool_call.arguments {
620        serde_json::Value::Object(map) => Some(map.clone()),
621        serde_json::Value::Null => None,
622        other => Some(serde_json::Map::from_iter([(
623            "input".to_string(),
624            other.clone(),
625        )])),
626    };
627
628    let request_handle = match stakpak_mcp_client::call_tool(
629        mcp_client,
630        CallToolRequestParam {
631            name: tool_call.name.clone().into(),
632            arguments,
633        },
634        metadata,
635    )
636    .await
637    {
638        Ok(handle) => handle,
639        Err(error) => {
640            return ToolExecutionResult::Completed {
641                result: format!("MCP tool call failed: {error}"),
642                is_error: true,
643            };
644        }
645    };
646
647    let peer_for_cancel = request_handle.peer.clone();
648    let request_id = request_handle.id.clone();
649
650    tokio::select! {
651        _ = cancel.cancelled() => {
652            let notification = CancelledNotification {
653                method: CancelledNotificationMethod,
654                params: CancelledNotificationParam {
655                    request_id,
656                    reason: Some("user cancel".to_string()),
657                },
658                extensions: Default::default(),
659            };
660
661            let _ = peer_for_cancel.send_notification(notification.into()).await;
662            ToolExecutionResult::Cancelled
663        }
664        server_result = request_handle.await_response() => {
665            match server_result {
666                Ok(ServerResult::CallToolResult(result)) => {
667                    ToolExecutionResult::Completed {
668                        result: render_call_tool_result(&result),
669                        is_error: result.is_error.unwrap_or(false),
670                    }
671                }
672                Ok(_) => ToolExecutionResult::Completed {
673                    result: "Unexpected MCP response type".to_string(),
674                    is_error: true,
675                },
676                Err(error) => ToolExecutionResult::Completed {
677                    result: format!("MCP tool execution error: {error}"),
678                    is_error: true,
679                },
680            }
681        }
682    }
683}
684
685fn render_call_tool_result(result: &rmcp::model::CallToolResult) -> String {
686    let rendered = result
687        .content
688        .iter()
689        .filter_map(|content| content.raw.as_text().map(|text| text.text.clone()))
690        .collect::<Vec<_>>()
691        .join("\n");
692
693    if !rendered.is_empty() {
694        return sanitize_text_output(&rendered);
695    }
696
697    if result.content.is_empty() {
698        return "<empty tool result>".to_string();
699    }
700
701    "<non-text tool result omitted for safety>".to_string()
702}
703
704fn checkpoint_signature(
705    messages: &[Message],
706    metadata: &serde_json::Value,
707) -> Result<String, String> {
708    serde_json::to_string(&(messages, metadata))
709        .map_err(|error| format!("Failed to serialize checkpoint messages: {error}"))
710}
711
712async fn persist_checkpoint(
713    state: &AppState,
714    session_id: Uuid,
715    run_id: Uuid,
716    active_model: &stakai::Model,
717    parent_id: Option<Uuid>,
718    messages: &[Message],
719    metadata: &serde_json::Value,
720) -> Result<Uuid, String> {
721    // TODO(ahmed): Migrate server/session checkpoint storage to `Vec<stakai::Message>` directly
722    // and remove the ChatMessage adapter conversion (`message_bridge::stakai_to_chat`).
723    let mut request = CreateCheckpointRequest::new(message_bridge::stakai_to_chat(messages))
724        .with_metadata(metadata.clone());
725
726    if let Some(parent_id) = parent_id {
727        request = request.with_parent(parent_id);
728    }
729
730    let checkpoint = state
731        .session_store
732        .create_checkpoint(session_id, &request)
733        .await
734        .map_err(|error| error.to_string())?;
735
736    let mut envelope_metadata = if metadata.is_object() {
737        metadata.clone()
738    } else {
739        json!({})
740    };
741
742    if let Some(obj) = envelope_metadata.as_object_mut() {
743        obj.insert(
744            "session_id".to_string(),
745            serde_json::Value::String(session_id.to_string()),
746        );
747        obj.insert(
748            "checkpoint_id".to_string(),
749            serde_json::Value::String(checkpoint.id.to_string()),
750        );
751        obj.insert(
752            ACTIVE_MODEL_METADATA_KEY.to_string(),
753            serde_json::Value::String(format!("{}/{}", active_model.provider, active_model.id)),
754        );
755    }
756
757    let envelope = build_checkpoint_envelope(run_id, messages.to_vec(), envelope_metadata);
758
759    state
760        .checkpoint_store
761        .save_latest(session_id, &envelope)
762        .await
763        .map_err(|error| {
764            format!(
765                "Failed to persist checkpoint envelope for session {}: {}",
766                session_id, error
767            )
768        })?;
769
770    Ok(checkpoint.id)
771}
772
/// Unit tests for the pure helpers in this file: run-context and envelope construction, tool
/// result rendering, checkpoint signatures, new-session detection and context prepending.
#[cfg(test)]
mod tests {
    use super::*;
    use rmcp::model::{CallToolResult, Content};
    use serde_json::json;
    use stakai::{ContentPart, Message, MessageContent, Role};

    // --- Run context / checkpoint envelope construction ---

    // The ids passed in must come back unchanged.
    #[test]
    fn run_id_is_not_regenerated_when_building_run_context() {
        let session_id = Uuid::new_v4();
        let run_id = Uuid::new_v4();

        let run_context = build_run_context(session_id, run_id);

        assert_eq!(run_context.session_id, session_id);
        assert_eq!(run_context.run_id, run_id);
    }

    // The envelope must be tagged with the run id it was built for.
    #[test]
    fn checkpoint_envelope_carries_same_run_id() {
        let run_id = Uuid::new_v4();
        let envelope = build_checkpoint_envelope(
            run_id,
            vec![Message::new(Role::User, "hello")],
            json!({"turn": 1}),
        );

        assert_eq!(envelope.run_id, Some(run_id));
    }

    // --- Tool result rendering ---

    // The BEL control character (U+0007) is expected to be stripped from text output.
    #[test]
    fn render_call_tool_result_sanitizes_text_blocks() {
        let result = CallToolResult::success(vec![Content::text("ok\u{0007}done")]);

        assert_eq!(render_call_tool_result(&result), "okdone");
    }

    // A result containing only non-text content renders as a fixed placeholder.
    #[test]
    fn render_call_tool_result_omits_non_text_blocks() {
        let result = CallToolResult::success(vec![Content::image("dGVzdA==", "image/png")]);

        assert_eq!(
            render_call_tool_result(&result),
            "<non-text tool result omitted for safety>"
        );
    }

    // --- Checkpoint signatures ---

    // Different message lists with identical metadata must produce different signatures.
    #[test]
    fn checkpoint_signature_changes_when_messages_change() {
        let messages_a = vec![Message::new(Role::User, "hello")];
        let messages_b = vec![
            Message::new(Role::User, "hello"),
            Message::new(Role::Assistant, "hi"),
        ];

        let sig_a = checkpoint_signature(&messages_a, &json!({}))
            .unwrap_or_else(|error| panic!("signature failed: {error}"));
        let sig_b = checkpoint_signature(&messages_b, &json!({}))
            .unwrap_or_else(|error| panic!("signature failed: {error}"));

        assert_ne!(sig_a, sig_b);
    }

    // Identical messages with different metadata must produce different signatures.
    #[test]
    fn checkpoint_signature_changes_when_metadata_changes() {
        let messages = vec![Message::new(Role::User, "hello")];

        let sig_a = checkpoint_signature(&messages, &json!({}))
            .unwrap_or_else(|error| panic!("signature failed: {error}"));
        let sig_b = checkpoint_signature(&messages, &json!({"trimmed_up_to_message_index": 5}))
            .unwrap_or_else(|error| panic!("signature failed: {error}"));

        assert_ne!(sig_a, sig_b);
    }

    // --- New-session detection ---

    #[test]
    fn is_new_session_empty_history() {
        assert!(is_new_session_history(&[]));
    }

    // A history holding only a system message still counts as new.
    #[test]
    fn is_new_session_system_only() {
        let messages = vec![Message::new(Role::System, "you are an agent")];
        assert!(is_new_session_history(&messages));
    }

    #[test]
    fn is_not_new_session_with_user_message() {
        let messages = vec![Message::new(Role::User, "hello")];
        assert!(!is_new_session_history(&messages));
    }

    #[test]
    fn is_not_new_session_with_system_and_user() {
        let messages = vec![
            Message::new(Role::System, "system"),
            Message::new(Role::User, "hello"),
        ];
        assert!(!is_new_session_history(&messages));
    }

    #[test]
    fn is_not_new_session_with_assistant() {
        let messages = vec![Message::new(Role::Assistant, "hi there")];
        assert!(!is_new_session_history(&messages));
    }

    // --- Context prepending ---

    // Text content: the context comes first and the original text is kept.
    #[test]
    fn prepend_context_to_text_message() {
        let msg = Message::new(Role::User, "how do I deploy?");
        let result = prepend_context_to_user_message(msg, "<context>env info</context>");

        let text = result.text().unwrap_or_default();
        assert!(
            text.starts_with("<context>env info</context>"),
            "context should be prepended"
        );
        assert!(
            text.contains("how do I deploy?"),
            "original text should be preserved"
        );
    }

    // Whitespace-only text is replaced by the context block alone (no separator added).
    #[test]
    fn prepend_context_to_empty_text_message() {
        let msg = Message::new(Role::User, "  ");
        let result = prepend_context_to_user_message(msg, "<context>env info</context>");

        let text = result.text().unwrap_or_default();
        assert_eq!(text, "<context>env info</context>");
    }

    // Multi-part content: the context becomes a new first text part.
    #[test]
    fn prepend_context_to_parts_message() {
        let msg = Message {
            role: Role::User,
            content: MessageContent::Parts(vec![ContentPart::text("original text")]),
            name: None,
            provider_options: None,
        };
        let result = prepend_context_to_user_message(msg, "<context>env info</context>");

        if let MessageContent::Parts(parts) = &result.content {
            assert_eq!(parts.len(), 2, "should have context part + original part");
            if let ContentPart::Text { text, .. } = &parts[0] {
                assert_eq!(text, "<context>env info</context>");
            } else {
                panic!("first part should be text");
            }
        } else {
            panic!("expected Parts content");
        }
    }

    // A whitespace-only context block must leave the message unchanged.
    #[test]
    fn prepend_empty_context_is_noop() {
        let msg = Message::new(Role::User, "hello");
        let result = prepend_context_to_user_message(msg, "   ");

        assert_eq!(result.text().unwrap_or_default(), "hello");
    }
}