xinput_mapper/
lib.rs

1//! Functional helpers to convert a DInput->XInput YAML mapping into an XInput-like state.
2//! - Pure functions; no hidden global state
3//! - YAML schema matches the one produced by `dinput_mapper`, with backward-compatible extensions
4//! - Comments are in English by request
5
6mod utils;
7pub use utils::*;
8
9use serde::Deserialize;
10use thiserror::Error;
11use std::collections::BTreeMap;
12
13/* =============================== YAML Types =============================== */
14
/// Root of a DInput->XInput mapping file.
///
/// `device` and `axes` are required; every other section may be omitted.
#[derive(Debug, Deserialize)]
pub struct MappingYaml {
    /// Identification of the source device.
    pub device: Device,
    /// Stick axes; each individual axis is optional.
    pub axes: Axes,
    /// Trigger descriptors; when absent both triggers read as 0.
    #[serde(default)]
    pub triggers: Option<Triggers>,
    /// HAT switch feeding the D-pad; when absent no D-pad flags come from a hat.
    #[serde(default)]
    pub dpad: Option<Dpad>,
    /// Button name -> bit index map; when absent no buttons are read from it.
    #[serde(default)]
    pub buttons: Option<Buttons>,
    /// Base offset (in flat bits) added to every button index; defaults to 0.
    /// Newer schema extension: older files simply omit it.
    #[serde(default)]
    pub buttons_base_bit_offset: usize,
}
29
/// Identification of the source device.
///
/// All fields are kept verbatim as strings; the mapping code in this module
/// does not interpret them.
#[derive(Debug, Deserialize)]
pub struct Device {
    /// Human-readable device name.
    pub name: String,
    /// Vendor ID as written in the YAML (the samples use hex strings like "0x1234").
    pub vid: String,
    /// Product ID as written in the YAML (same format as `vid`).
    pub pid: String,
}
36
/// Generic axis descriptor with optional bit-level addressing.
///
/// Backward compatible: if `bit_offset` is absent, `report_offset * 8` is used.
/// `size_bits` must be within 1..=24 to be readable (typical: 4/8/10/12/16);
/// other widths make the field unreadable.
/// A negative `logical_min` marks the field as signed (two's complement), so the
/// raw bits are sign-extended before scaling.
#[derive(Debug, Deserialize, Clone, Copy)]
pub struct Axis {
    pub report_offset: usize, // byte index; kept for backward compatibility, used only without `bit_offset`
    pub size_bits: u8, // field width in bits; typical: 8 or 16, 4/10/12 also work
    pub logical_min: i32, // raw value mapped to the low end of the output range
    pub logical_max: i32, // raw value mapped to the high end of the output range
    pub inverted: bool, // for stick axes, the scaled value is negated
    #[serde(default)]
    pub bit_offset: Option<usize>, // flat LSB-first bit offset; overrides report_offset when set
}
50
/// Stick axes. Any axis left out of the YAML reads as 0 (centered).
#[derive(Debug, Deserialize)]
pub struct Axes {
    /// Left stick, horizontal.
    #[serde(default)]
    pub left_x: Option<Axis>,
    /// Left stick, vertical.
    #[serde(default)]
    pub left_y: Option<Axis>,
    /// Right stick, horizontal.
    #[serde(default)]
    pub right_x: Option<Axis>,
    /// Right stick, vertical.
    #[serde(default)]
    pub right_y: Option<Axis>,
}
62
/// Single trigger descriptor.
///
/// The raw field is scaled from `[logical_min, logical_max]` onto `0..=255`.
/// Addressing follows the same rules as [`Axis`]: `bit_offset` wins, otherwise
/// `report_offset * 8` is used.
#[derive(Debug, Deserialize, Clone, Copy)]
pub struct TriggerDef {
    pub report_offset: usize, // byte index; used only when `bit_offset` is absent
    pub size_bits: u8, // field width in bits (readable range 1..=24)
    pub logical_min: i32, // raw value meaning "released"
    pub logical_max: i32, // raw value meaning "fully pressed"
    #[serde(default)]
    pub bit_offset: Option<usize>, // flat LSB-first bit offset; overrides report_offset when set
}
72
/// Trigger section: either separate left/right descriptors or one combined axis.
#[derive(Debug, Deserialize)]
pub struct Triggers {
    /// Left trigger; reads as 0 when absent.
    #[serde(default)]
    pub left_trigger: Option<TriggerDef>,
    /// Right trigger; reads as 0 when absent.
    #[serde(default)]
    pub right_trigger: Option<TriggerDef>,
    /// One axis shared by both triggers. Only consulted when both
    /// `left_trigger` and `right_trigger` are absent; it is then split around
    /// the middle of its logical range into the two trigger values.
    #[serde(default)]
    pub combined_axis: Option<Axis>,
}
83
/// D-Pad (HAT) descriptor.
///
/// Backward compatible byte model plus bit/nibble addressing.
/// Priority: if `bit_offset` is set, the value is read via bits (`size_bits`
/// defaults to 4) and `nibble` is ignored; otherwise the byte at
/// `report_offset` is read and optionally narrowed by `nibble`.
#[derive(Debug, Deserialize)]
pub struct Dpad {
    #[serde(rename = "type")]
    pub dtype: String, // YAML key `type`; only "hat" is decoded, any other value is ignored
    pub report_offset: usize, // byte index (legacy addressing)
    pub logical_min: i32, // lowest valid raw value, usually 0
    pub logical_max: i32, // highest valid raw value, usually 7
    pub neutral: u8, // raw value meaning "no direction pressed", e.g. 8 or 15
    #[serde(default)]
    pub bit_offset: Option<usize>, // flat bit offset for the HAT; takes priority over report_offset
    #[serde(default)]
    pub size_bits: Option<u8>, // HAT width in bits when `bit_offset` is set; defaults to 4
    #[serde(default)]
    pub nibble: Option<String>, // "low" | "high" for byte-packed 4-bit hats; other strings keep the whole byte
}
102
/// Button map: XInput button name -> flat LSB-first bit index, relative to
/// `MappingYaml::buttons_base_bit_offset`. Names the mapper does not recognize
/// are ignored.
#[derive(Debug, Deserialize, Default)]
pub struct Buttons(pub BTreeMap<String, usize>);
105
106/* ============================== XInput Types ============================== */
107
/// XInput button bit flags; the values equal the Windows `XINPUT_GAMEPAD_*` constants.
// The module is named like a type so call sites read `XButtons::A`; that is why
// the snake_case lint is silenced here.
#[allow(non_snake_case)]
pub mod XButtons {
    pub const DPAD_UP: u16 = 1 << 0; // 0x0001
    pub const DPAD_DOWN: u16 = 1 << 1; // 0x0002
    pub const DPAD_LEFT: u16 = 1 << 2; // 0x0004
    pub const DPAD_RIGHT: u16 = 1 << 3; // 0x0008
    pub const START: u16 = 1 << 4; // 0x0010
    pub const BACK: u16 = 1 << 5; // 0x0020
    pub const LEFT_THUMB: u16 = 1 << 6; // 0x0040
    pub const RIGHT_THUMB: u16 = 1 << 7; // 0x0080
    pub const LEFT_SHOULDER: u16 = 1 << 8; // 0x0100
    pub const RIGHT_SHOULDER: u16 = 1 << 9; // 0x0200
    // Bits 10 and 11 (0x0400, 0x0800) are left unassigned: reserved in older headers.
    pub const A: u16 = 1 << 12; // 0x1000
    pub const B: u16 = 1 << 13; // 0x2000
    pub const X: u16 = 1 << 14; // 0x4000
    pub const Y: u16 = 1 << 15; // 0x8000
}
127
/// Minimal XInput-like state we produce.
///
/// - Sticks are i16 in the standard range [-32768, 32767].
/// - Triggers are u8 in [0, 255].
/// - Buttons are packed into a u16 mask using the `XButtons` flags.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct XInputState {
    pub buttons: u16, // OR of the `XButtons` flags of every pressed button
    pub left_trigger: u8, // 0 = released, 255 = fully pressed
    pub right_trigger: u8, // 0 = released, 255 = fully pressed
    pub thumb_lx: i16, // left stick, horizontal
    pub thumb_ly: i16, // left stick, vertical
    pub thumb_rx: i16, // right stick, horizontal
    pub thumb_ry: i16, // right stick, vertical
}
142
143/* ================================ Errors ================================= */
144
145/// Distinguish IO vs YAML parse errors for file loading.
146#[derive(Debug, Error)]
147pub enum LoadError {
148    #[error("io error: {0}")] Io(#[from] std::io::Error),
149    #[error("yaml error: {0}")] Yaml(#[from] serde_yaml::Error),
150}
151
152/* =========================== Public API Functions ========================== */
153
154/// Parse YAML string into MappingYaml.
155pub fn parse_mapping_yaml_str(s: &str) -> Result<MappingYaml, serde_yaml::Error> {
156    serde_yaml::from_str::<MappingYaml>(s)
157}
158
159/// Parse YAML file content into MappingYaml.
160/// Returns `LoadError` to separate IO and YAML failures.
161pub fn parse_mapping_yaml_file(path: &std::path::Path) -> Result<MappingYaml, LoadError> {
162    let txt = std::fs::read_to_string(path)?; // io::Error -> LoadError::Io
163    let mapping = parse_mapping_yaml_str(&txt)?; // serde_yaml::Error -> LoadError::Yaml
164    Ok(mapping)
165}
166
167/// Convert a raw HID input report to an XInput-like state using the mapping.
168/// This is a *pure* function: output depends only on (mapping, report).
169pub fn map_report_to_xinput(m: &MappingYaml, report: &[u8]) -> XInputState {
170    // Axes
171    let lx = m.axes.left_x.map(|a| read_axis_to_i16(report, a)).unwrap_or(0);
172    let ly = m.axes.left_y.map(|a| read_axis_to_i16(report, a)).unwrap_or(0);
173    let rx = m.axes.right_x.map(|a| read_axis_to_i16(report, a)).unwrap_or(0);
174    let ry = m.axes.right_y.map(|a| read_axis_to_i16(report, a)).unwrap_or(0);
175
176    // Triggers (support combined axis split if L/R absent)
177    let (lt, rt) = match &m.triggers {
178        Some(tr) if
179            tr.left_trigger.is_none() &&
180            tr.right_trigger.is_none() &&
181            tr.combined_axis.is_some()
182        => {
183            let ax = tr.combined_axis.unwrap();
184            let raw = read_axis_raw_i32(report, ax).unwrap_or(0);
185            // scale to [-255, 255]
186            let v = scale_i32(raw, ax.logical_min, ax.logical_max, -255, 255);
187            if v >= 0 {
188                (v as u8, 0)
189            } else {
190                (0, -v as u8)
191            }
192        }
193        Some(tr) =>
194            (
195                tr.left_trigger.map(|t| read_trigger_to_u8(report, t)).unwrap_or(0),
196                tr.right_trigger.map(|t| read_trigger_to_u8(report, t)).unwrap_or(0),
197            ),
198        None => (0, 0),
199    };
200
201    // Buttons
202    let mut mask: u16 = 0;
203    if let Some(btns) = &m.buttons {
204        for (name, rel_bit) in btns.0.iter() {
205            let flat = m.buttons_base_bit_offset.saturating_add(*rel_bit);
206            if is_bit_set_flat(report, flat) {
207                match name.as_str() {
208                    "a" => {
209                        mask |= XButtons::A;
210                    }
211                    "b" => {
212                        mask |= XButtons::B;
213                    }
214                    "x" => {
215                        mask |= XButtons::X;
216                    }
217                    "y" => {
218                        mask |= XButtons::Y;
219                    }
220                    "lb" | "left_shoulder" => {
221                        mask |= XButtons::LEFT_SHOULDER;
222                    }
223                    "rb" | "right_shoulder" => {
224                        mask |= XButtons::RIGHT_SHOULDER;
225                    }
226                    "back" => {
227                        mask |= XButtons::BACK;
228                    }
229                    "start" => {
230                        mask |= XButtons::START;
231                    }
232                    "lt_click" | "ls" | "left_thumb" => {
233                        mask |= XButtons::LEFT_THUMB;
234                    }
235                    "rt_click" | "rs" | "right_thumb" => {
236                        mask |= XButtons::RIGHT_THUMB;
237                    }
238                    "dpad_up" => {
239                        mask |= XButtons::DPAD_UP;
240                    }
241                    "dpad_down" => {
242                        mask |= XButtons::DPAD_DOWN;
243                    }
244                    "dpad_left" => {
245                        mask |= XButtons::DPAD_LEFT;
246                    }
247                    "dpad_right" => {
248                        mask |= XButtons::DPAD_RIGHT;
249                    }
250                    _ => {}
251                }
252            }
253        }
254    }
255
256    // DPAD (hat) — supports bit/nibble models
257    if let Some(h) = &m.dpad {
258        if h.dtype == "hat" {
259            let value = read_hat_value(report, h);
260            let (up, right, down, left) = decode_hat_general(
261                value,
262                h.logical_min,
263                h.logical_max,
264                h.neutral
265            );
266            if up {
267                mask |= XButtons::DPAD_UP;
268            }
269            if right {
270                mask |= XButtons::DPAD_RIGHT;
271            }
272            if down {
273                mask |= XButtons::DPAD_DOWN;
274            }
275            if left {
276                mask |= XButtons::DPAD_LEFT;
277            }
278        }
279    }
280
281    XInputState {
282        buttons: mask,
283        left_trigger: lt,
284        right_trigger: rt,
285        thumb_lx: lx,
286        thumb_ly: ly,
287        thumb_rx: rx,
288        thumb_ry: ry,
289    }
290}
291
292/* ============================== Helper Logic ============================== */
293
294/// Read and normalize an axis to i16 [-32768..32767] with signed/unsigned handling and inversion.
295fn read_axis_to_i16(report: &[u8], ax: Axis) -> i16 {
296    let raw = read_axis_raw_i32(report, ax).unwrap_or(0);
297    let mut v = scale_i32(raw, ax.logical_min, ax.logical_max, -32768, 32767);
298    if ax.inverted {
299        v = -v;
300    }
301    v.clamp(i16::MIN as i32, i16::MAX as i32) as i16
302}
303
304/// Read trigger (u8 [0..255]) with signed/unsigned aware scaling.
305fn read_trigger_to_u8(report: &[u8], t: TriggerDef) -> u8 {
306    let raw = read_value_i32(report, t.bit_offset, t.report_offset, t.size_bits);
307    scale_i32(raw, t.logical_min, t.logical_max, 0, 255) as u8
308}
309
310/// Read raw i32 for an Axis, choosing bit_offset if present and sign-extending if needed.
311fn read_axis_raw_i32(report: &[u8], ax: Axis) -> Option<i32> {
312    let bit_off = ax.bit_offset.unwrap_or(ax.report_offset.saturating_mul(8));
313    let v = read_bits(report, bit_off, ax.size_bits)?;
314    let signed = ax.logical_min < 0;
315    Some(if signed { sign_extend(v, ax.size_bits) } else { v as i32 })
316}
317
318/// Generic value reader returning i32, preferring bit_offset when provided.
319fn read_value_i32(
320    report: &[u8],
321    bit_offset: Option<usize>,
322    report_offset: usize,
323    size_bits: u8
324) -> i32 {
325    let bit_off = bit_offset.unwrap_or(report_offset.saturating_mul(8));
326    let v = read_bits(report, bit_off, size_bits).unwrap_or(0);
327    v as i32
328}
329
/// Linearly map `v` from `[src_min, src_max]` onto `[dst_min, dst_max]`.
///
/// `v` is first clamped into the source range, and the division truncates
/// toward zero. A degenerate source range (`src_max <= src_min`) yields
/// `min(dst_min, dst_max)`.
///
/// Intermediate math uses `i128`. Spans and products of full-width `i32` ranges
/// (e.g. `i32::MIN..=i32::MAX`) therefore cannot overflow, which they would in
/// `i32` (spans) or even `i64` (products).
fn scale_i32(v: i32, src_min: i32, src_max: i32, dst_min: i32, dst_max: i32) -> i32 {
    if src_max <= src_min {
        return dst_min.min(dst_max);
    }
    let offset = i128::from(v.clamp(src_min, src_max)) - i128::from(src_min);
    let src_span = i128::from(src_max) - i128::from(src_min);
    let dst_span = i128::from(dst_max) - i128::from(dst_min);
    // 0 <= offset <= src_span, so the result lies between dst_min and dst_max and fits in i32.
    (i128::from(dst_min) + offset * dst_span / src_span) as i32
}
340
/// Read a little-endian, LSB-first bit field of `size_bits` bits that starts at
/// the flat bit offset `bit_offset`.
///
/// Supports widths 1..=24 (typical: 4, 8, 10, 12, 16).
///
/// Returns `None` in two cases:
/// - the width is unsupported;
/// - any bit of the field lies past the end of `report`.
///
/// A truncated field is rejected rather than silently zero-padded, because a
/// zero-padded partial read would look like a valid but wrong value.
fn read_bits(report: &[u8], bit_offset: usize, size_bits: u8) -> Option<u32> {
    if size_bits == 0 || size_bits > 24 {
        return None;
    }
    let last_bit = bit_offset.checked_add(usize::from(size_bits) - 1)?;
    let field = report.get(bit_offset / 8..=last_bit / 8)?;
    // At most 4 bytes: a 24-bit field starting at bit 7 of a byte ends in the 4th byte,
    // so assembling them little-endian always fits in a u32.
    let chunk = field
        .iter()
        .rev()
        .fold(0u32, |acc, &byte| (acc << 8) | u32::from(byte));
    let mask = (1u32 << size_bits) - 1;
    Some((chunk >> (bit_offset % 8)) & mask)
}
361
/// Test one bit addressed by a flat, LSB-first index (`byte * 8 + bit`).
///
/// Indices past the end of `report` read as "not set".
fn is_bit_set_flat(report: &[u8], flat_idx: usize) -> bool {
    let (byte_index, bit_index) = (flat_idx / 8, flat_idx % 8);
    matches!(report.get(byte_index), Some(&byte) if (byte >> bit_index) & 1 == 1)
}
371
/// Sign-extend the low `width` bits of `v` (two's complement) into an `i32`.
///
/// Bits of `v` above `width` are ignored. A `width` of 0 yields 0; widths of 32
/// or more return the bit pattern of `v` unchanged.
fn sign_extend(v: u32, width: u8) -> i32 {
    if width == 0 {
        return 0;
    }
    let shift = 32 - u32::from(width.min(32));
    // Move the field's sign bit up to bit 31, then arithmetic-shift back down so
    // the sign is replicated into the upper bits.
    ((v << shift) as i32) >> shift
}
386
387/* =============================== DPAD Logic =============================== */
388
389/// Read HAT value honoring bit/nibble descriptors with backward compatibility.
390fn read_hat_value(report: &[u8], h: &Dpad) -> u8 {
391    if let Some(bit_off) = h.bit_offset {
392        let sz = h.size_bits.unwrap_or(4).max(1);
393        read_bits(report, bit_off, sz).unwrap_or(0) as u8
394    } else {
395        // legacy byte read + optional nibble masking
396        if h.report_offset >= report.len() {
397            return h.neutral;
398        }
399        let mut v = report[h.report_offset];
400        if let Some(ref nb) = h.nibble {
401            match nb.as_str() {
402                "low" => {
403                    v &= 0x0f;
404                }
405                "high" => {
406                    v = (v >> 4) & 0x0f;
407                }
408                _ => {}
409            }
410        }
411        v
412    }
413}
414
415/// Generalized hat decoder.
416/// - Accepts [logical_min..logical_max] but treats `neutral` as no-press.
417/// - For 8-way values (0..7), keeps classic diagonals; for 4-way (0..3), maps to cardinals.
418fn decode_hat_general(
419    v: u8,
420    logical_min: i32,
421    logical_max: i32,
422    neutral: u8
423) -> (bool, bool, bool, bool) {
424    if v == neutral {
425        return (false, false, false, false);
426    }
427    // clamp into range for safety
428    let v = v.clamp(logical_min as u8, logical_max as u8);
429    let span = (logical_max - logical_min).max(0) as u8;
430
431    match span {
432        3 => {
433            // 4-way: 0=up,1=right,2=down,3=left
434            let idx = v - (logical_min as u8);
435            let up = idx == 0;
436            let right = idx == 1;
437            let down = idx == 2;
438            let left = idx == 3;
439            (up, right, down, left)
440        }
441        _ => {
442            // Default to classic 8-way semantics
443            let up = v == 0 || v == 1 || v == 7;
444            let right = v == 1 || v == 2 || v == 3;
445            let down = v == 3 || v == 4 || v == 5;
446            let left = v == 5 || v == 6 || v == 7;
447            (up, right, down, left)
448        }
449    }
450}
451
452/* ================================== Tests ================================== */
453
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Mapping covering every section with byte-aligned 8-bit fields:
    /// sticks in bytes 0..=3, triggers in bytes 4..=5, an 8-way hat in byte 6
    /// (neutral 8) and buttons at flat bits 56..=65 (bytes 7..=8).
    fn sample_yaml() -> &'static str {
        r#"device:
  name: "Sample"
  vid: "0x1234"
  pid: "0xabcd"
axes:
  left_x:  { report_offset: 0, size_bits: 8,  logical_min: 0,   logical_max: 255, inverted: false }
  left_y:  { report_offset: 1, size_bits: 8,  logical_min: 0,   logical_max: 255, inverted: true  }
  right_x: { report_offset: 2, size_bits: 8,  logical_min: 0,   logical_max: 255, inverted: false }
  right_y: { report_offset: 3, size_bits: 8,  logical_min: 0,   logical_max: 255, inverted: true  }
triggers:
  left_trigger:  { report_offset: 4, size_bits: 8, logical_min: 0, logical_max: 255 }
  right_trigger: { report_offset: 5, size_bits: 8, logical_min: 0, logical_max: 255 }
dpad:
  type: "hat"
  report_offset: 6
  logical_min: 0
  logical_max: 7
  neutral: 8
buttons:
  a: 56
  b: 57
  x: 58
  y: 59
  lb: 60
  rb: 61
  back: 62
  start: 63
  lt_click: 64
  rt_click: 65
"#
    }

    /// The sample parses, and the omitted `buttons_base_bit_offset` defaults to 0.
    #[test]
    fn parse_yaml_ok() {
        let m = parse_mapping_yaml_str(sample_yaml()).unwrap();
        assert_eq!(m.device.name, "Sample");
        assert!(m.axes.left_x.is_some());
        assert!(m.triggers.is_some());
        assert!(m.dpad.is_some());
        assert!(m.buttons.is_some());
        assert_eq!(m.buttons_base_bit_offset, 0);
    }

    /// End-to-end mapping of sticks, triggers, hat and buttons from one report.
    #[test]
    fn map_report_basic() {
        let m = parse_mapping_yaml_str(sample_yaml()).unwrap();

        let mut report = vec![0u8; 9];
        report[0] = 128; // LX near center
        report[1] = 128; // LY near center (axis is inverted)
        report[2] = 255; // RX at raw maximum
        report[3] = 0; // RY at raw minimum -> large positive after inversion
        report[4] = 200; // LT
        report[5] = 10; // RT
        report[6] = 2; // HAT value 2 = right in the 8-way layout
        report[7] = 0b1000_0011; // flat bits 56, 57, 63 -> A, B, START
        report[8] = 0b0000_0001; // flat bit 64 -> LT click (left thumb)

        let xs = map_report_to_xinput(&m, &report);

        use XButtons::*;
        let expected_mask = A | B | START | LEFT_THUMB | DPAD_RIGHT;
        assert_eq!(xs.buttons & expected_mask, expected_mask);
        assert_eq!(xs.left_trigger, 200);
        assert_eq!(xs.right_trigger, 10);
        assert!(xs.thumb_lx >= -512 && xs.thumb_lx <= 512);
        assert!(xs.thumb_ly >= -512 && xs.thumb_ly <= 512);
        assert!(xs.thumb_rx > 32000);
        assert!(xs.thumb_ry > 30000);
    }

    /// 8-way hat: cardinal, diagonal and neutral values.
    #[test]
    fn hat_diagonals_and_neutral() {
        assert_eq!(super::decode_hat_general(0, 0, 7, 8), (true, false, false, false));
        assert_eq!(super::decode_hat_general(1, 0, 7, 8), (true, true, false, false));
        assert_eq!(super::decode_hat_general(3, 0, 7, 8), (false, true, true, false));
        assert_eq!(super::decode_hat_general(7, 0, 7, 8), (true, false, false, true));
        assert_eq!(super::decode_hat_general(8, 0, 7, 8), (false, false, false, false));
    }

    /// A 12-bit field that starts mid-byte and spans three bytes.
    #[test]
    fn bit_reader_12bit_crossing() {
        // Target: the 12-bit value 0xABC (2748) starting at flat bit offset 5.
        // Build the byte stream by shifting the value into place inside a u32
        // and then slicing that u32 into bytes, least significant byte first.
        let value: u32 = 0xabc;
        let bit_offset = 5;
        let mut bucket: u32 = value << bit_offset;
        // bucket now occupies 17 bits (12 + 5); split it into 3 bytes, LSB first
        let b0 = (bucket & 0xff) as u8;
        bucket >>= 8;
        let b1 = (bucket & 0xff) as u8;
        bucket >>= 8;
        let b2 = (bucket & 0xff) as u8;
        let report = vec![b0, b1, b2, 0];
        let read = super::read_bits(&report, 5, 12).unwrap();
        assert_eq!(read, 0xabc);
    }

    /// Two's-complement sign extension at the 12-bit boundary.
    #[test]
    fn sign_extend_12bit() {
        // For width=12, 0x800 (only the sign bit set) is the most negative value, -2048
        let neg = super::sign_extend(0x800, 12);
        assert_eq!(neg, -2048);
        // 0x7ff (all bits below the sign bit set) is the most positive value, 2047
        let pos = super::sign_extend(0x7ff, 12);
        assert_eq!(pos, 2047);
    }

    /// A single signed combined axis split into the two triggers.
    #[test]
    fn combined_trigger_split() {
        // Build a mapping with only combined axis from -32768..32767 -> -255..255
        let yaml =
            r#"device: { name: "Cmb", vid: "0x0", pid: "0x0" }
axes:
  left_x:  { report_offset: 0, size_bits: 16, logical_min: -32768, logical_max: 32767, inverted: false }
triggers:
  combined_axis: { report_offset: 0, size_bits: 16, logical_min: -32768, logical_max: 32767, inverted: false }
"#;
        let m = parse_mapping_yaml_str(yaml).unwrap();

        // Expected convention: positive half -> right trigger, negative half -> left trigger.
        // Report where the 16-bit little-endian value at offset 0 is +32767 -> RT ~ 255, LT 0
        let mut report = vec![0u8; 2];
        report[0] = 0xff;
        report[1] = 0x7f;
        let xs = map_report_to_xinput(&m, &report);
        assert_eq!(xs.right_trigger, 255);
        assert_eq!(xs.left_trigger, 0);

        // Negative extreme (0x8000 = -32768) -> LT ~ 255
        report[0] = 0x00;
        report[1] = 0x80;
        let xs = map_report_to_xinput(&m, &report);
        assert_eq!(xs.left_trigger, 255);
        assert_eq!(xs.right_trigger, 0);
    }

    /// HAT packed in the high nibble of a byte via the legacy `nibble` field.
    #[test]
    fn hat_nibble_high() {
        // HAT packed in high nibble of byte 2, neutral=0xF
        let yaml =
            r#"device: { name: "Nib", vid: "0x0", pid: "0x0" }
axes: {}
triggers: {}
dpad:
  type: "hat"
  report_offset: 2
  logical_min: 0
  logical_max: 7
  neutral: 15
  nibble: "high"
"#;
        let m = parse_mapping_yaml_str(yaml).unwrap();

        let mut report = vec![0u8; 3];
        // high nibble = 0x1 (up-right); the low nibble holds junk that must be ignored
        report[2] = 0x10 | 0x0d;
        let xs = map_report_to_xinput(&m, &report);
        use XButtons::*;
        assert_ne!(xs.buttons & DPAD_UP, 0);
        assert_ne!(xs.buttons & DPAD_RIGHT, 0);
        assert_eq!(xs.buttons & DPAD_DOWN, 0);
        assert_eq!(xs.buttons & DPAD_LEFT, 0);
    }
}