1use bb8::Pool;
7use sozu_command_lib::{
8 channel::ChannelError,
9 proto::command::{request::RequestType, Request, Response, ResponseStatus, WorkerRequest},
10};
11use tempdir::TempDir;
12use tokio::{
13 fs::File,
14 io::{AsyncWriteExt, BufWriter},
15 task::{spawn_blocking as blocking, JoinError},
16};
17use tracing::trace;
18
19use crate::channel::{ConnectionManager, ConnectionProperties};
20
21pub mod channel;
22pub mod config;
23pub mod socket;
24#[cfg(feature = "unpooled")]
25pub mod unpooled;
26
27#[derive(thiserror::Error, Debug)]
31pub enum Error {
32 #[error("failed to create connection pool over unix socket, {0}")]
33 CreatePool(channel::Error),
34 #[error("failed to execute blocking task, {0}")]
35 Join(JoinError),
36 #[error("failed to get connection to socket, {0}")]
37 GetConnection(bb8::RunError<channel::Error>),
38 #[error("failed to send request, {0}")]
39 Send(ChannelError),
40 #[error("failed to read response, {0}")]
41 Receive(ChannelError),
42 #[error("got an invalid status code, {0}")]
43 InvalidStatusCode(i32),
44 #[error("failed to execute request, got status '{0}', {1}")]
45 Failure(String, String, Response),
46 #[error("failed to create temporary directory, {0}")]
47 CreateTempDir(std::io::Error),
48 #[error("failed to create temporary file, {0}")]
49 CreateTempFile(std::io::Error),
50 #[error("failed to serialize worker request, {0}")]
51 Serialize(serde_json::Error),
52 #[error("failed to write worker request, {0}")]
53 Write(std::io::Error),
54 #[error("failed to flush worker request buffer, {0}")]
55 Flush(std::io::Error),
56}
57
58impl From<JoinError> for Error {
59 #[tracing::instrument]
60 fn from(err: JoinError) -> Self {
61 Self::Join(err)
62 }
63}
64
65impl Error {
66 #[tracing::instrument]
67 pub fn is_recoverable(&self) -> bool {
68 !matches!(self, Self::Send(_) | Self::Receive(_) | Self::CreatePool(_) | Self::GetConnection(_))
69 }
70}
71
72#[async_trait::async_trait]
76pub trait Sender {
77 type Error;
78
79 async fn send(&self, request: RequestType) -> Result<Response, Self::Error>;
80
81 async fn send_all(&self, requests: &[RequestType]) -> Result<Response, Self::Error>;
82}
83
84#[derive(Clone, Debug)]
88pub struct Client {
89 pool: Pool<ConnectionManager>,
90}
91
92#[async_trait::async_trait]
93impl Sender for Client {
94 type Error = Error;
95
96 #[tracing::instrument(skip_all)]
97 async fn send(&self, request: RequestType) -> Result<Response, Self::Error> {
98 trace!("Retrieve a connection to Sōzu's socket");
99 let mut conn = self.pool.get().await.map_err(Error::GetConnection)?;
100
101 trace!("Send request to Sōzu");
102 conn.write_message(&Request {
103 request_type: Some(request),
104 })
105 .map_err(Error::Send)?;
106
107 loop {
108 trace!("Read request to Sōzu");
109 let response = conn.read_message().map_err(Error::Receive)?;
110
111 let status = ResponseStatus::try_from(response.status)
112 .map_err(|_| Error::InvalidStatusCode(response.status))?;
113
114 match status {
115 ResponseStatus::Processing => continue,
116 ResponseStatus::Failure => {
117 return Err(Error::Failure(status.as_str_name().to_string(), response.message.to_string().to_lowercase(), response));
118 }
119 ResponseStatus::Ok => {
120 return Ok(response);
121 }
122 }
123 }
124 }
125
126 #[tracing::instrument(skip_all)]
127 async fn send_all(&self, requests: &[RequestType]) -> Result<Response, Self::Error> {
128 let tmpdir =
131 blocking(|| TempDir::new(env!("CARGO_PKG_NAME")).map_err(Error::CreateTempDir))
132 .await??;
133
134 let path = tmpdir.path().join("requests.json");
135 let mut writer = BufWriter::new(File::create(&path).await.map_err(Error::CreateTempFile)?);
136
137 for (idx, request) in requests.iter().cloned().enumerate() {
138 let worker_request = WorkerRequest {
139 id: format!("{}-{idx}", env!("CARGO_PKG_NAME")).to_uppercase(),
140 content: Request::from(request),
141 };
142
143 let payload =
144 blocking(move || serde_json::to_string(&worker_request).map_err(Error::Serialize))
145 .await??;
146
147 writer
148 .write_all(format!("{payload}\n\0").as_bytes())
149 .await
150 .map_err(Error::Write)?;
151 }
152
153 writer.flush().await.map_err(Error::Flush)?;
154
155 self.send(RequestType::LoadState(path.to_string_lossy().to_string()))
158 .await
159 }
160}
161
162impl From<Pool<ConnectionManager>> for Client {
163 #[tracing::instrument(skip_all)]
164 fn from(pool: Pool<ConnectionManager>) -> Self {
165 Self { pool }
166 }
167}
168
169impl Client {
170 #[tracing::instrument]
171 pub async fn try_new(opts: ConnectionProperties) -> Result<Self, Error> {
172 let pool = Pool::builder()
173 .build(ConnectionManager::new(opts))
174 .await
175 .map_err(Error::CreatePool)?;
176
177 Ok(Self::from(pool))
178 }
179}