1use serde::Serialize;
8use std::time::Duration;
9use tokio::io::{AsyncReadExt, AsyncWriteExt};
10use tokio::net::UnixStream;
11use tonic::transport::Channel;
12use tracing::{debug, error, trace};
13
14use crate::errors::AgentProtocolError;
15use crate::grpc::{self, agent_processor_client::AgentProcessorClient};
16use crate::protocol::{
17 AgentRequest, AgentResponse, AuditMetadata, BodyMutation, Decision, EventType, HeaderOp,
18 RequestBodyChunkEvent, RequestCompleteEvent, RequestHeadersEvent, RequestMetadata,
19 ResponseBodyChunkEvent, ResponseHeadersEvent, WebSocketDecision, WebSocketFrameEvent,
20 MAX_MESSAGE_SIZE, PROTOCOL_VERSION,
21};
22
23pub struct AgentClient {
25 id: String,
27 connection: AgentConnection,
29 timeout: Duration,
31 #[allow(dead_code)]
33 max_retries: u32,
34}
35
36enum AgentConnection {
38 UnixSocket(UnixStream),
39 Grpc(AgentProcessorClient<Channel>),
40}
41
42impl AgentClient {
43 pub async fn unix_socket(
45 id: impl Into<String>,
46 path: impl AsRef<std::path::Path>,
47 timeout: Duration,
48 ) -> Result<Self, AgentProtocolError> {
49 let id = id.into();
50 let path = path.as_ref();
51
52 trace!(
53 agent_id = %id,
54 socket_path = %path.display(),
55 timeout_ms = timeout.as_millis() as u64,
56 "Connecting to agent via Unix socket"
57 );
58
59 let stream = UnixStream::connect(path).await.map_err(|e| {
60 error!(
61 agent_id = %id,
62 socket_path = %path.display(),
63 error = %e,
64 "Failed to connect to agent via Unix socket"
65 );
66 AgentProtocolError::ConnectionFailed(e.to_string())
67 })?;
68
69 debug!(
70 agent_id = %id,
71 socket_path = %path.display(),
72 "Connected to agent via Unix socket"
73 );
74
75 Ok(Self {
76 id,
77 connection: AgentConnection::UnixSocket(stream),
78 timeout,
79 max_retries: 3,
80 })
81 }
82
83 pub async fn grpc(
90 id: impl Into<String>,
91 address: impl Into<String>,
92 timeout: Duration,
93 ) -> Result<Self, AgentProtocolError> {
94 let id = id.into();
95 let address = address.into();
96
97 trace!(
98 agent_id = %id,
99 address = %address,
100 timeout_ms = timeout.as_millis() as u64,
101 "Connecting to agent via gRPC"
102 );
103
104 let channel = Channel::from_shared(address.clone())
105 .map_err(|e| {
106 error!(
107 agent_id = %id,
108 address = %address,
109 error = %e,
110 "Invalid gRPC URI"
111 );
112 AgentProtocolError::ConnectionFailed(format!("Invalid URI: {}", e))
113 })?
114 .timeout(timeout)
115 .connect()
116 .await
117 .map_err(|e| {
118 error!(
119 agent_id = %id,
120 address = %address,
121 error = %e,
122 "Failed to connect to agent via gRPC"
123 );
124 AgentProtocolError::ConnectionFailed(format!("gRPC connect failed: {}", e))
125 })?;
126
127 let client = AgentProcessorClient::new(channel);
128
129 debug!(
130 agent_id = %id,
131 address = %address,
132 "Connected to agent via gRPC"
133 );
134
135 Ok(Self {
136 id,
137 connection: AgentConnection::Grpc(client),
138 timeout,
139 max_retries: 3,
140 })
141 }
142
143 #[allow(dead_code)]
145 pub fn id(&self) -> &str {
146 &self.id
147 }
148
149 pub async fn send_event(
151 &mut self,
152 event_type: EventType,
153 payload: impl Serialize,
154 ) -> Result<AgentResponse, AgentProtocolError> {
155 match &mut self.connection {
156 AgentConnection::UnixSocket(_) => {
157 self.send_event_unix_socket(event_type, payload).await
158 }
159 AgentConnection::Grpc(_) => self.send_event_grpc(event_type, payload).await,
160 }
161 }
162
163 async fn send_event_unix_socket(
165 &mut self,
166 event_type: EventType,
167 payload: impl Serialize,
168 ) -> Result<AgentResponse, AgentProtocolError> {
169 let request = AgentRequest {
170 version: PROTOCOL_VERSION,
171 event_type,
172 payload: serde_json::to_value(payload)
173 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?,
174 };
175
176 let request_bytes = serde_json::to_vec(&request)
178 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
179
180 if request_bytes.len() > MAX_MESSAGE_SIZE {
182 return Err(AgentProtocolError::MessageTooLarge {
183 size: request_bytes.len(),
184 max: MAX_MESSAGE_SIZE,
185 });
186 }
187
188 let response = tokio::time::timeout(self.timeout, async {
190 self.send_raw_unix(&request_bytes).await?;
191 self.receive_raw_unix().await
192 })
193 .await
194 .map_err(|_| AgentProtocolError::Timeout(self.timeout))??;
195
196 let agent_response: AgentResponse = serde_json::from_slice(&response)
198 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
199
200 if agent_response.version != PROTOCOL_VERSION {
202 return Err(AgentProtocolError::VersionMismatch {
203 expected: PROTOCOL_VERSION,
204 actual: agent_response.version,
205 });
206 }
207
208 Ok(agent_response)
209 }
210
211 async fn send_event_grpc(
213 &mut self,
214 event_type: EventType,
215 payload: impl Serialize,
216 ) -> Result<AgentResponse, AgentProtocolError> {
217 let grpc_request = Self::build_grpc_request(event_type, payload)?;
219
220 let AgentConnection::Grpc(client) = &mut self.connection else {
221 unreachable!()
222 };
223
224 let response = tokio::time::timeout(self.timeout, client.process_event(grpc_request))
226 .await
227 .map_err(|_| AgentProtocolError::Timeout(self.timeout))?
228 .map_err(|e| {
229 AgentProtocolError::ConnectionFailed(format!("gRPC call failed: {}", e))
230 })?;
231
232 Self::convert_grpc_response(response.into_inner())
234 }
235
236 fn build_grpc_request(
238 event_type: EventType,
239 payload: impl Serialize,
240 ) -> Result<grpc::AgentRequest, AgentProtocolError> {
241 let payload_json = serde_json::to_value(&payload)
242 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
243
244 let grpc_event_type = match event_type {
245 EventType::RequestHeaders => grpc::EventType::RequestHeaders,
246 EventType::RequestBodyChunk => grpc::EventType::RequestBodyChunk,
247 EventType::ResponseHeaders => grpc::EventType::ResponseHeaders,
248 EventType::ResponseBodyChunk => grpc::EventType::ResponseBodyChunk,
249 EventType::RequestComplete => grpc::EventType::RequestComplete,
250 EventType::WebSocketFrame => grpc::EventType::WebsocketFrame,
251 };
252
253 let event = match event_type {
254 EventType::RequestHeaders => {
255 let event: RequestHeadersEvent = serde_json::from_value(payload_json)
256 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
257 grpc::agent_request::Event::RequestHeaders(grpc::RequestHeadersEvent {
258 metadata: Some(Self::convert_metadata_to_grpc(&event.metadata)),
259 method: event.method,
260 uri: event.uri,
261 headers: event
262 .headers
263 .into_iter()
264 .map(|(k, v)| (k, grpc::HeaderValues { values: v }))
265 .collect(),
266 })
267 }
268 EventType::RequestBodyChunk => {
269 let event: RequestBodyChunkEvent = serde_json::from_value(payload_json)
270 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
271 grpc::agent_request::Event::RequestBodyChunk(grpc::RequestBodyChunkEvent {
272 correlation_id: event.correlation_id,
273 data: event.data.into_bytes(),
274 is_last: event.is_last,
275 total_size: event.total_size.map(|s| s as u64),
276 chunk_index: event.chunk_index,
277 bytes_received: event.bytes_received as u64,
278 })
279 }
280 EventType::ResponseHeaders => {
281 let event: ResponseHeadersEvent = serde_json::from_value(payload_json)
282 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
283 grpc::agent_request::Event::ResponseHeaders(grpc::ResponseHeadersEvent {
284 correlation_id: event.correlation_id,
285 status: event.status as u32,
286 headers: event
287 .headers
288 .into_iter()
289 .map(|(k, v)| (k, grpc::HeaderValues { values: v }))
290 .collect(),
291 })
292 }
293 EventType::ResponseBodyChunk => {
294 let event: ResponseBodyChunkEvent = serde_json::from_value(payload_json)
295 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
296 grpc::agent_request::Event::ResponseBodyChunk(grpc::ResponseBodyChunkEvent {
297 correlation_id: event.correlation_id,
298 data: event.data.into_bytes(),
299 is_last: event.is_last,
300 total_size: event.total_size.map(|s| s as u64),
301 chunk_index: event.chunk_index,
302 bytes_sent: event.bytes_sent as u64,
303 })
304 }
305 EventType::RequestComplete => {
306 let event: RequestCompleteEvent = serde_json::from_value(payload_json)
307 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
308 grpc::agent_request::Event::RequestComplete(grpc::RequestCompleteEvent {
309 correlation_id: event.correlation_id,
310 status: event.status as u32,
311 duration_ms: event.duration_ms,
312 request_body_size: event.request_body_size as u64,
313 response_body_size: event.response_body_size as u64,
314 upstream_attempts: event.upstream_attempts,
315 error: event.error,
316 })
317 }
318 EventType::WebSocketFrame => {
319 use base64::{engine::general_purpose::STANDARD, Engine as _};
320 let event: WebSocketFrameEvent = serde_json::from_value(payload_json)
321 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
322 grpc::agent_request::Event::WebsocketFrame(grpc::WebSocketFrameEvent {
323 correlation_id: event.correlation_id,
324 opcode: event.opcode,
325 data: STANDARD.decode(&event.data).unwrap_or_default(),
326 client_to_server: event.client_to_server,
327 frame_index: event.frame_index,
328 fin: event.fin,
329 route_id: event.route_id,
330 client_ip: event.client_ip,
331 })
332 }
333 };
334
335 Ok(grpc::AgentRequest {
336 version: PROTOCOL_VERSION,
337 event_type: grpc_event_type as i32,
338 event: Some(event),
339 })
340 }
341
342 fn convert_metadata_to_grpc(metadata: &RequestMetadata) -> grpc::RequestMetadata {
344 grpc::RequestMetadata {
345 correlation_id: metadata.correlation_id.clone(),
346 request_id: metadata.request_id.clone(),
347 client_ip: metadata.client_ip.clone(),
348 client_port: metadata.client_port as u32,
349 server_name: metadata.server_name.clone(),
350 protocol: metadata.protocol.clone(),
351 tls_version: metadata.tls_version.clone(),
352 tls_cipher: metadata.tls_cipher.clone(),
353 route_id: metadata.route_id.clone(),
354 upstream_id: metadata.upstream_id.clone(),
355 timestamp: metadata.timestamp.clone(),
356 }
357 }
358
359 fn convert_grpc_response(
361 response: grpc::AgentResponse,
362 ) -> Result<AgentResponse, AgentProtocolError> {
363 let decision = match response.decision {
364 Some(grpc::agent_response::Decision::Allow(_)) => Decision::Allow,
365 Some(grpc::agent_response::Decision::Block(b)) => Decision::Block {
366 status: b.status as u16,
367 body: b.body,
368 headers: if b.headers.is_empty() {
369 None
370 } else {
371 Some(b.headers)
372 },
373 },
374 Some(grpc::agent_response::Decision::Redirect(r)) => Decision::Redirect {
375 url: r.url,
376 status: r.status as u16,
377 },
378 Some(grpc::agent_response::Decision::Challenge(c)) => Decision::Challenge {
379 challenge_type: c.challenge_type,
380 params: c.params,
381 },
382 None => Decision::Allow, };
384
385 let request_headers: Vec<HeaderOp> = response
386 .request_headers
387 .into_iter()
388 .filter_map(Self::convert_header_op_from_grpc)
389 .collect();
390
391 let response_headers: Vec<HeaderOp> = response
392 .response_headers
393 .into_iter()
394 .filter_map(Self::convert_header_op_from_grpc)
395 .collect();
396
397 let audit = response.audit.map(|a| AuditMetadata {
398 tags: a.tags,
399 rule_ids: a.rule_ids,
400 confidence: a.confidence,
401 reason_codes: a.reason_codes,
402 custom: a
403 .custom
404 .into_iter()
405 .map(|(k, v)| (k, serde_json::Value::String(v)))
406 .collect(),
407 });
408
409 let request_body_mutation = response.request_body_mutation.map(|m| BodyMutation {
411 data: m.data.map(|d| String::from_utf8_lossy(&d).to_string()),
412 chunk_index: m.chunk_index,
413 });
414
415 let response_body_mutation = response.response_body_mutation.map(|m| BodyMutation {
416 data: m.data.map(|d| String::from_utf8_lossy(&d).to_string()),
417 chunk_index: m.chunk_index,
418 });
419
420 let websocket_decision = response
422 .websocket_decision
423 .map(|ws_decision| match ws_decision {
424 grpc::agent_response::WebsocketDecision::WebsocketAllow(_) => {
425 WebSocketDecision::Allow
426 }
427 grpc::agent_response::WebsocketDecision::WebsocketDrop(_) => {
428 WebSocketDecision::Drop
429 }
430 grpc::agent_response::WebsocketDecision::WebsocketClose(c) => {
431 WebSocketDecision::Close {
432 code: c.code as u16,
433 reason: c.reason,
434 }
435 }
436 });
437
438 Ok(AgentResponse {
439 version: response.version,
440 decision,
441 request_headers,
442 response_headers,
443 routing_metadata: response.routing_metadata,
444 audit: audit.unwrap_or_default(),
445 needs_more: response.needs_more,
446 request_body_mutation,
447 response_body_mutation,
448 websocket_decision,
449 })
450 }
451
452 fn convert_header_op_from_grpc(op: grpc::HeaderOp) -> Option<HeaderOp> {
454 match op.operation? {
455 grpc::header_op::Operation::Set(s) => Some(HeaderOp::Set {
456 name: s.name,
457 value: s.value,
458 }),
459 grpc::header_op::Operation::Add(a) => Some(HeaderOp::Add {
460 name: a.name,
461 value: a.value,
462 }),
463 grpc::header_op::Operation::Remove(r) => Some(HeaderOp::Remove { name: r.name }),
464 }
465 }
466
467 async fn send_raw_unix(&mut self, data: &[u8]) -> Result<(), AgentProtocolError> {
469 let AgentConnection::UnixSocket(stream) = &mut self.connection else {
470 unreachable!()
471 };
472 let len_bytes = (data.len() as u32).to_be_bytes();
474 stream.write_all(&len_bytes).await?;
475 stream.write_all(data).await?;
477 stream.flush().await?;
478 Ok(())
479 }
480
481 async fn receive_raw_unix(&mut self) -> Result<Vec<u8>, AgentProtocolError> {
483 let AgentConnection::UnixSocket(stream) = &mut self.connection else {
484 unreachable!()
485 };
486 let mut len_bytes = [0u8; 4];
488 stream.read_exact(&mut len_bytes).await?;
489 let message_len = u32::from_be_bytes(len_bytes) as usize;
490
491 if message_len > MAX_MESSAGE_SIZE {
493 return Err(AgentProtocolError::MessageTooLarge {
494 size: message_len,
495 max: MAX_MESSAGE_SIZE,
496 });
497 }
498
499 let mut buffer = vec![0u8; message_len];
501 stream.read_exact(&mut buffer).await?;
502 Ok(buffer)
503 }
504
505 pub async fn close(self) -> Result<(), AgentProtocolError> {
507 match self.connection {
508 AgentConnection::UnixSocket(mut stream) => {
509 stream.shutdown().await?;
510 Ok(())
511 }
512 AgentConnection::Grpc(_) => Ok(()), }
514 }
515}