Skip to main content

steer_core/session/
state.rs

1use crate::config::model::ModelId;
2use crate::error::Result;
3use crate::tools::builtin_tools::READ_ONLY_TOOL_NAMES;
4use chrono::{DateTime, Utc};
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet};
8use std::path::PathBuf;
9use std::sync::Arc;
10
11use crate::app::{Message, MessageData};
12use crate::tools::{BackendRegistry, McpTransport, ToolBackend};
13use steer_tools::{ToolCall, result::ToolResult};
14
/// Connection lifecycle of a single MCP server.
#[derive(Debug, Clone)]
pub enum McpConnectionState {
    /// A connection attempt is in progress.
    Connecting,
    /// The server is connected and its tools were enumerated.
    Connected {
        /// Names of the tools this server exposes.
        tool_names: Vec<String>,
    },
    /// The connection was closed in an orderly way.
    Disconnected {
        /// Why the connection was closed, when known.
        reason: Option<String>,
    },
    /// The connection attempt did not succeed.
    Failed {
        /// Human-readable description of the failure.
        error: String,
    },
}
36
37/// Information about an MCP server
38#[derive(Debug, Clone)]
39pub struct McpServerInfo {
40    /// The configured server name
41    pub server_name: String,
42    /// The transport configuration
43    pub transport: McpTransport,
44    /// Current connection state
45    pub state: McpConnectionState,
46    /// Timestamp when this state was last updated
47    pub last_updated: DateTime<Utc>,
48}
49
50/// Defines the primary execution environment for a session's workspace
51#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
52#[serde(tag = "type", rename_all = "snake_case")]
53pub enum WorkspaceConfig {
54    Local {
55        path: PathBuf,
56    },
57    Remote {
58        agent_address: String,
59        auth: Option<RemoteAuth>,
60    },
61}
62
63impl WorkspaceConfig {
64    pub fn get_path(&self) -> Option<String> {
65        match self {
66            WorkspaceConfig::Local { path } => Some(path.to_string_lossy().to_string()),
67            WorkspaceConfig::Remote { agent_address, .. } => Some(agent_address.clone()),
68        }
69    }
70
71    /// Convert to steer_workspace::WorkspaceConfig
72    pub fn to_workspace_config(&self) -> steer_workspace::WorkspaceConfig {
73        match self {
74            WorkspaceConfig::Local { path } => {
75                steer_workspace::WorkspaceConfig::Local { path: path.clone() }
76            }
77            WorkspaceConfig::Remote {
78                agent_address,
79                auth,
80            } => steer_workspace::WorkspaceConfig::Remote {
81                address: agent_address.clone(),
82                auth: auth.as_ref().map(|a| a.to_workspace_auth()),
83            },
84        }
85    }
86}
87
88impl Default for WorkspaceConfig {
89    fn default() -> Self {
90        Self::Local {
91            path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
92        }
93    }
94}
95
96/// Complete session representation
97#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct Session {
99    pub id: String,
100    pub created_at: DateTime<Utc>,
101    pub updated_at: DateTime<Utc>,
102    pub config: SessionConfig,
103    pub state: SessionState,
104}
105
106impl Session {
107    pub fn new(id: String, config: SessionConfig) -> Self {
108        let now = Utc::now();
109        Self {
110            id,
111            created_at: now,
112            updated_at: now,
113            config,
114            state: SessionState::default(),
115        }
116    }
117
118    pub fn update_timestamp(&mut self) {
119        self.updated_at = Utc::now();
120    }
121
122    /// Check if session has any recent activity
123    pub fn is_recently_active(&self, threshold: chrono::Duration) -> bool {
124        let cutoff = Utc::now() - threshold;
125        self.updated_at > cutoff
126    }
127
128    /// Build a workspace from this session's configuration
129    pub async fn build_workspace(&self) -> Result<Arc<dyn crate::workspace::Workspace>> {
130        crate::workspace::create_workspace(&self.config.workspace.to_workspace_config()).await
131    }
132}
133
134/// Session configuration - immutable once created
135#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
136pub struct SessionConfig {
137    pub workspace: WorkspaceConfig,
138    #[serde(default)]
139    pub workspace_ref: Option<crate::workspace::WorkspaceRef>,
140    #[serde(default)]
141    pub workspace_id: Option<crate::workspace::WorkspaceId>,
142    #[serde(default)]
143    pub repo_ref: Option<crate::workspace::RepoRef>,
144    #[serde(default)]
145    pub parent_session_id: Option<crate::app::domain::types::SessionId>,
146    #[serde(default)]
147    pub workspace_name: Option<String>,
148    pub tool_config: SessionToolConfig,
149    /// Optional custom system prompt to use for the session. If `None`, Steer will
150    /// fall back to its built-in default prompt.
151    pub system_prompt: Option<String>,
152    /// Primary agent mode for this session. Defaults to "normal" if unset.
153    #[serde(default)]
154    pub primary_agent_id: Option<String>,
155    /// User-controlled policy overrides that apply on top of the primary agent base policy.
156    #[serde(default = "SessionPolicyOverrides::empty")]
157    pub policy_overrides: SessionPolicyOverrides,
158    #[serde(default, skip_serializing_if = "Option::is_none")]
159    pub title: Option<String>,
160    pub metadata: HashMap<String, String>,
161    pub default_model: ModelId,
162    #[serde(default)]
163    pub auto_compaction: AutoCompactionConfig,
164}
165
166impl SessionConfig {
167    /// Build a BackendRegistry from MCP server configurations.
168    /// Returns the registry and a map of MCP server connection states.
169    pub async fn build_registry(
170        &self,
171    ) -> Result<(BackendRegistry, HashMap<String, McpServerInfo>)> {
172        let mut registry = BackendRegistry::new();
173        let mut mcp_servers = HashMap::new();
174
175        for backend_config in &self.tool_config.backends {
176            let BackendConfig::Mcp {
177                server_name,
178                transport,
179                tool_filter,
180            } = backend_config;
181
182            tracing::info!(
183                "Attempting to initialize MCP backend '{}' with transport: {:?}",
184                server_name,
185                transport
186            );
187
188            let mut server_info = McpServerInfo {
189                server_name: server_name.clone(),
190                transport: transport.clone(),
191                state: McpConnectionState::Connecting,
192                last_updated: Utc::now(),
193            };
194
195            match crate::tools::McpBackend::new(
196                server_name.clone(),
197                transport.clone(),
198                tool_filter.clone(),
199            )
200            .await
201            {
202                Ok(mcp_backend) => {
203                    let tool_names = mcp_backend.supported_tools().await;
204                    let tool_count = tool_names.len();
205                    tracing::info!(
206                        "Successfully initialized MCP backend '{}' with {} tools",
207                        server_name,
208                        tool_count
209                    );
210                    server_info.state = McpConnectionState::Connected { tool_names };
211                    server_info.last_updated = Utc::now();
212                    registry
213                        .register(format!("mcp_{server_name}"), Arc::new(mcp_backend))
214                        .await;
215                }
216                Err(e) => {
217                    tracing::error!("Failed to initialize MCP backend '{}': {}", server_name, e);
218                    server_info.state = McpConnectionState::Failed {
219                        error: e.to_string(),
220                    };
221                    server_info.last_updated = Utc::now();
222                }
223            }
224
225            mcp_servers.insert(server_name.clone(), server_info);
226        }
227
228        Ok((registry, mcp_servers))
229    }
230
231    /// Filter tools based on visibility settings
232    pub fn filter_tools_by_visibility(
233        &self,
234        tools: Vec<steer_tools::ToolSchema>,
235    ) -> Vec<steer_tools::ToolSchema> {
236        match &self.tool_config.visibility {
237            ToolVisibility::All => tools,
238            ToolVisibility::ReadOnly => {
239                let read_only_names: HashSet<String> = READ_ONLY_TOOL_NAMES
240                    .iter()
241                    .map(|name| (*name).to_string())
242                    .collect();
243
244                tools
245                    .into_iter()
246                    .filter(|schema| read_only_names.contains(&schema.name))
247                    .collect()
248            }
249            ToolVisibility::Whitelist(allowed) => tools
250                .into_iter()
251                .filter(|schema| allowed.contains(&schema.name))
252                .collect(),
253            ToolVisibility::Blacklist(blocked) => tools
254                .into_iter()
255                .filter(|schema| !blocked.contains(&schema.name))
256                .collect(),
257        }
258    }
259
260    /// Minimal read-only configuration
261    #[cfg(test)]
262    pub fn read_only(default_model: ModelId) -> Self {
263        Self {
264            workspace: WorkspaceConfig::Local {
265                path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
266            },
267            workspace_ref: None,
268            workspace_id: None,
269            repo_ref: None,
270            parent_session_id: None,
271            workspace_name: None,
272            tool_config: SessionToolConfig::read_only(),
273            system_prompt: None,
274            primary_agent_id: None,
275            policy_overrides: SessionPolicyOverrides::empty(),
276            title: None,
277            metadata: HashMap::new(),
278            default_model,
279            auto_compaction: AutoCompactionConfig::default(),
280        }
281    }
282}
283
284/// Configuration for automatic context-window compaction.
285#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
286pub struct AutoCompactionConfig {
287    pub enabled: bool,
288    pub threshold_percent: u32,
289}
290
291impl Default for AutoCompactionConfig {
292    fn default() -> Self {
293        Self {
294            enabled: true,
295            threshold_percent: 90,
296        }
297    }
298}
299
300impl AutoCompactionConfig {
301    pub fn threshold_ratio(&self) -> f64 {
302        f64::from(self.threshold_percent) / 100.0
303    }
304}
305
306/// User-controlled policy overrides applied on top of a primary agent base policy.
307#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
308pub struct SessionPolicyOverrides {
309    #[serde(default, skip_serializing_if = "Option::is_none")]
310    pub default_model: Option<ModelId>,
311    #[serde(default, skip_serializing_if = "Option::is_none")]
312    pub tool_visibility: Option<ToolVisibility>,
313    #[serde(default = "ToolApprovalPolicyOverrides::empty")]
314    pub approval_policy: ToolApprovalPolicyOverrides,
315}
316
317impl SessionPolicyOverrides {
318    pub fn empty() -> Self {
319        Self {
320            default_model: None,
321            tool_visibility: None,
322            approval_policy: ToolApprovalPolicyOverrides::empty(),
323        }
324    }
325
326    pub fn is_empty(&self) -> bool {
327        self.default_model.is_none()
328            && self.tool_visibility.is_none()
329            && self.approval_policy.is_empty()
330    }
331}
332
333#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
334pub struct ToolApprovalPolicyOverrides {
335    #[serde(default = "ApprovalRulesOverrides::empty")]
336    pub preapproved: ApprovalRulesOverrides,
337}
338
339impl ToolApprovalPolicyOverrides {
340    pub fn empty() -> Self {
341        Self {
342            preapproved: ApprovalRulesOverrides::empty(),
343        }
344    }
345
346    pub fn is_empty(&self) -> bool {
347        self.preapproved.is_empty()
348    }
349}
350
351#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
352pub struct ApprovalRulesOverrides {
353    #[serde(default)]
354    pub tools: HashSet<String>,
355    #[serde(default)]
356    pub per_tool: HashMap<String, ToolRuleOverrides>,
357}
358
359impl ApprovalRulesOverrides {
360    pub fn empty() -> Self {
361        Self {
362            tools: HashSet::new(),
363            per_tool: HashMap::new(),
364        }
365    }
366
367    pub fn is_empty(&self) -> bool {
368        self.tools.is_empty() && self.per_tool.is_empty()
369    }
370}
371
372#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
373#[serde(tag = "type", rename_all = "snake_case")]
374pub enum ToolRuleOverrides {
375    Bash { patterns: Vec<String> },
376    DispatchAgent { agent_patterns: Vec<String> },
377}
378
379/// Tool visibility configuration - controls which tools are shown to the AI agent
380#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
381#[serde(tag = "type", content = "tools", rename_all = "snake_case")]
382pub enum ToolVisibility {
383    /// Show all registered tools to the AI
384    #[default]
385    All,
386
387    /// Only show read-only tools to the AI
388    ReadOnly,
389
390    /// Show only specific tools to the AI (whitelist)
391    Whitelist(HashSet<String>),
392
393    /// Hide specific tools from the AI (blacklist)
394    Blacklist(HashSet<String>),
395}
396
/// Outcome of evaluating a tool against an approval policy.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ToolDecision {
    /// The tool may run without further approval.
    Allow,
    /// Approval must be requested first (produced for `UnapprovedBehavior::Prompt`).
    Ask,
    /// The tool must not run.
    Deny,
}
403
404#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
405#[serde(rename_all = "snake_case")]
406pub enum UnapprovedBehavior {
407    #[default]
408    Prompt,
409    Deny,
410    Allow,
411}
412
413#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
414#[serde(tag = "type", rename_all = "snake_case")]
415pub enum ToolRule {
416    Bash { patterns: Vec<String> },
417    DispatchAgent { agent_patterns: Vec<String> },
418}
419
420#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
421pub struct ApprovalRules {
422    #[serde(default)]
423    pub tools: HashSet<String>,
424    #[serde(default)]
425    pub per_tool: HashMap<String, ToolRule>,
426}
427
428impl ApprovalRules {
429    pub fn is_empty(&self) -> bool {
430        self.tools.is_empty() && self.per_tool.is_empty()
431    }
432
433    pub fn bash_patterns(&self) -> Option<&[String]> {
434        self.per_tool.get("bash").and_then(|rule| match rule {
435            ToolRule::Bash { patterns } => Some(patterns.as_slice()),
436            ToolRule::DispatchAgent { .. } => None,
437        })
438    }
439
440    pub fn dispatch_agent_rule(&self) -> Option<&[String]> {
441        self.per_tool
442            .get("dispatch_agent")
443            .and_then(|rule| match rule {
444                ToolRule::DispatchAgent { agent_patterns } => Some(agent_patterns.as_slice()),
445                ToolRule::Bash { .. } => None,
446            })
447    }
448}
449
450#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
451pub struct ToolApprovalPolicy {
452    pub default_behavior: UnapprovedBehavior,
453    #[serde(default)]
454    pub preapproved: ApprovalRules,
455}
456
457impl Default for ToolApprovalPolicy {
458    fn default() -> Self {
459        Self {
460            default_behavior: UnapprovedBehavior::Prompt,
461            preapproved: ApprovalRules {
462                tools: READ_ONLY_TOOL_NAMES
463                    .iter()
464                    .map(|name| (*name).to_string())
465                    .collect(),
466                per_tool: HashMap::new(),
467            },
468        }
469    }
470}
471
472impl ToolApprovalPolicy {
473    pub fn tool_decision(&self, tool_name: &str) -> ToolDecision {
474        if self.preapproved.tools.contains(tool_name) {
475            ToolDecision::Allow
476        } else {
477            match self.default_behavior {
478                UnapprovedBehavior::Prompt => ToolDecision::Ask,
479                UnapprovedBehavior::Deny => ToolDecision::Deny,
480                UnapprovedBehavior::Allow => ToolDecision::Allow,
481            }
482        }
483    }
484
485    pub fn is_bash_pattern_preapproved(&self, command: &str) -> bool {
486        let Some(patterns) = self.preapproved.bash_patterns() else {
487            return false;
488        };
489        patterns.iter().any(|pattern| {
490            if pattern == command {
491                return true;
492            }
493            glob::Pattern::new(pattern)
494                .map(|glob| glob.matches(command))
495                .unwrap_or(false)
496        })
497    }
498
499    pub fn is_dispatch_agent_pattern_preapproved(&self, agent_id: &str) -> bool {
500        let Some(patterns) = self.preapproved.dispatch_agent_rule() else {
501            return false;
502        };
503        patterns.iter().any(|pattern| {
504            if pattern == agent_id {
505                return true;
506            }
507            glob::Pattern::new(pattern)
508                .map(|glob| glob.matches(agent_id))
509                .unwrap_or(false)
510        })
511    }
512
513    pub fn pre_approved_tools(&self) -> &HashSet<String> {
514        &self.preapproved.tools
515    }
516}
517
518impl ToolApprovalPolicyOverrides {
519    pub fn apply_to(&self, base: &ToolApprovalPolicy) -> ToolApprovalPolicy {
520        let mut merged = base.clone();
521
522        if !self.preapproved.tools.is_empty() {
523            merged
524                .preapproved
525                .tools
526                .extend(self.preapproved.tools.iter().cloned());
527        }
528
529        for (tool_name, override_rule) in &self.preapproved.per_tool {
530            let base_rule = merged.preapproved.per_tool.get(tool_name);
531            let merged_rule = merge_tool_rule_override(base_rule, override_rule);
532            merged
533                .preapproved
534                .per_tool
535                .insert(tool_name.clone(), merged_rule);
536        }
537
538        merged
539    }
540}
541
542fn merge_tool_rule_override(
543    base: Option<&ToolRule>,
544    override_rule: &ToolRuleOverrides,
545) -> ToolRule {
546    match (base, override_rule) {
547        (Some(ToolRule::Bash { patterns }), ToolRuleOverrides::Bash { patterns: extra }) => {
548            ToolRule::Bash {
549                patterns: merge_patterns(patterns, extra),
550            }
551        }
552        (
553            Some(ToolRule::DispatchAgent { agent_patterns }),
554            ToolRuleOverrides::DispatchAgent {
555                agent_patterns: extra,
556            },
557        ) => ToolRule::DispatchAgent {
558            agent_patterns: merge_patterns(agent_patterns, extra),
559        },
560        (_, ToolRuleOverrides::Bash { patterns }) => ToolRule::Bash {
561            patterns: patterns.clone(),
562        },
563        (_, ToolRuleOverrides::DispatchAgent { agent_patterns }) => ToolRule::DispatchAgent {
564            agent_patterns: agent_patterns.clone(),
565        },
566    }
567}
568
/// Appends each pattern from `extra` to a copy of `base` unless it is already present.
///
/// Order is preserved, and duplicates already inside `base` are kept as-is.
fn merge_patterns(base: &[String], extra: &[String]) -> Vec<String> {
    extra.iter().fold(base.to_vec(), |mut acc, candidate| {
        if !acc.contains(candidate) {
            acc.push(candidate.clone());
        }
        acc
    })
}
578
579/// Authentication configuration for remote backends
580#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
581pub enum RemoteAuth {
582    Bearer { token: String },
583    ApiKey { key: String },
584}
585
586impl RemoteAuth {
587    /// Convert to steer_workspace RemoteAuth type
588    pub fn to_workspace_auth(&self) -> steer_workspace::RemoteAuth {
589        match self {
590            RemoteAuth::Bearer { token } => steer_workspace::RemoteAuth::BearerToken(token.clone()),
591            RemoteAuth::ApiKey { key } => steer_workspace::RemoteAuth::ApiKey(key.clone()),
592        }
593    }
594}
595
596/// Tool filtering configuration for backends
597#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema)]
598#[serde(rename_all = "snake_case")]
599#[derive(Default)]
600pub enum ToolFilter {
601    /// Include all available tools
602    #[default]
603    All,
604    /// Include only the specified tools
605    Include(Vec<String>),
606    /// Include all tools except the specified ones
607    Exclude(Vec<String>),
608}
609
610/// Configuration for MCP server backends
611#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
612#[serde(tag = "type", rename_all = "snake_case")]
613pub enum BackendConfig {
614    Mcp {
615        server_name: String,
616        transport: McpTransport,
617        tool_filter: ToolFilter,
618    },
619}
620
621#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
622pub struct SessionToolConfig {
623    pub backends: Vec<BackendConfig>,
624    pub visibility: ToolVisibility,
625    pub approval_policy: ToolApprovalPolicy,
626    pub metadata: HashMap<String, String>,
627}
628
629impl Default for SessionToolConfig {
630    fn default() -> Self {
631        Self {
632            backends: Vec::new(),
633            visibility: ToolVisibility::All,
634            approval_policy: ToolApprovalPolicy::default(),
635            metadata: HashMap::new(),
636        }
637    }
638}
639
640impl SessionToolConfig {
641    pub fn read_only() -> Self {
642        Self {
643            backends: Vec::new(),
644            visibility: ToolVisibility::ReadOnly,
645            approval_policy: ToolApprovalPolicy::default(),
646            metadata: HashMap::new(),
647        }
648    }
649}
650
651/// Mutable session state that changes during execution
652#[derive(Debug, Clone, Serialize, Deserialize, Default)]
653pub struct SessionState {
654    /// Conversation messages
655    pub messages: Vec<Message>,
656
657    /// Tool call tracking
658    pub tool_calls: HashMap<String, ToolCallState>,
659
660    /// Tools that have been approved for this session
661    pub approved_tools: HashSet<String>,
662
663    /// Bash commands that have been approved for this session (dynamically added)
664    #[serde(default)]
665    pub approved_bash_patterns: HashSet<String>,
666
667    /// Last processed event sequence number for replay
668    pub last_event_sequence: u64,
669
670    /// Additional runtime metadata
671    pub metadata: HashMap<String, String>,
672
673    /// The ID of the currently active message (head of selected branch)
674    /// None means use last message semantics for backward compatibility
675    #[serde(default, skip_serializing_if = "Option::is_none")]
676    pub active_message_id: Option<String>,
677
678    /// Status of MCP server connections
679    /// This is a transient field that is rebuilt on session activation
680    #[serde(default, skip_serializing, skip_deserializing)]
681    pub mcp_servers: HashMap<String, McpServerInfo>,
682}
683
684impl SessionState {
685    /// Add a message to the conversation
686    pub fn add_message(&mut self, message: Message) {
687        self.messages.push(message);
688    }
689
690    /// Get the number of messages in the conversation
691    pub fn message_count(&self) -> usize {
692        self.messages.len()
693    }
694
695    /// Get the last message in the conversation
696    pub fn last_message(&self) -> Option<&Message> {
697        self.messages.last()
698    }
699
700    /// Add a tool call to tracking
701    pub fn add_tool_call(&mut self, tool_call: ToolCall) {
702        let state = ToolCallState {
703            tool_call: tool_call.clone(),
704            status: ToolCallStatus::PendingApproval,
705            started_at: None,
706            completed_at: None,
707            result: None,
708        };
709        self.tool_calls.insert(tool_call.id, state);
710    }
711
712    /// Update tool call status
713    pub fn update_tool_call_status(
714        &mut self,
715        tool_call_id: &str,
716        status: ToolCallStatus,
717    ) -> std::result::Result<(), String> {
718        let tool_call = self
719            .tool_calls
720            .get_mut(tool_call_id)
721            .ok_or_else(|| format!("Tool call not found: {tool_call_id}"))?;
722
723        // Update timestamps based on status changes
724        match (&tool_call.status, &status) {
725            (_, ToolCallStatus::Executing) => {
726                tool_call.started_at = Some(Utc::now());
727            }
728            (_, ToolCallStatus::Completed | ToolCallStatus::Failed { .. }) => {
729                tool_call.completed_at = Some(Utc::now());
730            }
731            _ => {}
732        }
733
734        tool_call.status = status;
735        Ok(())
736    }
737
738    /// Approve a tool for future use
739    pub fn approve_tool(&mut self, tool_name: String) {
740        self.approved_tools.insert(tool_name);
741    }
742
743    /// Check if a tool is approved
744    pub fn is_tool_approved(&self, tool_name: &str) -> bool {
745        self.approved_tools.contains(tool_name)
746    }
747
748    /// Validate internal consistency
749    pub fn validate(&self) -> std::result::Result<(), String> {
750        // Check that all tool calls referenced in messages exist
751        for message in &self.messages {
752            let tool_calls = Self::extract_tool_calls_from_message(message);
753            if !tool_calls.is_empty() {
754                for tool_call_id in tool_calls {
755                    if !self.tool_calls.contains_key(&tool_call_id) {
756                        return Err(format!(
757                            "Message references unknown tool call: {tool_call_id}"
758                        ));
759                    }
760                }
761            }
762        }
763
764        Ok(())
765    }
766
767    /// Extract tool call IDs from a message
768    fn extract_tool_calls_from_message(message: &Message) -> Vec<String> {
769        let mut tool_call_ids = Vec::new();
770
771        match &message.data {
772            MessageData::Assistant { content, .. } => {
773                for c in content {
774                    if let crate::app::conversation::AssistantContent::ToolCall {
775                        tool_call, ..
776                    } = c
777                    {
778                        tool_call_ids.push(tool_call.id.clone());
779                    }
780                }
781            }
782            MessageData::Tool { tool_use_id, .. } => {
783                tool_call_ids.push(tool_use_id.clone());
784            }
785            MessageData::User { .. } => {}
786        }
787
788        tool_call_ids
789    }
790}
791
792/// Tool call state tracking
793#[derive(Debug, Clone, Serialize, Deserialize)]
794pub struct ToolCallState {
795    pub tool_call: ToolCall,
796    pub status: ToolCallStatus,
797    pub started_at: Option<DateTime<Utc>>,
798    pub completed_at: Option<DateTime<Utc>>,
799    pub result: Option<ToolResult>,
800}
801
802impl ToolCallState {
803    pub fn is_pending(&self) -> bool {
804        matches!(self.status, ToolCallStatus::PendingApproval)
805    }
806
807    pub fn is_complete(&self) -> bool {
808        matches!(
809            self.status,
810            ToolCallStatus::Completed | ToolCallStatus::Failed { .. }
811        )
812    }
813
814    pub fn duration(&self) -> Option<chrono::Duration> {
815        match (self.started_at, self.completed_at) {
816            (Some(start), Some(end)) => Some(end - start),
817            _ => None,
818        }
819    }
820}
821
822/// Tool call execution status
823#[derive(Debug, Clone, Serialize, Deserialize)]
824#[serde(tag = "status", rename_all = "snake_case")]
825pub enum ToolCallStatus {
826    PendingApproval,
827    Approved,
828    Denied,
829    Executing,
830    Completed,
831    Failed { error: String },
832}
833
834impl ToolCallStatus {
835    pub fn is_terminal(&self) -> bool {
836        matches!(
837            self,
838            ToolCallStatus::Completed | ToolCallStatus::Failed { .. } | ToolCallStatus::Denied
839        )
840    }
841}
842
843/// Tool execution statistics
844#[derive(Debug, Clone, Serialize, Deserialize)]
845pub struct ToolExecutionStats {
846    #[serde(skip_serializing_if = "Option::is_none")]
847    pub output: Option<String>, // Legacy string output
848    #[serde(skip_serializing_if = "Option::is_none")]
849    pub json_output: Option<serde_json::Value>, // Typed JSON output
850    pub result_type: Option<String>, // Type name (e.g., "SearchResult")
851    pub success: bool,
852    pub execution_time_ms: u64,
853    pub metadata: HashMap<String, String>,
854}
855
856impl ToolExecutionStats {
857    pub fn success(output: String, execution_time_ms: u64) -> Self {
858        Self {
859            output: Some(output),
860            json_output: None,
861            result_type: None,
862            success: true,
863            execution_time_ms,
864            metadata: HashMap::new(),
865        }
866    }
867
868    pub fn success_typed(
869        json_output: serde_json::Value,
870        result_type: String,
871        execution_time_ms: u64,
872    ) -> Self {
873        Self {
874            output: None,
875            json_output: Some(json_output),
876            result_type: Some(result_type),
877            success: true,
878            execution_time_ms,
879            metadata: HashMap::new(),
880        }
881    }
882
883    pub fn failure(error: String, execution_time_ms: u64) -> Self {
884        Self {
885            output: Some(error),
886            json_output: None,
887            result_type: None,
888            success: false,
889            execution_time_ms,
890            metadata: HashMap::new(),
891        }
892    }
893
894    pub fn with_metadata(mut self, key: String, value: String) -> Self {
895        self.metadata.insert(key, value);
896        self
897    }
898}
899
900/// Session metadata for listing and filtering
901#[derive(Debug, Clone, Serialize, Deserialize)]
902pub struct SessionInfo {
903    pub id: String,
904    pub created_at: DateTime<Utc>,
905    pub updated_at: DateTime<Utc>,
906    /// The last known model used in this session
907    pub last_model: Option<ModelId>,
908    pub message_count: usize,
909    pub title: Option<String>,
910    pub metadata: HashMap<String, String>,
911}
912
913impl From<&Session> for SessionInfo {
914    fn from(session: &Session) -> Self {
915        Self {
916            id: session.id.clone(),
917            created_at: session.created_at,
918            updated_at: session.updated_at,
919            last_model: None, // TODO: Track last model used from events
920            message_count: session.state.message_count(),
921            title: session.config.title.clone(),
922            metadata: session.config.metadata.clone(),
923        }
924    }
925}
926
#[cfg(test)]
mod tests {
    //! Unit tests for session configuration, tool-approval policies, session state tracking,
    //! and MCP server bookkeeping.

    use super::*;
    use crate::app::conversation::{Message, MessageData, UserContent};
    use crate::config::model::builtin::claude_sonnet_4_5 as test_model;
    use crate::tools::DISPATCH_AGENT_TOOL_NAME;
    use crate::tools::builtin_tools::READ_ONLY_TOOL_NAMES;
    use steer_tools::tools::{BASH_TOOL_NAME, EDIT_TOOL_NAME};

    /// Converts string literals into any collection of owned `String`s (for example the
    /// `tools` set of an approval rule).
    fn owned_names<C: FromIterator<String>>(items: &[&str]) -> C {
        items.iter().map(|s| (*s).to_string()).collect()
    }

    /// Local-workspace session config with default tool settings and no overrides, title,
    /// or metadata. Shared by tests that only need a valid baseline config.
    fn local_test_config() -> SessionConfig {
        SessionConfig {
            workspace: WorkspaceConfig::Local {
                path: PathBuf::from("/test/path"),
            },
            workspace_ref: None,
            workspace_id: None,
            repo_ref: None,
            parent_session_id: None,
            workspace_name: None,
            tool_config: SessionToolConfig::default(), // No backends configured
            system_prompt: None,
            primary_agent_id: None,
            policy_overrides: SessionPolicyOverrides::empty(),
            title: None,
            metadata: HashMap::new(),
            default_model: test_model(),
            auto_compaction: AutoCompactionConfig::default(),
        }
    }

    /// Policy that preapproves `read_file` and `list_files` and applies `default_behavior`
    /// to every other tool.
    fn read_tools_policy(default_behavior: UnapprovedBehavior) -> ToolApprovalPolicy {
        ToolApprovalPolicy {
            default_behavior,
            preapproved: ApprovalRules {
                tools: owned_names(&["read_file", "list_files"]),
                per_tool: HashMap::new(),
            },
        }
    }

    #[test]
    fn test_session_creation() {
        let session = Session::new("test-session".to_string(), local_test_config());

        assert_eq!(session.id, "test-session");
        assert_eq!(
            session
                .config
                .tool_config
                .approval_policy
                .tool_decision("any_tool"),
            ToolDecision::Ask
        );
        assert_eq!(session.state.message_count(), 0);
    }

    #[test]
    fn test_tool_approval_policy_prompt_unapproved() {
        let policy = read_tools_policy(UnapprovedBehavior::Prompt);

        assert_eq!(policy.tool_decision("read_file"), ToolDecision::Allow);
        assert_eq!(policy.tool_decision("write_file"), ToolDecision::Ask);
    }

    #[test]
    fn test_tool_approval_policy_deny_unapproved() {
        let policy = read_tools_policy(UnapprovedBehavior::Deny);

        assert_eq!(policy.tool_decision("read_file"), ToolDecision::Allow);
        assert_eq!(policy.tool_decision("write_file"), ToolDecision::Deny);
    }

    #[test]
    fn test_tool_approval_policy_default() {
        let policy = ToolApprovalPolicy::default();

        assert_eq!(
            policy.tool_decision(READ_ONLY_TOOL_NAMES[0]),
            ToolDecision::Allow
        );
        assert_eq!(policy.tool_decision(BASH_TOOL_NAME), ToolDecision::Ask);
    }

    #[test]
    fn test_tool_approval_policy_allow_unapproved() {
        let policy = read_tools_policy(UnapprovedBehavior::Allow);

        assert_eq!(policy.tool_decision("read_file"), ToolDecision::Allow);
        assert_eq!(policy.tool_decision("write_file"), ToolDecision::Allow);
    }

    #[test]
    fn test_tool_approval_policy_overrides_union_rules() {
        let base_policy = ToolApprovalPolicy {
            default_behavior: UnapprovedBehavior::Prompt,
            preapproved: ApprovalRules {
                tools: owned_names(&["read_file"]),
                per_tool: [
                    (
                        BASH_TOOL_NAME.to_string(),
                        ToolRule::Bash {
                            patterns: vec!["git status".to_string()],
                        },
                    ),
                    (
                        DISPATCH_AGENT_TOOL_NAME.to_string(),
                        ToolRule::DispatchAgent {
                            agent_patterns: vec!["explore".to_string()],
                        },
                    ),
                ]
                .into_iter()
                .collect(),
            },
        };

        let overrides = ToolApprovalPolicyOverrides {
            preapproved: ApprovalRulesOverrides {
                tools: owned_names(&["write_file"]),
                per_tool: [
                    (
                        BASH_TOOL_NAME.to_string(),
                        ToolRuleOverrides::Bash {
                            patterns: vec!["git log".to_string()],
                        },
                    ),
                    (
                        DISPATCH_AGENT_TOOL_NAME.to_string(),
                        ToolRuleOverrides::DispatchAgent {
                            agent_patterns: vec!["review".to_string()],
                        },
                    ),
                ]
                .into_iter()
                .collect(),
            },
        };

        let merged = overrides.apply_to(&base_policy);

        // Overrides only add preapprovals; the base default behavior is kept.
        assert_eq!(merged.default_behavior, UnapprovedBehavior::Prompt);
        assert!(merged.preapproved.tools.contains("read_file"));
        assert!(merged.preapproved.tools.contains("write_file"));

        let bash_patterns = match merged
            .preapproved
            .per_tool
            .get(BASH_TOOL_NAME)
            .expect("bash rule")
        {
            ToolRule::Bash { patterns } => patterns,
            ToolRule::DispatchAgent { .. } => {
                panic!("Unexpected bash rule: dispatch agent")
            }
        };
        assert!(bash_patterns.contains(&"git status".to_string()));
        assert!(bash_patterns.contains(&"git log".to_string()));
        assert_eq!(bash_patterns.len(), 2);

        let agent_patterns = match merged
            .preapproved
            .per_tool
            .get(DISPATCH_AGENT_TOOL_NAME)
            .expect("dispatch_agent rule")
        {
            ToolRule::DispatchAgent { agent_patterns } => agent_patterns,
            ToolRule::Bash { .. } => panic!("Unexpected dispatch_agent rule: bash"),
        };
        assert!(agent_patterns.contains(&"explore".to_string()));
        assert!(agent_patterns.contains(&"review".to_string()));
        assert_eq!(agent_patterns.len(), 2);
    }

    #[test]
    fn test_bash_pattern_matching() {
        let policy = ToolApprovalPolicy {
            default_behavior: UnapprovedBehavior::Prompt,
            preapproved: ApprovalRules {
                tools: HashSet::new(),
                per_tool: [(
                    BASH_TOOL_NAME.to_string(),
                    ToolRule::Bash {
                        patterns: vec![
                            "git status".to_string(),
                            "git log*".to_string(),
                            "git * --oneline".to_string(),
                            "ls -?a*".to_string(),
                            "cargo build*".to_string(),
                        ],
                    },
                )]
                .into_iter()
                .collect(),
            },
        };

        assert!(policy.is_bash_pattern_preapproved("git status"));
        assert!(policy.is_bash_pattern_preapproved("git log --oneline"));
        assert!(policy.is_bash_pattern_preapproved("git show --oneline"));
        assert!(policy.is_bash_pattern_preapproved("ls -la"));
        assert!(policy.is_bash_pattern_preapproved("cargo build --release"));
        assert!(!policy.is_bash_pattern_preapproved("git commit"));
        assert!(!policy.is_bash_pattern_preapproved("ls -l"));
        assert!(!policy.is_bash_pattern_preapproved("rm -rf /"));
    }

    #[test]
    fn test_dispatch_agent_pattern_matching() {
        let policy = ToolApprovalPolicy {
            default_behavior: UnapprovedBehavior::Prompt,
            preapproved: ApprovalRules {
                tools: HashSet::new(),
                per_tool: [(
                    DISPATCH_AGENT_TOOL_NAME.to_string(),
                    ToolRule::DispatchAgent {
                        agent_patterns: vec!["explore".to_string(), "explore-*".to_string()],
                    },
                )]
                .into_iter()
                .collect(),
            },
        };

        assert!(policy.is_dispatch_agent_pattern_preapproved("explore"));
        assert!(policy.is_dispatch_agent_pattern_preapproved("explore-fast"));
        assert!(!policy.is_dispatch_agent_pattern_preapproved("build"));
    }

    #[test]
    fn test_session_state_validation() {
        let mut state = SessionState::default();

        // Valid empty state
        assert!(state.validate().is_ok());

        // Add a message
        let message = Message {
            data: MessageData::User {
                content: vec![UserContent::Text {
                    text: "Hello".to_string(),
                }],
            },
            timestamp: 123_456_789,
            id: "msg1".to_string(),
            parent_message_id: None,
        };
        state.add_message(message);

        assert!(state.validate().is_ok());
        assert_eq!(state.message_count(), 1);
    }

    #[test]
    fn test_tool_call_state_tracking() {
        let mut state = SessionState::default();

        let tool_call = ToolCall {
            id: "tool1".to_string(),
            name: "read_file".to_string(),
            parameters: serde_json::json!({"path": "/test.txt"}),
        };

        state.add_tool_call(tool_call);
        assert!(state.tool_calls.get("tool1").unwrap().is_pending());

        state
            .update_tool_call_status("tool1", ToolCallStatus::Executing)
            .unwrap();
        let tool_state = state.tool_calls.get("tool1").unwrap();
        assert!(tool_state.started_at.is_some());
        assert!(!tool_state.is_complete());

        state
            .update_tool_call_status("tool1", ToolCallStatus::Completed)
            .unwrap();
        let tool_state = state.tool_calls.get("tool1").unwrap();
        assert!(tool_state.completed_at.is_some());
        assert!(tool_state.is_complete());
    }

    #[test]
    fn test_session_tool_config_default() {
        let config = SessionToolConfig::default();
        assert!(config.backends.is_empty());
    }

    #[test]
    fn test_tool_filter_exclude() {
        let excluded =
            ToolFilter::Exclude(vec![BASH_TOOL_NAME.to_string(), EDIT_TOOL_NAME.to_string()]);

        if let ToolFilter::Exclude(tools) = &excluded {
            assert_eq!(tools.len(), 2);
            assert!(tools.contains(&BASH_TOOL_NAME.to_string()));
            assert!(tools.contains(&EDIT_TOOL_NAME.to_string()));
        } else {
            panic!("Expected ToolFilter::Exclude");
        }
    }

    #[test]
    fn test_session_tool_config_read_only() {
        let config = SessionToolConfig::read_only();
        assert_eq!(config.backends.len(), 0);
        assert!(matches!(config.visibility, ToolVisibility::ReadOnly));
        assert_eq!(
            config.approval_policy.default_behavior,
            UnapprovedBehavior::Prompt
        );
    }

    #[tokio::test]
    async fn test_session_config_build_registry_no_default_backends() {
        // Test that BackendRegistry only contains user-configured backends.
        // Builtin tools (dispatch_agent, web_fetch) are now in ToolRegistry,
        // not BackendRegistry.
        let config = local_test_config();

        let (registry, _mcp_servers) = config.build_registry().await.unwrap();
        let schemas = registry.get_tool_schemas().await;

        assert!(
            schemas.is_empty(),
            "BackendRegistry should be empty with default config; got: {:?}",
            schemas.iter().map(|s| &s.name).collect::<Vec<_>>()
        );
    }

    // Workspace tools are no longer part of BackendRegistry, so the former tests for workspace
    // tool registration, workspace tool visibility filtering, and the workspace backend were
    // removed; workspace tool visibility filtering now happens at the Workspace level.

    #[test]
    fn test_mcp_status_tracking() {
        // Test that MCP server info is properly tracked in session state
        let mut session_state = SessionState::default();

        // Add some MCP server info
        let mcp_info = McpServerInfo {
            server_name: "test-server".to_string(),
            transport: crate::tools::McpTransport::Stdio {
                command: "python".to_string(),
                args: vec!["-m".to_string(), "test_server".to_string()],
            },
            state: McpConnectionState::Connected {
                tool_names: vec![
                    "tool1".to_string(),
                    "tool2".to_string(),
                    "tool3".to_string(),
                    "tool4".to_string(),
                    "tool5".to_string(),
                ],
            },
            last_updated: Utc::now(),
        };

        session_state
            .mcp_servers
            .insert("test-server".to_string(), mcp_info);

        // Verify it's stored
        assert_eq!(session_state.mcp_servers.len(), 1);
        let stored = session_state.mcp_servers.get("test-server").unwrap();
        assert_eq!(stored.server_name, "test-server");
        assert!(matches!(
            stored.state,
            McpConnectionState::Connected { ref tool_names } if tool_names.len() == 5
        ));

        // Test failed connection
        let failed_info = McpServerInfo {
            server_name: "failed-server".to_string(),
            transport: crate::tools::McpTransport::Tcp {
                host: "localhost".to_string(),
                port: 9999,
            },
            state: McpConnectionState::Failed {
                error: "Connection refused".to_string(),
            },
            last_updated: Utc::now(),
        };

        session_state
            .mcp_servers
            .insert("failed-server".to_string(), failed_info);
        assert_eq!(session_state.mcp_servers.len(), 2);
    }

    #[tokio::test]
    async fn test_mcp_server_tracking_in_build_registry() {
        // Every configured MCP server must be tracked, whether or not it connects.
        let mut config = SessionConfig::read_only(test_model());

        // Expected to fail: the `.invalid` TLD is reserved and never resolves.
        config.tool_config.backends.push(BackendConfig::Mcp {
            server_name: "bad-server".to_string(),
            transport: crate::tools::McpTransport::Tcp {
                host: "nonexistent.invalid".to_string(),
                port: 12345,
            },
            tool_filter: ToolFilter::All,
        });

        // Spawns a real process, but `echo` does not speak MCP, so this is expected to fail too.
        config.tool_config.backends.push(BackendConfig::Mcp {
            server_name: "good-server".to_string(),
            transport: crate::tools::McpTransport::Stdio {
                command: "echo".to_string(),
                args: vec!["test".to_string()],
            },
            tool_filter: ToolFilter::All,
        });

        let (_registry, mcp_servers) = config.build_registry().await.unwrap();

        // Should have tracked both servers
        assert_eq!(mcp_servers.len(), 2);

        // Check the unresolvable TCP server
        let bad_server = mcp_servers.get("bad-server").unwrap();
        assert_eq!(bad_server.server_name, "bad-server");
        assert!(matches!(
            bad_server.state,
            McpConnectionState::Failed { .. }
        ));

        // Check the stdio server (fails because echo isn't an MCP server)
        let good_server = mcp_servers.get("good-server").unwrap();
        assert_eq!(good_server.server_name, "good-server");
        assert!(matches!(
            good_server.state,
            McpConnectionState::Failed { .. }
        ));
    }

    #[test]
    fn test_backend_config_mcp_variant() {
        let mcp_config = BackendConfig::Mcp {
            server_name: "test-mcp".to_string(),
            transport: crate::tools::McpTransport::Stdio {
                command: "python".to_string(),
                args: vec!["-m".to_string(), "test_server".to_string()],
            },
            tool_filter: ToolFilter::All,
        };

        let BackendConfig::Mcp {
            server_name,
            transport,
            ..
        } = mcp_config;

        assert_eq!(server_name, "test-mcp");
        if let crate::tools::McpTransport::Stdio { command, args } = transport {
            assert_eq!(command, "python");
            assert_eq!(args.len(), 2);
        } else {
            panic!("Expected Stdio transport");
        }
    }
}