1use rust_asm::class_reader::{AttributeInfo, LocalVariable, read_class_file};
2use rust_asm::class_writer::{
3 COMPUTE_FRAMES, COMPUTE_MAXS, ClassWriter as AsmClassWriter, FieldVisitor, MethodVisitor,
4};
5use rust_asm::constant_pool::{ConstantPoolBuilder, CpInfo};
6pub use rust_asm::insn::Label;
7use rust_asm::insn::LabelNode;
8pub use rust_asm::insn::{BootstrapArgument, Handle};
9use std::collections::HashMap;
10
11pub struct ClassFileWriter {
12 cw: AsmClassWriter,
13 class_signature: Option<String>,
14 nest_host: Option<String>,
15 nest_members: Vec<String>,
16 record_components: Vec<RecordComponentMetadata>,
17 runtime_invisible_annotations: Vec<AnnotationMetadata>,
18 method_metadata: Vec<MethodMetadata>,
19 field_metadata: Vec<FieldMetadata>,
20}
21
22impl ClassFileWriter {
23 pub fn new() -> Self {
24 Self {
25 cw: AsmClassWriter::new(COMPUTE_FRAMES | COMPUTE_MAXS),
26 class_signature: None,
27 nest_host: None,
28 nest_members: Vec::new(),
29 record_components: Vec::new(),
30 runtime_invisible_annotations: Vec::new(),
31 method_metadata: Vec::new(),
32 field_metadata: Vec::new(),
33 }
34 }
35
36 pub fn visit(
37 &mut self,
38 major_version: u16,
39 access_flags: u16,
40 name: &str,
41 super_name: Option<&str>,
42 interfaces: &[&str],
43 ) {
44 self.cw
45 .visit(major_version, 0, access_flags, name, super_name, interfaces);
46 }
47
48 pub fn visit_method(
49 &mut self,
50 access_flags: u16,
51 name: &str,
52 descriptor: &str,
53 ) -> MethodWriter {
54 let mv = self.cw.visit_method(access_flags, name, descriptor);
55 MethodWriter {
56 inner: mv,
57 name: name.to_string(),
58 descriptor: descriptor.to_string(),
59 signature: None,
60 exceptions: Vec::new(),
61 local_variables: Vec::new(),
62 }
63 }
64
65 pub fn visit_field(&mut self, access_flags: u16, name: &str, descriptor: &str) -> FieldWriter {
66 let fv = self.cw.visit_field(access_flags, name, descriptor);
67 FieldWriter {
68 inner: fv,
69 name: name.to_string(),
70 descriptor: descriptor.to_string(),
71 signature: None,
72 }
73 }
74
75 pub fn visit_signature(&mut self, signature: &str) {
76 self.class_signature = Some(signature.to_string());
77 }
78
79 pub fn visit_source_file(&mut self, name: &str) {
80 self.cw.visit_source_file(name);
81 }
82
83 pub fn visit_nest_host(&mut self, host: &str) {
84 self.nest_host = Some(host.to_string());
85 }
86
87 pub fn visit_nest_member(&mut self, member: &str) {
88 self.nest_members.push(member.to_string());
89 }
90
91 pub fn visit_record_component(
92 &mut self,
93 name: &str,
94 descriptor: &str,
95 signature: Option<&str>,
96 ) {
97 self.record_components.push(RecordComponentMetadata {
98 name: name.to_string(),
99 descriptor: descriptor.to_string(),
100 signature: signature.map(str::to_string),
101 });
102 }
103
104 pub fn visit_runtime_invisible_annotation(&mut self, annotation: AnnotationMetadata) {
105 self.runtime_invisible_annotations.push(annotation);
106 }
107
108 pub fn to_bytes(self) -> Result<Vec<u8>, String> {
109 let mut class_node = self.cw.to_class_node().map_err(|e| e.to_string())?;
110 add_extra_attributes(
111 &mut class_node,
112 ExtraAttributes {
113 class_signature: self.class_signature.as_deref(),
114 nest_host: self.nest_host.as_deref(),
115 nest_members: &self.nest_members,
116 record_components: &self.record_components,
117 runtime_invisible_annotations: &self.runtime_invisible_annotations,
118 field_metadata: &self.field_metadata,
119 method_metadata: &self.method_metadata,
120 },
121 );
122
123 let first_pass =
124 AsmClassWriter::write_class_node(&class_node, COMPUTE_FRAMES | COMPUTE_MAXS)
125 .map_err(|e| format!("{:?}", e))?;
126 let code_lengths = method_code_lengths(&first_pass)?;
127 add_local_variables(&mut class_node, &self.method_metadata, &code_lengths);
128
129 AsmClassWriter::write_class_node(&class_node, COMPUTE_FRAMES | COMPUTE_MAXS)
130 .map_err(|e| format!("{:?}", e))
131 }
132}
133
134impl Default for ClassFileWriter {
135 fn default() -> Self {
136 Self::new()
137 }
138}
139
140pub struct MethodWriter {
141 inner: MethodVisitor,
142 name: String,
143 descriptor: String,
144 signature: Option<String>,
145 exceptions: Vec<String>,
146 local_variables: Vec<LocalVariableSpec>,
147}
148
149impl MethodWriter {
150 pub fn visit_code(&mut self) {
151 self.inner.visit_code();
152 }
153
154 pub fn visit_insn(&mut self, opcode: u8) {
155 self.inner.visit_insn(opcode);
156 }
157
158 pub fn visit_var_insn(&mut self, opcode: u8, var_index: u16) {
159 self.inner.visit_var_insn(opcode, var_index);
160 }
161
162 pub fn visit_type_insn(&mut self, opcode: u8, type_name: &str) {
163 self.inner.visit_type_insn(opcode, type_name);
164 }
165
166 pub fn visit_new_array(&mut self, array_type: u8) {
167 self.inner
168 .visit_var_insn(rust_asm::opcodes::NEWARRAY, array_type as u16);
169 }
170
171 pub fn visit_jump_insn(&mut self, opcode: u8, target: Label) {
172 self.inner.visit_jump_insn(opcode, target);
173 }
174
175 pub fn visit_lookup_switch(&mut self, default: Label, pairs: &[(i32, Label)]) {
176 self.inner.visit_lookup_switch(default, pairs);
177 }
178
179 pub fn visit_label(&mut self, label: Label) {
180 self.inner.visit_label(label);
181 }
182
183 pub fn visit_line_number(&mut self, line: u16, label: Label) {
184 self.inner
185 .visit_line_number(line, LabelNode::from_label(label));
186 }
187
188 pub fn visit_try_catch_block(
189 &mut self,
190 start: Label,
191 end: Label,
192 handler: Label,
193 catch_type: Option<&str>,
194 ) {
195 self.inner
196 .visit_try_catch_block(start, end, handler, catch_type);
197 }
198
199 pub fn visit_local_variable(&mut self, name: &str, descriptor: &str, index: u16) {
200 self.local_variables.push(LocalVariableSpec {
201 name: name.to_string(),
202 descriptor: descriptor.to_string(),
203 index,
204 });
205 }
206
207 pub fn visit_signature(&mut self, signature: &str) {
208 self.signature = Some(signature.to_string());
209 }
210
211 pub fn visit_exception(&mut self, internal_name: &str) {
212 self.exceptions.push(internal_name.to_string());
213 }
214
215 pub fn visit_field_insn(&mut self, opcode: u8, owner: &str, name: &str, descriptor: &str) {
216 self.inner.visit_field_insn(opcode, owner, name, descriptor);
217 }
218
219 pub fn visit_method_insn(
220 &mut self,
221 opcode: u8,
222 owner: &str,
223 name: &str,
224 descriptor: &str,
225 is_interface: bool,
226 ) {
227 self.inner
228 .visit_method_insn(opcode, owner, name, descriptor, is_interface);
229 }
230
231 pub fn visit_invoke_dynamic_insn(
232 &mut self,
233 name: &str,
234 descriptor: &str,
235 bootstrap_method: Handle,
236 bootstrap_args: &[BootstrapArgument],
237 ) {
238 self.inner
239 .visit_invokedynamic_insn(name, descriptor, bootstrap_method, bootstrap_args);
240 }
241
242 pub fn visit_ldc_insn_int(&mut self, value: i32) {
243 self.inner
244 .visit_ldc_insn(rust_asm::insn::LdcInsnNode::int(value));
245 }
246
247 pub fn visit_ldc_insn_float(&mut self, value: f32) {
248 self.inner
249 .visit_ldc_insn(rust_asm::insn::LdcInsnNode::float(value));
250 }
251
252 pub fn visit_ldc_insn_long(&mut self, value: i64) {
253 self.inner
254 .visit_ldc_insn(rust_asm::insn::LdcInsnNode::long(value));
255 }
256
257 pub fn visit_ldc_insn_double(&mut self, value: f64) {
258 self.inner
259 .visit_ldc_insn(rust_asm::insn::LdcInsnNode::double(value));
260 }
261
262 pub fn visit_ldc_insn_string(&mut self, value: &str) {
263 self.inner
264 .visit_ldc_insn(rust_asm::insn::LdcInsnNode::string(value));
265 }
266
267 pub fn visit_ldc_insn_type(&mut self, type_name: &str) {
268 self.inner
269 .visit_ldc_insn(rust_asm::insn::LdcInsnNode::typed(
270 rust_asm::types::Type::get_object_type(type_name),
271 ));
272 }
273
274 pub fn visit_iinc_insn(&mut self, var_index: u16, increment: i16) {
275 self.inner.visit_iinc_insn(var_index, increment);
276 }
277
278 pub fn visit_maxs(&mut self, max_stack: u16, max_locals: u16) {
279 self.inner.visit_maxs(max_stack, max_locals);
280 }
281
282 pub fn visit_end(self, cw: &mut ClassFileWriter) {
283 cw.method_metadata.push(MethodMetadata {
284 name: self.name.clone(),
285 descriptor: self.descriptor.clone(),
286 signature: self.signature.clone(),
287 exceptions: self.exceptions.clone(),
288 local_variables: self.local_variables.clone(),
289 });
290 self.inner.visit_end(&mut cw.cw);
291 }
292}
293
294pub struct FieldWriter {
295 inner: FieldVisitor,
296 name: String,
297 descriptor: String,
298 signature: Option<String>,
299}
300
301impl FieldWriter {
302 pub fn visit_signature(&mut self, signature: &str) {
303 self.signature = Some(signature.to_string());
304 }
305
306 pub fn visit_end(self, cw: &mut ClassFileWriter) {
307 cw.field_metadata.push(FieldMetadata {
308 name: self.name.clone(),
309 descriptor: self.descriptor.clone(),
310 signature: self.signature.clone(),
311 });
312 self.inner.visit_end(&mut cw.cw);
313 }
314}
315
/// A local variable recorded by `MethodWriter::visit_local_variable`.
#[derive(Clone, Debug)]
struct LocalVariableSpec {
    /// Variable name.
    name: String,
    /// Variable descriptor.
    descriptor: String,
    /// Local variable index.
    index: u16,
}
322
323#[derive(Debug, Clone)]
324struct MethodMetadata {
325 name: String,
326 descriptor: String,
327 signature: Option<String>,
328 exceptions: Vec<String>,
329 local_variables: Vec<LocalVariableSpec>,
330}
331
/// Per-field data handed to the class writer by `FieldWriter::visit_end`.
#[derive(Clone, Debug)]
struct FieldMetadata {
    /// Field name; used with `descriptor` to match the field node.
    name: String,
    /// Field descriptor.
    descriptor: String,
    /// Signature to attach to the field, if any.
    signature: Option<String>,
}
338
/// One record component recorded by `ClassFileWriter::visit_record_component`.
#[derive(Clone, Debug)]
struct RecordComponentMetadata {
    /// Component name.
    name: String,
    /// Component descriptor.
    descriptor: String,
    /// Signature written as a nested `Signature` attribute, if any.
    signature: Option<String>,
}
345
346#[derive(Debug, Clone)]
347pub struct AnnotationMetadata {
348 pub descriptor: String,
349 pub elements: Vec<AnnotationElementMetadata>,
350}
351
352#[derive(Debug, Clone)]
353pub struct AnnotationElementMetadata {
354 pub name: String,
355 pub value: AnnotationElementValueMetadata,
356}
357
/// The value kinds an annotation element can carry.
#[derive(Clone, Debug)]
pub enum AnnotationElementValueMetadata {
    /// A string value.
    String(String),
    /// An integer value, stored as `i64`.
    Int(i64),
    /// A boolean value.
    Boolean(bool),
}
364
365struct ExtraAttributes<'a> {
366 class_signature: Option<&'a str>,
367 nest_host: Option<&'a str>,
368 nest_members: &'a [String],
369 record_components: &'a [RecordComponentMetadata],
370 runtime_invisible_annotations: &'a [AnnotationMetadata],
371 field_metadata: &'a [FieldMetadata],
372 method_metadata: &'a [MethodMetadata],
373}
374
375fn add_extra_attributes(class_node: &mut rust_asm::nodes::ClassNode, extras: ExtraAttributes<'_>) {
376 let mut cp = ConstantPoolBuilder::from_pool(class_node.constant_pool.clone());
377 if extras.class_signature.is_some()
378 || extras
379 .field_metadata
380 .iter()
381 .any(|metadata| metadata.signature.is_some())
382 || extras
383 .method_metadata
384 .iter()
385 .any(|metadata| metadata.signature.is_some())
386 {
387 cp.utf8("Signature");
388 }
389 if extras
390 .method_metadata
391 .iter()
392 .any(|metadata| !metadata.exceptions.is_empty())
393 {
394 cp.utf8("Exceptions");
395 }
396 if extras.nest_host.is_some() {
397 cp.utf8("NestHost");
398 }
399 if !extras.nest_members.is_empty() {
400 cp.utf8("NestMembers");
401 }
402 if !extras.record_components.is_empty() {
403 cp.utf8("Record");
404 }
405 if !extras.runtime_invisible_annotations.is_empty() {
406 cp.utf8("RuntimeInvisibleAnnotations");
407 }
408 for metadata in extras.field_metadata {
409 cp.utf8(&metadata.name);
410 cp.utf8(&metadata.descriptor);
411 }
412 for component in extras.record_components {
413 cp.utf8(&component.name);
414 cp.utf8(&component.descriptor);
415 if let Some(signature) = &component.signature {
416 cp.utf8("Signature");
417 cp.utf8(signature);
418 }
419 }
420 for annotation in extras.runtime_invisible_annotations {
421 cp.utf8(&annotation.descriptor);
422 for element in &annotation.elements {
423 cp.utf8(&element.name);
424 match &element.value {
425 AnnotationElementValueMetadata::String(value) => {
426 cp.utf8(value);
427 }
428 AnnotationElementValueMetadata::Int(value) => {
429 cp.integer(*value as i32);
430 }
431 AnnotationElementValueMetadata::Boolean(value) => {
432 cp.integer(i32::from(*value));
433 }
434 }
435 }
436 }
437 for metadata in extras.method_metadata {
438 cp.utf8(&metadata.name);
439 cp.utf8(&metadata.descriptor);
440 }
441
442 if let Some(signature) = extras.class_signature {
443 add_signature_attribute(&mut class_node.attributes, &mut cp, signature);
444 }
445 if let Some(host) = extras.nest_host {
446 add_nest_host_attribute(&mut class_node.attributes, &mut cp, host);
447 }
448 if !extras.nest_members.is_empty() {
449 add_nest_members_attribute(&mut class_node.attributes, &mut cp, extras.nest_members);
450 }
451 if !extras.record_components.is_empty() {
452 add_record_attribute(
453 &mut class_node.attributes,
454 &mut cp,
455 extras.record_components,
456 );
457 }
458 if !extras.runtime_invisible_annotations.is_empty() {
459 add_runtime_invisible_annotations_attribute(
460 &mut class_node.attributes,
461 &mut cp,
462 extras.runtime_invisible_annotations,
463 );
464 }
465
466 for (field, metadata) in class_node.fields.iter_mut().zip(extras.field_metadata) {
467 if field.name == metadata.name
468 && field.descriptor == metadata.descriptor
469 && let Some(signature) = metadata.signature.as_deref()
470 {
471 add_signature_attribute(&mut field.attributes, &mut cp, signature);
472 }
473 }
474
475 for (method, metadata) in class_node.methods.iter_mut().zip(extras.method_metadata) {
476 if method.name == metadata.name
477 && method.descriptor == metadata.descriptor
478 && let Some(signature) = metadata.signature.as_deref()
479 {
480 add_signature_attribute(&mut method.attributes, &mut cp, signature);
481 }
482 if method.name == metadata.name
483 && method.descriptor == metadata.descriptor
484 && !metadata.exceptions.is_empty()
485 {
486 add_exceptions_attribute(&mut method.attributes, &mut cp, &metadata.exceptions);
487 }
488 }
489
490 class_node.constant_pool = cp.into_pool();
491}
492
493fn add_runtime_invisible_annotations_attribute(
494 attributes: &mut Vec<AttributeInfo>,
495 cp: &mut ConstantPoolBuilder,
496 annotations: &[AnnotationMetadata],
497) {
498 attributes.retain(
499 |attr| !matches!(attr, AttributeInfo::Unknown { name, .. } if name == "RuntimeInvisibleAnnotations"),
500 );
501 let mut info = Vec::new();
502 info.extend_from_slice(&(annotations.len() as u16).to_be_bytes());
503 for annotation in annotations {
504 info.extend_from_slice(&cp.utf8(&annotation.descriptor).to_be_bytes());
505 info.extend_from_slice(&(annotation.elements.len() as u16).to_be_bytes());
506 for element in &annotation.elements {
507 info.extend_from_slice(&cp.utf8(&element.name).to_be_bytes());
508 write_annotation_value(&mut info, cp, &element.value);
509 }
510 }
511 attributes.push(AttributeInfo::Unknown {
512 name: "RuntimeInvisibleAnnotations".to_string(),
513 info,
514 });
515}
516
517fn write_annotation_value(
518 info: &mut Vec<u8>,
519 cp: &mut ConstantPoolBuilder,
520 value: &AnnotationElementValueMetadata,
521) {
522 match value {
523 AnnotationElementValueMetadata::String(value) => {
524 info.push(b's');
525 info.extend_from_slice(&cp.utf8(value).to_be_bytes());
526 }
527 AnnotationElementValueMetadata::Int(value) => {
528 info.push(b'I');
529 info.extend_from_slice(&cp.integer(*value as i32).to_be_bytes());
530 }
531 AnnotationElementValueMetadata::Boolean(value) => {
532 info.push(b'Z');
533 info.extend_from_slice(&cp.integer(i32::from(*value)).to_be_bytes());
534 }
535 }
536}
537
538fn add_record_attribute(
539 attributes: &mut Vec<AttributeInfo>,
540 cp: &mut ConstantPoolBuilder,
541 components: &[RecordComponentMetadata],
542) {
543 attributes
544 .retain(|attr| !matches!(attr, AttributeInfo::Unknown { name, .. } if name == "Record"));
545 let mut info = Vec::new();
546 info.extend_from_slice(&(components.len() as u16).to_be_bytes());
547 for component in components {
548 info.extend_from_slice(&cp.utf8(&component.name).to_be_bytes());
549 info.extend_from_slice(&cp.utf8(&component.descriptor).to_be_bytes());
550 if let Some(signature) = &component.signature {
551 info.extend_from_slice(&1u16.to_be_bytes());
552 info.extend_from_slice(&cp.utf8("Signature").to_be_bytes());
553 info.extend_from_slice(&2u32.to_be_bytes());
554 info.extend_from_slice(&cp.utf8(signature).to_be_bytes());
555 } else {
556 info.extend_from_slice(&0u16.to_be_bytes());
557 }
558 }
559 attributes.push(AttributeInfo::Unknown {
560 name: "Record".to_string(),
561 info,
562 });
563}
564
565fn add_nest_host_attribute(
566 attributes: &mut Vec<AttributeInfo>,
567 cp: &mut ConstantPoolBuilder,
568 host: &str,
569) {
570 attributes
571 .retain(|attr| !matches!(attr, AttributeInfo::Unknown { name, .. } if name == "NestHost"));
572 attributes.push(AttributeInfo::Unknown {
573 name: "NestHost".to_string(),
574 info: cp.class(host).to_be_bytes().to_vec(),
575 });
576}
577
578fn add_nest_members_attribute(
579 attributes: &mut Vec<AttributeInfo>,
580 cp: &mut ConstantPoolBuilder,
581 members: &[String],
582) {
583 attributes.retain(
584 |attr| !matches!(attr, AttributeInfo::Unknown { name, .. } if name == "NestMembers"),
585 );
586 let mut info = Vec::new();
587 info.extend_from_slice(&(members.len() as u16).to_be_bytes());
588 for member in members {
589 info.extend_from_slice(&cp.class(member).to_be_bytes());
590 }
591 attributes.push(AttributeInfo::Unknown {
592 name: "NestMembers".to_string(),
593 info,
594 });
595}
596
597fn add_signature_attribute(
598 attributes: &mut Vec<AttributeInfo>,
599 cp: &mut ConstantPoolBuilder,
600 signature: &str,
601) {
602 attributes.retain(|attr| !matches!(attr, AttributeInfo::Signature { .. }));
603 let signature_index = cp.utf8(signature);
604 attributes.push(AttributeInfo::Signature { signature_index });
605}
606
607fn add_exceptions_attribute(
608 attributes: &mut Vec<AttributeInfo>,
609 cp: &mut ConstantPoolBuilder,
610 exceptions: &[String],
611) {
612 attributes.retain(|attr| !matches!(attr, AttributeInfo::Exceptions { .. }));
613 let exception_index_table = exceptions
614 .iter()
615 .map(|exception| cp.class(exception))
616 .collect();
617 attributes.push(AttributeInfo::Exceptions {
618 exception_index_table,
619 });
620}
621
622fn add_local_variables(
623 class_node: &mut rust_asm::nodes::ClassNode,
624 method_metadata: &[MethodMetadata],
625 code_lengths: &HashMap<(String, String), u16>,
626) {
627 let mut cp = ConstantPoolBuilder::from_pool(class_node.constant_pool.clone());
628 cp.utf8("LocalVariableTable");
629
630 for (method, metadata) in class_node.methods.iter_mut().zip(method_metadata) {
631 if method.name != metadata.name || method.descriptor != metadata.descriptor {
632 continue;
633 }
634 let Some(length) = code_lengths
635 .get(&(metadata.name.clone(), metadata.descriptor.clone()))
636 .copied()
637 .filter(|length| *length > 0)
638 else {
639 continue;
640 };
641
642 method.local_variables.clear();
643 for variable in &metadata.local_variables {
644 method.local_variables.push(LocalVariable {
645 start_pc: 0,
646 length,
647 name_index: cp.utf8(&variable.name),
648 descriptor_index: cp.utf8(&variable.descriptor),
649 index: variable.index,
650 });
651 }
652 }
653
654 class_node.constant_pool = cp.into_pool();
655}
656
657fn method_code_lengths(bytes: &[u8]) -> Result<HashMap<(String, String), u16>, String> {
658 let class_file = read_class_file(bytes).map_err(|e| format!("{:?}", e))?;
659 let mut lengths = HashMap::new();
660
661 for method in class_file.methods {
662 let name = cp_utf8(&class_file.constant_pool, method.name_index)?.to_string();
663 let descriptor = cp_utf8(&class_file.constant_pool, method.descriptor_index)?.to_string();
664 let length = method
665 .attributes
666 .iter()
667 .find_map(|attr| match attr {
668 AttributeInfo::Code(code) => Some(code.code.len().min(u16::MAX as usize) as u16),
669 _ => None,
670 })
671 .unwrap_or(0);
672 lengths.insert((name, descriptor), length);
673 }
674
675 Ok(lengths)
676}
677
678fn cp_utf8(pool: &[CpInfo], index: u16) -> Result<&str, String> {
679 match pool.get(index as usize) {
680 Some(CpInfo::Utf8(value)) => Ok(value.as_str()),
681 _ => Err(format!("invalid UTF-8 constant pool index {index}")),
682 }
683}