1use async_trait::async_trait;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::time::Instant;
10use tokio::sync::mpsc;
11use tokio_stream::{wrappers::ReceiverStream, Stream, StreamExt};
12use tonic::{Request, Response, Status, Streaming};
13use tracing::{debug, error, info, trace, warn};
14
15use crate::grpc_v2::{
16 self, agent_service_v2_server::AgentServiceV2, agent_service_v2_server::AgentServiceV2Server,
17 AgentToProxy, ProxyToAgent,
18};
19use crate::v2::pool::CHANNEL_BUFFER_SIZE;
20use crate::v2::{AgentCapabilities, HandshakeRequest, HandshakeResponse, HealthStatus};
21use crate::{
22 AgentResponse, Decision, EventType, HeaderOp, RequestBodyChunkEvent, RequestCompleteEvent,
23 RequestHeadersEvent, RequestMetadata, ResponseBodyChunkEvent, ResponseHeadersEvent,
24 WebSocketFrameEvent,
25};
26
27#[async_trait]
85pub trait AgentHandlerV2: Send + Sync {
86 fn capabilities(&self) -> AgentCapabilities;
88
89 async fn on_handshake(&self, _request: HandshakeRequest) -> HandshakeResponse {
91 HandshakeResponse::success(self.capabilities())
93 }
94
95 async fn on_request_headers(&self, _event: RequestHeadersEvent) -> AgentResponse {
97 AgentResponse::default_allow()
98 }
99
100 async fn on_request_body_chunk(&self, _event: RequestBodyChunkEvent) -> AgentResponse {
102 AgentResponse::default_allow()
103 }
104
105 async fn on_response_headers(&self, _event: ResponseHeadersEvent) -> AgentResponse {
107 AgentResponse::default_allow()
108 }
109
110 async fn on_response_body_chunk(&self, _event: ResponseBodyChunkEvent) -> AgentResponse {
112 AgentResponse::default_allow()
113 }
114
115 async fn on_request_complete(&self, _event: RequestCompleteEvent) -> AgentResponse {
117 AgentResponse::default_allow()
118 }
119
120 async fn on_websocket_frame(&self, _event: WebSocketFrameEvent) -> AgentResponse {
122 AgentResponse::websocket_allow()
123 }
124
125 fn health_status(&self) -> HealthStatus {
127 HealthStatus::healthy(self.capabilities().agent_id.clone())
128 }
129
130 fn metrics_report(&self) -> Option<crate::v2::MetricsReport> {
132 None
133 }
134
135 async fn on_configure(&self, _config: serde_json::Value, _version: Option<String>) -> bool {
137 true
139 }
140
141 async fn on_shutdown(&self, _reason: ShutdownReason, _grace_period_ms: u64) {
143 }
145
146 async fn on_drain(&self, _duration_ms: u64, _reason: DrainReason) {
148 }
150
151 async fn on_stream_closed(&self) {}
153}
154
/// Reason supplied to [`AgentHandlerV2::on_shutdown`].
///
/// Derives `Hash` (alongside `Eq`) so reasons can be used as map/set keys,
/// e.g. for per-reason counters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShutdownReason {
    Graceful,
    Immediate,
    ConfigReload,
    Upgrade,
}
163
/// Reason supplied to [`AgentHandlerV2::on_drain`].
///
/// Derives `Hash` (alongside `Eq`) so reasons can be used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DrainReason {
    ConfigReload,
    Maintenance,
    HealthCheckFailed,
    Manual,
}
172
173pub struct GrpcAgentServerV2 {
209 id: String,
210 handler: Arc<dyn AgentHandlerV2>,
211}
212
213impl GrpcAgentServerV2 {
214 pub fn new(id: impl Into<String>, handler: Box<dyn AgentHandlerV2>) -> Self {
216 let id = id.into();
217 debug!(agent_id = %id, "Creating gRPC agent server v2");
218 Self {
219 id,
220 handler: Arc::from(handler),
221 }
222 }
223
224 pub fn into_service(self) -> AgentServiceV2Server<GrpcAgentHandlerV2> {
226 trace!(agent_id = %self.id, "Converting to tonic v2 service");
227 AgentServiceV2Server::new(GrpcAgentHandlerV2 {
228 id: self.id,
229 handler: self.handler,
230 })
231 }
232
233 pub async fn run(self, addr: std::net::SocketAddr) -> Result<(), crate::AgentProtocolError> {
235 info!(
236 agent_id = %self.id,
237 address = %addr,
238 "gRPC agent server v2 listening"
239 );
240
241 tonic::transport::Server::builder()
242 .add_service(self.into_service())
243 .serve(addr)
244 .await
245 .map_err(|e| {
246 error!(error = %e, "gRPC v2 server error");
247 crate::AgentProtocolError::ConnectionFailed(format!("gRPC v2 server error: {}", e))
248 })
249 }
250}
251
252pub struct GrpcAgentHandlerV2 {
254 id: String,
255 handler: Arc<dyn AgentHandlerV2>,
256}
257
258type ProcessResponseStream = Pin<Box<dyn Stream<Item = Result<AgentToProxy, Status>> + Send>>;
259type ControlResponseStream =
260 Pin<Box<dyn Stream<Item = Result<grpc_v2::ProxyControl, Status>> + Send>>;
261
262#[tonic::async_trait]
263impl AgentServiceV2 for GrpcAgentHandlerV2 {
264 type ProcessStreamStream = ProcessResponseStream;
265 type ControlStreamStream = ControlResponseStream;
266
267 async fn process_stream(
269 &self,
270 request: Request<Streaming<ProxyToAgent>>,
271 ) -> Result<Response<Self::ProcessStreamStream>, Status> {
272 let mut inbound = request.into_inner();
273 let (tx, rx) = mpsc::channel(CHANNEL_BUFFER_SIZE);
274 let handler = Arc::clone(&self.handler);
275 let agent_id = self.id.clone();
276
277 debug!(agent_id = %agent_id, "Starting v2 process stream");
278
279 tokio::spawn(async move {
280 let mut handshake_done = false;
281
282 while let Some(result) = inbound.next().await {
283 let msg = match result {
284 Ok(m) => m,
285 Err(e) => {
286 error!(agent_id = %agent_id, error = %e, "Stream error");
287 break;
288 }
289 };
290
291 let response = match msg.message {
292 Some(grpc_v2::proxy_to_agent::Message::Handshake(req)) => {
293 trace!(agent_id = %agent_id, "Processing handshake");
294 let handshake_req = convert_handshake_request(req);
295 let resp = handler.on_handshake(handshake_req).await;
296 handshake_done = resp.success;
297 Some(AgentToProxy {
298 message: Some(grpc_v2::agent_to_proxy::Message::Handshake(
299 convert_handshake_response(resp),
300 )),
301 })
302 }
303 Some(grpc_v2::proxy_to_agent::Message::RequestHeaders(e)) => {
304 if !handshake_done {
305 warn!(agent_id = %agent_id, "Received event before handshake");
306 continue;
307 }
308 let event = convert_request_headers_from_grpc(e);
309 let correlation_id = event.metadata.correlation_id.clone();
310 let start = Instant::now();
311 let resp = handler.on_request_headers(event).await;
312 let processing_time_ms = start.elapsed().as_millis() as u64;
313 Some(create_agent_response(
314 correlation_id,
315 resp,
316 processing_time_ms,
317 ))
318 }
319 Some(grpc_v2::proxy_to_agent::Message::RequestBodyChunk(e)) => {
320 if !handshake_done {
321 continue;
322 }
323 let event = convert_body_chunk_to_request(e);
324 let correlation_id = event.correlation_id.clone();
325 let start = Instant::now();
326 let resp = handler.on_request_body_chunk(event).await;
327 let processing_time_ms = start.elapsed().as_millis() as u64;
328 Some(create_agent_response(
329 correlation_id,
330 resp,
331 processing_time_ms,
332 ))
333 }
334 Some(grpc_v2::proxy_to_agent::Message::ResponseHeaders(e)) => {
335 if !handshake_done {
336 continue;
337 }
338 let event = convert_response_headers_from_grpc(e);
339 let correlation_id = event.correlation_id.clone();
340 let start = Instant::now();
341 let resp = handler.on_response_headers(event).await;
342 let processing_time_ms = start.elapsed().as_millis() as u64;
343 Some(create_agent_response(
344 correlation_id,
345 resp,
346 processing_time_ms,
347 ))
348 }
349 Some(grpc_v2::proxy_to_agent::Message::ResponseBodyChunk(e)) => {
350 if !handshake_done {
351 continue;
352 }
353 let event = convert_body_chunk_to_response(e);
354 let correlation_id = event.correlation_id.clone();
355 let start = Instant::now();
356 let resp = handler.on_response_body_chunk(event).await;
357 let processing_time_ms = start.elapsed().as_millis() as u64;
358 Some(create_agent_response(
359 correlation_id,
360 resp,
361 processing_time_ms,
362 ))
363 }
364 Some(grpc_v2::proxy_to_agent::Message::RequestComplete(e)) => {
365 if !handshake_done {
366 continue;
367 }
368 let event = convert_request_complete_from_grpc(e);
369 let correlation_id = event.correlation_id.clone();
370 let start = Instant::now();
371 let resp = handler.on_request_complete(event).await;
372 let processing_time_ms = start.elapsed().as_millis() as u64;
373 Some(create_agent_response(
374 correlation_id,
375 resp,
376 processing_time_ms,
377 ))
378 }
379 Some(grpc_v2::proxy_to_agent::Message::WebsocketFrame(e)) => {
380 if !handshake_done {
381 continue;
382 }
383 let event = convert_websocket_frame_from_grpc(e);
384 let correlation_id = event.correlation_id.clone();
385 let start = Instant::now();
386 let resp = handler.on_websocket_frame(event).await;
387 let processing_time_ms = start.elapsed().as_millis() as u64;
388 Some(create_agent_response(
389 correlation_id,
390 resp,
391 processing_time_ms,
392 ))
393 }
394 Some(grpc_v2::proxy_to_agent::Message::Ping(ping)) => {
395 trace!(agent_id = %agent_id, sequence = ping.sequence, "Received ping");
396 Some(AgentToProxy {
397 message: Some(grpc_v2::agent_to_proxy::Message::Pong(grpc_v2::Pong {
398 sequence: ping.sequence,
399 ping_timestamp_ms: ping.timestamp_ms,
400 timestamp_ms: now_ms(),
401 })),
402 })
403 }
404 Some(grpc_v2::proxy_to_agent::Message::Cancel(cancel)) => {
405 debug!(
406 agent_id = %agent_id,
407 correlation_id = %cancel.correlation_id,
408 "Request cancelled"
409 );
410 None
411 }
412 Some(grpc_v2::proxy_to_agent::Message::Configure(_)) => {
413 None
415 }
416 Some(grpc_v2::proxy_to_agent::Message::Guardrail(_)) => {
417 None
419 }
420 None => {
421 warn!(agent_id = %agent_id, "Empty message received");
422 None
423 }
424 };
425
426 if let Some(resp) = response {
427 if tx.send(Ok(resp)).await.is_err() {
428 debug!(agent_id = %agent_id, "Stream closed by receiver");
429 break;
430 }
431 }
432 }
433
434 handler.on_stream_closed().await;
435 debug!(agent_id = %agent_id, "Process stream ended");
436 });
437
438 let output_stream = ReceiverStream::new(rx);
439 Ok(Response::new(
440 Box::pin(output_stream) as Self::ProcessStreamStream
441 ))
442 }
443
444 async fn control_stream(
450 &self,
451 request: Request<Streaming<grpc_v2::AgentControl>>,
452 ) -> Result<Response<Self::ControlStreamStream>, Status> {
453 let mut inbound = request.into_inner();
454 let (tx, rx) = mpsc::channel::<Result<grpc_v2::ProxyControl, Status>>(16);
455 let handler = Arc::clone(&self.handler);
456 let agent_id = self.id.clone();
457
458 debug!(agent_id = %agent_id, "Starting v2 control stream");
459
460 let _handler_clone = Arc::clone(&handler);
462 let tx_clone = tx.clone();
463 let agent_id_clone = agent_id.clone();
464 tokio::spawn(async move {
465 while let Some(result) = inbound.next().await {
466 let msg = match result {
467 Ok(m) => m,
468 Err(e) => {
469 error!(agent_id = %agent_id_clone, error = %e, "Control stream error");
470 break;
471 }
472 };
473
474 match msg.message {
478 Some(grpc_v2::agent_control::Message::Health(health)) => {
479 trace!(
480 agent_id = %agent_id_clone,
481 state = health.state,
482 "Received health status from agent"
483 );
484 }
486 Some(grpc_v2::agent_control::Message::Metrics(metrics)) => {
487 trace!(
488 agent_id = %agent_id_clone,
489 counters = metrics.counters.len(),
490 gauges = metrics.gauges.len(),
491 "Received metrics report from agent"
492 );
493 }
495 Some(grpc_v2::agent_control::Message::ConfigUpdate(update)) => {
496 debug!(
497 agent_id = %agent_id_clone,
498 request_id = %update.request_id,
499 "Received config update request from agent"
500 );
501 let response = grpc_v2::ProxyControl {
503 message: Some(grpc_v2::proxy_control::Message::ConfigResponse(
504 grpc_v2::ConfigUpdateResponse {
505 request_id: update.request_id,
506 accepted: true,
507 error: None,
508 timestamp_ms: now_ms(),
509 },
510 )),
511 };
512 if tx_clone.send(Ok(response)).await.is_err() {
513 break;
514 }
515 }
516 Some(grpc_v2::agent_control::Message::Log(log)) => {
517 match log.level {
519 1 => {
520 trace!(agent_id = %agent_id_clone, msg = %log.message, "Agent log")
521 }
522 2 => {
523 debug!(agent_id = %agent_id_clone, msg = %log.message, "Agent log")
524 }
525 3 => warn!(agent_id = %agent_id_clone, msg = %log.message, "Agent log"),
526 4 => {
527 error!(agent_id = %agent_id_clone, msg = %log.message, "Agent log")
528 }
529 _ => info!(agent_id = %agent_id_clone, msg = %log.message, "Agent log"),
530 }
531 }
532 None => {
533 warn!(agent_id = %agent_id_clone, "Empty control message received");
534 }
535 }
536 }
537
538 debug!(agent_id = %agent_id_clone, "Control stream inbound handler ended");
539 });
540
541 let capabilities = handler.capabilities();
543 let health_interval_ms = capabilities.health.report_interval_ms;
544 let metrics_enabled = capabilities.features.metrics_export;
545
546 if health_interval_ms > 0 || metrics_enabled {
547 let handler_for_health = Arc::clone(&handler);
548 let tx_for_health = tx;
549 let agent_id_for_health = agent_id.clone();
550
551 tokio::spawn(async move {
552 let health_interval = std::time::Duration::from_millis(health_interval_ms as u64);
553 let mut interval = tokio::time::interval(health_interval);
554
555 loop {
556 interval.tick().await;
557
558 let health = handler_for_health.health_status();
560 trace!(
561 agent_id = %agent_id_for_health,
562 state = ?health.state,
563 message = ?health.message,
564 "Agent health status collected"
565 );
566
567 let heartbeat = grpc_v2::ProxyControl {
570 message: Some(grpc_v2::proxy_control::Message::Configure(
571 grpc_v2::ConfigureEvent {
572 config_json: "{}".to_string(),
573 config_version: None,
574 is_initial: false,
575 timestamp_ms: now_ms(),
576 },
577 )),
578 };
579
580 if tx_for_health.send(Ok(heartbeat)).await.is_err() {
581 debug!(
582 agent_id = %agent_id_for_health,
583 "Control stream closed, stopping health reporter"
584 );
585 break;
586 }
587 }
588 });
589 }
590
591 let output_stream = ReceiverStream::new(rx);
592 Ok(Response::new(
593 Box::pin(output_stream) as Self::ControlStreamStream
594 ))
595 }
596
597 async fn process_event(
599 &self,
600 request: Request<ProxyToAgent>,
601 ) -> Result<Response<AgentToProxy>, Status> {
602 let msg = request.into_inner();
603
604 trace!(agent_id = %self.id, "Processing single event (v1 compat)");
605
606 let response = match msg.message {
607 Some(grpc_v2::proxy_to_agent::Message::Handshake(req)) => {
608 let handshake_req = convert_handshake_request(req);
609 let resp = self.handler.on_handshake(handshake_req).await;
610 AgentToProxy {
611 message: Some(grpc_v2::agent_to_proxy::Message::Handshake(
612 convert_handshake_response(resp),
613 )),
614 }
615 }
616 Some(grpc_v2::proxy_to_agent::Message::RequestHeaders(e)) => {
617 let event = convert_request_headers_from_grpc(e);
618 let correlation_id = event.metadata.correlation_id.clone();
619 let start = Instant::now();
620 let resp = self.handler.on_request_headers(event).await;
621 let processing_time_ms = start.elapsed().as_millis() as u64;
622 create_agent_response(correlation_id, resp, processing_time_ms)
623 }
624 Some(grpc_v2::proxy_to_agent::Message::Ping(ping)) => AgentToProxy {
625 message: Some(grpc_v2::agent_to_proxy::Message::Pong(grpc_v2::Pong {
626 sequence: ping.sequence,
627 ping_timestamp_ms: ping.timestamp_ms,
628 timestamp_ms: now_ms(),
629 })),
630 },
631 _ => {
632 return Err(Status::invalid_argument("Unsupported event type"));
633 }
634 };
635
636 Ok(Response::new(response))
637 }
638}
639
640fn convert_handshake_request(req: grpc_v2::HandshakeRequest) -> HandshakeRequest {
645 HandshakeRequest {
646 supported_versions: req.supported_versions,
647 proxy_id: req.proxy_id,
648 proxy_version: req.proxy_version,
649 config: serde_json::from_str(&req.config_json).unwrap_or(serde_json::Value::Null),
650 }
651}
652
653fn convert_handshake_response(resp: HandshakeResponse) -> grpc_v2::HandshakeResponse {
654 grpc_v2::HandshakeResponse {
655 protocol_version: resp.protocol_version,
656 capabilities: Some(convert_capabilities_to_grpc(&resp.capabilities)),
657 success: resp.success,
658 error: resp.error,
659 }
660}
661
662fn convert_capabilities_to_grpc(caps: &AgentCapabilities) -> grpc_v2::AgentCapabilities {
663 grpc_v2::AgentCapabilities {
664 protocol_version: caps.protocol_version,
665 agent_id: caps.agent_id.clone(),
666 name: caps.name.clone(),
667 version: caps.version.clone(),
668 supported_events: caps
669 .supported_events
670 .iter()
671 .map(|e| event_type_to_i32(*e))
672 .collect(),
673 features: Some(grpc_v2::AgentFeatures {
674 streaming_body: caps.features.streaming_body,
675 websocket: caps.features.websocket,
676 guardrails: caps.features.guardrails,
677 config_push: caps.features.config_push,
678 metrics_export: caps.features.metrics_export,
679 concurrent_requests: caps.features.concurrent_requests,
680 cancellation: caps.features.cancellation,
681 flow_control: caps.features.flow_control,
682 health_reporting: caps.features.health_reporting,
683 }),
684 limits: Some(grpc_v2::AgentLimits {
685 max_body_size: caps.limits.max_body_size as u64,
686 max_concurrency: caps.limits.max_concurrency,
687 preferred_chunk_size: caps.limits.preferred_chunk_size as u64,
688 max_memory: caps.limits.max_memory.map(|m| m as u64),
689 max_processing_time_ms: caps.limits.max_processing_time_ms,
690 }),
691 health_config: Some(grpc_v2::HealthConfig {
692 report_interval_ms: caps.health.report_interval_ms,
693 include_load_metrics: caps.health.include_load_metrics,
694 include_resource_metrics: caps.health.include_resource_metrics,
695 }),
696 }
697}
698
699pub(crate) fn event_type_to_i32(event_type: EventType) -> i32 {
700 match event_type {
701 EventType::Configure => 8,
702 EventType::RequestHeaders => 1,
703 EventType::RequestBodyChunk => 2,
704 EventType::ResponseHeaders => 3,
705 EventType::ResponseBodyChunk => 4,
706 EventType::RequestComplete => 5,
707 EventType::WebSocketFrame => 6,
708 EventType::GuardrailInspect => 7,
709 }
710}
711
712fn convert_request_headers_from_grpc(e: grpc_v2::RequestHeadersEvent) -> RequestHeadersEvent {
713 let metadata = match e.metadata {
714 Some(m) => RequestMetadata {
715 correlation_id: m.correlation_id,
716 request_id: m.request_id,
717 client_ip: m.client_ip,
718 client_port: m.client_port as u16,
719 server_name: m.server_name,
720 protocol: m.protocol,
721 tls_version: m.tls_version,
722 tls_cipher: None,
723 route_id: m.route_id,
724 upstream_id: m.upstream_id,
725 timestamp: format!("{}", m.timestamp_ms),
726 traceparent: m.traceparent,
727 },
728 None => RequestMetadata {
729 correlation_id: String::new(),
730 request_id: String::new(),
731 client_ip: String::new(),
732 client_port: 0,
733 server_name: None,
734 protocol: String::new(),
735 tls_version: None,
736 tls_cipher: None,
737 route_id: None,
738 upstream_id: None,
739 timestamp: String::new(),
740 traceparent: None,
741 },
742 };
743
744 let headers = e
745 .headers
746 .into_iter()
747 .fold(std::collections::HashMap::new(), |mut map, h| {
748 map.entry(h.name).or_insert_with(Vec::new).push(h.value);
749 map
750 });
751
752 RequestHeadersEvent {
753 metadata,
754 method: e.method,
755 uri: e.uri,
756 headers,
757 }
758}
759
760fn convert_body_chunk_to_request(e: grpc_v2::BodyChunkEvent) -> RequestBodyChunkEvent {
761 use base64::{engine::general_purpose::STANDARD, Engine as _};
762 RequestBodyChunkEvent {
763 correlation_id: e.correlation_id,
764 data: STANDARD.encode(&e.data),
765 is_last: e.is_last,
766 total_size: e.total_size.map(|s| s as usize),
767 chunk_index: e.chunk_index,
768 bytes_received: e.bytes_transferred as usize,
769 }
770}
771
772fn convert_body_chunk_to_response(e: grpc_v2::BodyChunkEvent) -> ResponseBodyChunkEvent {
773 use base64::{engine::general_purpose::STANDARD, Engine as _};
774 ResponseBodyChunkEvent {
775 correlation_id: e.correlation_id,
776 data: STANDARD.encode(&e.data),
777 is_last: e.is_last,
778 total_size: e.total_size.map(|s| s as usize),
779 chunk_index: e.chunk_index,
780 bytes_sent: e.bytes_transferred as usize,
781 }
782}
783
784fn convert_response_headers_from_grpc(e: grpc_v2::ResponseHeadersEvent) -> ResponseHeadersEvent {
785 let headers = e
786 .headers
787 .into_iter()
788 .fold(std::collections::HashMap::new(), |mut map, h| {
789 map.entry(h.name).or_insert_with(Vec::new).push(h.value);
790 map
791 });
792
793 ResponseHeadersEvent {
794 correlation_id: e.correlation_id,
795 status: e.status_code as u16,
796 headers,
797 }
798}
799
800fn convert_request_complete_from_grpc(e: grpc_v2::RequestCompleteEvent) -> RequestCompleteEvent {
801 RequestCompleteEvent {
802 correlation_id: e.correlation_id,
803 status: e.status_code as u16,
804 duration_ms: e.duration_ms,
805 request_body_size: e.bytes_received as usize,
806 response_body_size: e.bytes_sent as usize,
807 upstream_attempts: 1,
808 error: e.error,
809 }
810}
811
812fn convert_websocket_frame_from_grpc(e: grpc_v2::WebSocketFrameEvent) -> WebSocketFrameEvent {
813 use base64::{engine::general_purpose::STANDARD, Engine as _};
814 WebSocketFrameEvent {
815 correlation_id: e.correlation_id,
816 opcode: format!("{}", e.frame_type),
817 data: STANDARD.encode(&e.payload),
818 client_to_server: e.client_to_server,
819 frame_index: 0,
820 fin: true,
821 route_id: None,
822 client_ip: String::new(),
823 }
824}
825
826fn create_agent_response(
827 correlation_id: String,
828 resp: AgentResponse,
829 processing_time_ms: u64,
830) -> AgentToProxy {
831 let decision = match resp.decision {
832 Decision::Allow => Some(grpc_v2::agent_response::Decision::Allow(
833 grpc_v2::AllowDecision {},
834 )),
835 Decision::Block {
836 status,
837 body,
838 headers,
839 } => Some(grpc_v2::agent_response::Decision::Block(
840 grpc_v2::BlockDecision {
841 status: status as u32,
842 body,
843 headers: headers
844 .unwrap_or_default()
845 .into_iter()
846 .map(|(k, v)| grpc_v2::Header { name: k, value: v })
847 .collect(),
848 },
849 )),
850 Decision::Redirect { url, status } => Some(grpc_v2::agent_response::Decision::Redirect(
851 grpc_v2::RedirectDecision {
852 url,
853 status: status as u32,
854 },
855 )),
856 Decision::Challenge {
857 challenge_type,
858 params,
859 } => Some(grpc_v2::agent_response::Decision::Challenge(
860 grpc_v2::ChallengeDecision {
861 challenge_type,
862 params,
863 },
864 )),
865 };
866
867 let request_headers: Vec<grpc_v2::HeaderOp> = resp
868 .request_headers
869 .into_iter()
870 .map(convert_header_op_to_grpc)
871 .collect();
872
873 let response_headers: Vec<grpc_v2::HeaderOp> = resp
874 .response_headers
875 .into_iter()
876 .map(convert_header_op_to_grpc)
877 .collect();
878
879 let audit = Some(grpc_v2::AuditMetadata {
880 tags: resp.audit.tags,
881 rule_ids: resp.audit.rule_ids,
882 confidence: resp.audit.confidence,
883 reason_codes: resp.audit.reason_codes,
884 custom: resp
885 .audit
886 .custom
887 .into_iter()
888 .map(|(k, v)| (k, v.to_string()))
889 .collect(),
890 });
891
892 AgentToProxy {
893 message: Some(grpc_v2::agent_to_proxy::Message::Response(
894 grpc_v2::AgentResponse {
895 correlation_id,
896 decision,
897 request_headers,
898 response_headers,
899 audit,
900 processing_time_ms: Some(processing_time_ms),
901 needs_more: resp.needs_more,
902 },
903 )),
904 }
905}
906
907fn convert_header_op_to_grpc(op: HeaderOp) -> grpc_v2::HeaderOp {
908 let operation = match op {
909 HeaderOp::Set { name, value } => {
910 Some(grpc_v2::header_op::Operation::Set(grpc_v2::Header {
911 name,
912 value,
913 }))
914 }
915 HeaderOp::Add { name, value } => {
916 Some(grpc_v2::header_op::Operation::Add(grpc_v2::Header {
917 name,
918 value,
919 }))
920 }
921 HeaderOp::Remove { name } => Some(grpc_v2::header_op::Operation::Remove(name)),
922 };
923 grpc_v2::HeaderOp { operation }
924}
925
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
932
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal handler relying on every default hook.
    struct TestHandlerV2;

    #[async_trait]
    impl AgentHandlerV2 for TestHandlerV2 {
        fn capabilities(&self) -> AgentCapabilities {
            AgentCapabilities::new("test-v2", "Test Agent V2", "1.0.0")
        }
    }

    #[test]
    fn test_create_server() {
        let server = GrpcAgentServerV2::new("test", Box::new(TestHandlerV2));
        assert_eq!(server.id, "test");
    }

    #[test]
    fn test_into_service_builds() {
        let server = GrpcAgentServerV2::new("test", Box::new(TestHandlerV2));
        let _service = server.into_service();
    }

    #[test]
    fn test_event_type_codes() {
        assert_eq!(event_type_to_i32(EventType::RequestHeaders), 1);
        assert_eq!(event_type_to_i32(EventType::RequestBodyChunk), 2);
        assert_eq!(event_type_to_i32(EventType::ResponseHeaders), 3);
        assert_eq!(event_type_to_i32(EventType::ResponseBodyChunk), 4);
        assert_eq!(event_type_to_i32(EventType::RequestComplete), 5);
        assert_eq!(event_type_to_i32(EventType::WebSocketFrame), 6);
        assert_eq!(event_type_to_i32(EventType::GuardrailInspect), 7);
        assert_eq!(event_type_to_i32(EventType::Configure), 8);
    }

    #[test]
    fn test_remove_header_op_conversion() {
        let op = convert_header_op_to_grpc(HeaderOp::Remove {
            name: "x-debug".to_string(),
        });
        match op.operation {
            Some(grpc_v2::header_op::Operation::Remove(name)) => assert_eq!(name, "x-debug"),
            _ => panic!("expected a Remove operation"),
        }
    }

    #[test]
    fn test_set_header_op_conversion() {
        let op = convert_header_op_to_grpc(HeaderOp::Set {
            name: "x-agent".to_string(),
            value: "1".to_string(),
        });
        match op.operation {
            Some(grpc_v2::header_op::Operation::Set(header)) => {
                assert_eq!(header.name, "x-agent");
                assert_eq!(header.value, "1");
            }
            _ => panic!("expected a Set operation"),
        }
    }

    #[test]
    fn test_now_ms_is_after_2020() {
        // 2020-01-01T00:00:00Z expressed in milliseconds.
        assert!(now_ms() > 1_577_836_800_000);
    }

    #[test]
    fn test_handshake_request_config_parsing() {
        let valid = convert_handshake_request(grpc_v2::HandshakeRequest {
            config_json: r#"{"mode":"block"}"#.to_string(),
            ..Default::default()
        });
        assert_eq!(valid.config["mode"], "block");

        let malformed = convert_handshake_request(grpc_v2::HandshakeRequest {
            config_json: "not json".to_string(),
            ..Default::default()
        });
        assert!(malformed.config.is_null());

        let empty = convert_handshake_request(grpc_v2::HandshakeRequest::default());
        assert!(empty.config.is_null());
    }
}