steer_grpc/grpc/
session_manager_ext.rs1use crate::grpc::conversions::{
2 message_to_proto, proto_to_tool_config, proto_to_workspace_config, session_config_to_proto,
3};
4use crate::grpc::error::GrpcError;
5use steer_core::session::{SessionConfig, SessionManager};
6use steer_proto::agent::v1 as proto;
7
8#[async_trait::async_trait]
10pub trait SessionManagerExt {
11 async fn get_session_proto(
13 &self,
14 session_id: &str,
15 ) -> Result<Option<proto::SessionState>, GrpcError>;
16
17 async fn create_session_grpc(
19 &self,
20 config: proto::CreateSessionRequest,
21 app_config: steer_core::app::AppConfig,
22 ) -> Result<(String, proto::SessionInfo), GrpcError>;
23}
24
25#[async_trait::async_trait]
26impl SessionManagerExt for SessionManager {
27 async fn get_session_proto(
28 &self,
29 session_id: &str,
30 ) -> Result<Option<proto::SessionState>, GrpcError> {
31 match self.store().get_session(session_id).await {
32 Ok(Some(session)) => {
33 let config = session_config_to_proto(&session.config);
35
36 let proto_state = proto::SessionState {
38 id: session_id.to_string(),
39 created_at: Some(prost_types::Timestamp::from(std::time::SystemTime::from(
40 session.created_at,
41 ))),
42 updated_at: Some(prost_types::Timestamp::from(std::time::SystemTime::from(
43 session.updated_at,
44 ))),
45 config: Some(config),
46 messages: session
47 .state
48 .messages
49 .into_iter()
50 .map(message_to_proto)
51 .collect::<Result<Vec<_>, _>>()
52 .map_err(GrpcError::ConversionError)?,
53 tool_calls: std::collections::HashMap::new(), approved_tools: session.state.approved_tools.into_iter().collect(),
55 last_event_sequence: session.state.last_event_sequence,
56 metadata: session.state.metadata,
57 };
58 Ok(Some(proto_state))
59 }
60 Ok(None) => Ok(None),
61 Err(e) => Err(GrpcError::CoreError(e.into())),
62 }
63 }
64
65 async fn create_session_grpc(
66 &self,
67 config: proto::CreateSessionRequest,
68 app_config: steer_core::app::AppConfig,
69 ) -> Result<(String, proto::SessionInfo), GrpcError> {
70 use steer_core::session::ToolApprovalPolicy;
71
72 let tool_policy = config
74 .tool_policy
75 .map(|policy| match policy.policy {
76 Some(proto::tool_approval_policy::Policy::AlwaysAsk(_)) => {
77 ToolApprovalPolicy::AlwaysAsk
78 }
79 Some(proto::tool_approval_policy::Policy::PreApproved(pre_approved)) => {
80 ToolApprovalPolicy::PreApproved {
81 tools: pre_approved.tools.into_iter().collect(),
82 }
83 }
84 Some(proto::tool_approval_policy::Policy::Mixed(mixed)) => {
85 ToolApprovalPolicy::Mixed {
86 pre_approved: mixed.pre_approved_tools.into_iter().collect(),
87 ask_for_others: mixed.ask_for_others,
88 }
89 }
90 None => ToolApprovalPolicy::AlwaysAsk,
91 })
92 .unwrap_or(ToolApprovalPolicy::AlwaysAsk);
93
94 let mut tool_config = config
95 .tool_config
96 .map(proto_to_tool_config)
97 .unwrap_or_default();
98
99 tool_config.approval_policy = tool_policy;
101
102 let workspace_config = config
103 .workspace_config
104 .map(proto_to_workspace_config)
105 .unwrap_or_default();
106
107 let session_config = SessionConfig {
108 workspace: workspace_config,
109 tool_config,
110 system_prompt: config.system_prompt,
111 metadata: config.metadata,
112 };
113
114 let (session_id, _command_tx) = self.create_session(session_config, app_config).await?;
115
116 let session_info =
118 self.get_session(&session_id)
119 .await?
120 .ok_or_else(|| GrpcError::SessionNotFound {
121 session_id: session_id.clone(),
122 })?;
123
124 let proto_info = proto::SessionInfo {
125 id: session_info.id,
126 created_at: Some(prost_types::Timestamp::from(std::time::SystemTime::from(
127 session_info.created_at,
128 ))),
129 updated_at: Some(prost_types::Timestamp::from(std::time::SystemTime::from(
130 session_info.updated_at,
131 ))),
132 status: proto::SessionStatus::Active as i32,
133 metadata: Some(proto::SessionMetadata {
134 labels: session_info.metadata,
135 annotations: std::collections::HashMap::new(),
136 }),
137 };
138
139 Ok((session_id, proto_info))
140 }
141}