diff options
Diffstat (limited to 'mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp')
-rw-r--r-- | mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 170 |
1 files changed, 96 insertions, 74 deletions
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index fd4cabbad..1b069c6 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -32,7 +32,6 @@ #include "llvm/ADT/StringRef.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/bit.h" -#include "llvm/Frontend/OpenMP/OMPConstants.h" #include "llvm/Support/InterleavedRange.h" #include <cstddef> #include <iterator> @@ -1737,10 +1736,10 @@ static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) { // Parser, printer and verifier for Target //===----------------------------------------------------------------------===// -// Helper function to get bitwise AND of `value` and 'flag' -static uint64_t mapTypeToBitFlag(uint64_t value, - llvm::omp::OpenMPOffloadMappingFlags flag) { - return value & llvm::to_underlying(flag); +// Helper function to get bitwise AND of `value` and 'flag' then return it as a +// boolean +static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag) { + return (value & flag) == flag; } /// Parses a map_entries map type from a string format back into its numeric @@ -1748,10 +1747,9 @@ static uint64_t mapTypeToBitFlag(uint64_t value, /// /// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `? /// `close, `? `present, `? ( `to` | `from` | `delete` `)` )+ `)` ) -static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) { - llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; - +static ParseResult parseMapClause(OpAsmParser &parser, + ClauseMapFlagsAttr &mapType) { + ClauseMapFlags mapTypeBits = ClauseMapFlags::none; // This simply verifies the correct keyword is read in, the // keyword itself is stored inside of the operation auto parseTypeAndMod = [&]() -> ParseResult { @@ -1760,35 +1758,64 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) { return failure(); if (mapTypeMod == "always") - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS; + mapTypeBits |= ClauseMapFlags::always; if (mapTypeMod == "implicit") - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT; + mapTypeBits |= ClauseMapFlags::implicit; if (mapTypeMod == "ompx_hold") - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD; + mapTypeBits |= ClauseMapFlags::ompx_hold; if (mapTypeMod == "close") - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE; + mapTypeBits |= ClauseMapFlags::close; if (mapTypeMod == "present") - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT; + mapTypeBits |= ClauseMapFlags::present; if (mapTypeMod == "to") - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; + mapTypeBits |= ClauseMapFlags::to; if (mapTypeMod == "from") - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + mapTypeBits |= ClauseMapFlags::from; if (mapTypeMod == "tofrom") - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + mapTypeBits |= ClauseMapFlags::to | ClauseMapFlags::from; if (mapTypeMod == "delete") - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE; + mapTypeBits |= ClauseMapFlags::del; + + if (mapTypeMod == "storage") + mapTypeBits |= ClauseMapFlags::storage; if (mapTypeMod == "return_param") - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM; + mapTypeBits |= ClauseMapFlags::return_param; + + if (mapTypeMod == "private") + mapTypeBits |= ClauseMapFlags::priv; + + if (mapTypeMod == "literal") + mapTypeBits |= ClauseMapFlags::literal; + + if (mapTypeMod == "attach") + mapTypeBits |= ClauseMapFlags::attach; + + if (mapTypeMod == "attach_always") + mapTypeBits |= ClauseMapFlags::attach_always; + + if (mapTypeMod == "attach_none") + mapTypeBits |= ClauseMapFlags::attach_none; + + if (mapTypeMod == "attach_auto") + mapTypeBits |= ClauseMapFlags::attach_auto; + + if (mapTypeMod == "ref_ptr") + mapTypeBits |= ClauseMapFlags::ref_ptr; + + if (mapTypeMod == "ref_ptee") + mapTypeBits |= ClauseMapFlags::ref_ptee; + + if (mapTypeMod == "ref_ptr_ptee") + mapTypeBits |= ClauseMapFlags::ref_ptr_ptee; return success(); }; @@ -1796,9 +1823,8 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) { if (parser.parseCommaSeparatedList(parseTypeAndMod)) return failure(); - mapType = parser.getBuilder().getIntegerAttr( - parser.getBuilder().getIntegerType(64, /*isSigned=*/false), - llvm::to_underlying(mapTypeBits)); + mapType = + parser.getBuilder().getAttr<mlir::omp::ClauseMapFlagsAttr>(mapTypeBits); return success(); } @@ -1806,60 +1832,62 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) { /// Prints a map_entries map type from its numeric value out into its string /// format. static void printMapClause(OpAsmPrinter &p, Operation *op, - IntegerAttr mapType) { - uint64_t mapTypeBits = mapType.getUInt(); - - bool emitAllocRelease = true; + ClauseMapFlagsAttr mapType) { llvm::SmallVector<std::string, 4> mapTypeStrs; + ClauseMapFlags mapFlags = mapType.getValue(); // handling of always, close, present placed at the beginning of the string // to aid readability - if (mapTypeToBitFlag(mapTypeBits, - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS)) + if (mapTypeToBool(mapFlags, ClauseMapFlags::always)) mapTypeStrs.push_back("always"); - if (mapTypeToBitFlag(mapTypeBits, - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT)) + if (mapTypeToBool(mapFlags, ClauseMapFlags::implicit)) mapTypeStrs.push_back("implicit"); - if (mapTypeToBitFlag(mapTypeBits, - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD)) + if (mapTypeToBool(mapFlags, ClauseMapFlags::ompx_hold)) mapTypeStrs.push_back("ompx_hold"); - if (mapTypeToBitFlag(mapTypeBits, - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE)) + if (mapTypeToBool(mapFlags, ClauseMapFlags::close)) mapTypeStrs.push_back("close"); - if (mapTypeToBitFlag(mapTypeBits, - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT)) + if (mapTypeToBool(mapFlags, ClauseMapFlags::present)) mapTypeStrs.push_back("present"); // special handling of to/from/tofrom/delete and release/alloc, release + // alloc are the abscense of one of the other flags, whereas tofrom requires // both the to and from flag to be set. - bool to = mapTypeToBitFlag(mapTypeBits, - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); - bool from = mapTypeToBitFlag( - mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); - if (to && from) { - emitAllocRelease = false; + bool to = mapTypeToBool(mapFlags, ClauseMapFlags::to); + bool from = mapTypeToBool(mapFlags, ClauseMapFlags::from); + + if (to && from) mapTypeStrs.push_back("tofrom"); - } else if (from) { - emitAllocRelease = false; + else if (from) mapTypeStrs.push_back("from"); - } else if (to) { - emitAllocRelease = false; + else if (to) mapTypeStrs.push_back("to"); - } - if (mapTypeToBitFlag(mapTypeBits, - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) { - emitAllocRelease = false; + + if (mapTypeToBool(mapFlags, ClauseMapFlags::del)) mapTypeStrs.push_back("delete"); - } - if (mapTypeToBitFlag( - mapTypeBits, - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) { - emitAllocRelease = false; + if (mapTypeToBool(mapFlags, ClauseMapFlags::return_param)) mapTypeStrs.push_back("return_param"); - } - if (emitAllocRelease) - mapTypeStrs.push_back("exit_release_or_enter_alloc"); + if (mapTypeToBool(mapFlags, ClauseMapFlags::storage)) + mapTypeStrs.push_back("storage"); + if (mapTypeToBool(mapFlags, ClauseMapFlags::priv)) + mapTypeStrs.push_back("private"); + if (mapTypeToBool(mapFlags, ClauseMapFlags::literal)) + mapTypeStrs.push_back("literal"); + if (mapTypeToBool(mapFlags, ClauseMapFlags::attach)) + mapTypeStrs.push_back("attach"); + if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_always)) + mapTypeStrs.push_back("attach_always"); + if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_none)) + mapTypeStrs.push_back("attach_none"); + if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_auto)) + mapTypeStrs.push_back("attach_auto"); + if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr)) + mapTypeStrs.push_back("ref_ptr"); + if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptee)) + mapTypeStrs.push_back("ref_ptee"); + if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr_ptee)) + mapTypeStrs.push_back("ref_ptr_ptee"); + if (mapFlags == ClauseMapFlags::none) + mapTypeStrs.push_back("none"); for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) { p << mapTypeStrs[i]; @@ -1963,21 +1991,15 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) { return emitError(op->getLoc(), "missing map operation"); if (auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) { - uint64_t mapTypeBits = mapInfoOp.getMapType(); - - bool to = mapTypeToBitFlag( - mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); - bool from = mapTypeToBitFlag( - mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); - bool del = mapTypeToBitFlag( - mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE); - - bool always = mapTypeToBitFlag( - mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS); - bool close = mapTypeToBitFlag( - mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE); - bool implicit = mapTypeToBitFlag( - mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT); + mlir::omp::ClauseMapFlags mapTypeBits = mapInfoOp.getMapType(); + + bool to = mapTypeToBool(mapTypeBits, ClauseMapFlags::to); + bool from = mapTypeToBool(mapTypeBits, ClauseMapFlags::from); + bool del = mapTypeToBool(mapTypeBits, ClauseMapFlags::del); + + bool always = mapTypeToBool(mapTypeBits, ClauseMapFlags::always); + bool close = mapTypeToBool(mapTypeBits, ClauseMapFlags::close); + bool implicit = mapTypeToBool(mapTypeBits, ClauseMapFlags::implicit); if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del) return emitError(op->getLoc(), |