sftool_lib/
write_flash.rs

1use crate::SifliTool;
2use crate::ram_command::{Command, RamCommand, Response};
3use crc::Algorithm;
4use indicatif::{ProgressBar, ProgressStyle};
5use lazy_static::lazy_static;
6use memmap2::Mmap;
7use phf::phf_map;
8use std::cmp::PartialEq;
9use std::collections::HashMap;
10use std::fmt::format;
11use std::fs::File;
12use std::io::{BufRead, BufReader, Read, Seek, SeekFrom, Write};
13use std::path::Path;
14use tempfile::tempfile;
15
/// Magic number at the start of every ELF file: `0x7F` followed by ASCII `ELF`.
const ELF_MAGIC: &[u8] = b"\x7FELF";
17
/// Operation that writes firmware images to the target's flash.
///
/// In this file it is implemented for `SifliTool`.
pub trait WriteFlashTrait {
    /// Write every configured image to flash.
    ///
    /// # Errors
    /// Returns a [`std::io::Error`] when the operation cannot be completed; see the
    /// implementation for the specific failure cases.
    fn write_flash(&mut self) -> Result<(), std::io::Error>;
}
21
/// Kind of firmware image given on the command line.
#[derive(Clone, Debug, Eq, PartialEq)]
enum FileType {
    /// Raw binary; it carries no addresses, so it must be passed as `<file>@<address>`.
    Bin,
    /// Intel HEX text file.
    Hex,
    /// ELF executable (`.elf` / `.axf`, or detected by its magic number).
    Elf,
}
28
/// One contiguous image to be written to flash.
struct WriteFlashFile {
    /// Flash address that receives the first byte of `file`.
    address: u32,
    /// Image contents, streamed from its current cursor position when downloading.
    file: File,
    /// CRC-32 of the image, sent as `crc` with `Command::Verify`.
    crc32: u32,
}
34
/// Parse an unsigned 32-bit integer written in decimal or with a radix prefix.
///
/// Accepted prefixes are `0x` (hexadecimal), `0b` (binary) and `0o` (octal), in either
/// letter case, so `0X1000` is accepted just like `0x1000`. Without a prefix the string
/// is parsed as decimal.
///
/// # Errors
/// Returns the [`std::num::ParseIntError`] of the underlying parse, e.g. for an empty
/// string, a prefix with no digits, digits invalid for the radix, or a value that does
/// not fit in `u32`.
fn str_to_u32(s: &str) -> Result<u32, std::num::ParseIntError> {
    // `get(..2)` is `None` for strings shorter than two bytes or when byte 2 is not a
    // char boundary, so slicing `&s[2..]` below is always valid.
    let prefixed = match s.get(..2) {
        Some("0x" | "0X") => Some((16, &s[2..])),
        Some("0b" | "0B") => Some((2, &s[2..])),
        Some("0o" | "0O") => Some((8, &s[2..])),
        _ => None,
    };

    match prefixed {
        Some((radix, digits)) => u32::from_str_radix(digits, radix),
        None => s.parse::<u32>(),
    }
}
46
47fn detect_file_type(path: &Path) -> Result<FileType, std::io::Error> {
48    if let Some(ext) = path.extension().and_then(|s| s.to_str()) {
49        match ext.to_lowercase().as_str() {
50            "bin" => return Ok(FileType::Bin),
51            "hex" => return Ok(FileType::Hex),
52            "elf" | "axf" => return Ok(FileType::Elf),
53            _ => {} // 如果扩展名无法识别,继续检查MAGIC
54        }
55    }
56    
57    // 如果没有可识别的扩展名,则检查文件MAGIC
58    let mut file = File::open(path)?;
59    let mut magic = [0u8; 4];
60    file.read_exact(&mut magic)?;
61    
62    if magic == ELF_MAGIC {
63        return Ok(FileType::Elf);
64    }
65    
66    Err(std::io::Error::new(
67        std::io::ErrorKind::InvalidInput,
68        "Unrecognized file type",
69    ))
70}
71
72fn hex_to_bin(hex_file: &Path) -> Result<Vec<WriteFlashFile>, std::io::Error> {
73    let mut write_flash_files: Vec<WriteFlashFile> = Vec::new();
74
75    let file = std::fs::File::open(hex_file)?;
76    let mut reader = std::io::BufReader::new(file);
77
78    let mut address = 0;
79    let mut temp_file = tempfile()?;
80
81    for line in reader.lines() {
82        let line = line?;
83        let line = line.trim_end_matches('\r');
84        let bytes_read = line.len();
85        if bytes_read == 0 {
86            break;
87        }
88        let ihex_record = ihex::Record::from_record_string(&line)
89            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
90
91        match ihex_record {
92            ihex::Record::ExtendedLinearAddress(addr) => {
93                address = (addr as u32) << 16;
94            }
95            ihex::Record::Data { offset, value } => {
96                // 获取当前文件长度
97                let metadata = temp_file.metadata()?;
98                let current_len = metadata.len();
99                let offset_u64 = offset as u64;
100
101                // 如果当前文件长度小于 offset,则说明文件中存在空隙,需要填充 0xFF
102                if current_len < offset_u64 {
103                    // 先定位到文件末尾(也就是 current_len 位置)
104                    temp_file.seek(SeekFrom::End(0))?;
105
106                    // 计算需要填充的字节数
107                    let gap_size = offset_u64 - current_len;
108
109                    // 构造一个填充缓冲区,该缓冲区内容全为 0xFF
110                    let fill_data = vec![0xFF; gap_size as usize];
111                    temp_file.write_all(&fill_data)?;
112                }
113
114                // 定位到指定的 offset 开始写入数据
115                temp_file.seek(SeekFrom::Start(offset_u64))?;
116                temp_file.write_all(&value)?;
117            }
118            ihex::Record::EndOfFile => {
119                temp_file.seek(SeekFrom::Start(0))?;
120                let crc32 = get_file_crc32(&temp_file.try_clone()?)?;
121                write_flash_files.push(WriteFlashFile {
122                    address,
123                    file: temp_file.try_clone()?,
124                    crc32,
125                });
126            }
127            _ => {}
128        }
129    }
130
131    Ok(write_flash_files)
132}
133
134fn elf_to_bin(elf_file: &Path) -> Result<Vec<WriteFlashFile>, std::io::Error> {
135    let mut write_flash_files: Vec<WriteFlashFile> = Vec::new();
136    const SECTOR_SIZE: u32 = 0x1000; // 扇区大小
137    const FILL_BYTE: u8 = 0xFF; // 填充字节
138
139    let file = File::open(elf_file)?;
140    let mmap = unsafe { Mmap::map(&file)? };
141    let elf = goblin::elf::Elf::parse(&mmap[..])
142        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
143
144    // 收集所有需要烧录的段
145    let mut load_segments: Vec<_> = elf.program_headers.iter()
146        .filter(|ph| ph.p_type == goblin::elf::program_header::PT_LOAD && ph.p_paddr < 0x2000_0000)
147        .collect();
148    load_segments.sort_by_key(|ph| ph.p_paddr);
149
150    if load_segments.is_empty() {
151        return Ok(write_flash_files);
152    }
153
154    let mut current_file = tempfile()?;
155    let mut current_base = (load_segments[0].p_paddr as u32) & !(SECTOR_SIZE - 1);
156    let mut current_offset = 0; // 跟踪当前文件中的偏移量
157
158    for ph in load_segments.iter() {
159        let vaddr = ph.p_paddr as u32;
160        let offset = ph.p_offset as usize;
161        let size = ph.p_filesz as usize;
162        let data = &mmap[offset..offset + size];
163        
164        // 计算当前段的对齐基地址
165        let segment_base = vaddr & !(SECTOR_SIZE - 1);
166
167        // 如果超出了当前对齐块,创建新文件
168        if segment_base > current_base + current_offset {
169            current_file.seek(std::io::SeekFrom::Start(0))?;
170            let crc32 = get_file_crc32(&current_file)?;
171            write_flash_files.push(WriteFlashFile {
172                address: current_base,
173                file: std::mem::replace(&mut current_file, tempfile()?),
174                crc32,
175            });
176            current_base = segment_base;
177            current_offset = 0;
178        }
179
180        // 计算相对于当前文件基地址的偏移
181        let relative_offset = vaddr - current_base;
182        
183        // 如果当前偏移小于目标偏移,填充间隙
184        if current_offset < relative_offset {
185            let padding = relative_offset - current_offset;
186            current_file.write_all(&vec![FILL_BYTE; padding as usize])?;
187            current_offset = relative_offset;
188        }
189
190        // 写入数据
191        current_file.write_all(data)?;
192        current_offset += size as u32;
193    }
194
195    // 处理最后一个bin文件
196    if current_offset > 0 {      
197        current_file.seek(std::io::SeekFrom::Start(0))?;
198        let crc32 = get_file_crc32(&current_file)?;
199        write_flash_files.push(WriteFlashFile {
200            address: current_base,
201            file: current_file,
202            crc32,
203        });
204    }
205
206    Ok(write_flash_files)
207}
208
209fn get_file_crc32(file: &File) -> Result<u32, std::io::Error> {
210    const CRC_32_ALGO: Algorithm<u32> = Algorithm {
211        width: 32,
212        poly: 0x04C11DB7,
213        init: 0,
214        refin: true,
215        refout: true,
216        xorout: 0,
217        check: 0x2DFD2D88,
218        residue: 0,
219    };
220
221    const CRC: crc::Crc<u32> = crc::Crc::<u32>::new(&CRC_32_ALGO);
222    let mut reader = BufReader::new(file);
223
224    let mut digest = CRC.digest();
225
226    let mut buffer = [0u8; 4 * 1024];
227    loop {
228        let n = reader.read(&mut buffer)?;
229        if n == 0 {
230            break;
231        }
232        digest.update(&buffer[..n]);
233    }
234
235    let checksum = digest.finalize();
236    reader.seek(SeekFrom::Start(0))?;
237    Ok(checksum)
238}
239
240lazy_static! {
241    static ref CHIP_MEMORY_LAYOUT: HashMap<&'static str, Vec<u32>> = {
242        let mut m = HashMap::new();
243        m.insert("sf32lb52", vec![0x10000000, 0x12000000]);
244        m
245    };
246}
247
248impl SifliTool {
249    fn erase_all(
250        &mut self,
251        write_flash_files: &[WriteFlashFile],
252        step: &mut i32,
253    ) -> Result<(), std::io::Error> {
254        let spinner = ProgressBar::new_spinner();
255        if !self.base.quiet {
256            spinner.enable_steady_tick(std::time::Duration::from_millis(100));
257            spinner.set_style(ProgressStyle::with_template("[{prefix}] {spinner} {msg}").unwrap());
258            spinner.set_prefix(format!("0x{:02X}", step));
259            spinner.set_message("Erasing all flash regions...");
260            *step = step.wrapping_add(1);
261        }
262        let mut erase_address: Vec<u32> = Vec::new();
263        for f in write_flash_files.iter() {
264            let address = f.address & 0xFF00_0000;
265            // 如果ERASE_ADDRESS中的地址已经被擦除过,则跳过
266            if erase_address.contains(&address) {
267                continue;
268            }
269            self.command(Command::EraseAll { address: f.address })?;
270            erase_address.push(address);
271        }
272        if !self.base.quiet {
273            spinner.finish_with_message("All flash regions erased");
274        }
275        Ok(())
276    }
277
278    fn verify(&mut self, address: u32, len: u32, crc: u32, step: &mut i32) -> Result<(), std::io::Error> {
279        let spinner = ProgressBar::new_spinner();
280        if !self.base.quiet {
281            spinner.enable_steady_tick(std::time::Duration::from_millis(100));
282            spinner.set_style(ProgressStyle::with_template("[{prefix}] {spinner} {msg}").unwrap());
283            spinner.set_prefix(format!("0x{:02X}", step));
284            spinner.set_message("Verifying data...");
285        }
286        let response = self.command(Command::Verify { address, len, crc })?;
287        if response != Response::Ok {
288            return Err(std::io::Error::new(
289                std::io::ErrorKind::InvalidData,
290                "Verify failed",
291            ));
292        }
293        if !self.base.quiet {
294            spinner.finish_with_message("Verify success!");
295        }
296        *step = step.wrapping_add(1);
297        Ok(())
298    }
299}
300
301impl WriteFlashTrait for SifliTool {
302    fn write_flash(&mut self) -> Result<(), std::io::Error> {
303        let mut step = self.step;
304        let params = self
305            .write_flash_params
306            .as_ref()
307            .cloned()
308            .ok_or(std::io::Error::new(
309                std::io::ErrorKind::InvalidInput,
310                "No write flash params",
311            ))?;
312        let mut write_flash_files: Vec<WriteFlashFile> = Vec::new();
313
314        let packet_size = if self.base.compat { 256 } else { 128 * 1024 };
315
316        for file in params.file_path.iter() {
317            // file@address
318            let parts: Vec<_> = file.split('@').collect();
319            // 如果存在@符号,则证明是bin文件
320            if parts.len() == 2 {
321                let addr = str_to_u32(parts[1])
322                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
323                let file = File::open(parts[0])?;
324                let crc32 = get_file_crc32(&file.try_clone()?)?;
325                write_flash_files.push(WriteFlashFile {
326                    address: addr,
327                    file,
328                    crc32,
329                });
330                continue;
331            }
332
333            let file_type = detect_file_type(Path::new(parts[0]))?;
334
335            match file_type {
336                FileType::Hex => {
337                    write_flash_files.append(&mut hex_to_bin(Path::new(parts[0]))?);
338                }
339                FileType::Elf => {
340                    write_flash_files.append(&mut elf_to_bin(Path::new(parts[0]))?);
341                }
342                FileType::Bin => {
343                    return Err(std::io::Error::new(
344                        std::io::ErrorKind::InvalidInput,
345                        "For binary files, please use the <file@address> format",
346                    ));
347                }
348            }
349        }
350
351        if params.erase_all {
352            self.erase_all(&write_flash_files, &mut step)?;
353        }
354
355        for file in write_flash_files.iter() {
356            let re_download_spinner = ProgressBar::new_spinner();
357            let download_bar = ProgressBar::new(file.file.metadata()?.len());
358
359            let download_bar_template = ProgressStyle::default_bar()
360                .template("[{prefix}] Download at {msg}... {wide_bar} {bytes_per_sec} {percent_precise}%")
361                .unwrap()
362                .progress_chars("=>-");
363
364            if !params.erase_all {
365                if !self.base.quiet {
366                    re_download_spinner.enable_steady_tick(std::time::Duration::from_millis(100));
367                    re_download_spinner.set_style(
368                        ProgressStyle::with_template("[{prefix}] {spinner} {msg}").unwrap(),
369                    );
370                    re_download_spinner.set_prefix(format!("0x{:02X}", step));
371                    re_download_spinner.set_message(format!(
372                        "Checking whether a re-download is necessary at address 0x{:08X}...",
373                        file.address
374                    ));
375                    step += 1;
376                }
377                let response = self.command(Command::Verify {
378                    address: file.address,
379                    len: file.file.metadata()?.len() as u32,
380                    crc: file.crc32,
381                })?;
382                if response == Response::Ok {
383                    if !self.base.quiet {
384                        re_download_spinner.finish_with_message("No need to re-download, skip!");
385                    }
386                    continue;
387                }
388                if !self.base.quiet {
389                    re_download_spinner.finish_with_message("Need to re-download");
390
391                    download_bar.set_style(download_bar_template);
392                    download_bar.set_message(format!("0x{:08X}", file.address));
393                    download_bar.set_prefix(format!("0x{:02X}", step));
394                    step += 1;
395                }
396
397                let res = self.command(Command::WriteAndErase {
398                    address: file.address,
399                    len: file.file.metadata()?.len() as u32,
400                })?;
401                if res != Response::RxWait {
402                    return Err(std::io::Error::new(
403                        std::io::ErrorKind::InvalidData,
404                        "Write flash failed",
405                    ));
406                }
407
408                let mut buffer = vec![0u8; 128 * 1024];
409                let mut reader = BufReader::new(&file.file);
410
411                loop {
412                    let bytes_read = reader.read(&mut buffer)?;
413                    if bytes_read == 0 {
414                        break;
415                    }
416                    let res = self.send_data(&buffer[..bytes_read])?;
417                    if res == Response::RxWait {
418                        if !self.base.quiet {
419                            download_bar.inc(bytes_read as u64);
420                            // downloaded += bytes_read;
421                        }
422                        continue;
423                    } else if res != Response::Ok {
424                        return Err(std::io::Error::new(
425                            std::io::ErrorKind::InvalidData,
426                            "Write flash failed",
427                        ));
428                    }
429                }
430
431                if !self.base.quiet {
432                    download_bar.finish_with_message("Download success!");
433                }
434            } else {
435                let mut buffer = vec![0u8; packet_size];
436                let mut reader = BufReader::new(&file.file);
437
438                if !self.base.quiet {
439                    download_bar.set_style(download_bar_template);
440                    download_bar.set_message(format!("0x{:08X}", file.address));
441                    download_bar.set_prefix(format!("0x{:02X}", step));
442                    step += 1;
443                }
444
445                let mut address = file.address;
446                loop {
447                    let bytes_read = reader.read(&mut buffer)?;
448                    if bytes_read == 0 {
449                        break;
450                    }
451                    self.port.write_all(
452                        Command::Write {
453                            address: address,
454                            len: bytes_read as u32,
455                        }
456                            .to_string()
457                            .as_bytes(),
458                    )?;
459                    self.port.flush()?;
460                    let res = self.send_data(&buffer[..bytes_read])?;
461                    if res != Response::Ok {
462                        return Err(std::io::Error::new(
463                            std::io::ErrorKind::InvalidData,
464                            "Write flash failed",
465                        ));
466                    }
467                    address += bytes_read as u32;
468                    if !self.base.quiet {
469                        download_bar.inc(bytes_read as u64);
470                    }
471                }
472                if !self.base.quiet {
473                    download_bar.finish_with_message("Download success!");
474                }
475            }
476            // verify
477            if params.verify {
478                self.verify(file.address, file.file.metadata()?.len() as u32, file.crc32, &mut step)?;
479            }
480        }
481        Ok(())
482    }
483}