#[cfg(feature = "gpu")]
use crate::memory::PooledBuffer;
use crate::memory::{MemoryAliasDetector, StridedView};
use crate::tensor::TensorStorage;
use crate::{Device, Result, Shape, Tensor, TensorError};
use scirs2_core::ndarray::{ArrayD, IxDyn};
use std::sync::Arc;

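/// A zero-copy view over tensor storage: a [`StridedView`] (shape, byte
/// strides, offset) plus a shared handle to the backing buffer. Every live
/// view is registered with a shared [`MemoryAliasDetector`] so overlapping
/// views of the same buffer can be detected.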
#[derive(Debug)]
pub struct TensorView<T> {
    pub storage: ViewStorage<T>,
    pub view: StridedView,
    device: Device,
    requires_grad: bool,
    grad: Option<Arc<TensorView<T>>>,
    alias_detector: Arc<MemoryAliasDetector>,
}

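/// The buffer a [`TensorView`] borrows from: a CPU `ndarray`, a pooled GPU
/// allocation, or a standalone GPU buffer.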
#[derive(Debug, Clone)]
pub enum ViewStorage<T> {
    CpuRef(Arc<ArrayD<T>>),
    #[cfg(feature = "gpu")]
    GpuPooled(Arc<PooledBuffer<'static>>),
    #[cfg(feature = "gpu")]
    GpuRef(Arc<crate::gpu::buffer::GpuBuffer<T>>),
}

impl<T> TensorView<T> {
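    /// Identifies the backing buffer by its `Arc` pointer address; two views
    /// report the same id exactly when they share storage.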
    fn get_buffer_id(&self) -> usize {
        match &self.storage {
            ViewStorage::CpuRef(arr) => Arc::as_ptr(arr) as usize,
            #[cfg(feature = "gpu")]
            ViewStorage::GpuRef(gpu_buffer) => Arc::as_ptr(gpu_buffer) as usize,
            #[cfg(feature = "gpu")]
            ViewStorage::GpuPooled(pooled_buffer) => Arc::as_ptr(pooled_buffer) as usize,
        }
    }

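    /// Total number of elements addressed by the view.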
    pub fn numel(&self) -> usize {
        self.view.shape.iter().product()
    }

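    /// Size of the viewed region in bytes.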
    pub fn size_bytes(&self) -> usize {
        self.view.size_bytes()
    }

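    /// Whether the view is laid out contiguously in row-major order.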
    pub fn is_contiguous(&self) -> bool {
        self.view.is_contiguous()
    }
}

impl<T: Clone + Default> TensorView<T> {
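    /// Builds a contiguous view over `tensor` with default byte strides.
    /// GPU storage is shared; a CPU array is cloned into an `Arc` once here,
    /// after which all derived views share that allocation.
    ///
    /// A minimal usage sketch (mirrors the tests below):
    ///
    /// ```ignore
    /// let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])?;
    /// let view = TensorView::from_tensor(&tensor)?;
    /// assert!(view.is_contiguous());
    /// ```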
    pub fn from_tensor(tensor: &Tensor<T>) -> Result<Self>
    where
        T: Clone + Send + Sync + 'static,
    {
        let element_size = std::mem::size_of::<T>();
        let shape = tensor.shape().dims().to_vec();
        let strides = compute_default_strides(&shape, element_size);

        let view = StridedView::new(0, shape, strides, element_size);
        let alias_detector = Arc::new(MemoryAliasDetector::new());

        let storage = match &tensor.storage {
            TensorStorage::Cpu(arr) => ViewStorage::CpuRef(Arc::new(arr.clone())),
            #[cfg(feature = "gpu")]
            TensorStorage::Gpu(gpu_buffer) => ViewStorage::GpuRef(Arc::new(gpu_buffer.clone())),
        };

        Ok(Self {
            storage,
            view,
            device: *tensor.device(),
            requires_grad: tensor.requires_grad(),
            grad: None,
            alias_detector,
        })
    }

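    /// Zero-copy transpose: permutes the view's shape and strides per `axes`
    /// without moving data. Errors if the alias detector reports a
    /// conflicting overlapping view on the same buffer.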
    pub fn transpose(&self, axes: &[usize]) -> Result<TensorView<T>>
    where
        T: Clone + Send + Sync + 'static,
    {
        let new_view = self.view.transpose(axes)?;

        let buffer_id = self.get_buffer_id();
        if self
            .alias_detector
            .check_alias(buffer_id, new_view.offset, new_view.size_bytes())
        {
            return Err(TensorError::invalid_argument(
                "Transpose would create memory alias".to_string(),
            ));
        }

        self.alias_detector
            .register_view(buffer_id, new_view.offset, new_view.size_bytes());

        Ok(Self {
            storage: self.storage.clone(),
            view: new_view,
            device: self.device,
            requires_grad: self.requires_grad,
            grad: None,
            alias_detector: Arc::clone(&self.alias_detector),
        })
    }

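    /// Zero-copy reshape via [`StridedView::reshape`]. Errors if the new
    /// shape cannot be expressed over the existing layout, or if the alias
    /// detector reports a conflict.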
    pub fn reshape(&self, new_shape: &[usize]) -> Result<TensorView<T>>
    where
        T: Clone + Send + Sync + 'static,
    {
        let new_view = self.view.reshape(new_shape)?;

        let buffer_id = self.get_buffer_id();
        if self
            .alias_detector
            .check_alias(buffer_id, new_view.offset, new_view.size_bytes())
        {
            return Err(TensorError::invalid_argument(
                "Reshape would create memory alias".to_string(),
            ));
        }

        self.alias_detector
            .register_view(buffer_id, new_view.offset, new_view.size_bytes());

        Ok(Self {
            storage: self.storage.clone(),
            view: new_view,
            device: self.device,
            requires_grad: self.requires_grad,
            grad: None,
            alias_detector: Arc::clone(&self.alias_detector),
        })
    }

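    /// Zero-copy slice over half-open `(start, end)` ranges, one pair per
    /// dimension. Errors if the alias detector reports a conflict.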
    pub fn slice(&self, ranges: &[(usize, usize)]) -> Result<TensorView<T>>
    where
        T: Clone + Send + Sync + 'static,
    {
        let new_view = self.view.slice(ranges)?;

        let buffer_id = self.get_buffer_id();
        if self
            .alias_detector
            .check_alias(buffer_id, new_view.offset, new_view.size_bytes())
        {
            return Err(TensorError::invalid_argument(
                "Slice would create memory alias".to_string(),
            ));
        }

        self.alias_detector
            .register_view(buffer_id, new_view.offset, new_view.size_bytes());

        Ok(Self {
            storage: self.storage.clone(),
            view: new_view,
            device: self.device,
            requires_grad: self.requires_grad,
            grad: None,
            alias_detector: Arc::clone(&self.alias_detector),
        })
    }

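    /// Converts the view back into an owned [`Tensor`]. Contiguous views are
    /// copied wholesale; non-contiguous views fall back to `materialize`,
    /// which gathers elements through the strides.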
    pub fn to_tensor(&self) -> Result<Tensor<T>>
    where
        T: Clone + Send + Sync + 'static + bytemuck::Pod + bytemuck::Zeroable,
    {
        if self.view.is_contiguous() {
            match &self.storage {
                ViewStorage::CpuRef(arr) => {
                    // Contiguous CPU view covering the whole array: clone the
                    // backing storage directly and keep the grad flag.
                    let mut result = Tensor::from_array((**arr).clone());
                    result.set_requires_grad(self.requires_grad);
                    Ok(result)
                }
                #[cfg(feature = "gpu")]
                ViewStorage::GpuRef(gpu_buffer) => {
                    let shape = Shape::new(self.view.shape.clone());
                    let mut result = Tensor::from_gpu_buffer((**gpu_buffer).clone(), shape);
                    result.set_requires_grad(self.requires_grad);
                    Ok(result)
                }
                #[cfg(feature = "gpu")]
                ViewStorage::GpuPooled(pooled_buffer) => {
                    // Resolve the GPU context once; the device, queue, and
                    // final buffer wrapper all use the same context.
                    let device_id = match self.device {
                        Device::Gpu(id) => id,
                        _ => {
                            return Err(TensorError::device_error_simple(
                                "Expected GPU device".to_string(),
                            ))
                        }
                    };
                    let ctx = crate::device::context::get_gpu_context(device_id)?;
                    let device = &ctx.device;
                    let queue = &ctx.queue;

                    let pool_buffer = pooled_buffer.buffer();
                    let data_size =
                        self.view.shape.iter().product::<usize>() * std::mem::size_of::<T>();
                    let new_buffer = device.create_buffer(&wgpu::BufferDescriptor {
                        label: Some("pooled_to_tensor_buffer"),
                        size: data_size as u64,
                        usage: wgpu::BufferUsages::STORAGE
                            | wgpu::BufferUsages::COPY_SRC
                            | wgpu::BufferUsages::COPY_DST,
                        mapped_at_creation: false,
                    });

                    let mut encoder =
                        device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
                            label: Some("pooled_buffer_copy"),
                        });

                    // Copy the view's bytes out of the pool, honoring the
                    // pooled allocation's byte offset, into a standalone
                    // buffer the tensor can own.
                    encoder.copy_buffer_to_buffer(
                        pool_buffer,
                        pooled_buffer.offset() as u64,
                        &new_buffer,
                        0,
                        data_size as u64,
                    );

                    queue.submit(std::iter::once(encoder.finish()));
                    device.poll(wgpu::PollType::wait_indefinitely()).ok();

                    let gpu_buffer = crate::gpu::buffer::GpuBuffer::from_wgpu_buffer(
                        new_buffer,
                        ctx.device.clone(),
                        ctx.queue.clone(),
                        Device::Gpu(device_id),
                        self.view.shape.iter().product::<usize>(),
                    );

                    let shape = Shape::new(self.view.shape.clone());
                    let mut result = Tensor::from_gpu_buffer(gpu_buffer, shape);
                    result.set_requires_grad(self.requires_grad);
                    Ok(result)
                }
            }
        } else {
            self.materialize()
        }
    }

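    /// Gathers a non-contiguous view into a freshly allocated contiguous
    /// tensor: element-by-element on the CPU, or via the
    /// `strided_materialize` compute shader on the GPU.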
    fn materialize(&self) -> Result<Tensor<T>>
    where
        T: Clone + Send + Sync + 'static + bytemuck::Pod + bytemuck::Zeroable,
    {
        match &self.storage {
            ViewStorage::CpuRef(arr) => {
                let total_elements: usize = self.view.shape.iter().product();
                let mut new_data = Vec::with_capacity(total_elements);

                for flat_index in 0..total_elements {
                    let multi_index = flat_to_multi_index(flat_index, &self.view.shape);
                    let strided_index =
                        multi_to_strided_index(&multi_index, &self.view.strides, self.view.offset);

                    // Strides are stored in bytes; convert the byte position
                    // back to an element index before reading.
                    let elem_index = strided_index / self.view.element_size;
                    if let Some(slice) = arr.as_slice() {
                        if elem_index < slice.len() {
                            new_data.push(slice[elem_index]);
                        }
                    }
                }

                let new_array = ArrayD::from_shape_vec(IxDyn(&self.view.shape), new_data)
                    .map_err(|e| TensorError::invalid_argument(e.to_string()))?;

                let mut result = Tensor::from_array(new_array);
                result.set_requires_grad(self.requires_grad);
                Ok(result)
            }
            #[cfg(feature = "gpu")]
            ViewStorage::GpuRef(gpu_buffer) => {
                use wgpu::util::DeviceExt;

                let device_id = match self.device {
                    Device::Gpu(id) => id,
                    _ => {
                        return Err(TensorError::device_error_simple(
                            "Expected GPU device".to_string(),
                        ))
                    }
                };

                let ctx = crate::device::context::get_gpu_context(device_id)?;
                let device = &ctx.device;
                let queue = &ctx.queue;

                let total_elements: usize = self.view.shape.iter().product();
                let output_size = total_elements * std::mem::size_of::<T>();

                let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
                    label: Some("strided_materialize_output"),
                    size: output_size as u64,
                    usage: wgpu::BufferUsages::STORAGE
                        | wgpu::BufferUsages::COPY_SRC
                        | wgpu::BufferUsages::COPY_DST,
                    mapped_at_creation: false,
                });

                #[repr(C)]
                #[derive(bytemuck::Pod, bytemuck::Zeroable, Copy, Clone)]
                struct StridedInfo {
                    ndim: u32,
                    total_elements: u32,
                    offset: u32,
                    element_size: u32,
                }

                let info = StridedInfo {
                    ndim: self.view.shape.len() as u32,
                    total_elements: total_elements as u32,
                    offset: self.view.offset as u32,
                    element_size: self.view.element_size as u32,
                };

                let info_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                    label: Some("strided_info"),
                    contents: bytemuck::cast_slice(&[info]),
                    usage: wgpu::BufferUsages::UNIFORM,
                });

                let shape_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                    label: Some("strided_shape"),
                    contents: bytemuck::cast_slice(
                        &self
                            .view
                            .shape
                            .iter()
                            .map(|&x| x as u32)
                            .collect::<Vec<u32>>(),
                    ),
                    usage: wgpu::BufferUsages::STORAGE,
                });

                let strides_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                    label: Some("strided_strides"),
                    contents: bytemuck::cast_slice(
                        &self
                            .view
                            .strides
                            .iter()
                            .map(|&x| x as u32)
                            .collect::<Vec<u32>>(),
                    ),
                    usage: wgpu::BufferUsages::STORAGE,
                });

                let shader_source = include_str!("gpu/shaders/strided_ops.wgsl");
                let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
                    label: Some("strided_materialize_shader"),
                    source: wgpu::ShaderSource::Wgsl(shader_source.into()),
                });

                let bind_group_layout =
                    device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                        label: Some("strided_materialize_bind_group_layout"),
                        entries: &[
                            wgpu::BindGroupLayoutEntry {
                                binding: 0,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                            wgpu::BindGroupLayoutEntry {
                                binding: 1,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Storage { read_only: false },
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                            wgpu::BindGroupLayoutEntry {
                                binding: 2,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Uniform,
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                            wgpu::BindGroupLayoutEntry {
                                binding: 3,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                            wgpu::BindGroupLayoutEntry {
                                binding: 4,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                        ],
                    });

                let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
                    label: Some("strided_materialize_bind_group"),
                    layout: &bind_group_layout,
                    entries: &[
                        wgpu::BindGroupEntry {
                            binding: 0,
                            resource: gpu_buffer.buffer().as_entire_binding(),
                        },
                        wgpu::BindGroupEntry {
                            binding: 1,
                            resource: output_buffer.as_entire_binding(),
                        },
                        wgpu::BindGroupEntry {
                            binding: 2,
                            resource: info_buffer.as_entire_binding(),
                        },
                        wgpu::BindGroupEntry {
                            binding: 3,
                            resource: shape_buffer.as_entire_binding(),
                        },
                        wgpu::BindGroupEntry {
                            binding: 4,
                            resource: strides_buffer.as_entire_binding(),
                        },
                    ],
                });

                let pipeline_layout =
                    device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some("strided_materialize_pipeline_layout"),
                        bind_group_layouts: &[&bind_group_layout],
                        immediate_size: 0,
                    });

                let compute_pipeline =
                    device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                        label: Some("strided_materialize_pipeline"),
                        layout: Some(&pipeline_layout),
                        module: &shader_module,
                        entry_point: Some("strided_materialize"),
                        cache: None,
                        compilation_options: Default::default(),
                    });

                let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
                    label: Some("strided_materialize_encoder"),
                });

                {
                    let mut compute_pass =
                        encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                            label: Some("strided_materialize_pass"),
                            timestamp_writes: None,
                        });

                    compute_pass.set_pipeline(&compute_pipeline);
                    compute_pass.set_bind_group(0, &bind_group, &[]);

                    let workgroup_size = 64;
                    let num_workgroups = (total_elements + workgroup_size - 1) / workgroup_size;
                    compute_pass.dispatch_workgroups(num_workgroups as u32, 1, 1);
                }

                queue.submit(std::iter::once(encoder.finish()));
                device.poll(wgpu::PollType::wait_indefinitely()).ok();

                let gpu_buffer_result = crate::gpu::buffer::GpuBuffer::from_wgpu_buffer(
                    output_buffer,
                    ctx.device.clone(),
                    ctx.queue.clone(),
                    Device::Gpu(device_id),
                    total_elements,
                );

                let shape = Shape::new(self.view.shape.clone());
                let mut result = Tensor::from_gpu_buffer(gpu_buffer_result, shape);
                result.set_requires_grad(self.requires_grad);
                Ok(result)
            }
            #[cfg(feature = "gpu")]
            ViewStorage::GpuPooled(pooled_buffer) => {
                use wgpu::util::DeviceExt;

                let device_id = match self.device {
                    Device::Gpu(id) => id,
                    _ => {
                        return Err(TensorError::device_error_simple(
                            "Expected GPU device".to_string(),
                        ))
                    }
                };

                let ctx = crate::device::context::get_gpu_context(device_id)?;
                let device = &ctx.device;
                let queue = &ctx.queue;

                let total_elements: usize = self.view.shape.iter().product();
                let output_size = total_elements * std::mem::size_of::<T>();

                let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
                    label: Some("pooled_strided_materialize_output"),
                    size: output_size as u64,
                    usage: wgpu::BufferUsages::STORAGE
                        | wgpu::BufferUsages::COPY_SRC
                        | wgpu::BufferUsages::COPY_DST,
                    mapped_at_creation: false,
                });

                #[repr(C)]
                #[derive(bytemuck::Pod, bytemuck::Zeroable, Copy, Clone)]
                struct PooledStridedInfo {
                    ndim: u32,
                    total_elements: u32,
                    offset: u32,
                    element_size: u32,
                    pool_offset: u32,
                    pad: [u32; 3],
                }

                let info = PooledStridedInfo {
                    ndim: self.view.shape.len() as u32,
                    total_elements: total_elements as u32,
                    offset: self.view.offset as u32,
                    element_size: self.view.element_size as u32,
                    pool_offset: (pooled_buffer.offset() / std::mem::size_of::<T>()) as u32,
                    pad: [0; 3],
                };

                let info_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                    label: Some("pooled_strided_info"),
                    contents: bytemuck::cast_slice(&[info]),
                    usage: wgpu::BufferUsages::UNIFORM,
                });

                let shape_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                    label: Some("pooled_strided_shape"),
                    contents: bytemuck::cast_slice(
                        &self
                            .view
                            .shape
                            .iter()
                            .map(|&x| x as u32)
                            .collect::<Vec<u32>>(),
                    ),
                    usage: wgpu::BufferUsages::STORAGE,
                });

                let strides_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                    label: Some("pooled_strided_strides"),
                    contents: bytemuck::cast_slice(
                        &self
                            .view
                            .strides
                            .iter()
                            .map(|&x| x as u32)
                            .collect::<Vec<u32>>(),
                    ),
                    usage: wgpu::BufferUsages::STORAGE,
                });

                let shader_source = include_str!("gpu/shaders/strided_ops.wgsl");
                let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
                    label: Some("pooled_strided_materialize_shader"),
                    source: wgpu::ShaderSource::Wgsl(shader_source.into()),
                });

                let bind_group_layout =
                    device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                        label: Some("pooled_strided_materialize_bind_group_layout"),
                        entries: &[
                            wgpu::BindGroupLayoutEntry {
                                binding: 0,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                            wgpu::BindGroupLayoutEntry {
                                binding: 1,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Storage { read_only: false },
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                            wgpu::BindGroupLayoutEntry {
                                binding: 2,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Uniform,
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                            wgpu::BindGroupLayoutEntry {
                                binding: 3,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                            wgpu::BindGroupLayoutEntry {
                                binding: 4,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                        ],
                    });

                let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
                    label: Some("pooled_strided_materialize_bind_group"),
                    layout: &bind_group_layout,
                    entries: &[
                        wgpu::BindGroupEntry {
                            binding: 0,
                            resource: pooled_buffer.buffer().as_entire_binding(),
                        },
                        wgpu::BindGroupEntry {
                            binding: 1,
                            resource: output_buffer.as_entire_binding(),
                        },
                        wgpu::BindGroupEntry {
                            binding: 2,
                            resource: info_buffer.as_entire_binding(),
                        },
                        wgpu::BindGroupEntry {
                            binding: 3,
                            resource: shape_buffer.as_entire_binding(),
                        },
                        wgpu::BindGroupEntry {
                            binding: 4,
                            resource: strides_buffer.as_entire_binding(),
                        },
                    ],
                });

                let pipeline_layout =
                    device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some("pooled_strided_materialize_pipeline_layout"),
                        bind_group_layouts: &[&bind_group_layout],
                        immediate_size: 0,
                    });

                let compute_pipeline =
                    device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                        label: Some("pooled_strided_materialize_pipeline"),
                        layout: Some(&pipeline_layout),
                        module: &shader_module,
                        entry_point: Some("strided_materialize"),
                        cache: None,
                        compilation_options: Default::default(),
                    });

                let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
                    label: Some("pooled_strided_materialize_encoder"),
                });

                {
                    let mut compute_pass =
                        encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                            label: Some("pooled_strided_materialize_pass"),
                            timestamp_writes: None,
                        });

                    compute_pass.set_pipeline(&compute_pipeline);
                    compute_pass.set_bind_group(0, &bind_group, &[]);

                    let workgroup_size = 64;
                    let num_workgroups = (total_elements + workgroup_size - 1) / workgroup_size;
                    compute_pass.dispatch_workgroups(num_workgroups as u32, 1, 1);
                }

                queue.submit(std::iter::once(encoder.finish()));
                device.poll(wgpu::PollType::wait_indefinitely()).ok();

                let gpu_buffer_result = crate::gpu::buffer::GpuBuffer::from_wgpu_buffer(
                    output_buffer,
                    ctx.device.clone(),
                    ctx.queue.clone(),
                    Device::Gpu(device_id),
                    total_elements,
                );

                let shape = Shape::new(self.view.shape.clone());
                let mut result = Tensor::from_gpu_buffer(gpu_buffer_result, shape);
                result.set_requires_grad(self.requires_grad);
                Ok(result)
            }
        }
    }

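    /// The view's logical shape.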
    pub fn shape(&self) -> &[usize] {
        &self.view.shape
    }

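    /// The view's strides, in bytes per step along each dimension (not in
    /// elements).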
    pub fn strides(&self) -> &[usize] {
        &self.view.strides
    }

    pub fn device(&self) -> &Device {
        &self.device
    }

    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }

    pub fn set_requires_grad(&mut self, requires_grad: bool) {
        self.requires_grad = requires_grad;
    }
}

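/// Cloning registers the new view with the shared alias detector, balancing
/// the unregistration performed by [`Drop`].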
impl<T> Clone for TensorView<T>
where
    T: Clone + Default + Send + Sync + 'static,
{
    fn clone(&self) -> Self {
        let buffer_id = self.get_buffer_id();
        self.alias_detector
            .register_view(buffer_id, self.view.offset, self.view.size_bytes());

        Self {
            storage: self.storage.clone(),
            view: self.view.clone(),
            device: self.device,
            requires_grad: self.requires_grad,
            grad: self.grad.clone(),
            alias_detector: Arc::clone(&self.alias_detector),
        }
    }
}

impl<T> Drop for TensorView<T> {
    fn drop(&mut self) {
        let buffer_id = self.get_buffer_id();
        self.alias_detector
            .unregister_view(buffer_id, self.view.offset, self.view.size_bytes());
    }
}

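/// Function-style helpers over [`TensorView`]: zero-copy transforms plus
/// simple memory introspection.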
pub struct TensorViewOps;

impl TensorViewOps {
    pub fn transpose_zero_copy<T>(tensor: &TensorView<T>, axes: &[usize]) -> Result<TensorView<T>>
    where
        T: Clone + Default + Send + Sync + 'static,
    {
        tensor.transpose(axes)
    }

    pub fn reshape_zero_copy<T>(
        tensor: &TensorView<T>,
        new_shape: &[usize],
    ) -> Result<TensorView<T>>
    where
        T: Clone + Default + Send + Sync + 'static,
    {
        tensor.reshape(new_shape)
    }

    pub fn slice_zero_copy<T>(
        tensor: &TensorView<T>,
        ranges: &[(usize, usize)],
    ) -> Result<TensorView<T>>
    where
        T: Clone + Default + Send + Sync + 'static,
    {
        tensor.slice(ranges)
    }

    pub fn shares_memory<T>(tensor1: &TensorView<T>, tensor2: &TensorView<T>) -> bool
    where
        T: Clone + Default,
    {
        tensor1.get_buffer_id() == tensor2.get_buffer_id()
    }

    pub fn memory_stats<T>(tensor: &TensorView<T>) -> MemoryStats
    where
        T: Clone + Default,
    {
        MemoryStats {
            total_elements: tensor.numel(),
            size_bytes: tensor.size_bytes(),
            is_contiguous: tensor.is_contiguous(),
            has_aliases: Self::has_memory_aliases(tensor),
        }
    }

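    /// Reports whether another live view overlaps this view's byte range in
    /// the same backing buffer, per the shared alias detector.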
    pub fn has_memory_aliases<T>(tensor: &TensorView<T>) -> bool
    where
        T: Clone + Default,
    {
        let buffer_id = tensor.get_buffer_id();
        let start_offset = tensor.view.offset;
        let size_bytes = tensor.size_bytes();

        tensor
            .alias_detector
            .check_alias(buffer_id, start_offset, size_bytes)
    }
}

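/// Summary of a view's memory footprint, produced by
/// [`TensorViewOps::memory_stats`].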
#[derive(Debug, Clone)]
pub struct MemoryStats {
    pub total_elements: usize,
    pub size_bytes: usize,
    pub is_contiguous: bool,
    pub has_aliases: bool,
}

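/// Converts a row-major flat index into a multi-dimensional index for
/// `shape`.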
fn flat_to_multi_index(flat_index: usize, shape: &[usize]) -> Vec<usize> {
    let mut multi_index = Vec::with_capacity(shape.len());
    let mut remaining = flat_index;

    for &dim in shape.iter().rev() {
        multi_index.push(remaining % dim);
        remaining /= dim;
    }

    multi_index.reverse();
    multi_index
}

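/// Resolves a multi-dimensional index to a byte position: `offset` plus the
/// dot product of the index with the byte strides.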
fn multi_to_strided_index(multi_index: &[usize], strides: &[usize], offset: usize) -> usize {
    let mut strided_index = offset;
    for (idx, &stride) in multi_index.iter().zip(strides.iter()) {
        strided_index += idx * stride;
    }
    strided_index
}

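/// Row-major byte strides for `shape`: the innermost dimension steps by
/// `element_size` bytes, each outer dimension by the product of the inner
/// extents (e.g. shape `[2, 3]` with 4-byte elements yields `[12, 4]`).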
fn compute_default_strides(shape: &[usize], element_size: usize) -> Vec<usize> {
    let mut strides = Vec::with_capacity(shape.len());
    let mut stride = element_size;

    for &dim in shape.iter().rev() {
        strides.push(stride);
        stride *= dim;
    }

    strides.reverse();
    strides
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::Tensor;

    #[test]
    fn test_tensor_view_creation() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: from_vec should succeed");
        let view = TensorView::from_tensor(&tensor).expect("test: from_tensor should succeed");

        assert_eq!(view.shape(), &[2, 2]);
        assert_eq!(view.numel(), 4);
        assert!(view.is_contiguous());
    }

    #[test]
    fn test_zero_copy_transpose() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
            .expect("test: from_vec should succeed");
        let view = TensorView::from_tensor(&tensor).expect("test: from_tensor should succeed");

        let transposed = view
            .transpose(&[1, 0])
            .expect("test: transpose should succeed");
        assert_eq!(transposed.shape(), &[3, 2]);
        // Byte strides: the original [2, 3] f32 view has strides [12, 4],
        // so the [1, 0] permutation yields [4, 12].
        assert_eq!(transposed.strides(), &[4, 12]);
    }

    #[test]
    fn test_zero_copy_reshape() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
            .expect("test: from_vec should succeed");
        let view = TensorView::from_tensor(&tensor).expect("test: from_tensor should succeed");

        let reshaped = view.reshape(&[3, 2]).expect("test: reshape should succeed");
        assert_eq!(reshaped.shape(), &[3, 2]);
        assert!(reshaped.is_contiguous());
    }

    #[test]
    fn test_zero_copy_slice() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
            .expect("test: from_vec should succeed");
        let view = TensorView::from_tensor(&tensor).expect("test: from_tensor should succeed");

        let sliced = view
            .slice(&[(0, 1), (1, 3)])
            .expect("test: slice should succeed");
        assert_eq!(sliced.shape(), &[1, 2]);
    }

    #[test]
    fn test_memory_stats() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: from_vec should succeed");
        let view = TensorView::from_tensor(&tensor).expect("test: from_tensor should succeed");

        let stats = TensorViewOps::memory_stats(&view);
        assert_eq!(stats.total_elements, 4);
        // 4 elements × 4 bytes per f32 = 16 bytes.
        assert_eq!(stats.size_bytes, 16);
        assert!(stats.is_contiguous);
    }

    #[test]
    fn test_shares_memory() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: from_vec should succeed");
        let view1 = TensorView::from_tensor(&tensor).expect("test: from_tensor should succeed");
        let view2 = view1
            .transpose(&[1, 0])
            .expect("test: transpose should succeed");

        assert!(TensorViewOps::shares_memory(&view1, &view2));
    }

    #[test]
    fn test_flat_to_multi_index() {
        let shape = vec![2, 3];
        assert_eq!(flat_to_multi_index(0, &shape), vec![0, 0]);
        assert_eq!(flat_to_multi_index(1, &shape), vec![0, 1]);
        assert_eq!(flat_to_multi_index(3, &shape), vec![1, 0]);
        assert_eq!(flat_to_multi_index(5, &shape), vec![1, 2]);
    }

    #[test]
    fn test_multi_to_strided_index() {
        let strides = vec![12, 4];
        let offset = 0;
        assert_eq!(multi_to_strided_index(&[0, 0], &strides, offset), 0);
        assert_eq!(multi_to_strided_index(&[0, 1], &strides, offset), 4);
        assert_eq!(multi_to_strided_index(&[1, 0], &strides, offset), 12);
        assert_eq!(multi_to_strided_index(&[1, 2], &strides, offset), 20);
    }
}