1#![allow(unreachable_code)]
2
3use sparkl_core::math::Vector;
4use sparkl_core::na::Scalar;
5
/// Atomic / reduction addition primitives, implemented with inline PTX on
/// the CUDA target. On non-CUDA targets every method is `unimplemented!()`
/// (see the impls below).
pub trait AtomicAdd {
    /// Reduction-add into `self` without returning the previous value.
    ///
    /// # Safety
    /// On the CUDA target, `self` must be a valid address in *shared*
    /// memory (the impls convert it with `cvta.to.shared`); concurrent
    /// access must go through atomics only.
    unsafe fn shared_red_add(&mut self, rhs: Self);
    /// Reduction-add into `self` without returning the previous value.
    ///
    /// # Safety
    /// On the CUDA target, `self` must be a valid address in *global*
    /// memory (the impls convert it with `cvta.to.global`); concurrent
    /// access must go through atomics only.
    unsafe fn global_red_add(&mut self, rhs: Self);
    /// Atomic add that returns the value stored at `self` before the add.
    ///
    /// # Safety
    /// Same contract as [`AtomicAdd::global_red_add`].
    unsafe fn global_atomic_add(&mut self, rhs: Self) -> Self;
}
12
/// Integer atomic primitives (min/exchange/compare-and-swap/decrement),
/// implemented with inline PTX on the CUDA target. On non-CUDA targets
/// every method is `unimplemented!()` (see the impls below).
pub trait AtomicInt {
    /// Reduction-min into `self` (global memory); no value is returned.
    ///
    /// # Safety
    /// `self` must be a valid global-memory address on the CUDA target;
    /// concurrent access must go through atomics only.
    unsafe fn global_red_min(&mut self, rhs: Self);
    /// Atomically stores `val` into `self` (global memory) and returns the
    /// previous value.
    ///
    /// # Safety
    /// Same contract as [`AtomicInt::global_red_min`].
    unsafe fn global_atomic_exch(&mut self, val: Self) -> Self;
    /// Atomic compare-and-swap on `self` (global memory): stores `val` if
    /// the current value equals `cmp`; returns the previous value.
    ///
    /// # Safety
    /// Same contract as [`AtomicInt::global_red_min`].
    unsafe fn global_atomic_cas(&mut self, cmp: Self, val: Self) -> Self;
    /// Atomic exchange with *acquire* ordering on a shared-memory location;
    /// returns the previous value.
    ///
    /// # Safety
    /// `self` must be a valid shared-memory address on the CUDA target;
    /// concurrent access must go through atomics only.
    unsafe fn shared_atomic_exch_acq(&mut self, val: Self) -> Self;
    /// Atomic exchange with *release* ordering on a shared-memory location;
    /// returns the previous value.
    ///
    /// # Safety
    /// Same contract as [`AtomicInt::shared_atomic_exch_acq`].
    unsafe fn shared_atomic_exch_rel(&mut self, val: Self) -> Self;
    /// Atomic wrapping decrement of `self` (global memory); returns the
    /// previous value.
    ///
    /// # Safety
    /// Same contract as [`AtomicInt::global_red_min`].
    unsafe fn global_atomic_dec(&mut self) -> Self;
}
21
/// `AtomicAdd` for `u32` via inline PTX.
///
/// Each method converts the generic address of `self` to the target state
/// space (`cvta.to.shared` / `cvta.to.global`) and then issues the matching
/// `red`/`atom` instruction. The `*mut u32` locals only exist to receive the
/// converted address from the `cvta` output operand.
impl AtomicAdd for u32 {
    /// `red.shared.add.u32` — fire-and-forget add into shared memory.
    /// NOTE(review): assumes `self` actually lives in shared memory;
    /// `cvta.to.shared` is only meaningful for such addresses.
    unsafe fn shared_red_add(&mut self, _rhs: Self) {
        #[cfg(target_os = "cuda")]
        {
            let integer_addr = self as *mut _;
            // Dummy receiver for the shared-space address produced by cvta.
            let mut shared_integer_addr: *mut u32 = core::ptr::null_mut();

            asm!(
                "cvta.to.shared.u64 {gbl_ptr}, {org_ptr};\
                red.shared.add.u32 [{gbl_ptr}], {number};",
                gbl_ptr = out(reg64) shared_integer_addr,
                org_ptr = in(reg64) integer_addr,
                number = in(reg32) _rhs
            );
        }

        // No host fallback: these primitives only make sense on the GPU.
        #[cfg(not(target_os = "cuda"))]
        unimplemented!();
    }

    /// `red.global.add.u32` — fire-and-forget add into global memory.
    unsafe fn global_red_add(&mut self, _rhs: Self) {
        #[cfg(target_os = "cuda")]
        {
            let integer_addr = self as *mut _;
            // Dummy receiver for the global-space address produced by cvta.
            let mut global_integer_addr: *mut u32 = core::ptr::null_mut();

            asm!(
                "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
                red.global.add.u32 [{gbl_ptr}], {number};",
                gbl_ptr = out(reg64) global_integer_addr,
                org_ptr = in(reg64) integer_addr,
                number = in(reg32) _rhs
            );
        }

        #[cfg(not(target_os = "cuda"))]
        unimplemented!();
    }

    /// `atom.global.add.u32` — atomic add returning the previous value.
    unsafe fn global_atomic_add(&mut self, _rhs: Self) -> Self {
        #[cfg(target_os = "cuda")]
        {
            // `old` receives the pre-add value from the atom instruction.
            let mut old = 0;
            let integer_addr = self as *mut _;
            let mut global_integer_addr: *mut u32 = core::ptr::null_mut();

            asm!(
                "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
                atom.global.add.u32 {old}, [{gbl_ptr}], {number};",
                gbl_ptr = out(reg64) global_integer_addr,
                org_ptr = in(reg64) integer_addr,
                number = in(reg32) _rhs,
                old = out(reg32) old,
            );

            old
        }

        #[cfg(not(target_os = "cuda"))]
        return unimplemented!();
    }
}
84
/// `AtomicInt` for `u32` via inline PTX. Same address-conversion pattern as
/// the `AtomicAdd` impl above: `cvta` to the right state space, then the
/// `red`/`atom` instruction.
impl AtomicInt for u32 {
    /// `red.global.min.u32` — fire-and-forget min into global memory.
    unsafe fn global_red_min(&mut self, _rhs: Self) {
        #[cfg(target_os = "cuda")]
        {
            let integer_addr = self as *mut _;
            // Dummy receiver for the global-space address produced by cvta.
            let mut global_integer_addr: *mut u32 = core::ptr::null_mut();

            asm!(
                "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
                red.global.min.u32 [{gbl_ptr}], {number};",
                gbl_ptr = out(reg64) global_integer_addr,
                org_ptr = in(reg64) integer_addr,
                number = in(reg32) _rhs
            );
        }

        // No host fallback: these primitives only make sense on the GPU.
        #[cfg(not(target_os = "cuda"))]
        unimplemented!();
    }

    /// `atom.global.exch.b32` — atomic exchange, returns the old value.
    unsafe fn global_atomic_exch(&mut self, _rhs: Self) -> Self {
        #[cfg(target_os = "cuda")]
        {
            let mut old = 0;
            let integer_addr = self as *mut _;
            let mut global_integer_addr: *mut u32 = core::ptr::null_mut();

            asm!(
                "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
                atom.global.exch.b32 {old}, [{gbl_ptr}], {number};",
                gbl_ptr = out(reg64) global_integer_addr,
                org_ptr = in(reg64) integer_addr,
                number = in(reg32) _rhs,
                old = out(reg32) old,
            );

            old
        }

        #[cfg(not(target_os = "cuda"))]
        return unimplemented!();
    }

    /// `atom.global.cas.b32` — compare-and-swap, returns the old value
    /// (compare against the returned value to detect success).
    unsafe fn global_atomic_cas(&mut self, _cmp: Self, _rhs: Self) -> Self {
        #[cfg(target_os = "cuda")]
        {
            let mut old = 0;
            let integer_addr = self as *mut _;
            let mut global_integer_addr: *mut u32 = core::ptr::null_mut();

            asm!(
                "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
                atom.global.cas.b32 {old}, [{gbl_ptr}], {cmp}, {rhs};",
                gbl_ptr = out(reg64) global_integer_addr,
                org_ptr = in(reg64) integer_addr,
                cmp = in(reg32) _cmp,
                rhs = in(reg32) _rhs,
                old = out(reg32) old,
            );

            old
        }

        #[cfg(not(target_os = "cuda"))]
        return unimplemented!();
    }

    /// `atom.acquire.shared.exch.b32` — exchange with acquire ordering on a
    /// shared-memory word (typical lock-take half of a spinlock).
    unsafe fn shared_atomic_exch_acq(&mut self, _rhs: Self) -> Self {
        #[cfg(target_os = "cuda")]
        {
            let mut old = 0;
            let integer_addr = self as *mut _;
            let mut shared_integer_addr: *mut u32 = core::ptr::null_mut();

            asm!(
                "cvta.to.shared.u64 {gbl_ptr}, {org_ptr};\
                atom.acquire.shared.exch.b32 {old}, [{gbl_ptr}], {number};",
                gbl_ptr = out(reg64) shared_integer_addr,
                org_ptr = in(reg64) integer_addr,
                number = in(reg32) _rhs,
                old = out(reg32) old,
            );

            old
        }

        #[cfg(not(target_os = "cuda"))]
        return unimplemented!();
    }

    /// `atom.release.shared.exch.b32` — exchange with release ordering on a
    /// shared-memory word (typical lock-release half of a spinlock).
    unsafe fn shared_atomic_exch_rel(&mut self, _rhs: Self) -> Self {
        #[cfg(target_os = "cuda")]
        {
            let mut old = 0;
            let integer_addr = self as *mut _;
            let mut shared_integer_addr: *mut u32 = core::ptr::null_mut();

            asm!(
                "cvta.to.shared.u64 {gbl_ptr}, {org_ptr};\
                atom.release.shared.exch.b32 {old}, [{gbl_ptr}], {number};",
                gbl_ptr = out(reg64) shared_integer_addr,
                org_ptr = in(reg64) integer_addr,
                number = in(reg32) _rhs,
                old = out(reg32) old,
            );

            old
        }

        #[cfg(not(target_os = "cuda"))]
        return unimplemented!();
    }

    /// `atom.global.dec.u32` with ceiling `u32::MAX`: PTX `dec` wraps to the
    /// ceiling when the current value is 0, so with MAX as the ceiling this
    /// is a plain wrapping decrement; returns the pre-decrement value.
    unsafe fn global_atomic_dec(&mut self) -> Self {
        #[cfg(target_os = "cuda")]
        {
            let mut old = 0;
            let max = u32::MAX;
            let integer_addr = self as *mut _;
            let mut global_integer_addr: *mut u32 = core::ptr::null_mut();

            asm!(
                "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
                atom.global.dec.u32 {old}, [{gbl_ptr}], {max};",
                gbl_ptr = out(reg64) global_integer_addr,
                org_ptr = in(reg64) integer_addr,
                old = out(reg32) old,
                max = in(reg32) max,
            );

            old
        }

        #[cfg(not(target_os = "cuda"))]
        return unimplemented!();
    }
}
222
/// `AtomicAdd` for `u64` via inline PTX — mirrors the `u32` impl with
/// 64-bit registers (`reg64`) and `.u64` instruction forms.
impl AtomicAdd for u64 {
    /// `red.shared.add.u64` — fire-and-forget add into shared memory.
    /// NOTE(review): assumes `self` actually lives in shared memory.
    unsafe fn shared_red_add(&mut self, _rhs: Self) {
        #[cfg(target_os = "cuda")]
        {
            let integer_addr = self as *mut _;
            // Dummy receiver for the shared-space address produced by cvta.
            let mut shared_integer_addr: *mut u64 = core::ptr::null_mut();

            asm!(
                "cvta.to.shared.u64 {gbl_ptr}, {org_ptr};\
                red.shared.add.u64 [{gbl_ptr}], {number};",
                gbl_ptr = out(reg64) shared_integer_addr,
                org_ptr = in(reg64) integer_addr,
                number = in(reg64) _rhs
            );
        }

        // No host fallback: these primitives only make sense on the GPU.
        #[cfg(not(target_os = "cuda"))]
        unimplemented!();
    }

    /// `red.global.add.u64` — fire-and-forget add into global memory.
    unsafe fn global_red_add(&mut self, _rhs: Self) {
        #[cfg(target_os = "cuda")]
        {
            let integer_addr = self as *mut _;
            let mut global_integer_addr: *mut u64 = core::ptr::null_mut();

            asm!(
                "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
                red.global.add.u64 [{gbl_ptr}], {number};",
                gbl_ptr = out(reg64) global_integer_addr,
                org_ptr = in(reg64) integer_addr,
                number = in(reg64) _rhs
            );
        }

        #[cfg(not(target_os = "cuda"))]
        unimplemented!();
    }

    /// `atom.global.add.u64` — atomic add returning the previous value.
    unsafe fn global_atomic_add(&mut self, _rhs: Self) -> Self {
        #[cfg(target_os = "cuda")]
        {
            // `old` receives the pre-add value from the atom instruction.
            let mut old = 0;
            let integer_addr = self as *mut _;
            let mut global_integer_addr: *mut u64 = core::ptr::null_mut();

            asm!(
                "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
                atom.global.add.u64 {old}, [{gbl_ptr}], {number};",
                gbl_ptr = out(reg64) global_integer_addr,
                org_ptr = in(reg64) integer_addr,
                number = in(reg64) _rhs,
                old = out(reg64) old
            );

            old
        }

        #[cfg(not(target_os = "cuda"))]
        return unimplemented!();
    }
}
285
286impl AtomicInt for u64 {
287 unsafe fn global_red_min(&mut self, _rhs: Self) {
288 #[cfg(target_os = "cuda")]
289 {
290 let integer_addr = self as *mut _;
291 let mut global_integer_addr: *mut u64 = core::ptr::null_mut();
292
293 asm!(
294 "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
295 red.global.min.u64 [{gbl_ptr}], {number};",
296 gbl_ptr = out(reg64) global_integer_addr,
297 org_ptr = in(reg64) integer_addr,
298 number = in(reg64) _rhs
299 );
300 }
301
302 #[cfg(not(target_os = "cuda"))]
303 unimplemented!();
304 }
305
306 unsafe fn global_atomic_exch(&mut self, _rhs: Self) -> Self {
307 #[cfg(target_os = "cuda")]
308 {
309 let mut old = 0;
310 let integer_addr = self as *mut _;
311 let mut global_integer_addr: *mut u64 = core::ptr::null_mut();
312
313 asm!(
314 "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
315 atom.global.exch.b64 {old}, [{gbl_ptr}], {number};",
316 gbl_ptr = out(reg64) global_integer_addr,
317 org_ptr = in(reg64) integer_addr,
318 number = in(reg64) _rhs,
319 old = out(reg64) old
320 );
321
322 old
323 }
324
325 #[cfg(not(target_os = "cuda"))]
326 return unimplemented!();
327 }
328
329 unsafe fn global_atomic_cas(&mut self, _cmp: Self, _rhs: Self) -> Self {
330 #[cfg(target_os = "cuda")]
331 {
332 let mut old = 0;
333 let integer_addr = self as *mut _;
334 let mut global_integer_addr: *mut u64 = core::ptr::null_mut();
335
336 asm!(
337 "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
338 atom.global.cas.b64 {old}, [{gbl_ptr}], {cmp}, {rhs};",
339 gbl_ptr = out(reg64) global_integer_addr,
340 org_ptr = in(reg64) integer_addr,
341 cmp = in(reg64) _cmp,
342 rhs = in(reg64) _rhs,
343 old = out(reg64) old
344 );
345
346 old
347 }
348
349 #[cfg(not(target_os = "cuda"))]
350 return unimplemented!();
351 }
352
353 unsafe fn shared_atomic_exch_acq(&mut self, _rhs: Self) -> Self {
354 #[cfg(target_os = "cuda")]
355 {
356 let mut old = 0;
357 let integer_addr = self as *mut _;
358 let mut shared_integer_addr: *mut u64 = core::ptr::null_mut();
359
360 asm!(
361 "cvta.to.shared.u64 {gbl_ptr}, {org_ptr};\
362 atom.acquire.shared.exch.b64 {old}, [{gbl_ptr}], {number};",
363 gbl_ptr = out(reg64) shared_integer_addr,
364 org_ptr = in(reg64) integer_addr,
365 number = in(reg64) _rhs,
366 old = out(reg64) old
367 );
368
369 old
370 }
371
372 #[cfg(not(target_os = "cuda"))]
373 return unimplemented!();
374 }
375
376 unsafe fn shared_atomic_exch_rel(&mut self, _rhs: Self) -> Self {
377 #[cfg(target_os = "cuda")]
378 {
379 let mut old = 0;
380 let integer_addr = self as *mut _;
381 let mut shared_integer_addr: *mut u64 = core::ptr::null_mut();
382
383 asm!(
384 "cvta.to.shared.u64 {gbl_ptr}, {org_ptr};\
385 atom.release.shared.exch.b64 {old}, [{gbl_ptr}], {number};",
386 gbl_ptr = out(reg64) shared_integer_addr,
387 org_ptr = in(reg64) integer_addr,
388 number = in(reg64) _rhs,
389 old = out(reg64) old
390 );
391
392 old
393 }
394
395 #[cfg(not(target_os = "cuda"))]
396 return unimplemented!();
397 }
398
399 unsafe fn global_atomic_dec(&mut self) -> Self {
400 #[cfg(target_os = "cuda")]
401 {
402 let mut old = 0;
403 let max = u64::MAX;
404 let integer_addr = self as *mut _;
405 let mut global_integer_addr: *mut u32 = core::ptr::null_mut();
406
407 asm!(
408 "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
409 atom.global.dec.u64 {old}, [{gbl_ptr}], {max};",
410 gbl_ptr = out(reg64) global_integer_addr,
411 org_ptr = in(reg64) integer_addr,
412 old = out(reg64) old,
413 max = in(reg64) max,
414 );
415
416 old
417 }
418
419 #[cfg(not(target_os = "cuda"))]
420 return unimplemented!();
421 }
422}
423
/// `AtomicAdd` for `f32` via inline PTX (`.f32` add forms). Note that
/// floating-point atomic adds are not associative, so accumulation order —
/// and therefore rounding — is nondeterministic across threads.
impl AtomicAdd for f32 {
    /// `red.shared.add.f32` — fire-and-forget add into shared memory.
    /// NOTE(review): assumes `self` actually lives in shared memory.
    unsafe fn shared_red_add(&mut self, _rhs: Self) {
        #[cfg(target_os = "cuda")]
        {
            let float_addr = self as *mut _;
            // Dummy receiver for the shared-space address produced by cvta.
            let mut shared_float_addr: *mut f32 = core::ptr::null_mut();

            asm!(
                "cvta.to.shared.u64 {gbl_ptr}, {org_ptr};\
                red.shared.add.f32 [{gbl_ptr}], {number};",
                gbl_ptr = out(reg64) shared_float_addr,
                org_ptr = in(reg64) float_addr,
                number = in(reg32) _rhs
            );
        }

        // No host fallback: these primitives only make sense on the GPU.
        #[cfg(not(target_os = "cuda"))]
        unimplemented!();
    }

    /// `red.global.add.f32` — fire-and-forget add into global memory.
    unsafe fn global_red_add(&mut self, _rhs: Self) {
        #[cfg(target_os = "cuda")]
        {
            let float_addr = self as *mut _;
            let mut global_float_addr: *mut f32 = core::ptr::null_mut();

            asm!(
                "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
                red.global.add.f32 [{gbl_ptr}], {number};",
                gbl_ptr = out(reg64) global_float_addr,
                org_ptr = in(reg64) float_addr,
                number = in(reg32) _rhs
            );
        }

        #[cfg(not(target_os = "cuda"))]
        unimplemented!();
    }

    /// `atom.global.add.f32` — atomic add returning the previous value.
    unsafe fn global_atomic_add(&mut self, _rhs: Self) -> Self {
        #[cfg(target_os = "cuda")]
        {
            // `old` receives the pre-add value from the atom instruction.
            let mut old = 0.0;
            let float_addr = self as *mut _;
            let mut global_float_addr: *mut f32 = core::ptr::null_mut();

            asm!(
                "cvta.to.global.u64 {gbl_ptr}, {org_ptr};\
                atom.global.add.f32 {old}, [{gbl_ptr}], {number};",
                gbl_ptr = out(reg64) global_float_addr,
                org_ptr = in(reg64) float_addr,
                number = in(reg32) _rhs,
                old = out(reg32) old
            );

            old
        }

        #[cfg(not(target_os = "cuda"))]
        unimplemented!();
    }
}
486
/// Component-wise `AtomicAdd` for vectors: each component delegates to the
/// scalar implementation.
impl<T: Scalar + AtomicAdd> AtomicAdd for Vector<T> {
    /// Applies `shared_red_add` to each component of `self` with the
    /// matching component of `rhs`.
    /// NOTE(review): assumes `zip_apply`'s closure mutates each component
    /// of `self` in place — confirm against the nalgebra version in use.
    unsafe fn shared_red_add(&mut self, rhs: Self) {
        self.zip_apply(&rhs, |a, b| a.shared_red_add(b))
    }

    /// Applies `global_red_add` to each component pair.
    unsafe fn global_red_add(&mut self, rhs: Self) {
        self.zip_apply(&rhs, |a, b| a.global_red_add(b))
    }

    /// Not provided for vectors: there is no single "previous value" to
    /// return from a sequence of independent per-component atomics.
    unsafe fn global_atomic_add(&mut self, _rhs: Self) -> Self {
        unimplemented!()
    }
}