1#![cfg_attr(not(any(test, feature = "std")), no_std)]
9
10use core::ops::Deref;
11use core::ops::DerefMut;
12use embedded_hal::blocking::i2c::Read;
13use embedded_hal::blocking::i2c::Write;
14
/// Status byte the bootloader returns to acknowledge a command or frame (mapped to `Ok`).
const BOOTLOADER_ACK: u8 = 0x79;
/// Status byte the bootloader returns to reject a command or frame (mapped to `Error::Nack`).
const BOOTLOADER_NACK: u8 = 0x1F;
/// Status byte reported as a busy condition (mapped to `Error::Busy`).
const BOOTLOADER_BUSY: u8 = 0x76;
18
/// Settings applied by [`Stm32`] to every bootloader transaction.
///
/// Create one with [`Config::i2c_address`], then adjust the public fields as needed.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct Config {
    /// I2C address used for every read and write to the bootloader.
    i2c_address: u8,

    /// Milliseconds to block after issuing a mass erase before polling for its
    /// acknowledgement.
    pub mass_erase_max_ms: u32,
}
33
/// Command bytes understood by the bootloader.
///
/// On the wire each command is followed by its bitwise complement (see `send_command`).
#[derive(Clone, Copy)]
#[repr(u8)]
enum Command {
    /// Read the bootloader version byte.
    GetVersion = 0x01,
    /// Read the chip product ID.
    GetId = 0x02,
    /// Read a block of target memory.
    ReadMemory = 0x11,
    /// Start execution at an address (sent by `go`).
    Go = 0x21,
    /// Write a block of target memory.
    WriteMemory = 0x31,
    /// Erase flash; in this driver only used for a mass erase.
    Erase = 0x44,
}
44
/// Page-count field sent after the Erase command to request erasing all of flash.
const SPECIAL_ERASE_ALL: [u8; 2] = [0xFF; 2];

/// Largest number of bytes moved by a single read-memory or write-memory transfer.
pub const MAX_READ_WRITE_SIZE: usize = 128;
48
49type Result<T, E = Error> = core::result::Result<T, E>;
50
/// Errors returned by [`Stm32`] operations.
///
/// Derives `PartialEq`/`Eq` so callers (and tests) can compare errors directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum Error {
    /// An I2C read or write failed. The underlying HAL error is passed to the logging
    /// hook rather than returned.
    TransportError,
    /// The bootloader answered with its NACK byte.
    Nack,
    /// The bootloader NACKed while acknowledging the command with this command byte.
    NackFromCommand(u8),
    /// The bootloader answered with its busy byte.
    Busy,
    /// The bootloader answered with a byte that is neither ACK, NACK nor busy.
    UnexpectedResponse,
    /// A caller-supplied argument was out of range (e.g. a transfer size).
    InvalidArgument,
    /// Read-back data differed from the expected data at this address.
    VerifyFailedAtAddress(u32),
    /// The acknowledgement following a mass erase was not a successful ACK.
    EraseFailed,
}
63
64pub struct Stm32<'a, I2c: Write + Read, DelayMs> {
65 dev: MaybeOwned<'a, I2c>,
66 delay: MaybeOwned<'a, DelayMs>,
67 config: Config,
68}
69
/// Progress snapshot handed to the callbacks of `write_bulk` and `verify`.
#[derive(Debug, Clone)]
pub struct Progress {
    /// Bytes processed so far, including the chunk just finished.
    pub bytes_complete: usize,
    /// Total number of bytes in the whole operation.
    pub bytes_total: usize,
}
75
#[cfg(feature = "std")]
impl<E, I2c> Stm32<'static, I2c, StdDelay>
where
    E: core::fmt::Debug,
    I2c: Write<Error = E> + Read<Error = E>,
{
    /// Creates a driver that owns `dev` and waits by sleeping the current thread
    /// (via [`StdDelay`]).
    pub fn new(dev: I2c, config: Config) -> Stm32<'static, I2c, StdDelay> {
        let dev = MaybeOwned::Owned(dev);
        let delay = MaybeOwned::Owned(StdDelay);
        Stm32 { dev, delay, config }
    }
}
90
#[cfg(feature = "std")]
impl<'a, E, I2c> Stm32<'a, I2c, StdDelay>
where
    E: core::fmt::Debug,
    I2c: Write<Error = E> + Read<Error = E>,
{
    /// Creates a driver that borrows `dev` for `'a` and waits by sleeping the current
    /// thread (via [`StdDelay`]).
    pub fn borrowed(dev: &'a mut I2c, config: Config) -> Stm32<'a, I2c, StdDelay> {
        let dev = MaybeOwned::Borrowed(dev);
        let delay = MaybeOwned::Owned(StdDelay);
        Stm32 { dev, delay, config }
    }
}
108
109impl<E, I2c, Delay> Stm32<'static, I2c, Delay>
110where
111 E: core::fmt::Debug,
112 I2c: Write<Error = E> + Read<Error = E>,
113 Delay: embedded_hal::blocking::delay::DelayMs<u32>,
114{
115 pub fn new_with_delay(dev: I2c, delay: Delay, config: Config) -> Stm32<'static, I2c, Delay> {
117 Self {
118 dev: MaybeOwned::Owned(dev),
119 delay: MaybeOwned::Owned(delay),
120 config,
121 }
122 }
123}
124
125impl<'a, E, I2c, Delay> Stm32<'a, I2c, Delay>
126where
127 E: core::fmt::Debug,
128 I2c: Write<Error = E> + Read<Error = E>,
129 Delay: embedded_hal::blocking::delay::DelayMs<u32>,
130{
131 pub fn borrowed_with_delay(
133 dev: &'a mut I2c,
134 delay: &'a mut Delay,
135 config: Config,
136 ) -> Stm32<'a, I2c, Delay> {
137 Self {
138 dev: MaybeOwned::Borrowed(dev),
139 delay: MaybeOwned::Borrowed(delay),
140 config,
141 }
142 }
143
144 pub fn get_chip_id(&mut self) -> Result<u16> {
145 self.send_command(Command::GetId)?;
146 let mut buffer = [0u8; 3];
149 self.read(&mut buffer)?;
150 self.get_ack_for_command(Command::GetId)?;
151 Ok(u16::from_be_bytes([buffer[1], buffer[2]]))
152 }
153
154 pub fn read_memory(&mut self, address: u32, out: &mut [u8]) -> Result<()> {
156 if out.len() > MAX_READ_WRITE_SIZE {
157 return Err(Error::InvalidArgument);
158 }
159 self.send_command(Command::ReadMemory)?;
160 self.send_address(address)?;
161 let mut buffer = [0u8; 2];
162 buffer[0] = (out.len() - 1) as u8;
163 buffer[1] = checksum(&buffer[0..1]);
164 self.write(&buffer)?;
165 self.get_ack_for_command(Command::ReadMemory)?;
166 self.read(out)?;
167
168 Ok(())
169 }
170
171 pub fn write_memory(&mut self, address: u32, data: &[u8]) -> Result<()> {
173 if data.len() > MAX_READ_WRITE_SIZE {
174 return Err(Error::InvalidArgument);
175 }
176 self.send_command(Command::WriteMemory)?;
177 self.send_address(address)?;
178
179 let mut buffer = [0u8; MAX_READ_WRITE_SIZE + 2];
180 buffer[0] = (data.len() - 1) as u8;
181 buffer[1..1 + data.len()].copy_from_slice(data);
182 buffer[1 + data.len()] = checksum(&buffer[0..1 + data.len()]);
183 self.write(&buffer[..data.len() + 2])?;
184
185 self.get_ack_for_command(Command::WriteMemory)
186 }
187
188 pub fn write_bulk(
190 &mut self,
191 mut address: u32,
192 bytes: &[u8],
193 mut progress_cb: impl FnMut(Progress),
194 ) -> Result<()> {
195 let mut complete = 0;
196 for chunk in bytes.chunks(MAX_READ_WRITE_SIZE) {
197 if self.write_memory(address, chunk).is_err() {
199 self.write_memory(address, chunk)?;
200 }
201 complete += chunk.len();
202 address += chunk.len() as u32;
203 progress_cb(Progress {
204 bytes_complete: complete,
205 bytes_total: bytes.len(),
206 });
207 }
208 Ok(())
209 }
210
211 pub fn verify(
214 &mut self,
215 mut address: u32,
216 bytes: &[u8],
217 mut progress_cb: impl FnMut(Progress),
218 ) -> Result<()> {
219 let mut read_back_buffer = [0; MAX_READ_WRITE_SIZE];
220 let mut complete = 0;
221 for chunk in bytes.chunks(MAX_READ_WRITE_SIZE) {
222 let read_back = &mut read_back_buffer[..chunk.len()];
223 self.read_memory(address, read_back)?;
224 for (offset, (expected, actual)) in chunk.iter().zip(read_back.iter()).enumerate() {
225 if expected != actual {
226 return Err(Error::VerifyFailedAtAddress(address + offset as u32));
227 }
228 }
229 complete += chunk.len();
230 address += chunk.len() as u32;
231 progress_cb(Progress {
232 bytes_complete: complete,
233 bytes_total: bytes.len(),
234 });
235 }
236 Ok(())
237 }
238
239 pub fn erase_flash(&mut self) -> Result<()> {
241 self.send_command(Command::Erase)?;
242 let mut buffer = [0u8; 3];
243 buffer[0..2].copy_from_slice(&SPECIAL_ERASE_ALL);
244 buffer[2] = checksum(&buffer[..2]);
245 self.write(&buffer)?;
246 self.delay.delay_ms(self.config.mass_erase_max_ms);
247 self.get_ack().map_err(|_| Error::EraseFailed)
248 }
249
250 pub fn get_bootloader_version(&mut self) -> Result<u8> {
252 self.send_command(Command::GetVersion)?;
253 let mut buffer = [0];
254 self.read(&mut buffer)?;
255 self.get_ack_for_command(Command::GetVersion)?;
256 Ok(buffer[0])
257 }
258
259 pub fn go(&mut self, address: u32) -> Result<()> {
262 self.send_command(Command::Go)?;
263 self.send_address(address)
264 }
265
266 fn write(&mut self, bytes: &[u8]) -> Result<()> {
267 self.dev
268 .write(self.config.i2c_address, bytes)
269 .map_err(|error| {
270 log_error(&error);
271 Error::TransportError
272 })
273 }
274
275 fn read(&mut self, out: &mut [u8]) -> Result<()> {
276 self.dev
277 .read(self.config.i2c_address, out)
278 .map_err(|error| {
279 log_error(&error);
280 Error::TransportError
281 })
282 }
283
284 fn read_with_timeout(&mut self, out: &mut [u8]) -> Result<()> {
285 const MAX_ATTEMPTS: u32 = 10000;
287 let mut attempts = 0;
288 loop {
289 attempts += 1;
290 let result = self.read(out);
291 if result.is_ok() || attempts == MAX_ATTEMPTS {
292 return result;
293 }
294 }
295 }
296
297 fn send_command(&mut self, command: Command) -> Result<()> {
298 let command_u8 = command as u8;
299 self.write(&[command_u8, !command_u8])?;
300 self.get_ack_for_command(command)
301 }
302
303 fn get_ack(&mut self) -> Result<()> {
304 let mut response = [0u8; 1];
305 self.read_with_timeout(&mut response)?;
306 match response[0] {
307 BOOTLOADER_ACK => Ok(()),
308 BOOTLOADER_NACK => Err(Error::Nack),
309 BOOTLOADER_BUSY => Err(Error::Busy),
310 _ => Err(Error::UnexpectedResponse),
311 }
312 }
313
314 fn get_ack_for_command(&mut self, command: Command) -> Result<()> {
315 self.get_ack().map_err(|error| match error {
316 Error::Nack => Error::NackFromCommand(command as u8),
317 x => x,
318 })
319 }
320
321 fn send_address(&mut self, address: u32) -> Result<()> {
322 let mut buffer = [0u8; 5];
323 buffer[0..4].copy_from_slice(&address.to_be_bytes());
324 buffer[4] = checksum(&buffer[0..4]);
325 self.write(&buffer)?;
326 self.get_ack()
327 }
328}
329
/// A value that is either owned outright or exclusively borrowed for `'a`.
///
/// Lets [`Stm32`] take its bus and delay provider either by value or by `&mut`, while the
/// rest of the driver uses them uniformly through `Deref`/`DerefMut`.
enum MaybeOwned<'a, T> {
    /// Exclusive borrow of a value owned elsewhere.
    Borrowed(&'a mut T),
    /// Value owned by this wrapper.
    Owned(T),
}

impl<T> Deref for MaybeOwned<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        match self {
            Self::Owned(value) => value,
            Self::Borrowed(value) => &**value,
        }
    }
}

impl<T> DerefMut for MaybeOwned<'_, T> {
    fn deref_mut(&mut self) -> &mut T {
        match self {
            Self::Owned(value) => value,
            Self::Borrowed(value) => &mut **value,
        }
    }
}
359
/// Delay provider that blocks the current thread with `std::thread::sleep`.
#[cfg(feature = "std")]
pub struct StdDelay;

#[cfg(feature = "std")]
impl embedded_hal::blocking::delay::DelayMs<u32> for StdDelay {
    /// Sleeps the current thread for `ms` milliseconds.
    fn delay_ms(&mut self, ms: u32) {
        let duration = std::time::Duration::from_millis(u64::from(ms));
        std::thread::sleep(duration);
    }
}
371
/// Computes the bootloader checksum for `bytes`.
///
/// A single byte is protected by its bitwise complement. Any other length uses the XOR of
/// all bytes, which is `0` for an empty slice.
fn checksum(bytes: &[u8]) -> u8 {
    let xor_of_all = bytes.iter().fold(0u8, |acc, &byte| acc ^ byte);
    match bytes.len() {
        1 => !xor_of_all,
        _ => xor_of_all,
    }
}
378
379impl core::fmt::Display for Error {
380 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
381 match self {
382 Error::TransportError => write!(f, "Transport error"),
383 Error::Nack => write!(f, "NACK"),
384 Error::NackFromCommand(command) => write!(f, "Nack from command {}", command),
385 Error::Busy => write!(f, "Busy"),
386 Error::UnexpectedResponse => write!(f, "Unexpected response"),
387 Error::InvalidArgument => write!(f, "Invalid argument"),
388 Error::VerifyFailedAtAddress(address) => {
389 write!(f, "Verify failed at address {:x}", address)
390 }
391 Error::EraseFailed => write!(f, "Erase failed"),
392 }
393 }
394}
395
396impl Config {
397 pub const fn i2c_address(i2c_address: u8) -> Self {
398 Self {
399 i2c_address,
400 mass_erase_max_ms: 200,
403 }
404 }
405}
406
/// Lets [`Error`] participate in `std` error handling (e.g. `Box<dyn std::error::Error>`);
/// the message comes from its `Display` impl.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
409
/// Reports an I2C error through whichever logging backends are enabled.
///
/// With neither the `defmt` nor the `log` feature enabled this is a no-op, which is why
/// the parameter is underscore-prefixed.
fn log_error<E>(_error: &E)
where
    E: core::fmt::Debug,
{
    #[cfg(feature = "defmt")]
    defmt::error!("I2C error: {:?}", defmt::Debug2Format(_error));
    #[cfg(feature = "log")]
    log::error!("I2C error: {:?}", _error);
}
420
#[cfg(test)]
mod tests {
    use super::*;
    use embedded_hal_mock::i2c;

    const I2C_ADDRESS: u8 = 0x51;
    const CONFIG: Config = Config::i2c_address(I2C_ADDRESS);

    /// Target address used by the memory tests, and its big-endian wire encoding.
    const ADDRESS: u32 = 0x1234_5678;
    const ADDRESS_BYTES: [u8; 4] = [0x12, 0x34, 0x56, 0x78];
    /// Wire encoding of `ADDRESS + 128`, where the second chunk of the bulk tests starts.
    const SECOND_CHUNK_ADDRESS_BYTES: [u8; 4] = [0x12, 0x34, 0x56, 0xf8];

    /// Delay provider that returns immediately so tests never sleep.
    struct Delay;

    impl embedded_hal::blocking::delay::DelayMs<u32> for Delay {
        fn delay_ms(&mut self, _ms: u32) {}
    }

    /// Expected write of `bytes` with the protocol checksum appended.
    fn mock_write_with_checksum(bytes: &[u8]) -> i2c::Transaction {
        let mut framed = bytes.to_vec();
        framed.push(checksum(bytes));
        i2c::Transaction::write(I2C_ADDRESS, framed)
    }

    /// Expected read that yields `bytes`.
    fn mock_read(bytes: &[u8]) -> i2c::Transaction {
        i2c::Transaction::read(I2C_ADDRESS, bytes.to_vec())
    }

    /// Expected read of a single ACK byte.
    fn ack() -> i2c::Transaction {
        mock_read(&[BOOTLOADER_ACK])
    }

    /// `chunk` prefixed with its `len - 1` length byte, as framed by `write_memory`.
    fn length_prefixed(chunk: &[u8]) -> Vec<u8> {
        let mut framed = vec![(chunk.len() - 1) as u8];
        framed.extend_from_slice(chunk);
        framed
    }

    /// Runs `body` against a driver on a mock bus primed with `expectations`, then checks
    /// that every expectation was consumed.
    fn run_with_mock<R>(
        expectations: &[i2c::Transaction],
        body: impl FnOnce(&mut Stm32<'_, i2c::Mock, Delay>) -> R,
    ) -> R {
        let mut bus = i2c::Mock::new(expectations);
        let mut delay = Delay;
        let result = body(&mut Stm32::borrowed_with_delay(&mut bus, &mut delay, CONFIG));
        bus.done();
        result
    }

    #[test]
    fn test_get_chip_id() {
        let expectations = [
            mock_write_with_checksum(&[Command::GetId as u8]),
            ack(),
            mock_read(&[0, 0xab, 0xcd]),
            ack(),
        ];
        let chip_id = run_with_mock(&expectations, |stm32| stm32.get_chip_id().unwrap());
        assert_eq!(chip_id, 0xabcd);
    }

    #[test]
    fn test_get_bootloader_version() {
        let expectations = [
            mock_write_with_checksum(&[Command::GetVersion as u8]),
            ack(),
            mock_read(&[0xef]),
            ack(),
        ];
        let version = run_with_mock(&expectations, |stm32| {
            stm32.get_bootloader_version().unwrap()
        });
        assert_eq!(version, 0xef);
    }

    #[test]
    fn test_read_memory() {
        let expectations = [
            mock_write_with_checksum(&[Command::ReadMemory as u8]),
            ack(),
            mock_write_with_checksum(&ADDRESS_BYTES),
            ack(),
            mock_write_with_checksum(&[0x02]),
            ack(),
            mock_read(&[0xab, 0xcd, 0xef]),
        ];
        let out = run_with_mock(&expectations, |stm32| {
            let mut out = [0u8; 3];
            stm32.read_memory(ADDRESS, &mut out).unwrap();
            out
        });
        assert_eq!(out, [0xab, 0xcd, 0xef]);
    }

    #[test]
    fn test_write_memory() {
        let expectations = [
            mock_write_with_checksum(&[Command::WriteMemory as u8]),
            ack(),
            mock_write_with_checksum(&ADDRESS_BYTES),
            ack(),
            mock_write_with_checksum(&[0x03, 0xab, 0xcd, 0xef, 0x12]),
            ack(),
        ];
        run_with_mock(&expectations, |stm32| {
            stm32
                .write_memory(ADDRESS, &[0xab, 0xcd, 0xef, 0x12])
                .unwrap()
        });
    }

    #[test]
    fn test_go() {
        let expectations = [
            mock_write_with_checksum(&[Command::Go as u8]),
            ack(),
            mock_write_with_checksum(&ADDRESS_BYTES),
            ack(),
        ];
        run_with_mock(&expectations, |stm32| stm32.go(ADDRESS).unwrap());
    }

    #[test]
    fn test_write_bulk() {
        let to_write: Vec<u8> = (0..200).collect();
        let (first, second) = to_write.split_at(128);
        let expectations = [
            mock_write_with_checksum(&[Command::WriteMemory as u8]),
            ack(),
            mock_write_with_checksum(&ADDRESS_BYTES),
            ack(),
            mock_write_with_checksum(&length_prefixed(first)),
            ack(),
            mock_write_with_checksum(&[Command::WriteMemory as u8]),
            ack(),
            mock_write_with_checksum(&SECOND_CHUNK_ADDRESS_BYTES),
            ack(),
            mock_write_with_checksum(&length_prefixed(second)),
            ack(),
        ];
        let mut callback_count = 0;
        run_with_mock(&expectations, |stm32| {
            stm32
                .write_bulk(ADDRESS, &to_write, |_| callback_count += 1)
                .unwrap()
        });
        assert_eq!(callback_count, 2);
    }

    #[test]
    fn test_verify() {
        let to_verify: Vec<u8> = (0..200).collect();
        let (first, second) = to_verify.split_at(128);
        let expectations = [
            mock_write_with_checksum(&[Command::ReadMemory as u8]),
            ack(),
            mock_write_with_checksum(&ADDRESS_BYTES),
            ack(),
            mock_write_with_checksum(&[(first.len() - 1) as u8]),
            ack(),
            mock_read(first),
            mock_write_with_checksum(&[Command::ReadMemory as u8]),
            ack(),
            mock_write_with_checksum(&SECOND_CHUNK_ADDRESS_BYTES),
            ack(),
            mock_write_with_checksum(&[(second.len() - 1) as u8]),
            ack(),
            mock_read(second),
        ];
        let mut callback_count = 0;
        run_with_mock(&expectations, |stm32| {
            stm32
                .verify(ADDRESS, &to_verify, |_| callback_count += 1)
                .unwrap()
        });
        assert_eq!(callback_count, 2);
    }
}
595}