walker_common/scoop/
source.rs

1use crate::USER_AGENT;
2use anyhow::bail;
3use aws_config::{BehaviorVersion, Region, meta::region::RegionProviderChain};
4use aws_sdk_s3::{
5    Client,
6    config::{AppName, Credentials},
7};
8use bytes::Bytes;
9use std::{
10    borrow::Cow,
11    path::{Path, PathBuf},
12};
13use url::Url;
14
/// A location that content can be discovered at, loaded from, deleted, or moved.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Source {
    /// A path on the local file system; may be a single file or a directory to walk.
    Path(PathBuf),
    /// An `http://` or `https://` URL.
    Http(Url),
    /// An S3 location; without an object key it stands for the whole bucket.
    S3(S3),
}
21
22impl TryFrom<&str> for Source {
23    type Error = anyhow::Error;
24
25    fn try_from(value: &str) -> Result<Self, Self::Error> {
26        Ok(
27            if value.starts_with("http://") || value.starts_with("https://") {
28                Self::Http(Url::parse(value)?)
29            } else if value.starts_with("s3://") {
30                Self::S3(S3::try_from(value)?)
31            } else {
32                Self::Path(value.into())
33            },
34        )
35    }
36}
37
38impl Source {
39    pub async fn discover(self) -> anyhow::Result<Vec<Self>> {
40        match self {
41            Self::Path(path) => Ok(Self::discover_path(path)?
42                .into_iter()
43                .map(Self::Path)
44                .collect()),
45            Self::S3(s3) if s3.key.is_none() => Ok(Self::discover_s3(s3)
46                .await?
47                .into_iter()
48                .map(Self::S3)
49                .collect()),
50            value => Ok(vec![value]),
51        }
52    }
53
54    fn discover_path(path: PathBuf) -> anyhow::Result<Vec<PathBuf>> {
55        log::debug!("Discovering: {}", path.display());
56
57        if !path.exists() {
58            bail!("{} does not exist", path.display());
59        } else if path.is_file() {
60            log::debug!("Is a file");
61            Ok(vec![path])
62        } else if path.is_dir() {
63            log::debug!("Is a directory");
64            let mut result = Vec::new();
65
66            for path in walkdir::WalkDir::new(path).into_iter() {
67                let path = path?;
68                if path.file_type().is_file() {
69                    result.push(path.path().to_path_buf());
70                }
71            }
72
73            Ok(result)
74        } else {
75            log::warn!("Is something unknown: {}", path.display());
76            Ok(vec![])
77        }
78    }
79
80    async fn discover_s3(s3: S3) -> anyhow::Result<Vec<S3>> {
81        let client = s3.client().await?;
82
83        let mut response = client
84            .list_objects_v2()
85            .bucket(s3.bucket.clone())
86            .max_keys(100)
87            .into_paginator()
88            .send();
89
90        let mut result = vec![];
91        while let Some(next) = response.next().await {
92            let next = next?;
93            for object in next.contents() {
94                if let Some(key) = object.key.clone() {
95                    result.push(key);
96                }
97            }
98        }
99
100        Ok(result
101            .into_iter()
102            .map(|key| S3 {
103                key: Some(key),
104                ..(s3.clone())
105            })
106            .collect())
107    }
108
109    pub fn name(&self) -> Cow<'_, str> {
110        match self {
111            Self::Path(path) => path.to_string_lossy(),
112            Self::Http(url) => url.as_str().into(),
113            Self::S3(s3) => format!(
114                "s3://{}/{}/{}",
115                s3.region,
116                s3.bucket,
117                s3.key.as_deref().unwrap_or_default()
118            )
119            .into(),
120        }
121    }
122
123    /// Load the content of the source
124    pub async fn load(&self) -> Result<Bytes, anyhow::Error> {
125        Ok(match self {
126            Self::Path(path) => tokio::fs::read(path).await?.into(),
127            Self::Http(url) => {
128                reqwest::get(url.clone())
129                    .await?
130                    .error_for_status()?
131                    .bytes()
132                    .await?
133            }
134            Self::S3(s3) => {
135                let client = s3.client();
136                client
137                    .await?
138                    .get_object()
139                    .key(s3.key.clone().unwrap_or_default())
140                    .bucket(s3.bucket.clone())
141                    .send()
142                    .await?
143                    .body
144                    .collect()
145                    .await?
146                    .into_bytes()
147            }
148        })
149    }
150
151    /// Delete the source
152    pub async fn delete(&self) -> anyhow::Result<()> {
153        match self {
154            Self::Path(file) => {
155                // just delete the file
156                tokio::fs::remove_file(&file).await?;
157            }
158            Self::Http(url) => {
159                // issue a DELETE request
160                reqwest::Client::builder()
161                    .build()?
162                    .delete(url.clone())
163                    .send()
164                    .await?;
165            }
166            Self::S3(s3) => {
167                // delete the object from the bucket
168                let client = s3.client();
169                client
170                    .await?
171                    .delete_object()
172                    .key(s3.key.clone().unwrap_or_default())
173                    .bucket(s3.bucket.clone())
174                    .send()
175                    .await?;
176            }
177        }
178
179        Ok(())
180    }
181
182    /// move the source
183    ///
184    /// NOTE: This is a no-op for HTTP sources.
185    pub async fn r#move(&self, path: &str) -> anyhow::Result<()> {
186        match self {
187            Self::Path(file) => {
188                let path = Path::new(&path);
189                tokio::fs::create_dir_all(path).await?;
190                tokio::fs::copy(&file, path.join(file)).await?;
191                tokio::fs::remove_file(&file).await?;
192            }
193            Self::Http(url) => {
194                // no-op, but warn
195                log::warn!("Unable to move HTTP source ({url}), skipping!");
196            }
197            Self::S3(s3) => {
198                let client = s3.client();
199                client
200                    .await?
201                    .copy_object()
202                    .copy_source(s3.key.clone().unwrap_or_default())
203                    .key(format!("{path}/{}", s3.key.as_deref().unwrap_or_default()))
204                    .bucket(s3.bucket.clone())
205                    .send()
206                    .await?;
207            }
208        }
209
210        Ok(())
211    }
212}
213
/// An S3 location (a bucket, or an object inside a bucket), as parsed from an `s3://` URL.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct S3 {
    /// Region handed to the AWS client; taken from the host part of the URL.
    region: String,
    /// Optional static credentials as `(access key id, secret access key)`; taken from the
    /// `user:password` part of the URL. Without them the AWS default configuration applies.
    credentials: Option<(String, String)>,
    /// Name of the bucket; the first segment of the URL path.
    bucket: String,
    /// Key of the object; the rest of the URL path. `None` if the URL only names a bucket.
    key: Option<String>,
}
221
222impl TryFrom<&str> for S3 {
223    type Error = anyhow::Error;
224
225    fn try_from(value: &str) -> Result<Self, Self::Error> {
226        let uri = fluent_uri::Uri::try_from(value)?;
227
228        let Some(auth) = uri.authority() else {
229            bail!("Missing authority");
230        };
231
232        let path = uri.path().to_string();
233        let path = path.trim_start_matches('/');
234        if path.is_empty() {
235            bail!("Missing bucket");
236        }
237
238        let (bucket, key) = match path.split_once('/') {
239            Some((bucket, key)) => (bucket.to_string(), Some(key.to_string())),
240            None => (path.to_string(), None),
241        };
242
243        let region = auth.host().to_string();
244
245        let credentials = auth.userinfo().and_then(|userinfo| {
246            userinfo
247                .split_once(':')
248                .map(|(username, password)| (username.to_string(), password.to_string()))
249        });
250
251        Ok(S3 {
252            region,
253            credentials,
254            bucket,
255            key,
256        })
257    }
258}
259
260impl S3 {
261    pub async fn client(&self) -> anyhow::Result<Client> {
262        let region_provider = RegionProviderChain::first_try(Region::new(self.region.clone()));
263
264        let mut shared_config = aws_config::defaults(BehaviorVersion::latest())
265            .region(region_provider)
266            .app_name(AppName::new(USER_AGENT)?);
267
268        if let Some((key_id, access_key)) = &self.credentials {
269            let credentials = Credentials::new(key_id, access_key, None, None, "config");
270            shared_config = shared_config.credentials_provider(credentials);
271        }
272
273        let shared_config = shared_config.load().await;
274
275        Ok(Client::new(&shared_config))
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the expected location in bucket `b1` for the given region, credentials and key.
    fn expected(region: &str, credentials: Option<(&str, &str)>, key: Option<&str>) -> S3 {
        S3 {
            region: region.to_string(),
            credentials: credentials.map(|(id, secret)| (id.to_string(), secret.to_string())),
            bucket: "b1".to_string(),
            key: key.map(str::to_string),
        }
    }

    /// Parse the three supported URL shapes for `host` and compare with the expected values:
    /// bucket only, bucket with credentials, and bucket with credentials and key.
    fn check_host(host: &str) {
        let foo_bar = Some(("foo", "bar"));

        let cases = [
            (format!("s3://{host}/b1"), None, None),
            (format!("s3://foo:bar@{host}/b1"), foo_bar, None),
            (
                format!("s3://foo:bar@{host}/b1/path/to/file"),
                foo_bar,
                Some("path/to/file"),
            ),
        ];

        for (url, credentials, key) in cases {
            assert_eq!(
                expected(host, credentials, key),
                S3::try_from(url.as_str()).unwrap()
            );
        }
    }

    #[test]
    fn parse_s3() {
        check_host("us-east-1");
    }

    #[test]
    fn parse_s3_custom_region() {
        check_host("my.own.endpoint");
    }
}