Skip to main content

steer_grpc/grpc/
client_adapter.rs

1use tokio::sync::{Mutex, mpsc};
2use tokio::task::JoinHandle;
3use tonic::Request;
4use tonic::transport::Channel;
5use tracing::{debug, error, info, warn};
6
7use crate::client_api::{
8    ClientEvent, CreateSessionParams, ProviderAuthStatus, ProviderInfo, StartAuthResponse,
9};
10use crate::grpc::conversions::{
11    model_to_proto, proto_to_client_event, proto_to_mcp_server_info, proto_to_message,
12    proto_to_provider_auth_status, proto_to_provider_info, proto_to_repo_info,
13    proto_to_start_auth_response, proto_to_workspace_info, proto_to_workspace_status,
14    session_policy_overrides_to_proto, session_tool_config_to_proto, workspace_config_to_proto,
15};
16use crate::grpc::error::GrpcError;
17
/// Result alias used by the fallible `AgentClient` methods in this file; the error side is
/// always [`GrpcError`].
type GrpcResult<T> = std::result::Result<T, GrpcError>;
19
20use steer_core::app::conversation::Message;
21use steer_core::session::McpServerInfo;
22use steer_proto::agent::v1::{
23    self as proto, CreateSessionRequest, DeleteSessionRequest, GetConversationRequest,
24    GetDefaultModelRequest, GetMcpServersRequest, GetSessionRequest, GetWorkspaceStatusRequest,
25    ListReposRequest, ListSessionsRequest, ListWorkspacesRequest, ResolveRepoRequest, SessionInfo,
26    SessionState, agent_service_client::AgentServiceClient,
27};
28
/// gRPC client for the agent service that tracks a single "active" session.
///
/// Session-scoped calls (`send_message`, `approve_tool`, `cancel_operation`, ...) use the id
/// stored by `create_session` / `resume_session`. Server events for that session are forwarded
/// by a background task onto an internal channel whose receiver is handed out (once) by
/// `subscribe_client_events`.
pub struct AgentClient {
    /// Generated tonic client, behind a mutex so it can be used from `&self` methods.
    client: Mutex<AgentServiceClient<Channel>>,
    /// Id of the active session; `None` until `create_session` or `resume_session` succeeds.
    session_id: Mutex<Option<String>>,
    /// Sending half cloned into the event-forwarding task.
    client_event_tx: mpsc::Sender<ClientEvent>,
    /// Receiving half; becomes `None` once `subscribe_client_events` has taken it.
    client_event_rx: Mutex<Option<mpsc::Receiver<ClientEvent>>>,
    /// Handle of the task forwarding session events, if a subscription is running.
    stream_handle: Mutex<Option<JoinHandle<()>>>,
}
36
37impl AgentClient {
38    pub async fn connect(addr: &str) -> GrpcResult<Self> {
39        info!("Connecting to gRPC server at {}", addr);
40
41        let client = AgentServiceClient::connect(addr.to_string()).await?;
42
43        info!("Successfully connected to gRPC server");
44
45        let (client_event_tx, client_event_rx) = mpsc::channel::<ClientEvent>(100);
46
47        Ok(Self {
48            client: Mutex::new(client),
49            session_id: Mutex::new(None),
50            client_event_tx,
51            client_event_rx: Mutex::new(Some(client_event_rx)),
52            stream_handle: Mutex::new(None),
53        })
54    }
55
56    pub async fn from_channel(channel: Channel) -> GrpcResult<Self> {
57        info!("Creating gRPC client from provided channel");
58
59        let client = AgentServiceClient::new(channel);
60        let (client_event_tx, client_event_rx) = mpsc::channel::<ClientEvent>(100);
61
62        Ok(Self {
63            client: Mutex::new(client),
64            session_id: Mutex::new(None),
65            client_event_tx,
66            client_event_rx: Mutex::new(Some(client_event_rx)),
67            stream_handle: Mutex::new(None),
68        })
69    }
70
71    pub async fn local(default_model: steer_core::config::model::ModelId) -> GrpcResult<Self> {
72        use crate::local_server::setup_local_grpc;
73        let (channel, _server_handle) = setup_local_grpc(default_model, None, None).await?;
74        Self::from_channel(channel).await
75    }
76
77    pub async fn create_session(&self, params: CreateSessionParams) -> GrpcResult<String> {
78        debug!("Creating new session with gRPC server");
79
80        let workspace_config = workspace_config_to_proto(&params.workspace);
81        let tool_config = session_tool_config_to_proto(&params.tool_config);
82
83        let request = Request::new(CreateSessionRequest {
84            metadata: params.metadata,
85            tool_config: Some(tool_config),
86            workspace_config: Some(workspace_config),
87            default_model: Some(model_to_proto(params.default_model)),
88            primary_agent_id: params.primary_agent_id,
89            policy_overrides: Some(session_policy_overrides_to_proto(&params.policy_overrides)),
90            auto_compaction: None,
91        });
92
93        let response = self
94            .client
95            .lock()
96            .await
97            .create_session(request)
98            .await
99            .map_err(Box::new)?;
100        let response = response.into_inner();
101        let session = response
102            .session
103            .ok_or_else(|| Box::new(tonic::Status::internal("No session info in response")))?;
104
105        *self.session_id.lock().await = Some(session.id.clone());
106
107        info!("Created session: {}", session.id);
108        Ok(session.id)
109    }
110
111    pub async fn resume_session(
112        &self,
113        session_id: &str,
114    ) -> GrpcResult<(Vec<Message>, Vec<String>, Vec<String>)> {
115        let result = self.get_conversation(session_id).await?;
116        *self.session_id.lock().await = Some(session_id.to_string());
117        info!("Resumed session: {}", session_id);
118        Ok(result)
119    }
120
121    pub async fn subscribe_session_events(&self) -> GrpcResult<()> {
122        let session_id = self
123            .session_id
124            .lock()
125            .await
126            .as_ref()
127            .cloned()
128            .ok_or_else(|| GrpcError::InvalidSessionState {
129                reason: "No active session - call create_session or resume_session first"
130                    .to_string(),
131            })?;
132
133        debug!("Subscribing to session events for session: {}", session_id);
134
135        if let Some(handle) = self.stream_handle.lock().await.take() {
136            handle.abort();
137            let _ = handle.await;
138        }
139
140        let evt_tx = self.client_event_tx.clone();
141
142        let request = Request::new(proto::SubscribeSessionEventsRequest {
143            session_id: session_id.clone(),
144            since_sequence: None,
145        });
146
147        let mut inbound_stream = self
148            .client
149            .lock()
150            .await
151            .subscribe_session_events(request)
152            .await
153            .map_err(Box::new)?
154            .into_inner();
155
156        let session_id_clone = session_id.clone();
157        let stream_handle = tokio::spawn(async move {
158            info!(
159                "Started event subscription handler for session: {}",
160                session_id_clone
161            );
162
163            while let Some(result) = inbound_stream.message().await.transpose() {
164                match result {
165                    Ok(server_event) => match proto_to_client_event(server_event) {
166                        Ok(Some(client_event)) => {
167                            if let Err(e) = evt_tx.send(client_event).await {
168                                warn!("Failed to forward client event: {}", e);
169                                break;
170                            }
171                        }
172                        Ok(None) => {}
173                        Err(e) => {
174                            error!("Failed to convert server event: {}", e);
175                        }
176                    },
177                    Err(e) => {
178                        error!("gRPC stream error: {}", e);
179                        break;
180                    }
181                }
182            }
183
184            info!(
185                "Event subscription handler ended for session: {}",
186                session_id_clone
187            );
188        });
189
190        *self.stream_handle.lock().await = Some(stream_handle);
191
192        info!("Event subscription started for session: {}", session_id);
193        Ok(())
194    }
195
196    pub async fn send_message(
197        &self,
198        message: String,
199        model: steer_core::config::model::ModelId,
200    ) -> GrpcResult<()> {
201        self.send_content_message(
202            vec![crate::client_api::UserContent::Text { text: message }],
203            model,
204        )
205        .await
206    }
207
208    pub async fn send_content_message(
209        &self,
210        content: Vec<crate::client_api::UserContent>,
211        model: steer_core::config::model::ModelId,
212    ) -> GrpcResult<()> {
213        let session_id = self
214            .session_id
215            .lock()
216            .await
217            .as_ref()
218            .cloned()
219            .ok_or_else(|| GrpcError::InvalidSessionState {
220                reason: "No active session".to_string(),
221            })?;
222
223        let fallback_text = content
224            .iter()
225            .filter_map(|item| match item {
226                crate::client_api::UserContent::Text { text } => Some(text.as_str()),
227                _ => None,
228            })
229            .collect::<Vec<_>>()
230            .join("\n");
231
232        let proto_content: Vec<proto::UserContent> = content
233            .into_iter()
234            .map(|item| {
235                let content = match item {
236                    crate::client_api::UserContent::Text { text } => {
237                        Some(proto::user_content::Content::Text(text))
238                    }
239                    crate::client_api::UserContent::CommandExecution {
240                        command,
241                        stdout,
242                        stderr,
243                        exit_code,
244                    } => Some(proto::user_content::Content::CommandExecution(
245                        proto::CommandExecution {
246                            command,
247                            stdout,
248                            stderr,
249                            exit_code,
250                        },
251                    )),
252                    crate::client_api::UserContent::Image { image } => {
253                        let source = match image.source {
254                            crate::client_api::ImageSource::SessionFile { relative_path } => {
255                                Some(proto::image_content::Source::SessionFile(
256                                    proto::SessionFileSource { relative_path },
257                                ))
258                            }
259                            crate::client_api::ImageSource::DataUrl { data_url } => {
260                                Some(proto::image_content::Source::DataUrl(
261                                    proto::DataUrlSource { data_url },
262                                ))
263                            }
264                            crate::client_api::ImageSource::Url { url } => {
265                                Some(proto::image_content::Source::Url(proto::UrlSource { url }))
266                            }
267                        };
268
269                        Some(proto::user_content::Content::Image(proto::ImageContent {
270                            mime_type: image.mime_type,
271                            source,
272                            width: image.width,
273                            height: image.height,
274                            bytes: image.bytes,
275                            sha256: image.sha256,
276                        }))
277                    }
278                };
279                proto::UserContent { content }
280            })
281            .collect();
282
283        let steer_core::config::model::ModelId { provider, id } = model;
284        let request = Request::new(proto::SendMessageRequest {
285            session_id,
286            message: fallback_text,
287            content: proto_content,
288            model: Some(proto::ModelSpec {
289                provider_id: provider.storage_key(),
290                model_id: id,
291            }),
292        });
293
294        self.client
295            .lock()
296            .await
297            .send_message(request)
298            .await
299            .map_err(Box::new)?;
300
301        Ok(())
302    }
303
304    pub async fn edit_message(
305        &self,
306        message_id: String,
307        content: Vec<crate::client_api::UserContent>,
308        model: steer_core::config::model::ModelId,
309    ) -> GrpcResult<()> {
310        let session_id = self
311            .session_id
312            .lock()
313            .await
314            .as_ref()
315            .cloned()
316            .ok_or_else(|| GrpcError::InvalidSessionState {
317                reason: "No active session".to_string(),
318            })?;
319
320        let fallback_text = content
321            .iter()
322            .filter_map(|item| match item {
323                crate::client_api::UserContent::Text { text } => Some(text.as_str()),
324                _ => None,
325            })
326            .collect::<Vec<_>>()
327            .join("\n");
328
329        let proto_content: Vec<proto::UserContent> = content
330            .into_iter()
331            .map(|item| {
332                let content = match item {
333                    crate::client_api::UserContent::Text { text } => {
334                        Some(proto::user_content::Content::Text(text))
335                    }
336                    crate::client_api::UserContent::CommandExecution {
337                        command,
338                        stdout,
339                        stderr,
340                        exit_code,
341                    } => Some(proto::user_content::Content::CommandExecution(
342                        proto::CommandExecution {
343                            command,
344                            stdout,
345                            stderr,
346                            exit_code,
347                        },
348                    )),
349                    crate::client_api::UserContent::Image { image } => {
350                        let source = match image.source {
351                            crate::client_api::ImageSource::SessionFile { relative_path } => {
352                                Some(proto::image_content::Source::SessionFile(
353                                    proto::SessionFileSource { relative_path },
354                                ))
355                            }
356                            crate::client_api::ImageSource::DataUrl { data_url } => {
357                                Some(proto::image_content::Source::DataUrl(
358                                    proto::DataUrlSource { data_url },
359                                ))
360                            }
361                            crate::client_api::ImageSource::Url { url } => {
362                                Some(proto::image_content::Source::Url(proto::UrlSource { url }))
363                            }
364                        };
365
366                        Some(proto::user_content::Content::Image(proto::ImageContent {
367                            mime_type: image.mime_type,
368                            source,
369                            width: image.width,
370                            height: image.height,
371                            bytes: image.bytes,
372                            sha256: image.sha256,
373                        }))
374                    }
375                };
376                proto::UserContent { content }
377            })
378            .collect();
379
380        let steer_core::config::model::ModelId { provider, id } = model;
381        let request = Request::new(proto::EditMessageRequest {
382            session_id,
383            message_id,
384            new_content: fallback_text,
385            content: proto_content,
386            model: Some(proto::ModelSpec {
387                provider_id: provider.storage_key(),
388                model_id: id,
389            }),
390        });
391
392        self.client
393            .lock()
394            .await
395            .edit_message(request)
396            .await
397            .map_err(Box::new)?;
398
399        Ok(())
400    }
401
402    pub async fn approve_tool(
403        &self,
404        tool_call_id: String,
405        decision: crate::client_api::ApprovalDecision,
406    ) -> GrpcResult<()> {
407        let session_id = self
408            .session_id
409            .lock()
410            .await
411            .as_ref()
412            .cloned()
413            .ok_or_else(|| GrpcError::InvalidSessionState {
414                reason: "No active session".to_string(),
415            })?;
416
417        use crate::client_api::ApprovalDecision;
418        use proto::approval_decision::DecisionType;
419
420        let decision_type = match decision {
421            ApprovalDecision::Deny => DecisionType::Deny(true),
422            ApprovalDecision::Once => DecisionType::Once(true),
423            ApprovalDecision::AlwaysTool => DecisionType::AlwaysTool(true),
424            ApprovalDecision::AlwaysBashPattern(pattern) => {
425                DecisionType::AlwaysBashPattern(pattern)
426            }
427        };
428
429        let request = Request::new(proto::ApproveToolRequest {
430            session_id,
431            tool_call_id,
432            decision: Some(proto::ApprovalDecision {
433                decision_type: Some(decision_type),
434            }),
435        });
436
437        self.client
438            .lock()
439            .await
440            .approve_tool(request)
441            .await
442            .map_err(Box::new)?;
443
444        Ok(())
445    }
446
447    pub async fn switch_primary_agent(&self, primary_agent_id: String) -> GrpcResult<()> {
448        let session_id = self
449            .session_id
450            .lock()
451            .await
452            .as_ref()
453            .cloned()
454            .ok_or_else(|| GrpcError::InvalidSessionState {
455                reason: "No active session".to_string(),
456            })?;
457
458        let request = Request::new(proto::SwitchPrimaryAgentRequest {
459            session_id,
460            primary_agent_id,
461        });
462
463        self.client
464            .lock()
465            .await
466            .switch_primary_agent(request)
467            .await
468            .map_err(Box::new)?;
469
470        Ok(())
471    }
472
473    pub async fn cancel_operation(&self) -> GrpcResult<()> {
474        let session_id = self
475            .session_id
476            .lock()
477            .await
478            .as_ref()
479            .cloned()
480            .ok_or_else(|| GrpcError::InvalidSessionState {
481                reason: "No active session".to_string(),
482            })?;
483
484        let request = Request::new(proto::CancelOperationRequest { session_id });
485
486        self.client
487            .lock()
488            .await
489            .cancel_operation(request)
490            .await
491            .map_err(Box::new)?;
492
493        Ok(())
494    }
495
496    pub async fn compact_session(
497        &self,
498        model: steer_core::config::model::ModelId,
499    ) -> GrpcResult<()> {
500        let session_id = self
501            .session_id
502            .lock()
503            .await
504            .as_ref()
505            .cloned()
506            .ok_or_else(|| GrpcError::InvalidSessionState {
507                reason: "No active session".to_string(),
508            })?;
509
510        let request = Request::new(proto::CompactSessionRequest {
511            session_id,
512            model: Some(model_to_proto(model)),
513        });
514
515        self.client
516            .lock()
517            .await
518            .compact_session(request)
519            .await
520            .map_err(Box::new)?;
521
522        Ok(())
523    }
524
525    pub async fn execute_bash_command(&self, command: String) -> GrpcResult<()> {
526        let session_id = self
527            .session_id
528            .lock()
529            .await
530            .as_ref()
531            .cloned()
532            .ok_or_else(|| GrpcError::InvalidSessionState {
533                reason: "No active session".to_string(),
534            })?;
535
536        let request = Request::new(proto::ExecuteBashCommandRequest {
537            session_id,
538            command,
539        });
540
541        self.client
542            .lock()
543            .await
544            .execute_bash_command(request)
545            .await
546            .map_err(Box::new)?;
547
548        Ok(())
549    }
550
551    pub async fn dequeue_queued_item(&self) -> GrpcResult<()> {
552        let session_id = self
553            .session_id
554            .lock()
555            .await
556            .as_ref()
557            .cloned()
558            .ok_or_else(|| GrpcError::InvalidSessionState {
559                reason: "No active session".to_string(),
560            })?;
561
562        let request = Request::new(proto::DequeueQueuedItemRequest { session_id });
563
564        self.client
565            .lock()
566            .await
567            .dequeue_queued_item(request)
568            .await
569            .map_err(Box::new)?;
570
571        Ok(())
572    }
573
574    pub async fn subscribe_client_events(&self) -> GrpcResult<mpsc::Receiver<ClientEvent>> {
575        let mut guard = self.client_event_rx.lock().await;
576        if let Some(receiver) = guard.take() {
577            Ok(receiver)
578        } else {
579            let reason = "Client events already subscribed".to_string();
580            warn!("{reason}");
581            Err(GrpcError::InvalidSessionState { reason })
582        }
583    }
584
585    pub async fn session_id(&self) -> Option<String> {
586        self.session_id.lock().await.clone()
587    }
588
589    pub async fn list_sessions(&self) -> GrpcResult<Vec<SessionInfo>> {
590        debug!("Listing sessions from gRPC server");
591
592        let request = Request::new(ListSessionsRequest {
593            filter: None,
594            page_size: None,
595            page_token: None,
596        });
597
598        let response = self
599            .client
600            .lock()
601            .await
602            .list_sessions(request)
603            .await
604            .map_err(Box::new)?;
605        let sessions_response = response.into_inner();
606
607        Ok(sessions_response.sessions)
608    }
609
610    pub async fn get_session(&self, session_id: &str) -> GrpcResult<Option<SessionState>> {
611        debug!("Getting session {} from gRPC server", session_id);
612
613        let request = Request::new(GetSessionRequest {
614            session_id: session_id.to_string(),
615        });
616
617        let mut stream = self
618            .client
619            .lock()
620            .await
621            .get_session(request)
622            .await
623            .map_err(GrpcError::from)?
624            .into_inner();
625
626        let mut header = None;
627        let mut messages = Vec::new();
628        let mut footer = None;
629
630        while let Some(response) = stream.message().await.map_err(GrpcError::from)? {
631            match response.chunk {
632                Some(proto::get_session_response::Chunk::Header(h)) => header = Some(h),
633                Some(proto::get_session_response::Chunk::Message(m)) => messages.push(m),
634                Some(proto::get_session_response::Chunk::Footer(f)) => footer = Some(f),
635                None => {}
636            }
637        }
638
639        match (header, footer) {
640            (Some(h), Some(f)) => Ok(Some(SessionState {
641                id: h.id,
642                created_at: h.created_at,
643                updated_at: h.updated_at,
644                config: h.config,
645                messages,
646                approved_tools: f.approved_tools,
647                last_event_sequence: h.last_event_sequence,
648                metadata: f.metadata,
649            })),
650            _ => Ok(None),
651        }
652    }
653
654    pub async fn delete_session(&self, session_id: &str) -> GrpcResult<bool> {
655        debug!("Deleting session {} from gRPC server", session_id);
656
657        let request = Request::new(DeleteSessionRequest {
658            session_id: session_id.to_string(),
659        });
660
661        match self.client.lock().await.delete_session(request).await {
662            Ok(_) => {
663                info!("Successfully deleted session: {}", session_id);
664                Ok(true)
665            }
666            Err(status) if status.code() == tonic::Code::NotFound => Ok(false),
667            Err(e) => Err(GrpcError::from(e)),
668        }
669    }
670
671    pub async fn get_conversation(
672        &self,
673        session_id: &str,
674    ) -> GrpcResult<(Vec<Message>, Vec<String>, Vec<String>)> {
675        info!(
676            "Client adapter getting conversation for session: {}",
677            session_id
678        );
679
680        let mut stream = self
681            .client
682            .lock()
683            .await
684            .get_conversation(GetConversationRequest {
685                session_id: session_id.to_string(),
686            })
687            .await
688            .map_err(Box::new)?
689            .into_inner();
690
691        let mut messages = Vec::new();
692        let mut approved_tools = Vec::new();
693        let mut compaction_summary_ids = Vec::new();
694
695        while let Some(response) = stream.message().await.map_err(GrpcError::from)? {
696            match response.chunk {
697                Some(proto::get_conversation_response::Chunk::Message(proto_msg)) => {
698                    match proto_to_message(proto_msg) {
699                        Ok(msg) => messages.push(msg),
700                        Err(e) => {
701                            warn!("Failed to convert message: {}", e);
702                            return Err(GrpcError::ConversionError(e));
703                        }
704                    }
705                }
706                Some(proto::get_conversation_response::Chunk::Footer(footer)) => {
707                    approved_tools = footer.approved_tools;
708                    compaction_summary_ids = footer.compaction_summary_ids;
709                }
710                None => {}
711            }
712        }
713
714        info!(
715            "Successfully converted {} messages from GetConversation response",
716            messages.len()
717        );
718
719        Ok((messages, approved_tools, compaction_summary_ids))
720    }
721
722    pub async fn shutdown(self) {
723        if let Some(handle) = self.stream_handle.lock().await.take() {
724            handle.abort();
725            let _ = handle.await;
726        }
727
728        if let Some(session_id) = &*self.session_id.lock().await {
729            info!("GrpcClientAdapter shut down for session: {}", session_id);
730        }
731    }
732
733    pub async fn get_mcp_servers(&self) -> GrpcResult<Vec<McpServerInfo>> {
734        let session_id = self
735            .session_id
736            .lock()
737            .await
738            .as_ref()
739            .cloned()
740            .ok_or_else(|| GrpcError::InvalidSessionState {
741                reason: "No active session".to_string(),
742            })?;
743
744        let request = Request::new(GetMcpServersRequest {
745            session_id: session_id.clone(),
746        });
747
748        let response = self
749            .client
750            .lock()
751            .await
752            .get_mcp_servers(request)
753            .await
754            .map_err(Box::new)?;
755
756        let servers = response
757            .into_inner()
758            .servers
759            .into_iter()
760            .filter_map(|s| proto_to_mcp_server_info(s).ok())
761            .collect();
762
763        Ok(servers)
764    }
765
766    pub async fn resolve_model(
767        &self,
768        input: &str,
769    ) -> GrpcResult<steer_core::config::model::ModelId> {
770        let request = Request::new(proto::ResolveModelRequest {
771            input: input.to_string(),
772        });
773
774        let response = self
775            .client
776            .lock()
777            .await
778            .resolve_model(request)
779            .await
780            .map_err(Box::new)?;
781
782        let inner = response.into_inner();
783        let model_spec = inner.model.ok_or_else(|| GrpcError::InvalidSessionState {
784            reason: format!("Server returned no model for input '{input}'"),
785        })?;
786
787        let provider_id: steer_core::config::provider::ProviderId =
788            serde_json::from_value(serde_json::Value::String(model_spec.provider_id.clone()))
789                .map_err(|_| GrpcError::InvalidSessionState {
790                    reason: format!(
791                        "Invalid provider ID from server: {}",
792                        model_spec.provider_id
793                    ),
794                })?;
795
796        Ok(steer_core::config::model::ModelId::new(
797            provider_id,
798            model_spec.model_id,
799        ))
800    }
801
802    pub async fn get_default_model(&self) -> GrpcResult<steer_core::config::model::ModelId> {
803        let request = Request::new(GetDefaultModelRequest {});
804
805        let response = self
806            .client
807            .lock()
808            .await
809            .get_default_model(request)
810            .await
811            .map_err(Box::new)?;
812
813        let inner = response.into_inner();
814        let model_spec = inner.model.ok_or_else(|| GrpcError::InvalidSessionState {
815            reason: "Server returned no default model".to_string(),
816        })?;
817
818        let provider_id: steer_core::config::provider::ProviderId =
819            serde_json::from_value(serde_json::Value::String(model_spec.provider_id.clone()))
820                .map_err(|_| GrpcError::InvalidSessionState {
821                    reason: format!(
822                        "Invalid provider ID from server: {}",
823                        model_spec.provider_id
824                    ),
825                })?;
826
827        Ok(steer_core::config::model::ModelId::new(
828            provider_id,
829            model_spec.model_id,
830        ))
831    }
832
833    pub async fn list_providers(&self) -> GrpcResult<Vec<ProviderInfo>> {
834        let request = Request::new(proto::ListProvidersRequest {});
835        let response = self
836            .client
837            .lock()
838            .await
839            .list_providers(request)
840            .await
841            .map_err(Box::new)?;
842        let providers = response
843            .into_inner()
844            .providers
845            .into_iter()
846            .map(proto_to_provider_info)
847            .collect::<Result<Vec<_>, _>>()?;
848        Ok(providers)
849    }
850
851    pub async fn get_provider_auth_status(
852        &self,
853        provider_id: Option<String>,
854    ) -> GrpcResult<Vec<ProviderAuthStatus>> {
855        let request = Request::new(proto::GetProviderAuthStatusRequest { provider_id });
856        let response = self
857            .client
858            .lock()
859            .await
860            .get_provider_auth_status(request)
861            .await
862            .map_err(Box::new)?;
863        let statuses = response
864            .into_inner()
865            .statuses
866            .into_iter()
867            .map(proto_to_provider_auth_status)
868            .collect::<Result<Vec<_>, _>>()?;
869        Ok(statuses)
870    }
871
872    pub async fn start_auth(&self, provider_id: String) -> GrpcResult<StartAuthResponse> {
873        let request = Request::new(proto::StartAuthRequest { provider_id });
874        let response = self
875            .client
876            .lock()
877            .await
878            .start_auth(request)
879            .await
880            .map_err(Box::new)?;
881        proto_to_start_auth_response(response.into_inner()).map_err(GrpcError::from)
882    }
883
884    pub async fn send_auth_input(
885        &self,
886        flow_id: String,
887        input: String,
888    ) -> GrpcResult<crate::client_api::AuthProgress> {
889        let request = Request::new(proto::SendAuthInputRequest { flow_id, input });
890        let response = self
891            .client
892            .lock()
893            .await
894            .send_auth_input(request)
895            .await
896            .map_err(Box::new)?;
897        let progress =
898            response
899                .into_inner()
900                .progress
901                .ok_or_else(|| GrpcError::InvalidSessionState {
902                    reason: "Missing auth progress in response".to_string(),
903                })?;
904        crate::grpc::conversions::proto_to_auth_progress(progress).map_err(GrpcError::from)
905    }
906
907    pub async fn get_auth_progress(
908        &self,
909        flow_id: String,
910    ) -> GrpcResult<crate::client_api::AuthProgress> {
911        let request = Request::new(proto::GetAuthProgressRequest { flow_id });
912        let response = self
913            .client
914            .lock()
915            .await
916            .get_auth_progress(request)
917            .await
918            .map_err(Box::new)?;
919        let progress =
920            response
921                .into_inner()
922                .progress
923                .ok_or_else(|| GrpcError::InvalidSessionState {
924                    reason: "Missing auth progress in response".to_string(),
925                })?;
926        crate::grpc::conversions::proto_to_auth_progress(progress).map_err(GrpcError::from)
927    }
928
929    pub async fn cancel_auth(&self, flow_id: String) -> GrpcResult<()> {
930        let request = Request::new(proto::CancelAuthRequest { flow_id });
931        self.client
932            .lock()
933            .await
934            .cancel_auth(request)
935            .await
936            .map_err(Box::new)?;
937        Ok(())
938    }
939
940    pub async fn list_models(
941        &self,
942        provider_id: Option<String>,
943    ) -> GrpcResult<Vec<proto::ProviderModel>> {
944        let request = Request::new(proto::ListModelsRequest { provider_id });
945
946        let response = self
947            .client
948            .lock()
949            .await
950            .list_models(request)
951            .await
952            .map_err(Box::new)?;
953
954        Ok(response.into_inner().models)
955    }
956
957    pub async fn list_workspace_files(&self) -> GrpcResult<Vec<String>> {
958        let session_id = self
959            .session_id
960            .lock()
961            .await
962            .as_ref()
963            .cloned()
964            .ok_or_else(|| GrpcError::InvalidSessionState {
965                reason: "No active session".to_string(),
966            })?;
967
968        let request = Request::new(proto::ListFilesRequest {
969            session_id,
970            query: String::new(),
971            max_results: 0,
972        });
973
974        let mut stream = self
975            .client
976            .lock()
977            .await
978            .list_files(request)
979            .await
980            .map_err(Box::new)?
981            .into_inner();
982
983        let mut all_files = Vec::new();
984        while let Some(response) = stream.message().await.map_err(Box::new)? {
985            all_files.extend(response.paths);
986        }
987
988        Ok(all_files)
989    }
990
991    pub async fn list_workspaces(
992        &self,
993        environment_id: Option<String>,
994    ) -> GrpcResult<Vec<steer_workspace::WorkspaceInfo>> {
995        let request = Request::new(ListWorkspacesRequest {
996            environment_id: environment_id.unwrap_or_default(),
997        });
998        let response = self
999            .client
1000            .lock()
1001            .await
1002            .list_workspaces(request)
1003            .await
1004            .map_err(Box::new)?;
1005
1006        let workspaces = response
1007            .into_inner()
1008            .workspaces
1009            .into_iter()
1010            .map(proto_to_workspace_info)
1011            .collect::<Result<Vec<_>, _>>()?;
1012
1013        Ok(workspaces)
1014    }
1015
1016    pub async fn list_repos(
1017        &self,
1018        environment_id: Option<String>,
1019    ) -> GrpcResult<Vec<steer_workspace::RepoInfo>> {
1020        let request = Request::new(ListReposRequest {
1021            environment_id: environment_id.unwrap_or_default(),
1022        });
1023        let response = self
1024            .client
1025            .lock()
1026            .await
1027            .list_repos(request)
1028            .await
1029            .map_err(Box::new)?;
1030
1031        let repos = response
1032            .into_inner()
1033            .repos
1034            .into_iter()
1035            .map(proto_to_repo_info)
1036            .collect::<Result<Vec<_>, _>>()?;
1037
1038        Ok(repos)
1039    }
1040
1041    pub async fn resolve_repo(
1042        &self,
1043        environment_id: Option<String>,
1044        path: String,
1045    ) -> GrpcResult<steer_workspace::RepoInfo> {
1046        let request = Request::new(ResolveRepoRequest {
1047            environment_id: environment_id.unwrap_or_default(),
1048            path,
1049        });
1050        let response = self
1051            .client
1052            .lock()
1053            .await
1054            .resolve_repo(request)
1055            .await
1056            .map_err(Box::new)?;
1057
1058        let repo = response
1059            .into_inner()
1060            .repo
1061            .ok_or_else(|| GrpcError::InvalidSessionState {
1062                reason: "Repo missing from response".to_string(),
1063            })?;
1064
1065        Ok(proto_to_repo_info(repo)?)
1066    }
1067
1068    pub async fn get_workspace_status(
1069        &self,
1070        workspace_id: &str,
1071    ) -> GrpcResult<steer_workspace::WorkspaceStatus> {
1072        let request = Request::new(GetWorkspaceStatusRequest {
1073            workspace_id: workspace_id.to_string(),
1074        });
1075
1076        let response = self
1077            .client
1078            .lock()
1079            .await
1080            .get_workspace_status(request)
1081            .await
1082            .map_err(Box::new)?;
1083
1084        let status =
1085            response
1086                .into_inner()
1087                .status
1088                .ok_or_else(|| GrpcError::InvalidSessionState {
1089                    reason: "Workspace status missing from response".to_string(),
1090                })?;
1091
1092        Ok(proto_to_workspace_status(status)?)
1093    }
1094}
1095
#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    use crate::grpc::conversions::tool_approval_policy_to_proto;
    use steer_core::session::{ApprovalRules, ToolApprovalPolicy, UnapprovedBehavior};
    use steer_proto::agent::v1::UnapprovedBehavior as ProtoBehavior;

    /// Verifies that `tool_approval_policy_to_proto` maps each `UnapprovedBehavior` onto the
    /// matching proto enum value and carries preapproved tools across.
    #[test]
    fn test_convert_tool_approval_policy() {
        // Default policy: expected to prompt, and to always emit a preapproved section.
        let converted = tool_approval_policy_to_proto(&ToolApprovalPolicy::default());
        assert_eq!(converted.default_behavior, ProtoBehavior::Prompt as i32);
        assert!(converted.preapproved.is_some());

        // Deny-by-default policy with a single preapproved tool.
        let deny_policy = ToolApprovalPolicy {
            default_behavior: UnapprovedBehavior::Deny,
            preapproved: ApprovalRules {
                tools: HashSet::from(["bash".to_string()]),
                per_tool: HashMap::new(),
            },
        };
        let converted = tool_approval_policy_to_proto(&deny_policy);
        assert_eq!(converted.default_behavior, ProtoBehavior::Deny as i32);
        let rules = converted.preapproved.unwrap();
        assert!(rules.tools.iter().any(|tool| tool == "bash"));

        // Allow-by-default policy with no preapproved rules.
        let allow_policy = ToolApprovalPolicy {
            default_behavior: UnapprovedBehavior::Allow,
            preapproved: ApprovalRules::default(),
        };
        let converted = tool_approval_policy_to_proto(&allow_policy);
        assert_eq!(converted.default_behavior, ProtoBehavior::Allow as i32);
    }
}