1use async_trait::async_trait;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::time::Instant;
10use tokio::sync::mpsc;
11use tokio_stream::{wrappers::ReceiverStream, Stream, StreamExt};
12use tonic::{Request, Response, Status, Streaming};
13use tracing::{debug, error, info, trace, warn};
14
15use crate::grpc_v2::{
16 self, agent_service_v2_server::AgentServiceV2, agent_service_v2_server::AgentServiceV2Server,
17 AgentToProxy, ProxyToAgent,
18};
19use crate::v2::pool::CHANNEL_BUFFER_SIZE;
20use crate::v2::{AgentCapabilities, HandshakeRequest, HandshakeResponse, HealthStatus};
21use crate::{
22 AgentResponse as V1Response, Decision, EventType, HeaderOp, RequestBodyChunkEvent,
23 RequestCompleteEvent, RequestHeadersEvent, RequestMetadata, ResponseBodyChunkEvent,
24 ResponseHeadersEvent, WebSocketFrameEvent,
25};
26
27#[async_trait]
36pub trait AgentHandlerV2: Send + Sync {
37 fn capabilities(&self) -> AgentCapabilities;
39
40 async fn on_handshake(&self, _request: HandshakeRequest) -> HandshakeResponse {
42 HandshakeResponse::success(self.capabilities())
44 }
45
46 async fn on_request_headers(&self, _event: RequestHeadersEvent) -> V1Response {
48 V1Response::default_allow()
49 }
50
51 async fn on_request_body_chunk(&self, _event: RequestBodyChunkEvent) -> V1Response {
53 V1Response::default_allow()
54 }
55
56 async fn on_response_headers(&self, _event: ResponseHeadersEvent) -> V1Response {
58 V1Response::default_allow()
59 }
60
61 async fn on_response_body_chunk(&self, _event: ResponseBodyChunkEvent) -> V1Response {
63 V1Response::default_allow()
64 }
65
66 async fn on_request_complete(&self, _event: RequestCompleteEvent) -> V1Response {
68 V1Response::default_allow()
69 }
70
71 async fn on_websocket_frame(&self, _event: WebSocketFrameEvent) -> V1Response {
73 V1Response::websocket_allow()
74 }
75
76 fn health_status(&self) -> HealthStatus {
78 HealthStatus::healthy(self.capabilities().agent_id.clone())
79 }
80
81 fn metrics_report(&self) -> Option<crate::v2::MetricsReport> {
83 None
84 }
85
86 async fn on_configure(&self, _config: serde_json::Value, _version: Option<String>) -> bool {
88 true
90 }
91
92 async fn on_shutdown(&self, _reason: ShutdownReason, _grace_period_ms: u64) {
94 }
96
97 async fn on_drain(&self, _duration_ms: u64, _reason: DrainReason) {
99 }
101
102 async fn on_stream_closed(&self) {}
104}
105
/// Why the agent is being asked to shut down; passed to `AgentHandlerV2::on_shutdown`.
///
/// Variant order is kept stable so the implicit discriminants never change.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShutdownReason {
    /// Orderly shutdown.
    Graceful,
    /// Shutdown without waiting.
    Immediate,
    /// Shutdown triggered by a configuration reload.
    ConfigReload,
    /// Shutdown triggered by an upgrade.
    Upgrade,
}
114
/// Why the agent is being asked to drain; passed to `AgentHandlerV2::on_drain`.
///
/// Variant order is kept stable so the implicit discriminants never change.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DrainReason {
    /// Drain caused by a configuration reload.
    ConfigReload,
    /// Drain for maintenance.
    Maintenance,
    /// Drain after a failed health check.
    HealthCheckFailed,
    /// Operator-requested drain.
    Manual,
}
123
124pub struct GrpcAgentServerV2 {
126 id: String,
127 handler: Arc<dyn AgentHandlerV2>,
128}
129
130impl GrpcAgentServerV2 {
131 pub fn new(id: impl Into<String>, handler: Box<dyn AgentHandlerV2>) -> Self {
133 let id = id.into();
134 debug!(agent_id = %id, "Creating gRPC agent server v2");
135 Self {
136 id,
137 handler: Arc::from(handler),
138 }
139 }
140
141 pub fn into_service(self) -> AgentServiceV2Server<GrpcAgentHandlerV2> {
143 trace!(agent_id = %self.id, "Converting to tonic v2 service");
144 AgentServiceV2Server::new(GrpcAgentHandlerV2 {
145 id: self.id,
146 handler: self.handler,
147 })
148 }
149
150 pub async fn run(self, addr: std::net::SocketAddr) -> Result<(), crate::AgentProtocolError> {
152 info!(
153 agent_id = %self.id,
154 address = %addr,
155 "gRPC agent server v2 listening"
156 );
157
158 tonic::transport::Server::builder()
159 .add_service(self.into_service())
160 .serve(addr)
161 .await
162 .map_err(|e| {
163 error!(error = %e, "gRPC v2 server error");
164 crate::AgentProtocolError::ConnectionFailed(format!("gRPC v2 server error: {}", e))
165 })
166 }
167}
168
169pub struct GrpcAgentHandlerV2 {
171 id: String,
172 handler: Arc<dyn AgentHandlerV2>,
173}
174
175type ProcessResponseStream = Pin<Box<dyn Stream<Item = Result<AgentToProxy, Status>> + Send>>;
176type ControlResponseStream =
177 Pin<Box<dyn Stream<Item = Result<grpc_v2::ProxyControl, Status>> + Send>>;
178
179#[tonic::async_trait]
180impl AgentServiceV2 for GrpcAgentHandlerV2 {
181 type ProcessStreamStream = ProcessResponseStream;
182 type ControlStreamStream = ControlResponseStream;
183
184 async fn process_stream(
186 &self,
187 request: Request<Streaming<ProxyToAgent>>,
188 ) -> Result<Response<Self::ProcessStreamStream>, Status> {
189 let mut inbound = request.into_inner();
190 let (tx, rx) = mpsc::channel(CHANNEL_BUFFER_SIZE);
191 let handler = Arc::clone(&self.handler);
192 let agent_id = self.id.clone();
193
194 debug!(agent_id = %agent_id, "Starting v2 process stream");
195
196 tokio::spawn(async move {
197 let mut handshake_done = false;
198
199 while let Some(result) = inbound.next().await {
200 let msg = match result {
201 Ok(m) => m,
202 Err(e) => {
203 error!(agent_id = %agent_id, error = %e, "Stream error");
204 break;
205 }
206 };
207
208 let response = match msg.message {
209 Some(grpc_v2::proxy_to_agent::Message::Handshake(req)) => {
210 trace!(agent_id = %agent_id, "Processing handshake");
211 let handshake_req = convert_handshake_request(req);
212 let resp = handler.on_handshake(handshake_req).await;
213 handshake_done = resp.success;
214 Some(AgentToProxy {
215 message: Some(grpc_v2::agent_to_proxy::Message::Handshake(
216 convert_handshake_response(resp),
217 )),
218 })
219 }
220 Some(grpc_v2::proxy_to_agent::Message::RequestHeaders(e)) => {
221 if !handshake_done {
222 warn!(agent_id = %agent_id, "Received event before handshake");
223 continue;
224 }
225 let event = convert_request_headers_from_grpc(e);
226 let correlation_id = event.metadata.correlation_id.clone();
227 let start = Instant::now();
228 let resp = handler.on_request_headers(event).await;
229 let processing_time_ms = start.elapsed().as_millis() as u64;
230 Some(create_agent_response(
231 correlation_id,
232 resp,
233 processing_time_ms,
234 ))
235 }
236 Some(grpc_v2::proxy_to_agent::Message::RequestBodyChunk(e)) => {
237 if !handshake_done {
238 continue;
239 }
240 let event = convert_body_chunk_to_request(e);
241 let correlation_id = event.correlation_id.clone();
242 let start = Instant::now();
243 let resp = handler.on_request_body_chunk(event).await;
244 let processing_time_ms = start.elapsed().as_millis() as u64;
245 Some(create_agent_response(
246 correlation_id,
247 resp,
248 processing_time_ms,
249 ))
250 }
251 Some(grpc_v2::proxy_to_agent::Message::ResponseHeaders(e)) => {
252 if !handshake_done {
253 continue;
254 }
255 let event = convert_response_headers_from_grpc(e);
256 let correlation_id = event.correlation_id.clone();
257 let start = Instant::now();
258 let resp = handler.on_response_headers(event).await;
259 let processing_time_ms = start.elapsed().as_millis() as u64;
260 Some(create_agent_response(
261 correlation_id,
262 resp,
263 processing_time_ms,
264 ))
265 }
266 Some(grpc_v2::proxy_to_agent::Message::ResponseBodyChunk(e)) => {
267 if !handshake_done {
268 continue;
269 }
270 let event = convert_body_chunk_to_response(e);
271 let correlation_id = event.correlation_id.clone();
272 let start = Instant::now();
273 let resp = handler.on_response_body_chunk(event).await;
274 let processing_time_ms = start.elapsed().as_millis() as u64;
275 Some(create_agent_response(
276 correlation_id,
277 resp,
278 processing_time_ms,
279 ))
280 }
281 Some(grpc_v2::proxy_to_agent::Message::RequestComplete(e)) => {
282 if !handshake_done {
283 continue;
284 }
285 let event = convert_request_complete_from_grpc(e);
286 let correlation_id = event.correlation_id.clone();
287 let start = Instant::now();
288 let resp = handler.on_request_complete(event).await;
289 let processing_time_ms = start.elapsed().as_millis() as u64;
290 Some(create_agent_response(
291 correlation_id,
292 resp,
293 processing_time_ms,
294 ))
295 }
296 Some(grpc_v2::proxy_to_agent::Message::WebsocketFrame(e)) => {
297 if !handshake_done {
298 continue;
299 }
300 let event = convert_websocket_frame_from_grpc(e);
301 let correlation_id = event.correlation_id.clone();
302 let start = Instant::now();
303 let resp = handler.on_websocket_frame(event).await;
304 let processing_time_ms = start.elapsed().as_millis() as u64;
305 Some(create_agent_response(
306 correlation_id,
307 resp,
308 processing_time_ms,
309 ))
310 }
311 Some(grpc_v2::proxy_to_agent::Message::Ping(ping)) => {
312 trace!(agent_id = %agent_id, sequence = ping.sequence, "Received ping");
313 Some(AgentToProxy {
314 message: Some(grpc_v2::agent_to_proxy::Message::Pong(grpc_v2::Pong {
315 sequence: ping.sequence,
316 ping_timestamp_ms: ping.timestamp_ms,
317 timestamp_ms: now_ms(),
318 })),
319 })
320 }
321 Some(grpc_v2::proxy_to_agent::Message::Cancel(cancel)) => {
322 debug!(
323 agent_id = %agent_id,
324 correlation_id = %cancel.correlation_id,
325 "Request cancelled"
326 );
327 None
328 }
329 Some(grpc_v2::proxy_to_agent::Message::Configure(_)) => {
330 None
332 }
333 Some(grpc_v2::proxy_to_agent::Message::Guardrail(_)) => {
334 None
336 }
337 None => {
338 warn!(agent_id = %agent_id, "Empty message received");
339 None
340 }
341 };
342
343 if let Some(resp) = response {
344 if tx.send(Ok(resp)).await.is_err() {
345 debug!(agent_id = %agent_id, "Stream closed by receiver");
346 break;
347 }
348 }
349 }
350
351 handler.on_stream_closed().await;
352 debug!(agent_id = %agent_id, "Process stream ended");
353 });
354
355 let output_stream = ReceiverStream::new(rx);
356 Ok(Response::new(
357 Box::pin(output_stream) as Self::ProcessStreamStream
358 ))
359 }
360
361 async fn control_stream(
367 &self,
368 request: Request<Streaming<grpc_v2::AgentControl>>,
369 ) -> Result<Response<Self::ControlStreamStream>, Status> {
370 let mut inbound = request.into_inner();
371 let (tx, rx) = mpsc::channel::<Result<grpc_v2::ProxyControl, Status>>(16);
372 let handler = Arc::clone(&self.handler);
373 let agent_id = self.id.clone();
374
375 debug!(agent_id = %agent_id, "Starting v2 control stream");
376
377 let _handler_clone = Arc::clone(&handler);
379 let tx_clone = tx.clone();
380 let agent_id_clone = agent_id.clone();
381 tokio::spawn(async move {
382 while let Some(result) = inbound.next().await {
383 let msg = match result {
384 Ok(m) => m,
385 Err(e) => {
386 error!(agent_id = %agent_id_clone, error = %e, "Control stream error");
387 break;
388 }
389 };
390
391 match msg.message {
395 Some(grpc_v2::agent_control::Message::Health(health)) => {
396 trace!(
397 agent_id = %agent_id_clone,
398 state = health.state,
399 "Received health status from agent"
400 );
401 }
403 Some(grpc_v2::agent_control::Message::Metrics(metrics)) => {
404 trace!(
405 agent_id = %agent_id_clone,
406 counters = metrics.counters.len(),
407 gauges = metrics.gauges.len(),
408 "Received metrics report from agent"
409 );
410 }
412 Some(grpc_v2::agent_control::Message::ConfigUpdate(update)) => {
413 debug!(
414 agent_id = %agent_id_clone,
415 request_id = %update.request_id,
416 "Received config update request from agent"
417 );
418 let response = grpc_v2::ProxyControl {
420 message: Some(grpc_v2::proxy_control::Message::ConfigResponse(
421 grpc_v2::ConfigUpdateResponse {
422 request_id: update.request_id,
423 accepted: true,
424 error: None,
425 timestamp_ms: now_ms(),
426 },
427 )),
428 };
429 if tx_clone.send(Ok(response)).await.is_err() {
430 break;
431 }
432 }
433 Some(grpc_v2::agent_control::Message::Log(log)) => {
434 match log.level {
436 1 => {
437 trace!(agent_id = %agent_id_clone, msg = %log.message, "Agent log")
438 }
439 2 => {
440 debug!(agent_id = %agent_id_clone, msg = %log.message, "Agent log")
441 }
442 3 => warn!(agent_id = %agent_id_clone, msg = %log.message, "Agent log"),
443 4 => {
444 error!(agent_id = %agent_id_clone, msg = %log.message, "Agent log")
445 }
446 _ => info!(agent_id = %agent_id_clone, msg = %log.message, "Agent log"),
447 }
448 }
449 None => {
450 warn!(agent_id = %agent_id_clone, "Empty control message received");
451 }
452 }
453 }
454
455 debug!(agent_id = %agent_id_clone, "Control stream inbound handler ended");
456 });
457
458 let capabilities = handler.capabilities();
460 let health_interval_ms = capabilities.health.report_interval_ms;
461 let metrics_enabled = capabilities.features.metrics_export;
462
463 if health_interval_ms > 0 || metrics_enabled {
464 let handler_for_health = Arc::clone(&handler);
465 let tx_for_health = tx;
466 let agent_id_for_health = agent_id.clone();
467
468 tokio::spawn(async move {
469 let health_interval = std::time::Duration::from_millis(health_interval_ms as u64);
470 let mut interval = tokio::time::interval(health_interval);
471
472 loop {
473 interval.tick().await;
474
475 let health = handler_for_health.health_status();
477 let health_msg = grpc_v2::ProxyControl {
478 message: Some(grpc_v2::proxy_control::Message::Configure(
479 grpc_v2::ConfigureEvent {
480 config_json: "{}".to_string(), config_version: None,
482 is_initial: false,
483 timestamp_ms: now_ms(),
484 },
485 )),
486 };
487
488 let _ = health_msg; let _ = health;
493
494 if tx_for_health.is_closed() {
496 debug!(agent_id = %agent_id_for_health, "Control stream closed, stopping health reporter");
497 break;
498 }
499 }
500 });
501 }
502
503 let output_stream = ReceiverStream::new(rx);
504 Ok(Response::new(
505 Box::pin(output_stream) as Self::ControlStreamStream
506 ))
507 }
508
509 async fn process_event(
511 &self,
512 request: Request<ProxyToAgent>,
513 ) -> Result<Response<AgentToProxy>, Status> {
514 let msg = request.into_inner();
515
516 trace!(agent_id = %self.id, "Processing single event (v1 compat)");
517
518 let response = match msg.message {
519 Some(grpc_v2::proxy_to_agent::Message::Handshake(req)) => {
520 let handshake_req = convert_handshake_request(req);
521 let resp = self.handler.on_handshake(handshake_req).await;
522 AgentToProxy {
523 message: Some(grpc_v2::agent_to_proxy::Message::Handshake(
524 convert_handshake_response(resp),
525 )),
526 }
527 }
528 Some(grpc_v2::proxy_to_agent::Message::RequestHeaders(e)) => {
529 let event = convert_request_headers_from_grpc(e);
530 let correlation_id = event.metadata.correlation_id.clone();
531 let start = Instant::now();
532 let resp = self.handler.on_request_headers(event).await;
533 let processing_time_ms = start.elapsed().as_millis() as u64;
534 create_agent_response(correlation_id, resp, processing_time_ms)
535 }
536 Some(grpc_v2::proxy_to_agent::Message::Ping(ping)) => AgentToProxy {
537 message: Some(grpc_v2::agent_to_proxy::Message::Pong(grpc_v2::Pong {
538 sequence: ping.sequence,
539 ping_timestamp_ms: ping.timestamp_ms,
540 timestamp_ms: now_ms(),
541 })),
542 },
543 _ => {
544 return Err(Status::invalid_argument("Unsupported event type"));
545 }
546 };
547
548 Ok(Response::new(response))
549 }
550}
551
552fn convert_handshake_request(req: grpc_v2::HandshakeRequest) -> HandshakeRequest {
557 HandshakeRequest {
558 supported_versions: req.supported_versions,
559 proxy_id: req.proxy_id,
560 proxy_version: req.proxy_version,
561 config: serde_json::from_str(&req.config_json).unwrap_or(serde_json::Value::Null),
562 }
563}
564
565fn convert_handshake_response(resp: HandshakeResponse) -> grpc_v2::HandshakeResponse {
566 grpc_v2::HandshakeResponse {
567 protocol_version: resp.protocol_version,
568 capabilities: Some(convert_capabilities_to_grpc(&resp.capabilities)),
569 success: resp.success,
570 error: resp.error,
571 }
572}
573
574fn convert_capabilities_to_grpc(caps: &AgentCapabilities) -> grpc_v2::AgentCapabilities {
575 grpc_v2::AgentCapabilities {
576 protocol_version: caps.protocol_version,
577 agent_id: caps.agent_id.clone(),
578 name: caps.name.clone(),
579 version: caps.version.clone(),
580 supported_events: caps
581 .supported_events
582 .iter()
583 .map(|e| event_type_to_i32(*e))
584 .collect(),
585 features: Some(grpc_v2::AgentFeatures {
586 streaming_body: caps.features.streaming_body,
587 websocket: caps.features.websocket,
588 guardrails: caps.features.guardrails,
589 config_push: caps.features.config_push,
590 metrics_export: caps.features.metrics_export,
591 concurrent_requests: caps.features.concurrent_requests,
592 cancellation: caps.features.cancellation,
593 flow_control: caps.features.flow_control,
594 health_reporting: caps.features.health_reporting,
595 }),
596 limits: Some(grpc_v2::AgentLimits {
597 max_body_size: caps.limits.max_body_size as u64,
598 max_concurrency: caps.limits.max_concurrency,
599 preferred_chunk_size: caps.limits.preferred_chunk_size as u64,
600 max_memory: caps.limits.max_memory.map(|m| m as u64),
601 max_processing_time_ms: caps.limits.max_processing_time_ms,
602 }),
603 health_config: Some(grpc_v2::HealthConfig {
604 report_interval_ms: caps.health.report_interval_ms,
605 include_load_metrics: caps.health.include_load_metrics,
606 include_resource_metrics: caps.health.include_resource_metrics,
607 }),
608 }
609}
610
611fn event_type_to_i32(event_type: EventType) -> i32 {
612 match event_type {
613 EventType::Configure => 8,
614 EventType::RequestHeaders => 1,
615 EventType::RequestBodyChunk => 2,
616 EventType::ResponseHeaders => 3,
617 EventType::ResponseBodyChunk => 4,
618 EventType::RequestComplete => 5,
619 EventType::WebSocketFrame => 6,
620 EventType::GuardrailInspect => 7,
621 }
622}
623
624fn convert_request_headers_from_grpc(e: grpc_v2::RequestHeadersEvent) -> RequestHeadersEvent {
625 let metadata = match e.metadata {
626 Some(m) => RequestMetadata {
627 correlation_id: m.correlation_id,
628 request_id: m.request_id,
629 client_ip: m.client_ip,
630 client_port: m.client_port as u16,
631 server_name: m.server_name,
632 protocol: m.protocol,
633 tls_version: m.tls_version,
634 tls_cipher: None,
635 route_id: m.route_id,
636 upstream_id: m.upstream_id,
637 timestamp: format!("{}", m.timestamp_ms),
638 traceparent: m.traceparent,
639 },
640 None => RequestMetadata {
641 correlation_id: String::new(),
642 request_id: String::new(),
643 client_ip: String::new(),
644 client_port: 0,
645 server_name: None,
646 protocol: String::new(),
647 tls_version: None,
648 tls_cipher: None,
649 route_id: None,
650 upstream_id: None,
651 timestamp: String::new(),
652 traceparent: None,
653 },
654 };
655
656 let headers = e
657 .headers
658 .into_iter()
659 .fold(std::collections::HashMap::new(), |mut map, h| {
660 map.entry(h.name).or_insert_with(Vec::new).push(h.value);
661 map
662 });
663
664 RequestHeadersEvent {
665 metadata,
666 method: e.method,
667 uri: e.uri,
668 headers,
669 }
670}
671
672fn convert_body_chunk_to_request(e: grpc_v2::BodyChunkEvent) -> RequestBodyChunkEvent {
673 use base64::{engine::general_purpose::STANDARD, Engine as _};
674 RequestBodyChunkEvent {
675 correlation_id: e.correlation_id,
676 data: STANDARD.encode(&e.data),
677 is_last: e.is_last,
678 total_size: e.total_size.map(|s| s as usize),
679 chunk_index: e.chunk_index,
680 bytes_received: e.bytes_transferred as usize,
681 }
682}
683
684fn convert_body_chunk_to_response(e: grpc_v2::BodyChunkEvent) -> ResponseBodyChunkEvent {
685 use base64::{engine::general_purpose::STANDARD, Engine as _};
686 ResponseBodyChunkEvent {
687 correlation_id: e.correlation_id,
688 data: STANDARD.encode(&e.data),
689 is_last: e.is_last,
690 total_size: e.total_size.map(|s| s as usize),
691 chunk_index: e.chunk_index,
692 bytes_sent: e.bytes_transferred as usize,
693 }
694}
695
696fn convert_response_headers_from_grpc(e: grpc_v2::ResponseHeadersEvent) -> ResponseHeadersEvent {
697 let headers = e
698 .headers
699 .into_iter()
700 .fold(std::collections::HashMap::new(), |mut map, h| {
701 map.entry(h.name).or_insert_with(Vec::new).push(h.value);
702 map
703 });
704
705 ResponseHeadersEvent {
706 correlation_id: e.correlation_id,
707 status: e.status_code as u16,
708 headers,
709 }
710}
711
712fn convert_request_complete_from_grpc(e: grpc_v2::RequestCompleteEvent) -> RequestCompleteEvent {
713 RequestCompleteEvent {
714 correlation_id: e.correlation_id,
715 status: e.status_code as u16,
716 duration_ms: e.duration_ms,
717 request_body_size: e.bytes_received as usize,
718 response_body_size: e.bytes_sent as usize,
719 upstream_attempts: 1,
720 error: e.error,
721 }
722}
723
724fn convert_websocket_frame_from_grpc(e: grpc_v2::WebSocketFrameEvent) -> WebSocketFrameEvent {
725 use base64::{engine::general_purpose::STANDARD, Engine as _};
726 WebSocketFrameEvent {
727 correlation_id: e.correlation_id,
728 opcode: format!("{}", e.frame_type),
729 data: STANDARD.encode(&e.payload),
730 client_to_server: e.client_to_server,
731 frame_index: 0,
732 fin: true,
733 route_id: None,
734 client_ip: String::new(),
735 }
736}
737
738fn create_agent_response(
739 correlation_id: String,
740 resp: V1Response,
741 processing_time_ms: u64,
742) -> AgentToProxy {
743 let decision = match resp.decision {
744 Decision::Allow => Some(grpc_v2::agent_response::Decision::Allow(
745 grpc_v2::AllowDecision {},
746 )),
747 Decision::Block {
748 status,
749 body,
750 headers,
751 } => Some(grpc_v2::agent_response::Decision::Block(
752 grpc_v2::BlockDecision {
753 status: status as u32,
754 body,
755 headers: headers
756 .unwrap_or_default()
757 .into_iter()
758 .map(|(k, v)| grpc_v2::Header { name: k, value: v })
759 .collect(),
760 },
761 )),
762 Decision::Redirect { url, status } => Some(grpc_v2::agent_response::Decision::Redirect(
763 grpc_v2::RedirectDecision {
764 url,
765 status: status as u32,
766 },
767 )),
768 Decision::Challenge {
769 challenge_type,
770 params,
771 } => Some(grpc_v2::agent_response::Decision::Challenge(
772 grpc_v2::ChallengeDecision {
773 challenge_type,
774 params,
775 },
776 )),
777 };
778
779 let request_headers: Vec<grpc_v2::HeaderOp> = resp
780 .request_headers
781 .into_iter()
782 .map(convert_header_op_to_grpc)
783 .collect();
784
785 let response_headers: Vec<grpc_v2::HeaderOp> = resp
786 .response_headers
787 .into_iter()
788 .map(convert_header_op_to_grpc)
789 .collect();
790
791 let audit = Some(grpc_v2::AuditMetadata {
792 tags: resp.audit.tags,
793 rule_ids: resp.audit.rule_ids,
794 confidence: resp.audit.confidence,
795 reason_codes: resp.audit.reason_codes,
796 custom: resp
797 .audit
798 .custom
799 .into_iter()
800 .map(|(k, v)| (k, v.to_string()))
801 .collect(),
802 });
803
804 AgentToProxy {
805 message: Some(grpc_v2::agent_to_proxy::Message::Response(
806 grpc_v2::AgentResponse {
807 correlation_id,
808 decision,
809 request_headers,
810 response_headers,
811 audit,
812 processing_time_ms: Some(processing_time_ms),
813 needs_more: resp.needs_more,
814 },
815 )),
816 }
817}
818
819fn convert_header_op_to_grpc(op: HeaderOp) -> grpc_v2::HeaderOp {
820 let operation = match op {
821 HeaderOp::Set { name, value } => {
822 Some(grpc_v2::header_op::Operation::Set(grpc_v2::Header {
823 name,
824 value,
825 }))
826 }
827 HeaderOp::Add { name, value } => {
828 Some(grpc_v2::header_op::Operation::Add(grpc_v2::Header {
829 name,
830 value,
831 }))
832 }
833 HeaderOp::Remove { name } => Some(grpc_v2::header_op::Operation::Remove(name)),
834 };
835 grpc_v2::HeaderOp { operation }
836}
837
/// Milliseconds since the Unix epoch, or 0 if the system clock reads earlier than the epoch.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
844
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal handler that relies on every default hook.
    struct TestHandlerV2;

    #[async_trait]
    impl AgentHandlerV2 for TestHandlerV2 {
        fn capabilities(&self) -> AgentCapabilities {
            AgentCapabilities::new("test-v2", "Test Agent V2", "1.0.0")
        }
    }

    #[test]
    fn test_create_server() {
        let server = GrpcAgentServerV2::new("test", Box::new(TestHandlerV2));
        assert_eq!(server.id, "test");
    }

    #[test]
    fn test_into_service_builds() {
        let _service = GrpcAgentServerV2::new("test", Box::new(TestHandlerV2)).into_service();
    }

    #[test]
    fn test_event_type_wire_values() {
        assert_eq!(event_type_to_i32(EventType::RequestHeaders), 1);
        assert_eq!(event_type_to_i32(EventType::RequestBodyChunk), 2);
        assert_eq!(event_type_to_i32(EventType::ResponseHeaders), 3);
        assert_eq!(event_type_to_i32(EventType::ResponseBodyChunk), 4);
        assert_eq!(event_type_to_i32(EventType::RequestComplete), 5);
        assert_eq!(event_type_to_i32(EventType::WebSocketFrame), 6);
        assert_eq!(event_type_to_i32(EventType::GuardrailInspect), 7);
        assert_eq!(event_type_to_i32(EventType::Configure), 8);
    }

    #[test]
    fn test_header_op_conversion() {
        let set = convert_header_op_to_grpc(HeaderOp::Set {
            name: "x-a".to_string(),
            value: "1".to_string(),
        });
        match set.operation {
            Some(grpc_v2::header_op::Operation::Set(h)) => {
                assert_eq!(h.name, "x-a");
                assert_eq!(h.value, "1");
            }
            _ => panic!("expected a Set operation"),
        }

        let remove = convert_header_op_to_grpc(HeaderOp::Remove {
            name: "x-drop".to_string(),
        });
        assert!(matches!(
            &remove.operation,
            Some(grpc_v2::header_op::Operation::Remove(n)) if n == "x-drop"
        ));
    }

    #[test]
    fn test_create_agent_response_keeps_correlation_and_timing() {
        let msg = create_agent_response("corr-1".to_string(), V1Response::default_allow(), 42);
        match msg.message {
            Some(grpc_v2::agent_to_proxy::Message::Response(r)) => {
                assert_eq!(r.correlation_id, "corr-1");
                assert_eq!(r.processing_time_ms, Some(42));
                assert!(r.decision.is_some());
            }
            _ => panic!("expected an AgentResponse message"),
        }
    }

    #[test]
    fn test_now_ms_is_after_2020() {
        assert!(now_ms() > 1_577_836_800_000);
    }
}