//! synaptic_models/backend.rs — provider backend abstraction: a production
//! HTTP implementation (`HttpBackend`) and a queue-driven test double
//! (`FakeBackend`).
1use std::{collections::VecDeque, pin::Pin, sync::Arc};
2
3use async_trait::async_trait;
4use futures::Stream;
5use serde_json::Value;
6use synaptic_core::SynapticError;
7use tokio::sync::Mutex;
8
/// A provider-agnostic HTTP request: target URL, ordered header pairs, and a
/// JSON body. Backends POST this to the provider endpoint.
#[derive(Debug, Clone)]
pub struct ProviderRequest {
    // Fully-qualified endpoint URL the request is sent to.
    pub url: String,
    // Header name/value pairs, applied in order.
    pub headers: Vec<(String, String)>,
    // JSON request body, serialized as-is.
    pub body: Value,
}
15
/// A parsed provider response: HTTP status code plus the JSON body.
/// The body is carried even for non-2xx statuses so callers can inspect
/// provider error payloads.
#[derive(Debug, Clone)]
pub struct ProviderResponse {
    // HTTP status code (e.g. 200, 429).
    pub status: u16,
    // Response body parsed as JSON.
    pub body: Value,
}
21
/// Boxed, pinned stream of raw response byte chunks; each item is either a
/// chunk of bytes or a transport error.
pub type ByteStream = Pin<Box<dyn Stream<Item = Result<bytes::Bytes, SynapticError>> + Send>>;
23
/// Transport abstraction over a provider's HTTP API. Implemented by the
/// production [`HttpBackend`] and the test [`FakeBackend`], so callers can
/// swap transports behind one interface.
#[async_trait]
pub trait ProviderBackend: Send + Sync {
    /// Sends one request and returns the full parsed response.
    async fn send(&self, request: ProviderRequest) -> Result<ProviderResponse, SynapticError>;
    /// Sends one request and returns the raw response byte stream.
    async fn send_stream(&self, request: ProviderRequest) -> Result<ByteStream, SynapticError>;
}
29
/// Production backend using reqwest.
pub struct HttpBackend {
    // Underlying HTTP client used for all requests.
    client: reqwest::Client,
}
34
35impl HttpBackend {
36    pub fn new() -> Self {
37        Self {
38            client: reqwest::Client::new(),
39        }
40    }
41}
42
impl Default for HttpBackend {
    /// Equivalent to [`HttpBackend::new`].
    fn default() -> Self {
        Self::new()
    }
}
48
49#[async_trait]
50impl ProviderBackend for HttpBackend {
51    async fn send(&self, request: ProviderRequest) -> Result<ProviderResponse, SynapticError> {
52        let mut builder = self.client.post(&request.url);
53        for (key, value) in &request.headers {
54            builder = builder.header(key, value);
55        }
56        builder = builder.json(&request.body);
57
58        let response = builder
59            .send()
60            .await
61            .map_err(|e| SynapticError::Model(format!("HTTP request failed: {e}")))?;
62
63        let status = response.status().as_u16();
64        let body: Value = response
65            .json()
66            .await
67            .map_err(|e| SynapticError::Parsing(format!("failed to parse response JSON: {e}")))?;
68
69        Ok(ProviderResponse { status, body })
70    }
71
72    async fn send_stream(&self, request: ProviderRequest) -> Result<ByteStream, SynapticError> {
73        use futures::StreamExt;
74
75        let mut builder = self.client.post(&request.url);
76        for (key, value) in &request.headers {
77            builder = builder.header(key, value);
78        }
79        builder = builder.json(&request.body);
80
81        let response = builder
82            .send()
83            .await
84            .map_err(|e| SynapticError::Model(format!("HTTP stream request failed: {e}")))?;
85
86        let stream = response
87            .bytes_stream()
88            .map(|result| result.map_err(|e| SynapticError::Model(format!("stream error: {e}"))));
89
90        Ok(Box::pin(stream))
91    }
92}
93
/// Test backend with queued responses and stream chunks.
pub struct FakeBackend {
    // FIFO queue of canned results; each `send` call pops one.
    responses: Arc<Mutex<VecDeque<Result<ProviderResponse, SynapticError>>>>,
    // FIFO queue of chunk lists; each `send_stream` call pops one list and
    // serves its chunks as a stream.
    stream_chunks: Arc<Mutex<VecDeque<Vec<bytes::Bytes>>>>,
}
99
100impl FakeBackend {
101    pub fn new() -> Self {
102        Self {
103            responses: Arc::new(Mutex::new(VecDeque::new())),
104            stream_chunks: Arc::new(Mutex::new(VecDeque::new())),
105        }
106    }
107
108    pub fn push_response(&self, response: ProviderResponse) -> &Self {
109        self.responses
110            .try_lock()
111            .expect("not concurrent during setup")
112            .push_back(Ok(response));
113        self
114    }
115
116    pub fn push_error(&self, error: SynapticError) -> &Self {
117        self.responses
118            .try_lock()
119            .expect("not concurrent during setup")
120            .push_back(Err(error));
121        self
122    }
123
124    pub fn push_stream_chunks(&self, chunks: Vec<bytes::Bytes>) -> &Self {
125        self.stream_chunks
126            .try_lock()
127            .expect("not concurrent during setup")
128            .push_back(chunks);
129        self
130    }
131}
132
impl Default for FakeBackend {
    /// Equivalent to [`FakeBackend::new`].
    fn default() -> Self {
        Self::new()
    }
}
138
139#[async_trait]
140impl ProviderBackend for FakeBackend {
141    async fn send(&self, _request: ProviderRequest) -> Result<ProviderResponse, SynapticError> {
142        let mut responses = self.responses.lock().await;
143        responses
144            .pop_front()
145            .unwrap_or_else(|| Err(SynapticError::Model("FakeBackend exhausted".to_string())))
146    }
147
148    async fn send_stream(&self, _request: ProviderRequest) -> Result<ByteStream, SynapticError> {
149        let mut stream_chunks = self.stream_chunks.lock().await;
150        let chunks = stream_chunks.pop_front().unwrap_or_default();
151
152        let stream = futures::stream::iter(chunks.into_iter().map(Ok));
153        Ok(Box::pin(stream))
154    }
155}