From 1c7584bb501bb6d4cbc3b95cb22e008220fb537a Mon Sep 17 00:00:00 2001 From: Ben Marshall Date: Mon, 18 Oct 2021 10:16:43 +0100 Subject: scalar-crypto: Initial commit of 1.0.0-rc2 spec work. (#99) Merged scalar-crypto pull request #99 of 1.0.0-rc2 spec work from Ben Marshall. See https://github.com/riscv/sail-riscv/pull/99. --- Makefile | 5 +- c_emulator/riscv_platform.c | 4 + c_emulator/riscv_platform.h | 3 + c_emulator/riscv_platform_impl.c | 14 + c_emulator/riscv_platform_impl.h | 3 + handwritten_support/0.11/riscv_extras.lem | 4 + handwritten_support/riscv_extras.lem | 4 + model/prelude.sail | 22 ++ model/riscv_csr_map.sail | 4 +- model/riscv_insts_zicsr.sail | 6 + model/riscv_insts_zkn.sail | 408 ++++++++++++++++++++++++++++++ model/riscv_insts_zks.sail | 78 ++++++ model/riscv_sys_control.sail | 24 ++ model/riscv_sys_regs.sail | 63 +++++ model/riscv_types_kext.sail | 352 ++++++++++++++++++++++++++ ocaml_emulator/platform.ml | 5 + 16 files changed, 997 insertions(+), 2 deletions(-) create mode 100644 model/riscv_insts_zkn.sail create mode 100644 model/riscv_insts_zks.sail create mode 100644 model/riscv_types_kext.sail diff --git a/Makefile b/Makefile index 5d9a461..f7a4a3e 100644 --- a/Makefile +++ b/Makefile @@ -25,6 +25,8 @@ SAIL_DEFAULT_INST += riscv_insts_fext.sail riscv_insts_cfext.sail ifeq ($(ARCH),RV64) SAIL_DEFAULT_INST += riscv_insts_dext.sail riscv_insts_cdext.sail endif +SAIL_DEFAULT_INST += riscv_insts_zkn.sail +SAIL_DEFAULT_INST += riscv_insts_zks.sail SAIL_SEQ_INST = $(SAIL_DEFAULT_INST) riscv_jalr_seq.sail SAIL_RMEM_INST = $(SAIL_DEFAULT_INST) riscv_jalr_rmem.sail riscv_insts_rmem.sail @@ -63,7 +65,8 @@ SAIL_ARCH_SRCS = $(PRELUDE) SAIL_ARCH_SRCS += riscv_types_common.sail riscv_types_ext.sail riscv_types.sail SAIL_ARCH_SRCS += riscv_vmem_types.sail $(SAIL_REGS_SRCS) $(SAIL_SYS_SRCS) riscv_platform.sail SAIL_ARCH_SRCS += riscv_mem.sail $(SAIL_VM_SRCS) -SAIL_ARCH_RVFI_SRCS = $(PRELUDE) rvfi_dii.sail riscv_types_common.sail riscv_types_ext.sail riscv_types.sail riscv_vmem_types.sail $(SAIL_REGS_SRCS) $(SAIL_SYS_SRCS) riscv_platform.sail riscv_mem.sail $(SAIL_VM_SRCS) +SAIL_ARCH_RVFI_SRCS = $(PRELUDE) rvfi_dii.sail riscv_types_common.sail riscv_types_ext.sail riscv_types.sail riscv_vmem_types.sail $(SAIL_REGS_SRCS) $(SAIL_SYS_SRCS) riscv_platform.sail riscv_mem.sail $(SAIL_VM_SRCS) riscv_types_kext.sail +SAIL_ARCH_SRCS += riscv_types_kext.sail # Shared/common code for the cryptography extension. SAIL_STEP_SRCS = riscv_step_common.sail riscv_step_ext.sail riscv_decode_ext.sail riscv_fetch.sail riscv_step.sail RVFI_STEP_SRCS = riscv_step_common.sail riscv_step_rvfi.sail riscv_decode_ext.sail riscv_fetch_rvfi.sail riscv_step.sail diff --git a/c_emulator/riscv_platform.c b/c_emulator/riscv_platform.c index 6529355..a528bee 100644 --- a/c_emulator/riscv_platform.c +++ b/c_emulator/riscv_platform.c @@ -45,6 +45,10 @@ mach_bits plat_rom_base(unit u) mach_bits plat_rom_size(unit u) { return rv_rom_size; } +// Provides entropy for the scalar cryptography extension. +mach_bits plat_get_16_random_bits() +{ return rv_16_random_bits(); } + mach_bits plat_clint_base(unit u) { return rv_clint_base; } diff --git a/c_emulator/riscv_platform.h b/c_emulator/riscv_platform.h index 464f6d0..f2a1c70 100644 --- a/c_emulator/riscv_platform.h +++ b/c_emulator/riscv_platform.h @@ -18,6 +18,9 @@ bool within_phys_mem(mach_bits, sail_int); mach_bits plat_rom_base(unit); mach_bits plat_rom_size(unit); +// Provides entropy for the scalar cryptography extension. +mach_bits plat_get_16_random_bits(); + mach_bits plat_clint_base(unit); mach_bits plat_clint_size(unit); diff --git a/c_emulator/riscv_platform_impl.c b/c_emulator/riscv_platform_impl.c index e43ba27..946b2ba 100644 --- a/c_emulator/riscv_platform_impl.c +++ b/c_emulator/riscv_platform_impl.c @@ -19,6 +19,20 @@ uint64_t rv_ram_size = UINT64_C(0x4000000); uint64_t rv_rom_base = UINT64_C(0x1000); uint64_t rv_rom_size = UINT64_C(0x100); +// Provides entropy for the scalar cryptography extension. +uint64_t rv_16_random_bits(void) { + // This function can be changed to support deterministic sequences of + // pseudo-random bytes. This is useful for testing. + const char *name = "/dev/urandom"; + FILE *f = fopen(name, "rb"); + uint16_t val; + if (fread(&val, 2, 1, f) != 1) { + fprintf(stderr, "Unable to read 2 bytes from %s\n", name); + } + fclose(f); + return (uint64_t)val; +} + uint64_t rv_clint_base = UINT64_C(0x2000000); uint64_t rv_clint_size = UINT64_C(0xc0000); diff --git a/c_emulator/riscv_platform_impl.h b/c_emulator/riscv_platform_impl.h index 0e1dadd..094cf3e 100644 --- a/c_emulator/riscv_platform_impl.h +++ b/c_emulator/riscv_platform_impl.h @@ -22,6 +22,9 @@ extern uint64_t rv_ram_size; extern uint64_t rv_rom_base; extern uint64_t rv_rom_size; +// Provides entropy for the scalar cryptography extension. +extern uint64_t rv_16_random_bits(void); + extern uint64_t rv_clint_base; extern uint64_t rv_clint_size; diff --git a/handwritten_support/0.11/riscv_extras.lem b/handwritten_support/0.11/riscv_extras.lem index db93001..2182940 100644 --- a/handwritten_support/0.11/riscv_extras.lem +++ b/handwritten_support/0.11/riscv_extras.lem @@ -143,6 +143,10 @@ val plat_term_read : forall 'a. Size 'a => unit -> bitvector 'a let plat_term_read () = wordFromInteger 0 declare ocaml target_rep function plat_term_read = `Platform.term_read` +val plat_get_16_random_bits : forall 'a. Size 'a => unit -> bitvector 'a +let plat_get_16_random_bits () = wordFromInteger 0 +declare ocaml target_rep function plat_get_16_random_bits = `Platform.get_16_random_bits` + val shift_bits_right : forall 'a 'b. Size 'a, Size 'b => bitvector 'a -> bitvector 'b -> bitvector 'a let shift_bits_right v m = shiftr v (uint m) val shift_bits_left : forall 'a 'b. Size 'a, Size 'b => bitvector 'a -> bitvector 'b -> bitvector 'a diff --git a/handwritten_support/riscv_extras.lem b/handwritten_support/riscv_extras.lem index 2431bb3..b0737e5 100644 --- a/handwritten_support/riscv_extras.lem +++ b/handwritten_support/riscv_extras.lem @@ -211,6 +211,10 @@ val plat_term_read : forall 'a. Size 'a => unit -> bitvector 'a let plat_term_read () = wordFromInteger 0 declare ocaml target_rep function plat_term_read = `Platform.term_read` +val plat_get_16_random_bits : forall 'a. Size 'a => unit -> bitvector 'a +let plat_get_16_random_bits () = wordFromInteger 0 +declare ocaml target_rep function plat_get_16_random_bits = `Platform.get_16_random_bits` + val shift_bits_right : forall 'a 'b. Size 'a, Size 'b => bitvector 'a -> bitvector 'b -> bitvector 'a let shift_bits_right v m = shiftr v (uint m) val shift_bits_left : forall 'a 'b. Size 'a, Size 'b => bitvector 'a -> bitvector 'b -> bitvector 'a diff --git a/model/prelude.sail b/model/prelude.sail index 17c2fa1..21c6793 100644 --- a/model/prelude.sail +++ b/model/prelude.sail @@ -248,6 +248,28 @@ function shift_right_arith32 (v : bits(32), shift : bits(5)) -> bits(32) = let v64 : bits(64) = EXTS(v) in (v64 >> shift)[31..0] +infix 7 >>> +infix 7 <<< + +val rotate_bits_right : forall 'n 'm, 'm >= 0. (bits('n), bits('m)) -> bits('n) +function rotate_bits_right (v, n) = + (v >> n) | (v << (to_bits(length(n), length(v)) - n)) + +val rotate_bits_left : forall 'n 'm, 'm >= 0. (bits('n), bits('m)) -> bits('n) +function rotate_bits_left (v, n) = + (v << n) | (v >> (to_bits(length(n), length(v)) - n)) + +val rotater : forall 'm 'n, 'm >= 'n >= 0. (bits('m), atom('n)) -> bits('m) +function rotater (v, n) = + (v >> n) | (v << (length(v) - n)) + +val rotatel : forall 'm 'n, 'm >= 'n >= 0. (bits('m), atom('n)) -> bits('m) +function rotatel (v, n) = + (v << n) | (v >> (length(v) - n)) + +overload operator >>> = {rotate_bits_right, rotater} +overload operator <<< = {rotate_bits_left, rotatel} + /* helpers for mappings */ val spc : unit <-> string diff --git a/model/riscv_csr_map.sail b/model/riscv_csr_map.sail index 656f80c..87211ea 100644 --- a/model/riscv_csr_map.sail +++ b/model/riscv_csr_map.sail @@ -86,6 +86,8 @@ mapping clause csr_name_map = 0x044 <-> "uip" mapping clause csr_name_map = 0x001 <-> "fflags" mapping clause csr_name_map = 0x002 <-> "frm" mapping clause csr_name_map = 0x003 <-> "fcsr" +/* user entropy source */ +mapping clause csr_name_map = 0x015 <-> "seed" /* counter/timers */ mapping clause csr_name_map = 0xC00 <-> "cycle" mapping clause csr_name_map = 0xC01 <-> "time" @@ -185,4 +187,4 @@ scattered function ext_read_CSR /* returns new value (after legalisation) if the CSR is defined */ val ext_write_CSR : (csreg, xlenbits) -> option(xlenbits) effect {rreg, wreg} -scattered function ext_write_CSR \ No newline at end of file +scattered function ext_write_CSR diff --git a/model/riscv_insts_zicsr.sail b/model/riscv_insts_zicsr.sail index fc2cda8..518396f 100644 --- a/model/riscv_insts_zicsr.sail +++ b/model/riscv_insts_zicsr.sail @@ -158,6 +158,9 @@ function readCSR csr : csreg -> xlenbits = { (0xC81, 32) => mtime[63 .. 32], (0xC82, 32) => minstret[63 .. 32], + /* user mode: Zkr */ + (0x015, _) => read_seed_csr(), + _ => /* check extensions */ match ext_read_CSR(csr) { Some(res) => res, @@ -235,6 +238,9 @@ function writeCSR (csr : csreg, value : xlenbits) -> unit = { (0x144, _) => { mip = legalize_sip(mip, mideleg, value); Some(mip.bits()) }, (0x180, _) => { satp = legalize_satp(cur_Architecture(), satp, value); Some(satp) }, + /* user mode: seed (entropy source). writes are ignored */ + (0x015, _) => write_seed_csr(), + _ => ext_write_CSR(csr, value) }; match res { diff --git a/model/riscv_insts_zkn.sail b/model/riscv_insts_zkn.sail new file mode 100644 index 0000000..c684ec2 --- /dev/null +++ b/model/riscv_insts_zkn.sail @@ -0,0 +1,408 @@ +/* + * Scalar Cryptography Extension - Scalar SHA256 instructions (RV32/RV64) + * ---------------------------------------------------------------------- + */ + +union clause ast = SHA256SIG0 : (regidx, regidx) +union clause ast = SHA256SIG1 : (regidx, regidx) +union clause ast = SHA256SUM0 : (regidx, regidx) +union clause ast = SHA256SUM1 : (regidx, regidx) + +mapping clause encdec = SHA256SUM0 (rs1, rd) if haveZknh() + <-> 0b00 @ 0b01000 @ 0b00000 @ rs1 @ 0b001 @ rd @ 0b0010011 + +mapping clause encdec = SHA256SUM1 (rs1, rd) if haveZknh() + <-> 0b00 @ 0b01000 @ 0b00001 @ rs1 @ 0b001 @ rd @ 0b0010011 + +mapping clause encdec = SHA256SIG0 (rs1, rd) if haveZknh() + <-> 0b00 @ 0b01000 @ 0b00010 @ rs1 @ 0b001 @ rd @ 0b0010011 + +mapping clause encdec = SHA256SIG1 (rs1, rd) if haveZknh() + <-> 0b00 @ 0b01000 @ 0b00011 @ rs1 @ 0b001 @ rd @ 0b0010011 + +mapping clause assembly = SHA256SIG0 (rs1, rd) + <-> "sha256sig0" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) + +mapping clause assembly = SHA256SIG1 (rs1, rd) + <-> "sha256sig1" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) + +mapping clause assembly = SHA256SUM0 (rs1, rd) + <-> "sha256sum0" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) + +mapping clause assembly = SHA256SUM1 (rs1, rd) + <-> "sha256sum1" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) + +function clause execute (SHA256SIG0(rs1, rd)) = { + let inb : bits(32) = X(rs1)[31..0]; + let result : bits(32) = (inb >>> 7) ^ (inb >>> 18) ^ (inb >> 3); + X(rd) = EXTS(result); + RETIRE_SUCCESS +} + +function clause execute (SHA256SIG1(rs1, rd)) = { + let inb : bits(32) = X(rs1)[31..0]; + let result : bits(32) = (inb >>> 17) ^ (inb >>> 19) ^ (inb >> 10); + X(rd) = EXTS(result); + RETIRE_SUCCESS +} + +function clause execute (SHA256SUM0(rs1, rd)) = { + let inb : bits(32) = X(rs1)[31..0]; + let result : bits(32) = (inb >>> 2) ^ (inb >>> 13) ^ (inb >>> 22); + X(rd) = EXTS(result); + RETIRE_SUCCESS +} + +function clause execute (SHA256SUM1(rs1, rd)) = { + let inb : bits(32) = X(rs1)[31..0]; + let result : bits(32) = (inb >>> 6) ^ (inb >>> 11) ^ (inb >>> 25); + X(rd) = EXTS(result); + RETIRE_SUCCESS +} + +/* + * Scalar Cryptography Extension - Scalar 32-bit AES instructions (encrypt) + * ---------------------------------------------------------------------- + */ + +union clause ast = AES32ESMI : (bits(2), regidx, regidx, regidx) + +mapping clause encdec = AES32ESMI (bs, rs2, rs1, rd) if haveZkne() & sizeof(xlen) == 32 + <-> bs @ 0b10011 @ rs2 @ rs1 @ 0b000 @ rd @ 0b0110011 + +mapping clause assembly = AES32ESMI (bs, rs2, rs1, rd) <-> + "aes32esmi" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) ^ sep() ^ reg_name(rs2) ^ sep() ^ hex_bits_2(bs) + +function clause execute (AES32ESMI (bs, rs2, rs1, rd)) = { + let shamt : bits( 5) = bs @ 0b000; /* shamt = bs*8 */ + let si : bits( 8) = (X(rs2) >> shamt)[7..0]; /* SBox Input */ + let so : bits( 8) = aes_sbox_fwd(si); + let mixed : bits(32) = aes_mixcolumn_byte_fwd(so); + let result : bits(32) = X(rs1)[31..0] ^ (mixed <<< shamt); + X(rd) = EXTS(result); + RETIRE_SUCCESS +} + +union clause ast = AES32ESI : (bits(2), regidx, regidx, regidx) + +mapping clause encdec = AES32ESI (bs, rs2, rs1, rd) if haveZkne() & sizeof(xlen) == 32 + <-> bs @ 0b10001 @ rs2 @ rs1 @ 0b000 @ rd @ 0b0110011 + +mapping clause assembly = AES32ESI (bs, rs2, rs1, rd) <-> + "aes32esi" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) ^ sep() ^ reg_name(rs2) ^ sep() ^ hex_bits_2(bs) + +function clause execute (AES32ESI (bs, rs2, rs1, rd)) = { + let shamt : bits( 5) = bs @ 0b000; /* shamt = bs*8 */ + let si : bits( 8) = (X(rs2) >> shamt)[7..0]; /* SBox Input */ + let so : bits(32) = 0x000000 @ aes_sbox_fwd(si); + let result : bits(32) = X(rs1)[31..0] ^ (so <<< shamt); + X(rd) = EXTS(result); + RETIRE_SUCCESS +} + +/* + * Scalar Cryptography Extension - Scalar 32-bit AES instructions (decrypt) + * ---------------------------------------------------------------------- + */ + +union clause ast = AES32DSMI : (bits(2), regidx, regidx, regidx) + +mapping clause encdec = AES32DSMI (bs, rs2, rs1, rd) if haveZknd() & sizeof(xlen) == 32 + <-> bs @ 0b10111 @ rs2 @ rs1 @ 0b000 @ rd @ 0b0110011 + +mapping clause assembly = AES32DSMI (bs, rs2, rs1, rd) <-> + "aes32dsmi" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) ^ sep() ^ reg_name(rs2) ^ sep() ^ hex_bits_2(bs) + +function clause execute (AES32DSMI (bs, rs2, rs1, rd)) = { + let shamt : bits( 5) = bs @ 0b000; /* shamt = bs*8 */ + let si : bits( 8) = (X(rs2) >> shamt)[7..0]; /* SBox Input */ + let so : bits( 8) = aes_sbox_inv(si); + let mixed : bits(32) = aes_mixcolumn_byte_inv(so); + let result : bits(32) = X(rs1)[31..0] ^ (mixed <<< shamt); + X(rd) = EXTS(result); + RETIRE_SUCCESS +} + +union clause ast = AES32DSI : (bits(2), regidx, regidx, regidx) + +mapping clause encdec = AES32DSI (bs, rs2, rs1, rd) if haveZknd() & sizeof(xlen) == 32 + <-> bs @ 0b10101 @ rs2 @ rs1 @ 0b000 @ rd @ 0b0110011 + +mapping clause assembly = AES32DSI (bs, rs2, rs1, rd) <-> + "aes32dsi" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) ^ sep() ^ reg_name(rs2) ^ sep() ^ hex_bits_2(bs) + +function clause execute (AES32DSI (bs, rs2, rs1, rd)) = { + let shamt : bits( 5) = bs @ 0b000; /* shamt = bs*8 */ + let si : bits( 8) = (X(rs2) >> shamt)[7..0]; /* SBox Input */ + let so : bits(32) = 0x000000 @ aes_sbox_inv(si); + let result : bits(32) = X(rs1)[31..0] ^ (so <<< shamt); + X(rd) = EXTS(result); + RETIRE_SUCCESS +} + +/* + * Scalar Cryptography Extension - Scalar 32-bit SHA512 instructions + * ---------------------------------------------------------------------- + */ + +union clause ast = SHA512SIG0L : (regidx, regidx, regidx) +union clause ast = SHA512SIG0H : (regidx, regidx, regidx) +union clause ast = SHA512SIG1L : (regidx, regidx, regidx) +union clause ast = SHA512SIG1H : (regidx, regidx, regidx) +union clause ast = SHA512SUM0R : (regidx, regidx, regidx) +union clause ast = SHA512SUM1R : (regidx, regidx, regidx) + +mapping clause encdec = SHA512SUM0R (rs2, rs1, rd) if haveZknh() & sizeof(xlen)==32 + <-> 0b01 @ 0b01000 @ rs2 @ rs1 @ 0b000 @ rd @ 0b0110011 + +mapping clause encdec = SHA512SUM1R (rs2, rs1, rd) if haveZknh() & sizeof(xlen)==32 + <-> 0b01 @ 0b01001 @ rs2 @ rs1 @ 0b000 @ rd @ 0b0110011 + +mapping clause encdec = SHA512SIG0L (rs2, rs1, rd) if haveZknh() & sizeof(xlen)==32 + <-> 0b01 @ 0b01010 @ rs2 @ rs1 @ 0b000 @ rd @ 0b0110011 + +mapping clause encdec = SHA512SIG0H (rs2, rs1, rd) if haveZknh() & sizeof(xlen)==32 + <-> 0b01 @ 0b01110 @ rs2 @ rs1 @ 0b000 @ rd @ 0b0110011 + +mapping clause encdec = SHA512SIG1L (rs2, rs1, rd) if haveZknh() & sizeof(xlen)==32 + <-> 0b01 @ 0b01011 @ rs2 @ rs1 @ 0b000 @ rd @ 0b0110011 + +mapping clause encdec = SHA512SIG1H (rs2, rs1, rd) if haveZknh() & sizeof(xlen)==32 + <-> 0b01 @ 0b01111 @ rs2 @ rs1 @ 0b000 @ rd @ 0b0110011 + +mapping clause assembly = SHA512SIG0L (rs2, rs1, rd) + <-> "sha512sig0l" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) ^ sep() ^ reg_name(rs2) + +mapping clause assembly = SHA512SIG0H (rs2, rs1, rd) + <-> "sha512sig0h" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) ^ sep() ^ reg_name(rs2) + +mapping clause assembly = SHA512SIG1L (rs2, rs1, rd) + <-> "sha512sig1l" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) ^ sep() ^ reg_name(rs2) + +mapping clause assembly = SHA512SIG1H (rs2, rs1, rd) + <-> "sha512sig1h" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) ^ sep() ^ reg_name(rs2) + +mapping clause assembly = SHA512SUM0R (rs2, rs1, rd) + <-> "sha512sum0r" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) ^ sep() ^ reg_name(rs2) + +mapping clause assembly = SHA512SUM1R (rs2, rs1, rd) + <-> "sha512sum1r" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) ^ sep() ^ reg_name(rs2) + +function clause execute (SHA512SIG0H(rs2, rs1, rd)) = { + X(rd) = EXTS((X(rs1) >> 1) ^ (X(rs1) >> 7) ^ (X(rs1) >> 8) ^ + (X(rs2) << 31) ^ (X(rs2) << 24) ); + RETIRE_SUCCESS +} + +function clause execute (SHA512SIG0L(rs2, rs1, rd)) = { + X(rd) = EXTS((X(rs1) >> 1) ^ (X(rs1) >> 7) ^ (X(rs1) >> 8) ^ + (X(rs2) << 31) ^ (X(rs2) << 25) ^ (X(rs2) << 24) ); + RETIRE_SUCCESS +} + +function clause execute (SHA512SIG1H(rs2, rs1, rd)) = { + X(rd) = EXTS((X(rs1) << 3) ^ (X(rs1) >> 6) ^ (X(rs1) >> 19) ^ + (X(rs2) >> 29) ^ (X(rs2) << 13) ); + RETIRE_SUCCESS +} + +function clause execute (SHA512SIG1L(rs2, rs1, rd)) = { + X(rd) = EXTS((X(rs1) << 3) ^ (X(rs1) >> 6) ^ (X(rs1) >> 19) ^ + (X(rs2) >> 29) ^ (X(rs2) << 26) ^ (X(rs2) << 13) ); + RETIRE_SUCCESS +} + +function clause execute (SHA512SUM0R(rs2, rs1, rd)) = { + X(rd) = EXTS((X(rs1) << 25) ^ (X(rs1) << 30) ^ (X(rs1) >> 28) ^ + (X(rs2) >> 7) ^ (X(rs2) >> 2) ^ (X(rs2) << 4) ); + RETIRE_SUCCESS +} + +function clause execute (SHA512SUM1R(rs2, rs1, rd)) = { + X(rd) = EXTS((X(rs1) << 23) ^ (X(rs1) >> 14) ^ (X(rs1) >> 18) ^ + (X(rs2) >> 9) ^ (X(rs2) << 18) ^ (X(rs2) << 14) ); + RETIRE_SUCCESS +} + +/* + * Scalar Cryptography Extension - Scalar 64-bit AES instructions + * ---------------------------------------------------------------------- + */ + +union clause ast = AES64KS1I : (bits(4), regidx, regidx) +union clause ast = AES64KS2 : (regidx, regidx, regidx) +union clause ast = AES64IM : (regidx, regidx) +union clause ast = AES64ESM : (regidx, regidx, regidx) +union clause ast = AES64ES : (regidx, regidx, regidx) +union clause ast = AES64DSM : (regidx, regidx, regidx) +union clause ast = AES64DS : (regidx, regidx, regidx) + +mapping clause encdec = AES64KS1I (rcon, rs1, rd) if (haveZkne() | haveZknd()) & (sizeof(xlen) == 64) & (rcon <_u 0xB) + <-> 0b00 @ 0b11000 @ 0b1 @ rcon @ rs1 @ 0b001 @ rd @ 0b0010011 + +mapping clause encdec = AES64IM (rs1, rd) if haveZknd() & sizeof(xlen) == 64 + <-> 0b00 @ 0b11000 @ 0b00000 @ rs1 @ 0b001 @ rd @ 0b0010011 + +mapping clause encdec = AES64KS2 (rs2, rs1, rd) if (haveZkne() | haveZknd()) & sizeof(xlen) == 64 + <-> 0b01 @ 0b11111 @ rs2 @ rs1 @ 0b000 @ rd @ 0b0110011 + +mapping clause encdec = AES64ESM (rs2, rs1, rd) if haveZkne() & sizeof(xlen) == 64 + <-> 0b00 @ 0b11011 @ rs2 @ rs1 @ 0b000 @ rd @ 0b0110011 + +mapping clause encdec = AES64ES (rs2, rs1, rd) if haveZkne() & sizeof(xlen) == 64 + <-> 0b00 @ 0b11001 @ rs2 @ rs1 @ 0b000 @ rd @ 0b0110011 + +mapping clause encdec = AES64DSM (rs2, rs1, rd) if haveZknd() & sizeof(xlen) == 64 + <-> 0b00 @ 0b11111 @ rs2 @ rs1 @ 0b000 @ rd @ 0b0110011 + +mapping clause encdec = AES64DS (rs2, rs1, rd) if haveZknd() & sizeof(xlen) == 64 + <-> 0b00 @ 0b11101 @ rs2 @ rs1 @ 0b000 @ rd @ 0b0110011 + +mapping clause assembly = AES64KS1I (rcon, rs1, rd) + <-> "aes64ks1i" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) ^ sep() ^ hex_bits_4(rcon) + +mapping clause assembly = AES64KS2 (rs2, rs1, rd) + <-> "aes64ks2" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) ^ sep() ^ reg_name(rs2) + +mapping clause assembly = AES64IM (rs1, rd) + <-> "aes64im" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) + +mapping clause assembly = AES64ESM (rs2, rs1, rd) + <-> "aes64esm" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) ^ sep() ^ reg_name(rs2) + +mapping clause assembly = AES64ES (rs2, rs1, rd) + <-> "aes64es" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) ^ sep() ^ reg_name(rs2) + +mapping clause assembly = AES64DSM (rs2, rs1, rd) + <-> "aes64dsm" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) ^ sep() ^ reg_name(rs2) + +mapping clause assembly = AES64DS (rs2, rs1, rd) + <-> "aes64ds" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) ^ sep() ^ reg_name(rs2) + +function clause execute (AES64KS1I(rcon, rs1, rd)) = { + assert(sizeof(xlen) == 64); + let rs1_hi : bits(32) = X(rs1)[63..32]; + let rc : bits(32) = aes_decode_rcon(rcon); + let rotated : bits(32) = if (rcon == 0xA) then rs1_hi else (rs1_hi >>> 8); + let post_sb : bits(32) = aes_subword_fwd(rotated); + X(rd) = (post_sb ^ rc) @ (post_sb ^ rc); + RETIRE_SUCCESS +} + +function clause execute (AES64KS2(rs2, rs1, rd)) = { + assert(sizeof(xlen) == 64); + let w0 : bits(32) = X(rs1)[63..32] ^ X(rs2)[31..0]; + let w1 : bits(32) = X(rs1)[63..32] ^ X(rs2)[31..0] ^ X(rs2)[63..32]; + X(rd) = w1 @ w0; + RETIRE_SUCCESS +} + +function clause execute (AES64IM(rs1, rd)) = { + assert(sizeof(xlen) == 64); + let w0 : bits(32) = aes_mixcolumn_inv(X(rs1)[31.. 0]); + let w1 : bits(32) = aes_mixcolumn_inv(X(rs1)[63..32]); + X(rd) = w1 @ w0; + RETIRE_SUCCESS +} + +function clause execute (AES64ESM(rs2, rs1, rd)) = { + assert(sizeof(xlen) == 64); + let sr : bits(64) = aes_rv64_shiftrows_fwd(X(rs2), X(rs1)); + let wd : bits(64) = sr[63..0]; + let sb : bits(64) = aes_apply_fwd_sbox_to_each_byte(wd); + X(rd) = aes_mixcolumn_fwd(sb[63..32]) @ aes_mixcolumn_fwd(sb[31..0]); + RETIRE_SUCCESS +} + +function clause execute (AES64ES(rs2, rs1, rd)) = { + assert(sizeof(xlen) == 64); + let sr : bits(64) = aes_rv64_shiftrows_fwd(X(rs2), X(rs1)); + let wd : bits(64) = sr[63..0]; + X(rd) = aes_apply_fwd_sbox_to_each_byte(wd); + RETIRE_SUCCESS +} + +function clause execute (AES64DSM(rs2, rs1, rd)) = { + assert(sizeof(xlen) == 64); + let sr : bits(64) = aes_rv64_shiftrows_inv(X(rs2), X(rs1)); + let wd : bits(64) = sr[63..0]; + let sb : bits(64) = aes_apply_inv_sbox_to_each_byte(wd); + X(rd) = aes_mixcolumn_inv(sb[63..32]) @ aes_mixcolumn_inv(sb[31..0]); + RETIRE_SUCCESS +} + +function clause execute (AES64DS(rs2, rs1, rd)) = { + assert(sizeof(xlen) == 64); + let sr : bits(64) = aes_rv64_shiftrows_inv(X(rs2), X(rs1)); + let wd : bits(64) = sr[63..0]; + X(rd) = aes_apply_inv_sbox_to_each_byte(wd); + RETIRE_SUCCESS +} + +/* + * Scalar Cryptography Extension - Scalar 64-bit SHA512 instructions + * ---------------------------------------------------------------------- + */ + +union clause ast = SHA512SIG0 : (regidx, regidx) +union clause ast = SHA512SIG1 : (regidx, regidx) +union clause ast = SHA512SUM0 : (regidx, regidx) +union clause ast = SHA512SUM1 : (regidx, regidx) + +mapping clause encdec = SHA512SUM0 (rs1, rd) if haveZknh() & sizeof(xlen)==64 + <-> 0b00 @ 0b01000 @ 0b00100 @ rs1 @ 0b001 @ rd @ 0b0010011 + +mapping clause encdec = SHA512SUM1 (rs1, rd) if haveZknh() & sizeof(xlen)==64 + <-> 0b00 @ 0b01000 @ 0b00101 @ rs1 @ 0b001 @ rd @ 0b0010011 + +mapping clause encdec = SHA512SIG0 (rs1, rd) if haveZknh() & sizeof(xlen)==64 + <-> 0b00 @ 0b01000 @ 0b00110 @ rs1 @ 0b001 @ rd @ 0b0010011 + +mapping clause encdec = SHA512SIG1 (rs1, rd) if haveZknh() & sizeof(xlen)==64 + <-> 0b00 @ 0b01000 @ 0b00111 @ rs1 @ 0b001 @ rd @ 0b0010011 + +mapping clause assembly = SHA512SIG0 (rs1, rd) + <-> "sha512sig0" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) + +mapping clause assembly = SHA512SIG1 (rs1, rd) + <-> "sha512sig1" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) + +mapping clause assembly = SHA512SUM0 (rs1, rd) + <-> "sha512sum0" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) + +mapping clause assembly = SHA512SUM1 (rs1, rd) + <-> "sha512sum1" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) + +/* Execute clauses for the 64-bit SHA512 instructions. */ + +function clause execute (SHA512SIG0(rs1, rd)) = { + assert(sizeof(xlen) == 64); + let input : bits(64) = X(rs1); + let result : bits(64) = (input >>> 1) ^ (input >>> 8) ^ (input >> 7); + X(rd) = result; + RETIRE_SUCCESS +} + +function clause execute (SHA512SIG1(rs1, rd)) = { + assert(sizeof(xlen) == 64); + let input : bits(64) = X(rs1); + let result : bits(64) = (input >>> 19) ^ (input >>> 61) ^ (input >> 6); + X(rd) = result; + RETIRE_SUCCESS +} + +function clause execute (SHA512SUM0(rs1, rd)) = { + assert(sizeof(xlen) == 64); + let input : bits(64) = X(rs1); + let result : bits(64) = (input >>> 28) ^ (input >>> 34) ^ (input >>> 39); + X(rd) = result; + RETIRE_SUCCESS +} + +function clause execute (SHA512SUM1(rs1, rd)) = { + assert(sizeof(xlen) == 64); + let input : bits(64) = X(rs1); + let result : bits(64) = (input >>> 14) ^ (input >>> 18) ^ (input >>> 41); + X(rd) = result; + RETIRE_SUCCESS +} diff --git a/model/riscv_insts_zks.sail b/model/riscv_insts_zks.sail new file mode 100644 index 0000000..153e1be --- /dev/null +++ b/model/riscv_insts_zks.sail @@ -0,0 +1,78 @@ +/* + * Scalar Cryptography Extension - Scalar SM3 instructions + * ---------------------------------------------------------------------- + */ + +union clause ast = SM3P0 : (regidx, regidx) +union clause ast = SM3P1 : (regidx, regidx) + +mapping clause encdec = SM3P0 (rs1, rd) if haveZksh() + <-> 0b00 @ 0b01000 @ 0b01000 @ rs1 @ 0b001 @ rd @ 0b0010011 + +mapping clause encdec = SM3P1 (rs1, rd) if haveZksh() + <-> 0b00 @ 0b01000 @ 0b01001 @ rs1 @ 0b001 @ rd @ 0b0010011 + +mapping clause assembly = SM3P0 (rs1, rd) <-> + "sm3p0" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) + +mapping clause assembly = SM3P1 (rs1, rd) <-> + "sm3p1" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) + +function clause execute (SM3P0(rs1, rd)) = { + let r1 : bits(32) = X(rs1)[31..0]; + let result : bits(32) = r1 ^ (r1 <<< 9) ^ (r1 <<< 17); + X(rd) = EXTS(result); + RETIRE_SUCCESS +} + +function clause execute (SM3P1(rs1, rd)) = { + let r1 : bits(32) = X(rs1)[31..0]; + let result : bits(32) = r1 ^ (r1 <<< 15) ^ (r1 <<< 23); + X(rd) = EXTS(result); + RETIRE_SUCCESS +} + +/* + * Scalar Cryptography Extension - Scalar SM4 instructions + * ---------------------------------------------------------------------- + */ + +union clause ast = SM4ED : (bits(2), regidx, regidx, regidx) +union clause ast = SM4KS : (bits(2), regidx, regidx, regidx) + +mapping clause encdec = SM4ED (bs, rs2, rs1, rd) if haveZksed() + <-> bs @ 0b11000 @ rs2 @ rs1 @ 0b000 @ rd @ 0b0110011 + +mapping clause encdec = SM4KS (bs, rs2, rs1, rd) if haveZksed() + <-> bs @ 0b11010 @ rs2 @ rs1 @ 0b000 @ rd @ 0b0110011 + +mapping clause assembly = SM4ED (bs, rs2, rs1, rd) <-> + "sm4ed" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) ^ sep() ^ reg_name(rs2) ^ sep() ^ hex_bits_2(bs) + +mapping clause assembly = SM4KS (bs, rs2, rs1, rd) <-> + "sm4ks" ^ spc() ^ reg_name(rd) ^ sep() ^ reg_name(rs1) ^ sep() ^ reg_name(rs2) ^ sep() ^ hex_bits_2(bs) + +function clause execute (SM4ED (bs, rs2, rs1, rd)) = { + let shamt : bits(5) = bs @ 0b000; /* shamt = bs*8 */ + let sb_in : bits(8) = (X(rs2)[31..0] >> shamt)[7..0]; + let x : bits(32) = 0x000000 @ sm4_sbox(sb_in); + let y : bits(32) = x ^ (x << 8) ^ ( x << 2) ^ + (x << 18) ^ ((x & 0x0000003F) << 26) ^ + ((x & 0x000000C0) << 10); + let z : bits(32) = (y <<< shamt); + let result : bits(32) = z ^ X(rs1)[31..0]; + X(rd) = EXTS(result); + RETIRE_SUCCESS +} + +function clause execute (SM4KS (bs, rs2, rs1, rd)) = { + let shamt : bits(5) = (bs @ 0b000); /* shamt = bs*8 */ + let sb_in : bits(8) = (X(rs2)[31..0] >> shamt)[7..0]; + let x : bits(32) = 0x000000 @ sm4_sbox(sb_in); + let y : bits(32) = x ^ ((x & 0x00000007) << 29) ^ ((x & 0x000000FE) << 7) ^ + ((x & 0x00000001) << 23) ^ ((x & 0x000000F8) << 13) ; + let z : bits(32) = (y <<< shamt); + let result : bits(32) = z ^ X(rs1)[31..0]; + X(rd) = EXTS(result); + RETIRE_SUCCESS +} diff --git a/model/riscv_sys_control.sail b/model/riscv_sys_control.sail index 6a13fdf..a3859ff 100644 --- a/model/riscv_sys_control.sail +++ b/model/riscv_sys_control.sail @@ -157,6 +157,9 @@ function is_CSR_defined (csr : csreg, p : Privilege) -> bool = 0xC81 => haveUsrMode() & (sizeof(xlen) == 32), // timeh 0xC82 => haveUsrMode() & (sizeof(xlen) == 32), // instreth + /* user mode: Zkr */ + 0x015 => haveZkr(), + /* check extensions */ _ => ext_is_CSR_defined(csr, p) } @@ -185,11 +188,32 @@ function check_Counteren(csr : csreg, p : Privilege) -> bool = else true } + +/* Seed may only be accessed if we are doing a write, and access has been + * allowed in the current priv mode + */ +function check_seed_CSR (csr : csreg, p : Privilege, isWrite : bool) -> bool = { + if ~(csr == 0x015) then { + true + } else if ~(isWrite) then { + /* Read-only access to the seed CSR is not allowed */ + false + } else { + match (p) { + Machine => true, + Supervisor => false, /* TODO: base this on mseccfg */ + User => false, /* TODO: base this on mseccfg */ + _ => false + } + } +} + function check_CSR(csr : csreg, p : Privilege, isWrite : bool) -> bool = is_CSR_defined(csr, p) & check_CSR_access(csrAccess(csr), csrPriv(csr), p, isWrite) & check_TVM_SATP(csr, p) & check_Counteren(csr, p) + & check_seed_CSR(csr, p, isWrite) /* Reservation handling for LR/SC. * diff --git a/model/riscv_sys_regs.sail b/model/riscv_sys_regs.sail index 2083f03..696d154 100644 --- a/model/riscv_sys_regs.sail +++ b/model/riscv_sys_regs.sail @@ -180,6 +180,16 @@ function haveUsrMode() -> bool = misa.U() == 0b1 function haveNExt() -> bool = misa.N() == 0b1 /* see below for F and D extension tests */ +/* Cryptography extension support. Note these will need updating once */ +/* Sail can be dynamically configured with different extension support */ +/* and have dynamic changes of XLEN via S/UXL */ +function haveZkr() -> bool = true +function haveZksh() -> bool = true +function haveZksed() -> bool = true +function haveZknh() -> bool = true +function haveZkne() -> bool = true +function haveZknd() -> bool = true + bitfield Mstatush : bits(32) = { MBE : 5, SBE : 4 @@ -736,3 +746,56 @@ function legalize_satp32(a : Architecture, o : bits(32), v : bits(32)) -> bits(3 /* disabled trigger/debug module */ register tselect : xlenbits + +/* + * The seed CSR (entropy source) + * ------------------------------------------------------------ + */ + +/* Valid return states for reading the seed CSR. */ +enum seed_opst = { + BIST, // Built-in-self-test. No randomness sampled. + ES16, // Entropy-sample-16. Valid 16-bits of randomness sampled. + WAIT, // Device still gathering entropy. + DEAD // Fatal device compromise. No randomness sampled. +} + +/* Mapping of status codes and their actual encodings. */ +mapping opst_code : seed_opst <-> bits(2) = { + BIST <-> 0b00, + WAIT <-> 0b01, + ES16 <-> 0b10, + DEAD <-> 0b11 +} + +/* + * Entropy Source - Platform access to random bits. + * WARNING: This function currently lacks a proper side-effect annotation. + * If you are using theorem prover tool flows, you + * may need to modify or stub out this function for now. + * NOTE: This would be better placed in riscv_platform.sail, but that file + * appears _after_ this one in the compile order meaning the valspec + * for this function is unavailable when it's first encountered in + * read_seed_csr. Hence it appears here. + */ +val get_16_random_bits = { + ocaml: "Platform.get_16_random_bits", + interpreter: "Platform.get_16_random_bits", + c: "plat_get_16_random_bits", + lem: "plat_get_16_random_bits" +} : unit -> bits(16) + +/* Entropy source spec requires an Illegal opcode exception be raised if the + * seed register is read without also being written. This function is only + * called once we know the CSR is being written, and all other access control + * checks have been done. + */ +function read_seed_csr() -> xlenbits = { + let reserved_bits : bits(6) = 0b000000; + let custom_bits : bits(8) = 0x00; + let seed : bits(16) = get_16_random_bits(); + EXTZ(opst_code(ES16) @ reserved_bits @ custom_bits @ seed) +} + +/* Writes to the seed CSR are ignored */ +function write_seed_csr () -> option(xlenbits) = None() diff --git a/model/riscv_types_kext.sail b/model/riscv_types_kext.sail new file mode 100644 index 0000000..9c0db6d --- /dev/null +++ b/model/riscv_types_kext.sail @@ -0,0 +1,352 @@ +/* + * This file contains types, mappings and functions used across the + * cryptography extension instructions. + * + * This file must be included in the model build whatever the value of XLEN. + */ + +/* + * Cryptography extension shared / utility functions + * ---------------------------------------------------------------------- + */ + +/* Auxiliary function for performing GF multiplicaiton */ +val xt2 : bits(8) -> bits(8) +function xt2(x) = { + (x << 1) ^ (if bit_to_bool(x[7]) then 0x1b else 0x00) +} + +val xt3 : bits(8) -> bits(8) +function xt3(x) = x ^ xt2(x) + +/* Multiply 8-bit field element by 4-bit value for AES MixCols step */ +val gfmul : (bits(8), bits(4)) -> bits(8) +function gfmul( x, y) = { + (if bit_to_bool(y[0]) then x else 0x00) ^ + (if bit_to_bool(y[1]) then xt2( x) else 0x00) ^ + (if bit_to_bool(y[2]) then xt2(xt2( x)) else 0x00) ^ + (if bit_to_bool(y[3]) then xt2(xt2(xt2(x))) else 0x00) +} + +/* 8-bit to 32-bit partial AES Mix Colum - forwards */ +val aes_mixcolumn_byte_fwd : bits(8) -> bits(32) +function aes_mixcolumn_byte_fwd(so) = { + gfmul(so, 0x3) @ so @ so @ gfmul(so, 0x2) +} + +/* 8-bit to 32-bit partial AES Mix Colum - inverse*/ +val aes_mixcolumn_byte_inv : bits(8) -> bits(32) +function aes_mixcolumn_byte_inv(so) = { + gfmul(so, 0xb) @ gfmul(so, 0xd) @ gfmul(so, 0x9) @ gfmul(so, 0xe) +} + +/* 32-bit to 32-bit AES forward MixColumn */ +val aes_mixcolumn_fwd : bits(32) -> bits(32) +function aes_mixcolumn_fwd(x) = { + let s0 : bits (8) = x[ 7.. 0]; + let s1 : bits (8) = x[15.. 8]; + let s2 : bits (8) = x[23..16]; + let s3 : bits (8) = x[31..24]; + let b0 : bits (8) = xt2(s0) ^ xt3(s1) ^ (s2) ^ (s3); + let b1 : bits (8) = (s0) ^ xt2(s1) ^ xt3(s2) ^ (s3); + let b2 : bits (8) = (s0) ^ (s1) ^ xt2(s2) ^ xt3(s3); + let b3 : bits (8) = xt3(s0) ^ (s1) ^ (s2) ^ xt2(s3); + b3 @ b2 @ b1 @ b0 /* Return value */ +} + +/* 32-bit to 32-bit AES inverse MixColumn */ +val aes_mixcolumn_inv : bits(32) -> bits(32) +function aes_mixcolumn_inv(x) = { + let s0 : bits (8) = x[ 7.. 0]; + let s1 : bits (8) = x[15.. 8]; + let s2 : bits (8) = x[23..16]; + let s3 : bits (8) = x[31..24]; + let b0 : bits (8) = gfmul(s0, 0xE) ^ gfmul(s1, 0xB) ^ gfmul(s2, 0xD) ^ gfmul(s3, 0x9); + let b1 : bits (8) = gfmul(s0, 0x9) ^ gfmul(s1, 0xE) ^ gfmul(s2, 0xB) ^ gfmul(s3, 0xD); + let b2 : bits (8) = gfmul(s0, 0xD) ^ gfmul(s1, 0x9) ^ gfmul(s2, 0xE) ^ gfmul(s3, 0xB); + let b3 : bits (8) = gfmul(s0, 0xB) ^ gfmul(s1, 0xD) ^ gfmul(s2, 0x9) ^ gfmul(s3, 0xE); + b3 @ b2 @ b1 @ b0 /* Return value */ +} + +val aes_decode_rcon : bits(4) -> bits(32) +function aes_decode_rcon(r) = { + match r { + 0x0 => 0x00000001, + 0x1 => 0x00000002, + 0x2 => 0x00000004, + 0x3 => 0x00000008, + 0x4 => 0x00000010, + 0x5 => 0x00000020, + 0x6 => 0x00000040, + 0x7 => 0x00000080, + 0x8 => 0x0000001b, + 0x9 => 0x00000036, + 0xA => 0x00000000, + 0xB => 0x00000000, + 0xC => 0x00000000, + 0xD => 0x00000000, + 0xE => 0x00000000, + 0xF => 0x00000000 + } +} + +/* SM4 SBox - only one sbox for forwards and inverse */ +let sm4_sbox_table : list(bits(8)) = [| +0xD6, 0x90, 0xE9, 0xFE, 0xCC, 0xE1, 0x3D, 0xB7, 0x16, 0xB6, 0x14, 0xC2, 0x28, +0xFB, 0x2C, 0x05, 0x2B, 0x67, 0x9A, 0x76, 0x2A, 0xBE, 0x04, 0xC3, 0xAA, 0x44, +0x13, 0x26, 0x49, 0x86, 0x06, 0x99, 0x9C, 0x42, 0x50, 0xF4, 0x91, 0xEF, 0x98, +0x7A, 0x33, 0x54, 0x0B, 0x43, 0xED, 0xCF, 0xAC, 0x62, 0xE4, 0xB3, 0x1C, 0xA9, +0xC9, 0x08, 0xE8, 0x95, 0x80, 0xDF, 0x94, 0xFA, 0x75, 0x8F, 0x3F, 0xA6, 0x47, +0x07, 0xA7, 0xFC, 0xF3, 0x73, 0x17, 0xBA, 0x83, 0x59, 0x3C, 0x19, 0xE6, 0x85, +0x4F, 0xA8, 0x68, 0x6B, 0x81, 0xB2, 0x71, 0x64, 0xDA, 0x8B, 0xF8, 0xEB, 0x0F, +0x4B, 0x70, 0x56, 0x9D, 0x35, 0x1E, 0x24, 0x0E, 0x5E, 0x63, 0x58, 0xD1, 0xA2, +0x25, 0x22, 0x7C, 0x3B, 0x01, 0x21, 0x78, 0x87, 0xD4, 0x00, 0x46, 0x57, 0x9F, +0xD3, 0x27, 0x52, 0x4C, 0x36, 0x02, 0xE7, 0xA0, 0xC4, 0xC8, 0x9E, 0xEA, 0xBF, +0x8A, 0xD2, 0x40, 0xC7, 0x38, 0xB5, 0xA3, 0xF7, 0xF2, 0xCE, 0xF9, 0x61, 0x15, +0xA1, 0xE0, 0xAE, 0x5D, 0xA4, 0x9B, 0x34, 0x1A, 0x55, 0xAD, 0x93, 0x32, 0x30, +0xF5, 0x8C, 0xB1, 0xE3, 0x1D, 0xF6, 0xE2, 0x2E, 0x82, 0x66, 0xCA, 0x60, 0xC0, +0x29, 0x23, 0xAB, 0x0D, 0x53, 0x4E, 0x6F, 0xD5, 0xDB, 0x37, 0x45, 0xDE, 0xFD, +0x8E, 0x2F, 0x03, 0xFF, 0x6A, 0x72, 0x6D, 0x6C, 0x5B, 0x51, 0x8D, 0x1B, 0xAF, +0x92, 0xBB, 0xDD, 0xBC, 0x7F, 0x11, 0xD9, 0x5C, 0x41, 0x1F, 0x10, 0x5A, 0xD8, +0x0A, 0xC1, 0x31, 0x88, 0xA5, 0xCD, 0x7B, 0xBD, 0x2D, 0x74, 0xD0, 0x12, 0xB8, +0xE5, 0xB4, 0xB0, 0x89, 0x69, 0x97, 0x4A, 0x0C, 0x96, 0x77, 0x7E, 0x65, 0xB9, +0xF1, 0x09, 0xC5, 0x6E, 0xC6, 0x84, 0x18, 0xF0, 0x7D, 0xEC, 0x3A, 0xDC, 0x4D, +0x20, 0x79, 0xEE, 0x5F, 0x3E, 0xD7, 0xCB, 0x39, 0x48 +|] + +let aes_sbox_fwd_table : list(bits(8)) = [| +0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, +0xd7, 0xab, 0x76, 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, +0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, +0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, 0x04, 0xc7, 0x23, 0xc3, +0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, 0x09, +0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, +0x2f, 0x84, 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, +0x39, 0x4a, 0x4c, 0x58, 0xcf, 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, +0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, 0x51, 0xa3, 0x40, 0x8f, 0x92, +0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, 0xcd, 0x0c, +0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, +0x73, 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, +0xde, 0x5e, 0x0b, 0xdb, 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, +0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, +0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, 0xba, 0x78, 0x25, +0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, +0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, +0xc1, 0x1d, 0x9e, 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, +0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, +0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16 +|] + +let aes_sbox_inv_table : list(bits(8)) = [| +0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, +0xf3, 0xd7, 0xfb, 0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, +0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb, 0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, +0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e, 0x08, 0x2e, 0xa1, 0x66, +0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25, 0x72, +0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, +0xb6, 0x92, 0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, +0x57, 0xa7, 0x8d, 0x9d, 0x84, 0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, +0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06, 0xd0, 0x2c, 0x1e, 0x8f, 0xca, +0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b, 0x3a, 0x91, +0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, +0x73, 0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, +0x1c, 0x75, 0xdf, 0x6e, 0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, +0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b, 0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, +0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4, 0x1f, 0xdd, 0xa8, +0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f, +0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, +0xc9, 0x9c, 0xef, 0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, +0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61, 0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, +0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d +|] + +/* Lookup function - takes an index and a list, and retrieves the + * x'th element of that list. + */ +val sbox_lookup : (bits(8), list(bits(8))) -> bits(8) +function sbox_lookup(x, table) = { + match (x, table) { + (0x00, t0::tn) => t0, + ( y, t0::tn) => sbox_lookup(x - 0x01, tn) + } +} + +/* Easy function to perform a forward AES SBox operation on 1 byte. */ +val aes_sbox_fwd : bits(8) -> bits(8) +function aes_sbox_fwd(x) = sbox_lookup(x, aes_sbox_fwd_table) + +/* Easy function to perform an inverse AES SBox operation on 1 byte. */ +val aes_sbox_inv : bits(8) -> bits(8) +function aes_sbox_inv(x) = sbox_lookup(x, aes_sbox_inv_table) + +/* AES SubWord function used in the key expansion + * - Applies the forward sbox to each byte in the input word. + */ +val aes_subword_fwd : bits(32) -> bits(32) +function aes_subword_fwd(x) = { + aes_sbox_fwd(x[31..24]) @ + aes_sbox_fwd(x[23..16]) @ + aes_sbox_fwd(x[15.. 8]) @ + aes_sbox_fwd(x[ 7.. 0]) +} + +/* AES Inverse SubWord function. + * - Applies the inverse sbox to each byte in the input word. + */ +val aes_subword_inv : bits(32) -> bits(32) +function aes_subword_inv(x) = { + aes_sbox_inv(x[31..24]) @ + aes_sbox_inv(x[23..16]) @ + aes_sbox_inv(x[15.. 8]) @ + aes_sbox_inv(x[ 7.. 0]) +} + +/* Easy function to perform an SM4 SBox operation on 1 byte. */ +val sm4_sbox : bits(8) -> bits(8) +function sm4_sbox(x) = sbox_lookup(x, sm4_sbox_table) + +val aes_get_column : (bits(128), nat) -> bits(32) +function aes_get_column(state,c) = (state >> (to_bits(7, 32 * c)))[31..0] + +/* 64-bit to 64-bit function which applies the AES forward sbox to each byte + * in a 64-bit word. + */ +val aes_apply_fwd_sbox_to_each_byte : bits(64) -> bits(64) +function aes_apply_fwd_sbox_to_each_byte(x) = { + aes_sbox_fwd(x[63..56]) @ + aes_sbox_fwd(x[55..48]) @ + aes_sbox_fwd(x[47..40]) @ + aes_sbox_fwd(x[39..32]) @ + aes_sbox_fwd(x[31..24]) @ + aes_sbox_fwd(x[23..16]) @ + aes_sbox_fwd(x[15.. 8]) @ + aes_sbox_fwd(x[ 7.. 0]) +} + +/* 64-bit to 64-bit function which applies the AES inverse sbox to each byte + * in a 64-bit word. + */ +val aes_apply_inv_sbox_to_each_byte : bits(64) -> bits(64) +function aes_apply_inv_sbox_to_each_byte(x) = { + aes_sbox_inv(x[63..56]) @ + aes_sbox_inv(x[55..48]) @ + aes_sbox_inv(x[47..40]) @ + aes_sbox_inv(x[39..32]) @ + aes_sbox_inv(x[31..24]) @ + aes_sbox_inv(x[23..16]) @ + aes_sbox_inv(x[15.. 8]) @ + aes_sbox_inv(x[ 7.. 0]) +} + +/* + * AES full-round transformation functions. + */ + +val getbyte : (bits(64), int) -> bits(8) +function getbyte(x, i) = (x >> to_bits(6, i * 8))[7..0] + +val aes_rv64_shiftrows_fwd : (bits(64), bits(64)) -> bits(64) +function aes_rv64_shiftrows_fwd(rs2, rs1) = { + getbyte(rs1, 3) @ + getbyte(rs2, 6) @ + getbyte(rs2, 1) @ + getbyte(rs1, 4) @ + getbyte(rs2, 7) @ + getbyte(rs2, 2) @ + getbyte(rs1, 5) @ + getbyte(rs1, 0) +} + +val aes_rv64_shiftrows_inv : (bits(64), bits(64)) -> bits(64) +function aes_rv64_shiftrows_inv(rs2, rs1) = { + getbyte(rs2, 3) @ + getbyte(rs2, 6) @ + getbyte(rs1, 1) @ + getbyte(rs1, 4) @ + getbyte(rs1, 7) @ + getbyte(rs2, 2) @ + getbyte(rs2, 5) @ + getbyte(rs1, 0) +} + +/* 128-bit to 128-bit implementation of the forward AES ShiftRows transform. + * Byte 0 of state is input column 0, bits 7..0. + * Byte 5 of state is input column 1, bits 15..8. + */ +val aes_shift_rows_fwd : bits(128) -> bits(128) +function aes_shift_rows_fwd(x) = { + let ic3 : bits(32) = aes_get_column(x, 3); + let ic2 : bits(32) = aes_get_column(x, 2); + let ic1 : bits(32) = aes_get_column(x, 1); + let ic0 : bits(32) = aes_get_column(x, 0); + let oc0 : bits(32) = ic0[31..24] @ ic1[23..16] @ ic2[15.. 8] @ ic3[ 7.. 0]; + let oc1 : bits(32) = ic1[31..24] @ ic2[23..16] @ ic3[15.. 8] @ ic0[ 7.. 0]; + let oc2 : bits(32) = ic2[31..24] @ ic3[23..16] @ ic0[15.. 8] @ ic1[ 7.. 0]; + let oc3 : bits(32) = ic3[31..24] @ ic0[23..16] @ ic1[15.. 8] @ ic2[ 7.. 0]; + (oc3 @ oc2 @ oc1 @ oc0) /* Return value */ +} + +/* 128-bit to 128-bit implementation of the inverse AES ShiftRows transform. + * Byte 0 of state is input column 0, bits 7..0. + * Byte 5 of state is input column 1, bits 15..8. + */ +val aes_shift_rows_inv : bits(128) -> bits(128) +function aes_shift_rows_inv(x) = { + let ic3 : bits(32) = aes_get_column(x, 3); /* In column 3 */ + let ic2 : bits(32) = aes_get_column(x, 2); + let ic1 : bits(32) = aes_get_column(x, 1); + let ic0 : bits(32) = aes_get_column(x, 0); + let oc0 : bits(32) = ic0[31..24] @ ic3[23..16] @ ic2[15.. 8] @ ic1[ 7.. 0]; + let oc1 : bits(32) = ic1[31..24] @ ic0[23..16] @ ic3[15.. 8] @ ic2[ 7.. 0]; + let oc2 : bits(32) = ic2[31..24] @ ic1[23..16] @ ic0[15.. 8] @ ic3[ 7.. 0]; + let oc3 : bits(32) = ic3[31..24] @ ic2[23..16] @ ic1[15.. 8] @ ic0[ 7.. 0]; + (oc3 @ oc2 @ oc1 @ oc0) /* Return value */ +} + +/* Applies the forward sub-bytes step of AES to a 128-bit vector + * representation of its state. + */ +val aes_subbytes_fwd : bits(128) -> bits(128) +function aes_subbytes_fwd(x) = { + let oc0 : bits(32) = aes_subword_fwd(aes_get_column(x, 0)); + let oc1 : bits(32) = aes_subword_fwd(aes_get_column(x, 1)); + let oc2 : bits(32) = aes_subword_fwd(aes_get_column(x, 2)); + let oc3 : bits(32) = aes_subword_fwd(aes_get_column(x, 3)); + (oc3 @ oc2 @ oc1 @ oc0) /* Return value */ +} + +/* Applies the inverse sub-bytes step of AES to a 128-bit vector + * representation of its state. + */ +val aes_subbytes_inv : bits(128) -> bits(128) +function aes_subbytes_inv(x) = { + let oc0 : bits(32) = aes_subword_inv(aes_get_column(x, 0)); + let oc1 : bits(32) = aes_subword_inv(aes_get_column(x, 1)); + let oc2 : bits(32) = aes_subword_inv(aes_get_column(x, 2)); + let oc3 : bits(32) = aes_subword_inv(aes_get_column(x, 3)); + (oc3 @ oc2 @ oc1 @ oc0) /* Return value */ +} + +/* Applies the forward MixColumns step of AES to a 128-bit vector + * representation of its state. + */ +val aes_mixcolumns_fwd : bits(128) -> bits(128) +function aes_mixcolumns_fwd(x) = { + let oc0 : bits(32) = aes_mixcolumn_fwd(aes_get_column(x, 0)); + let oc1 : bits(32) = aes_mixcolumn_fwd(aes_get_column(x, 1)); + let oc2 : bits(32) = aes_mixcolumn_fwd(aes_get_column(x, 2)); + let oc3 : bits(32) = aes_mixcolumn_fwd(aes_get_column(x, 3)); + (oc3 @ oc2 @ oc1 @ oc0) /* Return value */ +} + +/* Applies the inverse MixColumns step of AES to a 128-bit vector + * representation of its state. + */ +val aes_mixcolumns_inv : bits(128) -> bits(128) +function aes_mixcolumns_inv(x) = { + let oc0 : bits(32) = aes_mixcolumn_inv(aes_get_column(x, 0)); + let oc1 : bits(32) = aes_mixcolumn_inv(aes_get_column(x, 1)); + let oc2 : bits(32) = aes_mixcolumn_inv(aes_get_column(x, 2)); + let oc3 : bits(32) = aes_mixcolumn_inv(aes_get_column(x, 3)); + (oc3 @ oc2 @ oc1 @ oc0) /* Return value */ +} diff --git a/ocaml_emulator/platform.ml b/ocaml_emulator/platform.ml index 4664f73..81e33da 100644 --- a/ocaml_emulator/platform.ml +++ b/ocaml_emulator/platform.ml @@ -97,6 +97,11 @@ let insns_per_tick () = Big_int.of_int P.insns_per_tick let htif_tohost () = arch_bits_of_int64 (Big_int.to_int64 (Elf.elf_tohost ())) +(* Entropy Source - get random bits *) + +(* This function can be changed to support deterministic sequences of + pseudo-random bytes. This is useful for testing. *) +let get_16_random_bits () = arch_bits_of_int (Random.int 0xFFFF) (* load reservation *) -- cgit v1.1