stm32_bootloader_client/
lib.rs

1// Copyright 2022 The stm32-bootloader-client-rs Authors.
2// This project is dual-licensed under Apache 2.0 and MIT terms.
3// See LICENSE-APACHE and LICENSE-MIT for details.
4
5//! Communicates with the STM32 factory bootloader over i2c. See AN4221 for
6//! details of how the factory bootloader works.
7
8#![cfg_attr(not(any(test, feature = "std")), no_std)]
9
10use core::ops::Deref;
11use core::ops::DerefMut;
12use embedded_hal::blocking::i2c::Read;
13use embedded_hal::blocking::i2c::Write;
14
/// Reply byte meaning the bootloader accepted the previous frame.
const BOOTLOADER_ACK: u8 = 0x79;
/// Reply byte meaning the bootloader rejected the previous frame.
const BOOTLOADER_NACK: u8 = 0x1F;
/// Reply byte meaning the bootloader is busy.
const BOOTLOADER_BUSY: u8 = 0x76;
18
/// Settings for talking to an STM32 system bootloader over I2C.
///
/// Create one with [`Config::i2c_address`] and adjust the public fields as
/// needed.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub struct Config {
    /// I2C address of the bootloader. AN2606 lists the address for each chip;
    /// it also depends on which of the chip's I2C buses you are wired to.
    i2c_address: u8,

    /// Upper bound, in milliseconds, on how long a full flash erase may take.
    /// Look up "mass erase time" in the datasheet for your part and round up
    /// if unsure.
    pub mass_erase_max_ms: u32,
}
33
/// Opcodes of the bootloader commands used by this crate. On the wire, each
/// opcode is followed by its bitwise complement.
#[repr(u8)]
#[derive(Copy, Clone)]
enum Command {
    /// Query the bootloader version.
    GetVersion = 0x01,
    /// Query the chip's product ID.
    GetId = 0x02,
    /// Read a block of memory.
    ReadMemory = 0x11,
    /// Jump to an address and leave the bootloader.
    Go = 0x21,
    /// Write a block of memory.
    WriteMemory = 0x31,
    /// Erase flash.
    Erase = 0x44,
}
44
/// Special erase code (sent in place of a page list) requesting erasure of the
/// whole flash.
const SPECIAL_ERASE_ALL: [u8; 2] = [0xFF; 2];

/// Largest number of bytes a single `read_memory` / `write_memory` call accepts.
pub const MAX_READ_WRITE_SIZE: usize = 128;
48
/// Crate-local result type whose error defaults to [`Error`].
type Result<T, E = Error> = core::result::Result<T, E>;

/// Errors that can occur while talking to the bootloader.
///
/// Derives `PartialEq`/`Eq` so callers (and tests) can compare errors directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum Error {
    /// The underlying I2C read or write failed.
    TransportError,
    /// The bootloader replied with NACK.
    Nack,
    /// The bootloader NACKed the command with this opcode.
    NackFromCommand(u8),
    /// The bootloader replied that it is busy.
    Busy,
    /// The bootloader replied with a byte that is not ACK, NACK or BUSY.
    UnexpectedResponse,
    /// An argument was out of range (for example, a transfer length).
    InvalidArgument,
    /// Read-back data differed from the expected data at this address.
    VerifyFailedAtAddress(u32),
    /// The flash erase did not complete successfully.
    EraseFailed,
}
63
64pub struct Stm32<'a, I2c: Write + Read, DelayMs> {
65    dev: MaybeOwned<'a, I2c>,
66    delay: MaybeOwned<'a, DelayMs>,
67    config: Config,
68}
69
/// Snapshot of a bulk operation's progress, handed to progress callbacks after
/// each block is processed.
#[derive(Clone, Debug)]
pub struct Progress {
    /// Number of bytes processed so far.
    pub bytes_complete: usize,
    /// Total number of bytes in the operation.
    pub bytes_total: usize,
}
75
#[cfg(feature = "std")]
impl<E, I2c> Stm32<'static, I2c, StdDelay>
where
    E: core::fmt::Debug,
    I2c: Write<Error = E> + Read<Error = E>,
{
    /// Takes ownership of `dev` and waits using [`StdDelay`] (thread sleep).
    pub fn new(dev: I2c, config: Config) -> Self {
        let dev = MaybeOwned::Owned(dev);
        let delay = MaybeOwned::Owned(StdDelay);
        Self { dev, delay, config }
    }
}
90
#[cfg(feature = "std")]
impl<'a, E, I2c> Stm32<'a, I2c, StdDelay>
where
    E: core::fmt::Debug,
    I2c: Write<Error = E> + Read<Error = E>,
{
    /// Builds a client that only borrows the I2C implementation and waits using
    /// [`StdDelay`].
    ///
    /// Use this when other devices on the same I2C bus still need to be
    /// driven, so ownership of the bus cannot be given up.
    pub fn borrowed(dev: &'a mut I2c, config: Config) -> Self {
        let dev = MaybeOwned::Borrowed(dev);
        let delay = MaybeOwned::Owned(StdDelay);
        Self { dev, delay, config }
    }
}
108
109impl<E, I2c, Delay> Stm32<'static, I2c, Delay>
110where
111    E: core::fmt::Debug,
112    I2c: Write<Error = E> + Read<Error = E>,
113    Delay: embedded_hal::blocking::delay::DelayMs<u32>,
114{
115    /// Constructs a new instance with a custom delay implementation.
116    pub fn new_with_delay(dev: I2c, delay: Delay, config: Config) -> Stm32<'static, I2c, Delay> {
117        Self {
118            dev: MaybeOwned::Owned(dev),
119            delay: MaybeOwned::Owned(delay),
120            config,
121        }
122    }
123}
124
125impl<'a, E, I2c, Delay> Stm32<'a, I2c, Delay>
126where
127    E: core::fmt::Debug,
128    I2c: Write<Error = E> + Read<Error = E>,
129    Delay: embedded_hal::blocking::delay::DelayMs<u32>,
130{
131    /// Borrows both the I2C implementation and a custom delay.
132    pub fn borrowed_with_delay(
133        dev: &'a mut I2c,
134        delay: &'a mut Delay,
135        config: Config,
136    ) -> Stm32<'a, I2c, Delay> {
137        Self {
138            dev: MaybeOwned::Borrowed(dev),
139            delay: MaybeOwned::Borrowed(delay),
140            config,
141        }
142    }
143
144    pub fn get_chip_id(&mut self) -> Result<u16> {
145        self.send_command(Command::GetId)?;
146        // For STM32, the first byte will always be a 1 and the payload will
147        // always be 3 bytes.
148        let mut buffer = [0u8; 3];
149        self.read(&mut buffer)?;
150        self.get_ack_for_command(Command::GetId)?;
151        Ok(u16::from_be_bytes([buffer[1], buffer[2]]))
152    }
153
154    /// Reads memory starting from `address`, putting the result into `out`.
155    pub fn read_memory(&mut self, address: u32, out: &mut [u8]) -> Result<()> {
156        if out.len() > MAX_READ_WRITE_SIZE {
157            return Err(Error::InvalidArgument);
158        }
159        self.send_command(Command::ReadMemory)?;
160        self.send_address(address)?;
161        let mut buffer = [0u8; 2];
162        buffer[0] = (out.len() - 1) as u8;
163        buffer[1] = checksum(&buffer[0..1]);
164        self.write(&buffer)?;
165        self.get_ack_for_command(Command::ReadMemory)?;
166        self.read(out)?;
167
168        Ok(())
169    }
170
171    /// Writes `data` at `address`. Maximum write size is 256 bytes.
172    pub fn write_memory(&mut self, address: u32, data: &[u8]) -> Result<()> {
173        if data.len() > MAX_READ_WRITE_SIZE {
174            return Err(Error::InvalidArgument);
175        }
176        self.send_command(Command::WriteMemory)?;
177        self.send_address(address)?;
178
179        let mut buffer = [0u8; MAX_READ_WRITE_SIZE + 2];
180        buffer[0] = (data.len() - 1) as u8;
181        buffer[1..1 + data.len()].copy_from_slice(data);
182        buffer[1 + data.len()] = checksum(&buffer[0..1 + data.len()]);
183        self.write(&buffer[..data.len() + 2])?;
184
185        self.get_ack_for_command(Command::WriteMemory)
186    }
187
188    /// Writes `bytes` to `address`, calling `progress_cb` after each block.
189    pub fn write_bulk(
190        &mut self,
191        mut address: u32,
192        bytes: &[u8],
193        mut progress_cb: impl FnMut(Progress),
194    ) -> Result<()> {
195        let mut complete = 0;
196        for chunk in bytes.chunks(MAX_READ_WRITE_SIZE) {
197            // Write-memory sometimes gets a NACK, so allow a single retry.
198            if self.write_memory(address, chunk).is_err() {
199                self.write_memory(address, chunk)?;
200            }
201            complete += chunk.len();
202            address += chunk.len() as u32;
203            progress_cb(Progress {
204                bytes_complete: complete,
205                bytes_total: bytes.len(),
206            });
207        }
208        Ok(())
209    }
210
211    /// Verifies that memory at `address` is equal to `bytes`, calling
212    /// `progress_cb` to report progress.
213    pub fn verify(
214        &mut self,
215        mut address: u32,
216        bytes: &[u8],
217        mut progress_cb: impl FnMut(Progress),
218    ) -> Result<()> {
219        let mut read_back_buffer = [0; MAX_READ_WRITE_SIZE];
220        let mut complete = 0;
221        for chunk in bytes.chunks(MAX_READ_WRITE_SIZE) {
222            let read_back = &mut read_back_buffer[..chunk.len()];
223            self.read_memory(address, read_back)?;
224            for (offset, (expected, actual)) in chunk.iter().zip(read_back.iter()).enumerate() {
225                if expected != actual {
226                    return Err(Error::VerifyFailedAtAddress(address + offset as u32));
227                }
228            }
229            complete += chunk.len();
230            address += chunk.len() as u32;
231            progress_cb(Progress {
232                bytes_complete: complete,
233                bytes_total: bytes.len(),
234            });
235        }
236        Ok(())
237    }
238
239    /// Erase the flash of the STM32.
240    pub fn erase_flash(&mut self) -> Result<()> {
241        self.send_command(Command::Erase)?;
242        let mut buffer = [0u8; 3];
243        buffer[0..2].copy_from_slice(&SPECIAL_ERASE_ALL);
244        buffer[2] = checksum(&buffer[..2]);
245        self.write(&buffer)?;
246        self.delay.delay_ms(self.config.mass_erase_max_ms);
247        self.get_ack().map_err(|_| Error::EraseFailed)
248    }
249
250    /// Returns the version number of the bootloader.
251    pub fn get_bootloader_version(&mut self) -> Result<u8> {
252        self.send_command(Command::GetVersion)?;
253        let mut buffer = [0];
254        self.read(&mut buffer)?;
255        self.get_ack_for_command(Command::GetVersion)?;
256        Ok(buffer[0])
257    }
258
259    /// Exit system bootloader by jumping to the reset vector specified in the
260    /// vector table at `address`.
261    pub fn go(&mut self, address: u32) -> Result<()> {
262        self.send_command(Command::Go)?;
263        self.send_address(address)
264    }
265
266    fn write(&mut self, bytes: &[u8]) -> Result<()> {
267        self.dev
268            .write(self.config.i2c_address, bytes)
269            .map_err(|error| {
270                log_error(&error);
271                Error::TransportError
272            })
273    }
274
275    fn read(&mut self, out: &mut [u8]) -> Result<()> {
276        self.dev
277            .read(self.config.i2c_address, out)
278            .map_err(|error| {
279                log_error(&error);
280                Error::TransportError
281            })
282    }
283
284    fn read_with_timeout(&mut self, out: &mut [u8]) -> Result<()> {
285        // TODO: Implement timeout mechanism
286        const MAX_ATTEMPTS: u32 = 10000;
287        let mut attempts = 0;
288        loop {
289            attempts += 1;
290            let result = self.read(out);
291            if result.is_ok() || attempts == MAX_ATTEMPTS {
292                return result;
293            }
294        }
295    }
296
297    fn send_command(&mut self, command: Command) -> Result<()> {
298        let command_u8 = command as u8;
299        self.write(&[command_u8, !command_u8])?;
300        self.get_ack_for_command(command)
301    }
302
303    fn get_ack(&mut self) -> Result<()> {
304        let mut response = [0u8; 1];
305        self.read_with_timeout(&mut response)?;
306        match response[0] {
307            BOOTLOADER_ACK => Ok(()),
308            BOOTLOADER_NACK => Err(Error::Nack),
309            BOOTLOADER_BUSY => Err(Error::Busy),
310            _ => Err(Error::UnexpectedResponse),
311        }
312    }
313
314    fn get_ack_for_command(&mut self, command: Command) -> Result<()> {
315        self.get_ack().map_err(|error| match error {
316            Error::Nack => Error::NackFromCommand(command as u8),
317            x => x,
318        })
319    }
320
321    fn send_address(&mut self, address: u32) -> Result<()> {
322        let mut buffer = [0u8; 5];
323        buffer[0..4].copy_from_slice(&address.to_be_bytes());
324        buffer[4] = checksum(&buffer[0..4]);
325        self.write(&buffer)?;
326        self.get_ack()
327    }
328}
329
/// Holds either an owned `T` or a mutable borrow of one, and derefs to `T` in
/// both cases.
///
/// This resembles `std::borrow::Cow` but differs in three ways:
/// (a) it does not need std;
/// (b) it can be mutated through either variant without first becoming owned;
/// (c) it works with types that cannot or should not be cloned.
enum MaybeOwned<'a, T> {
    Borrowed(&'a mut T),
    Owned(T),
}

impl<T> Deref for MaybeOwned<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        match self {
            Self::Owned(value) => value,
            Self::Borrowed(reference) => &**reference,
        }
    }
}

impl<T> DerefMut for MaybeOwned<'_, T> {
    fn deref_mut(&mut self) -> &mut T {
        match self {
            Self::Owned(value) => value,
            Self::Borrowed(reference) => &mut **reference,
        }
    }
}
359
/// `DelayMs` implementation for targets with std; it blocks the current thread
/// with `std::thread::sleep`.
#[cfg(feature = "std")]
pub struct StdDelay;

#[cfg(feature = "std")]
impl embedded_hal::blocking::delay::DelayMs<u32> for StdDelay {
    fn delay_ms(&mut self, ms: u32) {
        let pause = std::time::Duration::from_millis(u64::from(ms));
        std::thread::sleep(pause);
    }
}
371
/// XOR checksum used by the bootloader protocol.
///
/// A single-byte frame is checksummed as that byte XOR 0xFF, i.e. its
/// complement. Any other frame, including an empty one, is checksummed as the
/// plain XOR of all its bytes.
fn checksum(bytes: &[u8]) -> u8 {
    let folded = bytes.iter().fold(0u8, |acc, &byte| acc ^ byte);
    match bytes.len() {
        1 => !folded,
        _ => folded,
    }
}
378
379impl core::fmt::Display for Error {
380    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
381        match self {
382            Error::TransportError => write!(f, "Transport error"),
383            Error::Nack => write!(f, "NACK"),
384            Error::NackFromCommand(command) => write!(f, "Nack from command {}", command),
385            Error::Busy => write!(f, "Busy"),
386            Error::UnexpectedResponse => write!(f, "Unexpected response"),
387            Error::InvalidArgument => write!(f, "Invalid argument"),
388            Error::VerifyFailedAtAddress(address) => {
389                write!(f, "Verify failed at address {:x}", address)
390            }
391            Error::EraseFailed => write!(f, "Erase failed"),
392        }
393    }
394}
395
396impl Config {
397    pub const fn i2c_address(i2c_address: u8) -> Self {
398        Self {
399            i2c_address,
400            // A moderately conservative default. stm32g071 has 40.1ms.
401            // stm32l452 has 24.59ms
402            mass_erase_max_ms: 200,
403        }
404    }
405}
406
// With std available, `Error` also implements `std::error::Error`, so it can be
// boxed as `dyn std::error::Error`. The message comes from its `Display` impl.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
409
/// Reports an I2C error through whichever logging backends are enabled:
/// `defmt` first, then `log`. Does nothing if neither feature is on.
fn log_error<E: core::fmt::Debug>(_error: &E) {
    #[cfg(feature = "defmt")]
    defmt::error!("I2C error: {:?}", defmt::Debug2Format(_error));
    #[cfg(feature = "log")]
    log::error!("I2C error: {:?}", _error);
}
420
#[cfg(test)]
mod tests {
    //! Replays scripted I2C traffic against the client using `embedded_hal_mock`.

    use super::*;
    use embedded_hal_mock::i2c;

    const I2C_ADDRESS: u8 = 0x51;
    const CONFIG: Config = Config::i2c_address(I2C_ADDRESS);

    /// Expected write of `payload` followed by its protocol checksum byte.
    fn mock_write_with_checksum(payload: &[u8]) -> i2c::Transaction {
        let mut framed = payload.to_vec();
        framed.push(checksum(payload));
        i2c::Transaction::write(I2C_ADDRESS, framed)
    }

    /// Expected read that yields `response`.
    fn mock_read(response: &[u8]) -> i2c::Transaction {
        i2c::Transaction::read(I2C_ADDRESS, response.to_vec())
    }

    /// Delay stub that returns immediately.
    struct Delay;

    impl embedded_hal::blocking::delay::DelayMs<u32> for Delay {
        fn delay_ms(&mut self, _ms: u32) {}
    }

    #[test]
    fn test_get_chip_id() {
        let script = [
            mock_write_with_checksum(&[Command::GetId as u8]),
            mock_read(&[BOOTLOADER_ACK]),
            mock_read(&[0, 0xab, 0xcd]),
            mock_read(&[BOOTLOADER_ACK]),
        ];
        let mut bus = i2c::Mock::new(&script);
        let mut delay = Delay;
        let mut client = Stm32::borrowed_with_delay(&mut bus, &mut delay, CONFIG);
        let id = client.get_chip_id().unwrap();
        assert_eq!(id, 0xabcd);
        bus.done();
    }

    #[test]
    fn test_get_bootloader_version() {
        let script = [
            mock_write_with_checksum(&[Command::GetVersion as u8]),
            mock_read(&[BOOTLOADER_ACK]),
            mock_read(&[0xef]),
            mock_read(&[BOOTLOADER_ACK]),
        ];
        let mut bus = i2c::Mock::new(&script);
        let mut delay = Delay;
        let mut client = Stm32::borrowed_with_delay(&mut bus, &mut delay, CONFIG);
        let version = client.get_bootloader_version().unwrap();
        assert_eq!(version, 0xef);
        bus.done();
    }

    #[test]
    fn test_read_memory() {
        let script = [
            mock_write_with_checksum(&[Command::ReadMemory as u8]),
            mock_read(&[BOOTLOADER_ACK]),
            mock_write_with_checksum(&[0x12, 0x34, 0x56, 0x78]),
            mock_read(&[BOOTLOADER_ACK]),
            mock_write_with_checksum(&[0x02]),
            mock_read(&[BOOTLOADER_ACK]),
            mock_read(&[0xab, 0xcd, 0xef]),
        ];
        let mut bus = i2c::Mock::new(&script);
        let mut delay = Delay;
        let mut client = Stm32::borrowed_with_delay(&mut bus, &mut delay, CONFIG);
        let mut received = [0u8; 3];
        client.read_memory(0x12345678, &mut received).unwrap();
        assert_eq!(received, [0xab, 0xcd, 0xef]);
        bus.done();
    }

    #[test]
    fn test_write_memory() {
        let script = [
            mock_write_with_checksum(&[Command::WriteMemory as u8]),
            mock_read(&[BOOTLOADER_ACK]),
            mock_write_with_checksum(&[0x12, 0x34, 0x56, 0x78]),
            mock_read(&[BOOTLOADER_ACK]),
            mock_write_with_checksum(&[0x03, 0xab, 0xcd, 0xef, 0x12]),
            mock_read(&[BOOTLOADER_ACK]),
        ];
        let mut bus = i2c::Mock::new(&script);
        let mut delay = Delay;
        let mut client = Stm32::borrowed_with_delay(&mut bus, &mut delay, CONFIG);
        let payload = [0xab, 0xcd, 0xef, 0x12];
        client.write_memory(0x12345678, &payload).unwrap();
        bus.done();
    }

    #[test]
    fn test_go() {
        let script = [
            mock_write_with_checksum(&[Command::Go as u8]),
            mock_read(&[BOOTLOADER_ACK]),
            mock_write_with_checksum(&[0x12, 0x34, 0x56, 0x78]),
            mock_read(&[BOOTLOADER_ACK]),
        ];
        let mut bus = i2c::Mock::new(&script);
        let mut delay = Delay;
        let mut client = Stm32::borrowed_with_delay(&mut bus, &mut delay, CONFIG);
        client.go(0x12345678).unwrap();
        bus.done();
    }

    #[test]
    fn test_write_bulk() {
        let to_write: Vec<u8> = (0..200).collect();
        let (head, tail) = to_write.split_at(MAX_READ_WRITE_SIZE);
        // Each data frame starts with (length - 1).
        let first_frame: Vec<u8> = core::iter::once((head.len() - 1) as u8)
            .chain(head.iter().copied())
            .collect();
        let second_frame: Vec<u8> = core::iter::once((tail.len() - 1) as u8)
            .chain(tail.iter().copied())
            .collect();
        let script = [
            mock_write_with_checksum(&[Command::WriteMemory as u8]),
            mock_read(&[BOOTLOADER_ACK]),
            mock_write_with_checksum(&[0x12, 0x34, 0x56, 0x78]),
            mock_read(&[BOOTLOADER_ACK]),
            mock_write_with_checksum(&first_frame),
            mock_read(&[BOOTLOADER_ACK]),
            mock_write_with_checksum(&[Command::WriteMemory as u8]),
            mock_read(&[BOOTLOADER_ACK]),
            mock_write_with_checksum(&[0x12, 0x34, 0x56, 0xf8]),
            mock_read(&[BOOTLOADER_ACK]),
            mock_write_with_checksum(&second_frame),
            mock_read(&[BOOTLOADER_ACK]),
        ];
        let mut bus = i2c::Mock::new(&script);
        let mut delay = Delay;
        let mut client = Stm32::borrowed_with_delay(&mut bus, &mut delay, CONFIG);
        let mut calls = 0;
        client
            .write_bulk(0x12345678, &to_write, |_| calls += 1)
            .unwrap();
        assert_eq!(calls, 2);
        bus.done();
    }

    #[test]
    fn test_verify() {
        let to_verify: Vec<u8> = (0..200).collect();
        let (head, tail) = to_verify.split_at(MAX_READ_WRITE_SIZE);
        let script = [
            mock_write_with_checksum(&[Command::ReadMemory as u8]),
            mock_read(&[BOOTLOADER_ACK]),
            mock_write_with_checksum(&[0x12, 0x34, 0x56, 0x78]),
            mock_read(&[BOOTLOADER_ACK]),
            mock_write_with_checksum(&[(head.len() - 1) as u8]),
            mock_read(&[BOOTLOADER_ACK]),
            mock_read(head),
            mock_write_with_checksum(&[Command::ReadMemory as u8]),
            mock_read(&[BOOTLOADER_ACK]),
            mock_write_with_checksum(&[0x12, 0x34, 0x56, 0xf8]),
            mock_read(&[BOOTLOADER_ACK]),
            mock_write_with_checksum(&[(tail.len() - 1) as u8]),
            mock_read(&[BOOTLOADER_ACK]),
            mock_read(tail),
        ];
        let mut bus = i2c::Mock::new(&script);
        let mut delay = Delay;
        let mut client = Stm32::borrowed_with_delay(&mut bus, &mut delay, CONFIG);
        let mut calls = 0;
        client
            .verify(0x12345678, &to_verify, |_| calls += 1)
            .unwrap();
        assert_eq!(calls, 2);
        bus.done();
    }
}