rust_actions_cache_api/
lib.rs1use bytes::Bytes;
10use reqwest::{Client, RequestBuilder, Response};
11use serde::{Deserialize, Serialize};
12use thiserror::Error;
13
14#[derive(Error, Debug)]
16#[non_exhaustive]
17pub enum Error {
18 #[error(transparent)]
20 Reqwest(#[from] reqwest::Error),
21 #[error("server rate limited the request, asking to wait {retry_after} seconds")]
23 RateLimit {
24 retry_after: u64,
26 #[source]
28 source: reqwest::Error,
29 },
30 #[error("did not find a runtime token in the ACTIONS_RUNTIME_TOKEN environment variable")]
32 NoRuntimeToken,
33 #[error("did not find the endpoint URL in the ACTIONS_CACHE_URL environment variable")]
35 NoEndpointUrl,
36}
37
38impl Error {
39 pub fn retry_after(&self) -> Option<u64> {
43 if let Self::RateLimit { retry_after, .. } = *self {
44 Some(retry_after)
45 } else {
46 None
47 }
48 }
49}
50
51pub type Result<T, E = Error> = std::result::Result<T, E>;
53
54#[derive(Deserialize, Debug)]
56pub struct CacheHit {
57 #[serde(rename = "cacheKey")]
59 pub key: String,
60 pub scope: String,
62}
63
64pub struct Cache {
69 client: Client,
70 token: String,
71 endpoint: String,
72}
73
74impl Cache {
75 pub fn new(user_agent: &str) -> Result<Self> {
79 let token = std::env::var("ACTIONS_RUNTIME_TOKEN").map_err(|_| Error::NoRuntimeToken)?;
80
81 let endpoint = format!(
82 "{}/_apis/artifactcache",
83 std::env::var("ACTIONS_CACHE_URL")
84 .map_err(|_| Error::NoEndpointUrl)?
85 .trim_end_matches('/')
86 );
87
88 let client = Client::builder().user_agent(user_agent).build()?;
89
90 Ok(Self {
91 client,
92 token,
93 endpoint,
94 })
95 }
96
97 fn api_request(&self, builder: RequestBuilder) -> RequestBuilder {
99 builder.bearer_auth(&self.token).header(
100 reqwest::header::ACCEPT,
101 "application/json;api-version=6.0-preview.1",
102 )
103 }
104
105 pub async fn get_url(
115 &self,
116 key_space: &str,
117 key_prefixes: &[&str],
118 ) -> Result<Option<(CacheHit, String)>> {
119 #[derive(Deserialize)]
120 pub struct GetResponse {
121 #[serde(flatten)]
122 hit: CacheHit,
123 #[serde(rename = "archiveLocation")]
124 location: String,
125 }
126
127 let response = self
128 .api_request(self.client.get(format!("{}/cache", self.endpoint)))
129 .query(&[("keys", &*key_prefixes.join(",")), ("version", key_space)])
130 .send()
131 .await?;
132
133 tracing::debug!(response_headers = ?response.headers());
134
135 if response.status() == reqwest::StatusCode::NO_CONTENT {
136 Ok(None)
137 } else {
138 let response: GetResponse = error_for_response(response)?.json().await?;
139 Ok(Some((response.hit, response.location)))
140 }
141 }
142
143 pub async fn get_bytes(
147 &self,
148 key_space: &str,
149 keys: &[&str],
150 ) -> Result<Option<(CacheHit, Bytes)>> {
151 if let Some((hit, location)) = self.get_url(key_space, keys).await? {
152 let response = self.client.get(location).send().await?;
153
154 tracing::debug!(response_headers = ?response.headers());
155
156 Ok(Some((hit, response.bytes().await?)))
157 } else {
158 Ok(None)
159 }
160 }
161
162 pub async fn put_bytes(&self, key_space: &str, key: &str, data: Bytes) -> Result<()> {
164 #[derive(Serialize)]
165 struct ReserveRequest<'a> {
166 key: &'a str,
167 version: &'a str,
168 }
169 #[derive(Deserialize)]
170 struct ReserveResponse {
171 #[serde(rename = "cacheId")]
172 cache_id: i64,
173 }
174
175 let response = self
176 .api_request(self.client.post(format!("{}/caches", self.endpoint)))
177 .json(&ReserveRequest {
178 key,
179 version: key_space,
180 })
181 .send()
182 .await?;
183
184 tracing::debug!(response_headers = ?response.headers());
185
186 let ReserveResponse { cache_id } = error_for_response(response)?.json().await?;
187
188 if !data.is_empty() {
189 let response = self
190 .api_request(
191 self.client
192 .patch(format!("{}/caches/{}", self.endpoint, cache_id)),
193 )
194 .header(
195 reqwest::header::CONTENT_RANGE,
196 format!("bytes {}-{}/*", 0, data.len() - 1),
197 )
198 .header(reqwest::header::CONTENT_TYPE, "application/octet-stream")
199 .body(data.clone())
200 .send()
201 .await?;
202
203 tracing::debug!(response_headers = ?response.headers());
204
205 error_for_response(response)?;
206 }
207
208 #[derive(Serialize)]
209 struct RequestBody<'a> {
210 key: &'a str,
211 version: &'a str,
212 }
213
214 #[derive(Serialize)]
215 struct FinalizeRequest {
216 size: usize,
217 }
218
219 let response = self
220 .api_request(
221 self.client
222 .post(format!("{}/caches/{}", self.endpoint, cache_id)),
223 )
224 .json(&FinalizeRequest { size: data.len() })
225 .send()
226 .await?;
227
228 tracing::debug!(response_headers = ?response.headers());
229
230 error_for_response(response)?;
231 Ok(())
232 }
233}
234
235fn error_for_response(response: Response) -> Result<Response> {
236 if response.status().is_client_error() || response.status().is_server_error() {
237 if let Some(retry_after) = response
238 .headers()
239 .get(reqwest::header::RETRY_AFTER)
240 .and_then(|v| v.to_str().ok()?.parse().ok())
241 {
242 return Err(Error::RateLimit {
243 retry_after,
244 source: response.error_for_status().unwrap_err(),
245 });
246 }
247 }
248 response.error_for_status().map_err(Into::into)
249}