1use alloc::string::String;
18use alloc::vec::Vec;
19
20use zerodds_cdr::{BufferReader, BufferWriter, Endianness};
21use zerodds_corba_csiv2::CompoundSecMechList;
22use zerodds_corba_iiop::profile_body::CdrError;
23
24use crate::component_tags::ComponentId;
25
26#[derive(Debug, Clone, PartialEq, Eq)]
28pub struct TaggedComponent {
29 pub tag: ComponentId,
31 pub component_data: Vec<u8>,
33}
34
35impl TaggedComponent {
36 pub fn encode(&self, w: &mut BufferWriter) -> Result<(), CdrError> {
41 w.write_u32(self.tag.as_u32())?;
42 let n = u32::try_from(self.component_data.len()).map_err(|_| CdrError::Overflow)?;
43 w.write_u32(n)?;
44 w.write_bytes(&self.component_data)?;
45 Ok(())
46 }
47
48 pub fn decode(r: &mut BufferReader<'_>) -> Result<Self, CdrError> {
53 let tag = ComponentId::from_u32(r.read_u32()?);
54 let n = r.read_u32()? as usize;
55 let bytes = r.read_bytes(n)?;
56 Ok(Self {
57 tag,
58 component_data: bytes.to_vec(),
59 })
60 }
61
62 pub fn structured(&self) -> Result<StructuredComponent, CdrError> {
68 StructuredComponent::decode(self.tag, &self.component_data)
69 }
70}
71
/// Payload of the [`ComponentId::OrbType`] component: a single `u32` ORB
/// type identifier.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OrbType(pub u32);
79
/// One code-set entry of a [`CodeSetComponentInfo`].
///
/// Encoded as a `u32` native code-set id followed by a `u32`-counted list of
/// `u32` conversion code-set ids.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CodeSetComponent {
    /// Native code-set id.
    pub native_code_set: u32,
    /// Conversion code-set ids, in wire order.
    pub conversion_code_sets: Vec<u32>,
}
90
/// Payload of the [`ComponentId::CodeSets`] component: one
/// [`CodeSetComponent`] for `char` data followed by one for `wchar` data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CodeSetComponentInfo {
    /// Code-set info for `char` data (encoded first).
    pub for_char_data: CodeSetComponent,
    /// Code-set info for `wchar` data (encoded second).
    pub for_wchar_data: CodeSetComponent,
}
99
/// A host/port pair: a CDR string followed by a `u16` port.
///
/// This is the payload of [`ComponentId::AlternateIiopAddress`] and also the
/// element type of [`TlsSecTrans::addresses`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AlternateIiopAddress {
    /// Host name or address string.
    pub host: String,
    /// Port number.
    pub port: u16,
}
108
/// Payload of the [`ComponentId::SslSecTrans`] component: three `u16` fields,
/// in wire order `target_supports`, `target_requires`, `port`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Ssl {
    /// Options the target supports (presumably CSIIOP `AssociationOptions`
    /// bit flags; confirm against the spec).
    pub target_supports: u16,
    /// Options the target requires (presumably CSIIOP `AssociationOptions`
    /// bit flags; confirm against the spec).
    pub target_requires: u16,
    /// Port of the SSL endpoint.
    pub port: u16,
}
121
/// Payload of the [`ComponentId::TlsSecTrans`] component: two `u16` option
/// fields followed by a `u32`-counted list of host/port addresses.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TlsSecTrans {
    /// Options the target supports (presumably CSIIOP `AssociationOptions`
    /// bit flags; confirm against the spec).
    pub target_supports: u16,
    /// Options the target requires (presumably CSIIOP `AssociationOptions`
    /// bit flags; confirm against the spec).
    pub target_requires: u16,
    /// Transport addresses, in wire order.
    pub addresses: Vec<AlternateIiopAddress>,
}
133
/// Payload of the [`ComponentId::RmiCustomMaxStreamFormat`] component: a
/// single octet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StreamFormatVersion(pub u8);
137
138#[derive(Debug, Clone, PartialEq, Eq)]
140pub enum StructuredComponent {
141 OrbType(OrbType),
143 CodeSets(CodeSetComponentInfo),
145 AlternateIiopAddress(AlternateIiopAddress),
147 Ssl(Ssl),
149 TlsSecTrans(TlsSecTrans),
151 CsiSecMechList(CompoundSecMechList),
153 StreamFormatVersion(StreamFormatVersion),
155 JavaCodebase(String),
157 Opaque {
159 tag: ComponentId,
161 bytes: Vec<u8>,
163 },
164}
165
166impl StructuredComponent {
167 pub fn decode(tag: ComponentId, encap: &[u8]) -> Result<Self, CdrError> {
173 let endianness = read_endianness(encap)?;
174 let body = &encap[1..];
175 match tag {
176 ComponentId::OrbType => {
177 let mut r = BufferReader::new(body, endianness);
178 Ok(Self::OrbType(OrbType(r.read_u32()?)))
179 }
180 ComponentId::CodeSets => {
181 let mut r = BufferReader::new(body, endianness);
182 let for_char = decode_code_set_component(&mut r)?;
183 let for_wchar = decode_code_set_component(&mut r)?;
184 Ok(Self::CodeSets(CodeSetComponentInfo {
185 for_char_data: for_char,
186 for_wchar_data: for_wchar,
187 }))
188 }
189 ComponentId::AlternateIiopAddress => {
190 let mut r = BufferReader::new(body, endianness);
191 let host = r.read_string()?;
192 let port = r.read_u16()?;
193 Ok(Self::AlternateIiopAddress(AlternateIiopAddress {
194 host,
195 port,
196 }))
197 }
198 ComponentId::SslSecTrans => {
199 let mut r = BufferReader::new(body, endianness);
200 Ok(Self::Ssl(Ssl {
201 target_supports: r.read_u16()?,
202 target_requires: r.read_u16()?,
203 port: r.read_u16()?,
204 }))
205 }
206 ComponentId::TlsSecTrans => {
207 let mut r = BufferReader::new(body, endianness);
208 let target_supports = r.read_u16()?;
209 let target_requires = r.read_u16()?;
210 let n = r.read_u32()? as usize;
211 let mut addresses = Vec::with_capacity(n.min(32));
212 for _ in 0..n {
213 let host = r.read_string()?;
214 let port = r.read_u16()?;
215 addresses.push(AlternateIiopAddress { host, port });
216 }
217 Ok(Self::TlsSecTrans(TlsSecTrans {
218 target_supports,
219 target_requires,
220 addresses,
221 }))
222 }
223 ComponentId::CsiSecMechList => {
224 let mut r = BufferReader::new(body, endianness);
225 Ok(Self::CsiSecMechList(CompoundSecMechList::decode(&mut r)?))
226 }
227 ComponentId::RmiCustomMaxStreamFormat => {
228 let mut r = BufferReader::new(body, endianness);
229 Ok(Self::StreamFormatVersion(StreamFormatVersion(r.read_u8()?)))
230 }
231 ComponentId::JavaCodebase => {
232 let mut r = BufferReader::new(body, endianness);
233 Ok(Self::JavaCodebase(r.read_string()?))
234 }
235 other => Ok(Self::Opaque {
236 tag: other,
237 bytes: encap.to_vec(),
238 }),
239 }
240 }
241
242 pub fn encode_encapsulation(&self, endianness: Endianness) -> Result<Vec<u8>, CdrError> {
248 let mut out = Vec::with_capacity(64);
249 out.push(endianness_to_byte(endianness));
250 let mut w = BufferWriter::new(endianness);
251 match self {
252 Self::OrbType(OrbType(v)) => w.write_u32(*v)?,
253 Self::CodeSets(info) => {
254 encode_code_set_component(&mut w, &info.for_char_data)?;
255 encode_code_set_component(&mut w, &info.for_wchar_data)?;
256 }
257 Self::AlternateIiopAddress(a) => {
258 w.write_string(&a.host)?;
259 w.write_u16(a.port)?;
260 }
261 Self::Ssl(s) => {
262 w.write_u16(s.target_supports)?;
263 w.write_u16(s.target_requires)?;
264 w.write_u16(s.port)?;
265 }
266 Self::TlsSecTrans(t) => {
267 w.write_u16(t.target_supports)?;
268 w.write_u16(t.target_requires)?;
269 let n = u32::try_from(t.addresses.len()).map_err(|_| CdrError::Overflow)?;
270 w.write_u32(n)?;
271 for a in &t.addresses {
272 w.write_string(&a.host)?;
273 w.write_u16(a.port)?;
274 }
275 }
276 Self::CsiSecMechList(list) => list.encode(&mut w)?,
277 Self::StreamFormatVersion(StreamFormatVersion(v)) => w.write_u8(*v)?,
278 Self::JavaCodebase(s) => w.write_string(s)?,
279 Self::Opaque { bytes, .. } => {
280 return Ok(bytes.clone());
283 }
284 }
285 out.extend_from_slice(w.as_bytes());
286 Ok(out)
287 }
288}
289
290fn read_endianness(encap: &[u8]) -> Result<Endianness, CdrError> {
291 if encap.is_empty() {
292 return Err(CdrError::Truncated);
293 }
294 match encap[0] {
295 0 => Ok(Endianness::Big),
296 1 => Ok(Endianness::Little),
297 _ => Err(CdrError::InvalidEndianness),
298 }
299}
300
301const fn endianness_to_byte(e: Endianness) -> u8 {
302 match e {
303 Endianness::Big => 0,
304 Endianness::Little => 1,
305 }
306}
307
308fn decode_code_set_component(r: &mut BufferReader<'_>) -> Result<CodeSetComponent, CdrError> {
309 let native_code_set = r.read_u32()?;
310 let n = r.read_u32()? as usize;
311 let mut conversion = Vec::with_capacity(n.min(16));
312 for _ in 0..n {
313 conversion.push(r.read_u32()?);
314 }
315 Ok(CodeSetComponent {
316 native_code_set,
317 conversion_code_sets: conversion,
318 })
319}
320
321fn encode_code_set_component(w: &mut BufferWriter, c: &CodeSetComponent) -> Result<(), CdrError> {
322 w.write_u32(c.native_code_set)?;
323 let n = u32::try_from(c.conversion_code_sets.len()).map_err(|_| CdrError::Overflow)?;
324 w.write_u32(n)?;
325 for cs in &c.conversion_code_sets {
326 w.write_u32(*cs)?;
327 }
328 Ok(())
329}
330
#[cfg(test)]
// Tests fail fast through unwrap/panic on purpose.
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Encodes `component` as an encapsulation in `endianness`, then decodes
    /// the result back under `tag`.
    fn encode_then_decode(
        component: &StructuredComponent,
        tag: ComponentId,
        endianness: Endianness,
    ) -> StructuredComponent {
        let encap = component.encode_encapsulation(endianness).unwrap();
        StructuredComponent::decode(tag, &encap).unwrap()
    }

    #[test]
    fn orb_type_round_trip() {
        let original = StructuredComponent::OrbType(OrbType(0x4F4D_4732));
        let decoded = encode_then_decode(&original, ComponentId::OrbType, Endianness::Big);
        assert_eq!(decoded, original);
    }

    #[test]
    fn code_sets_round_trip_le() {
        let expected = CodeSetComponentInfo {
            for_char_data: CodeSetComponent {
                native_code_set: 0x0001_0001,
                conversion_code_sets: alloc::vec![0x0001_0109],
            },
            for_wchar_data: CodeSetComponent {
                native_code_set: 0x0001_0109,
                conversion_code_sets: alloc::vec![],
            },
        };
        let original = StructuredComponent::CodeSets(expected.clone());
        match encode_then_decode(&original, ComponentId::CodeSets, Endianness::Little) {
            StructuredComponent::CodeSets(got) => assert_eq!(got, expected),
            other => panic!("expected CodeSets, got {other:?}"),
        }
    }

    #[test]
    fn alternate_iiop_address_round_trip() {
        let original = StructuredComponent::AlternateIiopAddress(AlternateIiopAddress {
            host: "alt.host".into(),
            port: 1234,
        });
        let decoded =
            encode_then_decode(&original, ComponentId::AlternateIiopAddress, Endianness::Big);
        assert_eq!(decoded, original);
    }

    #[test]
    fn ssl_sec_trans_round_trip() {
        let original = StructuredComponent::Ssl(Ssl {
            target_supports: 0x0040,
            target_requires: 0x0020,
            port: 4242,
        });
        let decoded = encode_then_decode(&original, ComponentId::SslSecTrans, Endianness::Big);
        assert_eq!(decoded, original);
    }

    #[test]
    fn tls_sec_trans_with_addresses_round_trip() {
        let original = StructuredComponent::TlsSecTrans(TlsSecTrans {
            target_supports: 0x0040,
            target_requires: 0x0040,
            addresses: alloc::vec![
                AlternateIiopAddress {
                    host: "tls-a.lab".into(),
                    port: 443,
                },
                AlternateIiopAddress {
                    host: "tls-b.lab".into(),
                    port: 8443,
                },
            ],
        });
        let decoded =
            encode_then_decode(&original, ComponentId::TlsSecTrans, Endianness::Little);
        assert_eq!(decoded, original);
    }

    #[test]
    fn csi_sec_mech_list_round_trip() {
        use zerodds_corba_csiv2::{
            AsContextSec, AssociationOptions, CompoundSecMech, CompoundSecMechList, SasContextSec,
        };
        let expected = CompoundSecMechList {
            stateful: true,
            mechanism_list: alloc::vec![CompoundSecMech {
                target_requires: AssociationOptions(
                    AssociationOptions::INTEGRITY | AssociationOptions::CONFIDENTIALITY,
                ),
                transport_mech_tag: 36,
                transport_mech_data: alloc::vec![0x01, 0x02, 0x03],
                as_context: AsContextSec {
                    target_supports: AssociationOptions(0x0040),
                    target_requires: AssociationOptions(0x0040),
                    client_authentication_mech: alloc::vec![0xAA, 0xBB],
                    target_name: alloc::vec![0xCC],
                },
                sas_context: SasContextSec {
                    target_supports: AssociationOptions(0x0080),
                    target_requires: AssociationOptions(0x0080),
                    privilege_authorities: alloc::vec![alloc::vec![0xDE, 0xAD]],
                    supported_naming_mechanisms: alloc::vec![alloc::vec![0xBE, 0xEF]],
                    supported_identity_types: 0x0001_0203,
                },
            }],
        };
        let original = StructuredComponent::CsiSecMechList(expected.clone());
        match encode_then_decode(&original, ComponentId::CsiSecMechList, Endianness::Little) {
            StructuredComponent::CsiSecMechList(got) => assert_eq!(got, expected),
            other => panic!("expected CsiSecMechList, got {other:?}"),
        }
    }

    #[test]
    fn stream_format_version_round_trip() {
        let original = StructuredComponent::StreamFormatVersion(StreamFormatVersion(2));
        let decoded = encode_then_decode(
            &original,
            ComponentId::RmiCustomMaxStreamFormat,
            Endianness::Big,
        );
        assert_eq!(decoded, original);
    }

    #[test]
    fn java_codebase_round_trip() {
        let original = StructuredComponent::JavaCodebase("http://server/codebase.jar".into());
        let decoded = encode_then_decode(&original, ComponentId::JavaCodebase, Endianness::Big);
        assert_eq!(decoded, original);
    }

    #[test]
    fn opaque_unknown_tag_pass_through() {
        let raw = alloc::vec![1, 0xff, 0xee, 0xdd];
        match StructuredComponent::decode(ComponentId::Other(9999), &raw).unwrap() {
            StructuredComponent::Opaque { tag, bytes } => {
                assert_eq!(tag, ComponentId::Other(9999));
                assert_eq!(bytes, raw);
            }
            other => panic!("expected Opaque, got {other:?}"),
        }
    }

    #[test]
    fn invalid_endianness_byte_is_diagnostic() {
        let garbage = alloc::vec![0xff, 0, 0, 0, 1];
        let result = StructuredComponent::decode(ComponentId::OrbType, &garbage);
        assert!(matches!(result.unwrap_err(), CdrError::InvalidEndianness));
    }

    #[test]
    fn tagged_component_round_trip() {
        let payload = StructuredComponent::OrbType(OrbType(42))
            .encode_encapsulation(Endianness::Big)
            .unwrap();
        let original = TaggedComponent {
            tag: ComponentId::OrbType,
            component_data: payload,
        };
        let mut writer = BufferWriter::new(Endianness::Big);
        original.encode(&mut writer).unwrap();
        let wire = writer.into_bytes();
        let mut reader = BufferReader::new(&wire, Endianness::Big);
        assert_eq!(TaggedComponent::decode(&mut reader).unwrap(), original);
    }
}