
tenflowers_core/tensor_view.rs

// Enhanced tensor with zero-copy view support
#[cfg(feature = "gpu")]
use crate::memory::PooledBuffer;
use crate::memory::{MemoryAliasDetector, StridedView};
use crate::tensor::TensorStorage;
use crate::{Device, Result, Shape, Tensor, TensorError};
use scirs2_core::ndarray::{ArrayD, IxDyn};
use std::sync::Arc;

/// Enhanced tensor that supports zero-copy views and strided operations
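///
/// Illustrative usage (a sketch assuming a CPU `f32` tensor built with `Tensor::from_vec`;
/// see the tests at the bottom of this file for runnable versions):
///
/// ```ignore
/// let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])?;
/// let view = TensorView::from_tensor(&tensor)?; // zero-copy view over the tensor data
/// let transposed = view.transpose(&[1, 0])?;    // zero-copy: only strides change
/// assert_eq!(transposed.shape(), &[3, 2]);
/// ```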
#[derive(Debug)]
pub struct TensorView<T> {
    /// Reference to the underlying tensor data
    pub storage: ViewStorage<T>,
    /// Strided view information
    pub view: StridedView,
    /// Device where the tensor is stored
    device: Device,
    /// Whether this tensor requires gradient computation
    requires_grad: bool,
    /// Gradient tensor (if any)
    grad: Option<Arc<TensorView<T>>>,
    /// Memory alias detector for safety
    alias_detector: Arc<MemoryAliasDetector>,
}

/// Storage for tensor views
#[derive(Debug, Clone)]
pub enum ViewStorage<T> {
    /// Reference to CPU tensor storage
    CpuRef(Arc<ArrayD<T>>),
    /// Reference to GPU pooled buffer
    #[cfg(feature = "gpu")]
    GpuPooled(Arc<PooledBuffer<'static>>),
    /// Reference to regular GPU buffer
    #[cfg(feature = "gpu")]
    GpuRef(Arc<crate::gpu::buffer::GpuBuffer<T>>),
}

impl<T> TensorView<T> {
    /// Get a unique buffer ID for alias detection
    fn get_buffer_id(&self) -> usize {
        match &self.storage {
            ViewStorage::CpuRef(arr) => Arc::as_ptr(arr) as usize,
            #[cfg(feature = "gpu")]
            ViewStorage::GpuRef(gpu_buffer) => Arc::as_ptr(gpu_buffer) as usize,
            #[cfg(feature = "gpu")]
            ViewStorage::GpuPooled(pooled_buffer) => Arc::as_ptr(pooled_buffer) as usize,
        }
    }

    /// Get the number of elements in the tensor
    pub fn numel(&self) -> usize {
        self.view.shape.iter().product()
    }

    /// Get the size in bytes
    pub fn size_bytes(&self) -> usize {
        self.view.size_bytes()
    }

    /// Check if the tensor is contiguous in memory
    pub fn is_contiguous(&self) -> bool {
        self.view.is_contiguous()
    }
}

impl<T: Clone + Default> TensorView<T> {
    /// Create a new tensor view from an existing tensor
    pub fn from_tensor(tensor: &Tensor<T>) -> Result<Self>
    where
        T: Clone + Send + Sync + 'static,
    {
        let element_size = std::mem::size_of::<T>();
        let shape = tensor.shape().dims().to_vec();
        let strides = compute_default_strides(&shape, element_size);

        let view = StridedView::new(0, shape, strides, element_size);
        let alias_detector = Arc::new(MemoryAliasDetector::new());

        let storage = match &tensor.storage {
            TensorStorage::Cpu(arr) => ViewStorage::CpuRef(Arc::new(arr.clone())),
            #[cfg(feature = "gpu")]
            TensorStorage::Gpu(gpu_buffer) => ViewStorage::GpuRef(Arc::new(gpu_buffer.clone())),
        };

        Ok(Self {
            storage,
            view,
            device: *tensor.device(),
            requires_grad: tensor.requires_grad(),
            grad: None,
            alias_detector,
        })
    }

    /// Create a zero-copy transpose view
    pub fn transpose(&self, axes: &[usize]) -> Result<TensorView<T>>
    where
        T: Clone + Send + Sync + 'static,
    {
        let new_view = self.view.transpose(axes)?;

        // Check for memory aliasing
        let buffer_id = self.get_buffer_id();
        if self
            .alias_detector
            .check_alias(buffer_id, new_view.offset, new_view.size_bytes())
        {
            return Err(TensorError::invalid_argument(
                "Transpose would create memory alias".to_string(),
            ));
        }

        // Register the new view
        self.alias_detector
            .register_view(buffer_id, new_view.offset, new_view.size_bytes());

        Ok(Self {
            storage: self.storage.clone(),
            view: new_view,
            device: self.device,
            requires_grad: self.requires_grad,
            grad: None,
            alias_detector: Arc::clone(&self.alias_detector),
        })
    }

    /// Create a zero-copy reshape view (when possible)
    pub fn reshape(&self, new_shape: &[usize]) -> Result<TensorView<T>>
    where
        T: Clone + Send + Sync + 'static,
    {
        let new_view = self.view.reshape(new_shape)?;

        // Check for memory aliasing
        let buffer_id = self.get_buffer_id();
        if self
            .alias_detector
            .check_alias(buffer_id, new_view.offset, new_view.size_bytes())
        {
            return Err(TensorError::invalid_argument(
                "Reshape would create memory alias".to_string(),
            ));
        }

        // Register the new view
        self.alias_detector
            .register_view(buffer_id, new_view.offset, new_view.size_bytes());

        Ok(Self {
            storage: self.storage.clone(),
            view: new_view,
            device: self.device,
            requires_grad: self.requires_grad,
            grad: None,
            alias_detector: Arc::clone(&self.alias_detector),
        })
    }

    /// Create a zero-copy slice view
    pub fn slice(&self, ranges: &[(usize, usize)]) -> Result<TensorView<T>>
    where
        T: Clone + Send + Sync + 'static,
    {
        let new_view = self.view.slice(ranges)?;

        // Check for memory aliasing
        let buffer_id = self.get_buffer_id();
        if self
            .alias_detector
            .check_alias(buffer_id, new_view.offset, new_view.size_bytes())
        {
            return Err(TensorError::invalid_argument(
                "Slice would create memory alias".to_string(),
            ));
        }

        // Register the new view
        self.alias_detector
            .register_view(buffer_id, new_view.offset, new_view.size_bytes());

        Ok(Self {
            storage: self.storage.clone(),
            view: new_view,
            device: self.device,
            requires_grad: self.requires_grad,
            grad: None,
            alias_detector: Arc::clone(&self.alias_detector),
        })
    }

    /// Convert back to a regular tensor (may require data copy)
    pub fn to_tensor(&self) -> Result<Tensor<T>>
    where
        T: Clone + Send + Sync + 'static + bytemuck::Pod + bytemuck::Zeroable,
    {
        if self.view.is_contiguous() {
            // Zero-copy conversion for contiguous views
            match &self.storage {
                ViewStorage::CpuRef(arr) => {
                    let _shape = Shape::new(self.view.shape.clone());
                    Ok(Tensor::from_array((**arr).clone()))
                }
                #[cfg(feature = "gpu")]
                ViewStorage::GpuRef(gpu_buffer) => {
                    let shape = Shape::new(self.view.shape.clone());
                    let mut result = Tensor::from_gpu_buffer((**gpu_buffer).clone(), shape);
                    result.set_requires_grad(self.requires_grad);
                    Ok(result)
                }
                #[cfg(feature = "gpu")]
                ViewStorage::GpuPooled(pooled_buffer) => {
                    // For pooled buffers, we need to extract the data from the pool
                    let device_id = match self.device {
                        Device::Gpu(id) => id,
                        _ => {
                            return Err(TensorError::device_error_simple(
                                "Expected GPU device".to_string(),
                            ))
                        }
                    };

                    // Fetch the GPU context once and reuse its device and queue
                    let ctx = crate::device::context::get_gpu_context(device_id)?;
                    let device = &ctx.device;
                    let queue = &ctx.queue;

                    let pool_buffer = pooled_buffer.buffer();

                    // Create a new buffer with the exact size needed
                    let data_size =
                        self.view.shape.iter().product::<usize>() * std::mem::size_of::<T>();
                    let new_buffer = device.create_buffer(&wgpu::BufferDescriptor {
                        label: Some("pooled_to_tensor_buffer"),
                        size: data_size as u64,
                        usage: wgpu::BufferUsages::STORAGE
                            | wgpu::BufferUsages::COPY_SRC
                            | wgpu::BufferUsages::COPY_DST,
                        mapped_at_creation: false,
                    });

                    // Copy data from pool buffer to new buffer
                    let mut encoder =
                        device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
                            label: Some("pooled_buffer_copy"),
                        });

                    encoder.copy_buffer_to_buffer(
                        pool_buffer,
                        pooled_buffer.offset() as u64,
                        &new_buffer,
                        0,
                        data_size as u64,
                    );

                    queue.submit(std::iter::once(encoder.finish()));
                    device.poll(wgpu::PollType::wait_indefinitely()).ok();

                    // Create GPU buffer wrapper
                    let gpu_buffer = crate::gpu::buffer::GpuBuffer::from_wgpu_buffer(
                        new_buffer,
                        ctx.device.clone(),
                        ctx.queue.clone(),
                        Device::Gpu(device_id),
                        self.view.shape.iter().product::<usize>(),
                    );

                    let shape = Shape::new(self.view.shape.clone());
                    let mut result = Tensor::from_gpu_buffer(gpu_buffer, shape);
                    result.set_requires_grad(self.requires_grad);
                    Ok(result)
                }
            }
        } else {
            // Non-contiguous view requires data copy
            self.materialize()
        }
    }

    /// Materialize the view into a new contiguous tensor
    fn materialize(&self) -> Result<Tensor<T>>
    where
        T: Clone + Send + Sync + 'static + bytemuck::Pod + bytemuck::Zeroable,
    {
        match &self.storage {
            ViewStorage::CpuRef(arr) => {
                // Create new contiguous array
                let total_elements: usize = self.view.shape.iter().product();
                let mut new_data = Vec::with_capacity(total_elements);

                // Copy data using strided indexing. Strides are stored in bytes, so the
                // strided index is a byte offset; dividing by the element size yields the
                // element index into the flat backing slice.
                for flat_index in 0..total_elements {
                    let multi_index = flat_to_multi_index(flat_index, &self.view.shape);
                    let strided_index =
                        multi_to_strided_index(&multi_index, &self.view.strides, self.view.offset);

                    let elem_index = strided_index / self.view.element_size;
                    if let Some(slice) = arr.as_slice() {
                        if elem_index < slice.len() {
                            new_data.push(slice[elem_index]);
                        }
                    }
                }

                let _shape = Shape::new(self.view.shape.clone());
                let new_array = ArrayD::from_shape_vec(IxDyn(&self.view.shape), new_data)
                    .map_err(|e| TensorError::invalid_argument(e.to_string()))?;

                Ok(Tensor::from_array(new_array))
            }
            #[cfg(feature = "gpu")]
            ViewStorage::GpuRef(gpu_buffer) => {
                // For GPU, we need to implement a kernel to gather strided data
                use wgpu::util::DeviceExt;

                let device_id = match self.device {
                    Device::Gpu(id) => id,
                    _ => {
                        return Err(TensorError::device_error_simple(
                            "Expected GPU device".to_string(),
                        ))
                    }
                };

                let ctx = crate::device::context::get_gpu_context(device_id)?;
                let device = &ctx.device;
                let queue = &ctx.queue;

                let total_elements: usize = self.view.shape.iter().product();
                let output_size = total_elements * std::mem::size_of::<T>();

                // Create output buffer
                let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
                    label: Some("strided_materialize_output"),
                    size: output_size as u64,
                    usage: wgpu::BufferUsages::STORAGE
                        | wgpu::BufferUsages::COPY_SRC
                        | wgpu::BufferUsages::COPY_DST,
                    mapped_at_creation: false,
                });

                // Create info buffer with view parameters
                #[repr(C)]
                #[derive(bytemuck::Pod, bytemuck::Zeroable, Copy, Clone)]
                struct StridedInfo {
                    ndim: u32,
                    total_elements: u32,
                    offset: u32,
                    element_size: u32,
                }

                let info = StridedInfo {
                    ndim: self.view.shape.len() as u32,
                    total_elements: total_elements as u32,
                    offset: self.view.offset as u32,
                    element_size: self.view.element_size as u32,
                };

                let info_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                    label: Some("strided_info"),
                    contents: bytemuck::cast_slice(&[info]),
                    usage: wgpu::BufferUsages::UNIFORM,
                });

                // Create shape and strides buffers
                let shape_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                    label: Some("strided_shape"),
                    contents: bytemuck::cast_slice(
                        &self
                            .view
                            .shape
                            .iter()
                            .map(|&x| x as u32)
                            .collect::<Vec<u32>>(),
                    ),
                    usage: wgpu::BufferUsages::STORAGE,
                });

                let strides_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                    label: Some("strided_strides"),
                    contents: bytemuck::cast_slice(
                        &self
                            .view
                            .strides
                            .iter()
                            .map(|&x| x as u32)
                            .collect::<Vec<u32>>(),
                    ),
                    usage: wgpu::BufferUsages::STORAGE,
                });

                // Create compute shader for strided materialization
                let shader_source = include_str!("gpu/shaders/strided_ops.wgsl");
                let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
                    label: Some("strided_materialize_shader"),
                    source: wgpu::ShaderSource::Wgsl(shader_source.into()),
                });

                // Create bind group layout
                let bind_group_layout =
                    device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                        label: Some("strided_materialize_bind_group_layout"),
                        entries: &[
                            wgpu::BindGroupLayoutEntry {
                                binding: 0,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                            wgpu::BindGroupLayoutEntry {
                                binding: 1,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Storage { read_only: false },
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                            wgpu::BindGroupLayoutEntry {
                                binding: 2,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Uniform,
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                            wgpu::BindGroupLayoutEntry {
                                binding: 3,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                            wgpu::BindGroupLayoutEntry {
                                binding: 4,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                        ],
                    });

                // Create bind group
                let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
                    label: Some("strided_materialize_bind_group"),
                    layout: &bind_group_layout,
                    entries: &[
                        wgpu::BindGroupEntry {
                            binding: 0,
                            resource: gpu_buffer.buffer().as_entire_binding(),
                        },
                        wgpu::BindGroupEntry {
                            binding: 1,
                            resource: output_buffer.as_entire_binding(),
                        },
                        wgpu::BindGroupEntry {
                            binding: 2,
                            resource: info_buffer.as_entire_binding(),
                        },
                        wgpu::BindGroupEntry {
                            binding: 3,
                            resource: shape_buffer.as_entire_binding(),
                        },
                        wgpu::BindGroupEntry {
                            binding: 4,
                            resource: strides_buffer.as_entire_binding(),
                        },
                    ],
                });

                // Create pipeline
                let pipeline_layout =
                    device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some("strided_materialize_pipeline_layout"),
                        bind_group_layouts: &[&bind_group_layout],
                        immediate_size: 0,
                    });

                let compute_pipeline =
                    device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                        label: Some("strided_materialize_pipeline"),
                        layout: Some(&pipeline_layout),
                        module: &shader_module,
                        entry_point: Some("strided_materialize"),
                        cache: None,
                        compilation_options: Default::default(),
                    });

                // Execute compute shader
                let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
                    label: Some("strided_materialize_encoder"),
                });

                {
                    let mut compute_pass =
                        encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                            label: Some("strided_materialize_pass"),
                            timestamp_writes: None,
                        });

                    compute_pass.set_pipeline(&compute_pipeline);
                    compute_pass.set_bind_group(0, &bind_group, &[]);

                    let workgroup_size = 64;
                    let num_workgroups = (total_elements + workgroup_size - 1) / workgroup_size;
                    compute_pass.dispatch_workgroups(num_workgroups as u32, 1, 1);
                }

                queue.submit(std::iter::once(encoder.finish()));
                device.poll(wgpu::PollType::wait_indefinitely()).ok();

                // Create result GPU buffer
                let gpu_buffer_result = crate::gpu::buffer::GpuBuffer::from_wgpu_buffer(
                    output_buffer,
                    ctx.device.clone(),
                    ctx.queue.clone(),
                    Device::Gpu(device_id),
                    total_elements,
                );

                let shape = Shape::new(self.view.shape.clone());
                let mut result = Tensor::from_gpu_buffer(gpu_buffer_result, shape);
                result.set_requires_grad(self.requires_grad);
                Ok(result)
            }
            #[cfg(feature = "gpu")]
            ViewStorage::GpuPooled(pooled_buffer) => {
                // For pooled buffers, we need to implement strided materialization from pool
                use wgpu::util::DeviceExt;

                let device_id = match self.device {
                    Device::Gpu(id) => id,
                    _ => {
                        return Err(TensorError::device_error_simple(
                            "Expected GPU device".to_string(),
                        ))
                    }
                };

                let ctx = crate::device::context::get_gpu_context(device_id)?;
                let device = &ctx.device;
                let queue = &ctx.queue;

                let total_elements: usize = self.view.shape.iter().product();
                let output_size = total_elements * std::mem::size_of::<T>();

                // Create output buffer
                let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
                    label: Some("pooled_strided_materialize_output"),
                    size: output_size as u64,
                    usage: wgpu::BufferUsages::STORAGE
                        | wgpu::BufferUsages::COPY_SRC
                        | wgpu::BufferUsages::COPY_DST,
                    mapped_at_creation: false,
                });

                // Create info buffer with view parameters including pool offset
                #[repr(C)]
                #[derive(bytemuck::Pod, bytemuck::Zeroable, Copy, Clone)]
                struct PooledStridedInfo {
                    ndim: u32,
                    total_elements: u32,
                    offset: u32,
                    element_size: u32,
                    pool_offset: u32, // Additional offset from pooled buffer
                    pad: [u32; 3],    // Padding for alignment
                }

                let info = PooledStridedInfo {
                    ndim: self.view.shape.len() as u32,
                    total_elements: total_elements as u32,
                    offset: self.view.offset as u32,
                    element_size: self.view.element_size as u32,
                    pool_offset: (pooled_buffer.offset() / std::mem::size_of::<T>()) as u32,
                    pad: [0; 3],
                };

                let info_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                    label: Some("pooled_strided_info"),
                    contents: bytemuck::cast_slice(&[info]),
                    usage: wgpu::BufferUsages::UNIFORM,
                });

                // Create shape and strides buffers
                let shape_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                    label: Some("pooled_strided_shape"),
                    contents: bytemuck::cast_slice(
                        &self
                            .view
                            .shape
                            .iter()
                            .map(|&x| x as u32)
                            .collect::<Vec<u32>>(),
                    ),
                    usage: wgpu::BufferUsages::STORAGE,
                });

                let strides_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                    label: Some("pooled_strided_strides"),
                    contents: bytemuck::cast_slice(
                        &self
                            .view
                            .strides
                            .iter()
                            .map(|&x| x as u32)
                            .collect::<Vec<u32>>(),
                    ),
                    usage: wgpu::BufferUsages::STORAGE,
                });

                // Create compute shader for pooled strided materialization
                let shader_source = include_str!("gpu/shaders/strided_ops.wgsl");
                let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
                    label: Some("pooled_strided_materialize_shader"),
                    source: wgpu::ShaderSource::Wgsl(shader_source.into()),
                });

                // Create bind group layout (same as regular strided)
                let bind_group_layout =
                    device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                        label: Some("pooled_strided_materialize_bind_group_layout"),
                        entries: &[
                            wgpu::BindGroupLayoutEntry {
                                binding: 0,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                            wgpu::BindGroupLayoutEntry {
                                binding: 1,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Storage { read_only: false },
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                            wgpu::BindGroupLayoutEntry {
                                binding: 2,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Uniform,
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                            wgpu::BindGroupLayoutEntry {
                                binding: 3,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                            wgpu::BindGroupLayoutEntry {
                                binding: 4,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                        ],
                    });

                // Create bind group using pool buffer
                let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
                    label: Some("pooled_strided_materialize_bind_group"),
                    layout: &bind_group_layout,
                    entries: &[
                        wgpu::BindGroupEntry {
                            binding: 0,
                            resource: pooled_buffer.buffer().as_entire_binding(),
                        },
                        wgpu::BindGroupEntry {
                            binding: 1,
                            resource: output_buffer.as_entire_binding(),
                        },
                        wgpu::BindGroupEntry {
                            binding: 2,
                            resource: info_buffer.as_entire_binding(),
                        },
                        wgpu::BindGroupEntry {
                            binding: 3,
                            resource: shape_buffer.as_entire_binding(),
                        },
                        wgpu::BindGroupEntry {
                            binding: 4,
                            resource: strides_buffer.as_entire_binding(),
                        },
                    ],
                });

                // Create pipeline
                let pipeline_layout =
                    device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some("pooled_strided_materialize_pipeline_layout"),
                        bind_group_layouts: &[&bind_group_layout],
                        immediate_size: 0,
                    });

                let compute_pipeline =
                    device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                        label: Some("pooled_strided_materialize_pipeline"),
                        layout: Some(&pipeline_layout),
                        module: &shader_module,
                        entry_point: Some("strided_materialize"),
                        cache: None,
                        compilation_options: Default::default(),
                    });

                // Execute compute shader
                let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
                    label: Some("pooled_strided_materialize_encoder"),
                });

                {
                    let mut compute_pass =
                        encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                            label: Some("pooled_strided_materialize_pass"),
                            timestamp_writes: None,
                        });

                    compute_pass.set_pipeline(&compute_pipeline);
                    compute_pass.set_bind_group(0, &bind_group, &[]);

                    let workgroup_size = 64;
                    let num_workgroups = (total_elements + workgroup_size - 1) / workgroup_size;
                    compute_pass.dispatch_workgroups(num_workgroups as u32, 1, 1);
                }

                queue.submit(std::iter::once(encoder.finish()));
                device.poll(wgpu::PollType::wait_indefinitely()).ok();

                // Create result GPU buffer
                let gpu_buffer_result = crate::gpu::buffer::GpuBuffer::from_wgpu_buffer(
                    output_buffer,
                    ctx.device.clone(),
                    ctx.queue.clone(),
                    Device::Gpu(device_id),
                    total_elements,
                );

                let shape = Shape::new(self.view.shape.clone());
                let mut result = Tensor::from_gpu_buffer(gpu_buffer_result, shape);
                result.set_requires_grad(self.requires_grad);
                Ok(result)
            }
        }
    }

    /// Get the shape of the tensor view
    pub fn shape(&self) -> &[usize] {
        &self.view.shape
    }

    /// Get the strides of the tensor view
    pub fn strides(&self) -> &[usize] {
        &self.view.strides
    }

    /// Get the device where the tensor is stored
    pub fn device(&self) -> &Device {
        &self.device
    }

    /// Check if the tensor requires gradient computation
    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }

    /// Set whether the tensor requires gradient computation
    pub fn set_requires_grad(&mut self, requires_grad: bool) {
        self.requires_grad = requires_grad;
    }
}

impl<T> Clone for TensorView<T>
where
    T: Clone + Default + Send + Sync + 'static,
{
    fn clone(&self) -> Self {
        // Register the cloned view in the alias detector
        let buffer_id = self.get_buffer_id();
        self.alias_detector
            .register_view(buffer_id, self.view.offset, self.view.size_bytes());

        Self {
            storage: self.storage.clone(),
            view: self.view.clone(),
            device: self.device,
            requires_grad: self.requires_grad,
            grad: self.grad.clone(),
            alias_detector: Arc::clone(&self.alias_detector),
        }
    }
}

impl<T> Drop for TensorView<T> {
    fn drop(&mut self) {
        // Unregister the view from the alias detector
        let buffer_id = self.get_buffer_id();
        self.alias_detector
            .unregister_view(buffer_id, self.view.offset, self.view.size_bytes());
    }
}

/// Memory-efficient tensor operations using views
pub struct TensorViewOps;

impl TensorViewOps {
    /// Perform zero-copy transpose if possible
    pub fn transpose_zero_copy<T>(tensor: &TensorView<T>, axes: &[usize]) -> Result<TensorView<T>>
    where
        T: Clone + Default + Send + Sync + 'static,
    {
        tensor.transpose(axes)
    }

    /// Perform zero-copy reshape if possible
    pub fn reshape_zero_copy<T>(
        tensor: &TensorView<T>,
        new_shape: &[usize],
    ) -> Result<TensorView<T>>
    where
        T: Clone + Default + Send + Sync + 'static,
    {
        tensor.reshape(new_shape)
    }

    /// Create a zero-copy slice view
    pub fn slice_zero_copy<T>(
        tensor: &TensorView<T>,
        ranges: &[(usize, usize)],
    ) -> Result<TensorView<T>>
    where
        T: Clone + Default + Send + Sync + 'static,
    {
        tensor.slice(ranges)
    }

    /// Check if two tensor views share memory
    pub fn shares_memory<T>(tensor1: &TensorView<T>, tensor2: &TensorView<T>) -> bool
    where
        T: Clone + Default,
    {
        tensor1.get_buffer_id() == tensor2.get_buffer_id()
    }

    /// Get memory usage statistics for a tensor view
    pub fn memory_stats<T>(tensor: &TensorView<T>) -> MemoryStats
    where
        T: Clone + Default,
    {
        MemoryStats {
            total_elements: tensor.numel(),
            size_bytes: tensor.size_bytes(),
            is_contiguous: tensor.is_contiguous(),
            has_aliases: Self::has_memory_aliases(tensor),
        }
    }

    /// Check if a tensor view has memory aliases
    ///
    /// Returns true if the tensor's memory region overlaps with any other tracked tensors.
    /// This helps detect potentially unsafe operations on aliased memory.
    pub fn has_memory_aliases<T>(tensor: &TensorView<T>) -> bool
    where
        T: Clone + Default,
    {
        // Use the alias detector to check for overlapping memory regions
        let buffer_id = tensor.get_buffer_id();
        let start_offset = tensor.view.offset;
        let size_bytes = tensor.size_bytes();

        // Check with the memory alias detector
        tensor
            .alias_detector
            .check_alias(buffer_id, start_offset, size_bytes)
    }
}

/// Memory usage statistics for tensor views
#[derive(Debug, Clone)]
pub struct MemoryStats {
    pub total_elements: usize,
    pub size_bytes: usize,
    pub is_contiguous: bool,
    pub has_aliases: bool,
}
/// Convert a flat element index into a multi-dimensional index for `shape`
fn flat_to_multi_index(flat_index: usize, shape: &[usize]) -> Vec<usize> {
    let mut multi_index = Vec::with_capacity(shape.len());
    let mut remaining = flat_index;

    for &dim in shape.iter().rev() {
        multi_index.push(remaining % dim);
        remaining /= dim;
    }

    multi_index.reverse();
    multi_index
}

/// Convert a multi-dimensional index into a linear offset using the given strides and base offset
fn multi_to_strided_index(multi_index: &[usize], strides: &[usize], offset: usize) -> usize {
    let mut strided_index = offset;
    for (idx, &stride) in multi_index.iter().zip(strides.iter()) {
        strided_index += idx * stride;
    }
    strided_index
}

/// Compute default row-major strides, expressed in bytes, for `shape`
fn compute_default_strides(shape: &[usize], element_size: usize) -> Vec<usize> {
    let mut strides = Vec::with_capacity(shape.len());
    let mut stride = element_size;

    for &dim in shape.iter().rev() {
        strides.push(stride);
        stride *= dim;
    }

    strides.reverse();
    strides
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::Tensor;

    #[test]
    fn test_tensor_view_creation() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: from_vec should succeed");
        let view = TensorView::from_tensor(&tensor).expect("test: from_tensor should succeed");

        assert_eq!(view.shape(), &[2, 2]);
        assert_eq!(view.numel(), 4);
        assert!(view.is_contiguous());
    }

    #[test]
    fn test_zero_copy_transpose() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
            .expect("test: from_vec should succeed");
        let view = TensorView::from_tensor(&tensor).expect("test: from_tensor should succeed");

        let transposed = view
            .transpose(&[1, 0])
            .expect("test: transpose should succeed");
        assert_eq!(transposed.shape(), &[3, 2]);
        assert_eq!(transposed.strides(), &[4, 12]); // Strides change for transpose
    }

    #[test]
    fn test_zero_copy_reshape() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
            .expect("test: from_vec should succeed");
        let view = TensorView::from_tensor(&tensor).expect("test: from_tensor should succeed");

        let reshaped = view.reshape(&[3, 2]).expect("test: reshape should succeed");
        assert_eq!(reshaped.shape(), &[3, 2]);
        assert!(reshaped.is_contiguous());
    }

    #[test]
    fn test_zero_copy_slice() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
            .expect("test: from_vec should succeed");
        let view = TensorView::from_tensor(&tensor).expect("test: from_tensor should succeed");

        let sliced = view
            .slice(&[(0, 1), (1, 3)])
            .expect("test: slice should succeed");
        assert_eq!(sliced.shape(), &[1, 2]);
    }

    #[test]
    fn test_memory_stats() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: from_vec should succeed");
        let view = TensorView::from_tensor(&tensor).expect("test: from_tensor should succeed");

        let stats = TensorViewOps::memory_stats(&view);
        assert_eq!(stats.total_elements, 4);
        assert_eq!(stats.size_bytes, 16); // 4 * 4 bytes
        assert!(stats.is_contiguous);
    }
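
    // Illustrative round-trip sketch: a contiguous CPU view converted back with `to_tensor`
    // is expected to take the zero-copy path and preserve the view's shape.
    #[test]
    fn test_to_tensor_contiguous_round_trip() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: from_vec should succeed");
        let view = TensorView::from_tensor(&tensor).expect("test: from_tensor should succeed");

        let round_trip = view.to_tensor().expect("test: to_tensor should succeed");
        assert_eq!(round_trip.shape().dims(), &[2, 2]);
    }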

    #[test]
    fn test_shares_memory() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: from_vec should succeed");
        let view1 = TensorView::from_tensor(&tensor).expect("test: from_tensor should succeed");
        let view2 = view1
            .transpose(&[1, 0])
            .expect("test: transpose should succeed");

        assert!(TensorViewOps::shares_memory(&view1, &view2));
    }
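
    // Worked example for the stride helper: default strides are row-major and expressed in
    // bytes, so a [2, 3] tensor of 4-byte elements gets strides [12, 4].
    #[test]
    fn test_compute_default_strides() {
        assert_eq!(compute_default_strides(&[2, 3], 4), vec![12, 4]);
        assert_eq!(compute_default_strides(&[2, 3, 4], 4), vec![48, 16, 4]);
    }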

    #[test]
    fn test_flat_to_multi_index() {
        let shape = vec![2, 3];
        assert_eq!(flat_to_multi_index(0, &shape), vec![0, 0]);
        assert_eq!(flat_to_multi_index(1, &shape), vec![0, 1]);
        assert_eq!(flat_to_multi_index(3, &shape), vec![1, 0]);
        assert_eq!(flat_to_multi_index(5, &shape), vec![1, 2]);
    }

    #[test]
    fn test_multi_to_strided_index() {
        let strides = vec![12, 4];
        let offset = 0;
        assert_eq!(multi_to_strided_index(&[0, 0], &strides, offset), 0);
        assert_eq!(multi_to_strided_index(&[0, 1], &strides, offset), 4);
        assert_eq!(multi_to_strided_index(&[1, 0], &strides, offset), 12);
        assert_eq!(multi_to_strided_index(&[1, 2], &strides, offset), 20);
    }
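
    // Illustrative sketch of the non-contiguous path: converting a transposed (strided) CPU
    // view with `to_tensor` is expected to go through `materialize` and produce a contiguous
    // tensor with the transposed shape.
    #[test]
    fn test_to_tensor_materializes_transposed_view() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
            .expect("test: from_vec should succeed");
        let view = TensorView::from_tensor(&tensor).expect("test: from_tensor should succeed");
        let transposed = view
            .transpose(&[1, 0])
            .expect("test: transpose should succeed");

        let materialized = transposed
            .to_tensor()
            .expect("test: to_tensor should succeed");
        assert_eq!(materialized.shape().dims(), &[3, 2]);
    }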
}