valve/adapter/pipeline.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::SystemTime;
4
5use async_trait::async_trait;
6use serde_json::Value;
7
8use super::{AdapterContext, NoOpRequestAdapter, NoOpResponseAdapter};
9use crate::error::ValveError;
10use crate::observer::{ObservedStream, StreamEvent, StreamPosition, telemetry::TelemetryToken};
11
12#[derive(Debug, Clone)]
13pub struct RequestToken {
14    pub payload: Value,
15    pub metadata: HashMap<String, String>,
16}
17
18#[derive(Debug, Clone)]
19pub struct ResponseToken {
20    pub payload: Value,
21    pub metadata: HashMap<String, String>,
22}
23
24impl TelemetryToken for RequestToken {
25    fn metadata(&self) -> &HashMap<String, String> {
26        &self.metadata
27    }
28
29    fn payload(&self) -> &Value {
30        &self.payload
31    }
32}
33
34impl TelemetryToken for ResponseToken {
35    fn metadata(&self) -> &HashMap<String, String> {
36        &self.metadata
37    }
38
39    fn payload(&self) -> &Value {
40        &self.payload
41    }
42}
43
44#[async_trait]
45pub trait RequestAdapter: Send + Sync {
46    async fn adapt(
47        &self,
48        payload: Value,
49        context: &AdapterContext,
50        stream: &ObservedStream<RequestToken>,
51    ) -> Result<Value, ValveError>;
52}
53
54#[async_trait]
55pub trait ResponseAdapter: Send + Sync {
56    async fn adapt(
57        &self,
58        payload: Value,
59        context: &AdapterContext,
60        stream: &ObservedStream<ResponseToken>,
61    ) -> Result<Value, ValveError>;
62}
63
64#[derive(Clone)]
65pub struct AdapterPipeline {
66    request_adapter: Arc<dyn RequestAdapter>,
67    response_adapter: Arc<dyn ResponseAdapter>,
68    request_stream: ObservedStream<RequestToken>,
69    response_stream: ObservedStream<ResponseToken>,
70    context: AdapterContext,
71}
72
73impl AdapterPipeline {
74    pub fn new(
75        request_adapter: Arc<dyn RequestAdapter>,
76        response_adapter: Arc<dyn ResponseAdapter>,
77    ) -> Self {
78        Self {
79            request_adapter,
80            response_adapter,
81            request_stream: ObservedStream::default(),
82            response_stream: ObservedStream::default(),
83            context: AdapterContext::default(),
84        }
85    }
86
87    pub fn noop() -> Self {
88        Self::new(
89            Arc::new(NoOpRequestAdapter::default()),
90            Arc::new(NoOpResponseAdapter::default()),
91        )
92    }
93
94    pub fn request_stream(&self) -> ObservedStream<RequestToken> {
95        self.request_stream.clone()
96    }
97
98    pub fn response_stream(&self) -> ObservedStream<ResponseToken> {
99        self.response_stream.clone()
100    }
101
102    pub fn context(&self) -> AdapterContext {
103        self.context.clone()
104    }
105
106    pub async fn adapt_request(&self, payload: Value) -> Result<Value, ValveError> {
107        let received_event = StreamEvent {
108            token: RequestToken {
109                payload: payload.clone(),
110                metadata: self.context.metadata(),
111            },
112            position: StreamPosition::BeforeAdapter,
113            timestamp: SystemTime::now(),
114        };
115        self.request_stream.broadcast(received_event);
116
117        let adapted = self
118            .request_adapter
119            .adapt(payload, &self.context, &self.request_stream)
120            .await?;
121
122        let adapted_event = StreamEvent {
123            token: RequestToken {
124                payload: adapted.clone(),
125                metadata: self.context.metadata(),
126            },
127            position: StreamPosition::AfterAdapter,
128            timestamp: SystemTime::now(),
129        };
130        self.request_stream.broadcast(adapted_event);
131
132        Ok(adapted)
133    }
134
135    pub async fn adapt_response(&self, payload: Value) -> Result<Value, ValveError> {
136        let received_event = StreamEvent {
137            token: ResponseToken {
138                payload: payload.clone(),
139                metadata: self.context.metadata(),
140            },
141            position: StreamPosition::BeforeAdapter,
142            timestamp: SystemTime::now(),
143        };
144        self.response_stream.broadcast(received_event);
145
146        let adapted = self
147            .response_adapter
148            .adapt(payload, &self.context, &self.response_stream)
149            .await?;
150
151        let adapted_event = StreamEvent {
152            token: ResponseToken {
153                payload: adapted.clone(),
154                metadata: self.context.metadata(),
155            },
156            position: StreamPosition::AfterAdapter,
157            timestamp: SystemTime::now(),
158        };
159        self.response_stream.broadcast(adapted_event);
160
161        Ok(adapted)
162    }
163}
164
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use super::*;
    use crate::observer::{StreamObserver, StreamPosition};

    /// Shared log of `(position, payload)` pairs captured by a
    /// `RecordingObserver`.
    type EventLog = Arc<Mutex<Vec<(StreamPosition, Value)>>>;

    /// Observer that appends every event it receives to a shared log, so one
    /// instance can watch both the request and the response stream.
    struct RecordingObserver {
        events: EventLog,
    }

    impl StreamObserver<RequestToken> for RecordingObserver {
        fn on_event(&self, event: &StreamEvent<RequestToken>) {
            self.events
                .lock()
                .unwrap()
                .push((event.position, event.token.payload.clone()));
        }
    }

    impl StreamObserver<ResponseToken> for RecordingObserver {
        fn on_event(&self, event: &StreamEvent<ResponseToken>) {
            self.events
                .lock()
                .unwrap()
                .push((event.position, event.token.payload.clone()));
        }
    }

    /// Request adapter that nests the payload under a `"wrapped"` key, so the
    /// `BeforeAdapter` and `AfterAdapter` payloads differ and can be told
    /// apart (the no-op adapter makes them identical).
    struct WrappingRequestAdapter;

    #[async_trait::async_trait]
    impl RequestAdapter for WrappingRequestAdapter {
        async fn adapt(
            &self,
            payload: Value,
            _context: &AdapterContext,
            _stream: &ObservedStream<RequestToken>,
        ) -> Result<Value, ValveError> {
            Ok(serde_json::json!({ "wrapped": payload }))
        }
    }

    #[tokio::test]
    async fn broadcasts_pipeline_events_during_adaptation() {
        let pipeline = AdapterPipeline::noop();
        let events: EventLog = Arc::new(Mutex::new(Vec::new()));
        let observer = Arc::new(RecordingObserver {
            events: events.clone(),
        });

        pipeline.request_stream().register(observer.clone());
        pipeline.response_stream().register(observer);

        let request = serde_json::json!({ "foo": "bar" });
        let response = serde_json::json!({ "baz": "qux" });

        let adapted_request = pipeline.adapt_request(request.clone()).await.unwrap();
        let adapted_response = pipeline.adapt_response(response.clone()).await.unwrap();

        assert_eq!(adapted_request, request);
        assert_eq!(adapted_response, response);

        let captured = events.lock().unwrap();
        let positions: Vec<StreamPosition> =
            captured.iter().map(|(position, _)| *position).collect();
        assert_eq!(
            positions,
            vec![
                StreamPosition::BeforeAdapter,
                StreamPosition::AfterAdapter,
                StreamPosition::BeforeAdapter,
                StreamPosition::AfterAdapter,
            ]
        );
    }

    #[tokio::test]
    async fn events_carry_original_and_adapted_payloads() {
        let pipeline = AdapterPipeline::new(
            Arc::new(WrappingRequestAdapter),
            Arc::new(NoOpResponseAdapter::default()),
        );
        let events: EventLog = Arc::new(Mutex::new(Vec::new()));
        pipeline.request_stream().register(Arc::new(RecordingObserver {
            events: events.clone(),
        }));

        let request = serde_json::json!({ "foo": "bar" });
        let expected = serde_json::json!({ "wrapped": { "foo": "bar" } });

        let adapted = pipeline.adapt_request(request.clone()).await.unwrap();
        assert_eq!(adapted, expected);

        // The BeforeAdapter event must carry the untouched input and the
        // AfterAdapter event the adapter's output, in that order.
        let captured = events.lock().unwrap();
        assert_eq!(
            *captured,
            vec![
                (StreamPosition::BeforeAdapter, request),
                (StreamPosition::AfterAdapter, expected),
            ]
        );
    }
}