sparkl2d_kernels/cuda/atomic.rs

#![allow(unreachable_code)]

use sparkl_core::math::Vector;
use sparkl_core::na::Scalar;

#[cfg(target_os = "cuda")]
use core::arch::asm;

// TODO: this is needed until Rust-GPU supports atomics.
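/// Atomic add operations implemented with inline PTX.
///
/// `shared_red_add` and `global_red_add` emit `red.*` reduction instructions and
/// discard the previous value, while `global_atomic_add` emits `atom.*` and
/// returns it. A minimal usage sketch (`grid_mass` and `contribution` are
/// hypothetical names, assuming device-side code running on the CUDA target):
///
/// ```ignore
/// unsafe {
///     // Accumulate a per-thread contribution into a value in global memory.
///     grid_mass.global_red_add(contribution);
/// }
/// ```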
pub trait AtomicAdd {
    unsafe fn shared_red_add(&mut self, rhs: Self);
    unsafe fn global_red_add(&mut self, rhs: Self);
    unsafe fn global_atomic_add(&mut self, rhs: Self) -> Self;
}

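/// Atomic integer operations (min-reduction, exchange, compare-and-swap, and a
/// wrapping decrement) on values in global or shared memory, also implemented
/// with inline PTX.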
pub trait AtomicInt {
    unsafe fn global_red_min(&mut self, rhs: Self);
    unsafe fn global_atomic_exch(&mut self, val: Self) -> Self;
    unsafe fn global_atomic_cas(&mut self, cmp: Self, val: Self) -> Self;
    unsafe fn shared_atomic_exch_acq(&mut self, val: Self) -> Self;
    unsafe fn shared_atomic_exch_rel(&mut self, val: Self) -> Self;
    unsafe fn global_atomic_dec(&mut self) -> Self;
}

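// All implementations below follow the same pattern: `cvta.to.{global,shared}.u64`
// converts the generic pointer to an address in the target state space, and the
// following `red.*`/`atom.*` instruction applies the operation to it. On non-CUDA
// targets the methods simply panic via `unimplemented!()`.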
impl AtomicAdd for u32 {
    unsafe fn shared_red_add(&mut self, _rhs: Self) {
        #[cfg(target_os = "cuda")]
        {
            let integer_addr = self as *mut _;
            let mut shared_integer_addr: *mut u32 = core::ptr::null_mut();

            asm!(
            "cvta.to.shared.u64 {gbl_ptr}, {org_ptr};\
            red.shared.add.u32 [{gbl_ptr}], {number};",
            gbl_ptr = out(reg64) shared_integer_addr,
            org_ptr = in(reg64) integer_addr,
            number = in(reg32) _rhs
            );
        }

        #[cfg(not(target_os = "cuda"))]
        unimplemented!();
    }

    unsafe fn global_red_add(&mut self, _rhs: Self) {
        #[cfg(target_os = "cuda")]
        {
            let integer_addr = self as *mut _;
            let mut global_integer_addr: *mut u32 = core::ptr::null_mut();

            asm!(
            "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
            red.global.add.u32 [{gbl_ptr}], {number};",
            gbl_ptr = out(reg64) global_integer_addr,
            org_ptr = in(reg64) integer_addr,
            number = in(reg32) _rhs
            );
        }

        #[cfg(not(target_os = "cuda"))]
        unimplemented!();
    }

    unsafe fn global_atomic_add(&mut self, _rhs: Self) -> Self {
        #[cfg(target_os = "cuda")]
        {
            let mut old = 0;
            let integer_addr = self as *mut _;
            let mut global_integer_addr: *mut u32 = core::ptr::null_mut();

            asm!(
            "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
            atom.global.add.u32 {old}, [{gbl_ptr}], {number};",
            gbl_ptr = out(reg64) global_integer_addr,
            org_ptr = in(reg64) integer_addr,
            number = in(reg32) _rhs,
            old = out(reg32) old,
            );

            old
        }

        #[cfg(not(target_os = "cuda"))]
        return unimplemented!();
    }
}

impl AtomicInt for u32 {
    unsafe fn global_red_min(&mut self, _rhs: Self) {
        #[cfg(target_os = "cuda")]
        {
            let integer_addr = self as *mut _;
            let mut global_integer_addr: *mut u32 = core::ptr::null_mut();

            asm!(
            "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
            red.global.min.u32 [{gbl_ptr}], {number};",
            gbl_ptr = out(reg64) global_integer_addr,
            org_ptr = in(reg64) integer_addr,
            number = in(reg32) _rhs
            );
        }

        #[cfg(not(target_os = "cuda"))]
        unimplemented!();
    }

    unsafe fn global_atomic_exch(&mut self, _rhs: Self) -> Self {
        #[cfg(target_os = "cuda")]
        {
            let mut old = 0;
            let integer_addr = self as *mut _;
            let mut global_integer_addr: *mut u32 = core::ptr::null_mut();

            asm!(
            "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
            atom.global.exch.b32 {old}, [{gbl_ptr}], {number};",
            gbl_ptr = out(reg64) global_integer_addr,
            org_ptr = in(reg64) integer_addr,
            number = in(reg32) _rhs,
            old = out(reg32) old,
            );

            old
        }

        #[cfg(not(target_os = "cuda"))]
        return unimplemented!();
    }

    unsafe fn global_atomic_cas(&mut self, _cmp: Self, _rhs: Self) -> Self {
        #[cfg(target_os = "cuda")]
        {
            let mut old = 0;
            let integer_addr = self as *mut _;
            let mut global_integer_addr: *mut u32 = core::ptr::null_mut();

            asm!(
            "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
            atom.global.cas.b32 {old}, [{gbl_ptr}], {cmp}, {rhs};",
            gbl_ptr = out(reg64) global_integer_addr,
            org_ptr = in(reg64) integer_addr,
            cmp = in(reg32) _cmp,
            rhs = in(reg32) _rhs,
            old = out(reg32) old,
            );

            old
        }

        #[cfg(not(target_os = "cuda"))]
        return unimplemented!();
    }

    unsafe fn shared_atomic_exch_acq(&mut self, _rhs: Self) -> Self {
        #[cfg(target_os = "cuda")]
        {
            let mut old = 0;
            let integer_addr = self as *mut _;
            let mut shared_integer_addr: *mut u32 = core::ptr::null_mut();

            asm!(
            "cvta.to.shared.u64 {gbl_ptr}, {org_ptr};\
            atom.acquire.shared.exch.b32 {old}, [{gbl_ptr}], {number};",
            gbl_ptr = out(reg64) shared_integer_addr,
            org_ptr = in(reg64) integer_addr,
            number = in(reg32) _rhs,
            old = out(reg32) old,
            );

            old
        }

        #[cfg(not(target_os = "cuda"))]
        return unimplemented!();
    }

    unsafe fn shared_atomic_exch_rel(&mut self, _rhs: Self) -> Self {
        #[cfg(target_os = "cuda")]
        {
            let mut old = 0;
            let integer_addr = self as *mut _;
            let mut shared_integer_addr: *mut u32 = core::ptr::null_mut();

            asm!(
            "cvta.to.shared.u64 {gbl_ptr}, {org_ptr};\
            atom.release.shared.exch.b32 {old}, [{gbl_ptr}], {number};",
            gbl_ptr = out(reg64) shared_integer_addr,
            org_ptr = in(reg64) integer_addr,
            number = in(reg32) _rhs,
            old = out(reg32) old,
            );

            old
        }

        #[cfg(not(target_os = "cuda"))]
        return unimplemented!();
    }

    unsafe fn global_atomic_dec(&mut self) -> Self {
        #[cfg(target_os = "cuda")]
        {
            let mut old = 0;
            let max = u32::MAX;
            let integer_addr = self as *mut _;
            let mut global_integer_addr: *mut u32 = core::ptr::null_mut();

            asm!(
            "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
            atom.global.dec.u32 {old}, [{gbl_ptr}], {max};",
            gbl_ptr = out(reg64) global_integer_addr,
            org_ptr = in(reg64) integer_addr,
            old = out(reg32) old,
            max = in(reg32) max,
            );

            old
        }

        #[cfg(not(target_os = "cuda"))]
        return unimplemented!();
    }
}

impl AtomicAdd for u64 {
    unsafe fn shared_red_add(&mut self, _rhs: Self) {
        #[cfg(target_os = "cuda")]
        {
            let integer_addr = self as *mut _;
            let mut shared_integer_addr: *mut u64 = core::ptr::null_mut();

            asm!(
            "cvta.to.shared.u64 {gbl_ptr}, {org_ptr};\
            red.shared.add.u64 [{gbl_ptr}], {number};",
            gbl_ptr = out(reg64) shared_integer_addr,
            org_ptr = in(reg64) integer_addr,
            number = in(reg64) _rhs
            );
        }

        #[cfg(not(target_os = "cuda"))]
        unimplemented!();
    }

    unsafe fn global_red_add(&mut self, _rhs: Self) {
        #[cfg(target_os = "cuda")]
        {
            let integer_addr = self as *mut _;
            let mut global_integer_addr: *mut u64 = core::ptr::null_mut();

            asm!(
            "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
            red.global.add.u64 [{gbl_ptr}], {number};",
            gbl_ptr = out(reg64) global_integer_addr,
            org_ptr = in(reg64) integer_addr,
            number = in(reg64) _rhs
            );
        }

        #[cfg(not(target_os = "cuda"))]
        unimplemented!();
    }

    unsafe fn global_atomic_add(&mut self, _rhs: Self) -> Self {
        #[cfg(target_os = "cuda")]
        {
            let mut old = 0;
            let integer_addr = self as *mut _;
            let mut global_integer_addr: *mut u64 = core::ptr::null_mut();

            asm!(
            "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
            atom.global.add.u64 {old}, [{gbl_ptr}], {number};",
            gbl_ptr = out(reg64) global_integer_addr,
            org_ptr = in(reg64) integer_addr,
            number = in(reg64) _rhs,
            old = out(reg64) old
            );

            old
        }

        #[cfg(not(target_os = "cuda"))]
        return unimplemented!();
    }
}

impl AtomicInt for u64 {
    unsafe fn global_red_min(&mut self, _rhs: Self) {
        #[cfg(target_os = "cuda")]
        {
            let integer_addr = self as *mut _;
            let mut global_integer_addr: *mut u64 = core::ptr::null_mut();

            asm!(
            "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
            red.global.min.u64 [{gbl_ptr}], {number};",
            gbl_ptr = out(reg64) global_integer_addr,
            org_ptr = in(reg64) integer_addr,
            number = in(reg64) _rhs
            );
        }

        #[cfg(not(target_os = "cuda"))]
        unimplemented!();
    }

    unsafe fn global_atomic_exch(&mut self, _rhs: Self) -> Self {
        #[cfg(target_os = "cuda")]
        {
            let mut old = 0;
            let integer_addr = self as *mut _;
            let mut global_integer_addr: *mut u64 = core::ptr::null_mut();

            asm!(
            "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
            atom.global.exch.b64 {old}, [{gbl_ptr}], {number};",
            gbl_ptr = out(reg64) global_integer_addr,
            org_ptr = in(reg64) integer_addr,
            number = in(reg64) _rhs,
            old = out(reg64) old
            );

            old
        }

        #[cfg(not(target_os = "cuda"))]
        return unimplemented!();
    }

    unsafe fn global_atomic_cas(&mut self, _cmp: Self, _rhs: Self) -> Self {
        #[cfg(target_os = "cuda")]
        {
            let mut old = 0;
            let integer_addr = self as *mut _;
            let mut global_integer_addr: *mut u64 = core::ptr::null_mut();

            asm!(
            "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
            atom.global.cas.b64 {old}, [{gbl_ptr}], {cmp}, {rhs};",
            gbl_ptr = out(reg64) global_integer_addr,
            org_ptr = in(reg64) integer_addr,
            cmp = in(reg64) _cmp,
            rhs = in(reg64) _rhs,
            old = out(reg64) old
            );

            old
        }

        #[cfg(not(target_os = "cuda"))]
        return unimplemented!();
    }

    unsafe fn shared_atomic_exch_acq(&mut self, _rhs: Self) -> Self {
        #[cfg(target_os = "cuda")]
        {
            let mut old = 0;
            let integer_addr = self as *mut _;
            let mut shared_integer_addr: *mut u64 = core::ptr::null_mut();

            asm!(
            "cvta.to.shared.u64 {gbl_ptr}, {org_ptr};\
            atom.acquire.shared.exch.b64 {old}, [{gbl_ptr}], {number};",
            gbl_ptr = out(reg64) shared_integer_addr,
            org_ptr = in(reg64) integer_addr,
            number = in(reg64) _rhs,
            old = out(reg64) old
            );

            old
        }

        #[cfg(not(target_os = "cuda"))]
        return unimplemented!();
    }

    unsafe fn shared_atomic_exch_rel(&mut self, _rhs: Self) -> Self {
        #[cfg(target_os = "cuda")]
        {
            let mut old = 0;
            let integer_addr = self as *mut _;
            let mut shared_integer_addr: *mut u64 = core::ptr::null_mut();

            asm!(
            "cvta.to.shared.u64 {gbl_ptr}, {org_ptr};\
            atom.release.shared.exch.b64 {old}, [{gbl_ptr}], {number};",
            gbl_ptr = out(reg64) shared_integer_addr,
            org_ptr = in(reg64) integer_addr,
            number = in(reg64) _rhs,
            old = out(reg64) old
            );

            old
        }

        #[cfg(not(target_os = "cuda"))]
        return unimplemented!();
    }

    unsafe fn global_atomic_dec(&mut self) -> Self {
        #[cfg(target_os = "cuda")]
        {
            let mut old = 0;
            let max = u64::MAX;
            let integer_addr = self as *mut _;
            let mut global_integer_addr: *mut u64 = core::ptr::null_mut();

            asm!(
            "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
            atom.global.dec.u64 {old}, [{gbl_ptr}], {max};",
            gbl_ptr = out(reg64) global_integer_addr,
            org_ptr = in(reg64) integer_addr,
            old = out(reg64) old,
            max = in(reg64) max,
            );

            old
        }

        #[cfg(not(target_os = "cuda"))]
        return unimplemented!();
    }
}

impl AtomicAdd for f32 {
    unsafe fn shared_red_add(&mut self, _rhs: Self) {
        #[cfg(target_os = "cuda")]
        {
            let float_addr = self as *mut _;
            let mut shared_float_addr: *mut f32 = core::ptr::null_mut();

            asm!(
            "cvta.to.shared.u64 {gbl_ptr}, {org_ptr};\
            red.shared.add.f32 [{gbl_ptr}], {number};",
            gbl_ptr = out(reg64) shared_float_addr,
            org_ptr = in(reg64) float_addr,
            number = in(reg32) _rhs
            );
        }

        #[cfg(not(target_os = "cuda"))]
        unimplemented!();
    }

    unsafe fn global_red_add(&mut self, _rhs: Self) {
        #[cfg(target_os = "cuda")]
        {
            let float_addr = self as *mut _;
            let mut global_float_addr: *mut f32 = core::ptr::null_mut();

            asm!(
            "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
            red.global.add.f32 [{gbl_ptr}], {number};",
            gbl_ptr = out(reg64) global_float_addr,
            org_ptr = in(reg64) float_addr,
            number = in(reg32) _rhs
            );
        }

        #[cfg(not(target_os = "cuda"))]
        unimplemented!();
    }

    unsafe fn global_atomic_add(&mut self, _rhs: Self) -> Self {
        #[cfg(target_os = "cuda")]
        {
            let mut old = 0.0;
            let float_addr = self as *mut _;
            let mut global_float_addr: *mut f32 = core::ptr::null_mut();

            asm!(
            "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
            atom.global.add.f32 {old}, [{gbl_ptr}], {number};",
            gbl_ptr = out(reg64) global_float_addr,
            org_ptr = in(reg64) float_addr,
            number = in(reg32) _rhs,
            old = out(reg32) old
            );

            old
        }

        #[cfg(not(target_os = "cuda"))]
        unimplemented!();
    }
}

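// Component-wise atomic adds: each coordinate delegates to the scalar implementation.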
impl<T: Scalar + AtomicAdd> AtomicAdd for Vector<T> {
    unsafe fn shared_red_add(&mut self, rhs: Self) {
        self.zip_apply(&rhs, |a, b| a.shared_red_add(b))
    }

    unsafe fn global_red_add(&mut self, rhs: Self) {
        self.zip_apply(&rhs, |a, b| a.global_red_add(b))
    }

    unsafe fn global_atomic_add(&mut self, _rhs: Self) -> Self {
        unimplemented!()
    }
}