1use async_trait::async_trait;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::time::Instant;
10use tokio::sync::mpsc;
11use tokio_stream::{wrappers::ReceiverStream, Stream, StreamExt};
12use tonic::{Request, Response, Status, Streaming};
13use tracing::{debug, error, info, trace, warn};
14
15use crate::grpc_v2::{
16 self, agent_service_v2_server::AgentServiceV2, agent_service_v2_server::AgentServiceV2Server,
17 AgentToProxy, ProxyToAgent,
18};
19use crate::v2::pool::CHANNEL_BUFFER_SIZE;
20use crate::v2::{
21 AgentCapabilities, HandshakeRequest, HandshakeResponse, HealthStatus,
22};
23use crate::{
24 AgentResponse as V1Response, Decision, EventType, HeaderOp, RequestBodyChunkEvent,
25 RequestCompleteEvent, RequestHeadersEvent, RequestMetadata, ResponseBodyChunkEvent,
26 ResponseHeadersEvent, WebSocketFrameEvent,
27};
28
29#[async_trait]
38pub trait AgentHandlerV2: Send + Sync {
39 fn capabilities(&self) -> AgentCapabilities;
41
42 async fn on_handshake(&self, _request: HandshakeRequest) -> HandshakeResponse {
44 HandshakeResponse::success(self.capabilities())
46 }
47
48 async fn on_request_headers(&self, _event: RequestHeadersEvent) -> V1Response {
50 V1Response::default_allow()
51 }
52
53 async fn on_request_body_chunk(&self, _event: RequestBodyChunkEvent) -> V1Response {
55 V1Response::default_allow()
56 }
57
58 async fn on_response_headers(&self, _event: ResponseHeadersEvent) -> V1Response {
60 V1Response::default_allow()
61 }
62
63 async fn on_response_body_chunk(&self, _event: ResponseBodyChunkEvent) -> V1Response {
65 V1Response::default_allow()
66 }
67
68 async fn on_request_complete(&self, _event: RequestCompleteEvent) -> V1Response {
70 V1Response::default_allow()
71 }
72
73 async fn on_websocket_frame(&self, _event: WebSocketFrameEvent) -> V1Response {
75 V1Response::websocket_allow()
76 }
77
78 fn health_status(&self) -> HealthStatus {
80 HealthStatus::healthy(self.capabilities().agent_id.clone())
81 }
82
83 fn metrics_report(&self) -> Option<crate::v2::MetricsReport> {
85 None
86 }
87
88 async fn on_configure(&self, _config: serde_json::Value, _version: Option<String>) -> bool {
90 true
92 }
93
94 async fn on_shutdown(&self, _reason: ShutdownReason, _grace_period_ms: u64) {
96 }
98
99 async fn on_drain(&self, _duration_ms: u64, _reason: DrainReason) {
101 }
103
104 async fn on_stream_closed(&self) {}
106}
107
/// Reason attached to a shutdown request, passed to `AgentHandlerV2::on_shutdown`.
///
/// Derives `Hash` in addition to `Eq` so callers can key maps or sets by reason
/// (e.g. per-reason counters).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShutdownReason {
    Graceful,
    Immediate,
    ConfigReload,
    Upgrade,
}
116
/// Reason attached to a drain request, passed to `AgentHandlerV2::on_drain`.
///
/// Derives `Hash` in addition to `Eq` so callers can key maps or sets by reason.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DrainReason {
    ConfigReload,
    Maintenance,
    HealthCheckFailed,
    Manual,
}
125
126pub struct GrpcAgentServerV2 {
128 id: String,
129 handler: Arc<dyn AgentHandlerV2>,
130}
131
132impl GrpcAgentServerV2 {
133 pub fn new(id: impl Into<String>, handler: Box<dyn AgentHandlerV2>) -> Self {
135 let id = id.into();
136 debug!(agent_id = %id, "Creating gRPC agent server v2");
137 Self {
138 id,
139 handler: Arc::from(handler),
140 }
141 }
142
143 pub fn into_service(self) -> AgentServiceV2Server<GrpcAgentHandlerV2> {
145 trace!(agent_id = %self.id, "Converting to tonic v2 service");
146 AgentServiceV2Server::new(GrpcAgentHandlerV2 {
147 id: self.id,
148 handler: self.handler,
149 })
150 }
151
152 pub async fn run(self, addr: std::net::SocketAddr) -> Result<(), crate::AgentProtocolError> {
154 info!(
155 agent_id = %self.id,
156 address = %addr,
157 "gRPC agent server v2 listening"
158 );
159
160 tonic::transport::Server::builder()
161 .add_service(self.into_service())
162 .serve(addr)
163 .await
164 .map_err(|e| {
165 error!(error = %e, "gRPC v2 server error");
166 crate::AgentProtocolError::ConnectionFailed(format!("gRPC v2 server error: {}", e))
167 })
168 }
169}
170
171pub struct GrpcAgentHandlerV2 {
173 id: String,
174 handler: Arc<dyn AgentHandlerV2>,
175}
176
177type ProcessResponseStream = Pin<Box<dyn Stream<Item = Result<AgentToProxy, Status>> + Send>>;
178type ControlResponseStream = Pin<Box<dyn Stream<Item = Result<grpc_v2::ProxyControl, Status>> + Send>>;
179
180#[tonic::async_trait]
181impl AgentServiceV2 for GrpcAgentHandlerV2 {
182 type ProcessStreamStream = ProcessResponseStream;
183 type ControlStreamStream = ControlResponseStream;
184
185 async fn process_stream(
187 &self,
188 request: Request<Streaming<ProxyToAgent>>,
189 ) -> Result<Response<Self::ProcessStreamStream>, Status> {
190 let mut inbound = request.into_inner();
191 let (tx, rx) = mpsc::channel(CHANNEL_BUFFER_SIZE);
192 let handler = Arc::clone(&self.handler);
193 let agent_id = self.id.clone();
194
195 debug!(agent_id = %agent_id, "Starting v2 process stream");
196
197 tokio::spawn(async move {
198 let mut handshake_done = false;
199
200 while let Some(result) = inbound.next().await {
201 let msg = match result {
202 Ok(m) => m,
203 Err(e) => {
204 error!(agent_id = %agent_id, error = %e, "Stream error");
205 break;
206 }
207 };
208
209 let response = match msg.message {
210 Some(grpc_v2::proxy_to_agent::Message::Handshake(req)) => {
211 trace!(agent_id = %agent_id, "Processing handshake");
212 let handshake_req = convert_handshake_request(req);
213 let resp = handler.on_handshake(handshake_req).await;
214 handshake_done = resp.success;
215 Some(AgentToProxy {
216 message: Some(grpc_v2::agent_to_proxy::Message::Handshake(
217 convert_handshake_response(resp),
218 )),
219 })
220 }
221 Some(grpc_v2::proxy_to_agent::Message::RequestHeaders(e)) => {
222 if !handshake_done {
223 warn!(agent_id = %agent_id, "Received event before handshake");
224 continue;
225 }
226 let event = convert_request_headers_from_grpc(e);
227 let correlation_id = event.metadata.correlation_id.clone();
228 let start = Instant::now();
229 let resp = handler.on_request_headers(event).await;
230 let processing_time_ms = start.elapsed().as_millis() as u64;
231 Some(create_agent_response(correlation_id, resp, processing_time_ms))
232 }
233 Some(grpc_v2::proxy_to_agent::Message::RequestBodyChunk(e)) => {
234 if !handshake_done {
235 continue;
236 }
237 let event = convert_body_chunk_to_request(e);
238 let correlation_id = event.correlation_id.clone();
239 let start = Instant::now();
240 let resp = handler.on_request_body_chunk(event).await;
241 let processing_time_ms = start.elapsed().as_millis() as u64;
242 Some(create_agent_response(correlation_id, resp, processing_time_ms))
243 }
244 Some(grpc_v2::proxy_to_agent::Message::ResponseHeaders(e)) => {
245 if !handshake_done {
246 continue;
247 }
248 let event = convert_response_headers_from_grpc(e);
249 let correlation_id = event.correlation_id.clone();
250 let start = Instant::now();
251 let resp = handler.on_response_headers(event).await;
252 let processing_time_ms = start.elapsed().as_millis() as u64;
253 Some(create_agent_response(correlation_id, resp, processing_time_ms))
254 }
255 Some(grpc_v2::proxy_to_agent::Message::ResponseBodyChunk(e)) => {
256 if !handshake_done {
257 continue;
258 }
259 let event = convert_body_chunk_to_response(e);
260 let correlation_id = event.correlation_id.clone();
261 let start = Instant::now();
262 let resp = handler.on_response_body_chunk(event).await;
263 let processing_time_ms = start.elapsed().as_millis() as u64;
264 Some(create_agent_response(correlation_id, resp, processing_time_ms))
265 }
266 Some(grpc_v2::proxy_to_agent::Message::RequestComplete(e)) => {
267 if !handshake_done {
268 continue;
269 }
270 let event = convert_request_complete_from_grpc(e);
271 let correlation_id = event.correlation_id.clone();
272 let start = Instant::now();
273 let resp = handler.on_request_complete(event).await;
274 let processing_time_ms = start.elapsed().as_millis() as u64;
275 Some(create_agent_response(correlation_id, resp, processing_time_ms))
276 }
277 Some(grpc_v2::proxy_to_agent::Message::WebsocketFrame(e)) => {
278 if !handshake_done {
279 continue;
280 }
281 let event = convert_websocket_frame_from_grpc(e);
282 let correlation_id = event.correlation_id.clone();
283 let start = Instant::now();
284 let resp = handler.on_websocket_frame(event).await;
285 let processing_time_ms = start.elapsed().as_millis() as u64;
286 Some(create_agent_response(correlation_id, resp, processing_time_ms))
287 }
288 Some(grpc_v2::proxy_to_agent::Message::Ping(ping)) => {
289 trace!(agent_id = %agent_id, sequence = ping.sequence, "Received ping");
290 Some(AgentToProxy {
291 message: Some(grpc_v2::agent_to_proxy::Message::Pong(grpc_v2::Pong {
292 sequence: ping.sequence,
293 ping_timestamp_ms: ping.timestamp_ms,
294 timestamp_ms: now_ms(),
295 })),
296 })
297 }
298 Some(grpc_v2::proxy_to_agent::Message::Cancel(cancel)) => {
299 debug!(
300 agent_id = %agent_id,
301 correlation_id = %cancel.correlation_id,
302 "Request cancelled"
303 );
304 None
305 }
306 Some(grpc_v2::proxy_to_agent::Message::Configure(_)) => {
307 None
309 }
310 Some(grpc_v2::proxy_to_agent::Message::Guardrail(_)) => {
311 None
313 }
314 None => {
315 warn!(agent_id = %agent_id, "Empty message received");
316 None
317 }
318 };
319
320 if let Some(resp) = response {
321 if tx.send(Ok(resp)).await.is_err() {
322 debug!(agent_id = %agent_id, "Stream closed by receiver");
323 break;
324 }
325 }
326 }
327
328 handler.on_stream_closed().await;
329 debug!(agent_id = %agent_id, "Process stream ended");
330 });
331
332 let output_stream = ReceiverStream::new(rx);
333 Ok(Response::new(Box::pin(output_stream) as Self::ProcessStreamStream))
334 }
335
336 async fn control_stream(
342 &self,
343 request: Request<Streaming<grpc_v2::AgentControl>>,
344 ) -> Result<Response<Self::ControlStreamStream>, Status> {
345 let mut inbound = request.into_inner();
346 let (tx, rx) = mpsc::channel::<Result<grpc_v2::ProxyControl, Status>>(16);
347 let handler = Arc::clone(&self.handler);
348 let agent_id = self.id.clone();
349
350 debug!(agent_id = %agent_id, "Starting v2 control stream");
351
352 let _handler_clone = Arc::clone(&handler);
354 let tx_clone = tx.clone();
355 let agent_id_clone = agent_id.clone();
356 tokio::spawn(async move {
357 while let Some(result) = inbound.next().await {
358 let msg = match result {
359 Ok(m) => m,
360 Err(e) => {
361 error!(agent_id = %agent_id_clone, error = %e, "Control stream error");
362 break;
363 }
364 };
365
366 match msg.message {
370 Some(grpc_v2::agent_control::Message::Health(health)) => {
371 trace!(
372 agent_id = %agent_id_clone,
373 state = health.state,
374 "Received health status from agent"
375 );
376 }
378 Some(grpc_v2::agent_control::Message::Metrics(metrics)) => {
379 trace!(
380 agent_id = %agent_id_clone,
381 counters = metrics.counters.len(),
382 gauges = metrics.gauges.len(),
383 "Received metrics report from agent"
384 );
385 }
387 Some(grpc_v2::agent_control::Message::ConfigUpdate(update)) => {
388 debug!(
389 agent_id = %agent_id_clone,
390 request_id = %update.request_id,
391 "Received config update request from agent"
392 );
393 let response = grpc_v2::ProxyControl {
395 message: Some(grpc_v2::proxy_control::Message::ConfigResponse(
396 grpc_v2::ConfigUpdateResponse {
397 request_id: update.request_id,
398 accepted: true,
399 error: None,
400 timestamp_ms: now_ms(),
401 },
402 )),
403 };
404 if tx_clone.send(Ok(response)).await.is_err() {
405 break;
406 }
407 }
408 Some(grpc_v2::agent_control::Message::Log(log)) => {
409 match log.level {
411 1 => trace!(agent_id = %agent_id_clone, msg = %log.message, "Agent log"),
412 2 => debug!(agent_id = %agent_id_clone, msg = %log.message, "Agent log"),
413 3 => warn!(agent_id = %agent_id_clone, msg = %log.message, "Agent log"),
414 4 => error!(agent_id = %agent_id_clone, msg = %log.message, "Agent log"),
415 _ => info!(agent_id = %agent_id_clone, msg = %log.message, "Agent log"),
416 }
417 }
418 None => {
419 warn!(agent_id = %agent_id_clone, "Empty control message received");
420 }
421 }
422 }
423
424 debug!(agent_id = %agent_id_clone, "Control stream inbound handler ended");
425 });
426
427 let capabilities = handler.capabilities();
429 let health_interval_ms = capabilities.health.report_interval_ms;
430 let metrics_enabled = capabilities.features.metrics_export;
431
432 if health_interval_ms > 0 || metrics_enabled {
433 let handler_for_health = Arc::clone(&handler);
434 let tx_for_health = tx;
435 let agent_id_for_health = agent_id.clone();
436
437 tokio::spawn(async move {
438 let health_interval = std::time::Duration::from_millis(health_interval_ms as u64);
439 let mut interval = tokio::time::interval(health_interval);
440
441 loop {
442 interval.tick().await;
443
444 let health = handler_for_health.health_status();
446 let health_msg = grpc_v2::ProxyControl {
447 message: Some(grpc_v2::proxy_control::Message::Configure(
448 grpc_v2::ConfigureEvent {
449 config_json: "{}".to_string(), config_version: None,
451 is_initial: false,
452 timestamp_ms: now_ms(),
453 },
454 )),
455 };
456
457 let _ = health_msg; let _ = health;
462
463 if tx_for_health.is_closed() {
465 debug!(agent_id = %agent_id_for_health, "Control stream closed, stopping health reporter");
466 break;
467 }
468 }
469 });
470 }
471
472 let output_stream = ReceiverStream::new(rx);
473 Ok(Response::new(Box::pin(output_stream) as Self::ControlStreamStream))
474 }
475
476 async fn process_event(
478 &self,
479 request: Request<ProxyToAgent>,
480 ) -> Result<Response<AgentToProxy>, Status> {
481 let msg = request.into_inner();
482
483 trace!(agent_id = %self.id, "Processing single event (v1 compat)");
484
485 let response = match msg.message {
486 Some(grpc_v2::proxy_to_agent::Message::Handshake(req)) => {
487 let handshake_req = convert_handshake_request(req);
488 let resp = self.handler.on_handshake(handshake_req).await;
489 AgentToProxy {
490 message: Some(grpc_v2::agent_to_proxy::Message::Handshake(
491 convert_handshake_response(resp),
492 )),
493 }
494 }
495 Some(grpc_v2::proxy_to_agent::Message::RequestHeaders(e)) => {
496 let event = convert_request_headers_from_grpc(e);
497 let correlation_id = event.metadata.correlation_id.clone();
498 let start = Instant::now();
499 let resp = self.handler.on_request_headers(event).await;
500 let processing_time_ms = start.elapsed().as_millis() as u64;
501 create_agent_response(correlation_id, resp, processing_time_ms)
502 }
503 Some(grpc_v2::proxy_to_agent::Message::Ping(ping)) => AgentToProxy {
504 message: Some(grpc_v2::agent_to_proxy::Message::Pong(grpc_v2::Pong {
505 sequence: ping.sequence,
506 ping_timestamp_ms: ping.timestamp_ms,
507 timestamp_ms: now_ms(),
508 })),
509 },
510 _ => {
511 return Err(Status::invalid_argument("Unsupported event type"));
512 }
513 };
514
515 Ok(Response::new(response))
516 }
517}
518
519fn convert_handshake_request(req: grpc_v2::HandshakeRequest) -> HandshakeRequest {
524 HandshakeRequest {
525 supported_versions: req.supported_versions,
526 proxy_id: req.proxy_id,
527 proxy_version: req.proxy_version,
528 config: serde_json::from_str(&req.config_json).unwrap_or(serde_json::Value::Null),
529 }
530}
531
532fn convert_handshake_response(resp: HandshakeResponse) -> grpc_v2::HandshakeResponse {
533 grpc_v2::HandshakeResponse {
534 protocol_version: resp.protocol_version,
535 capabilities: Some(convert_capabilities_to_grpc(&resp.capabilities)),
536 success: resp.success,
537 error: resp.error,
538 }
539}
540
541fn convert_capabilities_to_grpc(caps: &AgentCapabilities) -> grpc_v2::AgentCapabilities {
542 grpc_v2::AgentCapabilities {
543 protocol_version: caps.protocol_version,
544 agent_id: caps.agent_id.clone(),
545 name: caps.name.clone(),
546 version: caps.version.clone(),
547 supported_events: caps
548 .supported_events
549 .iter()
550 .map(|e| event_type_to_i32(*e))
551 .collect(),
552 features: Some(grpc_v2::AgentFeatures {
553 streaming_body: caps.features.streaming_body,
554 websocket: caps.features.websocket,
555 guardrails: caps.features.guardrails,
556 config_push: caps.features.config_push,
557 metrics_export: caps.features.metrics_export,
558 concurrent_requests: caps.features.concurrent_requests,
559 cancellation: caps.features.cancellation,
560 flow_control: caps.features.flow_control,
561 health_reporting: caps.features.health_reporting,
562 }),
563 limits: Some(grpc_v2::AgentLimits {
564 max_body_size: caps.limits.max_body_size as u64,
565 max_concurrency: caps.limits.max_concurrency,
566 preferred_chunk_size: caps.limits.preferred_chunk_size as u64,
567 max_memory: caps.limits.max_memory.map(|m| m as u64),
568 max_processing_time_ms: caps.limits.max_processing_time_ms,
569 }),
570 health_config: Some(grpc_v2::HealthConfig {
571 report_interval_ms: caps.health.report_interval_ms,
572 include_load_metrics: caps.health.include_load_metrics,
573 include_resource_metrics: caps.health.include_resource_metrics,
574 }),
575 }
576}
577
578fn event_type_to_i32(event_type: EventType) -> i32 {
579 match event_type {
580 EventType::Configure => 8,
581 EventType::RequestHeaders => 1,
582 EventType::RequestBodyChunk => 2,
583 EventType::ResponseHeaders => 3,
584 EventType::ResponseBodyChunk => 4,
585 EventType::RequestComplete => 5,
586 EventType::WebSocketFrame => 6,
587 EventType::GuardrailInspect => 7,
588 }
589}
590
591fn convert_request_headers_from_grpc(e: grpc_v2::RequestHeadersEvent) -> RequestHeadersEvent {
592 let metadata = match e.metadata {
593 Some(m) => RequestMetadata {
594 correlation_id: m.correlation_id,
595 request_id: m.request_id,
596 client_ip: m.client_ip,
597 client_port: m.client_port as u16,
598 server_name: m.server_name,
599 protocol: m.protocol,
600 tls_version: m.tls_version,
601 tls_cipher: None,
602 route_id: m.route_id,
603 upstream_id: m.upstream_id,
604 timestamp: format!("{}", m.timestamp_ms),
605 traceparent: m.traceparent,
606 },
607 None => RequestMetadata {
608 correlation_id: String::new(),
609 request_id: String::new(),
610 client_ip: String::new(),
611 client_port: 0,
612 server_name: None,
613 protocol: String::new(),
614 tls_version: None,
615 tls_cipher: None,
616 route_id: None,
617 upstream_id: None,
618 timestamp: String::new(),
619 traceparent: None,
620 },
621 };
622
623 let headers = e
624 .headers
625 .into_iter()
626 .fold(std::collections::HashMap::new(), |mut map, h| {
627 map.entry(h.name).or_insert_with(Vec::new).push(h.value);
628 map
629 });
630
631 RequestHeadersEvent {
632 metadata,
633 method: e.method,
634 uri: e.uri,
635 headers,
636 }
637}
638
639fn convert_body_chunk_to_request(e: grpc_v2::BodyChunkEvent) -> RequestBodyChunkEvent {
640 use base64::{engine::general_purpose::STANDARD, Engine as _};
641 RequestBodyChunkEvent {
642 correlation_id: e.correlation_id,
643 data: STANDARD.encode(&e.data),
644 is_last: e.is_last,
645 total_size: e.total_size.map(|s| s as usize),
646 chunk_index: e.chunk_index,
647 bytes_received: e.bytes_transferred as usize,
648 }
649}
650
651fn convert_body_chunk_to_response(e: grpc_v2::BodyChunkEvent) -> ResponseBodyChunkEvent {
652 use base64::{engine::general_purpose::STANDARD, Engine as _};
653 ResponseBodyChunkEvent {
654 correlation_id: e.correlation_id,
655 data: STANDARD.encode(&e.data),
656 is_last: e.is_last,
657 total_size: e.total_size.map(|s| s as usize),
658 chunk_index: e.chunk_index,
659 bytes_sent: e.bytes_transferred as usize,
660 }
661}
662
663fn convert_response_headers_from_grpc(e: grpc_v2::ResponseHeadersEvent) -> ResponseHeadersEvent {
664 let headers = e
665 .headers
666 .into_iter()
667 .fold(std::collections::HashMap::new(), |mut map, h| {
668 map.entry(h.name).or_insert_with(Vec::new).push(h.value);
669 map
670 });
671
672 ResponseHeadersEvent {
673 correlation_id: e.correlation_id,
674 status: e.status_code as u16,
675 headers,
676 }
677}
678
679fn convert_request_complete_from_grpc(e: grpc_v2::RequestCompleteEvent) -> RequestCompleteEvent {
680 RequestCompleteEvent {
681 correlation_id: e.correlation_id,
682 status: e.status_code as u16,
683 duration_ms: e.duration_ms,
684 request_body_size: e.bytes_received as usize,
685 response_body_size: e.bytes_sent as usize,
686 upstream_attempts: 1,
687 error: e.error,
688 }
689}
690
691fn convert_websocket_frame_from_grpc(e: grpc_v2::WebSocketFrameEvent) -> WebSocketFrameEvent {
692 use base64::{engine::general_purpose::STANDARD, Engine as _};
693 WebSocketFrameEvent {
694 correlation_id: e.correlation_id,
695 opcode: format!("{}", e.frame_type),
696 data: STANDARD.encode(&e.payload),
697 client_to_server: e.client_to_server,
698 frame_index: 0,
699 fin: true,
700 route_id: None,
701 client_ip: String::new(),
702 }
703}
704
705fn create_agent_response(
706 correlation_id: String,
707 resp: V1Response,
708 processing_time_ms: u64,
709) -> AgentToProxy {
710 let decision = match resp.decision {
711 Decision::Allow => Some(grpc_v2::agent_response::Decision::Allow(
712 grpc_v2::AllowDecision {},
713 )),
714 Decision::Block { status, body, headers } => {
715 Some(grpc_v2::agent_response::Decision::Block(
716 grpc_v2::BlockDecision {
717 status: status as u32,
718 body,
719 headers: headers
720 .unwrap_or_default()
721 .into_iter()
722 .map(|(k, v)| grpc_v2::Header { name: k, value: v })
723 .collect(),
724 },
725 ))
726 }
727 Decision::Redirect { url, status } => Some(grpc_v2::agent_response::Decision::Redirect(
728 grpc_v2::RedirectDecision {
729 url,
730 status: status as u32,
731 },
732 )),
733 Decision::Challenge { challenge_type, params } => {
734 Some(grpc_v2::agent_response::Decision::Challenge(
735 grpc_v2::ChallengeDecision {
736 challenge_type,
737 params,
738 },
739 ))
740 }
741 };
742
743 let request_headers: Vec<grpc_v2::HeaderOp> = resp
744 .request_headers
745 .into_iter()
746 .map(convert_header_op_to_grpc)
747 .collect();
748
749 let response_headers: Vec<grpc_v2::HeaderOp> = resp
750 .response_headers
751 .into_iter()
752 .map(convert_header_op_to_grpc)
753 .collect();
754
755 let audit = Some(grpc_v2::AuditMetadata {
756 tags: resp.audit.tags,
757 rule_ids: resp.audit.rule_ids,
758 confidence: resp.audit.confidence,
759 reason_codes: resp.audit.reason_codes,
760 custom: resp
761 .audit
762 .custom
763 .into_iter()
764 .map(|(k, v)| (k, v.to_string()))
765 .collect(),
766 });
767
768 AgentToProxy {
769 message: Some(grpc_v2::agent_to_proxy::Message::Response(
770 grpc_v2::AgentResponse {
771 correlation_id,
772 decision,
773 request_headers,
774 response_headers,
775 audit,
776 processing_time_ms: Some(processing_time_ms),
777 needs_more: resp.needs_more,
778 },
779 )),
780 }
781}
782
783fn convert_header_op_to_grpc(op: HeaderOp) -> grpc_v2::HeaderOp {
784 let operation = match op {
785 HeaderOp::Set { name, value } => {
786 Some(grpc_v2::header_op::Operation::Set(grpc_v2::Header {
787 name,
788 value,
789 }))
790 }
791 HeaderOp::Add { name, value } => {
792 Some(grpc_v2::header_op::Operation::Add(grpc_v2::Header {
793 name,
794 value,
795 }))
796 }
797 HeaderOp::Remove { name } => Some(grpc_v2::header_op::Operation::Remove(name)),
798 };
799 grpc_v2::HeaderOp { operation }
800}
801
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch.
fn now_ms() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
808
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal handler relying on every default method.
    struct TestHandlerV2;

    #[async_trait]
    impl AgentHandlerV2 for TestHandlerV2 {
        fn capabilities(&self) -> AgentCapabilities {
            AgentCapabilities::new("test-v2", "Test Agent V2", "1.0.0")
        }
    }

    #[test]
    fn test_create_server() {
        let server = GrpcAgentServerV2::new("test", Box::new(TestHandlerV2));
        assert_eq!(server.id, "test");
    }

    #[test]
    fn test_into_service_builds() {
        let _service = GrpcAgentServerV2::new("test", Box::new(TestHandlerV2)).into_service();
    }

    #[test]
    fn test_default_metrics_report_is_none() {
        assert!(TestHandlerV2.metrics_report().is_none());
    }

    #[test]
    fn test_event_type_wire_codes() {
        let expected = [
            (EventType::RequestHeaders, 1),
            (EventType::RequestBodyChunk, 2),
            (EventType::ResponseHeaders, 3),
            (EventType::ResponseBodyChunk, 4),
            (EventType::RequestComplete, 5),
            (EventType::WebSocketFrame, 6),
            (EventType::GuardrailInspect, 7),
            (EventType::Configure, 8),
        ];
        for (event_type, code) in expected.iter().copied() {
            assert_eq!(event_type_to_i32(event_type), code);
        }
    }

    #[test]
    fn test_header_op_remove_conversion() {
        let op = convert_header_op_to_grpc(HeaderOp::Remove {
            name: "x-test".to_string(),
        });
        assert!(matches!(
            op.operation,
            Some(grpc_v2::header_op::Operation::Remove(ref name)) if name == "x-test"
        ));
    }

    #[test]
    fn test_now_ms_is_after_2020() {
        assert!(now_ms() > 1_600_000_000_000);
    }
}