steer_grpc/grpc/
server.rs

1use crate::grpc::conversions::message_to_proto;
2use crate::grpc::session_manager_ext::SessionManagerExt;
3use std::sync::Arc;
4use steer_core::session::manager::SessionManager;
5use steer_proto::agent::v1::{self as proto, *};
6use tokio::sync::mpsc;
7use tokio_stream::wrappers::ReceiverStream;
8use tonic::{Request, Response, Status, Streaming};
9use tracing::{debug, error, info, warn};
10
11pub struct AgentServiceImpl {
12    session_manager: Arc<SessionManager>,
13    llm_config_provider: steer_core::config::LlmConfigProvider,
14    model_registry: Arc<steer_core::model_registry::ModelRegistry>,
15    provider_registry: Arc<steer_core::auth::ProviderRegistry>,
16}
17
18impl AgentServiceImpl {
19    pub fn new(
20        session_manager: Arc<SessionManager>,
21        llm_config_provider: steer_core::config::LlmConfigProvider,
22        model_registry: Arc<steer_core::model_registry::ModelRegistry>,
23        provider_registry: Arc<steer_core::auth::ProviderRegistry>,
24    ) -> Self {
25        Self {
26            session_manager,
27            llm_config_provider,
28            model_registry,
29            provider_registry,
30        }
31    }
32}
33
34#[tonic::async_trait]
35impl agent_service_server::AgentService for AgentServiceImpl {
36    type StreamSessionStream = ReceiverStream<Result<StreamSessionResponse, Status>>;
37    type ListFilesStream = ReceiverStream<Result<ListFilesResponse, Status>>;
38    type GetSessionStream =
39        std::pin::Pin<Box<dyn futures::Stream<Item = Result<GetSessionResponse, Status>> + Send>>;
40    type GetConversationStream = std::pin::Pin<
41        Box<dyn futures::Stream<Item = Result<GetConversationResponse, Status>> + Send>,
42    >;
43    type ActivateSessionStream = std::pin::Pin<
44        Box<dyn futures::Stream<Item = Result<ActivateSessionResponse, Status>> + Send>,
45    >;
46
47    async fn stream_session(
48        &self,
49        request: Request<Streaming<StreamSessionRequest>>,
50    ) -> Result<Response<Self::StreamSessionStream>, Status> {
51        let mut client_stream = request.into_inner();
52        let (tx, rx) = mpsc::channel(100);
53
54        // Clone session manager, llm_config_provider, model_registry, and provider_registry for the stream handler task
55        let session_manager = self.session_manager.clone();
56        let llm_config_provider = self.llm_config_provider.clone();
57        let model_registry = self.model_registry.clone();
58        let provider_registry = self.provider_registry.clone();
59
60        let _stream_task: tokio::task::JoinHandle<()> = tokio::spawn(async move {
61            // Handle the first message to establish the session connection
62            let (session_id, mut event_rx) = if let Some(client_message_result) =
63                client_stream.message().await.transpose()
64            {
65                match client_message_result {
66                    Ok(client_message) => {
67                        let session_id = client_message.session_id.clone();
68
69                        // Try to take the event receiver for this session
70                        let receiver = match session_manager
71                            .take_event_receiver(&client_message.session_id)
72                            .await
73                        {
74                            Ok(receiver) => {
75                                // Session is already active - TUI will call GetConversation RPC to get history
76                                debug!("Session {} is already active, TUI should call GetConversation to retrieve history", session_id);
77                                receiver
78                            },
79                            Err(steer_core::error::Error::SessionManager(steer_core::session::manager::SessionManagerError::SessionNotActive { session_id })) => {
80                                info!("Session {} not active, attempting to resume", session_id);
81
82                                // Try to resume the session
83                                match try_resume_session(&session_manager, &session_id, &llm_config_provider, &model_registry, &provider_registry).await {
84                                    Ok(()) => {
85                                        // Session resumed, try to take receiver again
86                                        match session_manager.take_event_receiver(&session_id).await {
87                                            Ok(receiver) => receiver,
88                                            Err(e) => {
89                                                error!("Failed to get event receiver after resuming session {}: {}", session_id, e);
90                                                let _ = tx
91                                                    .send(Err(Status::internal(format!(
92                                                        "Failed to establish stream after resuming session: {e}"
93                                                    ))))
94                                                    .await;
95                                                return;
96                                            }
97                                        }
98                                    }
99                                    Err(e) => {
100                                        error!("Failed to resume session {}: {}", session_id, e);
101                                        let _ = tx
102                                            .send(Err(e))
103                                            .await;
104                                        return;
105                                    }
106                                }
107                            }
108                            Err(steer_core::error::Error::SessionManager(steer_core::session::manager::SessionManagerError::SessionAlreadyHasListener { session_id })) => {
109                                error!("Session already has an active stream: {}", session_id);
110                                let _ = tx
111                                    .send(Err(Status::already_exists(format!(
112                                        "Session {session_id} already has an active stream"
113                                    ))))
114                                    .await;
115                                return;
116                            }
117                            Err(e) => {
118                                error!("Error taking event receiver: {}", e);
119                                let _ = tx
120                                    .send(Err(Status::internal(format!(
121                                        "Error establishing stream: {e}"
122                                    ))))
123                                    .await;
124                                return;
125                            }
126                        };
127
128                        // Process the first message
129                        if let Err(e) =
130                            handle_client_message(&session_manager, client_message).await
131                        {
132                            error!("Error handling first client message: {}", e);
133                            let _ = tx
134                                .send(Err(Status::internal(format!(
135                                    "Error processing message: {e}"
136                                ))))
137                                .await;
138                            return;
139                        }
140
141                        (session_id, receiver)
142                    }
143                    Err(e) => {
144                        error!("Error receiving first client message: {}", e);
145                        let _ = tx.send(Err(Status::internal("Stream error"))).await;
146                        return;
147                    }
148                }
149            } else {
150                error!("No initial client message received");
151                let _ = tx.send(Err(Status::internal("No initial message"))).await;
152                return;
153            };
154
155            let mut event_sequence = 0u64;
156
157            // Mark session as having an active subscriber
158            if let Err(e) = session_manager
159                .increment_subscriber_count(&session_id)
160                .await
161            {
162                warn!(
163                    "Failed to increment subscriber count for session {}: {}",
164                    session_id, e
165                );
166            }
167
168            // Spawn task to handle outgoing events (App -> Client)
169            let tx_clone = tx.clone();
170            let session_id_clone = session_id.clone();
171            let event_task = tokio::spawn(async move {
172                while let Some(app_event) = event_rx.recv().await {
173                    event_sequence += 1;
174                    let server_event = match crate::grpc::conversions::app_event_to_server_event(
175                        app_event,
176                        event_sequence,
177                    ) {
178                        Ok(event) => event,
179                        Err(e) => {
180                            warn!("Failed to convert app event to server event: {}", e);
181                            continue;
182                        }
183                    };
184
185                    if let Err(e) = tx_clone.send(Ok(server_event)).await {
186                        warn!("Failed to send event to client: {}", e);
187                        break;
188                    }
189                }
190                debug!(
191                    "Event forwarding task ended for session: {}",
192                    session_id_clone
193                );
194            });
195
196            // Handle incoming messages (Client -> App)
197            while let Some(client_message_result) = client_stream.message().await.transpose() {
198                match client_message_result {
199                    Ok(client_message) => {
200                        // Touch the session to update last activity
201                        if let Err(e) = session_manager.touch_session(&session_id).await {
202                            warn!("Failed to touch session {}: {}", session_id, e);
203                        }
204
205                        if let Err(e) =
206                            handle_client_message(&session_manager, client_message).await
207                        {
208                            error!("Error handling client message: {}", e);
209                            let _ = tx
210                                .send(Err(Status::internal(format!(
211                                    "Error processing message: {e}"
212                                ))))
213                                .await;
214                            break;
215                        }
216                    }
217                    Err(e) => {
218                        error!("Error receiving client message: {}", e);
219                        let _ = tx.send(Err(Status::internal("Stream error"))).await;
220                        break;
221                    }
222                }
223            }
224
225            // Clean up
226            event_task.abort();
227
228            // Decrement subscriber count
229            if let Err(e) = session_manager
230                .decrement_subscriber_count(&session_id)
231                .await
232            {
233                warn!(
234                    "Failed to decrement subscriber count for session {}: {}",
235                    session_id, e
236                );
237            }
238
239            info!("Client stream ended for session: {}", session_id);
240
241            // Check if we should suspend the session (no more subscribers)
242            if let Err(e) = session_manager
243                .maybe_suspend_idle_session(&session_id)
244                .await
245            {
246                warn!("Failed to check/suspend idle session {}: {}", session_id, e);
247            }
248        });
249
250        Ok(Response::new(ReceiverStream::new(rx)))
251    }
252
253    async fn create_session(
254        &self,
255        request: Request<CreateSessionRequest>,
256    ) -> Result<Response<CreateSessionResponse>, Status> {
257        let req = request.into_inner();
258
259        let app_config = steer_core::app::AppConfig {
260            llm_config_provider: self.llm_config_provider.clone(),
261            model_registry: self.model_registry.clone(),
262            provider_registry: self.provider_registry.clone(),
263        };
264
265        match self
266            .session_manager
267            .create_session_grpc(req, app_config)
268            .await
269        {
270            Ok((_session_id, session_info)) => Ok(Response::new(CreateSessionResponse {
271                session: Some(session_info),
272            })),
273            Err(e) => {
274                error!("Failed to create session: {}", e);
275                Err(e.into())
276            }
277        }
278    }
279
280    async fn list_sessions(
281        &self,
282        request: Request<ListSessionsRequest>,
283    ) -> Result<Response<ListSessionsResponse>, Status> {
284        let _req = request.into_inner();
285
286        // Create filter - for now just list all sessions
287        let filter = steer_core::session::SessionFilter::default();
288
289        match self.session_manager.list_sessions(filter).await {
290            Ok(sessions) => {
291                let proto_sessions = sessions
292                    .into_iter()
293                    .map(|session| SessionInfo {
294                        id: session.id,
295                        created_at: Some(prost_types::Timestamp::from(
296                            std::time::SystemTime::from(session.created_at),
297                        )),
298                        updated_at: Some(prost_types::Timestamp::from(
299                            std::time::SystemTime::from(session.updated_at),
300                        )),
301                        status: proto::SessionStatus::Active as i32,
302                        metadata: Some(proto::SessionMetadata {
303                            labels: session.metadata,
304                            annotations: std::collections::HashMap::new(),
305                        }),
306                    })
307                    .collect();
308
309                Ok(Response::new(ListSessionsResponse {
310                    sessions: proto_sessions,
311                    next_page_token: None,
312                }))
313            }
314            Err(e) => {
315                error!("Failed to list sessions: {}", e);
316                Err(Status::internal(format!("Failed to list sessions: {e}")))
317            }
318        }
319    }
320
321    async fn get_session(
322        &self,
323        request: Request<GetSessionRequest>,
324    ) -> Result<Response<Self::GetSessionStream>, Status> {
325        let req = request.into_inner();
326        let session_manager = self.session_manager.clone();
327
328        let stream = async_stream::try_stream! {
329            match session_manager.get_session_proto(&req.session_id).await {
330                Ok(Some(session_state)) => {
331                    // Send header
332                    yield GetSessionResponse {
333                        chunk: Some(get_session_response::Chunk::Header(SessionStateHeader {
334                            id: session_state.id,
335                            created_at: session_state.created_at,
336                            updated_at: session_state.updated_at,
337                            config: session_state.config,
338                        })),
339                    };
340
341                    // Stream messages one by one
342                    for message in session_state.messages {
343                        yield GetSessionResponse {
344                            chunk: Some(get_session_response::Chunk::Message(message)),
345                        };
346                    }
347
348                    // Stream tool calls
349                    for (key, value) in session_state.tool_calls {
350                        yield GetSessionResponse {
351                            chunk: Some(get_session_response::Chunk::ToolCall(ToolCallStateEntry {
352                                key,
353                                value: Some(value),
354                            })),
355                        };
356                    }
357
358                    // Send footer
359                    yield GetSessionResponse {
360                        chunk: Some(get_session_response::Chunk::Footer(SessionStateFooter {
361                            approved_tools: session_state.approved_tools,
362                            last_event_sequence: session_state.last_event_sequence,
363                            metadata: session_state.metadata,
364                        })),
365                    };
366                }
367                Ok(None) => {
368                    Err(Status::not_found(format!(
369                        "Session not found: {}",
370                        req.session_id
371                    )))?;
372                }
373                Err(e) => {
374                    error!("Failed to get session: {}", e);
375                    Err(Status::internal(format!("Failed to get session: {e}")))?;
376                }
377            }
378        };
379
380        Ok(Response::new(Box::pin(stream)))
381    }
382
383    async fn delete_session(
384        &self,
385        request: Request<DeleteSessionRequest>,
386    ) -> Result<Response<DeleteSessionResponse>, Status> {
387        let req = request.into_inner();
388
389        match self.session_manager.delete_session(&req.session_id).await {
390            Ok(true) => Ok(Response::new(DeleteSessionResponse {})),
391            Ok(false) => Err(Status::not_found(format!(
392                "Session not found: {}",
393                req.session_id
394            ))),
395            Err(e) => {
396                error!("Failed to delete session: {}", e);
397                Err(Status::internal(format!("Failed to delete session: {e}")))
398            }
399        }
400    }
401
402    async fn get_conversation(
403        &self,
404        request: Request<GetConversationRequest>,
405    ) -> Result<Response<Self::GetConversationStream>, Status> {
406        let req = request.into_inner();
407        let session_manager = self.session_manager.clone();
408
409        info!("GetConversation called for session: {}", req.session_id);
410
411        let stream = async_stream::try_stream! {
412            match session_manager.get_session_state(&req.session_id).await {
413                Ok(Some(session_state)) => {
414                    info!(
415                        "Found session state with {} messages and {} approved tools",
416                        session_state.messages.len(),
417                        session_state.approved_tools.len()
418                    );
419
420                    // Stream messages one by one
421                    for msg in session_state.messages {
422                        let proto_msg = message_to_proto(msg.clone())
423                            .map_err(|e| Status::internal(format!("Failed to convert message: {e}")))?;
424                        yield GetConversationResponse {
425                            chunk: Some(get_conversation_response::Chunk::Message(proto_msg)),
426                        };
427                    }
428
429                    // Send footer with approved tools
430                    yield GetConversationResponse {
431                        chunk: Some(get_conversation_response::Chunk::Footer(GetConversationFooter {
432                            approved_tools: session_state.approved_tools.into_iter().collect(),
433                        })),
434                    };
435                }
436                Ok(None) => {
437                    Err(Status::not_found(format!(
438                        "Session not found: {}",
439                        req.session_id
440                    )))?;
441                }
442                Err(e) => {
443                    error!("Failed to get session state: {}", e);
444                    Err(Status::internal(format!("Failed to get session state: {e}")))?;
445                }
446            }
447        };
448
449        Ok(Response::new(Box::pin(stream)))
450    }
451
452    async fn send_message(
453        &self,
454        request: Request<SendMessageRequest>,
455    ) -> Result<Response<SendMessageResponse>, Status> {
456        let req = request.into_inner();
457
458        let app_command = steer_core::app::AppCommand::ProcessUserInput(req.message);
459
460        match self
461            .session_manager
462            .send_command(&req.session_id, app_command)
463            .await
464        {
465            Ok(()) => {
466                // Generate operation ID for tracking
467                let operation_id = format!("op_{}", uuid::Uuid::new_v4());
468                Ok(Response::new(SendMessageResponse {
469                    operation: Some(Operation {
470                        id: operation_id,
471                        session_id: req.session_id,
472                        r#type: OperationType::SendMessage as i32,
473                        status: OperationStatus::Running as i32,
474                        created_at: Some(
475                            prost_types::Timestamp::from(std::time::SystemTime::now()),
476                        ),
477                        completed_at: None,
478                        metadata: std::collections::HashMap::new(),
479                    }),
480                }))
481            }
482            Err(e) => {
483                error!("Failed to send message: {}", e);
484                Err(Status::internal(format!("Failed to send message: {e}")))
485            }
486        }
487    }
488
489    async fn approve_tool(
490        &self,
491        request: Request<ApproveToolRequest>,
492    ) -> Result<Response<ApproveToolResponse>, Status> {
493        let req = request.into_inner();
494
495        let approval = match req.decision {
496            Some(decision) => match decision {
497                proto::ApprovalDecision {
498                    decision_type: Some(proto::approval_decision::DecisionType::Deny(true)),
499                } => steer_core::app::command::ApprovalType::Denied,
500                proto::ApprovalDecision {
501                    decision_type: Some(proto::approval_decision::DecisionType::Once(true)),
502                } => steer_core::app::command::ApprovalType::Once,
503                proto::ApprovalDecision {
504                    decision_type: Some(proto::approval_decision::DecisionType::AlwaysTool(true)),
505                } => steer_core::app::command::ApprovalType::AlwaysTool,
506                proto::ApprovalDecision {
507                    decision_type:
508                        Some(proto::approval_decision::DecisionType::AlwaysBashPattern(pattern)),
509                } => steer_core::app::command::ApprovalType::AlwaysBashPattern(pattern),
510                _ => {
511                    return Err(Status::invalid_argument(
512                        "Invalid approval decision enum value",
513                    ));
514                }
515            },
516            None => {
517                return Err(Status::invalid_argument("Missing approval decision"));
518            }
519        };
520
521        let app_command = steer_core::app::AppCommand::HandleToolResponse {
522            id: req.tool_call_id,
523            approval,
524        };
525
526        match self
527            .session_manager
528            .send_command(&req.session_id, app_command)
529            .await
530        {
531            Ok(()) => Ok(Response::new(ApproveToolResponse {})),
532            Err(e) => {
533                error!("Failed to approve tool: {}", e);
534                Err(Status::internal(format!("Failed to approve tool: {e}")))
535            }
536        }
537    }
538
539    async fn activate_session(
540        &self,
541        request: Request<ActivateSessionRequest>,
542    ) -> Result<Response<Self::ActivateSessionStream>, Status> {
543        let req = request.into_inner();
544        let session_manager = self.session_manager.clone();
545        let llm_config_provider = self.llm_config_provider.clone();
546        let model_registry = self.model_registry.clone();
547        let provider_registry = self.provider_registry.clone();
548
549        info!("ActivateSession called for {}", req.session_id);
550
551        let stream = async_stream::try_stream! {
552            // Check if already active or activate it
553            let state = if let Ok(Some(state)) = session_manager
554                .get_session_state(&req.session_id)
555                .await
556            {
557                state
558            } else {
559                // Not active, so activate it
560                let app_config = steer_core::app::AppConfig {
561                    llm_config_provider: llm_config_provider.clone(),
562                    model_registry: model_registry.clone(),
563                    provider_registry: provider_registry.clone(),
564                };
565
566                session_manager
567                    .resume_session(&req.session_id, app_config)
568                    .await
569                    .map_err(|e| Status::internal(format!("Failed to resume session: {e}")))?;
570
571                // Fetch state now that it's active
572                session_manager
573                    .get_session_state(&req.session_id)
574                    .await
575                    .map_err(|e| Status::internal(format!("Failed to get session state: {e}")))?
576                    .ok_or_else(|| Status::not_found(format!("Session not found: {}", req.session_id)))?
577            };
578
579            // Stream messages one by one
580            for msg in state.messages {
581                let proto_msg = message_to_proto(msg)
582                    .map_err(|e| Status::internal(format!("Failed to convert message: {e}")))?;
583                yield ActivateSessionResponse {
584                    chunk: Some(activate_session_response::Chunk::Message(proto_msg)),
585                };
586            }
587
588            // Send footer with approved tools
589            yield ActivateSessionResponse {
590                chunk: Some(activate_session_response::Chunk::Footer(ActivateSessionFooter {
591                    approved_tools: state.approved_tools.into_iter().collect(),
592                })),
593            };
594        };
595
596        Ok(Response::new(Box::pin(stream)))
597    }
598
599    async fn cancel_operation(
600        &self,
601        request: Request<CancelOperationRequest>,
602    ) -> Result<Response<CancelOperationResponse>, Status> {
603        let req = request.into_inner();
604
605        let app_command = steer_core::app::AppCommand::CancelProcessing;
606
607        match self
608            .session_manager
609            .send_command(&req.session_id, app_command)
610            .await
611        {
612            Ok(()) => Ok(Response::new(CancelOperationResponse {})),
613            Err(e) => {
614                error!("Failed to cancel operation: {}", e);
615                Err(Status::internal(format!("Failed to cancel operation: {e}")))
616            }
617        }
618    }
619
620    async fn list_files(
621        &self,
622        request: Request<ListFilesRequest>,
623    ) -> Result<Response<Self::ListFilesStream>, Status> {
624        let req = request.into_inner();
625
626        debug!("ListFiles called for session: {}", req.session_id);
627
628        // Get the session's workspace
629        let workspace = match self
630            .session_manager
631            .get_session_workspace(&req.session_id)
632            .await
633        {
634            Ok(Some(workspace)) => workspace,
635            Ok(None) => {
636                return Err(Status::not_found(format!(
637                    "Session not found: {}",
638                    req.session_id
639                )));
640            }
641            Err(e) => {
642                error!("Failed to get session workspace: {}", e);
643                return Err(Status::internal(format!(
644                    "Failed to get session workspace: {e}"
645                )));
646            }
647        };
648
649        // Create the response stream
650        let (tx, rx) = mpsc::channel(100);
651
652        // Spawn task to stream the files
653        let _list_task: tokio::task::JoinHandle<()> = tokio::spawn(async move {
654            // Get the file list from the workspace
655            let query = if req.query.is_empty() {
656                None
657            } else {
658                Some(req.query.as_str())
659            };
660
661            let max_results = if req.max_results == 0 {
662                None
663            } else {
664                Some(req.max_results as usize)
665            };
666
667            match workspace.list_files(query, max_results).await {
668                Ok(files) => {
669                    // Stream files in chunks of 1000
670                    for chunk in files.chunks(1000) {
671                        let response = ListFilesResponse {
672                            paths: chunk.to_vec(),
673                        };
674
675                        if let Err(e) = tx.send(Ok(response)).await {
676                            warn!("Failed to send file list chunk: {}", e);
677                            break;
678                        }
679                    }
680                }
681                Err(e) => {
682                    error!("Failed to list files: {}", e);
683                    let _ = tx
684                        .send(Err(Status::internal(format!("Failed to list files: {e}"))))
685                        .await;
686                }
687            }
688        });
689
690        Ok(Response::new(ReceiverStream::new(rx)))
691    }
692
693    async fn get_mcp_servers(
694        &self,
695        request: Request<GetMcpServersRequest>,
696    ) -> Result<Response<GetMcpServersResponse>, Status> {
697        let req = request.into_inner();
698
699        debug!("GetMcpServers called for session: {}", req.session_id);
700
701        // Get MCP server statuses from session manager
702        match self.session_manager.get_mcp_statuses(&req.session_id).await {
703            Ok(servers) => {
704                use crate::grpc::conversions::mcp_server_info_to_proto;
705
706                let proto_servers = servers.into_iter().map(mcp_server_info_to_proto).collect();
707
708                Ok(Response::new(GetMcpServersResponse {
709                    servers: proto_servers,
710                }))
711            }
712            Err(e) => {
713                error!("Failed to get MCP server statuses: {}", e);
714                Err(Status::internal(format!(
715                    "Failed to get MCP server statuses: {e}"
716                )))
717            }
718        }
719    }
720    async fn list_providers(
721        &self,
722        _request: Request<ListProvidersRequest>,
723    ) -> Result<Response<ListProvidersResponse>, Status> {
724        // Use the injected provider registry instance
725        let providers = self
726            .provider_registry
727            .all()
728            .map(|p| proto::ProviderInfo {
729                id: p.id.storage_key(),
730                name: p.name.clone(),
731                auth_schemes: p
732                    .auth_schemes
733                    .iter()
734                    .map(|s| match s {
735                        steer_core::config::toml_types::AuthScheme::ApiKey => {
736                            proto::ProviderAuthScheme::AuthSchemeApiKey as i32
737                        }
738                        steer_core::config::toml_types::AuthScheme::Oauth2 => {
739                            proto::ProviderAuthScheme::AuthSchemeOauth2 as i32
740                        }
741                    })
742                    .collect(),
743            })
744            .collect();
745
746        Ok(Response::new(ListProvidersResponse { providers }))
747    }
748
749    async fn list_models(
750        &self,
751        request: Request<ListModelsRequest>,
752    ) -> Result<Response<ListModelsResponse>, Status> {
753        let req = request.into_inner();
754
755        // Use the injected model registry
756        let model_registry = &self.model_registry;
757
758        // Get only recommended models from the registry
759        let all_models: Vec<proto::ProviderModel> = model_registry
760            .recommended() // Only recommended models
761            .filter(|m| {
762                if let Some(ref provider_id) = req.provider_id {
763                    m.provider.storage_key() == *provider_id
764                } else {
765                    true
766                }
767            })
768            .map(|m| proto::ProviderModel {
769                provider_id: m.provider.storage_key(),
770                model_id: m.id.clone(),
771                display_name: m.display_name.clone().unwrap_or_else(|| m.id.clone()),
772                supports_thinking: m
773                    .parameters
774                    .as_ref()
775                    .and_then(|p| p.thinking_config.as_ref())
776                    .map(|tc| tc.enabled)
777                    .unwrap_or(false),
778                aliases: m.aliases.clone(),
779            })
780            .collect();
781
782        Ok(Response::new(ListModelsResponse { models: all_models }))
783    }
784
785    async fn get_provider_auth_status(
786        &self,
787        request: Request<proto::GetProviderAuthStatusRequest>,
788    ) -> Result<Response<proto::GetProviderAuthStatusResponse>, Status> {
789        let req = request.into_inner();
790
791        let mut statuses = Vec::new();
792        for p in self.provider_registry.all() {
793            if let Some(ref filter) = req.provider_id {
794                if &p.id.storage_key() != filter {
795                    continue;
796                }
797            }
798            let status = match self
799                .llm_config_provider
800                .get_auth_for_provider(&p.id)
801                .await
802                .map_err(|e| Status::internal(format!("auth lookup failed: {e}")))?
803            {
804                Some(steer_core::config::ApiAuth::OAuth) => {
805                    proto::provider_auth_status::Status::AuthStatusOauth as i32
806                }
807                Some(steer_core::config::ApiAuth::Key(_)) => {
808                    proto::provider_auth_status::Status::AuthStatusApiKey as i32
809                }
810                None => proto::provider_auth_status::Status::AuthStatusNone as i32,
811            };
812            statuses.push(proto::ProviderAuthStatus {
813                provider_id: p.id.storage_key(),
814                status,
815            });
816        }
817
818        Ok(Response::new(proto::GetProviderAuthStatusResponse {
819            statuses,
820        }))
821    }
822
823    async fn resolve_model(
824        &self,
825        request: Request<proto::ResolveModelRequest>,
826    ) -> Result<Response<proto::ResolveModelResponse>, Status> {
827        let req = request.into_inner();
828
829        // Use the injected model registry to resolve the input
830        match self.model_registry.resolve(&req.input) {
831            Ok(model_id) => {
832                let model_spec = proto::ModelSpec {
833                    provider_id: model_id.0.storage_key(),
834                    model_id: model_id.1,
835                };
836                Ok(Response::new(proto::ResolveModelResponse {
837                    model: Some(model_spec),
838                }))
839            }
840            Err(e) => Err(Status::not_found(format!(
841                "Failed to resolve model '{}': {}",
842                req.input, e
843            ))),
844        }
845    }
846}
847
848async fn try_resume_session(
849    session_manager: &SessionManager,
850    session_id: &str,
851    llm_config_provider: &steer_core::config::LlmConfigProvider,
852    model_registry: &Arc<steer_core::model_registry::ModelRegistry>,
853    provider_registry: &Arc<steer_core::auth::ProviderRegistry>,
854) -> Result<(), Status> {
855    let app_config = steer_core::app::AppConfig {
856        llm_config_provider: llm_config_provider.clone(),
857        model_registry: model_registry.clone(),
858        provider_registry: provider_registry.clone(),
859    };
860
861    // Attempt to resume the session
862    match session_manager.resume_session(session_id, app_config).await {
863        Ok(_command_tx) => {
864            info!("Successfully resumed session: {}", session_id);
865            // TUI will call GetCurrentConversation when it connects
866            Ok(())
867        }
868        Err(steer_core::error::Error::SessionManager(
869            steer_core::session::manager::SessionManagerError::CapacityExceeded { current, max },
870        )) => {
871            warn!(
872                "Cannot resume session {}: server at capacity ({}/{})",
873                session_id, current, max
874            );
875            Err(Status::resource_exhausted(format!(
876                "Server at maximum capacity ({current}/{max}). Cannot resume session."
877            )))
878        }
879        Err(e) => {
880            error!("Failed to resume session {}: {}", session_id, e);
881            Err(Status::internal(format!("Failed to resume session: {e}")))
882        }
883    }
884}
885
886async fn handle_client_message(
887    session_manager: &SessionManager,
888    client_message: StreamSessionRequest,
889) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
890    debug!(
891        "Handling client message for session: {}",
892        client_message.session_id
893    );
894
895    if let Some(message) = client_message.message {
896        match message {
897            stream_session_request::Message::SendMessage(send_msg) => {
898                // Convert to AppCommand - just process user input since that's what exists
899                let app_command = steer_core::app::AppCommand::ProcessUserInput(send_msg.message);
900
901                session_manager
902                    .send_command(&client_message.session_id, app_command)
903                    .await
904                    .map_err(|e| format!("Failed to send message: {e}"))?;
905            }
906
907            stream_session_request::Message::ToolApproval(approval) => {
908                // Convert approval decision using existing HandleToolResponse
909                let approval_type = match approval.decision {
910                    Some(decision) => match decision.decision_type {
911                        Some(proto::approval_decision::DecisionType::Deny(_)) => {
912                            steer_core::app::command::ApprovalType::Denied
913                        }
914                        Some(proto::approval_decision::DecisionType::Once(_)) => {
915                            steer_core::app::command::ApprovalType::Once
916                        }
917                        Some(proto::approval_decision::DecisionType::AlwaysTool(_)) => {
918                            steer_core::app::command::ApprovalType::AlwaysTool
919                        }
920                        Some(proto::approval_decision::DecisionType::AlwaysBashPattern(
921                            pattern,
922                        )) => steer_core::app::command::ApprovalType::AlwaysBashPattern(pattern),
923                        None => {
924                            return Err(
925                                "Invalid approval decision: no decision type specified".into()
926                            );
927                        }
928                    },
929                    None => {
930                        return Err("Invalid approval decision: no decision provided".into());
931                    }
932                };
933
934                let app_command = steer_core::app::AppCommand::HandleToolResponse {
935                    id: approval.tool_call_id,
936                    approval: approval_type,
937                };
938
939                session_manager
940                    .send_command(&client_message.session_id, app_command)
941                    .await
942                    .map_err(|e| format!("Failed to approve tool: {e}"))?;
943            }
944
945            stream_session_request::Message::Cancel(_cancel) => {
946                // Use existing CancelProcessing command
947                let app_command = steer_core::app::AppCommand::CancelProcessing;
948
949                session_manager
950                    .send_command(&client_message.session_id, app_command)
951                    .await
952                    .map_err(|e| format!("Failed to cancel operation: {e}"))?;
953            }
954
955            stream_session_request::Message::Subscribe(_subscribe_request) => {
956                debug!("Subscribe message received - stream already established");
957                // No action needed - stream is already active
958            }
959
960            stream_session_request::Message::UpdateConfig(_update_config) => {
961                // UpdateConfig no longer supports changing the LLM provider
962                // Tool config updates are handled separately
963                debug!("UpdateConfig received but provider changes are no longer supported");
964            }
965
966            stream_session_request::Message::ExecuteCommand(execute_command) => {
967                use steer_core::app::conversation::AppCommandType;
968                let app_cmd_type = match AppCommandType::parse(&execute_command.command) {
969                    Ok(cmd) => cmd,
970                    Err(e) => {
971                        return Err(format!("Failed to parse command: {e}").into());
972                    }
973                };
974                let app_command = steer_core::app::AppCommand::ExecuteCommand(app_cmd_type);
975                session_manager
976                    .send_command(&client_message.session_id, app_command)
977                    .await
978                    .map_err(|e| format!("Failed to execute command: {e}"))?;
979            }
980
981            stream_session_request::Message::ExecuteBashCommand(execute_bash_command) => {
982                let app_command = steer_core::app::AppCommand::ExecuteBashCommand {
983                    command: execute_bash_command.command,
984                };
985                session_manager
986                    .send_command(&client_message.session_id, app_command)
987                    .await
988                    .map_err(|e| format!("Failed to execute bash command: {e}"))?;
989            }
990
991            stream_session_request::Message::EditMessage(edit_message) => {
992                let app_command = steer_core::app::AppCommand::EditMessage {
993                    message_id: edit_message.message_id,
994                    new_content: edit_message.new_content,
995                };
996                session_manager
997                    .send_command(&client_message.session_id, app_command)
998                    .await
999                    .map_err(|e| format!("Failed to edit message: {e}"))?;
1000            }
1001        }
1002    }
1003
1004    Ok(())
1005}
1006
// Tests for session lifecycle handling in this gRPC service: suspension driven by the
// subscriber count, and auto-resume when a client reconnects via `stream_session`.
#[cfg(test)]
mod tests {
    use super::*;

    use std::collections::HashMap;
    use steer_core::session::state::WorkspaceConfig;
    use steer_core::session::stores::sqlite::SqliteSessionStore;
    use steer_core::session::{SessionConfig, SessionManagerConfig, SessionToolConfig};
    use steer_proto::agent::v1::agent_service_client::AgentServiceClient;
    use steer_proto::agent::v1::{SendMessageRequest, SubscribeRequest};
    use tempfile::TempDir;
    use tokio::sync::mpsc;
    use tokio_stream::StreamExt;

    // Shared `AppConfig` fixture supplied by `steer_core::test_utils`.
    fn create_test_app_config() -> steer_core::app::AppConfig {
        steer_core::test_utils::test_app_config()
    }

    // Builds a `SessionManager` backed by a fresh SQLite database at `<tempdir>/test.db`.
    // The `TempDir` is returned so the caller keeps the directory (and database) alive;
    // `tempfile::TempDir` removes the directory when it is dropped.
    async fn create_test_session_manager() -> (Arc<SessionManager>, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let store = Arc::new(SqliteSessionStore::new(&db_path).await.unwrap());

        let config = SessionManagerConfig {
            max_concurrent_sessions: 100,
            default_model: steer_core::config::model::builtin::claude_3_7_sonnet_20250219(),
            auto_persist: true,
        };
        let session_manager = Arc::new(SessionManager::new(store, config));

        (session_manager, temp_dir)
    }

    // Starts an `AgentServiceImpl` gRPC server on an OS-assigned 127.0.0.1 port.
    // Returns the server's `http://` URL, the `SessionManager` the server uses (so tests can
    // inspect session state directly), and the backing `TempDir`.
    // Auth uses in-memory storage, and both registries are built with `load(&[])`.
    // NOTE(review): `load(&[])` presumably means built-in definitions only, with no user
    // overrides — confirm.
    async fn create_test_server() -> (String, Arc<SessionManager>, TempDir) {
        let (session_manager, temp_dir) = create_test_session_manager().await;

        let auth_storage = Arc::new(steer_core::test_utils::InMemoryAuthStorage::new());
        let llm_config_provider = steer_core::config::LlmConfigProvider::new(auth_storage);
        let model_registry =
            Arc::new(steer_core::model_registry::ModelRegistry::load(&[]).unwrap());
        let provider_registry = Arc::new(steer_core::auth::ProviderRegistry::load(&[]).unwrap());
        let service = AgentServiceImpl::new(
            session_manager.clone(),
            llm_config_provider,
            model_registry,
            provider_registry,
        );

        // Start server on random port
        // (binding port 0 lets the OS pick a free port; `local_addr` reports which one).
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // The JoinHandle is dropped when this function returns, which detaches the task; the
        // server keeps running until the test's runtime shuts down.
        let _server_task = tokio::spawn(async move {
            tonic::transport::Server::builder()
                .add_service(agent_service_server::AgentServiceServer::new(service))
                .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
                .await
                .unwrap();
        });

        // Give server time to start
        // (a fixed delay, not a readiness check).
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        let url = format!("http://{addr}");
        (url, session_manager, temp_dir)
    }

    // Drives `SessionManager` directly, without gRPC. After one subscriber joins and leaves,
    // `maybe_suspend_idle_session` should evict the session from memory while keeping it in
    // storage.
    #[tokio::test]
    async fn test_session_cleanup_on_disconnect() {
        let (session_manager, _temp_dir) = create_test_session_manager().await;

        // Create a session
        let session_config = SessionConfig {
            workspace: WorkspaceConfig::Local {
                path: std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
            },
            tool_config: SessionToolConfig::default(),
            system_prompt: None,
            metadata: HashMap::new(),
        };

        let app_config = create_test_app_config();

        let (session_id, _command_tx) = session_manager
            .create_session(session_config, app_config)
            .await
            .unwrap();

        // Verify session is active
        assert!(session_manager.is_session_active(&session_id).await);

        // Simulate a client connection by incrementing subscriber count
        session_manager
            .increment_subscriber_count(&session_id)
            .await
            .unwrap();

        // Verify session is still active
        assert!(session_manager.is_session_active(&session_id).await);

        // Simulate client disconnect by decrementing subscriber count
        session_manager
            .decrement_subscriber_count(&session_id)
            .await
            .unwrap();

        // Check if session should be suspended
        session_manager
            .maybe_suspend_idle_session(&session_id)
            .await
            .unwrap();

        // Verify session was suspended (not active in memory)
        assert!(
            !session_manager.is_session_active(&session_id).await,
            "Session should be suspended after last client disconnects"
        );

        // Verify session still exists in storage
        let session_info = session_manager.get_session(&session_id).await.unwrap();
        assert!(
            session_info.is_some(),
            "Session should still exist in storage after suspension"
        );
    }

    // With two subscribers, the session must stay active after the first one leaves. It should
    // be suspended only once the second one leaves.
    #[tokio::test]
    async fn test_session_with_multiple_subscribers() {
        let (session_manager, _temp_dir) = create_test_session_manager().await;

        // Create a session
        let session_config = SessionConfig {
            workspace: WorkspaceConfig::Local {
                path: std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
            },
            tool_config: SessionToolConfig::default(),
            system_prompt: None,
            metadata: HashMap::new(),
        };

        let app_config = create_test_app_config();

        let (session_id, _command_tx) = session_manager
            .create_session(session_config, app_config)
            .await
            .unwrap();

        // Simulate two clients connecting
        session_manager
            .increment_subscriber_count(&session_id)
            .await
            .unwrap();
        session_manager
            .increment_subscriber_count(&session_id)
            .await
            .unwrap();

        // First client disconnects
        session_manager
            .decrement_subscriber_count(&session_id)
            .await
            .unwrap();
        session_manager
            .maybe_suspend_idle_session(&session_id)
            .await
            .unwrap();

        // Session should still be active (one subscriber remaining)
        assert!(
            session_manager.is_session_active(&session_id).await,
            "Session should remain active with one subscriber"
        );

        // Second client disconnects
        session_manager
            .decrement_subscriber_count(&session_id)
            .await
            .unwrap();
        session_manager
            .maybe_suspend_idle_session(&session_id)
            .await
            .unwrap();

        // Now session should be suspended
        assert!(
            !session_manager.is_session_active(&session_id).await,
            "Session should be suspended after all clients disconnect"
        );
    }

    // End-to-end over gRPC: a client subscribes and sends a message via `stream_session`.
    // After the client drops its streams, the server should suspend the session while keeping
    // it in storage.
    #[tokio::test]
    async fn test_grpc_client_connect_disconnect_cleanup() {
        let (server_url, session_manager, _temp_dir) = create_test_server().await;

        // Create a session first
        let session_config = SessionConfig {
            workspace: WorkspaceConfig::Local {
                path: std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
            },
            tool_config: SessionToolConfig::default(),
            system_prompt: None,
            metadata: HashMap::new(),
        };

        let app_config = create_test_app_config();

        let (session_id, _command_tx) = session_manager
            .create_session(session_config, app_config)
            .await
            .unwrap();

        // Verify session is active
        assert!(session_manager.is_session_active(&session_id).await);

        // Connect client
        let mut client = AgentServiceClient::connect(server_url.clone())
            .await
            .unwrap();

        // Start streaming with subscribe message
        // (a one-item iterator, so the client's request side ends right after Subscribe).
        let request_stream = tokio_stream::iter(vec![StreamSessionRequest {
            session_id: session_id.clone(),
            message: Some(stream_session_request::Message::Subscribe(
                SubscribeRequest {
                    event_types: vec![],
                    since_sequence: None,
                },
            )),
        }]);

        let response = client.stream_session(request_stream).await.unwrap();
        // NOTE(review): `_stream` stays bound until the end of the test and is never dropped
        // explicitly. The final "suspended" assertion therefore relies on the server not
        // counting this first subscription once its request side has ended — confirm against
        // `stream_session`.
        let _stream = response.into_inner();

        // Send a test message to verify session is working
        // (queued on the channel before the second `stream_session` call consumes it).
        let (msg_tx, msg_rx) = mpsc::channel(10);
        msg_tx
            .send(StreamSessionRequest {
                session_id: session_id.clone(),
                message: Some(stream_session_request::Message::SendMessage(
                    SendMessageRequest {
                        session_id: session_id.clone(),
                        message: "Hello, test!".to_string(),
                        attachments: vec![],
                    },
                )),
            })
            .await
            .unwrap();

        // Create new request stream with the message channel
        let request_stream = tokio_stream::wrappers::ReceiverStream::new(msg_rx);
        let response = client.stream_session(request_stream).await.unwrap();
        let mut stream = response.into_inner();

        // Wait for some response to verify session is working
        let timeout =
            tokio::time::timeout(tokio::time::Duration::from_secs(5), stream.next()).await;

        // NOTE(review): this only checks that `next()` completed within 5s; it also passes if
        // the stream ended (`None`) or yielded an `Err`.
        assert!(timeout.is_ok(), "Should receive at least one event");

        // Drop the stream to simulate client disconnect
        // (dropping `msg_tx` also closes the request side of the second stream).
        drop(stream);
        drop(msg_tx);

        // Give the server time to process the disconnect
        // (fixed delay — timing-dependent).
        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;

        // Verify session was suspended (not active in memory)
        assert!(
            !session_manager.is_session_active(&session_id).await,
            "Session should be suspended after client disconnect"
        );

        // Verify session still exists in storage
        let session_info = session_manager.get_session(&session_id).await.unwrap();
        assert!(
            session_info.is_some(),
            "Session should still exist in storage"
        );
    }

    // A manually suspended session should be resumed when a client opens `stream_session` with
    // a Subscribe. It should stay active while that client is connected, and be suspended
    // again after the client disconnects.
    #[tokio::test]
    async fn test_grpc_basic_session_resume() {
        let (server_url, session_manager, _temp_dir) = create_test_server().await;

        // Create a session
        let session_config = SessionConfig {
            workspace: WorkspaceConfig::Local {
                path: std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
            },
            tool_config: SessionToolConfig::default(),
            system_prompt: None,
            metadata: HashMap::new(),
        };

        let app_config = create_test_app_config();

        let (session_id, _command_tx) = session_manager
            .create_session(session_config, app_config)
            .await
            .unwrap();

        // Suspend the session manually to simulate a disconnected state
        session_manager.suspend_session(&session_id).await.unwrap();
        assert!(
            !session_manager.is_session_active(&session_id).await,
            "Session should be suspended"
        );

        // Try to reconnect - this should auto-resume the session
        let mut client = AgentServiceClient::connect(server_url.clone())
            .await
            .unwrap();

        // Use a channel to keep the stream alive
        // (`msg_tx` is held until the explicit drop below, so the request side stays open).
        let (msg_tx, msg_rx) = mpsc::channel(10);

        // Send initial subscribe message
        msg_tx
            .send(StreamSessionRequest {
                session_id: session_id.clone(),
                message: Some(stream_session_request::Message::Subscribe(
                    SubscribeRequest {
                        event_types: vec![],
                        since_sequence: None,
                    },
                )),
            })
            .await
            .unwrap();

        let request_stream = tokio_stream::wrappers::ReceiverStream::new(msg_rx);
        let response = client.stream_session(request_stream).await;

        // The connection should succeed (auto-resume should work)
        assert!(
            response.is_ok(),
            "Should be able to connect to suspended session (auto-resume)"
        );

        let stream = response.unwrap().into_inner();

        // Give time for auto-resume to complete
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        // Session should be active again after auto-resume
        assert!(
            session_manager.is_session_active(&session_id).await,
            "Session should be active after auto-resume"
        );

        // Keep the stream alive a bit longer to ensure it stays active
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
        assert!(
            session_manager.is_session_active(&session_id).await,
            "Session should remain active while client is connected"
        );

        // Clean up - drop the stream to disconnect
        drop(stream);
        drop(msg_tx);

        // Give time for cleanup to run
        // (fixed delay — timing-dependent).
        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;

        // Now session should be suspended again after disconnect
        assert!(
            !session_manager.is_session_active(&session_id).await,
            "Session should be suspended after client disconnects"
        );
    }
}