Skip to main content

rust_asm/
types.rs

1use std::fmt;
2use std::hash::{Hash, Hasher};
3
/// A JVM type: `void`, a primitive, an array, an object (reference) type, or a method type.
///
/// Sort numbers (see [`Type::get_sort`]) are compatible with the constants of the Java ASM
/// library's `Type` class.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Type {
    /// `void` (descriptor `V`); only meaningful as a method return type.
    Void,
    /// `boolean` (descriptor `Z`).
    Boolean,
    /// `char` (descriptor `C`).
    Char,
    /// `byte` (descriptor `B`).
    Byte,
    /// `short` (descriptor `S`).
    Short,
    /// `int` (descriptor `I`).
    Int,
    /// `float` (descriptor `F`).
    Float,
    /// `long` (descriptor `J`).
    Long,
    /// `double` (descriptor `D`).
    Double,
    /// Array type, with the element type (which may itself be an array).
    Array(Box<Type>),
    /// Object type, storing the internal name (e.g. `"java/lang/Object"`).
    Object(String),
    /// Method type, storing argument types and return type.
    Method {
        argument_types: Vec<Type>,
        return_type: Box<Type>,
    },
}

impl Type {
    pub const VOID_TYPE: Type = Type::Void;
    pub const BOOLEAN_TYPE: Type = Type::Boolean;
    pub const CHAR_TYPE: Type = Type::Char;
    pub const BYTE_TYPE: Type = Type::Byte;
    pub const SHORT_TYPE: Type = Type::Short;
    pub const INT_TYPE: Type = Type::Int;
    pub const FLOAT_TYPE: Type = Type::Float;
    pub const LONG_TYPE: Type = Type::Long;
    pub const DOUBLE_TYPE: Type = Type::Double;

    /// Returns the `Type` corresponding to the given field or method descriptor.
    ///
    /// # Panics
    /// Panics if the descriptor is invalid: empty, truncated (e.g. an object descriptor
    /// without its closing `;` or a method descriptor without `)`), containing an unknown
    /// type character, or followed by extra characters after one complete descriptor.
    /// Use [`Type::try_get_type`] for a non-panicking variant.
    pub fn get_type(descriptor: &str) -> Self {
        Self::try_get_type(descriptor)
            .unwrap_or_else(|msg| panic!("{} (descriptor {:?})", msg, descriptor))
    }

    /// Parses a field or method descriptor without panicking.
    ///
    /// # Errors
    /// Returns a human-readable message if `descriptor` is not exactly one complete
    /// descriptor (same conditions under which [`Type::get_type`] panics).
    pub fn try_get_type(descriptor: &str) -> Result<Self, String> {
        let bytes = descriptor.as_bytes();
        let mut pos = 0;
        let ty = Self::parse(bytes, &mut pos)?;
        // A descriptor describes exactly one type; anything left over means it was malformed.
        if pos != bytes.len() {
            return Err(format!(
                "Trailing characters after position {} in descriptor",
                pos
            ));
        }
        Ok(ty)
    }

    /// Returns the `Type` corresponding to the given internal name.
    /// If the name starts with `'['`, it is treated as an array descriptor
    /// (and parsed with [`Type::get_type`], so it panics if malformed).
    pub fn get_object_type(internal_name: &str) -> Self {
        if internal_name.starts_with('[') {
            Self::get_type(internal_name)
        } else {
            Type::Object(internal_name.to_string())
        }
    }

    /// Returns the method `Type` corresponding to the given method descriptor.
    ///
    /// # Panics
    /// Panics if the descriptor is invalid, or is a valid field descriptor rather than a
    /// method descriptor.
    pub fn get_method_type(descriptor: &str) -> Self {
        let ty = Self::get_type(descriptor);
        match ty {
            Type::Method { .. } => ty,
            _ => panic!("Not a method descriptor: {}", descriptor),
        }
    }

    /// Creates a method type from its return type and argument types.
    pub fn get_method_type_from_parts(return_type: Type, argument_types: Vec<Type>) -> Self {
        Type::Method {
            argument_types,
            return_type: Box::new(return_type),
        }
    }

    /// Returns the sort of this type (as an integer, compatible with ASM constants):
    /// 0 = void, 1..=8 = boolean, char, byte, short, int, float, long, double,
    /// 9 = array, 10 = object, 11 = method.
    pub fn get_sort(&self) -> u8 {
        match self {
            Type::Void => 0,
            Type::Boolean => 1,
            Type::Char => 2,
            Type::Byte => 3,
            Type::Short => 4,
            Type::Int => 5,
            Type::Float => 6,
            Type::Long => 7,
            Type::Double => 8,
            Type::Array(_) => 9,
            Type::Object(_) => 10,
            Type::Method { .. } => 11,
        }
    }

    /// Returns the descriptor of this type (e.g. `"I"`, `"[Ljava/lang/String;"`, `"(IJ)V"`).
    pub fn get_descriptor(&self) -> String {
        let mut out = String::new();
        self.push_descriptor(&mut out);
        out
    }

    /// Appends this type's descriptor to `out`; used so nested arrays and method
    /// arguments are written into one buffer instead of one `String` per component.
    fn push_descriptor(&self, out: &mut String) {
        match self {
            Type::Void => out.push('V'),
            Type::Boolean => out.push('Z'),
            Type::Char => out.push('C'),
            Type::Byte => out.push('B'),
            Type::Short => out.push('S'),
            Type::Int => out.push('I'),
            Type::Float => out.push('F'),
            Type::Long => out.push('J'),
            Type::Double => out.push('D'),
            Type::Array(elem) => {
                out.push('[');
                elem.push_descriptor(out);
            }
            Type::Object(name) => {
                out.push('L');
                out.push_str(name);
                out.push(';');
            }
            Type::Method {
                argument_types,
                return_type,
            } => {
                out.push('(');
                for arg in argument_types {
                    arg.push_descriptor(out);
                }
                out.push(')');
                return_type.push_descriptor(out);
            }
        }
    }

    /// Returns the Java class name corresponding to this type (e.g. "int", "java.lang.Object[]").
    ///
    /// # Panics
    /// Panics if called on a method type.
    pub fn get_class_name(&self) -> String {
        match self {
            Type::Void => "void".to_string(),
            Type::Boolean => "boolean".to_string(),
            Type::Char => "char".to_string(),
            Type::Byte => "byte".to_string(),
            Type::Short => "short".to_string(),
            Type::Int => "int".to_string(),
            Type::Float => "float".to_string(),
            Type::Long => "long".to_string(),
            Type::Double => "double".to_string(),
            Type::Array(elem) => format!("{}[]", elem.get_class_name()),
            Type::Object(name) => name.replace('/', "."),
            Type::Method { .. } => panic!("get_class_name() called on a method type"),
        }
    }

    /// Returns the internal name of this type.
    /// For object types, this is the internal name (e.g. `"java/lang/Object"`).
    /// For array types, this is the descriptor itself (e.g. `"[I"`).
    /// For other types, returns `None`.
    pub fn internal_name(&self) -> Option<String> {
        match self {
            Type::Object(name) => Some(name.clone()),
            // The internal name of an array class is its descriptor.
            Type::Array(_) => Some(self.get_descriptor()),
            _ => None,
        }
    }

    /// If this is an array type, returns the number of dimensions.
    /// Otherwise returns 0.
    pub fn get_dimensions(&self) -> usize {
        match self {
            Type::Array(elem) => 1 + elem.get_dimensions(),
            _ => 0,
        }
    }

    /// If this is an array type, returns the component type one dimension down
    /// (which may itself be an array). Otherwise returns `None`.
    pub fn get_element_type(&self) -> Option<&Type> {
        match self {
            Type::Array(elem) => Some(elem.as_ref()),
            _ => None,
        }
    }

    /// If this is a method type, returns the argument types; otherwise `None`.
    pub fn get_argument_types(&self) -> Option<&[Type]> {
        match self {
            Type::Method { argument_types, .. } => Some(argument_types.as_slice()),
            _ => None,
        }
    }

    /// If this is a method type, returns the return type; otherwise `None`.
    pub fn get_return_type(&self) -> Option<&Type> {
        match self {
            Type::Method { return_type, .. } => Some(return_type.as_ref()),
            _ => None,
        }
    }

    /// Returns the size of values of this type in JVM slots: 0 for `void`,
    /// 2 for `long`/`double`, and 1 for every other type (method types included).
    pub fn get_size(&self) -> usize {
        match self {
            Type::Void => 0,
            Type::Long | Type::Double => 2,
            _ => 1,
        }
    }

    /// Returns the number of arguments of this method type.
    ///
    /// # Panics
    /// Panics if called on a non-method type.
    pub fn get_argument_count(&self) -> usize {
        match self {
            Type::Method { argument_types, .. } => argument_types.len(),
            _ => panic!("get_argument_count() called on a non-method type"),
        }
    }

    /// Parses one type from `bytes` starting at `*pos`, advancing `*pos` past it.
    ///
    /// Returns an error message instead of indexing out of bounds when the input ends
    /// early or contains an unknown type character.
    fn parse(bytes: &[u8], pos: &mut usize) -> Result<Self, String> {
        let first = match bytes.get(*pos) {
            Some(&b) => b,
            None => return Err(format!("Unexpected end of descriptor at position {}", *pos)),
        };
        *pos += 1;
        let ty = match first {
            b'V' => Type::Void,
            b'Z' => Type::Boolean,
            b'C' => Type::Char,
            b'B' => Type::Byte,
            b'S' => Type::Short,
            b'I' => Type::Int,
            b'F' => Type::Float,
            b'J' => Type::Long,
            b'D' => Type::Double,
            b'L' => {
                let start = *pos;
                let len = bytes[start..]
                    .iter()
                    .position(|&c| c == b';')
                    .ok_or_else(|| {
                        format!(
                            "Unterminated object descriptor starting at position {}: missing ';'",
                            start - 1
                        )
                    })?;
                let end = start + len;
                // `start` follows the ASCII 'L' and `end` is the ASCII ';', so both are char
                // boundaries of the original `&str` and this slice is always valid UTF-8.
                let name = std::str::from_utf8(&bytes[start..end])
                    .expect("Invalid UTF-8 in internal name")
                    .to_string();
                *pos = end + 1; // skip ';'
                Type::Object(name)
            }
            b'[' => Type::Array(Box::new(Self::parse(bytes, pos)?)),
            b'(' => {
                let mut argument_types = Vec::new();
                loop {
                    match bytes.get(*pos) {
                        Some(b')') => {
                            *pos += 1;
                            break;
                        }
                        Some(_) => argument_types.push(Self::parse(bytes, pos)?),
                        None => {
                            return Err(
                                "Unterminated method descriptor: missing ')'".to_string()
                            )
                        }
                    }
                }
                let return_type = Box::new(Self::parse(bytes, pos)?);
                Type::Method {
                    argument_types,
                    return_type,
                }
            }
            other => return Err(format!("Invalid descriptor character: {}", other as char)),
        };
        Ok(ty)
    }
}
281
282impl fmt::Display for Type {
283    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
284        write!(f, "{}", self.get_descriptor())
285    }
286}