steer_grpc/grpc/client_adapter.rs

1use async_trait::async_trait;
2use steer_core::error::Result;
3use tokio::sync::{Mutex, mpsc};
4use tokio::task::JoinHandle;
5use tokio_stream::wrappers::ReceiverStream;
6use tonic::Request;
7use tonic::transport::Channel;
8use tracing::{debug, error, info, warn};
9
10use crate::grpc::conversions::{
11    convert_app_command_to_client_message, proto_to_mcp_server_info, proto_to_message,
12    server_event_to_app_event, session_tool_config_to_proto, tool_approval_policy_to_proto,
13    workspace_config_to_proto,
14};
15use crate::grpc::error::GrpcError;
16
17type GrpcResult<T> = std::result::Result<T, GrpcError>;
18
19use steer_core::app::conversation::Message;
20use steer_core::app::io::{AppCommandSink, AppEventSource};
21use steer_core::app::{AppCommand, AppEvent};
22use steer_core::session::{McpServerInfo, SessionConfig};
23use steer_proto::agent::v1::{
24    self as proto, CreateSessionRequest, DeleteSessionRequest, GetConversationRequest,
25    GetMcpServersRequest, GetSessionRequest, ListSessionsRequest, SessionInfo, SessionState,
26    StreamSessionRequest, SubscribeRequest, agent_service_client::AgentServiceClient,
27    stream_session_request::Message as StreamSessionRequestType,
28};
29
30/// Adapter that bridges TUI's AppCommand/AppEvent interface with gRPC streaming
31pub struct AgentClient {
32    client: Mutex<AgentServiceClient<Channel>>,
33    session_id: Mutex<Option<String>>,
34    command_tx: Mutex<Option<mpsc::Sender<StreamSessionRequest>>>,
35    event_rx: Mutex<Option<mpsc::Receiver<AppEvent>>>,
36    stream_handle: Mutex<Option<JoinHandle<()>>>,
37}
38
39impl AgentClient {
40    /// Connect to a gRPC server
41    pub async fn connect(addr: &str) -> GrpcResult<Self> {
42        info!("Connecting to gRPC server at {}", addr);
43
44        let client = AgentServiceClient::connect(addr.to_string()).await?;
45
46        info!("Successfully connected to gRPC server");
47
48        Ok(Self {
49            client: Mutex::new(client),
50            session_id: Mutex::new(None),
51            command_tx: Mutex::new(None),
52            stream_handle: Mutex::new(None),
53            event_rx: Mutex::new(None),
54        })
55    }
56
57    /// Create client from an existing channel (for in-memory connections)
58    pub async fn from_channel(channel: Channel) -> GrpcResult<Self> {
59        info!("Creating gRPC client from provided channel");
60
61        let client = AgentServiceClient::new(channel);
62
63        Ok(Self {
64            client: Mutex::new(client),
65            session_id: Mutex::new(None),
66            command_tx: Mutex::new(None),
67            stream_handle: Mutex::new(None),
68            event_rx: Mutex::new(None),
69        })
70    }
71
72    /// Convenience constructor: spin up a localhost gRPC server and return a ready client.
73    pub async fn local(default_model: steer_core::config::model::ModelId) -> GrpcResult<Self> {
74        use crate::local_server::setup_local_grpc;
75        let (channel, _server_handle) = setup_local_grpc(default_model, None).await?;
76        Self::from_channel(channel).await
77    }
78
79    /// Create a new session on the server
80    pub async fn create_session(&self, config: SessionConfig) -> GrpcResult<String> {
81        debug!("Creating new session with gRPC server");
82
83        let tool_policy = tool_approval_policy_to_proto(&config.tool_config.approval_policy);
84        let workspace_config = workspace_config_to_proto(&config.workspace);
85        let tool_config = session_tool_config_to_proto(&config.tool_config);
86
87        let request = Request::new(CreateSessionRequest {
88            tool_policy: Some(tool_policy),
89            metadata: config.metadata,
90            tool_config: Some(tool_config),
91            workspace_config: Some(workspace_config),
92            system_prompt: config.system_prompt,
93        });
94
95        let response = self
96            .client
97            .lock()
98            .await
99            .create_session(request)
100            .await
101            .map_err(Box::new)?;
102        let response = response.into_inner();
103        let session = response
104            .session
105            .ok_or_else(|| Box::new(tonic::Status::internal("No session info in response")))?;
106
107        *self.session_id.lock().await = Some(session.id.clone());
108
109        info!("Created session: {}", session.id);
110        Ok(session.id)
111    }
112
113    /// Activate (load) an existing dormant session and get its state
114    pub async fn activate_session(
115        &self,
116        session_id: String,
117    ) -> GrpcResult<(Vec<Message>, Vec<String>)> {
118        info!("Activating remote session: {}", session_id);
119
120        let mut stream = self
121            .client
122            .lock()
123            .await
124            .activate_session(proto::ActivateSessionRequest {
125                session_id: session_id.clone(),
126            })
127            .await
128            .map_err(Box::new)?
129            .into_inner();
130
131        let mut messages = Vec::new();
132        let mut approved_tools = Vec::new();
133
134        while let Some(response) = stream
135            .message()
136            .await
137            .map_err(|e| GrpcError::CallFailed(Box::new(e)))?
138        {
139            match response.chunk {
140                Some(proto::activate_session_response::Chunk::Message(proto_msg)) => {
141                    match proto_to_message(proto_msg) {
142                        Ok(msg) => messages.push(msg),
143                        Err(e) => return Err(GrpcError::ConversionError(e)),
144                    }
145                }
146                Some(proto::activate_session_response::Chunk::Footer(footer)) => {
147                    approved_tools = footer.approved_tools;
148                }
149                None => {}
150            }
151        }
152
153        *self.session_id.lock().await = Some(session_id);
154        Ok((messages, approved_tools))
155    }
156
157    /// Start bidirectional streaming with the server
158    pub async fn start_streaming(&self) -> GrpcResult<()> {
159        let session_id = self
160            .session_id
161            .lock()
162            .await
163            .as_ref()
164            .cloned()
165            .ok_or_else(|| GrpcError::InvalidSessionState {
166                reason: "No session ID - call create_session or activate_session first".to_string(),
167            })?;
168
169        debug!("Starting bidirectional stream for session: {}", session_id);
170
171        // Create channels for command and event communication
172        let (cmd_tx, cmd_rx) = mpsc::channel::<StreamSessionRequest>(32);
173        let (evt_tx, evt_rx) = mpsc::channel::<AppEvent>(100);
174
175        // Create the bidirectional stream
176        let outbound_stream = ReceiverStream::new(cmd_rx);
177        let request = Request::new(outbound_stream);
178
179        let response = self
180            .client
181            .lock()
182            .await
183            .stream_session(request)
184            .await
185            .map_err(Box::new)?;
186        let mut inbound_stream = response.into_inner();
187
188        // Send initial subscribe message
189        let subscribe_msg = StreamSessionRequest {
190            session_id: session_id.clone(),
191            message: Some(StreamSessionRequestType::Subscribe(SubscribeRequest {
192                event_types: vec![], // Subscribe to all events
193                since_sequence: None,
194            })),
195        };
196
197        cmd_tx
198            .send(subscribe_msg)
199            .await
200            .map_err(|_| GrpcError::StreamError("Failed to send subscribe message".to_string()))?;
201
202        // Spawn task to handle incoming server events
203        let session_id_clone = session_id.clone();
204        let stream_handle = tokio::spawn(async move {
205            info!(
206                "Started event stream handler for session: {}",
207                session_id_clone
208            );
209
210            while let Some(result) = inbound_stream.message().await.transpose() {
211                match result {
212                    Ok(server_event) => {
213                        debug!(
214                            "Received server event: sequence {}",
215                            server_event.sequence_num
216                        );
217
218                        match server_event_to_app_event(server_event) {
219                            Ok(app_event) => {
220                                if let Err(e) = evt_tx.send(app_event).await {
221                                    warn!("Failed to forward event to TUI: {}", e);
222                                    break;
223                                }
224                            }
225                            Err(e) => {
226                                error!("Failed to convert server event: {}", e);
227                                // Continue processing other events instead of breaking
228                            }
229                        }
230                    }
231                    Err(e) => {
232                        error!("gRPC stream error: {}", e);
233                        break;
234                    }
235                }
236            }
237
238            info!(
239                "Event stream handler ended for session: {}",
240                session_id_clone
241            );
242        });
243
244        // Store the handles
245        *self.command_tx.lock().await = Some(cmd_tx);
246        *self.stream_handle.lock().await = Some(stream_handle);
247        // store receiver
248        *self.event_rx.lock().await = Some(evt_rx);
249
250        info!(
251            "Bidirectional streaming started for session: {}",
252            session_id
253        );
254        Ok(())
255    }
256
257    /// Send a command to the server
258    pub async fn send_command(&self, command: AppCommand) -> GrpcResult<()> {
259        let session_id = self
260            .session_id
261            .lock()
262            .await
263            .as_ref()
264            .cloned()
265            .ok_or_else(|| GrpcError::InvalidSessionState {
266                reason: "No active session".to_string(),
267            })?;
268
269        let command_tx = self
270            .command_tx
271            .lock()
272            .await
273            .as_ref()
274            .cloned()
275            .ok_or_else(|| GrpcError::InvalidSessionState {
276                reason: "Streaming not started - call start_streaming first".to_string(),
277            })?;
278
279        let message = convert_app_command_to_client_message(command, &session_id)?;
280
281        if let Some(message) = message {
282            command_tx.send(message).await.map_err(|_| {
283                GrpcError::StreamError("Failed to send command - stream may be closed".to_string())
284            })?;
285        }
286
287        Ok(())
288    }
289
290    /// Get the current session ID
291    pub async fn session_id(&self) -> Option<String> {
292        self.session_id.lock().await.clone()
293    }
294
295    /// List sessions on the remote server
296    pub async fn list_sessions(&self) -> GrpcResult<Vec<SessionInfo>> {
297        debug!("Listing sessions from gRPC server");
298
299        let request = Request::new(ListSessionsRequest {
300            filter: None,
301            page_size: None,
302            page_token: None,
303        });
304
305        let response = self
306            .client
307            .lock()
308            .await
309            .list_sessions(request)
310            .await
311            .map_err(Box::new)?;
312        let sessions_response = response.into_inner();
313
314        Ok(sessions_response.sessions)
315    }
316
317    /// Get session details from the remote server
318    pub async fn get_session(&self, session_id: &str) -> GrpcResult<Option<SessionState>> {
319        debug!("Getting session {} from gRPC server", session_id);
320
321        let request = Request::new(GetSessionRequest {
322            session_id: session_id.to_string(),
323        });
324
325        let mut stream = self
326            .client
327            .lock()
328            .await
329            .get_session(request)
330            .await
331            .map_err(|e| GrpcError::CallFailed(Box::new(e)))?
332            .into_inner();
333
334        let mut header = None;
335        let mut messages = Vec::new();
336        let mut tool_calls = std::collections::HashMap::new();
337        let mut footer = None;
338
339        while let Some(response) = stream
340            .message()
341            .await
342            .map_err(|e| GrpcError::CallFailed(Box::new(e)))?
343        {
344            match response.chunk {
345                Some(proto::get_session_response::Chunk::Header(h)) => header = Some(h),
346                Some(proto::get_session_response::Chunk::Message(m)) => messages.push(m),
347                Some(proto::get_session_response::Chunk::ToolCall(tc)) => {
348                    if let Some(value) = tc.value {
349                        tool_calls.insert(tc.key, value);
350                    }
351                }
352                Some(proto::get_session_response::Chunk::Footer(f)) => footer = Some(f),
353                None => {}
354            }
355        }
356
357        match (header, footer) {
358            (Some(h), Some(f)) => Ok(Some(SessionState {
359                id: h.id,
360                created_at: h.created_at,
361                updated_at: h.updated_at,
362                config: h.config,
363                messages,
364                tool_calls,
365                approved_tools: f.approved_tools,
366                last_event_sequence: f.last_event_sequence,
367                metadata: f.metadata,
368            })),
369            _ => Ok(None),
370        }
371    }
372
373    /// Delete a session on the remote server
374    pub async fn delete_session(&self, session_id: &str) -> GrpcResult<bool> {
375        debug!("Deleting session {} from gRPC server", session_id);
376
377        let request = Request::new(DeleteSessionRequest {
378            session_id: session_id.to_string(),
379        });
380
381        match self.client.lock().await.delete_session(request).await {
382            Ok(_) => {
383                info!("Successfully deleted session: {}", session_id);
384                Ok(true)
385            }
386            Err(status) if status.code() == tonic::Code::NotFound => Ok(false),
387            Err(e) => Err(GrpcError::CallFailed(Box::new(e))),
388        }
389    }
390
391    /// Get the current conversation for a session
392    pub async fn get_conversation(
393        &self,
394        session_id: &str,
395    ) -> GrpcResult<(Vec<Message>, Vec<String>)> {
396        info!(
397            "Client adapter getting conversation for session: {}",
398            session_id
399        );
400
401        let mut stream = self
402            .client
403            .lock()
404            .await
405            .get_conversation(GetConversationRequest {
406                session_id: session_id.to_string(),
407            })
408            .await
409            .map_err(Box::new)?
410            .into_inner();
411
412        let mut messages = Vec::new();
413        let mut approved_tools = Vec::new();
414
415        while let Some(response) = stream
416            .message()
417            .await
418            .map_err(|e| GrpcError::CallFailed(Box::new(e)))?
419        {
420            match response.chunk {
421                Some(proto::get_conversation_response::Chunk::Message(proto_msg)) => {
422                    match proto_to_message(proto_msg) {
423                        Ok(msg) => messages.push(msg),
424                        Err(e) => {
425                            warn!("Failed to convert message: {}", e);
426                            return Err(GrpcError::ConversionError(e));
427                        }
428                    }
429                }
430                Some(proto::get_conversation_response::Chunk::Footer(footer)) => {
431                    approved_tools = footer.approved_tools;
432                }
433                None => {}
434            }
435        }
436
437        info!(
438            "Successfully converted {} messages from GetConversation response",
439            messages.len()
440        );
441
442        Ok((messages, approved_tools))
443    }
444
445    /// Shutdown the adapter and clean up resources
446    pub async fn shutdown(self) {
447        if let Some(handle) = self.stream_handle.lock().await.take() {
448            handle.abort();
449            let _ = handle.await;
450        }
451
452        if let Some(session_id) = &*self.session_id.lock().await {
453            info!("GrpcClientAdapter shut down for session: {}", session_id);
454        }
455    }
456
457    pub async fn get_mcp_servers(&self) -> GrpcResult<Vec<McpServerInfo>> {
458        let session_id = self
459            .session_id
460            .lock()
461            .await
462            .as_ref()
463            .cloned()
464            .ok_or_else(|| GrpcError::InvalidSessionState {
465                reason: "No active session".to_string(),
466            })?;
467
468        let request = Request::new(GetMcpServersRequest {
469            session_id: session_id.clone(),
470        });
471
472        let response = self
473            .client
474            .lock()
475            .await
476            .get_mcp_servers(request)
477            .await
478            .map_err(Box::new)?;
479
480        let servers = response
481            .into_inner()
482            .servers
483            .into_iter()
484            .filter_map(|s| proto_to_mcp_server_info(s).ok())
485            .collect();
486
487        Ok(servers)
488    }
489
490    /// Resolve a model string (alias or provider/model) to a ModelId
491    pub async fn resolve_model(
492        &self,
493        input: &str,
494    ) -> GrpcResult<steer_core::config::model::ModelId> {
495        let request = Request::new(proto::ResolveModelRequest {
496            input: input.to_string(),
497        });
498
499        let response = self
500            .client
501            .lock()
502            .await
503            .resolve_model(request)
504            .await
505            .map_err(Box::new)?;
506
507        let inner = response.into_inner();
508        let model_spec = inner.model.ok_or_else(|| GrpcError::InvalidSessionState {
509            reason: format!("Server returned no model for input '{input}'"),
510        })?;
511
512        // Convert proto ModelSpec to core ModelId
513        // Try to deserialize the provider string using serde (same as ModelRegistry does)
514        let provider_id: steer_core::config::provider::ProviderId =
515            serde_json::from_value(serde_json::Value::String(model_spec.provider_id.clone()))
516                .map_err(|_| GrpcError::InvalidSessionState {
517                    reason: format!(
518                        "Invalid provider ID from server: {}",
519                        model_spec.provider_id
520                    ),
521                })?;
522
523        Ok((provider_id, model_spec.model_id))
524    }
525
526    /// List providers from server
527    pub async fn list_providers(&self) -> GrpcResult<Vec<proto::ProviderInfo>> {
528        let request = Request::new(proto::ListProvidersRequest {});
529        let response = self
530            .client
531            .lock()
532            .await
533            .list_providers(request)
534            .await
535            .map_err(Box::new)?;
536        Ok(response.into_inner().providers)
537    }
538
539    /// Get provider auth status from server
540    pub async fn get_provider_auth_status(
541        &self,
542        provider_id: Option<String>,
543    ) -> GrpcResult<Vec<proto::ProviderAuthStatus>> {
544        let request = Request::new(proto::GetProviderAuthStatusRequest { provider_id });
545        let response = self
546            .client
547            .lock()
548            .await
549            .get_provider_auth_status(request)
550            .await
551            .map_err(Box::new)?;
552        Ok(response.into_inner().statuses)
553    }
554
555    /// List available models (only recommended ones)
556    pub async fn list_models(
557        &self,
558        provider_id: Option<String>,
559    ) -> GrpcResult<Vec<proto::ProviderModel>> {
560        let request = Request::new(proto::ListModelsRequest { provider_id });
561
562        let response = self
563            .client
564            .lock()
565            .await
566            .list_models(request)
567            .await
568            .map_err(Box::new)?;
569
570        Ok(response.into_inner().models)
571    }
572}
573
574#[async_trait]
575impl AppCommandSink for AgentClient {
576    async fn send_command(&self, command: AppCommand) -> Result<()> {
577        self.send_command(command)
578            .await
579            .map_err(|e| steer_core::error::Error::InvalidOperation(e.to_string()))
580    }
581}
582
583#[async_trait]
584impl AppEventSource for AgentClient {
585    async fn subscribe(&self) -> mpsc::Receiver<AppEvent> {
586        // This is a blocking operation in a trait that doesn't support async
587        // We need to use block_on here
588        self.event_rx.lock().await.take().expect(
589            "Event receiver already taken - GrpcClientAdapter only supports single subscription",
590        )
591    }
592}
593
#[cfg(test)]
mod tests {
    // `super::*` already brings in `tool_approval_policy_to_proto`,
    // `convert_app_command_to_client_message` and `AppCommand` from the parent's imports.
    use super::*;
    use steer_core::session::ToolApprovalPolicy;
    use steer_proto::agent::v1::tool_approval_policy::Policy;

    #[test]
    fn test_convert_tool_approval_policy() {
        let always_ask = tool_approval_policy_to_proto(&ToolApprovalPolicy::AlwaysAsk);
        assert!(matches!(always_ask.policy, Some(Policy::AlwaysAsk(_))));

        let tools: std::collections::HashSet<String> = ["bash".to_string()].into_iter().collect();
        let pre_approved = tool_approval_policy_to_proto(&ToolApprovalPolicy::PreApproved { tools });
        assert!(matches!(pre_approved.policy, Some(Policy::PreApproved(_))));
    }

    #[test]
    fn test_convert_app_command_to_client_message() {
        let session_id = "test-session";

        // User input has a wire representation...
        let user_input = AppCommand::ProcessUserInput("Hello".to_string());
        let converted = convert_app_command_to_client_message(user_input, session_id).unwrap();
        assert!(converted.is_some());

        // ...while Shutdown converts to no message at all.
        let converted = convert_app_command_to_client_message(AppCommand::Shutdown, session_id).unwrap();
        assert!(converted.is_none());
    }
}