steer_grpc/grpc/session_manager_ext.rs

1use crate::grpc::conversions::{
2    message_to_proto, proto_to_tool_config, proto_to_workspace_config, session_config_to_proto,
3};
4use crate::grpc::error::GrpcError;
5use steer_core::session::{SessionConfig, SessionManager};
6use steer_proto::agent::v1 as proto;
7
8/// Extension trait for SessionManager that adds gRPC-specific functionality
9#[async_trait::async_trait]
10pub trait SessionManagerExt {
11    /// Get session state as protobuf SessionState
12    async fn get_session_proto(
13        &self,
14        session_id: &str,
15    ) -> Result<Option<proto::SessionState>, GrpcError>;
16
17    /// Create session for gRPC
18    async fn create_session_grpc(
19        &self,
20        config: proto::CreateSessionRequest,
21        app_config: steer_core::app::AppConfig,
22    ) -> Result<(String, proto::SessionInfo), GrpcError>;
23}
24
25#[async_trait::async_trait]
26impl SessionManagerExt for SessionManager {
27    async fn get_session_proto(
28        &self,
29        session_id: &str,
30    ) -> Result<Option<proto::SessionState>, GrpcError> {
31        match self.store().get_session(session_id).await {
32            Ok(Some(session)) => {
33                // Use the conversion function to convert config
34                let config = session_config_to_proto(&session.config);
35
36                // Convert internal SessionState to protobuf SessionState
37                let proto_state = proto::SessionState {
38                    id: session_id.to_string(),
39                    created_at: Some(prost_types::Timestamp::from(std::time::SystemTime::from(
40                        session.created_at,
41                    ))),
42                    updated_at: Some(prost_types::Timestamp::from(std::time::SystemTime::from(
43                        session.updated_at,
44                    ))),
45                    config: Some(config),
46                    messages: session
47                        .state
48                        .messages
49                        .into_iter()
50                        .map(message_to_proto)
51                        .collect::<Result<Vec<_>, _>>()
52                        .map_err(GrpcError::ConversionError)?,
53                    tool_calls: std::collections::HashMap::new(), // TODO: Convert tool calls
54                    approved_tools: session.state.approved_tools.into_iter().collect(),
55                    last_event_sequence: session.state.last_event_sequence,
56                    metadata: session.state.metadata,
57                };
58                Ok(Some(proto_state))
59            }
60            Ok(None) => Ok(None),
61            Err(e) => Err(GrpcError::CoreError(e.into())),
62        }
63    }
64
65    async fn create_session_grpc(
66        &self,
67        config: proto::CreateSessionRequest,
68        app_config: steer_core::app::AppConfig,
69    ) -> Result<(String, proto::SessionInfo), GrpcError> {
70        use steer_core::session::ToolApprovalPolicy;
71
72        // Convert protobuf config to internal SessionConfig
73        let tool_policy = config
74            .tool_policy
75            .map(|policy| match policy.policy {
76                Some(proto::tool_approval_policy::Policy::AlwaysAsk(_)) => {
77                    ToolApprovalPolicy::AlwaysAsk
78                }
79                Some(proto::tool_approval_policy::Policy::PreApproved(pre_approved)) => {
80                    ToolApprovalPolicy::PreApproved {
81                        tools: pre_approved.tools.into_iter().collect(),
82                    }
83                }
84                Some(proto::tool_approval_policy::Policy::Mixed(mixed)) => {
85                    ToolApprovalPolicy::Mixed {
86                        pre_approved: mixed.pre_approved_tools.into_iter().collect(),
87                        ask_for_others: mixed.ask_for_others,
88                    }
89                }
90                None => ToolApprovalPolicy::AlwaysAsk,
91            })
92            .unwrap_or(ToolApprovalPolicy::AlwaysAsk);
93
94        let mut tool_config = config
95            .tool_config
96            .map(proto_to_tool_config)
97            .unwrap_or_default();
98
99        // Set the approval policy in the tool config
100        tool_config.approval_policy = tool_policy;
101
102        let workspace_config = config
103            .workspace_config
104            .map(proto_to_workspace_config)
105            .unwrap_or_default();
106
107        let session_config = SessionConfig {
108            workspace: workspace_config,
109            tool_config,
110            system_prompt: config.system_prompt,
111            metadata: config.metadata,
112        };
113
114        let (session_id, _command_tx) = self.create_session(session_config, app_config).await?;
115
116        // Get session info for response
117        let session_info =
118            self.get_session(&session_id)
119                .await?
120                .ok_or_else(|| GrpcError::SessionNotFound {
121                    session_id: session_id.clone(),
122                })?;
123
124        let proto_info = proto::SessionInfo {
125            id: session_info.id,
126            created_at: Some(prost_types::Timestamp::from(std::time::SystemTime::from(
127                session_info.created_at,
128            ))),
129            updated_at: Some(prost_types::Timestamp::from(std::time::SystemTime::from(
130                session_info.updated_at,
131            ))),
132            status: proto::SessionStatus::Active as i32,
133            metadata: Some(proto::SessionMetadata {
134                labels: session_info.metadata,
135                annotations: std::collections::HashMap::new(),
136            }),
137        };
138
139        Ok((session_id, proto_info))
140    }
141}