1use serde::Serialize;
8use std::time::Duration;
9use tokio::io::{AsyncReadExt, AsyncWriteExt};
10use tokio::net::UnixStream;
11use tonic::transport::Channel;
12
13use crate::errors::AgentProtocolError;
14use crate::grpc::{self, agent_processor_client::AgentProcessorClient};
15use crate::protocol::{
16 AgentRequest, AgentResponse, AuditMetadata, Decision, EventType, HeaderOp, RequestBodyChunkEvent,
17 RequestCompleteEvent, RequestHeadersEvent, RequestMetadata, ResponseBodyChunkEvent,
18 ResponseHeadersEvent, MAX_MESSAGE_SIZE, PROTOCOL_VERSION,
19};
20
21pub struct AgentClient {
23 id: String,
25 connection: AgentConnection,
27 timeout: Duration,
29 #[allow(dead_code)]
31 max_retries: u32,
32}
33
34enum AgentConnection {
36 UnixSocket(UnixStream),
37 Grpc(AgentProcessorClient<Channel>),
38}
39
40impl AgentClient {
41 pub async fn unix_socket(
43 id: impl Into<String>,
44 path: impl AsRef<std::path::Path>,
45 timeout: Duration,
46 ) -> Result<Self, AgentProtocolError> {
47 let stream = UnixStream::connect(path.as_ref())
48 .await
49 .map_err(|e| AgentProtocolError::ConnectionFailed(e.to_string()))?;
50
51 Ok(Self {
52 id: id.into(),
53 connection: AgentConnection::UnixSocket(stream),
54 timeout,
55 max_retries: 3,
56 })
57 }
58
59 pub async fn grpc(
66 id: impl Into<String>,
67 address: impl Into<String>,
68 timeout: Duration,
69 ) -> Result<Self, AgentProtocolError> {
70 let address = address.into();
71 let channel = Channel::from_shared(address.clone())
72 .map_err(|e| AgentProtocolError::ConnectionFailed(format!("Invalid URI: {}", e)))?
73 .timeout(timeout)
74 .connect()
75 .await
76 .map_err(|e| AgentProtocolError::ConnectionFailed(format!("gRPC connect failed: {}", e)))?;
77
78 let client = AgentProcessorClient::new(channel);
79
80 Ok(Self {
81 id: id.into(),
82 connection: AgentConnection::Grpc(client),
83 timeout,
84 max_retries: 3,
85 })
86 }
87
88 #[allow(dead_code)]
90 pub fn id(&self) -> &str {
91 &self.id
92 }
93
94 pub async fn send_event(
96 &mut self,
97 event_type: EventType,
98 payload: impl Serialize,
99 ) -> Result<AgentResponse, AgentProtocolError> {
100 match &mut self.connection {
101 AgentConnection::UnixSocket(_) => {
102 self.send_event_unix_socket(event_type, payload).await
103 }
104 AgentConnection::Grpc(_) => {
105 self.send_event_grpc(event_type, payload).await
106 }
107 }
108 }
109
110 async fn send_event_unix_socket(
112 &mut self,
113 event_type: EventType,
114 payload: impl Serialize,
115 ) -> Result<AgentResponse, AgentProtocolError> {
116 let request = AgentRequest {
117 version: PROTOCOL_VERSION,
118 event_type,
119 payload: serde_json::to_value(payload)
120 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?,
121 };
122
123 let request_bytes = serde_json::to_vec(&request)
125 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
126
127 if request_bytes.len() > MAX_MESSAGE_SIZE {
129 return Err(AgentProtocolError::MessageTooLarge {
130 size: request_bytes.len(),
131 max: MAX_MESSAGE_SIZE,
132 });
133 }
134
135 let response = tokio::time::timeout(self.timeout, async {
137 self.send_raw_unix(&request_bytes).await?;
138 self.receive_raw_unix().await
139 })
140 .await
141 .map_err(|_| AgentProtocolError::Timeout(self.timeout))??;
142
143 let agent_response: AgentResponse = serde_json::from_slice(&response)
145 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
146
147 if agent_response.version != PROTOCOL_VERSION {
149 return Err(AgentProtocolError::VersionMismatch {
150 expected: PROTOCOL_VERSION,
151 actual: agent_response.version,
152 });
153 }
154
155 Ok(agent_response)
156 }
157
158 async fn send_event_grpc(
160 &mut self,
161 event_type: EventType,
162 payload: impl Serialize,
163 ) -> Result<AgentResponse, AgentProtocolError> {
164 let grpc_request = Self::build_grpc_request(event_type, payload)?;
166
167 let AgentConnection::Grpc(client) = &mut self.connection else {
168 unreachable!()
169 };
170
171 let response = tokio::time::timeout(self.timeout, client.process_event(grpc_request))
173 .await
174 .map_err(|_| AgentProtocolError::Timeout(self.timeout))?
175 .map_err(|e| AgentProtocolError::ConnectionFailed(format!("gRPC call failed: {}", e)))?;
176
177 Self::convert_grpc_response(response.into_inner())
179 }
180
181 fn build_grpc_request(
183 event_type: EventType,
184 payload: impl Serialize,
185 ) -> Result<grpc::AgentRequest, AgentProtocolError> {
186 let payload_json = serde_json::to_value(&payload)
187 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
188
189 let grpc_event_type = match event_type {
190 EventType::RequestHeaders => grpc::EventType::RequestHeaders,
191 EventType::RequestBodyChunk => grpc::EventType::RequestBodyChunk,
192 EventType::ResponseHeaders => grpc::EventType::ResponseHeaders,
193 EventType::ResponseBodyChunk => grpc::EventType::ResponseBodyChunk,
194 EventType::RequestComplete => grpc::EventType::RequestComplete,
195 };
196
197 let event = match event_type {
198 EventType::RequestHeaders => {
199 let event: RequestHeadersEvent = serde_json::from_value(payload_json)
200 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
201 grpc::agent_request::Event::RequestHeaders(grpc::RequestHeadersEvent {
202 metadata: Some(Self::convert_metadata_to_grpc(&event.metadata)),
203 method: event.method,
204 uri: event.uri,
205 headers: event.headers.into_iter().map(|(k, v)| {
206 (k, grpc::HeaderValues { values: v })
207 }).collect(),
208 })
209 }
210 EventType::RequestBodyChunk => {
211 let event: RequestBodyChunkEvent = serde_json::from_value(payload_json)
212 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
213 grpc::agent_request::Event::RequestBodyChunk(grpc::RequestBodyChunkEvent {
214 correlation_id: event.correlation_id,
215 data: event.data.into_bytes(),
216 is_last: event.is_last,
217 total_size: event.total_size.map(|s| s as u64),
218 })
219 }
220 EventType::ResponseHeaders => {
221 let event: ResponseHeadersEvent = serde_json::from_value(payload_json)
222 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
223 grpc::agent_request::Event::ResponseHeaders(grpc::ResponseHeadersEvent {
224 correlation_id: event.correlation_id,
225 status: event.status as u32,
226 headers: event.headers.into_iter().map(|(k, v)| {
227 (k, grpc::HeaderValues { values: v })
228 }).collect(),
229 })
230 }
231 EventType::ResponseBodyChunk => {
232 let event: ResponseBodyChunkEvent = serde_json::from_value(payload_json)
233 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
234 grpc::agent_request::Event::ResponseBodyChunk(grpc::ResponseBodyChunkEvent {
235 correlation_id: event.correlation_id,
236 data: event.data.into_bytes(),
237 is_last: event.is_last,
238 total_size: event.total_size.map(|s| s as u64),
239 })
240 }
241 EventType::RequestComplete => {
242 let event: RequestCompleteEvent = serde_json::from_value(payload_json)
243 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
244 grpc::agent_request::Event::RequestComplete(grpc::RequestCompleteEvent {
245 correlation_id: event.correlation_id,
246 status: event.status as u32,
247 duration_ms: event.duration_ms,
248 request_body_size: event.request_body_size as u64,
249 response_body_size: event.response_body_size as u64,
250 upstream_attempts: event.upstream_attempts,
251 error: event.error,
252 })
253 }
254 };
255
256 Ok(grpc::AgentRequest {
257 version: PROTOCOL_VERSION,
258 event_type: grpc_event_type as i32,
259 event: Some(event),
260 })
261 }
262
263 fn convert_metadata_to_grpc(metadata: &RequestMetadata) -> grpc::RequestMetadata {
265 grpc::RequestMetadata {
266 correlation_id: metadata.correlation_id.clone(),
267 request_id: metadata.request_id.clone(),
268 client_ip: metadata.client_ip.clone(),
269 client_port: metadata.client_port as u32,
270 server_name: metadata.server_name.clone(),
271 protocol: metadata.protocol.clone(),
272 tls_version: metadata.tls_version.clone(),
273 tls_cipher: metadata.tls_cipher.clone(),
274 route_id: metadata.route_id.clone(),
275 upstream_id: metadata.upstream_id.clone(),
276 timestamp: metadata.timestamp.clone(),
277 }
278 }
279
280 fn convert_grpc_response(
282 response: grpc::AgentResponse,
283 ) -> Result<AgentResponse, AgentProtocolError> {
284 let decision = match response.decision {
285 Some(grpc::agent_response::Decision::Allow(_)) => Decision::Allow,
286 Some(grpc::agent_response::Decision::Block(b)) => Decision::Block {
287 status: b.status as u16,
288 body: b.body,
289 headers: if b.headers.is_empty() { None } else { Some(b.headers) },
290 },
291 Some(grpc::agent_response::Decision::Redirect(r)) => Decision::Redirect {
292 url: r.url,
293 status: r.status as u16,
294 },
295 Some(grpc::agent_response::Decision::Challenge(c)) => Decision::Challenge {
296 challenge_type: c.challenge_type,
297 params: c.params,
298 },
299 None => Decision::Allow, };
301
302 let request_headers: Vec<HeaderOp> = response.request_headers
303 .into_iter()
304 .filter_map(Self::convert_header_op_from_grpc)
305 .collect();
306
307 let response_headers: Vec<HeaderOp> = response.response_headers
308 .into_iter()
309 .filter_map(Self::convert_header_op_from_grpc)
310 .collect();
311
312 let audit = response.audit.map(|a| AuditMetadata {
313 tags: a.tags,
314 rule_ids: a.rule_ids,
315 confidence: a.confidence,
316 reason_codes: a.reason_codes,
317 custom: a.custom.into_iter().map(|(k, v)| {
318 (k, serde_json::Value::String(v))
319 }).collect(),
320 });
321
322 Ok(AgentResponse {
323 version: response.version,
324 decision,
325 request_headers,
326 response_headers,
327 routing_metadata: response.routing_metadata,
328 audit: audit.unwrap_or_default(),
329 })
330 }
331
332 fn convert_header_op_from_grpc(op: grpc::HeaderOp) -> Option<HeaderOp> {
334 match op.operation? {
335 grpc::header_op::Operation::Set(s) => Some(HeaderOp::Set {
336 name: s.name,
337 value: s.value,
338 }),
339 grpc::header_op::Operation::Add(a) => Some(HeaderOp::Add {
340 name: a.name,
341 value: a.value,
342 }),
343 grpc::header_op::Operation::Remove(r) => Some(HeaderOp::Remove {
344 name: r.name,
345 }),
346 }
347 }
348
349 async fn send_raw_unix(&mut self, data: &[u8]) -> Result<(), AgentProtocolError> {
351 let AgentConnection::UnixSocket(stream) = &mut self.connection else {
352 unreachable!()
353 };
354 let len_bytes = (data.len() as u32).to_be_bytes();
356 stream.write_all(&len_bytes).await?;
357 stream.write_all(data).await?;
359 stream.flush().await?;
360 Ok(())
361 }
362
363 async fn receive_raw_unix(&mut self) -> Result<Vec<u8>, AgentProtocolError> {
365 let AgentConnection::UnixSocket(stream) = &mut self.connection else {
366 unreachable!()
367 };
368 let mut len_bytes = [0u8; 4];
370 stream.read_exact(&mut len_bytes).await?;
371 let message_len = u32::from_be_bytes(len_bytes) as usize;
372
373 if message_len > MAX_MESSAGE_SIZE {
375 return Err(AgentProtocolError::MessageTooLarge {
376 size: message_len,
377 max: MAX_MESSAGE_SIZE,
378 });
379 }
380
381 let mut buffer = vec![0u8; message_len];
383 stream.read_exact(&mut buffer).await?;
384 Ok(buffer)
385 }
386
387 pub async fn close(self) -> Result<(), AgentProtocolError> {
389 match self.connection {
390 AgentConnection::UnixSocket(mut stream) => {
391 stream.shutdown().await?;
392 Ok(())
393 }
394 AgentConnection::Grpc(_) => Ok(()), }
396 }
397}