1use crate::{
2 config::{TypeNameFormatLevel, type_name_format},
3 kernel::KernelMetadata,
4 logging::ProfileLevel,
5 memory_management::{MemoryAllocationMode, MemoryUsage},
6 runtime::Runtime,
7 server::{
8 CommunicationId, ComputeServer, CopyDescriptor, CubeCount, ExecutionMode, Handle, IoError,
9 KernelArguments, MemoryLayout, MemoryLayoutDescriptor, MemoryLayoutPolicy,
10 MemoryLayoutStrategy, ProfileError, ReduceOperation, ServerCommunication, ServerError,
11 ServerUtilities,
12 },
13 storage::{ComputeStorage, ManagedResource},
14};
15use alloc::{format, sync::Arc, vec, vec::Vec};
16use cubecl_common::{
17 backtrace::BackTrace,
18 bytes::{AllocationProperty, Bytes},
19 device::{Device, DeviceId},
20 device_handle::DeviceHandle,
21 future::DynFut,
22 profile::ProfileDuration,
23};
24use cubecl_ir::{DeviceProperties, ElemType, VectorSize, features::Features};
25use cubecl_zspace::Shape;
26
27#[allow(unused)]
28use cubecl_common::profile::TimingMethod;
29use cubecl_common::stream_id::StreamId;
30
31pub struct ComputeClient<R: Runtime> {
34 device: DeviceHandle<R::Server>,
35 utilities: Arc<ServerUtilities<R::Server>>,
36 stream_id: Option<StreamId>,
37}
38
39impl<R: Runtime> Clone for ComputeClient<R> {
40 fn clone(&self) -> Self {
41 Self {
42 device: self.device.clone(),
43 utilities: self.utilities.clone(),
44 stream_id: self.stream_id,
45 }
46 }
47}
48
49impl<R: Runtime> ComputeClient<R> {
50 pub fn info(&self) -> &<R::Server as ComputeServer>::Info {
52 &self.utilities.info
53 }
54
55 pub fn init<D: Device>(device: &D, server: R::Server) -> Self {
57 let utilities = server.utilities();
58 let context = DeviceHandle::<R::Server>::insert(device.to_id(), server)
59 .expect("Can't create a new client on an already registered server");
60
61 Self {
62 device: context,
63 utilities,
64 stream_id: None,
65 }
66 }
67
68 pub fn load<D: Device>(device: &D) -> Self {
70 let context = DeviceHandle::<R::Server>::new(device.to_id());
71
72 let utilities = context
74 .utilities()
75 .downcast::<ServerUtilities<R::Server>>()
76 .expect("Can downcast to `ServerUtilities`");
77
78 Self {
79 device: context,
80 utilities,
81 stream_id: None,
82 }
83 }
84
85 fn stream_id(&self) -> StreamId {
86 match self.stream_id {
87 Some(val) => val,
88 None => StreamId::current(),
89 }
90 }
91
92 pub unsafe fn set_stream(&mut self, stream_id: StreamId) {
98 self.stream_id = Some(stream_id);
99 }
100
101 fn do_read(&self, descriptors: Vec<CopyDescriptor>) -> DynFut<Result<Vec<Bytes>, ServerError>> {
102 let stream_id = self.stream_id();
103 self.device
104 .submit_blocking(move |server| server.read(descriptors, stream_id))
105 .unwrap()
106 }
107
108 pub fn read_async(
110 &self,
111 handles: Vec<Handle>,
112 ) -> impl Future<Output = Result<Vec<Bytes>, ServerError>> + Send {
113 let shapes = handles
114 .iter()
115 .map(|it| [it.size_in_used() as usize].into())
116 .collect::<Vec<Shape>>();
117 let descriptors = handles
118 .into_iter()
119 .zip(shapes)
120 .map(|(handle, shape)| CopyDescriptor::new(handle.binding(), shape, [1].into(), 1))
121 .collect();
122
123 self.do_read(descriptors)
124 }
125
126 pub fn read(&self, handles: Vec<Handle>) -> Vec<Bytes> {
132 cubecl_common::reader::read_sync(self.read_async(handles)).expect("TODO")
133 }
134
135 pub fn read_one(&self, handle: Handle) -> Result<Bytes, ServerError> {
137 Ok(cubecl_common::reader::read_sync(self.read_async(vec![handle]))?.remove(0))
138 }
139
140 pub fn read_one_unchecked(&self, handle: Handle) -> Bytes {
146 cubecl_common::reader::read_sync(self.read_async(vec![handle]))
147 .unwrap()
148 .remove(0)
149 }
150
151 pub fn read_tensor_async(
153 &self,
154 descriptors: Vec<CopyDescriptor>,
155 ) -> impl Future<Output = Result<Vec<Bytes>, ServerError>> + Send {
156 self.do_read(descriptors)
157 }
158
159 pub fn read_tensor(&self, descriptors: Vec<CopyDescriptor>) -> Vec<Bytes> {
172 cubecl_common::reader::read_sync(self.read_tensor_async(descriptors)).expect("TODO")
173 }
174
175 pub fn read_one_tensor_async(
178 &self,
179 descriptor: CopyDescriptor,
180 ) -> impl Future<Output = Result<Bytes, ServerError>> + Send {
181 let fut = self.read_tensor_async(vec![descriptor]);
182
183 async { Ok(fut.await?.remove(0)) }
184 }
185
186 pub fn read_one_unchecked_tensor(&self, descriptor: CopyDescriptor) -> Bytes {
193 self.read_tensor(vec![descriptor]).remove(0)
194 }
195
196 pub fn get_resource(
198 &self,
199 handle: Handle,
200 ) -> Result<
201 ManagedResource<<<R::Server as ComputeServer>::Storage as ComputeStorage>::Resource>,
202 ServerError,
203 > {
204 let stream_id = self.stream_id();
205 let binding = handle.binding();
206
207 self.device
208 .submit_blocking(move |state| state.get_resource(binding, stream_id))
209 .unwrap()
210 }
211
212 fn do_create_from_slices(
213 &self,
214 descriptors: Vec<MemoryLayoutDescriptor>,
215 slices: Vec<Vec<u8>>,
216 ) -> Result<Vec<MemoryLayout>, IoError> {
217 let stream_id = self.stream_id();
218 let (handle_base, layouts) = self.utilities.layout_policy.apply(stream_id, &descriptors);
219
220 let mut data: Vec<Bytes> = slices.into_iter().map(Bytes::from_bytes_vec).collect();
228 self.staging(data.iter_mut(), true);
229
230 let descriptors = descriptors
231 .into_iter()
232 .zip(layouts.iter())
233 .zip(data)
234 .map(|((desc, alloc), data)| {
235 (
236 CopyDescriptor::new(
237 alloc.memory.clone().binding(),
238 desc.shape,
239 alloc.strides.clone(),
240 desc.elem_size,
241 ),
242 data,
243 )
244 })
245 .collect::<Vec<_>>();
246
247 let (size, memory) = (handle_base.size(), handle_base.memory);
248 self.device.submit(move |server| {
249 server.initialize_memory(memory, size, stream_id);
250 server.write(descriptors, stream_id);
251 });
252
253 Ok(layouts)
254 }
255
256 fn do_create(
257 &self,
258 descriptors: Vec<MemoryLayoutDescriptor>,
259 mut data: Vec<Bytes>,
260 ) -> Result<Vec<MemoryLayout>, IoError> {
261 self.staging(data.iter_mut(), true);
267
268 let stream_id = self.stream_id();
269 let (handle_base, layouts) = self.utilities.layout_policy.apply(stream_id, &descriptors);
270
271 let descriptors = descriptors
272 .into_iter()
273 .zip(layouts.iter())
274 .zip(data)
275 .map(|((desc, layout), data)| {
276 (
277 CopyDescriptor::new(
278 layout.memory.clone().binding(),
279 desc.shape,
280 layout.strides.clone(),
281 desc.elem_size,
282 ),
283 data,
284 )
285 })
286 .collect::<Vec<_>>();
287
288 let (size, memory) = (handle_base.size(), handle_base.memory);
289 self.device.submit(move |server| {
290 server.initialize_memory(memory, size, stream_id);
291 server.write(descriptors, stream_id);
292 });
293
294 Ok(layouts)
295 }
296
297 pub fn create_from_slice(&self, slice: &[u8]) -> Handle {
303 let shape: Shape = [slice.len()].into();
304
305 self.do_create_from_slices(
306 vec![MemoryLayoutDescriptor::new(
307 MemoryLayoutStrategy::Contiguous,
308 shape,
309 1,
310 )],
311 vec![slice.to_vec()],
312 )
313 .unwrap()
314 .remove(0)
315 .memory
316 }
317
318 pub fn reserve_staging(&self, sizes: &[usize]) -> Vec<Bytes> {
334 if sizes.is_empty() {
335 return Vec::new();
336 }
337
338 let stream_id = self.stream_id();
339 let sizes_owned = sizes.to_vec();
340 let result = self
341 .device
342 .submit_blocking(move |server| server.staging(&sizes_owned, stream_id))
343 .unwrap();
344
345 match result {
346 Ok(stagings) => stagings,
347 Err(_) => sizes
351 .iter()
352 .map(|&size| Bytes::from_bytes_vec(vec![0u8; size]))
353 .collect(),
354 }
355 }
356
357 pub fn create_from_slice_pinned(&self, slice: &[u8]) -> Handle {
370 let mut staging = self.reserve_staging(&[slice.len()]);
371 let mut bytes = staging.pop().expect("reserve_staging returned no buffers");
372 bytes.copy_from_slice(slice);
373 self.create(bytes)
374 }
375
376 pub fn create_tensors_from_slices_pinned(
380 &self,
381 descriptors: Vec<(MemoryLayoutDescriptor, &[u8])>,
382 ) -> Vec<MemoryLayout> {
383 let sizes: Vec<usize> = descriptors.iter().map(|(_, s)| s.len()).collect();
384 let stagings = self.reserve_staging(&sizes);
385
386 let mut bytes_vec = Vec::with_capacity(descriptors.len());
387 let mut descs = Vec::with_capacity(descriptors.len());
388 for ((desc, slice), mut staging) in descriptors.into_iter().zip(stagings) {
389 staging.copy_from_slice(slice);
390 bytes_vec.push(staging);
391 descs.push(desc);
392 }
393
394 self.do_create(descs, bytes_vec).unwrap()
395 }
396
397 pub fn exclusive<'a, Re: Send + 'static, F: FnOnce() -> Re + Send + 'a>(
399 &'a self,
400 task: F,
401 ) -> Result<Re, ServerError> {
402 self.device
404 .exclusive(task)
405 .map_err(|err| ServerError::Generic {
406 reason: format!("Communication channel with the server is down: {err:?}"),
407 backtrace: BackTrace::capture(),
408 })
409 }
410
411 pub fn memory_persistent_allocation<
413 'a,
414 Re: Send,
415 Input: Send,
416 F: FnOnce(Input) -> Re + Send + 'a,
417 >(
418 &'a self,
419 input: Input,
420 task: F,
421 ) -> Result<Re, ServerError> {
422 let stream_id = StreamId::current();
423
424 self.device.submit(move |server| {
425 server.allocation_mode(MemoryAllocationMode::Persistent, stream_id);
426 });
427
428 let output = task(input);
430
431 self.device.submit(move |server| {
432 server.allocation_mode(MemoryAllocationMode::Auto, stream_id);
433 });
434
435 Ok(output)
436 }
437
438 pub fn create(&self, data: Bytes) -> Handle {
440 let shape = [data.len()].into();
441
442 self.do_create(
443 vec![MemoryLayoutDescriptor::new(
444 MemoryLayoutStrategy::Contiguous,
445 shape,
446 1,
447 )],
448 vec![data],
449 )
450 .unwrap()
451 .remove(0)
452 .memory
453 }
454
455 pub fn create_tensor_from_slice(
473 &self,
474 slice: &[u8],
475 shape: Shape,
476 elem_size: usize,
477 ) -> MemoryLayout {
478 self.do_create_from_slices(
479 vec![MemoryLayoutDescriptor::new(
480 MemoryLayoutStrategy::Optimized,
481 shape,
482 elem_size,
483 )],
484 vec![slice.to_vec()],
485 )
486 .unwrap()
487 .remove(0)
488 }
489
490 pub fn create_tensor(&self, bytes: Bytes, shape: Shape, elem_size: usize) -> MemoryLayout {
504 self.do_create(
505 vec![MemoryLayoutDescriptor::new(
506 MemoryLayoutStrategy::Optimized,
507 shape,
508 elem_size,
509 )],
510 vec![bytes],
511 )
512 .unwrap()
513 .remove(0)
514 }
515
516 pub fn create_tensors_from_slices(
524 &self,
525 descriptors: Vec<(MemoryLayoutDescriptor, &[u8])>,
526 ) -> Vec<MemoryLayout> {
527 let mut data = Vec::with_capacity(descriptors.len());
528 let mut descriptors_ = Vec::with_capacity(descriptors.len());
529 for (a, b) in descriptors {
530 data.push(b.to_vec());
531 descriptors_.push(a);
532 }
533
534 self.do_create_from_slices(descriptors_, data).unwrap()
535 }
536
537 pub fn create_tensors(
541 &self,
542 descriptors: Vec<(MemoryLayoutDescriptor, Bytes)>,
543 ) -> Vec<MemoryLayout> {
544 let (descriptors, data) = descriptors.into_iter().unzip();
545
546 self.do_create(descriptors, data).unwrap()
547 }
548
549 fn do_empty(
550 &self,
551 descriptors: Vec<MemoryLayoutDescriptor>,
552 ) -> Result<Vec<MemoryLayout>, IoError> {
553 let stream_id = self.stream_id();
554 let (handle_base, layouts) = self.utilities.layout_policy.apply(stream_id, &descriptors);
555
556 let (size, memory) = (handle_base.size(), handle_base.memory);
557 self.device.submit(move |server| {
558 server.initialize_memory(memory, size, stream_id);
559 });
560
561 Ok(layouts)
562 }
563
564 pub fn empty(&self, size: usize) -> Handle {
566 let shape: Shape = [size].into();
567 let descriptor = MemoryLayoutDescriptor::new(MemoryLayoutStrategy::Contiguous, shape, 1);
568 self.do_empty(vec![descriptor]).unwrap().remove(0).memory
569 }
570
571 pub fn empty_tensor(&self, shape: Shape, elem_size: usize) -> MemoryLayout {
574 let descriptor =
575 MemoryLayoutDescriptor::new(MemoryLayoutStrategy::Optimized, shape, elem_size);
576 self.do_empty(vec![descriptor]).unwrap().remove(0)
577 }
578
579 pub fn empty_tensors(&self, descriptors: Vec<MemoryLayoutDescriptor>) -> Vec<MemoryLayout> {
582 self.do_empty(descriptors).unwrap()
583 }
584
585 pub fn staging<'a, I>(&self, bytes: I, file_only: bool)
590 where
591 I: Iterator<Item = &'a mut Bytes>,
592 {
593 let has_staging = |b: &Bytes| match b.property() {
594 AllocationProperty::Pinned => false,
595 AllocationProperty::File => true,
596 AllocationProperty::Native | AllocationProperty::Other => !file_only,
597 };
598
599 let mut to_be_updated = Vec::new();
600 let sizes = bytes
601 .filter_map(|b| match has_staging(b) {
602 true => {
603 let len = b.len();
604 to_be_updated.push(b);
605 Some(len)
606 }
607 false => None,
608 })
609 .collect::<Vec<usize>>();
610
611 if sizes.is_empty() {
612 return;
613 }
614
615 let stream_id = self.stream_id();
616 let sizes = sizes.to_vec();
617 let stagings = self
618 .device
619 .submit_blocking(move |server| server.staging(&sizes, stream_id))
620 .unwrap();
621
622 let stagings = match stagings {
623 Ok(val) => val,
624 Err(_) => return,
625 };
626
627 to_be_updated
628 .into_iter()
629 .zip(stagings)
630 .for_each(|(b, mut staging)| {
631 b.copy_into(&mut staging);
632 core::mem::swap(b, &mut staging);
633 });
634 }
635
636 #[cfg_attr(
638 feature = "tracing",
639 tracing::instrument(level = "trace", skip(self, src, dst_server))
640 )]
641 pub fn to_client(&mut self, src: Handle, dst_server: &Self, dtype: ElemType) -> Handle {
642 let shape = [src.size_in_used() as usize];
643 let src_descriptor = src.copy_descriptor(shape.into(), [1].into(), 1);
644
645 if R::Server::SERVER_COMM_ENABLED {
646 self.to_client_tensor(src_descriptor, dst_server, dtype)
647 } else {
648 let alloc_desc = MemoryLayoutDescriptor::new(
649 MemoryLayoutStrategy::Contiguous,
650 src_descriptor.shape.clone(),
651 src_descriptor.elem_size,
652 );
653 self.change_client_sync(src_descriptor, alloc_desc, dst_server)
654 .memory
655 }
656 }
657
658 #[cfg_attr(
660 feature = "tracing",
661 tracing::instrument(level = "trace", skip(self, device_ids))
662 )]
663 pub fn ensure_init_collective(&mut self, device_ids: Vec<DeviceId>) {
664 let comm_id = CommunicationId::from(device_ids.clone());
665 let is_comms_init = self
666 .utilities
667 .initialized_comms
668 .read()
669 .unwrap()
670 .contains(&comm_id);
671 if !is_comms_init {
672 self.device
673 .submit(move |server| server.comm_init(device_ids).unwrap());
674 let mut initialized_comms = self.utilities.initialized_comms.write().unwrap();
675 initialized_comms.insert(comm_id);
676 self.device.flush_queue();
678 }
679 }
680
681 #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
683 pub fn sync_collective(&self) {
684 if DeviceHandle::<R::Server>::is_blocking() {
685 panic!("Can't use `sync_collective` with a blocking device handle");
686 }
687 let stream_id = self.stream_id();
688
689 self.device.submit(move |server| {
690 server.sync_collective(stream_id).unwrap();
691 });
692
693 self.device.flush_queue();
696 }
697
698 #[cfg_attr(
700 feature = "tracing",
701 tracing::instrument(level = "trace", skip(self, src, dst, dtype, device_ids, op))
702 )]
703 pub fn all_reduce(
704 &mut self,
705 src: Handle,
706 dst: Handle,
707 dtype: ElemType,
708 device_ids: Vec<DeviceId>,
709 op: ReduceOperation,
710 ) {
711 if DeviceHandle::<R::Server>::is_blocking() {
712 panic!("Can't use `all_reduce` with a blocking device handle");
713 }
714
715 let stream_id = self.stream_id();
716 let src = src.binding();
717 let dst = dst.binding();
718
719 self.ensure_init_collective(device_ids.clone());
720
721 self.device.submit(move |server| {
722 server
723 .all_reduce(src, dst, dtype, stream_id, op, device_ids)
724 .unwrap();
725 });
726 }
727
728 #[cfg_attr(
732 feature = "tracing",
733 tracing::instrument(level = "trace", skip(self, src_descriptor, dst_server))
734 )]
735 pub fn to_client_tensor(
736 &mut self,
737 src_descriptor: CopyDescriptor,
738 dst_server: &Self,
739 dtype: ElemType,
740 ) -> Handle {
741 let stream_id_src = self.stream_id();
742 let stream_id_dst = dst_server.stream_id();
743
744 let device_id_src = self.device.device_id();
745 let device_id_dst = dst_server.device.device_id();
746
747 let mut dst_server = dst_server.clone();
748 let handle = Handle::new(stream_id_dst, src_descriptor.handle.size_in_used());
749 let handle_cloned = handle.clone();
750
751 let device_ids = vec![device_id_src, device_id_dst];
752 self.ensure_init_collective(device_ids.clone());
753 dst_server.ensure_init_collective(device_ids);
754
755 self.device.submit(move |server_src| {
756 server_src
757 .send(src_descriptor, dtype, stream_id_src, device_id_dst)
758 .unwrap()
759 });
760
761 dst_server.device.submit(move |server_dst| {
762 server_dst
763 .recv(handle_cloned, dtype, stream_id_dst, device_id_src)
764 .unwrap();
765 server_dst.sync_collective(stream_id_dst).unwrap();
766 });
767
768 self.device.flush_queue();
772 dst_server.device.flush_queue();
773
774 handle
775 }
776
777 #[track_caller]
778 #[cfg_attr(feature = "tracing", tracing::instrument(level="trace",
779 skip(self, kernel, bindings),
780 fields(
781 kernel.name = %kernel.name(),
782 kernel.id = %kernel.id(),
783 )
784 ))]
785 unsafe fn launch_inner(
786 &self,
787 kernel: <R::Server as ComputeServer>::Kernel,
788 count: CubeCount,
789 bindings: KernelArguments,
790 mode: ExecutionMode,
791 stream_id: StreamId,
792 ) {
793 let level = self.utilities.logger.profile_level();
794
795 match level {
796 None | Some(ProfileLevel::ExecutionOnly) => {
797 let utilities = self.utilities.clone();
798 self.device.submit(move |state| {
799 let name = kernel.name();
800 unsafe { state.launch(kernel, count, bindings, mode, stream_id) };
801
802 if matches!(level, Some(ProfileLevel::ExecutionOnly)) {
803 let info = type_name_format(name, TypeNameFormatLevel::Balanced);
804 utilities.logger.register_execution(info);
805 }
806 });
807 }
808 Some(level) => {
809 let name = kernel.name();
810 let kernel_id = kernel.id();
811 let context = self.device.clone();
812 let count_moved = count.clone();
813 let (result, profile) = self
814 .profile(
815 move || {
816 context
817 .submit_blocking(move |state| unsafe {
818 state.launch(kernel, count_moved, bindings, mode, stream_id)
819 })
820 .unwrap()
821 },
822 name,
823 )
824 .unwrap();
825 let info = match level {
826 ProfileLevel::Full => {
827 format!("{name}: {kernel_id} CubeCount {count:?}")
828 }
829 _ => type_name_format(name, TypeNameFormatLevel::Balanced),
830 };
831 self.utilities.logger.register_profiled(info, profile);
832 result
833 }
834 }
835 }
836
837 #[track_caller]
839 pub fn launch(
840 &self,
841 kernel: <R::Server as ComputeServer>::Kernel,
842 count: CubeCount,
843 bindings: KernelArguments,
844 ) {
845 unsafe {
847 self.launch_inner(
848 kernel,
849 count,
850 bindings,
851 ExecutionMode::Checked,
852 self.stream_id(),
853 )
854 }
855 }
856
857 #[track_caller]
865 pub unsafe fn launch_unchecked(
866 &self,
867 kernel: <R::Server as ComputeServer>::Kernel,
868 count: CubeCount,
869 bindings: KernelArguments,
870 ) {
871 unsafe {
873 self.launch_inner(
874 kernel,
875 count,
876 bindings,
877 match self.utilities.check_mode {
878 crate::config::compilation::BoundsCheckMode::Enforce => ExecutionMode::Checked,
879 crate::config::compilation::BoundsCheckMode::Validate => {
880 ExecutionMode::Validate
881 }
882 crate::config::compilation::BoundsCheckMode::Auto => ExecutionMode::Unchecked,
883 },
884 self.stream_id(),
885 )
886 }
887 }
888
889 pub fn flush(&self) -> Result<(), ServerError> {
891 let stream_id = self.stream_id();
892
893 self.device
894 .submit_blocking(move |server| server.flush(stream_id))
895 .unwrap()
896 }
897
898 pub fn sync(&self) -> DynFut<Result<(), ServerError>> {
900 let stream_id = self.stream_id();
901
902 let fut = self
903 .device
904 .submit_blocking(move |server| server.sync(stream_id))
905 .unwrap();
906
907 self.utilities.logger.profile_summary();
908
909 fut
910 }
911
912 pub fn properties(&self) -> &DeviceProperties {
914 &self.utilities.properties
915 }
916
917 pub fn features(&self) -> &Features {
919 &self.utilities.properties.features
920 }
921
922 pub fn properties_mut(&mut self) -> Option<&mut DeviceProperties> {
926 Arc::get_mut(&mut self.utilities).map(|state| &mut state.properties)
927 }
928
929 pub fn memory_usage(&self) -> Result<MemoryUsage, ServerError> {
931 let stream_id = self.stream_id();
932 self.device
933 .submit_blocking(move |server| server.memory_usage(stream_id))
934 .unwrap()
935 }
936
937 pub fn enumerate_devices(&self, type_id: u16) -> Vec<DeviceId> {
939 R::enumerate_devices(type_id, self.info())
940 }
941
942 pub fn enumerate_all_devices(&self) -> Vec<DeviceId> {
944 R::enumerate_all_devices(self.info())
945 }
946
947 pub fn device_count(&self, type_id: u16) -> usize {
949 self.enumerate_devices(type_id).len()
950 }
951
952 pub fn device_count_total(&self) -> usize {
954 self.enumerate_all_devices().len()
955 }
956
957 pub unsafe fn allocation_mode(&self, mode: MemoryAllocationMode) {
963 let stream_id = self.stream_id();
964 self.device
965 .submit(move |server| server.allocation_mode(mode, stream_id));
966 }
967
968 pub fn memory_cleanup(&self) {
973 let stream_id = self.stream_id();
974 self.device
975 .submit(move |server| server.memory_cleanup(stream_id));
976 }
977
978 #[track_caller]
980 pub fn profile<O: Send + 'static>(
981 &self,
982 func: impl FnOnce() -> O + Send,
983 #[allow(unused)] func_name: &str,
984 ) -> Result<(O, ProfileDuration), ProfileError> {
985 #[cfg(feature = "profile-tracy")]
988 let location = std::panic::Location::caller();
989
990 #[cfg(feature = "profile-tracy")]
992 let _span = tracy_client::Client::running().unwrap().span_alloc(
993 None,
994 func_name,
995 location.file(),
996 location.line(),
997 0,
998 );
999
1000 let stream_id = self.stream_id();
1001
1002 #[cfg(feature = "profile-tracy")]
1003 let gpu_span = if self.utilities.properties.timing_method == TimingMethod::Device {
1004 let gpu_span = self
1005 .utilities
1006 .gpu_client
1007 .span_alloc(func_name, "profile", location.file(), location.line())
1008 .unwrap();
1009 Some(gpu_span)
1010 } else {
1011 None
1012 };
1013
1014 let device = self.device.clone();
1015 #[allow(unused_mut, reason = "Used in profile-tracy")]
1016 let mut result = self
1017 .device
1018 .exclusive(move || {
1019 let token =
1022 match device.submit_blocking(move |server| server.start_profile(stream_id)) {
1023 Ok(token) => match token {
1024 Ok(token) => token,
1025 Err(err) => return Err(err),
1026 },
1027 Err(err) => {
1028 return Err(ServerError::Generic {
1029 reason: alloc::format!(
1030 "Can't start profiling because of a call error: {err:?}"
1031 ),
1032 backtrace: BackTrace::capture(),
1033 });
1034 }
1035 };
1036
1037 let out = func();
1039
1040 let result = device
1042 .submit_blocking(move |server| {
1043 let mut result = server.end_profile(stream_id, token);
1044
1045 match result {
1046 Ok(result) => Ok((out, result)),
1047 Err(err) => Err(err),
1048 }
1049 })
1050 .unwrap();
1051
1052 Ok(result)
1053 })
1054 .unwrap()
1055 .map_err(|err| ProfileError::Unknown {
1056 reason: alloc::format!("{err:?}"),
1057 backtrace: BackTrace::capture(),
1058 })?;
1059
1060 #[cfg(feature = "profile-tracy")]
1061 if let Some(mut gpu_span) = gpu_span {
1062 gpu_span.end_zone();
1063 let epoch = self.utilities.epoch_time;
1064 result = result.map(|(o, result)| {
1066 (
1067 o,
1068 ProfileDuration::new(
1069 alloc::boxed::Box::pin(async move {
1070 let ticks = result.resolve().await;
1071 let start_duration =
1072 ticks.start_duration_since(epoch).as_nanos() as i64;
1073 let end_duration = ticks.end_duration_since(epoch).as_nanos() as i64;
1074 gpu_span.upload_timestamp_start(start_duration);
1075 gpu_span.upload_timestamp_end(end_duration);
1076 ticks
1077 }),
1078 TimingMethod::Device,
1079 ),
1080 )
1081 });
1082 }
1083
1084 result
1085 }
1086
1087 #[cfg_attr(
1089 feature = "tracing",
1090 tracing::instrument(
1091 level = "trace",
1092 skip(self, src_descriptor, alloc_descriptor, dst_server)
1093 )
1094 )]
1095 fn change_client_sync(
1096 &self,
1097 src_descriptor: CopyDescriptor,
1098 alloc_descriptor: MemoryLayoutDescriptor,
1099 dst_server: &Self,
1100 ) -> MemoryLayout {
1101 let shape = src_descriptor.shape.clone();
1102 let elem_size = src_descriptor.elem_size;
1103 let stream_id = self.stream_id();
1104
1105 let read = self
1106 .device
1107 .submit_blocking(move |server| server.read(vec![src_descriptor], stream_id))
1108 .unwrap();
1109
1110 let mut data = cubecl_common::future::block_on(read).unwrap();
1111
1112 let (handle_base, mut layouts) = self
1113 .utilities
1114 .layout_policy
1115 .apply(stream_id, &[alloc_descriptor]);
1116 let alloc = layouts.remove(0);
1117
1118 let desc_descriptor = CopyDescriptor {
1119 handle: handle_base.clone().binding(),
1120 shape,
1121 strides: alloc.strides.clone(),
1122 elem_size,
1123 };
1124
1125 let (size, memory) = (handle_base.size(), handle_base.memory);
1126 dst_server.device.submit(move |server| {
1127 server.initialize_memory(memory, size, stream_id);
1128 server.write(vec![(desc_descriptor, data.remove(0))], stream_id)
1129 });
1130
1131 alloc
1132 }
1133
1134 pub fn io_optimized_vector_sizes(
1136 &self,
1137 size: usize,
1138 ) -> impl Iterator<Item = VectorSize> + Clone {
1139 let load_width = self.properties().hardware.load_width as usize;
1140 let size_bits = size * 8;
1141 let max = load_width / size_bits;
1142 let max = usize::min(self.properties().hardware.max_vector_size, max);
1143
1144 let num_candidates = max.trailing_zeros() + 1;
1146
1147 (0..num_candidates).map(|i| 2usize.pow(i)).rev()
1148 }
1149}