aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp')
-rw-r--r--mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp154
1 files changed, 154 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 32ebe06..5672942 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -33,6 +33,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include "llvm/Support/InterleavedRange.h"
#include <cstddef>
#include <iterator>
#include <optional>
@@ -3385,6 +3386,9 @@ void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
Value result = getResult();
auto [newCli, gen, cons] = decodeCli(result);
+ // Structured binding `gen` cannot be captured in lambdas before C++20
+ OpOperand *generator = gen;
+
// Derive the CLI variable name from its generator:
// * "canonloop" for omp.canonical_loop
// * custom name for loop transformation generatees
@@ -3403,6 +3407,24 @@ void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
.Case([&](UnrollHeuristicOp op) -> std::string {
llvm_unreachable("heuristic unrolling does not generate a loop");
})
+ .Case([&](TileOp op) -> std::string {
+ auto [generateesFirst, generateesCount] =
+ op.getGenerateesODSOperandIndexAndLength();
+ unsigned firstGrid = generateesFirst;
+ unsigned firstIntratile = generateesFirst + generateesCount / 2;
+ unsigned end = generateesFirst + generateesCount;
+ unsigned opnum = generator->getOperandNumber();
+ // In the OpenMP apply and looprange clauses, indices are 1-based
+ if (firstGrid <= opnum && opnum < firstIntratile) {
+ unsigned gridnum = opnum - firstGrid + 1;
+ return ("grid" + Twine(gridnum)).str();
+ }
+ if (firstIntratile <= opnum && opnum < end) {
+ unsigned intratilenum = opnum - firstIntratile + 1;
+ return ("intratile" + Twine(intratilenum)).str();
+ }
+ llvm_unreachable("Unexpected generatee argument");
+ })
.Default([&](Operation *op) {
assert(false && "TODO: Custom name for this operation");
return "transformed";
@@ -3632,6 +3654,138 @@ UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
}
//===----------------------------------------------------------------------===//
+// TileOp
+//===----------------------------------------------------------------------===//
+
+static void printLoopTransformClis(OpAsmPrinter &p, TileOp op,
+ OperandRange generatees,
+ OperandRange applyees) {
+ if (!generatees.empty())
+ p << '(' << llvm::interleaved(generatees) << ')';
+
+ if (!applyees.empty())
+ p << " <- (" << llvm::interleaved(applyees) << ')';
+}
+
+static ParseResult parseLoopTransformClis(
+ OpAsmParser &parser,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &generateesOperands,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &applyeesOperands) {
+ if (parser.parseOptionalLess()) {
+ // Syntax 1: generatees present
+
+ if (parser.parseOperandList(generateesOperands,
+ mlir::OpAsmParser::Delimiter::Paren))
+ return failure();
+
+ if (parser.parseLess())
+ return failure();
+ } else {
+ // Syntax 2: generatees omitted
+ }
+
+ // Parse `<-` (`<` has already been parsed)
+ if (parser.parseMinus())
+ return failure();
+
+ if (parser.parseOperandList(applyeesOperands,
+ mlir::OpAsmParser::Delimiter::Paren))
+ return failure();
+
+ return success();
+}
+
+LogicalResult TileOp::verify() {
+ if (getApplyees().empty())
+ return emitOpError() << "must apply to at least one loop";
+
+ if (getSizes().size() != getApplyees().size())
+ return emitOpError() << "there must be one tile size for each applyee";
+
+ if (!getGeneratees().empty() &&
+ 2 * getSizes().size() != getGeneratees().size())
+ return emitOpError()
+ << "expecting two times the number of generatees than applyees";
+
+ DenseSet<Value> parentIVs;
+
+ Value parent = getApplyees().front();
+ for (auto &&applyee : llvm::drop_begin(getApplyees())) {
+ auto [parentCreate, parentGen, parentCons] = decodeCli(parent);
+ auto [create, gen, cons] = decodeCli(applyee);
+
+ if (!parentGen)
+ return emitOpError() << "applyee CLI has no generator";
+
+ auto parentLoop = dyn_cast_or_null<CanonicalLoopOp>(parentGen->getOwner());
+ if (!parentGen)
+ return emitOpError()
+ << "currently only supports omp.canonical_loop as applyee";
+
+ parentIVs.insert(parentLoop.getInductionVar());
+
+ if (!gen)
+ return emitOpError() << "applyee CLI has no generator";
+ auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
+ if (!loop)
+ return emitOpError()
+ << "currently only supports omp.canonical_loop as applyee";
+
+ // Canonical loop must be perfectly nested, i.e. the body of the parent must
+ // only contain the omp.canonical_loop of the nested loops, and
+ // omp.terminator
+ bool isPerfectlyNested = [&]() {
+ auto &parentBody = parentLoop.getRegion();
+ if (!parentBody.hasOneBlock())
+ return false;
+ auto &parentBlock = parentBody.getBlocks().front();
+
+ auto nestedLoopIt = parentBlock.begin();
+ if (nestedLoopIt == parentBlock.end() ||
+ (&*nestedLoopIt != loop.getOperation()))
+ return false;
+
+ auto termIt = std::next(nestedLoopIt);
+ if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
+ return false;
+
+ if (std::next(termIt) != parentBlock.end())
+ return false;
+
+ return true;
+ }();
+ if (!isPerfectlyNested)
+ return emitOpError() << "tiled loop nest must be perfectly nested";
+
+ if (parentIVs.contains(loop.getTripCount()))
+ return emitOpError() << "tiled loop nest must be rectangular";
+
+ parent = applyee;
+ }
+
+ // TODO: The tile sizes must be computed before the loop, but checking this
+ // requires dominance analysis. For instance:
+ //
+ // %canonloop = omp.new_cli
+ // omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) {
+ // // write to %x
+ // omp.terminator
+ // }
+ // %ts = llvm.load %x
+ // omp.tile <- (%canonloop) sizes(%ts : i32)
+
+ return success();
+}
+
+std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
+ return getODSOperandIndexAndLength(odsIndex_applyees);
+}
+
+std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
+ return getODSOperandIndexAndLength(odsIndex_generatees);
+}
+
+//===----------------------------------------------------------------------===//
// Critical construct (2.17.1)
//===----------------------------------------------------------------------===//