wsi_streamer/io/
s3_reader.rs

1use async_trait::async_trait;
2use aws_sdk_s3::Client;
3use bytes::Bytes;
4
5use super::RangeReader;
6use crate::error::IoError;
7
8/// S3-backed implementation of RangeReader.
9///
10/// Reads byte ranges from objects in S3 or S3-compatible storage (MinIO, GCS, etc.)
11/// using HTTP range requests. The object size is fetched once on creation via HEAD.
12#[derive(Clone)]
13pub struct S3RangeReader {
14    client: Client,
15    bucket: String,
16    key: String,
17    size: u64,
18    identifier: String,
19}
20
21impl S3RangeReader {
22    /// Create a new S3RangeReader for the given bucket and key.
23    ///
24    /// This performs a HEAD request to determine the object size.
25    /// Returns an error if the object does not exist or is inaccessible.
26    pub async fn new(client: Client, bucket: String, key: String) -> Result<Self, IoError> {
27        let head = client
28            .head_object()
29            .bucket(&bucket)
30            .key(&key)
31            .send()
32            .await
33            .map_err(|e| {
34                // Check if this is a 404 Not Found error
35                // The HeadObjectError has an is_not_found() method that we can use
36                let is_not_found = e
37                    .as_service_error()
38                    .map(|se| se.is_not_found())
39                    .unwrap_or(false);
40
41                if is_not_found {
42                    return IoError::NotFound(format!("s3://{}/{}", bucket, key));
43                }
44
45                // Also check for 404 status code in the raw response
46                let status_is_404 = e
47                    .raw_response()
48                    .map(|r| r.status().as_u16() == 404)
49                    .unwrap_or(false);
50
51                if status_is_404 {
52                    return IoError::NotFound(format!("s3://{}/{}", bucket, key));
53                }
54
55                // Fallback: check the error string for common patterns
56                let err_str = e.to_string();
57                if err_str.contains("NotFound")
58                    || err_str.contains("NoSuchKey")
59                    || err_str.contains("404")
60                {
61                    return IoError::NotFound(format!("s3://{}/{}", bucket, key));
62                }
63
64                IoError::S3(err_str)
65            })?;
66
67        let size = head.content_length().unwrap_or(0) as u64;
68        let identifier = format!("s3://{}/{}", bucket, key);
69
70        Ok(Self {
71            client,
72            bucket,
73            key,
74            size,
75            identifier,
76        })
77    }
78
79    /// Get the bucket name.
80    pub fn bucket(&self) -> &str {
81        &self.bucket
82    }
83
84    /// Get the object key.
85    pub fn key(&self) -> &str {
86        &self.key
87    }
88}
89
90#[async_trait]
91impl RangeReader for S3RangeReader {
92    async fn read_exact_at(&self, offset: u64, len: usize) -> Result<Bytes, IoError> {
93        // Validate range bounds
94        if offset + len as u64 > self.size {
95            return Err(IoError::RangeOutOfBounds {
96                offset,
97                requested: len as u64,
98                size: self.size,
99            });
100        }
101
102        // Handle zero-length reads
103        if len == 0 {
104            return Ok(Bytes::new());
105        }
106
107        // Build range header: "bytes=start-end" (inclusive on both ends)
108        let range = format!("bytes={}-{}", offset, offset + len as u64 - 1);
109
110        let resp = self
111            .client
112            .get_object()
113            .bucket(&self.bucket)
114            .key(&self.key)
115            .range(range)
116            .send()
117            .await
118            .map_err(|e| IoError::S3(e.to_string()))?;
119
120        let data = resp
121            .body
122            .collect()
123            .await
124            .map_err(|e| IoError::Connection(e.to_string()))?
125            .into_bytes();
126
127        Ok(data)
128    }
129
130    fn size(&self) -> u64 {
131        self.size
132    }
133
134    fn identifier(&self) -> &str {
135        &self.identifier
136    }
137}
138
139/// Create an S3 client with optional custom endpoint.
140///
141/// Use a custom endpoint for S3-compatible services like MinIO:
142/// ```ignore
143/// let client = create_s3_client(Some("http://localhost:9000")).await;
144/// ```
145///
146/// For AWS S3, pass `None` to use the default endpoint:
147/// ```ignore
148/// let client = create_s3_client(None).await;
149/// ```
150pub async fn create_s3_client(endpoint_url: Option<&str>) -> Client {
151    let mut config_loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
152
153    if let Some(endpoint) = endpoint_url {
154        config_loader = config_loader.endpoint_url(endpoint);
155    }
156
157    let sdk_config = config_loader.load().await;
158
159    // For S3-compatible services, we often need to use path-style addressing
160    let s3_config = if endpoint_url.is_some() {
161        aws_sdk_s3::config::Builder::from(&sdk_config)
162            .force_path_style(true)
163            .build()
164    } else {
165        aws_sdk_s3::config::Builder::from(&sdk_config).build()
166    };
167
168    Client::from_conf(s3_config)
169}
170
#[cfg(test)]
mod tests {
    // Intentionally empty: exercising this reader needs a live S3-compatible
    // service (e.g., MinIO), so it is covered by integration tests rather than
    // unit tests. See tests/integration/ for the E2E tests.
}