// steer_grpc/grpc/client_adapter.rs

1use async_trait::async_trait;
2use steer_core::error::Result;
3use tokio::sync::{Mutex, mpsc};
4use tokio::task::JoinHandle;
5use tokio_stream::wrappers::ReceiverStream;
6use tonic::Request;
7use tonic::transport::Channel;
8use tracing::{debug, error, info, warn};
9
10use crate::grpc::conversions::{
11    convert_app_command_to_client_message, proto_to_message, server_event_to_app_event,
12    session_tool_config_to_proto, tool_approval_policy_to_proto, workspace_config_to_proto,
13};
14use crate::grpc::error::GrpcError;
15
16type GrpcResult<T> = std::result::Result<T, GrpcError>;
17
18use steer_core::app::conversation::Message;
19use steer_core::app::io::{AppCommandSink, AppEventSource};
20use steer_core::app::{AppCommand, AppEvent};
21use steer_core::session::SessionConfig;
22use steer_proto::agent::v1::{
23    self as proto, CreateSessionRequest, DeleteSessionRequest, GetConversationRequest,
24    GetSessionRequest, ListSessionsRequest, SessionInfo, SessionState, StreamSessionRequest,
25    SubscribeRequest, agent_service_client::AgentServiceClient,
26    stream_session_request::Message as StreamSessionRequestType,
27};
28
29/// Adapter that bridges TUI's AppCommand/AppEvent interface with gRPC streaming
30pub struct GrpcClientAdapter {
31    client: Mutex<AgentServiceClient<Channel>>,
32    session_id: Mutex<Option<String>>,
33    command_tx: Mutex<Option<mpsc::Sender<StreamSessionRequest>>>,
34    event_rx: Mutex<Option<mpsc::Receiver<AppEvent>>>,
35    stream_handle: Mutex<Option<JoinHandle<()>>>,
36}
37
38impl GrpcClientAdapter {
39    /// Connect to a gRPC server
40    pub async fn connect(addr: &str) -> GrpcResult<Self> {
41        info!("Connecting to gRPC server at {}", addr);
42
43        let client = AgentServiceClient::connect(addr.to_string()).await?;
44
45        info!("Successfully connected to gRPC server");
46
47        Ok(Self {
48            client: Mutex::new(client),
49            session_id: Mutex::new(None),
50            command_tx: Mutex::new(None),
51            stream_handle: Mutex::new(None),
52            event_rx: Mutex::new(None),
53        })
54    }
55
56    /// Create client from an existing channel (for in-memory connections)
57    pub async fn from_channel(channel: Channel) -> GrpcResult<Self> {
58        info!("Creating gRPC client from provided channel");
59
60        let client = AgentServiceClient::new(channel);
61
62        Ok(Self {
63            client: Mutex::new(client),
64            session_id: Mutex::new(None),
65            command_tx: Mutex::new(None),
66            stream_handle: Mutex::new(None),
67            event_rx: Mutex::new(None),
68        })
69    }
70
71    /// Convenience constructor: spin up a localhost gRPC server and return a ready client.
72    pub async fn local(default_model: steer_core::api::Model) -> GrpcResult<Self> {
73        use crate::local_server::setup_local_grpc;
74        let (channel, _server_handle) = setup_local_grpc(default_model, None).await?;
75        Self::from_channel(channel).await
76    }
77
78    /// Create a new session on the server
79    pub async fn create_session(&self, config: SessionConfig) -> GrpcResult<String> {
80        debug!("Creating new session with gRPC server");
81
82        let tool_policy = tool_approval_policy_to_proto(&config.tool_config.approval_policy);
83        let workspace_config = workspace_config_to_proto(&config.workspace);
84        let tool_config = session_tool_config_to_proto(&config.tool_config);
85
86        let request = Request::new(CreateSessionRequest {
87            tool_policy: Some(tool_policy),
88            metadata: config.metadata,
89            tool_config: Some(tool_config),
90            workspace_config: Some(workspace_config),
91            system_prompt: config.system_prompt,
92        });
93
94        let response = self
95            .client
96            .lock()
97            .await
98            .create_session(request)
99            .await
100            .map_err(Box::new)?;
101        let response = response.into_inner();
102        let session = response
103            .session
104            .ok_or_else(|| Box::new(tonic::Status::internal("No session info in response")))?;
105
106        *self.session_id.lock().await = Some(session.id.clone());
107
108        info!("Created session: {}", session.id);
109        Ok(session.id)
110    }
111
112    /// Activate (load) an existing dormant session and get its state
113    pub async fn activate_session(
114        &self,
115        session_id: String,
116    ) -> GrpcResult<(Vec<Message>, Vec<String>)> {
117        info!("Activating remote session: {}", session_id);
118
119        let mut stream = self
120            .client
121            .lock()
122            .await
123            .activate_session(proto::ActivateSessionRequest {
124                session_id: session_id.clone(),
125            })
126            .await
127            .map_err(Box::new)?
128            .into_inner();
129
130        let mut messages = Vec::new();
131        let mut approved_tools = Vec::new();
132
133        while let Some(response) = stream
134            .message()
135            .await
136            .map_err(|e| GrpcError::CallFailed(Box::new(e)))?
137        {
138            match response.chunk {
139                Some(proto::activate_session_response::Chunk::Message(proto_msg)) => {
140                    match proto_to_message(proto_msg) {
141                        Ok(msg) => messages.push(msg),
142                        Err(e) => return Err(GrpcError::ConversionError(e)),
143                    }
144                }
145                Some(proto::activate_session_response::Chunk::Footer(footer)) => {
146                    approved_tools = footer.approved_tools;
147                }
148                None => {}
149            }
150        }
151
152        *self.session_id.lock().await = Some(session_id);
153        Ok((messages, approved_tools))
154    }
155
156    /// Start bidirectional streaming with the server
157    pub async fn start_streaming(&self) -> GrpcResult<()> {
158        let session_id = self
159            .session_id
160            .lock()
161            .await
162            .as_ref()
163            .cloned()
164            .ok_or_else(|| GrpcError::InvalidSessionState {
165                reason: "No session ID - call create_session or activate_session first".to_string(),
166            })?;
167
168        debug!("Starting bidirectional stream for session: {}", session_id);
169
170        // Create channels for command and event communication
171        let (cmd_tx, cmd_rx) = mpsc::channel::<StreamSessionRequest>(32);
172        let (evt_tx, evt_rx) = mpsc::channel::<AppEvent>(100);
173
174        // Create the bidirectional stream
175        let outbound_stream = ReceiverStream::new(cmd_rx);
176        let request = Request::new(outbound_stream);
177
178        let response = self
179            .client
180            .lock()
181            .await
182            .stream_session(request)
183            .await
184            .map_err(Box::new)?;
185        let mut inbound_stream = response.into_inner();
186
187        // Send initial subscribe message
188        let subscribe_msg = StreamSessionRequest {
189            session_id: session_id.clone(),
190            message: Some(StreamSessionRequestType::Subscribe(SubscribeRequest {
191                event_types: vec![], // Subscribe to all events
192                since_sequence: None,
193            })),
194        };
195
196        cmd_tx
197            .send(subscribe_msg)
198            .await
199            .map_err(|_| GrpcError::StreamError("Failed to send subscribe message".to_string()))?;
200
201        // Spawn task to handle incoming server events
202        let session_id_clone = session_id.clone();
203        let stream_handle = tokio::spawn(async move {
204            info!(
205                "Started event stream handler for session: {}",
206                session_id_clone
207            );
208
209            while let Some(result) = inbound_stream.message().await.transpose() {
210                match result {
211                    Ok(server_event) => {
212                        debug!(
213                            "Received server event: sequence {}",
214                            server_event.sequence_num
215                        );
216
217                        match server_event_to_app_event(server_event) {
218                            Ok(app_event) => {
219                                if let Err(e) = evt_tx.send(app_event).await {
220                                    warn!("Failed to forward event to TUI: {}", e);
221                                    break;
222                                }
223                            }
224                            Err(e) => {
225                                error!("Failed to convert server event: {}", e);
226                                // Continue processing other events instead of breaking
227                            }
228                        }
229                    }
230                    Err(e) => {
231                        error!("gRPC stream error: {}", e);
232                        break;
233                    }
234                }
235            }
236
237            info!(
238                "Event stream handler ended for session: {}",
239                session_id_clone
240            );
241        });
242
243        // Store the handles
244        *self.command_tx.lock().await = Some(cmd_tx);
245        *self.stream_handle.lock().await = Some(stream_handle);
246        // store receiver
247        *self.event_rx.lock().await = Some(evt_rx);
248
249        info!(
250            "Bidirectional streaming started for session: {}",
251            session_id
252        );
253        Ok(())
254    }
255
256    /// Send a command to the server
257    pub async fn send_command(&self, command: AppCommand) -> GrpcResult<()> {
258        let session_id = self
259            .session_id
260            .lock()
261            .await
262            .as_ref()
263            .cloned()
264            .ok_or_else(|| GrpcError::InvalidSessionState {
265                reason: "No active session".to_string(),
266            })?;
267
268        let command_tx = self
269            .command_tx
270            .lock()
271            .await
272            .as_ref()
273            .cloned()
274            .ok_or_else(|| GrpcError::InvalidSessionState {
275                reason: "Streaming not started - call start_streaming first".to_string(),
276            })?;
277
278        let message = convert_app_command_to_client_message(command, &session_id)?;
279
280        if let Some(message) = message {
281            command_tx.send(message).await.map_err(|_| {
282                GrpcError::StreamError("Failed to send command - stream may be closed".to_string())
283            })?;
284        }
285
286        Ok(())
287    }
288
289    /// Get the current session ID
290    pub async fn session_id(&self) -> Option<String> {
291        self.session_id.lock().await.clone()
292    }
293
294    /// List sessions on the remote server
295    pub async fn list_sessions(&self) -> GrpcResult<Vec<SessionInfo>> {
296        debug!("Listing sessions from gRPC server");
297
298        let request = Request::new(ListSessionsRequest {
299            filter: None,
300            page_size: None,
301            page_token: None,
302        });
303
304        let response = self
305            .client
306            .lock()
307            .await
308            .list_sessions(request)
309            .await
310            .map_err(Box::new)?;
311        let sessions_response = response.into_inner();
312
313        Ok(sessions_response.sessions)
314    }
315
316    /// Get session details from the remote server
317    pub async fn get_session(&self, session_id: &str) -> GrpcResult<Option<SessionState>> {
318        debug!("Getting session {} from gRPC server", session_id);
319
320        let request = Request::new(GetSessionRequest {
321            session_id: session_id.to_string(),
322        });
323
324        let mut stream = self
325            .client
326            .lock()
327            .await
328            .get_session(request)
329            .await
330            .map_err(|e| GrpcError::CallFailed(Box::new(e)))?
331            .into_inner();
332
333        let mut header = None;
334        let mut messages = Vec::new();
335        let mut tool_calls = std::collections::HashMap::new();
336        let mut footer = None;
337
338        while let Some(response) = stream
339            .message()
340            .await
341            .map_err(|e| GrpcError::CallFailed(Box::new(e)))?
342        {
343            match response.chunk {
344                Some(proto::get_session_response::Chunk::Header(h)) => header = Some(h),
345                Some(proto::get_session_response::Chunk::Message(m)) => messages.push(m),
346                Some(proto::get_session_response::Chunk::ToolCall(tc)) => {
347                    if let Some(value) = tc.value {
348                        tool_calls.insert(tc.key, value);
349                    }
350                }
351                Some(proto::get_session_response::Chunk::Footer(f)) => footer = Some(f),
352                None => {}
353            }
354        }
355
356        match (header, footer) {
357            (Some(h), Some(f)) => Ok(Some(SessionState {
358                id: h.id,
359                created_at: h.created_at,
360                updated_at: h.updated_at,
361                config: h.config,
362                messages,
363                tool_calls,
364                approved_tools: f.approved_tools,
365                last_event_sequence: f.last_event_sequence,
366                metadata: f.metadata,
367            })),
368            _ => Ok(None),
369        }
370    }
371
372    /// Delete a session on the remote server
373    pub async fn delete_session(&self, session_id: &str) -> GrpcResult<bool> {
374        debug!("Deleting session {} from gRPC server", session_id);
375
376        let request = Request::new(DeleteSessionRequest {
377            session_id: session_id.to_string(),
378        });
379
380        match self.client.lock().await.delete_session(request).await {
381            Ok(_) => {
382                info!("Successfully deleted session: {}", session_id);
383                Ok(true)
384            }
385            Err(status) if status.code() == tonic::Code::NotFound => Ok(false),
386            Err(e) => Err(GrpcError::CallFailed(Box::new(e))),
387        }
388    }
389
390    /// Get the current conversation for a session
391    pub async fn get_conversation(
392        &self,
393        session_id: &str,
394    ) -> GrpcResult<(Vec<Message>, Vec<String>)> {
395        info!(
396            "Client adapter getting conversation for session: {}",
397            session_id
398        );
399
400        let mut stream = self
401            .client
402            .lock()
403            .await
404            .get_conversation(GetConversationRequest {
405                session_id: session_id.to_string(),
406            })
407            .await
408            .map_err(Box::new)?
409            .into_inner();
410
411        let mut messages = Vec::new();
412        let mut approved_tools = Vec::new();
413
414        while let Some(response) = stream
415            .message()
416            .await
417            .map_err(|e| GrpcError::CallFailed(Box::new(e)))?
418        {
419            match response.chunk {
420                Some(proto::get_conversation_response::Chunk::Message(proto_msg)) => {
421                    match proto_to_message(proto_msg) {
422                        Ok(msg) => messages.push(msg),
423                        Err(e) => {
424                            warn!("Failed to convert message: {}", e);
425                            return Err(GrpcError::ConversionError(e));
426                        }
427                    }
428                }
429                Some(proto::get_conversation_response::Chunk::Footer(footer)) => {
430                    approved_tools = footer.approved_tools;
431                }
432                None => {}
433            }
434        }
435
436        info!(
437            "Successfully converted {} messages from GetConversation response",
438            messages.len()
439        );
440
441        Ok((messages, approved_tools))
442    }
443
444    /// Shutdown the adapter and clean up resources
445    pub async fn shutdown(self) {
446        if let Some(handle) = self.stream_handle.lock().await.take() {
447            handle.abort();
448            let _ = handle.await;
449        }
450
451        if let Some(session_id) = &*self.session_id.lock().await {
452            info!("GrpcClientAdapter shut down for session: {}", session_id);
453        }
454    }
455}
456
457#[async_trait]
458impl AppCommandSink for GrpcClientAdapter {
459    async fn send_command(&self, command: AppCommand) -> Result<()> {
460        self.send_command(command)
461            .await
462            .map_err(|e| steer_core::error::Error::InvalidOperation(e.to_string()))
463    }
464}
465
466#[async_trait]
467impl AppEventSource for GrpcClientAdapter {
468    async fn subscribe(&self) -> mpsc::Receiver<AppEvent> {
469        // This is a blocking operation in a trait that doesn't support async
470        // We need to use block_on here
471        self.event_rx.lock().await.take().expect(
472            "Event receiver already taken - GrpcClientAdapter only supports single subscription",
473        )
474    }
475}
476
#[cfg(test)]
mod tests {
    // `super::*` already brings the parent's conversion helpers (including
    // `tool_approval_policy_to_proto`) into scope, so no separate import is needed.
    use super::*;
    use steer_core::session::ToolApprovalPolicy;
    use steer_proto::agent::v1::tool_approval_policy::Policy;

    /// Both `AlwaysAsk` and `PreApproved` policies map to the matching proto variant.
    #[test]
    fn test_convert_tool_approval_policy() {
        let policy = ToolApprovalPolicy::AlwaysAsk;
        let proto_policy = tool_approval_policy_to_proto(&policy);
        assert!(
            matches!(proto_policy.policy, Some(Policy::AlwaysAsk(_))),
            "AlwaysAsk should map to Policy::AlwaysAsk"
        );

        let tools = std::collections::HashSet::from(["bash".to_string()]);
        let policy = ToolApprovalPolicy::PreApproved { tools };
        let proto_policy = tool_approval_policy_to_proto(&policy);
        assert!(
            matches!(proto_policy.policy, Some(Policy::PreApproved(_))),
            "PreApproved should map to Policy::PreApproved"
        );
    }

    /// User input produces a client message; `Shutdown` is local-only and produces none.
    #[test]
    fn test_convert_app_command_to_client_message() {
        let session_id = "test-session";

        let command = AppCommand::ProcessUserInput("Hello".to_string());
        let result = convert_app_command_to_client_message(command, session_id).unwrap();
        assert!(result.is_some(), "user input should be sent to the server");

        let command = AppCommand::Shutdown;
        let result = convert_app_command_to_client_message(command, session_id).unwrap();
        assert!(result.is_none(), "Shutdown should not be sent to the server");
    }
}