aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
authorTom Eccles <tom.eccles@arm.com>2024-06-27 12:06:22 +0100
committerGitHub <noreply@github.com>2024-06-27 12:06:22 +0100
commitd4e9ba59d6a2e334c983fa79f43b167d0583772b (patch)
tree0dd94e49839b92b0b83126985a1c47425ae92897 /mlir/lib
parent2a948d11c0540004dc906d948bac58398bafe928 (diff)
downloadllvm-d4e9ba59d6a2e334c983fa79f43b167d0583772b.zip
llvm-d4e9ba59d6a2e334c983fa79f43b167d0583772b.tar.gz
llvm-d4e9ba59d6a2e334c983fa79f43b167d0583772b.tar.bz2
[mlir][OpenMP] Standardise representation of reduction clause (#96215)
Now all operations with a reduction clause have an array of bools controlling whether each reduction variable should be passed by reference or value. This was already supported for Wsloop and Parallel. The new operations modified here currently have no flang lowering or translation to LLVMIR and so further changes are not needed. It isn't possible to check the verifier in mlir/test/Dialect/OpenMP/invalid.mlir because there is no way of parsing an operation to have an incorrect number of byref attributes. The verifier exists to pick up buggy operation builders or in-place operation modification.
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp71
1 files changed, 49 insertions, 22 deletions
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index fbad80a..c0be9e9 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -48,6 +48,11 @@ static ArrayAttr makeArrayAttr(MLIRContext *context,
return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
}
+static DenseBoolArrayAttr
+makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef<bool> boolArray) {
+ return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
+}
+
namespace {
struct MemRefPointerLikeModel
: public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
@@ -499,7 +504,7 @@ static ParseResult parseClauseWithRegionArgs(
return success();
})))
return failure();
- isByRef = DenseBoolArrayAttr::get(parser.getContext(), isByRefVec);
+ isByRef = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
auto *argsBegin = regionPrivateArgs.begin();
MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
@@ -591,7 +596,7 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
mlir::SmallVector<bool> isByRefVec;
isByRefVec.resize(privateVarTypes.size(), false);
DenseBoolArrayAttr isByRef =
- DenseBoolArrayAttr::get(op->getContext(), isByRefVec);
+ makeDenseBoolArrayAttr(op->getContext(), isByRefVec);
printClauseWithRegionArgs(p, op, argsSubrange, "private",
privateVarOperands, privateVarTypes, isByRef,
@@ -607,18 +612,22 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
static ParseResult
parseReductionVarList(OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
- SmallVectorImpl<Type> &types,
+ SmallVectorImpl<Type> &types, DenseBoolArrayAttr &isByRef,
ArrayAttr &redcuctionSymbols) {
SmallVector<SymbolRefAttr> reductionVec;
+ SmallVector<bool> isByRefVec;
if (failed(parser.parseCommaSeparatedList([&]() {
+ ParseResult optionalByref = parser.parseOptionalKeyword("byref");
if (parser.parseAttribute(reductionVec.emplace_back()) ||
parser.parseArrow() ||
parser.parseOperand(operands.emplace_back()) ||
parser.parseColonType(types.emplace_back()))
return failure();
+ isByRefVec.push_back(optionalByref.succeeded());
return success();
})))
return failure();
+ isByRef = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions);
return success();
@@ -628,11 +637,21 @@ parseReductionVarList(OpAsmParser &parser,
static void printReductionVarList(OpAsmPrinter &p, Operation *op,
OperandRange reductionVars,
TypeRange reductionTypes,
+ std::optional<DenseBoolArrayAttr> isByRef,
std::optional<ArrayAttr> reductions) {
- for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
+ auto getByRef = [&](unsigned i) -> const char * {
+ if (!isByRef || !*isByRef)
+ return "";
+ assert(isByRef->empty() || i < isByRef->size());
+ if (!isByRef->empty() && (*isByRef)[i])
+ return "byref ";
+ return "";
+ };
+
+ for (unsigned i = 0, e = reductionVars.size(); i < e; ++i) {
if (i != 0)
p << ", ";
- p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
+ p << getByRef(i) << (*reductions)[i] << " -> " << reductionVars[i] << " : "
<< reductionVars[i].getType();
}
}
@@ -641,16 +660,12 @@ static void printReductionVarList(OpAsmPrinter &p, Operation *op,
static LogicalResult
verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductions,
OperandRange reductionVars,
- std::optional<ArrayRef<bool>> byRef = std::nullopt) {
+ std::optional<ArrayRef<bool>> byRef) {
if (!reductionVars.empty()) {
if (!reductions || reductions->size() != reductionVars.size())
return op->emitOpError()
<< "expected as many reduction symbol references "
"as reduction variables";
- if (mlir::isa<omp::WsloopOp, omp::ParallelOp>(op))
- assert(byRef);
- else
- assert(!byRef); // TODO: support byref reductions on other operations
if (byRef && byRef->size() != reductionVars.size())
return op->emitError() << "expected as many reduction variable by "
"reference attributes as reduction variables";
@@ -1492,7 +1507,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
ParallelOp::build(builder, state, clauses.ifVar, clauses.numThreadsVar,
clauses.allocateVars, clauses.allocatorVars,
clauses.reductionVars,
- DenseBoolArrayAttr::get(ctx, clauses.reduceVarByRef),
+ makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
makeArrayAttr(ctx, clauses.reductionDeclSymbols),
clauses.procBindKindAttr, clauses.privateVars,
makeArrayAttr(ctx, clauses.privatizers));
@@ -1590,6 +1605,7 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
clauses.numTeamsUpperVar, clauses.ifVar,
clauses.threadLimitVar, clauses.allocateVars,
clauses.allocatorVars, clauses.reductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
makeArrayAttr(ctx, clauses.reductionDeclSymbols));
}
@@ -1621,7 +1637,8 @@ LogicalResult TeamsOp::verify() {
return emitError(
"expected equal sizes for allocate and allocator variables");
- return verifyReductionVarList(*this, getReductions(), getReductionVars());
+ return verifyReductionVarList(*this, getReductions(), getReductionVars(),
+ getReductionVarsByref());
}
//===----------------------------------------------------------------------===//
@@ -1633,6 +1650,7 @@ void SectionsOp::build(OpBuilder &builder, OperationState &state,
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
SectionsOp::build(builder, state, clauses.reductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
makeArrayAttr(ctx, clauses.reductionDeclSymbols),
clauses.allocateVars, clauses.allocatorVars,
clauses.nowaitAttr);
@@ -1643,7 +1661,8 @@ LogicalResult SectionsOp::verify() {
return emitError(
"expected equal sizes for allocate and allocator variables");
- return verifyReductionVarList(*this, getReductions(), getReductionVars());
+ return verifyReductionVarList(*this, getReductions(), getReductionVars(),
+ getReductionVarsByref());
}
LogicalResult SectionsOp::verifyRegions() {
@@ -1733,7 +1752,7 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
// privatizers.
WsloopOp::build(builder, state, clauses.linearVars, clauses.linearStepVars,
clauses.reductionVars,
- DenseBoolArrayAttr::get(ctx, clauses.reduceVarByRef),
+ makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
makeArrayAttr(ctx, clauses.reductionDeclSymbols),
clauses.scheduleValAttr, clauses.scheduleChunkVar,
clauses.scheduleModAttr, clauses.scheduleSimdAttr,
@@ -1934,6 +1953,7 @@ void TaskOp::build(OpBuilder &builder, OperationState &state,
TaskOp::build(
builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr,
clauses.mergeableAttr, clauses.inReductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.inReductionVarsByRef),
makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.priorityVar,
makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
clauses.allocateVars, clauses.allocatorVars);
@@ -1945,7 +1965,8 @@ LogicalResult TaskOp::verify() {
return failed(verifyDependVars)
? verifyDependVars
: verifyReductionVarList(*this, getInReductions(),
- getInReductionVars());
+ getInReductionVars(),
+ getInReductionVarsByref());
}
//===----------------------------------------------------------------------===//
@@ -1955,14 +1976,17 @@ LogicalResult TaskOp::verify() {
void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
const TaskgroupClauseOps &clauses) {
MLIRContext *ctx = builder.getContext();
- TaskgroupOp::build(builder, state, clauses.taskReductionVars,
- makeArrayAttr(ctx, clauses.taskReductionDeclSymbols),
- clauses.allocateVars, clauses.allocatorVars);
+ TaskgroupOp::build(
+ builder, state, clauses.taskReductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.taskReductionVarsByRef),
+ makeArrayAttr(ctx, clauses.taskReductionDeclSymbols),
+ clauses.allocateVars, clauses.allocatorVars);
}
LogicalResult TaskgroupOp::verify() {
return verifyReductionVarList(*this, getTaskReductions(),
- getTaskReductionVars());
+ getTaskReductionVars(),
+ getTaskReductionVarsByref());
}
//===----------------------------------------------------------------------===//
@@ -1976,7 +2000,9 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
TaskloopOp::build(
builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr,
clauses.mergeableAttr, clauses.inReductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.inReductionVarsByRef),
makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.reductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.priorityVar,
clauses.allocateVars, clauses.allocatorVars, clauses.grainsizeVar,
clauses.numTasksVar, clauses.nogroupAttr);
@@ -1994,10 +2020,11 @@ LogicalResult TaskloopOp::verify() {
if (getAllocateVars().size() != getAllocatorsVars().size())
return emitError(
"expected equal sizes for allocate and allocator variables");
- if (failed(
- verifyReductionVarList(*this, getReductions(), getReductionVars())) ||
+ if (failed(verifyReductionVarList(*this, getReductions(), getReductionVars(),
+ getReductionVarsByref())) ||
failed(verifyReductionVarList(*this, getInReductions(),
- getInReductionVars())))
+ getInReductionVars(),
+ getInReductionVarsByref())))
return failure();
if (!getReductionVars().empty() && getNogroup())