Skip to main content

stakpak_server/
session_actor.rs

1use crate::{message_bridge, state::AppState, types::SessionHandle};
2use async_trait::async_trait;
3use rmcp::model::{
4    CallToolRequestParam, CancelledNotification, CancelledNotificationMethod,
5    CancelledNotificationParam, ServerResult,
6};
7use serde_json::json;
8use stakai::Message;
9use stakpak_agent_core::{
10    AgentCommand, AgentConfig, AgentEvent, AgentHook, AgentRunContext, CheckpointEnvelopeV1,
11    CompactionConfig, ContextConfig, PassthroughCompactionEngine, ProposedToolCall, RetryConfig,
12    ToolExecutionResult, ToolExecutor, run_agent,
13};
14use stakpak_api::CreateCheckpointRequest;
15use stakpak_shared::utils::sanitize_text_output;
16use std::sync::Arc;
17use tokio::sync::{Mutex, mpsc};
18use tokio_util::sync::CancellationToken;
19use uuid::Uuid;
20
/// Upper bound on agent turns passed to `run_agent` via `AgentConfig::max_turns`.
const MAX_TURNS: usize = 64;
/// How often the background task flushes the in-memory transcript to checkpoint storage.
const CHECKPOINT_FLUSH_INTERVAL: std::time::Duration = std::time::Duration::from_millis(5_000);
/// Checkpoint-envelope metadata key holding the run's model as `"<provider>/<id>"`.
pub(crate) const ACTIVE_MODEL_METADATA_KEY: &str = "active_model";
24
25pub fn build_run_context(session_id: Uuid, run_id: Uuid) -> AgentRunContext {
26    AgentRunContext { run_id, session_id }
27}
28
29pub fn build_checkpoint_envelope(
30    run_id: Uuid,
31    messages: Vec<stakai::Message>,
32    metadata: serde_json::Value,
33) -> CheckpointEnvelopeV1 {
34    CheckpointEnvelopeV1::new(Some(run_id), messages, metadata)
35}
36
37pub fn spawn_session_actor(
38    state: AppState,
39    session_id: Uuid,
40    run_id: Uuid,
41    model: stakai::Model,
42    user_message: Message,
43) -> Result<SessionHandle, String> {
44    let (command_tx, command_rx) = mpsc::channel(128);
45    let cancel = CancellationToken::new();
46
47    let handle = SessionHandle::new(command_tx, cancel.clone());
48
49    let state_for_task = state.clone();
50    tokio::spawn(async move {
51        let actor_result = run_session_actor(
52            state_for_task.clone(),
53            session_id,
54            run_id,
55            model,
56            user_message,
57            command_rx,
58            cancel,
59        )
60        .await;
61
62        let finish_result = actor_result.map(|_| ());
63        let _ = state_for_task
64            .run_manager
65            .mark_run_finished(session_id, run_id, finish_result)
66            .await;
67    });
68
69    Ok(handle)
70}
71
72async fn run_session_actor(
73    state: AppState,
74    session_id: Uuid,
75    run_id: Uuid,
76    model: stakai::Model,
77    user_message: Message,
78    command_rx: mpsc::Receiver<AgentCommand>,
79    cancel: CancellationToken,
80) -> Result<(), String> {
81    let active_checkpoint = state
82        .session_store
83        .get_active_checkpoint(session_id)
84        .await
85        .ok();
86    let parent_checkpoint_id = active_checkpoint.as_ref().map(|checkpoint| checkpoint.id);
87
88    let initial_messages = match state.checkpoint_store.load_latest(session_id).await {
89        Ok(Some(envelope)) => envelope.messages,
90        Ok(None) => active_checkpoint
91            .map(|checkpoint| message_bridge::chat_to_stakai(checkpoint.state.messages))
92            .unwrap_or_default(),
93        Err(error) => {
94            return Err(format!("Failed to load checkpoint envelope: {error}"));
95        }
96    };
97
98    let mut baseline_messages = initial_messages.clone();
99    baseline_messages.push(user_message.clone());
100
101    let checkpoint_runtime = Arc::new(CheckpointRuntime::new(
102        state.clone(),
103        session_id,
104        run_id,
105        model.clone(),
106        parent_checkpoint_id,
107        baseline_messages,
108    ));
109
110    checkpoint_runtime
111        .persist_snapshot()
112        .await
113        .map_err(|error| format!("Failed to persist baseline checkpoint: {error}"))?;
114
115    let periodic_checkpoint_cancel = CancellationToken::new();
116    let periodic_checkpoint_runtime = checkpoint_runtime.clone();
117    let periodic_checkpoint_cancel_task = periodic_checkpoint_cancel.clone();
118    let periodic_task = tokio::spawn(async move {
119        let mut interval = tokio::time::interval(CHECKPOINT_FLUSH_INTERVAL);
120        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
121
122        loop {
123            tokio::select! {
124                _ = periodic_checkpoint_cancel_task.cancelled() => break,
125                _ = interval.tick() => {
126                    let _ = periodic_checkpoint_runtime.persist_snapshot().await;
127                }
128            }
129        }
130    });
131
132    let (core_event_tx, mut core_event_rx) = mpsc::channel::<AgentEvent>(256);
133
134    let event_state = state.clone();
135    let event_forwarder = tokio::spawn(async move {
136        while let Some(event) = core_event_rx.recv().await {
137            handle_core_event(&event_state, session_id, run_id, event).await;
138        }
139    });
140
141    let run_tools = state.current_mcp_tools().await;
142
143    let agent_config = AgentConfig {
144        model,
145        system_prompt: String::new(),
146        max_turns: MAX_TURNS,
147        max_output_tokens: 0,
148        provider_options: None,
149        tool_approval: state.tool_approval_policy.clone(),
150        context: ContextConfig::default(),
151        retry: RetryConfig::default(),
152        compaction: CompactionConfig::default(),
153        tools: run_tools,
154    };
155
156    let tool_executor = ServerToolExecutor {
157        state: state.clone(),
158    };
159
160    let hooks: Vec<Box<dyn AgentHook>> = vec![Box::new(ServerCheckpointHook {
161        checkpoint_runtime: checkpoint_runtime.clone(),
162    })];
163
164    let compactor = PassthroughCompactionEngine;
165    let run_context = build_run_context(session_id, run_id);
166
167    let run_result = run_agent(
168        run_context,
169        state.inference.as_ref(),
170        &agent_config,
171        initial_messages,
172        user_message,
173        &tool_executor,
174        &hooks,
175        core_event_tx,
176        command_rx,
177        cancel,
178        &compactor,
179    )
180    .await;
181
182    periodic_checkpoint_cancel.cancel();
183    let _ = periodic_task.await;
184
185    state.clear_pending_tools(session_id, run_id).await;
186
187    match &run_result {
188        Ok(result) => {
189            checkpoint_runtime.update_messages(&result.messages).await;
190            checkpoint_runtime
191                .persist_snapshot()
192                .await
193                .map_err(|error| format!("Failed to persist terminal checkpoint: {error}"))?;
194        }
195        Err(_) => {
196            let _ = checkpoint_runtime.persist_snapshot().await;
197        }
198    }
199
200    let _ = tokio::time::timeout(std::time::Duration::from_secs(2), event_forwarder).await;
201
202    run_result
203        .map(|_| ())
204        .map_err(|error| format!("Agent run failed: {error}"))
205}
206
207async fn handle_core_event(state: &AppState, session_id: Uuid, run_id: Uuid, event: AgentEvent) {
208    match &event {
209        AgentEvent::ToolCallsProposed { tool_calls, .. } => {
210            state
211                .set_pending_tools(session_id, run_id, tool_calls.clone())
212                .await;
213        }
214        AgentEvent::TurnCompleted { .. }
215        | AgentEvent::RunCompleted { .. }
216        | AgentEvent::RunError { .. } => {
217            state.clear_pending_tools(session_id, run_id).await;
218        }
219        _ => {}
220    }
221
222    state.events.publish(session_id, Some(run_id), event).await;
223}
224
225#[derive(Clone)]
226struct ServerToolExecutor {
227    state: AppState,
228}
229
230#[async_trait]
231impl ToolExecutor for ServerToolExecutor {
232    async fn execute_tool_call(
233        &self,
234        run: &AgentRunContext,
235        tool_call: &ProposedToolCall,
236        cancel: &CancellationToken,
237    ) -> Result<ToolExecutionResult, stakpak_agent_core::AgentError> {
238        Ok(execute_mcp_tool_call(&self.state, run.session_id, run.run_id, tool_call, cancel).await)
239    }
240}
241
242struct CheckpointRuntime {
243    state: AppState,
244    session_id: Uuid,
245    run_id: Uuid,
246    active_model: stakai::Model,
247    inner: Mutex<CheckpointRuntimeInner>,
248}
249
250struct CheckpointRuntimeInner {
251    parent_checkpoint_id: Option<Uuid>,
252    latest_messages: Vec<Message>,
253    last_persisted_signature: Option<String>,
254    dirty: bool,
255}
256
257impl CheckpointRuntime {
258    fn new(
259        state: AppState,
260        session_id: Uuid,
261        run_id: Uuid,
262        active_model: stakai::Model,
263        parent_checkpoint_id: Option<Uuid>,
264        latest_messages: Vec<Message>,
265    ) -> Self {
266        Self {
267            state,
268            session_id,
269            run_id,
270            active_model,
271            inner: Mutex::new(CheckpointRuntimeInner {
272                parent_checkpoint_id,
273                latest_messages,
274                last_persisted_signature: None,
275                dirty: true,
276            }),
277        }
278    }
279
280    async fn update_messages(&self, messages: &[Message]) {
281        let mut guard = self.inner.lock().await;
282        guard.latest_messages = messages.to_vec();
283        guard.dirty = true;
284    }
285
286    async fn persist_snapshot(&self) -> Result<Uuid, String> {
287        let mut guard = self.inner.lock().await;
288        self.persist_if_needed(&mut guard).await
289    }
290
291    async fn persist_if_needed(&self, guard: &mut CheckpointRuntimeInner) -> Result<Uuid, String> {
292        if !guard.dirty
293            && let Some(checkpoint_id) = guard.parent_checkpoint_id
294        {
295            return Ok(checkpoint_id);
296        }
297
298        let signature = checkpoint_signature(&guard.latest_messages)?;
299        let changed = guard.last_persisted_signature.as_deref() != Some(signature.as_str());
300        let should_persist = guard.parent_checkpoint_id.is_none() || (guard.dirty && changed);
301
302        if !should_persist {
303            guard.dirty = false;
304            if let Some(checkpoint_id) = guard.parent_checkpoint_id {
305                return Ok(checkpoint_id);
306            }
307        }
308
309        let checkpoint_id = persist_checkpoint(
310            &self.state,
311            self.session_id,
312            self.run_id,
313            &self.active_model,
314            guard.parent_checkpoint_id,
315            &guard.latest_messages,
316        )
317        .await?;
318
319        guard.parent_checkpoint_id = Some(checkpoint_id);
320        guard.last_persisted_signature = Some(signature);
321        guard.dirty = false;
322
323        Ok(checkpoint_id)
324    }
325}
326
327struct ServerCheckpointHook {
328    checkpoint_runtime: Arc<CheckpointRuntime>,
329}
330
331#[async_trait]
332impl AgentHook for ServerCheckpointHook {
333    async fn before_inference(
334        &self,
335        _run: &AgentRunContext,
336        messages: &[Message],
337        _model: &stakai::Model,
338    ) -> Result<(), stakpak_agent_core::AgentError> {
339        self.checkpoint_runtime.update_messages(messages).await;
340        Ok(())
341    }
342
343    async fn after_inference(
344        &self,
345        _run: &AgentRunContext,
346        messages: &[Message],
347        _model: &stakai::Model,
348    ) -> Result<(), stakpak_agent_core::AgentError> {
349        self.checkpoint_runtime.update_messages(messages).await;
350        Ok(())
351    }
352
353    async fn after_tool_execution(
354        &self,
355        _run: &AgentRunContext,
356        _tool_call: &ProposedToolCall,
357        messages: &[Message],
358    ) -> Result<(), stakpak_agent_core::AgentError> {
359        self.checkpoint_runtime.update_messages(messages).await;
360        Ok(())
361    }
362
363    async fn on_error(
364        &self,
365        _run: &AgentRunContext,
366        _error: &stakpak_agent_core::AgentError,
367        messages: &[Message],
368    ) -> Result<(), stakpak_agent_core::AgentError> {
369        self.checkpoint_runtime.update_messages(messages).await;
370        let _ = self.checkpoint_runtime.persist_snapshot().await;
371        Ok(())
372    }
373}
374
375async fn execute_mcp_tool_call(
376    state: &AppState,
377    session_id: Uuid,
378    run_id: Uuid,
379    tool_call: &ProposedToolCall,
380    cancel: &CancellationToken,
381) -> ToolExecutionResult {
382    let Some(mcp_client) = state.mcp_client.as_ref() else {
383        return ToolExecutionResult::Completed {
384            result: "MCP client is not initialized".to_string(),
385            is_error: true,
386        };
387    };
388
389    let metadata = Some(serde_json::Map::from_iter([
390        (
391            "session_id".to_string(),
392            serde_json::Value::String(session_id.to_string()),
393        ),
394        (
395            "run_id".to_string(),
396            serde_json::Value::String(run_id.to_string()),
397        ),
398        (
399            "tool_call_id".to_string(),
400            serde_json::Value::String(tool_call.id.clone()),
401        ),
402    ]));
403
404    let arguments = match &tool_call.arguments {
405        serde_json::Value::Object(map) => Some(map.clone()),
406        serde_json::Value::Null => None,
407        other => Some(serde_json::Map::from_iter([(
408            "input".to_string(),
409            other.clone(),
410        )])),
411    };
412
413    let request_handle = match stakpak_mcp_client::call_tool(
414        mcp_client,
415        CallToolRequestParam {
416            name: tool_call.name.clone().into(),
417            arguments,
418        },
419        metadata,
420    )
421    .await
422    {
423        Ok(handle) => handle,
424        Err(error) => {
425            return ToolExecutionResult::Completed {
426                result: format!("MCP tool call failed: {error}"),
427                is_error: true,
428            };
429        }
430    };
431
432    let peer_for_cancel = request_handle.peer.clone();
433    let request_id = request_handle.id.clone();
434
435    tokio::select! {
436        _ = cancel.cancelled() => {
437            let notification = CancelledNotification {
438                method: CancelledNotificationMethod,
439                params: CancelledNotificationParam {
440                    request_id,
441                    reason: Some("user cancel".to_string()),
442                },
443                extensions: Default::default(),
444            };
445
446            let _ = peer_for_cancel.send_notification(notification.into()).await;
447            ToolExecutionResult::Cancelled
448        }
449        server_result = request_handle.await_response() => {
450            match server_result {
451                Ok(ServerResult::CallToolResult(result)) => {
452                    ToolExecutionResult::Completed {
453                        result: render_call_tool_result(&result),
454                        is_error: result.is_error.unwrap_or(false),
455                    }
456                }
457                Ok(_) => ToolExecutionResult::Completed {
458                    result: "Unexpected MCP response type".to_string(),
459                    is_error: true,
460                },
461                Err(error) => ToolExecutionResult::Completed {
462                    result: format!("MCP tool execution error: {error}"),
463                    is_error: true,
464                },
465            }
466        }
467    }
468}
469
470fn render_call_tool_result(result: &rmcp::model::CallToolResult) -> String {
471    let rendered = result
472        .content
473        .iter()
474        .filter_map(|content| content.raw.as_text().map(|text| text.text.clone()))
475        .collect::<Vec<_>>()
476        .join("\n");
477
478    if !rendered.is_empty() {
479        return sanitize_text_output(&rendered);
480    }
481
482    if result.content.is_empty() {
483        return "<empty tool result>".to_string();
484    }
485
486    "<non-text tool result omitted for safety>".to_string()
487}
488
489fn checkpoint_signature(messages: &[Message]) -> Result<String, String> {
490    serde_json::to_string(messages)
491        .map_err(|error| format!("Failed to serialize checkpoint messages: {error}"))
492}
493
494async fn persist_checkpoint(
495    state: &AppState,
496    session_id: Uuid,
497    run_id: Uuid,
498    active_model: &stakai::Model,
499    parent_id: Option<Uuid>,
500    messages: &[Message],
501) -> Result<Uuid, String> {
502    // TODO(ahmed): Migrate server/session checkpoint storage to `Vec<stakai::Message>` directly
503    // and remove the ChatMessage adapter conversion (`message_bridge::stakai_to_chat`).
504    let mut request = CreateCheckpointRequest::new(message_bridge::stakai_to_chat(messages));
505
506    if let Some(parent_id) = parent_id {
507        request = request.with_parent(parent_id);
508    }
509
510    let checkpoint = state
511        .session_store
512        .create_checkpoint(session_id, &request)
513        .await
514        .map_err(|error| error.to_string())?;
515
516    let envelope = build_checkpoint_envelope(
517        run_id,
518        messages.to_vec(),
519        json!({
520            "session_id": session_id.to_string(),
521            "checkpoint_id": checkpoint.id.to_string(),
522            (ACTIVE_MODEL_METADATA_KEY): format!("{}/{}", active_model.provider, active_model.id),
523        }),
524    );
525
526    state
527        .checkpoint_store
528        .save_latest(session_id, &envelope)
529        .await
530        .map_err(|error| {
531            format!(
532                "Failed to persist checkpoint envelope for session {}: {}",
533                session_id, error
534            )
535        })?;
536
537    Ok(checkpoint.id)
538}
539
#[cfg(test)]
mod tests {
    use super::*;
    use rmcp::model::{CallToolResult, Content};
    use serde_json::json;
    use stakai::{Message, Role};

    #[test]
    fn run_id_is_not_regenerated_when_building_run_context() {
        let session_id = Uuid::new_v4();
        let run_id = Uuid::new_v4();

        let run_context = build_run_context(session_id, run_id);

        assert_eq!(run_context.session_id, session_id);
        assert_eq!(run_context.run_id, run_id);
    }

    #[test]
    fn checkpoint_envelope_carries_same_run_id() {
        let run_id = Uuid::new_v4();
        let envelope = build_checkpoint_envelope(
            run_id,
            vec![Message::new(Role::User, "hello")],
            json!({"turn": 1}),
        );

        assert_eq!(envelope.run_id, Some(run_id));
    }

    #[test]
    fn render_call_tool_result_sanitizes_text_blocks() {
        let result = CallToolResult::success(vec![Content::text("ok\u{0007}done")]);

        assert_eq!(render_call_tool_result(&result), "okdone");
    }

    #[test]
    fn render_call_tool_result_omits_non_text_blocks() {
        let result = CallToolResult::success(vec![Content::image("dGVzdA==", "image/png")]);

        assert_eq!(
            render_call_tool_result(&result),
            "<non-text tool result omitted for safety>"
        );
    }

    #[test]
    fn render_call_tool_result_reports_empty_content() {
        let result = CallToolResult::success(vec![]);

        assert_eq!(render_call_tool_result(&result), "<empty tool result>");
    }

    #[test]
    fn checkpoint_signature_changes_when_messages_change() {
        let messages_a = vec![Message::new(Role::User, "hello")];
        let messages_b = vec![
            Message::new(Role::User, "hello"),
            Message::new(Role::Assistant, "hi"),
        ];

        let sig_a = checkpoint_signature(&messages_a)
            .unwrap_or_else(|error| panic!("signature failed: {error}"));
        let sig_b = checkpoint_signature(&messages_b)
            .unwrap_or_else(|error| panic!("signature failed: {error}"));

        assert_ne!(sig_a, sig_b);
    }

    #[test]
    fn checkpoint_signature_is_stable_for_identical_messages() {
        // The "unchanged since last write" check in CheckpointRuntime relies on this.
        let messages = vec![
            Message::new(Role::User, "hello"),
            Message::new(Role::Assistant, "hi"),
        ];

        let first = checkpoint_signature(&messages)
            .unwrap_or_else(|error| panic!("signature failed: {error}"));
        let second = checkpoint_signature(&messages)
            .unwrap_or_else(|error| panic!("signature failed: {error}"));

        assert_eq!(first, second);
    }
}