Skip to main content

steer_grpc/grpc/
runtime_server.rs

1use crate::grpc::conversions::{
2    environment_descriptor_to_proto, message_to_proto, model_to_proto, proto_to_model,
3    proto_to_session_policy_overrides, proto_to_tool_config, proto_to_workspace_config,
4    repo_info_to_proto, session_event_to_proto, stream_delta_to_proto, workspace_info_to_proto,
5    workspace_status_to_proto,
6};
7use std::cmp::Ordering as CmpOrdering;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::time::{Duration, Instant};
12use steer_core::app::conversation::UserContent;
13use steer_core::app::domain::runtime::{RuntimeError, RuntimeHandle};
14use steer_core::app::domain::session::{SessionCatalog, SessionFilter};
15use steer_core::app::domain::types::SessionId;
16use steer_core::auth::api_key::ApiKeyAuthFlow;
17use steer_core::auth::{
18    AuthFlowWrapper, AuthMethod, AuthSource, DynAuthenticationFlow, ModelId as AuthModelId,
19    ModelVisibilityPolicy, ProviderId as AuthProviderId,
20};
21use steer_core::session::state::SessionConfig;
22use steer_proto::agent::v1::{
23    self as proto, ApproveToolRequest, ApproveToolResponse, CancelOperationRequest,
24    CancelOperationResponse, CompactSessionRequest, CompactSessionResponse, CreateSessionRequest,
25    CreateSessionResponse, DeleteSessionRequest, DeleteSessionResponse, DequeueQueuedItemRequest,
26    DequeueQueuedItemResponse, EditMessageRequest, EditMessageResponse, ExecuteBashCommandRequest,
27    ExecuteBashCommandResponse, GetConversationFooter, GetConversationRequest,
28    GetConversationResponse, GetMcpServersRequest, GetMcpServersResponse, GetSessionRequest,
29    GetSessionResponse, ListFilesRequest, ListFilesResponse, ListModelsRequest, ListModelsResponse,
30    ListProvidersRequest, ListProvidersResponse, ListSessionsRequest, ListSessionsResponse,
31    Operation, OperationStatus, OperationType, SendMessageRequest, SendMessageResponse,
32    SessionEvent, SessionInfo, SessionStateFooter, SessionStateHeader,
33    SubscribeSessionEventsRequest, SwitchPrimaryAgentRequest, SwitchPrimaryAgentResponse,
34    agent_service_server, get_conversation_response, get_session_response,
35};
36use steer_workspace::{EnvironmentManager, RepoManager, WorkspaceManager};
37use tokio::sync::{Mutex, broadcast, mpsc};
38use tokio_stream::wrappers::ReceiverStream;
39use tonic::{Request, Response, Status};
40use tracing::{debug, error, info, warn};
41use uuid::Uuid;
42
/// gRPC `AgentService` implementation backed by the session runtime.
///
/// Session lifecycle operations go through `runtime`, session listings and
/// persisted configs through `catalog`, and model / provider / auth /
/// workspace lookups through the corresponding registries and managers.
pub struct RuntimeAgentService {
    /// Handle used to create, resume, query, subscribe to and delete sessions.
    runtime: RuntimeHandle,
    /// Session catalog used for listing sessions and reading/updating their configs.
    catalog: Arc<dyn SessionCatalog>,
    /// Registry consulted when choosing a default model.
    model_registry: Arc<steer_core::model_registry::ModelRegistry>,
    /// Registry of known providers; used to validate providers for auth flows.
    provider_registry: Arc<steer_core::auth::ProviderRegistry>,
    /// Source of auth storage and provider auth plugins.
    llm_config_provider: steer_core::config::LlmConfigProvider,
    /// Environment manager (not used in the portion of the file shown here).
    environment_manager: Arc<dyn EnvironmentManager>,
    /// Resolves local paths to managed workspaces.
    workspace_manager: Arc<dyn WorkspaceManager>,
    /// Resolves local paths to repositories.
    repo_manager: Arc<dyn RepoManager>,
    /// In-progress authentication flows, keyed by flow id.
    auth_flow_manager: Arc<AuthFlowManager>,
}
54
55const AUTH_FLOW_TTL: Duration = Duration::from_secs(10 * 60);
56
57struct AuthFlowEntry {
58    flow: Arc<dyn DynAuthenticationFlow>,
59    state: Box<dyn std::any::Any + Send + Sync>,
60    last_updated: Instant,
61}
62
63#[derive(Default)]
64struct AuthFlowManager {
65    flows: Mutex<HashMap<String, AuthFlowEntry>>,
66}
67
68impl AuthFlowManager {
69    fn new() -> Self {
70        Self::default()
71    }
72
73    async fn insert(&self, flow_id: String, entry: AuthFlowEntry) {
74        let mut flows = self.flows.lock().await;
75        flows.insert(flow_id, entry);
76    }
77
78    async fn take(&self, flow_id: &str) -> Option<AuthFlowEntry> {
79        let mut flows = self.flows.lock().await;
80        flows.remove(flow_id)
81    }
82
83    async fn cleanup(&self) {
84        let mut flows = self.flows.lock().await;
85        flows.retain(|_, entry| entry.last_updated.elapsed() <= AUTH_FLOW_TTL);
86    }
87}
88
/// Dependencies required to construct a [`RuntimeAgentService`].
///
/// `RuntimeAgentService::new` moves each field into the service field of the
/// same name.
pub struct RuntimeAgentDeps {
    /// Handle to the session runtime.
    pub runtime: RuntimeHandle,
    /// Session catalog (listing sessions, reading/updating session configs).
    pub catalog: Arc<dyn SessionCatalog>,
    /// Provides auth storage and provider auth plugins.
    pub llm_config_provider: steer_core::config::LlmConfigProvider,
    /// Model registry used for default-model selection.
    pub model_registry: Arc<steer_core::model_registry::ModelRegistry>,
    /// Registry of known providers.
    pub provider_registry: Arc<steer_core::auth::ProviderRegistry>,
    /// Environment manager.
    pub environment_manager: Arc<dyn EnvironmentManager>,
    /// Workspace manager used to resolve local paths to workspaces.
    pub workspace_manager: Arc<dyn WorkspaceManager>,
    /// Repository manager used to resolve local paths to repositories.
    pub repo_manager: Arc<dyn RepoManager>,
}
99
100impl RuntimeAgentService {
101    pub fn new(deps: RuntimeAgentDeps) -> Self {
102        Self {
103            runtime: deps.runtime,
104            catalog: deps.catalog,
105            llm_config_provider: deps.llm_config_provider,
106            model_registry: deps.model_registry,
107            provider_registry: deps.provider_registry,
108            environment_manager: deps.environment_manager,
109            workspace_manager: deps.workspace_manager,
110            repo_manager: deps.repo_manager,
111            auth_flow_manager: Arc::new(AuthFlowManager::new()),
112        }
113    }
114
115    #[expect(clippy::result_large_err)]
116    fn parse_session_id(session_id: &str) -> Result<SessionId, Status> {
117        Uuid::parse_str(session_id)
118            .map(SessionId::from)
119            .map_err(|_| Status::invalid_argument(format!("Invalid session ID: {session_id}")))
120    }
121    fn parse_environment_id(
122        environment_id: &str,
123    ) -> Result<steer_workspace::EnvironmentId, Status> {
124        if environment_id.is_empty() {
125            return Ok(steer_workspace::EnvironmentId::local());
126        }
127        let id = Uuid::parse_str(environment_id).map_err(|_| {
128            Status::invalid_argument(format!("Invalid environment ID: {environment_id}"))
129        })?;
130        Ok(steer_workspace::EnvironmentId::from_uuid(id))
131    }
132
133    fn parse_workspace_id(workspace_id: &str) -> Result<steer_workspace::WorkspaceId, Status> {
134        let id = Uuid::parse_str(workspace_id).map_err(|_| {
135            Status::invalid_argument(format!("Invalid workspace ID: {workspace_id}"))
136        })?;
137        Ok(steer_workspace::WorkspaceId::from_uuid(id))
138    }
139
140    fn parse_repo_id(repo_id: &str) -> Result<steer_workspace::RepoId, Status> {
141        let id = Uuid::parse_str(repo_id)
142            .map_err(|_| Status::invalid_argument(format!("Invalid repo ID: {repo_id}")))?;
143        Ok(steer_workspace::RepoId::from_uuid(id))
144    }
145
146    fn select_default_model(&self) -> steer_core::config::model::ModelId {
147        let builtin_default = steer_core::config::model::builtin::default_model();
148
149        if let Some(config) = self.model_registry.get(&builtin_default)
150            && config.recommended
151        {
152            return builtin_default;
153        }
154
155        let mut recommended: Vec<_> = self.model_registry.recommended().collect();
156        if recommended.is_empty() {
157            return builtin_default;
158        }
159
160        recommended.sort_by(|a, b| {
161            let provider_cmp = a.provider.as_str().cmp(b.provider.as_str());
162            if provider_cmp == CmpOrdering::Equal {
163                a.id.cmp(&b.id)
164            } else {
165                provider_cmp
166            }
167        });
168
169        let chosen = recommended[0];
170        steer_core::config::model::ModelId::new(chosen.provider.clone(), chosen.id.clone())
171    }
172
173    fn workspace_manager_error_to_status(err: steer_workspace::WorkspaceManagerError) -> Status {
174        match err {
175            steer_workspace::WorkspaceManagerError::NotFound(msg) => Status::not_found(msg),
176            steer_workspace::WorkspaceManagerError::NotSupported(msg) => {
177                Status::failed_precondition(msg)
178            }
179            steer_workspace::WorkspaceManagerError::InvalidRequest(msg) => {
180                Status::invalid_argument(msg)
181            }
182            steer_workspace::WorkspaceManagerError::Io(msg)
183            | steer_workspace::WorkspaceManagerError::Other(msg) => Status::internal(msg),
184        }
185    }
186
187    fn environment_manager_error_to_status(
188        err: steer_workspace::EnvironmentManagerError,
189    ) -> Status {
190        match err {
191            steer_workspace::EnvironmentManagerError::NotFound(msg) => Status::not_found(msg),
192            steer_workspace::EnvironmentManagerError::NotSupported(msg) => {
193                Status::failed_precondition(msg)
194            }
195            steer_workspace::EnvironmentManagerError::InvalidRequest(msg) => {
196                Status::invalid_argument(msg)
197            }
198            steer_workspace::EnvironmentManagerError::Io(msg)
199            | steer_workspace::EnvironmentManagerError::Other(msg) => Status::internal(msg),
200        }
201    }
202
203    fn create_auth_flow(
204        &self,
205        provider_id: &steer_core::config::provider::ProviderId,
206    ) -> Result<(Arc<dyn DynAuthenticationFlow>, AuthMethod), Status> {
207        let provider_cfg = self.provider_registry.get(provider_id).ok_or_else(|| {
208            Status::not_found(format!("Unknown provider: {}", provider_id.as_str()))
209        })?;
210        let provider_name = provider_cfg.name.clone();
211        let auth_storage = self.llm_config_provider.auth_storage().clone();
212
213        if let Some(plugin) = self.llm_config_provider.plugin_registry().get(provider_id)
214            && let Some(flow) = plugin.create_flow(auth_storage.clone())
215        {
216            let methods = flow.available_methods();
217            let method = if methods.contains(&AuthMethod::OAuth) {
218                AuthMethod::OAuth
219            } else if methods.contains(&AuthMethod::ApiKey) {
220                AuthMethod::ApiKey
221            } else {
222                return Err(Status::failed_precondition(format!(
223                    "No supported auth methods for provider {}",
224                    provider_id.as_str()
225                )));
226            };
227            return Ok((Arc::from(flow), method));
228        }
229
230        let flow = AuthFlowWrapper::new(ApiKeyAuthFlow::new(
231            auth_storage,
232            provider_id.clone(),
233            provider_name,
234        ));
235        Ok((Arc::new(flow), AuthMethod::ApiKey))
236    }
237}
238
239#[tonic::async_trait]
240impl agent_service_server::AgentService for RuntimeAgentService {
241    type SubscribeSessionEventsStream = ReceiverStream<Result<SessionEvent, Status>>;
242    type ListFilesStream = ReceiverStream<Result<ListFilesResponse, Status>>;
243    type GetSessionStream =
244        std::pin::Pin<Box<dyn futures::Stream<Item = Result<GetSessionResponse, Status>> + Send>>;
245    type GetConversationStream = std::pin::Pin<
246        Box<dyn futures::Stream<Item = Result<GetConversationResponse, Status>> + Send>,
247    >;
248
249    async fn subscribe_session_events(
250        &self,
251        request: Request<SubscribeSessionEventsRequest>,
252    ) -> Result<Response<Self::SubscribeSessionEventsStream>, Status> {
253        let req = request.into_inner();
254        let session_id = Self::parse_session_id(&req.session_id)?;
255
256        if let Err(e) = self.runtime.resume_session(session_id).await
257            && !matches!(e, RuntimeError::SessionNotFound { .. })
258        {
259            error!("Failed to resume session {}: {}", session_id, e);
260        }
261
262        let subscription = self
263            .runtime
264            .subscribe_events(session_id)
265            .await
266            .map_err(|e| Status::internal(format!("Failed to subscribe: {e}")))?;
267
268        let delta_subscription = self
269            .runtime
270            .subscribe_deltas(session_id)
271            .await
272            .map_err(|e| Status::internal(format!("Failed to subscribe to deltas: {e}")))?;
273
274        let (tx, rx) = mpsc::channel(100);
275        let last_sequence = Arc::new(AtomicU64::new(req.since_sequence.unwrap_or(0)));
276        let delta_sequence = Arc::new(AtomicU64::new(0));
277
278        let mut min_live_seq = req.since_sequence.map(|seq| seq.saturating_add(1));
279
280        if let Some(after_seq) = req.since_sequence {
281            match self.runtime.load_events_after(session_id, after_seq).await {
282                Ok(events) => {
283                    let mut last_seq = after_seq;
284                    for (seq, event) in events {
285                        last_seq = last_seq.max(seq);
286                        let proto_event = match session_event_to_proto(event, seq) {
287                            Ok(event) => event,
288                            Err(e) => {
289                                warn!("Failed to convert session replay event: {}", e);
290                                continue;
291                            }
292                        };
293
294                        if proto_event.event.is_none() {
295                            continue;
296                        }
297
298                        if let Err(e) = tx.send(Ok(proto_event)).await {
299                            warn!("Failed to send replay event to client: {}", e);
300                            break;
301                        }
302                    }
303                    min_live_seq = Some(last_seq.saturating_add(1));
304                    last_sequence.store(last_seq, Ordering::Relaxed);
305                }
306                Err(e) => {
307                    warn!("Failed to load replay events: {}", e);
308                }
309            }
310        }
311
312        let event_tx = tx.clone();
313        let last_sequence_events = last_sequence.clone();
314        let delta_sequence_counter = delta_sequence.clone();
315        let min_live_seq = min_live_seq;
316        tokio::spawn(async move {
317            async fn send_delta(
318                delta: steer_core::app::domain::delta::StreamDelta,
319                tx: &mpsc::Sender<Result<proto::SessionEvent, Status>>,
320                last_sequence: &Arc<AtomicU64>,
321                delta_sequence: &Arc<AtomicU64>,
322            ) -> Result<(), ()> {
323                let sequence_num = last_sequence.load(Ordering::Relaxed);
324                let delta_sequence = delta_sequence.fetch_add(1, Ordering::Relaxed);
325                let proto_event = match stream_delta_to_proto(delta, sequence_num, delta_sequence) {
326                    Ok(event) => event,
327                    Err(e) => {
328                        warn!("Failed to convert stream delta: {}", e);
329                        return Ok(());
330                    }
331                };
332
333                if let Err(e) = tx.send(Ok(proto_event)).await {
334                    warn!("Failed to send delta to client: {}", e);
335                    return Err(());
336                }
337
338                Ok(())
339            }
340
341            let mut subscription = subscription;
342            let mut delta_rx = delta_subscription;
343            let mut events_closed = false;
344            let mut deltas_closed = false;
345
346            loop {
347                if events_closed && deltas_closed {
348                    break;
349                }
350
351                tokio::select! {
352                    envelope = subscription.recv(), if !events_closed => {
353                        match envelope {
354                            Some(envelope) => {
355                                loop {
356                                    match delta_rx.try_recv() {
357                                        Ok(delta) => {
358                                            if send_delta(
359                                                delta,
360                                                &event_tx,
361                                                &last_sequence_events,
362                                                &delta_sequence_counter,
363                                            )
364                                            .await
365                                            .is_err()
366                                            {
367                                                return;
368                                            }
369                                        }
370                                        Err(broadcast::error::TryRecvError::Empty) => break,
371                                        Err(broadcast::error::TryRecvError::Lagged(skipped)) => {
372                                            warn!("Delta subscription lagged by {} messages", skipped);
373                                        }
374                                        Err(broadcast::error::TryRecvError::Closed) => {
375                                            deltas_closed = true;
376                                            break;
377                                        }
378                                    }
379                                }
380
381                                if let Some(min_seq) = min_live_seq
382                                    && envelope.seq < min_seq {
383                                        continue;
384                                    }
385
386                                let proto_event = match session_event_to_proto(envelope.event, envelope.seq) {
387                                    Ok(event) => event,
388                                    Err(e) => {
389                                        warn!("Failed to convert session event: {}", e);
390                                        continue;
391                                    }
392                                };
393
394                                if proto_event.event.is_none() {
395                                    continue;
396                                }
397
398                                if let Err(e) = event_tx.send(Ok(proto_event)).await {
399                                    warn!("Failed to send event to client: {}", e);
400                                    break;
401                                }
402                                last_sequence_events.store(envelope.seq, Ordering::Relaxed);
403                            }
404                            None => {
405                                events_closed = true;
406                            }
407                        }
408                    }
409                    delta = delta_rx.recv(), if !deltas_closed => {
410                        match delta {
411                            Ok(delta) => {
412                                if send_delta(
413                                    delta,
414                                    &event_tx,
415                                    &last_sequence_events,
416                                    &delta_sequence_counter,
417                                )
418                                .await
419                                .is_err()
420                                {
421                                    break;
422                                }
423                            }
424                            Err(broadcast::error::RecvError::Lagged(skipped)) => {
425                                warn!("Delta subscription lagged by {} messages", skipped);
426                            }
427                            Err(broadcast::error::RecvError::Closed) => {
428                                deltas_closed = true;
429                            }
430                        }
431                    }
432                }
433            }
434            debug!("Event forwarding task ended for session: {}", session_id);
435        });
436
437        Ok(Response::new(ReceiverStream::new(rx)))
438    }
439
440    async fn create_session(
441        &self,
442        request: Request<CreateSessionRequest>,
443    ) -> Result<Response<CreateSessionResponse>, Status> {
444        let req = request.into_inner();
445
446        let default_model_spec = req
447            .default_model
448            .ok_or_else(|| Status::invalid_argument("Missing required default_model"))?;
449        let default_model = proto_to_model(&default_model_spec)
450            .map_err(|e| Status::invalid_argument(format!("Invalid default_model: {e}")))?;
451
452        let tool_config = req
453            .tool_config
454            .map(proto_to_tool_config)
455            .unwrap_or_default();
456
457        let workspace_config = req
458            .workspace_config
459            .map(proto_to_workspace_config)
460            .unwrap_or_default();
461
462        let policy_overrides = proto_to_session_policy_overrides(req.policy_overrides);
463
464        let mut workspace_id = None;
465        let mut workspace_ref = None;
466        let mut repo_ref = None;
467        let parent_session_id = None;
468        let mut workspace_name = None;
469
470        if repo_ref.is_none()
471            && let steer_core::session::state::WorkspaceConfig::Local { path } = &workspace_config
472        {
473            match self
474                .repo_manager
475                .resolve_repo(steer_workspace::EnvironmentId::local(), path)
476                .await
477            {
478                Ok(repo_info) => {
479                    repo_ref = Some(steer_workspace::RepoRef {
480                        environment_id: repo_info.environment_id,
481                        repo_id: repo_info.repo_id,
482                        root_path: repo_info.root_path.clone(),
483                        vcs_kind: repo_info.vcs_kind,
484                    });
485                    workspace_name = repo_info
486                        .root_path
487                        .file_name()
488                        .map(|n| n.to_string_lossy().into_owned());
489                }
490                Err(err) => {
491                    warn!("Failed to resolve repo for session: {}", err);
492                }
493            }
494        }
495
496        if workspace_id.is_none()
497            && workspace_ref.is_none()
498            && let steer_core::session::state::WorkspaceConfig::Local { path } = &workspace_config
499            && let Ok(info) = self.workspace_manager.resolve_workspace(path).await
500        {
501            workspace_id = Some(info.workspace_id);
502            workspace_ref = Some(steer_workspace::WorkspaceRef {
503                environment_id: info.environment_id,
504                workspace_id: info.workspace_id,
505                repo_id: info.repo_id,
506            });
507            workspace_name.clone_from(&info.name);
508        }
509
510        let session_config = SessionConfig {
511            workspace: workspace_config,
512            workspace_ref,
513            workspace_id,
514            repo_ref,
515            parent_session_id,
516            workspace_name,
517            tool_config,
518            system_prompt: None,
519            primary_agent_id: req.primary_agent_id,
520            policy_overrides,
521            metadata: req.metadata,
522            default_model,
523        };
524
525        match self.runtime.create_session(session_config.clone()).await {
526            Ok(session_id) => {
527                if let Err(e) = self
528                    .catalog
529                    .update_session_catalog(session_id, Some(&session_config), false, None)
530                    .await
531                {
532                    error!("Failed to update session catalog: {}", e);
533                    return Err(Status::internal(format!(
534                        "Failed to update session catalog: {e}"
535                    )));
536                }
537
538                let session_info = SessionInfo {
539                    id: session_id.to_string(),
540                    created_at: Some(prost_types::Timestamp::from(std::time::SystemTime::now())),
541                    updated_at: Some(prost_types::Timestamp::from(std::time::SystemTime::now())),
542                    status: proto::SessionStatus::Active as i32,
543                    metadata: None,
544                };
545                Ok(Response::new(CreateSessionResponse {
546                    session: Some(session_info),
547                }))
548            }
549            Err(e) => {
550                error!("Failed to create session: {}", e);
551                Err(Status::internal(format!("Failed to create session: {e}")))
552            }
553        }
554    }
555
556    async fn list_sessions(
557        &self,
558        _request: Request<ListSessionsRequest>,
559    ) -> Result<Response<ListSessionsResponse>, Status> {
560        let filter = SessionFilter::default();
561
562        match self.catalog.list_sessions(filter).await {
563            Ok(sessions) => {
564                let proto_sessions = sessions
565                    .into_iter()
566                    .map(|s| SessionInfo {
567                        id: s.id.to_string(),
568                        created_at: Some(prost_types::Timestamp::from(
569                            std::time::SystemTime::from(s.created_at),
570                        )),
571                        updated_at: Some(prost_types::Timestamp::from(
572                            std::time::SystemTime::from(s.updated_at),
573                        )),
574                        status: proto::SessionStatus::Active as i32,
575                        metadata: None,
576                    })
577                    .collect();
578
579                Ok(Response::new(ListSessionsResponse {
580                    sessions: proto_sessions,
581                    next_page_token: None,
582                }))
583            }
584            Err(e) => {
585                error!("Failed to list sessions: {}", e);
586                Err(Status::internal(format!("Failed to list sessions: {e}")))
587            }
588        }
589    }
590
591    async fn get_session(
592        &self,
593        request: Request<GetSessionRequest>,
594    ) -> Result<Response<Self::GetSessionStream>, Status> {
595        let req = request.into_inner();
596        let session_id = Self::parse_session_id(&req.session_id)?;
597        let runtime = self.runtime.clone();
598        let catalog = self.catalog.clone();
599
600        let stream = async_stream::try_stream! {
601            if let Err(e) = runtime.resume_session(session_id).await
602                && matches!(e, RuntimeError::SessionNotFound { .. }) {
603                    Err(Status::not_found(format!("Session not found: {session_id}")))?;
604                    return;
605                }
606
607            let state = runtime.get_session_state(session_id).await
608                .map_err(|e| Status::internal(format!("Failed to get session state: {e}")))?;
609
610            let config = catalog.get_session_config(session_id).await
611                .map_err(|e| Status::internal(format!("Failed to get session config: {e}")))?;
612
613            yield GetSessionResponse {
614                chunk: Some(get_session_response::Chunk::Header(SessionStateHeader {
615                    id: session_id.to_string(),
616                    created_at: Some(prost_types::Timestamp::from(std::time::SystemTime::now())),
617                    updated_at: Some(prost_types::Timestamp::from(std::time::SystemTime::now())),
618                    config: config.map(|c| crate::grpc::conversions::session_config_to_proto(&c)),
619                    last_event_sequence: state.event_sequence,
620                })),
621            };
622
623            for message in state.message_graph.messages {
624                let proto_msg = message_to_proto(message)
625                    .map_err(|e| Status::internal(format!("Failed to convert message: {e}")))?;
626                yield GetSessionResponse {
627                    chunk: Some(get_session_response::Chunk::Message(proto_msg)),
628                };
629            }
630
631            yield GetSessionResponse {
632                chunk: Some(get_session_response::Chunk::Footer(SessionStateFooter {
633                    approved_tools: state.approved_tools.into_iter().collect(),
634                    metadata: std::collections::HashMap::new(),
635                    queued_head: state
636                        .queued_work
637                        .front()
638                        .map(|item| match item {
639                            steer_core::app::domain::state::QueuedWorkItem::UserMessage(message) => {
640                                proto::QueuedWorkItem {
641                                    kind: proto::queued_work_item::Kind::UserMessage as i32,
642                                    content: message
643                                        .content
644                                        .iter()
645                                        .filter_map(|item| match item {
646                                            UserContent::Text { text } => Some(text.as_str()),
647                                            _ => None,
648                                        })
649                                        .collect::<Vec<_>>()
650                                        .join("\n"),
651                                    model: Some(model_to_proto(message.model.clone())),
652                                    queued_at: message.queued_at,
653                                    op_id: message.op_id.to_string(),
654                                    message_id: message.message_id.to_string(),
655                                    attachment_count: message
656                                        .content
657                                        .iter()
658                                        .filter(|item| matches!(item, UserContent::Image { .. }))
659                                        .count() as u32,
660                                }
661                            }
662                            steer_core::app::domain::state::QueuedWorkItem::DirectBash(command) => {
663                                proto::QueuedWorkItem {
664                                    kind: proto::queued_work_item::Kind::DirectBash as i32,
665                                    content: command.command.clone(),
666                                    model: None,
667                                    queued_at: command.queued_at,
668                                    op_id: command.op_id.to_string(),
669                                    message_id: command.message_id.to_string(),
670                                    attachment_count: 0,
671                                }
672                            }
673                        }),
674                    queued_count: state.queued_work.len() as u32,
675                })),
676            };
677        };
678
679        Ok(Response::new(Box::pin(stream)))
680    }
681
682    async fn delete_session(
683        &self,
684        request: Request<DeleteSessionRequest>,
685    ) -> Result<Response<DeleteSessionResponse>, Status> {
686        let req = request.into_inner();
687        let session_id = Self::parse_session_id(&req.session_id)?;
688
689        match self.runtime.delete_session(session_id).await {
690            Ok(()) => Ok(Response::new(DeleteSessionResponse {})),
691            Err(RuntimeError::SessionNotFound { .. }) => Err(Status::not_found(format!(
692                "Session not found: {}",
693                req.session_id
694            ))),
695            Err(e) => {
696                error!("Failed to delete session: {}", e);
697                Err(Status::internal(format!("Failed to delete session: {e}")))
698            }
699        }
700    }
701
702    async fn get_conversation(
703        &self,
704        request: Request<GetConversationRequest>,
705    ) -> Result<Response<Self::GetConversationStream>, Status> {
706        let req = request.into_inner();
707        let session_id = Self::parse_session_id(&req.session_id)?;
708        let runtime = self.runtime.clone();
709
710        info!("GetConversation called for session: {}", session_id);
711
712        let stream = async_stream::try_stream! {
713            if let Err(e) = runtime.resume_session(session_id).await
714                && matches!(e, RuntimeError::SessionNotFound { .. }) {
715                    Err(Status::not_found(format!("Session not found: {session_id}")))?;
716                    return;
717                }
718
719            let state = runtime.get_session_state(session_id).await
720                .map_err(|e| Status::internal(format!("Failed to get session state: {e}")))?;
721
722            info!(
723                "Found session state with {} messages and {} approved tools",
724                state.message_graph.messages.len(),
725                state.approved_tools.len()
726            );
727
728            for msg in state.message_graph.messages {
729                let proto_msg = message_to_proto(msg)
730                    .map_err(|e| Status::internal(format!("Failed to convert message: {e}")))?;
731                yield GetConversationResponse {
732                    chunk: Some(get_conversation_response::Chunk::Message(proto_msg)),
733                };
734            }
735
736            yield GetConversationResponse {
737                chunk: Some(get_conversation_response::Chunk::Footer(GetConversationFooter {
738                    approved_tools: state.approved_tools.into_iter().collect(),
739                })),
740            };
741        };
742
743        Ok(Response::new(Box::pin(stream)))
744    }
745
746    async fn send_message(
747        &self,
748        request: Request<SendMessageRequest>,
749    ) -> Result<Response<SendMessageResponse>, Status> {
750        let req = request.into_inner();
751        let session_id = Self::parse_session_id(&req.session_id)?;
752
753        let model = if let Some(model_spec) = req.model {
754            proto_to_model(&model_spec)
755                .map_err(|e| Status::invalid_argument(format!("Invalid model spec: {e}")))?
756        } else {
757            let config = self
758                .catalog
759                .get_session_config(session_id)
760                .await
761                .map_err(|e| Status::internal(format!("Failed to get session config: {e}")))?
762                .ok_or_else(|| Status::not_found("Session config not found"))?;
763            config.default_model
764        };
765
766        let user_content: Vec<UserContent> = req
767            .content
768            .into_iter()
769            .filter_map(|item| match item.content {
770                Some(proto::user_content::Content::Text(text)) => Some(UserContent::Text { text }),
771                Some(proto::user_content::Content::CommandExecution(cmd)) => {
772                    Some(UserContent::CommandExecution {
773                        command: cmd.command,
774                        stdout: cmd.stdout,
775                        stderr: cmd.stderr,
776                        exit_code: cmd.exit_code,
777                    })
778                }
779                Some(proto::user_content::Content::Image(image)) => {
780                    let source = image.source.map(|source| match source {
781                        proto::image_content::Source::SessionFile(file) => {
782                            steer_core::app::conversation::ImageSource::SessionFile {
783                                relative_path: file.relative_path,
784                            }
785                        }
786                        proto::image_content::Source::DataUrl(data_url) => {
787                            steer_core::app::conversation::ImageSource::DataUrl {
788                                data_url: data_url.data_url,
789                            }
790                        }
791                        proto::image_content::Source::Url(url) => {
792                            steer_core::app::conversation::ImageSource::Url { url: url.url }
793                        }
794                    });
795
796                    source.map(|source| UserContent::Image {
797                        image: steer_core::app::conversation::ImageContent {
798                            mime_type: image.mime_type,
799                            source,
800                            width: image.width,
801                            height: image.height,
802                            bytes: image.bytes,
803                            sha256: image.sha256,
804                        },
805                    })
806                }
807                None => None,
808            })
809            .collect();
810
811        let fallback_text = req.message;
812        let content = if user_content.is_empty() {
813            vec![UserContent::Text {
814                text: fallback_text,
815            }]
816        } else {
817            user_content
818        };
819
820        let has_text = content
821            .iter()
822            .any(|item| matches!(item, UserContent::Text { text } if !text.trim().is_empty()));
823
824        let has_non_text = content
825            .iter()
826            .any(|item| !matches!(item, UserContent::Text { .. }));
827
828        if !has_text && !has_non_text {
829            return Err(Status::invalid_argument("Input text cannot be empty"));
830        }
831
832        match self
833            .runtime
834            .submit_user_input(session_id, content, model)
835            .await
836        {
837            Ok(op_id) => Ok(Response::new(SendMessageResponse {
838                operation: Some(Operation {
839                    id: op_id.to_string(),
840                    session_id: session_id.to_string(),
841                    r#type: OperationType::SendMessage as i32,
842                    status: OperationStatus::Running as i32,
843                    created_at: Some(prost_types::Timestamp::from(std::time::SystemTime::now())),
844                    completed_at: None,
845                    metadata: std::collections::HashMap::new(),
846                }),
847            })),
848            Err(RuntimeError::InvalidInput { message }) => Err(Status::invalid_argument(message)),
849            Err(e) => {
850                error!("Failed to send message: {}", e);
851                Err(Status::internal(format!("Failed to send message: {e}")))
852            }
853        }
854    }
855
856    async fn edit_message(
857        &self,
858        request: Request<EditMessageRequest>,
859    ) -> Result<Response<EditMessageResponse>, Status> {
860        let req = request.into_inner();
861        let session_id = Self::parse_session_id(&req.session_id)?;
862
863        let model = if let Some(model_spec) = req.model {
864            proto_to_model(&model_spec)
865                .map_err(|e| Status::invalid_argument(format!("Invalid model spec: {e}")))?
866        } else {
867            let config = self
868                .catalog
869                .get_session_config(session_id)
870                .await
871                .map_err(|e| Status::internal(format!("Failed to get session config: {e}")))?
872                .ok_or_else(|| Status::not_found("Session config not found"))?;
873            config.default_model
874        };
875
876        let user_content: Vec<UserContent> = req
877            .content
878            .into_iter()
879            .filter_map(|item| match item.content {
880                Some(proto::user_content::Content::Text(text)) => Some(UserContent::Text { text }),
881                Some(proto::user_content::Content::CommandExecution(cmd)) => {
882                    Some(UserContent::CommandExecution {
883                        command: cmd.command,
884                        stdout: cmd.stdout,
885                        stderr: cmd.stderr,
886                        exit_code: cmd.exit_code,
887                    })
888                }
889                Some(proto::user_content::Content::Image(image)) => {
890                    let source = image.source.map(|source| match source {
891                        proto::image_content::Source::SessionFile(file) => {
892                            steer_core::app::conversation::ImageSource::SessionFile {
893                                relative_path: file.relative_path,
894                            }
895                        }
896                        proto::image_content::Source::DataUrl(data_url) => {
897                            steer_core::app::conversation::ImageSource::DataUrl {
898                                data_url: data_url.data_url,
899                            }
900                        }
901                        proto::image_content::Source::Url(url) => {
902                            steer_core::app::conversation::ImageSource::Url { url: url.url }
903                        }
904                    });
905
906                    source.map(|source| UserContent::Image {
907                        image: steer_core::app::conversation::ImageContent {
908                            mime_type: image.mime_type,
909                            source,
910                            width: image.width,
911                            height: image.height,
912                            bytes: image.bytes,
913                            sha256: image.sha256,
914                        },
915                    })
916                }
917                None => None,
918            })
919            .collect();
920
921        let content = if user_content.is_empty() {
922            vec![UserContent::Text {
923                text: req.new_content,
924            }]
925        } else {
926            user_content
927        };
928
929        self.runtime
930            .submit_edited_message(session_id, req.message_id, content, model)
931            .await
932            .map_err(|e| match e {
933                RuntimeError::InvalidInput { message } => Status::failed_precondition(message),
934                other => Status::internal(format!("Failed to edit message: {other}")),
935            })?;
936
937        Ok(Response::new(EditMessageResponse {}))
938    }
939
940    async fn dequeue_queued_item(
941        &self,
942        request: Request<DequeueQueuedItemRequest>,
943    ) -> Result<Response<DequeueQueuedItemResponse>, Status> {
944        let req = request.into_inner();
945        let session_id = Self::parse_session_id(&req.session_id)?;
946
947        self.runtime
948            .submit_dequeue_queued_item(session_id)
949            .await
950            .map_err(|e| match e {
951                RuntimeError::InvalidInput { message } => Status::failed_precondition(message),
952                other => Status::internal(format!("Failed to dequeue queued item: {other}")),
953            })?;
954
955        Ok(Response::new(DequeueQueuedItemResponse {}))
956    }
957
958    async fn approve_tool(
959        &self,
960        request: Request<ApproveToolRequest>,
961    ) -> Result<Response<ApproveToolResponse>, Status> {
962        let req = request.into_inner();
963        let session_id = Self::parse_session_id(&req.session_id)?;
964
965        let request_id = Uuid::parse_str(&req.tool_call_id)
966            .map(steer_core::app::domain::types::RequestId::from)
967            .map_err(|_| Status::invalid_argument("Invalid tool call ID"))?;
968
969        let (approved, remember) = match req.decision {
970            Some(decision) => match decision.decision_type {
971                Some(proto::approval_decision::DecisionType::Deny(_)) => (false, None),
972                Some(proto::approval_decision::DecisionType::Once(_)) => (true, None),
973                Some(proto::approval_decision::DecisionType::AlwaysTool(_)) => (
974                    true,
975                    Some(steer_core::app::domain::action::ApprovalMemory::PendingTool),
976                ),
977                Some(proto::approval_decision::DecisionType::AlwaysBashPattern(pattern)) => (
978                    true,
979                    Some(steer_core::app::domain::action::ApprovalMemory::BashPattern(pattern)),
980                ),
981                None => {
982                    return Err(Status::invalid_argument("Invalid approval decision"));
983                }
984            },
985            None => {
986                return Err(Status::invalid_argument("Missing approval decision"));
987            }
988        };
989
990        match self
991            .runtime
992            .submit_tool_approval(session_id, request_id, approved, remember)
993            .await
994        {
995            Ok(()) => Ok(Response::new(ApproveToolResponse {})),
996            Err(e) => {
997                error!("Failed to approve tool: {}", e);
998                Err(Status::internal(format!("Failed to approve tool: {e}")))
999            }
1000        }
1001    }
1002
1003    async fn switch_primary_agent(
1004        &self,
1005        request: Request<SwitchPrimaryAgentRequest>,
1006    ) -> Result<Response<SwitchPrimaryAgentResponse>, Status> {
1007        let req = request.into_inner();
1008        let session_id = Self::parse_session_id(&req.session_id)?;
1009
1010        self.runtime
1011            .switch_primary_agent(session_id, req.primary_agent_id)
1012            .await
1013            .map_err(|e| match e {
1014                RuntimeError::InvalidInput { message } => {
1015                    if message.contains("operation is active") {
1016                        Status::failed_precondition(message)
1017                    } else {
1018                        Status::invalid_argument(message)
1019                    }
1020                }
1021                other => Status::internal(format!("Failed to switch primary agent: {other}")),
1022            })?;
1023
1024        Ok(Response::new(SwitchPrimaryAgentResponse {}))
1025    }
1026
1027    async fn cancel_operation(
1028        &self,
1029        request: Request<CancelOperationRequest>,
1030    ) -> Result<Response<CancelOperationResponse>, Status> {
1031        let req = request.into_inner();
1032        let session_id = Self::parse_session_id(&req.session_id)?;
1033
1034        match self.runtime.cancel_operation(session_id, None).await {
1035            Ok(()) => Ok(Response::new(CancelOperationResponse {})),
1036            Err(e) => {
1037                error!("Failed to cancel operation: {}", e);
1038                Err(Status::internal(format!("Failed to cancel operation: {e}")))
1039            }
1040        }
1041    }
1042
1043    async fn compact_session(
1044        &self,
1045        request: Request<CompactSessionRequest>,
1046    ) -> Result<Response<CompactSessionResponse>, Status> {
1047        let req = request.into_inner();
1048        let session_id = Self::parse_session_id(&req.session_id)?;
1049        let model_spec = req
1050            .model
1051            .ok_or_else(|| Status::invalid_argument("Missing model spec"))?;
1052        let model = proto_to_model(&model_spec)
1053            .map_err(|e| Status::invalid_argument(format!("Invalid model spec: {e}")))?;
1054
1055        self.runtime
1056            .compact_session(session_id, model)
1057            .await
1058            .map_err(|e| match e {
1059                RuntimeError::InvalidInput { message } => Status::failed_precondition(message),
1060                other => Status::internal(format!("Failed to compact session: {other}")),
1061            })?;
1062
1063        Ok(Response::new(CompactSessionResponse {}))
1064    }
1065
1066    async fn execute_bash_command(
1067        &self,
1068        request: Request<ExecuteBashCommandRequest>,
1069    ) -> Result<Response<ExecuteBashCommandResponse>, Status> {
1070        let req = request.into_inner();
1071        let session_id = Self::parse_session_id(&req.session_id)?;
1072
1073        self.runtime
1074            .execute_bash_command(session_id, req.command)
1075            .await
1076            .map_err(|e| Status::internal(format!("Failed to execute bash command: {e}")))?;
1077
1078        Ok(Response::new(ExecuteBashCommandResponse {}))
1079    }
1080
1081    async fn list_files(
1082        &self,
1083        request: Request<ListFilesRequest>,
1084    ) -> Result<Response<Self::ListFilesStream>, Status> {
1085        let req = request.into_inner();
1086        let session_id = Self::parse_session_id(&req.session_id)?;
1087
1088        debug!("ListFiles called for session: {}", session_id);
1089
1090        let config = self
1091            .catalog
1092            .get_session_config(session_id)
1093            .await
1094            .map_err(|e| Status::internal(format!("Failed to get session config: {e}")))?
1095            .ok_or_else(|| Status::not_found(format!("Session not found: {session_id}")))?;
1096
1097        let workspace =
1098            steer_core::workspace::create_workspace(&config.workspace.to_workspace_config())
1099                .await
1100                .map_err(|e| Status::internal(format!("Failed to create workspace: {e}")))?;
1101
1102        let (tx, rx) = mpsc::channel(100);
1103
1104        let _list_task: tokio::task::JoinHandle<()> = tokio::spawn(async move {
1105            let query = if req.query.is_empty() {
1106                None
1107            } else {
1108                Some(req.query.as_str())
1109            };
1110
1111            let max_results = if req.max_results == 0 {
1112                None
1113            } else {
1114                Some(req.max_results as usize)
1115            };
1116
1117            match workspace.list_files(query, max_results).await {
1118                Ok(files) => {
1119                    for chunk in files.chunks(1000) {
1120                        let response = ListFilesResponse {
1121                            paths: chunk.to_vec(),
1122                        };
1123
1124                        if let Err(e) = tx.send(Ok(response)).await {
1125                            warn!("Failed to send file list chunk: {}", e);
1126                            break;
1127                        }
1128                    }
1129                }
1130                Err(e) => {
1131                    error!("Failed to list files: {}", e);
1132                    let _ = tx
1133                        .send(Err(Status::internal(format!("Failed to list files: {e}"))))
1134                        .await;
1135                }
1136            }
1137        });
1138
1139        Ok(Response::new(ReceiverStream::new(rx)))
1140    }
1141
1142    async fn get_mcp_servers(
1143        &self,
1144        request: Request<GetMcpServersRequest>,
1145    ) -> Result<Response<GetMcpServersResponse>, Status> {
1146        let req = request.into_inner();
1147        let session_id = Self::parse_session_id(&req.session_id)?;
1148
1149        debug!("GetMcpServers called for session: {}", session_id);
1150
1151        let state = self
1152            .runtime
1153            .get_session_state(session_id)
1154            .await
1155            .map_err(|e| Status::internal(format!("Failed to get session state: {e}")))?;
1156
1157        let config = self
1158            .catalog
1159            .get_session_config(session_id)
1160            .await
1161            .map_err(|e| Status::internal(format!("Failed to get session config: {e}")))?;
1162
1163        let transport_map: std::collections::HashMap<String, &steer_core::tools::McpTransport> =
1164            config
1165                .as_ref()
1166                .map(|c| {
1167                    c.tool_config
1168                        .backends
1169                        .iter()
1170                        .map(|b| {
1171                            let steer_core::session::state::BackendConfig::Mcp {
1172                                server_name,
1173                                transport,
1174                                ..
1175                            } = b;
1176                            (server_name.clone(), transport)
1177                        })
1178                        .collect()
1179                })
1180                .unwrap_or_default();
1181
1182        let servers: Vec<proto::McpServerInfo> = state
1183            .mcp_servers
1184            .into_iter()
1185            .map(|(name, mcp_state)| {
1186                use crate::grpc::conversions::mcp_transport_to_proto;
1187                use steer_core::app::domain::action::McpServerState;
1188
1189                let state = match mcp_state {
1190                    McpServerState::Connecting => proto::McpConnectionState {
1191                        state: Some(proto::mcp_connection_state::State::Connecting(
1192                            proto::McpConnecting {},
1193                        )),
1194                    },
1195                    McpServerState::Connected { tools } => {
1196                        let tool_names = tools.iter().map(|t| t.name.clone()).collect();
1197                        proto::McpConnectionState {
1198                            state: Some(proto::mcp_connection_state::State::Connected(
1199                                proto::McpConnected { tool_names },
1200                            )),
1201                        }
1202                    }
1203                    McpServerState::Disconnected { error } => {
1204                        let error_msg = error.unwrap_or_else(|| "Disconnected".to_string());
1205                        proto::McpConnectionState {
1206                            state: Some(proto::mcp_connection_state::State::Failed(
1207                                proto::McpFailed { error: error_msg },
1208                            )),
1209                        }
1210                    }
1211                    McpServerState::Failed { error } => proto::McpConnectionState {
1212                        state: Some(proto::mcp_connection_state::State::Failed(
1213                            proto::McpFailed { error },
1214                        )),
1215                    },
1216                };
1217
1218                proto::McpServerInfo {
1219                    server_name: name.clone(),
1220                    transport: transport_map.get(&name).map(|t| mcp_transport_to_proto(t)),
1221                    state: Some(state),
1222                    last_updated: Some(prost_types::Timestamp::from(std::time::SystemTime::now())),
1223                }
1224            })
1225            .collect();
1226
1227        Ok(Response::new(GetMcpServersResponse { servers }))
1228    }
1229
1230    async fn list_providers(
1231        &self,
1232        _request: Request<ListProvidersRequest>,
1233    ) -> Result<Response<ListProvidersResponse>, Status> {
1234        let providers = self
1235            .provider_registry
1236            .all()
1237            .map(|p| proto::ProviderInfo {
1238                id: p.id.storage_key(),
1239                name: p.name.clone(),
1240            })
1241            .collect();
1242
1243        Ok(Response::new(ListProvidersResponse { providers }))
1244    }
1245
1246    async fn list_models(
1247        &self,
1248        request: Request<ListModelsRequest>,
1249    ) -> Result<Response<ListModelsResponse>, Status> {
1250        let req = request.into_inner();
1251
1252        let mut auth_sources: HashMap<steer_core::config::provider::ProviderId, AuthSource> =
1253            HashMap::new();
1254        let mut visibility_policies: HashMap<
1255            steer_core::config::provider::ProviderId,
1256            Option<Arc<dyn ModelVisibilityPolicy>>,
1257        > = HashMap::new();
1258
1259        let mut all_models = Vec::new();
1260
1261        for model in self.model_registry.recommended() {
1262            if let Some(ref provider_id) = req.provider_id
1263                && model.provider.storage_key() != *provider_id
1264            {
1265                continue;
1266            }
1267
1268            let provider_id = model.provider.clone();
1269
1270            let auth_source = if let Some(source) = auth_sources.get(&provider_id) {
1271                source.clone()
1272            } else {
1273                let source = match self
1274                    .llm_config_provider
1275                    .resolve_auth_source(&provider_id)
1276                    .await
1277                {
1278                    Ok(source) => source,
1279                    Err(err) => {
1280                        warn!(
1281                            "Failed to resolve auth source for provider {}: {err}",
1282                            provider_id.as_str()
1283                        );
1284                        AuthSource::None
1285                    }
1286                };
1287                auth_sources.insert(provider_id.clone(), source.clone());
1288                source
1289            };
1290
1291            let policy = visibility_policies
1292                .entry(provider_id.clone())
1293                .or_insert_with(|| {
1294                    self.llm_config_provider
1295                        .plugin_registry()
1296                        .get(&provider_id)
1297                        .and_then(|plugin| plugin.model_visibility().map(Arc::from))
1298                });
1299
1300            if let Some(policy) = policy {
1301                let auth_model_id = AuthModelId {
1302                    provider_id: AuthProviderId(provider_id.as_str().to_string()),
1303                    model_id: model.id.clone(),
1304                };
1305                if !policy.allow_model(&auth_model_id, &auth_source) {
1306                    continue;
1307                }
1308            }
1309
1310            all_models.push(proto::ProviderModel {
1311                provider_id: model.provider.storage_key(),
1312                model_id: model.id.clone(),
1313                display_name: model
1314                    .display_name
1315                    .clone()
1316                    .unwrap_or_else(|| model.id.clone()),
1317                supports_thinking: model
1318                    .parameters
1319                    .as_ref()
1320                    .and_then(|p| p.thinking_config.as_ref())
1321                    .is_some_and(|tc| tc.enabled),
1322                aliases: model.aliases.clone(),
1323            });
1324        }
1325
1326        Ok(Response::new(ListModelsResponse { models: all_models }))
1327    }
1328
1329    async fn get_provider_auth_status(
1330        &self,
1331        request: Request<proto::GetProviderAuthStatusRequest>,
1332    ) -> Result<Response<proto::GetProviderAuthStatusResponse>, Status> {
1333        let req = request.into_inner();
1334
1335        let mut statuses = Vec::new();
1336        for p in self.provider_registry.all() {
1337            if let Some(ref filter) = req.provider_id
1338                && &p.id.storage_key() != filter
1339            {
1340                continue;
1341            }
1342            let auth_source = self
1343                .llm_config_provider
1344                .resolve_auth_source(&p.id)
1345                .await
1346                .map_err(|e| Status::internal(format!("auth lookup failed: {e}")))?;
1347            let auth_source = crate::grpc::conversions::auth_source_to_proto(auth_source);
1348            statuses.push(proto::ProviderAuthStatus {
1349                provider_id: p.id.storage_key(),
1350                auth_source: Some(auth_source),
1351            });
1352        }
1353
1354        Ok(Response::new(proto::GetProviderAuthStatusResponse {
1355            statuses,
1356        }))
1357    }
1358
1359    async fn start_auth(
1360        &self,
1361        request: Request<proto::StartAuthRequest>,
1362    ) -> Result<Response<proto::StartAuthResponse>, Status> {
1363        self.auth_flow_manager.cleanup().await;
1364        let req = request.into_inner();
1365        let provider_id = steer_core::config::provider::ProviderId(req.provider_id);
1366
1367        let (flow, method) = self.create_auth_flow(&provider_id)?;
1368        let state = flow
1369            .start_auth(method)
1370            .await
1371            .map_err(|e| Status::internal(format!("auth start failed: {e}")))?;
1372        let progress = flow
1373            .get_initial_progress(&state, method)
1374            .await
1375            .map_err(|e| Status::internal(format!("auth progress failed: {e}")))?;
1376
1377        let flow_id = Uuid::new_v4().to_string();
1378        self.auth_flow_manager
1379            .insert(
1380                flow_id.clone(),
1381                AuthFlowEntry {
1382                    flow,
1383                    state,
1384                    last_updated: Instant::now(),
1385                },
1386            )
1387            .await;
1388
1389        Ok(Response::new(proto::StartAuthResponse {
1390            flow_id,
1391            progress: Some(crate::grpc::conversions::auth_progress_to_proto(progress)),
1392        }))
1393    }
1394
1395    async fn send_auth_input(
1396        &self,
1397        request: Request<proto::SendAuthInputRequest>,
1398    ) -> Result<Response<proto::SendAuthInputResponse>, Status> {
1399        self.auth_flow_manager.cleanup().await;
1400        let req = request.into_inner();
1401        let flow_id = req.flow_id.clone();
1402
1403        let mut entry = self
1404            .auth_flow_manager
1405            .take(&flow_id)
1406            .await
1407            .ok_or_else(|| Status::not_found("Auth flow not found"))?;
1408
1409        let progress = entry
1410            .flow
1411            .handle_input(&mut entry.state, &req.input)
1412            .await
1413            .map_err(|e| Status::internal(format!("auth input failed: {e}")))?;
1414
1415        let done = matches!(
1416            progress,
1417            steer_core::auth::AuthProgress::Complete | steer_core::auth::AuthProgress::Error(_)
1418        );
1419
1420        if !done {
1421            entry.last_updated = Instant::now();
1422            self.auth_flow_manager.insert(flow_id, entry).await;
1423        }
1424
1425        Ok(Response::new(proto::SendAuthInputResponse {
1426            progress: Some(crate::grpc::conversions::auth_progress_to_proto(progress)),
1427        }))
1428    }
1429
1430    async fn get_auth_progress(
1431        &self,
1432        request: Request<proto::GetAuthProgressRequest>,
1433    ) -> Result<Response<proto::GetAuthProgressResponse>, Status> {
1434        self.auth_flow_manager.cleanup().await;
1435        let req = request.into_inner();
1436        let flow_id = req.flow_id.clone();
1437
1438        let mut entry = self
1439            .auth_flow_manager
1440            .take(&flow_id)
1441            .await
1442            .ok_or_else(|| Status::not_found("Auth flow not found"))?;
1443
1444        let progress = entry
1445            .flow
1446            .handle_input(&mut entry.state, "")
1447            .await
1448            .map_err(|e| Status::internal(format!("auth progress failed: {e}")))?;
1449
1450        let done = matches!(
1451            progress,
1452            steer_core::auth::AuthProgress::Complete | steer_core::auth::AuthProgress::Error(_)
1453        );
1454
1455        if !done {
1456            entry.last_updated = Instant::now();
1457            self.auth_flow_manager.insert(flow_id, entry).await;
1458        }
1459
1460        Ok(Response::new(proto::GetAuthProgressResponse {
1461            progress: Some(crate::grpc::conversions::auth_progress_to_proto(progress)),
1462        }))
1463    }
1464
1465    async fn cancel_auth(
1466        &self,
1467        request: Request<proto::CancelAuthRequest>,
1468    ) -> Result<Response<proto::CancelAuthResponse>, Status> {
1469        self.auth_flow_manager.cleanup().await;
1470        let req = request.into_inner();
1471        let flow_id = req.flow_id;
1472
1473        let _ = self.auth_flow_manager.take(&flow_id).await;
1474
1475        Ok(Response::new(proto::CancelAuthResponse {}))
1476    }
1477
1478    async fn resolve_model(
1479        &self,
1480        request: Request<proto::ResolveModelRequest>,
1481    ) -> Result<Response<proto::ResolveModelResponse>, Status> {
1482        let req = request.into_inner();
1483
1484        match self.model_registry.resolve(&req.input) {
1485            Ok(model_id) => {
1486                let steer_core::config::model::ModelId { provider, id } = model_id;
1487                let model_spec = proto::ModelSpec {
1488                    provider_id: provider.storage_key(),
1489                    model_id: id,
1490                };
1491                Ok(Response::new(proto::ResolveModelResponse {
1492                    model: Some(model_spec),
1493                }))
1494            }
1495            Err(e) => Err(Status::not_found(format!(
1496                "Failed to resolve model '{}': {}",
1497                req.input, e
1498            ))),
1499        }
1500    }
1501
1502    async fn get_default_model(
1503        &self,
1504        _request: Request<proto::GetDefaultModelRequest>,
1505    ) -> Result<Response<proto::GetDefaultModelResponse>, Status> {
1506        let model = self.select_default_model();
1507        Ok(Response::new(proto::GetDefaultModelResponse {
1508            model: Some(model_to_proto(model)),
1509        }))
1510    }
1511
1512    async fn create_workspace(
1513        &self,
1514        request: Request<proto::CreateWorkspaceRequest>,
1515    ) -> Result<Response<proto::CreateWorkspaceResponse>, Status> {
1516        let req = request.into_inner();
1517        let repo_id = Self::parse_repo_id(&req.repo_id)?;
1518        let parent_workspace_id = match req.parent_workspace_id {
1519            Some(value) => Some(Self::parse_workspace_id(&value)?),
1520            None => None,
1521        };
1522
1523        let strategy = match proto::WorkspaceCreateStrategy::try_from(req.strategy) {
1524            Ok(proto::WorkspaceCreateStrategy::JjWorkspace) => {
1525                steer_workspace::WorkspaceCreateStrategy::JjWorkspace
1526            }
1527            Ok(proto::WorkspaceCreateStrategy::GitWorktree) => {
1528                steer_workspace::WorkspaceCreateStrategy::GitWorktree
1529            }
1530            _ => {
1531                return Err(Status::invalid_argument(
1532                    "Unsupported workspace create strategy",
1533                ));
1534            }
1535        };
1536
1537        let request = steer_workspace::CreateWorkspaceRequest {
1538            repo_id,
1539            name: req.name,
1540            parent_workspace_id,
1541            strategy,
1542        };
1543
1544        let workspace = self
1545            .workspace_manager
1546            .create_workspace(request)
1547            .await
1548            .map_err(Self::workspace_manager_error_to_status)?;
1549
1550        Ok(Response::new(proto::CreateWorkspaceResponse {
1551            workspace: Some(workspace_info_to_proto(&workspace)),
1552        }))
1553    }
1554
1555    async fn resolve_repo(
1556        &self,
1557        request: Request<proto::ResolveRepoRequest>,
1558    ) -> Result<Response<proto::ResolveRepoResponse>, Status> {
1559        let req = request.into_inner();
1560        let environment_id = Self::parse_environment_id(&req.environment_id)?;
1561        let repo = self
1562            .repo_manager
1563            .resolve_repo(environment_id, std::path::Path::new(&req.path))
1564            .await
1565            .map_err(Self::workspace_manager_error_to_status)?;
1566
1567        Ok(Response::new(proto::ResolveRepoResponse {
1568            repo: Some(repo_info_to_proto(&repo)),
1569        }))
1570    }
1571
1572    async fn list_repos(
1573        &self,
1574        request: Request<proto::ListReposRequest>,
1575    ) -> Result<Response<proto::ListReposResponse>, Status> {
1576        let req = request.into_inner();
1577        let environment_id = Self::parse_environment_id(&req.environment_id)?;
1578        let repos = self
1579            .repo_manager
1580            .list_repos(environment_id)
1581            .await
1582            .map_err(Self::workspace_manager_error_to_status)?;
1583
1584        Ok(Response::new(proto::ListReposResponse {
1585            repos: repos.iter().map(repo_info_to_proto).collect(),
1586        }))
1587    }
1588
1589    async fn list_workspaces(
1590        &self,
1591        request: Request<proto::ListWorkspacesRequest>,
1592    ) -> Result<Response<proto::ListWorkspacesResponse>, Status> {
1593        let req = request.into_inner();
1594        let environment_id = Self::parse_environment_id(&req.environment_id)?;
1595
1596        let workspaces = self
1597            .workspace_manager
1598            .list_workspaces(steer_workspace::ListWorkspacesRequest { environment_id })
1599            .await
1600            .map_err(Self::workspace_manager_error_to_status)?;
1601
1602        Ok(Response::new(proto::ListWorkspacesResponse {
1603            workspaces: workspaces.iter().map(workspace_info_to_proto).collect(),
1604        }))
1605    }
1606
1607    async fn get_workspace_status(
1608        &self,
1609        request: Request<proto::GetWorkspaceStatusRequest>,
1610    ) -> Result<Response<proto::GetWorkspaceStatusResponse>, Status> {
1611        let req = request.into_inner();
1612        let workspace_id = Self::parse_workspace_id(&req.workspace_id)?;
1613
1614        let status = self
1615            .workspace_manager
1616            .get_workspace_status(workspace_id)
1617            .await
1618            .map_err(Self::workspace_manager_error_to_status)?;
1619
1620        Ok(Response::new(proto::GetWorkspaceStatusResponse {
1621            status: Some(workspace_status_to_proto(&status)),
1622        }))
1623    }
1624
1625    async fn delete_workspace(
1626        &self,
1627        request: Request<proto::DeleteWorkspaceRequest>,
1628    ) -> Result<Response<proto::DeleteWorkspaceResponse>, Status> {
1629        let req = request.into_inner();
1630        let workspace_id = Self::parse_workspace_id(&req.workspace_id)?;
1631
1632        self.workspace_manager
1633            .delete_workspace(steer_workspace::DeleteWorkspaceRequest { workspace_id })
1634            .await
1635            .map_err(Self::workspace_manager_error_to_status)?;
1636
1637        Ok(Response::new(proto::DeleteWorkspaceResponse {}))
1638    }
1639
1640    async fn create_environment(
1641        &self,
1642        request: Request<proto::CreateEnvironmentRequest>,
1643    ) -> Result<Response<proto::CreateEnvironmentResponse>, Status> {
1644        let req = request.into_inner();
1645        let request = steer_workspace::CreateEnvironmentRequest {
1646            root: req.root_path.map(std::path::PathBuf::from),
1647            name: req.name,
1648        };
1649
1650        let env = self
1651            .environment_manager
1652            .create_environment(request)
1653            .await
1654            .map_err(Self::environment_manager_error_to_status)?;
1655
1656        Ok(Response::new(proto::CreateEnvironmentResponse {
1657            environment: Some(environment_descriptor_to_proto(&env)),
1658        }))
1659    }
1660
1661    async fn get_environment(
1662        &self,
1663        request: Request<proto::GetEnvironmentRequest>,
1664    ) -> Result<Response<proto::GetEnvironmentResponse>, Status> {
1665        let req = request.into_inner();
1666        let environment_id = Self::parse_environment_id(&req.environment_id)?;
1667
1668        let env = self
1669            .environment_manager
1670            .get_environment(environment_id)
1671            .await
1672            .map_err(Self::environment_manager_error_to_status)?;
1673
1674        Ok(Response::new(proto::GetEnvironmentResponse {
1675            environment: Some(environment_descriptor_to_proto(&env)),
1676        }))
1677    }
1678
1679    async fn delete_environment(
1680        &self,
1681        request: Request<proto::DeleteEnvironmentRequest>,
1682    ) -> Result<Response<proto::DeleteEnvironmentResponse>, Status> {
1683        let req = request.into_inner();
1684        let environment_id = Self::parse_environment_id(&req.environment_id)?;
1685        let policy = match proto::EnvironmentDeletePolicy::try_from(req.policy) {
1686            Ok(proto::EnvironmentDeletePolicy::Soft) => {
1687                steer_workspace::EnvironmentDeletePolicy::Soft
1688            }
1689            Ok(proto::EnvironmentDeletePolicy::Hard) => {
1690                steer_workspace::EnvironmentDeletePolicy::Hard
1691            }
1692            _ => steer_workspace::EnvironmentDeletePolicy::Hard,
1693        };
1694
1695        self.environment_manager
1696            .delete_environment(environment_id, policy)
1697            .await
1698            .map_err(Self::environment_manager_error_to_status)?;
1699
1700        Ok(Response::new(proto::DeleteEnvironmentResponse {}))
1701    }
1702}