aboutsummaryrefslogtreecommitdiff
path: root/llvm/utils/spirv-sim/spirv-sim.py
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/utils/spirv-sim/spirv-sim.py')
-rwxr-xr-xllvm/utils/spirv-sim/spirv-sim.py658
1 files changed, 0 insertions, 658 deletions
diff --git a/llvm/utils/spirv-sim/spirv-sim.py b/llvm/utils/spirv-sim/spirv-sim.py
deleted file mode 100755
index 428b0ca..0000000
--- a/llvm/utils/spirv-sim/spirv-sim.py
+++ /dev/null
@@ -1,658 +0,0 @@
-#!/usr/bin/env python3
-
-from __future__ import annotations
-from dataclasses import dataclass
-from instructions import *
-from typing import Any, Iterable, Callable, Optional, Tuple, List, Dict
-import argparse
-import fileinput
-import inspect
-import re
-import sys
-
-RE_EXPECTS = re.compile(r"^([0-9]+,)*[0-9]+$")
-
-
-# Parse the SPIR-V instructions. Some instructions are ignored because
-# not required to simulate this module.
-# Instructions are to be implemented in instructions.py
-def parseInstruction(i):
- IGNORED = set(
- [
- "OpCapability",
- "OpMemoryModel",
- "OpExecutionMode",
- "OpExtension",
- "OpSource",
- "OpTypeInt",
- "OpTypeStruct",
- "OpTypeFloat",
- "OpTypeBool",
- "OpTypeVoid",
- "OpTypeFunction",
- "OpTypePointer",
- "OpTypeArray",
- ]
- )
- if i.opcode() in IGNORED:
- return None
-
- try:
- Type = getattr(sys.modules["instructions"], i.opcode())
- except AttributeError:
- raise RuntimeError(f"Unsupported instruction {i}")
- if not inspect.isclass(Type):
- raise RuntimeError(
- f"{i} instruction definition is not a class. Did you used 'def' instead of 'class'?"
- )
- return Type(i.line)
-
-
-# Split a list of instructions into pieces. Pieces are delimited by instructions of the type splitType.
-# The delimiter is the first instruction of the next piece.
-# This function returns no empty pieces:
-# - if 2 subsequent delimiters will mean 2 pieces. One with only the first delimiter, and the second
-# with the delimiter and following instructions.
-# - if the first instruction is a delimiter, the first piece will begin with this delimiter.
-def splitInstructions(
- splitType: type, instructions: Iterable[Instruction]
-) -> List[List[Instruction]]:
- blocks: List[List[Instruction]] = [[]]
- for instruction in instructions:
- if isinstance(instruction, splitType) and len(blocks[-1]) > 0:
- blocks.append([])
- blocks[-1].append(instruction)
- return blocks
-
-
-# Defines a BasicBlock in the simulator.
-# Begins at an OpLabel, and ends with a control-flow instruction.
-class BasicBlock:
- def __init__(self, instructions) -> None:
- assert isinstance(instructions[0], OpLabel)
- # The name of the basic block, which is the register of the leading
- # OpLabel.
- self._name = instructions[0].output_register()
- # The list of instructions belonging to this block.
- self._instructions = instructions[1:]
-
- # Returns the name of this basic block.
- def name(self):
- return self._name
-
- # Returns the instruction at index in this basic block.
- def __getitem__(self, index: int) -> Instruction:
- return self._instructions[index]
-
- # Returns the number of instructions in this basic block, excluding the
- # leading OpLabel.
- def __len__(self):
- return len(self._instructions)
-
- def dump(self):
- print(f" {self._name}:")
- for instruction in self._instructions:
- print(f" {instruction}")
-
-
-# Defines a Function in the simulator.
-class Function:
- def __init__(self, instructions) -> None:
- assert isinstance(instructions[0], OpFunction)
- # The name of the function (name of the register returned by OpFunction).
- self._name: str = instructions[0].output_register()
- # The list of basic blocks that belongs to this function.
- self._basic_blocks: List[BasicBlock] = []
- # The variables local to this function.
- self._variables: List[OpVariable] = [
- x for x in instructions if isinstance(x, OpVariable)
- ]
-
- assert isinstance(instructions[-1], OpFunctionEnd)
- body = filter(lambda x: not isinstance(x, OpVariable), instructions[1:-1])
- for block in splitInstructions(OpLabel, body):
- self._basic_blocks.append(BasicBlock(block))
-
- # Returns the name of this function.
- def name(self) -> str:
- return self._name
-
- # Returns the basic block at index in this function.
- def __getitem__(self, index: int) -> BasicBlock:
- return self._basic_blocks[index]
-
- # Returns the index of the basic block with the given name if found,
- # -1 otherwise.
- def get_bb_index(self, name) -> int:
- for i in range(len(self._basic_blocks)):
- if self._basic_blocks[i].name() == name:
- return i
- return -1
-
- def dump(self):
- print(" Variables:")
- for var in self._variables:
- print(f" {var}")
- print(" Blocks:")
- for bb in self._basic_blocks:
- bb.dump()
-
-
-# Represents an instruction pointer in the simulator.
-@dataclass
-class InstructionPointer:
- # The current function the IP points to.
- function: Function
- # The basic block index in function IP points to.
- basic_block: int
- # The instruction in basic_block IP points to.
- instruction_index: int
-
- def __str__(self):
- bb = self.function[self.basic_block]
- i = bb[self.instruction_index]
- return f"{bb.name()}:{self.instruction_index} in {self.function.name()} | {i}"
-
- def __hash__(self):
- return hash((self.function.name(), self.basic_block, self.instruction_index))
-
- # Returns the basic block IP points to.
- def bb(self) -> BasicBlock:
- return self.function[self.basic_block]
-
- # Returns the instruction IP points to.
- def instruction(self):
- return self.function[self.basic_block][self.instruction_index]
-
- # Increment IP by 1. This only works inside a basic-block boundary.
- # Incrementing IP when at the boundary of a basic block will fail.
- def __add__(self, value: int):
- bb = self.function[self.basic_block]
- assert len(bb) > self.instruction_index + value
- return InstructionPointer(
- self.function, self.basic_block, self.instruction_index + value
- )
-
-
-# Defines a Lane in this simulator.
-class Lane:
- # The registers known by this lane.
- _registers: Dict[str, Any]
- # The current IP of this lane.
- _ip: Optional[InstructionPointer]
- # If this lane running.
- _running: bool
- # The wave this lane belongs to.
- _wave: Wave
- # The callstack of this lane. Each tuple represents 1 call.
- # The first element is the IP the function will return to.
- # The second element is the callback to call to store the return value
- # into the correct register.
- _callstack: List[Tuple[InstructionPointer, Callable[[Any], None]]]
-
- _previous_bb: Optional[BasicBlock]
- _current_bb: Optional[BasicBlock]
-
- def __init__(self, wave: Wave, tid: int) -> None:
- self._registers = dict()
- self._ip = None
- self._running = True
- self._wave = wave
- self._callstack = []
-
- # The index of this lane in the wave.
- self._tid = tid
- # The last BB this lane was executing into.
- self._previous_bb = None
- # The current BB this lane is executing into.
- self._current_bb = None
-
- # Returns the lane/thread ID of this lane in its wave.
- def tid(self) -> int:
- return self._tid
-
- # Returns true is this lane if the first by index in the current active tangle.
- def is_first_active_lane(self) -> bool:
- return self._tid == self._wave.get_first_active_lane_index()
-
- # Broadcast value into the registers of all active lanes.
- def broadcast_register(self, register: str, value: Any) -> None:
- self._wave.broadcast_register(register, value)
-
- # Returns the IP this lane is currently at.
- def ip(self) -> InstructionPointer:
- assert self._ip is not None
- return self._ip
-
- # Returns true if this lane is running, false otherwise.
- # Running means not dead. An inactive lane is running.
- def running(self) -> bool:
- return self._running
-
- # Set the register at "name" to "value" in this lane.
- def set_register(self, name: str, value: Any) -> None:
- self._registers[name] = value
-
- # Get the value in register "name" in this lane.
- # If allow_undef is true, fetching an unknown register won't fail.
- def get_register(self, name: str, allow_undef: bool = False) -> Optional[Any]:
- if allow_undef and name not in self._registers:
- return None
- return self._registers[name]
-
- def set_ip(self, ip: InstructionPointer) -> None:
- if ip.bb() != self._current_bb:
- self._previous_bb = self._current_bb
- self._current_bb = ip.bb()
- self._ip = ip
-
- def get_previous_bb_name(self):
- return self._previous_bb.name()
-
- def handle_convergence_header(self, instruction):
- self._wave.handle_convergence_header(self, instruction)
-
- def do_call(self, ip, output_register):
- return_ip = None if self._ip is None else self._ip + 1
- self._callstack.append(
- (return_ip, lambda value: self.set_register(output_register, value))
- )
- self.set_ip(ip)
-
- def do_return(self, value):
- ip, callback = self._callstack[-1]
- self._callstack.pop()
-
- callback(value)
- if len(self._callstack) == 0:
- self._running = False
- else:
- self.set_ip(ip)
-
-
-# Represents the SPIR-V module in the simulator.
-class Module:
- _functions: Dict[str, Function]
- _prolog: List[Instruction]
- _globals: List[Instruction]
- _name2reg: Dict[str, str]
- _reg2name: Dict[str, str]
-
- def __init__(self, instructions) -> None:
- chunks = splitInstructions(OpFunction, instructions)
-
- # The instructions located outside of all functions.
- self._prolog = chunks[0]
- # The functions in this module.
- self._functions = {}
- # Global variables in this module.
- self._globals = [
- x
- for x in instructions
- if isinstance(x, OpVariable) or issubclass(type(x), OpConstant)
- ]
-
- # Helper dictionaries to get real names of registers, or registers by names.
- self._name2reg = {}
- self._reg2name = {}
- for instruction in instructions:
- if isinstance(instruction, OpName):
- name = instruction.name()
- reg = instruction.decoratedRegister()
- self._name2reg[name] = reg
- self._reg2name[reg] = name
-
- for chunk in chunks[1:]:
- function = Function(chunk)
- assert function.name() not in self._functions
- self._functions[function.name()] = function
-
- # Returns the register matching "name" if any, None otherwise.
- # This assumes names are unique.
- def getRegisterFromName(self, name):
- if name in self._name2reg:
- return self._name2reg[name]
- return None
-
- # Returns the name given to "register" if any, None otherwise.
- def getNameFromRegister(self, register):
- if register in self._reg2name:
- return self._reg2name[register]
- return None
-
- # Initialize the module before wave execution begins.
- # See Instruction::static_execution for more details.
- def initialize(self, lane):
- for instruction in self._globals:
- instruction.static_execution(lane)
-
- # Initialize builtins
- for instruction in self._prolog:
- if isinstance(instruction, OpDecorate):
- instruction.static_execution(lane)
-
- def execute_one_instruction(self, lane: Lane, ip: InstructionPointer) -> None:
- ip.instruction().runtime_execution(self, lane)
-
- # Returns the first valid IP for the function defined by the given register.
- # Calling this with a register not returned by OpFunction is illegal.
- def get_function_entry(self, register: str) -> InstructionPointer:
- if register not in self._functions:
- raise RuntimeError(f"Function defining {register} not found.")
- return InstructionPointer(self._functions[register], 0, 0)
-
- # Returns the first valid IP for the basic block defined by register.
- # Calling this with a register not returned by an OpLabel is illegal.
- def get_bb_entry(self, register: str) -> InstructionPointer:
- for name, function in self._functions.items():
- index = function.get_bb_index(register)
- if index != -1:
- return InstructionPointer(function, index, 0)
- raise RuntimeError(f"Instruction defining {register} not found.")
-
- # Returns the list of function names in this module.
- # If an OpName exists for this function, returns the pretty name, else
- # returns the register name.
- def get_function_names(self):
- return [self.getNameFromRegister(reg) for reg, func in self._functions.items()]
-
- # Returns the global variables defined in this module.
- def variables(self) -> Iterable:
- return [x.output_register() for x in self._globals]
-
- def dump(self, function_name: Optional[str] = None):
- print("Module:")
- print(" globals:")
- for instruction in self._globals:
- print(f" {instruction}")
-
- if function_name is None:
- print(" functions:")
- for register, function in self._functions.items():
- name = self.getNameFromRegister(register)
- print(f" Function {register} ({name})")
- function.dump()
- return
-
- register = self.getRegisterFromName(function_name)
- print(f" function {register} ({function_name}):")
- if register is not None:
- self._functions[register].dump()
- else:
- print(f" error: cannot find function.")
-
-
-# Defines a convergence requirement for the simulation:
-# A list of lanes impacted by a merge and possibly the associated
-# continue target.
-@dataclass
-class ConvergenceRequirement:
- mergeTarget: InstructionPointer
- continueTarget: Optional[InstructionPointer]
- impactedLanes: set[int]
-
-
-Task = Dict[InstructionPointer, List[Lane]]
-
-
-# Defines a Lane group/Wave in the simulator.
-class Wave:
- # The module this wave will execute.
- _module: Module
- # The lanes this wave will be composed of.
- _lanes: List[Lane]
- # The instructions scheduled for execution.
- _tasks: Task
- # The actual requirements to comply with when executing instructions.
- # E.g: the set of lanes required to merge before executing the merge block.
- _convergence_requirements: List[ConvergenceRequirement]
- # The indices of the active lanes for the current executing instruction.
- _active_lane_indices: set[int]
-
- def __init__(self, module, wave_size: int) -> None:
- assert wave_size > 0
- self._module = module
- self._lanes = []
-
- for i in range(wave_size):
- self._lanes.append(Lane(self, i))
-
- self._tasks = {}
- self._convergence_requirements = []
- # The indices of the active lanes for the current executing instruction.
- self._active_lane_indices = set()
-
- # Returns True if the given IP can be executed for the given list of lanes.
- def _is_task_candidate(self, ip: InstructionPointer, lanes: List[Lane]):
- merged_lanes: set[int] = set()
- for lane in self._lanes:
- if not lane.running():
- merged_lanes.add(lane.tid())
-
- for requirement in self._convergence_requirements:
- # This task is not executing a merge or continue target.
- # Adding all lanes at those points into the ignore list.
- if requirement.mergeTarget != ip and requirement.continueTarget != ip:
- for tid in requirement.impactedLanes:
- if self._lanes[tid].ip() == requirement.mergeTarget:
- merged_lanes.add(tid)
- if self._lanes[tid].ip() == requirement.continueTarget:
- merged_lanes.add(tid)
- continue
-
- # This task is executing the current requirement continue/merge
- # target.
- for tid in requirement.impactedLanes:
- lane = self._lanes[tid]
- if not lane.running():
- continue
-
- if lane.tid() in merged_lanes:
- continue
-
- if ip == requirement.mergeTarget:
- if lane.ip() != requirement.mergeTarget:
- return False
- else:
- if (
- lane.ip() != requirement.mergeTarget
- and lane.ip() != requirement.continueTarget
- ):
- return False
- return True
-
- # Returns the next task we can schedule. This must always return a task.
- # Calling this when all lanes are dead is invalid.
- def _get_next_runnable_task(self) -> Tuple[InstructionPointer, List[Lane]]:
- candidate = None
- for ip, lanes in self._tasks.items():
- if len(lanes) == 0:
- continue
- if self._is_task_candidate(ip, lanes):
- candidate = ip
- break
-
- if candidate:
- lanes = self._tasks[candidate]
- del self._tasks[ip]
- return (candidate, lanes)
- raise RuntimeError("No task to execute. Deadlock?")
-
- # Handle an encountered merge instruction for the given lane.
- def handle_convergence_header(self, lane: Lane, instruction: MergeInstruction):
- mergeTarget = self._module.get_bb_entry(instruction.merge_location())
- for requirement in self._convergence_requirements:
- if requirement.mergeTarget == mergeTarget:
- requirement.impactedLanes.add(lane.tid())
- return
-
- continueTarget = None
- if instruction.continue_location():
- continueTarget = self._module.get_bb_entry(instruction.continue_location())
- requirement = ConvergenceRequirement(
- mergeTarget, continueTarget, set([lane.tid()])
- )
- self._convergence_requirements.append(requirement)
-
- # Returns true if some instructions are scheduled for execution.
- def _has_tasks(self) -> bool:
- return len(self._tasks) > 0
-
- # Returns the index of the first active lane right now.
- def get_first_active_lane_index(self) -> int:
- return min(self._active_lane_indices)
-
- # Broadcast the given value to all active lane registers.
- def broadcast_register(self, register: str, value: Any) -> None:
- for tid in self._active_lane_indices:
- self._lanes[tid].set_register(register, value)
-
- # Returns the entrypoint of the function associated with 'name'.
- # Calling this function with an invalid name is illegal.
- def _get_function_entry_from_name(self, name: str) -> InstructionPointer:
- register = self._module.getRegisterFromName(name)
- assert register is not None
- return self._module.get_function_entry(register)
-
- # Run the wave on the function 'function_name' until all lanes are dead.
- # If verbose is True, execution trace is printed.
- # Returns the value returned by the function for each lane.
- def run(self, function_name: str, verbose: bool = False) -> List[Any]:
- for t in self._lanes:
- self._module.initialize(t)
-
- entry_ip = self._get_function_entry_from_name(function_name)
- assert entry_ip is not None
- for t in self._lanes:
- t.do_call(entry_ip, "__shader_output__")
-
- self._tasks[self._lanes[0].ip()] = self._lanes
- while self._has_tasks():
- ip, lanes = self._get_next_runnable_task()
- self._active_lane_indices = set([x.tid() for x in lanes])
- if verbose:
- print(
- f"Executing with lanes {self._active_lane_indices}: {ip.instruction()}"
- )
-
- for lane in lanes:
- self._module.execute_one_instruction(lane, ip)
- if not lane.running():
- continue
-
- if lane.ip() in self._tasks:
- self._tasks[lane.ip()].append(lane)
- else:
- self._tasks[lane.ip()] = [lane]
-
- if verbose and ip.instruction().has_output_register():
- register = ip.instruction().output_register()
- print(
- f" {register:3} = {[ x.get_register(register, allow_undef=True) for x in lanes ]}"
- )
-
- output = []
- for lane in self._lanes:
- output.append(lane.get_register("__shader_output__"))
- return output
-
- def dump_register(self, register: str) -> None:
- for lane in self._lanes:
- print(
- f" Lane {lane.tid():2} | {register:3} = {lane.get_register(register)}"
- )
-
-
-parser = argparse.ArgumentParser(
- description="simulator", formatter_class=argparse.ArgumentDefaultsHelpFormatter
-)
-parser.add_argument(
- "-i", "--input", help="Text SPIR-V to read from", required=False, default="-"
-)
-parser.add_argument("-f", "--function", help="Function to execute")
-parser.add_argument("-w", "--wave", help="Wave size", default=32, required=False)
-parser.add_argument(
- "-e",
- "--expects",
- help="Expected results per lanes, expects a list of values. Ex: '1, 2, 3'.",
-)
-parser.add_argument("-v", "--verbose", help="verbose", action="store_true")
-args = parser.parse_args()
-
-
-def load_instructions(filename: str):
- if filename is None:
- return []
-
- if filename.strip() != "-":
- try:
- with open(filename, "r") as f:
- lines = f.read().split("\n")
- except Exception: # (FileNotFoundError, PermissionError):
- return []
- else:
- lines = sys.stdin.readlines()
-
- # Remove leading/trailing whitespaces.
- lines = [x.strip() for x in lines]
- # Strip comments.
- lines = [x for x in filter(lambda x: len(x) != 0 and x[0] != ";", lines)]
-
- instructions = []
- for i in [Instruction(x) for x in lines]:
- out = parseInstruction(i)
- if out != None:
- instructions.append(out)
- return instructions
-
-
-def main():
- if args.expects is None or not RE_EXPECTS.match(args.expects):
- print("Invalid format for --expects/-e flag.", file=sys.stderr)
- sys.exit(1)
- if args.function is None:
- print("Invalid format for --function/-f flag.", file=sys.stderr)
- sys.exit(1)
- try:
- int(args.wave)
- except ValueError:
- print("Invalid format for --wave/-w flag.", file=sys.stderr)
- sys.exit(1)
-
- expected_results = [int(x.strip()) for x in args.expects.split(",")]
- wave_size = int(args.wave)
- if len(expected_results) != wave_size:
- print("Wave size != expected result array size", file=sys.stderr)
- sys.exit(1)
-
- instructions = load_instructions(args.input)
- if len(instructions) == 0:
- print("Invalid input. Expected a text SPIR-V module.")
- sys.exit(1)
-
- module = Module(instructions)
- if args.verbose:
- module.dump()
- module.dump(args.function)
-
- function_names = module.get_function_names()
- if args.function not in function_names:
- print(
- f"'{args.function}' function not found. Known functions are:",
- file=sys.stderr,
- )
- for name in function_names:
- print(f" - {name}", file=sys.stderr)
- sys.exit(1)
-
- wave = Wave(module, wave_size)
- results = wave.run(args.function, verbose=args.verbose)
-
- if expected_results != results:
- print("Expected != Observed", file=sys.stderr)
- print(f"{expected_results} != {results}", file=sys.stderr)
- sys.exit(1)
- sys.exit(0)
-
-
-main()