treasury_api/
lib.rs

1//! Defines API to communicate between client and server.
2
3use std::{
4    mem::{size_of, size_of_val},
5    sync::atomic::{AtomicU32, Ordering},
6};
7
8use bincode::Options;
9use eyre::WrapErr;
10use serde::{de::DeserializeOwned, Serialize};
11use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
12use treasury_id::AssetId;
13
/// First message that must be sent by a client after connecting to the treasury server.
///
/// On the wire it is 8 bytes: `magic` followed by `version`, each encoded as a
/// little-endian `u32` (see `send_handshake` / `recv_handshake`). `#[repr(C)]`
/// keeps `size_of::<Handshake>()` at exactly those 8 bytes, which the handshake
/// functions use as their buffer size.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub struct Handshake {
    /// Magic value that must be equal to [`MAGIC`]. Otherwise the server SHOULD drop the connection.
    pub magic: u32,

    /// Major version of the crate used by the client. If the versions used by client and
    /// server mismatch, the server SHOULD drop the connection.
    pub version: u32,
}
23
24/// First request that must follow handshake.
25/// Opens particular treasury the client is going to work with.
26#[derive(Debug, serde::Serialize, serde::Deserialize)]
27pub struct OpenRequest {
28    /// Path to directory that contains Treasury.toml or any in descendants.
29    pub path: Box<str>,
30
31    /// Specifies that new treasury must be init. Fails if treasury directory already contains `Treasury.toml`
32    /// But succeeds in descendant directories.
33    pub init: bool,
34}
35
36/// Response to the `OpenRequest`
37#[derive(Debug, serde::Serialize, serde::Deserialize)]
38pub enum OpenResponse {
39    Success,
40
41    /// Failure.
42    /// Payload contains description.
43    Failure {
44        description: Box<str>,
45    },
46}
47
48/// Requests to Treasury instance.
49#[derive(Debug, serde::Serialize, serde::Deserialize)]
50pub enum Request {
51    /// Stores new asset into treasury.
52    Store {
53        /// Url for source file.
54        source: Box<str>,
55
56        /// Source format.
57        format: Option<Box<str>>,
58
59        /// Targe format.
60        target: Box<str>,
61    },
62
63    /// Fetches url of the artifact for the specified asset.
64    FetchUrl { id: AssetId },
65
66    /// Fetches url of the artifact for the specified asset.
67    FindAsset { source: Box<str>, target: Box<str> },
68}
69
70/// Response to store request.
71#[derive(Debug, serde::Serialize, serde::Deserialize)]
72pub enum StoreResponse {
73    /// Success.
74    /// Payload contains asset id.
75    Success { id: AssetId, path: Box<str> },
76
77    /// Storing process requires to read data from URL, but can't access it from treasury host.
78    NeedData { url: Box<str> },
79
80    /// Failure.
81    /// Payload contains description.
82    Failure { description: Box<str> },
83}
84
85#[derive(Debug, serde::Serialize, serde::Deserialize)]
86pub enum FetchUrlResponse {
87    /// Success.
88    /// Payload contains URL of the artifact.
89    Success { artifact: Box<str> },
90
91    /// Asset not found
92    NotFound,
93
94    /// Failure response to any store request.
95    Failure { description: Box<str> },
96}
97
98#[derive(Debug, serde::Serialize, serde::Deserialize)]
99pub enum FindResponse {
100    /// Success.
101    /// Payload contains URL of the artifact.
102    Success { id: AssetId, path: Box<str> },
103
104    /// Asset not found
105    NotFound,
106
107    /// Failure response to any store request.
108    Failure { description: Box<str> },
109}
110
/// Magic value that opens every [`Handshake`]: the ASCII bytes `TRES` read as a
/// big-endian `u32`.
///
/// `send_handshake` writes it little-endian, so the bytes on the wire are `SERT`.
pub const MAGIC: u32 = 0x5452_4553;
112
113pub fn version() -> u32 {
114    static VERSION: AtomicU32 = AtomicU32::new(u32::MAX);
115
116    #[cold]
117    fn init_version() -> u32 {
118        // Initialize
119        env!("CARGO_PKG_VERSION_MAJOR")
120            .parse()
121            .expect("Bad major version")
122    }
123
124    let mut version = VERSION.load(Ordering::Relaxed);
125    if version == u32::MAX {
126        version = init_version();
127        VERSION.store(version, Ordering::Relaxed);
128    }
129    version
130}
131
/// Length prefix that precedes every message sent by `send_message`.
///
/// On the wire it is a single little-endian `u32`. `#[repr(C)]` keeps
/// `size_of::<MessageHeader>()` at 4 bytes, which the framing code relies on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub struct MessageHeader {
    /// Payload length in bytes, not counting the header itself.
    pub size: u32,
}
137
138pub const DEFAULT_PORT: u16 = 12345;
139
140pub fn get_port() -> u16 {
141    match std::env::var("TREASURY_SERVICE_PORT") {
142        Ok(port_string) => match port_string.parse() {
143            Ok(port) => port,
144            Err(_) => {
145                tracing::error!(
146                    "Failed to parse desired treasury port from env '{}'. Using default {}",
147                    port_string,
148                    DEFAULT_PORT
149                );
150                DEFAULT_PORT
151            }
152        },
153        Err(_) => DEFAULT_PORT,
154    }
155}
156
/// Largest payload, in bytes, that is framed/read through a stack buffer instead of a
/// heap allocation.
const INLINE_MESSAGE_LIMIT: usize = 4 * 1024; // 4 KiB
/// Hard cap on a single message payload; larger messages are rejected on both send and receive.
const MESSAGE_LIMIT: usize = 256 * 1024 * 1024; // 256 MiB
159
160pub async fn send_message<T: Serialize>(
161    stream: &mut (impl AsyncWrite + Unpin),
162    message: T,
163) -> eyre::Result<()> {
164    let size = bincode_options()
165        .serialized_size(&message)
166        .wrap_err("Failed to determine serialized size of the message")?;
167
168    eyre::ensure!(size <= MESSAGE_LIMIT as u64, "Message is too large");
169
170    let size = size as u32;
171    let header = MessageHeader { size };
172    tracing::debug!("Sending message header {:?}", header);
173
174    let mut buffer = [0; INLINE_MESSAGE_LIMIT];
175    if size > INLINE_MESSAGE_LIMIT as u32 {
176        let mut buffer = vec![0; size_of::<MessageHeader>() + size as usize];
177
178        buffer[..size_of::<MessageHeader>()].copy_from_slice(&header.size.to_le_bytes());
179
180        bincode_options()
181            .serialize_into(&mut buffer[size_of::<MessageHeader>()..], &message)
182            .wrap_err("Failed to serialize message")?;
183
184        stream
185            .write_all(&buffer)
186            .await
187            .wrap_err("Failed to send message")?;
188
189        tracing::debug!("{} bytes sent", buffer.len());
190    } else {
191        let buffer = &mut buffer[..size_of::<MessageHeader>() + size as usize];
192
193        buffer[..size_of::<MessageHeader>()].copy_from_slice(&header.size.to_le_bytes());
194
195        bincode_options()
196            .serialize_into(&mut buffer[size_of::<MessageHeader>()..], &message)
197            .wrap_err("Failed to serialize message")?;
198
199        stream
200            .write_all(buffer)
201            .await
202            .wrap_err("Failed to send message")?;
203
204        tracing::debug!("{} bytes sent", buffer.len());
205    }
206
207    Ok(())
208}
209
210async fn next_message_header(
211    stream: &mut (impl AsyncRead + Unpin),
212) -> std::io::Result<Option<MessageHeader>> {
213    let mut buffer = [0; size_of::<MessageHeader>()];
214    match stream.read_exact(&mut buffer).await {
215        Ok(_) => Ok(Some(MessageHeader {
216            size: u32::from_le_bytes(buffer),
217        })),
218        Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => Ok(None),
219        Err(err) => Err(err),
220    }
221}
222
223pub async fn recv_message<T: DeserializeOwned>(
224    stream: &mut (impl AsyncRead + Unpin),
225) -> eyre::Result<Option<T>> {
226    let header = match next_message_header(stream).await? {
227        None => {
228            tracing::debug!("Connection closed");
229            return Ok(None);
230        }
231        Some(header) => header,
232    };
233
234    tracing::debug!("Next message header {:?}", header);
235
236    eyre::ensure!(header.size <= MESSAGE_LIMIT as u32, "Message is too large");
237
238    let mut buffer = [0; INLINE_MESSAGE_LIMIT];
239
240    if header.size > INLINE_MESSAGE_LIMIT as u32 {
241        let mut buffer = vec![0; header.size as usize];
242        stream.read_exact(&mut buffer).await?;
243
244        tracing::debug!(
245            "{} bytes received",
246            size_of::<MessageHeader>() + header.size as usize
247        );
248
249        let message = bincode_options()
250            .deserialize(&buffer)
251            .wrap_err("Failed to parse request")?;
252
253        Ok(Some(message))
254    } else {
255        let buffer = &mut buffer[..header.size as usize];
256        stream.read_exact(buffer).await?;
257
258        tracing::debug!(
259            "{} bytes received",
260            size_of::<MessageHeader>() + header.size as usize
261        );
262
263        let message = bincode_options()
264            .deserialize(buffer)
265            .wrap_err("Failed to parse request")?;
266
267        Ok(Some(message))
268    }
269}
270
271pub async fn recv_handshake(stream: &mut (impl AsyncRead + Unpin)) -> eyre::Result<()> {
272    let mut buffer = [0; size_of::<Handshake>()];
273
274    stream
275        .read_exact(&mut buffer)
276        .await
277        .wrap_err("Handshake failed")?;
278
279    let handshake = Handshake {
280        magic: u32::from_le_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]),
281        version: u32::from_le_bytes([buffer[4], buffer[5], buffer[6], buffer[7]]),
282    };
283
284    tracing::debug!(
285        "Handshake received {}:{}",
286        handshake.magic,
287        handshake.version
288    );
289
290    eyre::ensure!(
291        handshake.magic == MAGIC,
292        "Wrong MAGIC number. Expected '{}', found '{}'",
293        MAGIC,
294        handshake.magic
295    );
296
297    let version = version();
298
299    eyre::ensure!(
300        handshake.version == version,
301        "Treasury API version mismatch. Expected '{}', found '{}'",
302        version,
303        handshake.version,
304    );
305
306    tracing::info!("Handshake valid");
307
308    Ok(())
309}
310
311pub async fn send_handshake(stream: &mut (impl AsyncWrite + Unpin)) -> eyre::Result<()> {
312    let mut buffer = [0; size_of::<Handshake>()];
313
314    buffer[..size_of_val(&MAGIC)].copy_from_slice(&MAGIC.to_le_bytes());
315    buffer[size_of_val(&MAGIC)..].copy_from_slice(&version().to_le_bytes());
316
317    stream
318        .write_all(&buffer)
319        .await
320        .wrap_err("Handshake failed")?;
321
322    tracing::debug!("Handshake sent {}:{}", MAGIC, version());
323
324    Ok(())
325}
326
327fn bincode_options() -> impl Options {
328    bincode::options()
329        .with_big_endian()
330        .with_fixint_encoding()
331        .allow_trailing_bytes()
332}