1use serde::Serialize;
8use std::time::Duration;
9use tokio::io::{AsyncReadExt, AsyncWriteExt};
10use tokio::net::UnixStream;
11use tonic::transport::Channel;
12use tracing::{debug, error, trace};
13
14use crate::errors::AgentProtocolError;
15use crate::grpc::{self, agent_processor_client::AgentProcessorClient};
16use crate::protocol::{
17 AgentRequest, AgentResponse, AuditMetadata, BodyMutation, Decision, EventType, HeaderOp,
18 RequestBodyChunkEvent, RequestCompleteEvent, RequestHeadersEvent, RequestMetadata,
19 ResponseBodyChunkEvent, ResponseHeadersEvent, WebSocketDecision, WebSocketFrameEvent,
20 MAX_MESSAGE_SIZE, PROTOCOL_VERSION,
21};
22
23pub struct AgentClient {
25 id: String,
27 connection: AgentConnection,
29 timeout: Duration,
31 #[allow(dead_code)]
33 max_retries: u32,
34}
35
36enum AgentConnection {
38 UnixSocket(UnixStream),
39 Grpc(AgentProcessorClient<Channel>),
40}
41
42impl AgentClient {
43 pub async fn unix_socket(
45 id: impl Into<String>,
46 path: impl AsRef<std::path::Path>,
47 timeout: Duration,
48 ) -> Result<Self, AgentProtocolError> {
49 let id = id.into();
50 let path = path.as_ref();
51
52 trace!(
53 agent_id = %id,
54 socket_path = %path.display(),
55 timeout_ms = timeout.as_millis() as u64,
56 "Connecting to agent via Unix socket"
57 );
58
59 let stream = UnixStream::connect(path).await.map_err(|e| {
60 error!(
61 agent_id = %id,
62 socket_path = %path.display(),
63 error = %e,
64 "Failed to connect to agent via Unix socket"
65 );
66 AgentProtocolError::ConnectionFailed(e.to_string())
67 })?;
68
69 debug!(
70 agent_id = %id,
71 socket_path = %path.display(),
72 "Connected to agent via Unix socket"
73 );
74
75 Ok(Self {
76 id,
77 connection: AgentConnection::UnixSocket(stream),
78 timeout,
79 max_retries: 3,
80 })
81 }
82
83 pub async fn grpc(
90 id: impl Into<String>,
91 address: impl Into<String>,
92 timeout: Duration,
93 ) -> Result<Self, AgentProtocolError> {
94 let id = id.into();
95 let address = address.into();
96
97 trace!(
98 agent_id = %id,
99 address = %address,
100 timeout_ms = timeout.as_millis() as u64,
101 "Connecting to agent via gRPC"
102 );
103
104 let channel = Channel::from_shared(address.clone())
105 .map_err(|e| {
106 error!(
107 agent_id = %id,
108 address = %address,
109 error = %e,
110 "Invalid gRPC URI"
111 );
112 AgentProtocolError::ConnectionFailed(format!("Invalid URI: {}", e))
113 })?
114 .timeout(timeout)
115 .connect()
116 .await
117 .map_err(|e| {
118 error!(
119 agent_id = %id,
120 address = %address,
121 error = %e,
122 "Failed to connect to agent via gRPC"
123 );
124 AgentProtocolError::ConnectionFailed(format!("gRPC connect failed: {}", e))
125 })?;
126
127 let client = AgentProcessorClient::new(channel);
128
129 debug!(
130 agent_id = %id,
131 address = %address,
132 "Connected to agent via gRPC"
133 );
134
135 Ok(Self {
136 id,
137 connection: AgentConnection::Grpc(client),
138 timeout,
139 max_retries: 3,
140 })
141 }
142
143 #[allow(dead_code)]
145 pub fn id(&self) -> &str {
146 &self.id
147 }
148
149 pub async fn send_event(
151 &mut self,
152 event_type: EventType,
153 payload: impl Serialize,
154 ) -> Result<AgentResponse, AgentProtocolError> {
155 match &mut self.connection {
156 AgentConnection::UnixSocket(_) => {
157 self.send_event_unix_socket(event_type, payload).await
158 }
159 AgentConnection::Grpc(_) => self.send_event_grpc(event_type, payload).await,
160 }
161 }
162
163 async fn send_event_unix_socket(
165 &mut self,
166 event_type: EventType,
167 payload: impl Serialize,
168 ) -> Result<AgentResponse, AgentProtocolError> {
169 let request = AgentRequest {
170 version: PROTOCOL_VERSION,
171 event_type,
172 payload: serde_json::to_value(payload)
173 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?,
174 };
175
176 let request_bytes = serde_json::to_vec(&request)
178 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
179
180 if request_bytes.len() > MAX_MESSAGE_SIZE {
182 return Err(AgentProtocolError::MessageTooLarge {
183 size: request_bytes.len(),
184 max: MAX_MESSAGE_SIZE,
185 });
186 }
187
188 let response = tokio::time::timeout(self.timeout, async {
190 self.send_raw_unix(&request_bytes).await?;
191 self.receive_raw_unix().await
192 })
193 .await
194 .map_err(|_| AgentProtocolError::Timeout(self.timeout))??;
195
196 let agent_response: AgentResponse = serde_json::from_slice(&response)
198 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
199
200 if agent_response.version != PROTOCOL_VERSION {
202 return Err(AgentProtocolError::VersionMismatch {
203 expected: PROTOCOL_VERSION,
204 actual: agent_response.version,
205 });
206 }
207
208 Ok(agent_response)
209 }
210
211 async fn send_event_grpc(
213 &mut self,
214 event_type: EventType,
215 payload: impl Serialize,
216 ) -> Result<AgentResponse, AgentProtocolError> {
217 let grpc_request = Self::build_grpc_request(event_type, payload)?;
219
220 let AgentConnection::Grpc(client) = &mut self.connection else {
221 unreachable!()
222 };
223
224 let response = tokio::time::timeout(self.timeout, client.process_event(grpc_request))
226 .await
227 .map_err(|_| AgentProtocolError::Timeout(self.timeout))?
228 .map_err(|e| {
229 AgentProtocolError::ConnectionFailed(format!("gRPC call failed: {}", e))
230 })?;
231
232 Self::convert_grpc_response(response.into_inner())
234 }
235
236 fn build_grpc_request(
238 event_type: EventType,
239 payload: impl Serialize,
240 ) -> Result<grpc::AgentRequest, AgentProtocolError> {
241 let payload_json = serde_json::to_value(&payload)
242 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
243
244 let grpc_event_type = match event_type {
245 EventType::Configure => {
246 return Err(AgentProtocolError::Serialization(
247 "Configure events are not supported via gRPC".to_string(),
248 ))
249 }
250 EventType::RequestHeaders => grpc::EventType::RequestHeaders,
251 EventType::RequestBodyChunk => grpc::EventType::RequestBodyChunk,
252 EventType::ResponseHeaders => grpc::EventType::ResponseHeaders,
253 EventType::ResponseBodyChunk => grpc::EventType::ResponseBodyChunk,
254 EventType::RequestComplete => grpc::EventType::RequestComplete,
255 EventType::WebSocketFrame => grpc::EventType::WebsocketFrame,
256 };
257
258 let event = match event_type {
259 EventType::Configure => unreachable!("Configure handled above"),
260 EventType::RequestHeaders => {
261 let event: RequestHeadersEvent = serde_json::from_value(payload_json)
262 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
263 grpc::agent_request::Event::RequestHeaders(grpc::RequestHeadersEvent {
264 metadata: Some(Self::convert_metadata_to_grpc(&event.metadata)),
265 method: event.method,
266 uri: event.uri,
267 headers: event
268 .headers
269 .into_iter()
270 .map(|(k, v)| (k, grpc::HeaderValues { values: v }))
271 .collect(),
272 })
273 }
274 EventType::RequestBodyChunk => {
275 let event: RequestBodyChunkEvent = serde_json::from_value(payload_json)
276 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
277 grpc::agent_request::Event::RequestBodyChunk(grpc::RequestBodyChunkEvent {
278 correlation_id: event.correlation_id,
279 data: event.data.into_bytes(),
280 is_last: event.is_last,
281 total_size: event.total_size.map(|s| s as u64),
282 chunk_index: event.chunk_index,
283 bytes_received: event.bytes_received as u64,
284 })
285 }
286 EventType::ResponseHeaders => {
287 let event: ResponseHeadersEvent = serde_json::from_value(payload_json)
288 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
289 grpc::agent_request::Event::ResponseHeaders(grpc::ResponseHeadersEvent {
290 correlation_id: event.correlation_id,
291 status: event.status as u32,
292 headers: event
293 .headers
294 .into_iter()
295 .map(|(k, v)| (k, grpc::HeaderValues { values: v }))
296 .collect(),
297 })
298 }
299 EventType::ResponseBodyChunk => {
300 let event: ResponseBodyChunkEvent = serde_json::from_value(payload_json)
301 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
302 grpc::agent_request::Event::ResponseBodyChunk(grpc::ResponseBodyChunkEvent {
303 correlation_id: event.correlation_id,
304 data: event.data.into_bytes(),
305 is_last: event.is_last,
306 total_size: event.total_size.map(|s| s as u64),
307 chunk_index: event.chunk_index,
308 bytes_sent: event.bytes_sent as u64,
309 })
310 }
311 EventType::RequestComplete => {
312 let event: RequestCompleteEvent = serde_json::from_value(payload_json)
313 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
314 grpc::agent_request::Event::RequestComplete(grpc::RequestCompleteEvent {
315 correlation_id: event.correlation_id,
316 status: event.status as u32,
317 duration_ms: event.duration_ms,
318 request_body_size: event.request_body_size as u64,
319 response_body_size: event.response_body_size as u64,
320 upstream_attempts: event.upstream_attempts,
321 error: event.error,
322 })
323 }
324 EventType::WebSocketFrame => {
325 use base64::{engine::general_purpose::STANDARD, Engine as _};
326 let event: WebSocketFrameEvent = serde_json::from_value(payload_json)
327 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
328 grpc::agent_request::Event::WebsocketFrame(grpc::WebSocketFrameEvent {
329 correlation_id: event.correlation_id,
330 opcode: event.opcode,
331 data: STANDARD.decode(&event.data).unwrap_or_default(),
332 client_to_server: event.client_to_server,
333 frame_index: event.frame_index,
334 fin: event.fin,
335 route_id: event.route_id,
336 client_ip: event.client_ip,
337 })
338 }
339 };
340
341 Ok(grpc::AgentRequest {
342 version: PROTOCOL_VERSION,
343 event_type: grpc_event_type as i32,
344 event: Some(event),
345 })
346 }
347
348 fn convert_metadata_to_grpc(metadata: &RequestMetadata) -> grpc::RequestMetadata {
350 grpc::RequestMetadata {
351 correlation_id: metadata.correlation_id.clone(),
352 request_id: metadata.request_id.clone(),
353 client_ip: metadata.client_ip.clone(),
354 client_port: metadata.client_port as u32,
355 server_name: metadata.server_name.clone(),
356 protocol: metadata.protocol.clone(),
357 tls_version: metadata.tls_version.clone(),
358 tls_cipher: metadata.tls_cipher.clone(),
359 route_id: metadata.route_id.clone(),
360 upstream_id: metadata.upstream_id.clone(),
361 timestamp: metadata.timestamp.clone(),
362 traceparent: metadata.traceparent.clone(),
363 }
364 }
365
366 fn convert_grpc_response(
368 response: grpc::AgentResponse,
369 ) -> Result<AgentResponse, AgentProtocolError> {
370 let decision = match response.decision {
371 Some(grpc::agent_response::Decision::Allow(_)) => Decision::Allow,
372 Some(grpc::agent_response::Decision::Block(b)) => Decision::Block {
373 status: b.status as u16,
374 body: b.body,
375 headers: if b.headers.is_empty() {
376 None
377 } else {
378 Some(b.headers)
379 },
380 },
381 Some(grpc::agent_response::Decision::Redirect(r)) => Decision::Redirect {
382 url: r.url,
383 status: r.status as u16,
384 },
385 Some(grpc::agent_response::Decision::Challenge(c)) => Decision::Challenge {
386 challenge_type: c.challenge_type,
387 params: c.params,
388 },
389 None => Decision::Allow, };
391
392 let request_headers: Vec<HeaderOp> = response
393 .request_headers
394 .into_iter()
395 .filter_map(Self::convert_header_op_from_grpc)
396 .collect();
397
398 let response_headers: Vec<HeaderOp> = response
399 .response_headers
400 .into_iter()
401 .filter_map(Self::convert_header_op_from_grpc)
402 .collect();
403
404 let audit = response.audit.map(|a| AuditMetadata {
405 tags: a.tags,
406 rule_ids: a.rule_ids,
407 confidence: a.confidence,
408 reason_codes: a.reason_codes,
409 custom: a
410 .custom
411 .into_iter()
412 .map(|(k, v)| (k, serde_json::Value::String(v)))
413 .collect(),
414 });
415
416 let request_body_mutation = response.request_body_mutation.map(|m| BodyMutation {
418 data: m.data.map(|d| String::from_utf8_lossy(&d).to_string()),
419 chunk_index: m.chunk_index,
420 });
421
422 let response_body_mutation = response.response_body_mutation.map(|m| BodyMutation {
423 data: m.data.map(|d| String::from_utf8_lossy(&d).to_string()),
424 chunk_index: m.chunk_index,
425 });
426
427 let websocket_decision = response
429 .websocket_decision
430 .map(|ws_decision| match ws_decision {
431 grpc::agent_response::WebsocketDecision::WebsocketAllow(_) => {
432 WebSocketDecision::Allow
433 }
434 grpc::agent_response::WebsocketDecision::WebsocketDrop(_) => {
435 WebSocketDecision::Drop
436 }
437 grpc::agent_response::WebsocketDecision::WebsocketClose(c) => {
438 WebSocketDecision::Close {
439 code: c.code as u16,
440 reason: c.reason,
441 }
442 }
443 });
444
445 Ok(AgentResponse {
446 version: response.version,
447 decision,
448 request_headers,
449 response_headers,
450 routing_metadata: response.routing_metadata,
451 audit: audit.unwrap_or_default(),
452 needs_more: response.needs_more,
453 request_body_mutation,
454 response_body_mutation,
455 websocket_decision,
456 })
457 }
458
459 fn convert_header_op_from_grpc(op: grpc::HeaderOp) -> Option<HeaderOp> {
461 match op.operation? {
462 grpc::header_op::Operation::Set(s) => Some(HeaderOp::Set {
463 name: s.name,
464 value: s.value,
465 }),
466 grpc::header_op::Operation::Add(a) => Some(HeaderOp::Add {
467 name: a.name,
468 value: a.value,
469 }),
470 grpc::header_op::Operation::Remove(r) => Some(HeaderOp::Remove { name: r.name }),
471 }
472 }
473
474 async fn send_raw_unix(&mut self, data: &[u8]) -> Result<(), AgentProtocolError> {
476 let AgentConnection::UnixSocket(stream) = &mut self.connection else {
477 unreachable!()
478 };
479 let len_bytes = (data.len() as u32).to_be_bytes();
481 stream.write_all(&len_bytes).await?;
482 stream.write_all(data).await?;
484 stream.flush().await?;
485 Ok(())
486 }
487
488 async fn receive_raw_unix(&mut self) -> Result<Vec<u8>, AgentProtocolError> {
490 let AgentConnection::UnixSocket(stream) = &mut self.connection else {
491 unreachable!()
492 };
493 let mut len_bytes = [0u8; 4];
495 stream.read_exact(&mut len_bytes).await?;
496 let message_len = u32::from_be_bytes(len_bytes) as usize;
497
498 if message_len > MAX_MESSAGE_SIZE {
500 return Err(AgentProtocolError::MessageTooLarge {
501 size: message_len,
502 max: MAX_MESSAGE_SIZE,
503 });
504 }
505
506 let mut buffer = vec![0u8; message_len];
508 stream.read_exact(&mut buffer).await?;
509 Ok(buffer)
510 }
511
512 pub async fn close(self) -> Result<(), AgentProtocolError> {
514 match self.connection {
515 AgentConnection::UnixSocket(mut stream) => {
516 stream.shutdown().await?;
517 Ok(())
518 }
519 AgentConnection::Grpc(_) => Ok(()), }
521 }
522}