1use async_trait::async_trait;
2use steer_core::error::Result;
3use tokio::sync::{Mutex, mpsc};
4use tokio::task::JoinHandle;
5use tokio_stream::wrappers::ReceiverStream;
6use tonic::Request;
7use tonic::transport::Channel;
8use tracing::{debug, error, info, warn};
9
10use crate::grpc::conversions::{
11 convert_app_command_to_client_message, proto_to_message, server_event_to_app_event,
12 session_tool_config_to_proto, tool_approval_policy_to_proto, workspace_config_to_proto,
13};
14use crate::grpc::error::GrpcError;
15
/// Result alias used throughout this adapter; every failure is reported as a [`GrpcError`].
type GrpcResult<T> = std::result::Result<T, GrpcError>;
17
18use steer_core::app::conversation::Message;
19use steer_core::app::io::{AppCommandSink, AppEventSource};
20use steer_core::app::{AppCommand, AppEvent};
21use steer_core::session::SessionConfig;
22use steer_proto::agent::v1::{
23 self as proto, CreateSessionRequest, DeleteSessionRequest, GetConversationRequest,
24 GetSessionRequest, ListSessionsRequest, SessionInfo, SessionState, StreamSessionRequest,
25 SubscribeRequest, agent_service_client::AgentServiceClient,
26 stream_session_request::Message as StreamSessionRequestType,
27};
28
29pub struct GrpcClientAdapter {
31 client: Mutex<AgentServiceClient<Channel>>,
32 session_id: Mutex<Option<String>>,
33 command_tx: Mutex<Option<mpsc::Sender<StreamSessionRequest>>>,
34 event_rx: Mutex<Option<mpsc::Receiver<AppEvent>>>,
35 stream_handle: Mutex<Option<JoinHandle<()>>>,
36}
37
38impl GrpcClientAdapter {
39 pub async fn connect(addr: &str) -> GrpcResult<Self> {
41 info!("Connecting to gRPC server at {}", addr);
42
43 let client = AgentServiceClient::connect(addr.to_string()).await?;
44
45 info!("Successfully connected to gRPC server");
46
47 Ok(Self {
48 client: Mutex::new(client),
49 session_id: Mutex::new(None),
50 command_tx: Mutex::new(None),
51 stream_handle: Mutex::new(None),
52 event_rx: Mutex::new(None),
53 })
54 }
55
56 pub async fn from_channel(channel: Channel) -> GrpcResult<Self> {
58 info!("Creating gRPC client from provided channel");
59
60 let client = AgentServiceClient::new(channel);
61
62 Ok(Self {
63 client: Mutex::new(client),
64 session_id: Mutex::new(None),
65 command_tx: Mutex::new(None),
66 stream_handle: Mutex::new(None),
67 event_rx: Mutex::new(None),
68 })
69 }
70
71 pub async fn local(default_model: steer_core::api::Model) -> GrpcResult<Self> {
73 use crate::local_server::setup_local_grpc;
74 let (channel, _server_handle) = setup_local_grpc(default_model, None).await?;
75 Self::from_channel(channel).await
76 }
77
78 pub async fn create_session(&self, config: SessionConfig) -> GrpcResult<String> {
80 debug!("Creating new session with gRPC server");
81
82 let tool_policy = tool_approval_policy_to_proto(&config.tool_config.approval_policy);
83 let workspace_config = workspace_config_to_proto(&config.workspace);
84 let tool_config = session_tool_config_to_proto(&config.tool_config);
85
86 let request = Request::new(CreateSessionRequest {
87 tool_policy: Some(tool_policy),
88 metadata: config.metadata,
89 tool_config: Some(tool_config),
90 workspace_config: Some(workspace_config),
91 system_prompt: config.system_prompt,
92 });
93
94 let response = self
95 .client
96 .lock()
97 .await
98 .create_session(request)
99 .await
100 .map_err(Box::new)?;
101 let response = response.into_inner();
102 let session = response
103 .session
104 .ok_or_else(|| Box::new(tonic::Status::internal("No session info in response")))?;
105
106 *self.session_id.lock().await = Some(session.id.clone());
107
108 info!("Created session: {}", session.id);
109 Ok(session.id)
110 }
111
112 pub async fn activate_session(
114 &self,
115 session_id: String,
116 ) -> GrpcResult<(Vec<Message>, Vec<String>)> {
117 info!("Activating remote session: {}", session_id);
118
119 let mut stream = self
120 .client
121 .lock()
122 .await
123 .activate_session(proto::ActivateSessionRequest {
124 session_id: session_id.clone(),
125 })
126 .await
127 .map_err(Box::new)?
128 .into_inner();
129
130 let mut messages = Vec::new();
131 let mut approved_tools = Vec::new();
132
133 while let Some(response) = stream
134 .message()
135 .await
136 .map_err(|e| GrpcError::CallFailed(Box::new(e)))?
137 {
138 match response.chunk {
139 Some(proto::activate_session_response::Chunk::Message(proto_msg)) => {
140 match proto_to_message(proto_msg) {
141 Ok(msg) => messages.push(msg),
142 Err(e) => return Err(GrpcError::ConversionError(e)),
143 }
144 }
145 Some(proto::activate_session_response::Chunk::Footer(footer)) => {
146 approved_tools = footer.approved_tools;
147 }
148 None => {}
149 }
150 }
151
152 *self.session_id.lock().await = Some(session_id);
153 Ok((messages, approved_tools))
154 }
155
156 pub async fn start_streaming(&self) -> GrpcResult<()> {
158 let session_id = self
159 .session_id
160 .lock()
161 .await
162 .as_ref()
163 .cloned()
164 .ok_or_else(|| GrpcError::InvalidSessionState {
165 reason: "No session ID - call create_session or activate_session first".to_string(),
166 })?;
167
168 debug!("Starting bidirectional stream for session: {}", session_id);
169
170 let (cmd_tx, cmd_rx) = mpsc::channel::<StreamSessionRequest>(32);
172 let (evt_tx, evt_rx) = mpsc::channel::<AppEvent>(100);
173
174 let outbound_stream = ReceiverStream::new(cmd_rx);
176 let request = Request::new(outbound_stream);
177
178 let response = self
179 .client
180 .lock()
181 .await
182 .stream_session(request)
183 .await
184 .map_err(Box::new)?;
185 let mut inbound_stream = response.into_inner();
186
187 let subscribe_msg = StreamSessionRequest {
189 session_id: session_id.clone(),
190 message: Some(StreamSessionRequestType::Subscribe(SubscribeRequest {
191 event_types: vec![], since_sequence: None,
193 })),
194 };
195
196 cmd_tx
197 .send(subscribe_msg)
198 .await
199 .map_err(|_| GrpcError::StreamError("Failed to send subscribe message".to_string()))?;
200
201 let session_id_clone = session_id.clone();
203 let stream_handle = tokio::spawn(async move {
204 info!(
205 "Started event stream handler for session: {}",
206 session_id_clone
207 );
208
209 while let Some(result) = inbound_stream.message().await.transpose() {
210 match result {
211 Ok(server_event) => {
212 debug!(
213 "Received server event: sequence {}",
214 server_event.sequence_num
215 );
216
217 match server_event_to_app_event(server_event) {
218 Ok(app_event) => {
219 if let Err(e) = evt_tx.send(app_event).await {
220 warn!("Failed to forward event to TUI: {}", e);
221 break;
222 }
223 }
224 Err(e) => {
225 error!("Failed to convert server event: {}", e);
226 }
228 }
229 }
230 Err(e) => {
231 error!("gRPC stream error: {}", e);
232 break;
233 }
234 }
235 }
236
237 info!(
238 "Event stream handler ended for session: {}",
239 session_id_clone
240 );
241 });
242
243 *self.command_tx.lock().await = Some(cmd_tx);
245 *self.stream_handle.lock().await = Some(stream_handle);
246 *self.event_rx.lock().await = Some(evt_rx);
248
249 info!(
250 "Bidirectional streaming started for session: {}",
251 session_id
252 );
253 Ok(())
254 }
255
256 pub async fn send_command(&self, command: AppCommand) -> GrpcResult<()> {
258 let session_id = self
259 .session_id
260 .lock()
261 .await
262 .as_ref()
263 .cloned()
264 .ok_or_else(|| GrpcError::InvalidSessionState {
265 reason: "No active session".to_string(),
266 })?;
267
268 let command_tx = self
269 .command_tx
270 .lock()
271 .await
272 .as_ref()
273 .cloned()
274 .ok_or_else(|| GrpcError::InvalidSessionState {
275 reason: "Streaming not started - call start_streaming first".to_string(),
276 })?;
277
278 let message = convert_app_command_to_client_message(command, &session_id)?;
279
280 if let Some(message) = message {
281 command_tx.send(message).await.map_err(|_| {
282 GrpcError::StreamError("Failed to send command - stream may be closed".to_string())
283 })?;
284 }
285
286 Ok(())
287 }
288
289 pub async fn session_id(&self) -> Option<String> {
291 self.session_id.lock().await.clone()
292 }
293
294 pub async fn list_sessions(&self) -> GrpcResult<Vec<SessionInfo>> {
296 debug!("Listing sessions from gRPC server");
297
298 let request = Request::new(ListSessionsRequest {
299 filter: None,
300 page_size: None,
301 page_token: None,
302 });
303
304 let response = self
305 .client
306 .lock()
307 .await
308 .list_sessions(request)
309 .await
310 .map_err(Box::new)?;
311 let sessions_response = response.into_inner();
312
313 Ok(sessions_response.sessions)
314 }
315
316 pub async fn get_session(&self, session_id: &str) -> GrpcResult<Option<SessionState>> {
318 debug!("Getting session {} from gRPC server", session_id);
319
320 let request = Request::new(GetSessionRequest {
321 session_id: session_id.to_string(),
322 });
323
324 let mut stream = self
325 .client
326 .lock()
327 .await
328 .get_session(request)
329 .await
330 .map_err(|e| GrpcError::CallFailed(Box::new(e)))?
331 .into_inner();
332
333 let mut header = None;
334 let mut messages = Vec::new();
335 let mut tool_calls = std::collections::HashMap::new();
336 let mut footer = None;
337
338 while let Some(response) = stream
339 .message()
340 .await
341 .map_err(|e| GrpcError::CallFailed(Box::new(e)))?
342 {
343 match response.chunk {
344 Some(proto::get_session_response::Chunk::Header(h)) => header = Some(h),
345 Some(proto::get_session_response::Chunk::Message(m)) => messages.push(m),
346 Some(proto::get_session_response::Chunk::ToolCall(tc)) => {
347 if let Some(value) = tc.value {
348 tool_calls.insert(tc.key, value);
349 }
350 }
351 Some(proto::get_session_response::Chunk::Footer(f)) => footer = Some(f),
352 None => {}
353 }
354 }
355
356 match (header, footer) {
357 (Some(h), Some(f)) => Ok(Some(SessionState {
358 id: h.id,
359 created_at: h.created_at,
360 updated_at: h.updated_at,
361 config: h.config,
362 messages,
363 tool_calls,
364 approved_tools: f.approved_tools,
365 last_event_sequence: f.last_event_sequence,
366 metadata: f.metadata,
367 })),
368 _ => Ok(None),
369 }
370 }
371
372 pub async fn delete_session(&self, session_id: &str) -> GrpcResult<bool> {
374 debug!("Deleting session {} from gRPC server", session_id);
375
376 let request = Request::new(DeleteSessionRequest {
377 session_id: session_id.to_string(),
378 });
379
380 match self.client.lock().await.delete_session(request).await {
381 Ok(_) => {
382 info!("Successfully deleted session: {}", session_id);
383 Ok(true)
384 }
385 Err(status) if status.code() == tonic::Code::NotFound => Ok(false),
386 Err(e) => Err(GrpcError::CallFailed(Box::new(e))),
387 }
388 }
389
390 pub async fn get_conversation(
392 &self,
393 session_id: &str,
394 ) -> GrpcResult<(Vec<Message>, Vec<String>)> {
395 info!(
396 "Client adapter getting conversation for session: {}",
397 session_id
398 );
399
400 let mut stream = self
401 .client
402 .lock()
403 .await
404 .get_conversation(GetConversationRequest {
405 session_id: session_id.to_string(),
406 })
407 .await
408 .map_err(Box::new)?
409 .into_inner();
410
411 let mut messages = Vec::new();
412 let mut approved_tools = Vec::new();
413
414 while let Some(response) = stream
415 .message()
416 .await
417 .map_err(|e| GrpcError::CallFailed(Box::new(e)))?
418 {
419 match response.chunk {
420 Some(proto::get_conversation_response::Chunk::Message(proto_msg)) => {
421 match proto_to_message(proto_msg) {
422 Ok(msg) => messages.push(msg),
423 Err(e) => {
424 warn!("Failed to convert message: {}", e);
425 return Err(GrpcError::ConversionError(e));
426 }
427 }
428 }
429 Some(proto::get_conversation_response::Chunk::Footer(footer)) => {
430 approved_tools = footer.approved_tools;
431 }
432 None => {}
433 }
434 }
435
436 info!(
437 "Successfully converted {} messages from GetConversation response",
438 messages.len()
439 );
440
441 Ok((messages, approved_tools))
442 }
443
444 pub async fn shutdown(self) {
446 if let Some(handle) = self.stream_handle.lock().await.take() {
447 handle.abort();
448 let _ = handle.await;
449 }
450
451 if let Some(session_id) = &*self.session_id.lock().await {
452 info!("GrpcClientAdapter shut down for session: {}", session_id);
453 }
454 }
455}
456
457#[async_trait]
458impl AppCommandSink for GrpcClientAdapter {
459 async fn send_command(&self, command: AppCommand) -> Result<()> {
460 self.send_command(command)
461 .await
462 .map_err(|e| steer_core::error::Error::InvalidOperation(e.to_string()))
463 }
464}
465
466#[async_trait]
467impl AppEventSource for GrpcClientAdapter {
468 async fn subscribe(&self) -> mpsc::Receiver<AppEvent> {
469 self.event_rx.lock().await.take().expect(
472 "Event receiver already taken - GrpcClientAdapter only supports single subscription",
473 )
474 }
475}
476
#[cfg(test)]
mod tests {
    use super::*;
    use crate::grpc::conversions::tool_approval_policy_to_proto;
    use steer_core::session::ToolApprovalPolicy;
    use steer_proto::agent::v1::tool_approval_policy::Policy;

    /// Each domain approval policy maps onto the matching proto `Policy` variant.
    #[test]
    fn test_convert_tool_approval_policy() {
        let always_ask = tool_approval_policy_to_proto(&ToolApprovalPolicy::AlwaysAsk);
        assert!(matches!(always_ask.policy, Some(Policy::AlwaysAsk(_))));

        let pre_approved = tool_approval_policy_to_proto(&ToolApprovalPolicy::PreApproved {
            tools: std::iter::once("bash".to_string()).collect(),
        });
        assert!(matches!(pre_approved.policy, Some(Policy::PreApproved(_))));
    }

    /// User input becomes a stream message, while `Shutdown` produces nothing to send.
    #[test]
    fn test_convert_app_command_to_client_message() {
        const SESSION_ID: &str = "test-session";

        let user_input = convert_app_command_to_client_message(
            AppCommand::ProcessUserInput("Hello".to_string()),
            SESSION_ID,
        )
        .unwrap();
        assert!(user_input.is_some());

        let shutdown =
            convert_app_command_to_client_message(AppCommand::Shutdown, SESSION_ID).unwrap();
        assert!(shutdown.is_none());
    }
}