aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp137
-rw-r--r--mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp14
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp29
-rw-r--r--mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp11
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp87
-rw-r--r--mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt13
-rw-r--r--mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp26
-rw-r--r--mlir/lib/Dialect/OpenACC/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp14
-rw-r--r--mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp28
-rw-r--r--mlir/lib/Dialect/OpenMP/CMakeLists.txt2
-rw-r--r--mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt14
-rw-r--r--mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp445
-rw-r--r--mlir/lib/Dialect/SCF/IR/SCF.cpp8
-rw-r--r--mlir/lib/Dialect/Shard/Transforms/Partition.cpp16
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp8
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp6
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp7
-rw-r--r--mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp2
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp27
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp31
22 files changed, 785 insertions, 142 deletions
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 2a8c330..f0de4db 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -320,6 +320,51 @@ LogicalResult ConvertF32x2ToF4x2Op::verify() {
return success();
}
+LogicalResult ConvertF8x2ToF16x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
+ return emitOpError("Only ")
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx)
+ << " types are supported for conversions from f8x2 to f16x2.";
+
+ return success();
+}
+
+LogicalResult ConvertF8x2ToBF16x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+ if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
+ return emitOpError("Only ")
+ << mlir::Float8E8M0FNUType::get(ctx)
+ << " type is supported for conversions from f8x2 to bf16x2.";
+
+ return success();
+}
+
+LogicalResult ConvertF6x2ToF16x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
+ return emitOpError("Only ")
+ << mlir::Float6E2M3FNType::get(ctx) << " and "
+ << mlir::Float6E3M2FNType::get(ctx)
+ << " types are supported for conversions from f6x2 to f16x2.";
+
+ return success();
+}
+
+LogicalResult ConvertF4x2ToF16x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
+ return emitOpError("Only ")
+ << mlir::Float4E2M1FNType::get(ctx)
+ << " type is supported for conversions from f4x2 to f16x2.";
+
+ return success();
+}
+
LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -2187,6 +2232,98 @@ ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
}
+NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
+
+ bool hasRelu = curOp.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+ .Case<Float8E4M3FNType>([&](Float8E4M3FNType type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
+ })
+ .Case<Float8E5M2Type>([&](Float8E5M2Type type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
+ })
+ .Default([](mlir::Type type) {
+ llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+
+ llvm::Value *packedI16 =
+ builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ return {intId, {packedI16}};
+}
+
+NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
+
+ llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
+ llvm::Value *packedI16 =
+ builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ return {intId, {packedI16}};
+}
+
+NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
+
+ bool hasRelu = curOp.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+ .Case<Float6E2M3FNType>([&](Float6E2M3FNType type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
+ })
+ .Case<Float6E3M2FNType>([&](Float6E3M2FNType type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
+ })
+ .Default([](mlir::Type type) {
+ llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+
+ llvm::Value *packedI16 =
+ builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ return {intId, {packedI16}};
+}
+
+NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
+
+ bool hasRelu = curOp.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+ .Case<Float4E2M1FNType>([&](Float4E2M1FNType type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
+ })
+ .Default([](mlir::Type type) {
+ llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+
+ llvm::Value *extendedI16 =
+ builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ return {intId, {extendedI16}};
+}
+
llvm::Intrinsic::ID
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
LLVM::ModuleTranslation &mt,
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
index d4ff095..37a45d4 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
@@ -18,4 +18,5 @@ add_mlir_dialect_library(MLIRLLVMIRTransforms
MLIRPass
MLIRTransforms
MLIRNVVMDialect
+ MLIROpenMPDialect
)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9a8a63e..794dda9 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -437,13 +437,15 @@ transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
for (auto [pos, dim] : llvm::enumerate(type.getShape())) {
if (!ShapedType::isDynamic(dim))
continue;
- Value cst = rewriter.create<arith::ConstantIndexOp>(tensor.getLoc(), pos);
- auto dimOp = rewriter.create<tensor::DimOp>(tensor.getLoc(), tensor, cst);
+ Value cst =
+ arith::ConstantIndexOp::create(rewriter, tensor.getLoc(), pos);
+ auto dimOp =
+ tensor::DimOp::create(rewriter, tensor.getLoc(), tensor, cst);
preservedOps.insert(dimOp);
dynamicDims.push_back(dimOp);
}
- auto allocation = rewriter.create<bufferization::AllocTensorOp>(
- tensor.getLoc(), type, dynamicDims);
+ auto allocation = bufferization::AllocTensorOp::create(
+ rewriter, tensor.getLoc(), type, dynamicDims);
// Set memory space if provided.
if (getMemorySpaceAttr())
allocation.setMemorySpaceAttr(getMemorySpaceAttr());
@@ -452,8 +454,8 @@ transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
// Only insert a materialization (typically bufferizes to a copy) when the
// value may be read from.
if (needsMaterialization) {
- auto copy = rewriter.create<bufferization::MaterializeInDestinationOp>(
- tensor.getLoc(), tensor, allocated);
+ auto copy = bufferization::MaterializeInDestinationOp::create(
+ rewriter, tensor.getLoc(), tensor, allocated);
preservedOps.insert(copy);
promoted.push_back(copy.getResult());
} else {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
index 15eb51a..5e10ba3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
@@ -43,6 +44,33 @@ struct StructuredOpInterface
auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
auto one = arith::ConstantIndexOp::create(builder, loc, 1);
+ Value iterationDomainIsNonDegenerate;
+ for (auto [start, end] : llvm::zip(starts, ends)) {
+ auto startValue = getValueOrCreateConstantIndexOp(builder, loc, start);
+ auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
+
+ // Loop Trip count > 0 iff start < end
+ Value dimensionHasNonZeroTripCount = index::CmpOp::create(
+ builder, loc, index::IndexCmpPredicate::SLT, startValue, endValue);
+
+ if (!iterationDomainIsNonDegenerate) {
+ iterationDomainIsNonDegenerate = dimensionHasNonZeroTripCount;
+ } else {
+ // Iteration domain is non-degenerate iff all dimensions have loop trip
+ // count > 0
+ iterationDomainIsNonDegenerate =
+ arith::AndIOp::create(builder, loc, iterationDomainIsNonDegenerate,
+ dimensionHasNonZeroTripCount);
+ }
+ }
+
+ if (!iterationDomainIsNonDegenerate)
+ return;
+
+ auto ifOp = scf::IfOp::create(builder, loc, iterationDomainIsNonDegenerate,
+ /*withElseRegion=*/false);
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
// Subtract one from the loop ends before composing with the indexing map
transform(ends, ends.begin(), [&](OpFoldResult end) {
auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
@@ -110,6 +138,7 @@ struct StructuredOpInterface
builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
}
}
+ builder.setInsertionPointAfter(ifOp);
}
};
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 94947b7..c551fba 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1437,6 +1437,13 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
atLeastOneReplacement |= replaceConstantUsesOf(
builder, getLoc(), getStrides(), getConstifiedMixedStrides());
+ // extract_strided_metadata(cast(x)) -> extract_strided_metadata(x).
+ if (auto prev = getSource().getDefiningOp<CastOp>())
+ if (isa<MemRefType>(prev.getSource().getType())) {
+ getSourceMutable().assign(prev.getSource());
+ atLeastOneReplacement = true;
+ }
+
return success(atLeastOneReplacement);
}
@@ -1744,11 +1751,11 @@ OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
}
TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
- return cast<TypedValue<PtrLikeTypeInterface>>(getSource());
+ return getSource();
}
TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
- return cast<TypedValue<PtrLikeTypeInterface>>(getDest());
+ return getDest();
}
bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index d35566a..bd02516 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -1033,91 +1033,6 @@ class ExtractStridedMetadataOpReinterpretCastFolder
}
};
-/// Replace `base, offset, sizes, strides =
-/// extract_strided_metadata(
-/// cast(src) to dstTy)`
-/// With
-/// ```
-/// base, ... = extract_strided_metadata(src)
-/// offset = !dstTy.srcOffset.isDynamic()
-/// ? dstTy.srcOffset
-/// : extract_strided_metadata(src).offset
-/// sizes = for each srcSize in dstTy.srcSizes:
-/// !srcSize.isDynamic()
-/// ? srcSize
-// : extract_strided_metadata(src).sizes[i]
-/// strides = for each srcStride in dstTy.srcStrides:
-/// !srcStrides.isDynamic()
-/// ? srcStrides
-/// : extract_strided_metadata(src).strides[i]
-/// ```
-///
-/// In other words, consume the `cast` and apply its effects
-/// on the offset, sizes, and strides or compute them directly from `src`.
-class ExtractStridedMetadataOpCastFolder
- : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult
- matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
- PatternRewriter &rewriter) const override {
- Value source = extractStridedMetadataOp.getSource();
- auto castOp = source.getDefiningOp<memref::CastOp>();
- if (!castOp)
- return failure();
-
- Location loc = extractStridedMetadataOp.getLoc();
- // Check if the source is suitable for extract_strided_metadata.
- SmallVector<Type> inferredReturnTypes;
- if (failed(extractStridedMetadataOp.inferReturnTypes(
- rewriter.getContext(), loc, {castOp.getSource()},
- /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
- inferredReturnTypes)))
- return rewriter.notifyMatchFailure(castOp,
- "cast source's type is incompatible");
-
- auto memrefType = cast<MemRefType>(source.getType());
- unsigned rank = memrefType.getRank();
- SmallVector<OpFoldResult> results;
- results.resize_for_overwrite(rank * 2 + 2);
-
- auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
- rewriter, loc, castOp.getSource());
-
- // Register the base_buffer.
- results[0] = newExtractStridedMetadata.getBaseBuffer();
-
- auto getConstantOrValue = [&rewriter](int64_t constant,
- OpFoldResult ofr) -> OpFoldResult {
- return ShapedType::isStatic(constant)
- ? OpFoldResult(rewriter.getIndexAttr(constant))
- : ofr;
- };
-
- auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset();
- assert(sourceStrides.size() == rank && "unexpected number of strides");
-
- // Register the new offset.
- results[1] =
- getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());
-
- const unsigned sizeStartIdx = 2;
- const unsigned strideStartIdx = sizeStartIdx + rank;
- ArrayRef<int64_t> sourceSizes = memrefType.getShape();
-
- SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
- SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
- for (unsigned i = 0; i < rank; ++i) {
- results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
- results[strideStartIdx + i] =
- getConstantOrValue(sourceStrides[i], strides[i]);
- }
- rewriter.replaceOp(extractStridedMetadataOp,
- getValueOrCreateConstantIndexOp(rewriter, loc, results));
- return success();
- }
-};
-
/// Replace `base, offset, sizes, strides = extract_strided_metadata(
/// memory_space_cast(src) to dstTy)`
/// with
@@ -1209,7 +1124,6 @@ void memref::populateExpandStridedMetadataPatterns(
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
ExtractStridedMetadataOpReinterpretCastFolder,
ExtractStridedMetadataOpSubviewFolder,
- ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpMemorySpaceCastFolder,
ExtractStridedMetadataOpAssumeAlignmentFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
@@ -1226,7 +1140,6 @@ void memref::populateResolveExtractStridedMetadataPatterns(
ExtractStridedMetadataOpSubviewFolder,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
ExtractStridedMetadataOpReinterpretCastFolder,
- ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpMemorySpaceCastFolder,
ExtractStridedMetadataOpAssumeAlignmentFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
diff --git a/mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt
new file mode 100644
index 0000000..f305068
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIROpenACCAnalysis
+ OpenACCSupport.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIROpenACCDialect
+ MLIROpenACCUtils
+ MLIRSupport
+)
+
diff --git a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
new file mode 100644
index 0000000..f6b4534
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
@@ -0,0 +1,26 @@
+//===- OpenACCSupport.cpp - OpenACCSupport Implementation -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the OpenACCSupport analysis interface.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
+#include "mlir/Dialect/OpenACC/OpenACCUtils.h"
+
+namespace mlir {
+namespace acc {
+
+std::string OpenACCSupport::getVariableName(Value v) {
+ if (impl)
+ return impl->getVariableName(v);
+ return acc::getVariableName(v);
+}
+
+} // namespace acc
+} // namespace mlir
diff --git a/mlir/lib/Dialect/OpenACC/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/CMakeLists.txt
index 7117520..e8a916e 100644
--- a/mlir/lib/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(Analysis)
add_subdirectory(IR)
add_subdirectory(Utils)
add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 5ca0100..ca46629 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -610,6 +610,20 @@ LogicalResult acc::FirstprivateOp::verify() {
}
//===----------------------------------------------------------------------===//
+// FirstprivateMapInitialOp
+//===----------------------------------------------------------------------===//
+LogicalResult acc::FirstprivateMapInitialOp::verify() {
+ if (getDataClause() != acc::DataClause::acc_firstprivate)
+ return emitError("data clause associated with firstprivate operation must "
+ "match its intent");
+ if (failed(checkVarAndVarType(*this)))
+ return failure();
+ if (failed(checkNoModifier(*this)))
+ return failure();
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//
LogicalResult acc::ReductionOp::verify() {
diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
index 1223325..89adda82 100644
--- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
+++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/OpenACC/OpenACCUtils.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/TypeSwitch.h"
mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region &region) {
@@ -78,3 +79,30 @@ mlir::acc::VariableTypeCategory mlir::acc::getTypeCategory(mlir::Value var) {
pointerLikeTy.getElementType());
return typeCategory;
}
+
+std::string mlir::acc::getVariableName(mlir::Value v) {
+ Value current = v;
+
+ // Walk through view operations until a name is found or can't go further
+ while (Operation *definingOp = current.getDefiningOp()) {
+ // Check for `acc.var_name` attribute
+ if (auto varNameAttr =
+ definingOp->getAttrOfType<VarNameAttr>(getVarNameAttrName()))
+ return varNameAttr.getName().str();
+
+ // If it is a data entry operation, get name via getVarName
+ if (isa<ACC_DATA_ENTRY_OPS>(definingOp))
+ if (auto name = acc::getVarName(definingOp))
+ return name->str();
+
+ // If it's a view operation, continue to the source
+ if (auto viewOp = dyn_cast<ViewLikeOpInterface>(definingOp)) {
+ current = viewOp.getViewSource();
+ continue;
+ }
+
+ break;
+ }
+
+ return "";
+}
diff --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
index 57a6d34..f3c02da 100644
--- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
@@ -1,3 +1,5 @@
+add_subdirectory(Transforms)
+
add_mlir_dialect_library(MLIROpenMPDialect
IR/OpenMPDialect.cpp
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..b9b8eda
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIROpenMPTransforms
+ OpenMPOffloadPrivatizationPrepare.cpp
+
+ DEPENDS
+ MLIROpenMPPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRFuncDialect
+ MLIRLLVMDialect
+ MLIROpenMPDialect
+ MLIRPass
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp
new file mode 100644
index 0000000..a9125ec
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp
@@ -0,0 +1,445 @@
+//===- OpenMPOffloadPrivatizationPrepare.cpp - Prepare OMP privatization --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <cstdint>
+#include <iterator>
+#include <utility>
+
+//===----------------------------------------------------------------------===//
+// A pass that prepares OpenMP code for translation of delayed privatization
+// in the context of deferred target tasks. Deferred target tasks are created
+// when the nowait clause is used on the target directive.
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "omp-prepare-for-offload-privatization"
+
+namespace mlir {
+namespace omp {
+
+#define GEN_PASS_DEF_PREPAREFOROMPOFFLOADPRIVATIZATIONPASS
+#include "mlir/Dialect/OpenMP/Transforms/Passes.h.inc"
+
+} // namespace omp
+} // namespace mlir
+
+using namespace mlir;
+namespace {
+
+//===----------------------------------------------------------------------===//
+// PrepareForOMPOffloadPrivatizationPass
+//===----------------------------------------------------------------------===//
+
+class PrepareForOMPOffloadPrivatizationPass
+ : public omp::impl::PrepareForOMPOffloadPrivatizationPassBase<
+ PrepareForOMPOffloadPrivatizationPass> {
+
+ void runOnOperation() override {
+ ModuleOp mod = getOperation();
+
+ // In this pass, we make host-allocated privatized variables persist for
+ // deferred target tasks by copying them to the heap. Once the target task
+ // is done, this heap memory is freed. Since all of this happens on the host
+ // we can skip device modules.
+ auto offloadModuleInterface =
+ dyn_cast<omp::OffloadModuleInterface>(mod.getOperation());
+ if (offloadModuleInterface && offloadModuleInterface.getIsTargetDevice())
+ return;
+
+ getOperation()->walk([&](omp::TargetOp targetOp) {
+ if (!hasPrivateVars(targetOp) || !isTargetTaskDeferred(targetOp))
+ return;
+ IRRewriter rewriter(&getContext());
+ OperandRange privateVars = targetOp.getPrivateVars();
+ SmallVector<mlir::Value> newPrivVars;
+ Value fakeDependVar;
+ omp::TaskOp cleanupTaskOp;
+
+ newPrivVars.reserve(privateVars.size());
+ std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
+ for (auto [privVarIdx, privVarSymPair] :
+ llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
+ Value privVar = std::get<0>(privVarSymPair);
+ Attribute privSym = std::get<1>(privVarSymPair);
+
+ omp::PrivateClauseOp privatizer = findPrivatizer(targetOp, privSym);
+ if (!privatizer.needsMap()) {
+ newPrivVars.push_back(privVar);
+ continue;
+ }
+ bool isFirstPrivate = privatizer.getDataSharingType() ==
+ omp::DataSharingClauseType::FirstPrivate;
+
+ Value mappedValue = targetOp.getMappedValueForPrivateVar(privVarIdx);
+ auto mapInfoOp = cast<omp::MapInfoOp>(mappedValue.getDefiningOp());
+
+ if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy) {
+ newPrivVars.push_back(privVar);
+ continue;
+ }
+
+ // For deferred target tasks (!$omp target nowait), we need to keep
+ // a copy of the original, i.e. host variable being privatized so
+ // that it is available when the target task is eventually executed.
+ // We do this by first allocating as much heap memory as is needed by
+ // the original variable. Then, we use the init and copy regions of the
+ // privatizer, an instance of omp::PrivateClauseOp to set up the heap-
+ // allocated copy.
+ // After the target task is done, we need to use the dealloc region
+ // of the privatizer to clean up everything. We also need to free
+ // the heap memory we allocated. But due to the deferred nature
+ // of the target task, we cannot simply deallocate right after the
+ // omp.target operation else we may end up freeing memory before
+ // its eventual use by the target task. So, we create a dummy
+ // dependence between the target task and new omp.task. In the omp.task,
+ // we do all the cleanup. So, we end up with the following structure
+ //
+ // omp.target map_entries(..) ... nowait depend(out:fakeDependVar) {
+ // ...
+ // omp.terminator
+ // }
+ // omp.task depend(in: fakeDependVar) {
+ // /*cleanup_code*/
+ // omp.terminator
+ // }
+ // fakeDependVar is the address of the first heap-allocated copy of the
+ // host variable being privatized.
+
+ bool needsCleanupTask = !privatizer.getDeallocRegion().empty();
+
+ // Allocate heap memory that corresponds to the type of memory
+ // pointed to by varPtr
+ // For boxchars this won't be a pointer. But, MapsForPrivatizedSymbols
+ // should have mapped the pointer to the boxchar so use that as varPtr.
+ Value varPtr = mapInfoOp.getVarPtr();
+ Type varType = mapInfoOp.getVarType();
+ bool isPrivatizedByValue =
+ !isa<LLVM::LLVMPointerType>(privVar.getType());
+
+ assert(isa<LLVM::LLVMPointerType>(varPtr.getType()));
+ Value heapMem =
+ allocateHeapMem(targetOp, varPtr, varType, mod, rewriter);
+ if (!heapMem)
+ targetOp.emitError(
+ "Unable to allocate heap memory when trying to move "
+ "a private variable out of the stack and into the "
+ "heap for use by a deferred target task");
+
+ if (needsCleanupTask && !fakeDependVar)
+ fakeDependVar = heapMem;
+
+ // The types of private vars should match before and after the
+ // transformation. In particular, if the type is a pointer,
+ // simply record the newly allocated malloc location as the
+ // new private variable. If, however, the type is not a pointer
+ // then, we need to load the value from the newly allocated
+ // location. We'll insert that load later after we have updated
+ // the malloc'd location with the contents of the original
+ // variable.
+ if (!isPrivatizedByValue)
+ newPrivVars.push_back(heapMem);
+
+ // We now need to copy the original private variable into the newly
+ // allocated location in the heap.
+ // Find the earliest insertion point for the copy. This will be before
+ // the first in the list of omp::MapInfoOp instances that use varPtr.
+ // After the copy these omp::MapInfoOp instances will refer to heapMem
+ // instead.
+ Operation *varPtrDefiningOp = varPtr.getDefiningOp();
+ DenseSet<Operation *> users;
+ if (varPtrDefiningOp) {
+ users.insert(varPtrDefiningOp->user_begin(),
+ varPtrDefiningOp->user_end());
+ } else {
+ auto blockArg = cast<BlockArgument>(varPtr);
+ users.insert(blockArg.user_begin(), blockArg.user_end());
+ }
+ auto usesVarPtr = [&users](Operation *op) -> bool {
+ return users.count(op);
+ };
+
+ SmallVector<Operation *> chainOfOps;
+ chainOfOps.push_back(mapInfoOp);
+ for (auto member : mapInfoOp.getMembers()) {
+ omp::MapInfoOp memberMap =
+ cast<omp::MapInfoOp>(member.getDefiningOp());
+ if (usesVarPtr(memberMap))
+ chainOfOps.push_back(memberMap);
+ if (memberMap.getVarPtrPtr()) {
+ Operation *defOp = memberMap.getVarPtrPtr().getDefiningOp();
+ if (defOp && usesVarPtr(defOp))
+ chainOfOps.push_back(defOp);
+ }
+ }
+
+ DominanceInfo dom;
+ llvm::sort(chainOfOps, [&](Operation *l, Operation *r) {
+ return dom.dominates(l, r);
+ });
+
+ rewriter.setInsertionPoint(chainOfOps.front());
+
+ Operation *firstOp = chainOfOps.front();
+ Location loc = firstOp->getLoc();
+
+ // Create a llvm.func for 'region' that is marked always_inline and call
+ // it.
+ auto createAlwaysInlineFuncAndCallIt =
+ [&](Region &region, llvm::StringRef funcName,
+ llvm::ArrayRef<Value> args, bool returnsValue) -> Value {
+ assert(!region.empty() && "region cannot be empty");
+ LLVM::LLVMFuncOp func = createFuncOpForRegion(
+ loc, mod, region, funcName, rewriter, returnsValue);
+ auto call = LLVM::CallOp::create(rewriter, loc, func, args);
+ return call.getResult();
+ };
+
+ Value moldArg, newArg;
+ if (isPrivatizedByValue) {
+ moldArg = LLVM::LoadOp::create(rewriter, loc, varType, varPtr);
+ newArg = LLVM::LoadOp::create(rewriter, loc, varType, heapMem);
+ } else {
+ moldArg = varPtr;
+ newArg = heapMem;
+ }
+
+ Value initializedVal;
+ if (!privatizer.getInitRegion().empty())
+ initializedVal = createAlwaysInlineFuncAndCallIt(
+ privatizer.getInitRegion(),
+ llvm::formatv("{0}_{1}", privatizer.getSymName(), "init").str(),
+ {moldArg, newArg}, /*returnsValue=*/true);
+ else
+ initializedVal = newArg;
+
+ if (isFirstPrivate && !privatizer.getCopyRegion().empty())
+ initializedVal = createAlwaysInlineFuncAndCallIt(
+ privatizer.getCopyRegion(),
+ llvm::formatv("{0}_{1}", privatizer.getSymName(), "copy").str(),
+ {moldArg, initializedVal}, /*returnsValue=*/true);
+
+ if (isPrivatizedByValue)
+ (void)LLVM::StoreOp::create(rewriter, loc, initializedVal, heapMem);
+
+ // clone origOp, replace all uses of varPtr with heapMem and
+ // erase origOp.
+ auto cloneModifyAndErase = [&](Operation *origOp) -> Operation * {
+ Operation *clonedOp = rewriter.clone(*origOp);
+ rewriter.replaceAllOpUsesWith(origOp, clonedOp);
+ rewriter.modifyOpInPlace(clonedOp, [&]() {
+ clonedOp->replaceUsesOfWith(varPtr, heapMem);
+ });
+ rewriter.eraseOp(origOp);
+ return clonedOp;
+ };
+
+ // Now that we have set up the heap-allocated copy of the private
+ // variable, rewrite all the uses of the original variable with
+ // the heap-allocated variable.
+ rewriter.setInsertionPoint(targetOp);
+ mapInfoOp = cast<omp::MapInfoOp>(cloneModifyAndErase(mapInfoOp));
+ rewriter.setInsertionPoint(mapInfoOp);
+
+ // Fix any members that may use varPtr to now use heapMem
+ for (auto member : mapInfoOp.getMembers()) {
+ auto memberMapInfoOp = cast<omp::MapInfoOp>(member.getDefiningOp());
+ if (!usesVarPtr(memberMapInfoOp))
+ continue;
+ memberMapInfoOp =
+ cast<omp::MapInfoOp>(cloneModifyAndErase(memberMapInfoOp));
+ rewriter.setInsertionPoint(memberMapInfoOp);
+
+ if (memberMapInfoOp.getVarPtrPtr()) {
+ Operation *varPtrPtrdefOp =
+ memberMapInfoOp.getVarPtrPtr().getDefiningOp();
+ rewriter.setInsertionPoint(cloneModifyAndErase(varPtrPtrdefOp));
+ }
+ }
+
+ // If the type of the private variable is not a pointer,
+ // which is typically the case with !fir.boxchar types, then
+ // we need to ensure that the new private variable is also
+ // not a pointer. Insert a load from heapMem right before
+ // targetOp.
+ if (isPrivatizedByValue) {
+ rewriter.setInsertionPoint(targetOp);
+ auto newPrivVar = LLVM::LoadOp::create(rewriter, mapInfoOp.getLoc(),
+ varType, heapMem);
+ newPrivVars.push_back(newPrivVar);
+ }
+
+ // Deallocate
+ if (needsCleanupTask) {
+ if (!cleanupTaskOp) {
+ assert(fakeDependVar &&
+ "Need a valid value to set up a dependency");
+ rewriter.setInsertionPointAfter(targetOp);
+ omp::TaskOperands taskOperands;
+ auto inDepend = omp::ClauseTaskDependAttr::get(
+ rewriter.getContext(), omp::ClauseTaskDepend::taskdependin);
+ taskOperands.dependKinds.push_back(inDepend);
+ taskOperands.dependVars.push_back(fakeDependVar);
+ cleanupTaskOp = omp::TaskOp::create(rewriter, loc, taskOperands);
+ Block *taskBlock = rewriter.createBlock(&cleanupTaskOp.getRegion());
+ rewriter.setInsertionPointToEnd(taskBlock);
+ omp::TerminatorOp::create(rewriter, cleanupTaskOp.getLoc());
+ }
+ rewriter.setInsertionPointToStart(
+ &*cleanupTaskOp.getRegion().getBlocks().begin());
+ (void)createAlwaysInlineFuncAndCallIt(
+ privatizer.getDeallocRegion(),
+ llvm::formatv("{0}_{1}", privatizer.getSymName(), "dealloc")
+ .str(),
+ {initializedVal}, /*returnsValue=*/false);
+ llvm::FailureOr<LLVM::LLVMFuncOp> freeFunc =
+ LLVM::lookupOrCreateFreeFn(rewriter, mod);
+ assert(llvm::succeeded(freeFunc) &&
+ "Could not find free in the module");
+ (void)LLVM::CallOp::create(rewriter, loc, freeFunc.value(),
+ ValueRange{heapMem});
+ }
+ }
+ assert(newPrivVars.size() == privateVars.size() &&
+ "The number of private variables must match before and after "
+ "transformation");
+ if (fakeDependVar) {
+ omp::ClauseTaskDependAttr outDepend = omp::ClauseTaskDependAttr::get(
+ rewriter.getContext(), omp::ClauseTaskDepend::taskdependout);
+ SmallVector<Attribute> newDependKinds;
+ if (!targetOp.getDependVars().empty()) {
+ std::optional<ArrayAttr> dependKinds = targetOp.getDependKinds();
+ assert(dependKinds && "bad depend clause in omp::TargetOp");
+ llvm::copy(*dependKinds, std::back_inserter(newDependKinds));
+ }
+ newDependKinds.push_back(outDepend);
+ ArrayAttr newDependKindsAttr =
+ ArrayAttr::get(rewriter.getContext(), newDependKinds);
+ targetOp.getDependVarsMutable().append(fakeDependVar);
+ targetOp.setDependKindsAttr(newDependKindsAttr);
+ }
+ rewriter.setInsertionPoint(targetOp);
+ targetOp.getPrivateVarsMutable().clear();
+ targetOp.getPrivateVarsMutable().assign(newPrivVars);
+ });
+ }
+
+private:
+ bool hasPrivateVars(omp::TargetOp targetOp) const {
+ return !targetOp.getPrivateVars().empty();
+ }
+
+ bool isTargetTaskDeferred(omp::TargetOp targetOp) const {
+ return targetOp.getNowait();
+ }
+
+ template <typename OpTy>
+ omp::PrivateClauseOp findPrivatizer(OpTy op, Attribute privSym) const {
+ SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
+ omp::PrivateClauseOp privatizer =
+ SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
+ op, privatizerName);
+ return privatizer;
+ }
+
+ // Get the (compile-time constant) size of varType as per the
+ // given DataLayout dl.
+ std::int64_t getSizeInBytes(const DataLayout &dl, Type varType) const {
+ llvm::TypeSize size = dl.getTypeSize(varType);
+ unsigned short alignment = dl.getTypeABIAlignment(varType);
+ return llvm::alignTo(size, alignment);
+ }
+
+ LLVM::LLVMFuncOp getMalloc(ModuleOp mod, IRRewriter &rewriter) const {
+ llvm::FailureOr<LLVM::LLVMFuncOp> mallocCall =
+ LLVM::lookupOrCreateMallocFn(rewriter, mod, rewriter.getI64Type());
+ assert(llvm::succeeded(mallocCall) &&
+ "Could not find malloc in the module");
+ return mallocCall.value();
+ }
+
+ Value allocateHeapMem(omp::TargetOp targetOp, Value privVar, Type varType,
+ ModuleOp mod, IRRewriter &rewriter) const {
+ OpBuilder::InsertionGuard guard(rewriter);
+ Value varPtr = privVar;
+ Operation *definingOp = varPtr.getDefiningOp();
+ BlockArgument blockArg;
+ if (!definingOp) {
+ blockArg = mlir::dyn_cast<BlockArgument>(varPtr);
+ rewriter.setInsertionPointToStart(blockArg.getParentBlock());
+ } else {
+ rewriter.setInsertionPoint(definingOp);
+ }
+ Location loc = definingOp ? definingOp->getLoc() : blockArg.getLoc();
+ LLVM::LLVMFuncOp mallocFn = getMalloc(mod, rewriter);
+
+ assert(mod.getDataLayoutSpec() &&
+ "MLIR module with no datalayout spec not handled yet");
+
+ const DataLayout &dl = DataLayout(mod);
+ std::int64_t distance = getSizeInBytes(dl, varType);
+
+ Value sizeBytes = LLVM::ConstantOp::create(
+ rewriter, loc, mallocFn.getFunctionType().getParamType(0), distance);
+
+ auto mallocCallOp =
+ LLVM::CallOp::create(rewriter, loc, mallocFn, ValueRange{sizeBytes});
+ return mallocCallOp.getResult();
+ }
+
+ // Create a function for srcRegion and attribute it to be always_inline.
+ // The big assumption here is that srcRegion is one of init, copy or dealloc
+ // regions of a omp::PrivateClauseop. Accordingly, the return type is assumed
+ // to either be the same as the types of the two arguments of the region (for
+ // init and copy regions) or void as would be the case for dealloc regions.
+ LLVM::LLVMFuncOp createFuncOpForRegion(Location loc, ModuleOp mod,
+ Region &srcRegion,
+ llvm::StringRef funcName,
+ IRRewriter &rewriter,
+ bool returnsValue = false) {
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end());
+ Region clonedRegion;
+ IRMapping mapper;
+ srcRegion.cloneInto(&clonedRegion, mapper);
+
+ SmallVector<Type> paramTypes;
+ llvm::copy(srcRegion.getArgumentTypes(), std::back_inserter(paramTypes));
+ Type resultType = returnsValue
+ ? srcRegion.getArgument(0).getType()
+ : LLVM::LLVMVoidType::get(rewriter.getContext());
+ LLVM::LLVMFunctionType funcType =
+ LLVM::LLVMFunctionType::get(resultType, paramTypes);
+
+ LLVM::LLVMFuncOp func =
+ LLVM::LLVMFuncOp::create(rewriter, loc, funcName, funcType);
+ func.setAlwaysInline(true);
+ rewriter.inlineRegionBefore(clonedRegion, func.getRegion(),
+ func.getRegion().end());
+ for (auto &block : func.getRegion().getBlocks()) {
+ if (isa<omp::YieldOp>(block.getTerminator())) {
+ omp::YieldOp yieldOp = cast<omp::YieldOp>(block.getTerminator());
+ rewriter.setInsertionPoint(yieldOp);
+ rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(yieldOp, TypeRange(),
+ yieldOp.getOperands());
+ }
+ }
+ return func;
+ }
+};
+} // namespace
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index a9da6c2..9bd13f3 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -2490,8 +2490,8 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
changed = true;
if (!constantTrue)
- constantTrue = rewriter.create<arith::ConstantOp>(
- op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
+ constantTrue = arith::ConstantOp::create(
+ rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
rewriter.modifyOpInPlace(use.getOwner(),
[&]() { use.set(constantTrue); });
@@ -2500,8 +2500,8 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
changed = true;
if (!constantFalse)
- constantFalse = rewriter.create<arith::ConstantOp>(
- op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
+ constantFalse = arith::ConstantOp::create(
+ rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
rewriter.modifyOpInPlace(use.getOwner(),
[&]() { use.set(constantFalse); });
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 5dc61a2..335ca1a 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -69,10 +69,10 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
Sharding sourceSharding,
TypedValue<ShapedType> sourceShard, GridOp grid,
int64_t splitTensorAxis, GridAxis splitGridAxis) {
- TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
+ TypedValue<ShapedType> targetShard =
AllSliceOp::create(builder, sourceShard, grid,
ArrayRef<GridAxis>(splitGridAxis), splitTensorAxis)
- .getResult());
+ .getResult();
Sharding targetSharding = targetShardingInSplitLastAxis(
builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis);
return {targetShard, targetSharding};
@@ -204,9 +204,8 @@ static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding(
APInt(64, splitTensorAxis));
ShapedType targetShape =
shardShapedType(sourceUnshardedShape, grid, targetSharding);
- TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
- tensor::CastOp::create(builder, targetShape, allGatherResult)
- .getResult());
+ TypedValue<ShapedType> targetShard =
+ tensor::CastOp::create(builder, targetShape, allGatherResult).getResult();
return {targetShard, targetSharding};
}
@@ -336,8 +335,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
ShapedType targetShape =
shardShapedType(sourceUnshardedShape, grid, targetSharding);
- TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
- tensor::CastOp::create(builder, targetShape, allToAllResult).getResult());
+ TypedValue<ShapedType> targetShard =
+ tensor::CastOp::create(builder, targetShape, allToAllResult).getResult();
return {targetShard, targetSharding};
}
@@ -510,8 +509,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
auto targetSharding = target.getSharding();
ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding,
- cast<TypedValue<ShapedType>>(source.getSrc()),
- sourceShardValue);
+ source.getSrc(), sourceShardValue);
}
TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index caf8016..99b7cda 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1001,8 +1001,12 @@ OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
!outputTy.hasStaticShape())
return {};
- if (inputTy.getDimSize(getAxis()) == 1)
- return DenseElementsAttr::get(outputTy, 0);
+ const Type outputElementTy = getElementTypeOrSelf(outputTy);
+ if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) {
+ const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
+ const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
+ return DenseElementsAttr::get(outputTy, zero);
+ }
return {};
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
index 8f46ad6..ef49c86 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
@@ -74,9 +74,9 @@ struct MixedSizeInputShuffleOpRewrite final
for (int64_t i = 0; i < origNumElems; ++i)
promoteMask[i] = i;
- Value promotedInput = rewriter.create<vector::ShuffleOp>(
- shuffleOp.getLoc(), promotedType, inputToPromote, inputToPromote,
- promoteMask);
+ Value promotedInput =
+ vector::ShuffleOp::create(rewriter, shuffleOp.getLoc(), promotedType,
+ inputToPromote, inputToPromote, promoteMask);
// Create the final shuffle with the promoted inputs.
Value promotedV1 = promoteV1 ? promotedInput : shuffleOp.getV1();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 7c019e7..8b5e950 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -341,13 +341,18 @@ private:
/// Return the distributed vector type based on the original type and the
/// distribution map. The map is expected to have a dimension equal to the
/// original type rank and should be a projection where the results are the
-/// distributed dimensions. The number of results should be equal to the number
+/// distributed dimensions. If the number of results is zero there is no
+/// distribution (i.e. original type is returned).
+/// Otherwise, The number of results should be equal to the number
/// of warp sizes which is currently limited to 1.
/// Example: For a vector<16x32x64> distributed with a map(d0, d1, d2) -> (d1)
/// and a warp size of 16 would distribute the second dimension (associated to
/// d1) and return vector<16x2x64>
static VectorType getDistributedType(VectorType originalType, AffineMap map,
int64_t warpSize) {
+ // If the map has zero results, return the original type.
+ if (map.getNumResults() == 0)
+ return originalType;
SmallVector<int64_t> targetShape(originalType.getShape());
for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
unsigned position = map.getDimPosition(i);
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 1599ae9..24e9095 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -736,7 +736,7 @@ OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc,
OpBuilder &builder) {
auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
- return builder.create<ArithOp>(loc, aVal, bVal).getResult();
+ return ArithOp::create(builder, loc, aVal, bVal).getResult();
}
// a helper utility to perform division operation on OpFoldResult and int64_t.
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 26770b3..d09dc19 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1505,14 +1505,19 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
return AffineMap::get(val.getContext());
// Get the layout of the vector type.
xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
- // If no layout is specified, assume the inner most dimension is distributed
- // for now.
+ // If no layout is specified, that means no distribution.
if (!layout)
- return AffineMap::getMultiDimMapWithTargets(
- vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
+ return AffineMap::getMultiDimMapWithTargets(vecRank, {},
+ val.getContext());
+ // Expecting vector and layout rank to match.
+ assert(layout.getRank() == vecRank &&
+ "Expecting vector and layout rank to match");
+ // A dimension is distributed only if layout suggests there are
+ // multiple lanes assigned for this dimension and the shape can be evenly
+ // distributed to those lanes.
SmallVector<unsigned int> distributedDims;
for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
- if (v > 1)
+ if (v > 1 && vecType.getShape()[i] % v == 0)
distributedDims.push_back(i);
}
return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
@@ -1525,15 +1530,13 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
vector::CombiningKind kind, uint32_t size) {
// First reduce on a single thread to get per lane reduction value.
- Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
+ Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
// Parallel reduction using butterfly shuffles.
for (uint64_t i = 1; i < size; i <<= 1) {
- Value shuffled =
- builder
- .create<gpu::ShuffleOp>(loc, laneVal, i,
- /*width=*/size,
- /*mode=*/gpu::ShuffleMode::XOR)
- .getShuffleResult();
+ Value shuffled = gpu::ShuffleOp::create(builder, loc, laneVal, i,
+ /*width=*/size,
+ /*mode=*/gpu::ShuffleMode::XOR)
+ .getShuffleResult();
laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
}
return laneVal;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 31a967d..9fc5ad9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -825,7 +825,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType),
baseTileValues);
- auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr);
+ auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);
// Get subgroup id
Value sgId =
@@ -837,25 +837,26 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
SmallVector<Value, 2> strideConsts;
strideConsts.push_back(
- rewriter.create<arith::ConstantIndexOp>(loc, colStride));
+ arith::ConstantIndexOp::create(rewriter, loc, colStride));
if (rows > 1)
strideConsts.insert(
strideConsts.begin(),
- rewriter.create<arith::ConstantIndexOp>(loc, rowStride));
+ arith::ConstantIndexOp::create(rewriter, loc, rowStride));
SmallVector<Value> newConstOps;
for (auto offsets : *sgOffsets) {
// Multiply offset with stride, broadcast it and add to baseConstVec
- Value mulOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
for (size_t i = 0; i < strideConsts.size(); ++i) {
- Value mul = rewriter.create<arith::MulIOp>(
- loc, rewriter.getIndexType(), offsets[i], strideConsts[i]);
- mulOffset = rewriter.create<arith::AddIOp>(
- loc, rewriter.getIndexType(), mulOffset, mul);
+ Value mul =
+ arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
+ offsets[i], strideConsts[i]);
+ mulOffset = arith::AddIOp::create(
+ rewriter, loc, rewriter.getIndexType(), mulOffset, mul);
}
// Broadcast to baseConstVec size
- auto bcastOffset = rewriter.create<vector::BroadcastOp>(
- loc, baseConstVec.getType(), mulOffset);
+ auto bcastOffset = vector::BroadcastOp::create(
+ rewriter, loc, baseConstVec.getType(), mulOffset);
auto finalConst =
arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
setLayoutIfNeeded(baseConstVec);
@@ -1138,8 +1139,8 @@ struct WgToSgVectorShapeCastOp
SmallVector<Value> newShapeCastOps;
for (auto src : adaptor.getSource()) {
- auto newShapeCast =
- rewriter.create<vector::ShapeCastOp>(op.getLoc(), newResultType, src);
+ auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+ newResultType, src);
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
!layout.getEffectiveInstDataAsInt().empty())
xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
@@ -1201,9 +1202,9 @@ struct WgToSgMultiDimReductionOp
SmallVector<Value> newReductions;
for (auto sgSrc : adaptor.getSource()) {
- auto newOp = rewriter.create<vector::MultiDimReductionOp>(
- op.getLoc(), newDstType, op.getKind(), sgSrc, adaptor.getAcc()[0],
- op.getReductionDims());
+ auto newOp = vector::MultiDimReductionOp::create(
+ rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
+ adaptor.getAcc()[0], op.getReductionDims());
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
!layout.getEffectiveInstDataAsInt().empty())
xegpu::setDistributeLayoutAttr(newOp->getResult(0),