1use crate::storage::{StorageBackend, StorageObject};
46use serde::{Deserialize, Serialize};
47use std::collections::HashMap;
48use std::sync::{Arc, Mutex};
49use std::time::SystemTime;
50use torsh_core::error::{Result, TorshError};
51
/// Configuration for the (mock) Amazon S3 storage backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct S3Config {
    /// Target bucket name.
    pub bucket: String,
    /// AWS region identifier, e.g. "us-east-1".
    pub region: String,
    /// Optional static access key ID; the in-memory mock does not consult it.
    pub access_key_id: Option<String>,
    /// Optional static secret key; the in-memory mock does not consult it.
    pub secret_access_key: Option<String>,
    /// Optional custom endpoint URL (e.g. for S3-compatible services); unused by the mock.
    pub endpoint: Option<String>,
    /// Whether server-side encryption would be requested; unused by the mock.
    pub server_side_encryption: bool,
    /// S3 storage class name (e.g. "STANDARD"); unused by the mock.
    pub storage_class: String,
    /// Payloads strictly larger than this many bytes take the multipart upload path.
    pub multipart_threshold: usize,
    /// Part size in bytes used when simulating a multipart upload.
    pub multipart_chunk_size: usize,
}
74
75impl Default for S3Config {
76 fn default() -> Self {
77 Self {
78 bucket: String::new(),
79 region: "us-east-1".to_string(),
80 access_key_id: None,
81 secret_access_key: None,
82 endpoint: None,
83 server_side_encryption: true,
84 storage_class: "STANDARD".to_string(),
85 multipart_threshold: 50 * 1024 * 1024, multipart_chunk_size: 10 * 1024 * 1024, }
88 }
89}
90
/// In-memory mock of an S3 storage backend.
///
/// Object bodies and their metadata live in two separate mutex-guarded maps.
pub struct MockS3Storage {
    /// Backend configuration (bucket, region, multipart settings, ...).
    config: S3Config,
    /// Map of object key -> raw object bytes.
    storage: Arc<Mutex<HashMap<String, Vec<u8>>>>,
    /// Map of object key -> recorded object metadata.
    metadata: Arc<Mutex<HashMap<String, StorageObject>>>,
}
103
104impl MockS3Storage {
105 pub fn new(bucket: String) -> Self {
107 Self {
108 config: S3Config {
109 bucket,
110 ..Default::default()
111 },
112 storage: Arc::new(Mutex::new(HashMap::new())),
113 metadata: Arc::new(Mutex::new(HashMap::new())),
114 }
115 }
116
117 pub fn with_config(config: S3Config) -> Self {
119 Self {
120 config,
121 storage: Arc::new(Mutex::new(HashMap::new())),
122 metadata: Arc::new(Mutex::new(HashMap::new())),
123 }
124 }
125
126 pub fn bucket(&self) -> &str {
128 &self.config.bucket
129 }
130
131 pub fn region(&self) -> &str {
133 &self.config.region
134 }
135
136 fn multipart_upload(&mut self, key: &str, data: &[u8]) -> Result<()> {
138 let chunk_size = self.config.multipart_chunk_size;
140 let num_parts = (data.len() + chunk_size - 1) / chunk_size;
141
142 for i in 0..num_parts {
144 let start = i * chunk_size;
145 let end = std::cmp::min(start + chunk_size, data.len());
146 let _part_data = &data[start..end];
147 }
149
150 self.storage
152 .lock()
153 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
154 .insert(key.to_string(), data.to_vec());
155
156 Ok(())
157 }
158}
159
160impl StorageBackend for MockS3Storage {
161 fn put(&mut self, key: &str, data: &[u8]) -> Result<()> {
162 if data.len() > self.config.multipart_threshold {
164 return self.multipart_upload(key, data);
165 }
166
167 self.storage
169 .lock()
170 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
171 .insert(key.to_string(), data.to_vec());
172
173 let metadata = StorageObject {
175 key: key.to_string(),
176 size: data.len() as u64,
177 last_modified: SystemTime::now(),
178 content_type: Some("application/octet-stream".to_string()),
179 etag: Some(format!("{:x}", md5::compute(data))),
180 metadata: HashMap::new(),
181 };
182
183 self.metadata
184 .lock()
185 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
186 .insert(key.to_string(), metadata);
187
188 Ok(())
189 }
190
191 fn get(&self, key: &str) -> Result<Vec<u8>> {
192 self.storage
193 .lock()
194 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
195 .get(key)
196 .cloned()
197 .ok_or_else(|| TorshError::InvalidArgument(format!("Key not found: {}", key)))
198 }
199
200 fn delete(&mut self, key: &str) -> Result<()> {
201 self.storage
202 .lock()
203 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
204 .remove(key);
205
206 self.metadata
207 .lock()
208 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
209 .remove(key);
210
211 Ok(())
212 }
213
214 fn exists(&self, key: &str) -> Result<bool> {
215 Ok(self
216 .storage
217 .lock()
218 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
219 .contains_key(key))
220 }
221
222 fn list(&self, prefix: &str) -> Result<Vec<StorageObject>> {
223 let metadata = self
224 .metadata
225 .lock()
226 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?;
227
228 Ok(metadata
229 .values()
230 .filter(|obj| obj.key.starts_with(prefix))
231 .cloned()
232 .collect())
233 }
234
235 fn get_metadata(&self, key: &str) -> Result<StorageObject> {
236 self.metadata
237 .lock()
238 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
239 .get(key)
240 .cloned()
241 .ok_or_else(|| TorshError::InvalidArgument(format!("Key not found: {}", key)))
242 }
243
244 fn backend_type(&self) -> &str {
245 "s3"
246 }
247}
248
/// Configuration for the (mock) Google Cloud Storage backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GcsConfig {
    /// Target bucket name.
    pub bucket: String,
    /// GCP project identifier.
    pub project_id: String,
    /// Optional service-account key material; the in-memory mock does not consult it.
    pub service_account_key: Option<String>,
    /// GCS storage class name (e.g. "STANDARD"); unused by the mock.
    pub storage_class: String,
}
261
262impl Default for GcsConfig {
263 fn default() -> Self {
264 Self {
265 bucket: String::new(),
266 project_id: String::new(),
267 service_account_key: None,
268 storage_class: "STANDARD".to_string(),
269 }
270 }
271}
272
/// In-memory mock of a Google Cloud Storage backend.
pub struct MockGcsStorage {
    /// Backend configuration (bucket, project, ...).
    config: GcsConfig,
    /// Map of object key -> raw object bytes.
    storage: Arc<Mutex<HashMap<String, Vec<u8>>>>,
    /// Map of object key -> recorded object metadata.
    metadata: Arc<Mutex<HashMap<String, StorageObject>>>,
}
279
280impl MockGcsStorage {
281 pub fn new(bucket: String, project_id: String) -> Self {
283 Self {
284 config: GcsConfig {
285 bucket,
286 project_id,
287 ..Default::default()
288 },
289 storage: Arc::new(Mutex::new(HashMap::new())),
290 metadata: Arc::new(Mutex::new(HashMap::new())),
291 }
292 }
293
294 pub fn with_config(config: GcsConfig) -> Self {
296 Self {
297 config,
298 storage: Arc::new(Mutex::new(HashMap::new())),
299 metadata: Arc::new(Mutex::new(HashMap::new())),
300 }
301 }
302
303 pub fn bucket(&self) -> &str {
305 &self.config.bucket
306 }
307
308 pub fn project_id(&self) -> &str {
310 &self.config.project_id
311 }
312}
313
314impl StorageBackend for MockGcsStorage {
315 fn put(&mut self, key: &str, data: &[u8]) -> Result<()> {
316 self.storage
317 .lock()
318 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
319 .insert(key.to_string(), data.to_vec());
320
321 let metadata = StorageObject {
322 key: key.to_string(),
323 size: data.len() as u64,
324 last_modified: SystemTime::now(),
325 content_type: Some("application/octet-stream".to_string()),
326 etag: Some(format!("{:x}", md5::compute(data))),
327 metadata: HashMap::new(),
328 };
329
330 self.metadata
331 .lock()
332 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
333 .insert(key.to_string(), metadata);
334
335 Ok(())
336 }
337
338 fn get(&self, key: &str) -> Result<Vec<u8>> {
339 self.storage
340 .lock()
341 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
342 .get(key)
343 .cloned()
344 .ok_or_else(|| TorshError::InvalidArgument(format!("Key not found: {}", key)))
345 }
346
347 fn delete(&mut self, key: &str) -> Result<()> {
348 self.storage
349 .lock()
350 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
351 .remove(key);
352
353 self.metadata
354 .lock()
355 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
356 .remove(key);
357
358 Ok(())
359 }
360
361 fn exists(&self, key: &str) -> Result<bool> {
362 Ok(self
363 .storage
364 .lock()
365 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
366 .contains_key(key))
367 }
368
369 fn list(&self, prefix: &str) -> Result<Vec<StorageObject>> {
370 let metadata = self
371 .metadata
372 .lock()
373 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?;
374
375 Ok(metadata
376 .values()
377 .filter(|obj| obj.key.starts_with(prefix))
378 .cloned()
379 .collect())
380 }
381
382 fn get_metadata(&self, key: &str) -> Result<StorageObject> {
383 self.metadata
384 .lock()
385 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
386 .get(key)
387 .cloned()
388 .ok_or_else(|| TorshError::InvalidArgument(format!("Key not found: {}", key)))
389 }
390
391 fn backend_type(&self) -> &str {
392 "gcs"
393 }
394}
395
/// Configuration for the (mock) Azure Blob Storage backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AzureConfig {
    /// Storage account name.
    pub account_name: String,
    /// Blob container name.
    pub container: String,
    /// Optional shared access key; the in-memory mock does not consult it.
    pub access_key: Option<String>,
    /// Optional SAS token; the in-memory mock does not consult it.
    pub sas_token: Option<String>,
    /// Blob access tier name (e.g. "Hot"); unused by the mock.
    pub access_tier: String,
}
410
411impl Default for AzureConfig {
412 fn default() -> Self {
413 Self {
414 account_name: String::new(),
415 container: String::new(),
416 access_key: None,
417 sas_token: None,
418 access_tier: "Hot".to_string(),
419 }
420 }
421}
422
/// In-memory mock of an Azure Blob Storage backend.
pub struct MockAzureStorage {
    /// Backend configuration (account, container, ...).
    config: AzureConfig,
    /// Map of blob key -> raw blob bytes.
    storage: Arc<Mutex<HashMap<String, Vec<u8>>>>,
    /// Map of blob key -> recorded object metadata.
    metadata: Arc<Mutex<HashMap<String, StorageObject>>>,
}
429
430impl MockAzureStorage {
431 pub fn new(account_name: String, container: String) -> Self {
433 Self {
434 config: AzureConfig {
435 account_name,
436 container,
437 ..Default::default()
438 },
439 storage: Arc::new(Mutex::new(HashMap::new())),
440 metadata: Arc::new(Mutex::new(HashMap::new())),
441 }
442 }
443
444 pub fn with_config(config: AzureConfig) -> Self {
446 Self {
447 config,
448 storage: Arc::new(Mutex::new(HashMap::new())),
449 metadata: Arc::new(Mutex::new(HashMap::new())),
450 }
451 }
452
453 pub fn account_name(&self) -> &str {
455 &self.config.account_name
456 }
457
458 pub fn container(&self) -> &str {
460 &self.config.container
461 }
462}
463
464impl StorageBackend for MockAzureStorage {
465 fn put(&mut self, key: &str, data: &[u8]) -> Result<()> {
466 self.storage
467 .lock()
468 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
469 .insert(key.to_string(), data.to_vec());
470
471 let metadata = StorageObject {
472 key: key.to_string(),
473 size: data.len() as u64,
474 last_modified: SystemTime::now(),
475 content_type: Some("application/octet-stream".to_string()),
476 etag: Some(format!("{:x}", md5::compute(data))),
477 metadata: HashMap::new(),
478 };
479
480 self.metadata
481 .lock()
482 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
483 .insert(key.to_string(), metadata);
484
485 Ok(())
486 }
487
488 fn get(&self, key: &str) -> Result<Vec<u8>> {
489 self.storage
490 .lock()
491 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
492 .get(key)
493 .cloned()
494 .ok_or_else(|| TorshError::InvalidArgument(format!("Key not found: {}", key)))
495 }
496
497 fn delete(&mut self, key: &str) -> Result<()> {
498 self.storage
499 .lock()
500 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
501 .remove(key);
502
503 self.metadata
504 .lock()
505 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
506 .remove(key);
507
508 Ok(())
509 }
510
511 fn exists(&self, key: &str) -> Result<bool> {
512 Ok(self
513 .storage
514 .lock()
515 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
516 .contains_key(key))
517 }
518
519 fn list(&self, prefix: &str) -> Result<Vec<StorageObject>> {
520 let metadata = self
521 .metadata
522 .lock()
523 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?;
524
525 Ok(metadata
526 .values()
527 .filter(|obj| obj.key.starts_with(prefix))
528 .cloned()
529 .collect())
530 }
531
532 fn get_metadata(&self, key: &str) -> Result<StorageObject> {
533 self.metadata
534 .lock()
535 .map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
536 .get(key)
537 .cloned()
538 .ok_or_else(|| TorshError::InvalidArgument(format!("Key not found: {}", key)))
539 }
540
541 fn backend_type(&self) -> &str {
542 "azure"
543 }
544}
545
#[cfg(test)]
mod tests {
    use super::*;

    /// Full put/get/exists/metadata/delete round-trip against the S3 mock.
    #[test]
    fn test_mock_s3_storage() {
        let mut s3 = MockS3Storage::new("test-bucket".to_string());
        assert_eq!(s3.backend_type(), "s3");
        assert_eq!(s3.bucket(), "test-bucket");

        let payload = b"test data";
        s3.put("test/key", payload).unwrap();
        assert_eq!(s3.get("test/key").unwrap(), payload);

        assert!(s3.exists("test/key").unwrap());
        assert!(!s3.exists("nonexistent").unwrap());

        let meta = s3.get_metadata("test/key").unwrap();
        assert_eq!(meta.size, payload.len() as u64);
        assert!(meta.etag.is_some());

        s3.delete("test/key").unwrap();
        assert!(!s3.exists("test/key").unwrap());
    }

    /// A payload above the default 50 MiB threshold must round-trip intact
    /// through the simulated multipart path.
    #[test]
    fn test_mock_s3_multipart_upload() {
        let mut s3 = MockS3Storage::new("test-bucket".to_string());

        let big = vec![0u8; 60 * 1024 * 1024];
        s3.put("test/large", &big).unwrap();
        assert_eq!(s3.get("test/large").unwrap().len(), big.len());
    }

    /// Prefix listing returns only keys under the given prefix.
    #[test]
    fn test_mock_s3_list() {
        let mut s3 = MockS3Storage::new("test-bucket".to_string());

        let fixtures: [(&str, &[u8]); 3] = [
            ("models/bert/v1.bin", b"data1"),
            ("models/bert/v2.bin", b"data2"),
            ("models/gpt/v1.bin", b"data3"),
        ];
        for (key, body) in fixtures {
            s3.put(key, body).unwrap();
        }

        assert_eq!(s3.list("models/bert/").unwrap().len(), 2);
        assert_eq!(s3.list("models/").unwrap().len(), 3);
    }

    /// Basic accessors and a put/get round-trip for the GCS mock.
    #[test]
    fn test_mock_gcs_storage() {
        let mut gcs = MockGcsStorage::new("test-bucket".to_string(), "test-project".to_string());
        assert_eq!(gcs.backend_type(), "gcs");
        assert_eq!(gcs.bucket(), "test-bucket");
        assert_eq!(gcs.project_id(), "test-project");

        let payload = b"test data";
        gcs.put("test/key", payload).unwrap();
        assert_eq!(gcs.get("test/key").unwrap(), payload);
    }

    /// Basic accessors and a put/get round-trip for the Azure mock.
    #[test]
    fn test_mock_azure_storage() {
        let mut azure =
            MockAzureStorage::new("testaccount".to_string(), "testcontainer".to_string());
        assert_eq!(azure.backend_type(), "azure");
        assert_eq!(azure.account_name(), "testaccount");
        assert_eq!(azure.container(), "testcontainer");

        let payload = b"test data";
        azure.put("test/key", payload).unwrap();
        assert_eq!(azure.get("test/key").unwrap(), payload);
    }

    /// S3Config::default must match the documented defaults.
    #[test]
    fn test_s3_config_defaults() {
        let cfg = S3Config::default();
        assert_eq!(cfg.region, "us-east-1");
        assert_eq!(cfg.storage_class, "STANDARD");
        assert!(cfg.server_side_encryption);
        assert_eq!(cfg.multipart_threshold, 50 * 1024 * 1024);
    }
}