1use crate::{message_bridge, state::AppState, types::SessionHandle};
2use async_trait::async_trait;
3use rmcp::model::{
4 CallToolRequestParam, CancelledNotification, CancelledNotificationMethod,
5 CancelledNotificationParam, ServerResult,
6};
7use serde_json::json;
8use stakai::Message;
9use stakpak_agent_core::{
10 AgentCommand, AgentConfig, AgentEvent, AgentHook, AgentRunContext, CheckpointEnvelopeV1,
11 CompactionConfig, ContextConfig, PassthroughCompactionEngine, ProposedToolCall, RetryConfig,
12 ToolExecutionResult, ToolExecutor, run_agent,
13};
14use stakpak_api::CreateCheckpointRequest;
15use stakpak_shared::utils::sanitize_text_output;
16use std::sync::Arc;
17use tokio::sync::{Mutex, mpsc};
18use tokio_util::sync::CancellationToken;
19use uuid::Uuid;
20
/// Turn budget handed to the agent core as `AgentConfig::max_turns`.
const MAX_TURNS: usize = 64;
/// Period of the background task that flushes the in-memory checkpoint snapshot.
const CHECKPOINT_FLUSH_INTERVAL: std::time::Duration = std::time::Duration::from_millis(5_000);
/// Envelope metadata key holding the active model as a `provider/id` string.
pub(crate) const ACTIVE_MODEL_METADATA_KEY: &str = "active_model";
24
25pub fn build_run_context(session_id: Uuid, run_id: Uuid) -> AgentRunContext {
26 AgentRunContext { run_id, session_id }
27}
28
29pub fn build_checkpoint_envelope(
30 run_id: Uuid,
31 messages: Vec<stakai::Message>,
32 metadata: serde_json::Value,
33) -> CheckpointEnvelopeV1 {
34 CheckpointEnvelopeV1::new(Some(run_id), messages, metadata)
35}
36
37pub fn spawn_session_actor(
38 state: AppState,
39 session_id: Uuid,
40 run_id: Uuid,
41 model: stakai::Model,
42 user_message: Message,
43) -> Result<SessionHandle, String> {
44 let (command_tx, command_rx) = mpsc::channel(128);
45 let cancel = CancellationToken::new();
46
47 let handle = SessionHandle::new(command_tx, cancel.clone());
48
49 let state_for_task = state.clone();
50 tokio::spawn(async move {
51 let actor_result = run_session_actor(
52 state_for_task.clone(),
53 session_id,
54 run_id,
55 model,
56 user_message,
57 command_rx,
58 cancel,
59 )
60 .await;
61
62 let finish_result = actor_result.map(|_| ());
63 let _ = state_for_task
64 .run_manager
65 .mark_run_finished(session_id, run_id, finish_result)
66 .await;
67 });
68
69 Ok(handle)
70}
71
72async fn run_session_actor(
73 state: AppState,
74 session_id: Uuid,
75 run_id: Uuid,
76 model: stakai::Model,
77 user_message: Message,
78 command_rx: mpsc::Receiver<AgentCommand>,
79 cancel: CancellationToken,
80) -> Result<(), String> {
81 let active_checkpoint = state
82 .session_store
83 .get_active_checkpoint(session_id)
84 .await
85 .ok();
86 let parent_checkpoint_id = active_checkpoint.as_ref().map(|checkpoint| checkpoint.id);
87
88 let initial_messages = match state.checkpoint_store.load_latest(session_id).await {
89 Ok(Some(envelope)) => envelope.messages,
90 Ok(None) => active_checkpoint
91 .map(|checkpoint| message_bridge::chat_to_stakai(checkpoint.state.messages))
92 .unwrap_or_default(),
93 Err(error) => {
94 return Err(format!("Failed to load checkpoint envelope: {error}"));
95 }
96 };
97
98 let mut baseline_messages = initial_messages.clone();
99 baseline_messages.push(user_message.clone());
100
101 let checkpoint_runtime = Arc::new(CheckpointRuntime::new(
102 state.clone(),
103 session_id,
104 run_id,
105 model.clone(),
106 parent_checkpoint_id,
107 baseline_messages,
108 ));
109
110 checkpoint_runtime
111 .persist_snapshot()
112 .await
113 .map_err(|error| format!("Failed to persist baseline checkpoint: {error}"))?;
114
115 let periodic_checkpoint_cancel = CancellationToken::new();
116 let periodic_checkpoint_runtime = checkpoint_runtime.clone();
117 let periodic_checkpoint_cancel_task = periodic_checkpoint_cancel.clone();
118 let periodic_task = tokio::spawn(async move {
119 let mut interval = tokio::time::interval(CHECKPOINT_FLUSH_INTERVAL);
120 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
121
122 loop {
123 tokio::select! {
124 _ = periodic_checkpoint_cancel_task.cancelled() => break,
125 _ = interval.tick() => {
126 let _ = periodic_checkpoint_runtime.persist_snapshot().await;
127 }
128 }
129 }
130 });
131
132 let (core_event_tx, mut core_event_rx) = mpsc::channel::<AgentEvent>(256);
133
134 let event_state = state.clone();
135 let event_forwarder = tokio::spawn(async move {
136 while let Some(event) = core_event_rx.recv().await {
137 handle_core_event(&event_state, session_id, run_id, event).await;
138 }
139 });
140
141 let run_tools = state.current_mcp_tools().await;
142
143 let agent_config = AgentConfig {
144 model,
145 system_prompt: String::new(),
146 max_turns: MAX_TURNS,
147 max_output_tokens: 0,
148 provider_options: None,
149 tool_approval: state.tool_approval_policy.clone(),
150 context: ContextConfig::default(),
151 retry: RetryConfig::default(),
152 compaction: CompactionConfig::default(),
153 tools: run_tools,
154 };
155
156 let tool_executor = ServerToolExecutor {
157 state: state.clone(),
158 };
159
160 let hooks: Vec<Box<dyn AgentHook>> = vec![Box::new(ServerCheckpointHook {
161 checkpoint_runtime: checkpoint_runtime.clone(),
162 })];
163
164 let compactor = PassthroughCompactionEngine;
165 let run_context = build_run_context(session_id, run_id);
166
167 let run_result = run_agent(
168 run_context,
169 state.inference.as_ref(),
170 &agent_config,
171 initial_messages,
172 user_message,
173 &tool_executor,
174 &hooks,
175 core_event_tx,
176 command_rx,
177 cancel,
178 &compactor,
179 )
180 .await;
181
182 periodic_checkpoint_cancel.cancel();
183 let _ = periodic_task.await;
184
185 state.clear_pending_tools(session_id, run_id).await;
186
187 match &run_result {
188 Ok(result) => {
189 checkpoint_runtime.update_messages(&result.messages).await;
190 checkpoint_runtime
191 .persist_snapshot()
192 .await
193 .map_err(|error| format!("Failed to persist terminal checkpoint: {error}"))?;
194 }
195 Err(_) => {
196 let _ = checkpoint_runtime.persist_snapshot().await;
197 }
198 }
199
200 let _ = tokio::time::timeout(std::time::Duration::from_secs(2), event_forwarder).await;
201
202 run_result
203 .map(|_| ())
204 .map_err(|error| format!("Agent run failed: {error}"))
205}
206
207async fn handle_core_event(state: &AppState, session_id: Uuid, run_id: Uuid, event: AgentEvent) {
208 match &event {
209 AgentEvent::ToolCallsProposed { tool_calls, .. } => {
210 state
211 .set_pending_tools(session_id, run_id, tool_calls.clone())
212 .await;
213 }
214 AgentEvent::TurnCompleted { .. }
215 | AgentEvent::RunCompleted { .. }
216 | AgentEvent::RunError { .. } => {
217 state.clear_pending_tools(session_id, run_id).await;
218 }
219 _ => {}
220 }
221
222 state.events.publish(session_id, Some(run_id), event).await;
223}
224
225#[derive(Clone)]
226struct ServerToolExecutor {
227 state: AppState,
228}
229
230#[async_trait]
231impl ToolExecutor for ServerToolExecutor {
232 async fn execute_tool_call(
233 &self,
234 run: &AgentRunContext,
235 tool_call: &ProposedToolCall,
236 cancel: &CancellationToken,
237 ) -> Result<ToolExecutionResult, stakpak_agent_core::AgentError> {
238 Ok(execute_mcp_tool_call(&self.state, run.session_id, run.run_id, tool_call, cancel).await)
239 }
240}
241
242struct CheckpointRuntime {
243 state: AppState,
244 session_id: Uuid,
245 run_id: Uuid,
246 active_model: stakai::Model,
247 inner: Mutex<CheckpointRuntimeInner>,
248}
249
250struct CheckpointRuntimeInner {
251 parent_checkpoint_id: Option<Uuid>,
252 latest_messages: Vec<Message>,
253 last_persisted_signature: Option<String>,
254 dirty: bool,
255}
256
257impl CheckpointRuntime {
258 fn new(
259 state: AppState,
260 session_id: Uuid,
261 run_id: Uuid,
262 active_model: stakai::Model,
263 parent_checkpoint_id: Option<Uuid>,
264 latest_messages: Vec<Message>,
265 ) -> Self {
266 Self {
267 state,
268 session_id,
269 run_id,
270 active_model,
271 inner: Mutex::new(CheckpointRuntimeInner {
272 parent_checkpoint_id,
273 latest_messages,
274 last_persisted_signature: None,
275 dirty: true,
276 }),
277 }
278 }
279
280 async fn update_messages(&self, messages: &[Message]) {
281 let mut guard = self.inner.lock().await;
282 guard.latest_messages = messages.to_vec();
283 guard.dirty = true;
284 }
285
286 async fn persist_snapshot(&self) -> Result<Uuid, String> {
287 let mut guard = self.inner.lock().await;
288 self.persist_if_needed(&mut guard).await
289 }
290
291 async fn persist_if_needed(&self, guard: &mut CheckpointRuntimeInner) -> Result<Uuid, String> {
292 if !guard.dirty
293 && let Some(checkpoint_id) = guard.parent_checkpoint_id
294 {
295 return Ok(checkpoint_id);
296 }
297
298 let signature = checkpoint_signature(&guard.latest_messages)?;
299 let changed = guard.last_persisted_signature.as_deref() != Some(signature.as_str());
300 let should_persist = guard.parent_checkpoint_id.is_none() || (guard.dirty && changed);
301
302 if !should_persist {
303 guard.dirty = false;
304 if let Some(checkpoint_id) = guard.parent_checkpoint_id {
305 return Ok(checkpoint_id);
306 }
307 }
308
309 let checkpoint_id = persist_checkpoint(
310 &self.state,
311 self.session_id,
312 self.run_id,
313 &self.active_model,
314 guard.parent_checkpoint_id,
315 &guard.latest_messages,
316 )
317 .await?;
318
319 guard.parent_checkpoint_id = Some(checkpoint_id);
320 guard.last_persisted_signature = Some(signature);
321 guard.dirty = false;
322
323 Ok(checkpoint_id)
324 }
325}
326
327struct ServerCheckpointHook {
328 checkpoint_runtime: Arc<CheckpointRuntime>,
329}
330
331#[async_trait]
332impl AgentHook for ServerCheckpointHook {
333 async fn before_inference(
334 &self,
335 _run: &AgentRunContext,
336 messages: &[Message],
337 _model: &stakai::Model,
338 ) -> Result<(), stakpak_agent_core::AgentError> {
339 self.checkpoint_runtime.update_messages(messages).await;
340 Ok(())
341 }
342
343 async fn after_inference(
344 &self,
345 _run: &AgentRunContext,
346 messages: &[Message],
347 _model: &stakai::Model,
348 ) -> Result<(), stakpak_agent_core::AgentError> {
349 self.checkpoint_runtime.update_messages(messages).await;
350 Ok(())
351 }
352
353 async fn after_tool_execution(
354 &self,
355 _run: &AgentRunContext,
356 _tool_call: &ProposedToolCall,
357 messages: &[Message],
358 ) -> Result<(), stakpak_agent_core::AgentError> {
359 self.checkpoint_runtime.update_messages(messages).await;
360 Ok(())
361 }
362
363 async fn on_error(
364 &self,
365 _run: &AgentRunContext,
366 _error: &stakpak_agent_core::AgentError,
367 messages: &[Message],
368 ) -> Result<(), stakpak_agent_core::AgentError> {
369 self.checkpoint_runtime.update_messages(messages).await;
370 let _ = self.checkpoint_runtime.persist_snapshot().await;
371 Ok(())
372 }
373}
374
375async fn execute_mcp_tool_call(
376 state: &AppState,
377 session_id: Uuid,
378 run_id: Uuid,
379 tool_call: &ProposedToolCall,
380 cancel: &CancellationToken,
381) -> ToolExecutionResult {
382 let Some(mcp_client) = state.mcp_client.as_ref() else {
383 return ToolExecutionResult::Completed {
384 result: "MCP client is not initialized".to_string(),
385 is_error: true,
386 };
387 };
388
389 let metadata = Some(serde_json::Map::from_iter([
390 (
391 "session_id".to_string(),
392 serde_json::Value::String(session_id.to_string()),
393 ),
394 (
395 "run_id".to_string(),
396 serde_json::Value::String(run_id.to_string()),
397 ),
398 (
399 "tool_call_id".to_string(),
400 serde_json::Value::String(tool_call.id.clone()),
401 ),
402 ]));
403
404 let arguments = match &tool_call.arguments {
405 serde_json::Value::Object(map) => Some(map.clone()),
406 serde_json::Value::Null => None,
407 other => Some(serde_json::Map::from_iter([(
408 "input".to_string(),
409 other.clone(),
410 )])),
411 };
412
413 let request_handle = match stakpak_mcp_client::call_tool(
414 mcp_client,
415 CallToolRequestParam {
416 name: tool_call.name.clone().into(),
417 arguments,
418 },
419 metadata,
420 )
421 .await
422 {
423 Ok(handle) => handle,
424 Err(error) => {
425 return ToolExecutionResult::Completed {
426 result: format!("MCP tool call failed: {error}"),
427 is_error: true,
428 };
429 }
430 };
431
432 let peer_for_cancel = request_handle.peer.clone();
433 let request_id = request_handle.id.clone();
434
435 tokio::select! {
436 _ = cancel.cancelled() => {
437 let notification = CancelledNotification {
438 method: CancelledNotificationMethod,
439 params: CancelledNotificationParam {
440 request_id,
441 reason: Some("user cancel".to_string()),
442 },
443 extensions: Default::default(),
444 };
445
446 let _ = peer_for_cancel.send_notification(notification.into()).await;
447 ToolExecutionResult::Cancelled
448 }
449 server_result = request_handle.await_response() => {
450 match server_result {
451 Ok(ServerResult::CallToolResult(result)) => {
452 ToolExecutionResult::Completed {
453 result: render_call_tool_result(&result),
454 is_error: result.is_error.unwrap_or(false),
455 }
456 }
457 Ok(_) => ToolExecutionResult::Completed {
458 result: "Unexpected MCP response type".to_string(),
459 is_error: true,
460 },
461 Err(error) => ToolExecutionResult::Completed {
462 result: format!("MCP tool execution error: {error}"),
463 is_error: true,
464 },
465 }
466 }
467 }
468}
469
470fn render_call_tool_result(result: &rmcp::model::CallToolResult) -> String {
471 let rendered = result
472 .content
473 .iter()
474 .filter_map(|content| content.raw.as_text().map(|text| text.text.clone()))
475 .collect::<Vec<_>>()
476 .join("\n");
477
478 if !rendered.is_empty() {
479 return sanitize_text_output(&rendered);
480 }
481
482 if result.content.is_empty() {
483 return "<empty tool result>".to_string();
484 }
485
486 "<non-text tool result omitted for safety>".to_string()
487}
488
489fn checkpoint_signature(messages: &[Message]) -> Result<String, String> {
490 serde_json::to_string(messages)
491 .map_err(|error| format!("Failed to serialize checkpoint messages: {error}"))
492}
493
494async fn persist_checkpoint(
495 state: &AppState,
496 session_id: Uuid,
497 run_id: Uuid,
498 active_model: &stakai::Model,
499 parent_id: Option<Uuid>,
500 messages: &[Message],
501) -> Result<Uuid, String> {
502 let mut request = CreateCheckpointRequest::new(message_bridge::stakai_to_chat(messages));
505
506 if let Some(parent_id) = parent_id {
507 request = request.with_parent(parent_id);
508 }
509
510 let checkpoint = state
511 .session_store
512 .create_checkpoint(session_id, &request)
513 .await
514 .map_err(|error| error.to_string())?;
515
516 let envelope = build_checkpoint_envelope(
517 run_id,
518 messages.to_vec(),
519 json!({
520 "session_id": session_id.to_string(),
521 "checkpoint_id": checkpoint.id.to_string(),
522 (ACTIVE_MODEL_METADATA_KEY): format!("{}/{}", active_model.provider, active_model.id),
523 }),
524 );
525
526 state
527 .checkpoint_store
528 .save_latest(session_id, &envelope)
529 .await
530 .map_err(|error| {
531 format!(
532 "Failed to persist checkpoint envelope for session {}: {}",
533 session_id, error
534 )
535 })?;
536
537 Ok(checkpoint.id)
538}
539
#[cfg(test)]
mod tests {
    use super::*;
    use rmcp::model::{CallToolResult, Content};
    use serde_json::json;
    use stakai::{Message, Role};

    /// The context must carry the caller's ids through unchanged.
    #[test]
    fn run_id_is_not_regenerated_when_building_run_context() {
        let (session_id, run_id) = (Uuid::new_v4(), Uuid::new_v4());

        let context = build_run_context(session_id, run_id);

        assert_eq!((context.session_id, context.run_id), (session_id, run_id));
    }

    #[test]
    fn checkpoint_envelope_carries_same_run_id() {
        let run_id = Uuid::new_v4();
        let history = vec![Message::new(Role::User, "hello")];

        let envelope = build_checkpoint_envelope(run_id, history, json!({"turn": 1}));

        assert_eq!(envelope.run_id, Some(run_id));
    }

    #[test]
    fn render_call_tool_result_sanitizes_text_blocks() {
        let bell_in_text = CallToolResult::success(vec![Content::text("ok\u{0007}done")]);

        let rendered = render_call_tool_result(&bell_in_text);

        assert_eq!(rendered, "okdone");
    }

    #[test]
    fn render_call_tool_result_omits_non_text_blocks() {
        let image_only = CallToolResult::success(vec![Content::image("dGVzdA==", "image/png")]);

        let rendered = render_call_tool_result(&image_only);

        assert_eq!(rendered, "<non-text tool result omitted for safety>");
    }

    #[test]
    fn checkpoint_signature_changes_when_messages_change() {
        fn signature_of(messages: &[Message]) -> String {
            checkpoint_signature(messages)
                .unwrap_or_else(|error| panic!("signature failed: {error}"))
        }

        let before = signature_of(&[Message::new(Role::User, "hello")]);
        let after = signature_of(&[
            Message::new(Role::User, "hello"),
            Message::new(Role::Assistant, "hi"),
        ]);

        assert_ne!(before, after);
    }
}