1use crate::common::{Reset, ResetKind};
4use rand::RngCore;
5use serde::{
6 Deserialize, Deserializer, Serialize, Serializer,
7 de::{SeqAccess, Visitor},
8 ser::SerializeTuple,
9};
10use std::{
11 fmt,
12 marker::PhantomData,
13 num::NonZeroUsize,
14 ops::{Deref, DerefMut, Index, IndexMut},
15 str::FromStr,
16};
17
/// A block of emulated memory backed by storage `D` (e.g. `Vec<u8>` or
/// [`ConstSlice`]), tagged with whether it is RAM and how RAM contents are
/// (re)initialized.
#[derive(Default, Clone, Serialize, Deserialize)]
pub struct Memory<D> {
    /// Fill pattern applied by `RamState::fill` when contents are (re)initialized.
    ram_state: RamState,
    /// `true` for RAM; only RAM is refilled on a hard reset.
    is_ram: bool,
    /// Backing storage.
    data: D,
}
26
27impl<D> Memory<D> {
28 pub fn new() -> Self
30 where
31 D: Default,
32 {
33 Self::default()
34 }
35
36 pub const fn is_ram(&self) -> bool {
38 self.is_ram
39 }
40}
41
42impl Memory<Vec<u8>> {
43 pub fn rom() -> Self {
45 Self::default()
46 }
47
48 pub const fn ram(ram_state: RamState) -> Self {
50 Self {
51 ram_state,
52 is_ram: true,
53 data: Vec::new(),
54 }
55 }
56
57 pub fn with_ram_state(mut self, state: RamState) -> Self {
59 self.ram_state = state;
60 self.ram_state.fill(&mut self.data);
61 self
62 }
63
64 pub fn with_size(mut self, size: usize) -> Self {
66 self.resize(size);
67 self
68 }
69
70 pub fn resize(&mut self, size: usize) {
72 self.data.resize(size, 0);
73 self.ram_state.fill(&mut self.data);
74 }
75}
76
77impl<T, const N: usize> Memory<ConstSlice<T, N>> {
78 pub fn rom_const() -> Self
80 where
81 T: Default + Copy,
82 {
83 Self::default()
84 }
85
86 pub fn ram_const(ram_state: RamState) -> Self
88 where
89 T: Default + Copy,
90 {
91 Self {
92 ram_state,
93 is_ram: true,
94 data: ConstSlice::new(),
95 }
96 }
97}
98
99impl Reset for Memory<Vec<u8>> {
100 fn reset(&mut self, kind: ResetKind) {
101 if self.is_ram && kind == ResetKind::Hard {
102 self.ram_state.fill(&mut self.data);
103 }
104 }
105}
106
107impl<const N: usize> Reset for Memory<ConstSlice<u8, N>> {
108 fn reset(&mut self, kind: ResetKind) {
109 if self.is_ram && kind == ResetKind::Hard {
110 self.ram_state.fill(&mut *self.data);
111 }
112 }
113}
114
115impl fmt::Debug for Memory<Vec<u8>> {
116 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
117 f.debug_struct("Memory")
118 .field("ram_state", &self.ram_state)
119 .field("is_ram", &self.is_ram)
120 .field("len", &self.data.len())
121 .field("capacity", &self.data.capacity())
122 .finish()
123 }
124}
125
126impl<T, const N: usize> fmt::Debug for Memory<ConstSlice<T, N>> {
127 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128 f.debug_struct("Memory")
129 .field("ram_state", &self.ram_state)
130 .field("is_ram", &self.is_ram)
131 .field("len", &self.data.len())
132 .finish()
133 }
134}
135
136impl<D: Deref> Deref for Memory<D> {
137 type Target = <D as Deref>::Target;
138 fn deref(&self) -> &Self::Target {
139 &self.data
140 }
141}
142
143impl<D: DerefMut> DerefMut for Memory<D> {
144 fn deref_mut(&mut self) -> &mut Self::Target {
145 &mut self.data
146 }
147}
148
149impl<T, D: AsRef<[T]>> AsRef<[T]> for Memory<D> {
150 fn as_ref(&self) -> &[T] {
151 self.data.as_ref()
152 }
153}
154
155impl<T, D: AsMut<[T]>> AsMut<[T]> for Memory<D> {
156 fn as_mut(&mut self) -> &mut [T] {
157 self.data.as_mut()
158 }
159}
160
/// A fixed-size inline array whose `usize` indexing wraps by masking the index
/// with `N - 1`.
///
/// The mask is only a true modulo when `N` is a power of two. This is checked
/// with `debug_assert!` on every index operation.
#[derive(Clone)]
pub struct ConstSlice<T, const N: usize>([T; N]);

impl<T, const N: usize> ConstSlice<T, N> {
    /// Creates a slice with every element set to `T::default()`.
    pub fn new() -> Self
    where
        T: Default + Copy,
    {
        <Self as Default>::default()
    }

    /// Creates a slice with every element set to `val`.
    pub const fn filled(val: T) -> Self
    where
        T: Copy,
    {
        let items = [val; N];
        Self(items)
    }

    /// Maps an arbitrary `index` into `0..N` by masking with `N - 1`.
    #[inline]
    fn masked_index(&self, index: usize) -> usize {
        debug_assert!(self.0.len().is_power_of_two());
        index & (self.0.len() - 1)
    }
}

impl<T: Default + Copy, const N: usize> Default for ConstSlice<T, N> {
    fn default() -> Self {
        Self::filled(T::default())
    }
}

impl<T, const N: usize> Deref for ConstSlice<T, N> {
    type Target = [T; N];

    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}

impl<T, const N: usize> DerefMut for ConstSlice<T, N> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(inner) = self;
        inner
    }
}

impl<T, const N: usize> AsRef<[T]> for ConstSlice<T, N> {
    fn as_ref(&self) -> &[T] {
        &self.0[..]
    }
}

impl<T, const N: usize> AsMut<[T]> for ConstSlice<T, N> {
    fn as_mut(&mut self) -> &mut [T] {
        &mut self.0[..]
    }
}

impl<T, const N: usize> Index<usize> for ConstSlice<T, N> {
    type Output = T;

    /// Wrapping index: `index` is masked into range, so it never panics for
    /// out-of-range values when `N` is a non-zero power of two.
    fn index(&self, index: usize) -> &Self::Output {
        let slot = self.masked_index(index);
        &self.0[slot]
    }
}

impl<T, const N: usize> IndexMut<usize> for ConstSlice<T, N> {
    /// Mutable wrapping index; see the `Index` impl.
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        let slot = self.masked_index(index);
        &mut self.0[slot]
    }
}
228
229impl<T: Serialize, const N: usize> Serialize for ConstSlice<T, N> {
230 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
231 where
232 S: Serializer,
233 {
234 let mut s = serializer.serialize_tuple(N)?;
235 for item in &self.0 {
236 s.serialize_element(item)?;
237 }
238 s.end()
239 }
240}
241
242impl<'de, T, const N: usize> Deserialize<'de> for ConstSlice<T, N>
243where
244 T: Deserialize<'de> + Default + Copy,
245{
246 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
247 where
248 D: Deserializer<'de>,
249 {
250 struct ArrayVisitor<T, const N: usize>(PhantomData<T>);
251
252 impl<'de, T, const N: usize> Visitor<'de> for ArrayVisitor<T, N>
253 where
254 T: Deserialize<'de> + Default + Copy,
255 {
256 type Value = [T; N];
257
258 fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259 formatter.write_str(&format!("an array of length {}", N))
260 }
261
262 #[inline]
263 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
264 where
265 A: SeqAccess<'de>,
266 {
267 let mut data = [T::default(); N];
268 for data in &mut data {
269 match (seq.next_element())? {
270 Some(val) => *data = val,
271 None => return Err(serde::de::Error::invalid_length(N, &self)),
272 }
273 }
274 Ok(data)
275 }
276 }
277
278 deserializer
279 .deserialize_tuple(N, ArrayVisitor(PhantomData))
280 .map(Self)
281 }
282}
283
/// Byte-level read access over a 16-bit address space.
pub trait Read {
    /// Reads the byte at `addr`.
    ///
    /// Takes `&mut self` so implementors may update state on a read. The
    /// default simply forwards to [`peek`](Read::peek).
    fn read(&mut self, addr: u16) -> u8 {
        self.peek(addr)
    }

    /// Reads a little-endian word: the low byte from `addr`, then the high
    /// byte from `addr + 1` (wrapping from `$FFFF` to `$0000`).
    fn read_u16(&mut self, addr: u16) -> u16 {
        let low = self.read(addr);
        let high = self.read(addr.wrapping_add(1));
        (u16::from(high) << 8) | u16::from(low)
    }

    /// Returns the byte at `addr` through a shared reference.
    fn peek(&self, addr: u16) -> u8;

    /// Peeks a little-endian word from `addr` and `addr + 1` (wrapping).
    fn peek_u16(&self, addr: u16) -> u16 {
        let bytes = [self.peek(addr), self.peek(addr.wrapping_add(1))];
        u16::from_le_bytes(bytes)
    }
}
308
/// Byte-level write access over a 16-bit address space.
pub trait Write {
    /// Writes `val` to `addr`.
    fn write(&mut self, addr: u16, val: u8);

    /// Writes `val` little-endian: the low byte to `addr`, then the high byte
    /// to `addr + 1` (wrapping from `$FFFF` to `$0000`). This mirrors
    /// `Read::read_u16`.
    fn write_u16(&mut self, addr: u16, val: u16) {
        let [lo, hi] = val.to_le_bytes();
        self.write(addr, lo);
        // The high byte belongs at the next address; writing it to `addr` would
        // overwrite the low byte that was just stored.
        self.write(addr.wrapping_add(1), hi);
    }
}
321
/// Pattern used to initialize RAM contents (applied by `RamState::fill`).
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[must_use]
pub enum RamState {
    /// Every byte is `$00`. This is the default.
    #[default]
    AllZeros,
    /// Every byte is `$FF`.
    AllOnes,
    /// Bytes come from `rand::rng()`.
    Random,
}
331
332impl RamState {
333 pub const fn as_slice() -> &'static [Self] {
335 &[Self::AllZeros, Self::AllOnes, Self::Random]
336 }
337
338 #[must_use]
340 pub const fn as_str(&self) -> &'static str {
341 match self {
342 Self::AllZeros => "all-zeros",
343 Self::AllOnes => "all-ones",
344 Self::Random => "random",
345 }
346 }
347
348 pub fn fill(&self, data: &mut [u8]) {
350 match self {
351 RamState::AllZeros => data.fill(0x00),
352 RamState::AllOnes => data.fill(0xFF),
353 RamState::Random => {
354 rand::rng().fill_bytes(data);
355 }
356 }
357 }
358}
359
360impl From<usize> for RamState {
361 fn from(value: usize) -> Self {
362 match value {
363 0 => Self::AllZeros,
364 1 => Self::AllOnes,
365 _ => Self::Random,
366 }
367 }
368}
369
370impl AsRef<str> for RamState {
371 fn as_ref(&self) -> &str {
372 self.as_str()
373 }
374}
375
376impl std::fmt::Display for RamState {
377 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
378 let s = match self {
379 Self::AllZeros => "All $00",
380 Self::AllOnes => "All $FF",
381 Self::Random => "Random",
382 };
383 write!(f, "{s}")
384 }
385}
386
387impl FromStr for RamState {
388 type Err = &'static str;
389 fn from_str(s: &str) -> Result<Self, Self::Err> {
390 match s {
391 "all-zeros" => Ok(Self::AllZeros),
392 "all-ones" => Ok(Self::AllOnes),
393 "random" => Ok(Self::Random),
394 _ => Err("invalid RamState value. valid options: `all-zeros`, `all-ones`, or `random`"),
395 }
396 }
397}
398
/// Access permission for a single bank slot (checked by `Banks::readable` and
/// `Banks::writable`).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[must_use]
pub enum BankAccess {
    /// Neither readable nor writable.
    None,
    /// Readable only.
    Read,
    /// Readable and writable.
    ReadWrite,
}
407
/// A bank-switching table: maps fixed-size windows of the address range
/// `start..=end` onto pages of a larger backing store.
#[derive(Clone, Serialize, Deserialize)]
#[must_use]
pub struct Banks {
    /// First address of the banked range.
    start: usize,
    /// Last address of the banked range (inclusive).
    end: usize,
    /// `end - start`; `get` also uses it as the address mask for slot lookup.
    size: NonZeroUsize,
    /// Size in bytes of each slot/page; validated as a power of two in `new`.
    window: NonZeroUsize,
    /// `log2(window)`: converts between page numbers and byte offsets.
    shift: usize,
    /// `page_count - 1` (saturating); masks page numbers passed to `set`/`set_range`.
    mask: usize,
    /// Byte offset into the backing store currently mapped by each slot.
    banks: Vec<usize>,
    /// Access permission of each slot.
    access: Vec<BankAccess>,
    /// Number of whole `window`-sized pages in the backing store.
    page_count: usize,
}
422
423#[derive(thiserror::Error, Debug)]
424#[must_use]
425pub enum Error {
426 #[error("bank `window` must a non-zero power of two")]
427 InvalidWindow,
428 #[error("bank `size` must be non-zero")]
429 InvalidSize,
430}
431
432impl Banks {
433 pub fn new(
434 start: usize,
435 end: usize,
436 capacity: usize,
437 window: impl TryInto<NonZeroUsize>,
438 ) -> Result<Self, Error> {
439 let window = window.try_into().map_err(|_| Error::InvalidWindow)?;
440 if !window.is_power_of_two() {
441 return Err(Error::InvalidWindow);
442 }
443
444 let size = NonZeroUsize::try_from(end - start).map_err(|_| Error::InvalidSize)?;
445 let bank_count = (size.get() + 1) / window;
446
447 let mut banks = vec![0; bank_count];
448 let access = vec![BankAccess::ReadWrite; bank_count];
449 for (i, bank) in banks.iter_mut().enumerate() {
450 *bank = (i * window.get()) % capacity;
451 }
452 let page_count = capacity / window.get();
453
454 Ok(Self {
455 start,
456 end,
457 size,
458 window,
459 shift: window.trailing_zeros() as usize,
460 mask: page_count.saturating_sub(1),
461 banks,
462 access,
463 page_count,
464 })
465 }
466
467 pub fn set(&mut self, mut bank: usize, page: usize) {
468 if bank >= self.banks.len() {
469 bank %= self.banks.len();
470 }
471 assert!(bank < self.banks.len());
472 self.banks[bank] = (page & self.mask) << self.shift;
473 debug_assert!(self.banks[bank] < self.page_count * self.window.get());
474 }
475
476 pub fn set_range(&mut self, start: usize, end: usize, page: usize) {
477 let mut new_addr = (page & self.mask) << self.shift;
478 for mut bank in start..=end {
479 if bank >= self.banks.len() {
480 bank %= self.banks.len();
481 }
482 assert!(bank < self.banks.len());
483 self.banks[bank] = new_addr;
484 debug_assert!(self.banks[bank] < self.page_count * self.window.get());
485 new_addr += self.window.get();
486 }
487 }
488
489 pub fn set_access(&mut self, mut bank: usize, access: BankAccess) {
490 if bank >= self.banks.len() {
491 bank %= self.banks.len();
492 }
493 assert!(bank < self.banks.len());
494 self.access[bank] = access;
495 }
496
497 pub fn set_access_range(&mut self, start: usize, end: usize, access: BankAccess) {
498 for slot in start..=end {
499 self.set_access(slot, access);
500 }
501 }
502
503 pub fn readable(&self, addr: u16) -> bool {
504 let slot = self.get(addr);
505 assert!(slot < self.banks.len());
506 matches!(self.access[slot], BankAccess::Read | BankAccess::ReadWrite)
507 }
508
509 pub fn writable(&self, addr: u16) -> bool {
510 let slot = self.get(addr);
511 assert!(slot < self.banks.len());
512 self.access[slot] == BankAccess::ReadWrite
513 }
514
515 #[must_use]
516 pub const fn last(&self) -> usize {
517 self.page_count.saturating_sub(1)
518 }
519
520 #[must_use]
521 pub fn banks_len(&self) -> usize {
522 self.banks.len()
523 }
524
525 #[must_use]
526 pub const fn get(&self, addr: u16) -> usize {
527 (addr as usize & self.size.get()) >> self.shift
528 }
529
530 #[must_use]
531 pub fn translate(&self, addr: u16) -> usize {
532 let slot = self.get(addr);
533 assert!(slot < self.banks.len());
534 let page_offset = self.banks[slot];
535 page_offset | (addr as usize) & (self.window.get() - 1)
536 }
537
538 #[must_use]
539 pub fn page(&self, bank: usize) -> usize {
540 self.banks[bank] >> self.shift
541 }
542
543 #[must_use]
544 pub fn page_offset(&self, bank: usize) -> usize {
545 self.banks[bank]
546 }
547
548 #[must_use]
549 pub const fn page_count(&self) -> usize {
550 self.page_count
551 }
552}
553
554impl std::fmt::Debug for Banks {
555 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
556 f.debug_struct("Bank")
557 .field("start", &format_args!("${:04X}", self.start))
558 .field("end", &format_args!("${:04X}", self.end))
559 .field("size", &format_args!("${:04X}", self.size))
560 .field("window", &format_args!("${:04X}", self.window))
561 .field("shift", &self.shift)
562 .field("mask", &self.shift)
563 .field("banks", &self.banks)
564 .field("page_count", &self.page_count)
565 .finish()
566 }
567}
568
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `$8000-$FFFF` bank table over 128 KiB of backing storage with
    /// the given `window` size.
    fn upper_half_banks(window: usize) -> Banks {
        Banks::new(
            0x8000,
            0xFFFF,
            128 * 1024,
            NonZeroUsize::new(window).unwrap(),
        )
        .unwrap()
    }

    #[test]
    fn get_bank() {
        let banks = upper_half_banks(0x4000);
        let expected = [
            (0x8000, 0),
            (0x9FFF, 0),
            (0xA000, 0),
            (0xBFFF, 0),
            (0xC000, 1),
            (0xDFFF, 1),
            (0xE000, 1),
            (0xFFFF, 1),
        ];
        for (addr, slot) in expected {
            assert_eq!(banks.get(addr), slot);
        }
    }

    #[test]
    fn bank_translate() {
        let mut banks = upper_half_banks(0x2000);

        let last_bank = banks.last();
        assert_eq!(last_bank, 15, "bank count");

        assert_eq!(banks.translate(0x8000), 0x0000);
        let steps = [(1, 0x2000), (2, 0x4000), (0, 0x0000), (last_bank, 0x1E000)];
        for (page, offset) in steps {
            banks.set(0, page);
            assert_eq!(banks.translate(0x8000), offset);
        }
    }
}