diff options
-rw-r--r-- | mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h | 300 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h | 7 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 72 | ||||
-rw-r--r-- | mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 226 |
4 files changed, 595 insertions, 10 deletions
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h new file mode 100644 index 0000000..6454076 --- /dev/null +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h @@ -0,0 +1,300 @@ +//===-- OpenMPClauseOperands.h ----------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the structures defining MLIR operands associated with each +// OpenMP clause, and structures grouping the appropriate operands for each +// construct. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_OPENMP_OPENMPCLAUSEOPERANDS_H_ +#define MLIR_DIALECT_OPENMP_OPENMPCLAUSEOPERANDS_H_ + +#include "mlir/IR/BuiltinAttributes.h" +#include "llvm/ADT/SmallVector.h" + +#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.h.inc" + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.h.inc" + +namespace mlir { +namespace omp { + +//===----------------------------------------------------------------------===// +// Mixin structures defining MLIR operands associated with each OpenMP clause. +//===----------------------------------------------------------------------===// + +struct AlignedClauseOps { + llvm::SmallVector<Value> alignedVars; + llvm::SmallVector<Attribute> alignmentAttrs; +}; + +struct AllocateClauseOps { + llvm::SmallVector<Value> allocatorVars, allocateVars; +}; + +struct CollapseClauseOps { + llvm::SmallVector<Value> loopLBVar, loopUBVar, loopStepVar; +}; + +struct CopyprivateClauseOps { + llvm::SmallVector<Value> copyprivateVars; + llvm::SmallVector<Attribute> copyprivateFuncs; +}; + +struct DependClauseOps { + llvm::SmallVector<Attribute> dependTypeAttrs; + llvm::SmallVector<Value> dependVars; +}; + +struct DeviceClauseOps { + Value deviceVar; +}; + +struct DeviceTypeClauseOps { + // The default capture type. + DeclareTargetDeviceType deviceType = DeclareTargetDeviceType::any; +}; + +struct DistScheduleClauseOps { + UnitAttr distScheduleStaticAttr; + Value distScheduleChunkSizeVar; +}; + +struct DoacrossClauseOps { + llvm::SmallVector<Value> doacrossVectorVars; + ClauseDependAttr doacrossDependTypeAttr; + IntegerAttr doacrossNumLoopsAttr; +}; + +struct FinalClauseOps { + Value finalVar; +}; + +struct GrainsizeClauseOps { + Value grainsizeVar; +}; + +struct HintClauseOps { + IntegerAttr hintAttr; +}; + +struct IfClauseOps { + Value ifVar; +}; + +struct InReductionClauseOps { + llvm::SmallVector<Value> inReductionVars; + llvm::SmallVector<Attribute> inReductionDeclSymbols; +}; + +struct LinearClauseOps { + llvm::SmallVector<Value> linearVars, linearStepVars; +}; + +struct LoopRelatedOps { + UnitAttr loopInclusiveAttr; +}; + +struct MapClauseOps { + llvm::SmallVector<Value> mapVars; +}; + +struct MergeableClauseOps { + UnitAttr mergeableAttr; +}; + +struct NameClauseOps { + StringAttr nameAttr; +}; + +struct NogroupClauseOps { + UnitAttr nogroupAttr; +}; + +struct NontemporalClauseOps { + llvm::SmallVector<Value> nontemporalVars; +}; + +struct NowaitClauseOps { + UnitAttr nowaitAttr; +}; + +struct NumTasksClauseOps { + Value numTasksVar; +}; + +struct NumTeamsClauseOps { + Value numTeamsLowerVar, numTeamsUpperVar; +}; + +struct NumThreadsClauseOps { + Value numThreadsVar; +}; + +struct OrderClauseOps { + ClauseOrderKindAttr orderAttr; +}; + +struct OrderedClauseOps { + IntegerAttr orderedAttr; +}; + +struct ParallelizationLevelClauseOps { + UnitAttr parLevelSimdAttr; +}; + +struct PriorityClauseOps { + Value priorityVar; +}; + +struct PrivateClauseOps { + // SSA values that correspond to "original" values being privatized. + // They refer to the SSA value outside the OpenMP region from which a clone is + // created inside the region. + llvm::SmallVector<Value> privateVars; + // The list of symbols referring to delayed privatizer ops (i.e. `omp.private` + // ops). + llvm::SmallVector<Attribute> privatizers; +}; + +struct ProcBindClauseOps { + ClauseProcBindKindAttr procBindKindAttr; +}; + +struct ReductionClauseOps { + llvm::SmallVector<Value> reductionVars; + llvm::SmallVector<Attribute> reductionDeclSymbols; + UnitAttr reductionByRefAttr; +}; + +struct SafelenClauseOps { + IntegerAttr safelenAttr; +}; + +struct ScheduleClauseOps { + ClauseScheduleKindAttr scheduleValAttr; + ScheduleModifierAttr scheduleModAttr; + Value scheduleChunkVar; + UnitAttr scheduleSimdAttr; +}; + +struct SimdlenClauseOps { + IntegerAttr simdlenAttr; +}; + +struct TaskReductionClauseOps { + llvm::SmallVector<Value> taskReductionVars; + llvm::SmallVector<Attribute> taskReductionDeclSymbols; +}; + +struct ThreadLimitClauseOps { + Value threadLimitVar; +}; + +struct UntiedClauseOps { + UnitAttr untiedAttr; +}; + +struct UseDeviceClauseOps { + llvm::SmallVector<Value> useDevicePtrVars, useDeviceAddrVars; +}; + +//===----------------------------------------------------------------------===// +// Structures defining clause operands associated with each OpenMP leaf +// construct. +// +// These mirror the arguments expected by the corresponding OpenMP MLIR ops. +//===----------------------------------------------------------------------===// + +namespace detail { +template <typename... Mixins> +struct Clauses : public Mixins... {}; +} // namespace detail + +using CriticalClauseOps = detail::Clauses<HintClauseOps, NameClauseOps>; + +// TODO `indirect` clause. +using DeclareTargetClauseOps = detail::Clauses<DeviceTypeClauseOps>; + +using DistributeClauseOps = + detail::Clauses<AllocateClauseOps, DistScheduleClauseOps, OrderClauseOps, + PrivateClauseOps>; + +// TODO `filter` clause. +using MaskedClauseOps = detail::Clauses<>; + +using OrderedOpClauseOps = detail::Clauses<DoacrossClauseOps>; + +using OrderedRegionClauseOps = detail::Clauses<ParallelizationLevelClauseOps>; + +using ParallelClauseOps = + detail::Clauses<AllocateClauseOps, IfClauseOps, NumThreadsClauseOps, + PrivateClauseOps, ProcBindClauseOps, ReductionClauseOps>; + +using SectionsClauseOps = detail::Clauses<AllocateClauseOps, NowaitClauseOps, + PrivateClauseOps, ReductionClauseOps>; + +// TODO `linear` clause. +using SimdLoopClauseOps = + detail::Clauses<AlignedClauseOps, CollapseClauseOps, IfClauseOps, + LoopRelatedOps, NontemporalClauseOps, OrderClauseOps, + PrivateClauseOps, ReductionClauseOps, SafelenClauseOps, + SimdlenClauseOps>; + +using SingleClauseOps = detail::Clauses<AllocateClauseOps, CopyprivateClauseOps, + NowaitClauseOps, PrivateClauseOps>; + +// TODO `defaultmap`, `has_device_addr`, `is_device_ptr`, `uses_allocators` +// clauses. +using TargetClauseOps = + detail::Clauses<AllocateClauseOps, DependClauseOps, DeviceClauseOps, + IfClauseOps, InReductionClauseOps, MapClauseOps, + NowaitClauseOps, PrivateClauseOps, ReductionClauseOps, + ThreadLimitClauseOps>; + +using TargetDataClauseOps = detail::Clauses<DeviceClauseOps, IfClauseOps, + MapClauseOps, UseDeviceClauseOps>; + +using TargetEnterExitUpdateDataClauseOps = + detail::Clauses<DependClauseOps, DeviceClauseOps, IfClauseOps, MapClauseOps, + NowaitClauseOps>; + +// TODO `affinity`, `detach` clauses. +using TaskClauseOps = + detail::Clauses<AllocateClauseOps, DependClauseOps, FinalClauseOps, + IfClauseOps, InReductionClauseOps, MergeableClauseOps, + PriorityClauseOps, PrivateClauseOps, UntiedClauseOps>; + +using TaskgroupClauseOps = + detail::Clauses<AllocateClauseOps, TaskReductionClauseOps>; + +using TaskloopClauseOps = + detail::Clauses<AllocateClauseOps, CollapseClauseOps, FinalClauseOps, + GrainsizeClauseOps, IfClauseOps, InReductionClauseOps, + LoopRelatedOps, MergeableClauseOps, NogroupClauseOps, + NumTasksClauseOps, PriorityClauseOps, PrivateClauseOps, + ReductionClauseOps, UntiedClauseOps>; + +using TaskwaitClauseOps = detail::Clauses<DependClauseOps, NowaitClauseOps>; + +using TeamsClauseOps = + detail::Clauses<AllocateClauseOps, IfClauseOps, NumTeamsClauseOps, + PrivateClauseOps, ReductionClauseOps, ThreadLimitClauseOps>; + +using WsloopClauseOps = + detail::Clauses<AllocateClauseOps, CollapseClauseOps, LinearClauseOps, + LoopRelatedOps, NowaitClauseOps, OrderClauseOps, + OrderedClauseOps, PrivateClauseOps, ReductionClauseOps, + ScheduleClauseOps>; + +} // namespace omp +} // namespace mlir + +#endif // MLIR_DIALECT_OPENMP_OPENMPCLAUSEOPERANDS_H_ diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h index 23509c5..c656bdc 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h @@ -26,11 +26,10 @@ #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.h.inc" #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.h.inc" -#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.h.inc" -#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.h.inc" -#define GET_ATTRDEF_CLASSES -#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.h.inc" +#include "mlir/Dialect/OpenMP/OpenMPClauseOperands.h" + +#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.h.inc" #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h" diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 4574518..a38a82f 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -287,7 +287,8 @@ def ParallelOp : OpenMP_Op<"parallel", [ let regions = (region AnyRegion:$region); let builders = [ - OpBuilder<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)> + OpBuilder<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>, + OpBuilder<(ins CArg<"const ParallelClauseOps &">:$clauses)> ]; let extraClassDeclaration = [{ /// Returns the number of reduction variables. @@ -362,6 +363,10 @@ def TeamsOp : OpenMP_Op<"teams", [ let regions = (region AnyRegion:$region); + let builders = [ + OpBuilder<(ins CArg<"const TeamsClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ oilist( `num_teams` `(` ( $num_teams_lower^ `:` type($num_teams_lower) )? `to` @@ -451,6 +456,10 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments, let regions = (region SizedRegion<1>:$region); + let builders = [ + OpBuilder<(ins CArg<"const SectionsClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ oilist( `reduction` `(` custom<ReductionVarList>( @@ -495,6 +504,10 @@ def SingleOp : OpenMP_Op<"single", [AttrSizedOperandSegments]> { let regions = (region AnyRegion:$region); + let builders = [ + OpBuilder<(ins CArg<"const SingleClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ oilist(`allocate` `(` custom<AllocateAndAllocator>( @@ -601,6 +614,7 @@ def WsloopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments, OpBuilder<(ins "ValueRange":$lowerBound, "ValueRange":$upperBound, "ValueRange":$step, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>, + OpBuilder<(ins CArg<"const WsloopClauseOps &">:$clauses)> ]; let regions = (region AnyRegion:$region); @@ -698,6 +712,11 @@ def SimdLoopOp : OpenMP_Op<"simdloop", [AttrSizedOperandSegments, ); let regions = (region AnyRegion:$region); + + let builders = [ + OpBuilder<(ins CArg<"const SimdLoopClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ oilist(`aligned` `(` custom<AlignedClause>($aligned_vars, type($aligned_vars), @@ -781,6 +800,10 @@ def DistributeOp : OpenMP_Op<"distribute", [AttrSizedOperandSegments, let regions = (region AnyRegion:$region); + let builders = [ + OpBuilder<(ins CArg<"const DistributeClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ oilist(`dist_schedule_static` $dist_schedule_static |`chunk_size` `(` $chunk_size `:` type($chunk_size) `)` @@ -883,6 +906,9 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments, Variadic<AnyType>:$allocate_vars, Variadic<AnyType>:$allocators_vars); let regions = (region AnyRegion:$region); + let builders = [ + OpBuilder<(ins CArg<"const TaskClauseOps &">:$clauses)> + ]; let assemblyFormat = [{ oilist(`if` `(` $if_expr `)` |`final` `(` $final_expr `)` @@ -1037,6 +1063,10 @@ def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments, let regions = (region AnyRegion:$region); + let builders = [ + OpBuilder<(ins CArg<"const TaskloopClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ oilist(`if` `(` $if_expr `)` |`final` `(` $final_expr `)` @@ -1106,6 +1136,10 @@ def TaskgroupOp : OpenMP_Op<"taskgroup", [AttrSizedOperandSegments, let regions = (region AnyRegion:$region); + let builders = [ + OpBuilder<(ins CArg<"const TaskgroupClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ oilist(`task_reduction` `(` custom<ReductionVarList>( @@ -1432,6 +1466,10 @@ def TargetDataOp: OpenMP_Op<"target_data", [AttrSizedOperandSegments, let regions = (region AnyRegion:$region); + let builders = [ + OpBuilder<(ins CArg<"const TargetDataClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ oilist(`if` `(` $if_expr `:` type($if_expr) `)` | `device` `(` $device `:` type($device) `)` @@ -1486,6 +1524,10 @@ def TargetEnterDataOp: OpenMP_Op<"target_enter_data", UnitAttr:$nowait, Variadic<AnyType>:$map_operands); + let builders = [ + OpBuilder<(ins CArg<"const TargetEnterExitUpdateDataClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ oilist(`if` `(` $if_expr `:` type($if_expr) `)` | `device` `(` $device `:` type($device) `)` @@ -1540,6 +1582,10 @@ def TargetExitDataOp: OpenMP_Op<"target_exit_data", UnitAttr:$nowait, Variadic<AnyType>:$map_operands); + let builders = [ + OpBuilder<(ins CArg<"const TargetEnterExitUpdateDataClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ oilist(`if` `(` $if_expr `:` type($if_expr) `)` | `device` `(` $device `:` type($device) `)` @@ -1596,6 +1642,10 @@ def TargetUpdateOp: OpenMP_Op<"target_update", [AttrSizedOperandSegments, UnitAttr:$nowait, Variadic<OpenMP_PointerLikeType>:$map_operands); + let builders = [ + OpBuilder<(ins CArg<"const TargetEnterExitUpdateDataClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ oilist(`if` `(` $if_expr `:` type($if_expr) `)` | `device` `(` $device `:` type($device) `)` @@ -1649,6 +1699,10 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac let regions = (region AnyRegion:$region); + let builders = [ + OpBuilder<(ins CArg<"const TargetClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ oilist( `if` `(` $if_expr `)` | `device` `(` $device `:` type($device) `)` @@ -1693,6 +1747,10 @@ def CriticalDeclareOp : OpenMP_Op<"critical.declare", [Symbol]> { let arguments = (ins SymbolNameAttr:$sym_name, DefaultValuedAttr<I64Attr, "0">:$hint_val); + let builders = [ + OpBuilder<(ins CArg<"const CriticalClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ $sym_name oilist(`hint` `(` custom<SynchronizationHint>($hint_val) `)`) attr-dict @@ -1773,6 +1831,10 @@ def OrderedOp : OpenMP_Op<"ordered"> { ConfinedAttr<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$num_loops_val, Variadic<AnyType>:$depend_vec_vars); + let builders = [ + OpBuilder<(ins CArg<"const OrderedOpClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ ( `depend_type` `` $depend_type_val^ )? ( `depend_vec` `(` $depend_vec_vars^ `:` type($depend_vec_vars) `)` )? @@ -1797,6 +1859,10 @@ def OrderedRegionOp : OpenMP_Op<"ordered.region"> { let regions = (region AnyRegion:$region); + let builders = [ + OpBuilder<(ins CArg<"const OrderedRegionClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ ( `simd` $simd^ )? $region attr-dict}]; let hasVerifier = 1; } @@ -1812,6 +1878,10 @@ def TaskwaitOp : OpenMP_Op<"taskwait"> { of the current task. }]; + let builders = [ + OpBuilder<(ins CArg<"const TaskwaitClauseOps &">:$clauses)> + ]; + let assemblyFormat = "attr-dict"; } diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index a043431..5436553 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -41,6 +41,11 @@ using namespace mlir; using namespace mlir::omp; +static ArrayAttr makeArrayAttr(MLIRContext *context, + llvm::ArrayRef<Attribute> attrs) { + return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs); +} + namespace { struct MemRefPointerLikeModel : public PointerLikeType::ExternalModel<MemRefPointerLikeModel, @@ -1161,6 +1166,17 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) { return success(); } +//===----------------------------------------------------------------------===// +// TargetDataOp +//===----------------------------------------------------------------------===// + +void TargetDataOp::build(OpBuilder &builder, OperationState &state, + const TargetDataClauseOps &clauses) { + TargetDataOp::build(builder, state, clauses.ifVar, clauses.deviceVar, + clauses.useDevicePtrVars, clauses.useDeviceAddrVars, + clauses.mapVars); +} + LogicalResult TargetDataOp::verify() { if (getMapOperands().empty() && getUseDevicePtr().empty() && getUseDeviceAddr().empty()) { @@ -1170,6 +1186,20 @@ LogicalResult TargetDataOp::verify() { return verifyMapClause(*this, getMapOperands()); } +//===----------------------------------------------------------------------===// +// TargetEnterDataOp +//===----------------------------------------------------------------------===// + +void TargetEnterDataOp::build( + OpBuilder &builder, OperationState &state, + const TargetEnterExitUpdateDataClauseOps &clauses) { + MLIRContext *ctx = builder.getContext(); + TargetEnterDataOp::build(builder, state, clauses.ifVar, clauses.deviceVar, + makeArrayAttr(ctx, clauses.dependTypeAttrs), + clauses.dependVars, clauses.nowaitAttr, + clauses.mapVars); +} + LogicalResult TargetEnterDataOp::verify() { LogicalResult verifyDependVars = verifyDependVarList(*this, getDepends(), getDependVars()); @@ -1177,6 +1207,20 @@ LogicalResult TargetEnterDataOp::verify() { : verifyMapClause(*this, getMapOperands()); } +//===----------------------------------------------------------------------===// +// TargetExitDataOp +//===----------------------------------------------------------------------===// + +void TargetExitDataOp::build( + OpBuilder &builder, OperationState &state, + const TargetEnterExitUpdateDataClauseOps &clauses) { + MLIRContext *ctx = builder.getContext(); + TargetExitDataOp::build(builder, state, clauses.ifVar, clauses.deviceVar, + makeArrayAttr(ctx, clauses.dependTypeAttrs), + clauses.dependVars, clauses.nowaitAttr, + clauses.mapVars); +} + LogicalResult TargetExitDataOp::verify() { LogicalResult verifyDependVars = verifyDependVarList(*this, getDepends(), getDependVars()); @@ -1184,6 +1228,19 @@ LogicalResult TargetExitDataOp::verify() { : verifyMapClause(*this, getMapOperands()); } +//===----------------------------------------------------------------------===// +// TargetUpdateOp +//===----------------------------------------------------------------------===// + +void TargetUpdateOp::build(OpBuilder &builder, OperationState &state, + const TargetEnterExitUpdateDataClauseOps &clauses) { + MLIRContext *ctx = builder.getContext(); + TargetUpdateOp::build(builder, state, clauses.ifVar, clauses.deviceVar, + makeArrayAttr(ctx, clauses.dependTypeAttrs), + clauses.dependVars, clauses.nowaitAttr, + clauses.mapVars); +} + LogicalResult TargetUpdateOp::verify() { LogicalResult verifyDependVars = verifyDependVarList(*this, getDepends(), getDependVars()); @@ -1191,6 +1248,22 @@ LogicalResult TargetUpdateOp::verify() { : verifyMapClause(*this, getMapOperands()); } +//===----------------------------------------------------------------------===// +// TargetOp +//===----------------------------------------------------------------------===// + +void TargetOp::build(OpBuilder &builder, OperationState &state, + const TargetClauseOps &clauses) { + MLIRContext *ctx = builder.getContext(); + // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars, + // inReductionDeclSymbols, privateVars, privatizers, reductionVars, + // reductionByRefAttr, reductionDeclSymbols. + TargetOp::build(builder, state, clauses.ifVar, clauses.deviceVar, + clauses.threadLimitVar, + makeArrayAttr(ctx, clauses.dependTypeAttrs), + clauses.dependVars, clauses.nowaitAttr, clauses.mapVars); +} + LogicalResult TargetOp::verify() { LogicalResult verifyDependVars = verifyDependVarList(*this, getDepends(), getDependVars()); @@ -1213,6 +1286,17 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, state.addAttributes(attributes); } +void ParallelOp::build(OpBuilder &builder, OperationState &state, + const ParallelClauseOps &clauses) { + MLIRContext *ctx = builder.getContext(); + ParallelOp::build( + builder, state, clauses.ifVar, clauses.numThreadsVar, + clauses.allocateVars, clauses.allocatorVars, clauses.reductionVars, + makeArrayAttr(ctx, clauses.reductionDeclSymbols), + clauses.procBindKindAttr, clauses.privateVars, + makeArrayAttr(ctx, clauses.privatizers), clauses.reductionByRefAttr); +} + template <typename OpType> static LogicalResult verifyPrivateVarList(OpType &op) { auto privateVars = op.getPrivateVars(); @@ -1280,6 +1364,17 @@ static bool opInGlobalImplicitParallelRegion(Operation *op) { return true; } +void TeamsOp::build(OpBuilder &builder, OperationState &state, + const TeamsClauseOps &clauses) { + MLIRContext *ctx = builder.getContext(); + // TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers. + TeamsOp::build(builder, state, clauses.numTeamsLowerVar, + clauses.numTeamsUpperVar, clauses.ifVar, + clauses.threadLimitVar, clauses.allocateVars, + clauses.allocatorVars, clauses.reductionVars, + makeArrayAttr(ctx, clauses.reductionDeclSymbols)); +} + LogicalResult TeamsOp::verify() { // Check parent region // TODO If nested inside of a target region, also check that it does not @@ -1312,9 +1407,19 @@ LogicalResult TeamsOp::verify() { } //===----------------------------------------------------------------------===// -// Verifier for SectionsOp +// SectionsOp //===----------------------------------------------------------------------===// +void SectionsOp::build(OpBuilder &builder, OperationState &state, + const SectionsClauseOps &clauses) { + MLIRContext *ctx = builder.getContext(); + // TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers. + SectionsOp::build(builder, state, clauses.reductionVars, + makeArrayAttr(ctx, clauses.reductionDeclSymbols), + clauses.allocateVars, clauses.allocatorVars, + clauses.nowaitAttr); +} + LogicalResult SectionsOp::verify() { if (getAllocateVars().size() != getAllocatorsVars().size()) return emitError( @@ -1334,6 +1439,20 @@ LogicalResult SectionsOp::verifyRegions() { return success(); } +//===----------------------------------------------------------------------===// +// SingleOp +//===----------------------------------------------------------------------===// + +void SingleOp::build(OpBuilder &builder, OperationState &state, + const SingleClauseOps &clauses) { + MLIRContext *ctx = builder.getContext(); + // TODO Store clauses in op: privateVars, privatizers. + SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, + clauses.copyprivateVars, + makeArrayAttr(ctx, clauses.copyprivateFuncs), + clauses.nowaitAttr); +} + LogicalResult SingleOp::verify() { // Check for allocate clause restrictions if (getAllocateVars().size() != getAllocatorsVars().size()) @@ -1481,9 +1600,21 @@ void printLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion, } //===----------------------------------------------------------------------===// -// Verifier for Simd construct [2.9.3.1] +// Simd construct [2.9.3.1] //===----------------------------------------------------------------------===// +void SimdLoopOp::build(OpBuilder &builder, OperationState &state, + const SimdLoopClauseOps &clauses) { + MLIRContext *ctx = builder.getContext(); + // TODO Store clauses in op: privateVars, reductionByRefAttr, reductionVars, + // privatizers, reductionDeclSymbols. + SimdLoopOp::build( + builder, state, clauses.loopLBVar, clauses.loopUBVar, clauses.loopStepVar, + clauses.alignedVars, makeArrayAttr(ctx, clauses.alignmentAttrs), + clauses.ifVar, clauses.nontemporalVars, clauses.orderAttr, + clauses.simdlenAttr, clauses.safelenAttr, clauses.loopInclusiveAttr); +} + LogicalResult SimdLoopOp::verify() { if (this->getLowerBound().empty()) { return emitOpError() << "empty lowerbound for simd loop operation"; @@ -1504,9 +1635,17 @@ LogicalResult SimdLoopOp::verify() { } //===----------------------------------------------------------------------===// -// Verifier for Distribute construct [2.9.4.1] +// Distribute construct [2.9.4.1] //===----------------------------------------------------------------------===// +void DistributeOp::build(OpBuilder &builder, OperationState &state, + const DistributeClauseOps &clauses) { + // TODO Store clauses in op: privateVars, privatizers. + DistributeOp::build(builder, state, clauses.distScheduleStaticAttr, + clauses.distScheduleChunkSizeVar, clauses.allocateVars, + clauses.allocatorVars, clauses.orderAttr); +} + LogicalResult DistributeOp::verify() { if (this->getChunkSize() && !this->getDistScheduleStatic()) return emitOpError() << "chunk size set without " @@ -1630,6 +1769,19 @@ LogicalResult ReductionOp::verify() { //===----------------------------------------------------------------------===// // TaskOp //===----------------------------------------------------------------------===// + +void TaskOp::build(OpBuilder &builder, OperationState &state, + const TaskClauseOps &clauses) { + MLIRContext *ctx = builder.getContext(); + // TODO Store clauses in op: privateVars, privatizers. + TaskOp::build( + builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr, + clauses.mergeableAttr, clauses.inReductionVars, + makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.priorityVar, + makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars, + clauses.allocateVars, clauses.allocatorVars); +} + LogicalResult TaskOp::verify() { LogicalResult verifyDependVars = verifyDependVarList(*this, getDepends(), getDependVars()); @@ -1642,6 +1794,15 @@ LogicalResult TaskOp::verify() { //===----------------------------------------------------------------------===// // TaskgroupOp //===----------------------------------------------------------------------===// + +void TaskgroupOp::build(OpBuilder &builder, OperationState &state, + const TaskgroupClauseOps &clauses) { + MLIRContext *ctx = builder.getContext(); + TaskgroupOp::build(builder, state, clauses.taskReductionVars, + makeArrayAttr(ctx, clauses.taskReductionDeclSymbols), + clauses.allocateVars, clauses.allocatorVars); +} + LogicalResult TaskgroupOp::verify() { return verifyReductionVarList(*this, getTaskReductions(), getTaskReductionVars()); @@ -1650,6 +1811,21 @@ LogicalResult TaskgroupOp::verify() { //===----------------------------------------------------------------------===// // TaskloopOp //===----------------------------------------------------------------------===// + +void TaskloopOp::build(OpBuilder &builder, OperationState &state, + const TaskloopClauseOps &clauses) { + MLIRContext *ctx = builder.getContext(); + // TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers. + TaskloopOp::build( + builder, state, clauses.loopLBVar, clauses.loopUBVar, clauses.loopStepVar, + clauses.loopInclusiveAttr, clauses.ifVar, clauses.finalVar, + clauses.untiedAttr, clauses.mergeableAttr, clauses.inReductionVars, + makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.reductionVars, + makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.priorityVar, + clauses.allocateVars, clauses.allocatorVars, clauses.grainsizeVar, + clauses.numTasksVar, clauses.nogroupAttr); +} + SmallVector<Value> TaskloopOp::getAllReductionVars() { SmallVector<Value> allReductionNvars(getInReductionVars().begin(), getInReductionVars().end()); @@ -1703,14 +1879,33 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state, state.addAttributes(attributes); } +void WsloopOp::build(OpBuilder &builder, OperationState &state, + const WsloopClauseOps &clauses) { + MLIRContext *ctx = builder.getContext(); + // TODO Store clauses in op: allocateVars, allocatorVars, privateVars, + // privatizers. + WsloopOp::build( + builder, state, clauses.loopLBVar, clauses.loopUBVar, clauses.loopStepVar, + clauses.linearVars, clauses.linearStepVars, clauses.reductionVars, + makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.scheduleValAttr, + clauses.scheduleChunkVar, clauses.scheduleModAttr, + clauses.scheduleSimdAttr, clauses.nowaitAttr, clauses.reductionByRefAttr, + clauses.orderedAttr, clauses.orderAttr, clauses.loopInclusiveAttr); +} + LogicalResult WsloopOp::verify() { return verifyReductionVarList(*this, getReductions(), getReductionVars()); } //===----------------------------------------------------------------------===// -// Verifier for critical construct (2.17.1) +// Critical construct (2.17.1) //===----------------------------------------------------------------------===// +void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state, + const CriticalClauseOps &clauses) { + CriticalDeclareOp::build(builder, state, clauses.nameAttr, clauses.hintAttr); +} + LogicalResult CriticalDeclareOp::verify() { return verifySynchronizationHint(*this, getHintVal()); } @@ -1730,9 +1925,15 @@ LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { } //===----------------------------------------------------------------------===// -// Verifier for ordered construct +// Ordered construct //===----------------------------------------------------------------------===// +void OrderedOp::build(OpBuilder &builder, OperationState &state, + const OrderedOpClauseOps &clauses) { + OrderedOp::build(builder, state, clauses.doacrossDependTypeAttr, + clauses.doacrossNumLoopsAttr, clauses.doacrossVectorVars); +} + LogicalResult OrderedOp::verify() { auto container = (*this)->getParentOfType<WsloopOp>(); if (!container || !container.getOrderedValAttr() || @@ -1749,6 +1950,11 @@ LogicalResult OrderedOp::verify() { return success(); } +void OrderedRegionOp::build(OpBuilder &builder, OperationState &state, + const OrderedRegionClauseOps &clauses) { + OrderedRegionOp::build(builder, state, clauses.parLevelSimdAttr); +} + LogicalResult OrderedRegionOp::verify() { // TODO: The code generation for ordered simd directive is not supported yet. if (getSimd()) @@ -1766,6 +1972,16 @@ LogicalResult OrderedRegionOp::verify() { } //===----------------------------------------------------------------------===// +// TaskwaitOp +//===----------------------------------------------------------------------===// + +void TaskwaitOp::build(OpBuilder &builder, OperationState &state, + const TaskwaitClauseOps &clauses) { + // TODO Store clauses in op: dependTypeAttrs, dependVars, nowaitAttr. + TaskwaitOp::build(builder, state); +} + +//===----------------------------------------------------------------------===// // Verifier for AtomicReadOp //===----------------------------------------------------------------------===// |