1use crate::{
2 message_bridge,
3 sandbox::{SandboxConfig, SandboxedMcpServer},
4 state::AppState,
5 types::SessionHandle,
6};
7use async_trait::async_trait;
8use rmcp::model::{
9 CallToolRequestParam, CancelledNotification, CancelledNotificationMethod,
10 CancelledNotificationParam, ServerResult,
11};
12use serde_json::json;
13use stakai::Message;
14use stakpak_agent_core::{
15 AgentCommand, AgentConfig, AgentEvent, AgentHook, AgentRunContext, CheckpointEnvelopeV1,
16 CompactionConfig, ContextConfig, PassthroughCompactionEngine, ProposedToolCall, RetryConfig,
17 ToolExecutionResult, ToolExecutor, run_agent,
18};
19use stakpak_api::CreateCheckpointRequest;
20use stakpak_mcp_client::McpClient;
21use stakpak_shared::utils::sanitize_text_output;
22use std::sync::Arc;
23use tokio::sync::{Mutex, mpsc};
24use tokio_util::sync::CancellationToken;
25use uuid::Uuid;
26
/// Maximum number of turns passed to `AgentConfig::max_turns` for every run.
const MAX_TURNS: usize = 64;
/// Period of the background task that flushes dirty checkpoint state while a run is active.
const CHECKPOINT_FLUSH_INTERVAL: std::time::Duration = std::time::Duration::from_millis(5_000);
/// Envelope metadata key under which the `provider/model` of the run is recorded.
pub(crate) const ACTIVE_MODEL_METADATA_KEY: &str = "active_model";
30
31pub fn build_run_context(session_id: Uuid, run_id: Uuid) -> AgentRunContext {
32 AgentRunContext { run_id, session_id }
33}
34
35pub fn build_checkpoint_envelope(
36 run_id: Uuid,
37 messages: Vec<stakai::Message>,
38 metadata: serde_json::Value,
39) -> CheckpointEnvelopeV1 {
40 CheckpointEnvelopeV1::new(Some(run_id), messages, metadata)
41}
42
43pub fn spawn_session_actor(
44 state: AppState,
45 session_id: Uuid,
46 run_id: Uuid,
47 model: stakai::Model,
48 user_message: Message,
49 sandbox_config: Option<SandboxConfig>,
50) -> Result<SessionHandle, String> {
51 let (command_tx, command_rx) = mpsc::channel(128);
52 let cancel = CancellationToken::new();
53
54 let handle = SessionHandle::new(command_tx, cancel.clone());
55
56 let state_for_task = state.clone();
57 tokio::spawn(async move {
58 let actor_result = run_session_actor(
59 state_for_task.clone(),
60 session_id,
61 run_id,
62 model,
63 user_message,
64 command_rx,
65 cancel,
66 sandbox_config,
67 )
68 .await;
69
70 let finish_result = actor_result.map(|_| ());
71 let _ = state_for_task
72 .run_manager
73 .mark_run_finished(session_id, run_id, finish_result)
74 .await;
75 });
76
77 Ok(handle)
78}
79
80#[allow(clippy::too_many_arguments)]
81async fn run_session_actor(
82 state: AppState,
83 session_id: Uuid,
84 run_id: Uuid,
85 model: stakai::Model,
86 user_message: Message,
87 command_rx: mpsc::Receiver<AgentCommand>,
88 cancel: CancellationToken,
89 sandbox_config: Option<SandboxConfig>,
90) -> Result<(), String> {
91 let active_checkpoint = state
92 .session_store
93 .get_active_checkpoint(session_id)
94 .await
95 .ok();
96 let parent_checkpoint_id = active_checkpoint.as_ref().map(|checkpoint| checkpoint.id);
97
98 let initial_messages = match state.checkpoint_store.load_latest(session_id).await {
99 Ok(Some(envelope)) => envelope.messages,
100 Ok(None) => active_checkpoint
101 .map(|checkpoint| message_bridge::chat_to_stakai(checkpoint.state.messages))
102 .unwrap_or_default(),
103 Err(error) => {
104 return Err(format!("Failed to load checkpoint envelope: {error}"));
105 }
106 };
107
108 let mut baseline_messages = initial_messages.clone();
109 baseline_messages.push(user_message.clone());
110
111 let checkpoint_runtime = Arc::new(CheckpointRuntime::new(
112 state.clone(),
113 session_id,
114 run_id,
115 model.clone(),
116 parent_checkpoint_id,
117 baseline_messages,
118 ));
119
120 checkpoint_runtime
121 .persist_snapshot()
122 .await
123 .map_err(|error| format!("Failed to persist baseline checkpoint: {error}"))?;
124
125 let periodic_checkpoint_cancel = CancellationToken::new();
126 let periodic_checkpoint_runtime = checkpoint_runtime.clone();
127 let periodic_checkpoint_cancel_task = periodic_checkpoint_cancel.clone();
128 let periodic_task = tokio::spawn(async move {
129 let mut interval = tokio::time::interval(CHECKPOINT_FLUSH_INTERVAL);
130 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
131
132 loop {
133 tokio::select! {
134 _ = periodic_checkpoint_cancel_task.cancelled() => break,
135 _ = interval.tick() => {
136 let _ = periodic_checkpoint_runtime.persist_snapshot().await;
137 }
138 }
139 }
140 });
141
142 let (core_event_tx, mut core_event_rx) = mpsc::channel::<AgentEvent>(256);
143
144 let event_state = state.clone();
145 let event_forwarder = tokio::spawn(async move {
146 while let Some(event) = core_event_rx.recv().await {
147 handle_core_event(&event_state, session_id, run_id, event).await;
148 }
149 });
150
151 let sandbox = if let Some(sandbox_config) = sandbox_config {
154 tracing::info!(session_id = %session_id, image = %sandbox_config.image, "Spawning sandbox container for session");
155 Some(
156 SandboxedMcpServer::spawn(&sandbox_config)
157 .await
158 .map_err(|e| format!("Failed to start sandbox for session {session_id}: {e}"))?,
159 )
160 } else {
161 None
162 };
163
164 let (run_tools, tool_executor): (Vec<stakai::Tool>, Box<dyn ToolExecutor + Send + Sync>) =
165 if let Some(ref sandbox) = sandbox {
166 (
167 sandbox.tools.clone(),
168 Box::new(SandboxedToolExecutor {
169 mcp_client: sandbox.client.clone(),
170 }),
171 )
172 } else {
173 (
174 state.current_mcp_tools().await,
175 Box::new(ServerToolExecutor {
176 state: state.clone(),
177 }),
178 )
179 };
180
181 let agent_config = AgentConfig {
182 model,
183 system_prompt: String::new(),
184 max_turns: MAX_TURNS,
185 max_output_tokens: 0,
186 provider_options: None,
187 tool_approval: state.tool_approval_policy.clone(),
188 context: ContextConfig::default(),
189 retry: RetryConfig::default(),
190 compaction: CompactionConfig::default(),
191 tools: run_tools,
192 };
193
194 let hooks: Vec<Box<dyn AgentHook>> = vec![Box::new(ServerCheckpointHook {
195 checkpoint_runtime: checkpoint_runtime.clone(),
196 })];
197
198 let compactor = PassthroughCompactionEngine;
199 let run_context = build_run_context(session_id, run_id);
200
201 let run_result = run_agent(
202 run_context,
203 state.inference.as_ref(),
204 &agent_config,
205 initial_messages,
206 user_message,
207 tool_executor.as_ref(),
208 &hooks,
209 core_event_tx,
210 command_rx,
211 cancel,
212 &compactor,
213 )
214 .await;
215
216 periodic_checkpoint_cancel.cancel();
217 let _ = periodic_task.await;
218
219 if let Some(sandbox) = sandbox {
221 sandbox.shutdown().await;
222 }
223
224 state.clear_pending_tools(session_id, run_id).await;
225
226 match &run_result {
227 Ok(result) => {
228 checkpoint_runtime.update_messages(&result.messages).await;
229 checkpoint_runtime
230 .persist_snapshot()
231 .await
232 .map_err(|error| format!("Failed to persist terminal checkpoint: {error}"))?;
233 }
234 Err(_) => {
235 let _ = checkpoint_runtime.persist_snapshot().await;
236 }
237 }
238
239 let _ = tokio::time::timeout(std::time::Duration::from_secs(2), event_forwarder).await;
240
241 run_result
242 .map(|_| ())
243 .map_err(|error| format!("Agent run failed: {error}"))
244}
245
246async fn handle_core_event(state: &AppState, session_id: Uuid, run_id: Uuid, event: AgentEvent) {
247 match &event {
248 AgentEvent::ToolCallsProposed { tool_calls, .. } => {
249 state
250 .set_pending_tools(session_id, run_id, tool_calls.clone())
251 .await;
252 }
253 AgentEvent::TurnCompleted { .. }
254 | AgentEvent::RunCompleted { .. }
255 | AgentEvent::RunError { .. } => {
256 state.clear_pending_tools(session_id, run_id).await;
257 }
258 _ => {}
259 }
260
261 state.events.publish(session_id, Some(run_id), event).await;
262}
263
264#[derive(Clone)]
265struct ServerToolExecutor {
266 state: AppState,
267}
268
269#[async_trait]
270impl ToolExecutor for ServerToolExecutor {
271 async fn execute_tool_call(
272 &self,
273 run: &AgentRunContext,
274 tool_call: &ProposedToolCall,
275 cancel: &CancellationToken,
276 ) -> Result<ToolExecutionResult, stakpak_agent_core::AgentError> {
277 Ok(execute_mcp_tool_call(&self.state, run.session_id, run.run_id, tool_call, cancel).await)
278 }
279}
280
281#[derive(Clone)]
283struct SandboxedToolExecutor {
284 mcp_client: Arc<McpClient>,
285}
286
287#[async_trait]
288impl ToolExecutor for SandboxedToolExecutor {
289 async fn execute_tool_call(
290 &self,
291 run: &AgentRunContext,
292 tool_call: &ProposedToolCall,
293 cancel: &CancellationToken,
294 ) -> Result<ToolExecutionResult, stakpak_agent_core::AgentError> {
295 Ok(execute_mcp_tool_call_with_client(
296 &self.mcp_client,
297 run.session_id,
298 run.run_id,
299 tool_call,
300 cancel,
301 )
302 .await)
303 }
304}
305
306struct CheckpointRuntime {
307 state: AppState,
308 session_id: Uuid,
309 run_id: Uuid,
310 active_model: stakai::Model,
311 inner: Mutex<CheckpointRuntimeInner>,
312}
313
314struct CheckpointRuntimeInner {
315 parent_checkpoint_id: Option<Uuid>,
316 latest_messages: Vec<Message>,
317 last_persisted_signature: Option<String>,
318 dirty: bool,
319}
320
321impl CheckpointRuntime {
322 fn new(
323 state: AppState,
324 session_id: Uuid,
325 run_id: Uuid,
326 active_model: stakai::Model,
327 parent_checkpoint_id: Option<Uuid>,
328 latest_messages: Vec<Message>,
329 ) -> Self {
330 Self {
331 state,
332 session_id,
333 run_id,
334 active_model,
335 inner: Mutex::new(CheckpointRuntimeInner {
336 parent_checkpoint_id,
337 latest_messages,
338 last_persisted_signature: None,
339 dirty: true,
340 }),
341 }
342 }
343
344 async fn update_messages(&self, messages: &[Message]) {
345 let mut guard = self.inner.lock().await;
346 guard.latest_messages = messages.to_vec();
347 guard.dirty = true;
348 }
349
350 async fn persist_snapshot(&self) -> Result<Uuid, String> {
351 let mut guard = self.inner.lock().await;
352 self.persist_if_needed(&mut guard).await
353 }
354
355 async fn persist_if_needed(&self, guard: &mut CheckpointRuntimeInner) -> Result<Uuid, String> {
356 if !guard.dirty
357 && let Some(checkpoint_id) = guard.parent_checkpoint_id
358 {
359 return Ok(checkpoint_id);
360 }
361
362 let signature = checkpoint_signature(&guard.latest_messages)?;
363 let changed = guard.last_persisted_signature.as_deref() != Some(signature.as_str());
364 let should_persist = guard.parent_checkpoint_id.is_none() || (guard.dirty && changed);
365
366 if !should_persist {
367 guard.dirty = false;
368 if let Some(checkpoint_id) = guard.parent_checkpoint_id {
369 return Ok(checkpoint_id);
370 }
371 }
372
373 let checkpoint_id = persist_checkpoint(
374 &self.state,
375 self.session_id,
376 self.run_id,
377 &self.active_model,
378 guard.parent_checkpoint_id,
379 &guard.latest_messages,
380 )
381 .await?;
382
383 guard.parent_checkpoint_id = Some(checkpoint_id);
384 guard.last_persisted_signature = Some(signature);
385 guard.dirty = false;
386
387 Ok(checkpoint_id)
388 }
389}
390
391struct ServerCheckpointHook {
392 checkpoint_runtime: Arc<CheckpointRuntime>,
393}
394
395#[async_trait]
396impl AgentHook for ServerCheckpointHook {
397 async fn before_inference(
398 &self,
399 _run: &AgentRunContext,
400 messages: &[Message],
401 _model: &stakai::Model,
402 ) -> Result<(), stakpak_agent_core::AgentError> {
403 self.checkpoint_runtime.update_messages(messages).await;
404 Ok(())
405 }
406
407 async fn after_inference(
408 &self,
409 _run: &AgentRunContext,
410 messages: &[Message],
411 _model: &stakai::Model,
412 ) -> Result<(), stakpak_agent_core::AgentError> {
413 self.checkpoint_runtime.update_messages(messages).await;
414 Ok(())
415 }
416
417 async fn after_tool_execution(
418 &self,
419 _run: &AgentRunContext,
420 _tool_call: &ProposedToolCall,
421 messages: &[Message],
422 ) -> Result<(), stakpak_agent_core::AgentError> {
423 self.checkpoint_runtime.update_messages(messages).await;
424 Ok(())
425 }
426
427 async fn on_error(
428 &self,
429 _run: &AgentRunContext,
430 _error: &stakpak_agent_core::AgentError,
431 messages: &[Message],
432 ) -> Result<(), stakpak_agent_core::AgentError> {
433 self.checkpoint_runtime.update_messages(messages).await;
434 let _ = self.checkpoint_runtime.persist_snapshot().await;
435 Ok(())
436 }
437}
438
439async fn execute_mcp_tool_call(
440 state: &AppState,
441 session_id: Uuid,
442 run_id: Uuid,
443 tool_call: &ProposedToolCall,
444 cancel: &CancellationToken,
445) -> ToolExecutionResult {
446 let Some(mcp_client) = state.mcp_client.as_ref() else {
447 return ToolExecutionResult::Completed {
448 result: "MCP client is not initialized".to_string(),
449 is_error: true,
450 };
451 };
452
453 execute_mcp_tool_call_with_client(mcp_client, session_id, run_id, tool_call, cancel).await
454}
455
456async fn execute_mcp_tool_call_with_client(
457 mcp_client: &McpClient,
458 session_id: Uuid,
459 run_id: Uuid,
460 tool_call: &ProposedToolCall,
461 cancel: &CancellationToken,
462) -> ToolExecutionResult {
463 let metadata = Some(serde_json::Map::from_iter([
464 (
465 "session_id".to_string(),
466 serde_json::Value::String(session_id.to_string()),
467 ),
468 (
469 "run_id".to_string(),
470 serde_json::Value::String(run_id.to_string()),
471 ),
472 (
473 "tool_call_id".to_string(),
474 serde_json::Value::String(tool_call.id.clone()),
475 ),
476 ]));
477
478 let arguments = match &tool_call.arguments {
479 serde_json::Value::Object(map) => Some(map.clone()),
480 serde_json::Value::Null => None,
481 other => Some(serde_json::Map::from_iter([(
482 "input".to_string(),
483 other.clone(),
484 )])),
485 };
486
487 let request_handle = match stakpak_mcp_client::call_tool(
488 mcp_client,
489 CallToolRequestParam {
490 name: tool_call.name.clone().into(),
491 arguments,
492 },
493 metadata,
494 )
495 .await
496 {
497 Ok(handle) => handle,
498 Err(error) => {
499 return ToolExecutionResult::Completed {
500 result: format!("MCP tool call failed: {error}"),
501 is_error: true,
502 };
503 }
504 };
505
506 let peer_for_cancel = request_handle.peer.clone();
507 let request_id = request_handle.id.clone();
508
509 tokio::select! {
510 _ = cancel.cancelled() => {
511 let notification = CancelledNotification {
512 method: CancelledNotificationMethod,
513 params: CancelledNotificationParam {
514 request_id,
515 reason: Some("user cancel".to_string()),
516 },
517 extensions: Default::default(),
518 };
519
520 let _ = peer_for_cancel.send_notification(notification.into()).await;
521 ToolExecutionResult::Cancelled
522 }
523 server_result = request_handle.await_response() => {
524 match server_result {
525 Ok(ServerResult::CallToolResult(result)) => {
526 ToolExecutionResult::Completed {
527 result: render_call_tool_result(&result),
528 is_error: result.is_error.unwrap_or(false),
529 }
530 }
531 Ok(_) => ToolExecutionResult::Completed {
532 result: "Unexpected MCP response type".to_string(),
533 is_error: true,
534 },
535 Err(error) => ToolExecutionResult::Completed {
536 result: format!("MCP tool execution error: {error}"),
537 is_error: true,
538 },
539 }
540 }
541 }
542}
543
544fn render_call_tool_result(result: &rmcp::model::CallToolResult) -> String {
545 let rendered = result
546 .content
547 .iter()
548 .filter_map(|content| content.raw.as_text().map(|text| text.text.clone()))
549 .collect::<Vec<_>>()
550 .join("\n");
551
552 if !rendered.is_empty() {
553 return sanitize_text_output(&rendered);
554 }
555
556 if result.content.is_empty() {
557 return "<empty tool result>".to_string();
558 }
559
560 "<non-text tool result omitted for safety>".to_string()
561}
562
563fn checkpoint_signature(messages: &[Message]) -> Result<String, String> {
564 serde_json::to_string(messages)
565 .map_err(|error| format!("Failed to serialize checkpoint messages: {error}"))
566}
567
568async fn persist_checkpoint(
569 state: &AppState,
570 session_id: Uuid,
571 run_id: Uuid,
572 active_model: &stakai::Model,
573 parent_id: Option<Uuid>,
574 messages: &[Message],
575) -> Result<Uuid, String> {
576 let mut request = CreateCheckpointRequest::new(message_bridge::stakai_to_chat(messages));
579
580 if let Some(parent_id) = parent_id {
581 request = request.with_parent(parent_id);
582 }
583
584 let checkpoint = state
585 .session_store
586 .create_checkpoint(session_id, &request)
587 .await
588 .map_err(|error| error.to_string())?;
589
590 let envelope = build_checkpoint_envelope(
591 run_id,
592 messages.to_vec(),
593 json!({
594 "session_id": session_id.to_string(),
595 "checkpoint_id": checkpoint.id.to_string(),
596 (ACTIVE_MODEL_METADATA_KEY): format!("{}/{}", active_model.provider, active_model.id),
597 }),
598 );
599
600 state
601 .checkpoint_store
602 .save_latest(session_id, &envelope)
603 .await
604 .map_err(|error| {
605 format!(
606 "Failed to persist checkpoint envelope for session {}: {}",
607 session_id, error
608 )
609 })?;
610
611 Ok(checkpoint.id)
612}
613
#[cfg(test)]
mod tests {
    use super::*;
    use rmcp::model::{CallToolResult, Content};
    use serde_json::json;
    use stakai::{Message, Role};

    /// Both ids must pass through `build_run_context` untouched.
    #[test]
    fn run_id_is_not_regenerated_when_building_run_context() {
        let (session_id, run_id) = (Uuid::new_v4(), Uuid::new_v4());

        let context = build_run_context(session_id, run_id);

        assert_eq!((context.session_id, context.run_id), (session_id, run_id));
    }

    /// The envelope must be tagged with the same run id that produced it.
    #[test]
    fn checkpoint_envelope_carries_same_run_id() {
        let run_id = Uuid::new_v4();
        let messages = vec![Message::new(Role::User, "hello")];

        let envelope = build_checkpoint_envelope(run_id, messages, json!({"turn": 1}));

        assert_eq!(envelope.run_id, Some(run_id));
    }

    /// Control characters in text blocks are stripped by the sanitizer.
    #[test]
    fn render_call_tool_result_sanitizes_text_blocks() {
        let text_only = CallToolResult::success(vec![Content::text("ok\u{0007}done")]);

        assert_eq!(render_call_tool_result(&text_only), "okdone");
    }

    /// Image payloads are never surfaced verbatim.
    #[test]
    fn render_call_tool_result_omits_non_text_blocks() {
        let image_only = CallToolResult::success(vec![Content::image("dGVzdA==", "image/png")]);

        assert_eq!(
            render_call_tool_result(&image_only),
            "<non-text tool result omitted for safety>"
        );
    }

    /// Appending a message must change the signature used for change detection.
    #[test]
    fn checkpoint_signature_changes_when_messages_change() {
        let shorter = vec![Message::new(Role::User, "hello")];
        let longer = vec![
            Message::new(Role::User, "hello"),
            Message::new(Role::Assistant, "hi"),
        ];
        let signature = |messages: &[Message]| {
            checkpoint_signature(messages)
                .unwrap_or_else(|error| panic!("signature failed: {error}"))
        };

        assert_ne!(signature(shorter.as_slice()), signature(longer.as_slice()));
    }
}