aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--pk/bbl.c19
-rw-r--r--pk/emulation.c466
-rw-r--r--pk/encoding.h99
-rw-r--r--pk/entry.S4
-rw-r--r--pk/handlers.c2
-rw-r--r--pk/init.c10
-rw-r--r--pk/mentry.S126
-rw-r--r--pk/minit.c35
-rw-r--r--pk/mtrap.c65
-rw-r--r--pk/mtrap.h124
-rw-r--r--pk/pk.c2
-rw-r--r--pk/pk.h1
12 files changed, 346 insertions, 607 deletions
diff --git a/pk/bbl.c b/pk/bbl.c
index 68ad712..dc6e1ea 100644
--- a/pk/bbl.c
+++ b/pk/bbl.c
@@ -2,13 +2,14 @@
#include "vm.h"
#include "config.h"
-volatile int elf_loaded;
+static volatile int elf_loaded;
static void enter_entry_point()
{
- write_csr(mepc, current.entry);
- asm volatile("eret");
- __builtin_unreachable();
+ prepare_supervisor_mode();
+ write_csr(mepc, current.entry);
+ asm volatile("eret");
+ __builtin_unreachable();
}
void run_loaded_program(struct mainvars* args)
@@ -16,13 +17,13 @@ void run_loaded_program(struct mainvars* args)
if (!current.is_supervisor)
panic("bbl can't run user binaries; try using pk instead");
- supervisor_vm_init();
+ supervisor_vm_init();
#ifdef PK_ENABLE_LOGO
- print_logo();
+ print_logo();
#endif
- mb();
- elf_loaded = 1;
- enter_entry_point();
+ mb();
+ elf_loaded = 1;
+ enter_entry_point();
}
void boot_other_hart()
diff --git a/pk/emulation.c b/pk/emulation.c
index 7e8000e..bd5b6a8 100644
--- a/pk/emulation.c
+++ b/pk/emulation.c
@@ -2,175 +2,112 @@
#include "softfloat.h"
#include <limits.h>
-DECLARE_EMULATION_FUNC(truly_illegal_insn)
+void redirect_trap(uintptr_t epc, uintptr_t mstatus)
{
- return -1;
+ write_csr(sepc, epc);
+ write_csr(scause, read_csr(mcause));
+ write_csr(mepc, read_csr(stvec));
+
+ uintptr_t prev_priv = EXTRACT_FIELD(mstatus, MSTATUS_MPP);
+ uintptr_t prev_ie = EXTRACT_FIELD(mstatus, MSTATUS_MPIE);
+ kassert(prev_priv <= PRV_S);
+ mstatus = INSERT_FIELD(mstatus, MSTATUS_SPP, prev_priv);
+ mstatus = INSERT_FIELD(mstatus, MSTATUS_SPIE, prev_ie);
+ mstatus = INSERT_FIELD(mstatus, MSTATUS_MPP, PRV_S);
+ mstatus = INSERT_FIELD(mstatus, MSTATUS_MPIE, 0);
+ write_csr(mstatus, mstatus);
+ leave();
+}
+
+void __attribute__((noinline)) truly_illegal_insn(uintptr_t* regs, uintptr_t mcause, uintptr_t mepc, uintptr_t mstatus, insn_t insn)
+{
+ redirect_trap(mepc, mstatus);
}
-uintptr_t misaligned_load_trap(uintptr_t mcause, uintptr_t* regs)
+void misaligned_load_trap(uintptr_t* regs, uintptr_t mcause, uintptr_t mepc)
{
uintptr_t mstatus = read_csr(mstatus);
- uintptr_t mepc = read_csr(mepc);
- insn_fetch_t fetch = get_insn(mcause, mstatus, mepc);
- if (fetch.error)
- return -1;
-
- uintptr_t val, res, tmp;
- uintptr_t addr = GET_RS1(fetch.insn, regs) + IMM_I(fetch.insn);
-
- #define DO_LOAD(type_lo, type_hi, insn_lo, insn_hi) ({ \
- type_lo val_lo; \
- type_hi val_hi; \
- uintptr_t addr_lo = addr & -sizeof(type_hi); \
- uintptr_t addr_hi = addr_lo + sizeof(type_hi); \
- uintptr_t masked_addr = sizeof(type_hi) < 4 ? addr % sizeof(type_hi) : addr; \
- res = unpriv_mem_access(mstatus, mepc, \
- insn_lo " %[val_lo], (%[addr_lo]);" \
- insn_hi " %[val_hi], (%[addr_hi])", \
- val_lo, val_hi, addr_lo, addr_hi); \
- val_lo >>= masked_addr * 8; \
- val_hi <<= (sizeof(type_hi) - masked_addr) * 8; \
- val = (type_hi)(val_lo | val_hi); \
- })
-
- if ((fetch.insn & MASK_LW) == MATCH_LW)
- DO_LOAD(uint32_t, int32_t, "lw", "lw");
-#ifdef __riscv64
- else if ((fetch.insn & MASK_LD) == MATCH_LD)
- DO_LOAD(uint64_t, uint64_t, "ld", "ld");
- else if ((fetch.insn & MASK_LWU) == MATCH_LWU)
- DO_LOAD(uint32_t, uint32_t, "lwu", "lwu");
-#endif
- else if ((fetch.insn & MASK_FLD) == MATCH_FLD) {
+ insn_t insn = get_insn(mepc);
+
+ int shift = 0, fp = 0, len;
+ if ((insn & MASK_LW) == MATCH_LW)
+ len = 4, shift = 8*(sizeof(uintptr_t) - len);
#ifdef __riscv64
- DO_LOAD(uint64_t, uint64_t, "ld", "ld");
- if (res == 0) {
- SET_F64_RD(fetch.insn, regs, val);
- goto success;
- }
-#else
- DO_LOAD(uint32_t, int32_t, "lw", "lw");
- if (res == 0) {
- uint64_t double_val = val;
- addr += 4;
- DO_LOAD(uint32_t, int32_t, "lw", "lw");
- double_val |= (uint64_t)val << 32;
- if (res == 0) {
- SET_F64_RD(fetch.insn, regs, val);
- goto success;
- }
- }
+ else if ((insn & MASK_LD) == MATCH_LD)
+ len = 8, shift = 8*(sizeof(uintptr_t) - len);
+ else if ((insn & MASK_LWU) == MATCH_LWU)
+ fp = 0, len = 4, shift = 0;
#endif
- } else if ((fetch.insn & MASK_FLW) == MATCH_FLW) {
- DO_LOAD(uint32_t, uint32_t, "lw", "lw");
- if (res == 0) {
- SET_F32_RD(fetch.insn, regs, val);
- goto success;
- }
- } else if ((fetch.insn & MASK_LH) == MATCH_LH) {
- // equivalent to DO_LOAD(uint32_t, int16_t, "lhu", "lh")
- res = unpriv_mem_access(mstatus, mepc,
- "lbu %[val], 0(%[addr]);"
- "lb %[tmp], 1(%[addr])",
- val, tmp, addr, mstatus/*X*/);
- val |= tmp << 8;
- } else if ((fetch.insn & MASK_LHU) == MATCH_LHU) {
- // equivalent to DO_LOAD(uint32_t, uint16_t, "lhu", "lhu")
- res = unpriv_mem_access(mstatus, mepc,
- "lbu %[val], 0(%[addr]);"
- "lbu %[tmp], 1(%[addr])",
- val, tmp, addr, mstatus/*X*/);
- val |= tmp << 8;
- } else {
- return -1;
- }
-
- if (res) {
- restore_mstatus(mstatus, mepc);
- return -1;
- }
+ else if ((insn & MASK_FLD) == MATCH_FLD)
+ fp = 1, len = 8;
+ else if ((insn & MASK_FLW) == MATCH_FLW)
+ fp = 1, len = 4;
+ else if ((insn & MASK_LH) == MATCH_LH)
+ len = 2, shift = 8*(sizeof(uintptr_t) - len);
+ else if ((insn & MASK_LHU) == MATCH_LHU)
+ len = 2;
+ else
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
- SET_RD(fetch.insn, regs, val);
+ uintptr_t addr = GET_RS1(insn, regs) + IMM_I(insn);
+ uintptr_t val = 0, tmp, tmp2;
+ unpriv_mem_access("add %[tmp2], %[addr], %[len];"
+ "1: slli %[val], %[val], 8;"
+ "lbu %[tmp], -1(%[tmp2]);"
+ "addi %[tmp2], %[tmp2], -1;"
+ "or %[val], %[val], %[tmp];"
+ "bne %[addr], %[tmp2], 1b;",
+ val, tmp, tmp2, addr, len);
+
+ if (shift)
+ val = (intptr_t)val << shift >> shift;
+
+ if (!fp)
+ SET_RD(insn, regs, val);
+ else if (len == 8)
+ SET_F64_RD(insn, regs, val);
+ else
+ SET_F32_RD(insn, regs, val);
-success:
write_csr(mepc, mepc + 4);
- return 0;
}
-uintptr_t misaligned_store_trap(uintptr_t mcause, uintptr_t* regs)
+void misaligned_store_trap(uintptr_t* regs, uintptr_t mcause, uintptr_t mepc)
{
uintptr_t mstatus = read_csr(mstatus);
- uintptr_t mepc = read_csr(mepc);
- insn_fetch_t fetch = get_insn(mcause, mstatus, mepc);
- if (fetch.error)
- return -1;
-
- uintptr_t addr = GET_RS1(fetch.insn, regs) + IMM_S(fetch.insn);
- uintptr_t val = GET_RS2(fetch.insn, regs), error;
-
- if ((fetch.insn & MASK_SW) == MATCH_SW) {
-SW:
- error = unpriv_mem_access(mstatus, mepc,
- "sb %[val], 0(%[addr]);"
- "srl %[val], %[val], 8; sb %[val], 1(%[addr]);"
- "srl %[val], %[val], 8; sb %[val], 2(%[addr]);"
- "srl %[val], %[val], 8; sb %[val], 3(%[addr]);",
- unused1, unused2, val, addr);
-#ifdef __riscv64
- } else if ((fetch.insn & MASK_SD) == MATCH_SD) {
-SD:
- error = unpriv_mem_access(mstatus, mepc,
- "sb %[val], 0(%[addr]);"
- "srl %[val], %[val], 8; sb %[val], 1(%[addr]);"
- "srl %[val], %[val], 8; sb %[val], 2(%[addr]);"
- "srl %[val], %[val], 8; sb %[val], 3(%[addr]);"
- "srl %[val], %[val], 8; sb %[val], 4(%[addr]);"
- "srl %[val], %[val], 8; sb %[val], 5(%[addr]);"
- "srl %[val], %[val], 8; sb %[val], 6(%[addr]);"
- "srl %[val], %[val], 8; sb %[val], 7(%[addr]);",
- unused1, unused2, val, addr);
-#endif
- } else if ((fetch.insn & MASK_SH) == MATCH_SH) {
- error = unpriv_mem_access(mstatus, mepc,
- "sb %[val], 0(%[addr]);"
- "srl %[val], %[val], 8; sb %[val], 1(%[addr]);",
- unused1, unused2, val, addr);
- } else if ((fetch.insn & MASK_FSD) == MATCH_FSD) {
-#ifdef __riscv64
- val = GET_F64_RS2(fetch.insn, regs);
- goto SD;
-#else
- uint64_t double_val = GET_F64_RS2(fetch.insn, regs);
- uint32_t val_lo = double_val, val_hi = double_val >> 32;
- error = unpriv_mem_access(mstatus, mepc,
- "sb %[val_lo], 0(%[addr]);"
- "srl %[val_lo], %[val_lo], 8; sb %[val_lo], 1(%[addr]);"
- "srl %[val_lo], %[val_lo], 8; sb %[val_lo], 2(%[addr]);"
- "srl %[val_lo], %[val_lo], 8; sb %[val_lo], 3(%[addr]);"
- "sb %[val_hi], 4(%[addr]);"
- "srl %[val_hi], %[val_hi], 8; sb %[val_hi], 5(%[addr]);"
- "srl %[val_hi], %[val_hi], 8; sb %[val_hi], 6(%[addr]);"
- "srl %[val_hi], %[val_hi], 8; sb %[val_hi], 7(%[addr]);",
- unused1, unused2, val_lo, val_hi, addr);
-#endif
- } else if ((fetch.insn & MASK_FSW) == MATCH_FSW) {
- val = GET_F32_RS2(fetch.insn, regs);
- goto SW;
- } else
- return -1;
-
- if (error) {
- restore_mstatus(mstatus, mepc);
- return -1;
- }
+ insn_t insn = get_insn(mepc);
+
+ uintptr_t val = GET_RS2(insn, regs), error;
+ int len;
+ if ((insn & MASK_SW) == MATCH_SW)
+ len = 4;
+ else if ((insn & MASK_SD) == MATCH_SD)
+ len = 8;
+ else if ((insn & MASK_FSD) == MATCH_FSD)
+ len = 8, val = GET_F64_RS2(insn, regs);
+ else if ((insn & MASK_FSW) == MATCH_FSW)
+ len = 4, val = GET_F32_RS2(insn, regs);
+ else if ((insn & MASK_SH) == MATCH_SH)
+ len = 2;
+ else
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
+
+ uintptr_t addr = GET_RS1(insn, regs) + IMM_S(insn);
+ uintptr_t tmp, tmp2, addr_end = addr + len;
+ unpriv_mem_access("mv %[tmp], %[val];"
+ "mv %[tmp2], %[addr];"
+ "1: sb %[tmp], 0(%[tmp2]);"
+ "srli %[tmp], %[tmp], 8;"
+ "addi %[tmp2], %[tmp2], 1;"
+ "bne %[tmp2], %[addr_end], 1b",
+ tmp, tmp2, unused1, val, addr, addr_end);
write_csr(mepc, mepc + 4);
- return 0;
}
DECLARE_EMULATION_FUNC(emulate_float_load)
{
- uintptr_t val_lo, val_hi, error;
+ uintptr_t val_lo, val_hi;
uint64_t val;
uintptr_t addr = GET_RS1(insn, regs) + IMM_I(insn);
@@ -178,51 +115,37 @@ DECLARE_EMULATION_FUNC(emulate_float_load)
{
case MATCH_FLW & MASK_FUNCT3:
if (addr % 4 != 0)
- return misaligned_load_trap(mcause, regs);
-
- error = unpriv_mem_access(mstatus, mepc,
- "lw %[val_lo], (%[addr])",
- val_lo, val_hi/*X*/, addr, mstatus/*X*/);
+ return misaligned_load_trap(regs, mcause, mepc);
- if (error == 0) {
- SET_F32_RD(insn, regs, val_lo);
- goto success;
- }
+ unpriv_mem_access("lw %[val_lo], (%[addr])",
+ val_lo, unused1, unused2, addr, mstatus/*X*/);
+ SET_F32_RD(insn, regs, val_lo);
break;
case MATCH_FLD & MASK_FUNCT3:
if (addr % sizeof(uintptr_t) != 0)
- return misaligned_load_trap(mcause, regs);
+ return misaligned_load_trap(regs, mcause, mepc);
+
#ifdef __riscv64
- error = unpriv_mem_access(mstatus, mepc,
- "ld %[val], (%[addr])",
- val, val_hi/*X*/, addr, mstatus/*X*/);
+ unpriv_mem_access("ld %[val], (%[addr])",
+ val, val_hi/*X*/, unused1, addr, mstatus/*X*/);
#else
- error = unpriv_mem_access(mstatus, mepc,
- "lw %[val_lo], (%[addr]);"
- "lw %[val_hi], 4(%[addr])",
- val_lo, val_hi, addr, mstatus/*X*/);
+ unpriv_mem_access("lw %[val_lo], (%[addr]);"
+ "lw %[val_hi], 4(%[addr])",
+ val_lo, val_hi, unused1, addr, mstatus/*X*/);
val = val_lo | ((uint64_t)val_hi << 32);
#endif
-
- if (error == 0) {
- SET_F64_RD(insn, regs, val);
- goto success;
- }
+ SET_F64_RD(insn, regs, val);
break;
- }
- restore_mstatus(mstatus, mepc);
- return -1;
-
-success:
- write_csr(mepc, mepc + 4);
- return 0;
+ default:
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
+ }
}
DECLARE_EMULATION_FUNC(emulate_float_store)
{
- uintptr_t val_lo, val_hi, error;
+ uintptr_t val_lo, val_hi;
uint64_t val;
uintptr_t addr = GET_RS1(insn, regs) + IMM_S(insn);
@@ -230,44 +153,33 @@ DECLARE_EMULATION_FUNC(emulate_float_store)
{
case MATCH_FSW & MASK_FUNCT3:
if (addr % 4 != 0)
- return misaligned_store_trap(mcause, regs);
+ return misaligned_store_trap(regs, mcause, mepc);
val_lo = GET_F32_RS2(insn, regs);
- error = unpriv_mem_access(mstatus, mepc,
- "sw %[val_lo], (%[addr])",
- unused1, unused2, val_lo, addr);
+ unpriv_mem_access("sw %[val_lo], (%[addr])",
+ unused1, unused2, unused3, val_lo, addr);
break;
case MATCH_FSD & MASK_FUNCT3:
if (addr % sizeof(uintptr_t) != 0)
- return misaligned_store_trap(mcause, regs);
+ return misaligned_store_trap(regs, mcause, mepc);
val = GET_F64_RS2(insn, regs);
#ifdef __riscv64
- error = unpriv_mem_access(mstatus, mepc,
- "sd %[val], (%[addr])",
- unused1, unused2, val, addr);
+ unpriv_mem_access("sd %[val], (%[addr])",
+ unused1, unused2, unused3, val, addr);
#else
val_lo = val;
val_hi = val >> 32;
- error = unpriv_mem_access(mstatus, mepc,
- "sw %[val_lo], (%[addr]);"
- "sw %[val_hi], 4(%[addr])",
- unused1, unused2, val_lo, val_hi, addr);
+ unpriv_mem_access("sw %[val_lo], (%[addr]);"
+ "sw %[val_hi], 4(%[addr])",
+ unused1, unused2, unused3, val_lo, val_hi, addr);
#endif
break;
default:
- error = 1;
- }
-
- if (error) {
- restore_mstatus(mstatus, mepc);
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
}
-
- write_csr(mepc, mepc + 4);
- return 0;
}
#ifdef __riscv64
@@ -298,17 +210,15 @@ DECLARE_EMULATION_FUNC(emulate_mul_div)
else if ((insn & MASK_MULHSU) == MATCH_MULHSU)
val = ((double_int)(intptr_t)rs1 * (double_int)rs2) >> (8 * sizeof(rs1));
else
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
SET_RD(insn, regs, val);
- write_csr(mepc, mepc + 4);
- return 0;
}
DECLARE_EMULATION_FUNC(emulate_mul_div32)
{
#ifndef __riscv64
- return truly_illegal_insn(mcause, regs, insn, mstatus, mepc);
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
#endif
uint32_t rs1 = GET_RS1(insn, regs), rs2 = GET_RS2(insn, regs);
@@ -326,14 +236,12 @@ DECLARE_EMULATION_FUNC(emulate_mul_div32)
else if ((insn & MASK_REMU) == MATCH_REMU)
val = rs1 % rs2;
else
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
SET_RD(insn, regs, val);
- write_csr(mepc, mepc + 4);
- return 0;
}
-static inline int emulate_read_csr(int num, uintptr_t* result, uintptr_t mstatus)
+static inline int emulate_read_csr(int num, uintptr_t mstatus, uintptr_t* result)
{
switch (num)
{
@@ -353,15 +261,14 @@ static inline int emulate_read_csr(int num, uintptr_t* result, uintptr_t mstatus
return -1;
}
-static inline int emulate_write_csr(int num, uintptr_t value, uintptr_t mstatus)
+static inline void emulate_write_csr(int num, uintptr_t value, uintptr_t mstatus)
{
switch (num)
{
- case CSR_FRM: SET_FRM(value); return 0;
- case CSR_FFLAGS: SET_FFLAGS(value); return 0;
- case CSR_FCSR: SET_FCSR(value); return 0;
+ case CSR_FRM: SET_FRM(value); return;
+ case CSR_FFLAGS: SET_FFLAGS(value); return;
+ case CSR_FCSR: SET_FCSR(value); return;
}
- return -1;
}
DECLARE_EMULATION_FUNC(emulate_system)
@@ -371,28 +278,26 @@ DECLARE_EMULATION_FUNC(emulate_system)
int csr_num = (uint32_t)insn >> 20;
uintptr_t csr_val, new_csr_val;
- if (emulate_read_csr(csr_num, &csr_val, mstatus) != 0)
- return -1;
+ if (emulate_read_csr(csr_num, mstatus, &csr_val))
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
int do_write = rs1_num;
switch (GET_RM(insn))
{
- case 0: return -1;
+ case 0: return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
case 1: new_csr_val = rs1_val; do_write = 1; break;
case 2: new_csr_val = csr_val | rs1_val; break;
case 3: new_csr_val = csr_val & ~rs1_val; break;
- case 4: return -1;
+ case 4: return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
case 5: new_csr_val = rs1_num; do_write = 1; break;
case 6: new_csr_val = csr_val | rs1_num; break;
case 7: new_csr_val = csr_val & ~rs1_num; break;
}
- if (do_write && emulate_write_csr(csr_num, new_csr_val, mstatus) != 0)
- return -1;
+ if (do_write)
+ emulate_write_csr(csr_num, new_csr_val, mstatus);
SET_RD(insn, regs, csr_val);
- write_csr(mepc, mepc + 4);
- return 0;
}
DECLARE_EMULATION_FUNC(emulate_fp)
@@ -435,17 +340,17 @@ DECLARE_EMULATION_FUNC(emulate_fp)
// if FPU is disabled, punt back to the OS
if (unlikely((mstatus & MSTATUS_FS) == 0))
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
extern int32_t fp_emulation_table[];
int32_t* pf = (void*)fp_emulation_table + ((insn >> 25) & 0x7c);
emulation_func f = (emulation_func)(uintptr_t)*pf;
SETUP_STATIC_ROUNDING(insn);
- return f(mcause, regs, insn, mstatus, mepc);
+ return f(regs, mcause, mepc, mstatus, insn);
}
-uintptr_t emulate_any_fadd(uintptr_t mcause, uintptr_t* regs, insn_t insn, uintptr_t mstatus, uintptr_t mepc, uintptr_t neg_b)
+void emulate_any_fadd(uintptr_t* regs, uintptr_t mcause, uintptr_t mepc, uintptr_t mstatus, insn_t insn, int32_t neg_b)
{
if (GET_PRECISION(insn) == PRECISION_S) {
uint32_t rs1 = GET_F32_RS1(insn, regs);
@@ -456,21 +361,18 @@ uintptr_t emulate_any_fadd(uintptr_t mcause, uintptr_t* regs, insn_t insn, uintp
uint64_t rs2 = GET_F64_RS2(insn, regs) ^ ((uint64_t)neg_b << 32);
SET_F64_RD(insn, regs, f64_add(rs1, rs2));
} else {
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
}
-
- write_csr(mepc, mepc + 4);
- return 0;
}
DECLARE_EMULATION_FUNC(emulate_fadd)
{
- return emulate_any_fadd(mcause, regs, insn, mstatus, mepc, 0);
+ return emulate_any_fadd(regs, mcause, mepc, mstatus, insn, 0);
}
DECLARE_EMULATION_FUNC(emulate_fsub)
{
- return emulate_any_fadd(mcause, regs, insn, mstatus, mepc, INT32_MIN);
+ return emulate_any_fadd(regs, mcause, mepc, mstatus, insn, INT32_MIN);
}
DECLARE_EMULATION_FUNC(emulate_fmul)
@@ -484,11 +386,8 @@ DECLARE_EMULATION_FUNC(emulate_fmul)
uint64_t rs2 = GET_F64_RS2(insn, regs);
SET_F64_RD(insn, regs, f64_mul(rs1, rs2));
} else {
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
}
-
- write_csr(mepc, mepc + 4);
- return 0;
}
DECLARE_EMULATION_FUNC(emulate_fdiv)
@@ -502,35 +401,29 @@ DECLARE_EMULATION_FUNC(emulate_fdiv)
uint64_t rs2 = GET_F64_RS2(insn, regs);
SET_F64_RD(insn, regs, f64_div(rs1, rs2));
} else {
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
}
-
- write_csr(mepc, mepc + 4);
- return 0;
}
DECLARE_EMULATION_FUNC(emulate_fsqrt)
{
if ((insn >> 20) & 0x1f)
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
if (GET_PRECISION(insn) == PRECISION_S) {
SET_F32_RD(insn, regs, f32_sqrt(GET_F32_RS1(insn, regs)));
} else if (GET_PRECISION(insn) == PRECISION_D) {
SET_F64_RD(insn, regs, f64_sqrt(GET_F64_RS1(insn, regs)));
} else {
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
}
-
- write_csr(mepc, mepc + 4);
- return 0;
}
DECLARE_EMULATION_FUNC(emulate_fsgnj)
{
int rm = GET_RM(insn);
if (rm >= 3)
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
#define DO_FSGNJ(rs1, rs2, rm) ({ \
typeof(rs1) rs1_sign = (rs1) >> (8*sizeof(rs1)-1); \
@@ -548,18 +441,15 @@ DECLARE_EMULATION_FUNC(emulate_fsgnj)
uint64_t rs2 = GET_F64_RS2(insn, regs);
SET_F64_RD(insn, regs, DO_FSGNJ(rs1, rs2, rm));
} else {
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
}
-
- write_csr(mepc, mepc + 4);
- return 0;
}
DECLARE_EMULATION_FUNC(emulate_fmin)
{
int rm = GET_RM(insn);
if (rm >= 2)
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
if (GET_PRECISION(insn) == PRECISION_S) {
uint32_t rs1 = GET_F32_RS1(insn, regs);
@@ -576,11 +466,8 @@ DECLARE_EMULATION_FUNC(emulate_fmin)
int use_rs1 = f64_lt_quiet(arg1, arg2) || isNaNF64UI(rs2);
SET_F64_RD(insn, regs, use_rs1 ? rs1 : rs2);
} else {
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
}
-
- write_csr(mepc, mepc + 4);
- return 0;
}
DECLARE_EMULATION_FUNC(emulate_fcvt_ff)
@@ -588,24 +475,21 @@ DECLARE_EMULATION_FUNC(emulate_fcvt_ff)
int rs2_num = (insn >> 20) & 0x1f;
if (GET_PRECISION(insn) == PRECISION_S) {
if (rs2_num != 1)
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
SET_F32_RD(insn, regs, f64_to_f32(GET_F64_RS1(insn, regs)));
} else if (GET_PRECISION(insn) == PRECISION_D) {
if (rs2_num != 0)
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
SET_F64_RD(insn, regs, f32_to_f64(GET_F32_RS1(insn, regs)));
} else {
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
}
-
- write_csr(mepc, mepc + 4);
- return 0;
}
DECLARE_EMULATION_FUNC(emulate_fcvt_fi)
{
if (GET_PRECISION(insn) != PRECISION_S && GET_PRECISION(insn) != PRECISION_D)
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
int negative = 0;
uint64_t uint_val = GET_RS1(insn, regs);
@@ -627,7 +511,7 @@ DECLARE_EMULATION_FUNC(emulate_fcvt_fi)
break;
#endif
default:
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
}
uint64_t float64 = ui64_to_f64(uint_val);
@@ -638,9 +522,6 @@ DECLARE_EMULATION_FUNC(emulate_fcvt_fi)
SET_F32_RD(insn, regs, f64_to_f32(float64));
else
SET_F64_RD(insn, regs, float64);
-
- write_csr(mepc, mepc + 4);
- return 0;
}
DECLARE_EMULATION_FUNC(emulate_fcvt_if)
@@ -648,10 +529,10 @@ DECLARE_EMULATION_FUNC(emulate_fcvt_if)
int rs2_num = (insn >> 20) & 0x1f;
#ifdef __riscv64
if (rs2_num >= 4)
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
#else
if (rs2_num >= 2)
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
#endif
int64_t float64;
@@ -660,7 +541,7 @@ DECLARE_EMULATION_FUNC(emulate_fcvt_if)
else if (GET_PRECISION(insn) == PRECISION_D)
float64 = GET_F64_RS1(insn, regs);
else
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
int negative = 0;
if (float64 < 0) {
@@ -716,16 +597,13 @@ DECLARE_EMULATION_FUNC(emulate_fcvt_if)
SET_FS_DIRTY();
SET_RD(insn, regs, result);
-
- write_csr(mepc, mepc + 4);
- return 0;
}
DECLARE_EMULATION_FUNC(emulate_fcmp)
{
int rm = GET_RM(insn);
if (rm >= 3)
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
uintptr_t result;
if (GET_PRECISION(insn) == PRECISION_S) {
@@ -745,11 +623,9 @@ DECLARE_EMULATION_FUNC(emulate_fcmp)
result = f64_lt(rs1, rs2);
goto success;
}
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
success:
SET_RD(insn, regs, result);
- write_csr(mepc, mepc + 4);
- return 0;
}
DECLARE_EMULATION_FUNC(emulate_fmv_if)
@@ -762,11 +638,9 @@ DECLARE_EMULATION_FUNC(emulate_fmv_if)
result = GET_F64_RS1(insn, regs);
#endif
else
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
SET_RD(insn, regs, result);
- write_csr(mepc, mepc + 4);
- return 0;
}
DECLARE_EMULATION_FUNC(emulate_fmv_fi)
@@ -778,18 +652,16 @@ DECLARE_EMULATION_FUNC(emulate_fmv_fi)
else if ((insn & MASK_FMV_D_X) == MATCH_FMV_D_X)
SET_F64_RD(insn, regs, rs1);
else
- return -1;
-
- write_csr(mepc, mepc + 4);
- return 0;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
}
-uintptr_t emulate_any_fmadd(int op, uintptr_t* regs, insn_t insn, uintptr_t mstatus, uintptr_t mepc)
+DECLARE_EMULATION_FUNC(emulate_fmadd)
{
// if FPU is disabled, punt back to the OS
if (unlikely((mstatus & MSTATUS_FS) == 0))
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
+ int op = (insn >> 2) & 3;
SETUP_STATIC_ROUNDING(insn);
if (GET_PRECISION(insn) == PRECISION_S) {
uint32_t rs1 = GET_F32_RS1(insn, regs);
@@ -802,32 +674,6 @@ uintptr_t emulate_any_fmadd(int op, uintptr_t* regs, insn_t insn, uintptr_t msta
uint64_t rs3 = GET_F64_RS3(insn, regs);
SET_F64_RD(insn, regs, softfloat_mulAddF64(op, rs1, rs2, rs3));
} else {
- return -1;
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
}
- write_csr(mepc, mepc + 4);
- return 0;
-}
-
-DECLARE_EMULATION_FUNC(emulate_fmadd)
-{
- int op = 0;
- return emulate_any_fmadd(op, regs, insn, mstatus, mepc);
-}
-
-DECLARE_EMULATION_FUNC(emulate_fmsub)
-{
- int op = softfloat_mulAdd_subC;
- return emulate_any_fmadd(op, regs, insn, mstatus, mepc);
-}
-
-DECLARE_EMULATION_FUNC(emulate_fnmadd)
-{
- int op = softfloat_mulAdd_subC | softfloat_mulAdd_subProd;
- return emulate_any_fmadd(op, regs, insn, mstatus, mepc);
-}
-
-DECLARE_EMULATION_FUNC(emulate_fnmsub)
-{
- int op = softfloat_mulAdd_subProd;
- return emulate_any_fmadd(op, regs, insn, mstatus, mepc);
}
diff --git a/pk/encoding.h b/pk/encoding.h
index e9a495f..df04845 100644
--- a/pk/encoding.h
+++ b/pk/encoding.h
@@ -3,37 +3,41 @@
#ifndef RISCV_CSR_ENCODING_H
#define RISCV_CSR_ENCODING_H
-#define MSTATUS_IE 0x00000001
-#define MSTATUS_PRV 0x00000006
-#define MSTATUS_IE1 0x00000008
-#define MSTATUS_PRV1 0x00000030
-#define MSTATUS_IE2 0x00000040
-#define MSTATUS_PRV2 0x00000180
-#define MSTATUS_IE3 0x00000200
-#define MSTATUS_PRV3 0x00000C00
-#define MSTATUS_FS 0x00003000
-#define MSTATUS_XS 0x0000C000
-#define MSTATUS_MPRV 0x00010000
-#define MSTATUS_VM 0x003E0000
+#define MSTATUS_UIE 0x00000001
+#define MSTATUS_SIE 0x00000002
+#define MSTATUS_HIE 0x00000004
+#define MSTATUS_MIE 0x00000008
+#define MSTATUS_UPIE 0x00000010
+#define MSTATUS_SPIE 0x00000020
+#define MSTATUS_HPIE 0x00000040
+#define MSTATUS_MPIE 0x00000080
+#define MSTATUS_SPP 0x00000100
+#define MSTATUS_HPP 0x00000600
+#define MSTATUS_MPP 0x00001800
+#define MSTATUS_FS 0x00006000
+#define MSTATUS_XS 0x00018000
+#define MSTATUS_MPRV 0x00020000
+#define MSTATUS_VM 0x007C0000
#define MSTATUS32_SD 0x80000000
#define MSTATUS64_SD 0x8000000000000000
-#define SSTATUS_IE 0x00000001
-#define SSTATUS_PIE 0x00000008
-#define SSTATUS_PS 0x00000010
-#define SSTATUS_FS 0x00003000
-#define SSTATUS_XS 0x0000C000
-#define SSTATUS_MPRV 0x00010000
-#define SSTATUS_TIE 0x01000000
+#define SSTATUS_UIE 0x00000001
+#define SSTATUS_SIE 0x00000002
+#define SSTATUS_UPIE 0x00000010
+#define SSTATUS_SPIE 0x00000020
+#define SSTATUS_SPP 0x00000100
+#define SSTATUS_FS 0x00006000
+#define SSTATUS_XS 0x00018000
+#define SSTATUS_VM 0x007C0000
#define SSTATUS32_SD 0x80000000
#define SSTATUS64_SD 0x8000000000000000
-#define MIP_SSIP 0x00000002
-#define MIP_HSIP 0x00000004
-#define MIP_MSIP 0x00000008
-#define MIP_STIP 0x00000020
-#define MIP_HTIP 0x00000040
-#define MIP_MTIP 0x00000080
+#define MIP_SSIP (1 << IRQ_S_SOFT)
+#define MIP_HSIP (1 << IRQ_H_SOFT)
+#define MIP_MSIP (1 << IRQ_M_SOFT)
+#define MIP_STIP (1 << IRQ_S_TIMER)
+#define MIP_HTIP (1 << IRQ_H_TIMER)
+#define MIP_MTIP (1 << IRQ_M_TIMER)
#define SIP_SSIP MIP_SSIP
#define SIP_STIP MIP_STIP
@@ -50,14 +54,14 @@
#define VM_SV39 9
#define VM_SV48 10
-#define UA_RV32 0
-#define UA_RV64 4
-#define UA_RV128 8
-
-#define IRQ_SOFT 0
-#define IRQ_TIMER 1
-#define IRQ_HOST 2
-#define IRQ_COP 3
+#define IRQ_S_SOFT 1
+#define IRQ_H_SOFT 2
+#define IRQ_M_SOFT 3
+#define IRQ_S_TIMER 5
+#define IRQ_H_TIMER 6
+#define IRQ_M_TIMER 7
+#define IRQ_COP 8
+#define IRQ_HOST 9
#define IMPL_ROCKET 1
@@ -335,18 +339,12 @@
#define MASK_SCALL 0xffffffff
#define MATCH_SBREAK 0x100073
#define MASK_SBREAK 0xffffffff
-#define MATCH_SRET 0x10000073
+#define MATCH_SRET 0x10200073
#define MASK_SRET 0xffffffff
-#define MATCH_SFENCE_VM 0x10100073
+#define MATCH_SFENCE_VM 0x10400073
#define MASK_SFENCE_VM 0xfff07fff
-#define MATCH_WFI 0x10200073
+#define MATCH_WFI 0x10500073
#define MASK_WFI 0xffffffff
-#define MATCH_MRTH 0x30600073
-#define MASK_MRTH 0xffffffff
-#define MATCH_MRTS 0x30500073
-#define MASK_MRTS 0xffffffff
-#define MATCH_HRTS 0x20500073
-#define MASK_HRTS 0xffffffff
#define MATCH_CSRRW 0x1073
#define MASK_CSRRW 0x707f
#define MATCH_CSRRS 0x2073
@@ -643,6 +641,8 @@
#define CSR_SIE 0x104
#define CSR_SSCRATCH 0x140
#define CSR_SEPC 0x141
+#define CSR_SCAUSE 0x142
+#define CSR_SBADADDR 0x143
#define CSR_SIP 0x144
#define CSR_SPTBR 0x180
#define CSR_SASID 0x181
@@ -650,12 +650,11 @@
#define CSR_TIMEW 0x901
#define CSR_INSTRETW 0x902
#define CSR_STIME 0xd01
-#define CSR_SCAUSE 0xd42
-#define CSR_SBADADDR 0xd43
#define CSR_STIMEW 0xa01
#define CSR_MSTATUS 0x300
#define CSR_MTVEC 0x301
-#define CSR_MTDELEG 0x302
+#define CSR_MEDELEG 0x302
+#define CSR_MIDELEG 0x303
#define CSR_MIE 0x304
#define CSR_MTIMECMP 0x321
#define CSR_MSCRATCH 0x340
@@ -787,9 +786,6 @@ DECLARE_INSN(sbreak, MATCH_SBREAK, MASK_SBREAK)
DECLARE_INSN(sret, MATCH_SRET, MASK_SRET)
DECLARE_INSN(sfence_vm, MATCH_SFENCE_VM, MASK_SFENCE_VM)
DECLARE_INSN(wfi, MATCH_WFI, MASK_WFI)
-DECLARE_INSN(mrth, MATCH_MRTH, MASK_MRTH)
-DECLARE_INSN(mrts, MATCH_MRTS, MASK_MRTS)
-DECLARE_INSN(hrts, MATCH_HRTS, MASK_HRTS)
DECLARE_INSN(csrrw, MATCH_CSRRW, MASK_CSRRW)
DECLARE_INSN(csrrs, MATCH_CSRRS, MASK_CSRRS)
DECLARE_INSN(csrrc, MATCH_CSRRC, MASK_CSRRC)
@@ -954,6 +950,8 @@ DECLARE_CSR(stvec, CSR_STVEC)
DECLARE_CSR(sie, CSR_SIE)
DECLARE_CSR(sscratch, CSR_SSCRATCH)
DECLARE_CSR(sepc, CSR_SEPC)
+DECLARE_CSR(scause, CSR_SCAUSE)
+DECLARE_CSR(sbadaddr, CSR_SBADADDR)
DECLARE_CSR(sip, CSR_SIP)
DECLARE_CSR(sptbr, CSR_SPTBR)
DECLARE_CSR(sasid, CSR_SASID)
@@ -961,12 +959,11 @@ DECLARE_CSR(cyclew, CSR_CYCLEW)
DECLARE_CSR(timew, CSR_TIMEW)
DECLARE_CSR(instretw, CSR_INSTRETW)
DECLARE_CSR(stime, CSR_STIME)
-DECLARE_CSR(scause, CSR_SCAUSE)
-DECLARE_CSR(sbadaddr, CSR_SBADADDR)
DECLARE_CSR(stimew, CSR_STIMEW)
DECLARE_CSR(mstatus, CSR_MSTATUS)
DECLARE_CSR(mtvec, CSR_MTVEC)
-DECLARE_CSR(mtdeleg, CSR_MTDELEG)
+DECLARE_CSR(medeleg, CSR_MEDELEG)
+DECLARE_CSR(mideleg, CSR_MIDELEG)
DECLARE_CSR(mie, CSR_MIE)
DECLARE_CSR(mtimecmp, CSR_MTIMECMP)
DECLARE_CSR(mscratch, CSR_MSCRATCH)
diff --git a/pk/entry.S b/pk/entry.S
index 70ea2f1..d5fe55f 100644
--- a/pk/entry.S
+++ b/pk/entry.S
@@ -66,8 +66,8 @@ trap_entry:
jal handle_trap
mv a0,sp
- # don't restore sstatus if trap came from kernel
- andi s0,s0,SSTATUS_PS
+ # don't restore sscratch if trap came from kernel
+ andi s0,s0,SSTATUS_SPP
bnez s0,start_user
addi sp,sp,320
csrw sscratch,sp
diff --git a/pk/handlers.c b/pk/handlers.c
index 881cc17..34e39fe 100644
--- a/pk/handlers.c
+++ b/pk/handlers.c
@@ -47,7 +47,7 @@ void handle_misaligned_store(trapframe_t* tf)
static void segfault(trapframe_t* tf, uintptr_t addr, const char* type)
{
dump_tf(tf);
- const char* who = (tf->status & MSTATUS_PRV1) ? "Kernel" : "User";
+ const char* who = (tf->status & SSTATUS_SPP) ? "Kernel" : "User";
panic("%s %s segfault @ %p", who, type, addr);
}
diff --git a/pk/init.c b/pk/init.c
index e9f195f..8010629 100644
--- a/pk/init.c
+++ b/pk/init.c
@@ -19,7 +19,7 @@ char* uarch_counter_names[NUM_COUNTERS];
void init_tf(trapframe_t* tf, long pc, long sp)
{
memset(tf, 0, sizeof(*tf));
- tf->status = read_csr(sstatus);
+ tf->status = (read_csr(sstatus) &~ SSTATUS_SPP &~ SSTATUS_SIE) | SSTATUS_SPIE;
tf->gpr[2] = sp;
tf->epc = pc;
}
@@ -68,3 +68,11 @@ void boot_loader(struct mainvars* args)
run_loaded_program(args);
}
+
+void prepare_supervisor_mode()
+{
+ uintptr_t mstatus = read_csr(mstatus);
+ mstatus = INSERT_FIELD(mstatus, MSTATUS_MPP, PRV_S);
+ mstatus = INSERT_FIELD(mstatus, MSTATUS_MPIE, 0);
+ write_csr(mstatus, mstatus);
+}
diff --git a/pk/mentry.S b/pk/mentry.S
index 7f0fc61..baf35ec 100644
--- a/pk/mentry.S
+++ b/pk/mentry.S
@@ -24,30 +24,6 @@ trap_table:
.word bad_trap
.word bad_trap
-#define HANDLE_USER_TRAP_IN_MACHINE_MODE 0 \
- | (0 << (31- 0)) /* IF misaligned */ \
- | (0 << (31- 1)) /* IF fault */ \
- | (1 << (31- 2)) /* illegal instruction */ \
- | (0 << (31- 3)) /* breakpoint */ \
- | (1 << (31- 4)) /* load misaligned */ \
- | (0 << (31- 5)) /* load fault */ \
- | (1 << (31- 6)) /* store misaligned */ \
- | (0 << (31- 7)) /* store fault */ \
- | (0 << (31- 8)) /* user environment call */ \
- | (0 << (31- 9)) /* super environment call */ \
-
-#define HANDLE_SUPERVISOR_TRAP_IN_MACHINE_MODE 0 \
- | (0 << (31- 0)) /* IF misaligned */ \
- | (0 << (31- 1)) /* IF fault */ \
- | (1 << (31- 2)) /* illegal instruction */ \
- | (0 << (31- 3)) /* breakpoint */ \
- | (1 << (31- 4)) /* load misaligned */ \
- | (0 << (31- 5)) /* load fault */ \
- | (1 << (31- 6)) /* store misaligned */ \
- | (0 << (31- 7)) /* store fault */ \
- | (0 << (31- 8)) /* user environment call */ \
- | (1 << (31- 9)) /* super environment call */ \
-
.option norvc
.section .text.init,"ax",@progbits
.globl mentry
@@ -58,19 +34,9 @@ mentry:
STORE a0, 10*REGBYTES(sp)
STORE a1, 11*REGBYTES(sp)
- csrr a0, mcause
- bltz a0, .Linterrupt
-
- li a1, HANDLE_USER_TRAP_IN_MACHINE_MODE
- SLL32 a1, a1, a0
- bltz a1, .Lhandle_trap_in_machine_mode
-
- # Redirect the trap to the supervisor.
-.Lmrts:
- LOAD a0, 10*REGBYTES(sp)
- LOAD a1, 11*REGBYTES(sp)
- csrrw sp, mscratch, sp
- mrts
+ csrr a1, mcause
+ bltz a1, .Linterrupt
+ j .Lhandle_trap_in_machine_mode
.align 6
# Entry point from supervisor mode (mtvec + 0x040)
@@ -78,25 +44,9 @@ mentry:
STORE a0, 10*REGBYTES(sp)
STORE a1, 11*REGBYTES(sp)
- csrr a0, mcause
- bltz a0, .Linterrupt
-
- li a1, HANDLE_SUPERVISOR_TRAP_IN_MACHINE_MODE
- SLL32 a1, a1, a0
- bltz a1, .Lhandle_trap_in_machine_mode
-
-.Linterrupt_in_supervisor:
- # Detect double faults.
- csrr a0, mstatus
- SLL32 a0, a0, 31 - CONST_CTZ(MSTATUS_PRV2)
- bltz a0, .Lsupervisor_double_fault
-
-.Lreturn_from_supervisor_double_fault:
- # Redirect the trap to the supervisor.
- LOAD a0, 10*REGBYTES(sp)
- LOAD a1, 11*REGBYTES(sp)
- csrrw sp, mscratch, sp
- mrts
+ csrr a1, mcause
+ bltz a1, .Linterrupt
+ j .Lhandle_trap_in_machine_mode
.align 6
# Entry point from hypervisor mode (mtvec + 0x080)
@@ -110,16 +60,12 @@ mentry:
addi sp, sp, -INTEGER_CONTEXT_SIZE
STORE a0,10*REGBYTES(sp)
STORE a1,11*REGBYTES(sp)
- li a0, TRAP_FROM_MACHINE_MODE_VECTOR
+ li a1, TRAP_FROM_MACHINE_MODE_VECTOR
j .Lhandle_trap_in_machine_mode
-.Lsupervisor_double_fault:
- # Return to supervisor trap entry with interrupts disabled.
- # Set PRV2=U, IE2=1, PRV1=S (it already is), and IE1=0.
- li a0, MSTATUS_PRV2 | MSTATUS_IE2 | MSTATUS_IE1
- csrc mstatus, a0
- j .Lreturn_from_supervisor_double_fault
-
+ nop
+ nop
+ nop
nop
nop
nop
@@ -198,26 +144,18 @@ mentry:
j init_other_hart
.Linterrupt:
- sll a0, a0, 1 # discard MSB
+ sll a1, a1, 1 # discard MSB
# See if this is a timer interrupt; post a supervisor interrupt if so.
- li a1, IRQ_TIMER * 2
+ li a0, IRQ_M_TIMER * 2
bne a0, a1, 1f
li a0, MIP_MTIP
- csrc mip, a0
- li a1, MIP_STIP
csrc mie, a0
- csrs mip, a1
+ li a0, MIP_STIP
+ csrs mip, a0
-.Linterrupt_supervisor:
- # There are three cases: PRV1=U; PRV1=S and IE1=1; and PRV1=S and IE1=0.
- # For cases 1-2, do an MRTS; for case 3, we can't, so ERET.
- csrr a0, mstatus
- li a1, (MSTATUS_PRV1 & ~(MSTATUS_PRV1<<1)) * PRV_S
- and a0, a0, MSTATUS_PRV1 | MSTATUS_IE1
- bne a0, a1, .Lmrts
-
- # And then go back whence we came.
+.Leret:
+ # Go back whence we came.
LOAD a0, 10*REGBYTES(sp)
LOAD a1, 11*REGBYTES(sp)
csrrw sp, mscratch, sp
@@ -225,19 +163,17 @@ mentry:
1:
# See if this is an IPI; register a supervisor SW interrupt if so.
-#if IRQ_SOFT != 0
-#error
-#endif
- bnez a0, 1f
+ li a0, IRQ_M_SOFT * 2
+ bne a0, a1, 1f
csrc mip, MIP_MSIP
csrs mip, MIP_SSIP
- j .Linterrupt_supervisor
+ j .Leret
1:
# See if this is an HTIF interrupt; if so, handle it in machine mode.
- li a1, IRQ_HOST * 2
+ li a0, IRQ_HOST * 2
bne a0, a1, .Lbad_trap
- li a0, HTIF_INTERRUPT_VECTOR
+ li a1, HTIF_INTERRUPT_VECTOR
.Lhandle_trap_in_machine_mode:
# Preserve the registers. Compute the address of the trap handler.
@@ -248,14 +184,15 @@ mentry:
STORE t0, 5*REGBYTES(sp)
1:auipc t0, %pcrel_hi(trap_table) # t0 <- %hi(trap_table)
STORE t1, 6*REGBYTES(sp)
- sll t1, a0, 2 # t1 <- mcause << 2
+ sll t1, a1, 2 # t1 <- mcause << 2
STORE t2, 7*REGBYTES(sp)
add t0, t0, t1 # t0 <- %hi(trap_table)[mcause]
STORE s0, 8*REGBYTES(sp)
lw t0, %pcrel_lo(1b)(t0) # t0 <- trap_table[mcause]
STORE s1, 9*REGBYTES(sp)
- mv a1, sp # a1 <- regs
+ mv a0, sp # a0 <- regs
STORE a2,12*REGBYTES(sp)
+ csrr a2, mepc # a2 <- mepc
STORE a3,13*REGBYTES(sp)
STORE a4,14*REGBYTES(sp)
STORE a5,15*REGBYTES(sp)
@@ -289,6 +226,7 @@ mentry:
sw tp, (sp) # Move the emulated FCSR from tp into x0's save slot.
#endif
+restore_regs:
# Restore all of the registers.
LOAD ra, 1*REGBYTES(sp)
LOAD gp, 3*REGBYTES(sp)
@@ -298,6 +236,7 @@ mentry:
LOAD t2, 7*REGBYTES(sp)
LOAD s0, 8*REGBYTES(sp)
LOAD s1, 9*REGBYTES(sp)
+ LOAD a0,10*REGBYTES(sp)
LOAD a1,11*REGBYTES(sp)
LOAD a2,12*REGBYTES(sp)
LOAD a3,13*REGBYTES(sp)
@@ -319,18 +258,13 @@ mentry:
LOAD t4,29*REGBYTES(sp)
LOAD t5,30*REGBYTES(sp)
LOAD t6,31*REGBYTES(sp)
-
- bnez a0, 1f
-
- # Go back whence we came.
- LOAD a0, 10*REGBYTES(sp)
LOAD sp, 2*REGBYTES(sp)
eret
-1:# Redirect the trap to the supervisor.
- LOAD a0, 10*REGBYTES(sp)
- LOAD sp, 2*REGBYTES(sp)
- mrts
+.globl leave
+leave:
+ csrr sp, mscratch
+ j restore_regs
.Lbad_trap:
j bad_trap
diff --git a/pk/minit.c b/pk/minit.c
index 91e2c9c..80380f8 100644
--- a/pk/minit.c
+++ b/pk/minit.c
@@ -12,13 +12,8 @@ static void mstatus_init()
panic("supervisor support is required");
uintptr_t ms = 0;
- ms = INSERT_FIELD(ms, MSTATUS_PRV, PRV_M);
- ms = INSERT_FIELD(ms, MSTATUS_PRV1, PRV_S);
- ms = INSERT_FIELD(ms, MSTATUS_PRV2, PRV_U);
- ms = INSERT_FIELD(ms, MSTATUS_IE2, 1);
ms = INSERT_FIELD(ms, MSTATUS_VM, VM_CHOICE);
- ms = INSERT_FIELD(ms, MSTATUS_FS, 3);
- ms = INSERT_FIELD(ms, MSTATUS_XS, 3);
+ ms = INSERT_FIELD(ms, MSTATUS_FS, 1);
write_csr(mstatus, ms);
ms = read_csr(mstatus);
@@ -27,7 +22,24 @@ static void mstatus_init()
write_csr(mtimecmp, 0);
clear_csr(mip, MIP_MSIP);
- set_csr(mie, MIP_MSIP);
+ write_csr(mie, -1);
+}
+
+static void delegate_traps()
+{
+ uintptr_t interrupts = MIP_SSIP | MIP_STIP;
+ uintptr_t exceptions =
+ (1U << CAUSE_MISALIGNED_FETCH) |
+ (1U << CAUSE_FAULT_FETCH | CAUSE_BREAKPOINT) |
+ (1U << CAUSE_FAULT_LOAD) |
+ (1U << CAUSE_FAULT_STORE) |
+ (1U << CAUSE_BREAKPOINT) |
+ (1U << CAUSE_USER_ECALL);
+
+ write_csr(mideleg, interrupts);
+ write_csr(medeleg, exceptions);
+ kassert(read_csr(mideleg) == interrupts);
+ kassert(read_csr(medeleg) == exceptions);
}
static void memory_init()
@@ -71,18 +83,19 @@ void hls_init(uint32_t id, uintptr_t* csrs)
}
}
-static void init_hart()
+static void hart_init()
{
mstatus_init();
fp_init();
+ delegate_traps();
}
void init_first_hart()
{
- init_hart();
+ file_init();
+ hart_init();
memset(HLS(), 0, sizeof(*HLS()));
- file_init();
parse_device_tree();
struct mainvars arg_buffer;
@@ -95,7 +108,7 @@ void init_first_hart()
void init_other_hart()
{
- init_hart();
+ hart_init();
// wait until virtual memory is enabled
while (*(pte_t* volatile*)&root_page_table == NULL)
diff --git a/pk/mtrap.c b/pk/mtrap.c
index 478710e..7bc6e7c 100644
--- a/pk/mtrap.c
+++ b/pk/mtrap.c
@@ -4,7 +4,7 @@
#include "vm.h"
#include <errno.h>
-uintptr_t illegal_insn_trap(uintptr_t mcause, uintptr_t* regs)
+void illegal_insn_trap(uintptr_t* regs, uintptr_t mcause, uintptr_t mepc)
{
asm (".pushsection .rodata\n"
"illegal_insn_trap_table:\n"
@@ -34,9 +34,9 @@ uintptr_t illegal_insn_trap(uintptr_t mcause, uintptr_t* regs)
" .word truly_illegal_insn\n"
#ifdef PK_ENABLE_FP_EMULATION
" .word emulate_fmadd\n"
- " .word emulate_fmsub\n"
- " .word emulate_fnmsub\n"
- " .word emulate_fnmadd\n"
+ " .word emulate_fmadd\n"
+ " .word emulate_fmadd\n"
+ " .word emulate_fmadd\n"
" .word emulate_fp\n"
#else
" .word truly_illegal_insn\n"
@@ -63,17 +63,17 @@ uintptr_t illegal_insn_trap(uintptr_t mcause, uintptr_t* regs)
" .popsection");
uintptr_t mstatus = read_csr(mstatus);
- uintptr_t mepc = read_csr(mepc);
- insn_fetch_t fetch = get_insn(mcause, mstatus, mepc);
+ insn_t insn = get_insn(mepc);
- if (fetch.error || (fetch.insn & 3) != 3)
- return -1;
+ if ((insn & 3) != 3)
+ return truly_illegal_insn(regs, mcause, mepc, mstatus, insn);
+ write_csr(mepc, mepc + 4);
extern int32_t illegal_insn_trap_table[];
- int32_t* pf = (void*)illegal_insn_trap_table + (fetch.insn & 0x7c);
+ int32_t* pf = (void*)illegal_insn_trap_table + (insn & 0x7c);
emulation_func f = (emulation_func)(uintptr_t)*pf;
- return f(mcause, regs, fetch.insn, mstatus, mepc);
+ f(regs, mcause, mepc, mstatus, insn);
}
void __attribute__((noreturn)) bad_trap()
@@ -81,11 +81,11 @@ void __attribute__((noreturn)) bad_trap()
panic("machine mode: unhandlable trap %d @ %p", read_csr(mcause), read_csr(mepc));
}
-uintptr_t htif_interrupt(uintptr_t mcause, uintptr_t* regs)
+void htif_interrupt()
{
uintptr_t fromhost = swap_csr(mfromhost, 0);
if (!fromhost)
- return 0;
+ return;
uintptr_t dev = FROMHOST_DEV(fromhost);
uintptr_t cmd = FROMHOST_CMD(fromhost);
@@ -95,7 +95,7 @@ uintptr_t htif_interrupt(uintptr_t mcause, uintptr_t* regs)
sbi_device_message* prev = NULL;
for (size_t i = 0, n = HLS()->device_request_queue_size; i < n; i++) {
if (!supervisor_paddr_valid(m, sizeof(*m))
- && EXTRACT_FIELD(read_csr(mstatus), MSTATUS_PRV1) != PRV_M)
+ && EXTRACT_FIELD(read_csr(mstatus), MSTATUS_MPP) != PRV_M)
panic("htif: page fault");
sbi_device_message* next = (void*)m->sbi_private_data;
@@ -119,7 +119,7 @@ uintptr_t htif_interrupt(uintptr_t mcause, uintptr_t* regs)
// signal software interrupt
set_csr(mip, MIP_SSIP);
- return 0;
+ return;
}
prev = m;
@@ -141,7 +141,7 @@ static uintptr_t mcall_console_putchar(uint8_t ch)
uintptr_t fromhost = read_csr(mfromhost);
if (FROMHOST_DEV(fromhost) != 1 || FROMHOST_CMD(fromhost) != 1) {
if (fromhost)
- htif_interrupt(0, 0);
+ htif_interrupt();
continue;
}
write_csr(mfromhost, 0);
@@ -168,7 +168,7 @@ static uintptr_t mcall_dev_req(sbi_device_message *m)
{
//printm("req %d %p\n", HLS()->device_request_queue_size, m);
if (!supervisor_paddr_valid(m, sizeof(*m))
- && EXTRACT_FIELD(read_csr(mstatus), MSTATUS_PRV1) != PRV_M)
+ && EXTRACT_FIELD(read_csr(mstatus), MSTATUS_MPP) != PRV_M)
return -EFAULT;
if ((m->dev > 0xFFU) | (m->cmd > 0xFFU) | (m->data > 0x0000FFFFFFFFFFFFU))
@@ -186,7 +186,7 @@ static uintptr_t mcall_dev_req(sbi_device_message *m)
static uintptr_t mcall_dev_resp()
{
- htif_interrupt(0, 0);
+ htif_interrupt();
sbi_device_message* m = HLS()->device_response_queue_head;
if (m) {
@@ -249,7 +249,7 @@ static uintptr_t mcall_set_timer(unsigned long long when)
return 0;
}
-uintptr_t mcall_trap(uintptr_t mcause, uintptr_t* regs)
+void mcall_trap(uintptr_t* regs, uintptr_t mcause, uintptr_t mepc)
{
uintptr_t n = regs[17], arg0 = regs[10], retval;
switch (n)
@@ -283,25 +283,22 @@ uintptr_t mcall_trap(uintptr_t mcause, uintptr_t* regs)
break;
}
regs[10] = retval;
- write_csr(mepc, read_csr(mepc) + 4);
- return 0;
+ write_csr(mepc, mepc + 4);
}
-static uintptr_t machine_page_fault(uintptr_t mcause, uintptr_t* regs, uintptr_t mepc)
+static void machine_page_fault(uintptr_t* regs, uintptr_t mepc)
{
// See if this trap occurred when emulating an instruction on behalf of
// a lower privilege level.
extern int32_t unprivileged_access_ranges[];
extern int32_t unprivileged_access_ranges_end[];
-
int32_t* p = unprivileged_access_ranges;
do {
if (mepc >= p[0] && mepc < p[1]) {
- // Yes. Skip to the end of the unprivileged access region.
- // Mark t0 zero so the emulation routine knows this occurred.
- regs[5] = 0;
- write_csr(mepc, p[1]);
- return 0;
+ // Yes. Redirect the trap to the supervisor.
+ write_csr(sbadaddr, read_csr(mbadaddr));
+ redirect_trap(regs[14], regs[5]);
+ return;
}
p += 2;
} while (p < unprivileged_access_ranges_end);
@@ -310,15 +307,9 @@ static uintptr_t machine_page_fault(uintptr_t mcause, uintptr_t* regs, uintptr_t
bad_trap();
}
-static uintptr_t machine_illegal_instruction(uintptr_t mcause, uintptr_t* regs, uintptr_t mepc)
-{
- bad_trap();
-}
-
-uintptr_t trap_from_machine_mode(uintptr_t dummy, uintptr_t* regs)
+void trap_from_machine_mode(uintptr_t* regs, uintptr_t dummy, uintptr_t mepc)
{
uintptr_t mcause = read_csr(mcause);
- uintptr_t mepc = read_csr(mepc);
// restore mscratch, since we clobbered it.
write_csr(mscratch, MACHINE_STACK_TOP() - MENTRY_FRAME_SIZE);
@@ -326,11 +317,11 @@ uintptr_t trap_from_machine_mode(uintptr_t dummy, uintptr_t* regs)
{
case CAUSE_FAULT_LOAD:
case CAUSE_FAULT_STORE:
- return machine_page_fault(mcause, regs, mepc);
+ return machine_page_fault(regs, mepc);
case CAUSE_ILLEGAL_INSTRUCTION:
- return machine_illegal_instruction(mcause, regs, mepc);
+ return bad_trap();
case CAUSE_MACHINE_ECALL:
- return mcall_trap(mcause, regs);
+ return mcall_trap(regs, dummy, mepc);
default:
bad_trap();
}
diff --git a/pk/mtrap.h b/pk/mtrap.h
index 61624d3..c4de413 100644
--- a/pk/mtrap.h
+++ b/pk/mtrap.h
@@ -11,72 +11,34 @@
#define GET_MACRO(_1,_2,_3,_4,NAME,...) NAME
-#define unpriv_mem_access(a, b, c, d, ...) GET_MACRO(__VA_ARGS__, unpriv_mem_access3, unpriv_mem_access2, unpriv_mem_access1, unpriv_mem_access0)(a, b, c, d, __VA_ARGS__)
-#define unpriv_mem_access0(a, b, c, d, e) ({ uintptr_t z = 0, z1 = 0, z2 = 0; unpriv_mem_access_base(a, b, c, d, e, z, z1, z2); })
-#define unpriv_mem_access1(a, b, c, d, e, f) ({ uintptr_t z = 0, z1 = 0; unpriv_mem_access_base(a, b, c, d, e, f, z, z1); })
-#define unpriv_mem_access2(a, b, c, d, e, f, g) ({ uintptr_t z = 0; unpriv_mem_access_base(a, b, c, d, e, f, g, z); })
-#define unpriv_mem_access3(a, b, c, d, e, f, g, h) unpriv_mem_access_base(a, b, c, d, e, f, g, h)
-#define unpriv_mem_access_base(mstatus, mepc, code, o0, o1, i0, i1, i2) ({ \
- register uintptr_t result asm("t0"); \
- uintptr_t unused1, unused2 __attribute__((unused)); \
- uintptr_t scratch = MSTATUS_MPRV; \
- asm volatile ("csrrs %[result], mstatus, %[scratch]\n" \
+#define unpriv_mem_access(a, b, c, ...) GET_MACRO(__VA_ARGS__, unpriv_mem_access3, unpriv_mem_access2, unpriv_mem_access1, unpriv_mem_access0)(a, b, c, __VA_ARGS__)
+#define unpriv_mem_access0(a, b, c, d) ({ uintptr_t z = 0, z1 = 0, z2 = 0; unpriv_mem_access_base(a, b, c, d, z, z1, z2); })
+#define unpriv_mem_access1(a, b, c, d, e) ({ uintptr_t z = 0, z1 = 0; unpriv_mem_access_base(a, b, c, d, e, z, z1); })
+#define unpriv_mem_access2(a, b, c, d, e, f) ({ uintptr_t z = 0; unpriv_mem_access_base(a, b, c, d, e, f, z); })
+#define unpriv_mem_access3(a, b, c, d, e, f, g) unpriv_mem_access_base(a, b, c, d, e, f, g)
+#define unpriv_mem_access_base(code, o0, o1, o2, i0, i1, i2) ({ \
+ register uintptr_t scratch asm ("t0") = MSTATUS_MPRV; \
+ register uintptr_t __mepc asm ("a4") = mepc; \
+ uintptr_t unused1, unused2, unused3 __attribute__((unused)); \
+ asm volatile ("csrrs %[scratch], mstatus, %[scratch]\n" \
"98: " code "\n" \
- "99: csrc mstatus, %[scratch]\n" \
+ "99: csrw mstatus, %[scratch]\n" \
".pushsection .unpriv,\"a\",@progbits\n" \
".word 98b; .word 99b\n" \
".popsection" \
- : [o0] "=&r"(o0), [o1] "=&r"(o1), \
- [result] "+&r"(result) \
+ : [o0] "=&r"(o0), [o1] "=&r"(o1), [o2] "=&r"(o2), \
+ [scratch] "+&r"(scratch) \
: [i0] "rJ"(i0), [i1] "rJ"(i1), [i2] "rJ"(i2), \
- [scratch] "r"(scratch), [mepc] "r"(mepc)); \
- unlikely(!result); })
-
-#define restore_mstatus(mstatus, mepc) ({ \
- uintptr_t scratch; \
- uintptr_t mask = MSTATUS_PRV1 | MSTATUS_IE1 | MSTATUS_PRV2 | MSTATUS_IE2 | MSTATUS_PRV3 | MSTATUS_IE3; \
- asm volatile("csrc mstatus, %[mask];" \
- "csrw mepc, %[mepc];" \
- "and %[scratch], %[mask], %[mstatus];" \
- "csrs mstatus, %[scratch]" \
- : [scratch] "=r"(scratch) \
- : [mstatus] "r"(mstatus), [mepc] "r"(mepc), \
- [mask] "r"(mask)); })
-
-#define unpriv_load_1(ptr, dest, mstatus, mepc, type, insn) ({ \
- type value, dummy; void* addr = (ptr); \
- uintptr_t res = unpriv_mem_access(mstatus, mepc, insn " %[value], (%[addr])", value, dummy, addr); \
- (dest) = (typeof(dest))(uintptr_t)value; \
- res; })
-#define unpriv_load(ptr, dest) ({ \
- uintptr_t res; \
- uintptr_t mstatus = read_csr(mstatus), mepc = read_csr(mepc); \
- if (sizeof(*ptr) == 1) res = unpriv_load_1(ptr, dest, mstatus, mepc, int8_t, "lb"); \
- else if (sizeof(*ptr) == 2) res = unpriv_load_1(ptr, dest, mstatus, mepc, int16_t, "lh"); \
- else if (sizeof(*ptr) == 4) res = unpriv_load_1(ptr, dest, mstatus, mepc, int32_t, "lw"); \
- else if (sizeof(uintptr_t) == 8 && sizeof(*ptr) == 8) res = unpriv_load_1(ptr, dest, mstatus, mepc, int64_t, "ld"); \
- else __builtin_trap(); \
- if (res) restore_mstatus(mstatus, mepc); \
- res; })
-
-#define unpriv_store_1(ptr, src, mstatus, mepc, type, insn) ({ \
- type dummy1, dummy2, value = (type)(uintptr_t)(src); void* addr = (ptr); \
- uintptr_t res = unpriv_mem_access(mstatus, mepc, insn " %z[value], (%[addr])", dummy1, dummy2, addr, value); \
- res; })
-#define unpriv_store(ptr, src) ({ \
- uintptr_t res; \
- uintptr_t mstatus = read_csr(mstatus), mepc = read_csr(mepc); \
- if (sizeof(*ptr) == 1) res = unpriv_store_1(ptr, src, mstatus, mepc, int8_t, "sb"); \
- else if (sizeof(*ptr) == 2) res = unpriv_store_1(ptr, src, mstatus, mepc, int16_t, "sh"); \
- else if (sizeof(*ptr) == 4) res = unpriv_store_1(ptr, src, mstatus, mepc, int32_t, "sw"); \
- else if (sizeof(uintptr_t) == 8 && sizeof(*ptr) == 8) res = unpriv_store_1(ptr, src, mstatus, mepc, int64_t, "sd"); \
- else __builtin_trap(); \
- if (res) restore_mstatus(mstatus, mepc); \
- res; })
+ "r"(__mepc)); \
+})
typedef uint32_t insn_t;
-typedef uintptr_t (*emulation_func)(uintptr_t, uintptr_t*, insn_t, uintptr_t, uintptr_t);
-#define DECLARE_EMULATION_FUNC(name) uintptr_t name(uintptr_t mcause, uintptr_t* regs, insn_t insn, uintptr_t mstatus, uintptr_t mepc)
+typedef void (*emulation_func)(uintptr_t*, uintptr_t, uintptr_t, uintptr_t, insn_t);
+#define DECLARE_EMULATION_FUNC(name) void name(uintptr_t* regs, uintptr_t mcause, uintptr_t mepc, uintptr_t mstatus, insn_t insn)
+
+void truly_illegal_insn(uintptr_t* regs, uintptr_t mcause, uintptr_t mepc, uintptr_t mstatus, insn_t insn);
+void redirect_trap(uintptr_t epc, uintptr_t mstatus);
+void leave();
#define GET_REG(insn, pos, regs) ({ \
int mask = (1 << (5+LOG_REGBYTES)) - (1 << LOG_REGBYTES); \
@@ -127,7 +89,7 @@ typedef uintptr_t (*emulation_func)(uintptr_t, uintptr_t*, insn_t, uintptr_t, ui
# define SETUP_STATIC_ROUNDING(insn) ({ \
register long tp asm("tp") = read_csr(frm); \
if (likely(((insn) & MASK_FUNCT3) == MASK_FUNCT3)) ; \
- else if (GET_RM(insn) > 4) return -1; \
+ else if (GET_RM(insn) > 4) return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); \
else tp = GET_RM(insn); \
asm volatile ("":"+r"(tp)); })
# define softfloat_raiseFlags(which) set_csr(fflags, which)
@@ -147,7 +109,7 @@ typedef uintptr_t (*emulation_func)(uintptr_t, uintptr_t*, insn_t, uintptr_t, ui
# define SETUP_STATIC_ROUNDING(insn) ({ \
register int tp asm("tp"); tp &= 0xFF; \
if (likely(((insn) & MASK_FUNCT3) == MASK_FUNCT3)) tp |= tp << 8; \
- else if (GET_RM(insn) > 4) return -1; \
+ else if (GET_RM(insn) > 4) return truly_illegal_insn(regs, mcause, mepc, mstatus, insn); \
else tp |= GET_RM(insn) << 13; \
asm volatile ("":"+r"(tp)); })
# define softfloat_raiseFlags(which) ({ asm volatile ("or tp, tp, %0" :: "rI"(which)); })
@@ -164,42 +126,26 @@ typedef uintptr_t (*emulation_func)(uintptr_t, uintptr_t*, insn_t, uintptr_t, ui
#define SET_F64_RD(insn, regs, val) (SET_F64_REG(insn, 7, regs, val), SET_FS_DIRTY())
#define SET_FS_DIRTY() set_csr(mstatus, MSTATUS_FS)
-typedef struct {
- uintptr_t error;
- insn_t insn;
-} insn_fetch_t;
-
-static insn_fetch_t __attribute__((always_inline))
- get_insn(uintptr_t mcause, uintptr_t mstatus, uintptr_t mepc)
+static insn_t __attribute__((always_inline)) get_insn(uintptr_t mepc)
{
- insn_fetch_t fetch;
insn_t insn;
#ifdef __riscv_compressed
int rvc_mask = 3, insn_hi;
- fetch.error = unpriv_mem_access(mstatus, mepc,
- "lhu %[insn], 0(%[mepc]);"
- "and %[insn_hi], %[insn], %[rvc_mask];"
- "bne %[insn_hi], %[rvc_mask], 1f;"
- "lh %[insn_hi], 2(%[mepc]);"
- "sll %[insn_hi], %[insn_hi], 16;"
- "or %[insn], %[insn], %[insn_hi];"
- "1:",
- insn, insn_hi, rvc_mask);
+ unpriv_mem_access("lhu %[insn], 0(%[mepc]);"
+ "and %[insn_hi], %[insn], %[rvc_mask];"
+ "bne %[insn_hi], %[rvc_mask], 1f;"
+ "lh %[insn_hi], 2(%[mepc]);"
+ "sll %[insn_hi], %[insn_hi], 16;"
+ "or %[insn], %[insn], %[insn_hi];"
+ "1:",
+ insn, insn_hi, unused1, mepc, rvc_mask);
#else
- fetch.error = unpriv_mem_access(mstatus, mepc,
- "lw %[insn], 0(%[mepc])",
- insn, unused1);
+ unpriv_mem_access("lw %[insn], 0(%[mepc])",
+ insn, unused1, unused2, mepc);
#endif
- fetch.insn = insn;
-
- if (unlikely(fetch.error)) {
- // we've messed up mstatus, mepc, and mcause, so restore them all
- restore_mstatus(mstatus, mepc);
- write_csr(mcause, mcause);
- }
- return fetch;
+ return insn;
}
static inline long __attribute__((pure)) cpuid()
diff --git a/pk/pk.c b/pk/pk.c
index a257bf5..d3f6b3e 100644
--- a/pk/pk.c
+++ b/pk/pk.c
@@ -12,8 +12,10 @@ void run_loaded_program(struct mainvars* args)
extern char trap_entry;
write_csr(stvec, &trap_entry);
write_csr(sscratch, 0);
+ clear_csr(sie, SIP_STIP | SIP_SSIP);
// enter supervisor mode
+ prepare_supervisor_mode();
asm volatile("la t0, 1f; csrw mepc, t0; eret; 1:" ::: "t0");
// copy phdrs to user stack
diff --git a/pk/pk.h b/pk/pk.h
index e023242..c76825b 100644
--- a/pk/pk.h
+++ b/pk/pk.h
@@ -67,6 +67,7 @@ void handle_misaligned_load(trapframe_t*);
void handle_misaligned_store(trapframe_t*);
void handle_fault_load(trapframe_t*);
void handle_fault_store(trapframe_t*);
+void prepare_supervisor_mode();
void boot_loader(struct mainvars*);
void run_loaded_program(struct mainvars*);
void boot_other_hart();