1use bb8::Pool;
7use sozu_command_lib::{
8 channel::ChannelError,
9 proto::command::{Request, Response, ResponseStatus, WorkerRequest, request::RequestType},
10};
11
12use tokio::{
13 fs::File,
14 io::{AsyncWriteExt, BufWriter},
15 task::{JoinError, spawn_blocking as blocking},
16};
17use tracing::trace;
18
19use crate::channel::{ConnectionManager, ConnectionProperties};
20
21pub mod channel;
22pub mod config;
23pub mod socket;
24#[cfg(feature = "unpooled")]
25pub mod unpooled;
26
/// Errors returned by [`Client`] while talking to Sōzu.
///
/// NOTE(review): `Failure` embeds a whole `Response`, which may make `Result<_, Error>` large
/// (clippy `result_large_err`); consider boxing it in a future breaking release.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// Building the connection pool failed (see `Client::try_new`).
    #[error("failed to create connection pool over unix socket, {0}")]
    CreatePool(channel::Error),
    /// A task spawned with `spawn_blocking` could not be joined.
    #[error("failed to execute blocking task, {0}")]
    Join(JoinError),
    /// No connection could be obtained from the pool.
    #[error("failed to get connection to socket, {0}")]
    GetConnection(bb8::RunError<channel::Error>),
    /// Writing a request on a pooled connection failed.
    #[error("failed to send request, {0}")]
    Send(ChannelError),
    /// Reading a response from a pooled connection failed.
    #[error("failed to read response, {0}")]
    Receive(ChannelError),
    /// A response carried a status code that does not map to a known `ResponseStatus`.
    #[error("got an invalid status code, {0}")]
    InvalidStatusCode(i32),
    /// Sōzu answered with a failure status.
    ///
    /// Fields: the status name, the response message (lower-cased by `Client::send`), and the
    /// raw response.
    #[error("failed to execute request, got status '{0}', {1}")]
    Failure(String, String, Response),
    /// The temporary directory holding batched requests could not be created.
    #[error("failed to create temporary directory, {0}")]
    CreateTempDir(std::io::Error),
    /// The temporary file holding batched requests could not be created.
    #[error("failed to create temporary file, {0}")]
    CreateTempFile(std::io::Error),
    /// A worker request could not be serialized to JSON.
    #[error("failed to serialize worker request, {0}")]
    Serialize(serde_json::Error),
    /// A serialized worker request could not be written to the temporary file.
    #[error("failed to write worker request, {0}")]
    Write(std::io::Error),
    /// Flushing the buffered temporary file failed.
    #[error("failed to flush worker request buffer, {0}")]
    Flush(std::io::Error),
}
57
58impl From<JoinError> for Error {
59 #[tracing::instrument]
60 fn from(err: JoinError) -> Self {
61 Self::Join(err)
62 }
63}
64
65impl Error {
66 #[tracing::instrument]
67 pub fn is_recoverable(&self) -> bool {
68 !matches!(
69 self,
70 Self::Send(_) | Self::Receive(_) | Self::CreatePool(_) | Self::GetConnection(_)
71 )
72 }
73}
74
/// Something able to send command requests to Sōzu.
///
/// Implementors pick their own error type through [`Sender::Error`].
#[async_trait::async_trait]
pub trait Sender {
    /// Error returned when a request cannot be delivered or is rejected.
    type Error;

    /// Sends a single request and returns the response obtained for it.
    async fn send(&self, request: RequestType) -> Result<Response, Self::Error>;

    /// Sends a batch of requests and returns a single response for the whole batch.
    async fn send_all(&self, requests: &[RequestType]) -> Result<Response, Self::Error>;
}
86
/// Client talking to Sōzu's unix socket through a pool of connections.
///
/// Requests are sent through its [`Sender`] implementation.
#[derive(Clone, Debug)]
pub struct Client {
    /// Pool of socket connections, created and managed by [`ConnectionManager`].
    pool: Pool<ConnectionManager>,
}
94
95#[async_trait::async_trait]
96impl Sender for Client {
97 type Error = Error;
98
99 #[tracing::instrument(skip_all)]
100 async fn send(&self, request: RequestType) -> Result<Response, Self::Error> {
101 trace!("Retrieve a connection to Sōzu's socket");
102 let mut conn = self.pool.get().await.map_err(Error::GetConnection)?;
103
104 trace!("Send request to Sōzu");
105 conn.write_message(&Request {
106 request_type: Some(request),
107 })
108 .map_err(Error::Send)?;
109
110 loop {
111 trace!("Read request to Sōzu");
112 let response = conn.read_message().map_err(Error::Receive)?;
113
114 let status = ResponseStatus::try_from(response.status)
115 .map_err(|_| Error::InvalidStatusCode(response.status))?;
116
117 match status {
118 ResponseStatus::Processing => continue,
119 ResponseStatus::Failure => {
120 return Err(Error::Failure(
121 status.as_str_name().to_string(),
122 response.message.to_string().to_lowercase(),
123 response,
124 ));
125 }
126 ResponseStatus::Ok => {
127 return Ok(response);
128 }
129 }
130 }
131 }
132
133 #[tracing::instrument(skip_all)]
134 async fn send_all(&self, requests: &[RequestType]) -> Result<Response, Self::Error> {
135 let tmpdir = blocking(|| tempfile::tempdir().map_err(Error::CreateTempDir)).await??;
138
139 let path = tmpdir.path().join("requests.json");
140 let mut writer = BufWriter::new(File::create(&path).await.map_err(Error::CreateTempFile)?);
141
142 for (idx, request) in requests.iter().cloned().enumerate() {
143 let worker_request = WorkerRequest {
144 id: format!("{}-{idx}", env!("CARGO_PKG_NAME")).to_uppercase(),
145 content: Request::from(request),
146 };
147
148 let payload =
149 blocking(move || serde_json::to_string(&worker_request).map_err(Error::Serialize))
150 .await??;
151
152 writer
153 .write_all(format!("{payload}\n\0").as_bytes())
154 .await
155 .map_err(Error::Write)?;
156 }
157
158 writer.flush().await.map_err(Error::Flush)?;
159
160 self.send(RequestType::LoadState(path.to_string_lossy().to_string()))
163 .await
164 }
165}
166
167impl From<Pool<ConnectionManager>> for Client {
168 #[tracing::instrument(skip_all)]
169 fn from(pool: Pool<ConnectionManager>) -> Self {
170 Self { pool }
171 }
172}
173
174impl Client {
175 #[tracing::instrument]
176 pub async fn try_new(opts: ConnectionProperties) -> Result<Self, Error> {
177 let pool = Pool::builder()
178 .build(ConnectionManager::new(opts))
179 .await
180 .map_err(Error::CreatePool)?;
181
182 Ok(Self::from(pool))
183 }
184}