1use crate::grpc::conversions::{
2 environment_descriptor_to_proto, message_to_proto, model_to_proto, proto_to_model,
3 proto_to_session_policy_overrides, proto_to_tool_config, proto_to_workspace_config,
4 repo_info_to_proto, session_event_to_proto, stream_delta_to_proto, workspace_info_to_proto,
5 workspace_status_to_proto,
6};
7use std::cmp::Ordering as CmpOrdering;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::time::{Duration, Instant};
12use steer_core::app::domain::runtime::{RuntimeError, RuntimeHandle};
13use steer_core::app::domain::session::{SessionCatalog, SessionFilter};
14use steer_core::app::domain::types::SessionId;
15use steer_core::auth::api_key::ApiKeyAuthFlow;
16use steer_core::auth::{
17 AuthFlowWrapper, AuthMethod, AuthSource, DynAuthenticationFlow, ModelId as AuthModelId,
18 ModelVisibilityPolicy, ProviderId as AuthProviderId,
19};
20use steer_core::session::state::SessionConfig;
21use steer_proto::agent::v1::{
22 self as proto, ApproveToolRequest, ApproveToolResponse, CancelOperationRequest,
23 CancelOperationResponse, CompactSessionRequest, CompactSessionResponse, CreateSessionRequest,
24 CreateSessionResponse, DeleteSessionRequest, DeleteSessionResponse, DequeueQueuedItemRequest,
25 DequeueQueuedItemResponse, EditMessageRequest, EditMessageResponse, ExecuteBashCommandRequest,
26 ExecuteBashCommandResponse, GetConversationFooter, GetConversationRequest,
27 GetConversationResponse, GetMcpServersRequest, GetMcpServersResponse, GetSessionRequest,
28 GetSessionResponse, ListFilesRequest, ListFilesResponse, ListModelsRequest, ListModelsResponse,
29 ListProvidersRequest, ListProvidersResponse, ListSessionsRequest, ListSessionsResponse,
30 Operation, OperationStatus, OperationType, SendMessageRequest, SendMessageResponse,
31 SessionEvent, SessionInfo, SessionStateFooter, SessionStateHeader,
32 SubscribeSessionEventsRequest, SwitchPrimaryAgentRequest, SwitchPrimaryAgentResponse,
33 agent_service_server, get_conversation_response, get_session_response,
34};
35use steer_workspace::{EnvironmentManager, RepoManager, WorkspaceManager};
36use tokio::sync::{Mutex, broadcast, mpsc};
37use tokio_stream::wrappers::ReceiverStream;
38use tonic::{Request, Response, Status};
39use tracing::{debug, error, info, warn};
40use uuid::Uuid;
41
/// gRPC `AgentService` implementation that delegates to the session runtime.
///
/// Built from [`RuntimeAgentDeps`] by `RuntimeAgentService::new`, which also
/// creates an empty auth-flow registry.
pub struct RuntimeAgentService {
    /// Handle used to create, resume, subscribe to, and query sessions.
    runtime: RuntimeHandle,
    /// Session catalog: stores session configs and backs session listing.
    catalog: Arc<dyn SessionCatalog>,
    /// Model registry; consulted when choosing a default model.
    model_registry: Arc<steer_core::model_registry::ModelRegistry>,
    /// Provider registry; used to look up provider configs for auth flows.
    provider_registry: Arc<steer_core::auth::ProviderRegistry>,
    /// LLM configuration, including auth storage and the auth plugin registry.
    llm_config_provider: steer_core::config::LlmConfigProvider,
    /// Environment manager. NOTE(review): not used in the visible code;
    /// presumably used by environment RPCs elsewhere in this impl.
    environment_manager: Arc<dyn EnvironmentManager>,
    /// Workspace manager; resolves local paths to managed workspaces.
    workspace_manager: Arc<dyn WorkspaceManager>,
    /// Repository manager; resolves local paths to repositories.
    repo_manager: Arc<dyn RepoManager>,
    /// In-progress authentication flows, keyed by flow id.
    auth_flow_manager: Arc<AuthFlowManager>,
}
53
/// Lifetime of an in-progress authentication flow (ten minutes). Entries whose
/// `last_updated` is older than this are treated as expired.
const AUTH_FLOW_TTL: Duration = Duration::from_secs(60 * 10);
55
/// One in-progress authentication flow held by `AuthFlowManager`.
struct AuthFlowEntry {
    /// The flow implementation driving this authentication.
    flow: Arc<dyn DynAuthenticationFlow>,
    /// Type-erased, flow-specific state. NOTE(review): presumably downcast by
    /// whoever resumes the flow — confirm against the callers.
    state: Box<dyn std::any::Any + Send + Sync>,
    /// Timestamp used for expiry: entries older than `AUTH_FLOW_TTL` are
    /// dropped by `AuthFlowManager::cleanup`.
    last_updated: Instant,
}
61
/// Registry of in-progress authentication flows, keyed by flow id and guarded
/// by a `tokio::sync::Mutex`.
#[derive(Default)]
struct AuthFlowManager {
    /// Pending flows; expired entries are pruned by `cleanup`.
    flows: Mutex<HashMap<String, AuthFlowEntry>>,
}
66
67impl AuthFlowManager {
68 fn new() -> Self {
69 Self::default()
70 }
71
72 async fn insert(&self, flow_id: String, entry: AuthFlowEntry) {
73 let mut flows = self.flows.lock().await;
74 flows.insert(flow_id, entry);
75 }
76
77 async fn take(&self, flow_id: &str) -> Option<AuthFlowEntry> {
78 let mut flows = self.flows.lock().await;
79 flows.remove(flow_id)
80 }
81
82 async fn cleanup(&self) {
83 let mut flows = self.flows.lock().await;
84 flows.retain(|_, entry| entry.last_updated.elapsed() <= AUTH_FLOW_TTL);
85 }
86}
87
/// Dependencies required to construct a [`RuntimeAgentService`] through
/// `RuntimeAgentService::new`.
pub struct RuntimeAgentDeps {
    /// Runtime handle used to create, resume, and query sessions.
    pub runtime: RuntimeHandle,
    /// Session catalog holding session configs.
    pub catalog: Arc<dyn SessionCatalog>,
    /// LLM configuration (auth storage and auth plugin registry).
    pub llm_config_provider: steer_core::config::LlmConfigProvider,
    /// Model registry used for default-model selection.
    pub model_registry: Arc<steer_core::model_registry::ModelRegistry>,
    /// Provider registry used to look up provider configs.
    pub provider_registry: Arc<steer_core::auth::ProviderRegistry>,
    /// Environment manager.
    pub environment_manager: Arc<dyn EnvironmentManager>,
    /// Workspace manager used to resolve local paths to workspaces.
    pub workspace_manager: Arc<dyn WorkspaceManager>,
    /// Repository manager used to resolve local paths to repositories.
    pub repo_manager: Arc<dyn RepoManager>,
}
98
99impl RuntimeAgentService {
100 pub fn new(deps: RuntimeAgentDeps) -> Self {
101 Self {
102 runtime: deps.runtime,
103 catalog: deps.catalog,
104 llm_config_provider: deps.llm_config_provider,
105 model_registry: deps.model_registry,
106 provider_registry: deps.provider_registry,
107 environment_manager: deps.environment_manager,
108 workspace_manager: deps.workspace_manager,
109 repo_manager: deps.repo_manager,
110 auth_flow_manager: Arc::new(AuthFlowManager::new()),
111 }
112 }
113
114 #[expect(clippy::result_large_err)]
115 fn parse_session_id(session_id: &str) -> Result<SessionId, Status> {
116 Uuid::parse_str(session_id)
117 .map(SessionId::from)
118 .map_err(|_| Status::invalid_argument(format!("Invalid session ID: {session_id}")))
119 }
120 fn parse_environment_id(
121 environment_id: &str,
122 ) -> Result<steer_workspace::EnvironmentId, Status> {
123 if environment_id.is_empty() {
124 return Ok(steer_workspace::EnvironmentId::local());
125 }
126 let id = Uuid::parse_str(environment_id).map_err(|_| {
127 Status::invalid_argument(format!("Invalid environment ID: {environment_id}"))
128 })?;
129 Ok(steer_workspace::EnvironmentId::from_uuid(id))
130 }
131
132 fn parse_workspace_id(workspace_id: &str) -> Result<steer_workspace::WorkspaceId, Status> {
133 let id = Uuid::parse_str(workspace_id).map_err(|_| {
134 Status::invalid_argument(format!("Invalid workspace ID: {workspace_id}"))
135 })?;
136 Ok(steer_workspace::WorkspaceId::from_uuid(id))
137 }
138
139 fn parse_repo_id(repo_id: &str) -> Result<steer_workspace::RepoId, Status> {
140 let id = Uuid::parse_str(repo_id)
141 .map_err(|_| Status::invalid_argument(format!("Invalid repo ID: {repo_id}")))?;
142 Ok(steer_workspace::RepoId::from_uuid(id))
143 }
144
145 fn select_default_model(&self) -> steer_core::config::model::ModelId {
146 let builtin_default = steer_core::config::model::builtin::default_model();
147
148 if let Some(config) = self.model_registry.get(&builtin_default)
149 && config.recommended
150 {
151 return builtin_default;
152 }
153
154 let mut recommended: Vec<_> = self.model_registry.recommended().collect();
155 if recommended.is_empty() {
156 return builtin_default;
157 }
158
159 recommended.sort_by(|a, b| {
160 let provider_cmp = a.provider.as_str().cmp(b.provider.as_str());
161 if provider_cmp == CmpOrdering::Equal {
162 a.id.cmp(&b.id)
163 } else {
164 provider_cmp
165 }
166 });
167
168 let chosen = recommended[0];
169 steer_core::config::model::ModelId::new(chosen.provider.clone(), chosen.id.clone())
170 }
171
172 fn workspace_manager_error_to_status(err: steer_workspace::WorkspaceManagerError) -> Status {
173 match err {
174 steer_workspace::WorkspaceManagerError::NotFound(msg) => Status::not_found(msg),
175 steer_workspace::WorkspaceManagerError::NotSupported(msg) => {
176 Status::failed_precondition(msg)
177 }
178 steer_workspace::WorkspaceManagerError::InvalidRequest(msg) => {
179 Status::invalid_argument(msg)
180 }
181 steer_workspace::WorkspaceManagerError::Io(msg)
182 | steer_workspace::WorkspaceManagerError::Other(msg) => Status::internal(msg),
183 }
184 }
185
186 fn environment_manager_error_to_status(
187 err: steer_workspace::EnvironmentManagerError,
188 ) -> Status {
189 match err {
190 steer_workspace::EnvironmentManagerError::NotFound(msg) => Status::not_found(msg),
191 steer_workspace::EnvironmentManagerError::NotSupported(msg) => {
192 Status::failed_precondition(msg)
193 }
194 steer_workspace::EnvironmentManagerError::InvalidRequest(msg) => {
195 Status::invalid_argument(msg)
196 }
197 steer_workspace::EnvironmentManagerError::Io(msg)
198 | steer_workspace::EnvironmentManagerError::Other(msg) => Status::internal(msg),
199 }
200 }
201
202 fn create_auth_flow(
203 &self,
204 provider_id: &steer_core::config::provider::ProviderId,
205 ) -> Result<(Arc<dyn DynAuthenticationFlow>, AuthMethod), Status> {
206 let provider_cfg = self.provider_registry.get(provider_id).ok_or_else(|| {
207 Status::not_found(format!("Unknown provider: {}", provider_id.as_str()))
208 })?;
209 let provider_name = provider_cfg.name.clone();
210 let auth_storage = self.llm_config_provider.auth_storage().clone();
211
212 if let Some(plugin) = self.llm_config_provider.plugin_registry().get(provider_id)
213 && let Some(flow) = plugin.create_flow(auth_storage.clone())
214 {
215 let methods = flow.available_methods();
216 let method = if methods.contains(&AuthMethod::OAuth) {
217 AuthMethod::OAuth
218 } else if methods.contains(&AuthMethod::ApiKey) {
219 AuthMethod::ApiKey
220 } else {
221 return Err(Status::failed_precondition(format!(
222 "No supported auth methods for provider {}",
223 provider_id.as_str()
224 )));
225 };
226 return Ok((Arc::from(flow), method));
227 }
228
229 let flow = AuthFlowWrapper::new(ApiKeyAuthFlow::new(
230 auth_storage,
231 provider_id.clone(),
232 provider_name,
233 ));
234 Ok((Arc::new(flow), AuthMethod::ApiKey))
235 }
236}
237
238#[tonic::async_trait]
239impl agent_service_server::AgentService for RuntimeAgentService {
    /// Stream of session events and stream deltas for `subscribe_session_events`.
    type SubscribeSessionEventsStream = ReceiverStream<Result<SessionEvent, Status>>;
    /// Stream of file-path chunks for `list_files`.
    type ListFilesStream = ReceiverStream<Result<ListFilesResponse, Status>>;
    /// Boxed stream of header / message / footer chunks for `get_session`.
    type GetSessionStream =
        std::pin::Pin<Box<dyn futures::Stream<Item = Result<GetSessionResponse, Status>> + Send>>;
    /// Boxed stream of message / footer chunks for `get_conversation`.
    type GetConversationStream = std::pin::Pin<
        Box<dyn futures::Stream<Item = Result<GetConversationResponse, Status>> + Send>,
    >;
247
248 async fn subscribe_session_events(
249 &self,
250 request: Request<SubscribeSessionEventsRequest>,
251 ) -> Result<Response<Self::SubscribeSessionEventsStream>, Status> {
252 let req = request.into_inner();
253 let session_id = Self::parse_session_id(&req.session_id)?;
254
255 if let Err(e) = self.runtime.resume_session(session_id).await
256 && !matches!(e, RuntimeError::SessionNotFound { .. })
257 {
258 error!("Failed to resume session {}: {}", session_id, e);
259 }
260
261 let subscription = self
262 .runtime
263 .subscribe_events(session_id)
264 .await
265 .map_err(|e| Status::internal(format!("Failed to subscribe: {e}")))?;
266
267 let delta_subscription = self
268 .runtime
269 .subscribe_deltas(session_id)
270 .await
271 .map_err(|e| Status::internal(format!("Failed to subscribe to deltas: {e}")))?;
272
273 let (tx, rx) = mpsc::channel(100);
274 let last_sequence = Arc::new(AtomicU64::new(req.since_sequence.unwrap_or(0)));
275 let delta_sequence = Arc::new(AtomicU64::new(0));
276
277 let mut min_live_seq = req.since_sequence.map(|seq| seq.saturating_add(1));
278
279 if let Some(after_seq) = req.since_sequence {
280 match self.runtime.load_events_after(session_id, after_seq).await {
281 Ok(events) => {
282 let mut last_seq = after_seq;
283 for (seq, event) in events {
284 last_seq = last_seq.max(seq);
285 let proto_event = match session_event_to_proto(event, seq) {
286 Ok(event) => event,
287 Err(e) => {
288 warn!("Failed to convert session replay event: {}", e);
289 continue;
290 }
291 };
292
293 if proto_event.event.is_none() {
294 continue;
295 }
296
297 if let Err(e) = tx.send(Ok(proto_event)).await {
298 warn!("Failed to send replay event to client: {}", e);
299 break;
300 }
301 }
302 min_live_seq = Some(last_seq.saturating_add(1));
303 last_sequence.store(last_seq, Ordering::Relaxed);
304 }
305 Err(e) => {
306 warn!("Failed to load replay events: {}", e);
307 }
308 }
309 }
310
311 let event_tx = tx.clone();
312 let last_sequence_events = last_sequence.clone();
313 let delta_sequence_counter = delta_sequence.clone();
314 let min_live_seq = min_live_seq;
315 tokio::spawn(async move {
316 async fn send_delta(
317 delta: steer_core::app::domain::delta::StreamDelta,
318 tx: &mpsc::Sender<Result<proto::SessionEvent, Status>>,
319 last_sequence: &Arc<AtomicU64>,
320 delta_sequence: &Arc<AtomicU64>,
321 ) -> Result<(), ()> {
322 let sequence_num = last_sequence.load(Ordering::Relaxed);
323 let delta_sequence = delta_sequence.fetch_add(1, Ordering::Relaxed);
324 let proto_event = match stream_delta_to_proto(delta, sequence_num, delta_sequence) {
325 Ok(event) => event,
326 Err(e) => {
327 warn!("Failed to convert stream delta: {}", e);
328 return Ok(());
329 }
330 };
331
332 if let Err(e) = tx.send(Ok(proto_event)).await {
333 warn!("Failed to send delta to client: {}", e);
334 return Err(());
335 }
336
337 Ok(())
338 }
339
340 let mut subscription = subscription;
341 let mut delta_rx = delta_subscription;
342 let mut events_closed = false;
343 let mut deltas_closed = false;
344
345 loop {
346 if events_closed && deltas_closed {
347 break;
348 }
349
350 tokio::select! {
351 envelope = subscription.recv(), if !events_closed => {
352 match envelope {
353 Some(envelope) => {
354 loop {
355 match delta_rx.try_recv() {
356 Ok(delta) => {
357 if send_delta(
358 delta,
359 &event_tx,
360 &last_sequence_events,
361 &delta_sequence_counter,
362 )
363 .await
364 .is_err()
365 {
366 return;
367 }
368 }
369 Err(broadcast::error::TryRecvError::Empty) => break,
370 Err(broadcast::error::TryRecvError::Lagged(skipped)) => {
371 warn!("Delta subscription lagged by {} messages", skipped);
372 }
373 Err(broadcast::error::TryRecvError::Closed) => {
374 deltas_closed = true;
375 break;
376 }
377 }
378 }
379
380 if let Some(min_seq) = min_live_seq
381 && envelope.seq < min_seq {
382 continue;
383 }
384
385 let proto_event = match session_event_to_proto(envelope.event, envelope.seq) {
386 Ok(event) => event,
387 Err(e) => {
388 warn!("Failed to convert session event: {}", e);
389 continue;
390 }
391 };
392
393 if proto_event.event.is_none() {
394 continue;
395 }
396
397 if let Err(e) = event_tx.send(Ok(proto_event)).await {
398 warn!("Failed to send event to client: {}", e);
399 break;
400 }
401 last_sequence_events.store(envelope.seq, Ordering::Relaxed);
402 }
403 None => {
404 events_closed = true;
405 }
406 }
407 }
408 delta = delta_rx.recv(), if !deltas_closed => {
409 match delta {
410 Ok(delta) => {
411 if send_delta(
412 delta,
413 &event_tx,
414 &last_sequence_events,
415 &delta_sequence_counter,
416 )
417 .await
418 .is_err()
419 {
420 break;
421 }
422 }
423 Err(broadcast::error::RecvError::Lagged(skipped)) => {
424 warn!("Delta subscription lagged by {} messages", skipped);
425 }
426 Err(broadcast::error::RecvError::Closed) => {
427 deltas_closed = true;
428 }
429 }
430 }
431 }
432 }
433 debug!("Event forwarding task ended for session: {}", session_id);
434 });
435
436 Ok(Response::new(ReceiverStream::new(rx)))
437 }
438
439 async fn create_session(
440 &self,
441 request: Request<CreateSessionRequest>,
442 ) -> Result<Response<CreateSessionResponse>, Status> {
443 let req = request.into_inner();
444
445 let default_model_spec = req
446 .default_model
447 .ok_or_else(|| Status::invalid_argument("Missing required default_model"))?;
448 let default_model = proto_to_model(&default_model_spec)
449 .map_err(|e| Status::invalid_argument(format!("Invalid default_model: {e}")))?;
450
451 let tool_config = req
452 .tool_config
453 .map(proto_to_tool_config)
454 .unwrap_or_default();
455
456 let workspace_config = req
457 .workspace_config
458 .map(proto_to_workspace_config)
459 .unwrap_or_default();
460
461 let policy_overrides = proto_to_session_policy_overrides(req.policy_overrides);
462
463 let mut workspace_id = None;
464 let mut workspace_ref = None;
465 let mut repo_ref = None;
466 let parent_session_id = None;
467 let mut workspace_name = None;
468
469 if repo_ref.is_none()
470 && let steer_core::session::state::WorkspaceConfig::Local { path } = &workspace_config
471 {
472 match self
473 .repo_manager
474 .resolve_repo(steer_workspace::EnvironmentId::local(), path)
475 .await
476 {
477 Ok(repo_info) => {
478 repo_ref = Some(steer_workspace::RepoRef {
479 environment_id: repo_info.environment_id,
480 repo_id: repo_info.repo_id,
481 root_path: repo_info.root_path.clone(),
482 vcs_kind: repo_info.vcs_kind,
483 });
484 workspace_name = repo_info
485 .root_path
486 .file_name()
487 .map(|n| n.to_string_lossy().into_owned());
488 }
489 Err(err) => {
490 warn!("Failed to resolve repo for session: {}", err);
491 }
492 }
493 }
494
495 if workspace_id.is_none()
496 && workspace_ref.is_none()
497 && let steer_core::session::state::WorkspaceConfig::Local { path } = &workspace_config
498 && let Ok(info) = self.workspace_manager.resolve_workspace(path).await
499 {
500 workspace_id = Some(info.workspace_id);
501 workspace_ref = Some(steer_workspace::WorkspaceRef {
502 environment_id: info.environment_id,
503 workspace_id: info.workspace_id,
504 repo_id: info.repo_id,
505 });
506 workspace_name.clone_from(&info.name);
507 }
508
509 let session_config = SessionConfig {
510 workspace: workspace_config,
511 workspace_ref,
512 workspace_id,
513 repo_ref,
514 parent_session_id,
515 workspace_name,
516 tool_config,
517 system_prompt: None,
518 primary_agent_id: req.primary_agent_id,
519 policy_overrides,
520 metadata: req.metadata,
521 default_model,
522 };
523
524 match self.runtime.create_session(session_config.clone()).await {
525 Ok(session_id) => {
526 if let Err(e) = self
527 .catalog
528 .update_session_catalog(session_id, Some(&session_config), false, None)
529 .await
530 {
531 error!("Failed to update session catalog: {}", e);
532 return Err(Status::internal(format!(
533 "Failed to update session catalog: {e}"
534 )));
535 }
536
537 let session_info = SessionInfo {
538 id: session_id.to_string(),
539 created_at: Some(prost_types::Timestamp::from(std::time::SystemTime::now())),
540 updated_at: Some(prost_types::Timestamp::from(std::time::SystemTime::now())),
541 status: proto::SessionStatus::Active as i32,
542 metadata: None,
543 };
544 Ok(Response::new(CreateSessionResponse {
545 session: Some(session_info),
546 }))
547 }
548 Err(e) => {
549 error!("Failed to create session: {}", e);
550 Err(Status::internal(format!("Failed to create session: {e}")))
551 }
552 }
553 }
554
555 async fn list_sessions(
556 &self,
557 _request: Request<ListSessionsRequest>,
558 ) -> Result<Response<ListSessionsResponse>, Status> {
559 let filter = SessionFilter::default();
560
561 match self.catalog.list_sessions(filter).await {
562 Ok(sessions) => {
563 let proto_sessions = sessions
564 .into_iter()
565 .map(|s| SessionInfo {
566 id: s.id.to_string(),
567 created_at: Some(prost_types::Timestamp::from(
568 std::time::SystemTime::from(s.created_at),
569 )),
570 updated_at: Some(prost_types::Timestamp::from(
571 std::time::SystemTime::from(s.updated_at),
572 )),
573 status: proto::SessionStatus::Active as i32,
574 metadata: None,
575 })
576 .collect();
577
578 Ok(Response::new(ListSessionsResponse {
579 sessions: proto_sessions,
580 next_page_token: None,
581 }))
582 }
583 Err(e) => {
584 error!("Failed to list sessions: {}", e);
585 Err(Status::internal(format!("Failed to list sessions: {e}")))
586 }
587 }
588 }
589
590 async fn get_session(
591 &self,
592 request: Request<GetSessionRequest>,
593 ) -> Result<Response<Self::GetSessionStream>, Status> {
594 let req = request.into_inner();
595 let session_id = Self::parse_session_id(&req.session_id)?;
596 let runtime = self.runtime.clone();
597 let catalog = self.catalog.clone();
598
599 let stream = async_stream::try_stream! {
600 if let Err(e) = runtime.resume_session(session_id).await
601 && matches!(e, RuntimeError::SessionNotFound { .. }) {
602 Err(Status::not_found(format!("Session not found: {session_id}")))?;
603 return;
604 }
605
606 let state = runtime.get_session_state(session_id).await
607 .map_err(|e| Status::internal(format!("Failed to get session state: {e}")))?;
608
609 let config = catalog.get_session_config(session_id).await
610 .map_err(|e| Status::internal(format!("Failed to get session config: {e}")))?;
611
612 yield GetSessionResponse {
613 chunk: Some(get_session_response::Chunk::Header(SessionStateHeader {
614 id: session_id.to_string(),
615 created_at: Some(prost_types::Timestamp::from(std::time::SystemTime::now())),
616 updated_at: Some(prost_types::Timestamp::from(std::time::SystemTime::now())),
617 config: config.map(|c| crate::grpc::conversions::session_config_to_proto(&c)),
618 last_event_sequence: state.event_sequence,
619 })),
620 };
621
622 for message in state.message_graph.messages {
623 let proto_msg = message_to_proto(message)
624 .map_err(|e| Status::internal(format!("Failed to convert message: {e}")))?;
625 yield GetSessionResponse {
626 chunk: Some(get_session_response::Chunk::Message(proto_msg)),
627 };
628 }
629
630 yield GetSessionResponse {
631 chunk: Some(get_session_response::Chunk::Footer(SessionStateFooter {
632 approved_tools: state.approved_tools.into_iter().collect(),
633 metadata: std::collections::HashMap::new(),
634 queued_head: state
635 .queued_work
636 .front()
637 .map(|item| match item {
638 steer_core::app::domain::state::QueuedWorkItem::UserMessage(message) => {
639 proto::QueuedWorkItem {
640 kind: proto::queued_work_item::Kind::UserMessage as i32,
641 content: message.text.as_str().to_string(),
642 model: Some(model_to_proto(message.model.clone())),
643 queued_at: message.queued_at,
644 op_id: message.op_id.to_string(),
645 message_id: message.message_id.to_string(),
646 }
647 }
648 steer_core::app::domain::state::QueuedWorkItem::DirectBash(command) => {
649 proto::QueuedWorkItem {
650 kind: proto::queued_work_item::Kind::DirectBash as i32,
651 content: command.command.clone(),
652 model: None,
653 queued_at: command.queued_at,
654 op_id: command.op_id.to_string(),
655 message_id: command.message_id.to_string(),
656 }
657 }
658 }),
659 queued_count: state.queued_work.len() as u32,
660 })),
661 };
662 };
663
664 Ok(Response::new(Box::pin(stream)))
665 }
666
667 async fn delete_session(
668 &self,
669 request: Request<DeleteSessionRequest>,
670 ) -> Result<Response<DeleteSessionResponse>, Status> {
671 let req = request.into_inner();
672 let session_id = Self::parse_session_id(&req.session_id)?;
673
674 match self.runtime.delete_session(session_id).await {
675 Ok(()) => Ok(Response::new(DeleteSessionResponse {})),
676 Err(RuntimeError::SessionNotFound { .. }) => Err(Status::not_found(format!(
677 "Session not found: {}",
678 req.session_id
679 ))),
680 Err(e) => {
681 error!("Failed to delete session: {}", e);
682 Err(Status::internal(format!("Failed to delete session: {e}")))
683 }
684 }
685 }
686
687 async fn get_conversation(
688 &self,
689 request: Request<GetConversationRequest>,
690 ) -> Result<Response<Self::GetConversationStream>, Status> {
691 let req = request.into_inner();
692 let session_id = Self::parse_session_id(&req.session_id)?;
693 let runtime = self.runtime.clone();
694
695 info!("GetConversation called for session: {}", session_id);
696
697 let stream = async_stream::try_stream! {
698 if let Err(e) = runtime.resume_session(session_id).await
699 && matches!(e, RuntimeError::SessionNotFound { .. }) {
700 Err(Status::not_found(format!("Session not found: {session_id}")))?;
701 return;
702 }
703
704 let state = runtime.get_session_state(session_id).await
705 .map_err(|e| Status::internal(format!("Failed to get session state: {e}")))?;
706
707 info!(
708 "Found session state with {} messages and {} approved tools",
709 state.message_graph.messages.len(),
710 state.approved_tools.len()
711 );
712
713 for msg in state.message_graph.messages {
714 let proto_msg = message_to_proto(msg)
715 .map_err(|e| Status::internal(format!("Failed to convert message: {e}")))?;
716 yield GetConversationResponse {
717 chunk: Some(get_conversation_response::Chunk::Message(proto_msg)),
718 };
719 }
720
721 yield GetConversationResponse {
722 chunk: Some(get_conversation_response::Chunk::Footer(GetConversationFooter {
723 approved_tools: state.approved_tools.into_iter().collect(),
724 })),
725 };
726 };
727
728 Ok(Response::new(Box::pin(stream)))
729 }
730
731 async fn send_message(
732 &self,
733 request: Request<SendMessageRequest>,
734 ) -> Result<Response<SendMessageResponse>, Status> {
735 let req = request.into_inner();
736 let session_id = Self::parse_session_id(&req.session_id)?;
737
738 let model = if let Some(model_spec) = req.model {
739 proto_to_model(&model_spec)
740 .map_err(|e| Status::invalid_argument(format!("Invalid model spec: {e}")))?
741 } else {
742 let config = self
743 .catalog
744 .get_session_config(session_id)
745 .await
746 .map_err(|e| Status::internal(format!("Failed to get session config: {e}")))?
747 .ok_or_else(|| Status::not_found("Session config not found"))?;
748 config.default_model
749 };
750
751 match self
752 .runtime
753 .submit_user_input(session_id, req.message, model)
754 .await
755 {
756 Ok(op_id) => Ok(Response::new(SendMessageResponse {
757 operation: Some(Operation {
758 id: op_id.to_string(),
759 session_id: session_id.to_string(),
760 r#type: OperationType::SendMessage as i32,
761 status: OperationStatus::Running as i32,
762 created_at: Some(prost_types::Timestamp::from(std::time::SystemTime::now())),
763 completed_at: None,
764 metadata: std::collections::HashMap::new(),
765 }),
766 })),
767 Err(RuntimeError::InvalidInput { message }) => Err(Status::invalid_argument(message)),
768 Err(e) => {
769 error!("Failed to send message: {}", e);
770 Err(Status::internal(format!("Failed to send message: {e}")))
771 }
772 }
773 }
774
775 async fn edit_message(
776 &self,
777 request: Request<EditMessageRequest>,
778 ) -> Result<Response<EditMessageResponse>, Status> {
779 let req = request.into_inner();
780 let session_id = Self::parse_session_id(&req.session_id)?;
781
782 let model = if let Some(model_spec) = req.model {
783 proto_to_model(&model_spec)
784 .map_err(|e| Status::invalid_argument(format!("Invalid model spec: {e}")))?
785 } else {
786 let config = self
787 .catalog
788 .get_session_config(session_id)
789 .await
790 .map_err(|e| Status::internal(format!("Failed to get session config: {e}")))?
791 .ok_or_else(|| Status::not_found("Session config not found"))?;
792 config.default_model
793 };
794
795 self.runtime
796 .submit_edited_message(session_id, req.message_id, req.new_content, model)
797 .await
798 .map_err(|e| match e {
799 RuntimeError::InvalidInput { message } => Status::failed_precondition(message),
800 other => Status::internal(format!("Failed to edit message: {other}")),
801 })?;
802
803 Ok(Response::new(EditMessageResponse {}))
804 }
805
806 async fn dequeue_queued_item(
807 &self,
808 request: Request<DequeueQueuedItemRequest>,
809 ) -> Result<Response<DequeueQueuedItemResponse>, Status> {
810 let req = request.into_inner();
811 let session_id = Self::parse_session_id(&req.session_id)?;
812
813 self.runtime
814 .submit_dequeue_queued_item(session_id)
815 .await
816 .map_err(|e| match e {
817 RuntimeError::InvalidInput { message } => Status::failed_precondition(message),
818 other => Status::internal(format!("Failed to dequeue queued item: {other}")),
819 })?;
820
821 Ok(Response::new(DequeueQueuedItemResponse {}))
822 }
823
824 async fn approve_tool(
825 &self,
826 request: Request<ApproveToolRequest>,
827 ) -> Result<Response<ApproveToolResponse>, Status> {
828 let req = request.into_inner();
829 let session_id = Self::parse_session_id(&req.session_id)?;
830
831 let request_id = Uuid::parse_str(&req.tool_call_id)
832 .map(steer_core::app::domain::types::RequestId::from)
833 .map_err(|_| Status::invalid_argument("Invalid tool call ID"))?;
834
835 let (approved, remember) = match req.decision {
836 Some(decision) => match decision.decision_type {
837 Some(proto::approval_decision::DecisionType::Deny(_)) => (false, None),
838 Some(proto::approval_decision::DecisionType::Once(_)) => (true, None),
839 Some(proto::approval_decision::DecisionType::AlwaysTool(_)) => (
840 true,
841 Some(steer_core::app::domain::action::ApprovalMemory::PendingTool),
842 ),
843 Some(proto::approval_decision::DecisionType::AlwaysBashPattern(pattern)) => (
844 true,
845 Some(steer_core::app::domain::action::ApprovalMemory::BashPattern(pattern)),
846 ),
847 None => {
848 return Err(Status::invalid_argument("Invalid approval decision"));
849 }
850 },
851 None => {
852 return Err(Status::invalid_argument("Missing approval decision"));
853 }
854 };
855
856 match self
857 .runtime
858 .submit_tool_approval(session_id, request_id, approved, remember)
859 .await
860 {
861 Ok(()) => Ok(Response::new(ApproveToolResponse {})),
862 Err(e) => {
863 error!("Failed to approve tool: {}", e);
864 Err(Status::internal(format!("Failed to approve tool: {e}")))
865 }
866 }
867 }
868
869 async fn switch_primary_agent(
870 &self,
871 request: Request<SwitchPrimaryAgentRequest>,
872 ) -> Result<Response<SwitchPrimaryAgentResponse>, Status> {
873 let req = request.into_inner();
874 let session_id = Self::parse_session_id(&req.session_id)?;
875
876 self.runtime
877 .switch_primary_agent(session_id, req.primary_agent_id)
878 .await
879 .map_err(|e| match e {
880 RuntimeError::InvalidInput { message } => {
881 if message.contains("operation is active") {
882 Status::failed_precondition(message)
883 } else {
884 Status::invalid_argument(message)
885 }
886 }
887 other => Status::internal(format!("Failed to switch primary agent: {other}")),
888 })?;
889
890 Ok(Response::new(SwitchPrimaryAgentResponse {}))
891 }
892
893 async fn cancel_operation(
894 &self,
895 request: Request<CancelOperationRequest>,
896 ) -> Result<Response<CancelOperationResponse>, Status> {
897 let req = request.into_inner();
898 let session_id = Self::parse_session_id(&req.session_id)?;
899
900 match self.runtime.cancel_operation(session_id, None).await {
901 Ok(()) => Ok(Response::new(CancelOperationResponse {})),
902 Err(e) => {
903 error!("Failed to cancel operation: {}", e);
904 Err(Status::internal(format!("Failed to cancel operation: {e}")))
905 }
906 }
907 }
908
909 async fn compact_session(
910 &self,
911 request: Request<CompactSessionRequest>,
912 ) -> Result<Response<CompactSessionResponse>, Status> {
913 let req = request.into_inner();
914 let session_id = Self::parse_session_id(&req.session_id)?;
915 let model_spec = req
916 .model
917 .ok_or_else(|| Status::invalid_argument("Missing model spec"))?;
918 let model = proto_to_model(&model_spec)
919 .map_err(|e| Status::invalid_argument(format!("Invalid model spec: {e}")))?;
920
921 self.runtime
922 .compact_session(session_id, model)
923 .await
924 .map_err(|e| match e {
925 RuntimeError::InvalidInput { message } => Status::failed_precondition(message),
926 other => Status::internal(format!("Failed to compact session: {other}")),
927 })?;
928
929 Ok(Response::new(CompactSessionResponse {}))
930 }
931
932 async fn execute_bash_command(
933 &self,
934 request: Request<ExecuteBashCommandRequest>,
935 ) -> Result<Response<ExecuteBashCommandResponse>, Status> {
936 let req = request.into_inner();
937 let session_id = Self::parse_session_id(&req.session_id)?;
938
939 self.runtime
940 .execute_bash_command(session_id, req.command)
941 .await
942 .map_err(|e| Status::internal(format!("Failed to execute bash command: {e}")))?;
943
944 Ok(Response::new(ExecuteBashCommandResponse {}))
945 }
946
947 async fn list_files(
948 &self,
949 request: Request<ListFilesRequest>,
950 ) -> Result<Response<Self::ListFilesStream>, Status> {
951 let req = request.into_inner();
952 let session_id = Self::parse_session_id(&req.session_id)?;
953
954 debug!("ListFiles called for session: {}", session_id);
955
956 let config = self
957 .catalog
958 .get_session_config(session_id)
959 .await
960 .map_err(|e| Status::internal(format!("Failed to get session config: {e}")))?
961 .ok_or_else(|| Status::not_found(format!("Session not found: {session_id}")))?;
962
963 let workspace =
964 steer_core::workspace::create_workspace(&config.workspace.to_workspace_config())
965 .await
966 .map_err(|e| Status::internal(format!("Failed to create workspace: {e}")))?;
967
968 let (tx, rx) = mpsc::channel(100);
969
970 let _list_task: tokio::task::JoinHandle<()> = tokio::spawn(async move {
971 let query = if req.query.is_empty() {
972 None
973 } else {
974 Some(req.query.as_str())
975 };
976
977 let max_results = if req.max_results == 0 {
978 None
979 } else {
980 Some(req.max_results as usize)
981 };
982
983 match workspace.list_files(query, max_results).await {
984 Ok(files) => {
985 for chunk in files.chunks(1000) {
986 let response = ListFilesResponse {
987 paths: chunk.to_vec(),
988 };
989
990 if let Err(e) = tx.send(Ok(response)).await {
991 warn!("Failed to send file list chunk: {}", e);
992 break;
993 }
994 }
995 }
996 Err(e) => {
997 error!("Failed to list files: {}", e);
998 let _ = tx
999 .send(Err(Status::internal(format!("Failed to list files: {e}"))))
1000 .await;
1001 }
1002 }
1003 });
1004
1005 Ok(Response::new(ReceiverStream::new(rx)))
1006 }
1007
1008 async fn get_mcp_servers(
1009 &self,
1010 request: Request<GetMcpServersRequest>,
1011 ) -> Result<Response<GetMcpServersResponse>, Status> {
1012 let req = request.into_inner();
1013 let session_id = Self::parse_session_id(&req.session_id)?;
1014
1015 debug!("GetMcpServers called for session: {}", session_id);
1016
1017 let state = self
1018 .runtime
1019 .get_session_state(session_id)
1020 .await
1021 .map_err(|e| Status::internal(format!("Failed to get session state: {e}")))?;
1022
1023 let config = self
1024 .catalog
1025 .get_session_config(session_id)
1026 .await
1027 .map_err(|e| Status::internal(format!("Failed to get session config: {e}")))?;
1028
1029 let transport_map: std::collections::HashMap<String, &steer_core::tools::McpTransport> =
1030 config
1031 .as_ref()
1032 .map(|c| {
1033 c.tool_config
1034 .backends
1035 .iter()
1036 .map(|b| {
1037 let steer_core::session::state::BackendConfig::Mcp {
1038 server_name,
1039 transport,
1040 ..
1041 } = b;
1042 (server_name.clone(), transport)
1043 })
1044 .collect()
1045 })
1046 .unwrap_or_default();
1047
1048 let servers: Vec<proto::McpServerInfo> = state
1049 .mcp_servers
1050 .into_iter()
1051 .map(|(name, mcp_state)| {
1052 use crate::grpc::conversions::mcp_transport_to_proto;
1053 use steer_core::app::domain::action::McpServerState;
1054
1055 let state = match mcp_state {
1056 McpServerState::Connecting => proto::McpConnectionState {
1057 state: Some(proto::mcp_connection_state::State::Connecting(
1058 proto::McpConnecting {},
1059 )),
1060 },
1061 McpServerState::Connected { tools } => {
1062 let tool_names = tools.iter().map(|t| t.name.clone()).collect();
1063 proto::McpConnectionState {
1064 state: Some(proto::mcp_connection_state::State::Connected(
1065 proto::McpConnected { tool_names },
1066 )),
1067 }
1068 }
1069 McpServerState::Disconnected { error } => {
1070 let error_msg = error.unwrap_or_else(|| "Disconnected".to_string());
1071 proto::McpConnectionState {
1072 state: Some(proto::mcp_connection_state::State::Failed(
1073 proto::McpFailed { error: error_msg },
1074 )),
1075 }
1076 }
1077 McpServerState::Failed { error } => proto::McpConnectionState {
1078 state: Some(proto::mcp_connection_state::State::Failed(
1079 proto::McpFailed { error },
1080 )),
1081 },
1082 };
1083
1084 proto::McpServerInfo {
1085 server_name: name.clone(),
1086 transport: transport_map.get(&name).map(|t| mcp_transport_to_proto(t)),
1087 state: Some(state),
1088 last_updated: Some(prost_types::Timestamp::from(std::time::SystemTime::now())),
1089 }
1090 })
1091 .collect();
1092
1093 Ok(Response::new(GetMcpServersResponse { servers }))
1094 }
1095
1096 async fn list_providers(
1097 &self,
1098 _request: Request<ListProvidersRequest>,
1099 ) -> Result<Response<ListProvidersResponse>, Status> {
1100 let providers = self
1101 .provider_registry
1102 .all()
1103 .map(|p| proto::ProviderInfo {
1104 id: p.id.storage_key(),
1105 name: p.name.clone(),
1106 })
1107 .collect();
1108
1109 Ok(Response::new(ListProvidersResponse { providers }))
1110 }
1111
1112 async fn list_models(
1113 &self,
1114 request: Request<ListModelsRequest>,
1115 ) -> Result<Response<ListModelsResponse>, Status> {
1116 let req = request.into_inner();
1117
1118 let mut auth_sources: HashMap<steer_core::config::provider::ProviderId, AuthSource> =
1119 HashMap::new();
1120 let mut visibility_policies: HashMap<
1121 steer_core::config::provider::ProviderId,
1122 Option<Arc<dyn ModelVisibilityPolicy>>,
1123 > = HashMap::new();
1124
1125 let mut all_models = Vec::new();
1126
1127 for model in self.model_registry.recommended() {
1128 if let Some(ref provider_id) = req.provider_id
1129 && model.provider.storage_key() != *provider_id
1130 {
1131 continue;
1132 }
1133
1134 let provider_id = model.provider.clone();
1135
1136 let auth_source = if let Some(source) = auth_sources.get(&provider_id) {
1137 source.clone()
1138 } else {
1139 let source = match self
1140 .llm_config_provider
1141 .resolve_auth_source(&provider_id)
1142 .await
1143 {
1144 Ok(source) => source,
1145 Err(err) => {
1146 warn!(
1147 "Failed to resolve auth source for provider {}: {err}",
1148 provider_id.as_str()
1149 );
1150 AuthSource::None
1151 }
1152 };
1153 auth_sources.insert(provider_id.clone(), source.clone());
1154 source
1155 };
1156
1157 let policy = visibility_policies
1158 .entry(provider_id.clone())
1159 .or_insert_with(|| {
1160 self.llm_config_provider
1161 .plugin_registry()
1162 .get(&provider_id)
1163 .and_then(|plugin| plugin.model_visibility().map(Arc::from))
1164 });
1165
1166 if let Some(policy) = policy {
1167 let auth_model_id = AuthModelId {
1168 provider_id: AuthProviderId(provider_id.as_str().to_string()),
1169 model_id: model.id.clone(),
1170 };
1171 if !policy.allow_model(&auth_model_id, &auth_source) {
1172 continue;
1173 }
1174 }
1175
1176 all_models.push(proto::ProviderModel {
1177 provider_id: model.provider.storage_key(),
1178 model_id: model.id.clone(),
1179 display_name: model
1180 .display_name
1181 .clone()
1182 .unwrap_or_else(|| model.id.clone()),
1183 supports_thinking: model
1184 .parameters
1185 .as_ref()
1186 .and_then(|p| p.thinking_config.as_ref())
1187 .is_some_and(|tc| tc.enabled),
1188 aliases: model.aliases.clone(),
1189 });
1190 }
1191
1192 Ok(Response::new(ListModelsResponse { models: all_models }))
1193 }
1194
1195 async fn get_provider_auth_status(
1196 &self,
1197 request: Request<proto::GetProviderAuthStatusRequest>,
1198 ) -> Result<Response<proto::GetProviderAuthStatusResponse>, Status> {
1199 let req = request.into_inner();
1200
1201 let mut statuses = Vec::new();
1202 for p in self.provider_registry.all() {
1203 if let Some(ref filter) = req.provider_id
1204 && &p.id.storage_key() != filter
1205 {
1206 continue;
1207 }
1208 let auth_source = self
1209 .llm_config_provider
1210 .resolve_auth_source(&p.id)
1211 .await
1212 .map_err(|e| Status::internal(format!("auth lookup failed: {e}")))?;
1213 let auth_source = crate::grpc::conversions::auth_source_to_proto(auth_source);
1214 statuses.push(proto::ProviderAuthStatus {
1215 provider_id: p.id.storage_key(),
1216 auth_source: Some(auth_source),
1217 });
1218 }
1219
1220 Ok(Response::new(proto::GetProviderAuthStatusResponse {
1221 statuses,
1222 }))
1223 }
1224
1225 async fn start_auth(
1226 &self,
1227 request: Request<proto::StartAuthRequest>,
1228 ) -> Result<Response<proto::StartAuthResponse>, Status> {
1229 self.auth_flow_manager.cleanup().await;
1230 let req = request.into_inner();
1231 let provider_id = steer_core::config::provider::ProviderId(req.provider_id);
1232
1233 let (flow, method) = self.create_auth_flow(&provider_id)?;
1234 let state = flow
1235 .start_auth(method)
1236 .await
1237 .map_err(|e| Status::internal(format!("auth start failed: {e}")))?;
1238 let progress = flow
1239 .get_initial_progress(&state, method)
1240 .await
1241 .map_err(|e| Status::internal(format!("auth progress failed: {e}")))?;
1242
1243 let flow_id = Uuid::new_v4().to_string();
1244 self.auth_flow_manager
1245 .insert(
1246 flow_id.clone(),
1247 AuthFlowEntry {
1248 flow,
1249 state,
1250 last_updated: Instant::now(),
1251 },
1252 )
1253 .await;
1254
1255 Ok(Response::new(proto::StartAuthResponse {
1256 flow_id,
1257 progress: Some(crate::grpc::conversions::auth_progress_to_proto(progress)),
1258 }))
1259 }
1260
1261 async fn send_auth_input(
1262 &self,
1263 request: Request<proto::SendAuthInputRequest>,
1264 ) -> Result<Response<proto::SendAuthInputResponse>, Status> {
1265 self.auth_flow_manager.cleanup().await;
1266 let req = request.into_inner();
1267 let flow_id = req.flow_id.clone();
1268
1269 let mut entry = self
1270 .auth_flow_manager
1271 .take(&flow_id)
1272 .await
1273 .ok_or_else(|| Status::not_found("Auth flow not found"))?;
1274
1275 let progress = entry
1276 .flow
1277 .handle_input(&mut entry.state, &req.input)
1278 .await
1279 .map_err(|e| Status::internal(format!("auth input failed: {e}")))?;
1280
1281 let done = matches!(
1282 progress,
1283 steer_core::auth::AuthProgress::Complete | steer_core::auth::AuthProgress::Error(_)
1284 );
1285
1286 if !done {
1287 entry.last_updated = Instant::now();
1288 self.auth_flow_manager.insert(flow_id, entry).await;
1289 }
1290
1291 Ok(Response::new(proto::SendAuthInputResponse {
1292 progress: Some(crate::grpc::conversions::auth_progress_to_proto(progress)),
1293 }))
1294 }
1295
1296 async fn get_auth_progress(
1297 &self,
1298 request: Request<proto::GetAuthProgressRequest>,
1299 ) -> Result<Response<proto::GetAuthProgressResponse>, Status> {
1300 self.auth_flow_manager.cleanup().await;
1301 let req = request.into_inner();
1302 let flow_id = req.flow_id.clone();
1303
1304 let mut entry = self
1305 .auth_flow_manager
1306 .take(&flow_id)
1307 .await
1308 .ok_or_else(|| Status::not_found("Auth flow not found"))?;
1309
1310 let progress = entry
1311 .flow
1312 .handle_input(&mut entry.state, "")
1313 .await
1314 .map_err(|e| Status::internal(format!("auth progress failed: {e}")))?;
1315
1316 let done = matches!(
1317 progress,
1318 steer_core::auth::AuthProgress::Complete | steer_core::auth::AuthProgress::Error(_)
1319 );
1320
1321 if !done {
1322 entry.last_updated = Instant::now();
1323 self.auth_flow_manager.insert(flow_id, entry).await;
1324 }
1325
1326 Ok(Response::new(proto::GetAuthProgressResponse {
1327 progress: Some(crate::grpc::conversions::auth_progress_to_proto(progress)),
1328 }))
1329 }
1330
1331 async fn cancel_auth(
1332 &self,
1333 request: Request<proto::CancelAuthRequest>,
1334 ) -> Result<Response<proto::CancelAuthResponse>, Status> {
1335 self.auth_flow_manager.cleanup().await;
1336 let req = request.into_inner();
1337 let flow_id = req.flow_id;
1338
1339 let _ = self.auth_flow_manager.take(&flow_id).await;
1340
1341 Ok(Response::new(proto::CancelAuthResponse {}))
1342 }
1343
1344 async fn resolve_model(
1345 &self,
1346 request: Request<proto::ResolveModelRequest>,
1347 ) -> Result<Response<proto::ResolveModelResponse>, Status> {
1348 let req = request.into_inner();
1349
1350 match self.model_registry.resolve(&req.input) {
1351 Ok(model_id) => {
1352 let steer_core::config::model::ModelId { provider, id } = model_id;
1353 let model_spec = proto::ModelSpec {
1354 provider_id: provider.storage_key(),
1355 model_id: id,
1356 };
1357 Ok(Response::new(proto::ResolveModelResponse {
1358 model: Some(model_spec),
1359 }))
1360 }
1361 Err(e) => Err(Status::not_found(format!(
1362 "Failed to resolve model '{}': {}",
1363 req.input, e
1364 ))),
1365 }
1366 }
1367
1368 async fn get_default_model(
1369 &self,
1370 _request: Request<proto::GetDefaultModelRequest>,
1371 ) -> Result<Response<proto::GetDefaultModelResponse>, Status> {
1372 let model = self.select_default_model();
1373 Ok(Response::new(proto::GetDefaultModelResponse {
1374 model: Some(model_to_proto(model)),
1375 }))
1376 }
1377
1378 async fn create_workspace(
1379 &self,
1380 request: Request<proto::CreateWorkspaceRequest>,
1381 ) -> Result<Response<proto::CreateWorkspaceResponse>, Status> {
1382 let req = request.into_inner();
1383 let repo_id = Self::parse_repo_id(&req.repo_id)?;
1384 let parent_workspace_id = match req.parent_workspace_id {
1385 Some(value) => Some(Self::parse_workspace_id(&value)?),
1386 None => None,
1387 };
1388
1389 let strategy = match proto::WorkspaceCreateStrategy::try_from(req.strategy) {
1390 Ok(proto::WorkspaceCreateStrategy::JjWorkspace) => {
1391 steer_workspace::WorkspaceCreateStrategy::JjWorkspace
1392 }
1393 Ok(proto::WorkspaceCreateStrategy::GitWorktree) => {
1394 steer_workspace::WorkspaceCreateStrategy::GitWorktree
1395 }
1396 _ => {
1397 return Err(Status::invalid_argument(
1398 "Unsupported workspace create strategy",
1399 ));
1400 }
1401 };
1402
1403 let request = steer_workspace::CreateWorkspaceRequest {
1404 repo_id,
1405 name: req.name,
1406 parent_workspace_id,
1407 strategy,
1408 };
1409
1410 let workspace = self
1411 .workspace_manager
1412 .create_workspace(request)
1413 .await
1414 .map_err(Self::workspace_manager_error_to_status)?;
1415
1416 Ok(Response::new(proto::CreateWorkspaceResponse {
1417 workspace: Some(workspace_info_to_proto(&workspace)),
1418 }))
1419 }
1420
1421 async fn resolve_repo(
1422 &self,
1423 request: Request<proto::ResolveRepoRequest>,
1424 ) -> Result<Response<proto::ResolveRepoResponse>, Status> {
1425 let req = request.into_inner();
1426 let environment_id = Self::parse_environment_id(&req.environment_id)?;
1427 let repo = self
1428 .repo_manager
1429 .resolve_repo(environment_id, std::path::Path::new(&req.path))
1430 .await
1431 .map_err(Self::workspace_manager_error_to_status)?;
1432
1433 Ok(Response::new(proto::ResolveRepoResponse {
1434 repo: Some(repo_info_to_proto(&repo)),
1435 }))
1436 }
1437
1438 async fn list_repos(
1439 &self,
1440 request: Request<proto::ListReposRequest>,
1441 ) -> Result<Response<proto::ListReposResponse>, Status> {
1442 let req = request.into_inner();
1443 let environment_id = Self::parse_environment_id(&req.environment_id)?;
1444 let repos = self
1445 .repo_manager
1446 .list_repos(environment_id)
1447 .await
1448 .map_err(Self::workspace_manager_error_to_status)?;
1449
1450 Ok(Response::new(proto::ListReposResponse {
1451 repos: repos.iter().map(repo_info_to_proto).collect(),
1452 }))
1453 }
1454
1455 async fn list_workspaces(
1456 &self,
1457 request: Request<proto::ListWorkspacesRequest>,
1458 ) -> Result<Response<proto::ListWorkspacesResponse>, Status> {
1459 let req = request.into_inner();
1460 let environment_id = Self::parse_environment_id(&req.environment_id)?;
1461
1462 let workspaces = self
1463 .workspace_manager
1464 .list_workspaces(steer_workspace::ListWorkspacesRequest { environment_id })
1465 .await
1466 .map_err(Self::workspace_manager_error_to_status)?;
1467
1468 Ok(Response::new(proto::ListWorkspacesResponse {
1469 workspaces: workspaces.iter().map(workspace_info_to_proto).collect(),
1470 }))
1471 }
1472
1473 async fn get_workspace_status(
1474 &self,
1475 request: Request<proto::GetWorkspaceStatusRequest>,
1476 ) -> Result<Response<proto::GetWorkspaceStatusResponse>, Status> {
1477 let req = request.into_inner();
1478 let workspace_id = Self::parse_workspace_id(&req.workspace_id)?;
1479
1480 let status = self
1481 .workspace_manager
1482 .get_workspace_status(workspace_id)
1483 .await
1484 .map_err(Self::workspace_manager_error_to_status)?;
1485
1486 Ok(Response::new(proto::GetWorkspaceStatusResponse {
1487 status: Some(workspace_status_to_proto(&status)),
1488 }))
1489 }
1490
1491 async fn delete_workspace(
1492 &self,
1493 request: Request<proto::DeleteWorkspaceRequest>,
1494 ) -> Result<Response<proto::DeleteWorkspaceResponse>, Status> {
1495 let req = request.into_inner();
1496 let workspace_id = Self::parse_workspace_id(&req.workspace_id)?;
1497
1498 self.workspace_manager
1499 .delete_workspace(steer_workspace::DeleteWorkspaceRequest { workspace_id })
1500 .await
1501 .map_err(Self::workspace_manager_error_to_status)?;
1502
1503 Ok(Response::new(proto::DeleteWorkspaceResponse {}))
1504 }
1505
1506 async fn create_environment(
1507 &self,
1508 request: Request<proto::CreateEnvironmentRequest>,
1509 ) -> Result<Response<proto::CreateEnvironmentResponse>, Status> {
1510 let req = request.into_inner();
1511 let request = steer_workspace::CreateEnvironmentRequest {
1512 root: req.root_path.map(std::path::PathBuf::from),
1513 name: req.name,
1514 };
1515
1516 let env = self
1517 .environment_manager
1518 .create_environment(request)
1519 .await
1520 .map_err(Self::environment_manager_error_to_status)?;
1521
1522 Ok(Response::new(proto::CreateEnvironmentResponse {
1523 environment: Some(environment_descriptor_to_proto(&env)),
1524 }))
1525 }
1526
1527 async fn get_environment(
1528 &self,
1529 request: Request<proto::GetEnvironmentRequest>,
1530 ) -> Result<Response<proto::GetEnvironmentResponse>, Status> {
1531 let req = request.into_inner();
1532 let environment_id = Self::parse_environment_id(&req.environment_id)?;
1533
1534 let env = self
1535 .environment_manager
1536 .get_environment(environment_id)
1537 .await
1538 .map_err(Self::environment_manager_error_to_status)?;
1539
1540 Ok(Response::new(proto::GetEnvironmentResponse {
1541 environment: Some(environment_descriptor_to_proto(&env)),
1542 }))
1543 }
1544
1545 async fn delete_environment(
1546 &self,
1547 request: Request<proto::DeleteEnvironmentRequest>,
1548 ) -> Result<Response<proto::DeleteEnvironmentResponse>, Status> {
1549 let req = request.into_inner();
1550 let environment_id = Self::parse_environment_id(&req.environment_id)?;
1551 let policy = match proto::EnvironmentDeletePolicy::try_from(req.policy) {
1552 Ok(proto::EnvironmentDeletePolicy::Soft) => {
1553 steer_workspace::EnvironmentDeletePolicy::Soft
1554 }
1555 Ok(proto::EnvironmentDeletePolicy::Hard) => {
1556 steer_workspace::EnvironmentDeletePolicy::Hard
1557 }
1558 _ => steer_workspace::EnvironmentDeletePolicy::Hard,
1559 };
1560
1561 self.environment_manager
1562 .delete_environment(environment_id, policy)
1563 .await
1564 .map_err(Self::environment_manager_error_to_status)?;
1565
1566 Ok(Response::new(proto::DeleteEnvironmentResponse {}))
1567 }
1568}