1use core::cell::UnsafeCell;
4
5use zencan_common::{sdo::AbortCode, AtomicCell};
6
/// Storage access trait implemented by the backing store of a single
/// sub-object.
///
/// The partial-write methods support segmented (streamed) downloads; their
/// default implementations reject streaming with
/// [`AbortCode::UnsupportedAccess`], so only implementors with resumable
/// buffers need to override them.
pub trait SubObjectAccess: Sync + Send {
    /// Read up to `buf.len()` bytes starting at `offset`, returning the
    /// number of bytes copied into `buf`.
    fn read(&self, offset: usize, buf: &mut [u8]) -> Result<usize, AbortCode>;

    /// Current readable size of the sub-object, in bytes.
    fn read_size(&self) -> usize;

    /// Overwrite the sub-object value with `data` in a single operation.
    fn write(&self, data: &[u8]) -> Result<(), AbortCode>;

    /// Begin a segmented write (download) to this sub-object.
    ///
    /// Default: segmented writes are unsupported.
    fn begin_partial(&self) -> Result<(), AbortCode> {
        Err(AbortCode::UnsupportedAccess)
    }

    /// Append one chunk of a segmented write started with `begin_partial`.
    fn write_partial(&self, _buf: &[u8]) -> Result<(), AbortCode> {
        Err(AbortCode::UnsupportedAccess)
    }

    /// Complete a segmented write started with `begin_partial`.
    fn end_partial(&self) -> Result<(), AbortCode> {
        Err(AbortCode::UnsupportedAccess)
    }
}
104
/// A scalar value stored in an [`AtomicCell`] so it can be accessed from a
/// shared (`&self`) sub-object without locks.
#[allow(missing_debug_implementations)]
pub struct ScalarField<T: Copy> {
    // Backing storage; all access goes through AtomicCell load/store
    value: AtomicCell<T>,
}
110
111impl<T: Send + Copy + PartialEq> ScalarField<T> {
112 pub fn load(&self) -> T {
114 self.value.load()
115 }
116
117 pub fn store(&self, value: T) {
119 self.value.store(value);
120 }
121}
122
123impl<T: Copy + Default> Default for ScalarField<T> {
124 fn default() -> Self {
125 Self {
126 value: AtomicCell::default(),
127 }
128 }
129}
130
/// Implement `new` and [`SubObjectAccess`] for `ScalarField<$rust_type>`,
/// serializing values as little-endian bytes.
///
/// The `$data_type` argument documents the corresponding CANopen data type
/// at the call site; it is not currently referenced by the expansion.
macro_rules! impl_scalar_field {
    ($rust_type: ty, $data_type: ty) => {
        impl ScalarField<$rust_type> {
            /// Create a new field with the given initial value.
            pub const fn new(value: $rust_type) -> Self {
                Self {
                    value: AtomicCell::new(value),
                }
            }
        }
        impl SubObjectAccess for ScalarField<$rust_type> {
            fn read(&self, offset: usize, buf: &mut [u8]) -> Result<usize, AbortCode> {
                // Serialize to little-endian, then copy the requested window
                let bytes = self.value.load().to_le_bytes();
                if offset < bytes.len() {
                    let read_len = buf.len().min(bytes.len() - offset);
                    buf[0..read_len].copy_from_slice(&bytes[offset..offset + read_len]);
                    Ok(read_len)
                } else {
                    // Reads past the end succeed with zero bytes
                    Ok(0)
                }
            }

            fn read_size(&self) -> usize {
                core::mem::size_of::<$rust_type>()
            }

            fn write(&self, data: &[u8]) -> Result<(), AbortCode> {
                // An exact-length buffer is required; report whether the
                // client supplied too few or too many bytes.
                // (Path-qualified `core::mem::size_of` for consistency with
                // `read_size` and to avoid relying on prelude re-exports.)
                let value = <$rust_type>::from_le_bytes(data.try_into().map_err(|_| {
                    if data.len() < core::mem::size_of::<$rust_type>() {
                        AbortCode::DataTypeMismatchLengthLow
                    } else {
                        AbortCode::DataTypeMismatchLengthHigh
                    }
                })?);
                self.value.store(value);
                Ok(())
            }
        }
    };
}
171
// Instantiate scalar-field support for each supported CANopen scalar type
impl_scalar_field!(u8, DataType::UInt8);
impl_scalar_field!(u16, DataType::UInt16);
impl_scalar_field!(u32, DataType::UInt32);
impl_scalar_field!(i8, DataType::Int8);
impl_scalar_field!(i16, DataType::Int16);
impl_scalar_field!(i32, DataType::Int32);
impl_scalar_field!(f32, DataType::Float);
179
180impl SubObjectAccess for ScalarField<bool> {
182 fn read(&self, offset: usize, buf: &mut [u8]) -> Result<usize, AbortCode> {
183 let value = self.value.load();
184 if offset != 0 || buf.len() > 1 {
185 return Err(AbortCode::DataTypeMismatchLengthHigh);
186 }
187 buf[0] = if value { 1 } else { 0 };
188 Ok(1)
189 }
190
191 fn read_size(&self) -> usize {
192 1
193 }
194
195 fn write(&self, data: &[u8]) -> Result<(), AbortCode> {
196 if data.len() != 1 {
197 return Err(AbortCode::DataTypeMismatchLengthHigh);
198 }
199 let value = data[0] != 0;
200 self.value.store(value);
201 Ok(())
202 }
203}
204
/// A fixed-size byte-array sub-object supporting whole-value and segmented
/// (partial) writes.
#[allow(clippy::len_without_is_empty, missing_debug_implementations)]
pub struct ByteField<const N: usize> {
    // Raw storage; only dereferenced inside critical sections
    value: UnsafeCell<[u8; N]>,
    // Some(next_offset) while a segmented write is in progress, None otherwise
    write_offset: AtomicCell<Option<usize>>,
}
213
// SAFETY: the `UnsafeCell` contents are only dereferenced inside
// `critical_section::with` closures (see `load`, `store`, `read`, `write`,
// and `write_partial` below), which serializes all access to the buffer;
// `write_offset` is an `AtomicCell` and is safe to share as-is.
unsafe impl<const N: usize> Sync for ByteField<N> {}
215
216impl<const N: usize> ByteField<N> {
217 pub const fn new(value: [u8; N]) -> Self {
219 Self {
220 value: UnsafeCell::new(value),
221 write_offset: AtomicCell::new(None),
222 }
223 }
224
225 pub fn len(&self) -> usize {
227 N
228 }
229
230 pub fn store(&self, value: [u8; N]) {
232 self.write_offset.store(None);
234 critical_section::with(|_| {
235 let bytes = unsafe { &mut *self.value.get() };
236 bytes.copy_from_slice(&value);
237 });
238 }
239
240 pub fn load(&self) -> [u8; N] {
242 critical_section::with(|_| unsafe { *self.value.get() })
243 }
244}
245
246impl<const N: usize> Default for ByteField<N> {
247 fn default() -> Self {
248 Self {
249 value: UnsafeCell::new([0; N]),
250 write_offset: AtomicCell::new(None),
251 }
252 }
253}
254
impl<const N: usize> SubObjectAccess for ByteField<N> {
    fn read(&self, offset: usize, buf: &mut [u8]) -> Result<usize, AbortCode> {
        // Copy out under the critical section so a concurrent store/write
        // cannot tear the data
        critical_section::with(|_| {
            let bytes = unsafe { &*self.value.get() };
            if bytes.len() > offset {
                let read_len = buf.len().min(bytes.len() - offset);
                buf[..read_len].copy_from_slice(&bytes[offset..offset + read_len]);
                Ok(read_len)
            } else {
                // Reading past the end yields zero bytes rather than an error
                Ok(0)
            }
        })
    }

    fn read_size(&self) -> usize {
        N
    }

    fn write(&self, data: &[u8]) -> Result<(), AbortCode> {
        critical_section::with(|_| {
            let bytes = unsafe { &mut *self.value.get() };
            if data.len() > bytes.len() {
                return Err(AbortCode::DataTypeMismatchLengthHigh);
            }
            // Shorter writes leave the tail of the buffer unchanged
            bytes[..data.len()].copy_from_slice(data);
            Ok(())
        })
    }

    fn begin_partial(&self) -> Result<(), AbortCode> {
        // Start (or restart) a segmented download at offset 0
        self.write_offset.store(Some(0));
        Ok(())
    }

    fn write_partial(&self, buf: &[u8]) -> Result<(), AbortCode> {
        // Claim this chunk's offset and advance the cursor in one atomic
        // update (assumes `AtomicCell::fetch_update` mirrors the std atomics
        // API: the closure always returns Some, so unwrap cannot fail).
        // NOTE(review): the cursor advances even when the bounds check below
        // fails; the transfer is expected to abort on error, and the stale
        // cursor is cleared by the next begin_partial/store — confirm.
        let offset = self
            .write_offset
            .fetch_update(|old| Some(old.map(|x| x + buf.len())))
            .unwrap();
        if offset.is_none() {
            // write_partial without a preceding begin_partial
            return Err(AbortCode::GeneralError);
        }
        let offset = offset.unwrap();
        if offset + buf.len() > N {
            return Err(AbortCode::DataTypeMismatchLengthHigh);
        }
        critical_section::with(|_| {
            let bytes = unsafe { &mut *self.value.get() };
            bytes[offset..offset + buf.len()].copy_from_slice(buf);
        });
        Ok(())
    }

    fn end_partial(&self) -> Result<(), AbortCode> {
        // Finish the segmented write; clearing the cursor re-arms the
        // GeneralError path in write_partial
        self.write_offset.store(None);
        Ok(())
    }
}
315
/// A byte field whose readable length ends at the first null byte — a
/// newtype over [`ByteField`] suitable for visible-string sub-objects.
#[allow(clippy::len_without_is_empty, missing_debug_implementations)]
pub struct NullTermByteField<const N: usize>(ByteField<N>);
321
322impl<const N: usize> NullTermByteField<N> {
323 pub const fn new(value: [u8; N]) -> Self {
325 Self(ByteField::new(value))
326 }
327
328 pub fn len(&self) -> usize {
330 N
331 }
332
333 pub fn load(&self) -> [u8; N] {
338 self.0.load()
339 }
340
341 pub fn store(&self, value: [u8; N]) {
343 self.0.store(value);
344 }
345
346 pub fn set_str(&self, value: &[u8]) -> Result<(), AbortCode> {
351 self.0.begin_partial()?;
352 self.0.write_partial(value)?;
353 if value.len() < N {
354 self.0.write_partial(&[0])?;
355 }
356 self.end_partial()?;
357 Ok(())
358 }
359}
360
361impl<const N: usize> Default for NullTermByteField<N> {
362 fn default() -> Self {
363 Self(ByteField::default())
364 }
365}
366
impl<const N: usize> SubObjectAccess for NullTermByteField<N> {
    fn read(&self, offset: usize, buf: &mut [u8]) -> Result<usize, AbortCode> {
        let size = self.0.read(offset, buf)?;
        // Truncate the reported size at the first null byte in this chunk.
        // NOTE(review): for chunked reads (offset > 0) the null scan only
        // covers the current window — confirm callers read from offset 0 or
        // bound the transfer with read_size().
        let size = buf[0..size].iter().position(|b| *b == 0).unwrap_or(size);
        Ok(size)
    }

    fn read_size(&self) -> usize {
        // Length up to (not including) the first null, or N if none present
        critical_section::with(|_| {
            let bytes = unsafe { &*self.0.value.get() };
            bytes.iter().position(|b| *b == 0).unwrap_or(bytes.len())
        })
    }

    fn write(&self, data: &[u8]) -> Result<(), AbortCode> {
        // Whole-value write: store the data plus, if there is room, a
        // terminating null
        self.0.begin_partial()?;
        self.0.write_partial(data)?;
        if data.len() < N {
            self.0.write_partial(&[0])?;
        }
        self.0.end_partial()?;
        Ok(())
    }

    fn begin_partial(&self) -> Result<(), AbortCode> {
        self.0.begin_partial()
    }

    fn write_partial(&self, data: &[u8]) -> Result<(), AbortCode> {
        self.0.write_partial(data)
    }

    fn end_partial(&self) -> Result<(), AbortCode> {
        // Append a terminator if the segmented write left room for one
        if self.0.write_offset.load().unwrap_or(0) < N {
            self.0.write_partial(&[0])?;
        }
        self.0.end_partial()
    }
}
408
/// A read-only sub-object backed by a `'static` byte slice.
#[derive(Clone, Copy, Debug)]
pub struct ConstByteRefField {
    value: &'static [u8],
}
414
415impl ConstByteRefField {
416 pub const fn new(value: &'static [u8]) -> Self {
418 Self { value }
419 }
420}
421
422impl SubObjectAccess for ConstByteRefField {
423 fn read(&self, offset: usize, buf: &mut [u8]) -> Result<usize, AbortCode> {
424 let read_len = buf.len().min(self.value.len() - offset);
425 buf[..read_len].copy_from_slice(&self.value[offset..offset + read_len]);
426 Ok(read_len)
427 }
428
429 fn read_size(&self) -> usize {
430 self.value.len()
431 }
432
433 fn write(&self, _data: &[u8]) -> Result<(), AbortCode> {
434 Err(AbortCode::ReadOnly)
435 }
436}
437
/// A read-only sub-object storing its value inline as `N` bytes.
#[derive(Debug)]
pub struct ConstField<const N: usize> {
    bytes: [u8; N],
}
446
447impl<const N: usize> ConstField<N> {
448 pub const fn new(bytes: [u8; N]) -> Self {
450 Self { bytes }
451 }
452}
453
454impl<const N: usize> SubObjectAccess for ConstField<N> {
455 fn read(&self, offset: usize, buf: &mut [u8]) -> Result<usize, AbortCode> {
456 if offset < self.bytes.len() {
457 let read_len = buf.len().min(self.bytes.len() - offset);
458 buf[..read_len].copy_from_slice(&self.bytes[offset..offset + read_len]);
459 Ok(read_len)
460 } else {
461 Ok(0)
462 }
463 }
464
465 fn read_size(&self) -> usize {
466 N
467 }
468
469 fn write(&self, _data: &[u8]) -> Result<(), AbortCode> {
470 Err(AbortCode::ReadOnly)
471 }
472}
473
/// A sub-object whose accesses are delegated to a handler registered at
/// runtime.
#[allow(missing_debug_implementations)]
pub struct CallbackSubObject {
    // None until register_handler installs a delegate
    handler: AtomicCell<Option<&'static dyn SubObjectAccess>>,
}
479
480impl Default for CallbackSubObject {
481 fn default() -> Self {
482 Self::new()
483 }
484}
485
486impl CallbackSubObject {
487 pub const fn new() -> Self {
489 Self {
490 handler: AtomicCell::new(None),
491 }
492 }
493
494 pub fn register_handler(&self, handler: &'static dyn SubObjectAccess) {
496 self.handler.store(Some(handler));
497 }
498}
499
500impl SubObjectAccess for CallbackSubObject {
501 fn read(&self, offset: usize, buf: &mut [u8]) -> Result<usize, AbortCode> {
502 if let Some(handler) = self.handler.load() {
503 handler.read(offset, buf)
504 } else {
505 Err(AbortCode::ResourceNotAvailable)
506 }
507 }
508
509 fn read_size(&self) -> usize {
510 if let Some(handler) = self.handler.load() {
511 handler.read_size()
512 } else {
513 0
514 }
515 }
516
517 fn write(&self, data: &[u8]) -> Result<(), AbortCode> {
518 if let Some(handler) = self.handler.load() {
519 handler.write(data)
520 } else {
521 Err(AbortCode::ResourceNotAvailable)
522 }
523 }
524
525 fn begin_partial(&self) -> Result<(), AbortCode> {
526 if let Some(handler) = self.handler.load() {
527 handler.begin_partial()
528 } else {
529 Err(AbortCode::ResourceNotAvailable)
530 }
531 }
532
533 fn write_partial(&self, buf: &[u8]) -> Result<(), AbortCode> {
534 if let Some(handler) = self.handler.load() {
535 handler.write_partial(buf)
536 } else {
537 Err(AbortCode::ResourceNotAvailable)
538 }
539 }
540
541 fn end_partial(&self) -> Result<(), AbortCode> {
542 if let Some(handler) = self.handler.load() {
543 handler.end_partial()
544 } else {
545 Err(AbortCode::ResourceNotAvailable)
546 }
547 }
548}
549
#[cfg(test)]
mod tests {
    use zencan_common::objects::{ObjectCode, SubInfo};

    use crate::object_dict::{ObjectAccess, ProvidesSubObjects};

    use super::*;

    /// Example record object exercising scalar, bool, and string fields
    #[derive(Default)]
    struct ExampleRecord {
        val1: ScalarField<u32>,
        val2: ScalarField<bool>,
        val3: NullTermByteField<10>,
    }

    impl ProvidesSubObjects for ExampleRecord {
        fn get_sub_object(&self, sub: u8) -> Option<(SubInfo, &dyn SubObjectAccess)> {
            match sub {
                // Sub 0 conventionally holds the highest sub-index (3 here)
                0 => Some((
                    SubInfo::MAX_SUB_NUMBER,
                    const { &ConstField::new(3u8.to_le_bytes()) },
                )),
                1 => Some((SubInfo::new_u32().rw_access(), &self.val1)),
                2 => Some((SubInfo::new_u8().rw_access(), &self.val2)),
                3 => Some((
                    SubInfo::new_visibile_str(self.val3.len()).rw_access(),
                    &self.val3,
                )),
                _ => None,
            }
        }

        fn object_code(&self) -> ObjectCode {
            ObjectCode::Record
        }
    }

    #[test]
    fn test_record_with_provides_sub_objects() {
        let record = ExampleRecord::default();

        assert_eq!(3, record.read_u8(0).unwrap());
        record.write(1, &42u32.to_le_bytes()).unwrap();
        assert_eq!(42, record.read_u32(1).unwrap());

        // Full-width partial write: no room left for a terminator
        record.begin_partial(3).unwrap();
        record
            .write_partial(3, &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
            .unwrap();
        let mut buf = [0; 10];
        record.read(3, 0, &mut buf).unwrap();
        assert_eq!([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], buf);
        // Chunked partial write: end_partial appends the null terminator
        record.begin_partial(3).unwrap();
        record.write_partial(3, &[0, 1, 2, 3]).unwrap();
        record.write_partial(3, &[4, 5, 6, 7]).unwrap();
        record.end_partial(3).unwrap();
        let mut buf = [0; 9];
        record.read(3, 0, &mut buf).unwrap();
        assert_eq!([0u8, 1, 2, 3, 4, 5, 6, 7, 0], buf)
    }

    /// Common read-path checks shared by the per-field tests
    fn sub_read_test_helper(field: &dyn SubObjectAccess, expected_bytes: &[u8]) {
        let n = expected_bytes.len();

        // Offsets 1 and 2 are exercised below, so at least 3 bytes are
        // required. (Message fixed: previously read "cannot be shorted
        // than 2 bytes".)
        assert!(n > 2, "Expected bytes must be longer than 2 bytes");

        assert_eq!(n, field.read_size());

        // Oversized buffer: read returns exactly n bytes
        let mut read_buf = vec![0xffu8; n + 10];
        let read_size = field.read(0, &mut read_buf).unwrap();
        assert_eq!(n, read_size);
        assert_eq!(expected_bytes, &read_buf[0..n]);

        // Reads are repeatable
        let mut read_buf = vec![0xffu8; n + 10];
        let read_size = field.read(0, &mut read_buf).unwrap();
        assert_eq!(n, read_size);
        assert_eq!(expected_bytes, &read_buf[0..n]);

        // Read starting at offset 2
        let mut read_buf = vec![0xffu8; n + 10];
        let read_size = field.read(2, &mut read_buf).unwrap();
        assert_eq!(n - 2, read_size);
        assert_eq!(&expected_bytes[2..], &read_buf[0..n - 2]);

        // Undersized buffer at offset 1: read is limited by buffer length
        let mut read_buf = vec![0xffu8; n - 2];
        let read_size = field.read(1, &mut read_buf).unwrap();
        assert_eq!(n - 2, read_size);
        assert_eq!(expected_bytes[1..n - 1], read_buf);
    }

    #[test]
    fn test_scalar_field() {
        let field = ScalarField::<u32>::new(42u32);

        let exp_bytes = 42u32.to_le_bytes();

        sub_read_test_helper(&field, &exp_bytes);
    }

    #[test]
    fn test_byte_field() {
        const N: usize = 10;
        let field = ByteField::new([0; N]);

        let write_data = Vec::from_iter(0u8..N as u8);
        field.write(&write_data).unwrap();

        sub_read_test_helper(&field, &write_data);
    }

    #[test]
    fn test_null_term_byte_field() {
        let field = NullTermByteField::new([0; 10]);
        field.write(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap();
        sub_read_test_helper(&field, &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
        field.write(&[1, 2, 3, 4]).unwrap();
        sub_read_test_helper(&field, &[1, 2, 3, 4]);
    }

    #[test]
    fn test_const_field() {
        let field = ConstField::new([1, 2, 3, 4, 5]);
        sub_read_test_helper(&field, &[1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_const_byte_ref_field() {
        let field = ConstByteRefField::new(&[1, 2, 3, 4, 5]);
        sub_read_test_helper(&field, &[1, 2, 3, 4, 5]);
    }
}
687}