1use scirs2_core::ndarray::{Array1, Array2};
77use std::collections::HashMap;
78use std::marker::PhantomData;
79use std::ptr::NonNull;
80use std::sync::{Arc, Mutex, RwLock};
81
82pub struct MemorySafety;
84
85impl MemorySafety {
86 pub fn document_safety(operation: &str) -> MemorySafetyGuarantee {
88 match operation {
89 "array_indexing" => MemorySafetyGuarantee {
90 operation: operation.to_string(),
91 guarantees: vec![
92 "Bounds checking prevents buffer overflows".to_string(),
93 "Panic on out-of-bounds access in debug mode".to_string(),
94 "Optional bounds checking in release mode for performance".to_string(),
95 ],
96 unsafe_blocks: vec![],
97 mitigation_strategies: vec![
98 "Use checked indexing methods when bounds are uncertain".to_string(),
99 "Validate input dimensions before processing".to_string(),
100 ],
101 },
102 "parallel_processing" => MemorySafetyGuarantee {
103 operation: operation.to_string(),
104 guarantees: vec![
105 "Send and Sync traits prevent data races".to_string(),
106 "Rayon provides work-stealing without data races".to_string(),
107 "Immutable borrows allow safe parallel reading".to_string(),
108 ],
109 unsafe_blocks: vec![],
110 mitigation_strategies: vec![
111 "Use Arc<T> for shared ownership across threads".to_string(),
112 "Use Mutex<T> or RwLock<T> for shared mutable access".to_string(),
113 ],
114 },
115 "gpu_operations" => MemorySafetyGuarantee {
116 operation: operation.to_string(),
117 guarantees: vec![
118 "CUDA memory is managed through RAII wrappers".to_string(),
119 "GPU pointers are opaque and cannot be dereferenced on CPU".to_string(),
120 "Automatic cleanup of GPU resources on drop".to_string(),
121 ],
122 unsafe_blocks: vec![
123 "CUDA FFI calls require unsafe blocks".to_string(),
124 "Memory transfers between CPU and GPU use unsafe operations".to_string(),
125 ],
126 mitigation_strategies: vec![
127 "Wrap all CUDA operations in safe abstractions".to_string(),
128 "Validate GPU memory allocation success".to_string(),
129 "Use typed GPU pointers to prevent type confusion".to_string(),
130 ],
131 },
132 _ => MemorySafetyGuarantee {
133 operation: operation.to_string(),
134 guarantees: vec!["General Rust memory safety guarantees apply".to_string()],
135 unsafe_blocks: vec![],
136 mitigation_strategies: vec![],
137 },
138 }
139 }
140
141 pub fn validate_unsafe_usage(code_block: &str) -> UnsafeValidationResult {
143 let mut issues = Vec::new();
144 let mut recommendations = Vec::new();
145
146 if code_block.contains("transmute") {
148 issues.push("transmute operations can break type safety".to_string());
149 recommendations.push("Consider using safe casting alternatives".to_string());
150 }
151
152 if code_block.contains("from_raw_parts") {
153 issues.push("Raw pointer operations require careful validation".to_string());
154 recommendations.push("Ensure pointer validity and proper alignment".to_string());
155 }
156
157 if code_block.contains("assume_init") {
158 issues.push("Uninitialized memory access detected".to_string());
159 recommendations
160 .push("Use MaybeUninit for safer uninitialized memory handling".to_string());
161 }
162
163 let safety_score = if issues.is_empty() {
164 100
165 } else {
166 std::cmp::max(0, 100 - (issues.len() * 20)) as u8
167 };
168
169 UnsafeValidationResult {
170 safety_score,
171 issues,
172 recommendations,
173 requires_review: safety_score < 80,
174 }
175 }
176}
177
/// Safety notes documented for one named operation.
///
/// Built by `MemorySafety::document_safety`.
#[derive(Debug, Clone)]
pub struct MemorySafetyGuarantee {
    /// Name of the operation this entry describes (echoes the lookup key).
    pub operation: String,
    /// Human-readable guarantees that are documented for the operation.
    pub guarantees: Vec<String>,
    /// Descriptions of where `unsafe` code is involved; empty when none is listed.
    pub unsafe_blocks: Vec<String>,
    /// Suggested practices for reducing the remaining risk; may be empty.
    pub mitigation_strategies: Vec<String>,
}
186
/// Outcome of the heuristic scan performed by `MemorySafety::validate_unsafe_usage`.
#[derive(Debug, Clone)]
pub struct UnsafeValidationResult {
    /// Score out of 100; 100 means no flagged pattern was found.
    pub safety_score: u8,
    /// One message per flagged pattern.
    pub issues: Vec<String>,
    /// One suggestion per flagged pattern, in the same order as `issues`.
    pub recommendations: Vec<String>,
    /// Set when `safety_score` is below 80.
    pub requires_review: bool,
}
195
/// Non-panicking access and validation helpers for array types.
///
/// Indices are passed as a slice with one entry per axis; the implementations
/// in this file treat a slice of the wrong length as "no such element".
pub trait SafeArrayOps<T> {
    /// Returns a shared reference to the element at `index`, or `None` when
    /// the index does not address an element.
    fn safe_get(&self, index: &[usize]) -> Option<&T>;

    /// Returns a mutable reference to the element at `index`, or `None` when
    /// the index does not address an element.
    fn safe_get_mut(&mut self, index: &[usize]) -> Option<&mut T>;

    /// Checks the array's dimensions, returning a description of the problem
    /// as `Err` when they are rejected.
    fn validate_dimensions(&self) -> Result<(), String>;

    /// Reports whether `index` addresses an element of the array.
    fn is_valid_index(&self, index: &[usize]) -> bool;
}
210
211impl<T> SafeArrayOps<T> for Array2<T> {
212 fn safe_get(&self, index: &[usize]) -> Option<&T> {
213 if index.len() != 2 {
214 return None;
215 }
216 self.get((index[0], index[1]))
217 }
218
219 fn safe_get_mut(&mut self, index: &[usize]) -> Option<&mut T> {
220 if index.len() != 2 {
221 return None;
222 }
223 self.get_mut((index[0], index[1]))
224 }
225
226 fn validate_dimensions(&self) -> Result<(), String> {
227 if self.nrows() == 0 || self.ncols() == 0 {
228 Err("Array has zero-sized dimension".to_string())
229 } else if self.nrows() > isize::MAX as usize || self.ncols() > isize::MAX as usize {
230 Err("Array dimension exceeds maximum safe size".to_string())
231 } else {
232 Ok(())
233 }
234 }
235
236 fn is_valid_index(&self, index: &[usize]) -> bool {
237 index.len() == 2 && index[0] < self.nrows() && index[1] < self.ncols()
238 }
239}
240
241impl<T> SafeArrayOps<T> for Array1<T> {
242 fn safe_get(&self, index: &[usize]) -> Option<&T> {
243 if index.len() != 1 {
244 return None;
245 }
246 self.get(index[0])
247 }
248
249 fn safe_get_mut(&mut self, index: &[usize]) -> Option<&mut T> {
250 if index.len() != 1 {
251 return None;
252 }
253 self.get_mut(index[0])
254 }
255
256 fn validate_dimensions(&self) -> Result<(), String> {
257 if self.is_empty() {
258 Err("Array is empty".to_string())
259 } else if self.len() > isize::MAX as usize {
260 Err("Array length exceeds maximum safe size".to_string())
261 } else {
262 Ok(())
263 }
264 }
265
266 fn is_valid_index(&self, index: &[usize]) -> bool {
267 index.len() == 1 && index[0] < self.len()
268 }
269}
270
/// Pool that hands out `Vec<T>` buffers and takes them back for reuse.
///
/// The shared state sits behind `Arc<Mutex<..>>` because every
/// `SafePooledBuffer` handle keeps its own clone of it.
pub struct SafeMemoryPool<T> {
    /// Idle buffers, grouped by the capacity they were requested with.
    pools: Arc<Mutex<HashMap<usize, Vec<Vec<T>>>>>,
    /// Counter of outstanding handles: incremented by `allocate`, decremented
    /// from the handle's `Drop` implementation.
    allocated_count: Arc<Mutex<usize>>,
    /// Maximum number of idle buffers kept per capacity group.
    max_pool_size: usize,
}
277
278impl<T> SafeMemoryPool<T> {
279 pub fn new() -> Self {
281 Self {
282 pools: Arc::new(Mutex::new(HashMap::new())),
283 allocated_count: Arc::new(Mutex::new(0)),
284 max_pool_size: 1000, }
286 }
287
288 pub fn with_limits(max_pool_size: usize) -> Self {
290 Self {
291 pools: Arc::new(Mutex::new(HashMap::new())),
292 allocated_count: Arc::new(Mutex::new(0)),
293 max_pool_size,
294 }
295 }
296
297 pub fn allocate(&self, capacity: usize) -> SafePooledBuffer<T> {
299 let buffer = {
300 let mut pools = self.pools.lock().unwrap();
301 if let Some(pool) = pools.get_mut(&capacity) {
302 if let Some(mut buffer) = pool.pop() {
303 buffer.clear();
304 buffer
305 } else {
306 Vec::with_capacity(capacity)
307 }
308 } else {
309 Vec::with_capacity(capacity)
310 }
311 };
312
313 {
314 let mut count = self.allocated_count.lock().unwrap();
315 *count += 1;
316 }
317
318 SafePooledBuffer {
319 buffer: Some(buffer),
320 capacity,
321 pool: self.pools.clone(),
322 allocated_count: self.allocated_count.clone(),
323 max_pool_size: self.max_pool_size,
324 }
325 }
326
327 pub fn stats(&self) -> MemoryPoolStats {
329 let allocated_count = *self.allocated_count.lock().unwrap();
330 let pools = self.pools.lock().unwrap();
331 let pooled_count: usize = pools.values().map(|v| v.len()).sum();
332
333 MemoryPoolStats {
334 allocated_count,
335 pooled_count,
336 pool_sizes: pools.iter().map(|(&k, v)| (k, v.len())).collect(),
337 }
338 }
339}
340
impl<T> Default for SafeMemoryPool<T> {
    /// Same as `SafeMemoryPool::new()`.
    fn default() -> Self {
        Self::new()
    }
}
346
/// Snapshot of a `SafeMemoryPool`'s counters, as returned by its `stats` method.
#[derive(Debug, Clone)]
pub struct MemoryPoolStats {
    /// Value of the pool's outstanding-handle counter.
    pub allocated_count: usize,
    /// Total number of idle buffers across all capacity groups.
    pub pooled_count: usize,
    /// `(capacity, idle buffer count)` for every capacity group in the pool.
    pub pool_sizes: Vec<(usize, usize)>,
}
354
/// Handle to a buffer obtained from `SafeMemoryPool::allocate`.
///
/// Dropping the handle offers the buffer back to the pool it came from.
pub struct SafePooledBuffer<T> {
    /// The borrowed buffer; `None` once `into_inner` has taken it.
    buffer: Option<Vec<T>>,
    /// Capacity the buffer was requested with; used as the pool key on return.
    capacity: usize,
    /// Shared idle-buffer storage of the originating pool.
    pool: Arc<Mutex<HashMap<usize, Vec<Vec<T>>>>>,
    /// Shared outstanding-handle counter of the originating pool.
    allocated_count: Arc<Mutex<usize>>,
    /// Retention limit per capacity group, copied from the pool at allocation.
    max_pool_size: usize,
}
363
364impl<T> SafePooledBuffer<T> {
365 pub fn as_mut_vec(&mut self) -> &mut Vec<T> {
367 self.buffer.as_mut().expect("Buffer has been consumed")
368 }
369
370 pub fn as_ref_vec(&self) -> &Vec<T> {
372 self.buffer.as_ref().expect("Buffer has been consumed")
373 }
374
375 pub fn into_inner(mut self) -> Vec<T> {
377 self.buffer.take().expect("Buffer has been consumed")
378 }
379}
380
381impl<T> Drop for SafePooledBuffer<T> {
382 fn drop(&mut self) {
383 if let Some(buffer) = self.buffer.take() {
384 let mut pools = self.pool.lock().unwrap();
386 let pool = pools.entry(self.capacity).or_default();
387
388 if pool.len() < self.max_pool_size {
389 pool.push(buffer);
390 }
391 let mut count = self.allocated_count.lock().unwrap();
395 *count = count.saturating_sub(1);
396 }
397 }
398}
399
/// Lets a handle be used wherever a `&Vec<T>` is expected.
impl<T> std::ops::Deref for SafePooledBuffer<T> {
    type Target = Vec<T>;

    /// Delegates to `as_ref_vec`, so it panics under the same condition
    /// (the handle no longer holds a buffer).
    fn deref(&self) -> &Self::Target {
        self.as_ref_vec()
    }
}
407
/// Lets a handle be used wherever a `&mut Vec<T>` is expected.
impl<T> std::ops::DerefMut for SafePooledBuffer<T> {
    /// Delegates to `as_mut_vec`, so it panics under the same condition
    /// (the handle no longer holds a buffer).
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_mut_vec()
    }
}
413
/// Thin wrapper around a non-null raw pointer to `T`.
///
/// The wrapper only stores the pointer; nothing in this type frees or
/// dereferences it.
#[derive(Debug)]
pub struct SafePtr<T> {
    /// The wrapped pointer; non-null by construction of `NonNull`.
    ptr: NonNull<T>,
    /// Marker tying the wrapper to `T` without storing a value of that type.
    _marker: PhantomData<T>,
}
420
impl<T> SafePtr<T> {
    /// Wraps an existing non-null pointer.
    ///
    /// # Safety
    ///
    /// NOTE(review): the caller contract is not spelled out in the visible
    /// code. Presumably `ptr` must point to a valid `T` for as long as the
    /// wrapper (or any pointer obtained from it) is used — confirm and
    /// document the exact requirements.
    pub unsafe fn new(ptr: NonNull<T>) -> Self {
        Self {
            ptr,
            _marker: PhantomData,
        }
    }

    /// Returns the wrapped pointer as `*const T`.
    ///
    /// # Safety
    ///
    /// This function only copies the pointer out; it does not dereference it.
    /// NOTE(review): it is nevertheless declared `unsafe`, presumably to flag
    /// that the returned pointer carries no validity guarantee — confirm the
    /// intended contract.
    pub unsafe fn as_ptr(&self) -> *const T {
        self.ptr.as_ptr()
    }

    /// Returns the wrapped pointer as `*mut T`.
    ///
    /// # Safety
    ///
    /// This function only copies the pointer out; it does not dereference it.
    /// Note that it takes `&self`, so the type system does not enforce
    /// exclusive access to the pointee.
    /// NOTE(review): callers presumably must rule out aliasing mutable access
    /// themselves — confirm the intended contract.
    pub unsafe fn as_mut_ptr(&self) -> *mut T {
        self.ptr.as_ptr()
    }
}
455
// SAFETY: NOTE(review) — the justification is not recorded in the visible
// code. The wrapper itself never dereferences the pointer and the impl is
// limited to `T: Send`; presumably soundness relies on callers of the unsafe
// constructor/accessors upholding validity across threads — confirm.
unsafe impl<T: Send> Send for SafePtr<T> {}
// SAFETY: NOTE(review) — same caveat as for `Send`; limited to `T: Sync`.
// `as_mut_ptr` hands out `*mut T` through `&self`, so shared use across
// threads depends on callers synchronizing any writes — confirm.
unsafe impl<T: Sync> Sync for SafePtr<T> {}
459
/// A value shared behind `Arc<RwLock<..>>`, tagged with an identifier.
pub struct SafeSharedModel<T> {
    /// The shared, lock-protected value.
    inner: Arc<RwLock<T>>,
    /// Identifier included in the panic message when the lock is poisoned.
    id: String,
}
465
466impl<T> SafeSharedModel<T> {
467 pub fn new(model: T, id: String) -> Self {
469 Self {
470 inner: Arc::new(RwLock::new(model)),
471 id,
472 }
473 }
474
475 pub fn read(&self) -> std::sync::RwLockReadGuard<'_, T> {
477 self.inner
478 .read()
479 .unwrap_or_else(|e| panic!("RwLock poisoned for model {}: {}", self.id, e))
480 }
481
482 pub fn write(&self) -> std::sync::RwLockWriteGuard<'_, T> {
484 self.inner
485 .write()
486 .unwrap_or_else(|e| panic!("RwLock poisoned for model {}: {}", self.id, e))
487 }
488
489 pub fn try_read(&self) -> Option<std::sync::RwLockReadGuard<'_, T>> {
491 self.inner.try_read().ok()
492 }
493
494 pub fn try_write(&self) -> Option<std::sync::RwLockWriteGuard<'_, T>> {
496 self.inner.try_write().ok()
497 }
498
499 pub fn clone_ref(&self) -> Self {
501 Self {
502 inner: Arc::clone(&self.inner),
503 id: self.id.clone(),
504 }
505 }
506}
507
impl<T: Clone> SafeSharedModel<T> {
    /// Returns a copy of the current value, taken under a read lock.
    ///
    /// Goes through `read`, so it panics if the lock is poisoned.
    pub fn clone_model(&self) -> T {
        self.read().clone()
    }
}
514
// NOTE(review): every identifier in this module is already snake_case, so this
// allow looks unnecessary — confirm before removing.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// A known operation name is echoed back and has at least one guarantee.
    #[test]
    fn test_memory_safety_documentation() {
        let guarantee = MemorySafety::document_safety("array_indexing");
        assert_eq!(guarantee.operation, "array_indexing");
        assert!(!guarantee.guarantees.is_empty());
    }

    /// Text without flagged patterns scores 100; `transmute` lowers the score.
    #[test]
    fn test_unsafe_validation() {
        let safe_code = "let x = vec![1, 2, 3]; let y = &x[0];";
        let result = MemorySafety::validate_unsafe_usage(safe_code);
        assert_eq!(result.safety_score, 100);
        assert!(result.issues.is_empty());

        let unsafe_code = "let x = transmute::<i32, f32>(42);";
        let result = MemorySafety::validate_unsafe_usage(unsafe_code);
        assert!(result.safety_score < 100);
        assert!(!result.issues.is_empty());
    }

    /// Exercises the `SafeArrayOps` helpers on a 10x10 array.
    #[test]
    fn test_safe_array_operations() {
        let array = Array2::<f64>::zeros((10, 10));

        // In-bounds access succeeds.
        assert!(array.safe_get(&[0, 0]).is_some());
        // One past the end on both axes is rejected.
        assert!(array.safe_get(&[10, 10]).is_none());
        // An index with the wrong number of axes is rejected.
        assert!(array.safe_get(&[5]).is_none());
        // A non-empty array passes dimension validation.
        assert!(array.validate_dimensions().is_ok());

        assert!(array.is_valid_index(&[5, 5]));
        // Row 10 is out of range for a 10-row array.
        assert!(!array.is_valid_index(&[10, 5]));
    }

    /// Allocation bumps the counter; dropping the handle pools the buffer.
    #[test]
    fn test_memory_pool() {
        let pool = SafeMemoryPool::<i32>::new();

        let buffer = pool.allocate(100);
        assert_eq!(buffer.capacity(), 100);

        let stats = pool.stats();
        assert_eq!(stats.allocated_count, 1);

        // Returning the handle moves the buffer into the idle pool.
        drop(buffer);

        let stats = pool.stats();
        assert_eq!(stats.allocated_count, 0);
        assert_eq!(stats.pooled_count, 1);
    }

    /// Reads, writes and a cloned handle all observe the same shared value.
    #[test]
    fn test_shared_model() {
        let model = vec![1, 2, 3, 4, 5];
        let shared = SafeSharedModel::new(model, "test_model".to_string());

        {
            // Scoped so the read guard is released before writing.
            let reader = shared.read();
            assert_eq!(reader.len(), 5);
        }

        {
            // Scoped so the write guard is released before the next read.
            let mut writer = shared.write();
            writer.push(6);
            assert_eq!(writer.len(), 6);
        }

        let shared2 = shared.clone_ref();
        // The second handle sees the element pushed through the first.
        let reader = shared2.read();
        assert_eq!(reader.len(), 6);
    }

    /// The handle derefs to its vector and `into_inner` yields the contents.
    #[test]
    fn test_pooled_buffer_deref() {
        let pool = SafeMemoryPool::<i32>::new();
        let mut buffer = pool.allocate(10);

        // `push`, `len` and indexing all go through `Deref`/`DerefMut`.
        buffer.push(42);
        assert_eq!(buffer.len(), 1);
        assert_eq!(buffer[0], 42);

        let inner = buffer.into_inner();
        assert_eq!(inner, vec![42]);
    }
}