#!/usr/bin/env python3

from __future__ import annotations
from dataclasses import dataclass
from instructions import *
from typing import Any, Iterable, Callable, Optional, Tuple, List, Dict
import argparse
import fileinput
import inspect
import re
import sys

# Format accepted by --expects/-e: a comma-separated list of integers.
# Blanks around the values are allowed so the example given in the flag help
# ('1, 2, 3') is accepted; main() strips every value before calling int().
# An optional leading '-' allows negative expected results.
RE_EXPECTS = re.compile(r"^\s*-?[0-9]+\s*(,\s*-?[0-9]+\s*)*$")

# Opcodes dropped while parsing because they are not required to simulate
# the module.
_IGNORED_OPCODES = frozenset(
    {
        "OpCapability",
        "OpMemoryModel",
        "OpExecutionMode",
        "OpExtension",
        "OpSource",
        "OpTypeInt",
        "OpTypeStruct",
        "OpTypeFloat",
        "OpTypeBool",
        "OpTypeVoid",
        "OpTypeFunction",
        "OpTypePointer",
        "OpTypeArray",
    }
)


# Parse the SPIR-V instructions. Some instructions are ignored because
# not required to simulate this module.
# Instructions are to be implemented in instructions.py
def parseInstruction(i):
    """Convert a raw instruction into an instance of its simulator class.

    The class is looked up by opcode name (``i.opcode()``) in the
    ``instructions`` module and constructed from ``i.line``.

    Returns:
        The simulator instruction, or None when the opcode is listed in
        _IGNORED_OPCODES.

    Raises:
        RuntimeError: if ``instructions`` defines nothing for this opcode, or
            if what it defines is not a class.
    """
    if i.opcode() in _IGNORED_OPCODES:
        return None

    try:
        instruction_class = getattr(sys.modules["instructions"], i.opcode())
    except AttributeError:
        # The AttributeError itself carries no extra information.
        raise RuntimeError(f"Unsupported instruction {i}") from None

    if not inspect.isclass(instruction_class):
        raise RuntimeError(
            f"{i} instruction definition is not a class. Did you used 'def' instead of 'class'?"
        )
    return instruction_class(i.line)


def splitInstructions(
    splitType: type, instructions: Iterable[Instruction]
) -> List[List[Instruction]]:
    """Split `instructions` into pieces delimited by instances of `splitType`.

    A delimiter is the first instruction of the piece it opens. Pieces are
    never empty, except for one case: an empty input yields a single empty
    piece (``[[]]``). In particular:
      - two consecutive delimiters produce two pieces, the first one holding
        only the first delimiter;
      - if the first instruction is a delimiter, the first piece begins with
        that delimiter (no empty leading piece is produced).
    """
    blocks: List[List[Instruction]] = [[]]
    for instruction in instructions:
        # Only open a new piece if the current one already has content, so
        # that a leading delimiter does not create an empty first piece.
        if isinstance(instruction, splitType) and blocks[-1]:
            blocks.append([])
        blocks[-1].append(instruction)
    return blocks


# Defines a BasicBlock in the simulator.
# Begins at an OpLabel, and ends with a control-flow instruction.
class BasicBlock:
    """A basic block of the simulated module.

    Begins at an OpLabel and ends with a control-flow instruction. The
    leading OpLabel only provides the block name and is not stored.
    """

    def __init__(self, instructions) -> None:
        assert isinstance(instructions[0], OpLabel)
        # The name of the basic block, which is the register of the leading
        # OpLabel.
        self._name = instructions[0].output_register()
        # The list of instructions belonging to this block (OpLabel excluded).
        self._instructions = instructions[1:]

    # Returns the name of this basic block.
    def name(self):
        return self._name

    # Returns the instruction at index in this basic block.
    def __getitem__(self, index: int) -> Instruction:
        return self._instructions[index]

    # Returns the number of instructions in this basic block, excluding the
    # leading OpLabel.
    def __len__(self):
        return len(self._instructions)

    def dump(self):
        print(f" {self._name}:")
        for instruction in self._instructions:
            print(f" {instruction}")


# Defines a Function in the simulator.
class Function:
    def __init__(self, instructions) -> None:
        assert isinstance(instructions[0], OpFunction)
        assert isinstance(instructions[-1], OpFunctionEnd)
        # The name of the function (name of the register returned by OpFunction).
        self._name: str = instructions[0].output_register()
        # The variables local to this function. In this file they are only
        # used by dump(): OpVariable instructions are removed from the blocks.
        self._variables: List[OpVariable] = [
            x for x in instructions if isinstance(x, OpVariable)
        ]
        body = [x for x in instructions[1:-1] if not isinstance(x, OpVariable)]
        # The list of basic blocks that belongs to this function. An empty
        # body makes splitInstructions return one empty piece; it is skipped
        # because BasicBlock requires a leading OpLabel.
        self._basic_blocks: List[BasicBlock] = [
            BasicBlock(block) for block in splitInstructions(OpLabel, body) if block
        ]

    # Returns the name of this function.
    def name(self) -> str:
        return self._name

    # Returns the basic block at index in this function.
    def __getitem__(self, index: int) -> BasicBlock:
        return self._basic_blocks[index]

    # Returns the index of the basic block with the given name if found,
    # -1 otherwise.
    def get_bb_index(self, name) -> int:
        for index, bb in enumerate(self._basic_blocks):
            if bb.name() == name:
                return index
        return -1

    def dump(self):
        print(" Variables:")
        for var in self._variables:
            print(f" {var}")
        print(" Blocks:")
        for bb in self._basic_blocks:
            bb.dump()


# Represents an instruction pointer in the simulator.
@dataclass
class InstructionPointer:
    # The current function the IP points to.
    function: Function
    # The basic block index in function IP points to.
    basic_block: int
    # The instruction in basic_block IP points to.
    instruction_index: int

    def __str__(self):
        bb = self.function[self.basic_block]
        i = bb[self.instruction_index]
        return f"{bb.name()}:{self.instruction_index} in {self.function.name()} | {i}"

    # IPs are used as dictionary keys by Wave. The function is hashed by name;
    # equal IPs refer to the same Function object, hence the same name.
    def __hash__(self):
        return hash((self.function.name(), self.basic_block, self.instruction_index))

    # Returns the basic block IP points to.
    def bb(self) -> BasicBlock:
        return self.function[self.basic_block]

    # Returns the instruction IP points to.
    def instruction(self):
        return self.function[self.basic_block][self.instruction_index]

    # Increment IP by value. This only works inside a basic-block boundary:
    # incrementing IP past the end of its basic block fails the assertion.
    def __add__(self, value: int):
        bb = self.function[self.basic_block]
        assert len(bb) > self.instruction_index + value
        return InstructionPointer(
            self.function, self.basic_block, self.instruction_index + value
        )


# Defines a Lane in this simulator.
class Lane:
    # The registers known by this lane.
    _registers: Dict[str, Any]
    # The current IP of this lane.
    _ip: Optional[InstructionPointer]
    # Whether this lane is running (False once its outermost call returned).
    _running: bool
    # The wave this lane belongs to.
    _wave: Wave
    # The callstack of this lane. Each tuple represents 1 call.
    # The first element is the IP the function will return to.
    # The second element is the callback to call to store the return value
    # into the correct register.
    _callstack: List[Tuple[InstructionPointer, Callable[[Any], None]]]
    # The index of this lane in the wave.
    _tid: int
    # The last BB this lane was executing into.
    _previous_bb: Optional[BasicBlock]
    # The current BB this lane is executing into.
    _current_bb: Optional[BasicBlock]

    def __init__(self, wave: Wave, tid: int) -> None:
        self._registers = dict()
        self._ip = None
        self._running = True
        self._wave = wave
        self._callstack = []
        self._tid = tid
        self._previous_bb = None
        self._current_bb = None

    # Returns the lane/thread ID of this lane in its wave.
    def tid(self) -> int:
        return self._tid

    # Returns true if this lane is the first (lowest index) lane of the
    # currently active tangle.
    def is_first_active_lane(self) -> bool:
        return self._tid == self._wave.get_first_active_lane_index()

    # Broadcast value into the registers of all active lanes.
    def broadcast_register(self, register: str, value: Any) -> None:
        self._wave.broadcast_register(register, value)

    # Returns the IP this lane is currently at.
    def ip(self) -> InstructionPointer:
        assert self._ip is not None
        return self._ip

    # Returns true if this lane is running, false otherwise.
    # Running means not dead. An inactive lane is running.
    def running(self) -> bool:
        return self._running

    # Set the register at "name" to "value" in this lane.
    def set_register(self, name: str, value: Any) -> None:
        self._registers[name] = value

    # Get the value in register "name" in this lane.
    # If allow_undef is true, fetching an unknown register returns None
    # instead of raising KeyError.
    def get_register(self, name: str, allow_undef: bool = False) -> Optional[Any]:
        if allow_undef and name not in self._registers:
            return None
        return self._registers[name]

    # Move this lane to "ip", tracking the block it came from.
    def set_ip(self, ip: InstructionPointer) -> None:
        # Index 0 is only reached by entering a block (branch or call), never
        # by stepping with "+ 1". Checking it catches back-edges to the
        # current block (e.g. a single-block loop), which the block identity
        # check alone misses, leaving a stale predecessor behind.
        if ip.instruction_index == 0 or ip.bb() != self._current_bb:
            self._previous_bb = self._current_bb
            self._current_bb = ip.bb()
        self._ip = ip

    # Returns the name of the block this lane entered its current block from
    # (presumably consumed by OpPhi in instructions.py).
    def get_previous_bb_name(self):
        assert self._previous_bb is not None
        return self._previous_bb.name()

    def handle_convergence_header(self, instruction):
        self._wave.handle_convergence_header(self, instruction)

    # Enter the function at "ip". The returned value will be stored in
    # "output_register" when the callee returns.
    def do_call(self, ip, output_register):
        return_ip = None if self._ip is None else self._ip + 1
        self._callstack.append(
            (return_ip, lambda value: self.set_register(output_register, value))
        )
        self.set_ip(ip)

    # Return "value" from the current function. Returning from the outermost
    # call stops this lane.
    def do_return(self, value):
        ip, callback = self._callstack.pop()
        callback(value)

        if not self._callstack:
            self._running = False
        else:
            self.set_ip(ip)


# Represents the SPIR-V module in the simulator.
class Module:
    # The functions of this module, keyed by the register defined by OpFunction.
    _functions: Dict[str, Function]
    # The instructions located before the first function.
    _prolog: List[Instruction]
    # Global variables and constants (see __init__).
    _globals: List[Instruction]
    # OpName mappings: pretty name -> register, and register -> pretty name.
    _name2reg: Dict[str, str]
    _reg2name: Dict[str, str]

    def __init__(self, instructions) -> None:
        chunks = splitInstructions(OpFunction, instructions)
        # chunks[0] normally holds the instructions preceding the first
        # OpFunction. The parser drops capabilities, types, etc., so a module
        # can start directly with OpFunction: chunks[0] is then a function
        # and must not be mistaken for (and discarded as) the prolog.
        if not chunks or (chunks[0] and isinstance(chunks[0][0], OpFunction)):
            chunks = [[]] + chunks

        self._prolog = chunks[0]

        # Global variables and constants. Function-local OpVariables are
        # collected too: Function strips them from its basic blocks, so they
        # are only statically executed through initialize().
        self._globals = [
            x for x in instructions if isinstance(x, (OpVariable, OpConstant))
        ]

        # Helper dictionaries to get real names of registers, or registers by names.
        self._name2reg = {}
        self._reg2name = {}
        for instruction in instructions:
            if isinstance(instruction, OpName):
                name = instruction.name()
                reg = instruction.decoratedRegister()
                self._name2reg[name] = reg
                self._reg2name[reg] = name

        self._functions = {}
        for chunk in chunks[1:]:
            function = Function(chunk)
            assert function.name() not in self._functions
            self._functions[function.name()] = function

    # Returns the register matching "name" if any, None otherwise.
    # This assumes names are unique.
    def getRegisterFromName(self, name):
        return self._name2reg.get(name)

    # Returns the name given to "register" if any, None otherwise.
    def getNameFromRegister(self, register):
        return self._reg2name.get(register)

    # Initialize the module before wave execution begins.
    # See Instruction::static_execution for more details.
    def initialize(self, lane):
        for instruction in self._globals:
            instruction.static_execution(lane)

        # Initialize builtins
        for instruction in self._prolog:
            if isinstance(instruction, OpDecorate):
                instruction.static_execution(lane)

    def execute_one_instruction(self, lane: Lane, ip: InstructionPointer) -> None:
        ip.instruction().runtime_execution(self, lane)

    # Returns the first valid IP for the function defined by the given register.
    # Raises RuntimeError if no function defines this register.
    def get_function_entry(self, register: str) -> InstructionPointer:
        if register not in self._functions:
            raise RuntimeError(f"Function defining {register} not found.")
        return InstructionPointer(self._functions[register], 0, 0)

    # Returns the first valid IP for the basic block defined by register.
    # Raises RuntimeError if no basic block is defined by this register.
    def get_bb_entry(self, register: str) -> InstructionPointer:
        for function in self._functions.values():
            index = function.get_bb_index(register)
            if index != -1:
                return InstructionPointer(function, index, 0)
        raise RuntimeError(f"Instruction defining {register} not found.")

    # Returns the list of function names in this module.
    # If an OpName exists for this function, returns the pretty name, else
    # returns the register name.
    def get_function_names(self):
        names = []
        for register in self._functions:
            name = self.getNameFromRegister(register)
            names.append(register if name is None else name)
        return names

    # Returns the output registers of the global variables and constants.
    def variables(self) -> Iterable:
        return [x.output_register() for x in self._globals]

    def dump(self, function_name: Optional[str] = None):
        print("Module:")
        print(" globals:")
        for instruction in self._globals:
            print(f" {instruction}")

        if function_name is None:
            print(" functions:")
            for register, function in self._functions.items():
                name = self.getNameFromRegister(register)
                print(f" Function {register} ({name})")
                function.dump()
            return

        register = self.getRegisterFromName(function_name)
        print(f" function {register} ({function_name}):")
        if register is not None:
            self._functions[register].dump()
        else:
            print(" error: cannot find function.")


# Defines a convergence requirement for the simulation:
# A list of lanes impacted by a merge and possibly the associated
# continue target.
@dataclass
class ConvergenceRequirement:
    mergeTarget: InstructionPointer
    continueTarget: Optional[InstructionPointer]
    impactedLanes: set[int]


# Scheduled work: for each IP, the lanes waiting to execute it.
Task = Dict[InstructionPointer, List[Lane]]


# Defines a Lane group/Wave in the simulator.
class Wave:
    # The module this wave will execute.
    _module: Module
    # The lanes this wave will be composed of.
    _lanes: List[Lane]
    # The instructions scheduled for execution.
    _tasks: Task
    # The actual requirements to comply with when executing instructions.
    # E.g: the set of lanes required to merge before executing the merge block.
    _convergence_requirements: List[ConvergenceRequirement]
    # The indices of the active lanes for the current executing instruction.
    _active_lane_indices: set[int]

    def __init__(self, module, wave_size: int) -> None:
        assert wave_size > 0
        self._module = module
        self._lanes = [Lane(self, i) for i in range(wave_size)]
        self._tasks = {}
        self._convergence_requirements = []
        self._active_lane_indices = set()

    # Returns True if the given IP can be executed for the given list of lanes.
    def _is_task_candidate(self, ip: InstructionPointer, lanes: List[Lane]):
        # Lanes that cannot hold back this task: dead lanes, plus lanes
        # parked at the merge/continue target of a requirement unrelated to ip.
        merged_lanes: set[int] = set()
        for lane in self._lanes:
            if not lane.running():
                merged_lanes.add(lane.tid())

        for requirement in self._convergence_requirements:
            # This task is not executing a merge or continue target.
            # Adding all lanes at those points into the ignore list.
            if requirement.mergeTarget != ip and requirement.continueTarget != ip:
                for tid in requirement.impactedLanes:
                    if self._lanes[tid].ip() == requirement.mergeTarget:
                        merged_lanes.add(tid)
                    if self._lanes[tid].ip() == requirement.continueTarget:
                        merged_lanes.add(tid)
                continue

            # This task is executing the current requirement continue/merge
            # target.
            for tid in requirement.impactedLanes:
                lane = self._lanes[tid]
                if not lane.running():
                    continue

                if lane.tid() in merged_lanes:
                    continue

                if ip == requirement.mergeTarget:
                    # The merge block waits for every remaining impacted lane.
                    if lane.ip() != requirement.mergeTarget:
                        return False
                else:
                    # The continue target waits until each remaining impacted
                    # lane reached either the merge or the continue target.
                    if (
                        lane.ip() != requirement.mergeTarget
                        and lane.ip() != requirement.continueTarget
                    ):
                        return False
        return True

    # Returns the next task we can schedule, and removes it from the queue.
    # Raises RuntimeError when no scheduled task can run (e.g. a deadlock).
    def _get_next_runnable_task(self) -> Tuple[InstructionPointer, List[Lane]]:
        candidate = None
        for ip, lanes in self._tasks.items():
            if len(lanes) == 0:
                continue
            if self._is_task_candidate(ip, lanes):
                candidate = ip
                break

        if candidate is not None:
            lanes = self._tasks.pop(candidate)
            return (candidate, lanes)
        raise RuntimeError("No task to execute. Deadlock?")

    # Handle an encountered merge instruction for the given lane.
    def handle_convergence_header(self, lane: Lane, instruction: MergeInstruction):
        mergeTarget = self._module.get_bb_entry(instruction.merge_location())
        for requirement in self._convergence_requirements:
            if requirement.mergeTarget == mergeTarget:
                requirement.impactedLanes.add(lane.tid())
                return

        continueTarget = None
        if instruction.continue_location():
            continueTarget = self._module.get_bb_entry(instruction.continue_location())
        requirement = ConvergenceRequirement(
            mergeTarget, continueTarget, set([lane.tid()])
        )
        self._convergence_requirements.append(requirement)

    # Returns true if some instructions are scheduled for execution.
    def _has_tasks(self) -> bool:
        return len(self._tasks) > 0

    # Returns the index of the first active lane right now.
    def get_first_active_lane_index(self) -> int:
        return min(self._active_lane_indices)

    # Broadcast the given value to all active lane registers.
    def broadcast_register(self, register: str, value: Any) -> None:
        for tid in self._active_lane_indices:
            self._lanes[tid].set_register(register, value)

    # Returns the entrypoint of the function associated with 'name'.
    # 'name' is an OpName pretty name or, for functions listed by register
    # (see Module.get_function_names), a raw register.
    # Raises RuntimeError if no such function exists.
    def _get_function_entry_from_name(self, name: str) -> InstructionPointer:
        register = self._module.getRegisterFromName(name)
        if register is None:
            register = name
        return self._module.get_function_entry(register)

    # Run the wave on the function 'function_name' until all lanes are dead.
    # If verbose is True, execution trace is printed.
    # Returns the value returned by the function for each lane.
    def run(self, function_name: str, verbose: bool = False) -> List[Any]:
        for lane in self._lanes:
            self._module.initialize(lane)

        entry_ip = self._get_function_entry_from_name(function_name)
        for lane in self._lanes:
            lane.do_call(entry_ip, "__shader_output__")

        # All lanes start converged at the entry point. A copy is scheduled so
        # the task bookkeeping never aliases the wave's own lane list.
        self._tasks[self._lanes[0].ip()] = list(self._lanes)
        while self._has_tasks():
            ip, lanes = self._get_next_runnable_task()
            self._active_lane_indices = {x.tid() for x in lanes}
            if verbose:
                print(
                    f"Executing with lanes {self._active_lane_indices}: {ip.instruction()}"
                )

            for lane in lanes:
                self._module.execute_one_instruction(lane, ip)
                if not lane.running():
                    continue
                self._tasks.setdefault(lane.ip(), []).append(lane)

            if verbose and ip.instruction().has_output_register():
                register = ip.instruction().output_register()
                values = [x.get_register(register, allow_undef=True) for x in lanes]
                print(f" {register:3} = {values}")

        return [lane.get_register("__shader_output__") for lane in self._lanes]

    def dump_register(self, register: str) -> None:
        for lane in self._lanes:
            print(
                f" Lane {lane.tid():2} | {register:3} = {lane.get_register(register)}"
            )


parser = argparse.ArgumentParser(
    description="simulator", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
    "-i", "--input", help="Text SPIR-V to read from", required=False, default="-"
)
parser.add_argument("-f", "--function", help="Function to execute")
parser.add_argument("-w", "--wave", help="Wave size", default=32, required=False)
parser.add_argument(
    "-e",
    "--expects",
    help="Expected results per lanes, expects a list of values. Ex: '1, 2, 3'.",
)
parser.add_argument("-v", "--verbose", help="verbose", action="store_true")


# Load and parse the text SPIR-V module from "filename" ("-" reads stdin).
# Returns the parsed instructions, or an empty list when filename is None or
# the file cannot be read (the caller reports this as an invalid input).
def load_instructions(filename: str):
    if filename is None:
        return []

    if filename.strip() != "-":
        try:
            with open(filename, "r") as f:
                lines = f.read().split("\n")
        except (OSError, UnicodeDecodeError):
            return []
    else:
        lines = sys.stdin.readlines()

    # Remove leading/trailing whitespaces.
    lines = [x.strip() for x in lines]
    # Drop empty lines and comment lines (starting with ';').
    lines = [x for x in lines if x and x[0] != ";"]

    instructions = []
    for raw in [Instruction(x) for x in lines]:
        out = parseInstruction(raw)
        if out is not None:
            instructions.append(out)
    return instructions


# Command-line entry point. "argv" defaults to sys.argv[1:]. Exits with 0 when
# the per-lane results match --expects, 1 otherwise or on invalid usage.
def main(argv: Optional[List[str]] = None) -> None:
    # Parsed here rather than at import time so importing the module does
    # not consume the process arguments.
    args = parser.parse_args(argv)

    if args.expects is None or not RE_EXPECTS.match(args.expects):
        print("Invalid format for --expects/-e flag.", file=sys.stderr)
        sys.exit(1)
    if args.function is None:
        print("Invalid format for --function/-f flag.", file=sys.stderr)
        sys.exit(1)
    try:
        wave_size = int(args.wave)
    except ValueError:
        print("Invalid format for --wave/-w flag.", file=sys.stderr)
        sys.exit(1)

    expected_results = [int(x.strip()) for x in args.expects.split(",")]
    if len(expected_results) != wave_size:
        print("Wave size != expected result array size", file=sys.stderr)
        sys.exit(1)

    instructions = load_instructions(args.input)
    if len(instructions) == 0:
        print("Invalid input. Expected a text SPIR-V module.", file=sys.stderr)
        sys.exit(1)

    module = Module(instructions)
    if args.verbose:
        module.dump()
        module.dump(args.function)

    function_names = module.get_function_names()
    if args.function not in function_names:
        print(
            f"'{args.function}' function not found. Known functions are:",
            file=sys.stderr,
        )
        for name in function_names:
            print(f" - {name}", file=sys.stderr)
        sys.exit(1)

    wave = Wave(module, wave_size)
    results = wave.run(args.function, verbose=args.verbose)

    if expected_results != results:
        print("Expected != Observed", file=sys.stderr)
        print(f"{expected_results} != {results}", file=sys.stderr)
        sys.exit(1)
    sys.exit(0)


if __name__ == "__main__":
    main()