// sesh_sdk/vec.rs

//! Vector operations for batch audio processing.
//!
//! Each op has two code paths: an inline Rust fallback (always available) and a
//! host-accelerated import (used when the host reports `sesh_vec_version() > 0`). The
//! SDK selects the path at runtime, so plugin authors call the same functions
//! regardless of platform.

use std::sync::atomic::{AtomicU32, Ordering};

// ---------------------------------------------------------------------------
// Host capability detection
// ---------------------------------------------------------------------------

extern "C" {
    /// Host-reported version of the vec-op import set; `0` means no accelerated ops.
    fn sesh_vec_version() -> u32;
}

/// Cached host vec version.
///
/// `0` = not yet queried; `u32::MAX` = queried, and the host reported no vec support
/// (e.g. web stubs). Any other value is the version the host reported.
static HOST_VEC_VERSION: AtomicU32 = AtomicU32::new(0);

/// Returns the host's vec-op version, or `0` when accelerated ops are unavailable.
///
/// The host is queried on the first call and the answer is cached in `HOST_VEC_VERSION`.
/// The internal "unsupported" sentinel is never returned, so every call yields the same
/// value and the result can safely be compared against a minimum version.
fn host_version() -> u32 {
    const UNSUPPORTED: u32 = u32::MAX;

    match HOST_VEC_VERSION.load(Ordering::Relaxed) {
        0 => {}
        UNSUPPORTED => return 0,
        cached => return cached,
    }

    // SAFETY: the import takes no arguments and returns a plain integer; assumes the host
    // always provides this symbol (stub builds report 0).
    let reported = unsafe { sesh_vec_version() };
    // A literal u32::MAX answer would collide with the sentinel; treat it as unsupported.
    let version = if reported == UNSUPPORTED { 0 } else { reported };
    // Relaxed is sufficient: racing first callers each query the host and store the same
    // value (assuming the host's answer is stable).
    let cached = if version == 0 { UNSUPPORTED } else { version };
    HOST_VEC_VERSION.store(cached, Ordering::Relaxed);
    version
}

32#[inline]
33fn use_host_ops() -> bool {
34    host_version() > 0 && host_version() != u32::MAX
35}
36
// ---------------------------------------------------------------------------
// Host imports (C ABI, raw pointers)
// ---------------------------------------------------------------------------

// Every `len`/`buf_len` argument is a sample count. The Rust wrappers only pass pointers
// that are valid for that many samples; the `*_assign` wrappers pass `dst` as an input too,
// so the host is presumably expected to support in-place operation.
extern "C" {
    // --- Fill / copy --------------------------------------------------------
    fn sesh_vec_copy_host(dst: *mut f32, src: *const f32, len: u32);
    fn sesh_vec_fill_host(dst: *mut f32, value: f32, len: u32);

    // --- Binary and scalar arithmetic ---------------------------------------
    fn sesh_vec_add_host(dst: *mut f32, a: *const f32, b: *const f32, len: u32);
    fn sesh_vec_add_scalar_host(dst: *mut f32, value: f32, len: u32);
    fn sesh_vec_mul_host(dst: *mut f32, a: *const f32, b: *const f32, len: u32);
    fn sesh_vec_mul_scalar_host(dst: *mut f32, value: f32, len: u32);
    fn sesh_vec_mul_add_host(dst: *mut f32, src: *const f32, gain: f32, len: u32);
    fn sesh_vec_div_host(dst: *mut f32, a: *const f32, b: *const f32, len: u32);
    fn sesh_vec_pow_host(dst: *mut f32, src: *const f32, exp: *const f32, len: u32);
    fn sesh_vec_clamp_host(dst: *mut f32, src: *const f32, min: f32, max: f32, len: u32);

    // --- Unary math ---------------------------------------------------------
    fn sesh_vec_abs_host(dst: *mut f32, src: *const f32, len: u32);
    fn sesh_vec_neg_host(dst: *mut f32, src: *const f32, len: u32);
    fn sesh_vec_sqrt_host(dst: *mut f32, src: *const f32, len: u32);
    fn sesh_vec_recip_host(dst: *mut f32, src: *const f32, len: u32);

    // --- Waveshaping (per-sample parameter buffers) -------------------------
    fn sesh_vec_tanh_host(dst: *mut f32, src: *const f32, drive: *const f32, len: u32);
    fn sesh_vec_hard_clip_host(
        dst: *mut f32,
        src: *const f32,
        threshold: *const f32,
        len: u32,
    );

    // --- Circular buffers / delay (`pos` is in/out for writes) --------------
    fn sesh_vec_ring_write_host(
        buf: *mut f32,
        buf_len: u32,
        pos: *mut u32,
        src: *const f32,
        len: u32,
    );
    fn sesh_vec_ring_read_host(
        buf: *const f32,
        buf_len: u32,
        pos: u32,
        dst: *mut f32,
        offset: u32,
        len: u32,
    );
    fn sesh_vec_delay_read_host(
        buf: *const f32,
        buf_len: u32,
        pos: u32,
        dst: *mut f32,
        time: *const f32,
        len: u32,
    );

    // --- Stateful generators / processors -----------------------------------
    // `phase` and `state` point at caller-owned state the host reads and updates;
    // `waveform`, `filter_type` and `mode` are the `#[repr(u32)]` enum discriminants.
    fn sesh_vec_osc_host(
        phase: *mut f32,
        dst: *mut f32,
        freq: f32,
        waveform: u32,
        sample_rate: f32,
        len: u32,
    );
    fn sesh_vec_biquad_host(
        state: *mut f32,
        dst: *mut f32,
        src: *const f32,
        cutoff: *const f32,
        q: *const f32,
        gain: *const f32,
        filter_type: u32,
        sample_rate: f32,
        len: u32,
    );
    fn sesh_vec_envelope_host(
        state: *mut f32,
        dst: *mut f32,
        src: *const f32,
        attack: *const f32,
        release: *const f32,
        mode: u32,
        sample_rate: f32,
        len: u32,
    );
}

// ---------------------------------------------------------------------------
// Enums and state types
// ---------------------------------------------------------------------------

/// Oscillator waveform shape.
///
/// The explicit discriminants are sent to the host as `u32`, so keep them stable.
#[repr(u32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Waveform {
    Sine = 0,
    Triangle = 1,
    Saw = 2,
    Square = 3,
}

/// Biquad filter type.
///
/// The explicit discriminants are sent to the host as `u32`, so keep them stable.
#[repr(u32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FilterType {
    Lowpass = 0,
    Highpass = 1,
    Bandpass = 2,
    Notch = 3,
    /// Parametric EQ band — boost/cut at cutoff frequency.
    Peak = 4,
    /// Boost/cut below cutoff frequency.
    LowShelf = 5,
    /// Boost/cut above cutoff frequency.
    HighShelf = 6,
    /// Phase shift without changing amplitude — used in phasers.
    Allpass = 7,
}

/// Internal state for a biquad filter (two-sample history).
///
/// `x1`/`x2` are the previous two input samples and `y1`/`y2` the previous two outputs.
/// `#[repr(C)]` matters: the host import receives this struct as a pointer to `f32`, so
/// the fields are laid out as four consecutive floats in declaration order.
#[repr(C)]
#[derive(Debug, Clone, Default, PartialEq)]
pub struct BiquadState {
    pub x1: f32,
    pub x2: f32,
    pub y1: f32,
    pub y2: f32,
}

impl BiquadState {
    /// Creates a cleared history (all zeros). Equivalent to `Default::default()`, but
    /// usable in `const`/`static` contexts.
    pub const fn new() -> Self {
        Self { x1: 0.0, x2: 0.0, y1: 0.0, y2: 0.0 }
    }
}

/// Envelope follower detection mode.
///
/// The explicit discriminants are sent to the host as `u32`, so keep them stable.
#[repr(u32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum EnvelopeMode {
    /// Track instantaneous peaks.
    Peak = 0,
    /// Track root-mean-square level.
    Rms = 1,
}

/// Internal state for an envelope follower.
///
/// `current` is the smoothed detector level: the rectified peak in `EnvelopeMode::Peak`,
/// or the mean square (before the square root) in `EnvelopeMode::Rms`. `#[repr(C)]`
/// because the host import receives this struct as a pointer to `f32`.
#[repr(C)]
#[derive(Debug, Clone, Default, PartialEq)]
pub struct EnvelopeState {
    pub current: f32,
}

impl EnvelopeState {
    /// Creates a follower at rest (`current == 0.0`). Equivalent to `Default::default()`,
    /// but usable in `const`/`static` contexts.
    pub const fn new() -> Self {
        Self { current: 0.0 }
    }
}

// ===========================================================================
// Math ops
// ===========================================================================

155/// Copy `src` into `dst`.
156pub fn vec_copy(dst: &mut [f32], src: &[f32]) {
157    let len = dst.len().min(src.len());
158    if use_host_ops() {
159        unsafe { sesh_vec_copy_host(dst.as_mut_ptr(), src.as_ptr(), len as u32) }
160    } else {
161        dst[..len].copy_from_slice(&src[..len]);
162    }
163}
164
165/// Fill `dst` with a constant value.
166pub fn vec_fill(dst: &mut [f32], value: f32) {
167    let len = dst.len();
168    if use_host_ops() {
169        unsafe { sesh_vec_fill_host(dst.as_mut_ptr(), value, len as u32) }
170    } else {
171        for s in dst.iter_mut() {
172            *s = value;
173        }
174    }
175}
176
177/// Element-wise addition: `dst[i] = a[i] + b[i]`.
178pub fn vec_add(dst: &mut [f32], a: &[f32], b: &[f32]) {
179    let len = dst.len().min(a.len()).min(b.len());
180    if use_host_ops() {
181        unsafe { sesh_vec_add_host(dst.as_mut_ptr(), a.as_ptr(), b.as_ptr(), len as u32) }
182    } else {
183        for i in 0..len {
184            dst[i] = a[i] + b[i];
185        }
186    }
187}
188
189/// In-place element-wise addition: `dst[i] += src[i]`.
190///
191/// Convenience wrapper — calls the same underlying op as `vec_add` with dst aliased as input.
192pub fn vec_add_assign(dst: &mut [f32], src: &[f32]) {
193    let len = dst.len().min(src.len());
194    if use_host_ops() {
195        unsafe { sesh_vec_add_host(dst.as_mut_ptr(), dst.as_ptr(), src.as_ptr(), len as u32) }
196    } else {
197        for i in 0..len {
198            dst[i] += src[i];
199        }
200    }
201}
202
203/// Add scalar to every element: `dst[i] += value`.
204pub fn vec_add_scalar(dst: &mut [f32], value: f32) {
205    let len = dst.len();
206    if use_host_ops() {
207        unsafe { sesh_vec_add_scalar_host(dst.as_mut_ptr(), value, len as u32) }
208    } else {
209        for s in dst.iter_mut() {
210            *s += value;
211        }
212    }
213}
214
215/// Element-wise multiplication: `dst[i] = a[i] * b[i]`.
216pub fn vec_mul(dst: &mut [f32], a: &[f32], b: &[f32]) {
217    let len = dst.len().min(a.len()).min(b.len());
218    if use_host_ops() {
219        unsafe { sesh_vec_mul_host(dst.as_mut_ptr(), a.as_ptr(), b.as_ptr(), len as u32) }
220    } else {
221        for i in 0..len {
222            dst[i] = a[i] * b[i];
223        }
224    }
225}
226
227/// In-place element-wise multiplication: `dst[i] *= src[i]`.
228///
229/// Convenience wrapper — calls the same underlying op as `vec_mul` with dst aliased as input.
230pub fn vec_mul_assign(dst: &mut [f32], src: &[f32]) {
231    let len = dst.len().min(src.len());
232    if use_host_ops() {
233        unsafe { sesh_vec_mul_host(dst.as_mut_ptr(), dst.as_ptr(), src.as_ptr(), len as u32) }
234    } else {
235        for i in 0..len {
236            dst[i] *= src[i];
237        }
238    }
239}
240
241/// Multiply every element by scalar: `dst[i] *= value`.
242pub fn vec_mul_scalar(dst: &mut [f32], value: f32) {
243    let len = dst.len();
244    if use_host_ops() {
245        unsafe { sesh_vec_mul_scalar_host(dst.as_mut_ptr(), value, len as u32) }
246    } else {
247        for s in dst.iter_mut() {
248            *s *= value;
249        }
250    }
251}
252
253/// Multiply and accumulate: `dst[i] += src[i] * gain`.
254pub fn vec_mul_add(dst: &mut [f32], src: &[f32], gain: f32) {
255    let len = dst.len().min(src.len());
256    if use_host_ops() {
257        unsafe { sesh_vec_mul_add_host(dst.as_mut_ptr(), src.as_ptr(), gain, len as u32) }
258    } else {
259        for i in 0..len {
260            dst[i] += src[i] * gain;
261        }
262    }
263}
264
265/// Clamp: `dst[i] = clamp(src[i], min, max)`.
266pub fn vec_clamp(dst: &mut [f32], src: &[f32], min: f32, max: f32) {
267    let len = dst.len().min(src.len());
268    if use_host_ops() {
269        unsafe { sesh_vec_clamp_host(dst.as_mut_ptr(), src.as_ptr(), min, max, len as u32) }
270    } else {
271        for i in 0..len {
272            dst[i] = src[i].clamp(min, max);
273        }
274    }
275}
276
277/// In-place clamp: `dst[i] = clamp(dst[i], min, max)`.
278pub fn vec_clamp_assign(dst: &mut [f32], min: f32, max: f32) {
279    let len = dst.len();
280    if use_host_ops() {
281        unsafe { sesh_vec_clamp_host(dst.as_mut_ptr(), dst.as_ptr(), min, max, len as u32) }
282    } else {
283        for i in 0..len {
284            dst[i] = dst[i].clamp(min, max);
285        }
286    }
287}
288
// ===========================================================================
// Circular buffer ops
// ===========================================================================

293/// Write `src` into circular buffer `buf` starting at `*pos`, wrapping at `buf.len()`.
294/// Advances `*pos` by `src.len()`.
295pub fn vec_ring_write(buf: &mut [f32], pos: &mut usize, src: &[f32]) {
296    let buf_len = buf.len();
297    let frames = src.len();
298    if use_host_ops() {
299        let mut pos32 = *pos as u32;
300        unsafe {
301            sesh_vec_ring_write_host(
302                buf.as_mut_ptr(), buf_len as u32, &mut pos32, src.as_ptr(), frames as u32,
303            );
304        }
305        *pos = pos32 as usize;
306    } else {
307        for i in 0..frames {
308            buf[(*pos + i) % buf_len] = src[i];
309        }
310        *pos = (*pos + frames) % buf_len;
311    }
312}
313
314/// Read `dst.len()` contiguous samples from circular buffer at `pos - offset`, wrapping.
315pub fn vec_ring_read(buf: &[f32], pos: usize, dst: &mut [f32], offset: usize) {
316    let buf_len = buf.len();
317    let frames = dst.len();
318    if use_host_ops() {
319        unsafe {
320            sesh_vec_ring_read_host(
321                buf.as_ptr(), buf_len as u32, pos as u32,
322                dst.as_mut_ptr(), offset as u32, frames as u32,
323            );
324        }
325    } else {
326        let start = (pos + buf_len - offset) % buf_len;
327        for i in 0..frames {
328            dst[i] = buf[(start + i) % buf_len];
329        }
330    }
331}
332
// ===========================================================================
// Delay op
// ===========================================================================

337/// Per-sample modulated delay read with linear interpolation.
338///
339/// For each sample `i`, reads from circular buffer at a fractional offset
340/// `time[i]` samples behind where the write head was at sample `i`.
341/// `pos` should be the write head position *after* the most recent `vec_ring_write`.
342pub fn vec_delay_read(buf: &[f32], pos: usize, dst: &mut [f32], time: &[f32]) {
343    let buf_len = buf.len();
344    let frames = dst.len().min(time.len());
345    if use_host_ops() {
346        unsafe {
347            sesh_vec_delay_read_host(
348                buf.as_ptr(), buf_len as u32, pos as u32,
349                dst.as_mut_ptr(), time.as_ptr(), frames as u32,
350            );
351        }
352    } else {
353        for i in 0..frames {
354            // The write head was at (pos - frames + i) when sample i was written.
355            let write_pos_at_i = (pos + buf_len - frames + i) % buf_len;
356
357            let delay_int = time[i] as usize;
358            let delay_frac = time[i] - delay_int as f32;
359
360            let idx1 = (write_pos_at_i + buf_len - delay_int) % buf_len;
361            let idx2 = (idx1 + buf_len - 1) % buf_len;
362
363            dst[i] = buf[idx1] + delay_frac * (buf[idx2] - buf[idx1]);
364        }
365    }
366}
367
// ===========================================================================
// Oscillator
// ===========================================================================

372/// Fill `dst` with oscillator output. Advances `*phase`. `freq` is in Hz.
373pub fn vec_osc(
374    phase: &mut f32,
375    dst: &mut [f32],
376    freq: f32,
377    waveform: Waveform,
378    sample_rate: f32,
379) {
380    let frames = dst.len();
381    if use_host_ops() {
382        unsafe {
383            sesh_vec_osc_host(
384                phase as *mut f32, dst.as_mut_ptr(),
385                freq, waveform as u32, sample_rate, frames as u32,
386            );
387        }
388    } else {
389        let phase_inc = freq / sample_rate;
390        for i in 0..frames {
391            dst[i] = match waveform {
392                Waveform::Sine => (*phase * std::f32::consts::TAU).sin(),
393                Waveform::Triangle => 4.0 * (*phase - (*phase + 0.5).floor()).abs() - 1.0,
394                Waveform::Saw => 2.0 * (*phase - (*phase + 0.5).floor()),
395                Waveform::Square => if *phase % 1.0 < 0.5 { 1.0 } else { -1.0 },
396            };
397            *phase += phase_inc;
398            if *phase >= 1.0 {
399                *phase -= 1.0;
400            }
401        }
402    }
403}
404
// ===========================================================================
// Filter
// ===========================================================================

409/// Biquad filter with per-sample modulation of cutoff, Q, and gain.
410///
411/// `cutoff` is in Hz, `q` is the Q factor, `gain` is in dB (used for Peak/Shelf types).
412/// Coefficients are recomputed each sample from the parameter buffers.
413pub fn vec_biquad(
414    state: &mut BiquadState,
415    dst: &mut [f32],
416    src: &[f32],
417    cutoff: &[f32],
418    q: &[f32],
419    gain: &[f32],
420    filter_type: FilterType,
421    sample_rate: f32,
422) {
423    let frames = dst.len().min(src.len()).min(cutoff.len()).min(q.len()).min(gain.len());
424    if use_host_ops() {
425        unsafe {
426            sesh_vec_biquad_host(
427                state as *mut BiquadState as *mut f32,
428                dst.as_mut_ptr(), src.as_ptr(),
429                cutoff.as_ptr(), q.as_ptr(), gain.as_ptr(),
430                filter_type as u32, sample_rate, frames as u32,
431            );
432        }
433    } else {
434        for i in 0..frames {
435            let w0 = std::f32::consts::TAU * cutoff[i] / sample_rate;
436            let cos_w0 = w0.cos();
437            let sin_w0 = w0.sin();
438            let alpha = sin_w0 / (2.0 * q[i]);
439            let a_db = gain[i];
440            let a_lin = 10.0f32.powf(a_db / 40.0);
441
442            let (b0, b1, b2, a0, a1, a2) = match filter_type {
443                FilterType::Lowpass => {
444                    let b1 = 1.0 - cos_w0;
445                    let b0 = b1 / 2.0;
446                    (b0, b1, b0, 1.0 + alpha, -2.0 * cos_w0, 1.0 - alpha)
447                }
448                FilterType::Highpass => {
449                    let b1 = -(1.0 + cos_w0);
450                    let b0 = (1.0 + cos_w0) / 2.0;
451                    (b0, b1, b0, 1.0 + alpha, -2.0 * cos_w0, 1.0 - alpha)
452                }
453                FilterType::Bandpass => {
454                    (alpha, 0.0, -alpha, 1.0 + alpha, -2.0 * cos_w0, 1.0 - alpha)
455                }
456                FilterType::Notch => {
457                    (1.0, -2.0 * cos_w0, 1.0, 1.0 + alpha, -2.0 * cos_w0, 1.0 - alpha)
458                }
459                FilterType::Peak => {
460                    (
461                        1.0 + alpha * a_lin,
462                        -2.0 * cos_w0,
463                        1.0 - alpha * a_lin,
464                        1.0 + alpha / a_lin,
465                        -2.0 * cos_w0,
466                        1.0 - alpha / a_lin,
467                    )
468                }
469                FilterType::LowShelf => {
470                    let two_sqrt_a_alpha = 2.0 * a_lin.sqrt() * alpha;
471                    (
472                        a_lin * ((a_lin + 1.0) - (a_lin - 1.0) * cos_w0 + two_sqrt_a_alpha),
473                        2.0 * a_lin * ((a_lin - 1.0) - (a_lin + 1.0) * cos_w0),
474                        a_lin * ((a_lin + 1.0) - (a_lin - 1.0) * cos_w0 - two_sqrt_a_alpha),
475                        (a_lin + 1.0) + (a_lin - 1.0) * cos_w0 + two_sqrt_a_alpha,
476                        -2.0 * ((a_lin - 1.0) + (a_lin + 1.0) * cos_w0),
477                        (a_lin + 1.0) + (a_lin - 1.0) * cos_w0 - two_sqrt_a_alpha,
478                    )
479                }
480                FilterType::HighShelf => {
481                    let two_sqrt_a_alpha = 2.0 * a_lin.sqrt() * alpha;
482                    (
483                        a_lin * ((a_lin + 1.0) + (a_lin - 1.0) * cos_w0 + two_sqrt_a_alpha),
484                        -2.0 * a_lin * ((a_lin - 1.0) + (a_lin + 1.0) * cos_w0),
485                        a_lin * ((a_lin + 1.0) + (a_lin - 1.0) * cos_w0 - two_sqrt_a_alpha),
486                        (a_lin + 1.0) - (a_lin - 1.0) * cos_w0 + two_sqrt_a_alpha,
487                        2.0 * ((a_lin - 1.0) - (a_lin + 1.0) * cos_w0),
488                        (a_lin + 1.0) - (a_lin - 1.0) * cos_w0 - two_sqrt_a_alpha,
489                    )
490                }
491                FilterType::Allpass => {
492                    (1.0 - alpha, -2.0 * cos_w0, 1.0 + alpha, 1.0 + alpha, -2.0 * cos_w0, 1.0 - alpha)
493                }
494            };
495
496            // Normalize coefficients.
497            let b0 = b0 / a0;
498            let b1 = b1 / a0;
499            let b2 = b2 / a0;
500            let a1 = a1 / a0;
501            let a2 = a2 / a0;
502
503            let x0 = src[i];
504            let y0 = b0 * x0 + b1 * state.x1 + b2 * state.x2
505                - a1 * state.y1 - a2 * state.y2;
506
507            state.x2 = state.x1;
508            state.x1 = x0;
509            state.y2 = state.y1;
510            state.y1 = y0;
511
512            dst[i] = y0;
513        }
514    }
515}
516
// ===========================================================================
// Dynamics
// ===========================================================================

521/// Envelope follower. Tracks amplitude of `src` with attack/release smoothing.
522///
523/// `attack` and `release` are in seconds (per-sample buffers for modulation).
524/// Output in `dst` is the smoothed envelope value.
525pub fn vec_envelope(
526    state: &mut EnvelopeState,
527    dst: &mut [f32],
528    src: &[f32],
529    attack: &[f32],
530    release: &[f32],
531    mode: EnvelopeMode,
532    sample_rate: f32,
533) {
534    let frames = dst.len().min(src.len()).min(attack.len()).min(release.len());
535    if use_host_ops() {
536        unsafe {
537            sesh_vec_envelope_host(
538                state as *mut EnvelopeState as *mut f32,
539                dst.as_mut_ptr(), src.as_ptr(),
540                attack.as_ptr(), release.as_ptr(),
541                mode as u32, sample_rate, frames as u32,
542            );
543        }
544    } else {
545        for i in 0..frames {
546            let input_level = match mode {
547                EnvelopeMode::Peak => src[i].abs(),
548                EnvelopeMode::Rms => src[i] * src[i],
549            };
550
551            let att_coeff = (-1.0 / (attack[i] * sample_rate)).exp();
552            let rel_coeff = (-1.0 / (release[i] * sample_rate)).exp();
553
554            let coeff = if input_level > state.current { att_coeff } else { rel_coeff };
555            state.current = coeff * state.current + (1.0 - coeff) * input_level;
556
557            dst[i] = match mode {
558                EnvelopeMode::Peak => state.current,
559                EnvelopeMode::Rms => state.current.sqrt(),
560            };
561        }
562    }
563}
564
// ===========================================================================
// Waveshaping
// ===========================================================================

569/// Soft saturation: `dst[i] = tanh(src[i] * drive[i])`.
570pub fn vec_tanh(dst: &mut [f32], src: &[f32], drive: &[f32]) {
571    let len = dst.len().min(src.len()).min(drive.len());
572    if use_host_ops() {
573        unsafe { sesh_vec_tanh_host(dst.as_mut_ptr(), src.as_ptr(), drive.as_ptr(), len as u32) }
574    } else {
575        for i in 0..len {
576            dst[i] = (src[i] * drive[i]).tanh();
577        }
578    }
579}
580
581/// Hard clipping: clamp `src` to `±threshold[i]`.
582pub fn vec_hard_clip(dst: &mut [f32], src: &[f32], threshold: &[f32]) {
583    let len = dst.len().min(src.len()).min(threshold.len());
584    if use_host_ops() {
585        unsafe {
586            sesh_vec_hard_clip_host(dst.as_mut_ptr(), src.as_ptr(), threshold.as_ptr(), len as u32)
587        }
588    } else {
589        for i in 0..len {
590            dst[i] = src[i].clamp(-threshold[i], threshold[i]);
591        }
592    }
593}
594
// ===========================================================================
// Unary / additional math ops
// ===========================================================================

599/// Absolute value: `dst[i] = |src[i]|`.
600pub fn vec_abs(dst: &mut [f32], src: &[f32]) {
601    let len = dst.len().min(src.len());
602    if use_host_ops() {
603        unsafe { sesh_vec_abs_host(dst.as_mut_ptr(), src.as_ptr(), len as u32) }
604    } else {
605        for i in 0..len {
606            dst[i] = src[i].abs();
607        }
608    }
609}
610
611/// Negate: `dst[i] = -src[i]`. Phase inversion.
612pub fn vec_neg(dst: &mut [f32], src: &[f32]) {
613    let len = dst.len().min(src.len());
614    if use_host_ops() {
615        unsafe { sesh_vec_neg_host(dst.as_mut_ptr(), src.as_ptr(), len as u32) }
616    } else {
617        for i in 0..len {
618            dst[i] = -src[i];
619        }
620    }
621}
622
623/// Square root: `dst[i] = sqrt(src[i])`.
624pub fn vec_sqrt(dst: &mut [f32], src: &[f32]) {
625    let len = dst.len().min(src.len());
626    if use_host_ops() {
627        unsafe { sesh_vec_sqrt_host(dst.as_mut_ptr(), src.as_ptr(), len as u32) }
628    } else {
629        for i in 0..len {
630            dst[i] = src[i].sqrt();
631        }
632    }
633}
634
635/// Reciprocal: `dst[i] = 1.0 / src[i]`.
636pub fn vec_recip(dst: &mut [f32], src: &[f32]) {
637    let len = dst.len().min(src.len());
638    if use_host_ops() {
639        unsafe { sesh_vec_recip_host(dst.as_mut_ptr(), src.as_ptr(), len as u32) }
640    } else {
641        for i in 0..len {
642            dst[i] = 1.0 / src[i];
643        }
644    }
645}
646
647/// Element-wise division: `dst[i] = a[i] / b[i]`.
648pub fn vec_div(dst: &mut [f32], a: &[f32], b: &[f32]) {
649    let len = dst.len().min(a.len()).min(b.len());
650    if use_host_ops() {
651        unsafe { sesh_vec_div_host(dst.as_mut_ptr(), a.as_ptr(), b.as_ptr(), len as u32) }
652    } else {
653        for i in 0..len {
654            dst[i] = a[i] / b[i];
655        }
656    }
657}
658
659/// Element-wise power: `dst[i] = src[i].powf(exp[i])`.
660pub fn vec_pow(dst: &mut [f32], src: &[f32], exp: &[f32]) {
661    let len = dst.len().min(src.len()).min(exp.len());
662    if use_host_ops() {
663        unsafe { sesh_vec_pow_host(dst.as_mut_ptr(), src.as_ptr(), exp.as_ptr(), len as u32) }
664    } else {
665        for i in 0..len {
666            dst[i] = src[i].powf(exp[i]);
667        }
668    }
669}
670
// ===========================================================================
// In-place (_assign) variants
// ===========================================================================
//
// These are Rust convenience wrappers around the same host imports. They pass `dst` as
// both the output and the input pointer, so they rely on the host ops supporting
// in-place operation (output == input). The wrappers exist only because the borrow
// checker will not let one slice be borrowed as `&mut` and `&` at the same time; they
// add no C/host API surface.

679/// In-place soft saturation: `dst[i] = tanh(dst[i] * drive[i])`.
680pub fn vec_tanh_assign(dst: &mut [f32], drive: &[f32]) {
681    let len = dst.len().min(drive.len());
682    if use_host_ops() {
683        unsafe { sesh_vec_tanh_host(dst.as_mut_ptr(), dst.as_ptr(), drive.as_ptr(), len as u32) }
684    } else {
685        for i in 0..len {
686            dst[i] = (dst[i] * drive[i]).tanh();
687        }
688    }
689}
690
691/// In-place hard clipping: clamp `dst` to `±threshold[i]`.
692pub fn vec_hard_clip_assign(dst: &mut [f32], threshold: &[f32]) {
693    let len = dst.len().min(threshold.len());
694    if use_host_ops() {
695        unsafe { sesh_vec_hard_clip_host(dst.as_mut_ptr(), dst.as_ptr(), threshold.as_ptr(), len as u32) }
696    } else {
697        for i in 0..len {
698            dst[i] = dst[i].clamp(-threshold[i], threshold[i]);
699        }
700    }
701}
702
703/// In-place absolute value: `dst[i] = |dst[i]|`.
704pub fn vec_abs_assign(dst: &mut [f32]) {
705    let len = dst.len();
706    if use_host_ops() {
707        unsafe { sesh_vec_abs_host(dst.as_mut_ptr(), dst.as_ptr(), len as u32) }
708    } else {
709        for i in 0..len {
710            dst[i] = dst[i].abs();
711        }
712    }
713}
714
715/// In-place negate: `dst[i] = -dst[i]`.
716pub fn vec_neg_assign(dst: &mut [f32]) {
717    let len = dst.len();
718    if use_host_ops() {
719        unsafe { sesh_vec_neg_host(dst.as_mut_ptr(), dst.as_ptr(), len as u32) }
720    } else {
721        for i in 0..len {
722            dst[i] = -dst[i];
723        }
724    }
725}
726
727/// In-place square root: `dst[i] = sqrt(dst[i])`.
728pub fn vec_sqrt_assign(dst: &mut [f32]) {
729    let len = dst.len();
730    if use_host_ops() {
731        unsafe { sesh_vec_sqrt_host(dst.as_mut_ptr(), dst.as_ptr(), len as u32) }
732    } else {
733        for i in 0..len {
734            dst[i] = dst[i].sqrt();
735        }
736    }
737}
738
739/// In-place reciprocal: `dst[i] = 1.0 / dst[i]`.
740pub fn vec_recip_assign(dst: &mut [f32]) {
741    let len = dst.len();
742    if use_host_ops() {
743        unsafe { sesh_vec_recip_host(dst.as_mut_ptr(), dst.as_ptr(), len as u32) }
744    } else {
745        for i in 0..len {
746            dst[i] = 1.0 / dst[i];
747        }
748    }
749}
750
751/// In-place element-wise division: `dst[i] /= src[i]`.
752pub fn vec_div_assign(dst: &mut [f32], src: &[f32]) {
753    let len = dst.len().min(src.len());
754    if use_host_ops() {
755        unsafe { sesh_vec_div_host(dst.as_mut_ptr(), dst.as_ptr(), src.as_ptr(), len as u32) }
756    } else {
757        for i in 0..len {
758            dst[i] /= src[i];
759        }
760    }
761}
762
763/// In-place element-wise power: `dst[i] = dst[i].powf(exp[i])`.
764pub fn vec_pow_assign(dst: &mut [f32], exp: &[f32]) {
765    let len = dst.len().min(exp.len());
766    if use_host_ops() {
767        unsafe { sesh_vec_pow_host(dst.as_mut_ptr(), dst.as_ptr(), exp.as_ptr(), len as u32) }
768    } else {
769        for i in 0..len {
770            dst[i] = dst[i].powf(exp[i]);
771        }
772    }
773}