1use async_trait::async_trait;
2use steer_core::error::Result;
3use tokio::sync::{Mutex, mpsc};
4use tokio::task::JoinHandle;
5use tokio_stream::wrappers::ReceiverStream;
6use tonic::Request;
7use tonic::transport::Channel;
8use tracing::{debug, error, info, warn};
9
10use crate::grpc::conversions::{
11 convert_app_command_to_client_message, proto_to_mcp_server_info, proto_to_message,
12 server_event_to_app_event, session_tool_config_to_proto, tool_approval_policy_to_proto,
13 workspace_config_to_proto,
14};
15use crate::grpc::error::GrpcError;
16
17type GrpcResult<T> = std::result::Result<T, GrpcError>;
18
19use steer_core::app::conversation::Message;
20use steer_core::app::io::{AppCommandSink, AppEventSource};
21use steer_core::app::{AppCommand, AppEvent};
22use steer_core::session::{McpServerInfo, SessionConfig};
23use steer_proto::agent::v1::{
24 self as proto, CreateSessionRequest, DeleteSessionRequest, GetConversationRequest,
25 GetMcpServersRequest, GetSessionRequest, ListSessionsRequest, SessionInfo, SessionState,
26 StreamSessionRequest, SubscribeRequest, agent_service_client::AgentServiceClient,
27 stream_session_request::Message as StreamSessionRequestType,
28};
29
/// gRPC client for the agent service that tracks one current session and, once
/// `start_streaming` has been called, a bidirectional event stream for it.
///
/// All state sits behind `tokio::sync::Mutex` so every method can take `&self`.
pub struct AgentClient {
    /// Generated service stub; locked while each RPC is issued.
    client: Mutex<AgentServiceClient<Channel>>,
    /// Current session ID, set by `create_session` or `activate_session`.
    session_id: Mutex<Option<String>>,
    /// Sender feeding the outbound half of the session stream; set by `start_streaming`.
    command_tx: Mutex<Option<mpsc::Sender<StreamSessionRequest>>>,
    /// Receiver of converted server events; set by `start_streaming` and handed out
    /// (taken) by `AppEventSource::subscribe`.
    event_rx: Mutex<Option<mpsc::Receiver<AppEvent>>>,
    /// Background task forwarding inbound server events; aborted by `shutdown`.
    stream_handle: Mutex<Option<JoinHandle<()>>>,
}
38
39impl AgentClient {
40 pub async fn connect(addr: &str) -> GrpcResult<Self> {
42 info!("Connecting to gRPC server at {}", addr);
43
44 let client = AgentServiceClient::connect(addr.to_string()).await?;
45
46 info!("Successfully connected to gRPC server");
47
48 Ok(Self {
49 client: Mutex::new(client),
50 session_id: Mutex::new(None),
51 command_tx: Mutex::new(None),
52 stream_handle: Mutex::new(None),
53 event_rx: Mutex::new(None),
54 })
55 }
56
57 pub async fn from_channel(channel: Channel) -> GrpcResult<Self> {
59 info!("Creating gRPC client from provided channel");
60
61 let client = AgentServiceClient::new(channel);
62
63 Ok(Self {
64 client: Mutex::new(client),
65 session_id: Mutex::new(None),
66 command_tx: Mutex::new(None),
67 stream_handle: Mutex::new(None),
68 event_rx: Mutex::new(None),
69 })
70 }
71
72 pub async fn local(default_model: steer_core::config::model::ModelId) -> GrpcResult<Self> {
74 use crate::local_server::setup_local_grpc;
75 let (channel, _server_handle) = setup_local_grpc(default_model, None).await?;
76 Self::from_channel(channel).await
77 }
78
79 pub async fn create_session(&self, config: SessionConfig) -> GrpcResult<String> {
81 debug!("Creating new session with gRPC server");
82
83 let tool_policy = tool_approval_policy_to_proto(&config.tool_config.approval_policy);
84 let workspace_config = workspace_config_to_proto(&config.workspace);
85 let tool_config = session_tool_config_to_proto(&config.tool_config);
86
87 let request = Request::new(CreateSessionRequest {
88 tool_policy: Some(tool_policy),
89 metadata: config.metadata,
90 tool_config: Some(tool_config),
91 workspace_config: Some(workspace_config),
92 system_prompt: config.system_prompt,
93 });
94
95 let response = self
96 .client
97 .lock()
98 .await
99 .create_session(request)
100 .await
101 .map_err(Box::new)?;
102 let response = response.into_inner();
103 let session = response
104 .session
105 .ok_or_else(|| Box::new(tonic::Status::internal("No session info in response")))?;
106
107 *self.session_id.lock().await = Some(session.id.clone());
108
109 info!("Created session: {}", session.id);
110 Ok(session.id)
111 }
112
113 pub async fn activate_session(
115 &self,
116 session_id: String,
117 ) -> GrpcResult<(Vec<Message>, Vec<String>)> {
118 info!("Activating remote session: {}", session_id);
119
120 let mut stream = self
121 .client
122 .lock()
123 .await
124 .activate_session(proto::ActivateSessionRequest {
125 session_id: session_id.clone(),
126 })
127 .await
128 .map_err(Box::new)?
129 .into_inner();
130
131 let mut messages = Vec::new();
132 let mut approved_tools = Vec::new();
133
134 while let Some(response) = stream
135 .message()
136 .await
137 .map_err(|e| GrpcError::CallFailed(Box::new(e)))?
138 {
139 match response.chunk {
140 Some(proto::activate_session_response::Chunk::Message(proto_msg)) => {
141 match proto_to_message(proto_msg) {
142 Ok(msg) => messages.push(msg),
143 Err(e) => return Err(GrpcError::ConversionError(e)),
144 }
145 }
146 Some(proto::activate_session_response::Chunk::Footer(footer)) => {
147 approved_tools = footer.approved_tools;
148 }
149 None => {}
150 }
151 }
152
153 *self.session_id.lock().await = Some(session_id);
154 Ok((messages, approved_tools))
155 }
156
157 pub async fn start_streaming(&self) -> GrpcResult<()> {
159 let session_id = self
160 .session_id
161 .lock()
162 .await
163 .as_ref()
164 .cloned()
165 .ok_or_else(|| GrpcError::InvalidSessionState {
166 reason: "No session ID - call create_session or activate_session first".to_string(),
167 })?;
168
169 debug!("Starting bidirectional stream for session: {}", session_id);
170
171 let (cmd_tx, cmd_rx) = mpsc::channel::<StreamSessionRequest>(32);
173 let (evt_tx, evt_rx) = mpsc::channel::<AppEvent>(100);
174
175 let outbound_stream = ReceiverStream::new(cmd_rx);
177 let request = Request::new(outbound_stream);
178
179 let response = self
180 .client
181 .lock()
182 .await
183 .stream_session(request)
184 .await
185 .map_err(Box::new)?;
186 let mut inbound_stream = response.into_inner();
187
188 let subscribe_msg = StreamSessionRequest {
190 session_id: session_id.clone(),
191 message: Some(StreamSessionRequestType::Subscribe(SubscribeRequest {
192 event_types: vec![], since_sequence: None,
194 })),
195 };
196
197 cmd_tx
198 .send(subscribe_msg)
199 .await
200 .map_err(|_| GrpcError::StreamError("Failed to send subscribe message".to_string()))?;
201
202 let session_id_clone = session_id.clone();
204 let stream_handle = tokio::spawn(async move {
205 info!(
206 "Started event stream handler for session: {}",
207 session_id_clone
208 );
209
210 while let Some(result) = inbound_stream.message().await.transpose() {
211 match result {
212 Ok(server_event) => {
213 debug!(
214 "Received server event: sequence {}",
215 server_event.sequence_num
216 );
217
218 match server_event_to_app_event(server_event) {
219 Ok(app_event) => {
220 if let Err(e) = evt_tx.send(app_event).await {
221 warn!("Failed to forward event to TUI: {}", e);
222 break;
223 }
224 }
225 Err(e) => {
226 error!("Failed to convert server event: {}", e);
227 }
229 }
230 }
231 Err(e) => {
232 error!("gRPC stream error: {}", e);
233 break;
234 }
235 }
236 }
237
238 info!(
239 "Event stream handler ended for session: {}",
240 session_id_clone
241 );
242 });
243
244 *self.command_tx.lock().await = Some(cmd_tx);
246 *self.stream_handle.lock().await = Some(stream_handle);
247 *self.event_rx.lock().await = Some(evt_rx);
249
250 info!(
251 "Bidirectional streaming started for session: {}",
252 session_id
253 );
254 Ok(())
255 }
256
257 pub async fn send_command(&self, command: AppCommand) -> GrpcResult<()> {
259 let session_id = self
260 .session_id
261 .lock()
262 .await
263 .as_ref()
264 .cloned()
265 .ok_or_else(|| GrpcError::InvalidSessionState {
266 reason: "No active session".to_string(),
267 })?;
268
269 let command_tx = self
270 .command_tx
271 .lock()
272 .await
273 .as_ref()
274 .cloned()
275 .ok_or_else(|| GrpcError::InvalidSessionState {
276 reason: "Streaming not started - call start_streaming first".to_string(),
277 })?;
278
279 let message = convert_app_command_to_client_message(command, &session_id)?;
280
281 if let Some(message) = message {
282 command_tx.send(message).await.map_err(|_| {
283 GrpcError::StreamError("Failed to send command - stream may be closed".to_string())
284 })?;
285 }
286
287 Ok(())
288 }
289
290 pub async fn session_id(&self) -> Option<String> {
292 self.session_id.lock().await.clone()
293 }
294
295 pub async fn list_sessions(&self) -> GrpcResult<Vec<SessionInfo>> {
297 debug!("Listing sessions from gRPC server");
298
299 let request = Request::new(ListSessionsRequest {
300 filter: None,
301 page_size: None,
302 page_token: None,
303 });
304
305 let response = self
306 .client
307 .lock()
308 .await
309 .list_sessions(request)
310 .await
311 .map_err(Box::new)?;
312 let sessions_response = response.into_inner();
313
314 Ok(sessions_response.sessions)
315 }
316
317 pub async fn get_session(&self, session_id: &str) -> GrpcResult<Option<SessionState>> {
319 debug!("Getting session {} from gRPC server", session_id);
320
321 let request = Request::new(GetSessionRequest {
322 session_id: session_id.to_string(),
323 });
324
325 let mut stream = self
326 .client
327 .lock()
328 .await
329 .get_session(request)
330 .await
331 .map_err(|e| GrpcError::CallFailed(Box::new(e)))?
332 .into_inner();
333
334 let mut header = None;
335 let mut messages = Vec::new();
336 let mut tool_calls = std::collections::HashMap::new();
337 let mut footer = None;
338
339 while let Some(response) = stream
340 .message()
341 .await
342 .map_err(|e| GrpcError::CallFailed(Box::new(e)))?
343 {
344 match response.chunk {
345 Some(proto::get_session_response::Chunk::Header(h)) => header = Some(h),
346 Some(proto::get_session_response::Chunk::Message(m)) => messages.push(m),
347 Some(proto::get_session_response::Chunk::ToolCall(tc)) => {
348 if let Some(value) = tc.value {
349 tool_calls.insert(tc.key, value);
350 }
351 }
352 Some(proto::get_session_response::Chunk::Footer(f)) => footer = Some(f),
353 None => {}
354 }
355 }
356
357 match (header, footer) {
358 (Some(h), Some(f)) => Ok(Some(SessionState {
359 id: h.id,
360 created_at: h.created_at,
361 updated_at: h.updated_at,
362 config: h.config,
363 messages,
364 tool_calls,
365 approved_tools: f.approved_tools,
366 last_event_sequence: f.last_event_sequence,
367 metadata: f.metadata,
368 })),
369 _ => Ok(None),
370 }
371 }
372
373 pub async fn delete_session(&self, session_id: &str) -> GrpcResult<bool> {
375 debug!("Deleting session {} from gRPC server", session_id);
376
377 let request = Request::new(DeleteSessionRequest {
378 session_id: session_id.to_string(),
379 });
380
381 match self.client.lock().await.delete_session(request).await {
382 Ok(_) => {
383 info!("Successfully deleted session: {}", session_id);
384 Ok(true)
385 }
386 Err(status) if status.code() == tonic::Code::NotFound => Ok(false),
387 Err(e) => Err(GrpcError::CallFailed(Box::new(e))),
388 }
389 }
390
391 pub async fn get_conversation(
393 &self,
394 session_id: &str,
395 ) -> GrpcResult<(Vec<Message>, Vec<String>)> {
396 info!(
397 "Client adapter getting conversation for session: {}",
398 session_id
399 );
400
401 let mut stream = self
402 .client
403 .lock()
404 .await
405 .get_conversation(GetConversationRequest {
406 session_id: session_id.to_string(),
407 })
408 .await
409 .map_err(Box::new)?
410 .into_inner();
411
412 let mut messages = Vec::new();
413 let mut approved_tools = Vec::new();
414
415 while let Some(response) = stream
416 .message()
417 .await
418 .map_err(|e| GrpcError::CallFailed(Box::new(e)))?
419 {
420 match response.chunk {
421 Some(proto::get_conversation_response::Chunk::Message(proto_msg)) => {
422 match proto_to_message(proto_msg) {
423 Ok(msg) => messages.push(msg),
424 Err(e) => {
425 warn!("Failed to convert message: {}", e);
426 return Err(GrpcError::ConversionError(e));
427 }
428 }
429 }
430 Some(proto::get_conversation_response::Chunk::Footer(footer)) => {
431 approved_tools = footer.approved_tools;
432 }
433 None => {}
434 }
435 }
436
437 info!(
438 "Successfully converted {} messages from GetConversation response",
439 messages.len()
440 );
441
442 Ok((messages, approved_tools))
443 }
444
445 pub async fn shutdown(self) {
447 if let Some(handle) = self.stream_handle.lock().await.take() {
448 handle.abort();
449 let _ = handle.await;
450 }
451
452 if let Some(session_id) = &*self.session_id.lock().await {
453 info!("GrpcClientAdapter shut down for session: {}", session_id);
454 }
455 }
456
457 pub async fn get_mcp_servers(&self) -> GrpcResult<Vec<McpServerInfo>> {
458 let session_id = self
459 .session_id
460 .lock()
461 .await
462 .as_ref()
463 .cloned()
464 .ok_or_else(|| GrpcError::InvalidSessionState {
465 reason: "No active session".to_string(),
466 })?;
467
468 let request = Request::new(GetMcpServersRequest {
469 session_id: session_id.clone(),
470 });
471
472 let response = self
473 .client
474 .lock()
475 .await
476 .get_mcp_servers(request)
477 .await
478 .map_err(Box::new)?;
479
480 let servers = response
481 .into_inner()
482 .servers
483 .into_iter()
484 .filter_map(|s| proto_to_mcp_server_info(s).ok())
485 .collect();
486
487 Ok(servers)
488 }
489
490 pub async fn resolve_model(
492 &self,
493 input: &str,
494 ) -> GrpcResult<steer_core::config::model::ModelId> {
495 let request = Request::new(proto::ResolveModelRequest {
496 input: input.to_string(),
497 });
498
499 let response = self
500 .client
501 .lock()
502 .await
503 .resolve_model(request)
504 .await
505 .map_err(Box::new)?;
506
507 let inner = response.into_inner();
508 let model_spec = inner.model.ok_or_else(|| GrpcError::InvalidSessionState {
509 reason: format!("Server returned no model for input '{input}'"),
510 })?;
511
512 let provider_id: steer_core::config::provider::ProviderId =
515 serde_json::from_value(serde_json::Value::String(model_spec.provider_id.clone()))
516 .map_err(|_| GrpcError::InvalidSessionState {
517 reason: format!(
518 "Invalid provider ID from server: {}",
519 model_spec.provider_id
520 ),
521 })?;
522
523 Ok((provider_id, model_spec.model_id))
524 }
525
526 pub async fn list_providers(&self) -> GrpcResult<Vec<proto::ProviderInfo>> {
528 let request = Request::new(proto::ListProvidersRequest {});
529 let response = self
530 .client
531 .lock()
532 .await
533 .list_providers(request)
534 .await
535 .map_err(Box::new)?;
536 Ok(response.into_inner().providers)
537 }
538
539 pub async fn get_provider_auth_status(
541 &self,
542 provider_id: Option<String>,
543 ) -> GrpcResult<Vec<proto::ProviderAuthStatus>> {
544 let request = Request::new(proto::GetProviderAuthStatusRequest { provider_id });
545 let response = self
546 .client
547 .lock()
548 .await
549 .get_provider_auth_status(request)
550 .await
551 .map_err(Box::new)?;
552 Ok(response.into_inner().statuses)
553 }
554
555 pub async fn list_models(
557 &self,
558 provider_id: Option<String>,
559 ) -> GrpcResult<Vec<proto::ProviderModel>> {
560 let request = Request::new(proto::ListModelsRequest { provider_id });
561
562 let response = self
563 .client
564 .lock()
565 .await
566 .list_models(request)
567 .await
568 .map_err(Box::new)?;
569
570 Ok(response.into_inner().models)
571 }
572}
573
574#[async_trait]
575impl AppCommandSink for AgentClient {
576 async fn send_command(&self, command: AppCommand) -> Result<()> {
577 self.send_command(command)
578 .await
579 .map_err(|e| steer_core::error::Error::InvalidOperation(e.to_string()))
580 }
581}
582
583#[async_trait]
584impl AppEventSource for AgentClient {
585 async fn subscribe(&self) -> mpsc::Receiver<AppEvent> {
586 self.event_rx.lock().await.take().expect(
589 "Event receiver already taken - GrpcClientAdapter only supports single subscription",
590 )
591 }
592}
593
#[cfg(test)]
mod tests {
    use super::*;
    use crate::grpc::conversions::tool_approval_policy_to_proto;
    use steer_core::session::ToolApprovalPolicy;
    use steer_proto::agent::v1::tool_approval_policy::Policy;

    /// Each tool approval policy variant maps onto the matching proto variant.
    #[test]
    fn test_convert_tool_approval_policy() {
        let always_ask = tool_approval_policy_to_proto(&ToolApprovalPolicy::AlwaysAsk);
        assert!(matches!(always_ask.policy, Some(Policy::AlwaysAsk(_))));

        let pre_approved = tool_approval_policy_to_proto(&ToolApprovalPolicy::PreApproved {
            tools: std::collections::HashSet::from(["bash".to_string()]),
        });
        assert!(matches!(pre_approved.policy, Some(Policy::PreApproved(_))));
    }

    /// User input yields a client message; `Shutdown` yields no message.
    #[test]
    fn test_convert_app_command_to_client_message() {
        const SESSION: &str = "test-session";

        let user_input = convert_app_command_to_client_message(
            AppCommand::ProcessUserInput("Hello".to_string()),
            SESSION,
        )
        .unwrap();
        assert!(user_input.is_some());

        let shutdown =
            convert_app_command_to_client_message(AppCommand::Shutdown, SESSION).unwrap();
        assert!(shutdown.is_none());
    }
}