use arch::x86_64::kernel::apic;
use arch::x86_64::kernel::idt;
use arch::x86_64::kernel::irq;
use arch::x86_64::kernel::percore::*;
use arch::x86_64::mm::paging::{BasePageSize, PageSize, PageTableEntryFlags};
use config::*;
use core::{mem, ptr};
use environment;
use mm;
use scheduler::task::{Task, TaskFrame};
/// Register state saved on a task's kernel stack during a context switch.
///
/// The field order mirrors the push/pop sequence of the context-switch code
/// (fields are popped bottom-up, so `rip` sits at the highest address), and
/// `create_stack_frame` writes this layout directly into raw stack memory —
/// hence `#[repr(C, packed)]`; do not reorder fields.
#[repr(C, packed)]
struct State {
	/// FS segment base, used as the thread-local-storage pointer (see `TaskTLS::get_fs`)
	fs: usize,
	/// general purpose register R15
	r15: usize,
	/// general purpose register R14
	r14: usize,
	/// general purpose register R13
	r13: usize,
	/// general purpose register R12
	r12: usize,
	/// general purpose register R11
	r11: usize,
	/// general purpose register R10
	r10: usize,
	/// general purpose register R9
	r9: usize,
	/// general purpose register R8
	r8: usize,
	/// RDI — first System V ABI argument register
	rdi: usize,
	/// RSI — second System V ABI argument register
	rsi: usize,
	/// frame/base pointer
	rbp: usize,
	/// general purpose register RBX
	rbx: usize,
	/// general purpose register RDX
	rdx: usize,
	/// general purpose register RCX
	rcx: usize,
	/// general purpose register RAX
	rax: usize,
	/// status/flags register (RFLAGS)
	rflags: usize,
	/// instruction pointer to resume execution at
	rip: usize,
}
/// Stack addresses of the stacks that were already set up during boot,
/// recovered from the TSS (see `TaskStacks::from_boot_stacks`).
pub struct BootStack {
	/// base address of the boot kernel stack (derived from `tss.rsp[0]`)
	stack: usize,
	/// base address of interrupt stack table entry 0 (derived from `tss.ist[0]`)
	ist0: usize,
}
/// Dynamically allocated stack area of a regular (non-boot) task.
///
/// One contiguous virtual region holds the interrupt, kernel, and user stacks
/// separated by unmapped guard pages (see `TaskStacks::new` for the layout).
pub struct CommonStack {
	/// start address of the virtual region (including the guard pages)
	virt_addr: usize,
	/// start address of the backing physical memory (stacks only, no guard pages)
	phys_addr: usize,
	/// mapped size: user stack + DEFAULT_STACK_SIZE + KERNEL_STACK_SIZE
	/// (excludes the four guard pages)
	total_size: usize,
}
/// The stacks belonging to a task: either the pre-existing boot stacks or a
/// freshly allocated common stack area. Only `Common` stacks are freed on drop.
pub enum TaskStacks {
	Boot(BootStack),
	Common(CommonStack),
}
impl TaskStacks {
	/// Allocates and maps the stack area for a new task.
	///
	/// `size` is the requested user-stack size in bytes; it is raised to at
	/// least `KERNEL_STACK_SIZE` and rounded up to a multiple of the base
	/// page size.
	///
	/// Resulting virtual layout (low to high addresses):
	/// guard page | interrupt stack (KERNEL_STACK_SIZE) | guard page |
	/// kernel stack (DEFAULT_STACK_SIZE) | guard page |
	/// user stack (user_stack_size) | guard page
	///
	/// The four guard pages stay unmapped so a stack overflow faults instead
	/// of silently corrupting the neighboring stack.
	///
	/// Panics if virtual or physical memory cannot be allocated.
	pub fn new(size: usize) -> TaskStacks {
		let user_stack_size = if size < KERNEL_STACK_SIZE {
			KERNEL_STACK_SIZE
		} else {
			align_up!(size, BasePageSize::SIZE)
		};
		let total_size = user_stack_size + DEFAULT_STACK_SIZE + KERNEL_STACK_SIZE;
		// Reserve four extra pages of virtual space for the guard pages;
		// only `total_size` bytes of physical memory are backed.
		let virt_addr = ::arch::mm::virtualmem::allocate(total_size + 4 * BasePageSize::SIZE)
			.expect("Failed to allocate Virtual Memory for TaskStacks");
		let phys_addr = ::arch::mm::physicalmem::allocate(total_size)
			.expect("Failed to allocate Physical Memory for TaskStacks");
		debug!(
			"Create stacks at {:#X} with a size of {} KB",
			virt_addr,
			total_size >> 10
		);
		let mut flags = PageTableEntryFlags::empty();
		flags.normal().writable().execute_disable();

		// map the interrupt stack behind the first guard page
		::arch::mm::paging::map::<BasePageSize>(
			virt_addr + BasePageSize::SIZE,
			phys_addr,
			KERNEL_STACK_SIZE / BasePageSize::SIZE,
			flags,
		);

		// map the kernel stack behind the second guard page
		::arch::mm::paging::map::<BasePageSize>(
			virt_addr + KERNEL_STACK_SIZE + 2 * BasePageSize::SIZE,
			phys_addr + KERNEL_STACK_SIZE,
			DEFAULT_STACK_SIZE / BasePageSize::SIZE,
			flags,
		);

		// map the user stack behind the third guard page
		::arch::mm::paging::map::<BasePageSize>(
			virt_addr + KERNEL_STACK_SIZE + DEFAULT_STACK_SIZE + 3 * BasePageSize::SIZE,
			phys_addr + KERNEL_STACK_SIZE + DEFAULT_STACK_SIZE,
			user_stack_size / BasePageSize::SIZE,
			flags,
		);

		// fill the user stack with a recognizable marker pattern (0xAC)
		unsafe {
			ptr::write_bytes(
				(virt_addr + KERNEL_STACK_SIZE + DEFAULT_STACK_SIZE + 3 * BasePageSize::SIZE)
					as *mut u8,
				0xAC,
				user_stack_size,
			);
		}
		TaskStacks::Common(CommonStack {
			virt_addr: virt_addr,
			phys_addr: phys_addr,
			total_size: total_size,
		})
	}
	/// Wraps the stacks that the boot code already installed, reading their
	/// locations back out of the per-core TSS.
	pub fn from_boot_stacks() -> TaskStacks {
		let tss = unsafe { &(*PERCORE.tss.get()) };
		// rsp[0]/ist[0] point 0x10 below the top of their stacks; recover the
		// base address by undoing that offset and the stack size.
		// NOTE(review): assumes the boot code initialized the TSS with exactly
		// this convention — confirm against the TSS setup code.
		let stack = tss.rsp[0] as usize + 0x10 - KERNEL_STACK_SIZE;
		debug!("Using boot stack {:#X}", stack);
		let ist0 = tss.ist[0] as usize + 0x10 - KERNEL_STACK_SIZE;
		debug!("IST0 is located at {:#X}", ist0);
		TaskStacks::Boot(BootStack {
			stack: stack,
			ist0: ist0,
		})
	}
	/// Size of the user stack in bytes (0 for boot stacks, which have none).
	pub fn get_user_stack_size(&self) -> usize {
		match self {
			TaskStacks::Boot(_) => 0,
			TaskStacks::Common(stacks) => {
				stacks.total_size - DEFAULT_STACK_SIZE - KERNEL_STACK_SIZE
			}
		}
	}
	/// Base address of the user stack (0 for boot stacks).
	pub fn get_user_stack(&self) -> usize {
		match self {
			TaskStacks::Boot(_) => 0,
			TaskStacks::Common(stacks) => {
				// skip interrupt stack, kernel stack, and three guard pages
				stacks.virt_addr + KERNEL_STACK_SIZE + DEFAULT_STACK_SIZE + 3 * BasePageSize::SIZE
			}
		}
	}
	/// Base address of the kernel stack.
	pub fn get_kernel_stack(&self) -> usize {
		match self {
			TaskStacks::Boot(stacks) => stacks.stack,
			TaskStacks::Common(stacks) => {
				// skip interrupt stack and two guard pages
				stacks.virt_addr + KERNEL_STACK_SIZE + 2 * BasePageSize::SIZE
			}
		}
	}
	/// Size of the kernel stack. For common tasks the kernel stack occupies
	/// the DEFAULT_STACK_SIZE region of the layout created in `new`.
	pub fn get_kernel_stack_size(&self) -> usize {
		match self {
			TaskStacks::Boot(_) => KERNEL_STACK_SIZE,
			TaskStacks::Common(_) => DEFAULT_STACK_SIZE,
		}
	}
	/// Base address of the interrupt stack (IST0 region).
	/// NOTE(review): "interupt" is misspelled, but renaming would break callers.
	pub fn get_interupt_stack(&self) -> usize {
		match self {
			TaskStacks::Boot(stacks) => stacks.ist0,
			TaskStacks::Common(stacks) => stacks.virt_addr + BasePageSize::SIZE,
		}
	}
	/// Size of the interrupt stack (always KERNEL_STACK_SIZE).
	pub fn get_interupt_stack_size(&self) -> usize {
		KERNEL_STACK_SIZE
	}
}
impl Drop for TaskStacks {
	fn drop(&mut self) {
		// Boot stacks live for the whole kernel lifetime and are never freed.
		match self {
			TaskStacks::Boot(_) => {}
			TaskStacks::Common(stacks) => {
				debug!(
					"Deallocating stacks at {:#X} with a size of {} KB",
					stacks.virt_addr,
					stacks.total_size >> 10,
				);
				// Unmap the whole region including the four guard pages
				// (mirrors the `total_size + 4 pages` reservation in `new`).
				::arch::mm::paging::unmap::<BasePageSize>(
					stacks.virt_addr,
					stacks.total_size / BasePageSize::SIZE + 4,
				);
				::arch::mm::virtualmem::deallocate(
					stacks.virt_addr,
					stacks.total_size + 4 * BasePageSize::SIZE,
				);
				::arch::mm::physicalmem::deallocate(stacks.phys_addr, stacks.total_size);
			}
		}
	}
}
impl Clone for TaskStacks {
	/// Allocates a fresh, independent stack area with the same user-stack
	/// size as the original. Cloning a boot stack yields minimum-sized
	/// common stacks, since `get_user_stack_size` reports 0 for it — exactly
	/// what `TaskStacks::new(0)` produces.
	fn clone(&self) -> TaskStacks {
		TaskStacks::new(self.get_user_stack_size())
	}
}
/// Per-task thread-local-storage block, allocated from the kernel heap.
pub struct TaskTLS {
	// start address of the allocation
	address: usize,
	// allocation size in bytes (page-aligned, may exceed the TLS data size)
	size: usize,
	// value to load into the FS base: points at the self-pointer word
	// just past the TLS data
	fs: usize,
}
impl TaskTLS {
	/// Allocates and initializes a TLS block of `tls_size` bytes:
	/// copies the initialized TLS data from the application image, zeroes the
	/// remainder, and plants the self-pointer the thread-pointer register
	/// must reference.
	pub fn new(tls_size: usize) -> Self {
		// Size of the initialized portion (.tdata) within the TLS segment.
		let tdata_size: usize = environment::get_tls_filesz();

		// TLS data is kept 32-byte aligned; one extra usize behind it holds
		// the self-pointer word.
		let tls_allocation_size = align_up!(tls_size, 32) + mem::size_of::<usize>();

		// Round the allocation up to whole pages.
		let memory_size = align_up!(tls_allocation_size, BasePageSize::SIZE);
		// NOTE(review): confirm the semantics of the boolean flag passed to
		// mm::allocate — not visible from this file.
		let ptr = ::mm::allocate(memory_size, true);

		// FS will point just past the (aligned) TLS data, at the self-pointer.
		let tls_pointer = ptr + align_up!(tls_size, 32);
		unsafe {
			// Copy the initialized TLS data from the application image.
			ptr::copy_nonoverlapping(
				environment::get_tls_start() as *const u8,
				ptr as *mut u8,
				tdata_size,
			);
			// Zero the uninitialized remainder (.tbss).
			ptr::write_bytes(
				(ptr + tdata_size) as *mut u8,
				0,
				align_up!(tls_size, 32) - tdata_size,
			);

			// Per the x86-64 TLS ABI the word at the thread pointer (%fs:0)
			// must contain its own address, so code can materialize the
			// thread pointer by reading through FS.
			*(tls_pointer as *mut usize) = tls_pointer;
		}
		debug!(
			"Set up TLS at 0x{:x}, tdata_size 0x{:x}, tls_size 0x{:x}",
			tls_pointer, tdata_size, tls_size
		);
		Self {
			address: ptr,
			size: memory_size,
			fs: tls_pointer,
		}
	}
	/// Start address of the underlying allocation.
	#[inline]
	pub fn address(&self) -> usize {
		self.address
	}
	/// Value to load into the FS base register for this task.
	#[inline]
	pub fn get_fs(&self) -> usize {
		self.fs
	}
}
impl Drop for TaskTLS {
	/// Returns the TLS allocation to the kernel heap.
	fn drop(&mut self) {
		let address = self.address;
		let size = self.size;
		debug!(
			"Deallocate TLS at 0x{:x} (size 0x{:x})",
			address, size
		);
		mm::deallocate(address, size);
	}
}
impl Clone for TaskTLS {
	/// A cloned task gets a pristine TLS block re-initialized from the
	/// application image, not a copy of the current TLS contents.
	fn clone(&self) -> Self {
		let tls_size = environment::get_tls_memsz();
		Self::new(tls_size)
	}
}
// Dummy entry point so unit-test builds link without the naked kernel version.
#[cfg(test)]
extern "C" fn task_entry(func: extern "C" fn(usize), arg: usize) {}
// Common entry point of every task: `create_stack_frame` sets the saved RIP
// here, with `func`/`arg` in RDI/RSI, so the first context switch "returns"
// into this function.
// NOTE(review): a #[naked] function with a Rust body relies on the compiler
// emitting no prologue — fragile across toolchain versions; confirm.
#[cfg(not(test))]
#[inline(never)]
#[naked]
extern "C" fn task_entry(func: extern "C" fn(usize), arg: usize) -> ! {
	// Drop to user mode before running the task function (project macro).
	switch_to_user!();
	func(arg);
	// Return to kernel mode so the task can be torn down.
	switch_to_kernel!();
	// Tasks must not return; terminate via the scheduler (never returns).
	core_scheduler().exit(0);
}
impl TaskFrame for Task {
	/// Builds the initial kernel-stack frame so that the first context switch
	/// into this task "restores" a `State` that jumps to `task_entry(func, arg)`.
	fn create_stack_frame(&mut self, func: extern "C" fn(usize), arg: usize) {
		// Set up thread-local storage first, if the image has a TLS segment.
		let tls_size = environment::get_tls_memsz();
		self.tls = if tls_size > 0 {
			Some(TaskTLS::new(tls_size))
		} else {
			None
		};
		unsafe {
			// Start 16 bytes below the top of the kernel stack and plant a
			// marker word there (helps spot stack corruption).
			let mut stack = (self.stacks.get_kernel_stack() + self.stacks.get_kernel_stack_size()
				- 0x10) as *mut usize;
			*stack = 0xDEAD_BEEFusize;
			// Reserve space for the register state below the marker and
			// zero-initialize all registers.
			stack = (stack as usize - mem::size_of::<State>()) as *mut usize;
			let state = stack as *mut State;
			ptr::write_bytes(state as *mut u8, 0, mem::size_of::<State>());
			if let Some(tls) = &self.tls {
				(*state).fs = tls.get_fs();
			}
			// First switch jumps to task_entry with func/arg in the System V
			// argument registers RDI/RSI.
			(*state).rip = task_entry as usize;
			(*state).rdi = func as usize;
			(*state).rsi = arg as usize;
			// 0x1202 = IF (interrupts enabled) | RFLAGS bit 12 (IOPL low bit)
			// | reserved bit 1 (always set).
			(*state).rflags = 0x1202usize;
			// Record where the saved state lives; the context switch restores
			// from last_stack_pointer.
			self.last_stack_pointer = stack as usize;
			self.user_stack_pointer =
				self.stacks.get_user_stack() + self.stacks.get_user_stack_size() - 0x10;
		}
	}
}
/// APIC timer interrupt: wake any tasks whose timeout expired, then reschedule.
/// The EOI is sent before calling the scheduler — presumably so the interrupt
/// is acknowledged even if the scheduler switches away from this task
/// (NOTE(review): confirm intent).
extern "x86-interrupt" fn timer_handler(_stack_frame: &mut irq::ExceptionStackFrame) {
	core_scheduler().handle_waiting_tasks();
	apic::eoi();
	core_scheduler().scheduler();
}
/// Registers `timer_handler` in the IDT for the APIC timer interrupt vector.
pub fn install_timer_handler() {
	let vector = apic::TIMER_INTERRUPT_NUMBER;
	let handler = timer_handler as usize;
	idt::set_gate(vector, handler, 0);
}