Skip to main content

stakpak_server/
session_actor.rs

1use crate::{
2    message_bridge,
3    sandbox::{SandboxConfig, SandboxedMcpServer},
4    state::AppState,
5    types::SessionHandle,
6};
7use async_trait::async_trait;
8use rmcp::model::{
9    CallToolRequestParam, CancelledNotification, CancelledNotificationMethod,
10    CancelledNotificationParam, ServerResult,
11};
12use serde_json::json;
13use stakai::Message;
14use stakpak_agent_core::{
15    AgentCommand, AgentConfig, AgentEvent, AgentHook, AgentRunContext, CheckpointEnvelopeV1,
16    CompactionConfig, ContextConfig, PassthroughCompactionEngine, ProposedToolCall, RetryConfig,
17    ToolExecutionResult, ToolExecutor, run_agent,
18};
19use stakpak_api::CreateCheckpointRequest;
20use stakpak_mcp_client::McpClient;
21use stakpak_shared::utils::sanitize_text_output;
22use std::sync::Arc;
23use tokio::sync::{Mutex, mpsc};
24use tokio_util::sync::CancellationToken;
25use uuid::Uuid;
26
/// Metadata key under which the run's `provider/model` identifier is stored in every
/// checkpoint envelope written by this module.
pub(crate) const ACTIVE_MODEL_METADATA_KEY: &str = "active_model";
/// Upper bound on agent turns handed to the core loop via `AgentConfig::max_turns`.
const MAX_TURNS: usize = 64;
/// Period of the background task that flushes dirty checkpoint snapshots.
const CHECKPOINT_FLUSH_INTERVAL: std::time::Duration = std::time::Duration::from_secs(5);
30
31pub fn build_run_context(session_id: Uuid, run_id: Uuid) -> AgentRunContext {
32    AgentRunContext { run_id, session_id }
33}
34
35pub fn build_checkpoint_envelope(
36    run_id: Uuid,
37    messages: Vec<stakai::Message>,
38    metadata: serde_json::Value,
39) -> CheckpointEnvelopeV1 {
40    CheckpointEnvelopeV1::new(Some(run_id), messages, metadata)
41}
42
43pub fn spawn_session_actor(
44    state: AppState,
45    session_id: Uuid,
46    run_id: Uuid,
47    model: stakai::Model,
48    user_message: Message,
49    sandbox_config: Option<SandboxConfig>,
50) -> Result<SessionHandle, String> {
51    let (command_tx, command_rx) = mpsc::channel(128);
52    let cancel = CancellationToken::new();
53
54    let handle = SessionHandle::new(command_tx, cancel.clone());
55
56    let state_for_task = state.clone();
57    tokio::spawn(async move {
58        let actor_result = run_session_actor(
59            state_for_task.clone(),
60            session_id,
61            run_id,
62            model,
63            user_message,
64            command_rx,
65            cancel,
66            sandbox_config,
67        )
68        .await;
69
70        let finish_result = actor_result.map(|_| ());
71        let _ = state_for_task
72            .run_manager
73            .mark_run_finished(session_id, run_id, finish_result)
74            .await;
75    });
76
77    Ok(handle)
78}
79
80#[allow(clippy::too_many_arguments)]
81async fn run_session_actor(
82    state: AppState,
83    session_id: Uuid,
84    run_id: Uuid,
85    model: stakai::Model,
86    user_message: Message,
87    command_rx: mpsc::Receiver<AgentCommand>,
88    cancel: CancellationToken,
89    sandbox_config: Option<SandboxConfig>,
90) -> Result<(), String> {
91    let active_checkpoint = state
92        .session_store
93        .get_active_checkpoint(session_id)
94        .await
95        .ok();
96    let parent_checkpoint_id = active_checkpoint.as_ref().map(|checkpoint| checkpoint.id);
97
98    let initial_messages = match state.checkpoint_store.load_latest(session_id).await {
99        Ok(Some(envelope)) => envelope.messages,
100        Ok(None) => active_checkpoint
101            .map(|checkpoint| message_bridge::chat_to_stakai(checkpoint.state.messages))
102            .unwrap_or_default(),
103        Err(error) => {
104            return Err(format!("Failed to load checkpoint envelope: {error}"));
105        }
106    };
107
108    let mut baseline_messages = initial_messages.clone();
109    baseline_messages.push(user_message.clone());
110
111    let checkpoint_runtime = Arc::new(CheckpointRuntime::new(
112        state.clone(),
113        session_id,
114        run_id,
115        model.clone(),
116        parent_checkpoint_id,
117        baseline_messages,
118    ));
119
120    checkpoint_runtime
121        .persist_snapshot()
122        .await
123        .map_err(|error| format!("Failed to persist baseline checkpoint: {error}"))?;
124
125    let periodic_checkpoint_cancel = CancellationToken::new();
126    let periodic_checkpoint_runtime = checkpoint_runtime.clone();
127    let periodic_checkpoint_cancel_task = periodic_checkpoint_cancel.clone();
128    let periodic_task = tokio::spawn(async move {
129        let mut interval = tokio::time::interval(CHECKPOINT_FLUSH_INTERVAL);
130        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
131
132        loop {
133            tokio::select! {
134                _ = periodic_checkpoint_cancel_task.cancelled() => break,
135                _ = interval.tick() => {
136                    let _ = periodic_checkpoint_runtime.persist_snapshot().await;
137                }
138            }
139        }
140    });
141
142    let (core_event_tx, mut core_event_rx) = mpsc::channel::<AgentEvent>(256);
143
144    let event_state = state.clone();
145    let event_forwarder = tokio::spawn(async move {
146        while let Some(event) = core_event_rx.recv().await {
147            handle_core_event(&event_state, session_id, run_id, event).await;
148        }
149    });
150
151    // If sandbox is requested, spawn a sandboxed MCP server for this session.
152    // Otherwise, use the shared in-process MCP client.
153    let sandbox = if let Some(sandbox_config) = sandbox_config {
154        tracing::info!(session_id = %session_id, image = %sandbox_config.image, "Spawning sandbox container for session");
155        Some(
156            SandboxedMcpServer::spawn(&sandbox_config)
157                .await
158                .map_err(|e| format!("Failed to start sandbox for session {session_id}: {e}"))?,
159        )
160    } else {
161        None
162    };
163
164    let (run_tools, tool_executor): (Vec<stakai::Tool>, Box<dyn ToolExecutor + Send + Sync>) =
165        if let Some(ref sandbox) = sandbox {
166            (
167                sandbox.tools.clone(),
168                Box::new(SandboxedToolExecutor {
169                    mcp_client: sandbox.client.clone(),
170                }),
171            )
172        } else {
173            (
174                state.current_mcp_tools().await,
175                Box::new(ServerToolExecutor {
176                    state: state.clone(),
177                }),
178            )
179        };
180
181    let agent_config = AgentConfig {
182        model,
183        system_prompt: String::new(),
184        max_turns: MAX_TURNS,
185        max_output_tokens: 0,
186        provider_options: None,
187        tool_approval: state.tool_approval_policy.clone(),
188        context: ContextConfig::default(),
189        retry: RetryConfig::default(),
190        compaction: CompactionConfig::default(),
191        tools: run_tools,
192    };
193
194    let hooks: Vec<Box<dyn AgentHook>> = vec![Box::new(ServerCheckpointHook {
195        checkpoint_runtime: checkpoint_runtime.clone(),
196    })];
197
198    let compactor = PassthroughCompactionEngine;
199    let run_context = build_run_context(session_id, run_id);
200
201    let run_result = run_agent(
202        run_context,
203        state.inference.as_ref(),
204        &agent_config,
205        initial_messages,
206        user_message,
207        tool_executor.as_ref(),
208        &hooks,
209        core_event_tx,
210        command_rx,
211        cancel,
212        &compactor,
213    )
214    .await;
215
216    periodic_checkpoint_cancel.cancel();
217    let _ = periodic_task.await;
218
219    // Shut down sandbox container if one was started
220    if let Some(sandbox) = sandbox {
221        sandbox.shutdown().await;
222    }
223
224    state.clear_pending_tools(session_id, run_id).await;
225
226    match &run_result {
227        Ok(result) => {
228            checkpoint_runtime.update_messages(&result.messages).await;
229            checkpoint_runtime
230                .persist_snapshot()
231                .await
232                .map_err(|error| format!("Failed to persist terminal checkpoint: {error}"))?;
233        }
234        Err(_) => {
235            let _ = checkpoint_runtime.persist_snapshot().await;
236        }
237    }
238
239    let _ = tokio::time::timeout(std::time::Duration::from_secs(2), event_forwarder).await;
240
241    run_result
242        .map(|_| ())
243        .map_err(|error| format!("Agent run failed: {error}"))
244}
245
246async fn handle_core_event(state: &AppState, session_id: Uuid, run_id: Uuid, event: AgentEvent) {
247    match &event {
248        AgentEvent::ToolCallsProposed { tool_calls, .. } => {
249            state
250                .set_pending_tools(session_id, run_id, tool_calls.clone())
251                .await;
252        }
253        AgentEvent::TurnCompleted { .. }
254        | AgentEvent::RunCompleted { .. }
255        | AgentEvent::RunError { .. } => {
256            state.clear_pending_tools(session_id, run_id).await;
257        }
258        _ => {}
259    }
260
261    state.events.publish(session_id, Some(run_id), event).await;
262}
263
264#[derive(Clone)]
265struct ServerToolExecutor {
266    state: AppState,
267}
268
269#[async_trait]
270impl ToolExecutor for ServerToolExecutor {
271    async fn execute_tool_call(
272        &self,
273        run: &AgentRunContext,
274        tool_call: &ProposedToolCall,
275        cancel: &CancellationToken,
276    ) -> Result<ToolExecutionResult, stakpak_agent_core::AgentError> {
277        Ok(execute_mcp_tool_call(&self.state, run.session_id, run.run_id, tool_call, cancel).await)
278    }
279}
280
281/// Tool executor that routes calls through a per-session sandboxed MCP client.
282#[derive(Clone)]
283struct SandboxedToolExecutor {
284    mcp_client: Arc<McpClient>,
285}
286
287#[async_trait]
288impl ToolExecutor for SandboxedToolExecutor {
289    async fn execute_tool_call(
290        &self,
291        run: &AgentRunContext,
292        tool_call: &ProposedToolCall,
293        cancel: &CancellationToken,
294    ) -> Result<ToolExecutionResult, stakpak_agent_core::AgentError> {
295        Ok(execute_mcp_tool_call_with_client(
296            &self.mcp_client,
297            run.session_id,
298            run.run_id,
299            tool_call,
300            cancel,
301        )
302        .await)
303    }
304}
305
306struct CheckpointRuntime {
307    state: AppState,
308    session_id: Uuid,
309    run_id: Uuid,
310    active_model: stakai::Model,
311    inner: Mutex<CheckpointRuntimeInner>,
312}
313
314struct CheckpointRuntimeInner {
315    parent_checkpoint_id: Option<Uuid>,
316    latest_messages: Vec<Message>,
317    last_persisted_signature: Option<String>,
318    dirty: bool,
319}
320
321impl CheckpointRuntime {
322    fn new(
323        state: AppState,
324        session_id: Uuid,
325        run_id: Uuid,
326        active_model: stakai::Model,
327        parent_checkpoint_id: Option<Uuid>,
328        latest_messages: Vec<Message>,
329    ) -> Self {
330        Self {
331            state,
332            session_id,
333            run_id,
334            active_model,
335            inner: Mutex::new(CheckpointRuntimeInner {
336                parent_checkpoint_id,
337                latest_messages,
338                last_persisted_signature: None,
339                dirty: true,
340            }),
341        }
342    }
343
344    async fn update_messages(&self, messages: &[Message]) {
345        let mut guard = self.inner.lock().await;
346        guard.latest_messages = messages.to_vec();
347        guard.dirty = true;
348    }
349
350    async fn persist_snapshot(&self) -> Result<Uuid, String> {
351        let mut guard = self.inner.lock().await;
352        self.persist_if_needed(&mut guard).await
353    }
354
355    async fn persist_if_needed(&self, guard: &mut CheckpointRuntimeInner) -> Result<Uuid, String> {
356        if !guard.dirty
357            && let Some(checkpoint_id) = guard.parent_checkpoint_id
358        {
359            return Ok(checkpoint_id);
360        }
361
362        let signature = checkpoint_signature(&guard.latest_messages)?;
363        let changed = guard.last_persisted_signature.as_deref() != Some(signature.as_str());
364        let should_persist = guard.parent_checkpoint_id.is_none() || (guard.dirty && changed);
365
366        if !should_persist {
367            guard.dirty = false;
368            if let Some(checkpoint_id) = guard.parent_checkpoint_id {
369                return Ok(checkpoint_id);
370            }
371        }
372
373        let checkpoint_id = persist_checkpoint(
374            &self.state,
375            self.session_id,
376            self.run_id,
377            &self.active_model,
378            guard.parent_checkpoint_id,
379            &guard.latest_messages,
380        )
381        .await?;
382
383        guard.parent_checkpoint_id = Some(checkpoint_id);
384        guard.last_persisted_signature = Some(signature);
385        guard.dirty = false;
386
387        Ok(checkpoint_id)
388    }
389}
390
391struct ServerCheckpointHook {
392    checkpoint_runtime: Arc<CheckpointRuntime>,
393}
394
395#[async_trait]
396impl AgentHook for ServerCheckpointHook {
397    async fn before_inference(
398        &self,
399        _run: &AgentRunContext,
400        messages: &[Message],
401        _model: &stakai::Model,
402    ) -> Result<(), stakpak_agent_core::AgentError> {
403        self.checkpoint_runtime.update_messages(messages).await;
404        Ok(())
405    }
406
407    async fn after_inference(
408        &self,
409        _run: &AgentRunContext,
410        messages: &[Message],
411        _model: &stakai::Model,
412    ) -> Result<(), stakpak_agent_core::AgentError> {
413        self.checkpoint_runtime.update_messages(messages).await;
414        Ok(())
415    }
416
417    async fn after_tool_execution(
418        &self,
419        _run: &AgentRunContext,
420        _tool_call: &ProposedToolCall,
421        messages: &[Message],
422    ) -> Result<(), stakpak_agent_core::AgentError> {
423        self.checkpoint_runtime.update_messages(messages).await;
424        Ok(())
425    }
426
427    async fn on_error(
428        &self,
429        _run: &AgentRunContext,
430        _error: &stakpak_agent_core::AgentError,
431        messages: &[Message],
432    ) -> Result<(), stakpak_agent_core::AgentError> {
433        self.checkpoint_runtime.update_messages(messages).await;
434        let _ = self.checkpoint_runtime.persist_snapshot().await;
435        Ok(())
436    }
437}
438
439async fn execute_mcp_tool_call(
440    state: &AppState,
441    session_id: Uuid,
442    run_id: Uuid,
443    tool_call: &ProposedToolCall,
444    cancel: &CancellationToken,
445) -> ToolExecutionResult {
446    let Some(mcp_client) = state.mcp_client.as_ref() else {
447        return ToolExecutionResult::Completed {
448            result: "MCP client is not initialized".to_string(),
449            is_error: true,
450        };
451    };
452
453    execute_mcp_tool_call_with_client(mcp_client, session_id, run_id, tool_call, cancel).await
454}
455
456async fn execute_mcp_tool_call_with_client(
457    mcp_client: &McpClient,
458    session_id: Uuid,
459    run_id: Uuid,
460    tool_call: &ProposedToolCall,
461    cancel: &CancellationToken,
462) -> ToolExecutionResult {
463    let metadata = Some(serde_json::Map::from_iter([
464        (
465            "session_id".to_string(),
466            serde_json::Value::String(session_id.to_string()),
467        ),
468        (
469            "run_id".to_string(),
470            serde_json::Value::String(run_id.to_string()),
471        ),
472        (
473            "tool_call_id".to_string(),
474            serde_json::Value::String(tool_call.id.clone()),
475        ),
476    ]));
477
478    let arguments = match &tool_call.arguments {
479        serde_json::Value::Object(map) => Some(map.clone()),
480        serde_json::Value::Null => None,
481        other => Some(serde_json::Map::from_iter([(
482            "input".to_string(),
483            other.clone(),
484        )])),
485    };
486
487    let request_handle = match stakpak_mcp_client::call_tool(
488        mcp_client,
489        CallToolRequestParam {
490            name: tool_call.name.clone().into(),
491            arguments,
492        },
493        metadata,
494    )
495    .await
496    {
497        Ok(handle) => handle,
498        Err(error) => {
499            return ToolExecutionResult::Completed {
500                result: format!("MCP tool call failed: {error}"),
501                is_error: true,
502            };
503        }
504    };
505
506    let peer_for_cancel = request_handle.peer.clone();
507    let request_id = request_handle.id.clone();
508
509    tokio::select! {
510        _ = cancel.cancelled() => {
511            let notification = CancelledNotification {
512                method: CancelledNotificationMethod,
513                params: CancelledNotificationParam {
514                    request_id,
515                    reason: Some("user cancel".to_string()),
516                },
517                extensions: Default::default(),
518            };
519
520            let _ = peer_for_cancel.send_notification(notification.into()).await;
521            ToolExecutionResult::Cancelled
522        }
523        server_result = request_handle.await_response() => {
524            match server_result {
525                Ok(ServerResult::CallToolResult(result)) => {
526                    ToolExecutionResult::Completed {
527                        result: render_call_tool_result(&result),
528                        is_error: result.is_error.unwrap_or(false),
529                    }
530                }
531                Ok(_) => ToolExecutionResult::Completed {
532                    result: "Unexpected MCP response type".to_string(),
533                    is_error: true,
534                },
535                Err(error) => ToolExecutionResult::Completed {
536                    result: format!("MCP tool execution error: {error}"),
537                    is_error: true,
538                },
539            }
540        }
541    }
542}
543
544fn render_call_tool_result(result: &rmcp::model::CallToolResult) -> String {
545    let rendered = result
546        .content
547        .iter()
548        .filter_map(|content| content.raw.as_text().map(|text| text.text.clone()))
549        .collect::<Vec<_>>()
550        .join("\n");
551
552    if !rendered.is_empty() {
553        return sanitize_text_output(&rendered);
554    }
555
556    if result.content.is_empty() {
557        return "<empty tool result>".to_string();
558    }
559
560    "<non-text tool result omitted for safety>".to_string()
561}
562
563fn checkpoint_signature(messages: &[Message]) -> Result<String, String> {
564    serde_json::to_string(messages)
565        .map_err(|error| format!("Failed to serialize checkpoint messages: {error}"))
566}
567
568async fn persist_checkpoint(
569    state: &AppState,
570    session_id: Uuid,
571    run_id: Uuid,
572    active_model: &stakai::Model,
573    parent_id: Option<Uuid>,
574    messages: &[Message],
575) -> Result<Uuid, String> {
576    // TODO(ahmed): Migrate server/session checkpoint storage to `Vec<stakai::Message>` directly
577    // and remove the ChatMessage adapter conversion (`message_bridge::stakai_to_chat`).
578    let mut request = CreateCheckpointRequest::new(message_bridge::stakai_to_chat(messages));
579
580    if let Some(parent_id) = parent_id {
581        request = request.with_parent(parent_id);
582    }
583
584    let checkpoint = state
585        .session_store
586        .create_checkpoint(session_id, &request)
587        .await
588        .map_err(|error| error.to_string())?;
589
590    let envelope = build_checkpoint_envelope(
591        run_id,
592        messages.to_vec(),
593        json!({
594            "session_id": session_id.to_string(),
595            "checkpoint_id": checkpoint.id.to_string(),
596            (ACTIVE_MODEL_METADATA_KEY): format!("{}/{}", active_model.provider, active_model.id),
597        }),
598    );
599
600    state
601        .checkpoint_store
602        .save_latest(session_id, &envelope)
603        .await
604        .map_err(|error| {
605            format!(
606                "Failed to persist checkpoint envelope for session {}: {}",
607                session_id, error
608            )
609        })?;
610
611    Ok(checkpoint.id)
612}
613
#[cfg(test)]
mod tests {
    use super::*;
    use rmcp::model::{CallToolResult, Content};
    use serde_json::json;
    use stakai::{Message, Role};

    /// The run context must echo the caller's ids rather than minting new ones.
    #[test]
    fn run_id_is_not_regenerated_when_building_run_context() {
        let (session_id, run_id) = (Uuid::new_v4(), Uuid::new_v4());

        let context = build_run_context(session_id, run_id);

        assert_eq!((context.session_id, context.run_id), (session_id, run_id));
    }

    /// Envelopes are tagged with the run that produced them.
    #[test]
    fn checkpoint_envelope_carries_same_run_id() {
        let run_id = Uuid::new_v4();
        let history = vec![Message::new(Role::User, "hello")];

        let envelope = build_checkpoint_envelope(run_id, history, json!({"turn": 1}));

        assert_eq!(envelope.run_id, Some(run_id));
    }

    /// Control characters (here BEL) are stripped from rendered text output.
    #[test]
    fn render_call_tool_result_sanitizes_text_blocks() {
        let with_bell = CallToolResult::success(vec![Content::text("ok\u{0007}done")]);

        let rendered = render_call_tool_result(&with_bell);

        assert_eq!(rendered, "okdone");
    }

    /// Image-only results are replaced with a placeholder rather than rendered.
    #[test]
    fn render_call_tool_result_omits_non_text_blocks() {
        let image_only = CallToolResult::success(vec![Content::image("dGVzdA==", "image/png")]);

        let rendered = render_call_tool_result(&image_only);

        assert_eq!(rendered, "<non-text tool result omitted for safety>");
    }

    /// Appending a message must change the signature, otherwise it would be skipped.
    #[test]
    fn checkpoint_signature_changes_when_messages_change() {
        fn signature_of(messages: &[Message]) -> String {
            checkpoint_signature(messages)
                .unwrap_or_else(|error| panic!("signature failed: {error}"))
        }

        let shorter = vec![Message::new(Role::User, "hello")];
        let longer = vec![
            Message::new(Role::User, "hello"),
            Message::new(Role::Assistant, "hi"),
        ];

        assert_ne!(signature_of(&shorter), signature_of(&longer));
    }
}