steer_grpc/grpc/client_adapter.rs

1use async_trait::async_trait;
2use steer_core::error::Result;
3use tokio::sync::{Mutex, mpsc};
4use tokio::task::JoinHandle;
5use tokio_stream::wrappers::ReceiverStream;
6use tonic::Request;
7use tonic::transport::Channel;
8use tracing::{debug, error, info, warn};
9
10use crate::grpc::conversions::{
11    convert_app_command_to_client_message, proto_to_mcp_server_info, proto_to_message,
12    server_event_to_app_event, session_tool_config_to_proto, tool_approval_policy_to_proto,
13    workspace_config_to_proto,
14};
15use crate::grpc::error::GrpcError;
16
17type GrpcResult<T> = std::result::Result<T, GrpcError>;
18
19use steer_core::app::conversation::Message;
20use steer_core::app::io::{AppCommandSink, AppEventSource};
21use steer_core::app::{AppCommand, AppEvent};
22use steer_core::session::{McpServerInfo, SessionConfig};
23use steer_proto::agent::v1::{
24    self as proto, CreateSessionRequest, DeleteSessionRequest, GetConversationRequest,
25    GetMcpServersRequest, GetSessionRequest, ListSessionsRequest, SessionInfo, SessionState,
26    StreamSessionRequest, SubscribeRequest, agent_service_client::AgentServiceClient,
27    stream_session_request::Message as StreamSessionRequestType,
28};
29
30/// Adapter that bridges TUI's AppCommand/AppEvent interface with gRPC streaming
31pub struct AgentClient {
32    client: Mutex<AgentServiceClient<Channel>>,
33    session_id: Mutex<Option<String>>,
34    command_tx: Mutex<Option<mpsc::Sender<StreamSessionRequest>>>,
35    event_rx: Mutex<Option<mpsc::Receiver<AppEvent>>>,
36    stream_handle: Mutex<Option<JoinHandle<()>>>,
37}
38
39impl AgentClient {
40    /// Connect to a gRPC server
41    pub async fn connect(addr: &str) -> GrpcResult<Self> {
42        info!("Connecting to gRPC server at {}", addr);
43
44        let client = AgentServiceClient::connect(addr.to_string()).await?;
45
46        info!("Successfully connected to gRPC server");
47
48        Ok(Self {
49            client: Mutex::new(client),
50            session_id: Mutex::new(None),
51            command_tx: Mutex::new(None),
52            stream_handle: Mutex::new(None),
53            event_rx: Mutex::new(None),
54        })
55    }
56
57    /// Create client from an existing channel (for in-memory connections)
58    pub async fn from_channel(channel: Channel) -> GrpcResult<Self> {
59        info!("Creating gRPC client from provided channel");
60
61        let client = AgentServiceClient::new(channel);
62
63        Ok(Self {
64            client: Mutex::new(client),
65            session_id: Mutex::new(None),
66            command_tx: Mutex::new(None),
67            stream_handle: Mutex::new(None),
68            event_rx: Mutex::new(None),
69        })
70    }
71
72    /// Create a new session on the server
73    pub async fn create_session(&self, config: SessionConfig) -> GrpcResult<String> {
74        debug!("Creating new session with gRPC server");
75
76        let tool_policy = tool_approval_policy_to_proto(&config.tool_config.approval_policy);
77        let workspace_config = workspace_config_to_proto(&config.workspace);
78        let tool_config = session_tool_config_to_proto(&config.tool_config);
79
80        let request = Request::new(CreateSessionRequest {
81            tool_policy: Some(tool_policy),
82            metadata: config.metadata,
83            tool_config: Some(tool_config),
84            workspace_config: Some(workspace_config),
85            system_prompt: config.system_prompt,
86        });
87
88        let response = self
89            .client
90            .lock()
91            .await
92            .create_session(request)
93            .await
94            .map_err(Box::new)?;
95        let response = response.into_inner();
96        let session = response
97            .session
98            .ok_or_else(|| Box::new(tonic::Status::internal("No session info in response")))?;
99
100        *self.session_id.lock().await = Some(session.id.clone());
101
102        info!("Created session: {}", session.id);
103        Ok(session.id)
104    }
105
106    /// Activate (load) an existing dormant session and get its state
107    pub async fn activate_session(
108        &self,
109        session_id: String,
110    ) -> GrpcResult<(Vec<Message>, Vec<String>)> {
111        info!("Activating remote session: {}", session_id);
112
113        let mut stream = self
114            .client
115            .lock()
116            .await
117            .activate_session(proto::ActivateSessionRequest {
118                session_id: session_id.clone(),
119            })
120            .await
121            .map_err(Box::new)?
122            .into_inner();
123
124        let mut messages = Vec::new();
125        let mut approved_tools = Vec::new();
126
127        while let Some(response) = stream
128            .message()
129            .await
130            .map_err(|e| GrpcError::CallFailed(Box::new(e)))?
131        {
132            match response.chunk {
133                Some(proto::activate_session_response::Chunk::Message(proto_msg)) => {
134                    match proto_to_message(proto_msg) {
135                        Ok(msg) => messages.push(msg),
136                        Err(e) => return Err(GrpcError::ConversionError(e)),
137                    }
138                }
139                Some(proto::activate_session_response::Chunk::Footer(footer)) => {
140                    approved_tools = footer.approved_tools;
141                }
142                None => {}
143            }
144        }
145
146        *self.session_id.lock().await = Some(session_id);
147        Ok((messages, approved_tools))
148    }
149
150    /// Start bidirectional streaming with the server
151    pub async fn start_streaming(&self) -> GrpcResult<()> {
152        let session_id = self
153            .session_id
154            .lock()
155            .await
156            .as_ref()
157            .cloned()
158            .ok_or_else(|| GrpcError::InvalidSessionState {
159                reason: "No session ID - call create_session or activate_session first".to_string(),
160            })?;
161
162        debug!("Starting bidirectional stream for session: {}", session_id);
163
164        // Create channels for command and event communication
165        let (cmd_tx, cmd_rx) = mpsc::channel::<StreamSessionRequest>(32);
166        let (evt_tx, evt_rx) = mpsc::channel::<AppEvent>(100);
167
168        // Create the bidirectional stream
169        let outbound_stream = ReceiverStream::new(cmd_rx);
170        let request = Request::new(outbound_stream);
171
172        let response = self
173            .client
174            .lock()
175            .await
176            .stream_session(request)
177            .await
178            .map_err(Box::new)?;
179        let mut inbound_stream = response.into_inner();
180
181        // Send initial subscribe message
182        let subscribe_msg = StreamSessionRequest {
183            session_id: session_id.clone(),
184            message: Some(StreamSessionRequestType::Subscribe(SubscribeRequest {
185                event_types: vec![], // Subscribe to all events
186                since_sequence: None,
187            })),
188        };
189
190        cmd_tx
191            .send(subscribe_msg)
192            .await
193            .map_err(|_| GrpcError::StreamError("Failed to send subscribe message".to_string()))?;
194
195        // Spawn task to handle incoming server events
196        let session_id_clone = session_id.clone();
197        let stream_handle = tokio::spawn(async move {
198            info!(
199                "Started event stream handler for session: {}",
200                session_id_clone
201            );
202
203            while let Some(result) = inbound_stream.message().await.transpose() {
204                match result {
205                    Ok(server_event) => {
206                        debug!(
207                            "Received server event: sequence {}",
208                            server_event.sequence_num
209                        );
210
211                        match server_event_to_app_event(server_event) {
212                            Ok(app_event) => {
213                                if let Err(e) = evt_tx.send(app_event).await {
214                                    warn!("Failed to forward event to TUI: {}", e);
215                                    break;
216                                }
217                            }
218                            Err(e) => {
219                                error!("Failed to convert server event: {}", e);
220                                // Continue processing other events instead of breaking
221                            }
222                        }
223                    }
224                    Err(e) => {
225                        error!("gRPC stream error: {}", e);
226                        break;
227                    }
228                }
229            }
230
231            info!(
232                "Event stream handler ended for session: {}",
233                session_id_clone
234            );
235        });
236
237        // Store the handles
238        *self.command_tx.lock().await = Some(cmd_tx);
239        *self.stream_handle.lock().await = Some(stream_handle);
240        // store receiver
241        *self.event_rx.lock().await = Some(evt_rx);
242
243        info!(
244            "Bidirectional streaming started for session: {}",
245            session_id
246        );
247        Ok(())
248    }
249
250    /// Send a command to the server
251    pub async fn send_command(&self, command: AppCommand) -> GrpcResult<()> {
252        let session_id = self
253            .session_id
254            .lock()
255            .await
256            .as_ref()
257            .cloned()
258            .ok_or_else(|| GrpcError::InvalidSessionState {
259                reason: "No active session".to_string(),
260            })?;
261
262        let command_tx = self
263            .command_tx
264            .lock()
265            .await
266            .as_ref()
267            .cloned()
268            .ok_or_else(|| GrpcError::InvalidSessionState {
269                reason: "Streaming not started - call start_streaming first".to_string(),
270            })?;
271
272        let message = convert_app_command_to_client_message(command, &session_id)?;
273
274        if let Some(message) = message {
275            command_tx.send(message).await.map_err(|_| {
276                GrpcError::StreamError("Failed to send command - stream may be closed".to_string())
277            })?;
278        }
279
280        Ok(())
281    }
282
283    /// Get the current session ID
284    pub async fn session_id(&self) -> Option<String> {
285        self.session_id.lock().await.clone()
286    }
287
288    /// List sessions on the remote server
289    pub async fn list_sessions(&self) -> GrpcResult<Vec<SessionInfo>> {
290        debug!("Listing sessions from gRPC server");
291
292        let request = Request::new(ListSessionsRequest {
293            filter: None,
294            page_size: None,
295            page_token: None,
296        });
297
298        let response = self
299            .client
300            .lock()
301            .await
302            .list_sessions(request)
303            .await
304            .map_err(Box::new)?;
305        let sessions_response = response.into_inner();
306
307        Ok(sessions_response.sessions)
308    }
309
310    /// Get session details from the remote server
311    pub async fn get_session(&self, session_id: &str) -> GrpcResult<Option<SessionState>> {
312        debug!("Getting session {} from gRPC server", session_id);
313
314        let request = Request::new(GetSessionRequest {
315            session_id: session_id.to_string(),
316        });
317
318        let mut stream = self
319            .client
320            .lock()
321            .await
322            .get_session(request)
323            .await
324            .map_err(|e| GrpcError::CallFailed(Box::new(e)))?
325            .into_inner();
326
327        let mut header = None;
328        let mut messages = Vec::new();
329        let mut tool_calls = std::collections::HashMap::new();
330        let mut footer = None;
331
332        while let Some(response) = stream
333            .message()
334            .await
335            .map_err(|e| GrpcError::CallFailed(Box::new(e)))?
336        {
337            match response.chunk {
338                Some(proto::get_session_response::Chunk::Header(h)) => header = Some(h),
339                Some(proto::get_session_response::Chunk::Message(m)) => messages.push(m),
340                Some(proto::get_session_response::Chunk::ToolCall(tc)) => {
341                    if let Some(value) = tc.value {
342                        tool_calls.insert(tc.key, value);
343                    }
344                }
345                Some(proto::get_session_response::Chunk::Footer(f)) => footer = Some(f),
346                None => {}
347            }
348        }
349
350        match (header, footer) {
351            (Some(h), Some(f)) => Ok(Some(SessionState {
352                id: h.id,
353                created_at: h.created_at,
354                updated_at: h.updated_at,
355                config: h.config,
356                messages,
357                tool_calls,
358                approved_tools: f.approved_tools,
359                last_event_sequence: f.last_event_sequence,
360                metadata: f.metadata,
361            })),
362            _ => Ok(None),
363        }
364    }
365
366    /// Delete a session on the remote server
367    pub async fn delete_session(&self, session_id: &str) -> GrpcResult<bool> {
368        debug!("Deleting session {} from gRPC server", session_id);
369
370        let request = Request::new(DeleteSessionRequest {
371            session_id: session_id.to_string(),
372        });
373
374        match self.client.lock().await.delete_session(request).await {
375            Ok(_) => {
376                info!("Successfully deleted session: {}", session_id);
377                Ok(true)
378            }
379            Err(status) if status.code() == tonic::Code::NotFound => Ok(false),
380            Err(e) => Err(GrpcError::CallFailed(Box::new(e))),
381        }
382    }
383
384    /// Get the current conversation for a session
385    pub async fn get_conversation(
386        &self,
387        session_id: &str,
388    ) -> GrpcResult<(Vec<Message>, Vec<String>)> {
389        info!(
390            "Client adapter getting conversation for session: {}",
391            session_id
392        );
393
394        let mut stream = self
395            .client
396            .lock()
397            .await
398            .get_conversation(GetConversationRequest {
399                session_id: session_id.to_string(),
400            })
401            .await
402            .map_err(Box::new)?
403            .into_inner();
404
405        let mut messages = Vec::new();
406        let mut approved_tools = Vec::new();
407
408        while let Some(response) = stream
409            .message()
410            .await
411            .map_err(|e| GrpcError::CallFailed(Box::new(e)))?
412        {
413            match response.chunk {
414                Some(proto::get_conversation_response::Chunk::Message(proto_msg)) => {
415                    match proto_to_message(proto_msg) {
416                        Ok(msg) => messages.push(msg),
417                        Err(e) => {
418                            warn!("Failed to convert message: {}", e);
419                            return Err(GrpcError::ConversionError(e));
420                        }
421                    }
422                }
423                Some(proto::get_conversation_response::Chunk::Footer(footer)) => {
424                    approved_tools = footer.approved_tools;
425                }
426                None => {}
427            }
428        }
429
430        info!(
431            "Successfully converted {} messages from GetConversation response",
432            messages.len()
433        );
434
435        Ok((messages, approved_tools))
436    }
437
438    /// Shutdown the adapter and clean up resources
439    pub async fn shutdown(self) {
440        if let Some(handle) = self.stream_handle.lock().await.take() {
441            handle.abort();
442            let _ = handle.await;
443        }
444
445        if let Some(session_id) = &*self.session_id.lock().await {
446            info!("GrpcClientAdapter shut down for session: {}", session_id);
447        }
448    }
449
450    pub async fn get_mcp_servers(&self) -> GrpcResult<Vec<McpServerInfo>> {
451        let session_id = self
452            .session_id
453            .lock()
454            .await
455            .as_ref()
456            .cloned()
457            .ok_or_else(|| GrpcError::InvalidSessionState {
458                reason: "No active session".to_string(),
459            })?;
460
461        let request = Request::new(GetMcpServersRequest {
462            session_id: session_id.clone(),
463        });
464
465        let response = self
466            .client
467            .lock()
468            .await
469            .get_mcp_servers(request)
470            .await
471            .map_err(Box::new)?;
472
473        let servers = response
474            .into_inner()
475            .servers
476            .into_iter()
477            .filter_map(|s| proto_to_mcp_server_info(s).ok())
478            .collect();
479
480        Ok(servers)
481    }
482}
483
484#[async_trait]
485impl AppCommandSink for AgentClient {
486    async fn send_command(&self, command: AppCommand) -> Result<()> {
487        self.send_command(command)
488            .await
489            .map_err(|e| steer_core::error::Error::InvalidOperation(e.to_string()))
490    }
491}
492
493#[async_trait]
494impl AppEventSource for AgentClient {
495    async fn subscribe(&self) -> mpsc::Receiver<AppEvent> {
496        // This is a blocking operation in a trait that doesn't support async
497        // We need to use block_on here
498        self.event_rx.lock().await.take().expect(
499            "Event receiver already taken - GrpcClientAdapter only supports single subscription",
500        )
501    }
502}
503
#[cfg(test)]
mod tests {
    use super::*;
    use steer_core::session::ToolApprovalPolicy;
    use steer_proto::agent::v1::tool_approval_policy::Policy;

    #[test]
    fn test_convert_tool_approval_policy() {
        // AlwaysAsk maps onto the matching proto variant.
        let always_ask = tool_approval_policy_to_proto(&ToolApprovalPolicy::AlwaysAsk);
        assert!(matches!(always_ask.policy, Some(Policy::AlwaysAsk(_))));

        // PreApproved keeps its variant when carrying a tool set.
        let tools = std::collections::HashSet::from(["bash".to_string()]);
        let pre_approved =
            tool_approval_policy_to_proto(&ToolApprovalPolicy::PreApproved { tools });
        assert!(matches!(pre_approved.policy, Some(Policy::PreApproved(_))));
    }

    #[test]
    fn test_convert_app_command_to_client_message() {
        let session_id = "test-session";

        // User input has a wire representation.
        let user_input = convert_app_command_to_client_message(
            AppCommand::ProcessUserInput("Hello".to_string()),
            session_id,
        )
        .unwrap();
        assert!(user_input.is_some());

        // Shutdown is handled locally and produces no message.
        let shutdown =
            convert_app_command_to_client_message(AppCommand::Shutdown, session_id).unwrap();
        assert!(shutdown.is_none());
    }
}