diff options
author | IIITM-Jay <jaydev.neuroscitech@gmail.com> | 2024-10-25 01:43:58 +0530 |
---|---|---|
committer | IIITM-Jay <jaydev.neuroscitech@gmail.com> | 2024-10-25 01:43:58 +0530 |
commit | d57a94cf8eb07917a2084b6d502b9225fc9ce210 (patch) | |
tree | ef59527e0beba91314303c5732354fd88e582dcd /shared_utils.py | |
parent | 0b7f6180f893ec7cddaa6fcf55c7a7e3969ccd17 (diff) | |
download | riscv-opcodes-d57a94cf8eb07917a2084b6d502b9225fc9ce210.zip riscv-opcodes-d57a94cf8eb07917a2084b6d502b9225fc9ce210.tar.gz riscv-opcodes-d57a94cf8eb07917a2084b6d502b9225fc9ce210.tar.bz2 |
clean up codes for refactoring parsing logic
Diffstat (limited to 'shared_utils.py')
-rw-r--r-- | shared_utils.py | 481 |
1 files changed, 279 insertions, 202 deletions
diff --git a/shared_utils.py b/shared_utils.py index 8c081e2..5c92515 100644 --- a/shared_utils.py +++ b/shared_utils.py @@ -5,6 +5,7 @@ import logging import os import pprint import re +from itertools import chain from constants import * @@ -15,290 +16,353 @@ pretty_printer = pprint.PrettyPrinter(indent=2) logging.basicConfig(level=LOG_LEVEL, format=LOG_FORMAT) -def process_enc_line(line, ext): - """ - This function processes each line of the encoding files (rv*). As part of - the processing, the function ensures that the encoding is legal through the - following checks:: - - - there is no over specification (same bits assigned different values) - - there is no under specification (some bits not assigned values) - - bit ranges are in the format hi..lo=val where hi > lo - - value assigned is representable in the bit range - - also checks that the mapping of arguments of an instruction exists in - arg_lut. - - If the above checks pass, then the function returns a tuple of the name and - a dictionary containing basic information of the instruction which includes: - - variables: list of arguments used by the instruction whose mapping - exists in the arg_lut dictionary - - encoding: this contains the 32-bit encoding of the instruction where - '-' is used to represent position of arguments and 1/0 is used to - reprsent the static encoding of the bits - - extension: this field contains the rv* filename from which this - instruction was included - - match: hex value representing the bits that need to match to detect - this instruction - - mask: hex value representin the bits that need to be masked to extract - the value required for matching. - """ - encoding = initialize_encoding() - name, remaining = parse_instruction_name(line) - - # Fixed ranges of the form hi..lo=val - process_fixed_ranges(remaining, encoding, line) - - # Single fixed values of the form <lsb>=<val> - remaining = process_single_fixed(remaining, encoding, line) +# Initialize encoding to 32-bit '-' values +def initialize_encoding(bits=32): + """Initialize encoding with '-' to represent don't care bits.""" + return ["-"] * bits - # Create match and mask strings - match, mask = create_match_and_mask(encoding) - - # Process instruction arguments - args = process_arguments(remaining, encoding, name) - - # Create and return the final instruction dictionary - instruction_dict = create_instruction_dict(encoding, args, ext, match, mask) - - return name, instruction_dict +# Validate bit range and value +def validate_bit_range(msb, lsb, entry_value, line): + """Validate the bit range and entry value.""" + if msb < lsb: + logging.error( + f'{line.split(" ")[0]:<10} has position {msb} less than position {lsb} in its encoding' + ) + raise SystemExit(1) -def initialize_encoding(): - """Initialize a 32-bit encoding with '-' representing 'don't care'.""" - return ["-"] * 32 + if entry_value >= (1 << (msb - lsb + 1)): + logging.error( + f'{line.split(" ")[0]:<10} has an illegal value {entry_value} assigned as per the bit width {msb - lsb}' + ) + raise SystemExit(1) -def parse_instruction_name(line): - """Extract the instruction name and remaining part of the line.""" +# Split the instruction line into name and remaining part +def parse_instruction_line(line): + """Parse the instruction name and the remaining encoding details.""" name, remaining = line.split(" ", 1) - name = name.replace(".", "_").lstrip() + name = name.replace(".", "_") # Replace dots for compatibility + remaining = remaining.lstrip() # Remove leading whitespace return name, remaining -def process_fixed_ranges(remaining, encoding, line): - """Process bit ranges of the form hi..lo=val, checking for errors and updating encoding.""" - for s2, s1, entry in fixed_ranges.findall(remaining): - msb, lsb = int(s2), int(s1) - validate_bit_range(msb, lsb, line) - validate_entry_value(msb, lsb, entry, line) - update_encoding(msb, lsb, entry, encoding, line) +# Verify Overlapping Bits +def check_overlapping_bits(encoding, ind, line): + """Check for overlapping bits in the encoding.""" + if encoding[31 - ind] != "-": + logging.error( + f'{line.split(" ")[0]:<10} has {ind} bit overlapping in its opcodes' + ) + raise SystemExit(1) -def validate_bit_range(msb, lsb, line): - """Ensure that msb > lsb and raise an error if not.""" - if msb < lsb: - log_and_exit(f"{get_instruction_name(line)} has msb < lsb in its encoding") +# Update encoding for fixed ranges +def update_encoding_for_fixed_range(encoding, msb, lsb, entry_value, line): + """ + Update encoding bits for a given bit range. + Checks for overlapping bits and assigns the value accordingly. + """ + for ind in range(lsb, msb + 1): + check_overlapping_bits(encoding, ind, line) + bit = str((entry_value >> (ind - lsb)) & 1) + encoding[31 - ind] = bit -def validate_entry_value(msb, lsb, entry, line): - """Ensure that the value assigned to a bit range is legal for its width.""" - entry_value = int(entry, 0) - if entry_value >= (1 << (msb - lsb + 1)): - log_and_exit( - f"{get_instruction_name(line)} has an illegal value for the bit width {msb - lsb}" - ) +# Process fixed bit patterns +def process_fixed_ranges(remaining, encoding, line): + """Process fixed bit ranges in the encoding.""" + for s2, s1, entry in fixed_ranges.findall(remaining): + msb, lsb, entry_value = int(s2), int(s1), int(entry, 0) + # Validate bit range and entry value + validate_bit_range(msb, lsb, entry_value, line) + update_encoding_for_fixed_range(encoding, msb, lsb, entry_value, line) -def update_encoding(msb, lsb, entry, encoding, line): - """Update the encoding array for a given bit range.""" - entry_value = int(entry, 0) - for ind in range(lsb, msb + 1): - if encoding[31 - ind] != "-": - log_and_exit( - f"{get_instruction_name(line)} has overlapping bits in its opcodes" - ) - encoding[31 - ind] = str((entry_value >> (ind - lsb)) & 1) + return fixed_ranges.sub(" ", remaining) +# Process single bit assignments def process_single_fixed(remaining, encoding, line): - """Process single fixed values of the form <lsb>=<val>.""" - for lsb, value, _ in single_fixed.findall(remaining): + """Process single fixed assignments in the encoding.""" + for lsb, value, drop in single_fixed.findall(remaining): lsb = int(lsb, 0) value = int(value, 0) - if encoding[31 - lsb] != "-": - log_and_exit( - f"{get_instruction_name(line)} has overlapping bits in its opcodes" - ) - encoding[31 - lsb] = str(value) - return fixed_ranges.sub(" ", remaining) - -def create_match_and_mask(encoding): - """Generate match and mask strings from the encoding array.""" - match = "".join(encoding).replace("-", "0") - mask = "".join(encoding).replace("0", "1").replace("-", "0") - return match, mask + check_overlapping_bits(encoding, lsb, line) + encoding[31 - lsb] = str(value) -def process_arguments(remaining, encoding, name): - """Process instruction arguments and update the encoding with argument positions.""" - args = single_fixed.sub(" ", remaining).split() - encoding_args = encoding.copy() +# Main function to check argument look-up table +def check_arg_lut(args, encoding_args, name): + """Check if arguments are present in arg_lut.""" for arg in args: if arg not in arg_lut: - handle_missing_arg(arg, name) + arg = handle_arg_lut_mapping(arg, name) msb, lsb = arg_lut[arg] - update_arg_encoding(msb, lsb, arg, encoding_args, name) - return args, encoding_args + update_encoding_args(encoding_args, arg, msb, lsb) -def handle_missing_arg(arg, name): - """Handle missing argument mapping in arg_lut.""" - if "=" in arg: - existing_arg = arg.split("=")[0] +# Handle missing argument mappings +def handle_arg_lut_mapping(arg, name): + """Handle cases where an argument needs to be mapped to an existing one.""" + parts = arg.split("=") + if len(parts) == 2: + existing_arg, new_arg = parts if existing_arg in arg_lut: arg_lut[arg] = arg_lut[existing_arg] - return - log_and_exit(f"Variable {arg} in instruction {name} not mapped in arg_lut") + else: + logging.error( + f" Found field {existing_arg} in variable {arg} in instruction {name} " + f"whose mapping in arg_lut does not exist" + ) + raise SystemExit(1) + else: + logging.error( + f" Found variable {arg} in instruction {name} " + f"whose mapping in arg_lut does not exist" + ) + raise SystemExit(1) + return arg -def update_arg_encoding(msb, lsb, arg, encoding_args, name): - """Update the encoding array with the argument positions.""" +# Update encoding args with variables +def update_encoding_args(encoding_args, arg, msb, lsb): + """Update encoding arguments and ensure no overlapping.""" for ind in range(lsb, msb + 1): - if encoding_args[31 - ind] != "-": - log_and_exit(f"Variable {arg} overlaps in bit {ind} in instruction {name}") + check_overlapping_bits(encoding_args, ind, arg) encoding_args[31 - ind] = arg -def create_instruction_dict(encoding, args, ext, match, mask): - """Create the final dictionary for the instruction.""" - return { +# Compute match and mask +def convert_encoding_to_match_mask(encoding): + """Convert the encoding list to match and mask strings.""" + match = "".join(encoding).replace("-", "0") + mask = "".join(encoding).replace("0", "1").replace("-", "0") + return hex(int(match, 2)), hex(int(mask, 2)) + + +# Processing main function for a line in the encoding file +def process_enc_line(line, ext): + """ + This function processes each line of the encoding files (rv*). As part of + the processing, the function ensures that the encoding is legal through the + following checks:: + - there is no over specification (same bits assigned different values) + - there is no under specification (some bits not assigned values) + - bit ranges are in the format hi..lo=val where hi > lo + - value assigned is representable in the bit range + - also checks that the mapping of arguments of an instruction exists in + arg_lut. + If the above checks pass, then the function returns a tuple of the name and + a dictionary containing basic information of the instruction which includes: + - variables: list of arguments used by the instruction whose mapping + exists in the arg_lut dictionary + - encoding: this contains the 32-bit encoding of the instruction where + '-' is used to represent position of arguments and 1/0 is used to + reprsent the static encoding of the bits + - extension: this field contains the rv* filename from which this + instruction was included + - match: hex value representing the bits that need to match to detect + this instruction + - mask: hex value representin the bits that need to be masked to extract + the value required for matching. + """ + encoding = initialize_encoding() + + # Parse the instruction line + name, remaining = parse_instruction_line(line) + + # Process fixed ranges + remaining = process_fixed_ranges(remaining, encoding, line) + + # Process single fixed assignments + process_single_fixed(remaining, encoding, line) + + # Convert the list of encodings into a match and mask + match, mask = convert_encoding_to_match_mask(encoding) + + # Check arguments in arg_lut + args = single_fixed.sub(" ", remaining).split() + encoding_args = encoding.copy() + + check_arg_lut(args, encoding_args, name) + + # Return single_dict + return name, { "encoding": "".join(encoding), "variable_fields": args, "extension": [os.path.basename(ext)], - "match": hex(int(match, 2)), - "mask": hex(int(mask, 2)), + "match": match, + "mask": mask, } -def log_and_exit(message): - """Log an error message and exit the program.""" - logging.error(message) - raise SystemExit(1) +# Extract ISA Type +def extract_isa_type(ext_name): + """Extracts the ISA type from the extension name.""" + return ext_name.split("_")[0] -def get_instruction_name(line): - """Helper to extract the instruction name from a line.""" - return line.split(" ")[0] +# Verify the types for RV* +def is_rv_variant(type1, type2): + """Checks if the types are RV variants (rv32/rv64).""" + return (type2 == "rv" and type1 in {"rv32", "rv64"}) or ( + type1 == "rv" and type2 in {"rv32", "rv64"} + ) -def overlaps(x, y): - """ - Check if two bit strings overlap without conflicts. +# Check for same base ISA +def has_same_base_isa(type1, type2): + """Determines if the two ISA types share the same base.""" + return type1 == type2 or is_rv_variant(type1, type2) - Args: - x (str): First bit string. - y (str): Second bit string. - Returns: - bool: True if the bit strings overlap without conflicts, False otherwise. +# Compare the base ISA type of a given extension name against a list of extension names +def same_base_isa(ext_name, ext_name_list): + """Checks if the base ISA type of ext_name matches any in ext_name_list.""" + type1 = extract_isa_type(ext_name) + return any(has_same_base_isa(type1, extract_isa_type(ext)) for ext in ext_name_list) - In the context of RISC-V opcodes, this function ensures that the bit ranges - defined by two different bit strings do not conflict. - """ - # Minimum length of the two strings - min_len = min(len(x), len(y)) +# Pad two strings to equal length +def pad_to_equal_length(str1, str2, pad_char="-"): + """Pads two strings to equal length using the given padding character.""" + max_len = max(len(str1), len(str2)) + return str1.rjust(max_len, pad_char), str2.rjust(max_len, pad_char) - for char_x, char_y in zip(x[:min_len], y[:min_len]): - if char_x != "-" and char_y != "-" and char_x != char_y: - return False - return True +# Check compatibility for two characters +def has_no_conflict(char1, char2): + """Checks if two characters are compatible (either matching or don't-care).""" + return char1 == "-" or char2 == "-" or char1 == char2 -def overlap_allowed(a, x, y): - """ - Check if there is an overlap between keys and values in a dictionary. +# Conflict check between two encoded strings +def overlaps(x, y): + """Checks if two encoded strings overlap without conflict.""" + x, y = pad_to_equal_length(x, y) + return all(has_no_conflict(x[i], y[i]) for i in range(len(x))) - Args: - a (dict): The dictionary where keys are mapped to sets or lists of keys. - x (str): The first key to check. - y (str): The second key to check. - Returns: - bool: True if both (x, y) or (y, x) are present in the dictionary - as described, False otherwise. +# Check presence of keys in dictionary. +def is_in_nested_dict(a, key1, key2): + """Checks if key2 exists in the dictionary under key1.""" + return key1 in a and key2 in a[key1] - This function determines if `x` is a key in the dictionary `a` and - its corresponding value contains `y`, or if `y` is a key and its - corresponding value contains `x`. - """ - return x in a and y in a[x] or y in a and x in a[y] +# Overlap allowance +def overlap_allowed(a, x, y): + """Determines if overlap is allowed between x and y based on nested dictionary checks""" + return is_in_nested_dict(a, x, y) or is_in_nested_dict(a, y, x) -# Checks if overlap between two extensions is allowed +# Check overlap allowance between extensions def extension_overlap_allowed(x, y): + """Checks if overlap is allowed between two extensions using the overlapping_extensions dictionary.""" return overlap_allowed(overlapping_extensions, x, y) -# Checks if overlap between two instructions is allowed +# Check overlap allowance between instructions def instruction_overlap_allowed(x, y): + """Checks if overlap is allowed between two instructions using the overlapping_instructions dictionary.""" return overlap_allowed(overlapping_instructions, x, y) -# Checks if ext_name shares the same base ISA with any in ext_name_list -def same_base_isa(ext_name, ext_name_list): - type1 = ext_name.split("_")[0] - for ext_name1 in ext_name_list: - type2 = ext_name1.split("_")[0] - if ( - type1 == type2 - or (type2 == "rv" and type1 in ["rv32", "rv64"]) - or (type1 == "rv" and type2 in ["rv32", "rv64"]) - ): - return True - return False +# Check 'nf' field +def is_segmented_instruction(instruction): + """Checks if an instruction contains the 'nf' field.""" + return "nf" in instruction["variable_fields"] + +# Expand 'nf' fields +def update_with_expanded_instructions(updated_dict, key, value): + """Expands 'nf' fields in the instruction dictionary and updates it with new instructions.""" + for new_key, new_value in expand_nf_field(key, value): + updated_dict[new_key] = new_value -# Expands instructions with "nf" field in variable_fields, otherwise returns unchanged + +# Process instructions, expanding segmented ones and updating the dictionary def add_segmented_vls_insn(instr_dict): - updated_dict = {} - for k, v in instr_dict.items(): - if "nf" in v["variable_fields"]: - updated_dict.update(expand_nf_field(k, v)) - else: - updated_dict[k] = v - return updated_dict + """Processes instructions, expanding segmented ones and updating the dictionary.""" + # Use dictionary comprehension for efficiency + return dict( + chain.from_iterable( + ( + expand_nf_field(key, value) + if is_segmented_instruction(value) + else [(key, value)] + ) + for key, value in instr_dict.items() + ) + ) -# Expands nf field in instruction name and updates instruction details +# Expand the 'nf' field in the instruction dictionary def expand_nf_field(name, single_dict): + """Validate and prepare the instruction dictionary.""" + validate_nf_field(single_dict, name) + remove_nf_field(single_dict) + update_mask(single_dict) + + name_expand_index = name.find("e") + + # Pre compute the base match value and encoding prefix + base_match = int(single_dict["match"], 16) + encoding_prefix = single_dict["encoding"][3:] + + expanded_instructions = [ + create_expanded_instruction( + name, single_dict, nf, name_expand_index, base_match, encoding_prefix + ) + for nf in range(8) # Range of 0 to 7 + ] + + return expanded_instructions + + +# Validate the presence of 'nf' +def validate_nf_field(single_dict, name): + """Validates the presence of 'nf' in variable fields before expansion.""" if "nf" not in single_dict["variable_fields"]: logging.error(f"Cannot expand nf field for instruction {name}") raise SystemExit(1) - single_dict["variable_fields"].remove("nf") # Remove "nf" from variable fields - single_dict["mask"] = hex( - int(single_dict["mask"], 16) | (0b111 << 29) - ) # Update mask - name_expand_index = name.find("e") - expanded_instructions = [] - for nf in range(8): # Expand nf for values 0 to 7 - new_single_dict = copy.deepcopy(single_dict) - new_single_dict["match"] = hex(int(single_dict["match"], 16) | (nf << 29)) - new_single_dict["encoding"] = format(nf, "03b") + single_dict["encoding"][3:] - new_name = ( - name - if nf == 0 - else f"{name[:name_expand_index]}seg{nf+1}{name[name_expand_index:]}" - ) - expanded_instructions.append((new_name, new_single_dict)) - return expanded_instructions +# Remove 'nf' from variable fields +def remove_nf_field(single_dict): + """Removes 'nf' from variable fields in the instruction dictionary.""" + single_dict["variable_fields"].remove("nf") -# Extracts the extensions used in an instruction dictionary -def instr_dict_2_extensions(instr_dict): - return list({item["extension"][0] for item in instr_dict.values()}) +# Update the mask to include the 'nf' field +def update_mask(single_dict): + """Updates the mask to include the 'nf' field in the instruction dictionary.""" + single_dict["mask"] = hex(int(single_dict["mask"], 16) | 0b111 << 29) -# Returns signed interpretation of a value within a given width -def signed(value, width): - return value if 0 <= value < (1 << (width - 1)) else value - (1 << width) +# Create an expanded instruction +def create_expanded_instruction( + name, single_dict, nf, name_expand_index, base_match, encoding_prefix +): + """Creates an expanded instruction based on 'nf' value.""" + new_single_dict = copy.deepcopy(single_dict) + + # Update match value in one step + new_single_dict["match"] = hex(base_match | (nf << 29)) + new_single_dict["encoding"] = format(nf, "03b") + encoding_prefix + # Construct new instruction name + new_name = ( + name + if nf == 0 + else f"{name[:name_expand_index]}seg{nf + 1}{name[name_expand_index:]}" + ) + return (new_name, new_single_dict) + + +# Return a list of relevant lines from the specified file def read_lines(file): """Reads lines from a file and returns non-blank, non-comment lines.""" with open(file) as fp: @@ -306,6 +370,7 @@ def read_lines(file): return [line for line in lines if line and not line.startswith("#")] +# Update the instruction dictionary def process_standard_instructions(lines, instr_dict, file_name): """Processes standard instructions from the given lines and updates the instruction dictionary.""" for line in lines: @@ -342,6 +407,7 @@ def process_standard_instructions(lines, instr_dict, file_name): instr_dict[name] = single_dict +# Incorporate pseudo instructions into the instruction dictionary based on given conditions def process_pseudo_instructions( lines, instr_dict, file_name, opcodes_dir, include_pseudo, include_pseudo_ops ): @@ -371,6 +437,7 @@ def process_pseudo_instructions( instr_dict[name]["extension"].extend(single_dict["extension"]) +# Integrate imported instructions into the instruction dictionary def process_imported_instructions(lines, instr_dict, file_name, opcodes_dir): """Processes imported instructions from the given lines and updates the instruction dictionary.""" for line in lines: @@ -396,6 +463,7 @@ def process_imported_instructions(lines, instr_dict, file_name, opcodes_dir): break +# Locate the path of the specified extension file, checking fallback directories def find_extension_file(ext, opcodes_dir): """Finds the extension file path, considering the unratified directory if necessary.""" ext_file = f"{opcodes_dir}/{ext}" @@ -406,6 +474,7 @@ def find_extension_file(ext, opcodes_dir): return ext_file +# Confirm the presence of an original instruction in the corresponding extension file. def validate_instruction_in_extension(inst, ext_file, file_name, pseudo_inst): """Validates if the original instruction exists in the dependent extension.""" found = False @@ -419,15 +488,14 @@ def validate_instruction_in_extension(inst, ext_file, file_name, pseudo_inst): ) +# Construct a dictionary of instructions filtered by specified criteria def create_inst_dict(file_filter, include_pseudo=False, include_pseudo_ops=[]): """Creates a dictionary of instructions based on the provided file filters.""" """ This function return a dictionary containing all instructions associated with an extension defined by the file_filter input. - Allowed input extensions: needs to be rv* file name without the 'rv' prefix i.e. '_i', '32_i', etc. - Each node of the dictionary will correspond to an instruction which again is a dictionary. The dictionary contents of each instruction includes: - variables: list of arguments used by the instruction whose mapping @@ -441,7 +509,6 @@ def create_inst_dict(file_filter, include_pseudo=False, include_pseudo_ops=[]): this instruction - mask: hex value representin the bits that need to be masked to extract the value required for matching. - In order to build this dictionary, the function does 2 passes over the same rv<file_filter> file: - First pass: extracts all standard instructions, skipping pseudo ops @@ -489,3 +556,13 @@ def create_inst_dict(file_filter, include_pseudo=False, include_pseudo_ops=[]): process_imported_instructions(lines, instr_dict, file_name, opcodes_dir) return instr_dict + + +# Extracts the extensions used in an instruction dictionary +def instr_dict_2_extensions(instr_dict): + return list({item["extension"][0] for item in instr_dict.values()}) + + +# Returns signed interpretation of a value within a given width +def signed(value, width): + return value if 0 <= value < (1 << (width - 1)) else value - (1 << width) |