diff options
author | Andrew Waterman <waterman@cs.berkeley.edu> | 2016-02-05 18:14:42 -0800 |
---|---|---|
committer | Andrew Waterman <waterman@cs.berkeley.edu> | 2016-02-19 13:01:11 -0800 |
commit | ad7a60abeac14c47cfd8d96b6ca2cd07adb833ca (patch) | |
tree | 2cf73a67263ca320590b17db182aa4c887dd22c2 /pk | |
parent | 1d78c4a12e30ca6d51d6863e96661f80534cac74 (diff) | |
download | pk-ad7a60abeac14c47cfd8d96b6ca2cd07adb833ca.zip pk-ad7a60abeac14c47cfd8d96b6ca2cd07adb833ca.tar.gz pk-ad7a60abeac14c47cfd8d96b6ca2cd07adb833ca.tar.bz2 |
WIP on priv spec v1.9
Diffstat (limited to 'pk')
-rw-r--r-- | pk/bbl.c | 19 | ||||
-rw-r--r-- | pk/emulation.c | 466 | ||||
-rw-r--r-- | pk/encoding.h | 99 | ||||
-rw-r--r-- | pk/entry.S | 4 | ||||
-rw-r--r-- | pk/handlers.c | 2 | ||||
-rw-r--r-- | pk/init.c | 10 | ||||
-rw-r--r-- | pk/mentry.S | 126 | ||||
-rw-r--r-- | pk/minit.c | 35 | ||||
-rw-r--r-- | pk/mtrap.c | 65 | ||||
-rw-r--r-- | pk/mtrap.h | 124 | ||||
-rw-r--r-- | pk/pk.c | 2 | ||||
-rw-r--r-- | pk/pk.h | 1 |
12 files changed, 346 insertions, 607 deletions
@@ -2,13 +2,14 @@ #include "vm.h" #include "config.h" -volatile int elf_loaded; +static volatile int elf_loaded; static void enter_entry_point() { - write_csr(mepc, current.entry); - asm volatile("eret"); - __builtin_unreachable(); + prepare_supervisor_mode(); + write_csr(mepc, current.entry); + asm volatile("eret"); + __builtin_unreachable(); } void run_loaded_program(struct mainvars* args) @@ -16,13 +17,13 @@ void run_loaded_program(struct mainvars* args) if (!current.is_supervisor) panic("bbl can't run user binaries; try using pk instead"); - supervisor_vm_init(); + supervisor_vm_init(); #ifdef PK_ENABLE_LOGO - print_logo(); + print_logo(); #endif - mb(); - elf_loaded = 1; - enter_entry_point(); + mb(); + elf_loaded = 1; + enter_entry_point(); } void boot_other_hart() diff --git a/pk/emulation.c b/pk/emulation.c index 7e8000e..bd5b6a8 100644 --- a/pk/emulation.c +++ b/pk/emulation.c @@ -2,175 +2,112 @@ #include "softfloat.h" #include <limits.h> -DECLARE_EMULATION_FUNC(truly_illegal_insn) +void redirect_trap(uintptr_t epc, uintptr_t mstatus) { - return -1; + write_csr(sepc, epc); + write_csr(scause, read_csr(mcause)); + write_csr(mepc, read_csr(stvec)); + + uintptr_t prev_priv = EXTRACT_FIELD(mstatus, MSTATUS_MPP); + uintptr_t prev_ie = EXTRACT_FIELD(mstatus, MSTATUS_MPIE); + kassert(prev_priv <= PRV_S); + mstatus = INSERT_FIELD(mstatus, MSTATUS_SPP, prev_priv); + mstatus = INSERT_FIELD(mstatus, MSTATUS_SPIE, prev_ie); + mstatus = INSERT_FIELD(mstatus, MSTATUS_MPP, PRV_S); + mstatus = INSERT_FIELD(mstatus, MSTATUS_MPIE, 0); + write_csr(mstatus, mstatus); + leave(); +} + +void __attribute__((noinline)) truly_illegal_insn(uintptr_t* regs, uintptr_t mcause, uintptr_t mepc, uintptr_t mstatus, insn_t insn) +{ + redirect_trap(mepc, mstatus); } -uintptr_t misaligned_load_trap(uintptr_t mcause, uintptr_t* regs) +void misaligned_load_trap(uintptr_t* regs, uintptr_t mcause, uintptr_t mepc) { uintptr_t mstatus = read_csr(mstatus); - uintptr_t mepc = read_csr(mepc); - insn_fetch_t fetch = get_insn(mcause, mstatus, mepc); - if (fetch.error) - return -1; - - uintptr_t val, res, tmp; - uintptr_t addr = GET_RS1(fetch.insn, regs) + IMM_I(fetch.insn); - - #define DO_LOAD(type_lo, type_hi, insn_lo, insn_hi) ({ \ - type_lo val_lo; \ - type_hi val_hi; \ - uintptr_t addr_lo = addr & -sizeof(type_hi); \ - uintptr_t addr_hi = addr_lo + sizeof(type_hi); \ - uintptr_t masked_addr = sizeof(type_hi) < 4 ? addr % sizeof(type_hi) : addr; \ - res = unpriv_mem_access(mstatus, mepc, \ - insn_lo " %[val_lo], (%[addr_lo]);" \ - insn_hi " %[val_hi], (%[addr_hi])", \ - val_lo, val_hi, addr_lo, addr_hi); \ - val_lo >>= masked_addr * 8; \ - val_hi <<= (sizeof(type_hi) - masked_addr) * 8; \ - val = (type_hi)(val_lo | val_hi); \ - }) - - if ((fetch.insn & MASK_LW) == MATCH_LW) - DO_LOAD(uint32_t, int32_t, "lw", "lw"); -#ifdef __riscv64 - else if ((fetch.insn & MASK_LD) == MATCH_LD) - DO_LOAD(uint64_t, uint64_t, "ld", "ld"); - else if ((fetch.insn & MASK_LWU) == MATCH_LWU) - DO_LOAD(uint32_t, uint32_t, "lwu", "lwu"); -#endif - else if ((fetch.insn & MASK_FLD) == MATCH_FLD) { + insn_t insn = get_insn(mepc); + + int shift = 0, fp = 0, len; + if ((insn & MASK_LW) == MATCH_LW) + len = 4, shift = 8*(sizeof(uintptr_t) - len); #ifdef __riscv64 - DO_LOAD(uint64_t, uint64_t, "ld", "ld"); - if (res == 0) { - SET_F64_RD(fetch.insn, regs, val); - goto success; - } -#else - DO_LOAD(uint32_t, int32_t, "lw", "lw"); - if (res == 0) { - uint64_t double_val = val; - addr += 4; - DO_LOAD(uint32_t, int32_t, "lw", "lw"); - double_val |= (uint64_t)val << 32; - if (res == 0) { - SET_F64_RD(fetch.insn, regs, val); - goto success; - } - } + else if ((insn & MASK_LD) == MATCH_LD) + len = 8, shift = 8*(sizeof(uintptr_t) - len); + else if ((insn & MASK_LWU) == MATCH_LWU) + fp = 0, len = 4, shift = 0; #endif - } else if ((fetch.insn & MASK_FLW) == MATCH_FLW) { - DO_LOAD(uint32_t, uint32_t, "lw", "lw"); - if (res == 0) { - SET_F32_RD(fetch.insn, regs, val); - goto success; - } - } else if ((fetch.insn & MASK_LH) == MATCH_LH) { - // equivalent to DO_LOAD(uint32_t, int16_t, "lhu", "lh") - res = unpriv_mem_access(mstatus, mepc, - "lbu %[val], 0(%[addr]);" - "lb %[tmp], 1(%[addr])", - val, tmp, addr, mstatus/*X*/); - val |= tmp << 8; - } else if ((fetch.insn & MASK_LHU) == MATCH_LHU) { - // equivalent to DO_LOAD(uint32_t, uint16_t, "lhu", "lhu") - res = unpriv_mem_access(mstatus, mepc, - "lbu %[val], 0(%[addr]);" - "lbu %[tmp], 1(%[addr])", - val, tmp, addr, mstatus/*X*/); - val |= tmp << 8; - } else { - return -1; - } - - if (res) { - restore_mstatus(mstatus, mepc); - return -1; - } + else if ((insn & MASK_FLD) == MATCH_FLD) + fp = 1, len = 8; + else if ((insn & MASK_FLW) == MATCH_FLW) + fp = 1, len = 4; + else if ((insn & MASK_LH) == MATCH_LH) + len = 2, shift = 8*(sizeof(uintptr_t) - len); + else if ((insn & MASK_LHU) == MATCH_LHU) + len = 2; + else + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); - SET_RD(fetch.insn, regs, val); + uintptr_t addr = GET_RS1(insn, regs) + IMM_I(insn); + uintptr_t val = 0, tmp, tmp2; + unpriv_mem_access("add %[tmp2], %[addr], %[len];" + "1: slli %[val], %[val], 8;" + "lbu %[tmp], -1(%[tmp2]);" + "addi %[tmp2], %[tmp2], -1;" + "or %[val], %[val], %[tmp];" + "bne %[addr], %[tmp2], 1b;", + val, tmp, tmp2, addr, len); + + if (shift) + val = (intptr_t)val << shift >> shift; + + if (!fp) + SET_RD(insn, regs, val); + else if (len == 8) + SET_F64_RD(insn, regs, val); + else + SET_F32_RD(insn, regs, val); -success: write_csr(mepc, mepc + 4); - return 0; } -uintptr_t misaligned_store_trap(uintptr_t mcause, uintptr_t* regs) +void misaligned_store_trap(uintptr_t* regs, uintptr_t mcause, uintptr_t mepc) { uintptr_t mstatus = read_csr(mstatus); - uintptr_t mepc = read_csr(mepc); - insn_fetch_t fetch = get_insn(mcause, mstatus, mepc); - if (fetch.error) - return -1; - - uintptr_t addr = GET_RS1(fetch.insn, regs) + IMM_S(fetch.insn); - uintptr_t val = GET_RS2(fetch.insn, regs), error; - - if ((fetch.insn & MASK_SW) == MATCH_SW) { -SW: - error = unpriv_mem_access(mstatus, mepc, - "sb %[val], 0(%[addr]);" - "srl %[val], %[val], 8; sb %[val], 1(%[addr]);" - "srl %[val], %[val], 8; sb %[val], 2(%[addr]);" - "srl %[val], %[val], 8; sb %[val], 3(%[addr]);", - unused1, unused2, val, addr); -#ifdef __riscv64 - } else if ((fetch.insn & MASK_SD) == MATCH_SD) { -SD: - error = unpriv_mem_access(mstatus, mepc, - "sb %[val], 0(%[addr]);" - "srl %[val], %[val], 8; sb %[val], 1(%[addr]);" - "srl %[val], %[val], 8; sb %[val], 2(%[addr]);" - "srl %[val], %[val], 8; sb %[val], 3(%[addr]);" - "srl %[val], %[val], 8; sb %[val], 4(%[addr]);" - "srl %[val], %[val], 8; sb %[val], 5(%[addr]);" - "srl %[val], %[val], 8; sb %[val], 6(%[addr]);" - "srl %[val], %[val], 8; sb %[val], 7(%[addr]);", - unused1, unused2, val, addr); -#endif - } else if ((fetch.insn & MASK_SH) == MATCH_SH) { - error = unpriv_mem_access(mstatus, mepc, - "sb %[val], 0(%[addr]);" - "srl %[val], %[val], 8; sb %[val], 1(%[addr]);", - unused1, unused2, val, addr); - } else if ((fetch.insn & MASK_FSD) == MATCH_FSD) { -#ifdef __riscv64 - val = GET_F64_RS2(fetch.insn, regs); - goto SD; -#else - uint64_t double_val = GET_F64_RS2(fetch.insn, regs); - uint32_t val_lo = double_val, val_hi = double_val >> 32; - error = unpriv_mem_access(mstatus, mepc, - "sb %[val_lo], 0(%[addr]);" - "srl %[val_lo], %[val_lo], 8; sb %[val_lo], 1(%[addr]);" - "srl %[val_lo], %[val_lo], 8; sb %[val_lo], 2(%[addr]);" - "srl %[val_lo], %[val_lo], 8; sb %[val_lo], 3(%[addr]);" - "sb %[val_hi], 4(%[addr]);" - "srl %[val_hi], %[val_hi], 8; sb %[val_hi], 5(%[addr]);" - "srl %[val_hi], %[val_hi], 8; sb %[val_hi], 6(%[addr]);" - "srl %[val_hi], %[val_hi], 8; sb %[val_hi], 7(%[addr]);", - unused1, unused2, val_lo, val_hi, addr); -#endif - } else if ((fetch.insn & MASK_FSW) == MATCH_FSW) { - val = GET_F32_RS2(fetch.insn, regs); - goto SW; - } else - return -1; - - if (error) { - restore_mstatus(mstatus, mepc); - return -1; - } + insn_t insn = get_insn(mepc); + + uintptr_t val = GET_RS2(insn, regs), error; + int len; + if ((insn & MASK_SW) == MATCH_SW) + len = 4; + else if ((insn & MASK_SD) == MATCH_SD) + len = 8; + else if ((insn & MASK_FSD) == MATCH_FSD) + len = 8, val = GET_F64_RS2(insn, regs); + else if ((insn & MASK_FSW) == MATCH_FSW) + len = 4, val = GET_F32_RS2(insn, regs); + else if ((insn & MASK_SH) == MATCH_SH) + len = 2; + else + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); + + uintptr_t addr = GET_RS1(insn, regs) + IMM_S(insn); + uintptr_t tmp, tmp2, addr_end = addr + len; + unpriv_mem_access("mv %[tmp], %[val];" + "mv %[tmp2], %[addr];" + "1: sb %[tmp], 0(%[tmp2]);" + "srli %[tmp], %[tmp], 8;" + "addi %[tmp2], %[tmp2], 1;" + "bne %[tmp2], %[addr_end], 1b", + tmp, tmp2, unused1, val, addr, addr_end); write_csr(mepc, mepc + 4); - return 0; } DECLARE_EMULATION_FUNC(emulate_float_load) { - uintptr_t val_lo, val_hi, error; + uintptr_t val_lo, val_hi; uint64_t val; uintptr_t addr = GET_RS1(insn, regs) + IMM_I(insn); @@ -178,51 +115,37 @@ DECLARE_EMULATION_FUNC(emulate_float_load) { case MATCH_FLW & MASK_FUNCT3: if (addr % 4 != 0) - return misaligned_load_trap(mcause, regs); - - error = unpriv_mem_access(mstatus, mepc, - "lw %[val_lo], (%[addr])", - val_lo, val_hi/*X*/, addr, mstatus/*X*/); + return misaligned_load_trap(regs, mcause, mepc); - if (error == 0) { - SET_F32_RD(insn, regs, val_lo); - goto success; - } + unpriv_mem_access("lw %[val_lo], (%[addr])", + val_lo, unused1, unused2, addr, mstatus/*X*/); + SET_F32_RD(insn, regs, val_lo); break; case MATCH_FLD & MASK_FUNCT3: if (addr % sizeof(uintptr_t) != 0) - return misaligned_load_trap(mcause, regs); + return misaligned_load_trap(regs, mcause, mepc); + #ifdef __riscv64 - error = unpriv_mem_access(mstatus, mepc, - "ld %[val], (%[addr])", - val, val_hi/*X*/, addr, mstatus/*X*/); + unpriv_mem_access("ld %[val], (%[addr])", + val, val_hi/*X*/, unused1, addr, mstatus/*X*/); #else - error = unpriv_mem_access(mstatus, mepc, - "lw %[val_lo], (%[addr]);" - "lw %[val_hi], 4(%[addr])", - val_lo, val_hi, addr, mstatus/*X*/); + unpriv_mem_access("lw %[val_lo], (%[addr]);" + "lw %[val_hi], 4(%[addr])", + val_lo, val_hi, unused1, addr, mstatus/*X*/); val = val_lo | ((uint64_t)val_hi << 32); #endif - - if (error == 0) { - SET_F64_RD(insn, regs, val); - goto success; - } + SET_F64_RD(insn, regs, val); break; - } - restore_mstatus(mstatus, mepc); - return -1; - -success: - write_csr(mepc, mepc + 4); - return 0; + default: + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); + } } DECLARE_EMULATION_FUNC(emulate_float_store) { - uintptr_t val_lo, val_hi, error; + uintptr_t val_lo, val_hi; uint64_t val; uintptr_t addr = GET_RS1(insn, regs) + IMM_S(insn); @@ -230,44 +153,33 @@ DECLARE_EMULATION_FUNC(emulate_float_store) { case MATCH_FSW & MASK_FUNCT3: if (addr % 4 != 0) - return misaligned_store_trap(mcause, regs); + return misaligned_store_trap(regs, mcause, mepc); val_lo = GET_F32_RS2(insn, regs); - error = unpriv_mem_access(mstatus, mepc, - "sw %[val_lo], (%[addr])", - unused1, unused2, val_lo, addr); + unpriv_mem_access("sw %[val_lo], (%[addr])", + unused1, unused2, unused3, val_lo, addr); break; case MATCH_FSD & MASK_FUNCT3: if (addr % sizeof(uintptr_t) != 0) - return misaligned_store_trap(mcause, regs); + return misaligned_store_trap(regs, mcause, mepc); val = GET_F64_RS2(insn, regs); #ifdef __riscv64 - error = unpriv_mem_access(mstatus, mepc, - "sd %[val], (%[addr])", - unused1, unused2, val, addr); + unpriv_mem_access("sd %[val], (%[addr])", + unused1, unused2, unused3, val, addr); #else val_lo = val; val_hi = val >> 32; - error = unpriv_mem_access(mstatus, mepc, - "sw %[val_lo], (%[addr]);" - "sw %[val_hi], 4(%[addr])", - unused1, unused2, val_lo, val_hi, addr); + unpriv_mem_access("sw %[val_lo], (%[addr]);" + "sw %[val_hi], 4(%[addr])", + unused1, unused2, unused3, val_lo, val_hi, addr); #endif break; default: - error = 1; - } - - if (error) { - restore_mstatus(mstatus, mepc); - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); } - - write_csr(mepc, mepc + 4); - return 0; } #ifdef __riscv64 @@ -298,17 +210,15 @@ DECLARE_EMULATION_FUNC(emulate_mul_div) else if ((insn & MASK_MULHSU) == MATCH_MULHSU) val = ((double_int)(intptr_t)rs1 * (double_int)rs2) >> (8 * sizeof(rs1)); else - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); SET_RD(insn, regs, val); - write_csr(mepc, mepc + 4); - return 0; } DECLARE_EMULATION_FUNC(emulate_mul_div32) { #ifndef __riscv64 - return truly_illegal_insn(mcause, regs, insn, mstatus, mepc); + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); #endif uint32_t rs1 = GET_RS1(insn, regs), rs2 = GET_RS2(insn, regs); @@ -326,14 +236,12 @@ DECLARE_EMULATION_FUNC(emulate_mul_div32) else if ((insn & MASK_REMU) == MATCH_REMU) val = rs1 % rs2; else - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); SET_RD(insn, regs, val); - write_csr(mepc, mepc + 4); - return 0; } -static inline int emulate_read_csr(int num, uintptr_t* result, uintptr_t mstatus) +static inline int emulate_read_csr(int num, uintptr_t mstatus, uintptr_t* result) { switch (num) { @@ -353,15 +261,14 @@ static inline int emulate_read_csr(int num, uintptr_t* result, uintptr_t mstatus return -1; } -static inline int emulate_write_csr(int num, uintptr_t value, uintptr_t mstatus) +static inline void emulate_write_csr(int num, uintptr_t value, uintptr_t mstatus) { switch (num) { - case CSR_FRM: SET_FRM(value); return 0; - case CSR_FFLAGS: SET_FFLAGS(value); return 0; - case CSR_FCSR: SET_FCSR(value); return 0; + case CSR_FRM: SET_FRM(value); return; + case CSR_FFLAGS: SET_FFLAGS(value); return; + case CSR_FCSR: SET_FCSR(value); return; } - return -1; } DECLARE_EMULATION_FUNC(emulate_system) @@ -371,28 +278,26 @@ DECLARE_EMULATION_FUNC(emulate_system) int csr_num = (uint32_t)insn >> 20; uintptr_t csr_val, new_csr_val; - if (emulate_read_csr(csr_num, &csr_val, mstatus) != 0) - return -1; + if (emulate_read_csr(csr_num, mstatus, &csr_val)) + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); int do_write = rs1_num; switch (GET_RM(insn)) { - case 0: return -1; + case 0: return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); case 1: new_csr_val = rs1_val; do_write = 1; break; case 2: new_csr_val = csr_val | rs1_val; break; case 3: new_csr_val = csr_val & ~rs1_val; break; - case 4: return -1; + case 4: return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); case 5: new_csr_val = rs1_num; do_write = 1; break; case 6: new_csr_val = csr_val | rs1_num; break; case 7: new_csr_val = csr_val & ~rs1_num; break; } - if (do_write && emulate_write_csr(csr_num, new_csr_val, mstatus) != 0) - return -1; + if (do_write) + emulate_write_csr(csr_num, new_csr_val, mstatus); SET_RD(insn, regs, csr_val); - write_csr(mepc, mepc + 4); - return 0; } DECLARE_EMULATION_FUNC(emulate_fp) @@ -435,17 +340,17 @@ DECLARE_EMULATION_FUNC(emulate_fp) // if FPU is disabled, punt back to the OS if (unlikely((mstatus & MSTATUS_FS) == 0)) - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); extern int32_t fp_emulation_table[]; int32_t* pf = (void*)fp_emulation_table + ((insn >> 25) & 0x7c); emulation_func f = (emulation_func)(uintptr_t)*pf; SETUP_STATIC_ROUNDING(insn); - return f(mcause, regs, insn, mstatus, mepc); + return f(regs, mcause, mepc, mstatus, insn); } -uintptr_t emulate_any_fadd(uintptr_t mcause, uintptr_t* regs, insn_t insn, uintptr_t mstatus, uintptr_t mepc, uintptr_t neg_b) +void emulate_any_fadd(uintptr_t* regs, uintptr_t mcause, uintptr_t mepc, uintptr_t mstatus, insn_t insn, int32_t neg_b) { if (GET_PRECISION(insn) == PRECISION_S) { uint32_t rs1 = GET_F32_RS1(insn, regs); @@ -456,21 +361,18 @@ uintptr_t emulate_any_fadd(uintptr_t mcause, uintptr_t* regs, insn_t insn, uintp uint64_t rs2 = GET_F64_RS2(insn, regs) ^ ((uint64_t)neg_b << 32); SET_F64_RD(insn, regs, f64_add(rs1, rs2)); } else { - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); } - - write_csr(mepc, mepc + 4); - return 0; } DECLARE_EMULATION_FUNC(emulate_fadd) { - return emulate_any_fadd(mcause, regs, insn, mstatus, mepc, 0); + return emulate_any_fadd(regs, mcause, mepc, mstatus, insn, 0); } DECLARE_EMULATION_FUNC(emulate_fsub) { - return emulate_any_fadd(mcause, regs, insn, mstatus, mepc, INT32_MIN); + return emulate_any_fadd(regs, mcause, mepc, mstatus, insn, INT32_MIN); } DECLARE_EMULATION_FUNC(emulate_fmul) @@ -484,11 +386,8 @@ DECLARE_EMULATION_FUNC(emulate_fmul) uint64_t rs2 = GET_F64_RS2(insn, regs); SET_F64_RD(insn, regs, f64_mul(rs1, rs2)); } else { - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); } - - write_csr(mepc, mepc + 4); - return 0; } DECLARE_EMULATION_FUNC(emulate_fdiv) @@ -502,35 +401,29 @@ DECLARE_EMULATION_FUNC(emulate_fdiv) uint64_t rs2 = GET_F64_RS2(insn, regs); SET_F64_RD(insn, regs, f64_div(rs1, rs2)); } else { - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); } - - write_csr(mepc, mepc + 4); - return 0; } DECLARE_EMULATION_FUNC(emulate_fsqrt) { if ((insn >> 20) & 0x1f) - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); if (GET_PRECISION(insn) == PRECISION_S) { SET_F32_RD(insn, regs, f32_sqrt(GET_F32_RS1(insn, regs))); } else if (GET_PRECISION(insn) == PRECISION_D) { SET_F64_RD(insn, regs, f64_sqrt(GET_F64_RS1(insn, regs))); } else { - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); } - - write_csr(mepc, mepc + 4); - return 0; } DECLARE_EMULATION_FUNC(emulate_fsgnj) { int rm = GET_RM(insn); if (rm >= 3) - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); #define DO_FSGNJ(rs1, rs2, rm) ({ \ typeof(rs1) rs1_sign = (rs1) >> (8*sizeof(rs1)-1); \ @@ -548,18 +441,15 @@ DECLARE_EMULATION_FUNC(emulate_fsgnj) uint64_t rs2 = GET_F64_RS2(insn, regs); SET_F64_RD(insn, regs, DO_FSGNJ(rs1, rs2, rm)); } else { - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); } - - write_csr(mepc, mepc + 4); - return 0; } DECLARE_EMULATION_FUNC(emulate_fmin) { int rm = GET_RM(insn); if (rm >= 2) - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); if (GET_PRECISION(insn) == PRECISION_S) { uint32_t rs1 = GET_F32_RS1(insn, regs); @@ -576,11 +466,8 @@ DECLARE_EMULATION_FUNC(emulate_fmin) int use_rs1 = f64_lt_quiet(arg1, arg2) || isNaNF64UI(rs2); SET_F64_RD(insn, regs, use_rs1 ? rs1 : rs2); } else { - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); } - - write_csr(mepc, mepc + 4); - return 0; } DECLARE_EMULATION_FUNC(emulate_fcvt_ff) @@ -588,24 +475,21 @@ DECLARE_EMULATION_FUNC(emulate_fcvt_ff) int rs2_num = (insn >> 20) & 0x1f; if (GET_PRECISION(insn) == PRECISION_S) { if (rs2_num != 1) - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); SET_F32_RD(insn, regs, f64_to_f32(GET_F64_RS1(insn, regs))); } else if (GET_PRECISION(insn) == PRECISION_D) { if (rs2_num != 0) - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); SET_F64_RD(insn, regs, f32_to_f64(GET_F32_RS1(insn, regs))); } else { - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); } - - write_csr(mepc, mepc + 4); - return 0; } DECLARE_EMULATION_FUNC(emulate_fcvt_fi) { if (GET_PRECISION(insn) != PRECISION_S && GET_PRECISION(insn) != PRECISION_D) - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); int negative = 0; uint64_t uint_val = GET_RS1(insn, regs); @@ -627,7 +511,7 @@ DECLARE_EMULATION_FUNC(emulate_fcvt_fi) break; #endif default: - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); } uint64_t float64 = ui64_to_f64(uint_val); @@ -638,9 +522,6 @@ DECLARE_EMULATION_FUNC(emulate_fcvt_fi) SET_F32_RD(insn, regs, f64_to_f32(float64)); else SET_F64_RD(insn, regs, float64); - - write_csr(mepc, mepc + 4); - return 0; } DECLARE_EMULATION_FUNC(emulate_fcvt_if) @@ -648,10 +529,10 @@ DECLARE_EMULATION_FUNC(emulate_fcvt_if) int rs2_num = (insn >> 20) & 0x1f; #ifdef __riscv64 if (rs2_num >= 4) - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); #else if (rs2_num >= 2) - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); #endif int64_t float64; @@ -660,7 +541,7 @@ DECLARE_EMULATION_FUNC(emulate_fcvt_if) else if (GET_PRECISION(insn) == PRECISION_D) float64 = GET_F64_RS1(insn, regs); else - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); int negative = 0; if (float64 < 0) { @@ -716,16 +597,13 @@ DECLARE_EMULATION_FUNC(emulate_fcvt_if) SET_FS_DIRTY(); SET_RD(insn, regs, result); - - write_csr(mepc, mepc + 4); - return 0; } DECLARE_EMULATION_FUNC(emulate_fcmp) { int rm = GET_RM(insn); if (rm >= 3) - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); uintptr_t result; if (GET_PRECISION(insn) == PRECISION_S) { @@ -745,11 +623,9 @@ DECLARE_EMULATION_FUNC(emulate_fcmp) result = f64_lt(rs1, rs2); goto success; } - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); success: SET_RD(insn, regs, result); - write_csr(mepc, mepc + 4); - return 0; } DECLARE_EMULATION_FUNC(emulate_fmv_if) @@ -762,11 +638,9 @@ DECLARE_EMULATION_FUNC(emulate_fmv_if) result = GET_F64_RS1(insn, regs); #endif else - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); SET_RD(insn, regs, result); - write_csr(mepc, mepc + 4); - return 0; } DECLARE_EMULATION_FUNC(emulate_fmv_fi) @@ -778,18 +652,16 @@ DECLARE_EMULATION_FUNC(emulate_fmv_fi) else if ((insn & MASK_FMV_D_X) == MATCH_FMV_D_X) SET_F64_RD(insn, regs, rs1); else - return -1; - - write_csr(mepc, mepc + 4); - return 0; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); } -uintptr_t emulate_any_fmadd(int op, uintptr_t* regs, insn_t insn, uintptr_t mstatus, uintptr_t mepc) +DECLARE_EMULATION_FUNC(emulate_fmadd) { // if FPU is disabled, punt back to the OS if (unlikely((mstatus & MSTATUS_FS) == 0)) - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); + int op = (insn >> 2) & 3; SETUP_STATIC_ROUNDING(insn); if (GET_PRECISION(insn) == PRECISION_S) { uint32_t rs1 = GET_F32_RS1(insn, regs); @@ -802,32 +674,6 @@ uintptr_t emulate_any_fmadd(int op, uintptr_t* regs, insn_t insn, uintptr_t msta uint64_t rs3 = GET_F64_RS3(insn, regs); SET_F64_RD(insn, regs, softfloat_mulAddF64(op, rs1, rs2, rs3)); } else { - return -1; + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); } - write_csr(mepc, mepc + 4); - return 0; -} - -DECLARE_EMULATION_FUNC(emulate_fmadd) -{ - int op = 0; - return emulate_any_fmadd(op, regs, insn, mstatus, mepc); -} - -DECLARE_EMULATION_FUNC(emulate_fmsub) -{ - int op = softfloat_mulAdd_subC; - return emulate_any_fmadd(op, regs, insn, mstatus, mepc); -} - -DECLARE_EMULATION_FUNC(emulate_fnmadd) -{ - int op = softfloat_mulAdd_subC | softfloat_mulAdd_subProd; - return emulate_any_fmadd(op, regs, insn, mstatus, mepc); -} - -DECLARE_EMULATION_FUNC(emulate_fnmsub) -{ - int op = softfloat_mulAdd_subProd; - return emulate_any_fmadd(op, regs, insn, mstatus, mepc); } diff --git a/pk/encoding.h b/pk/encoding.h index e9a495f..df04845 100644 --- a/pk/encoding.h +++ b/pk/encoding.h @@ -3,37 +3,41 @@ #ifndef RISCV_CSR_ENCODING_H #define RISCV_CSR_ENCODING_H -#define MSTATUS_IE 0x00000001 -#define MSTATUS_PRV 0x00000006 -#define MSTATUS_IE1 0x00000008 -#define MSTATUS_PRV1 0x00000030 -#define MSTATUS_IE2 0x00000040 -#define MSTATUS_PRV2 0x00000180 -#define MSTATUS_IE3 0x00000200 -#define MSTATUS_PRV3 0x00000C00 -#define MSTATUS_FS 0x00003000 -#define MSTATUS_XS 0x0000C000 -#define MSTATUS_MPRV 0x00010000 -#define MSTATUS_VM 0x003E0000 +#define MSTATUS_UIE 0x00000001 +#define MSTATUS_SIE 0x00000002 +#define MSTATUS_HIE 0x00000004 +#define MSTATUS_MIE 0x00000008 +#define MSTATUS_UPIE 0x00000010 +#define MSTATUS_SPIE 0x00000020 +#define MSTATUS_HPIE 0x00000040 +#define MSTATUS_MPIE 0x00000080 +#define MSTATUS_SPP 0x00000100 +#define MSTATUS_HPP 0x00000600 +#define MSTATUS_MPP 0x00001800 +#define MSTATUS_FS 0x00006000 +#define MSTATUS_XS 0x00018000 +#define MSTATUS_MPRV 0x00020000 +#define MSTATUS_VM 0x007C0000 #define MSTATUS32_SD 0x80000000 #define MSTATUS64_SD 0x8000000000000000 -#define SSTATUS_IE 0x00000001 -#define SSTATUS_PIE 0x00000008 -#define SSTATUS_PS 0x00000010 -#define SSTATUS_FS 0x00003000 -#define SSTATUS_XS 0x0000C000 -#define SSTATUS_MPRV 0x00010000 -#define SSTATUS_TIE 0x01000000 +#define SSTATUS_UIE 0x00000001 +#define SSTATUS_SIE 0x00000002 +#define SSTATUS_UPIE 0x00000010 +#define SSTATUS_SPIE 0x00000020 +#define SSTATUS_SPP 0x00000100 +#define SSTATUS_FS 0x00006000 +#define SSTATUS_XS 0x00018000 +#define SSTATUS_VM 0x007C0000 #define SSTATUS32_SD 0x80000000 #define SSTATUS64_SD 0x8000000000000000 -#define MIP_SSIP 0x00000002 -#define MIP_HSIP 0x00000004 -#define MIP_MSIP 0x00000008 -#define MIP_STIP 0x00000020 -#define MIP_HTIP 0x00000040 -#define MIP_MTIP 0x00000080 +#define MIP_SSIP (1 << IRQ_S_SOFT) +#define MIP_HSIP (1 << IRQ_H_SOFT) +#define MIP_MSIP (1 << IRQ_M_SOFT) +#define MIP_STIP (1 << IRQ_S_TIMER) +#define MIP_HTIP (1 << IRQ_H_TIMER) +#define MIP_MTIP (1 << IRQ_M_TIMER) #define SIP_SSIP MIP_SSIP #define SIP_STIP MIP_STIP @@ -50,14 +54,14 @@ #define VM_SV39 9 #define VM_SV48 10 -#define UA_RV32 0 -#define UA_RV64 4 -#define UA_RV128 8 - -#define IRQ_SOFT 0 -#define IRQ_TIMER 1 -#define IRQ_HOST 2 -#define IRQ_COP 3 +#define IRQ_S_SOFT 1 +#define IRQ_H_SOFT 2 +#define IRQ_M_SOFT 3 +#define IRQ_S_TIMER 5 +#define IRQ_H_TIMER 6 +#define IRQ_M_TIMER 7 +#define IRQ_COP 8 +#define IRQ_HOST 9 #define IMPL_ROCKET 1 @@ -335,18 +339,12 @@ #define MASK_SCALL 0xffffffff #define MATCH_SBREAK 0x100073 #define MASK_SBREAK 0xffffffff -#define MATCH_SRET 0x10000073 +#define MATCH_SRET 0x10200073 #define MASK_SRET 0xffffffff -#define MATCH_SFENCE_VM 0x10100073 +#define MATCH_SFENCE_VM 0x10400073 #define MASK_SFENCE_VM 0xfff07fff -#define MATCH_WFI 0x10200073 +#define MATCH_WFI 0x10500073 #define MASK_WFI 0xffffffff -#define MATCH_MRTH 0x30600073 -#define MASK_MRTH 0xffffffff -#define MATCH_MRTS 0x30500073 -#define MASK_MRTS 0xffffffff -#define MATCH_HRTS 0x20500073 -#define MASK_HRTS 0xffffffff #define MATCH_CSRRW 0x1073 #define MASK_CSRRW 0x707f #define MATCH_CSRRS 0x2073 @@ -643,6 +641,8 @@ #define CSR_SIE 0x104 #define CSR_SSCRATCH 0x140 #define CSR_SEPC 0x141 +#define CSR_SCAUSE 0x142 +#define CSR_SBADADDR 0x143 #define CSR_SIP 0x144 #define CSR_SPTBR 0x180 #define CSR_SASID 0x181 @@ -650,12 +650,11 @@ #define CSR_TIMEW 0x901 #define CSR_INSTRETW 0x902 #define CSR_STIME 0xd01 -#define CSR_SCAUSE 0xd42 -#define CSR_SBADADDR 0xd43 #define CSR_STIMEW 0xa01 #define CSR_MSTATUS 0x300 #define CSR_MTVEC 0x301 -#define CSR_MTDELEG 0x302 +#define CSR_MEDELEG 0x302 +#define CSR_MIDELEG 0x303 #define CSR_MIE 0x304 #define CSR_MTIMECMP 0x321 #define CSR_MSCRATCH 0x340 @@ -787,9 +786,6 @@ DECLARE_INSN(sbreak, MATCH_SBREAK, MASK_SBREAK) DECLARE_INSN(sret, MATCH_SRET, MASK_SRET) DECLARE_INSN(sfence_vm, MATCH_SFENCE_VM, MASK_SFENCE_VM) DECLARE_INSN(wfi, MATCH_WFI, MASK_WFI) -DECLARE_INSN(mrth, MATCH_MRTH, MASK_MRTH) -DECLARE_INSN(mrts, MATCH_MRTS, MASK_MRTS) -DECLARE_INSN(hrts, MATCH_HRTS, MASK_HRTS) DECLARE_INSN(csrrw, MATCH_CSRRW, MASK_CSRRW) DECLARE_INSN(csrrs, MATCH_CSRRS, MASK_CSRRS) DECLARE_INSN(csrrc, MATCH_CSRRC, MASK_CSRRC) @@ -954,6 +950,8 @@ DECLARE_CSR(stvec, CSR_STVEC) DECLARE_CSR(sie, CSR_SIE) DECLARE_CSR(sscratch, CSR_SSCRATCH) DECLARE_CSR(sepc, CSR_SEPC) +DECLARE_CSR(scause, CSR_SCAUSE) +DECLARE_CSR(sbadaddr, CSR_SBADADDR) DECLARE_CSR(sip, CSR_SIP) DECLARE_CSR(sptbr, CSR_SPTBR) DECLARE_CSR(sasid, CSR_SASID) @@ -961,12 +959,11 @@ DECLARE_CSR(cyclew, CSR_CYCLEW) DECLARE_CSR(timew, CSR_TIMEW) DECLARE_CSR(instretw, CSR_INSTRETW) DECLARE_CSR(stime, CSR_STIME) -DECLARE_CSR(scause, CSR_SCAUSE) -DECLARE_CSR(sbadaddr, CSR_SBADADDR) DECLARE_CSR(stimew, CSR_STIMEW) DECLARE_CSR(mstatus, CSR_MSTATUS) DECLARE_CSR(mtvec, CSR_MTVEC) -DECLARE_CSR(mtdeleg, CSR_MTDELEG) +DECLARE_CSR(medeleg, CSR_MEDELEG) +DECLARE_CSR(mideleg, CSR_MIDELEG) DECLARE_CSR(mie, CSR_MIE) DECLARE_CSR(mtimecmp, CSR_MTIMECMP) DECLARE_CSR(mscratch, CSR_MSCRATCH) @@ -66,8 +66,8 @@ trap_entry: jal handle_trap mv a0,sp - # don't restore sstatus if trap came from kernel - andi s0,s0,SSTATUS_PS + # don't restore sscratch if trap came from kernel + andi s0,s0,SSTATUS_SPP bnez s0,start_user addi sp,sp,320 csrw sscratch,sp diff --git a/pk/handlers.c b/pk/handlers.c index 881cc17..34e39fe 100644 --- a/pk/handlers.c +++ b/pk/handlers.c @@ -47,7 +47,7 @@ void handle_misaligned_store(trapframe_t* tf) static void segfault(trapframe_t* tf, uintptr_t addr, const char* type) { dump_tf(tf); - const char* who = (tf->status & MSTATUS_PRV1) ? "Kernel" : "User"; + const char* who = (tf->status & SSTATUS_SPP) ? "Kernel" : "User"; panic("%s %s segfault @ %p", who, type, addr); } @@ -19,7 +19,7 @@ char* uarch_counter_names[NUM_COUNTERS]; void init_tf(trapframe_t* tf, long pc, long sp) { memset(tf, 0, sizeof(*tf)); - tf->status = read_csr(sstatus); + tf->status = (read_csr(sstatus) &~ SSTATUS_SPP &~ SSTATUS_SIE) | SSTATUS_SPIE; tf->gpr[2] = sp; tf->epc = pc; } @@ -68,3 +68,11 @@ void boot_loader(struct mainvars* args) run_loaded_program(args); } + +void prepare_supervisor_mode() +{ + uintptr_t mstatus = read_csr(mstatus); + mstatus = INSERT_FIELD(mstatus, MSTATUS_MPP, PRV_S); + mstatus = INSERT_FIELD(mstatus, MSTATUS_MPIE, 0); + write_csr(mstatus, mstatus); +} diff --git a/pk/mentry.S b/pk/mentry.S index 7f0fc61..baf35ec 100644 --- a/pk/mentry.S +++ b/pk/mentry.S @@ -24,30 +24,6 @@ trap_table: .word bad_trap .word bad_trap -#define HANDLE_USER_TRAP_IN_MACHINE_MODE 0 \ - | (0 << (31- 0)) /* IF misaligned */ \ - | (0 << (31- 1)) /* IF fault */ \ - | (1 << (31- 2)) /* illegal instruction */ \ - | (0 << (31- 3)) /* breakpoint */ \ - | (1 << (31- 4)) /* load misaligned */ \ - | (0 << (31- 5)) /* load fault */ \ - | (1 << (31- 6)) /* store misaligned */ \ - | (0 << (31- 7)) /* store fault */ \ - | (0 << (31- 8)) /* user environment call */ \ - | (0 << (31- 9)) /* super environment call */ \ - -#define HANDLE_SUPERVISOR_TRAP_IN_MACHINE_MODE 0 \ - | (0 << (31- 0)) /* IF misaligned */ \ - | (0 << (31- 1)) /* IF fault */ \ - | (1 << (31- 2)) /* illegal instruction */ \ - | (0 << (31- 3)) /* breakpoint */ \ - | (1 << (31- 4)) /* load misaligned */ \ - | (0 << (31- 5)) /* load fault */ \ - | (1 << (31- 6)) /* store misaligned */ \ - | (0 << (31- 7)) /* store fault */ \ - | (0 << (31- 8)) /* user environment call */ \ - | (1 << (31- 9)) /* super environment call */ \ - .option norvc .section .text.init,"ax",@progbits .globl mentry @@ -58,19 +34,9 @@ mentry: STORE a0, 10*REGBYTES(sp) STORE a1, 11*REGBYTES(sp) - csrr a0, mcause - bltz a0, .Linterrupt - - li a1, HANDLE_USER_TRAP_IN_MACHINE_MODE - SLL32 a1, a1, a0 - bltz a1, .Lhandle_trap_in_machine_mode - - # Redirect the trap to the supervisor. -.Lmrts: - LOAD a0, 10*REGBYTES(sp) - LOAD a1, 11*REGBYTES(sp) - csrrw sp, mscratch, sp - mrts + csrr a1, mcause + bltz a1, .Linterrupt + j .Lhandle_trap_in_machine_mode .align 6 # Entry point from supervisor mode (mtvec + 0x040) @@ -78,25 +44,9 @@ mentry: STORE a0, 10*REGBYTES(sp) STORE a1, 11*REGBYTES(sp) - csrr a0, mcause - bltz a0, .Linterrupt - - li a1, HANDLE_SUPERVISOR_TRAP_IN_MACHINE_MODE - SLL32 a1, a1, a0 - bltz a1, .Lhandle_trap_in_machine_mode - -.Linterrupt_in_supervisor: - # Detect double faults. - csrr a0, mstatus - SLL32 a0, a0, 31 - CONST_CTZ(MSTATUS_PRV2) - bltz a0, .Lsupervisor_double_fault - -.Lreturn_from_supervisor_double_fault: - # Redirect the trap to the supervisor. - LOAD a0, 10*REGBYTES(sp) - LOAD a1, 11*REGBYTES(sp) - csrrw sp, mscratch, sp - mrts + csrr a1, mcause + bltz a1, .Linterrupt + j .Lhandle_trap_in_machine_mode .align 6 # Entry point from hypervisor mode (mtvec + 0x080) @@ -110,16 +60,12 @@ mentry: addi sp, sp, -INTEGER_CONTEXT_SIZE STORE a0,10*REGBYTES(sp) STORE a1,11*REGBYTES(sp) - li a0, TRAP_FROM_MACHINE_MODE_VECTOR + li a1, TRAP_FROM_MACHINE_MODE_VECTOR j .Lhandle_trap_in_machine_mode -.Lsupervisor_double_fault: - # Return to supervisor trap entry with interrupts disabled. - # Set PRV2=U, IE2=1, PRV1=S (it already is), and IE1=0. - li a0, MSTATUS_PRV2 | MSTATUS_IE2 | MSTATUS_IE1 - csrc mstatus, a0 - j .Lreturn_from_supervisor_double_fault - + nop + nop + nop nop nop nop @@ -198,26 +144,18 @@ mentry: j init_other_hart .Linterrupt: - sll a0, a0, 1 # discard MSB + sll a1, a1, 1 # discard MSB # See if this is a timer interrupt; post a supervisor interrupt if so. - li a1, IRQ_TIMER * 2 + li a0, IRQ_M_TIMER * 2 bne a0, a1, 1f li a0, MIP_MTIP - csrc mip, a0 - li a1, MIP_STIP csrc mie, a0 - csrs mip, a1 + li a0, MIP_STIP + csrs mip, a0 -.Linterrupt_supervisor: - # There are three cases: PRV1=U; PRV1=S and IE1=1; and PRV1=S and IE1=0. - # For cases 1-2, do an MRTS; for case 3, we can't, so ERET. - csrr a0, mstatus - li a1, (MSTATUS_PRV1 & ~(MSTATUS_PRV1<<1)) * PRV_S - and a0, a0, MSTATUS_PRV1 | MSTATUS_IE1 - bne a0, a1, .Lmrts - - # And then go back whence we came. +.Leret: + # Go back whence we came. LOAD a0, 10*REGBYTES(sp) LOAD a1, 11*REGBYTES(sp) csrrw sp, mscratch, sp @@ -225,19 +163,17 @@ mentry: 1: # See if this is an IPI; register a supervisor SW interrupt if so. -#if IRQ_SOFT != 0 -#error -#endif - bnez a0, 1f + li a0, IRQ_M_SOFT * 2 + bne a0, a1, 1f csrc mip, MIP_MSIP csrs mip, MIP_SSIP - j .Linterrupt_supervisor + j .Leret 1: # See if this is an HTIF interrupt; if so, handle it in machine mode. - li a1, IRQ_HOST * 2 + li a0, IRQ_HOST * 2 bne a0, a1, .Lbad_trap - li a0, HTIF_INTERRUPT_VECTOR + li a1, HTIF_INTERRUPT_VECTOR .Lhandle_trap_in_machine_mode: # Preserve the registers. Compute the address of the trap handler. @@ -248,14 +184,15 @@ mentry: STORE t0, 5*REGBYTES(sp) 1:auipc t0, %pcrel_hi(trap_table) # t0 <- %hi(trap_table) STORE t1, 6*REGBYTES(sp) - sll t1, a0, 2 # t1 <- mcause << 2 + sll t1, a1, 2 # t1 <- mcause << 2 STORE t2, 7*REGBYTES(sp) add t0, t0, t1 # t0 <- %hi(trap_table)[mcause] STORE s0, 8*REGBYTES(sp) lw t0, %pcrel_lo(1b)(t0) # t0 <- trap_table[mcause] STORE s1, 9*REGBYTES(sp) - mv a1, sp # a1 <- regs + mv a0, sp # a0 <- regs STORE a2,12*REGBYTES(sp) + csrr a2, mepc # a2 <- mepc STORE a3,13*REGBYTES(sp) STORE a4,14*REGBYTES(sp) STORE a5,15*REGBYTES(sp) @@ -289,6 +226,7 @@ mentry: sw tp, (sp) # Move the emulated FCSR from tp into x0's save slot. #endif +restore_regs: # Restore all of the registers. LOAD ra, 1*REGBYTES(sp) LOAD gp, 3*REGBYTES(sp) @@ -298,6 +236,7 @@ mentry: LOAD t2, 7*REGBYTES(sp) LOAD s0, 8*REGBYTES(sp) LOAD s1, 9*REGBYTES(sp) + LOAD a0,10*REGBYTES(sp) LOAD a1,11*REGBYTES(sp) LOAD a2,12*REGBYTES(sp) LOAD a3,13*REGBYTES(sp) @@ -319,18 +258,13 @@ mentry: LOAD t4,29*REGBYTES(sp) LOAD t5,30*REGBYTES(sp) LOAD t6,31*REGBYTES(sp) - - bnez a0, 1f - - # Go back whence we came. - LOAD a0, 10*REGBYTES(sp) LOAD sp, 2*REGBYTES(sp) eret -1:# Redirect the trap to the supervisor. - LOAD a0, 10*REGBYTES(sp) - LOAD sp, 2*REGBYTES(sp) - mrts +.globl leave +leave: + csrr sp, mscratch + j restore_regs .Lbad_trap: j bad_trap @@ -12,13 +12,8 @@ static void mstatus_init() panic("supervisor support is required"); uintptr_t ms = 0; - ms = INSERT_FIELD(ms, MSTATUS_PRV, PRV_M); - ms = INSERT_FIELD(ms, MSTATUS_PRV1, PRV_S); - ms = INSERT_FIELD(ms, MSTATUS_PRV2, PRV_U); - ms = INSERT_FIELD(ms, MSTATUS_IE2, 1); ms = INSERT_FIELD(ms, MSTATUS_VM, VM_CHOICE); - ms = INSERT_FIELD(ms, MSTATUS_FS, 3); - ms = INSERT_FIELD(ms, MSTATUS_XS, 3); + ms = INSERT_FIELD(ms, MSTATUS_FS, 1); write_csr(mstatus, ms); ms = read_csr(mstatus); @@ -27,7 +22,24 @@ static void mstatus_init() write_csr(mtimecmp, 0); clear_csr(mip, MIP_MSIP); - set_csr(mie, MIP_MSIP); + write_csr(mie, -1); +} + +static void delegate_traps() +{ + uintptr_t interrupts = MIP_SSIP | MIP_STIP; + uintptr_t exceptions = + (1U << CAUSE_MISALIGNED_FETCH) | + (1U << CAUSE_FAULT_FETCH | CAUSE_BREAKPOINT) | + (1U << CAUSE_FAULT_LOAD) | + (1U << CAUSE_FAULT_STORE) | + (1U << CAUSE_BREAKPOINT) | + (1U << CAUSE_USER_ECALL); + + write_csr(mideleg, interrupts); + write_csr(medeleg, exceptions); + kassert(read_csr(mideleg) == interrupts); + kassert(read_csr(medeleg) == exceptions); } static void memory_init() @@ -71,18 +83,19 @@ void hls_init(uint32_t id, uintptr_t* csrs) } } -static void init_hart() +static void hart_init() { mstatus_init(); fp_init(); + delegate_traps(); } void init_first_hart() { - init_hart(); + file_init(); + hart_init(); memset(HLS(), 0, sizeof(*HLS())); - file_init(); parse_device_tree(); struct mainvars arg_buffer; @@ -95,7 +108,7 @@ void init_first_hart() void init_other_hart() { - init_hart(); + hart_init(); // wait until virtual memory is enabled while (*(pte_t* volatile*)&root_page_table == NULL) @@ -4,7 +4,7 @@ #include "vm.h" #include <errno.h> -uintptr_t illegal_insn_trap(uintptr_t mcause, uintptr_t* regs) +void illegal_insn_trap(uintptr_t* regs, uintptr_t mcause, uintptr_t mepc) { asm (".pushsection .rodata\n" "illegal_insn_trap_table:\n" @@ -34,9 +34,9 @@ uintptr_t illegal_insn_trap(uintptr_t mcause, uintptr_t* regs) " .word truly_illegal_insn\n" #ifdef PK_ENABLE_FP_EMULATION " .word emulate_fmadd\n" - " .word emulate_fmsub\n" - " .word emulate_fnmsub\n" - " .word emulate_fnmadd\n" + " .word emulate_fmadd\n" + " .word emulate_fmadd\n" + " .word emulate_fmadd\n" " .word emulate_fp\n" #else " .word truly_illegal_insn\n" @@ -63,17 +63,17 @@ uintptr_t illegal_insn_trap(uintptr_t mcause, uintptr_t* regs) " .popsection"); uintptr_t mstatus = read_csr(mstatus); - uintptr_t mepc = read_csr(mepc); - insn_fetch_t fetch = get_insn(mcause, mstatus, mepc); + insn_t insn = get_insn(mepc); - if (fetch.error || (fetch.insn & 3) != 3) - return -1; + if ((insn & 3) != 3) + return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); + write_csr(mepc, mepc + 4); extern int32_t illegal_insn_trap_table[]; - int32_t* pf = (void*)illegal_insn_trap_table + (fetch.insn & 0x7c); + int32_t* pf = (void*)illegal_insn_trap_table + (insn & 0x7c); emulation_func f = (emulation_func)(uintptr_t)*pf; - return f(mcause, regs, fetch.insn, mstatus, mepc); + f(regs, mcause, mepc, mstatus, insn); } void __attribute__((noreturn)) bad_trap() @@ -81,11 +81,11 @@ void __attribute__((noreturn)) bad_trap() panic("machine mode: unhandlable trap %d @ %p", read_csr(mcause), read_csr(mepc)); } -uintptr_t htif_interrupt(uintptr_t mcause, uintptr_t* regs) +void htif_interrupt() { uintptr_t fromhost = swap_csr(mfromhost, 0); if (!fromhost) - return 0; + return; uintptr_t dev = FROMHOST_DEV(fromhost); uintptr_t cmd = FROMHOST_CMD(fromhost); @@ -95,7 +95,7 @@ uintptr_t htif_interrupt(uintptr_t mcause, uintptr_t* regs) sbi_device_message* prev = NULL; for (size_t i = 0, n = HLS()->device_request_queue_size; i < n; i++) { if (!supervisor_paddr_valid(m, sizeof(*m)) - && EXTRACT_FIELD(read_csr(mstatus), MSTATUS_PRV1) != PRV_M) + && EXTRACT_FIELD(read_csr(mstatus), MSTATUS_MPP) != PRV_M) panic("htif: page fault"); sbi_device_message* next = (void*)m->sbi_private_data; @@ -119,7 +119,7 @@ uintptr_t htif_interrupt(uintptr_t mcause, uintptr_t* regs) // signal software interrupt set_csr(mip, MIP_SSIP); - return 0; + return; } prev = m; @@ -141,7 +141,7 @@ static uintptr_t mcall_console_putchar(uint8_t ch) uintptr_t fromhost = read_csr(mfromhost); if (FROMHOST_DEV(fromhost) != 1 || FROMHOST_CMD(fromhost) != 1) { if (fromhost) - htif_interrupt(0, 0); + htif_interrupt(); continue; } write_csr(mfromhost, 0); @@ -168,7 +168,7 @@ static uintptr_t mcall_dev_req(sbi_device_message *m) { //printm("req %d %p\n", HLS()->device_request_queue_size, m); if (!supervisor_paddr_valid(m, sizeof(*m)) - && EXTRACT_FIELD(read_csr(mstatus), MSTATUS_PRV1) != PRV_M) + && EXTRACT_FIELD(read_csr(mstatus), MSTATUS_MPP) != PRV_M) return -EFAULT; if ((m->dev > 0xFFU) | (m->cmd > 0xFFU) | (m->data > 0x0000FFFFFFFFFFFFU)) @@ -186,7 +186,7 @@ static uintptr_t mcall_dev_req(sbi_device_message *m) static uintptr_t mcall_dev_resp() { - htif_interrupt(0, 0); + htif_interrupt(); sbi_device_message* m = HLS()->device_response_queue_head; if (m) { @@ -249,7 +249,7 @@ static uintptr_t mcall_set_timer(unsigned long long when) return 0; } -uintptr_t mcall_trap(uintptr_t mcause, uintptr_t* regs) +void mcall_trap(uintptr_t* regs, uintptr_t mcause, uintptr_t mepc) { uintptr_t n = regs[17], arg0 = regs[10], retval; switch (n) @@ -283,25 +283,22 @@ uintptr_t mcall_trap(uintptr_t mcause, uintptr_t* regs) break; } regs[10] = retval; - write_csr(mepc, read_csr(mepc) + 4); - return 0; + write_csr(mepc, mepc + 4); } -static uintptr_t machine_page_fault(uintptr_t mcause, uintptr_t* regs, uintptr_t mepc) +static void machine_page_fault(uintptr_t* regs, uintptr_t mepc) { // See if this trap occurred when emulating an instruction on behalf of // a lower privilege level. extern int32_t unprivileged_access_ranges[]; extern int32_t unprivileged_access_ranges_end[]; - int32_t* p = unprivileged_access_ranges; do { if (mepc >= p[0] && mepc < p[1]) { - // Yes. Skip to the end of the unprivileged access region. - // Mark t0 zero so the emulation routine knows this occurred. - regs[5] = 0; - write_csr(mepc, p[1]); - return 0; + // Yes. Redirect the trap to the supervisor. + write_csr(sbadaddr, read_csr(mbadaddr)); + redirect_trap(regs[14], regs[5]); + return; } p += 2; } while (p < unprivileged_access_ranges_end); @@ -310,15 +307,9 @@ static uintptr_t machine_page_fault(uintptr_t mcause, uintptr_t* regs, uintptr_t bad_trap(); } -static uintptr_t machine_illegal_instruction(uintptr_t mcause, uintptr_t* regs, uintptr_t mepc) -{ - bad_trap(); -} - -uintptr_t trap_from_machine_mode(uintptr_t dummy, uintptr_t* regs) +void trap_from_machine_mode(uintptr_t* regs, uintptr_t dummy, uintptr_t mepc) { uintptr_t mcause = read_csr(mcause); - uintptr_t mepc = read_csr(mepc); // restore mscratch, since we clobbered it. write_csr(mscratch, MACHINE_STACK_TOP() - MENTRY_FRAME_SIZE); @@ -326,11 +317,11 @@ uintptr_t trap_from_machine_mode(uintptr_t dummy, uintptr_t* regs) { case CAUSE_FAULT_LOAD: case CAUSE_FAULT_STORE: - return machine_page_fault(mcause, regs, mepc); + return machine_page_fault(regs, mepc); case CAUSE_ILLEGAL_INSTRUCTION: - return machine_illegal_instruction(mcause, regs, mepc); + return bad_trap(); case CAUSE_MACHINE_ECALL: - return mcall_trap(mcause, regs); + return mcall_trap(regs, dummy, mepc); default: bad_trap(); } @@ -11,72 +11,34 @@ #define GET_MACRO(_1,_2,_3,_4,NAME,...) NAME -#define unpriv_mem_access(a, b, c, d, ...) GET_MACRO(__VA_ARGS__, unpriv_mem_access3, unpriv_mem_access2, unpriv_mem_access1, unpriv_mem_access0)(a, b, c, d, __VA_ARGS__) -#define unpriv_mem_access0(a, b, c, d, e) ({ uintptr_t z = 0, z1 = 0, z2 = 0; unpriv_mem_access_base(a, b, c, d, e, z, z1, z2); }) -#define unpriv_mem_access1(a, b, c, d, e, f) ({ uintptr_t z = 0, z1 = 0; unpriv_mem_access_base(a, b, c, d, e, f, z, z1); }) -#define unpriv_mem_access2(a, b, c, d, e, f, g) ({ uintptr_t z = 0; unpriv_mem_access_base(a, b, c, d, e, f, g, z); }) -#define unpriv_mem_access3(a, b, c, d, e, f, g, h) unpriv_mem_access_base(a, b, c, d, e, f, g, h) -#define unpriv_mem_access_base(mstatus, mepc, code, o0, o1, i0, i1, i2) ({ \ - register uintptr_t result asm("t0"); \ - uintptr_t unused1, unused2 __attribute__((unused)); \ - uintptr_t scratch = MSTATUS_MPRV; \ - asm volatile ("csrrs %[result], mstatus, %[scratch]\n" \ +#define unpriv_mem_access(a, b, c, ...) GET_MACRO(__VA_ARGS__, unpriv_mem_access3, unpriv_mem_access2, unpriv_mem_access1, unpriv_mem_access0)(a, b, c, __VA_ARGS__) +#define unpriv_mem_access0(a, b, c, d) ({ uintptr_t z = 0, z1 = 0, z2 = 0; unpriv_mem_access_base(a, b, c, d, z, z1, z2); }) +#define unpriv_mem_access1(a, b, c, d, e) ({ uintptr_t z = 0, z1 = 0; unpriv_mem_access_base(a, b, c, d, e, z, z1); }) +#define unpriv_mem_access2(a, b, c, d, e, f) ({ uintptr_t z = 0; unpriv_mem_access_base(a, b, c, d, e, f, z); }) +#define unpriv_mem_access3(a, b, c, d, e, f, g) unpriv_mem_access_base(a, b, c, d, e, f, g) +#define unpriv_mem_access_base(code, o0, o1, o2, i0, i1, i2) ({ \ + register uintptr_t scratch asm ("t0") = MSTATUS_MPRV; \ + register uintptr_t __mepc asm ("a4") = mepc; \ + uintptr_t unused1, unused2, unused3 __attribute__((unused)); \ + asm volatile ("csrrs %[scratch], mstatus, %[scratch]\n" \ "98: " code "\n" \ - "99: csrc mstatus, %[scratch]\n" \ + "99: csrw mstatus, %[scratch]\n" \ ".pushsection .unpriv,\"a\",@progbits\n" \ ".word 98b; .word 99b\n" \ ".popsection" \ - : [o0] "=&r"(o0), [o1] "=&r"(o1), \ - [result] "+&r"(result) \ + : [o0] "=&r"(o0), [o1] "=&r"(o1), [o2] "=&r"(o2), \ + [scratch] "+&r"(scratch) \ : [i0] "rJ"(i0), [i1] "rJ"(i1), [i2] "rJ"(i2), \ - [scratch] "r"(scratch), [mepc] "r"(mepc)); \ - unlikely(!result); }) - -#define restore_mstatus(mstatus, mepc) ({ \ - uintptr_t scratch; \ - uintptr_t mask = MSTATUS_PRV1 | MSTATUS_IE1 | MSTATUS_PRV2 | MSTATUS_IE2 | MSTATUS_PRV3 | MSTATUS_IE3; \ - asm volatile("csrc mstatus, %[mask];" \ - "csrw mepc, %[mepc];" \ - "and %[scratch], %[mask], %[mstatus];" \ - "csrs mstatus, %[scratch]" \ - : [scratch] "=r"(scratch) \ - : [mstatus] "r"(mstatus), [mepc] "r"(mepc), \ - [mask] "r"(mask)); }) - -#define unpriv_load_1(ptr, dest, mstatus, mepc, type, insn) ({ \ - type value, dummy; void* addr = (ptr); \ - uintptr_t res = unpriv_mem_access(mstatus, mepc, insn " %[value], (%[addr])", value, dummy, addr); \ - (dest) = (typeof(dest))(uintptr_t)value; \ - res; }) -#define unpriv_load(ptr, dest) ({ \ - uintptr_t res; \ - uintptr_t mstatus = read_csr(mstatus), mepc = read_csr(mepc); \ - if (sizeof(*ptr) == 1) res = unpriv_load_1(ptr, dest, mstatus, mepc, int8_t, "lb"); \ - else if (sizeof(*ptr) == 2) res = unpriv_load_1(ptr, dest, mstatus, mepc, int16_t, "lh"); \ - else if (sizeof(*ptr) == 4) res = unpriv_load_1(ptr, dest, mstatus, mepc, int32_t, "lw"); \ - else if (sizeof(uintptr_t) == 8 && sizeof(*ptr) == 8) res = unpriv_load_1(ptr, dest, mstatus, mepc, int64_t, "ld"); \ - else __builtin_trap(); \ - if (res) restore_mstatus(mstatus, mepc); \ - res; }) - -#define unpriv_store_1(ptr, src, mstatus, mepc, type, insn) ({ \ - type dummy1, dummy2, value = (type)(uintptr_t)(src); void* addr = (ptr); \ - uintptr_t res = unpriv_mem_access(mstatus, mepc, insn " %z[value], (%[addr])", dummy1, dummy2, addr, value); \ - res; }) -#define unpriv_store(ptr, src) ({ \ - uintptr_t res; \ - uintptr_t mstatus = read_csr(mstatus), mepc = read_csr(mepc); \ - if (sizeof(*ptr) == 1) res = unpriv_store_1(ptr, src, mstatus, mepc, int8_t, "sb"); \ - else if (sizeof(*ptr) == 2) res = unpriv_store_1(ptr, src, mstatus, mepc, int16_t, "sh"); \ - else if (sizeof(*ptr) == 4) res = unpriv_store_1(ptr, src, mstatus, mepc, int32_t, "sw"); \ - else if (sizeof(uintptr_t) == 8 && sizeof(*ptr) == 8) res = unpriv_store_1(ptr, src, mstatus, mepc, int64_t, "sd"); \ - else __builtin_trap(); \ - if (res) restore_mstatus(mstatus, mepc); \ - res; }) + "r"(__mepc)); \ +}) typedef uint32_t insn_t; -typedef uintptr_t (*emulation_func)(uintptr_t, uintptr_t*, insn_t, uintptr_t, uintptr_t); -#define DECLARE_EMULATION_FUNC(name) uintptr_t name(uintptr_t mcause, uintptr_t* regs, insn_t insn, uintptr_t mstatus, uintptr_t mepc) +typedef void (*emulation_func)(uintptr_t*, uintptr_t, uintptr_t, uintptr_t, insn_t); +#define DECLARE_EMULATION_FUNC(name) void name(uintptr_t* regs, uintptr_t mcause, uintptr_t mepc, uintptr_t mstatus, insn_t insn) + +void truly_illegal_insn(uintptr_t* regs, uintptr_t mcause, uintptr_t mepc, uintptr_t mstatus, insn_t insn); +void redirect_trap(uintptr_t epc, uintptr_t mstatus); +void leave(); #define GET_REG(insn, pos, regs) ({ \ int mask = (1 << (5+LOG_REGBYTES)) - (1 << LOG_REGBYTES); \ @@ -127,7 +89,7 @@ typedef uintptr_t (*emulation_func)(uintptr_t, uintptr_t*, insn_t, uintptr_t, ui # define SETUP_STATIC_ROUNDING(insn) ({ \ register long tp asm("tp") = read_csr(frm); \ if (likely(((insn) & MASK_FUNCT3) == MASK_FUNCT3)) ; \ - else if (GET_RM(insn) > 4) return -1; \ + else if (GET_RM(insn) > 4) return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); \ else tp = GET_RM(insn); \ asm volatile ("":"+r"(tp)); }) # define softfloat_raiseFlags(which) set_csr(fflags, which) @@ -147,7 +109,7 @@ typedef uintptr_t (*emulation_func)(uintptr_t, uintptr_t*, insn_t, uintptr_t, ui # define SETUP_STATIC_ROUNDING(insn) ({ \ register int tp asm("tp"); tp &= 0xFF; \ if (likely(((insn) & MASK_FUNCT3) == MASK_FUNCT3)) tp |= tp << 8; \ - else if (GET_RM(insn) > 4) return -1; \ + else if (GET_RM(insn) > 4) return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); \ else tp |= GET_RM(insn) << 13; \ asm volatile ("":"+r"(tp)); }) # define softfloat_raiseFlags(which) ({ asm volatile ("or tp, tp, %0" :: "rI"(which)); }) @@ -164,42 +126,26 @@ typedef uintptr_t (*emulation_func)(uintptr_t, uintptr_t*, insn_t, uintptr_t, ui #define SET_F64_RD(insn, regs, val) (SET_F64_REG(insn, 7, regs, val), SET_FS_DIRTY()) #define SET_FS_DIRTY() set_csr(mstatus, MSTATUS_FS) -typedef struct { - uintptr_t error; - insn_t insn; -} insn_fetch_t; - -static insn_fetch_t __attribute__((always_inline)) - get_insn(uintptr_t mcause, uintptr_t mstatus, uintptr_t mepc) +static insn_t __attribute__((always_inline)) get_insn(uintptr_t mepc) { - insn_fetch_t fetch; insn_t insn; #ifdef __riscv_compressed int rvc_mask = 3, insn_hi; - fetch.error = unpriv_mem_access(mstatus, mepc, - "lhu %[insn], 0(%[mepc]);" - "and %[insn_hi], %[insn], %[rvc_mask];" - "bne %[insn_hi], %[rvc_mask], 1f;" - "lh %[insn_hi], 2(%[mepc]);" - "sll %[insn_hi], %[insn_hi], 16;" - "or %[insn], %[insn], %[insn_hi];" - "1:", - insn, insn_hi, rvc_mask); + unpriv_mem_access("lhu %[insn], 0(%[mepc]);" + "and %[insn_hi], %[insn], %[rvc_mask];" + "bne %[insn_hi], %[rvc_mask], 1f;" + "lh %[insn_hi], 2(%[mepc]);" + "sll %[insn_hi], %[insn_hi], 16;" + "or %[insn], %[insn], %[insn_hi];" + "1:", + insn, insn_hi, unused1, mepc, rvc_mask); #else - fetch.error = unpriv_mem_access(mstatus, mepc, - "lw %[insn], 0(%[mepc])", - insn, unused1); + unpriv_mem_access("lw %[insn], 0(%[mepc])", + insn, unused1, unused2, mepc); #endif - fetch.insn = insn; - - if (unlikely(fetch.error)) { - // we've messed up mstatus, mepc, and mcause, so restore them all - restore_mstatus(mstatus, mepc); - write_csr(mcause, mcause); - } - return fetch; + return insn; } static inline long __attribute__((pure)) cpuid() @@ -12,8 +12,10 @@ void run_loaded_program(struct mainvars* args) extern char trap_entry; write_csr(stvec, &trap_entry); write_csr(sscratch, 0); + clear_csr(sie, SIP_STIP | SIP_SSIP); // enter supervisor mode + prepare_supervisor_mode(); asm volatile("la t0, 1f; csrw mepc, t0; eret; 1:" ::: "t0"); // copy phdrs to user stack @@ -67,6 +67,7 @@ void handle_misaligned_load(trapframe_t*); void handle_misaligned_store(trapframe_t*); void handle_fault_load(trapframe_t*); void handle_fault_store(trapframe_t*); +void prepare_supervisor_mode(); void boot_loader(struct mainvars*); void run_loaded_program(struct mainvars*); void boot_other_hart(); |