sozu_client/
lib.rs

//! # Sōzu client
//!
//! This library provides a client to interact with Sōzu.
//! The client can send one-off requests or batches of requests; a batch is
//! written to a temporary file and loaded by Sōzu through a single `LoadState` request.
5
6use bb8::Pool;
7use sozu_command_lib::{
8    channel::ChannelError,
9    proto::command::{Request, Response, ResponseStatus, WorkerRequest, request::RequestType},
10};
11
12use tokio::{
13    fs::File,
14    io::{AsyncWriteExt, BufWriter},
15    task::{JoinError, spawn_blocking as blocking},
16};
17use tracing::trace;
18
19use crate::channel::{ConnectionManager, ConnectionProperties};
20
21pub mod channel;
22pub mod config;
23pub mod socket;
24#[cfg(feature = "unpooled")]
25pub mod unpooled;
26
27// -----------------------------------------------------------------------------
28// Error
29
30#[derive(thiserror::Error, Debug)]
31pub enum Error {
32    #[error("failed to create connection pool over unix socket, {0}")]
33    CreatePool(channel::Error),
34    #[error("failed to execute blocking task, {0}")]
35    Join(JoinError),
36    #[error("failed to get connection to socket, {0}")]
37    GetConnection(bb8::RunError<channel::Error>),
38    #[error("failed to send request, {0}")]
39    Send(ChannelError),
40    #[error("failed to read response, {0}")]
41    Receive(ChannelError),
42    #[error("got an invalid status code, {0}")]
43    InvalidStatusCode(i32),
44    #[error("failed to execute request, got status '{0}', {1}")]
45    Failure(String, String, Response),
46    #[error("failed to create temporary directory, {0}")]
47    CreateTempDir(std::io::Error),
48    #[error("failed to create temporary file, {0}")]
49    CreateTempFile(std::io::Error),
50    #[error("failed to serialize worker request, {0}")]
51    Serialize(serde_json::Error),
52    #[error("failed to write worker request, {0}")]
53    Write(std::io::Error),
54    #[error("failed to flush worker request buffer, {0}")]
55    Flush(std::io::Error),
56}
57
58impl From<JoinError> for Error {
59    #[tracing::instrument]
60    fn from(err: JoinError) -> Self {
61        Self::Join(err)
62    }
63}
64
65impl Error {
66    #[tracing::instrument]
67    pub fn is_recoverable(&self) -> bool {
68        !matches!(
69            self,
70            Self::Send(_) | Self::Receive(_) | Self::CreatePool(_) | Self::GetConnection(_)
71        )
72    }
73}
74
75// -----------------------------------------------------------------------------
76// Sender
77
78#[async_trait::async_trait]
79pub trait Sender {
80    type Error;
81
82    async fn send(&self, request: RequestType) -> Result<Response, Self::Error>;
83
84    async fn send_all(&self, requests: &[RequestType]) -> Result<Response, Self::Error>;
85}
86
87// -----------------------------------------------------------------------------
88// Client
89
90#[derive(Clone, Debug)]
91pub struct Client {
92    pool: Pool<ConnectionManager>,
93}
94
95#[async_trait::async_trait]
96impl Sender for Client {
97    type Error = Error;
98
99    #[tracing::instrument(skip_all)]
100    async fn send(&self, request: RequestType) -> Result<Response, Self::Error> {
101        trace!("Retrieve a connection to Sōzu's socket");
102        let mut conn = self.pool.get().await.map_err(Error::GetConnection)?;
103
104        trace!("Send request to Sōzu");
105        conn.write_message(&Request {
106            request_type: Some(request),
107        })
108        .map_err(Error::Send)?;
109
110        loop {
111            trace!("Read request to Sōzu");
112            let response = conn.read_message().map_err(Error::Receive)?;
113
114            let status = ResponseStatus::try_from(response.status)
115                .map_err(|_| Error::InvalidStatusCode(response.status))?;
116
117            match status {
118                ResponseStatus::Processing => continue,
119                ResponseStatus::Failure => {
120                    return Err(Error::Failure(
121                        status.as_str_name().to_string(),
122                        response.message.to_string().to_lowercase(),
123                        response,
124                    ));
125                }
126                ResponseStatus::Ok => {
127                    return Ok(response);
128                }
129            }
130        }
131    }
132
133    #[tracing::instrument(skip_all)]
134    async fn send_all(&self, requests: &[RequestType]) -> Result<Response, Self::Error> {
135        // -------------------------------------------------------------------------
136        // Create temporary folder and writer to batch requests
137        let tmpdir = blocking(|| tempfile::tempdir().map_err(Error::CreateTempDir)).await??;
138
139        let path = tmpdir.path().join("requests.json");
140        let mut writer = BufWriter::new(File::create(&path).await.map_err(Error::CreateTempFile)?);
141
142        for (idx, request) in requests.iter().cloned().enumerate() {
143            let worker_request = WorkerRequest {
144                id: format!("{}-{idx}", env!("CARGO_PKG_NAME")).to_uppercase(),
145                content: Request::from(request),
146            };
147
148            let payload =
149                blocking(move || serde_json::to_string(&worker_request).map_err(Error::Serialize))
150                    .await??;
151
152            writer
153                .write_all(format!("{payload}\n\0").as_bytes())
154                .await
155                .map_err(Error::Write)?;
156        }
157
158        writer.flush().await.map_err(Error::Flush)?;
159
160        // -------------------------------------------------------------------------
161        // Send a LoadState request with the file that we have created.
162        self.send(RequestType::LoadState(path.to_string_lossy().to_string()))
163            .await
164    }
165}
166
167impl From<Pool<ConnectionManager>> for Client {
168    #[tracing::instrument(skip_all)]
169    fn from(pool: Pool<ConnectionManager>) -> Self {
170        Self { pool }
171    }
172}
173
174impl Client {
175    #[tracing::instrument]
176    pub async fn try_new(opts: ConnectionProperties) -> Result<Self, Error> {
177        let pool = Pool::builder()
178            .build(ConnectionManager::new(opts))
179            .await
180            .map_err(Error::CreatePool)?;
181
182        Ok(Self::from(pool))
183    }
184}