From 6517fe26a2a0c89c3112f4a383c601572c71d64a Mon Sep 17 00:00:00 2001 From: Andrew Waterman Date: Thu, 12 Mar 2015 17:38:04 -0700 Subject: Update to new privileged spec --- pk/atomic.h | 75 +++--- pk/bits.h | 36 +++ pk/console.c | 4 +- pk/device.c | 12 +- pk/elf.c | 73 ++++-- pk/emulation.c | 812 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ pk/encoding.h | 94 ++++--- pk/entry.S | 59 ++--- pk/file.c | 29 +-- pk/file.h | 10 +- pk/fp.c | 275 ------------------- pk/fp.h | 23 -- pk/fp_asm.S | 92 +------ pk/frontend.c | 27 +- pk/frontend.h | 15 ++ pk/handlers.c | 30 +-- pk/init.c | 83 +++--- pk/int.c | 89 ------- pk/mcall.h | 15 ++ pk/mentry.S | 250 ++++++++++++++++++ pk/minit.c | 62 +++++ pk/mtrap.c | 222 ++++++++++++++++ pk/mtrap.h | 232 +++++++++++++++++ pk/pk.S | 15 -- pk/pk.h | 39 ++- pk/pk.ld | 17 +- pk/pk.mk.in | 14 +- pk/sbi.S | 7 + pk/sbi.h | 27 ++ pk/sbi_entry.S | 61 +++++ pk/sbi_impl.c | 23 ++ pk/string.c | 22 ++ pk/syscall.c | 33 ++- pk/syscall.h | 20 +- pk/vm.c | 226 +++++++++------- pk/vm.h | 11 + 36 files changed, 2273 insertions(+), 861 deletions(-) create mode 100644 pk/bits.h create mode 100644 pk/emulation.c delete mode 100644 pk/fp.c delete mode 100644 pk/fp.h delete mode 100644 pk/int.c create mode 100644 pk/mcall.h create mode 100644 pk/mentry.S create mode 100644 pk/minit.c create mode 100644 pk/mtrap.c create mode 100644 pk/mtrap.h create mode 100644 pk/sbi.S create mode 100644 pk/sbi.h create mode 100644 pk/sbi_entry.S create mode 100644 pk/sbi_impl.c (limited to 'pk') diff --git a/pk/atomic.h b/pk/atomic.h index 24db8be..c2adf00 100644 --- a/pk/atomic.h +++ b/pk/atomic.h @@ -6,55 +6,40 @@ #include "config.h" #include "encoding.h" -typedef struct { volatile long val; } atomic_t; -typedef struct { atomic_t lock; } spinlock_t; -#define SPINLOCK_INIT {{0}} +#define disable_irqsave() clear_csr(sstatus, SSTATUS_IE) +#define enable_irqrestore(flags) set_csr(sstatus, (flags) & SSTATUS_IE) -#define mb() __sync_synchronize() - -static inline void atomic_set(atomic_t* a, long val) -{ - a->val = val; -} +typedef struct { int lock; } spinlock_t; +#define SPINLOCK_INIT {0} -static inline long atomic_read(atomic_t* a) -{ - return a->val; -} - -static inline long atomic_add(atomic_t* a, long inc) -{ -#ifdef PK_ENABLE_ATOMICS - return __sync_fetch_and_add(&a->val, inc); -#else - long ret = atomic_read(a); - atomic_set(a, ret + inc); - return ret; -#endif -} - -static inline long atomic_swap(atomic_t* a, long val) -{ -#ifdef PK_ENABLE_ATOMICS - return __sync_lock_test_and_set(&a->val, val); -#else - long ret = atomic_read(a); - atomic_set(a, val); - return ret; -#endif -} +#define mb() __sync_synchronize() +#define atomic_set(ptr, val) (*(volatile typeof(*(ptr)) *)(ptr) = val) +#define atomic_read(ptr) (*(volatile typeof(*(ptr)) *)(ptr)) -static inline long atomic_cas(atomic_t* a, long compare, long swap) -{ #ifdef PK_ENABLE_ATOMICS - return __sync_val_compare_and_swap(&a->val, compare, swap); +# define atomic_add(ptr, inc) __sync_fetch_and_add(ptr, inc) +# define atomic_swap(ptr, swp) __sync_lock_test_and_set(ptr, swp) +# define atomic_cas(ptr, cmp, swp) __sync_val_compare_and_swap(ptr, cmp, swp) #else - long ret = atomic_read(a); - if (ret == compare) - atomic_set(a, swap); - return ret; +# define atomic_add(ptr, inc) ({ \ + long flags = disable_irqsave(); \ + typeof(ptr) res = *(volatile typeof(ptr))(ptr); \ + *(volatile typeof(ptr))(ptr) = res + (inc); \ + enable_irqrestore(flags); \ + res; }) +# define atomic_swap(ptr, swp) ({ \ + long flags = disable_irqsave(); \ + typeof(ptr) res = *(volatile typeof(ptr))(ptr); \ + *(volatile typeof(ptr))(ptr) = (swp); \ + enable_irqrestore(flags); \ + res; }) +# define atomic_cas(ptr, cmp, swp) ({ \ + long flags = disable_irqsave(); \ + typeof(ptr) res = *(volatile typeof(ptr))(ptr); \ + if (res == (cmp)) *(volatile typeof(ptr))(ptr) = (swp); \ + enable_irqrestore(flags); \ + res; }) #endif -} static inline void spinlock_lock(spinlock_t* lock) { @@ -74,7 +59,7 @@ static inline void spinlock_unlock(spinlock_t* lock) static inline long spinlock_lock_irqsave(spinlock_t* lock) { - long flags = clear_csr(mstatus, MSTATUS_IE); + long flags = disable_irqsave(); spinlock_lock(lock); return flags; } @@ -82,7 +67,7 @@ static inline long spinlock_lock_irqsave(spinlock_t* lock) static inline void spinlock_unlock_irqrestore(spinlock_t* lock, long flags) { spinlock_unlock(lock); - set_csr(mstatus, flags & MSTATUS_IE); + enable_irqrestore(flags); } #endif diff --git a/pk/bits.h b/pk/bits.h new file mode 100644 index 0000000..e7fd8d3 --- /dev/null +++ b/pk/bits.h @@ -0,0 +1,36 @@ +#ifndef PK_BITS_H +#define PK_BITS_H + +#define CONST_POPCOUNT2(x) ((((x) >> 0) & 1) + (((x) >> 1) & 1)) +#define CONST_POPCOUNT4(x) (CONST_POPCOUNT2(x) + CONST_POPCOUNT2((x)>>2)) +#define CONST_POPCOUNT8(x) (CONST_POPCOUNT4(x) + CONST_POPCOUNT4((x)>>4)) +#define CONST_POPCOUNT16(x) (CONST_POPCOUNT8(x) + CONST_POPCOUNT8((x)>>8)) +#define CONST_POPCOUNT32(x) (CONST_POPCOUNT16(x) + CONST_POPCOUNT16((x)>>16)) +#define CONST_POPCOUNT64(x) (CONST_POPCOUNT32(x) + CONST_POPCOUNT32((x)>>32)) +#define CONST_POPCOUNT(x) CONST_POPCOUNT64(x) + +#define CONST_CTZ2(x) CONST_POPCOUNT2(((x) & -(x))-1) +#define CONST_CTZ4(x) CONST_POPCOUNT4(((x) & -(x))-1) +#define CONST_CTZ8(x) CONST_POPCOUNT8(((x) & -(x))-1) +#define CONST_CTZ16(x) CONST_POPCOUNT16(((x) & -(x))-1) +#define CONST_CTZ32(x) CONST_POPCOUNT32(((x) & -(x))-1) +#define CONST_CTZ64(x) CONST_POPCOUNT64(((x) & -(x))-1) +#define CONST_CTZ(x) CONST_CTZ64(x) + +#define STR(x) XSTR(x) +#define XSTR(x) #x + +#ifdef __riscv64 +# define SLL32 sllw +# define STORE sd +# define LOAD ld +# define LOG_REGBYTES 3 +#else +# define SLL32 sll +# define STORE sw +# define LOAD lw +# define LOG_REGBYTES 2 +#endif +#define REGBYTES (1 << LOG_REGBYTES) + +#endif diff --git a/pk/console.c b/pk/console.c index 366c313..a15cb8a 100644 --- a/pk/console.c +++ b/pk/console.c @@ -126,8 +126,8 @@ void dump_tf(trapframe_t* tf) for(int j = 0; j < 4; j++) printk("%s %lx%c",regnames[i+j],tf->gpr[i+j],j < 3 ? ' ' : '\n'); } - printk("pc %lx va %lx insn %x\n", tf->epc, tf->badvaddr, - (uint32_t)tf->insn); + printk("pc %lx va %lx insn %x sr %lx\n", tf->epc, tf->badvaddr, + (uint32_t)tf->insn, tf->status); } void do_panic(const char* s, ...) diff --git a/pk/device.c b/pk/device.c index 73cd71d..f8b39ca 100644 --- a/pk/device.c +++ b/pk/device.c @@ -1,18 +1,8 @@ #include "pk.h" +#include "frontend.h" #include #include -static uint64_t tohost_sync(unsigned dev, unsigned cmd, uint64_t payload) -{ - uint64_t tohost = (uint64_t)dev << 56 | (uint64_t)cmd << 48 | payload; - uint64_t fromhost; - __sync_synchronize(); - while (swap_csr(tohost, tohost) != 0); - while ((fromhost = swap_csr(fromhost, 0)) == 0); - __sync_synchronize(); - return fromhost; -} - void enumerate_devices() { char buf[64] __attribute__((aligned(64))); diff --git a/pk/elf.c b/pk/elf.c index 28760eb..dde28ed 100644 --- a/pk/elf.c +++ b/pk/elf.c @@ -21,36 +21,65 @@ void load_elf(const char* fn, elf_info* info) eh64.e_ident[2] == 'L' && eh64.e_ident[3] == 'F')) goto fail; - size_t bias = 0; - extern char _end; - if (eh64.e_type == ET_DYN) - bias = ROUNDUP((uintptr_t)&_end, RISCV_PGSIZE); + uintptr_t min_vaddr = -1, max_vaddr = 0; #define LOAD_ELF do { \ eh = (typeof(eh))&eh64; \ size_t phdr_size = eh->e_phnum*sizeof(*ph); \ - if (info->phdr_top - phdr_size < info->stack_bottom) \ + if (phdr_size > info->phdr_size) \ goto fail; \ - info->phdr = info->phdr_top - phdr_size; \ ssize_t ret = file_pread(file, (void*)info->phdr, phdr_size, eh->e_phoff); \ - if (ret < (ssize_t)phdr_size) goto fail; \ - info->entry = bias + eh->e_entry; \ + if (ret < (ssize_t)phdr_size) \ + goto fail; \ info->phnum = eh->e_phnum; \ info->phent = sizeof(*ph); \ ph = (typeof(ph))info->phdr; \ - for(int i = 0; i < eh->e_phnum; i++, ph++) { \ - if(ph->p_type == PT_LOAD && ph->p_memsz) { \ - info->brk_min = MAX(info->brk_min, bias + ph->p_vaddr + ph->p_memsz); \ - size_t vaddr = ROUNDDOWN(ph->p_vaddr, RISCV_PGSIZE), prepad = ph->p_vaddr - vaddr; \ - size_t memsz = ph->p_memsz + prepad, filesz = ph->p_filesz + prepad; \ - size_t offset = ph->p_offset - prepad; \ - vaddr += bias; \ - if (__do_mmap(vaddr, filesz, -1, MAP_FIXED|MAP_PRIVATE, file, offset) != vaddr) \ - goto fail; \ - size_t mapped = ROUNDUP(filesz, RISCV_PGSIZE); \ - if (memsz > mapped) \ - if (__do_mmap(vaddr + mapped, memsz - mapped, -1, MAP_FIXED|MAP_PRIVATE|MAP_ANONYMOUS, 0, 0) != vaddr + mapped) \ + info->is_supervisor = (eh->e_entry >> (8*sizeof(eh->e_entry)-1)) != 0; \ + if (info->is_supervisor) \ + info->first_free_paddr = ROUNDUP(info->first_free_paddr, SUPERPAGE_SIZE); \ + for (int i = 0; i < eh->e_phnum; i++) \ + if (ph[i].p_type == PT_LOAD && ph[i].p_memsz && ph[i].p_vaddr < min_vaddr) \ + min_vaddr = ph[i].p_vaddr; \ + if (info->is_supervisor) \ + min_vaddr = ROUNDDOWN(min_vaddr, SUPERPAGE_SIZE); \ + else \ + min_vaddr = ROUNDDOWN(min_vaddr, RISCV_PGSIZE); \ + uintptr_t bias = 0; \ + if (info->is_supervisor || eh->e_type == ET_DYN) \ + bias = info->first_free_paddr - min_vaddr; \ + info->entry = eh->e_entry; \ + if (!info->is_supervisor) { \ + info->entry += bias; \ + min_vaddr += bias; \ + } \ + info->bias = bias; \ + int flags = MAP_FIXED | MAP_PRIVATE; \ + if (info->is_supervisor) \ + flags |= MAP_POPULATE; \ + for (int i = eh->e_phnum - 1; i >= 0; i--) { \ + if(ph[i].p_type == PT_LOAD && ph[i].p_memsz) { \ + uintptr_t prepad = ph[i].p_vaddr % RISCV_PGSIZE; \ + uintptr_t vaddr = ph[i].p_vaddr + bias; \ + if (vaddr + ph[i].p_memsz > max_vaddr) \ + max_vaddr = vaddr + ph[i].p_memsz; \ + if (info->is_supervisor) { \ + if (!__valid_user_range(vaddr - prepad, vaddr + ph[i].p_memsz)) \ + goto fail; \ + ret = file_pread(file, (void*)vaddr, ph[i].p_filesz, ph[i].p_offset); \ + if (ret < (ssize_t)ph[i].p_filesz) \ goto fail; \ + memset((void*)vaddr - prepad, 0, prepad); \ + memset((void*)vaddr + ph[i].p_filesz, 0, ph[i].p_memsz - ph[i].p_filesz); \ + } else { \ + int flags2 = flags | (prepad ? MAP_POPULATE : 0); \ + if (__do_mmap(vaddr - prepad, ph[i].p_filesz + prepad, -1, flags2, file, ph[i].p_offset - prepad) != vaddr) \ + goto fail; \ + memset((void*)vaddr - prepad, 0, prepad); \ + size_t mapped = ROUNDUP(ph[i].p_filesz + prepad, RISCV_PGSIZE) - prepad; \ + if (ph[i].p_memsz > mapped) \ + if (__do_mmap(vaddr + mapped, ph[i].p_memsz - mapped, -1, flags|MAP_ANONYMOUS, 0, 0) != vaddr + mapped) \ + goto fail; \ + } \ } \ } \ } while(0) @@ -71,6 +100,10 @@ void load_elf(const char* fn, elf_info* info) else goto fail; + info->first_user_vaddr = min_vaddr; + info->first_vaddr_after_user = ROUNDUP(max_vaddr - info->bias, RISCV_PGSIZE); + info->brk_min = max_vaddr; + file_decref(file); return; diff --git a/pk/emulation.c b/pk/emulation.c new file mode 100644 index 0000000..25c4de0 --- /dev/null +++ b/pk/emulation.c @@ -0,0 +1,812 @@ +#include "mtrap.h" +#include "softfloat.h" +#include + +DECLARE_EMULATION_FUNC(truly_illegal_insn) +{ + return -1; +} + +uintptr_t misaligned_load_trap(uintptr_t mcause, uintptr_t* regs) +{ + uintptr_t mstatus = read_csr(mstatus); + uintptr_t mepc = read_csr(mepc); + insn_fetch_t fetch = get_insn(mcause, mstatus, mepc); + if (fetch.error) + return -1; + + uintptr_t val, res, tmp; + uintptr_t addr = GET_RS1(fetch.insn, regs) + IMM_I(fetch.insn); + + #define DO_LOAD(type_lo, type_hi, insn_lo, insn_hi) ({ \ + type_lo val_lo; \ + type_hi val_hi; \ + uintptr_t addr_lo = addr & -sizeof(type_hi); \ + uintptr_t addr_hi = addr_lo + sizeof(type_hi); \ + uintptr_t masked_addr = sizeof(type_hi) < 4 ? addr % sizeof(type_hi) : addr; \ + res = unpriv_mem_access(mstatus, mepc, \ + insn_lo " %[val_lo], (%[addr_lo]);" \ + insn_hi " %[val_hi], (%[addr_hi])", \ + val_lo, val_hi, addr_lo, addr_hi); \ + val_lo >>= masked_addr * 8; \ + val_hi <<= (sizeof(type_hi) - masked_addr) * 8; \ + val = (type_hi)(val_lo | val_hi); \ + }) + + if ((fetch.insn & MASK_LW) == MATCH_LW) + DO_LOAD(uint32_t, int32_t, "lw", "lw"); +#ifdef __riscv64 + else if ((fetch.insn & MASK_LD) == MATCH_LD) + DO_LOAD(uint64_t, uint64_t, "ld", "ld"); + else if ((fetch.insn & MASK_LWU) == MATCH_LWU) + DO_LOAD(uint32_t, uint32_t, "lwu", "lwu"); +#endif + else if ((fetch.insn & MASK_FLD) == MATCH_FLD) { +#ifdef __riscv64 + DO_LOAD(uint64_t, uint64_t, "ld", "ld"); + if (res == 0) { + SET_F64_RD(fetch.insn, regs, val); + goto success; + } +#else + DO_LOAD(uint32_t, int32_t, "lw", "lw"); + if (res == 0) { + uint64_t double_val = val; + addr += 4; + DO_LOAD(uint32_t, int32_t, "lw", "lw"); + double_val |= (uint64_t)val << 32; + if (res == 0) { + SET_F64_RD(fetch.insn, regs, val); + goto success; + } + } +#endif + } else if ((fetch.insn & MASK_FLW) == MATCH_FLW) { + DO_LOAD(uint32_t, uint32_t, "lw", "lw"); + if (res == 0) { + SET_F32_RD(fetch.insn, regs, val); + goto success; + } + } else if ((fetch.insn & MASK_LH) == MATCH_LH) { + // equivalent to DO_LOAD(uint32_t, int16_t, "lhu", "lh") + res = unpriv_mem_access(mstatus, mepc, + "lbu %[val], 0(%[addr]);" + "lb %[tmp], 1(%[addr])", + val, tmp, addr, mstatus/*X*/); + val |= tmp << 8; + } else if ((fetch.insn & MASK_LHU) == MATCH_LHU) { + // equivalent to DO_LOAD(uint32_t, uint16_t, "lhu", "lhu") + res = unpriv_mem_access(mstatus, mepc, + "lbu %[val], 0(%[addr]);" + "lbu %[tmp], 1(%[addr])", + val, tmp, addr, mstatus/*X*/); + val |= tmp << 8; + } else { + return -1; + } + + if (res) { + restore_mstatus(mstatus, mepc); + return -1; + } + + SET_RD(fetch.insn, regs, val); + +success: + write_csr(mepc, mepc + 4); + return 0; +} + +uintptr_t misaligned_store_trap(uintptr_t mcause, uintptr_t* regs) +{ + uintptr_t mstatus = read_csr(mstatus); + uintptr_t mepc = read_csr(mepc); + insn_fetch_t fetch = get_insn(mcause, mstatus, mepc); + if (fetch.error) + return -1; + + uintptr_t addr = GET_RS1(fetch.insn, regs) + IMM_S(fetch.insn); + uintptr_t val = GET_RS2(fetch.insn, regs), error; + + if ((fetch.insn & MASK_SW) == MATCH_SW) { +SW: + error = unpriv_mem_access(mstatus, mepc, + "sb %[val], 0(%[addr]);" + "srl %[val], %[val], 8; sb %[val], 1(%[addr]);" + "srl %[val], %[val], 8; sb %[val], 2(%[addr]);" + "srl %[val], %[val], 8; sb %[val], 3(%[addr]);", + unused1, unused2, val, addr); +#ifdef __riscv64 + } else if ((fetch.insn & MASK_SD) == MATCH_SD) { +SD: + error = unpriv_mem_access(mstatus, mepc, + "sb %[val], 0(%[addr]);" + "srl %[val], %[val], 8; sb %[val], 1(%[addr]);" + "srl %[val], %[val], 8; sb %[val], 2(%[addr]);" + "srl %[val], %[val], 8; sb %[val], 3(%[addr]);" + "srl %[val], %[val], 8; sb %[val], 4(%[addr]);" + "srl %[val], %[val], 8; sb %[val], 5(%[addr]);" + "srl %[val], %[val], 8; sb %[val], 6(%[addr]);" + "srl %[val], %[val], 8; sb %[val], 7(%[addr]);", + unused1, unused2, val, addr); +#endif + } else if ((fetch.insn & MASK_SH) == MATCH_SH) { + error = unpriv_mem_access(mstatus, mepc, + "sb %[val], 0(%[addr]);" + "srl %[val], %[val], 8; sb %[val], 1(%[addr]);", + unused1, unused2, val, addr); + } else if ((fetch.insn & MASK_FSD) == MATCH_FSD) { +#ifdef __riscv64 + val = GET_F64_RS2(fetch.insn, regs); + goto SD; +#else + uint64_t double_val = GET_F64_RS2(fetch.insn, regs); + uint32_t val_lo = double_val, val_hi = double_val >> 32; + error = unpriv_mem_access(mstatus, mepc, + "sb %[val_lo], 0(%[addr]);" + "srl %[val_lo], %[val_lo], 8; sb %[val_lo], 1(%[addr]);" + "srl %[val_lo], %[val_lo], 8; sb %[val_lo], 2(%[addr]);" + "srl %[val_lo], %[val_lo], 8; sb %[val_lo], 3(%[addr]);" + "sb %[val_hi], 4(%[addr]);" + "srl %[val_hi], %[val_hi], 8; sb %[val_hi], 5(%[addr]);" + "srl %[val_hi], %[val_hi], 8; sb %[val_hi], 6(%[addr]);" + "srl %[val_hi], %[val_hi], 8; sb %[val_hi], 7(%[addr]);", + unused1, unused2, val_lo, val_hi, addr); +#endif + } else if ((fetch.insn & MASK_FSW) == MATCH_FSW) { + val = GET_F32_RS2(fetch.insn, regs); + goto SW; + } else + return -1; + + if (error) { + restore_mstatus(mstatus, mepc); + return -1; + } + + write_csr(mepc, mepc + 4); + return 0; +} + +DECLARE_EMULATION_FUNC(emulate_float_load) +{ + uintptr_t val_lo, val_hi, error; + uint64_t val; + uintptr_t addr = GET_RS1(insn, regs) + IMM_I(insn); + + switch (insn & MASK_FUNCT3) + { + case MATCH_FLW & MASK_FUNCT3: + if (addr % 4 != 0) + return misaligned_load_trap(mcause, regs); + + error = unpriv_mem_access(mstatus, mepc, + "lw %[val_lo], (%[addr])", + val_lo, val_hi/*X*/, addr, mstatus/*X*/); + + if (error == 0) { + SET_F32_RD(insn, regs, val_lo); + goto success; + } + break; + + case MATCH_FLD & MASK_FUNCT3: + if (addr % sizeof(uintptr_t) != 0) + return misaligned_load_trap(mcause, regs); +#ifdef __riscv64 + error = unpriv_mem_access(mstatus, mepc, + "ld %[val], (%[addr])", + val, val_hi/*X*/, addr, mstatus/*X*/); +#else + error = unpriv_mem_access(mstatus, mepc, + "lw %[val_lo], (%[addr]);" + "lw %[val_hi], 4(%[addr])", + val_lo, val_hi, addr, mstatus/*X*/); + val = val_lo | ((uint64_t)val_hi << 32); +#endif + + if (error == 0) { + SET_F64_RD(insn, regs, val); + goto success; + } + break; + } + + restore_mstatus(mstatus, mepc); + return -1; + +success: + write_csr(mepc, mepc + 4); + return 0; +} + +DECLARE_EMULATION_FUNC(emulate_float_store) +{ + uintptr_t val_lo, val_hi, error; + uint64_t val; + uintptr_t addr = GET_RS1(insn, regs) + IMM_I(insn); + + switch (insn & MASK_FUNCT3) + { + case MATCH_FSW & MASK_FUNCT3: + if (addr % 4 != 0) + return misaligned_store_trap(mcause, regs); + + val_lo = GET_F32_RS2(insn, regs); + error = unpriv_mem_access(mstatus, mepc, + "sw %[val_lo], (%[addr])", + unused1, unused2, val_lo, addr); + break; + + case MATCH_FSD & MASK_FUNCT3: + if (addr % sizeof(uintptr_t) != 0) + return misaligned_store_trap(mcause, regs); + + val = GET_F64_RS2(insn, regs); +#ifdef __riscv64 + error = unpriv_mem_access(mstatus, mepc, + "sd %[val], (%[addr])", + unused1, unused2, val, addr); +#else + val_lo = val; + val_hi = val >> 32; + error = unpriv_mem_access(mstatus, mepc, + "sw %[val_lo], (%[addr]);" + "sw %[val_hi], 4(%[addr])", + unused1, unused2, val_lo, val_hi, addr); +#endif + break; + + default: + error = 1; + } + + if (error) { + restore_mstatus(mstatus, mepc); + return -1; + } + + write_csr(mepc, mepc + 4); + return 0; +} + +#ifdef __riscv64 +typedef int double_int __attribute__((mode(TI))); +typedef unsigned int double_uint __attribute__((mode(TI))); +#else +typedef int64_t double_int; +typedef uint64_t double_int; +#endif + +DECLARE_EMULATION_FUNC(emulate_mul_div) +{ + uintptr_t rs1 = GET_RS1(insn, regs), rs2 = GET_RS2(insn, regs), val; + + // If compiled with -mno-multiply, GCC will expand these out + if ((insn & MASK_MUL) == MATCH_MUL) + val = rs1 * rs2; + else if ((insn & MASK_DIV) == MATCH_DIV) + val = (intptr_t)rs1 / (intptr_t)rs2; + else if ((insn & MASK_DIVU) == MATCH_DIVU) + val = rs1 / rs2; + else if ((insn & MASK_REM) == MATCH_REM) + val = (intptr_t)rs1 % (intptr_t)rs2; + else if ((insn & MASK_REMU) == MATCH_REMU) + val = rs1 % rs2; + else if ((insn & MASK_MULH) == MATCH_MULH) + val = ((double_int)(intptr_t)rs1 * (double_int)(intptr_t)rs2) >> (8 * sizeof(rs1)); + else if ((insn & MASK_MULHU) == MATCH_MULHU) + val = ((double_int)rs1 * (double_int)rs2) >> (8 * sizeof(rs1)); + else if ((insn & MASK_MULHSU) == MATCH_MULHSU) + val = ((double_int)(intptr_t)rs1 * (double_int)rs2) >> (8 * sizeof(rs1)); + else + return -1; + + SET_RD(insn, regs, val); + return 0; +} + +DECLARE_EMULATION_FUNC(emulate_mul_div32) +{ +#ifndef __riscv64 + return truly_illegal_insn(mcause, regs, insn, mstatus, mepc); +#endif + + uint32_t rs1 = GET_RS1(insn, regs), rs2 = GET_RS2(insn, regs); + int32_t val; + + // If compiled with -mno-multiply, GCC will expand these out + if ((insn & MASK_MUL) == MATCH_MULW) + val = rs1 * rs2; + else if ((insn & MASK_DIV) == MATCH_DIV) + val = (int32_t)rs1 / (int32_t)rs2; + else if ((insn & MASK_DIVU) == MATCH_DIVU) + val = rs1 / rs2; + else if ((insn & MASK_REM) == MATCH_REM) + val = (int32_t)rs1 % (int32_t)rs2; + else if ((insn & MASK_REMU) == MATCH_REMU) + val = rs1 % rs2; + else + return -1; + + SET_RD(insn, regs, val); + return 0; +} + +static inline int emulate_read_csr(int num, uintptr_t* result, uintptr_t mstatus) +{ + switch (num) + { + case CSR_FRM: + if ((mstatus & MSTATUS_FS) == 0) break; + *result = GET_FRM(); + return 0; + case CSR_FFLAGS: + if ((mstatus & MSTATUS_FS) == 0) break; + *result = GET_FFLAGS(); + return 0; + case CSR_FCSR: + if ((mstatus & MSTATUS_FS) == 0) break; + *result = GET_FCSR(); + return 0; + } + return -1; +} + +static inline int emulate_write_csr(int num, uintptr_t value, uintptr_t mstatus) +{ + switch (num) + { + case CSR_FRM: SET_FRM(value); return 0; + case CSR_FFLAGS: SET_FFLAGS(value); return 0; + case CSR_FCSR: SET_FCSR(value); return 0; + } + return -1; +} + +DECLARE_EMULATION_FUNC(emulate_system) +{ + int rs1_num = (insn >> 15) & 0x1f; + uintptr_t rs1_val = GET_RS1(insn, regs); + int csr_num = (uint32_t)insn >> 20; + uintptr_t csr_val, new_csr_val; + + if (emulate_read_csr(csr_num, &csr_val, mstatus) != 0) + return -1; + + int do_write = rs1_num; + switch (GET_RM(insn)) + { + case 0: return -1; + case 1: new_csr_val = rs1_val; do_write = 1; break; + case 2: new_csr_val = csr_val | rs1_val; break; + case 3: new_csr_val = csr_val & ~rs1_val; break; + case 4: return -1; + case 5: new_csr_val = rs1_num; do_write = 1; break; + case 6: new_csr_val = csr_val | rs1_num; break; + case 7: new_csr_val = csr_val & ~rs1_num; break; + } + + if (do_write && emulate_write_csr(csr_num, new_csr_val, mstatus) != 0) + return -1; + + SET_RD(insn, regs, csr_val); + return 0; +} + +DECLARE_EMULATION_FUNC(emulate_fp) +{ + asm (".pushsection .rodata\n" + "fp_emulation_table:\n" + " .word emulate_fadd\n" + " .word emulate_fsub\n" + " .word emulate_fmul\n" + " .word emulate_fdiv\n" + " .word emulate_fsgnj\n" + " .word emulate_fmin\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word emulate_fcvt_ff\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word emulate_fsqrt\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word emulate_fcmp\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word emulate_fcvt_if\n" + " .word truly_illegal_insn\n" + " .word emulate_fcvt_fi\n" + " .word truly_illegal_insn\n" + " .word emulate_fmv_if\n" + " .word truly_illegal_insn\n" + " .word emulate_fmv_fi\n" + " .word truly_illegal_insn\n" + " .popsection"); + + // if FPU is disabled, punt back to the OS + if (unlikely((mstatus & MSTATUS_FS) == 0)) + return -1; + + extern int32_t fp_emulation_table[]; + int32_t* pf = (void*)fp_emulation_table + ((insn >> 25) & 0x7c); + emulation_func f = (emulation_func)(uintptr_t)*pf; + + SETUP_STATIC_ROUNDING(insn); + return f(mcause, regs, insn, mstatus, mepc); +} + +uintptr_t emulate_any_fadd(uintptr_t mcause, uintptr_t* regs, insn_t insn, uintptr_t neg_b) +{ + if (GET_PRECISION(insn) == PRECISION_S) { + uint32_t rs1 = GET_F32_RS1(insn, regs); + uint32_t rs2 = GET_F32_RS2(insn, regs) ^ neg_b; + SET_F32_RD(insn, regs, f32_add(rs1, rs2)); + return 0; + } else if (GET_PRECISION(insn) == PRECISION_D) { + uint64_t rs1 = GET_F64_RS1(insn, regs); + uint64_t rs2 = GET_F64_RS2(insn, regs) ^ ((uint64_t)neg_b << 32); + SET_F64_RD(insn, regs, f64_add(rs1, rs2)); + return 0; + } + return -1; +} + +DECLARE_EMULATION_FUNC(emulate_fadd) +{ + return emulate_any_fadd(mcause, regs, insn, 0); +} + +DECLARE_EMULATION_FUNC(emulate_fsub) +{ + return emulate_any_fadd(mcause, regs, insn, INT32_MIN); +} + +DECLARE_EMULATION_FUNC(emulate_fmul) +{ + if (GET_PRECISION(insn) == PRECISION_S) { + uint32_t rs1 = GET_F32_RS1(insn, regs); + uint32_t rs2 = GET_F32_RS2(insn, regs); + SET_F32_RD(insn, regs, f32_mul(rs1, rs2)); + return 0; + } else if (GET_PRECISION(insn) == PRECISION_D) { + uint64_t rs1 = GET_F64_RS1(insn, regs); + uint64_t rs2 = GET_F64_RS2(insn, regs); + SET_F64_RD(insn, regs, f64_mul(rs1, rs2)); + return 0; + } + return -1; +} + +DECLARE_EMULATION_FUNC(emulate_fdiv) +{ + if (GET_PRECISION(insn) == PRECISION_S) { + uint32_t rs1 = GET_F32_RS1(insn, regs); + uint32_t rs2 = GET_F32_RS2(insn, regs); + SET_F32_RD(insn, regs, f32_div(rs1, rs2)); + return 0; + } else if (GET_PRECISION(insn) == PRECISION_D) { + uint64_t rs1 = GET_F64_RS1(insn, regs); + uint64_t rs2 = GET_F64_RS2(insn, regs); + SET_F64_RD(insn, regs, f64_div(rs1, rs2)); + return 0; + } + return -1; +} + +DECLARE_EMULATION_FUNC(emulate_fsqrt) +{ + if ((insn >> 20) & 0x1f) + return -1; + + if (GET_PRECISION(insn) == PRECISION_S) { + SET_F32_RD(insn, regs, f32_sqrt(GET_F32_RS1(insn, regs))); + return 0; + } else if (GET_PRECISION(insn) == PRECISION_D) { + SET_F64_RD(insn, regs, f64_sqrt(GET_F64_RS1(insn, regs))); + return 0; + } + return -1; +} + +DECLARE_EMULATION_FUNC(emulate_fsgnj) +{ + int rm = GET_RM(insn); + if (rm >= 3) + return -1; + + #define DO_FSGNJ(rs1, rs2, rm) ({ \ + typeof(rs1) rs1_sign = (rs1) >> (8*sizeof(rs1)-1); \ + typeof(rs1) rs2_sign = (rs2) >> (8*sizeof(rs1)-1); \ + rs1_sign &= (rm) >> 1; \ + rs1_sign ^= (rm) ^ rs2_sign; \ + ((rs1) << 1 >> 1) | (rs1_sign << (8*sizeof(rs1)-1)); }) + + if (GET_PRECISION(insn) == PRECISION_S) { + uint32_t rs1 = GET_F32_RS1(insn, regs); + uint32_t rs2 = GET_F32_RS2(insn, regs); + SET_F32_RD(insn, regs, DO_FSGNJ(rs1, rs2, rm)); + return 0; + } else if (GET_PRECISION(insn) == PRECISION_D) { + uint64_t rs1 = GET_F64_RS1(insn, regs); + uint64_t rs2 = GET_F64_RS2(insn, regs); + SET_F64_RD(insn, regs, DO_FSGNJ(rs1, rs2, rm)); + return 0; + } + return -1; +} + +DECLARE_EMULATION_FUNC(emulate_fmin) +{ + int rm = GET_RM(insn); + if (rm >= 2) + return -1; + + if (GET_PRECISION(insn) == PRECISION_S) { + uint32_t rs1 = GET_F32_RS1(insn, regs); + uint32_t rs2 = GET_F32_RS2(insn, regs); + uint32_t arg1 = rm ? rs2 : rs1; + uint32_t arg2 = rm ? rs1 : rs2; + int use_rs1 = f32_lt_quiet(arg1, arg2) || isNaNF32UI(rs2); + SET_F32_RD(insn, regs, use_rs1 ? rs1 : rs2); + return 0; + } else if (GET_PRECISION(insn) == PRECISION_D) { + uint64_t rs1 = GET_F64_RS1(insn, regs); + uint64_t rs2 = GET_F64_RS2(insn, regs); + uint64_t arg1 = rm ? rs2 : rs1; + uint64_t arg2 = rm ? rs1 : rs2; + int use_rs1 = f64_lt_quiet(arg1, arg2) || isNaNF64UI(rs2); + SET_F64_RD(insn, regs, use_rs1 ? rs1 : rs2); + return 0; + } + return -1; +} + +DECLARE_EMULATION_FUNC(emulate_fcvt_ff) +{ + int rs2_num = (insn >> 20) & 0x1f; + if (GET_PRECISION(insn) == PRECISION_S) { + if (rs2_num != 1) + return -1; + SET_F32_RD(insn, regs, f64_to_f32(GET_F64_RS1(insn, regs))); + return 0; + } else if (GET_PRECISION(insn) == PRECISION_D) { + if (rs2_num != 0) + return -1; + SET_F64_RD(insn, regs, f32_to_f64(GET_F32_RS1(insn, regs))); + return 0; + } + return -1; +} + +DECLARE_EMULATION_FUNC(emulate_fcvt_fi) +{ + if (GET_PRECISION(insn) != PRECISION_S && GET_PRECISION(insn) != PRECISION_D) + return -1; + + int negative = 0; + uint64_t uint_val = GET_RS1(insn, regs); + + switch ((insn >> 20) & 0x1f) + { + case 0: // int32 + negative = (int32_t)uint_val < 0; + uint_val = negative ? -(int32_t)uint_val : (int32_t)uint_val; + break; + case 1: // uint32 + uint_val = (uint32_t)uint_val; + break; +#ifdef __riscv64 + case 2: // int64 + negative = (int64_t)uint_val < 0; + uint_val = negative ? -uint_val : uint_val; + case 3: // uint64 + break; +#endif + default: + return -1; + } + + uint64_t float64 = ui64_to_f64(uint_val); + if (negative) + float64 ^= INT64_MIN; + + if (GET_PRECISION(insn) == PRECISION_S) + SET_F32_RD(insn, regs, f64_to_f32(float64)); + else + SET_F64_RD(insn, regs, float64); + + return 0; +} + +DECLARE_EMULATION_FUNC(emulate_fcvt_if) +{ + int rs2_num = (insn >> 20) & 0x1f; +#ifdef __riscv64 + if (rs2_num >= 4) + return -1; +#else + if (rs2_num >= 2) + return -1; +#endif + + int64_t float64; + if (GET_PRECISION(insn) == PRECISION_S) + float64 = f32_to_f64(GET_F32_RS1(insn, regs)); + else if (GET_PRECISION(insn) == PRECISION_D) + float64 = GET_F64_RS1(insn, regs); + else + return -1; + + int negative = 0; + if (float64 < 0) { + negative = 1; + float64 ^= INT64_MIN; + } + uint64_t uint_val = f64_to_ui64(float64, softfloat_roundingMode, true); + uint64_t result, limit, limit_result; + + switch (rs2_num) + { + case 0: // int32 + if (negative) { + result = (int32_t)-uint_val; + limit_result = limit = (uint32_t)INT32_MIN; + } else { + result = (int32_t)uint_val; + limit_result = limit = INT32_MAX; + } + break; + + case 1: // uint32 + limit = limit_result = UINT32_MAX; + if (negative) + result = limit = 0; + else + result = (uint32_t)uint_val; + break; + + case 2: // int32 + if (negative) { + result = (int64_t)-uint_val; + limit_result = limit = (uint64_t)INT64_MIN; + } else { + result = (int64_t)uint_val; + limit_result = limit = INT64_MAX; + } + break; + + case 3: // uint64 + limit = limit_result = UINT64_MAX; + if (negative) + result = limit = 0; + else + result = (uint64_t)uint_val; + break; + } + + if (uint_val > limit) { + result = limit_result; + softfloat_raiseFlags(softfloat_flag_invalid); + } + + SET_FS_DIRTY(); + SET_RD(insn, regs, result); + + return 0; +} + +DECLARE_EMULATION_FUNC(emulate_fcmp) +{ + int rm = GET_RM(insn); + if (rm >= 3) + return -1; + + uintptr_t result; + if (GET_PRECISION(insn) == PRECISION_S) { + uint32_t rs1 = GET_F32_RS1(insn, regs); + uint32_t rs2 = GET_F32_RS2(insn, regs); + if (rm != 1) + result = f32_eq(rs1, rs2); + if (rm == 1 || (rm == 0 && !result)) + result = f32_lt(rs1, rs2); + goto success; + } else if (GET_PRECISION(insn) == PRECISION_D) { + uint64_t rs1 = GET_F64_RS1(insn, regs); + uint64_t rs2 = GET_F64_RS2(insn, regs); + if (rm != 1) + result = f64_eq(rs1, rs2); + if (rm == 1 || (rm == 0 && !result)) + result = f64_lt(rs1, rs2); + goto success; + } + return -1; +success: + SET_RD(insn, regs, result); + return 0; +} + +DECLARE_EMULATION_FUNC(emulate_fmv_if) +{ + uintptr_t result; + if ((insn & MASK_FMV_X_S) == MATCH_FMV_X_S) + result = GET_F32_RS1(insn, regs); +#ifdef __riscv64 + else if ((insn & MASK_FMV_X_D) == MATCH_FMV_X_D) + result = GET_F64_RS1(insn, regs); +#endif + else + return -1; + + SET_RD(insn, regs, result); + return 0; +} + +DECLARE_EMULATION_FUNC(emulate_fmv_fi) +{ + uintptr_t rs1 = GET_RS1(insn, regs); + + if ((insn & MASK_FMV_S_X) == MATCH_FMV_S_X) + SET_F32_RD(insn, regs, rs1); + else if ((insn & MASK_FMV_D_X) == MATCH_FMV_D_X) + SET_F64_RD(insn, regs, rs1); + else + return -1; + + return 0; +} + +uintptr_t emulate_any_fmadd(int op, uintptr_t* regs, insn_t insn, uintptr_t mstatus) +{ + // if FPU is disabled, punt back to the OS + if (unlikely((mstatus & MSTATUS_FS) == 0)) + return -1; + + SETUP_STATIC_ROUNDING(insn); + if (GET_PRECISION(insn) == PRECISION_S) { + uint32_t rs1 = GET_F32_RS1(insn, regs); + uint32_t rs2 = GET_F32_RS2(insn, regs); + uint32_t rs3 = GET_F32_RS3(insn, regs); + SET_F32_RD(insn, regs, softfloat_mulAddF32(op, rs1, rs2, rs3)); + return 0; + } else if (GET_PRECISION(insn) == PRECISION_D) { + uint64_t rs1 = GET_F64_RS1(insn, regs); + uint64_t rs2 = GET_F64_RS2(insn, regs); + uint64_t rs3 = GET_F64_RS3(insn, regs); + SET_F64_RD(insn, regs, softfloat_mulAddF64(op, rs1, rs2, rs3)); + return 0; + } + return -1; +} + +DECLARE_EMULATION_FUNC(emulate_fmadd) +{ + int op = 0; + return emulate_any_fmadd(op, regs, insn, mstatus); +} + +DECLARE_EMULATION_FUNC(emulate_fmsub) +{ + int op = softfloat_mulAdd_subC; + return emulate_any_fmadd(op, regs, insn, mstatus); +} + +DECLARE_EMULATION_FUNC(emulate_fnmadd) +{ + int op = softfloat_mulAdd_subC | softfloat_mulAdd_subProd; + return emulate_any_fmadd(op, regs, insn, mstatus); +} + +DECLARE_EMULATION_FUNC(emulate_fnmsub) +{ + int op = softfloat_mulAdd_subProd; + return emulate_any_fmadd(op, regs, insn, mstatus); +} diff --git a/pk/encoding.h b/pk/encoding.h index d20ee5b..4b929d3 100644 --- a/pk/encoding.h +++ b/pk/encoding.h @@ -14,8 +14,7 @@ #define MSTATUS_PRV2 0x00001800 #define MSTATUS_IE3 0x00002000 #define MSTATUS_PRV3 0x0000C000 -#define MSTATUS_IE4 0x00010000 -#define MSTATUS_PRV4 0x00060000 +#define MSTATUS_MPRV 0x00030000 #define MSTATUS_VM 0x00780000 #define MSTATUS_STIE 0x01000000 #define MSTATUS_HTIE 0x02000000 @@ -28,6 +27,18 @@ #define MSTATUS64_HA 0x00000F0000000000 #define MSTATUS64_SD 0x8000000000000000 +#define SSTATUS_SIP 0x00000002 +#define SSTATUS_IE 0x00000010 +#define SSTATUS_PIE 0x00000080 +#define SSTATUS_PS 0x00000100 +#define SSTATUS_UA 0x000F0000 +#define SSTATUS_TIE 0x01000000 +#define SSTATUS_TIP 0x04000000 +#define SSTATUS_FS 0x18000000 +#define SSTATUS_XS 0x60000000 +#define SSTATUS32_SD 0x80000000 +#define SSTATUS64_SD 0x8000000000000000 + #define PRV_U 0 #define PRV_S 1 #define PRV_H 2 @@ -70,10 +81,12 @@ # define MSTATUS_SA MSTATUS64_SA # define MSTATUS_HA MSTATUS64_HA # define MSTATUS_SD MSTATUS64_SD +# define SSTATUS_SD SSTATUS64_SD # define RISCV_PGLEVELS 3 # define RISCV_PGSHIFT 13 #else # define MSTATUS_SD MSTATUS32_SD +# define SSTATUS_SD SSTATUS32_SD # define RISCV_PGLEVELS 2 # define RISCV_PGSHIFT 12 #endif @@ -82,7 +95,9 @@ #ifndef __ASSEMBLER__ -#define read_csr(reg) ({ long __tmp; \ +#ifdef __GNUC__ + +#define read_csr(reg) ({ unsigned long __tmp; \ asm volatile ("csrr %0, " #reg : "=r"(__tmp)); \ __tmp; }) @@ -93,31 +108,25 @@ asm volatile ("csrrw %0, " #reg ", %1" : "=r"(__tmp) : "r"(val)); \ __tmp; }) -#define set_csr(reg, bit) ({ long __tmp; \ +#define set_csr(reg, bit) ({ unsigned long __tmp; \ if (__builtin_constant_p(bit) && (bit) < 32) \ asm volatile ("csrrs %0, " #reg ", %1" : "=r"(__tmp) : "i"(bit)); \ else \ asm volatile ("csrrs %0, " #reg ", %1" : "=r"(__tmp) : "r"(bit)); \ __tmp; }) -#define clear_csr(reg, bit) ({ long __tmp; \ +#define clear_csr(reg, bit) ({ unsigned long __tmp; \ if (__builtin_constant_p(bit) && (bit) < 32) \ asm volatile ("csrrc %0, " #reg ", %1" : "=r"(__tmp) : "i"(bit)); \ else \ asm volatile ("csrrc %0, " #reg ", %1" : "=r"(__tmp) : "r"(bit)); \ __tmp; }) -#define rdtime() ({ unsigned long __tmp; \ - asm volatile ("rdtime %0" : "=r"(__tmp)); \ - __tmp; }) - -#define rdcycle() ({ unsigned long __tmp; \ - asm volatile ("rdcycle %0" : "=r"(__tmp)); \ - __tmp; }) +#define rdtime() read_csr(time) +#define rdcycle() read_csr(cycle) +#define rdinstret() read_csr(instret) -#define rdinstret() ({ unsigned long __tmp; \ - asm volatile ("rdinstret %0" : "=r"(__tmp)); \ - __tmp; }) +#endif #endif @@ -251,6 +260,8 @@ #define MASK_MULH 0xfe00707f #define MATCH_FMUL_S 0x10000053 #define MASK_FMUL_S 0xfe00007f +#define MATCH_MCALL 0x20000073 +#define MASK_MCALL 0xffffffff #define MATCH_CSRRSI 0x6073 #define MASK_CSRRSI 0x707f #define MATCH_SRAI 0x40005013 @@ -291,6 +302,8 @@ #define MASK_FSUB_D 0xfe00007f #define MATCH_FSGNJX_S 0x20002053 #define MASK_FSGNJX_S 0xfe00707f +#define MATCH_MRTS 0x30900073 +#define MASK_MRTS 0xffffffff #define MATCH_FEQ_D 0xa2002053 #define MASK_FEQ_D 0xfe00707f #define MATCH_FCVT_D_WU 0xd2100053 @@ -407,8 +420,6 @@ #define MASK_FMADD_S 0x600007f #define MATCH_FSQRT_S 0x58000053 #define MASK_FSQRT_S 0xfff0007f -#define MATCH_MSENTER 0x30900073 -#define MASK_MSENTER 0xffffffff #define MATCH_AMOMIN_W 0x8000202f #define MASK_AMOMIN_W 0xf800707f #define MATCH_FSGNJN_S 0x20001053 @@ -472,34 +483,38 @@ #define CSR_UARCH15 0xccf #define CSR_SSTATUS 0x100 #define CSR_STVEC 0x101 -#define CSR_SCOMPARE 0x121 +#define CSR_STIMECMP 0x121 #define CSR_SSCRATCH 0x140 #define CSR_SEPC 0x141 #define CSR_SPTBR 0x188 #define CSR_SASID 0x189 -#define CSR_COUNT 0x900 +#define CSR_SCYCLE 0x900 #define CSR_STIME 0x901 #define CSR_SINSTRET 0x902 #define CSR_SCAUSE 0xd40 #define CSR_SBADADDR 0xd41 -#define CSR_TOHOST 0x580 -#define CSR_FROMHOST 0x581 #define CSR_MSTATUS 0x300 #define CSR_MSCRATCH 0x340 #define CSR_MEPC 0x341 -#define CSR_MCAUSE 0xf40 -#define CSR_MBADADDR 0xf41 +#define CSR_MCAUSE 0x342 +#define CSR_MBADADDR 0x343 #define CSR_RESET 0x780 +#define CSR_TOHOST 0x781 +#define CSR_FROMHOST 0x782 +#define CSR_SEND_IPI 0x783 +#define CSR_HARTID 0xfc0 #define CSR_CYCLEH 0xc80 #define CSR_TIMEH 0xc81 #define CSR_INSTRETH 0xc82 -#define CSR_COUNTH 0x980 +#define CSR_SCYCLEH 0x980 #define CSR_STIMEH 0x981 #define CSR_SINSTRETH 0x982 #define CAUSE_MISALIGNED_FETCH 0x0 #define CAUSE_FAULT_FETCH 0x1 -#define CAUSE_ILLEGAL_INSTRUCTION 0x4 -#define CAUSE_SYSCALL 0x6 +#define CAUSE_ILLEGAL_INSTRUCTION 0x2 +#define CAUSE_SCALL 0x4 +#define CAUSE_HCALL 0x5 +#define CAUSE_MCALL 0x6 #define CAUSE_BREAKPOINT 0x7 #define CAUSE_MISALIGNED_LOAD 0x8 #define CAUSE_FAULT_LOAD 0x9 @@ -569,6 +584,7 @@ DECLARE_INSN(csrrci, MATCH_CSRRCI, MASK_CSRRCI) DECLARE_INSN(addi, MATCH_ADDI, MASK_ADDI) DECLARE_INSN(mulh, MATCH_MULH, MASK_MULH) DECLARE_INSN(fmul_s, MATCH_FMUL_S, MASK_FMUL_S) +DECLARE_INSN(mcall, MATCH_MCALL, MASK_MCALL) DECLARE_INSN(csrrsi, MATCH_CSRRSI, MASK_CSRRSI) DECLARE_INSN(srai, MATCH_SRAI, MASK_SRAI) DECLARE_INSN(amoand_d, MATCH_AMOAND_D, MASK_AMOAND_D) @@ -589,6 +605,7 @@ DECLARE_INSN(sraiw, MATCH_SRAIW, MASK_SRAIW) DECLARE_INSN(srl, MATCH_SRL, MASK_SRL) DECLARE_INSN(fsub_d, MATCH_FSUB_D, MASK_FSUB_D) DECLARE_INSN(fsgnjx_s, MATCH_FSGNJX_S, MASK_FSGNJX_S) +DECLARE_INSN(mrts, MATCH_MRTS, MASK_MRTS) DECLARE_INSN(feq_d, MATCH_FEQ_D, MASK_FEQ_D) DECLARE_INSN(fcvt_d_wu, MATCH_FCVT_D_WU, MASK_FCVT_D_WU) DECLARE_INSN(or, MATCH_OR, MASK_OR) @@ -647,7 +664,6 @@ DECLARE_INSN(csrrwi, MATCH_CSRRWI, MASK_CSRRWI) DECLARE_INSN(sc_d, MATCH_SC_D, MASK_SC_D) DECLARE_INSN(fmadd_s, MATCH_FMADD_S, MASK_FMADD_S) DECLARE_INSN(fsqrt_s, MATCH_FSQRT_S, MASK_FSQRT_S) -DECLARE_INSN(msenter, MATCH_MSENTER, MASK_MSENTER) DECLARE_INSN(amomin_w, MATCH_AMOMIN_W, MASK_AMOMIN_W) DECLARE_INSN(fsgnjn_s, MATCH_FSGNJN_S, MASK_FSGNJN_S) DECLARE_INSN(amoswap_d, MATCH_AMOSWAP_D, MASK_AMOSWAP_D) @@ -694,28 +710,30 @@ DECLARE_CSR(uarch14, CSR_UARCH14) DECLARE_CSR(uarch15, CSR_UARCH15) DECLARE_CSR(sstatus, CSR_SSTATUS) DECLARE_CSR(stvec, CSR_STVEC) -DECLARE_CSR(scompare, CSR_SCOMPARE) +DECLARE_CSR(stimecmp, CSR_STIMECMP) DECLARE_CSR(sscratch, CSR_SSCRATCH) DECLARE_CSR(sepc, CSR_SEPC) DECLARE_CSR(sptbr, CSR_SPTBR) DECLARE_CSR(sasid, CSR_SASID) -DECLARE_CSR(count, CSR_COUNT) +DECLARE_CSR(scycle, CSR_SCYCLE) DECLARE_CSR(stime, CSR_STIME) DECLARE_CSR(sinstret, CSR_SINSTRET) DECLARE_CSR(scause, CSR_SCAUSE) DECLARE_CSR(sbadaddr, CSR_SBADADDR) -DECLARE_CSR(tohost, CSR_TOHOST) -DECLARE_CSR(fromhost, CSR_FROMHOST) DECLARE_CSR(mstatus, CSR_MSTATUS) DECLARE_CSR(mscratch, CSR_MSCRATCH) DECLARE_CSR(mepc, CSR_MEPC) DECLARE_CSR(mcause, CSR_MCAUSE) DECLARE_CSR(mbadaddr, CSR_MBADADDR) DECLARE_CSR(reset, CSR_RESET) +DECLARE_CSR(tohost, CSR_TOHOST) +DECLARE_CSR(fromhost, CSR_FROMHOST) +DECLARE_CSR(send_ipi, CSR_SEND_IPI) +DECLARE_CSR(hartid, CSR_HARTID) DECLARE_CSR(cycleh, CSR_CYCLEH) DECLARE_CSR(timeh, CSR_TIMEH) DECLARE_CSR(instreth, CSR_INSTRETH) -DECLARE_CSR(counth, CSR_COUNTH) +DECLARE_CSR(scycleh, CSR_SCYCLEH) DECLARE_CSR(stimeh, CSR_STIMEH) DECLARE_CSR(sinstreth, CSR_SINSTRETH) #endif @@ -745,28 +763,30 @@ DECLARE_CAUSE("uarch14", CAUSE_UARCH14) DECLARE_CAUSE("uarch15", CAUSE_UARCH15) DECLARE_CAUSE("sstatus", CAUSE_SSTATUS) DECLARE_CAUSE("stvec", CAUSE_STVEC) -DECLARE_CAUSE("scompare", CAUSE_SCOMPARE) +DECLARE_CAUSE("stimecmp", CAUSE_STIMECMP) DECLARE_CAUSE("sscratch", CAUSE_SSCRATCH) DECLARE_CAUSE("sepc", CAUSE_SEPC) DECLARE_CAUSE("sptbr", CAUSE_SPTBR) DECLARE_CAUSE("sasid", CAUSE_SASID) -DECLARE_CAUSE("count", CAUSE_COUNT) +DECLARE_CAUSE("scycle", CAUSE_SCYCLE) DECLARE_CAUSE("stime", CAUSE_STIME) DECLARE_CAUSE("sinstret", CAUSE_SINSTRET) DECLARE_CAUSE("scause", CAUSE_SCAUSE) DECLARE_CAUSE("sbadaddr", CAUSE_SBADADDR) -DECLARE_CAUSE("tohost", CAUSE_TOHOST) -DECLARE_CAUSE("fromhost", CAUSE_FROMHOST) DECLARE_CAUSE("mstatus", CAUSE_MSTATUS) DECLARE_CAUSE("mscratch", CAUSE_MSCRATCH) DECLARE_CAUSE("mepc", CAUSE_MEPC) DECLARE_CAUSE("mcause", CAUSE_MCAUSE) DECLARE_CAUSE("mbadaddr", CAUSE_MBADADDR) DECLARE_CAUSE("reset", CAUSE_RESET) +DECLARE_CAUSE("tohost", CAUSE_TOHOST) +DECLARE_CAUSE("fromhost", CAUSE_FROMHOST) +DECLARE_CAUSE("send_ipi", CAUSE_SEND_IPI) +DECLARE_CAUSE("hartid", CAUSE_HARTID) DECLARE_CAUSE("cycleh", CAUSE_CYCLEH) DECLARE_CAUSE("timeh", CAUSE_TIMEH) DECLARE_CAUSE("instreth", CAUSE_INSTRETH) -DECLARE_CAUSE("counth", CAUSE_COUNTH) +DECLARE_CAUSE("scycleh", CAUSE_SCYCLEH) DECLARE_CAUSE("stimeh", CAUSE_STIMEH) DECLARE_CAUSE("sinstreth", CAUSE_SINSTRETH) #endif diff --git a/pk/entry.S b/pk/entry.S index aced3b8..cdf076f 100644 --- a/pk/entry.S +++ b/pk/entry.S @@ -1,16 +1,7 @@ // See LICENSE for license details. #include "encoding.h" - -#ifdef __riscv64 -# define STORE sd -# define LOAD ld -# define REGBYTES 8 -#else -# define STORE sw -# define LOAD lw -# define REGBYTES 4 -#endif +#include "bits.h" .macro save_tf # save gprs @@ -47,47 +38,42 @@ # get sr, epc, badvaddr, cause addi t0,sp,320 - csrrw t0,mscratch,t0 - csrr t1,mstatus - csrr t2,mepc - csrr t3,mcause + csrrw t0,sscratch,t0 + csrr t1,sstatus + csrr t2,sepc + csrr t3,scause STORE t0,2*REGBYTES(x2) STORE t1,32*REGBYTES(x2) STORE t2,33*REGBYTES(x2) STORE t3,35*REGBYTES(x2) - la gp, _gp - # get faulting insn, if it wasn't a fetch-related trap li x5,-1 STORE x5,36*REGBYTES(x2) 1: .endm - .section .text.init,"ax",@progbits + .text .global trap_entry trap_entry: - # entry point for reset - j _start - - # entry point when coming from machine mode - j 1f - - # entry point when coming from other modes - csrrw sp, mscratch, sp + csrrw sp, sscratch, sp 1:addi sp,sp,-320 save_tf move a0,sp - j handle_trap + j handle_trap .globl pop_tf pop_tf: # write the trap frame onto the stack - # restore sr (disable interrupts) and epc - LOAD a1,32*REGBYTES(a0) - LOAD a2,33*REGBYTES(a0) - csrw mstatus, a1 - csrw mepc, a2 + # restore sstatus and epc + csrc sstatus, SSTATUS_IE + li t0, SSTATUS_PS + LOAD t1, 32*REGBYTES(a0) + LOAD t2, 33*REGBYTES(a0) + csrc sstatus, t0 + and t0, t0, t1 + csrs sstatus, t0 + csrw sepc, t2 # restore x registers LOAD x1,1*REGBYTES(a0) @@ -124,13 +110,4 @@ pop_tf: # write the trap frame onto the stack LOAD x10,10*REGBYTES(a0) # gtfo - mret - - - .bss - .align 4 - .global stack_bot - .global stack_top -stack_bot: - .skip 4096 -stack_top: + sret diff --git a/pk/file.c b/pk/file.c index bc495c2..ad1bde3 100644 --- a/pk/file.c +++ b/pk/file.c @@ -8,10 +8,9 @@ #include "vm.h" #define MAX_FDS 128 -static atomic_t fds[MAX_FDS]; +static file_t* fds[MAX_FDS]; #define MAX_FILES 128 -static file_t files[MAX_FILES] = {[0 ... MAX_FILES-1] = {-1,{0}}}; -file_t *stdout, *stdin, *stderr; +file_t files[MAX_FILES] = {[0 ... MAX_FILES-1] = {-1,0}}; void file_incref(file_t* f) { @@ -43,7 +42,7 @@ int file_dup(file_t* f) { for (int i = 0; i < MAX_FDS; i++) { - if (atomic_cas(&fds[i], 0, (long)f) == 0) + if (atomic_cas(&fds[i], 0, f) == 0) { file_incref(f); return i; @@ -54,24 +53,18 @@ int file_dup(file_t* f) void file_init() { - stdin = file_get_free(); - stdout = file_get_free(); - stderr = file_get_free(); - - stdin->kfd = 0; - stdout->kfd = 1; - stderr->kfd = 2; - - // create user FDs 0, 1, and 2 - file_dup(stdin); - file_dup(stdout); - file_dup(stderr); + // create stdin, stdout, stderr and FDs 0-2 + for (int i = 0; i < 3; i++) { + file_t* f = file_get_free(); + f->kfd = i; + file_dup(f); + } } file_t* file_get(int fd) { file_t* f; - if (fd < 0 || fd >= MAX_FDS || (f = (file_t*)atomic_read(&fds[fd])) == NULL) + if (fd < 0 || fd >= MAX_FDS || (f = atomic_read(&fds[fd])) == NULL) return 0; long old_cnt; @@ -114,7 +107,7 @@ int fd_close(int fd) file_t* f = file_get(fd); if (!f) return -1; - file_t* old = (file_t*)atomic_cas(&fds[fd], (long)f, 0); + file_t* old = atomic_cas(&fds[fd], (long)f, 0); file_decref(f); if (old != f) return -1; diff --git a/pk/file.h b/pk/file.h index 68b68a3..0d942b2 100644 --- a/pk/file.h +++ b/pk/file.h @@ -5,15 +5,19 @@ #include #include +#include #include "atomic.h" typedef struct file { - int kfd; // file descriptor on the appserver side - atomic_t refcnt; + int kfd; // file descriptor on the host side of the HTIF + uint32_t refcnt; } file_t; -extern file_t *stdin, *stdout, *stderr; +extern file_t files[]; +#define stdin (files + 0) +#define stdout (files + 1) +#define stderr (files + 2) file_t* file_get(int fd); file_t* file_open(const char* fn, int flags, int mode); diff --git a/pk/fp.c b/pk/fp.c deleted file mode 100644 index 96eb449..0000000 --- a/pk/fp.c +++ /dev/null @@ -1,275 +0,0 @@ -// See LICENSE for license details. - -#include "pk.h" -#include "fp.h" -#include "config.h" - -#ifdef PK_ENABLE_FP_EMULATION - -#include "softfloat.h" -#include - -#define noisy 0 - -static inline void -validate_address(trapframe_t* tf, long addr, int size, int store) -{ -} - -#ifdef __riscv_hard_float -# define get_fcsr() ({ fcsr_t fcsr; asm ("frcsr %0" : "=r"(fcsr)); fcsr; }) -# define put_fcsr(value) ({ asm ("fscsr %0" :: "r"(value)); }) -# define get_f32_reg(i) ({ \ - register int value asm("a0"); \ - register long offset asm("a1") = (i) * 8; \ - asm ("1: auipc %0, %%pcrel_hi(get_f32_reg); add %0, %0, %1; jalr %0, %%pcrel_lo(1b)" : "=&r"(value) : "r"(offset)); \ - value; }) -# define put_f32_reg(i, value) ({ \ - long tmp; \ - register long __value asm("a0") = (value); \ - register long offset asm("a1") = (i) * 8; \ - asm ("1: auipc %0, %%pcrel_hi(put_f32_reg); add %0, %0, %1; jalr %0, %%pcrel_lo(1b)" : "=&r"(tmp) : "r"(offset), "r"(__value)); }) -# ifdef __riscv64 -# define get_f64_reg(i) ({ \ - register long value asm("a0"); \ - register long offset asm("a1") = (i) * 8; \ - asm ("1: auipc %0, %%pcrel_hi(get_f64_reg); add %0, %0, %1; jalr %0, %%pcrel_lo(1b)" : "=&r"(value) : "r"(offset)); \ - value; }) -# define put_f64_reg(i, value) ({ \ - long tmp; \ - register long __value asm("a0") = (value); \ - register long offset asm("a1") = (i) * 8; \ - asm ("1: auipc %0, %%pcrel_hi(put_f64_reg); add %0, %0, %1; jalr %0, %%pcrel_lo(1b)" : "=&r"(tmp) : "r"(offset), "r"(__value)); }) -# else -# define get_f64_reg(i) ({ \ - long long value; \ - register long long* valuep asm("a0") = &value; \ - register long offset asm("a1") = (i) * 8; \ - asm ("1: auipc %0, %%pcrel_hi(get_f64_reg); add %0, %0, %1; jalr %0, %%pcrel_lo(1b)" : "=&r"(valuep) : "r"(offset)); \ - value; }) -# define put_f64_reg(i, value) ({ \ - long long __value = (value); \ - register long long* valuep asm("a0") = &__value; \ - register long offset asm("a1") = (i) * 8; \ - asm ("1: auipc %0, %%pcrel_hi(put_f64_reg); add %0, %0, %1; jalr %0, %%pcrel_lo(1b)" : "=&r"(tmp) : "r"(offset), "r"(__value)); }) -# endif -#else -static fp_state_t fp_state; -# define get_fcsr() fp_state.fcsr -# define put_fcsr(value) fp_state.fcsr = (value) -# define get_f32_reg(i) fp_state.fpr[i] -# define get_f64_reg(i) fp_state.fpr[i] -# define put_f32_reg(i, value) fp_state.fpr[i] = (value) -# define put_f64_reg(i, value) fp_state.fpr[i] = (value) -#endif - -int emulate_fp(trapframe_t* tf) -{ - if(noisy) - printk("FPU emulation at pc %lx, insn %x\n",tf->epc,(uint32_t)tf->insn); - - #define RS1 ((tf->insn >> 15) & 0x1F) - #define RS2 ((tf->insn >> 20) & 0x1F) - #define RS3 ((tf->insn >> 27) & 0x1F) - #define RD ((tf->insn >> 7) & 0x1F) - #define RM ((tf->insn >> 12) & 0x7) - - int32_t imm = (int32_t)tf->insn >> 20; - int32_t bimm = RD | imm >> 5 << 5; - - #define XRS1 (tf->gpr[RS1]) - #define XRS2 (tf->gpr[RS2]) - #define XRDR (tf->gpr[RD]) - - #define frs1d get_f64_reg(RS1) - #define frs2d get_f64_reg(RS2) - #define frs3d get_f64_reg(RS3) - #define frs1s get_f32_reg(RS1) - #define frs2s get_f32_reg(RS2) - #define frs3s get_f32_reg(RS3) - - long effective_address_load = XRS1 + imm; - long effective_address_store = XRS1 + bimm; - - fcsr_t fcsr = get_fcsr(); - softfloat_exceptionFlags = fcsr.fcsr.flags; - softfloat_roundingMode = (RM == 7) ? fcsr.fcsr.rm : RM; - - #define IS_INSN(x) ((tf->insn & MASK_ ## x) == MATCH_ ## x) - - #define DO_WRITEBACK(dp, value) ({ \ - if (dp) put_f64_reg(RD, value); \ - else put_f32_reg(RD, value); }) - - #define DO_CSR(which, op) ({ long tmp = which; which op; tmp; }) - - if(IS_INSN(FDIV_S)) - DO_WRITEBACK(0, f32_div(frs1s, frs2s)); - else if(IS_INSN(FDIV_D)) - DO_WRITEBACK(1, f64_div(frs1d, frs2d)); - else if(IS_INSN(FSQRT_S)) - DO_WRITEBACK(0, f32_sqrt(frs1s)); - else if(IS_INSN(FSQRT_D)) - DO_WRITEBACK(1, f64_sqrt(frs1d)); - else if(IS_INSN(FLW)) - { - validate_address(tf, effective_address_load, 4, 0); - DO_WRITEBACK(0, *(uint32_t*)effective_address_load); - } - else if(IS_INSN(FLD)) - { - validate_address(tf, effective_address_load, 8, 0); - DO_WRITEBACK(1, *(uint64_t*)effective_address_load); - } - else if(IS_INSN(FSW)) - { - validate_address(tf, effective_address_store, 4, 1); - *(uint32_t*)effective_address_store = frs2s; - } - else if(IS_INSN(FSD)) - { - validate_address(tf, effective_address_store, 8, 1); - *(uint64_t*)effective_address_store = frs2d; - } - else if(IS_INSN(FMV_X_S)) - XRDR = frs1s; - else if(IS_INSN(FMV_X_D)) - XRDR = frs1d; - else if(IS_INSN(FMV_S_X)) - DO_WRITEBACK(0, XRS1); - else if(IS_INSN(FMV_D_X)) - DO_WRITEBACK(1, XRS1); - else if(IS_INSN(FSGNJ_S)) - DO_WRITEBACK(0, (frs1s &~ (uint32_t)INT32_MIN) | (frs2s & (uint32_t)INT32_MIN)); - else if(IS_INSN(FSGNJ_D)) - DO_WRITEBACK(1, (frs1d &~ INT64_MIN) | (frs2d & INT64_MIN)); - else if(IS_INSN(FSGNJN_S)) - DO_WRITEBACK(0, (frs1s &~ (uint32_t)INT32_MIN) | ((~frs2s) & (uint32_t)INT32_MIN)); - else if(IS_INSN(FSGNJN_D)) - DO_WRITEBACK(1, (frs1d &~ INT64_MIN) | ((~frs2d) & INT64_MIN)); - else if(IS_INSN(FSGNJX_S)) - DO_WRITEBACK(0, frs1s ^ (frs2s & (uint32_t)INT32_MIN)); - else if(IS_INSN(FSGNJX_D)) - DO_WRITEBACK(1, frs1d ^ (frs2d & INT64_MIN)); - else if(IS_INSN(FEQ_S)) - XRDR = f32_eq(frs1s, frs2s); - else if(IS_INSN(FEQ_D)) - XRDR = f64_eq(frs1d, frs2d); - else if(IS_INSN(FLE_S)) - XRDR = f32_eq(frs1s, frs2s) || f32_lt(frs1s, frs2s); - else if(IS_INSN(FLE_D)) - XRDR = f64_eq(frs1d, frs2d) || f64_lt(frs1d, frs2d); - else if(IS_INSN(FLT_S)) - XRDR = f32_lt(frs1s, frs2s); - else if(IS_INSN(FLT_D)) - XRDR = f64_lt(frs1d, frs2d); - else if(IS_INSN(FCVT_S_W)) - DO_WRITEBACK(0, i64_to_f32((int64_t)(int32_t)XRS1)); - else if(IS_INSN(FCVT_S_L)) - DO_WRITEBACK(0, i64_to_f32(XRS1)); - else if(IS_INSN(FCVT_S_D)) - DO_WRITEBACK(0, f64_to_f32(frs1d)); - else if(IS_INSN(FCVT_D_W)) - DO_WRITEBACK(1, i64_to_f64((int64_t)(int32_t)XRS1)); - else if(IS_INSN(FCVT_D_L)) - DO_WRITEBACK(1, i64_to_f64(XRS1)); - else if(IS_INSN(FCVT_D_S)) - DO_WRITEBACK(1, f32_to_f64(frs1s)); - else if(IS_INSN(FCVT_S_WU)) - DO_WRITEBACK(0, ui64_to_f32((uint64_t)(uint32_t)XRS1)); - else if(IS_INSN(FCVT_S_LU)) - DO_WRITEBACK(0, ui64_to_f32(XRS1)); - else if(IS_INSN(FCVT_D_WU)) - DO_WRITEBACK(1, ui64_to_f64((uint64_t)(uint32_t)XRS1)); - else if(IS_INSN(FCVT_D_LU)) - DO_WRITEBACK(1, ui64_to_f64(XRS1)); - else if(IS_INSN(FADD_S)) - DO_WRITEBACK(0, f32_mulAdd(frs1s, 0x3f800000, frs2s)); - else if(IS_INSN(FADD_D)) - DO_WRITEBACK(1, f64_mulAdd(frs1d, 0x3ff0000000000000LL, frs2d)); - else if(IS_INSN(FSUB_S)) - DO_WRITEBACK(0, f32_mulAdd(frs1s, 0x3f800000, frs2s ^ (uint32_t)INT32_MIN)); - else if(IS_INSN(FSUB_D)) - DO_WRITEBACK(1, f64_mulAdd(frs1d, 0x3ff0000000000000LL, frs2d ^ INT64_MIN)); - else if(IS_INSN(FMUL_S)) - DO_WRITEBACK(0, f32_mulAdd(frs1s, frs2s, 0)); - else if(IS_INSN(FMUL_D)) - DO_WRITEBACK(1, f64_mulAdd(frs1d, frs2d, 0)); - else if(IS_INSN(FMADD_S)) - DO_WRITEBACK(0, f32_mulAdd(frs1s, frs2s, frs3s)); - else if(IS_INSN(FMADD_D)) - DO_WRITEBACK(1, f64_mulAdd(frs1d, frs2d, frs3d)); - else if(IS_INSN(FMSUB_S)) - DO_WRITEBACK(0, f32_mulAdd(frs1s, frs2s, frs3s ^ (uint32_t)INT32_MIN)); - else if(IS_INSN(FMSUB_D)) - DO_WRITEBACK(1, f64_mulAdd(frs1d, frs2d, frs3d ^ INT64_MIN)); - else if(IS_INSN(FNMADD_S)) - DO_WRITEBACK(0, f32_mulAdd(frs1s, frs2s, frs3s) ^ (uint32_t)INT32_MIN); - else if(IS_INSN(FNMADD_D)) - DO_WRITEBACK(1, f64_mulAdd(frs1d, frs2d, frs3d) ^ INT64_MIN); - else if(IS_INSN(FNMSUB_S)) - DO_WRITEBACK(0, f32_mulAdd(frs1s, frs2s, frs3s ^ (uint32_t)INT32_MIN) ^ (uint32_t)INT32_MIN); - else if(IS_INSN(FNMSUB_D)) - DO_WRITEBACK(1, f64_mulAdd(frs1d, frs2d, frs3d ^ INT64_MIN) ^ INT64_MIN); - else if(IS_INSN(FCVT_W_S)) - XRDR = f32_to_i32(frs1s, softfloat_roundingMode, true); - else if(IS_INSN(FCVT_W_D)) - XRDR = f64_to_i32(frs1d, softfloat_roundingMode, true); - else if(IS_INSN(FCVT_L_S)) - XRDR = f32_to_i64(frs1s, softfloat_roundingMode, true); - else if(IS_INSN(FCVT_L_D)) - XRDR = f64_to_i64(frs1d, softfloat_roundingMode, true); - else if(IS_INSN(FCVT_WU_S)) - XRDR = f32_to_ui32(frs1s, softfloat_roundingMode, true); - else if(IS_INSN(FCVT_WU_D)) - XRDR = f64_to_ui32(frs1d, softfloat_roundingMode, true); - else if(IS_INSN(FCVT_LU_S)) - XRDR = f32_to_ui64(frs1s, softfloat_roundingMode, true); - else if(IS_INSN(FCVT_LU_D)) - XRDR = f64_to_ui64(frs1d, softfloat_roundingMode, true); - else if(IS_INSN(FCLASS_S)) - XRDR = f32_classify(frs1s); - else if(IS_INSN(FCLASS_D)) - XRDR = f64_classify(frs1s); - else if(IS_INSN(CSRRS) && imm == CSR_FCSR) XRDR = DO_CSR(fcsr.bits, |= XRS1); - else if(IS_INSN(CSRRS) && imm == CSR_FRM) XRDR = DO_CSR(fcsr.fcsr.rm, |= XRS1); - else if(IS_INSN(CSRRS) && imm == CSR_FFLAGS) XRDR = DO_CSR(fcsr.fcsr.flags, |= XRS1); - else if(IS_INSN(CSRRSI) && imm == CSR_FCSR) XRDR = DO_CSR(fcsr.bits, |= RS1); - else if(IS_INSN(CSRRSI) && imm == CSR_FRM) XRDR = DO_CSR(fcsr.fcsr.rm, |= RS1); - else if(IS_INSN(CSRRSI) && imm == CSR_FFLAGS) XRDR = DO_CSR(fcsr.fcsr.flags, |= RS1); - else if(IS_INSN(CSRRC) && imm == CSR_FCSR) XRDR = DO_CSR(fcsr.bits, &= ~XRS1); - else if(IS_INSN(CSRRC) && imm == CSR_FRM) XRDR = DO_CSR(fcsr.fcsr.rm, &= ~XRS1); - else if(IS_INSN(CSRRC) && imm == CSR_FFLAGS) XRDR = DO_CSR(fcsr.fcsr.flags, &= ~XRS1); - else if(IS_INSN(CSRRCI) && imm == CSR_FCSR) XRDR = DO_CSR(fcsr.bits, &= ~RS1); - else if(IS_INSN(CSRRCI) && imm == CSR_FRM) XRDR = DO_CSR(fcsr.fcsr.rm, &= ~RS1); - else if(IS_INSN(CSRRCI) && imm == CSR_FFLAGS) XRDR = DO_CSR(fcsr.fcsr.flags, &= ~RS1); - else if(IS_INSN(CSRRW) && imm == CSR_FCSR) XRDR = DO_CSR(fcsr.bits, = XRS1); - else if(IS_INSN(CSRRW) && imm == CSR_FRM) XRDR = DO_CSR(fcsr.fcsr.rm, = XRS1); - else if(IS_INSN(CSRRW) && imm == CSR_FFLAGS) XRDR = DO_CSR(fcsr.fcsr.flags, = XRS1); - else if(IS_INSN(CSRRWI) && imm == CSR_FCSR) XRDR = DO_CSR(fcsr.bits, = RS1); - else if(IS_INSN(CSRRWI) && imm == CSR_FRM) XRDR = DO_CSR(fcsr.fcsr.rm, = RS1); - else if(IS_INSN(CSRRWI) && imm == CSR_FFLAGS) XRDR = DO_CSR(fcsr.fcsr.flags, = RS1); - else - return -1; - - put_fcsr(fcsr); - - return 0; -} - -#define STR(x) XSTR(x) -#define XSTR(x) #x - -#define PUT_FP_REG(which, type, val) asm("fmv." STR(type) ".x f" STR(which) ",%0" : : "r"(val)) -#define GET_FP_REG(which, type, val) asm("fmv.x." STR(type) " %0,f" STR(which) : "=r"(val)) -#define LOAD_FP_REG(which, type, val) asm("fl" STR(type) " f" STR(which) ",%0" : : "m"(val)) -#define STORE_FP_REG(which, type, val) asm("fs" STR(type) " f" STR(which) ",%0" : "=m"(val) : : "memory") - -#endif - -void fp_init() -{ - if (read_csr(mstatus) & MSTATUS_FS) - for (int i = 0; i < 32; i++) - put_f64_reg(i, 0); -} diff --git a/pk/fp.h b/pk/fp.h deleted file mode 100644 index 4cfb167..0000000 --- a/pk/fp.h +++ /dev/null @@ -1,23 +0,0 @@ -// See LICENSE for license details. - -#ifndef _FP_H -#define _FP_H - -typedef union { - struct { - uint8_t flags : 5; - uint8_t rm : 3; - } fcsr; - uint8_t bits; -} fcsr_t; - -typedef struct -{ - uint64_t fpr[32]; - fcsr_t fcsr; -} fp_state_t; - -void put_fp_state(const void* fp_regs, uint8_t fsr); -long get_fp_state(void* fp_regs); - -#endif diff --git a/pk/fp_asm.S b/pk/fp_asm.S index 0a9f34b..0839511 100644 --- a/pk/fp_asm.S +++ b/pk/fp_asm.S @@ -1,13 +1,13 @@ // See LICENSE for license details. -#define get_f32(which) fmv.x.s a0, which; ret -#define put_f32(which) fmv.s.x which, a0; ret +#define get_f32(which) fmv.x.s a0, which; jr t0 +#define put_f32(which) fmv.s.x which, a0; jr t0 #ifdef __riscv64 -# define get_f64(which) fmv.x.d a0, which; ret -# define put_f64(which) fmv.d.x which, a0; ret +# define get_f64(which) fmv.x.d a0, which; jr t0 +# define put_f64(which) fmv.d.x which, a0; jr t0 #else -# define get_f64(which) fsd which, 0(a0); ret -# define put_f64(which) fld which, 0(a0); ret +# define get_f64(which) fsd which, 0(a0); jr t0 +# define put_f64(which) fld which, 0(a0); jr t0 #endif .text @@ -153,83 +153,3 @@ put_f64(f29) put_f64(f30) put_f64(f31) - - .text - .globl get_fp_state -get_fp_state: - - fsd f0 , 0(a0) - fsd f1 , 8(a0) - fsd f2 , 16(a0) - fsd f3 , 24(a0) - fsd f4 , 32(a0) - fsd f5 , 40(a0) - fsd f6 , 48(a0) - fsd f7 , 56(a0) - fsd f8 , 64(a0) - fsd f9 , 72(a0) - fsd f10, 80(a0) - fsd f11, 88(a0) - fsd f12, 96(a0) - fsd f13,104(a0) - fsd f14,112(a0) - fsd f15,120(a0) - fsd f16,128(a0) - fsd f17,136(a0) - fsd f18,144(a0) - fsd f19,152(a0) - fsd f20,160(a0) - fsd f21,168(a0) - fsd f22,176(a0) - fsd f23,184(a0) - fsd f24,192(a0) - fsd f25,200(a0) - fsd f26,208(a0) - fsd f27,216(a0) - fsd f28,224(a0) - fsd f29,232(a0) - fsd f30,240(a0) - fsd f31,248(a0) - - frsr a0 - ret - - .globl put_fp_state -put_fp_state: - - fld f0 , 0(a0) - fld f1 , 8(a0) - fld f2 , 16(a0) - fld f3 , 24(a0) - fld f4 , 32(a0) - fld f5 , 40(a0) - fld f6 , 48(a0) - fld f7 , 56(a0) - fld f8 , 64(a0) - fld f9 , 72(a0) - fld f10, 80(a0) - fld f11, 88(a0) - fld f12, 96(a0) - fld f13,104(a0) - fld f14,112(a0) - fld f15,120(a0) - fld f16,128(a0) - fld f17,136(a0) - fld f18,144(a0) - fld f19,152(a0) - fld f20,160(a0) - fld f21,168(a0) - fld f22,176(a0) - fld f23,184(a0) - fld f24,192(a0) - fld f25,200(a0) - fld f26,208(a0) - fld f27,216(a0) - fld f28,224(a0) - fld f29,232(a0) - fld f30,240(a0) - fld f31,248(a0) - - fssr a1 - - ret diff --git a/pk/frontend.c b/pk/frontend.c index 3771bca..0929d54 100644 --- a/pk/frontend.c +++ b/pk/frontend.c @@ -3,14 +3,30 @@ #include "pk.h" #include "atomic.h" #include "frontend.h" +#include "sbi.h" +#include "mcall.h" #include +uint64_t tohost_sync(unsigned dev, unsigned cmd, uint64_t payload) +{ + uint64_t fromhost; + __sync_synchronize(); + + sbi_device_message m = {dev, cmd, payload}, *p; + do_mcall(MCALL_SEND_DEVICE_REQUEST, &m); + while ((p = (void*)do_mcall(MCALL_RECEIVE_DEVICE_RESPONSE)) == 0); + kassert(p == &m); + + __sync_synchronize(); + return m.data; +} + long frontend_syscall(long n, long a0, long a1, long a2, long a3, long a4, long a5, long a6) { static volatile uint64_t magic_mem[8]; static spinlock_t lock = SPINLOCK_INIT; - long irq = spinlock_lock_irqsave(&lock); + spinlock_lock(&lock); magic_mem[0] = n; magic_mem[1] = a0; @@ -21,15 +37,10 @@ long frontend_syscall(long n, long a0, long a1, long a2, long a3, long a4, long magic_mem[6] = a5; magic_mem[7] = a6; - mb(); - - write_csr(tohost, magic_mem); - while (swap_csr(fromhost, 0) == 0); - - mb(); + tohost_sync(0, 0, (uintptr_t)magic_mem); long ret = magic_mem[0]; - spinlock_unlock_irqrestore(&lock, irq); + spinlock_unlock(&lock); return ret; } diff --git a/pk/frontend.h b/pk/frontend.h index dde0d0c..b6418f2 100644 --- a/pk/frontend.h +++ b/pk/frontend.h @@ -3,6 +3,21 @@ #ifndef _RISCV_FRONTEND_H #define _RISCV_FRONTEND_H +#include + +#ifdef __riscv64 +# define TOHOST_CMD(dev, cmd, payload) \ + (((uint64_t)(dev) << 56) | ((uint64_t)(cmd) << 48) | (uint64_t)(payload)) +#else +# define TOHOST_CMD(dev, cmd, payload) ({ \ + if ((dev) || (cmd)) __builtin_trap(); \ + (payload); }) +#endif +#define FROMHOST_DEV(fromhost_value) ((uint64_t)(fromhost_value) >> 56) +#define FROMHOST_CMD(fromhost_value) ((uint64_t)(fromhost_value) << 8 >> 56) +#define FROMHOST_DATA(fromhost_value) ((uint64_t)(fromhost_value) << 16 >> 16) + long frontend_syscall(long n, long a0, long a1, long a2, long a3, long a4, long a5, long a6); +uint64_t tohost_sync(unsigned dev, unsigned cmd, uint64_t payload); #endif diff --git a/pk/handlers.c b/pk/handlers.c index 35adefc..eb18038 100644 --- a/pk/handlers.c +++ b/pk/handlers.c @@ -14,20 +14,6 @@ static void handle_illegal_instruction(trapframe_t* tf) else kassert(len == 2); -#ifdef PK_ENABLE_FP_EMULATION - if (emulate_fp(tf) == 0) - { - tf->epc += len; - return; - } -#endif - - if (emulate_int(tf) == 0) - { - tf->epc += len; - return; - } - dump_tf(tf); panic("An illegal instruction was executed!"); } @@ -73,14 +59,14 @@ static void handle_fault_fetch(trapframe_t* tf) void handle_fault_load(trapframe_t* tf) { - tf->badvaddr = read_csr(mbadaddr); + tf->badvaddr = read_csr(sbadaddr); if (handle_page_fault(tf->badvaddr, PROT_READ) != 0) segfault(tf, tf->badvaddr, "load"); } void handle_fault_store(trapframe_t* tf) { - tf->badvaddr = read_csr(mbadaddr); + tf->badvaddr = read_csr(sbadaddr); if (handle_page_fault(tf->badvaddr, PROT_WRITE) != 0) segfault(tf, tf->badvaddr, "store"); } @@ -92,9 +78,17 @@ static void handle_syscall(trapframe_t* tf) tf->epc += 4; } +static void handle_interrupt(trapframe_t* tf) +{ + clear_csr(sstatus, SSTATUS_SIP); + + pop_tf(tf); +} + void handle_trap(trapframe_t* tf) { - set_csr(mstatus, MSTATUS_IE); + if ((intptr_t)tf->cause < 0) + return handle_interrupt(tf); typedef void (*trap_handler)(trapframe_t*); @@ -102,7 +96,7 @@ void handle_trap(trapframe_t* tf) [CAUSE_MISALIGNED_FETCH] = handle_misaligned_fetch, [CAUSE_FAULT_FETCH] = handle_fault_fetch, [CAUSE_ILLEGAL_INSTRUCTION] = handle_illegal_instruction, - [CAUSE_SYSCALL] = handle_syscall, + [CAUSE_SCALL] = handle_syscall, [CAUSE_BREAKPOINT] = handle_breakpoint, [CAUSE_MISALIGNED_LOAD] = handle_misaligned_load, [CAUSE_MISALIGNED_STORE] = handle_misaligned_store, diff --git a/pk/init.c b/pk/init.c index 58147ed..5a2c258 100644 --- a/pk/init.c +++ b/pk/init.c @@ -6,6 +6,7 @@ #include "frontend.h" #include "elf.h" #include +#include #include elf_info current; @@ -17,12 +18,12 @@ char* uarch_counter_names[NUM_COUNTERS]; void init_tf(trapframe_t* tf, long pc, long sp, int user64) { - memset(tf,0,sizeof(*tf)); - if(sizeof(void*) != 8) - kassert(!user64); - tf->status = read_csr(mstatus); - if (user64) - tf->status |= (long long)UA_RV64 << __builtin_ctzll(MSTATUS_UA); + memset(tf, 0, sizeof(*tf)); + if (user64) { + kassert(sizeof(void*) == 8); + set_csr(sstatus, UA_RV64 * (SSTATUS_UA & ~(SSTATUS_UA << 1))); + } + tf->status = read_csr(sstatus); tf->gpr[2] = sp; tf->epc = pc; } @@ -40,22 +41,34 @@ static void handle_option(const char* s) uarch_counters_enabled = 1; break; - case 'p': // physical memory mode - have_vm = 0; + case 'm': // memory capacity in MiB + { + uintptr_t mem_mb = atol(&s[2]); + if (!mem_mb) + goto need_nonzero_int; + mem_size = mem_mb << 20; + if ((mem_size >> 20) < mem_mb) + mem_size = (typeof(mem_size))-1 & -RISCV_PGSIZE; + break; + } + + case 'p': // number of harts + num_harts = atol(&s[2]); + if (!num_harts) + goto need_nonzero_int; break; default: panic("unrecognized option: `%c'", s[1]); break; } -} + return; -struct mainvars { - uint64_t argc; - uint64_t argv[127]; // this space is shared with the arg strings themselves -}; +need_nonzero_int: + panic("the -%c flag requires a nonzero argument", s[1]); +} -static struct mainvars* handle_args(struct mainvars* args) +struct mainvars* parse_args(struct mainvars* args) { long r = frontend_syscall(SYS_getmainvars, (uintptr_t)args, sizeof(*args), 0, 0, 0, 0, 0); kassert(r == 0); @@ -68,10 +81,32 @@ static struct mainvars* handle_args(struct mainvars* args) return (struct mainvars*)&args->argv[a0-1]; } -static void user_init(struct mainvars* args) +uintptr_t boot_loader(struct mainvars* args) { + // load program named by argv[0] + long phdrs[128]; + current.phdr = (uintptr_t)phdrs; + current.phdr_size = sizeof(phdrs); + if (!args->argc) + panic("tell me what ELF to load!"); + load_elf((char*)(uintptr_t)args->argv[0], ¤t); + + if (current.is_supervisor) { + supervisor_vm_init(); + write_csr(mepc, current.entry); + asm volatile("mret"); + __builtin_unreachable(); + } + + pk_vm_init(); + asm volatile("la t0, 1f; csrw mepc, t0; mret; 1:" ::: "t0"); + + // copy phdrs to user stack + size_t stack_top = current.stack_top - current.phdr_size; + memcpy((void*)stack_top, (void*)current.phdr, current.phdr_size); + current.phdr = stack_top; + // copy argv to user stack - size_t stack_top = current.stack_top; for (size_t i = 0; i < args->argc; i++) { size_t len = strlen((char*)(uintptr_t)args->argv[i])+1; stack_top -= len; @@ -79,12 +114,6 @@ static void user_init(struct mainvars* args) args->argv[i] = stack_top; } stack_top &= -sizeof(void*); - populate_mapping((void*)stack_top, current.stack_top - stack_top, PROT_WRITE); - - // load program named by argv[0] - current.phdr_top = stack_top; - load_elf((char*)(uintptr_t)args->argv[0], ¤t); - stack_top = current.phdr; struct { long key; @@ -152,13 +181,3 @@ static void user_init(struct mainvars* args) __clear_cache(0, 0); pop_tf(&tf); } - -void boot() -{ - file_init(); - struct mainvars args0; - struct mainvars* args = handle_args(&args0); - vm_init(); - fp_init(); - user_init(args); -} diff --git a/pk/int.c b/pk/int.c deleted file mode 100644 index 38cc7f0..0000000 --- a/pk/int.c +++ /dev/null @@ -1,89 +0,0 @@ -// See LICENSE for license details. - -#include "pk.h" - - -#include "softint.h" -#include - -#define noisy 0 - - -int emulate_int(trapframe_t* tf) -{ - if(noisy) - printk("Int emulation at pc %lx, insn %x\n",tf->epc,(uint32_t)tf->insn); - - #define RS1 ((tf->insn >> 15) & 0x1F) - #define RS2 ((tf->insn >> 20) & 0x1F) - #define RD ((tf->insn >> 7) & 0x1F) - -// #define XRS1 (tf->gpr[RS1]) -// #define XRS2 (tf->gpr[RS2]) - #define XRD (tf->gpr[RD]) - - unsigned long xrs1 = tf->gpr[RS1]; - unsigned long xrs2 = tf->gpr[RS2]; - - #define IS_INSN(x) ((tf->insn & MASK_ ## x) == MATCH_ ## x) - - if(IS_INSN(DIV)) - { - if(noisy) - printk("emulating div\n"); - - int num_negative = 0; - - if ((signed long) xrs1 < 0) - { - xrs1 = -xrs1; - num_negative++; - } - - if ((signed long) xrs2 < 0) - { - xrs2 = -xrs2; - num_negative++; - } - - unsigned long res = softint_udivrem(xrs1, xrs2, 0); - if (num_negative == 1) - XRD = -res; - else - XRD = res; - } - else if(IS_INSN(DIVU)) - { - if(noisy) - printk("emulating divu\n"); - XRD = softint_udivrem( xrs1, xrs2, 0); - } - else if(IS_INSN(MUL)) - { - if(noisy) - printk("emulating mul\n"); - XRD = softint_mul(xrs1, xrs2); - } - else if(IS_INSN(REM)) - { - if(noisy) - printk("emulating rem\n"); - - if ((signed long) xrs1 < 0) {xrs1 = -xrs1;} - if ((signed long) xrs2 < 0) {xrs2 = -xrs2;} - - XRD = softint_udivrem(xrs1, xrs2, 1); - } - else if(IS_INSN(REMU)) - { - if(noisy) - printk("emulating remu\n"); - XRD = softint_udivrem(xrs1, xrs2, 1); - } - else - return -1; - - return 0; -} - - diff --git a/pk/mcall.h b/pk/mcall.h new file mode 100644 index 0000000..9992891 --- /dev/null +++ b/pk/mcall.h @@ -0,0 +1,15 @@ +#ifndef _PK_MCALL_H +#define _PK_MCALL_H + +#define MCALL_HART_ID 0 +#define MCALL_CONSOLE_PUTCHAR 1 +#define MCALL_SEND_DEVICE_REQUEST 2 +#define MCALL_RECEIVE_DEVICE_RESPONSE 3 + +#ifndef __ASSEMBLER__ + +extern uintptr_t do_mcall(uintptr_t which, ...); + +#endif + +#endif diff --git a/pk/mentry.S b/pk/mentry.S new file mode 100644 index 0000000..4e64b7e --- /dev/null +++ b/pk/mentry.S @@ -0,0 +1,250 @@ +// See LICENSE for license details. + +#include "mtrap.h" + +#define HANDLE_TRAP_IN_MACHINE_MODE 0 \ + | (0 << (31- 0)) /* IF misaligned */ \ + | (0 << (31- 1)) /* IF fault */ \ + | (1 << (31- 2)) /* illegal instruction */ \ + | (1 << (31- 3)) /* reserved */ \ + | (0 << (31- 4)) /* system call */ \ + | (1 << (31- 5)) /* hypervisor call */ \ + | (1 << (31- 6)) /* machine call */ \ + | (0 << (31- 7)) /* breakpoint */ \ + | (1 << (31- 8)) /* load misaligned */ \ + | (0 << (31- 9)) /* load fault */ \ + | (1 << (31-10)) /* store misaligned */ \ + | (0 << (31-11)) /* store fault */ + + .section .text.init,"ax",@progbits + .globl mentry +mentry: + # Entry point from user mode. + .align 6 + csrrw sp, mscratch, sp + STORE a0, 10*REGBYTES(sp) + STORE a1, 11*REGBYTES(sp) + + csrr a0, mcause + bltz a0, .Linterrupt + + li a1, HANDLE_TRAP_IN_MACHINE_MODE + SLL32 a1, a1, a0 + bltz a1, .Lhandle_trap_in_machine_mode + + # Redirect the trap to the supervisor. +.Lmrts: + LOAD a0, 10*REGBYTES(sp) + LOAD a1, 11*REGBYTES(sp) + csrrw sp, mscratch, sp + mrts + + .align 6 + # Entry point from supervisor mode. + csrrw sp, mscratch, sp + STORE a0, 10*REGBYTES(sp) + STORE a1, 11*REGBYTES(sp) + + csrr a0, mcause + bltz a0, .Linterrupt + + li a1, HANDLE_TRAP_IN_MACHINE_MODE + SLL32 a1, a1, a0 + bltz a1, .Lhandle_trap_in_machine_mode + +.Linterrupt_in_supervisor: + # For now, direct all interrupts to supervisor mode. + + # Detect double faults. + csrr a0, mstatus + SLL32 a0, a0, 31 - CONST_CTZ(MSTATUS_PRV2) + bltz a0, .Lsupervisor_double_fault + +.Lreturn_from_supervisor_double_fault: + # Redirect the trap to the supervisor. + LOAD a0, 10*REGBYTES(sp) + LOAD a1, 11*REGBYTES(sp) + csrrw sp, mscratch, sp + mrts + + .align 6 + # Entry point from hypervisor mode. Not implemented. + j bad_trap + + .align 6 + csrw mscratch, sp + addi sp, sp, -INTEGER_CONTEXT_SIZE + STORE a0,10*REGBYTES(sp) + STORE a1,11*REGBYTES(sp) + + csrr a0, mcause + li a1, CAUSE_MCALL + beq a0, a1, .Lhandle_trap_in_machine_mode + li a1, CAUSE_FAULT_LOAD + beq a0, a1, .Lhandle_trap_in_machine_mode + li a1, CAUSE_FAULT_STORE + beq a0, a1, .Lhandle_trap_in_machine_mode + + # Uh oh... + j bad_trap + + .align 6 + # Entry point for power-on reset. + # TODO per-hart stacks + la sp, _end + RISCV_PGSIZE + 1 + li t0, -RISCV_PGSIZE + and sp, sp, t0 + j machine_init + + # XXX depend on sbi_base to force its linkage + la x0, sbi_base + +.Linterrupt: + sll a0, a0, 1 # discard MSB + +#if IRQ_TIMER != 0 +#error +#endif + # Send timer interrupts to the OS. + beqz a0, .Lmrts + + # See if this is an IPI; register a supervisor SW interrupt if so. + li a1, IRQ_IPI * 2 + bne a0, a1, 1f + csrc mstatus, MSTATUS_MSIP + csrs mstatus, MSTATUS_SSIP + j .Lmrts +1: + + # See if this is an HTIF interrupt; if so, handle it in machine mode. + li a1, IRQ_HOST * 2 + bne a0, a1, 1f + li a0, 12 + j .Lhandle_trap_in_machine_mode +1: + + # We don't know how to handle this interrupt. We're hosed. + j bad_trap + +.Lsupervisor_double_fault: + # Return to supervisor trap entry with interrupts disabled. + # Set PRV2=U, IE2=1, PRV1=S (it already is), and IE1=0. + li a0, MSTATUS_PRV2 | MSTATUS_IE2 | MSTATUS_IE1 + csrc mstatus, a0 + j .Lreturn_from_supervisor_double_fault + +.Lhandle_trap_in_machine_mode: + # Preserve the registers. Compute the address of the trap handler. + STORE ra, 1*REGBYTES(sp) + csrr ra, mscratch # ra <- user sp + STORE gp, 3*REGBYTES(sp) + STORE tp, 4*REGBYTES(sp) + STORE t0, 5*REGBYTES(sp) +1:auipc t0, %pcrel_hi(trap_table) # t0 <- %hi(trap_table) + STORE t1, 6*REGBYTES(sp) + sll t1, a0, 2 # t1 <- mcause << 2 + STORE t2, 7*REGBYTES(sp) + add t0, t0, t1 # t0 <- %hi(trap_table)[mcause] + STORE s0, 8*REGBYTES(sp) + lw t0, %pcrel_lo(1b)(t0) # t0 <- handlers[mcause] + STORE s1, 9*REGBYTES(sp) + mv a1, sp # a1 <- regs + STORE a2,12*REGBYTES(sp) + STORE a3,13*REGBYTES(sp) + STORE a4,14*REGBYTES(sp) + STORE a5,15*REGBYTES(sp) + STORE a6,16*REGBYTES(sp) + STORE a7,17*REGBYTES(sp) + STORE s2,18*REGBYTES(sp) + STORE s3,19*REGBYTES(sp) + STORE s4,20*REGBYTES(sp) + STORE s5,21*REGBYTES(sp) + STORE s6,22*REGBYTES(sp) + STORE s7,23*REGBYTES(sp) + STORE s8,24*REGBYTES(sp) + STORE s9,25*REGBYTES(sp) + STORE s10,26*REGBYTES(sp) + STORE s11,27*REGBYTES(sp) + STORE t3,28*REGBYTES(sp) + STORE t4,29*REGBYTES(sp) + STORE t5,30*REGBYTES(sp) + STORE t6,31*REGBYTES(sp) + STORE ra, 2*REGBYTES(sp) # sp + +#ifndef __riscv_hard_float + lw tp, (sp) # Move the emulated FCSR from x0's save slot into tp. +#endif + STORE x0, (sp) # Zero x0's save slot. + + # Invoke the handler. + jalr t0 + +#ifndef __riscv_hard_float + sw tp, (sp) # Move the emulated FCSR from tp into x0's save slot. +#endif + + # Restore all of the registers. + LOAD ra, 1*REGBYTES(sp) + LOAD gp, 3*REGBYTES(sp) + LOAD tp, 4*REGBYTES(sp) + LOAD t0, 5*REGBYTES(sp) + LOAD t1, 6*REGBYTES(sp) + LOAD t2, 7*REGBYTES(sp) + LOAD s0, 8*REGBYTES(sp) + LOAD s1, 9*REGBYTES(sp) + LOAD a1,11*REGBYTES(sp) + LOAD a2,12*REGBYTES(sp) + LOAD a3,13*REGBYTES(sp) + LOAD a4,14*REGBYTES(sp) + LOAD a5,15*REGBYTES(sp) + LOAD a6,16*REGBYTES(sp) + LOAD a7,17*REGBYTES(sp) + LOAD s2,18*REGBYTES(sp) + LOAD s3,19*REGBYTES(sp) + LOAD s4,20*REGBYTES(sp) + LOAD s5,21*REGBYTES(sp) + LOAD s6,22*REGBYTES(sp) + LOAD s7,23*REGBYTES(sp) + LOAD s8,24*REGBYTES(sp) + LOAD s9,25*REGBYTES(sp) + LOAD s10,26*REGBYTES(sp) + LOAD s11,27*REGBYTES(sp) + LOAD t3,28*REGBYTES(sp) + LOAD t4,29*REGBYTES(sp) + LOAD t5,30*REGBYTES(sp) + LOAD t6,31*REGBYTES(sp) + + bnez a0, 1f + + # Go back whence we came. + LOAD a0, 10*REGBYTES(sp) + csrw mscratch, sp + LOAD sp, 2*REGBYTES(sp) + mret + +1:# Redirect the trap to the supervisor. + LOAD a0, 10*REGBYTES(sp) + csrw mscratch, sp + LOAD sp, 2*REGBYTES(sp) + mrts + + .data + .align 6 +trap_table: + .word bad_trap + .word bad_trap + .word illegal_insn_trap + .word bad_trap + .word bad_trap + .word bad_trap + .word mcall_trap + .word bad_trap + .word misaligned_load_trap + .word machine_page_fault + .word misaligned_store_trap + .word machine_page_fault + .word htif_interrupt + .word bad_trap + .word bad_trap + .word bad_trap + .word bad_trap diff --git a/pk/minit.c b/pk/minit.c new file mode 100644 index 0000000..a8a449f --- /dev/null +++ b/pk/minit.c @@ -0,0 +1,62 @@ +#include "vm.h" +#include "mtrap.h" + +uintptr_t mem_size; +uint32_t num_harts; + +static void mstatus_init() +{ + uintptr_t ms = read_csr(mstatus); + ms = INSERT_FIELD(ms, MSTATUS_SA, UA_RV64); + ms = INSERT_FIELD(ms, MSTATUS_UA, UA_RV64); + ms = INSERT_FIELD(ms, MSTATUS_PRV1, PRV_S); + ms = INSERT_FIELD(ms, MSTATUS_IE1, 0); + ms = INSERT_FIELD(ms, MSTATUS_PRV2, PRV_U); + ms = INSERT_FIELD(ms, MSTATUS_IE2, 1); + ms = INSERT_FIELD(ms, MSTATUS_MPRV, PRV_M); + ms = INSERT_FIELD(ms, MSTATUS_VM, VM_SV43); + ms = INSERT_FIELD(ms, MSTATUS_FS, 3); + ms = INSERT_FIELD(ms, MSTATUS_XS, 3); + write_csr(mstatus, ms); + ms = read_csr(mstatus); + + if (EXTRACT_FIELD(ms, MSTATUS_PRV1) != PRV_S) { + ms = INSERT_FIELD(ms, MSTATUS_PRV1, PRV_U); + ms = INSERT_FIELD(ms, MSTATUS_IE1, 1); + write_csr(mstatus, ms); + + panic("supervisor support is required"); + } + + if (EXTRACT_FIELD(ms, MSTATUS_VM) != VM_SV43) + have_vm = 0; +} + +static void memory_init() +{ + if (mem_size == 0) + panic("could not determine memory capacity"); +} + +static void hart_init() +{ + if (num_harts == 0) + panic("could not determine number of harts"); + + if (num_harts != 1) + panic("TODO: SMP support"); +} + +void machine_init() +{ + file_init(); + + struct mainvars arg_buffer; + struct mainvars *args = parse_args(&arg_buffer); + + mstatus_init(); + memory_init(); + hart_init(); + vm_init(); + boot_loader(args); +} diff --git a/pk/mtrap.c b/pk/mtrap.c new file mode 100644 index 0000000..27936fe --- /dev/null +++ b/pk/mtrap.c @@ -0,0 +1,222 @@ +#include "mtrap.h" +#include "frontend.h" +#include "mcall.h" +#include "vm.h" +#include + +uintptr_t illegal_insn_trap(uintptr_t mcause, uintptr_t* regs) +{ + asm (".pushsection .rodata\n" + "illegal_insn_trap_table:\n" + " .word truly_illegal_insn\n" + " .word emulate_float_load\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word emulate_float_store\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word emulate_mul_div\n" + " .word truly_illegal_insn\n" + " .word emulate_mul_div32\n" + " .word truly_illegal_insn\n" + " .word emulate_fmadd\n" + " .word emulate_fmsub\n" + " .word emulate_fnmsub\n" + " .word emulate_fnmadd\n" + " .word emulate_fp\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word emulate_system\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .word truly_illegal_insn\n" + " .popsection"); + + uintptr_t mstatus = read_csr(mstatus); + uintptr_t mepc = read_csr(mepc); + + insn_fetch_t fetch = get_insn(mcause, mstatus, mepc); + + if (fetch.error || (fetch.insn & 3) != 3) + return -1; + + extern int32_t illegal_insn_trap_table[]; + int32_t* pf = (void*)illegal_insn_trap_table + (fetch.insn & 0x7c); + emulation_func f = (emulation_func)(uintptr_t)*pf; + return f(mcause, regs, fetch.insn, mstatus, mepc); +} + +void __attribute__((noreturn)) bad_trap() +{ + panic("machine mode: unhandlable trap %d @ %p", read_csr(mcause), read_csr(mepc)); +} + +uintptr_t htif_interrupt(uintptr_t mcause, uintptr_t* regs) +{ + uintptr_t fromhost = swap_csr(fromhost, 0); + if (!fromhost) + return 0; + + uintptr_t dev = FROMHOST_DEV(fromhost); + uintptr_t cmd = FROMHOST_CMD(fromhost); + uintptr_t data = FROMHOST_DATA(fromhost); + + sbi_device_message* m = MAILBOX()->device_request_queue_head; + sbi_device_message* prev = NULL; + for (size_t i = 0, n = MAILBOX()->device_request_queue_size; i < n; i++) { + if (!supervisor_paddr_valid(m, sizeof(*m)) + && EXTRACT_FIELD(read_csr(mstatus), MSTATUS_PRV1) != PRV_M) + panic("htif: page fault"); + + sbi_device_message* next = (void*)m->sbi_private_data; + if (m->dev == dev && m->cmd == cmd) { + m->data = data; + + // dequeue from request queue + if (prev) + prev->sbi_private_data = (uintptr_t)next; + else + MAILBOX()->device_request_queue_head = next; + MAILBOX()->device_request_queue_size = n-1; + m->sbi_private_data = 0; + + // enqueue to response queue + if (MAILBOX()->device_response_queue_tail) + MAILBOX()->device_response_queue_tail->sbi_private_data = (uintptr_t)m; + else + MAILBOX()->device_response_queue_head = m; + MAILBOX()->device_response_queue_tail = m; + + // signal software interrupt + set_csr(mstatus, MSTATUS_SSIP); + return 0; + } + + prev = m; + m = (void*)atomic_read(&m->sbi_private_data); + } + + panic("htif: no record"); +} + +static uintptr_t mcall_console_putchar(uint8_t ch) +{ + while (swap_csr(tohost, TOHOST_CMD(1, 1, ch)) != 0); + while (1) { + uintptr_t fromhost = read_csr(fromhost); + if (FROMHOST_DEV(fromhost) != 1 || FROMHOST_CMD(fromhost) != 1) { + if (fromhost) + htif_interrupt(0, 0); + continue; + } + write_csr(fromhost, 0); + break; + } + return 0; +} + +#define printm(str, ...) ({ \ + char buf[1024], *p = buf; sprintk(buf, str, __VA_ARGS__); \ + while (*p) mcall_console_putchar(*p++); }) + +static uintptr_t mcall_dev_req(sbi_device_message *m) +{ + //printm("req %d %p\n", MAILBOX()->device_request_queue_size, m); +#ifndef __riscv64 + return -ENOSYS; // TODO: RV32 HTIF? +#else + if (!supervisor_paddr_valid(m, sizeof(*m)) + && EXTRACT_FIELD(read_csr(mstatus), MSTATUS_PRV1) != PRV_M) + return -EFAULT; + + if ((m->dev > 0xFFU) | (m->cmd > 0xFFU) | (m->data > 0x0000FFFFFFFFFFFFU)) + return -EINVAL; + + while (swap_csr(tohost, TOHOST_CMD(m->dev, m->cmd, m->data)) != 0) + ; + + m->sbi_private_data = (uintptr_t)MAILBOX()->device_request_queue_head; + MAILBOX()->device_request_queue_head = m; + MAILBOX()->device_request_queue_size++; + + return 0; +#endif +} + +static uintptr_t mcall_dev_resp() +{ + htif_interrupt(0, 0); + + sbi_device_message* m = MAILBOX()->device_response_queue_head; + if (m) { + //printm("resp %p\n", m); + sbi_device_message* next = (void*)atomic_read(&m->sbi_private_data); + MAILBOX()->device_response_queue_head = next; + if (!next) + MAILBOX()->device_response_queue_tail = 0; + } + return (uintptr_t)m; +} + +uintptr_t mcall_trap(uintptr_t mcause, uintptr_t* regs) +{ + if (EXTRACT_FIELD(read_csr(mstatus), MSTATUS_PRV1) < PRV_S) + return -1; + + uintptr_t n = regs[10], arg0 = regs[11], retval; + switch (n) + { + case MCALL_HART_ID: + retval = 0; // TODO + break; + case MCALL_CONSOLE_PUTCHAR: + retval = mcall_console_putchar(arg0); + break; + case MCALL_SEND_DEVICE_REQUEST: + retval = mcall_dev_req((sbi_device_message*)arg0); + break; + case MCALL_RECEIVE_DEVICE_RESPONSE: + retval = mcall_dev_resp(); + break; + default: + retval = -ENOSYS; + break; + } + regs[10] = retval; + write_csr(mepc, read_csr(mepc) + 4); + return 0; +} + +uintptr_t machine_page_fault(uintptr_t mcause, uintptr_t* regs) +{ + // See if this trap occurred when emulating an instruction on behalf of + // a lower privilege level. + extern int32_t unprivileged_access_ranges[]; + extern int32_t unprivileged_access_ranges_end[]; + uintptr_t mepc = read_csr(mepc); + + int32_t* p = unprivileged_access_ranges; + do { + if (mepc >= p[0] && mepc < p[1]) { + // Yes. Skip to the end of the unprivileged access region. + // Mark t0 zero so the emulation routine knows this occurred. + regs[5] = 0; + write_csr(mepc, p[1]); + return 0; + } + p += 2; + } while (p < unprivileged_access_ranges_end); + + // No. We're boned. + bad_trap(); +} diff --git a/pk/mtrap.h b/pk/mtrap.h new file mode 100644 index 0000000..f5b0cfa --- /dev/null +++ b/pk/mtrap.h @@ -0,0 +1,232 @@ +#ifndef _PK_MTRAP_H +#define _PK_MTRAP_H + +#include "pk.h" +#include "bits.h" +#include "encoding.h" + +#ifndef __ASSEMBLER__ + +#include "sbi.h" + +#define GET_MACRO(_1,_2,_3,_4,NAME,...) NAME + +#define unpriv_mem_access(a, b, c, d, ...) GET_MACRO(__VA_ARGS__, unpriv_mem_access3, unpriv_mem_access2, unpriv_mem_access1, unpriv_mem_access0)(a, b, c, d, __VA_ARGS__) +#define unpriv_mem_access0(a, b, c, d, e) ({ uintptr_t z = 0, z1 = 0, z2 = 0; unpriv_mem_access_base(a, b, c, d, e, z, z1, z2); }) +#define unpriv_mem_access1(a, b, c, d, e, f) ({ uintptr_t z = 0, z1 = 0; unpriv_mem_access_base(a, b, c, d, e, f, z, z1); }) +#define unpriv_mem_access2(a, b, c, d, e, f, g) ({ uintptr_t z = 0; unpriv_mem_access_base(a, b, c, d, e, f, g, z); }) +#define unpriv_mem_access3(a, b, c, d, e, f, g, h) unpriv_mem_access_base(a, b, c, d, e, f, g, h) +#define unpriv_mem_access_base(mstatus, mepc, code, o0, o1, i0, i1, i2) ({ \ + register uintptr_t result asm("t0"); \ + uintptr_t unused1, unused2 __attribute__((unused)); \ + uintptr_t scratch = ~(mstatus) & MSTATUS_PRV1; \ + scratch <<= CONST_CTZ32(MSTATUS_MPRV) - CONST_CTZ32(MSTATUS_PRV1); \ + asm volatile ("csrrc %[result], mstatus, %[scratch]\n" \ + "98: " code "\n" \ + "99: csrs mstatus, %[scratch]\n" \ + ".pushsection .unpriv,\"a\",@progbits\n" \ + ".word 98b; .word 99b\n" \ + ".popsection" \ + : [o0] "=&r"(o0), [o1] "=&r"(o1), \ + [result] "+&r"(result) \ + : [i0] "rJ"(i0), [i1] "rJ"(i1), [i2] "rJ"(i2), \ + [scratch] "r"(scratch), [mepc] "r"(mepc)); \ + unlikely(!result); }) + +#define restore_mstatus(mstatus, mepc) ({ \ + uintptr_t scratch; \ + uintptr_t mask = MSTATUS_PRV1 | MSTATUS_IE1 | MSTATUS_PRV2 | MSTATUS_IE2 | MSTATUS_PRV3 | MSTATUS_IE3; \ + asm volatile("csrc mstatus, %[mask];" \ + "csrw mepc, %[mepc];" \ + "and %[scratch], %[mask], %[mstatus];" \ + "csrs mstatus, %[scratch]" \ + : [scratch] "=r"(scratch) \ + : [mstatus] "r"(mstatus), [mepc] "r"(mepc), \ + [mask] "r"(mask)); }) + +#define unpriv_load_1(ptr, dest, mstatus, mepc, type, insn) ({ \ + type value, dummy; void* addr = (ptr); \ + uintptr_t res = unpriv_mem_access(mstatus, mepc, insn " %[value], (%[addr])", value, dummy, addr); \ + (dest) = (typeof(dest))(uintptr_t)value; \ + res; }) +#define unpriv_load(ptr, dest) ({ \ + uintptr_t res; \ + uintptr_t mstatus = read_csr(mstatus), mepc = read_csr(mepc); \ + if (sizeof(*ptr) == 1) res = unpriv_load_1(ptr, dest, mstatus, mepc, int8_t, "lb"); \ + else if (sizeof(*ptr) == 2) res = unpriv_load_1(ptr, dest, mstatus, mepc, int16_t, "lh"); \ + else if (sizeof(*ptr) == 4) res = unpriv_load_1(ptr, dest, mstatus, mepc, int32_t, "lw"); \ + else if (sizeof(uintptr_t) == 8 && sizeof(*ptr) == 8) res = unpriv_load_1(ptr, dest, mstatus, mepc, int64_t, "ld"); \ + else __builtin_trap(); \ + if (res) restore_mstatus(mstatus, mepc); \ + res; }) + +#define unpriv_store_1(ptr, src, mstatus, mepc, type, insn) ({ \ + type dummy1, dummy2, value = (type)(uintptr_t)(src); void* addr = (ptr); \ + uintptr_t res = unpriv_mem_access(mstatus, mepc, insn " %z[value], (%[addr])", dummy1, dummy2, addr, value); \ + res; }) +#define unpriv_store(ptr, src) ({ \ + uintptr_t res; \ + uintptr_t mstatus = read_csr(mstatus), mepc = read_csr(mepc); \ + if (sizeof(*ptr) == 1) res = unpriv_store_1(ptr, src, mstatus, mepc, int8_t, "sb"); \ + else if (sizeof(*ptr) == 2) res = unpriv_store_1(ptr, src, mstatus, mepc, int16_t, "sh"); \ + else if (sizeof(*ptr) == 4) res = unpriv_store_1(ptr, src, mstatus, mepc, int32_t, "sw"); \ + else if (sizeof(uintptr_t) == 8 && sizeof(*ptr) == 8) res = unpriv_store_1(ptr, src, mstatus, mepc, int64_t, "sd"); \ + else __builtin_trap(); \ + if (res) restore_mstatus(mstatus, mepc); \ + res; }) + +typedef uint32_t insn_t; +typedef uintptr_t (*emulation_func)(uintptr_t, uintptr_t*, insn_t, uintptr_t, uintptr_t); +#define DECLARE_EMULATION_FUNC(name) uintptr_t name(uintptr_t mcause, uintptr_t* regs, insn_t insn, uintptr_t mstatus, uintptr_t mepc) + +#define GET_REG(insn, pos, regs) ({ \ + int mask = (1 << (5+LOG_REGBYTES)) - (1 << LOG_REGBYTES); \ + (uintptr_t*)((uintptr_t)regs + (((insn) >> ((pos) - LOG_REGBYTES)) & mask)); \ +}) +#define GET_RS1(insn, regs) (*GET_REG(insn, 15, regs)) +#define GET_RS2(insn, regs) (*GET_REG(insn, 20, regs)) +#define SET_RD(insn, regs, val) (*GET_REG(insn, 7, regs) = (val)) +#define IMM_I(insn) ((int32_t)(insn) >> 20) +#define IMM_S(insn) (((int32_t)(insn) >> 25 << 5) | (int32_t)(((insn) >> 7) & 0x1f)) +#define MASK_FUNCT3 0x7000 + +#define GET_PRECISION(insn) (((insn) >> 25) & 3) +#define GET_RM(insn) (((insn) >> 12) & 7) +#define PRECISION_S 0 +#define PRECISION_D 1 + +#ifdef __riscv_hard_float +# define GET_F32_REG(insn, pos, regs) ({ \ + register int32_t value asm("a0") = ((insn) >> ((pos)-3)) & 0xf8; \ + uintptr_t tmp; \ + asm ("1: auipc %0, %%pcrel_hi(get_f32_reg); add %0, %0, %1; jalr t0, %0, %%pcrel_lo(1b)" : "=&r"(tmp), "+&r"(value) :: "t0"); \ + value; }) +# define SET_F32_REG(insn, pos, regs, val) ({ \ + register uint32_t value asm("a0") = (val); \ + uintptr_t offset = ((insn) >> ((pos)-3)) & 0xf8; \ + uintptr_t tmp; \ + asm volatile ("1: auipc %0, %%pcrel_hi(put_f32_reg); add %0, %0, %2; jalr t0, %0, %%pcrel_lo(1b)" : "=&r"(tmp) : "r"(value), "r"(offset) : "t0"); }) +# define GET_F64_REG(insn, pos, regs) ({ \ + register uintptr_t value asm("a0") = ((insn) >> ((pos)-3)) & 0xf8; \ + uintptr_t tmp; \ + asm ("1: auipc %0, %%pcrel_hi(get_f64_reg); add %0, %0, %1; jalr t0, %0, %%pcrel_lo(1b)" : "=&r"(tmp), "+&r"(value) :: "t0"); \ + sizeof(uintptr_t) == 4 ? *(int64_t*)value : (int64_t)value; }) +# define SET_F64_REG(insn, pos, regs, val) ({ \ + uint64_t __val = (val); \ + register uintptr_t value asm("a0") = sizeof(uintptr_t) == 4 ? (uintptr_t)&__val : (uintptr_t)__val; \ + uintptr_t offset = ((insn) >> ((pos)-3)) & 0xf8; \ + uintptr_t tmp; \ + asm volatile ("1: auipc %0, %%pcrel_hi(put_f64_reg); add %0, %0, %2; jalr t0, %0, %%pcrel_lo(1b)" : "=&r"(tmp) : "r"(value), "r"(offset) : "t0"); }) +# define GET_FCSR() read_csr(fcsr) +# define SET_FCSR(value) write_csr(fcsr, (value)) +# define GET_FRM() read_csr(frm) +# define SET_FRM(value) write_csr(frm, (value)) +# define GET_FFLAGS() read_csr(fflags) +# define SET_FFLAGS(value) write_csr(fflags, (value)) + +# define SETUP_STATIC_ROUNDING(insn) ({ \ + register long tp asm("tp") = read_csr(frm); \ + if (likely(((insn) & MASK_FUNCT3) == MASK_FUNCT3)) ; \ + else if (GET_RM(insn) > 4) return -1; \ + else tp = GET_RM(insn); \ + asm volatile ("":"+r"(tp)); }) +# define softfloat_raiseFlags(which) set_csr(fflags, which) +# define softfloat_roundingMode ({ register int tp asm("tp"); tp; }) +#else +# define GET_F64_REG(insn, pos, regs) (((int64_t*)(&(regs)[32]))[((insn) >> (pos)) & 0x1f]) +# define SET_F64_REG(insn, pos, regs, val) (GET_F64_REG(insn, pos, regs) = (val)) +# define GET_F32_REG(insn, pos, regs) (*(int32_t*)GET_F64_REG(insn, pos, regs)) +# define SET_F32_REG(insn, pos, regs, val) (GET_F32_REG(insn, pos, regs) = (val)) +# define GET_FCSR() ({ register int tp asm("tp"); tp & 0xFF; }) +# define SET_FCSR(value) ({ asm volatile("add tp, x0, %0" :: "rI"((value) & 0xFF)); }) +# define GET_FRM() (GET_FCSR() >> 5) +# define SET_FRM(value) SET_FCSR(GET_FFLAGS() | ((value) << 5)) +# define GET_FFLAGS() (GET_FCSR() & 0x1F) +# define SET_FFLAGS(value) SET_FCSR((GET_FRM() << 5) | ((value) & 0x1F)) + +# define SETUP_STATIC_ROUNDING(insn) ({ \ + register int tp asm("tp"); tp &= 0xFF; \ + if (likely(((insn) & MASK_FUNCT3) == MASK_FUNCT3)) tp |= tp << 8; \ + else if (GET_RM(insn) > 4) return -1; \ + else tp |= GET_RM(insn) << 13; \ + asm volatile ("":"+r"(tp)); }) +# define softfloat_raiseFlags(which) ({ asm volatile ("or tp, tp, %0" :: "rI"(which)); }) +# define softfloat_roundingMode ({ register int tp asm("tp"); tp >> 13; }) +#endif + +#define GET_F32_RS1(insn, regs) (GET_F32_REG(insn, 15, regs)) +#define GET_F32_RS2(insn, regs) (GET_F32_REG(insn, 20, regs)) +#define GET_F32_RS3(insn, regs) (GET_F32_REG(insn, 27, regs)) +#define GET_F64_RS1(insn, regs) (GET_F64_REG(insn, 15, regs)) +#define GET_F64_RS2(insn, regs) (GET_F64_REG(insn, 20, regs)) +#define GET_F64_RS3(insn, regs) (GET_F64_REG(insn, 27, regs)) +#define SET_F32_RD(insn, regs, val) (SET_F32_REG(insn, 15, regs, val), SET_FS_DIRTY()) +#define SET_F64_RD(insn, regs, val) (SET_F64_REG(insn, 15, regs, val), SET_FS_DIRTY()) +#define SET_FS_DIRTY() set_csr(mstatus, MSTATUS_FS) + +typedef struct { + uintptr_t error; + insn_t insn; +} insn_fetch_t; + +static insn_fetch_t __attribute__((always_inline)) + get_insn(uintptr_t mcause, uintptr_t mstatus, uintptr_t mepc) +{ + insn_fetch_t fetch; + insn_t insn; + +#ifdef __rvc + int rvc_mask = 3, insn_hi; + fetch.error = unpriv_mem_access(mstatus, mepc, + "mv %[insn], %[rvc_mask];" + "lhu %[insn], 0(%[mepc]);" + "and %[insn_hi], %[insn], %[rvc_mask];" + "bne %[insn_hi], %[rvc_mask], 1f;" + "lh %[insn_hi], 2(%[mepc]);" + "sll %[insn_hi], %[insn_hi], 16;" + "or %[insn], %[insn], %[insn_hi];" + "1:", + insn, insn_hi, rvc_mask); +#else + fetch.error = unpriv_mem_access(mstatus, mepc, + "lw %[insn], 0(%[mepc])", + insn, unused1); +#endif + fetch.insn = insn; + + if (unlikely(fetch.error)) { + // we've messed up mstatus, mepc, and mcause, so restore them all + restore_mstatus(mstatus, mepc); + write_csr(mcause, mcause); + } + + return fetch; +} + +typedef struct { + sbi_device_message* device_request_queue_head; + size_t device_request_queue_size; + sbi_device_message* device_response_queue_head; + sbi_device_message* device_response_queue_tail; +} mailbox_t; + +#define MACHINE_STACK_TOP() ({ \ + register uintptr_t sp asm ("sp"); \ + (void*)((sp + RISCV_PGSIZE) & -RISCV_PGSIZE); }) +#define MAILBOX() ((mailbox_t*)(MACHINE_STACK_TOP() - MAILBOX_SIZE)) + +#endif // !__ASSEMBLER__ + +#define MACHINE_STACK_SIZE RISCV_PGSIZE +#define MENTRY_FRAME_SIZE (INTEGER_CONTEXT_SIZE + SOFT_FLOAT_CONTEXT_SIZE \ + + MAILBOX_SIZE) + +#ifdef __riscv_hard_float +# define SOFT_FLOAT_CONTEXT_SIZE 0 +#else +# define SOFT_FLOAT_CONTEXT_SIZE (8 * 32) +#endif +#define MAILBOX_SIZE 64 +#define INTEGER_CONTEXT_SIZE (32 * REGBYTES) + +#endif diff --git a/pk/pk.S b/pk/pk.S index f657343..e69de29 100644 --- a/pk/pk.S +++ b/pk/pk.S @@ -1,15 +0,0 @@ -// See LICENSE for license details. - -#include "encoding.h" - -.section .text,"ax",@progbits -.globl _start -_start: - la gp, _gp - la sp, stack_top - csrw mscratch, sp - - li t0, MSTATUS_FS | MSTATUS_XS - csrs mstatus, t0 - - call boot diff --git a/pk/pk.h b/pk/pk.h index 011bff3..06d68ee 100644 --- a/pk/pk.h +++ b/pk/pk.h @@ -19,6 +19,11 @@ typedef struct long insn; } trapframe_t; +struct mainvars { + uint64_t argc; + uint64_t argv[127]; // this space is shared with the arg strings themselves +}; + #define panic(s,...) do { do_panic(s"\n", ##__VA_ARGS__); } while(0) #define kassert(cond) do { if(!(cond)) kassert_fail(""#cond); } while(0) void do_panic(const char* s, ...) __attribute__((noreturn)); @@ -26,21 +31,24 @@ void kassert_fail(const char* s) __attribute__((noreturn)); #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define CLAMP(a, lo, hi) MIN(MAX(a, lo), hi) -#define ROUNDUP(a, b) ((((a)-1)/(b)+1)*(b)) -#define ROUNDDOWN(a, b) ((a)/(b)*(b)) + +#define likely(x) __builtin_expect((x), 1) +#define unlikely(x) __builtin_expect((x), 0) + +#define EXTRACT_FIELD(val, which) (((val) & (which)) / ((which) & ~((which)-1))) +#define INSERT_FIELD(val, which, fieldval) (((val) & ~(which)) | ((fieldval) * ((which) & ~((which)-1)))) #ifdef __cplusplus extern "C" { #endif +extern uintptr_t mem_size; extern int have_vm; -extern uint32_t mem_mb; -int emulate_fp(trapframe_t*); -void fp_init(); - -int emulate_int(trapframe_t*); +extern uint32_t num_harts; +struct mainvars* parse_args(struct mainvars*); void printk(const char* s, ...); +void sprintk(char* out, const char* s, ...); void init_tf(trapframe_t*, long pc, long sp, int user64); void pop_tf(trapframe_t*) __attribute__((noreturn)); void dump_tf(trapframe_t*); @@ -50,21 +58,25 @@ void handle_misaligned_load(trapframe_t*); void handle_misaligned_store(trapframe_t*); void handle_fault_load(trapframe_t*); void handle_fault_store(trapframe_t*); -void boot(); +uintptr_t boot_loader(struct mainvars*); typedef struct { int elf64; int phent; int phnum; - size_t user_min; + int is_supervisor; + size_t phdr; + size_t phdr_size; + size_t first_free_paddr; + size_t first_user_vaddr; + size_t first_vaddr_after_user; + size_t bias; size_t entry; size_t brk_min; size_t brk; size_t brk_max; size_t mmap_max; size_t stack_bottom; - size_t phdr; - size_t phdr_top; size_t stack_top; size_t t0; } elf_info; @@ -89,6 +101,9 @@ extern char* uarch_counter_names[NUM_COUNTERS]; } #endif -#endif +#endif // !__ASSEMBLER__ + +#define ROUNDUP(a, b) ((((a)-1)/(b)+1)*(b)) +#define ROUNDDOWN(a, b) ((a)/(b)*(b)) #endif diff --git a/pk/pk.ld b/pk/pk.ld index 9167f3f..6f20a56 100644 --- a/pk/pk.ld +++ b/pk/pk.ld @@ -1,6 +1,6 @@ OUTPUT_ARCH( "riscv" ) -ENTRY( _start ) +ENTRY( mentry ) SECTIONS { @@ -10,7 +10,7 @@ SECTIONS /*--------------------------------------------------------------------*/ /* Begining of code and text segment */ - . = 0x00002000; + . = 0x0; _ftext = .; PROVIDE( eprol = . ); @@ -34,6 +34,10 @@ SECTIONS *(.rodata) *(.rodata.*) *(.gnu.linkonce.r.*) + + unprivileged_access_ranges = .; + *(.unpriv) + unprivileged_access_ranges_end = .; } /* End of code and read-only segment */ @@ -46,7 +50,6 @@ SECTIONS /* Start of initialized data segment */ . = ALIGN(16); - PROVIDE( _gp = . + 0x800 ); _fdata = .; /* data: Writable data */ @@ -54,7 +57,9 @@ SECTIONS { *(.data) *(.data.*) + *(.srodata*) *(.gnu.linkonce.d.*) + *(.comment) } /* End of initialized data segment */ @@ -80,10 +85,16 @@ SECTIONS { *(.bss) *(.bss.*) + *(.sbss*) *(.gnu.linkonce.b.*) *(COMMON) } + .sbi : + { + *(.sbi) + } + /* End of uninitialized data segment (used by syscalls.c for heap) */ PROVIDE( end = . ); _end = .; diff --git a/pk/pk.mk.in b/pk/pk.mk.in index 9c105ce..72f4412 100644 --- a/pk/pk.mk.in +++ b/pk/pk.mk.in @@ -1,12 +1,9 @@ pk_subproject_deps = \ - softfloat_riscv \ softfloat \ - softint \ pk_hdrs = \ - pk.h \ + mtrap.h \ encoding.h \ - fp.h \ atomic.h \ file.h \ frontend.h \ @@ -14,21 +11,26 @@ pk_hdrs = \ vm.h \ pk_c_srcs = \ + mtrap.c \ + minit.c \ + emulation.c \ + sbi_impl.c \ init.c \ file.c \ syscall.c \ handlers.c \ frontend.c \ - fp.c \ - int.c \ elf.c \ console.c \ vm.c \ string.c \ pk_asm_srcs = \ + mentry.S \ entry.S \ fp_asm.S \ + sbi_entry.S \ + sbi.S \ pk_test_srcs = diff --git a/pk/sbi.S b/pk/sbi.S new file mode 100644 index 0000000..a8d5066 --- /dev/null +++ b/pk/sbi.S @@ -0,0 +1,7 @@ +.globl sbi_hart_id; sbi_hart_id = -2048 +.globl sbi_num_harts; sbi_num_harts = -2032 +.globl sbi_query_memory; sbi_query_memory = -2016 +.globl sbi_console_putchar; sbi_console_putchar = -2000 +.globl sbi_send_device_request; sbi_send_device_request = -1984 +.globl sbi_receive_device_response; sbi_receive_device_response = -1968 +.globl sbi_send_ipi; sbi_send_ipi = -1952 diff --git a/pk/sbi.h b/pk/sbi.h new file mode 100644 index 0000000..b9e60b4 --- /dev/null +++ b/pk/sbi.h @@ -0,0 +1,27 @@ +#ifndef _ASM_RISCV_SBI_H +#define _ASM_RISCV_SBI_H + +typedef struct { + unsigned long base; + unsigned long size; + unsigned long node_id; +} memory_block_info; + +unsigned long sbi_query_memory(unsigned long id, memory_block_info *p); + +unsigned long sbi_hart_id(void); +unsigned long sbi_num_harts(void); +void sbi_send_ipi(uintptr_t hart_id); +void sbi_console_putchar(unsigned char ch); + +typedef struct { + unsigned long dev; + unsigned long cmd; + unsigned long data; + unsigned long sbi_private_data; +} sbi_device_message; + +unsigned long sbi_send_device_request(uintptr_t req); +uintptr_t sbi_receive_device_response(void); + +#endif diff --git a/pk/sbi_entry.S b/pk/sbi_entry.S new file mode 100644 index 0000000..33e998a --- /dev/null +++ b/pk/sbi_entry.S @@ -0,0 +1,61 @@ +#include "encoding.h" +#include "mcall.h" + + .section .sbi,"ax",@progbits + .align RISCV_PGSHIFT + .globl sbi_base +sbi_base: + + # TODO: figure out something better to do with this space. It's not + # protected from the OS, so beware. + .skip RISCV_PGSIZE - 2048 + + # hart_id + .align 4 + li a0, MCALL_HART_ID + mcall + ret + + # num_harts + .align 4 + lw a0, num_harts + ret + + # query_memory + .align 4 + j __sbi_query_memory + + # console_putchar + .align 4 + mv a1, a0 + li a0, MCALL_CONSOLE_PUTCHAR + mcall + ret + + # send_device_request + .align 4 + mv a1, a0 + li a0, MCALL_SEND_DEVICE_REQUEST + mcall + ret + + # receive_device_response + .align 4 + mv a1, a0 + li a0, MCALL_RECEIVE_DEVICE_RESPONSE + mcall + ret + + # send ipi + .align 4 + csrw send_ipi, a0 + ret + + # end of SBI trampolines + + .globl do_mcall +do_mcall: + mcall + ret + + .align RISCV_PGSHIFT diff --git a/pk/sbi_impl.c b/pk/sbi_impl.c new file mode 100644 index 0000000..03a56bc --- /dev/null +++ b/pk/sbi_impl.c @@ -0,0 +1,23 @@ +#include "pk.h" +#include "vm.h" +#include "frontend.h" +#include "sbi.h" +#include "mcall.h" +#include + +#define sbi_printk(str, ...) ({ \ + char buf[1024]; /* XXX */ \ + sprintk(buf, str, __VA_ARGS__); \ + for (size_t i = 0; buf[i]; i++) \ + do_mcall(MCALL_CONSOLE_PUTCHAR, buf[i]); }) + +uintptr_t __sbi_query_memory(uintptr_t id, memory_block_info *p) +{ + if (id == 0) { + p->base = current.first_free_paddr; + p->size = mem_size - p->base; + return 0; + } + + return -1; +} diff --git a/pk/string.c b/pk/string.c index 40b62e9..b1b9abc 100644 --- a/pk/string.c +++ b/pk/string.c @@ -1,5 +1,6 @@ #include #include +#include void* memcpy(void* dest, const void* src, size_t len) { @@ -51,3 +52,24 @@ char* strcpy(char* dest, const char* src) ; return dest; } + +long atol(const char* str) +{ + long res = 0; + int sign = 0; + + while (*str == ' ') + str++; + + if (*str == '-' || *str == '+') { + sign = *str == '-'; + str++; + } + + while (*str) { + res *= 10; + res += *str++ - '0'; + } + + return sign ? -res : res; +} diff --git a/pk/syscall.c b/pk/syscall.c index 1b097b7..2e14d22 100644 --- a/pk/syscall.c +++ b/pk/syscall.c @@ -413,7 +413,7 @@ static int sys_stub_nosys() return -ENOSYS; } -long do_syscall(long a0, long a1, long a2, long a3, long a4, long a5, long n) +long do_syscall(long a0, long a1, long a2, long a3, long a4, long a5, unsigned long n) { const static void* syscall_table[] = { [SYS_exit] = sys_exit, @@ -421,17 +421,11 @@ long do_syscall(long a0, long a1, long a2, long a3, long a4, long a5, long n) [SYS_read] = sys_read, [SYS_pread] = sys_pread, [SYS_write] = sys_write, - [SYS_open] = sys_open, [SYS_openat] = sys_openat, [SYS_close] = sys_close, [SYS_fstat] = sys_fstat, [SYS_lseek] = sys_lseek, - [SYS_stat] = sys_stat, - [SYS_lstat] = sys_lstat, [SYS_fstatat] = sys_fstatat, - [SYS_link] = sys_link, - [SYS_unlink] = sys_unlink, - [SYS_mkdir] = sys_mkdir, [SYS_linkat] = sys_linkat, [SYS_unlinkat] = sys_unlinkat, [SYS_mkdirat] = sys_mkdirat, @@ -448,11 +442,9 @@ long do_syscall(long a0, long a1, long a2, long a3, long a4, long a5, long n) [SYS_mremap] = sys_mremap, [SYS_mprotect] = sys_mprotect, [SYS_rt_sigaction] = sys_rt_sigaction, - [SYS_time] = sys_time, [SYS_gettimeofday] = sys_gettimeofday, [SYS_times] = sys_times, [SYS_writev] = sys_writev, - [SYS_access] = sys_access, [SYS_faccessat] = sys_faccessat, [SYS_fcntl] = sys_fcntl, [SYS_getdents] = sys_getdents, @@ -462,9 +454,26 @@ long do_syscall(long a0, long a1, long a2, long a3, long a4, long a5, long n) [SYS_ioctl] = sys_stub_nosys, }; - if(n >= ARRAY_SIZE(syscall_table) || !syscall_table[n]) + const static void* old_syscall_table[] = { + [-OLD_SYSCALL_THRESHOLD + SYS_open] = sys_open, + [-OLD_SYSCALL_THRESHOLD + SYS_link] = sys_link, + [-OLD_SYSCALL_THRESHOLD + SYS_unlink] = sys_unlink, + [-OLD_SYSCALL_THRESHOLD + SYS_mkdir] = sys_mkdir, + [-OLD_SYSCALL_THRESHOLD + SYS_access] = sys_access, + [-OLD_SYSCALL_THRESHOLD + SYS_stat] = sys_stat, + [-OLD_SYSCALL_THRESHOLD + SYS_lstat] = sys_lstat, + [-OLD_SYSCALL_THRESHOLD + SYS_time] = sys_time, + }; + + syscall_t f = 0; + + if (n < ARRAY_SIZE(syscall_table)) + f = syscall_table[n]; + else if (n - OLD_SYSCALL_THRESHOLD < ARRAY_SIZE(old_syscall_table)) + f = old_syscall_table[n - OLD_SYSCALL_THRESHOLD]; + + if (!f) panic("bad syscall #%ld!",n); - long r = ((syscall_t)syscall_table[n])(a0, a1, a2, a3, a4, a5, n); - return r; + return f(a0, a1, a2, a3, a4, a5, n); } diff --git a/pk/syscall.h b/pk/syscall.h index 6632bc9..98441f0 100644 --- a/pk/syscall.h +++ b/pk/syscall.h @@ -9,24 +9,17 @@ #define SYS_kill 129 #define SYS_read 63 #define SYS_write 64 -#define SYS_open 1024 #define SYS_openat 56 #define SYS_close 57 #define SYS_lseek 62 #define SYS_brk 214 -#define SYS_link 1025 -#define SYS_unlink 1026 -#define SYS_mkdir 1030 #define SYS_linkat 37 #define SYS_unlinkat 35 #define SYS_mkdirat 34 #define SYS_chdir 49 #define SYS_getcwd 17 -#define SYS_stat 1038 #define SYS_fstat 80 -#define SYS_lstat 1039 #define SYS_fstatat 79 -#define SYS_access 1033 #define SYS_faccessat 48 #define SYS_pread 67 #define SYS_pwrite 68 @@ -39,7 +32,6 @@ #define SYS_munmap 215 #define SYS_mremap 216 #define SYS_mprotect 226 -#define SYS_time 1062 #define SYS_getmainvars 2011 #define SYS_rt_sigaction 134 #define SYS_writev 66 @@ -52,6 +44,16 @@ #define SYS_rt_sigprocmask 135 #define SYS_ioctl 29 +#define OLD_SYSCALL_THRESHOLD 1024 +#define SYS_open 1024 +#define SYS_link 1025 +#define SYS_unlink 1026 +#define SYS_mkdir 1030 +#define SYS_access 1033 +#define SYS_stat 1038 +#define SYS_lstat 1039 +#define SYS_time 1062 + #define IS_ERR_VALUE(x) ((unsigned long)(x) >= (unsigned long)-4096) #define ERR_PTR(x) ((void*)(long)(x)) #define PTR_ERR(x) ((long)(x)) @@ -59,6 +61,6 @@ #define AT_FDCWD -100 void sys_exit(int code) __attribute__((noreturn)); -long do_syscall(long a0, long a1, long a2, long a3, long a4, long a5, long n); +long do_syscall(long a0, long a1, long a2, long a3, long a4, long a5, unsigned long n); #endif diff --git a/pk/vm.c b/pk/vm.c index 290a12c..c54417f 100644 --- a/pk/vm.c +++ b/pk/vm.c @@ -10,7 +10,7 @@ typedef struct { size_t length; file_t* file; size_t offset; - size_t refcnt; + unsigned refcnt; int prot; } vmr_t; @@ -26,20 +26,21 @@ static size_t free_pages; static uintptr_t __page_alloc() { - if (next_free_page == free_pages) - return 0; + kassert(next_free_page != free_pages); uintptr_t addr = first_free_page + RISCV_PGSIZE * next_free_page++; memset((void*)addr, 0, RISCV_PGSIZE); return addr; } static vmr_t* __vmr_alloc(uintptr_t addr, size_t length, file_t* file, - size_t offset, size_t refcnt, int prot) + size_t offset, unsigned refcnt, int prot) { for (vmr_t* v = vmrs; v < vmrs + MAX_VMR; v++) { if (v->refcnt == 0) { + if (file) + file_incref(file); v->addr = addr; v->length = length; v->file = file; @@ -52,7 +53,7 @@ static vmr_t* __vmr_alloc(uintptr_t addr, size_t length, file_t* file, return NULL; } -static void __vmr_decref(vmr_t* v, size_t dec) +static void __vmr_decref(vmr_t* v, unsigned dec) { if ((v->refcnt -= dec) == 0) { @@ -95,9 +96,18 @@ static pte_t pte_create(uintptr_t ppn, int kprot, int uprot) return super_pte_create(ppn, kprot, uprot, 0); } -static __attribute__((always_inline)) pte_t* __walk_internal(uintptr_t addr, int create) +static void __maybe_create_root_page_table() +{ + if (root_page_table) + return; + root_page_table = (void*)__page_alloc(); + if (have_vm) + write_csr(sptbr, root_page_table); +} +static pte_t* __walk_internal(uintptr_t addr, int create) { const size_t pte_per_page = RISCV_PGSIZE/sizeof(void*); + __maybe_create_root_page_table(); pte_t* t = root_page_table; for (unsigned i = RISCV_PGLEVELS-1; i > 0; i--) @@ -108,8 +118,6 @@ static __attribute__((always_inline)) pte_t* __walk_internal(uintptr_t addr, int if (!create) return 0; uintptr_t page = __page_alloc(); - if (page == 0) - return 0; t[idx] = ptd_create(ppn(page)); } else @@ -138,16 +146,15 @@ static int __va_avail(uintptr_t vaddr) static uintptr_t __vm_alloc(size_t npage) { uintptr_t start = current.brk, end = current.mmap_max - npage*RISCV_PGSIZE; - for (uintptr_t a = start; a <= end; a += RISCV_PGSIZE) + for (uintptr_t a = end; a >= start; a -= RISCV_PGSIZE) { if (!__va_avail(a)) continue; - uintptr_t first = a, last = a + (npage-1) * RISCV_PGSIZE; - for (a = last; a > first && __va_avail(a); a -= RISCV_PGSIZE) + uintptr_t last = a, first = a - (npage-1) * RISCV_PGSIZE; + for (a = first; a < last && __va_avail(a); a += RISCV_PGSIZE) ; - if (a > first) - continue; - return a; + if (a >= last) + return a; } return 0; } @@ -157,6 +164,13 @@ static void flush_tlb() asm volatile("sfence.vm"); } +int __valid_user_range(uintptr_t vaddr, size_t len) +{ + if (vaddr + len < vaddr) + return 0; + return vaddr >= current.first_free_paddr && vaddr + len <= current.mmap_max; +} + static int __handle_page_fault(uintptr_t vaddr, int prot) { uintptr_t vpn = vaddr >> RISCV_PGSHIFT; @@ -168,7 +182,7 @@ static int __handle_page_fault(uintptr_t vaddr, int prot) return -1; else if (!(*pte & PTE_V)) { - kassert(vaddr < current.stack_top && vaddr >= current.user_min); + kassert(__valid_user_range(vaddr, 1)); uintptr_t ppn = vpn; vmr_t* v = (vmr_t*)*pte; @@ -225,8 +239,7 @@ uintptr_t __do_mmap(uintptr_t addr, size_t length, int prot, int flags, file_t* size_t npage = (length-1)/RISCV_PGSIZE+1; if (flags & MAP_FIXED) { - if ((addr & (RISCV_PGSIZE-1)) || addr < current.user_min || - addr + length > current.stack_top || addr + length < addr) + if ((addr & (RISCV_PGSIZE-1)) || !__valid_user_range(addr, length)) return (uintptr_t)-1; } else if ((addr = __vm_alloc(npage)) == 0) @@ -247,19 +260,19 @@ uintptr_t __do_mmap(uintptr_t addr, size_t length, int prot, int flags, file_t* *pte = (pte_t)v; } - if (f) file_incref(f); - if (!have_vm || (flags & MAP_POPULATE)) for (uintptr_t a = addr; a < addr + length; a += RISCV_PGSIZE) kassert(__handle_page_fault(a, prot) == 0); + if (current.brk_min != 0 && addr < current.brk_max) + current.brk_max = ROUNDUP(addr + length, RISCV_PGSIZE); + return addr; } int do_munmap(uintptr_t addr, size_t length) { - if ((addr & (RISCV_PGSIZE-1)) || addr < current.user_min || - addr + length > current.stack_top || addr + length < addr) + if ((addr & (RISCV_PGSIZE-1)) || !__valid_user_range(addr, length)) return -EINVAL; spinlock_lock(&vm_lock); @@ -280,8 +293,6 @@ uintptr_t do_mmap(uintptr_t addr, size_t length, int prot, int flags, int fd, of spinlock_lock(&vm_lock); addr = __do_mmap(addr, length, prot, flags, f, offset); - if (addr < current.brk_max) - current.brk_max = addr; spinlock_unlock(&vm_lock); if (f) file_decref(f); @@ -318,29 +329,34 @@ uintptr_t do_brk(size_t addr) return addr; } +uintptr_t __do_mremap(uintptr_t addr, size_t old_size, size_t new_size, int flags) +{ + for (size_t i = 0; i < MAX_VMR; i++) + { + if (vmrs[i].refcnt && addr == vmrs[i].addr && old_size == vmrs[i].length) + { + size_t old_npage = (vmrs[i].length-1)/RISCV_PGSIZE+1; + size_t new_npage = (new_size-1)/RISCV_PGSIZE+1; + if (new_size < old_size) + __do_munmap(addr + new_size, old_size - new_size); + else if (new_size > old_size) + __do_mmap(addr + old_size, new_size - old_size, vmrs[i].prot, 0, + vmrs[i].file, vmrs[i].offset + new_size - old_size); + __vmr_decref(&vmrs[i], old_npage - new_npage); + return addr; + } + } + return -1; +} + uintptr_t do_mremap(uintptr_t addr, size_t old_size, size_t new_size, int flags) { - uintptr_t res = -1; if (((addr | old_size | new_size) & (RISCV_PGSIZE-1)) || (flags & MREMAP_FIXED)) return -EINVAL; spinlock_lock(&vm_lock); - for (size_t i = 0; i < MAX_VMR; i++) - { - if (vmrs[i].refcnt && addr == vmrs[i].addr && old_size == vmrs[i].length) - { - size_t old_npage = (vmrs[i].length-1)/RISCV_PGSIZE+1; - size_t new_npage = (new_size-1)/RISCV_PGSIZE+1; - if (new_size < old_size) - __do_munmap(addr + new_size, old_size - new_size); - else if (new_size > old_size) - __do_mmap(addr + old_size, new_size - old_size, vmrs[i].prot, 0, - vmrs[i].file, vmrs[i].offset + new_size - old_size); - __vmr_decref(&vmrs[i], old_npage - new_npage); - res = addr; - } - } + uintptr_t res = __do_mremap(addr, old_size, new_size, flags); spinlock_unlock(&vm_lock); return res; @@ -385,14 +401,15 @@ uintptr_t do_mprotect(uintptr_t addr, size_t length, int prot) return res; } -static void __map_kernel_range(uintptr_t paddr, size_t len, int prot) +void __map_kernel_range(uintptr_t vaddr, uintptr_t paddr, size_t len, int prot) { + uintptr_t n = ROUNDUP(len, RISCV_PGSIZE) / RISCV_PGSIZE; pte_t perms = pte_create(0, prot, 0); - for (uintptr_t a = paddr; a < paddr + len; a += RISCV_PGSIZE) + for (uintptr_t a = vaddr, i = 0; i < n; i++, a += RISCV_PGSIZE) { pte_t* pte = __walk_create(a); kassert(pte); - *pte = a | perms; + *pte = (a - vaddr + paddr) | perms; } } @@ -401,71 +418,88 @@ void populate_mapping(const void* start, size_t size, int prot) uintptr_t a0 = ROUNDDOWN((uintptr_t)start, RISCV_PGSIZE); for (uintptr_t a = a0; a < (uintptr_t)start+size; a += RISCV_PGSIZE) { - atomic_t* atom = (atomic_t*)(a & -sizeof(atomic_t)); if (prot & PROT_WRITE) - atomic_add(atom, 0); + atomic_add((int*)a, 0); else - atomic_read(atom); + atomic_read((int*)a); } } -void vm_init() +static uintptr_t sbi_top_paddr() { extern char _end; - current.user_min = ROUNDUP((uintptr_t)&_end, RISCV_PGSIZE); - current.brk_min = current.user_min; - current.brk = 0; + return ROUNDUP((uintptr_t)&_end, RISCV_PGSIZE); +} - uint32_t mem_mb = *(volatile uint32_t*)0; +#define first_free_paddr() (sbi_top_paddr() + RISCV_PGSIZE /* boot stack */) - if (mem_mb == 0) - { - current.stack_bottom = 0; - current.stack_top = 0; - current.brk_max = 0; - current.mmap_max = 0; - } - else - { - uintptr_t max_addr = (uintptr_t)mem_mb << 20; - size_t mem_pages = max_addr >> RISCV_PGSHIFT; - const size_t min_free_pages = 2*RISCV_PGLEVELS; - const size_t min_stack_pages = 8; - const size_t max_stack_pages = 1024; - kassert(mem_pages > min_free_pages + min_stack_pages); - free_pages = MAX(mem_pages >> (RISCV_PGLEVEL_BITS-1), min_free_pages); - size_t stack_pages = CLAMP(mem_pages/32, min_stack_pages, max_stack_pages); - first_free_page = max_addr - free_pages * RISCV_PGSIZE; - - uintptr_t root_page_table_paddr = __page_alloc(); - kassert(root_page_table_paddr); - root_page_table = (pte_t*)root_page_table_paddr; - - __map_kernel_range(0, current.user_min, PROT_READ|PROT_WRITE|PROT_EXEC); - - int vm_field = sizeof(long) == 4 ? VM_SV32 : VM_SV43; - if (have_vm) - { -#if 0 - write_csr(sptbr, root_page_table_paddr); - set_csr(mstatus, vm_field << __builtin_ctz(MSTATUS_VM)); -#endif - have_vm = (clear_csr(mstatus, MSTATUS_VM) & MSTATUS_VM) != VM_MBARE; - } +void vm_init() +{ + current.first_free_paddr = first_free_paddr(); - size_t stack_size = RISCV_PGSIZE * stack_pages; - current.stack_top = MIN(first_free_page, 0x80000000); // for RV32 sanity - uintptr_t stack_bot = current.stack_top - stack_size; + size_t mem_pages = mem_size >> RISCV_PGSHIFT; + free_pages = MAX(8, mem_pages >> (RISCV_PGLEVEL_BITS-1)); + first_free_page = mem_size - free_pages * RISCV_PGSIZE; + current.mmap_max = current.brk_max = first_free_page; +} - if (have_vm) - { - __map_kernel_range(first_free_page, free_pages * RISCV_PGSIZE, PROT_READ|PROT_WRITE); - kassert(__do_mmap(stack_bot, stack_size, -1, MAP_FIXED|MAP_PRIVATE|MAP_ANONYMOUS, 0, 0) == stack_bot); - set_csr(mstatus, vm_field); - } +void supervisor_vm_init() +{ + uintptr_t highest_va = -current.first_free_paddr; + mem_size = MIN(mem_size, highest_va - current.first_user_vaddr) & -SUPERPAGE_SIZE; + + pte_t* sbi_pt = (pte_t*)(current.first_vaddr_after_user + current.bias); + memset(sbi_pt, 0, RISCV_PGSIZE); + pte_t* middle_pt = (void*)sbi_pt + RISCV_PGSIZE; +#if RISCV_PGLEVELS == 2 + root_page_table = middle_pt; +#elif RISCV_PGLEVELS == 3 + kassert(current.first_user_vaddr >= -(SUPERPAGE_SIZE << RISCV_PGLEVEL_BITS)); + root_page_table = (void*)middle_pt + RISCV_PGSIZE; + memset(root_page_table, 0, RISCV_PGSIZE); + root_page_table[(1<> l2_shift) & ((1 << RISCV_PGLEVEL_BITS)-1); + middle_pt[l2_idx] = paddr | PTE_V | PTE_G | PTE_SR | PTE_SW | PTE_SX; } + current.first_vaddr_after_user += (void*)root_page_table + RISCV_PGSIZE - (void*)sbi_pt; + + // map SBI at top of vaddr space + uintptr_t num_sbi_pages = sbi_top_paddr() / RISCV_PGSIZE; + for (uintptr_t i = 0; i < num_sbi_pages; i++) { + uintptr_t idx = (1 << RISCV_PGLEVEL_BITS) - num_sbi_pages + i; + sbi_pt[idx] = (i * RISCV_PGSIZE) | PTE_V | PTE_G | PTE_SR | PTE_SX; + } + pte_t* sbi_pte = middle_pt + ((1 << RISCV_PGLEVEL_BITS)-1); + kassert(!*sbi_pte); + *sbi_pte = (uintptr_t)sbi_pt | PTE_T | PTE_V; + + // disable our allocator + kassert(next_free_page == 0); + free_pages = 0; + + flush_tlb(); +} + +void pk_vm_init() +{ + __map_kernel_range(0, 0, current.first_free_paddr, PROT_READ|PROT_WRITE|PROT_EXEC); + __map_kernel_range(first_free_page, first_free_page, free_pages * RISCV_PGSIZE, PROT_READ|PROT_WRITE); + + extern char trap_entry; + write_csr(stvec, &trap_entry); + write_csr(sscratch, __page_alloc() + RISCV_PGSIZE); + + size_t stack_size = RISCV_PGSIZE * CLAMP(mem_size/(RISCV_PGSIZE*32), 1, 256); + current.stack_bottom = __do_mmap(0, stack_size, PROT_READ|PROT_WRITE|PROT_EXEC, MAP_PRIVATE|MAP_ANONYMOUS, 0, 0); + kassert(current.stack_bottom != (uintptr_t)-1); + current.stack_top = current.stack_bottom + stack_size; + kassert(current.stack_top == current.mmap_max); } diff --git a/pk/vm.h b/pk/vm.h index d46abec..273d71c 100644 --- a/pk/vm.h +++ b/pk/vm.h @@ -7,6 +7,8 @@ #include #include +#define SUPERPAGE_SIZE ((uintptr_t)(RISCV_PGSIZE << RISCV_PGLEVEL_BITS)) + #define PROT_READ 1 #define PROT_WRITE 2 #define PROT_EXEC 4 @@ -17,9 +19,18 @@ #define MAP_POPULATE 0x8000 #define MREMAP_FIXED 0x2 +#define supervisor_paddr_valid(start, length) \ + ((uintptr_t)(start) >= current.first_user_vaddr + current.bias \ + && (uintptr_t)(start) + (length) < mem_size \ + && (uintptr_t)(start) + (length) >= (uintptr_t)(start)) + void vm_init(); +void supervisor_vm_init(); +void pk_vm_init(); int handle_page_fault(uintptr_t vaddr, int prot); void populate_mapping(const void* start, size_t size, int prot); +void __map_kernel_range(uintptr_t va, uintptr_t pa, size_t len, int prot); +int __valid_user_range(uintptr_t vaddr, size_t len); uintptr_t __do_mmap(uintptr_t addr, size_t length, int prot, int flags, file_t* file, off_t offset); uintptr_t do_mmap(uintptr_t addr, size_t length, int prot, int flags, int fd, off_t offset); int do_munmap(uintptr_t addr, size_t length); -- cgit v1.1