Skip to main content

solana_optimizer/
optimizer.rs

1use solana_rbpf::ebpf;
2use solana_rbpf::elf::Executable;
3use solana_rbpf::program::{BuiltinProgram, SBPFVersion, FunctionRegistry};
4use solana_rbpf::vm::Config;
5use std::sync::Arc;
6use elf::ElfBytes;
7use elf::endian::{AnyEndian, EndianParse};
8use elf::file::Class;
9use std::fs;
10
11#[derive(Debug, serde::Serialize)]
12pub struct Issue {
13    kind: String,
14    offset: usize,
15    desc: String,
16}
17
18pub struct Optimizer {
19    insns: Vec<solana_rbpf::ebpf::Insn>,
20    issues: Vec<Issue>,
21    elf_bytes: Vec<u8>,
22    text_section_idx: usize,
23}
24
25impl Optimizer {
26    pub fn new(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
27        let elf_bytes = fs::read(path)?;
28        let elf = ElfBytes::<AnyEndian>::minimal_parse(&elf_bytes)?;
29        let (shdrs_opt, strtab_opt) = elf.section_headers_with_strtab()?;
30        let shdrs = shdrs_opt.ok_or("No section headers")?;
31        let strtab = strtab_opt.ok_or("No string table")?;
32        let text_section_idx = shdrs
33            .iter()
34            .position(|sh| strtab.get(sh.sh_name as usize).ok() == Some(".text"))
35            .ok_or("No .text section")?;
36        let text_section = shdrs.get(text_section_idx).map_err(|_| "Invalid text section index")?;
37        let text_bytes = elf.section_data(&text_section)?.0;
38        let insns = Self::disassemble_text_bytes(text_bytes)?;
39        Ok(Self { insns, issues: Vec::new(), elf_bytes, text_section_idx })
40    }
41
42    fn disassemble_text_bytes(bytes: &[u8]) -> Result<Vec<solana_rbpf::ebpf::Insn>, Box<dyn std::error::Error>> {
43        let mut insns = Vec::new();
44        let mut offset = 0;
45        while offset + 8 <= bytes.len() {
46            let chunk = &bytes[offset..offset + 8];
47            insns.push(ebpf::Insn {
48                ptr: 0,
49                opc: chunk[0],
50                dst: chunk[1] & 0x0F,
51                src: (chunk[1] >> 4) & 0x0F,
52                off: i16::from_le_bytes([chunk[2], chunk[3]]),
53                imm: i64::from_le_bytes([chunk[4], chunk[5], chunk[6], chunk[7], 0, 0, 0, 0]),
54            });
55            offset += 8;
56        }
57        Ok(insns)
58    }
59
60    pub fn remove_logs(&mut self) {
61        let original_len = self.insns.len();
62        self.insns.retain(|insn| {
63            if insn.opc == 0x91 { // 假设 sol_log opcode
64                self.issues.push(Issue {
65                    kind: "LogRemoved".to_string(),
66                    offset: insn.off as usize,
67                    desc: "Removed redundant sol_log call".to_string(),
68                });
69                false
70            } else {
71                true
72            }
73        });
74        println!("Removed {} log instructions", original_len - self.insns.len());
75    }
76
77    #[allow(unused_mut)]
78    pub fn merge_loads(&mut self) {
79        let mut i = 0;
80        while i < self.insns.len() - 1 {
81            if self.insns[i].opc == ebpf::LD_DW_IMM {
82                let mut j = i + 1;
83                while j < self.insns.len() {
84                    if self.insns[j].opc == ebpf::LD_DW_IMM &&
85                       self.insns[i].dst == self.insns[j].dst &&
86                       self.insns[i].imm == self.insns[j].imm {
87                        self.insns.remove(j);
88                        self.issues.push(Issue {
89                            kind: "LoadMerged".to_string(),
90                            offset: i,
91                            desc: "Merged redundant load".to_string(),
92                        });
93                    } else {
94                        break;
95                    }
96                }
97            }
98            i += 1;
99        }
100        println!("Merged redundant load instructions");
101    }
102
103    pub fn merge_arithmetic(&mut self) {
104        let mut i = 0;
105        while i < self.insns.len() - 1 {
106            let curr = &self.insns[i];
107            if curr.opc == ebpf::ADD64_IMM || curr.opc == ebpf::SUB64_IMM {
108                let next = &self.insns[i + 1];
109                if next.opc == ebpf::ADD64_IMM && curr.dst == next.dst && curr.src == 0 && next.src == 0 {
110                    self.insns[i].imm += next.imm;
111                    self.insns.remove(i + 1);
112                    self.issues.push(Issue {
113                        kind: "AddMerged".to_string(),
114                        offset: i,
115                        desc: "Merged consecutive additions".to_string(),
116                    });
117                    continue;
118                } else if next.opc == ebpf::SUB64_IMM && curr.dst == next.dst && curr.src == 0 && next.src == 0 {
119                    if curr.opc == ebpf::ADD64_IMM {
120                        self.insns[i].imm -= next.imm;
121                    } else {
122                        self.insns[i].imm += next.imm;
123                        self.insns[i].opc = ebpf::ADD64_IMM;
124                    }
125                    self.insns.remove(i + 1);
126                    self.issues.push(Issue {
127                        kind: "ArithmeticMerged".to_string(),
128                        offset: i,
129                        desc: "Merged addition and subtraction".to_string(),
130                    });
131                    continue;
132                } else if (curr.opc == ebpf::ADD64_IMM && next.opc == ebpf::SUB64_IMM) ||
133                          (curr.opc == ebpf::SUB64_IMM && next.opc == ebpf::ADD64_IMM) {
134                    if curr.dst == next.dst && curr.imm == next.imm && curr.src == 0 && next.src == 0 {
135                        self.insns.remove(i + 1);
136                        self.insns.remove(i);
137                        self.issues.push(Issue {
138                            kind: "ArithmeticEliminated".to_string(),
139                            offset: i,
140                            desc: "Eliminated canceling addition and subtraction".to_string(),
141                        });
142                        continue;
143                    }
144                }
145            }
146            i += 1;
147        }
148        println!("Merged arithmetic instructions");
149    }
150
151    pub fn fold_constants(&mut self) {
152        let mut i = 0;
153        while i < self.insns.len() - 1 {
154            let curr = &self.insns[i];
155            if curr.opc == ebpf::LD_DW_IMM {
156                let next = &self.insns[i + 1];
157                if next.opc == ebpf::ADD64_IMM && next.dst == curr.dst && next.src == 0 {
158                    self.insns[i].imm += next.imm;
159                    self.insns.remove(i + 1);
160                    self.issues.push(Issue {
161                        kind: "ConstantFolded".to_string(),
162                        offset: i,
163                        desc: "Folded load and addition constants".to_string(),
164                    });
165                    continue;
166                } else if next.opc == ebpf::SUB64_IMM && next.dst == curr.dst && next.src == 0 {
167                    self.insns[i].imm -= next.imm;
168                    self.insns.remove(i + 1);
169                    self.issues.push(Issue {
170                        kind: "ConstantFolded".to_string(),
171                        offset: i,
172                        desc: "Folded load and subtraction constants".to_string(),
173                    });
174                    continue;
175                }
176            }
177            i += 1;
178        }
179        println!("Folded constant computations");
180    }
181
182    pub fn eliminate_dead_code(&mut self) {
183        let mut i = 0;
184        while i < self.insns.len() - 1 {
185            let curr = &self.insns[i];
186            let next = &self.insns[i + 1];
187            if curr.dst == next.dst && next.opc != ebpf::EXIT && !Self::reads_src(curr, next) {
188                self.insns.remove(i);
189                self.issues.push(Issue {
190                    kind: "DeadCodeEliminated".to_string(),
191                    offset: i,
192                    desc: "Removed overwritten dead code".to_string(),
193                });
194                continue;
195            }
196            i += 1;
197        }
198        println!("Eliminated dead code");
199    }
200
201    fn reads_src(curr: &solana_rbpf::ebpf::Insn, next: &solana_rbpf::ebpf::Insn) -> bool {
202        next.src == curr.dst || (next.opc == ebpf::JA && curr.dst == 0)
203    }
204
205    pub fn reduce_strength(&mut self) {
206        for (i, insn) in self.insns.iter_mut().enumerate() {
207            if insn.opc == ebpf::MUL64_IMM && insn.src == 0 {
208                match insn.imm {
209                    2 => {
210                        insn.opc = ebpf::LSH64_IMM;
211                        insn.imm = 1;
212                        self.issues.push(Issue {
213                            kind: "StrengthReduced".to_string(),
214                            offset: i,
215                            desc: "Replaced multiplication by 2 with left shift by 1".to_string(),
216                        });
217                    }
218                    4 => {
219                        insn.opc = ebpf::LSH64_IMM;
220                        insn.imm = 2;
221                        self.issues.push(Issue {
222                            kind: "StrengthReduced".to_string(),
223                            offset: i,
224                            desc: "Replaced multiplication by 4 with left shift by 2".to_string(),
225                        });
226                    }
227                    _ => {}
228                }
229            }
230        }
231        println!("Reduced instruction strength");
232    }
233
234    pub fn optimize_branches(&mut self) {
235        let mut i = 0;
236        while i < self.insns.len() - 1 {
237            let curr = &self.insns[i];
238            if curr.opc == ebpf::JEQ_IMM && curr.off >= 0 {
239                let next = &self.insns[i + 1];
240                if next.opc == ebpf::JA && (i as i16 + curr.off + 1) == (i as i16 + next.off + 1) {
241                    self.insns.remove(i + 1);
242                    self.insns.remove(i);
243                    self.issues.push(Issue {
244                        kind: "BranchEliminated".to_string(),
245                        offset: i,
246                        desc: "Eliminated redundant branch".to_string(),
247                    });
248                    continue;
249                }
250            }
251            i += 1;
252        }
253        println!("Optimized branch instructions");
254    }
255
256    pub fn check_size(&mut self) {
257        let size = self.insns.len() * 8;
258        if size > 128 * 1024 {
259            self.issues.push(Issue {
260                kind: "SizeExceeded".to_string(),
261                offset: 0,
262                desc: format!("Program size {} bytes exceeds 128KB", size),
263            });
264        }
265    }
266
267    pub fn generate(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
268        use solana_rbpf::vm::TestContextObject;
269    
270        let loader = Arc::new(BuiltinProgram::<TestContextObject>::new_loader(
271            Config::default(),
272            FunctionRegistry::default(),
273        ));
274    
275        let executable = Executable::from_text_bytes(
276            &self.insns.iter().flat_map(|insn| {
277                let imm_bytes = insn.imm.to_le_bytes();
278                [
279                    insn.opc,
280                    (insn.dst & 0x0F) | ((insn.src & 0x0F) << 4),
281                    insn.off.to_le_bytes()[0],
282                    insn.off.to_le_bytes()[1],
283                    imm_bytes[0],
284                    imm_bytes[1],
285                    imm_bytes[2],
286                    imm_bytes[3],
287                ]
288            }).collect::<Vec<u8>>(),
289            loader,
290            SBPFVersion::V2,
291            FunctionRegistry::default(),
292        )?;
293        let optimized_text = executable.get_text_bytes().1.to_vec();
294        let elf = ElfBytes::<AnyEndian>::minimal_parse(&self.elf_bytes)?;
295        let ehdr = elf.ehdr;
296        let (shdrs_opt, _) = elf.section_headers_with_strtab()?;
297        let shdrs = shdrs_opt.ok_or("No section headers")?;
298        let mut new_shdrs: Vec<_> = shdrs.iter().collect();
299        let mut elf_bytes = Vec::new();
300    
301        // 预留头部空间
302        let header_size = ehdr.e_ehsize as usize;
303        elf_bytes.resize(header_size, 0);
304    
305        // 计算段数据起始偏移
306        let mut offset = header_size as u64;
307        let mut section_data = Vec::new();
308    
309        // 写入所有段数据并更新偏移
310        for (i, sh) in new_shdrs.iter_mut().enumerate() {
311            let data = if i == self.text_section_idx {
312                optimized_text.clone()
313            } else {
314                elf.section_data(sh)?.0.to_vec()
315            };
316            section_data.push((offset, data.clone()));
317            sh.sh_offset = offset;
318            sh.sh_size = data.len() as u64;
319            offset += data.len() as u64;
320            offset = (offset + 7) & !7; // 8 字节对齐
321        }
322    
323        // 段表起始位置
324        let shoff = offset;
325        println!("e_shoff: 0x{:x}, offset after sections: 0x{:x}", shoff, offset);
326    
327        // 写入段数据
328        for (_, data) in section_data.iter() {
329            elf_bytes.extend_from_slice(data);
330            let padding = (8 - (data.len() % 8)) % 8;
331            elf_bytes.extend_from_slice(&vec![0; padding]);
332        }
333    
334        // 写入段表
335        let section_table_start = elf_bytes.len();
336        for sh in new_shdrs.iter() {
337            let sh_bytes = if ehdr.class == Class::ELF64 {
338                let mut bytes = Vec::new();
339                bytes.extend_from_slice(&sh.sh_name.to_le_bytes()); // 4 bytes
340                bytes.extend_from_slice(&sh.sh_type.to_le_bytes()); // 4 bytes
341                bytes.extend_from_slice(&sh.sh_flags.to_le_bytes()); // 8 bytes
342                bytes.extend_from_slice(&sh.sh_addr.to_le_bytes()); // 8 bytes
343                bytes.extend_from_slice(&sh.sh_offset.to_le_bytes()); // 8 bytes
344                bytes.extend_from_slice(&sh.sh_size.to_le_bytes()); // 8 bytes
345                bytes.extend_from_slice(&sh.sh_link.to_le_bytes()); // 4 bytes
346                bytes.extend_from_slice(&sh.sh_info.to_le_bytes()); // 4 bytes
347                bytes.extend_from_slice(&sh.sh_addralign.to_le_bytes()); // 8 bytes
348                bytes.extend_from_slice(&sh.sh_entsize.to_le_bytes()); // 8 bytes
349                bytes
350            } else {
351                let mut bytes = Vec::new();
352                bytes.extend_from_slice(&sh.sh_name.to_le_bytes()); // 4 bytes
353                bytes.extend_from_slice(&sh.sh_type.to_le_bytes()); // 4 bytes
354                bytes.extend_from_slice(&(sh.sh_flags as u32).to_le_bytes()); // 4 bytes
355                bytes.extend_from_slice(&(sh.sh_addr as u32).to_le_bytes()); // 4 bytes
356                bytes.extend_from_slice(&(sh.sh_offset as u32).to_le_bytes()); // 4 bytes
357                bytes.extend_from_slice(&(sh.sh_size as u32).to_le_bytes()); // 4 bytes
358                bytes.extend_from_slice(&sh.sh_link.to_le_bytes()); // 4 bytes
359                bytes.extend_from_slice(&sh.sh_info.to_le_bytes()); // 4 bytes
360                bytes.extend_from_slice(&(sh.sh_addralign as u32).to_le_bytes()); // 4 bytes
361                bytes.extend_from_slice(&(sh.sh_entsize as u32).to_le_bytes()); // 4 bytes
362                bytes
363            };
364            elf_bytes.extend_from_slice(&sh_bytes);
365        }
366    
367        // 构造并写入 ELF 头部
368        let class_val = match ehdr.class {
369            Class::ELF32 => elf::abi::ELFCLASS32,
370            Class::ELF64 => elf::abi::ELFCLASS64,
371        };
372        let endian_val = if ehdr.endianness.is_little() {
373            elf::abi::ELFDATA2LSB
374        } else {
375            elf::abi::ELFDATA2MSB
376        };
377        let mut ehdr_bytes = Vec::new();
378        ehdr_bytes.extend_from_slice(&[0x7f, b'E', b'L', b'F']); // ei_magic,4 bytes
379        ehdr_bytes.extend_from_slice(&[class_val, endian_val, ehdr.version.try_into()?, ehdr.osabi]); // 4 bytes
380        ehdr_bytes.extend_from_slice(&[ehdr.abiversion, 0, 0, 0, 0, 0, 0, 0]); // EI_PAD,8 bytes
381        ehdr_bytes.extend_from_slice(&ehdr.e_type.to_le_bytes()); // 2 bytes
382        ehdr_bytes.extend_from_slice(&ehdr.e_machine.to_le_bytes()); // 2 bytes
383        ehdr_bytes.extend_from_slice(&ehdr.version.to_le_bytes()); // 4 bytes
384        if ehdr.class == Class::ELF64 {
385            ehdr_bytes.extend_from_slice(&ehdr.e_entry.to_le_bytes()); // 8 bytes
386            ehdr_bytes.extend_from_slice(&ehdr.e_phoff.to_le_bytes()); // 8 bytes (保留原始值)
387            ehdr_bytes.extend_from_slice(&shoff.to_le_bytes()); // 8 bytes
388        } else {
389            ehdr_bytes.extend_from_slice(&(ehdr.e_entry as u32).to_le_bytes()); // 4 bytes
390            ehdr_bytes.extend_from_slice(&(ehdr.e_phoff as u32).to_le_bytes()); // 4 bytes
391            ehdr_bytes.extend_from_slice(&(shoff as u32).to_le_bytes()); // 4 bytes
392        }
393        ehdr_bytes.extend_from_slice(&ehdr.e_flags.to_le_bytes()); // 4 bytes
394        ehdr_bytes.extend_from_slice(&ehdr.e_ehsize.to_le_bytes()); // 2 bytes
395        ehdr_bytes.extend_from_slice(&ehdr.e_phentsize.to_le_bytes()); // 2 bytes
396        ehdr_bytes.extend_from_slice(&ehdr.e_phnum.to_le_bytes()); // 2 bytes
397        let shentsize: u16 = match ehdr.class {
398            Class::ELF32 => 40,
399            Class::ELF64 => 64,
400        };
401        ehdr_bytes.extend_from_slice(&shentsize.to_le_bytes()); // 2 bytes
402        ehdr_bytes.extend_from_slice(&ehdr.e_shnum.to_le_bytes()); // 2 bytes
403        ehdr_bytes.extend_from_slice(&ehdr.e_shstrndx.to_le_bytes()); // 2 bytes
404    
405        if ehdr_bytes.len() < header_size {
406            ehdr_bytes.resize(header_size, 0);
407        }
408        elf_bytes[..header_size].copy_from_slice(&ehdr_bytes);
409    
410        println!("Section table start: 0x{:x}, Final file length: 0x{:x} ({} bytes)", section_table_start, elf_bytes.len(), elf_bytes.len());
411        Ok(elf_bytes)
412    }
413
414    pub fn report(&self) -> String {
415        serde_json::to_string_pretty(&self.issues).unwrap_or("[]".to_string())
416    }
417}
418
419pub fn optimize_sbf(input_path: &str, output_path: &str) -> Result<String, Box<dyn std::error::Error>> {
420    let mut optimizer = Optimizer::new(input_path)?;
421    optimizer.remove_logs();
422    optimizer.merge_loads();
423    optimizer.merge_arithmetic();
424    optimizer.fold_constants();
425    //optimizer.eliminate_dead_code();
426    optimizer.reduce_strength();
427    optimizer.optimize_branches();
428    optimizer.check_size();
429    let optimized_bytes = optimizer.generate()?;
430    fs::write(output_path, optimized_bytes)?;
431    Ok(optimizer.report())
432}
433
434fn main() -> Result<(), Box<dyn std::error::Error>> {
435    let args: Vec<String> = std::env::args().collect();
436    if args.len() != 5 || args[1] != "--input" || args[3] != "--output" {
437        eprintln!("Usage: {} --input <input_file> --output <output_file>", args[0]);
438        std::process::exit(1);
439    }
440    let input_path = &args[2];
441    let output_path = &args[4];
442    let report = optimize_sbf(input_path, output_path)?;
443    println!("Optimization report:\n{}", report);
444    Ok(())
445}