1use async_trait::async_trait;
2use steer_core::error::Result;
3use tokio::sync::{Mutex, mpsc};
4use tokio::task::JoinHandle;
5use tokio_stream::wrappers::ReceiverStream;
6use tonic::Request;
7use tonic::transport::Channel;
8use tracing::{debug, error, info, warn};
9
10use crate::grpc::conversions::{
11 convert_app_command_to_client_message, proto_to_mcp_server_info, proto_to_message,
12 server_event_to_app_event, session_tool_config_to_proto, tool_approval_policy_to_proto,
13 workspace_config_to_proto,
14};
15use crate::grpc::error::GrpcError;
16
17type GrpcResult<T> = std::result::Result<T, GrpcError>;
18
19use steer_core::app::conversation::Message;
20use steer_core::app::io::{AppCommandSink, AppEventSource};
21use steer_core::app::{AppCommand, AppEvent};
22use steer_core::session::{McpServerInfo, SessionConfig};
23use steer_proto::agent::v1::{
24 self as proto, CreateSessionRequest, DeleteSessionRequest, GetConversationRequest,
25 GetMcpServersRequest, GetSessionRequest, ListSessionsRequest, SessionInfo, SessionState,
26 StreamSessionRequest, SubscribeRequest, agent_service_client::AgentServiceClient,
27 stream_session_request::Message as StreamSessionRequestType,
28};
29
30pub struct AgentClient {
32 client: Mutex<AgentServiceClient<Channel>>,
33 session_id: Mutex<Option<String>>,
34 command_tx: Mutex<Option<mpsc::Sender<StreamSessionRequest>>>,
35 event_rx: Mutex<Option<mpsc::Receiver<AppEvent>>>,
36 stream_handle: Mutex<Option<JoinHandle<()>>>,
37}
38
39impl AgentClient {
40 pub async fn connect(addr: &str) -> GrpcResult<Self> {
42 info!("Connecting to gRPC server at {}", addr);
43
44 let client = AgentServiceClient::connect(addr.to_string()).await?;
45
46 info!("Successfully connected to gRPC server");
47
48 Ok(Self {
49 client: Mutex::new(client),
50 session_id: Mutex::new(None),
51 command_tx: Mutex::new(None),
52 stream_handle: Mutex::new(None),
53 event_rx: Mutex::new(None),
54 })
55 }
56
57 pub async fn from_channel(channel: Channel) -> GrpcResult<Self> {
59 info!("Creating gRPC client from provided channel");
60
61 let client = AgentServiceClient::new(channel);
62
63 Ok(Self {
64 client: Mutex::new(client),
65 session_id: Mutex::new(None),
66 command_tx: Mutex::new(None),
67 stream_handle: Mutex::new(None),
68 event_rx: Mutex::new(None),
69 })
70 }
71
72 pub async fn create_session(&self, config: SessionConfig) -> GrpcResult<String> {
74 debug!("Creating new session with gRPC server");
75
76 let tool_policy = tool_approval_policy_to_proto(&config.tool_config.approval_policy);
77 let workspace_config = workspace_config_to_proto(&config.workspace);
78 let tool_config = session_tool_config_to_proto(&config.tool_config);
79
80 let request = Request::new(CreateSessionRequest {
81 tool_policy: Some(tool_policy),
82 metadata: config.metadata,
83 tool_config: Some(tool_config),
84 workspace_config: Some(workspace_config),
85 system_prompt: config.system_prompt,
86 });
87
88 let response = self
89 .client
90 .lock()
91 .await
92 .create_session(request)
93 .await
94 .map_err(Box::new)?;
95 let response = response.into_inner();
96 let session = response
97 .session
98 .ok_or_else(|| Box::new(tonic::Status::internal("No session info in response")))?;
99
100 *self.session_id.lock().await = Some(session.id.clone());
101
102 info!("Created session: {}", session.id);
103 Ok(session.id)
104 }
105
106 pub async fn activate_session(
108 &self,
109 session_id: String,
110 ) -> GrpcResult<(Vec<Message>, Vec<String>)> {
111 info!("Activating remote session: {}", session_id);
112
113 let mut stream = self
114 .client
115 .lock()
116 .await
117 .activate_session(proto::ActivateSessionRequest {
118 session_id: session_id.clone(),
119 })
120 .await
121 .map_err(Box::new)?
122 .into_inner();
123
124 let mut messages = Vec::new();
125 let mut approved_tools = Vec::new();
126
127 while let Some(response) = stream
128 .message()
129 .await
130 .map_err(|e| GrpcError::CallFailed(Box::new(e)))?
131 {
132 match response.chunk {
133 Some(proto::activate_session_response::Chunk::Message(proto_msg)) => {
134 match proto_to_message(proto_msg) {
135 Ok(msg) => messages.push(msg),
136 Err(e) => return Err(GrpcError::ConversionError(e)),
137 }
138 }
139 Some(proto::activate_session_response::Chunk::Footer(footer)) => {
140 approved_tools = footer.approved_tools;
141 }
142 None => {}
143 }
144 }
145
146 *self.session_id.lock().await = Some(session_id);
147 Ok((messages, approved_tools))
148 }
149
150 pub async fn start_streaming(&self) -> GrpcResult<()> {
152 let session_id = self
153 .session_id
154 .lock()
155 .await
156 .as_ref()
157 .cloned()
158 .ok_or_else(|| GrpcError::InvalidSessionState {
159 reason: "No session ID - call create_session or activate_session first".to_string(),
160 })?;
161
162 debug!("Starting bidirectional stream for session: {}", session_id);
163
164 let (cmd_tx, cmd_rx) = mpsc::channel::<StreamSessionRequest>(32);
166 let (evt_tx, evt_rx) = mpsc::channel::<AppEvent>(100);
167
168 let outbound_stream = ReceiverStream::new(cmd_rx);
170 let request = Request::new(outbound_stream);
171
172 let response = self
173 .client
174 .lock()
175 .await
176 .stream_session(request)
177 .await
178 .map_err(Box::new)?;
179 let mut inbound_stream = response.into_inner();
180
181 let subscribe_msg = StreamSessionRequest {
183 session_id: session_id.clone(),
184 message: Some(StreamSessionRequestType::Subscribe(SubscribeRequest {
185 event_types: vec![], since_sequence: None,
187 })),
188 };
189
190 cmd_tx
191 .send(subscribe_msg)
192 .await
193 .map_err(|_| GrpcError::StreamError("Failed to send subscribe message".to_string()))?;
194
195 let session_id_clone = session_id.clone();
197 let stream_handle = tokio::spawn(async move {
198 info!(
199 "Started event stream handler for session: {}",
200 session_id_clone
201 );
202
203 while let Some(result) = inbound_stream.message().await.transpose() {
204 match result {
205 Ok(server_event) => {
206 debug!(
207 "Received server event: sequence {}",
208 server_event.sequence_num
209 );
210
211 match server_event_to_app_event(server_event) {
212 Ok(app_event) => {
213 if let Err(e) = evt_tx.send(app_event).await {
214 warn!("Failed to forward event to TUI: {}", e);
215 break;
216 }
217 }
218 Err(e) => {
219 error!("Failed to convert server event: {}", e);
220 }
222 }
223 }
224 Err(e) => {
225 error!("gRPC stream error: {}", e);
226 break;
227 }
228 }
229 }
230
231 info!(
232 "Event stream handler ended for session: {}",
233 session_id_clone
234 );
235 });
236
237 *self.command_tx.lock().await = Some(cmd_tx);
239 *self.stream_handle.lock().await = Some(stream_handle);
240 *self.event_rx.lock().await = Some(evt_rx);
242
243 info!(
244 "Bidirectional streaming started for session: {}",
245 session_id
246 );
247 Ok(())
248 }
249
250 pub async fn send_command(&self, command: AppCommand) -> GrpcResult<()> {
252 let session_id = self
253 .session_id
254 .lock()
255 .await
256 .as_ref()
257 .cloned()
258 .ok_or_else(|| GrpcError::InvalidSessionState {
259 reason: "No active session".to_string(),
260 })?;
261
262 let command_tx = self
263 .command_tx
264 .lock()
265 .await
266 .as_ref()
267 .cloned()
268 .ok_or_else(|| GrpcError::InvalidSessionState {
269 reason: "Streaming not started - call start_streaming first".to_string(),
270 })?;
271
272 let message = convert_app_command_to_client_message(command, &session_id)?;
273
274 if let Some(message) = message {
275 command_tx.send(message).await.map_err(|_| {
276 GrpcError::StreamError("Failed to send command - stream may be closed".to_string())
277 })?;
278 }
279
280 Ok(())
281 }
282
283 pub async fn session_id(&self) -> Option<String> {
285 self.session_id.lock().await.clone()
286 }
287
288 pub async fn list_sessions(&self) -> GrpcResult<Vec<SessionInfo>> {
290 debug!("Listing sessions from gRPC server");
291
292 let request = Request::new(ListSessionsRequest {
293 filter: None,
294 page_size: None,
295 page_token: None,
296 });
297
298 let response = self
299 .client
300 .lock()
301 .await
302 .list_sessions(request)
303 .await
304 .map_err(Box::new)?;
305 let sessions_response = response.into_inner();
306
307 Ok(sessions_response.sessions)
308 }
309
310 pub async fn get_session(&self, session_id: &str) -> GrpcResult<Option<SessionState>> {
312 debug!("Getting session {} from gRPC server", session_id);
313
314 let request = Request::new(GetSessionRequest {
315 session_id: session_id.to_string(),
316 });
317
318 let mut stream = self
319 .client
320 .lock()
321 .await
322 .get_session(request)
323 .await
324 .map_err(|e| GrpcError::CallFailed(Box::new(e)))?
325 .into_inner();
326
327 let mut header = None;
328 let mut messages = Vec::new();
329 let mut tool_calls = std::collections::HashMap::new();
330 let mut footer = None;
331
332 while let Some(response) = stream
333 .message()
334 .await
335 .map_err(|e| GrpcError::CallFailed(Box::new(e)))?
336 {
337 match response.chunk {
338 Some(proto::get_session_response::Chunk::Header(h)) => header = Some(h),
339 Some(proto::get_session_response::Chunk::Message(m)) => messages.push(m),
340 Some(proto::get_session_response::Chunk::ToolCall(tc)) => {
341 if let Some(value) = tc.value {
342 tool_calls.insert(tc.key, value);
343 }
344 }
345 Some(proto::get_session_response::Chunk::Footer(f)) => footer = Some(f),
346 None => {}
347 }
348 }
349
350 match (header, footer) {
351 (Some(h), Some(f)) => Ok(Some(SessionState {
352 id: h.id,
353 created_at: h.created_at,
354 updated_at: h.updated_at,
355 config: h.config,
356 messages,
357 tool_calls,
358 approved_tools: f.approved_tools,
359 last_event_sequence: f.last_event_sequence,
360 metadata: f.metadata,
361 })),
362 _ => Ok(None),
363 }
364 }
365
366 pub async fn delete_session(&self, session_id: &str) -> GrpcResult<bool> {
368 debug!("Deleting session {} from gRPC server", session_id);
369
370 let request = Request::new(DeleteSessionRequest {
371 session_id: session_id.to_string(),
372 });
373
374 match self.client.lock().await.delete_session(request).await {
375 Ok(_) => {
376 info!("Successfully deleted session: {}", session_id);
377 Ok(true)
378 }
379 Err(status) if status.code() == tonic::Code::NotFound => Ok(false),
380 Err(e) => Err(GrpcError::CallFailed(Box::new(e))),
381 }
382 }
383
384 pub async fn get_conversation(
386 &self,
387 session_id: &str,
388 ) -> GrpcResult<(Vec<Message>, Vec<String>)> {
389 info!(
390 "Client adapter getting conversation for session: {}",
391 session_id
392 );
393
394 let mut stream = self
395 .client
396 .lock()
397 .await
398 .get_conversation(GetConversationRequest {
399 session_id: session_id.to_string(),
400 })
401 .await
402 .map_err(Box::new)?
403 .into_inner();
404
405 let mut messages = Vec::new();
406 let mut approved_tools = Vec::new();
407
408 while let Some(response) = stream
409 .message()
410 .await
411 .map_err(|e| GrpcError::CallFailed(Box::new(e)))?
412 {
413 match response.chunk {
414 Some(proto::get_conversation_response::Chunk::Message(proto_msg)) => {
415 match proto_to_message(proto_msg) {
416 Ok(msg) => messages.push(msg),
417 Err(e) => {
418 warn!("Failed to convert message: {}", e);
419 return Err(GrpcError::ConversionError(e));
420 }
421 }
422 }
423 Some(proto::get_conversation_response::Chunk::Footer(footer)) => {
424 approved_tools = footer.approved_tools;
425 }
426 None => {}
427 }
428 }
429
430 info!(
431 "Successfully converted {} messages from GetConversation response",
432 messages.len()
433 );
434
435 Ok((messages, approved_tools))
436 }
437
438 pub async fn shutdown(self) {
440 if let Some(handle) = self.stream_handle.lock().await.take() {
441 handle.abort();
442 let _ = handle.await;
443 }
444
445 if let Some(session_id) = &*self.session_id.lock().await {
446 info!("GrpcClientAdapter shut down for session: {}", session_id);
447 }
448 }
449
450 pub async fn get_mcp_servers(&self) -> GrpcResult<Vec<McpServerInfo>> {
451 let session_id = self
452 .session_id
453 .lock()
454 .await
455 .as_ref()
456 .cloned()
457 .ok_or_else(|| GrpcError::InvalidSessionState {
458 reason: "No active session".to_string(),
459 })?;
460
461 let request = Request::new(GetMcpServersRequest {
462 session_id: session_id.clone(),
463 });
464
465 let response = self
466 .client
467 .lock()
468 .await
469 .get_mcp_servers(request)
470 .await
471 .map_err(Box::new)?;
472
473 let servers = response
474 .into_inner()
475 .servers
476 .into_iter()
477 .filter_map(|s| proto_to_mcp_server_info(s).ok())
478 .collect();
479
480 Ok(servers)
481 }
482}
483
484#[async_trait]
485impl AppCommandSink for AgentClient {
486 async fn send_command(&self, command: AppCommand) -> Result<()> {
487 self.send_command(command)
488 .await
489 .map_err(|e| steer_core::error::Error::InvalidOperation(e.to_string()))
490 }
491}
492
493#[async_trait]
494impl AppEventSource for AgentClient {
495 async fn subscribe(&self) -> mpsc::Receiver<AppEvent> {
496 self.event_rx.lock().await.take().expect(
499 "Event receiver already taken - GrpcClientAdapter only supports single subscription",
500 )
501 }
502}
503
#[cfg(test)]
mod tests {
    use super::*;
    use crate::grpc::conversions::tool_approval_policy_to_proto;
    use steer_core::session::ToolApprovalPolicy;
    use steer_proto::agent::v1::tool_approval_policy::Policy;

    /// Each approval policy variant maps onto the matching proto variant.
    #[test]
    fn test_convert_tool_approval_policy() {
        let always_ask = tool_approval_policy_to_proto(&ToolApprovalPolicy::AlwaysAsk);
        assert!(matches!(always_ask.policy, Some(Policy::AlwaysAsk(_))));

        let pre_approved = ToolApprovalPolicy::PreApproved {
            tools: std::iter::once("bash".to_string()).collect(),
        };
        let converted = tool_approval_policy_to_proto(&pre_approved);
        assert!(matches!(converted.policy, Some(Policy::PreApproved(_))));
    }

    /// User input produces a client message, while `Shutdown` produces none.
    #[test]
    fn test_convert_app_command_to_client_message() {
        let session_id = "test-session";

        let user_input = convert_app_command_to_client_message(
            AppCommand::ProcessUserInput("Hello".to_string()),
            session_id,
        )
        .unwrap();
        assert!(user_input.is_some());

        let shutdown =
            convert_app_command_to_client_message(AppCommand::Shutdown, session_id).unwrap();
        assert!(shutdown.is_none());
    }
}