1use crate::grpc::conversions::message_to_proto;
2use crate::grpc::session_manager_ext::SessionManagerExt;
3use std::sync::Arc;
4use steer_core::session::manager::SessionManager;
5use steer_proto::agent::v1::{self as proto, *};
6use tokio::sync::mpsc;
7use tokio_stream::wrappers::ReceiverStream;
8use tonic::{Request, Response, Status, Streaming};
9use tracing::{debug, error, info, warn};
10
11pub struct AgentServiceImpl {
12 session_manager: Arc<SessionManager>,
13 llm_config_provider: steer_core::config::LlmConfigProvider,
14 model_registry: Arc<steer_core::model_registry::ModelRegistry>,
15 provider_registry: Arc<steer_core::auth::ProviderRegistry>,
16}
17
18impl AgentServiceImpl {
19 pub fn new(
20 session_manager: Arc<SessionManager>,
21 llm_config_provider: steer_core::config::LlmConfigProvider,
22 model_registry: Arc<steer_core::model_registry::ModelRegistry>,
23 provider_registry: Arc<steer_core::auth::ProviderRegistry>,
24 ) -> Self {
25 Self {
26 session_manager,
27 llm_config_provider,
28 model_registry,
29 provider_registry,
30 }
31 }
32}
33
34#[tonic::async_trait]
35impl agent_service_server::AgentService for AgentServiceImpl {
36 type StreamSessionStream = ReceiverStream<Result<StreamSessionResponse, Status>>;
37 type ListFilesStream = ReceiverStream<Result<ListFilesResponse, Status>>;
38 type GetSessionStream =
39 std::pin::Pin<Box<dyn futures::Stream<Item = Result<GetSessionResponse, Status>> + Send>>;
40 type GetConversationStream = std::pin::Pin<
41 Box<dyn futures::Stream<Item = Result<GetConversationResponse, Status>> + Send>,
42 >;
43 type ActivateSessionStream = std::pin::Pin<
44 Box<dyn futures::Stream<Item = Result<ActivateSessionResponse, Status>> + Send>,
45 >;
46
47 async fn stream_session(
48 &self,
49 request: Request<Streaming<StreamSessionRequest>>,
50 ) -> Result<Response<Self::StreamSessionStream>, Status> {
51 let mut client_stream = request.into_inner();
52 let (tx, rx) = mpsc::channel(100);
53
54 let session_manager = self.session_manager.clone();
56 let llm_config_provider = self.llm_config_provider.clone();
57 let model_registry = self.model_registry.clone();
58 let provider_registry = self.provider_registry.clone();
59
60 let _stream_task: tokio::task::JoinHandle<()> = tokio::spawn(async move {
61 let (session_id, mut event_rx) = if let Some(client_message_result) =
63 client_stream.message().await.transpose()
64 {
65 match client_message_result {
66 Ok(client_message) => {
67 let session_id = client_message.session_id.clone();
68
69 let receiver = match session_manager
71 .take_event_receiver(&client_message.session_id)
72 .await
73 {
74 Ok(receiver) => {
75 debug!("Session {} is already active, TUI should call GetConversation to retrieve history", session_id);
77 receiver
78 },
79 Err(steer_core::error::Error::SessionManager(steer_core::session::manager::SessionManagerError::SessionNotActive { session_id })) => {
80 info!("Session {} not active, attempting to resume", session_id);
81
82 match try_resume_session(&session_manager, &session_id, &llm_config_provider, &model_registry, &provider_registry).await {
84 Ok(()) => {
85 match session_manager.take_event_receiver(&session_id).await {
87 Ok(receiver) => receiver,
88 Err(e) => {
89 error!("Failed to get event receiver after resuming session {}: {}", session_id, e);
90 let _ = tx
91 .send(Err(Status::internal(format!(
92 "Failed to establish stream after resuming session: {e}"
93 ))))
94 .await;
95 return;
96 }
97 }
98 }
99 Err(e) => {
100 error!("Failed to resume session {}: {}", session_id, e);
101 let _ = tx
102 .send(Err(e))
103 .await;
104 return;
105 }
106 }
107 }
108 Err(steer_core::error::Error::SessionManager(steer_core::session::manager::SessionManagerError::SessionAlreadyHasListener { session_id })) => {
109 error!("Session already has an active stream: {}", session_id);
110 let _ = tx
111 .send(Err(Status::already_exists(format!(
112 "Session {session_id} already has an active stream"
113 ))))
114 .await;
115 return;
116 }
117 Err(e) => {
118 error!("Error taking event receiver: {}", e);
119 let _ = tx
120 .send(Err(Status::internal(format!(
121 "Error establishing stream: {e}"
122 ))))
123 .await;
124 return;
125 }
126 };
127
128 if let Err(e) =
130 handle_client_message(&session_manager, client_message).await
131 {
132 error!("Error handling first client message: {}", e);
133 let _ = tx
134 .send(Err(Status::internal(format!(
135 "Error processing message: {e}"
136 ))))
137 .await;
138 return;
139 }
140
141 (session_id, receiver)
142 }
143 Err(e) => {
144 error!("Error receiving first client message: {}", e);
145 let _ = tx.send(Err(Status::internal("Stream error"))).await;
146 return;
147 }
148 }
149 } else {
150 error!("No initial client message received");
151 let _ = tx.send(Err(Status::internal("No initial message"))).await;
152 return;
153 };
154
155 let mut event_sequence = 0u64;
156
157 if let Err(e) = session_manager
159 .increment_subscriber_count(&session_id)
160 .await
161 {
162 warn!(
163 "Failed to increment subscriber count for session {}: {}",
164 session_id, e
165 );
166 }
167
168 let tx_clone = tx.clone();
170 let session_id_clone = session_id.clone();
171 let event_task = tokio::spawn(async move {
172 while let Some(app_event) = event_rx.recv().await {
173 event_sequence += 1;
174 let server_event = match crate::grpc::conversions::app_event_to_server_event(
175 app_event,
176 event_sequence,
177 ) {
178 Ok(event) => event,
179 Err(e) => {
180 warn!("Failed to convert app event to server event: {}", e);
181 continue;
182 }
183 };
184
185 if let Err(e) = tx_clone.send(Ok(server_event)).await {
186 warn!("Failed to send event to client: {}", e);
187 break;
188 }
189 }
190 debug!(
191 "Event forwarding task ended for session: {}",
192 session_id_clone
193 );
194 });
195
196 while let Some(client_message_result) = client_stream.message().await.transpose() {
198 match client_message_result {
199 Ok(client_message) => {
200 if let Err(e) = session_manager.touch_session(&session_id).await {
202 warn!("Failed to touch session {}: {}", session_id, e);
203 }
204
205 if let Err(e) =
206 handle_client_message(&session_manager, client_message).await
207 {
208 error!("Error handling client message: {}", e);
209 let _ = tx
210 .send(Err(Status::internal(format!(
211 "Error processing message: {e}"
212 ))))
213 .await;
214 break;
215 }
216 }
217 Err(e) => {
218 error!("Error receiving client message: {}", e);
219 let _ = tx.send(Err(Status::internal("Stream error"))).await;
220 break;
221 }
222 }
223 }
224
225 event_task.abort();
227
228 if let Err(e) = session_manager
230 .decrement_subscriber_count(&session_id)
231 .await
232 {
233 warn!(
234 "Failed to decrement subscriber count for session {}: {}",
235 session_id, e
236 );
237 }
238
239 info!("Client stream ended for session: {}", session_id);
240
241 if let Err(e) = session_manager
243 .maybe_suspend_idle_session(&session_id)
244 .await
245 {
246 warn!("Failed to check/suspend idle session {}: {}", session_id, e);
247 }
248 });
249
250 Ok(Response::new(ReceiverStream::new(rx)))
251 }
252
253 async fn create_session(
254 &self,
255 request: Request<CreateSessionRequest>,
256 ) -> Result<Response<CreateSessionResponse>, Status> {
257 let req = request.into_inner();
258
259 let app_config = steer_core::app::AppConfig {
260 llm_config_provider: self.llm_config_provider.clone(),
261 model_registry: self.model_registry.clone(),
262 provider_registry: self.provider_registry.clone(),
263 };
264
265 match self
266 .session_manager
267 .create_session_grpc(req, app_config)
268 .await
269 {
270 Ok((_session_id, session_info)) => Ok(Response::new(CreateSessionResponse {
271 session: Some(session_info),
272 })),
273 Err(e) => {
274 error!("Failed to create session: {}", e);
275 Err(e.into())
276 }
277 }
278 }
279
280 async fn list_sessions(
281 &self,
282 request: Request<ListSessionsRequest>,
283 ) -> Result<Response<ListSessionsResponse>, Status> {
284 let _req = request.into_inner();
285
286 let filter = steer_core::session::SessionFilter::default();
288
289 match self.session_manager.list_sessions(filter).await {
290 Ok(sessions) => {
291 let proto_sessions = sessions
292 .into_iter()
293 .map(|session| SessionInfo {
294 id: session.id,
295 created_at: Some(prost_types::Timestamp::from(
296 std::time::SystemTime::from(session.created_at),
297 )),
298 updated_at: Some(prost_types::Timestamp::from(
299 std::time::SystemTime::from(session.updated_at),
300 )),
301 status: proto::SessionStatus::Active as i32,
302 metadata: Some(proto::SessionMetadata {
303 labels: session.metadata,
304 annotations: std::collections::HashMap::new(),
305 }),
306 })
307 .collect();
308
309 Ok(Response::new(ListSessionsResponse {
310 sessions: proto_sessions,
311 next_page_token: None,
312 }))
313 }
314 Err(e) => {
315 error!("Failed to list sessions: {}", e);
316 Err(Status::internal(format!("Failed to list sessions: {e}")))
317 }
318 }
319 }
320
321 async fn get_session(
322 &self,
323 request: Request<GetSessionRequest>,
324 ) -> Result<Response<Self::GetSessionStream>, Status> {
325 let req = request.into_inner();
326 let session_manager = self.session_manager.clone();
327
328 let stream = async_stream::try_stream! {
329 match session_manager.get_session_proto(&req.session_id).await {
330 Ok(Some(session_state)) => {
331 yield GetSessionResponse {
333 chunk: Some(get_session_response::Chunk::Header(SessionStateHeader {
334 id: session_state.id,
335 created_at: session_state.created_at,
336 updated_at: session_state.updated_at,
337 config: session_state.config,
338 })),
339 };
340
341 for message in session_state.messages {
343 yield GetSessionResponse {
344 chunk: Some(get_session_response::Chunk::Message(message)),
345 };
346 }
347
348 for (key, value) in session_state.tool_calls {
350 yield GetSessionResponse {
351 chunk: Some(get_session_response::Chunk::ToolCall(ToolCallStateEntry {
352 key,
353 value: Some(value),
354 })),
355 };
356 }
357
358 yield GetSessionResponse {
360 chunk: Some(get_session_response::Chunk::Footer(SessionStateFooter {
361 approved_tools: session_state.approved_tools,
362 last_event_sequence: session_state.last_event_sequence,
363 metadata: session_state.metadata,
364 })),
365 };
366 }
367 Ok(None) => {
368 Err(Status::not_found(format!(
369 "Session not found: {}",
370 req.session_id
371 )))?;
372 }
373 Err(e) => {
374 error!("Failed to get session: {}", e);
375 Err(Status::internal(format!("Failed to get session: {e}")))?;
376 }
377 }
378 };
379
380 Ok(Response::new(Box::pin(stream)))
381 }
382
383 async fn delete_session(
384 &self,
385 request: Request<DeleteSessionRequest>,
386 ) -> Result<Response<DeleteSessionResponse>, Status> {
387 let req = request.into_inner();
388
389 match self.session_manager.delete_session(&req.session_id).await {
390 Ok(true) => Ok(Response::new(DeleteSessionResponse {})),
391 Ok(false) => Err(Status::not_found(format!(
392 "Session not found: {}",
393 req.session_id
394 ))),
395 Err(e) => {
396 error!("Failed to delete session: {}", e);
397 Err(Status::internal(format!("Failed to delete session: {e}")))
398 }
399 }
400 }
401
402 async fn get_conversation(
403 &self,
404 request: Request<GetConversationRequest>,
405 ) -> Result<Response<Self::GetConversationStream>, Status> {
406 let req = request.into_inner();
407 let session_manager = self.session_manager.clone();
408
409 info!("GetConversation called for session: {}", req.session_id);
410
411 let stream = async_stream::try_stream! {
412 match session_manager.get_session_state(&req.session_id).await {
413 Ok(Some(session_state)) => {
414 info!(
415 "Found session state with {} messages and {} approved tools",
416 session_state.messages.len(),
417 session_state.approved_tools.len()
418 );
419
420 for msg in session_state.messages {
422 let proto_msg = message_to_proto(msg.clone())
423 .map_err(|e| Status::internal(format!("Failed to convert message: {e}")))?;
424 yield GetConversationResponse {
425 chunk: Some(get_conversation_response::Chunk::Message(proto_msg)),
426 };
427 }
428
429 yield GetConversationResponse {
431 chunk: Some(get_conversation_response::Chunk::Footer(GetConversationFooter {
432 approved_tools: session_state.approved_tools.into_iter().collect(),
433 })),
434 };
435 }
436 Ok(None) => {
437 Err(Status::not_found(format!(
438 "Session not found: {}",
439 req.session_id
440 )))?;
441 }
442 Err(e) => {
443 error!("Failed to get session state: {}", e);
444 Err(Status::internal(format!("Failed to get session state: {e}")))?;
445 }
446 }
447 };
448
449 Ok(Response::new(Box::pin(stream)))
450 }
451
452 async fn send_message(
453 &self,
454 request: Request<SendMessageRequest>,
455 ) -> Result<Response<SendMessageResponse>, Status> {
456 let req = request.into_inner();
457
458 let app_command = steer_core::app::AppCommand::ProcessUserInput(req.message);
459
460 match self
461 .session_manager
462 .send_command(&req.session_id, app_command)
463 .await
464 {
465 Ok(()) => {
466 let operation_id = format!("op_{}", uuid::Uuid::new_v4());
468 Ok(Response::new(SendMessageResponse {
469 operation: Some(Operation {
470 id: operation_id,
471 session_id: req.session_id,
472 r#type: OperationType::SendMessage as i32,
473 status: OperationStatus::Running as i32,
474 created_at: Some(
475 prost_types::Timestamp::from(std::time::SystemTime::now()),
476 ),
477 completed_at: None,
478 metadata: std::collections::HashMap::new(),
479 }),
480 }))
481 }
482 Err(e) => {
483 error!("Failed to send message: {}", e);
484 Err(Status::internal(format!("Failed to send message: {e}")))
485 }
486 }
487 }
488
489 async fn approve_tool(
490 &self,
491 request: Request<ApproveToolRequest>,
492 ) -> Result<Response<ApproveToolResponse>, Status> {
493 let req = request.into_inner();
494
495 let approval = match req.decision {
496 Some(decision) => match decision {
497 proto::ApprovalDecision {
498 decision_type: Some(proto::approval_decision::DecisionType::Deny(true)),
499 } => steer_core::app::command::ApprovalType::Denied,
500 proto::ApprovalDecision {
501 decision_type: Some(proto::approval_decision::DecisionType::Once(true)),
502 } => steer_core::app::command::ApprovalType::Once,
503 proto::ApprovalDecision {
504 decision_type: Some(proto::approval_decision::DecisionType::AlwaysTool(true)),
505 } => steer_core::app::command::ApprovalType::AlwaysTool,
506 proto::ApprovalDecision {
507 decision_type:
508 Some(proto::approval_decision::DecisionType::AlwaysBashPattern(pattern)),
509 } => steer_core::app::command::ApprovalType::AlwaysBashPattern(pattern),
510 _ => {
511 return Err(Status::invalid_argument(
512 "Invalid approval decision enum value",
513 ));
514 }
515 },
516 None => {
517 return Err(Status::invalid_argument("Missing approval decision"));
518 }
519 };
520
521 let app_command = steer_core::app::AppCommand::HandleToolResponse {
522 id: req.tool_call_id,
523 approval,
524 };
525
526 match self
527 .session_manager
528 .send_command(&req.session_id, app_command)
529 .await
530 {
531 Ok(()) => Ok(Response::new(ApproveToolResponse {})),
532 Err(e) => {
533 error!("Failed to approve tool: {}", e);
534 Err(Status::internal(format!("Failed to approve tool: {e}")))
535 }
536 }
537 }
538
539 async fn activate_session(
540 &self,
541 request: Request<ActivateSessionRequest>,
542 ) -> Result<Response<Self::ActivateSessionStream>, Status> {
543 let req = request.into_inner();
544 let session_manager = self.session_manager.clone();
545 let llm_config_provider = self.llm_config_provider.clone();
546 let model_registry = self.model_registry.clone();
547 let provider_registry = self.provider_registry.clone();
548
549 info!("ActivateSession called for {}", req.session_id);
550
551 let stream = async_stream::try_stream! {
552 let state = if let Ok(Some(state)) = session_manager
554 .get_session_state(&req.session_id)
555 .await
556 {
557 state
558 } else {
559 let app_config = steer_core::app::AppConfig {
561 llm_config_provider: llm_config_provider.clone(),
562 model_registry: model_registry.clone(),
563 provider_registry: provider_registry.clone(),
564 };
565
566 session_manager
567 .resume_session(&req.session_id, app_config)
568 .await
569 .map_err(|e| Status::internal(format!("Failed to resume session: {e}")))?;
570
571 session_manager
573 .get_session_state(&req.session_id)
574 .await
575 .map_err(|e| Status::internal(format!("Failed to get session state: {e}")))?
576 .ok_or_else(|| Status::not_found(format!("Session not found: {}", req.session_id)))?
577 };
578
579 for msg in state.messages {
581 let proto_msg = message_to_proto(msg)
582 .map_err(|e| Status::internal(format!("Failed to convert message: {e}")))?;
583 yield ActivateSessionResponse {
584 chunk: Some(activate_session_response::Chunk::Message(proto_msg)),
585 };
586 }
587
588 yield ActivateSessionResponse {
590 chunk: Some(activate_session_response::Chunk::Footer(ActivateSessionFooter {
591 approved_tools: state.approved_tools.into_iter().collect(),
592 })),
593 };
594 };
595
596 Ok(Response::new(Box::pin(stream)))
597 }
598
599 async fn cancel_operation(
600 &self,
601 request: Request<CancelOperationRequest>,
602 ) -> Result<Response<CancelOperationResponse>, Status> {
603 let req = request.into_inner();
604
605 let app_command = steer_core::app::AppCommand::CancelProcessing;
606
607 match self
608 .session_manager
609 .send_command(&req.session_id, app_command)
610 .await
611 {
612 Ok(()) => Ok(Response::new(CancelOperationResponse {})),
613 Err(e) => {
614 error!("Failed to cancel operation: {}", e);
615 Err(Status::internal(format!("Failed to cancel operation: {e}")))
616 }
617 }
618 }
619
620 async fn list_files(
621 &self,
622 request: Request<ListFilesRequest>,
623 ) -> Result<Response<Self::ListFilesStream>, Status> {
624 let req = request.into_inner();
625
626 debug!("ListFiles called for session: {}", req.session_id);
627
628 let workspace = match self
630 .session_manager
631 .get_session_workspace(&req.session_id)
632 .await
633 {
634 Ok(Some(workspace)) => workspace,
635 Ok(None) => {
636 return Err(Status::not_found(format!(
637 "Session not found: {}",
638 req.session_id
639 )));
640 }
641 Err(e) => {
642 error!("Failed to get session workspace: {}", e);
643 return Err(Status::internal(format!(
644 "Failed to get session workspace: {e}"
645 )));
646 }
647 };
648
649 let (tx, rx) = mpsc::channel(100);
651
652 let _list_task: tokio::task::JoinHandle<()> = tokio::spawn(async move {
654 let query = if req.query.is_empty() {
656 None
657 } else {
658 Some(req.query.as_str())
659 };
660
661 let max_results = if req.max_results == 0 {
662 None
663 } else {
664 Some(req.max_results as usize)
665 };
666
667 match workspace.list_files(query, max_results).await {
668 Ok(files) => {
669 for chunk in files.chunks(1000) {
671 let response = ListFilesResponse {
672 paths: chunk.to_vec(),
673 };
674
675 if let Err(e) = tx.send(Ok(response)).await {
676 warn!("Failed to send file list chunk: {}", e);
677 break;
678 }
679 }
680 }
681 Err(e) => {
682 error!("Failed to list files: {}", e);
683 let _ = tx
684 .send(Err(Status::internal(format!("Failed to list files: {e}"))))
685 .await;
686 }
687 }
688 });
689
690 Ok(Response::new(ReceiverStream::new(rx)))
691 }
692
693 async fn get_mcp_servers(
694 &self,
695 request: Request<GetMcpServersRequest>,
696 ) -> Result<Response<GetMcpServersResponse>, Status> {
697 let req = request.into_inner();
698
699 debug!("GetMcpServers called for session: {}", req.session_id);
700
701 match self.session_manager.get_mcp_statuses(&req.session_id).await {
703 Ok(servers) => {
704 use crate::grpc::conversions::mcp_server_info_to_proto;
705
706 let proto_servers = servers.into_iter().map(mcp_server_info_to_proto).collect();
707
708 Ok(Response::new(GetMcpServersResponse {
709 servers: proto_servers,
710 }))
711 }
712 Err(e) => {
713 error!("Failed to get MCP server statuses: {}", e);
714 Err(Status::internal(format!(
715 "Failed to get MCP server statuses: {e}"
716 )))
717 }
718 }
719 }
720 async fn list_providers(
721 &self,
722 _request: Request<ListProvidersRequest>,
723 ) -> Result<Response<ListProvidersResponse>, Status> {
724 let providers = self
726 .provider_registry
727 .all()
728 .map(|p| proto::ProviderInfo {
729 id: p.id.storage_key(),
730 name: p.name.clone(),
731 auth_schemes: p
732 .auth_schemes
733 .iter()
734 .map(|s| match s {
735 steer_core::config::toml_types::AuthScheme::ApiKey => {
736 proto::ProviderAuthScheme::AuthSchemeApiKey as i32
737 }
738 steer_core::config::toml_types::AuthScheme::Oauth2 => {
739 proto::ProviderAuthScheme::AuthSchemeOauth2 as i32
740 }
741 })
742 .collect(),
743 })
744 .collect();
745
746 Ok(Response::new(ListProvidersResponse { providers }))
747 }
748
749 async fn list_models(
750 &self,
751 request: Request<ListModelsRequest>,
752 ) -> Result<Response<ListModelsResponse>, Status> {
753 let req = request.into_inner();
754
755 let model_registry = &self.model_registry;
757
758 let all_models: Vec<proto::ProviderModel> = model_registry
760 .recommended() .filter(|m| {
762 if let Some(ref provider_id) = req.provider_id {
763 m.provider.storage_key() == *provider_id
764 } else {
765 true
766 }
767 })
768 .map(|m| proto::ProviderModel {
769 provider_id: m.provider.storage_key(),
770 model_id: m.id.clone(),
771 display_name: m.display_name.clone().unwrap_or_else(|| m.id.clone()),
772 supports_thinking: m
773 .parameters
774 .as_ref()
775 .and_then(|p| p.thinking_config.as_ref())
776 .map(|tc| tc.enabled)
777 .unwrap_or(false),
778 aliases: m.aliases.clone(),
779 })
780 .collect();
781
782 Ok(Response::new(ListModelsResponse { models: all_models }))
783 }
784
785 async fn get_provider_auth_status(
786 &self,
787 request: Request<proto::GetProviderAuthStatusRequest>,
788 ) -> Result<Response<proto::GetProviderAuthStatusResponse>, Status> {
789 let req = request.into_inner();
790
791 let mut statuses = Vec::new();
792 for p in self.provider_registry.all() {
793 if let Some(ref filter) = req.provider_id {
794 if &p.id.storage_key() != filter {
795 continue;
796 }
797 }
798 let status = match self
799 .llm_config_provider
800 .get_auth_for_provider(&p.id)
801 .await
802 .map_err(|e| Status::internal(format!("auth lookup failed: {e}")))?
803 {
804 Some(steer_core::config::ApiAuth::OAuth) => {
805 proto::provider_auth_status::Status::AuthStatusOauth as i32
806 }
807 Some(steer_core::config::ApiAuth::Key(_)) => {
808 proto::provider_auth_status::Status::AuthStatusApiKey as i32
809 }
810 None => proto::provider_auth_status::Status::AuthStatusNone as i32,
811 };
812 statuses.push(proto::ProviderAuthStatus {
813 provider_id: p.id.storage_key(),
814 status,
815 });
816 }
817
818 Ok(Response::new(proto::GetProviderAuthStatusResponse {
819 statuses,
820 }))
821 }
822
823 async fn resolve_model(
824 &self,
825 request: Request<proto::ResolveModelRequest>,
826 ) -> Result<Response<proto::ResolveModelResponse>, Status> {
827 let req = request.into_inner();
828
829 match self.model_registry.resolve(&req.input) {
831 Ok(model_id) => {
832 let model_spec = proto::ModelSpec {
833 provider_id: model_id.0.storage_key(),
834 model_id: model_id.1,
835 };
836 Ok(Response::new(proto::ResolveModelResponse {
837 model: Some(model_spec),
838 }))
839 }
840 Err(e) => Err(Status::not_found(format!(
841 "Failed to resolve model '{}': {}",
842 req.input, e
843 ))),
844 }
845 }
846}
847
848async fn try_resume_session(
849 session_manager: &SessionManager,
850 session_id: &str,
851 llm_config_provider: &steer_core::config::LlmConfigProvider,
852 model_registry: &Arc<steer_core::model_registry::ModelRegistry>,
853 provider_registry: &Arc<steer_core::auth::ProviderRegistry>,
854) -> Result<(), Status> {
855 let app_config = steer_core::app::AppConfig {
856 llm_config_provider: llm_config_provider.clone(),
857 model_registry: model_registry.clone(),
858 provider_registry: provider_registry.clone(),
859 };
860
861 match session_manager.resume_session(session_id, app_config).await {
863 Ok(_command_tx) => {
864 info!("Successfully resumed session: {}", session_id);
865 Ok(())
867 }
868 Err(steer_core::error::Error::SessionManager(
869 steer_core::session::manager::SessionManagerError::CapacityExceeded { current, max },
870 )) => {
871 warn!(
872 "Cannot resume session {}: server at capacity ({}/{})",
873 session_id, current, max
874 );
875 Err(Status::resource_exhausted(format!(
876 "Server at maximum capacity ({current}/{max}). Cannot resume session."
877 )))
878 }
879 Err(e) => {
880 error!("Failed to resume session {}: {}", session_id, e);
881 Err(Status::internal(format!("Failed to resume session: {e}")))
882 }
883 }
884}
885
886async fn handle_client_message(
887 session_manager: &SessionManager,
888 client_message: StreamSessionRequest,
889) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
890 debug!(
891 "Handling client message for session: {}",
892 client_message.session_id
893 );
894
895 if let Some(message) = client_message.message {
896 match message {
897 stream_session_request::Message::SendMessage(send_msg) => {
898 let app_command = steer_core::app::AppCommand::ProcessUserInput(send_msg.message);
900
901 session_manager
902 .send_command(&client_message.session_id, app_command)
903 .await
904 .map_err(|e| format!("Failed to send message: {e}"))?;
905 }
906
907 stream_session_request::Message::ToolApproval(approval) => {
908 let approval_type = match approval.decision {
910 Some(decision) => match decision.decision_type {
911 Some(proto::approval_decision::DecisionType::Deny(_)) => {
912 steer_core::app::command::ApprovalType::Denied
913 }
914 Some(proto::approval_decision::DecisionType::Once(_)) => {
915 steer_core::app::command::ApprovalType::Once
916 }
917 Some(proto::approval_decision::DecisionType::AlwaysTool(_)) => {
918 steer_core::app::command::ApprovalType::AlwaysTool
919 }
920 Some(proto::approval_decision::DecisionType::AlwaysBashPattern(
921 pattern,
922 )) => steer_core::app::command::ApprovalType::AlwaysBashPattern(pattern),
923 None => {
924 return Err(
925 "Invalid approval decision: no decision type specified".into()
926 );
927 }
928 },
929 None => {
930 return Err("Invalid approval decision: no decision provided".into());
931 }
932 };
933
934 let app_command = steer_core::app::AppCommand::HandleToolResponse {
935 id: approval.tool_call_id,
936 approval: approval_type,
937 };
938
939 session_manager
940 .send_command(&client_message.session_id, app_command)
941 .await
942 .map_err(|e| format!("Failed to approve tool: {e}"))?;
943 }
944
945 stream_session_request::Message::Cancel(_cancel) => {
946 let app_command = steer_core::app::AppCommand::CancelProcessing;
948
949 session_manager
950 .send_command(&client_message.session_id, app_command)
951 .await
952 .map_err(|e| format!("Failed to cancel operation: {e}"))?;
953 }
954
955 stream_session_request::Message::Subscribe(_subscribe_request) => {
956 debug!("Subscribe message received - stream already established");
957 }
959
960 stream_session_request::Message::UpdateConfig(_update_config) => {
961 debug!("UpdateConfig received but provider changes are no longer supported");
964 }
965
966 stream_session_request::Message::ExecuteCommand(execute_command) => {
967 use steer_core::app::conversation::AppCommandType;
968 let app_cmd_type = match AppCommandType::parse(&execute_command.command) {
969 Ok(cmd) => cmd,
970 Err(e) => {
971 return Err(format!("Failed to parse command: {e}").into());
972 }
973 };
974 let app_command = steer_core::app::AppCommand::ExecuteCommand(app_cmd_type);
975 session_manager
976 .send_command(&client_message.session_id, app_command)
977 .await
978 .map_err(|e| format!("Failed to execute command: {e}"))?;
979 }
980
981 stream_session_request::Message::ExecuteBashCommand(execute_bash_command) => {
982 let app_command = steer_core::app::AppCommand::ExecuteBashCommand {
983 command: execute_bash_command.command,
984 };
985 session_manager
986 .send_command(&client_message.session_id, app_command)
987 .await
988 .map_err(|e| format!("Failed to execute bash command: {e}"))?;
989 }
990
991 stream_session_request::Message::EditMessage(edit_message) => {
992 let app_command = steer_core::app::AppCommand::EditMessage {
993 message_id: edit_message.message_id,
994 new_content: edit_message.new_content,
995 };
996 session_manager
997 .send_command(&client_message.session_id, app_command)
998 .await
999 .map_err(|e| format!("Failed to edit message: {e}"))?;
1000 }
1001 }
1002 }
1003
1004 Ok(())
1005}
1006
1007#[cfg(test)]
1008mod tests {
1009 use super::*;
1010
1011 use std::collections::HashMap;
1012 use steer_core::session::state::WorkspaceConfig;
1013 use steer_core::session::stores::sqlite::SqliteSessionStore;
1014 use steer_core::session::{SessionConfig, SessionManagerConfig, SessionToolConfig};
1015 use steer_proto::agent::v1::agent_service_client::AgentServiceClient;
1016 use steer_proto::agent::v1::{SendMessageRequest, SubscribeRequest};
1017 use tempfile::TempDir;
1018 use tokio::sync::mpsc;
1019 use tokio_stream::StreamExt;
1020
1021 fn create_test_app_config() -> steer_core::app::AppConfig {
1022 steer_core::test_utils::test_app_config()
1023 }
1024
1025 async fn create_test_session_manager() -> (Arc<SessionManager>, TempDir) {
1026 let temp_dir = TempDir::new().unwrap();
1027 let db_path = temp_dir.path().join("test.db");
1028 let store = Arc::new(SqliteSessionStore::new(&db_path).await.unwrap());
1029
1030 let config = SessionManagerConfig {
1031 max_concurrent_sessions: 100,
1032 default_model: steer_core::config::model::builtin::claude_3_7_sonnet_20250219(),
1033 auto_persist: true,
1034 };
1035 let session_manager = Arc::new(SessionManager::new(store, config));
1036
1037 (session_manager, temp_dir)
1038 }
1039
1040 async fn create_test_server() -> (String, Arc<SessionManager>, TempDir) {
1041 let (session_manager, temp_dir) = create_test_session_manager().await;
1042
1043 let auth_storage = Arc::new(steer_core::test_utils::InMemoryAuthStorage::new());
1044 let llm_config_provider = steer_core::config::LlmConfigProvider::new(auth_storage);
1045 let model_registry =
1046 Arc::new(steer_core::model_registry::ModelRegistry::load(&[]).unwrap());
1047 let provider_registry = Arc::new(steer_core::auth::ProviderRegistry::load(&[]).unwrap());
1048 let service = AgentServiceImpl::new(
1049 session_manager.clone(),
1050 llm_config_provider,
1051 model_registry,
1052 provider_registry,
1053 );
1054
1055 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
1057 let addr = listener.local_addr().unwrap();
1058
1059 let _server_task = tokio::spawn(async move {
1060 tonic::transport::Server::builder()
1061 .add_service(agent_service_server::AgentServiceServer::new(service))
1062 .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
1063 .await
1064 .unwrap();
1065 });
1066
1067 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
1069
1070 let url = format!("http://{addr}");
1071 (url, session_manager, temp_dir)
1072 }
1073
1074 #[tokio::test]
1075 async fn test_session_cleanup_on_disconnect() {
1076 let (session_manager, _temp_dir) = create_test_session_manager().await;
1077
1078 let session_config = SessionConfig {
1080 workspace: WorkspaceConfig::Local {
1081 path: std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
1082 },
1083 tool_config: SessionToolConfig::default(),
1084 system_prompt: None,
1085 metadata: HashMap::new(),
1086 };
1087
1088 let app_config = create_test_app_config();
1089
1090 let (session_id, _command_tx) = session_manager
1091 .create_session(session_config, app_config)
1092 .await
1093 .unwrap();
1094
1095 assert!(session_manager.is_session_active(&session_id).await);
1097
1098 session_manager
1100 .increment_subscriber_count(&session_id)
1101 .await
1102 .unwrap();
1103
1104 assert!(session_manager.is_session_active(&session_id).await);
1106
1107 session_manager
1109 .decrement_subscriber_count(&session_id)
1110 .await
1111 .unwrap();
1112
1113 session_manager
1115 .maybe_suspend_idle_session(&session_id)
1116 .await
1117 .unwrap();
1118
1119 assert!(
1121 !session_manager.is_session_active(&session_id).await,
1122 "Session should be suspended after last client disconnects"
1123 );
1124
1125 let session_info = session_manager.get_session(&session_id).await.unwrap();
1127 assert!(
1128 session_info.is_some(),
1129 "Session should still exist in storage after suspension"
1130 );
1131 }
1132
1133 #[tokio::test]
1134 async fn test_session_with_multiple_subscribers() {
1135 let (session_manager, _temp_dir) = create_test_session_manager().await;
1136
1137 let session_config = SessionConfig {
1139 workspace: WorkspaceConfig::Local {
1140 path: std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
1141 },
1142 tool_config: SessionToolConfig::default(),
1143 system_prompt: None,
1144 metadata: HashMap::new(),
1145 };
1146
1147 let app_config = create_test_app_config();
1148
1149 let (session_id, _command_tx) = session_manager
1150 .create_session(session_config, app_config)
1151 .await
1152 .unwrap();
1153
1154 session_manager
1156 .increment_subscriber_count(&session_id)
1157 .await
1158 .unwrap();
1159 session_manager
1160 .increment_subscriber_count(&session_id)
1161 .await
1162 .unwrap();
1163
1164 session_manager
1166 .decrement_subscriber_count(&session_id)
1167 .await
1168 .unwrap();
1169 session_manager
1170 .maybe_suspend_idle_session(&session_id)
1171 .await
1172 .unwrap();
1173
1174 assert!(
1176 session_manager.is_session_active(&session_id).await,
1177 "Session should remain active with one subscriber"
1178 );
1179
1180 session_manager
1182 .decrement_subscriber_count(&session_id)
1183 .await
1184 .unwrap();
1185 session_manager
1186 .maybe_suspend_idle_session(&session_id)
1187 .await
1188 .unwrap();
1189
1190 assert!(
1192 !session_manager.is_session_active(&session_id).await,
1193 "Session should be suspended after all clients disconnect"
1194 );
1195 }
1196
1197 #[tokio::test]
1198 async fn test_grpc_client_connect_disconnect_cleanup() {
1199 let (server_url, session_manager, _temp_dir) = create_test_server().await;
1200
1201 let session_config = SessionConfig {
1203 workspace: WorkspaceConfig::Local {
1204 path: std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
1205 },
1206 tool_config: SessionToolConfig::default(),
1207 system_prompt: None,
1208 metadata: HashMap::new(),
1209 };
1210
1211 let app_config = create_test_app_config();
1212
1213 let (session_id, _command_tx) = session_manager
1214 .create_session(session_config, app_config)
1215 .await
1216 .unwrap();
1217
1218 assert!(session_manager.is_session_active(&session_id).await);
1220
1221 let mut client = AgentServiceClient::connect(server_url.clone())
1223 .await
1224 .unwrap();
1225
1226 let request_stream = tokio_stream::iter(vec![StreamSessionRequest {
1228 session_id: session_id.clone(),
1229 message: Some(stream_session_request::Message::Subscribe(
1230 SubscribeRequest {
1231 event_types: vec![],
1232 since_sequence: None,
1233 },
1234 )),
1235 }]);
1236
1237 let response = client.stream_session(request_stream).await.unwrap();
1238 let _stream = response.into_inner();
1239
1240 let (msg_tx, msg_rx) = mpsc::channel(10);
1242 msg_tx
1243 .send(StreamSessionRequest {
1244 session_id: session_id.clone(),
1245 message: Some(stream_session_request::Message::SendMessage(
1246 SendMessageRequest {
1247 session_id: session_id.clone(),
1248 message: "Hello, test!".to_string(),
1249 attachments: vec![],
1250 },
1251 )),
1252 })
1253 .await
1254 .unwrap();
1255
1256 let request_stream = tokio_stream::wrappers::ReceiverStream::new(msg_rx);
1258 let response = client.stream_session(request_stream).await.unwrap();
1259 let mut stream = response.into_inner();
1260
1261 let timeout =
1263 tokio::time::timeout(tokio::time::Duration::from_secs(5), stream.next()).await;
1264
1265 assert!(timeout.is_ok(), "Should receive at least one event");
1266
1267 drop(stream);
1269 drop(msg_tx);
1270
1271 tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
1273
1274 assert!(
1276 !session_manager.is_session_active(&session_id).await,
1277 "Session should be suspended after client disconnect"
1278 );
1279
1280 let session_info = session_manager.get_session(&session_id).await.unwrap();
1282 assert!(
1283 session_info.is_some(),
1284 "Session should still exist in storage"
1285 );
1286 }
1287
1288 #[tokio::test]
1289 async fn test_grpc_basic_session_resume() {
1290 let (server_url, session_manager, _temp_dir) = create_test_server().await;
1291
1292 let session_config = SessionConfig {
1294 workspace: WorkspaceConfig::Local {
1295 path: std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
1296 },
1297 tool_config: SessionToolConfig::default(),
1298 system_prompt: None,
1299 metadata: HashMap::new(),
1300 };
1301
1302 let app_config = create_test_app_config();
1303
1304 let (session_id, _command_tx) = session_manager
1305 .create_session(session_config, app_config)
1306 .await
1307 .unwrap();
1308
1309 session_manager.suspend_session(&session_id).await.unwrap();
1311 assert!(
1312 !session_manager.is_session_active(&session_id).await,
1313 "Session should be suspended"
1314 );
1315
1316 let mut client = AgentServiceClient::connect(server_url.clone())
1318 .await
1319 .unwrap();
1320
1321 let (msg_tx, msg_rx) = mpsc::channel(10);
1323
1324 msg_tx
1326 .send(StreamSessionRequest {
1327 session_id: session_id.clone(),
1328 message: Some(stream_session_request::Message::Subscribe(
1329 SubscribeRequest {
1330 event_types: vec![],
1331 since_sequence: None,
1332 },
1333 )),
1334 })
1335 .await
1336 .unwrap();
1337
1338 let request_stream = tokio_stream::wrappers::ReceiverStream::new(msg_rx);
1339 let response = client.stream_session(request_stream).await;
1340
1341 assert!(
1343 response.is_ok(),
1344 "Should be able to connect to suspended session (auto-resume)"
1345 );
1346
1347 let stream = response.unwrap().into_inner();
1348
1349 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
1351
1352 assert!(
1354 session_manager.is_session_active(&session_id).await,
1355 "Session should be active after auto-resume"
1356 );
1357
1358 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
1360 assert!(
1361 session_manager.is_session_active(&session_id).await,
1362 "Session should remain active while client is connected"
1363 );
1364
1365 drop(stream);
1367 drop(msg_tx);
1368
1369 tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
1371
1372 assert!(
1374 !session_manager.is_session_active(&session_id).await,
1375 "Session should be suspended after client disconnects"
1376 );
1377 }
1378}