Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--  mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 68
-rw-r--r--  mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp | 21
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp | 7
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 137
-rw-r--r--  mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 14
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp | 29
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 11
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp | 13
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp | 87
-rw-r--r--  mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt | 13
-rw-r--r--  mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp | 26
-rw-r--r--  mlir/lib/Dialect/OpenACC/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 14
-rw-r--r--  mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp | 28
-rw-r--r--  mlir/lib/Dialect/OpenMP/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 170
-rw-r--r--  mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt | 14
-rw-r--r--  mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp | 447
-rw-r--r--  mlir/lib/Dialect/SCF/IR/SCF.cpp | 104
-rw-r--r--  mlir/lib/Dialect/Shard/Transforms/Partition.cpp | 16
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp | 16
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp | 95
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 37
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 377
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp | 48
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp | 22
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 7
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 2
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 25
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 27
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 6
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 31
-rw-r--r--  mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 25
35 files changed, 1586 insertions(+), 361 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 61166db..585b6da 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -360,45 +360,53 @@ LogicalResult ScaledExtPacked816Op::verify() {
//===----------------------------------------------------------------------===//
// WMMAOp
//===----------------------------------------------------------------------===//
-LogicalResult WMMAOp::verify() {
- Type sourceAType = getSourceA().getType();
- Type sourceBType = getSourceB().getType();
- Type destType = getDestC().getType();
- VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
- VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
- VectorType destVectorType = dyn_cast<VectorType>(destType);
+ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser,
+ IntegerAttr &m, IntegerAttr &n,
+ IntegerAttr &k) {
+ SmallVector<int64_t, 3> dimensions;
+ if (parser.parseDimensionList(dimensions, false, false))
+ return failure();
+ if (dimensions.size() != 3)
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected 3 dimensions in MNK dimension list";
- Type sourceAElemType = sourceVectorAType.getElementType();
- Type sourceBElemType = sourceVectorBType.getElementType();
- Type destElemType = destVectorType.getElementType();
+ m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
+ n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
+ k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
+ return success();
+}
- if (sourceVectorAType.getNumElements() !=
- sourceVectorBType.getNumElements()) {
+LogicalResult WMMAOp::verify() {
+ auto sourceAType = cast<VectorType>(getSourceA().getType());
+ auto sourceBType = cast<VectorType>(getSourceB().getType());
+ auto destType = cast<VectorType>(getDestC().getType());
+
+ Type sourceAElemType = sourceAType.getElementType();
+ Type sourceBElemType = sourceBType.getElementType();
+ if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
return emitOpError("source vectors have different lengths: ")
- << sourceVectorAType << " vs. " << sourceVectorBType;
+ << sourceAType << " vs. " << sourceBType;
}
- bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
- bool isSrcFloat =
- isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
- sourceAElemType);
-
- if (isDestFloat && !isSrcFloat) {
- return emitOpError("Expected float sources with float destination");
- }
+ bool isDestFloat = destType.getElementType().isFloat();
+ bool isSrcFloat = sourceAElemType.isFloat();
- if (!isDestFloat && isSrcFloat) {
- return emitOpError("Expected int sources with int destination");
- }
+ if (isDestFloat && !isSrcFloat)
+ return emitOpError("expected float sources with float destination");
+ if (!isDestFloat && isSrcFloat)
+ return emitOpError("expected int sources with int destination");
- if (sourceAElemType != sourceBElemType &&
- !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
- isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
+ if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
return emitOpError(
"source element types much match (except for fp8) but have ")
<< sourceAType << " and " << sourceBType;
}
+
+ if (!sourceAElemType.isInteger(4) && getK() != 16) {
+ return emitOpError("K dimension must be 16 for source element type ")
+ << sourceAElemType;
+ }
return success();
}
@@ -414,11 +422,11 @@ LogicalResult MFMAOp::verify() {
Type sourceElem = sourceType, destElem = destType;
uint32_t sourceLen = 1, destLen = 1;
- if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
+ if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
sourceLen = sourceVector.getNumElements();
sourceElem = sourceVector.getElementType();
}
- if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
+ if (auto destVector = dyn_cast<VectorType>(destType)) {
destLen = destVector.getNumElements();
destElem = destVector.getElementType();
}
@@ -443,7 +451,7 @@ LogicalResult MFMAOp::verify() {
return emitOpError("expected both non-small-float source operand types "
"to match exactly");
}
- // Normalize the wider integer types the compiler expects to i8
+ // Normalize the wider integer types the compiler expects to i8.
if (sourceElem.isInteger(32)) {
sourceLen *= 4;
sourceElem = b.getI8Type();
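
For illustration, the new parseMNKDimensionList directive lets WMMAOp spell its shape as a plain dimension list. A hypothetical round-trip, assuming the op's declarative assembly uses the directive for the m/n/k attributes (the operand syntax and types are illustrative, not taken from this diff):

    // 16x16x16 is parsed into the i32 attributes m, n and k; per the
    // verifier above, K must be 16 unless the sources are i4.
    %d = amdgpu.wmma 16x16x16 %a * %b + %c
         : vector<8xf16>, vector<8xf16>, vector<8xf32>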
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp
index 316721b..60ae78b 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp
@@ -45,18 +45,15 @@ struct LoopUnroll : public affine::impl::AffineLoopUnrollBase<LoopUnroll> {
const std::function<unsigned(AffineForOp)> getUnrollFactor;
LoopUnroll() : getUnrollFactor(nullptr) {}
- LoopUnroll(const LoopUnroll &other)
-
- = default;
+ LoopUnroll(const LoopUnroll &other) = default;
explicit LoopUnroll(
std::optional<unsigned> unrollFactor = std::nullopt,
- bool unrollUpToFactor = false, bool unrollFull = false,
+ bool unrollUpToFactor = false,
const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr)
: getUnrollFactor(getUnrollFactor) {
if (unrollFactor)
this->unrollFactor = *unrollFactor;
this->unrollUpToFactor = unrollUpToFactor;
- this->unrollFull = unrollFull;
}
void runOnOperation() override;
@@ -85,11 +82,17 @@ static void gatherInnermostLoops(FunctionOpInterface f,
}
void LoopUnroll::runOnOperation() {
+ if (!(unrollFactor.getValue() > 0 || unrollFactor.getValue() == -1)) {
+ emitError(UnknownLoc::get(&getContext()),
+ "Invalid option: 'unroll-factor' should be greater than 0 or "
+ "equal to -1");
+ return signalPassFailure();
+ }
FunctionOpInterface func = getOperation();
if (func.isExternal())
return;
- if (unrollFull && unrollFullThreshold.hasValue()) {
+ if (unrollFactor.getValue() == -1 && unrollFullThreshold.hasValue()) {
// Store short loops as we walk.
SmallVector<AffineForOp, 4> loops;
@@ -130,7 +133,7 @@ LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) {
return loopUnrollByFactor(forOp, getUnrollFactor(forOp),
/*annotateFn=*/nullptr, cleanUpUnroll);
// Unroll completely if full loop unroll was specified.
- if (unrollFull)
+ if (unrollFactor.getValue() == -1)
return loopUnrollFull(forOp);
// Otherwise, unroll by the given unroll factor.
if (unrollUpToFactor)
@@ -141,9 +144,9 @@ LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) {
std::unique_ptr<InterfacePass<FunctionOpInterface>>
mlir::affine::createLoopUnrollPass(
- int unrollFactor, bool unrollUpToFactor, bool unrollFull,
+ int unrollFactor, bool unrollUpToFactor,
const std::function<unsigned(AffineForOp)> &getUnrollFactor) {
return std::make_unique<LoopUnroll>(
unrollFactor == -1 ? std::nullopt : std::optional<unsigned>(unrollFactor),
- unrollUpToFactor, unrollFull, getUnrollFactor);
+ unrollUpToFactor, getUnrollFactor);
}
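
With the unrollFull flag folded away, full unrolling is now requested by passing -1 as the factor. A minimal sketch of building the pass with the new signature above (the pipeline setup is illustrative):

    // -1 selects loopUnrollFull(); any factor > 0 unrolls by that factor,
    // and anything else is now rejected up front in runOnOperation().
    pm.addNestedPass<func::FuncOp>(
        mlir::affine::createLoopUnrollPass(/*unrollFactor=*/-1,
                                           /*unrollUpToFactor=*/false,
                                           /*getUnrollFactor=*/nullptr));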
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
index a6159ee..f0ddb50 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -14,13 +14,6 @@
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
-namespace mlir {
-namespace bufferization {
-#define GEN_PASS_DEF_TENSORCOPYINSERTION
-#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
-} // namespace bufferization
-} // namespace mlir
-
using namespace mlir;
using namespace mlir::bufferization;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 2a8c330..f0de4db 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -320,6 +320,51 @@ LogicalResult ConvertF32x2ToF4x2Op::verify() {
return success();
}
+LogicalResult ConvertF8x2ToF16x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
+ return emitOpError("Only ")
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx)
+ << " types are supported for conversions from f8x2 to f16x2.";
+
+ return success();
+}
+
+LogicalResult ConvertF8x2ToBF16x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+ if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
+ return emitOpError("Only ")
+ << mlir::Float8E8M0FNUType::get(ctx)
+ << " type is supported for conversions from f8x2 to bf16x2.";
+
+ return success();
+}
+
+LogicalResult ConvertF6x2ToF16x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
+ return emitOpError("Only ")
+ << mlir::Float6E2M3FNType::get(ctx) << " and "
+ << mlir::Float6E3M2FNType::get(ctx)
+ << " types are supported for conversions from f6x2 to f16x2.";
+
+ return success();
+}
+
+LogicalResult ConvertF4x2ToF16x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
+ return emitOpError("Only ")
+ << mlir::Float4E2M1FNType::get(ctx)
+ << " type is supported for conversions from f4x2 to f16x2.";
+
+ return success();
+}
+
LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -2187,6 +2232,98 @@ ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
}
+NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
+
+ bool hasRelu = curOp.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+ .Case<Float8E4M3FNType>([&](Float8E4M3FNType type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
+ })
+ .Case<Float8E5M2Type>([&](Float8E5M2Type type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
+ })
+ .Default([](mlir::Type type) {
+ llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+
+ llvm::Value *packedI16 =
+ builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ return {intId, {packedI16}};
+}
+
+NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
+
+ llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
+ llvm::Value *packedI16 =
+ builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ return {intId, {packedI16}};
+}
+
+NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
+
+ bool hasRelu = curOp.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+ .Case<Float6E2M3FNType>([&](Float6E2M3FNType type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
+ })
+ .Case<Float6E3M2FNType>([&](Float6E3M2FNType type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
+ })
+ .Default([](mlir::Type type) {
+ llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+
+ llvm::Value *packedI16 =
+ builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ return {intId, {packedI16}};
+}
+
+NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
+
+ bool hasRelu = curOp.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+ .Case<Float4E2M1FNType>([&](Float4E2M1FNType type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
+ })
+ .Default([](mlir::Type type) {
+ llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+
+ llvm::Value *extendedI16 =
+ builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ return {intId, {extendedI16}};
+}
+
llvm::Intrinsic::ID
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
LLVM::ModuleTranslation &mt,
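
For reference, a sketch of the LLVM IR these helpers produce for the non-relu e4m3 case (value names and the source's pre-bitcast type are illustrative):

    ; two packed e4m3 values are bitcast to i16 and handed to the intrinsic
    %packed = bitcast <2 x i8> %src to i16
    %halves = call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn(i16 %packed)

The f4x2 path differs only in that the 8-bit source is zero-extended to i16 rather than bitcast, matching the CreateZExt call above.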
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
index d4ff095..37a45d4 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
@@ -18,4 +18,5 @@ add_mlir_dialect_library(MLIRLLVMIRTransforms
MLIRPass
MLIRTransforms
MLIRNVVMDialect
+ MLIROpenMPDialect
)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9a8a63e..794dda9 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -437,13 +437,15 @@ transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
for (auto [pos, dim] : llvm::enumerate(type.getShape())) {
if (!ShapedType::isDynamic(dim))
continue;
- Value cst = rewriter.create<arith::ConstantIndexOp>(tensor.getLoc(), pos);
- auto dimOp = rewriter.create<tensor::DimOp>(tensor.getLoc(), tensor, cst);
+ Value cst =
+ arith::ConstantIndexOp::create(rewriter, tensor.getLoc(), pos);
+ auto dimOp =
+ tensor::DimOp::create(rewriter, tensor.getLoc(), tensor, cst);
preservedOps.insert(dimOp);
dynamicDims.push_back(dimOp);
}
- auto allocation = rewriter.create<bufferization::AllocTensorOp>(
- tensor.getLoc(), type, dynamicDims);
+ auto allocation = bufferization::AllocTensorOp::create(
+ rewriter, tensor.getLoc(), type, dynamicDims);
// Set memory space if provided.
if (getMemorySpaceAttr())
allocation.setMemorySpaceAttr(getMemorySpaceAttr());
@@ -452,8 +454,8 @@ transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
// Only insert a materialization (typically bufferizes to a copy) when the
// value may be read from.
if (needsMaterialization) {
- auto copy = rewriter.create<bufferization::MaterializeInDestinationOp>(
- tensor.getLoc(), tensor, allocated);
+ auto copy = bufferization::MaterializeInDestinationOp::create(
+ rewriter, tensor.getLoc(), tensor, allocated);
preservedOps.insert(copy);
promoted.push_back(copy.getResult());
} else {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
index 15eb51a..5e10ba3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
@@ -43,6 +44,33 @@ struct StructuredOpInterface
auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
auto one = arith::ConstantIndexOp::create(builder, loc, 1);
+ Value iterationDomainIsNonDegenerate;
+ for (auto [start, end] : llvm::zip(starts, ends)) {
+ auto startValue = getValueOrCreateConstantIndexOp(builder, loc, start);
+ auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
+
+ // Loop trip count > 0 iff start < end.
+ Value dimensionHasNonZeroTripCount = index::CmpOp::create(
+ builder, loc, index::IndexCmpPredicate::SLT, startValue, endValue);
+
+ if (!iterationDomainIsNonDegenerate) {
+ iterationDomainIsNonDegenerate = dimensionHasNonZeroTripCount;
+ } else {
+ // Iteration domain is non-degenerate iff all dimensions have loop trip
+ // count > 0
+ iterationDomainIsNonDegenerate =
+ arith::AndIOp::create(builder, loc, iterationDomainIsNonDegenerate,
+ dimensionHasNonZeroTripCount);
+ }
+ }
+
+ if (!iterationDomainIsNonDegenerate)
+ return;
+
+ auto ifOp = scf::IfOp::create(builder, loc, iterationDomainIsNonDegenerate,
+ /*withElseRegion=*/false);
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
// Subtract one from the loop ends before composing with the indexing map
transform(ends, ends.begin(), [&](OpFoldResult end) {
auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
@@ -110,6 +138,7 @@ struct StructuredOpInterface
builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
}
}
+ builder.setInsertionPointAfter(ifOp);
}
};
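
The net effect: the in-bounds assertions are now emitted inside a guard, so they cannot fire for ops whose iteration domain is empty. A rough sketch of the generated IR (names are illustrative):

    // One i1 per loop dimension, testing start < end, AND-ed together.
    %nonempty = arith.andi %dim0_nonempty, %dim1_nonempty : i1
    scf.if %nonempty {
      // ... index computations and cf.assert checks as before ...
    }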
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 94947b7..c551fba 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1437,6 +1437,13 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
atLeastOneReplacement |= replaceConstantUsesOf(
builder, getLoc(), getStrides(), getConstifiedMixedStrides());
+ // extract_strided_metadata(cast(x)) -> extract_strided_metadata(x).
+ if (auto prev = getSource().getDefiningOp<CastOp>())
+ if (isa<MemRefType>(prev.getSource().getType())) {
+ getSourceMutable().assign(prev.getSource());
+ atLeastOneReplacement = true;
+ }
+
return success(atLeastOneReplacement);
}
@@ -1744,11 +1751,11 @@ OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
}
TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
- return cast<TypedValue<PtrLikeTypeInterface>>(getSource());
+ return getSource();
}
TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
- return cast<TypedValue<PtrLikeTypeInterface>>(getDest());
+ return getDest();
}
bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
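
This fold subsumes the ExtractStridedMetadataOpCastFolder pattern deleted from ExpandStridedMetadata.cpp below: it simply rewires the metadata op to look through a memref.cast. A sketch with illustrative types (result type annotations elided):

    // Before folding:
    %c = memref.cast %src : memref<4x4xf32> to memref<?x?xf32>
    %b, %off, %sz:2, %st:2 = memref.extract_strided_metadata %c
    // After folding, the source operand is %src; the cast can then DCE away:
    %b, %off, %sz:2, %st:2 = memref.extract_strided_metadata %src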
diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
index 11400de..a15bf89 100644
--- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -59,6 +59,17 @@ struct DimOpInterface
}
};
+struct ExpandShapeOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<ExpandShapeOpInterface,
+ memref::ExpandShapeOp> {
+ void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+ ValueBoundsConstraintSet &cstr) const {
+ auto expandOp = cast<memref::ExpandShapeOp>(op);
+ assert(value == expandOp.getResult() && "invalid value");
+ cstr.bound(value)[dim] == expandOp.getOutputShape()[dim];
+ }
+};
+
struct GetGlobalOpInterface
: public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface,
GetGlobalOp> {
@@ -123,6 +134,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels(
memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
+ memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>(
+ *ctx);
memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx);
memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx);
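
The new model lets value-bounds queries see through memref.expand_shape by equating each result dimension with the corresponding output_shape value. A hypothetical case this enables (expand_shape syntax as in upstream MLIR):

    %e = memref.expand_shape %m [[0, 1]] output_shape [%n, 4]
         : memref<?xf32> into memref<?x4xf32>
    // ValueBoundsConstraintSet can now prove dim(%e, 0) == %n.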
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index d35566a..bd02516 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -1033,91 +1033,6 @@ class ExtractStridedMetadataOpReinterpretCastFolder
}
};
-/// Replace `base, offset, sizes, strides =
-/// extract_strided_metadata(
-/// cast(src) to dstTy)`
-/// With
-/// ```
-/// base, ... = extract_strided_metadata(src)
-/// offset = !dstTy.srcOffset.isDynamic()
-/// ? dstTy.srcOffset
-/// : extract_strided_metadata(src).offset
-/// sizes = for each srcSize in dstTy.srcSizes:
-/// !srcSize.isDynamic()
-/// ? srcSize
-// : extract_strided_metadata(src).sizes[i]
-/// strides = for each srcStride in dstTy.srcStrides:
-/// !srcStrides.isDynamic()
-/// ? srcStrides
-/// : extract_strided_metadata(src).strides[i]
-/// ```
-///
-/// In other words, consume the `cast` and apply its effects
-/// on the offset, sizes, and strides or compute them directly from `src`.
-class ExtractStridedMetadataOpCastFolder
- : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult
- matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
- PatternRewriter &rewriter) const override {
- Value source = extractStridedMetadataOp.getSource();
- auto castOp = source.getDefiningOp<memref::CastOp>();
- if (!castOp)
- return failure();
-
- Location loc = extractStridedMetadataOp.getLoc();
- // Check if the source is suitable for extract_strided_metadata.
- SmallVector<Type> inferredReturnTypes;
- if (failed(extractStridedMetadataOp.inferReturnTypes(
- rewriter.getContext(), loc, {castOp.getSource()},
- /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
- inferredReturnTypes)))
- return rewriter.notifyMatchFailure(castOp,
- "cast source's type is incompatible");
-
- auto memrefType = cast<MemRefType>(source.getType());
- unsigned rank = memrefType.getRank();
- SmallVector<OpFoldResult> results;
- results.resize_for_overwrite(rank * 2 + 2);
-
- auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
- rewriter, loc, castOp.getSource());
-
- // Register the base_buffer.
- results[0] = newExtractStridedMetadata.getBaseBuffer();
-
- auto getConstantOrValue = [&rewriter](int64_t constant,
- OpFoldResult ofr) -> OpFoldResult {
- return ShapedType::isStatic(constant)
- ? OpFoldResult(rewriter.getIndexAttr(constant))
- : ofr;
- };
-
- auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset();
- assert(sourceStrides.size() == rank && "unexpected number of strides");
-
- // Register the new offset.
- results[1] =
- getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());
-
- const unsigned sizeStartIdx = 2;
- const unsigned strideStartIdx = sizeStartIdx + rank;
- ArrayRef<int64_t> sourceSizes = memrefType.getShape();
-
- SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
- SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
- for (unsigned i = 0; i < rank; ++i) {
- results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
- results[strideStartIdx + i] =
- getConstantOrValue(sourceStrides[i], strides[i]);
- }
- rewriter.replaceOp(extractStridedMetadataOp,
- getValueOrCreateConstantIndexOp(rewriter, loc, results));
- return success();
- }
-};
-
/// Replace `base, offset, sizes, strides = extract_strided_metadata(
/// memory_space_cast(src) to dstTy)`
/// with
@@ -1209,7 +1124,6 @@ void memref::populateExpandStridedMetadataPatterns(
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
ExtractStridedMetadataOpReinterpretCastFolder,
ExtractStridedMetadataOpSubviewFolder,
- ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpMemorySpaceCastFolder,
ExtractStridedMetadataOpAssumeAlignmentFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
@@ -1226,7 +1140,6 @@ void memref::populateResolveExtractStridedMetadataPatterns(
ExtractStridedMetadataOpSubviewFolder,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
ExtractStridedMetadataOpReinterpretCastFolder,
- ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpMemorySpaceCastFolder,
ExtractStridedMetadataOpAssumeAlignmentFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
diff --git a/mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt
new file mode 100644
index 0000000..f305068
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIROpenACCAnalysis
+ OpenACCSupport.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIROpenACCDialect
+ MLIROpenACCUtils
+ MLIRSupport
+)
+
diff --git a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
new file mode 100644
index 0000000..f6b4534
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
@@ -0,0 +1,26 @@
+//===- OpenACCSupport.cpp - OpenACCSupport Implementation -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the OpenACCSupport analysis interface.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
+#include "mlir/Dialect/OpenACC/OpenACCUtils.h"
+
+namespace mlir {
+namespace acc {
+
+std::string OpenACCSupport::getVariableName(Value v) {
+ if (impl)
+ return impl->getVariableName(v);
+ return acc::getVariableName(v);
+}
+
+} // namespace acc
+} // namespace mlir
diff --git a/mlir/lib/Dialect/OpenACC/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/CMakeLists.txt
index 7117520..e8a916e 100644
--- a/mlir/lib/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(Analysis)
add_subdirectory(IR)
add_subdirectory(Utils)
add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 5ca0100..ca46629 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -610,6 +610,20 @@ LogicalResult acc::FirstprivateOp::verify() {
}
//===----------------------------------------------------------------------===//
+// FirstprivateMapInitialOp
+//===----------------------------------------------------------------------===//
+LogicalResult acc::FirstprivateMapInitialOp::verify() {
+ if (getDataClause() != acc::DataClause::acc_firstprivate)
+ return emitError("data clause associated with firstprivate operation must "
+ "match its intent");
+ if (failed(checkVarAndVarType(*this)))
+ return failure();
+ if (failed(checkNoModifier(*this)))
+ return failure();
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//
LogicalResult acc::ReductionOp::verify() {
diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
index 1223325..89adda82 100644
--- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
+++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/OpenACC/OpenACCUtils.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/TypeSwitch.h"
mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region &region) {
@@ -78,3 +79,30 @@ mlir::acc::VariableTypeCategory mlir::acc::getTypeCategory(mlir::Value var) {
pointerLikeTy.getElementType());
return typeCategory;
}
+
+std::string mlir::acc::getVariableName(mlir::Value v) {
+ Value current = v;
+
+ // Walk through view operations until a name is found or we can't go further.
+ while (Operation *definingOp = current.getDefiningOp()) {
+ // Check for `acc.var_name` attribute
+ if (auto varNameAttr =
+ definingOp->getAttrOfType<VarNameAttr>(getVarNameAttrName()))
+ return varNameAttr.getName().str();
+
+ // If it is a data entry operation, get name via getVarName
+ if (isa<ACC_DATA_ENTRY_OPS>(definingOp))
+ if (auto name = acc::getVarName(definingOp))
+ return name->str();
+
+ // If it's a view operation, continue to the source
+ if (auto viewOp = dyn_cast<ViewLikeOpInterface>(definingOp)) {
+ current = viewOp.getViewSource();
+ continue;
+ }
+
+ break;
+ }
+
+ return "";
+}
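
A hedged sketch of how the walk resolves a name through a view (the subview and the attribute placement are illustrative; any ViewLikeOpInterface op, and any op carrying acc.var_name, behaves the same way):

    %view = memref.subview %base[0] [8] [1]
        : memref<16xf32> to memref<8xf32>   // no name on the subview itself
    // getVariableName(%view) follows getViewSource() to %base; if %base's
    // defining op carries an acc.var_name attribute, or is a data entry op
    // with a name, that name is returned; otherwise the result is "".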
diff --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
index 57a6d34..f3c02da 100644
--- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
@@ -1,3 +1,5 @@
+add_subdirectory(Transforms)
+
add_mlir_dialect_library(MLIROpenMPDialect
IR/OpenMPDialect.cpp
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index fd4cabbad..1b069c6 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -32,7 +32,6 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
-#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Support/InterleavedRange.h"
#include <cstddef>
#include <iterator>
@@ -1737,10 +1736,10 @@ static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
// Parser, printer and verifier for Target
//===----------------------------------------------------------------------===//
-// Helper function to get bitwise AND of `value` and 'flag'
-static uint64_t mapTypeToBitFlag(uint64_t value,
- llvm::omp::OpenMPOffloadMappingFlags flag) {
- return value & llvm::to_underlying(flag);
+// Helper function to check whether all bits of `flag` are set in `value`,
+// returning the result as a boolean.
+static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag) {
+ return (value & flag) == flag;
}
/// Parses a map_entries map type from a string format back into its numeric
@@ -1748,10 +1747,9 @@ static uint64_t mapTypeToBitFlag(uint64_t value,
///
/// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `?
/// `close, `? `present, `? ( `to` | `from` | `delete` `)` )+ `)` )
-static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
- llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
-
+static ParseResult parseMapClause(OpAsmParser &parser,
+ ClauseMapFlagsAttr &mapType) {
+ ClauseMapFlags mapTypeBits = ClauseMapFlags::none;
// This simply verifies that the correct keyword is read in; the
// keyword itself is stored inside the operation.
auto parseTypeAndMod = [&]() -> ParseResult {
@@ -1760,35 +1758,64 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
return failure();
if (mapTypeMod == "always")
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
+ mapTypeBits |= ClauseMapFlags::always;
if (mapTypeMod == "implicit")
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
+ mapTypeBits |= ClauseMapFlags::implicit;
if (mapTypeMod == "ompx_hold")
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
+ mapTypeBits |= ClauseMapFlags::ompx_hold;
if (mapTypeMod == "close")
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
+ mapTypeBits |= ClauseMapFlags::close;
if (mapTypeMod == "present")
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
+ mapTypeBits |= ClauseMapFlags::present;
if (mapTypeMod == "to")
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
+ mapTypeBits |= ClauseMapFlags::to;
if (mapTypeMod == "from")
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+ mapTypeBits |= ClauseMapFlags::from;
if (mapTypeMod == "tofrom")
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+ mapTypeBits |= ClauseMapFlags::to | ClauseMapFlags::from;
if (mapTypeMod == "delete")
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
+ mapTypeBits |= ClauseMapFlags::del;
+
+ if (mapTypeMod == "storage")
+ mapTypeBits |= ClauseMapFlags::storage;
if (mapTypeMod == "return_param")
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
+ mapTypeBits |= ClauseMapFlags::return_param;
+
+ if (mapTypeMod == "private")
+ mapTypeBits |= ClauseMapFlags::priv;
+
+ if (mapTypeMod == "literal")
+ mapTypeBits |= ClauseMapFlags::literal;
+
+ if (mapTypeMod == "attach")
+ mapTypeBits |= ClauseMapFlags::attach;
+
+ if (mapTypeMod == "attach_always")
+ mapTypeBits |= ClauseMapFlags::attach_always;
+
+ if (mapTypeMod == "attach_none")
+ mapTypeBits |= ClauseMapFlags::attach_none;
+
+ if (mapTypeMod == "attach_auto")
+ mapTypeBits |= ClauseMapFlags::attach_auto;
+
+ if (mapTypeMod == "ref_ptr")
+ mapTypeBits |= ClauseMapFlags::ref_ptr;
+
+ if (mapTypeMod == "ref_ptee")
+ mapTypeBits |= ClauseMapFlags::ref_ptee;
+
+ if (mapTypeMod == "ref_ptr_ptee")
+ mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;
return success();
};
@@ -1796,9 +1823,8 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
if (parser.parseCommaSeparatedList(parseTypeAndMod))
return failure();
- mapType = parser.getBuilder().getIntegerAttr(
- parser.getBuilder().getIntegerType(64, /*isSigned=*/false),
- llvm::to_underlying(mapTypeBits));
+ mapType =
+ parser.getBuilder().getAttr<mlir::omp::ClauseMapFlagsAttr>(mapTypeBits);
return success();
}
@@ -1806,60 +1832,62 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
/// Prints a map_entries map type from its numeric value out into its string
/// format.
static void printMapClause(OpAsmPrinter &p, Operation *op,
- IntegerAttr mapType) {
- uint64_t mapTypeBits = mapType.getUInt();
-
- bool emitAllocRelease = true;
+ ClauseMapFlagsAttr mapType) {
llvm::SmallVector<std::string, 4> mapTypeStrs;
+ ClauseMapFlags mapFlags = mapType.getValue();
// Handling of always, close, present is placed at the beginning of the string
// to aid readability.
- if (mapTypeToBitFlag(mapTypeBits,
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::always))
mapTypeStrs.push_back("always");
- if (mapTypeToBitFlag(mapTypeBits,
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::implicit))
mapTypeStrs.push_back("implicit");
- if (mapTypeToBitFlag(mapTypeBits,
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD))
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::ompx_hold))
mapTypeStrs.push_back("ompx_hold");
- if (mapTypeToBitFlag(mapTypeBits,
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::close))
mapTypeStrs.push_back("close");
- if (mapTypeToBitFlag(mapTypeBits,
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::present))
mapTypeStrs.push_back("present");
// special handling of to/from/tofrom/delete and release/alloc; release and
// alloc are the absence of one of the other flags, whereas tofrom requires
// both the to and from flags to be set.
- bool to = mapTypeToBitFlag(mapTypeBits,
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
- bool from = mapTypeToBitFlag(
- mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
- if (to && from) {
- emitAllocRelease = false;
+ bool to = mapTypeToBool(mapFlags, ClauseMapFlags::to);
+ bool from = mapTypeToBool(mapFlags, ClauseMapFlags::from);
+
+ if (to && from)
mapTypeStrs.push_back("tofrom");
- } else if (from) {
- emitAllocRelease = false;
+ else if (from)
mapTypeStrs.push_back("from");
- } else if (to) {
- emitAllocRelease = false;
+ else if (to)
mapTypeStrs.push_back("to");
- }
- if (mapTypeToBitFlag(mapTypeBits,
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
- emitAllocRelease = false;
+
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::del))
mapTypeStrs.push_back("delete");
- }
- if (mapTypeToBitFlag(
- mapTypeBits,
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) {
- emitAllocRelease = false;
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::return_param))
mapTypeStrs.push_back("return_param");
- }
- if (emitAllocRelease)
- mapTypeStrs.push_back("exit_release_or_enter_alloc");
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::storage))
+ mapTypeStrs.push_back("storage");
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::priv))
+ mapTypeStrs.push_back("private");
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::literal))
+ mapTypeStrs.push_back("literal");
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::attach))
+ mapTypeStrs.push_back("attach");
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_always))
+ mapTypeStrs.push_back("attach_always");
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_none))
+ mapTypeStrs.push_back("attach_none");
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_auto))
+ mapTypeStrs.push_back("attach_auto");
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr))
+ mapTypeStrs.push_back("ref_ptr");
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptee))
+ mapTypeStrs.push_back("ref_ptee");
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr_ptee))
+ mapTypeStrs.push_back("ref_ptr_ptee");
+ if (mapFlags == ClauseMapFlags::none)
+ mapTypeStrs.push_back("none");
for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
p << mapTypeStrs[i];
@@ -1963,21 +1991,15 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
return emitError(op->getLoc(), "missing map operation");
if (auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
- uint64_t mapTypeBits = mapInfoOp.getMapType();
-
- bool to = mapTypeToBitFlag(
- mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
- bool from = mapTypeToBitFlag(
- mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
- bool del = mapTypeToBitFlag(
- mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
-
- bool always = mapTypeToBitFlag(
- mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
- bool close = mapTypeToBitFlag(
- mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
- bool implicit = mapTypeToBitFlag(
- mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
+ mlir::omp::ClauseMapFlags mapTypeBits = mapInfoOp.getMapType();
+
+ bool to = mapTypeToBool(mapTypeBits, ClauseMapFlags::to);
+ bool from = mapTypeToBool(mapTypeBits, ClauseMapFlags::from);
+ bool del = mapTypeToBool(mapTypeBits, ClauseMapFlags::del);
+
+ bool always = mapTypeToBool(mapTypeBits, ClauseMapFlags::always);
+ bool close = mapTypeToBool(mapTypeBits, ClauseMapFlags::close);
+ bool implicit = mapTypeToBool(mapTypeBits, ClauseMapFlags::implicit);
if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
return emitError(op->getLoc(),
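
Since the map type is now carried as a ClauseMapFlagsAttr, the assembly spells flags by name rather than as a raw 64-bit constant. A hypothetical omp.map.info in the new form (operand details are illustrative and otherwise follow the existing syntax):

    %m = omp.map.info var_ptr(%x : !llvm.ptr, i32)
             map_clauses(always, tofrom) capture(ByRef) -> !llvm.ptr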
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..b9b8eda
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIROpenMPTransforms
+ OpenMPOffloadPrivatizationPrepare.cpp
+
+ DEPENDS
+ MLIROpenMPPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRFuncDialect
+ MLIRLLVMDialect
+ MLIROpenMPDialect
+ MLIRPass
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp
new file mode 100644
index 0000000..c117d9b
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp
@@ -0,0 +1,447 @@
+//===- OpenMPOffloadPrivatizationPrepare.cpp - Prepare OMP privatization --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <cstdint>
+#include <iterator>
+#include <utility>
+
+//===----------------------------------------------------------------------===//
+// A pass that prepares OpenMP code for translation of delayed privatization
+// in the context of deferred target tasks. Deferred target tasks are created
+// when the nowait clause is used on the target directive.
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "omp-prepare-for-offload-privatization"
+
+namespace mlir {
+namespace omp {
+
+#define GEN_PASS_DEF_PREPAREFOROMPOFFLOADPRIVATIZATIONPASS
+#include "mlir/Dialect/OpenMP/Transforms/Passes.h.inc"
+
+} // namespace omp
+} // namespace mlir
+
+using namespace mlir;
+namespace {
+
+//===----------------------------------------------------------------------===//
+// PrepareForOMPOffloadPrivatizationPass
+//===----------------------------------------------------------------------===//
+
+class PrepareForOMPOffloadPrivatizationPass
+ : public omp::impl::PrepareForOMPOffloadPrivatizationPassBase<
+ PrepareForOMPOffloadPrivatizationPass> {
+
+ void runOnOperation() override {
+ ModuleOp mod = getOperation();
+
+ // In this pass, we make host-allocated privatized variables persist for
+ // deferred target tasks by copying them to the heap. Once the target task
+ // is done, this heap memory is freed. Since all of this happens on the
+ // host, we can skip device modules.
+ auto offloadModuleInterface =
+ dyn_cast<omp::OffloadModuleInterface>(mod.getOperation());
+ if (offloadModuleInterface && offloadModuleInterface.getIsTargetDevice())
+ return;
+
+ getOperation()->walk([&](omp::TargetOp targetOp) {
+ if (!hasPrivateVars(targetOp) || !isTargetTaskDeferred(targetOp))
+ return;
+ IRRewriter rewriter(&getContext());
+ OperandRange privateVars = targetOp.getPrivateVars();
+ SmallVector<mlir::Value> newPrivVars;
+ Value fakeDependVar;
+ omp::TaskOp cleanupTaskOp;
+
+ newPrivVars.reserve(privateVars.size());
+ std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
+ for (auto [privVarIdx, privVarSymPair] :
+ llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
+ Value privVar = std::get<0>(privVarSymPair);
+ Attribute privSym = std::get<1>(privVarSymPair);
+
+ omp::PrivateClauseOp privatizer = findPrivatizer(targetOp, privSym);
+ if (!privatizer.needsMap()) {
+ newPrivVars.push_back(privVar);
+ continue;
+ }
+ bool isFirstPrivate = privatizer.getDataSharingType() ==
+ omp::DataSharingClauseType::FirstPrivate;
+
+ Value mappedValue = targetOp.getMappedValueForPrivateVar(privVarIdx);
+ auto mapInfoOp = cast<omp::MapInfoOp>(mappedValue.getDefiningOp());
+
+ if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy) {
+ newPrivVars.push_back(privVar);
+ continue;
+ }
+
+ // For deferred target tasks (!$omp target nowait), we need to keep
+ // a copy of the original, i.e. host variable being privatized so
+ // that it is available when the target task is eventually executed.
+ // We do this by first allocating as much heap memory as is needed by
+ // the original variable. Then, we use the init and copy regions of the
+ // privatizer, an instance of omp::PrivateClauseOp to set up the heap-
+ // allocated copy.
+ // After the target task is done, we need to use the dealloc region
+ // of the privatizer to clean up everything. We also need to free
+ // the heap memory we allocated. But due to the deferred nature
+ // of the target task, we cannot simply deallocate right after the
+ // omp.target operation else we may end up freeing memory before
+ // its eventual use by the target task. So, we create a dummy
+ // dependence between the target task and new omp.task. In the omp.task,
+ // we do all the cleanup. So, we end up with the following structure
+ //
+ // omp.target map_entries(..) ... nowait depend(out:fakeDependVar) {
+ // ...
+ // omp.terminator
+ // }
+ // omp.task depend(in: fakeDependVar) {
+ // /*cleanup_code*/
+ // omp.terminator
+ // }
+ // fakeDependVar is the address of the first heap-allocated copy of the
+ // host variable being privatized.
+
+ bool needsCleanupTask = !privatizer.getDeallocRegion().empty();
+
+ // Allocate heap memory that corresponds to the type of memory
+ // pointed to by varPtr. For boxchars this won't be a pointer, but
+ // MapsForPrivatizedSymbols should have mapped the pointer to the
+ // boxchar, so use that as varPtr.
+ Value varPtr = mapInfoOp.getVarPtr();
+ Type varType = mapInfoOp.getVarType();
+ bool isPrivatizedByValue =
+ !isa<LLVM::LLVMPointerType>(privVar.getType());
+
+ assert(isa<LLVM::LLVMPointerType>(varPtr.getType()));
+ Value heapMem =
+ allocateHeapMem(targetOp, varPtr, varType, mod, rewriter);
+ if (!heapMem)
+ targetOp.emitError(
+ "Unable to allocate heap memory when trying to move "
+ "a private variable out of the stack and into the "
+ "heap for use by a deferred target task");
+
+ if (needsCleanupTask && !fakeDependVar)
+ fakeDependVar = heapMem;
+
+ // The types of private vars should match before and after the
+ // transformation. In particular, if the type is a pointer,
+ // simply record the newly allocated malloc location as the
+ // new private variable. If, however, the type is not a pointer,
+ // we need to load the value from the newly allocated
+ // location. We'll insert that load later after we have updated
+ // the malloc'd location with the contents of the original
+ // variable.
+ if (!isPrivatizedByValue)
+ newPrivVars.push_back(heapMem);
+
+ // We now need to copy the original private variable into the newly
+ // allocated location in the heap.
+ // Find the earliest insertion point for the copy. This will be before
+ // the first in the list of omp::MapInfoOp instances that use varPtr.
+ // After the copy, these omp::MapInfoOp instances will refer to heapMem
+ // instead.
+ Operation *varPtrDefiningOp = varPtr.getDefiningOp();
+ DenseSet<Operation *> users;
+ if (varPtrDefiningOp) {
+ users.insert(varPtrDefiningOp->user_begin(),
+ varPtrDefiningOp->user_end());
+ } else {
+ auto blockArg = cast<BlockArgument>(varPtr);
+ users.insert(blockArg.user_begin(), blockArg.user_end());
+ }
+ auto usesVarPtr = [&users](Operation *op) -> bool {
+ return users.count(op);
+ };
+
+ SmallVector<Operation *> chainOfOps;
+ chainOfOps.push_back(mapInfoOp);
+ for (auto member : mapInfoOp.getMembers()) {
+ omp::MapInfoOp memberMap =
+ cast<omp::MapInfoOp>(member.getDefiningOp());
+ if (usesVarPtr(memberMap))
+ chainOfOps.push_back(memberMap);
+ if (memberMap.getVarPtrPtr()) {
+ Operation *defOp = memberMap.getVarPtrPtr().getDefiningOp();
+ if (defOp && usesVarPtr(defOp))
+ chainOfOps.push_back(defOp);
+ }
+ }
+
+ DominanceInfo dom;
+ llvm::sort(chainOfOps, [&](Operation *l, Operation *r) {
+ if (l == r)
+ return false;
+ return dom.properlyDominates(l, r);
+ });
+
+ rewriter.setInsertionPoint(chainOfOps.front());
+
+ Operation *firstOp = chainOfOps.front();
+ Location loc = firstOp->getLoc();
+
+ // Create a llvm.func for 'region' that is marked always_inline and call
+ // it.
+ auto createAlwaysInlineFuncAndCallIt =
+ [&](Region &region, llvm::StringRef funcName,
+ llvm::ArrayRef<Value> args, bool returnsValue) -> Value {
+ assert(!region.empty() && "region cannot be empty");
+ LLVM::LLVMFuncOp func = createFuncOpForRegion(
+ loc, mod, region, funcName, rewriter, returnsValue);
+ auto call = LLVM::CallOp::create(rewriter, loc, func, args);
+ return call.getResult();
+ };
+
+ Value moldArg, newArg;
+ if (isPrivatizedByValue) {
+ moldArg = LLVM::LoadOp::create(rewriter, loc, varType, varPtr);
+ newArg = LLVM::LoadOp::create(rewriter, loc, varType, heapMem);
+ } else {
+ moldArg = varPtr;
+ newArg = heapMem;
+ }
+
+ Value initializedVal;
+ if (!privatizer.getInitRegion().empty())
+ initializedVal = createAlwaysInlineFuncAndCallIt(
+ privatizer.getInitRegion(),
+ llvm::formatv("{0}_{1}", privatizer.getSymName(), "init").str(),
+ {moldArg, newArg}, /*returnsValue=*/true);
+ else
+ initializedVal = newArg;
+
+ if (isFirstPrivate && !privatizer.getCopyRegion().empty())
+ initializedVal = createAlwaysInlineFuncAndCallIt(
+ privatizer.getCopyRegion(),
+ llvm::formatv("{0}_{1}", privatizer.getSymName(), "copy").str(),
+ {moldArg, initializedVal}, /*returnsValue=*/true);
+
+ if (isPrivatizedByValue)
+ (void)LLVM::StoreOp::create(rewriter, loc, initializedVal, heapMem);
+
+ // clone origOp, replace all uses of varPtr with heapMem and
+ // erase origOp.
+ auto cloneModifyAndErase = [&](Operation *origOp) -> Operation * {
+ Operation *clonedOp = rewriter.clone(*origOp);
+ rewriter.replaceAllOpUsesWith(origOp, clonedOp);
+ rewriter.modifyOpInPlace(clonedOp, [&]() {
+ clonedOp->replaceUsesOfWith(varPtr, heapMem);
+ });
+ rewriter.eraseOp(origOp);
+ return clonedOp;
+ };
+
+ // Now that we have set up the heap-allocated copy of the private
+ // variable, rewrite all the uses of the original variable with
+ // the heap-allocated variable.
+ rewriter.setInsertionPoint(targetOp);
+ mapInfoOp = cast<omp::MapInfoOp>(cloneModifyAndErase(mapInfoOp));
+ rewriter.setInsertionPoint(mapInfoOp);
+
+ // Fix any members that may use varPtr to now use heapMem
+ for (auto member : mapInfoOp.getMembers()) {
+ auto memberMapInfoOp = cast<omp::MapInfoOp>(member.getDefiningOp());
+ if (!usesVarPtr(memberMapInfoOp))
+ continue;
+ memberMapInfoOp =
+ cast<omp::MapInfoOp>(cloneModifyAndErase(memberMapInfoOp));
+ rewriter.setInsertionPoint(memberMapInfoOp);
+
+ if (memberMapInfoOp.getVarPtrPtr()) {
+ Operation *varPtrPtrdefOp =
+ memberMapInfoOp.getVarPtrPtr().getDefiningOp();
+ rewriter.setInsertionPoint(cloneModifyAndErase(varPtrPtrdefOp));
+ }
+ }
+
+ // If the type of the private variable is not a pointer,
+ // which is typically the case with !fir.boxchar types, then
+ // we need to ensure that the new private variable is also
+ // not a pointer. Insert a load from heapMem right before
+ // targetOp.
+ if (isPrivatizedByValue) {
+ rewriter.setInsertionPoint(targetOp);
+ auto newPrivVar = LLVM::LoadOp::create(rewriter, mapInfoOp.getLoc(),
+ varType, heapMem);
+ newPrivVars.push_back(newPrivVar);
+ }
+
+ // Deallocate
+ if (needsCleanupTask) {
+ if (!cleanupTaskOp) {
+ assert(fakeDependVar &&
+ "Need a valid value to set up a dependency");
+ rewriter.setInsertionPointAfter(targetOp);
+ omp::TaskOperands taskOperands;
+ auto inDepend = omp::ClauseTaskDependAttr::get(
+ rewriter.getContext(), omp::ClauseTaskDepend::taskdependin);
+ taskOperands.dependKinds.push_back(inDepend);
+ taskOperands.dependVars.push_back(fakeDependVar);
+ cleanupTaskOp = omp::TaskOp::create(rewriter, loc, taskOperands);
+ Block *taskBlock = rewriter.createBlock(&cleanupTaskOp.getRegion());
+ rewriter.setInsertionPointToEnd(taskBlock);
+ omp::TerminatorOp::create(rewriter, cleanupTaskOp.getLoc());
+ }
+ rewriter.setInsertionPointToStart(
+ &*cleanupTaskOp.getRegion().getBlocks().begin());
+ (void)createAlwaysInlineFuncAndCallIt(
+ privatizer.getDeallocRegion(),
+ llvm::formatv("{0}_{1}", privatizer.getSymName(), "dealloc")
+ .str(),
+ {initializedVal}, /*returnsValue=*/false);
+ llvm::FailureOr<LLVM::LLVMFuncOp> freeFunc =
+ LLVM::lookupOrCreateFreeFn(rewriter, mod);
+ assert(llvm::succeeded(freeFunc) &&
+ "Could not find free in the module");
+ (void)LLVM::CallOp::create(rewriter, loc, freeFunc.value(),
+ ValueRange{heapMem});
+ }
+ }
+ assert(newPrivVars.size() == privateVars.size() &&
+ "The number of private variables must match before and after "
+ "transformation");
+ if (fakeDependVar) {
+ omp::ClauseTaskDependAttr outDepend = omp::ClauseTaskDependAttr::get(
+ rewriter.getContext(), omp::ClauseTaskDepend::taskdependout);
+ SmallVector<Attribute> newDependKinds;
+ if (!targetOp.getDependVars().empty()) {
+ std::optional<ArrayAttr> dependKinds = targetOp.getDependKinds();
+ assert(dependKinds && "bad depend clause in omp::TargetOp");
+ llvm::copy(*dependKinds, std::back_inserter(newDependKinds));
+ }
+ newDependKinds.push_back(outDepend);
+ ArrayAttr newDependKindsAttr =
+ ArrayAttr::get(rewriter.getContext(), newDependKinds);
+ targetOp.getDependVarsMutable().append(fakeDependVar);
+ targetOp.setDependKindsAttr(newDependKindsAttr);
+ }
+ rewriter.setInsertionPoint(targetOp);
+ targetOp.getPrivateVarsMutable().clear();
+ targetOp.getPrivateVarsMutable().assign(newPrivVars);
+ });
+ }
+
+private:
+ bool hasPrivateVars(omp::TargetOp targetOp) const {
+ return !targetOp.getPrivateVars().empty();
+ }
+
+ bool isTargetTaskDeferred(omp::TargetOp targetOp) const {
+ return targetOp.getNowait();
+ }
+
+ template <typename OpTy>
+ omp::PrivateClauseOp findPrivatizer(OpTy op, Attribute privSym) const {
+ SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
+ omp::PrivateClauseOp privatizer =
+ SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
+ op, privatizerName);
+ return privatizer;
+ }
+
+ // Get the (compile-time constant) size of varType as per the
+ // given DataLayout dl.
+ std::int64_t getSizeInBytes(const DataLayout &dl, Type varType) const {
+ llvm::TypeSize size = dl.getTypeSize(varType);
+ unsigned short alignment = dl.getTypeABIAlignment(varType);
+ return llvm::alignTo(size, alignment);
+ }
+
+ LLVM::LLVMFuncOp getMalloc(ModuleOp mod, IRRewriter &rewriter) const {
+ llvm::FailureOr<LLVM::LLVMFuncOp> mallocCall =
+ LLVM::lookupOrCreateMallocFn(rewriter, mod, rewriter.getI64Type());
+ assert(llvm::succeeded(mallocCall) &&
+ "Could not find malloc in the module");
+ return mallocCall.value();
+ }
+
+ Value allocateHeapMem(omp::TargetOp targetOp, Value privVar, Type varType,
+ ModuleOp mod, IRRewriter &rewriter) const {
+ OpBuilder::InsertionGuard guard(rewriter);
+ Value varPtr = privVar;
+ Operation *definingOp = varPtr.getDefiningOp();
+ BlockArgument blockArg;
+ if (!definingOp) {
+ blockArg = mlir::dyn_cast<BlockArgument>(varPtr);
+ rewriter.setInsertionPointToStart(blockArg.getParentBlock());
+ } else {
+ rewriter.setInsertionPoint(definingOp);
+ }
+ Location loc = definingOp ? definingOp->getLoc() : blockArg.getLoc();
+ LLVM::LLVMFuncOp mallocFn = getMalloc(mod, rewriter);
+
+ assert(mod.getDataLayoutSpec() &&
+ "MLIR module with no datalayout spec not handled yet");
+
+ const DataLayout &dl = DataLayout(mod);
+ std::int64_t distance = getSizeInBytes(dl, varType);
+
+ Value sizeBytes = LLVM::ConstantOp::create(
+ rewriter, loc, mallocFn.getFunctionType().getParamType(0), distance);
+
+ auto mallocCallOp =
+ LLVM::CallOp::create(rewriter, loc, mallocFn, ValueRange{sizeBytes});
+ return mallocCallOp.getResult();
+ }
+
+  // Create a function for srcRegion and mark it always_inline.
+  // The big assumption here is that srcRegion is one of the init, copy or
+  // dealloc regions of an omp::PrivateClauseOp. Accordingly, the return type
+  // is assumed to either match the type of the region's two arguments (for
+  // init and copy regions) or be void, as is the case for dealloc regions.
+ LLVM::LLVMFuncOp createFuncOpForRegion(Location loc, ModuleOp mod,
+ Region &srcRegion,
+ llvm::StringRef funcName,
+ IRRewriter &rewriter,
+ bool returnsValue = false) {
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end());
+ Region clonedRegion;
+ IRMapping mapper;
+ srcRegion.cloneInto(&clonedRegion, mapper);
+
+ SmallVector<Type> paramTypes;
+ llvm::copy(srcRegion.getArgumentTypes(), std::back_inserter(paramTypes));
+ Type resultType = returnsValue
+ ? srcRegion.getArgument(0).getType()
+ : LLVM::LLVMVoidType::get(rewriter.getContext());
+ LLVM::LLVMFunctionType funcType =
+ LLVM::LLVMFunctionType::get(resultType, paramTypes);
+
+ LLVM::LLVMFuncOp func =
+ LLVM::LLVMFuncOp::create(rewriter, loc, funcName, funcType);
+ func.setAlwaysInline(true);
+ rewriter.inlineRegionBefore(clonedRegion, func.getRegion(),
+ func.getRegion().end());
+ for (auto &block : func.getRegion().getBlocks()) {
+ if (isa<omp::YieldOp>(block.getTerminator())) {
+ omp::YieldOp yieldOp = cast<omp::YieldOp>(block.getTerminator());
+ rewriter.setInsertionPoint(yieldOp);
+ rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(yieldOp, TypeRange(),
+ yieldOp.getOperands());
+ }
+ }
+ return func;
+ }
+};
+} // namespace
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index a9da6c2..744a595 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -27,6 +27,7 @@
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/DebugLog.h"
@@ -291,9 +292,102 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
}
};
+// Pattern to eliminate ExecuteRegionOp results that merely forward values
+// defined outside the region. If there are multiple yield operations, all of
+// them must have the same operands for the pattern to be applicable.
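+// For example (illustrative IR; %ext is defined outside the region):
+//
+//   %0:2 = scf.execute_region -> (i32, i32) {
+//     %a = "some.op"() : () -> i32
+//     scf.yield %a, %ext : i32, i32
+//   }
+//
+// is rewritten to yield only %a, and uses of %0#1 are replaced by %ext
+// directly.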
+struct ExecuteRegionForwardingEliminator
+ : public OpRewritePattern<ExecuteRegionOp> {
+ using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExecuteRegionOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.getNumResults() == 0)
+ return failure();
+
+ SmallVector<Operation *> yieldOps;
+ for (Block &block : op.getRegion()) {
+ if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator()))
+ yieldOps.push_back(yield.getOperation());
+ }
+
+ if (yieldOps.empty())
+ return failure();
+
+ // Check if all yield operations have the same operands.
+ auto yieldOpsOperands = yieldOps[0]->getOperands();
+ for (auto *yieldOp : yieldOps) {
+ if (yieldOp->getOperands() != yieldOpsOperands)
+ return failure();
+ }
+
+ SmallVector<Value> externalValues;
+ SmallVector<Value> internalValues;
+ SmallVector<Value> opResultsToReplaceWithExternalValues;
+ SmallVector<Value> opResultsToKeep;
+ for (auto [index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) {
+ if (isValueFromInsideRegion(yieldedValue, op)) {
+ internalValues.push_back(yieldedValue);
+ opResultsToKeep.push_back(op.getResult(index));
+ } else {
+ externalValues.push_back(yieldedValue);
+ opResultsToReplaceWithExternalValues.push_back(op.getResult(index));
+ }
+ }
+ // No yielded external values - nothing to do.
+ if (externalValues.empty())
+ return failure();
+
+ // There are yielded external values - create a new execute_region returning
+ // just the internal values.
+ SmallVector<Type> resultTypes;
+ for (Value value : internalValues)
+ resultTypes.push_back(value.getType());
+ auto newOp =
+ ExecuteRegionOp::create(rewriter, op.getLoc(), TypeRange(resultTypes));
+ newOp->setAttrs(op->getAttrs());
+
+ // Move old op's region to the new operation.
+ rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
+ newOp.getRegion().end());
+
+    // Replace each yield operation with a new one that yields only the
+    // internal values. scf.execute_region must have at least one yield
+    // operation.
+ for (auto *yieldOp : yieldOps) {
+ rewriter.setInsertionPoint(yieldOp);
+ rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp,
+ ValueRange(internalValues));
+ }
+
+ // Replace the old operation with the external values directly.
+ rewriter.replaceAllUsesWith(opResultsToReplaceWithExternalValues,
+ externalValues);
+ // Replace the old operation's remaining results with the new operation's
+ // results.
+ rewriter.replaceAllUsesWith(opResultsToKeep, newOp.getResults());
+ rewriter.eraseOp(op);
+ return success();
+ }
+
+private:
+  bool isValueFromInsideRegion(Value value,
+                               ExecuteRegionOp executeRegionOp) const {
+    // Check if the value is defined within the execute_region.
+    if (Operation *defOp = value.getDefiningOp())
+      return &executeRegionOp.getRegion() == defOp->getParentRegion();
+
+    // If it's a block argument, check if it's from within the region.
+    if (BlockArgument blockArg = dyn_cast<BlockArgument>(value))
+      return &executeRegionOp.getRegion() == blockArg.getParentRegion();
+
+    return false; // Value is from outside the region.
+  }
+};
+
void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
+ results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner,
+ ExecuteRegionForwardingEliminator>(context);
}
void ExecuteRegionOp::getSuccessorRegions(
@@ -2490,8 +2584,8 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
changed = true;
if (!constantTrue)
- constantTrue = rewriter.create<arith::ConstantOp>(
- op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
+ constantTrue = arith::ConstantOp::create(
+ rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
rewriter.modifyOpInPlace(use.getOwner(),
[&]() { use.set(constantTrue); });
@@ -2500,8 +2594,8 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
changed = true;
if (!constantFalse)
- constantFalse = rewriter.create<arith::ConstantOp>(
- op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
+ constantFalse = arith::ConstantOp::create(
+ rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
rewriter.modifyOpInPlace(use.getOwner(),
[&]() { use.set(constantFalse); });
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 5dc61a2..335ca1a 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -69,10 +69,10 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
Sharding sourceSharding,
TypedValue<ShapedType> sourceShard, GridOp grid,
int64_t splitTensorAxis, GridAxis splitGridAxis) {
- TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
+ TypedValue<ShapedType> targetShard =
AllSliceOp::create(builder, sourceShard, grid,
ArrayRef<GridAxis>(splitGridAxis), splitTensorAxis)
- .getResult());
+ .getResult();
Sharding targetSharding = targetShardingInSplitLastAxis(
builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis);
return {targetShard, targetSharding};
@@ -204,9 +204,8 @@ static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding(
APInt(64, splitTensorAxis));
ShapedType targetShape =
shardShapedType(sourceUnshardedShape, grid, targetSharding);
- TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
- tensor::CastOp::create(builder, targetShape, allGatherResult)
- .getResult());
+ TypedValue<ShapedType> targetShard =
+ tensor::CastOp::create(builder, targetShape, allGatherResult).getResult();
return {targetShard, targetSharding};
}
@@ -336,8 +335,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
ShapedType targetShape =
shardShapedType(sourceUnshardedShape, grid, targetSharding);
- TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
- tensor::CastOp::create(builder, targetShape, allToAllResult).getResult());
+ TypedValue<ShapedType> targetShard =
+ tensor::CastOp::create(builder, targetShape, allToAllResult).getResult();
return {targetShard, targetSharding};
}
@@ -510,8 +509,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
auto targetSharding = target.getSharding();
ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding,
- cast<TypedValue<ShapedType>>(source.getSrc()),
- sourceShardValue);
+ source.getSrc(), sourceShardValue);
}
TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
index 73e0f3d..f53d272 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
@@ -159,14 +159,22 @@ IterationGraphSorter::IterationGraphSorter(
loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)),
strategy(strategy) {
// One map per tensor.
- assert(loop2InsLvl.size() == ins.size());
+ assert(this->loop2InsLvl.size() == this->ins.size());
// All the affine maps have the same number of dimensions (loops).
assert(llvm::all_equal(llvm::map_range(
- loop2InsLvl, [](AffineMap m) { return m.getNumDims(); })));
+ this->loop2InsLvl, [](AffineMap m) { return m.getNumDims(); })));
// The number of results of the map should match the rank of the tensor.
- assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](auto mvPair) {
+ assert(llvm::all_of(llvm::zip(this->loop2InsLvl, this->ins), [](auto mvPair) {
auto [m, v] = mvPair;
- return m.getNumResults() == cast<ShapedType>(v.getType()).getRank();
+
+      // For ranked types the rank must match.
+      // Simply return true for UnrankedTensorType.
+ if (auto shapedType = llvm::dyn_cast<ShapedType>(v.getType())) {
+ return !shapedType.hasRank() ||
+ (m.getNumResults() == shapedType.getRank());
+ }
+ // Non-shaped (scalar) types behave like rank-0.
+ return m.getNumResults() == 0;
}));
itGraph.resize(getNumLoops(), std::vector<bool>(getNumLoops(), false));
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
index 1cba1bb..eb47e85 100644
--- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -12,6 +12,97 @@
namespace mlir {
namespace tosa {
+llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
+ return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor());
+}
+
+TosaSpecificationVersion getMinVersion(const Profile &profile) {
+ switch (profile) {
+ case Profile::pro_int:
+ case Profile::pro_fp:
+ return TosaSpecificationVersion(1, 0);
+ case Profile::none:
+ return TosaSpecificationVersion(0, 0);
+ }
+ llvm_unreachable("Unknown TOSA profile");
+}
+
+TosaSpecificationVersion getMinVersion(const Extension &extension) {
+ switch (extension) {
+ case Extension::int16:
+ case Extension::int4:
+ case Extension::bf16:
+ case Extension::fp8e4m3:
+ case Extension::fp8e5m2:
+ case Extension::fft:
+ case Extension::variable:
+ case Extension::controlflow:
+ case Extension::doubleround:
+ case Extension::inexactround:
+ case Extension::dynamic:
+ return TosaSpecificationVersion(1, 0);
+ case Extension::mxfp:
+ case Extension::int64:
+ return TosaSpecificationVersion(1, 1);
+ case Extension::none:
+ return TosaSpecificationVersion(0, 0);
+ }
+ llvm_unreachable("Unknown TOSA extension");
+}
+
+TosaSpecificationVersion getMinVersion(const Level &level) {
+ switch (level) {
+ case Level::eightK:
+ case Level::none:
+ return TosaSpecificationVersion(1, 0);
+ }
+ llvm_unreachable("Unknown TOSA level");
+}
+
+FailureOr<TargetEnv>
+TargetEnv::createTargetEnvFromAttr(TargetEnvAttr targetAttr,
+ Location targetEnvAttrLoc) {
+ if (failed(verifyTargetInformation(targetAttr, targetEnvAttrLoc)))
+ return failure();
+
+ return TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(),
+ targetAttr.getProfiles(), targetAttr.getExtensions());
+}
+
+LogicalResult TargetEnv::verifyTargetInformation(TargetEnvAttr targetAttr,
+ Location targetAttrLoc) {
+ TosaSpecificationVersion targetVersion(targetAttr.getSpecificationVersion());
+
+ const auto isCompatibleWithTargetVersion =
+ [&](const auto &targetEnum, Location targetAttrLoc,
+ StringRef enumName) -> LogicalResult {
+ const TosaSpecificationVersion minRequiredVersion =
+ getMinVersion(targetEnum);
+ if (!targetVersion.isBackwardsCompatibleWith(minRequiredVersion))
+ return emitError(targetAttrLoc, enumName)
+ << " '" << stringifyEnum(targetEnum)
+ << "' is not compatible with the target version "
+ << stringifyVersion(targetVersion)
+ << ", minimum required version is "
+ << stringifyVersion(minRequiredVersion);
+ return success();
+ };
+
+ for (const auto &profile : targetAttr.getProfiles())
+ if (failed(
+ isCompatibleWithTargetVersion(profile, targetAttrLoc, "profile")))
+ return failure();
+ for (const auto &extension : targetAttr.getExtensions())
+ if (failed(isCompatibleWithTargetVersion(extension, targetAttrLoc,
+ "extension")))
+ return failure();
+ if (failed(isCompatibleWithTargetVersion(targetAttr.getLevel(), targetAttrLoc,
+ "level")))
+ return failure();
+
+ return success();
+}
+
TargetEnvAttr lookupTargetEnv(Operation *op) {
while (op) {
op = SymbolTable::getNearestSymbolTable(op);
@@ -39,9 +130,5 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
return getDefaultTargetEnv(op->getContext());
}
-llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
- return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor());
-}
-
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index caf8016..a85ff10a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -76,28 +76,6 @@ template <typename OpTy>
struct PoolPadFoldAdaptor;
template <>
-struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
- using OpTy = tosa::AvgPool2dOp;
- static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
- const llvm::ArrayRef<int64_t> kernel = op.getKernel();
- if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
- newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
- return false;
- return true;
- }
- static bool checkPadConstCompliance(OpTy op, Value padConst) {
- return checkMatchingPadConstAndZp(padConst, op.getInputZp());
- }
- static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
- Value padInput, ArrayRef<int64_t> newPad) {
- rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
- op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
- op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad),
- op.getAccType());
- }
-};
-
-template <>
struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
using OpTy = tosa::MaxPool2dOp;
static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
@@ -245,13 +223,6 @@ struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
};
} // namespace
-void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
- PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
- context);
-}
-
void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
@@ -1001,8 +972,12 @@ OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
!outputTy.hasStaticShape())
return {};
- if (inputTy.getDimSize(getAxis()) == 1)
- return DenseElementsAttr::get(outputTy, 0);
+ const Type outputElementTy = getElementTypeOrSelf(outputTy);
+ if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) {
+ const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
+ const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
+ return DenseElementsAttr::get(outputTy, zero);
+ }
return {};
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 00f84bc..0aff67f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -321,6 +321,19 @@ ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
}
}
+  // special handling: block_size accepts a *bare* BlockSize enum keyword
+ if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
+ if (name == "block_size" && succeeded(parser.parseOptionalKeyword(&kw))) {
+ auto sym = symbolizeBlockSize(kw);
+ if (!sym)
+ return parser.emitError(parser.getCurrentLocation())
+ << "invalid block_size value: " << kw;
+ auto attr = BlockSizeAttr::get(parser.getContext(), sym.value());
+ outAttrs.push_back(NamedAttribute(name, attr));
+ return success();
+ }
+ }
+
// Default path: parse any normal attribute literal, including fully qualified
// enum keyword
Attribute attr;
@@ -357,7 +370,7 @@ ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
result.operands)))
return failure();
- result.addTypes(fnTy.getResult(0));
+ result.addTypes(fnTy.getResults());
result.addAttributes(attrs);
return success();
@@ -373,6 +386,8 @@ void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) {
} else if (auto nanPropagationModeAttr =
dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
parser << nanPropagationModeAttr.getValue();
+ } else if (auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
+ parser << blockSizeAttr.getValue();
} else {
parser.printAttribute(attr);
}
@@ -508,6 +523,33 @@ void ReduceMinOp::print(OpAsmPrinter &parser) {
printWithNanPropagationHandling(parser, *this);
}
+ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
+ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void CastFromBlockScaledOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
+ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
//===----------------------------------------------------------------------===//
// Tosa utilities.
//===----------------------------------------------------------------------===//
@@ -933,32 +975,35 @@ static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
// verify that aType and bType have the same element types
template <typename T>
-static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
- auto inputType = llvm::dyn_cast<TensorType>(inType);
- auto outputType = llvm::dyn_cast<TensorType>(outType);
- if (!inputType) {
- op.emitOpError("expect shaped tensor for input, got ") << inType;
+static LogicalResult verifySameElementTypes(T op, Type aType, Type bType,
+ StringRef aName = "input",
+ StringRef bName = "output") {
+ auto aTType = llvm::dyn_cast<TensorType>(aType);
+ auto bTType = llvm::dyn_cast<TensorType>(bType);
+ if (!aTType) {
+ op.emitOpError("expect shaped tensor for") << aName << ", got " << aType;
return failure();
}
- if (!outputType) {
- op.emitOpError("expect shaped tensor for output, got ") << outType;
+ if (!bTType) {
+ op.emitOpError("expect shaped tensor for") << bName << ", got" << bType;
return failure();
}
- auto inputElementType = inputType.getElementType();
- auto outputElementType = outputType.getElementType();
- auto inputQuantType =
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
- auto outputQuantType =
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
- if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
- (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
- inputElementType != outputElementType) {
+ auto aElementType = aTType.getElementType();
+ auto bElementType = bTType.getElementType();
+ auto aQuantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
+ auto bQuantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
+ if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
+ (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
+ aElementType != bElementType) {
// only check if both element types are int/index/float/UniformQuantized
// eg, not sure how to check quant::QuantizedType
// this happens in test_conv2d_q_grouped_convolution in
// tfl-to-tosa-pipeline.mlir
- op.emitOpError("expect input and output to have same element type, got ")
- << inputElementType << " and " << outputElementType;
+ op.emitOpError("expect ")
+ << aName << " and " << bName << " to have same element type, got "
+ << aElementType << " and " << bElementType;
return failure();
}
return success();
@@ -1846,6 +1891,161 @@ LogicalResult MatMulOp::verify() {
return success();
}
+LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ MatmulTBlockScaledOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ SmallVector<int64_t, 3> outShape(3, ShapedType::kDynamic);
+
+ const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
+ if (aDataShape.hasRank()) {
+ outShape[0] = aDataShape.getDimSize(0);
+ outShape[1] = aDataShape.getDimSize(1);
+ }
+
+ const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
+ if (aScaleShape.hasRank()) {
+ outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
+ : outShape[0];
+ outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)
+ : outShape[1];
+ }
+
+ // If B batch size is 1, it is broadcast across A's batch size
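+  // (e.g. an A batch of 4 with a B batch of 1 gives an output batch of 4).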
+ const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
+ if (bDataShape.hasRank()) {
+ const int64_t bDataBatchSize = bDataShape.getDimSize(0);
+ if (bDataBatchSize != 1)
+ outShape[0] =
+ ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
+ outShape[2] = bDataShape.getDimSize(1);
+ }
+
+ const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
+ if (bScaleShape.hasRank()) {
+ const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
+ if (bScaleBatchSize != 1)
+ outShape[0] =
+ ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
+ outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
+ : outShape[2];
+ }
+
+ inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+ return success();
+}
+
+LogicalResult MatmulTBlockScaledOp::verify() {
+ // Verify same input data types
+ const Type aDataType = getAData().getType();
+ const Type bDataType = getBData().getType();
+ if (failed(verifySameElementTypes(*this, aDataType, bDataType, "A_data",
+ "B_data")))
+ return failure();
+
+ auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim,
+ const StringRef operandName,
+ const StringRef dimName) -> LogicalResult {
+ if (ShapedType::isDynamic(currDim)) {
+ currDim = newDim;
+ return success();
+ } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
+ return emitOpError("expected ")
+ << dimName << " of " << operandName << " to match size " << currDim
+ << ", got " << newDim;
+ }
+ return success();
+ };
+
+ // Verify input shape compatibility
+ int64_t N = ShapedType::kDynamic;
+ int64_t D = ShapedType::kDynamic;
+ int64_t H = ShapedType::kDynamic;
+ int64_t W = ShapedType::kDynamic;
+ int64_t C = ShapedType::kDynamic;
+ int64_t multiplesOfC = ShapedType::kDynamic;
+
+ const ShapeAdaptor aDataShape = ShapeAdaptor(aDataType);
+ if (aDataShape.hasRank()) {
+ N = aDataShape.getDimSize(0);
+ H = aDataShape.getDimSize(1);
+ C = aDataShape.getDimSize(2);
+ }
+
+ const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
+ if (aScaleShape.hasRank()) {
+ if (failed(tryUpdateDimOrFailure(N, aScaleShape.getDimSize(0), "a_scale",
+ "batch")) ||
+ failed(tryUpdateDimOrFailure(H, aScaleShape.getDimSize(1), "a_scale",
+ "height")))
+ return failure();
+ multiplesOfC = aScaleShape.getDimSize(2);
+ }
+
+ const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
+ if (bDataShape.hasRank()) {
+ if (failed(tryUpdateDimOrFailure(D, bDataShape.getDimSize(0), "b_data",
+ "batch")) ||
+ failed(tryUpdateDimOrFailure(C, bDataShape.getDimSize(2), "b_data",
+ "channels")))
+ return failure();
+ W = bDataShape.getDimSize(1);
+ }
+
+ const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
+ if (bScaleShape.hasRank()) {
+ if (failed(tryUpdateDimOrFailure(D, bScaleShape.getDimSize(0), "b_scale",
+ "batch")) ||
+ failed(tryUpdateDimOrFailure(W, bScaleShape.getDimSize(1), "b_scale",
+ "width")) ||
+ failed(tryUpdateDimOrFailure(multiplesOfC, bScaleShape.getDimSize(2),
+ "b_scale", "C/block_size")))
+ return failure();
+ }
+
+ // Verify batch size is broadcast compatible
+ if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
+ return emitOpError("expect B matrix batch size to be broadcast compatible "
+ "with A, got D=")
+ << D << " vs N=" << N;
+
+ // Verify C is a multiple of block size
+ const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ if (ShapedType::isStatic(C) && C % blockSize != 0)
+ return emitOpError("expect C to be a multiple of block size, got C=")
+ << C << ", block_size=" << blockSize;
+
+ // Verify multiplesOfC is C / block size
+ if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
+ multiplesOfC != C / blockSize)
+ return emitOpError(
+ "expect scale operands dimension 2 to equal C/block_size (")
+ << C << "/" << blockSize << ")"
+ << ", got " << multiplesOfC;
+
+ // Verify output shape
+ N = ShapedType::isDynamic(N) ? D : N;
+ const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W};
+ const auto outputType = cast<ShapedType>(getResult().getType());
+ if (outputType.hasRank() &&
+ failed(
+ verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) {
+ InFlightDiagnostic opError = emitOpError("expected output shape ");
+ auto stringifyDim = [&](int64_t d) {
+ if (ShapedType::isDynamic(d))
+ opError << "?";
+ else
+ opError << d;
+ };
+ llvm::interleaveComma(outputType.getShape(), opError, stringifyDim);
+ opError << " to be compatible with expected output shape ";
+ llvm::interleaveComma(expectedOutputShape, opError, stringifyDim);
+ return opError;
+ }
+
+ return success();
+}
+
LogicalResult tosa::PadOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
PadOp::Adaptor adaptor,
@@ -3762,6 +3962,145 @@ LogicalResult RescaleOp::inferReturnTypeComponents(
return success();
}
+LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ CastFromBlockScaledOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ const ShapeAdaptor inputShape(adaptor.getInputData().getType());
+ inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+ return success();
+}
+
+LogicalResult CastFromBlockScaledOp::verify() {
+ const Type inputDataType = getInputData().getType();
+ const Type outputDataType = getResult().getType();
+ if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
+ return emitOpError() << "require compatible shapes for input_data ("
+ << inputDataType << ") and "
+ << "output_data (" << outputDataType << ")";
+
+ const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
+
+ if (inputDataShape.hasRank()) {
+ const unsigned int blockSize =
+ BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ const int64_t inputDataLastDim =
+ inputDataShape.getDimSize(inputDataShape.getRank() - 1);
+    if (ShapedType::isStatic(inputDataLastDim) &&
+        inputDataLastDim % blockSize != 0)
+ return emitOpError() << "expect last dimension of input_data ("
+ << inputDataLastDim
+ << ") to be divisible by block_size (" << blockSize
+ << ")";
+
+ const Type inputScaleType = getInputScale().getType();
+ const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
+
+ if (inputScaleShape.hasRank()) {
+ SmallVector<int64_t> inputDataDims, inputScaleDims;
+ inputDataShape.getDims(inputDataDims);
+ inputScaleShape.getDims(inputScaleDims);
+
+ if (inputDataDims.size() != inputScaleDims.size() ||
+ failed(verifyCompatibleShape(
+ ArrayRef<int64_t>(inputDataDims).drop_back(1),
+ ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
+ return emitOpError() << "require compatible shapes for input_data ("
+ << inputDataType << ") and "
+ << "input_scale (" << inputScaleType
+ << ") except for the last dimension";
+
+ const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
+ inputScaleDims.back()};
+ if (ShapedType::isStatic(inputDataLastDim) &&
+ failed(verifyCompatibleDims(dimsToCheck)))
+ return emitOpError()
+ << "expect last dimension of input_scale ("
+ << inputScaleDims.back()
+ << ") to be equal to last dimension of input_data / block_size ("
+ << inputDataDims.back() / blockSize << ")";
+ }
+ }
+
+ return success();
+}
+
+LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ CastToBlockScaledOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ const ShapeAdaptor inputShape(adaptor.getInputData().getType());
+ inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+ if (!inputShape.hasRank())
+ return success();
+
+  // Calculate the output_scale shape if a ranked input is provided.
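+  // E.g. an input_data shape of [N, C] with block_size 32 gives an
+  // output_scale shape of [N, C / 32].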
+ SmallVector<int64_t> outputScaleShape;
+ inputShape.getDims(outputScaleShape);
+ const int64_t lastDimLoc = inputShape.getRank() - 1;
+ const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
+ if (ShapedType::isStatic(lastDimSize)) {
+ const unsigned int blockSize =
+ BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
+ outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
+ }
+ inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
+ return success();
+}
+
+LogicalResult CastToBlockScaledOp::verify() {
+ const Type inputDataType = getInputData().getType();
+ const Type outputDataType = getResult(0).getType();
+ if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
+ return emitOpError() << "require compatible shapes for input_data ("
+ << inputDataType << ") and "
+ << "output_data (" << outputDataType << ")";
+
+ const unsigned int blockSize =
+ BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
+ if (inputDataShape.hasRank()) {
+ const int64_t inputDataLastDim =
+ inputDataShape.getDimSize(inputDataShape.getRank() - 1);
+ if (ShapedType::isStatic(inputDataLastDim) &&
+ inputDataLastDim % blockSize != 0)
+ return emitOpError() << "expect last dimension of input_data ("
+ << inputDataLastDim
+ << ") to be divisible by block_size (" << blockSize
+ << ")";
+ }
+
+ const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
+ const Type outputScaleType = getResult(1).getType();
+ const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
+ if (outputDataShape.hasRank() && outputScaleShape.hasRank()) {
+ SmallVector<int64_t> outputDataDims, outputScaleDims;
+ outputDataShape.getDims(outputDataDims);
+ outputScaleShape.getDims(outputScaleDims);
+
+ if (outputDataDims.size() != outputScaleDims.size() ||
+ failed(verifyCompatibleShape(
+ ArrayRef<int64_t>(outputDataDims).drop_back(1),
+ ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
+ return emitOpError() << "require compatible shapes for output_data ("
+ << outputDataType << ") and "
+ << "output_scale (" << outputScaleType
+ << ") except for the last dimension";
+
+ const int64_t outputDataLastDim = outputDataDims.back();
+ const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
+ outputScaleDims.back()};
+ if (ShapedType::isStatic(outputDataLastDim) &&
+ failed(verifyCompatibleDims(dimsToCheck)))
+ return emitOpError()
+ << "expect last dimension of output_scale ("
+ << outputScaleDims.back()
+ << ") to be equal to last dimension of output_data / block_size ("
+ << outputDataDims.back() / blockSize << ")";
+ }
+
+ return success();
+}
+
LogicalResult IfOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
IfOp::Adaptor adaptor,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index f072e3e..ab363ee 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -19,12 +19,19 @@ TosaProfileCompliance::TosaProfileCompliance() {
const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
+ const TypeInfo i64T = {mlir::IntegerType::getTypeID(), 64};
const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
+ // micro-scaling formats
+ const TypeInfo fp6e2m3T = {mlir::Float6E2M3FNType::getTypeID(), 6};
+ const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6};
+ const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4};
+ const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8};
+
// The profile-based compliance content below is auto-generated by a script
// in https://git.mlplatform.org/tosa/specification.git
#include "mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc"
@@ -44,10 +51,11 @@ TosaProfileCompliance::getProfileComplianceMap() {
// Base populating function
LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands,
- Value output) {
- for (auto operand : operands)
+ ValueRange results) {
+ for (const auto &operand : operands)
addValue(operand);
- addValue(output);
+ for (const auto &result : results)
+ addValue(result);
return success();
}
@@ -170,23 +178,6 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
}
template <>
-LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
- addValue(op.getInputReal());
- addValue(op.getInputImag());
- addValue(op.getOutputReal());
- addValue(op.getOutputImag());
- return success();
-}
-
-template <>
-LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
- addValue(op.getInputReal());
- addValue(op.getOutputReal());
- addValue(op.getOutputImag());
- return success();
-}
-
-template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
addValue(op.getOnTrue());
addValue(op.getOnFalse());
@@ -239,7 +230,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// This helper function populates the info for all operands.
#define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
if (isa<tosa::tosaOp##Op>(op)) { \
- return populateProfileInfo(op->getOperands(), op->getResult(0)); \
+ return populateProfileInfo(op->getOperands(), op->getResults()); \
}
// Skip irrelevant operands when they are independent and not tied to any
@@ -250,8 +241,6 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
POPULATE_PROFILE_INFO_CUSTOM(Conv3D)
POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D)
POPULATE_PROFILE_INFO_CUSTOM(Mul)
- POPULATE_PROFILE_INFO_CUSTOM(FFT2d)
- POPULATE_PROFILE_INFO_CUSTOM(RFFT2d)
POPULATE_PROFILE_INFO_CUSTOM(Concat)
POPULATE_PROFILE_INFO_CUSTOM(Pad)
POPULATE_PROFILE_INFO_CUSTOM(Reshape)
@@ -269,7 +258,12 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
  // For most of the tosa operators, all operands are profile/extension related
  // and hence are all considered in this profile-based compliance check.
+ POPULATE_PROFILE_INFO_COMMON(MatmulTBlockScaled)
+ POPULATE_PROFILE_INFO_COMMON(FFT2d)
+ POPULATE_PROFILE_INFO_COMMON(RFFT2d)
POPULATE_PROFILE_INFO_COMMON(Cast)
+ POPULATE_PROFILE_INFO_COMMON(CastFromBlockScaled)
+ POPULATE_PROFILE_INFO_COMMON(CastToBlockScaled)
POPULATE_PROFILE_INFO_COMMON(Const)
POPULATE_PROFILE_INFO_COMMON(ArgMax)
POPULATE_PROFILE_INFO_COMMON(Sub)
@@ -623,6 +617,14 @@ TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) {
return {"fp8e4m3"};
} else if (typeInfo.typeID == mlir::Float8E5M2Type::getTypeID()) {
return {"fp8e5m2"};
+ } else if (typeInfo.typeID == mlir::Float6E2M3FNType::getTypeID()) {
+ return {"fp6e2m3"};
+ } else if (typeInfo.typeID == mlir::Float6E3M2FNType::getTypeID()) {
+ return {"fp6e3m2"};
+ } else if (typeInfo.typeID == mlir::Float4E2M1FNType::getTypeID()) {
+ return {"fp4e2m1"};
+ } else if (typeInfo.typeID == mlir::Float8E8M0FNUType::getTypeID()) {
+ return {"fp8e8m0"};
}
llvm_unreachable("unknown type");
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 82f2f7e..4d0b61a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -635,6 +635,8 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
CHECK_RANKS_AND_SIZES(Transpose);
// Type Conversion
CHECK_RANKS_AND_SIZES(Cast);
+ CHECK_RANKS_AND_SIZES(CastFromBlockScaled);
+ CHECK_RANKS_AND_SIZES(CastToBlockScaled);
CHECK_RANKS_AND_SIZES(Rescale);
// Control Flow Operators
CHECK_RANKS_AND_SIZES(If);
@@ -657,6 +659,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
CHECK_SIZES(TransposeConv2D);
CHECK_SIZES(FFT2d);
CHECK_SIZES(MatMul);
+ CHECK_SIZES(MatmulTBlockScaled);
CHECK_SIZES(MaxPool2d);
CHECK_SIZES(RFFT2d);
// Scatter/Gather Operators
@@ -1192,9 +1195,9 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
if (isa<FloatType>(type)) {
return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
- Float8E5M2Type>(type);
- }
- if (auto intTy = dyn_cast<IntegerType>(type)) {
+ Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType,
+ Float6E3M2FNType, Float8E8M0FNUType>(type);
+ } else if (auto intTy = dyn_cast<IntegerType>(type)) {
if (intTy.isSignless()) {
switch (intTy.getWidth()) {
case 1:
@@ -1203,6 +1206,7 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
case 16:
case 32:
case 48:
+ case 64:
return true;
}
} else if (allowUnsigned && intTy.isUnsigned()) {
@@ -1220,13 +1224,19 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
}
void TosaValidation::runOnOperation() {
+ ModuleOp modOp = getOperation();
+ const TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(modOp);
+ const auto maybeTargetEnv =
+ tosa::TargetEnv::createTargetEnvFromAttr(targetEnvAttr, modOp.getLoc());
+ if (failed(maybeTargetEnv))
+ return signalPassFailure();
+ targetEnv = *maybeTargetEnv;
+
TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
if (!tosaDialect)
return;
- targetEnv = tosa::TargetEnv(lookupTargetEnvOrDefault(getOperation()));
-
- getOperation().walk([&](Operation *op) {
+ modOp.walk([&](Operation *op) {
if (op->getDialect() != tosaDialect)
return;
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
index 8f46ad6..ef49c86 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
@@ -74,9 +74,9 @@ struct MixedSizeInputShuffleOpRewrite final
for (int64_t i = 0; i < origNumElems; ++i)
promoteMask[i] = i;
- Value promotedInput = rewriter.create<vector::ShuffleOp>(
- shuffleOp.getLoc(), promotedType, inputToPromote, inputToPromote,
- promoteMask);
+ Value promotedInput =
+ vector::ShuffleOp::create(rewriter, shuffleOp.getLoc(), promotedType,
+ inputToPromote, inputToPromote, promoteMask);
// Create the final shuffle with the promoted inputs.
Value promotedV1 = promoteV1 ? promotedInput : shuffleOp.getV1();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 7c019e7..8b5e950 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -341,13 +341,18 @@ private:
/// Return the distributed vector type based on the original type and the
/// distribution map. The map is expected to have a dimension equal to the
/// original type rank and should be a projection where the results are the
-/// distributed dimensions. The number of results should be equal to the number
+/// distributed dimensions. If the number of results is zero, there is no
+/// distribution (i.e., the original type is returned).
+/// Otherwise, the number of results should be equal to the number
/// of warp sizes which is currently limited to 1.
/// Example: A vector<16x32x64> distributed with a map (d0, d1, d2) -> (d1)
/// and a warp size of 16 would distribute the second dimension (associated
/// with d1) and return vector<16x2x64>.
static VectorType getDistributedType(VectorType originalType, AffineMap map,
int64_t warpSize) {
+ // If the map has zero results, return the original type.
+ if (map.getNumResults() == 0)
+ return originalType;
SmallVector<int64_t> targetShape(originalType.getShape());
for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
unsigned position = map.getDimPosition(i);
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 1599ae9..24e9095 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -736,7 +736,7 @@ OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc,
OpBuilder &builder) {
auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
- return builder.create<ArithOp>(loc, aVal, bVal).getResult();
+ return ArithOp::create(builder, loc, aVal, bVal).getResult();
}
// a helper utility to perform division operation on OpFoldResult and int64_t.
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index f77784a..2c37140 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -145,8 +145,26 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
xegpu::DistributeLayoutAttr layout =
xegpu::getDistributeLayoutAttr(operandOrResult);
if (layout && layout.isForSubgroup()) {
- if (!layout.getEffectiveInstDataAsInt().empty())
- return layout.getEffectiveInstDataAsInt();
+ if (!layout.getEffectiveInstDataAsInt().empty()) {
+ SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt();
+      // Remove leading unit dimensions from inst_data.
+      // For example, if the inst_data is [1, 1, 32],
+      // it will pass [32] as the unroll/blocking size.
+      // Skip this for xegpu nd ops since their data is 2D.
+      // TODO: For vector ops, experiment with the upstream patterns that
+      // remove leading unit dims,
+      // populateCastAwayVectorLeadingOneDimPatterns.
+ Operation *definingOp = value.getDefiningOp();
+ bool skipLeadingUnitDimRemoval =
+ definingOp &&
+ (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::DpasOp,
+ xegpu::StoreNdOp, xegpu::PrefetchNdOp>(definingOp));
+ if (!skipLeadingUnitDimRemoval) {
+ auto it = llvm::find_if(instData, [](auto val) { return val != 1; });
+ instData.erase(instData.begin(), it);
+ }
+ return instData;
+ }
if (auto type = dyn_cast<ShapedType>(value.getType()))
return llvm::to_vector(type.getShape());
@@ -354,7 +372,6 @@ void XeGPUBlockingPass::runOnOperation() {
// To create a new attribute with a different chunk_size:
auto newEncoding = xegpu::ScatterTensorDescAttr::get(
ctx, tdescTy.getMemorySpace(), blockedChunkSize);
-
encoding = newEncoding;
}
}
@@ -363,7 +380,7 @@ void XeGPUBlockingPass::runOnOperation() {
xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
tdescTy.getLayoutAttr().dropInstData());
} else {
- newTy = type.clone(tileShape, elemTy);
+ newTy = VectorType::get(tileShape, elemTy);
}
if (returnSingleType)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 26770b3..d09dc19 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1505,14 +1505,19 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
return AffineMap::get(val.getContext());
// Get the layout of the vector type.
xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
- // If no layout is specified, assume the inner most dimension is distributed
- // for now.
+ // If no layout is specified, that means no distribution.
if (!layout)
- return AffineMap::getMultiDimMapWithTargets(
- vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
+ return AffineMap::getMultiDimMapWithTargets(vecRank, {},
+ val.getContext());
+ // Expecting vector and layout rank to match.
+ assert(layout.getRank() == vecRank &&
+ "Expecting vector and layout rank to match");
+ // A dimension is distributed only if layout suggests there are
+ // multiple lanes assigned for this dimension and the shape can be evenly
+ // distributed to those lanes.
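+  // E.g. a lane layout of [1, 16] on vector<16x32xf32> distributes only
+  // dimension 1, since 32 is a multiple of 16.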
SmallVector<unsigned int> distributedDims;
for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
- if (v > 1)
+ if (v > 1 && vecType.getShape()[i] % v == 0)
distributedDims.push_back(i);
}
return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
@@ -1525,15 +1530,13 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
vector::CombiningKind kind, uint32_t size) {
// First reduce on a single thread to get per lane reduction value.
- Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
+ Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
// Parallel reduction using butterfly shuffles.
for (uint64_t i = 1; i < size; i <<= 1) {
- Value shuffled =
- builder
- .create<gpu::ShuffleOp>(loc, laneVal, i,
- /*width=*/size,
- /*mode=*/gpu::ShuffleMode::XOR)
- .getShuffleResult();
+ Value shuffled = gpu::ShuffleOp::create(builder, loc, laneVal, i,
+ /*width=*/size,
+ /*mode=*/gpu::ShuffleMode::XOR)
+ .getShuffleResult();
laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
}
return laneVal;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index aafa1b7..e6e71cc 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -66,8 +66,6 @@ protected:
Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
Location loc, PatternRewriter &rewriter) const {
if (auto vecTy = dyn_cast<VectorType>(destTy)) {
- assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
- "Expecting blockSize size to match the rank of destTy.");
auto shape = vecTy.getShape();
return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
}
@@ -93,8 +91,6 @@ protected:
ArrayRef<int64_t> blockSize, Location loc,
PatternRewriter &rewriter) const {
if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
- assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
- "Expecting blockSize size to match the rank of src.");
return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
blockSize);
}
@@ -635,7 +631,7 @@ struct UnrollLoadGatherOpWithOffset
VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
Type elemTy = valueTy.getElementType();
- VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+ VectorType newValueTy = VectorType::get(*targetShape, elemTy);
SmallVector<Type> convertedMaskTypes;
SmallVector<Value> convertedMasks;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 31a967d..9fc5ad9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -825,7 +825,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType),
baseTileValues);
- auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr);
+ auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);
// Get subgroup id
Value sgId =
@@ -837,25 +837,26 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
SmallVector<Value, 2> strideConsts;
strideConsts.push_back(
- rewriter.create<arith::ConstantIndexOp>(loc, colStride));
+ arith::ConstantIndexOp::create(rewriter, loc, colStride));
if (rows > 1)
strideConsts.insert(
strideConsts.begin(),
- rewriter.create<arith::ConstantIndexOp>(loc, rowStride));
+ arith::ConstantIndexOp::create(rewriter, loc, rowStride));
SmallVector<Value> newConstOps;
for (auto offsets : *sgOffsets) {
// Multiply offset with stride, broadcast it and add to baseConstVec
- Value mulOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
for (size_t i = 0; i < strideConsts.size(); ++i) {
- Value mul = rewriter.create<arith::MulIOp>(
- loc, rewriter.getIndexType(), offsets[i], strideConsts[i]);
- mulOffset = rewriter.create<arith::AddIOp>(
- loc, rewriter.getIndexType(), mulOffset, mul);
+ Value mul =
+ arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
+ offsets[i], strideConsts[i]);
+ mulOffset = arith::AddIOp::create(
+ rewriter, loc, rewriter.getIndexType(), mulOffset, mul);
}
// Broadcast to baseConstVec size
- auto bcastOffset = rewriter.create<vector::BroadcastOp>(
- loc, baseConstVec.getType(), mulOffset);
+ auto bcastOffset = vector::BroadcastOp::create(
+ rewriter, loc, baseConstVec.getType(), mulOffset);
auto finalConst =
arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
setLayoutIfNeeded(baseConstVec);
@@ -1138,8 +1139,8 @@ struct WgToSgVectorShapeCastOp
SmallVector<Value> newShapeCastOps;
for (auto src : adaptor.getSource()) {
- auto newShapeCast =
- rewriter.create<vector::ShapeCastOp>(op.getLoc(), newResultType, src);
+ auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+ newResultType, src);
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
!layout.getEffectiveInstDataAsInt().empty())
xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
@@ -1201,9 +1202,9 @@ struct WgToSgMultiDimReductionOp
SmallVector<Value> newReductions;
for (auto sgSrc : adaptor.getSource()) {
- auto newOp = rewriter.create<vector::MultiDimReductionOp>(
- op.getLoc(), newDstType, op.getKind(), sgSrc, adaptor.getAcc()[0],
- op.getReductionDims());
+ auto newOp = vector::MultiDimReductionOp::create(
+ rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
+ adaptor.getAcc()[0], op.getReductionDims());
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
!layout.getEffectiveInstDataAsInt().empty())
xegpu::setDistributeLayoutAttr(newOp->getResult(0),
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 2c56a43..b4605cd 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -246,11 +246,28 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
if (!computeShapeRatio(srcShape, shape))
return {value};
+ int64_t srcShapeRank = srcShape.size();
+ int64_t targetShapeRank = shape.size();
+
+ SmallVector<int64_t> adjustedTargetShape(srcShape.size());
+ int64_t rankDiff = srcShapeRank - targetShapeRank;
+ std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff,
+ 1);
+ std::copy(shape.begin(), shape.end(), adjustedTargetShape.begin() + rankDiff);
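+  // E.g. srcShape = [2, 4, 8] with shape = [4, 8] gives
+  // adjustedTargetShape = [1, 4, 8].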
+
SmallVector<Value> result;
- for (SmallVector<int64_t> offsets : StaticTileOffsetRange(srcShape, shape)) {
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(srcShape, adjustedTargetShape)) {
SmallVector<int64_t> staticStrides(offsets.size(), 1);
- result.push_back(vector::ExtractStridedSliceOp::create(
- builder, loc, value, offsets, shape, staticStrides));
+ Value slice = vector::ExtractStridedSliceOp::create(
+ builder, loc, value, offsets, adjustedTargetShape, staticStrides);
+
+ // Reshape to remove leading unit dims if needed
+ if (srcShapeRank > targetShapeRank) {
+ auto targetTy = VectorType::get(shape, vecTy.getElementType());
+ slice = vector::ShapeCastOp::create(builder, loc, targetTy, slice);
+ }
+ result.push_back(slice);
}
return result;
@@ -274,7 +291,7 @@ Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
for (auto [src, offsets] :
llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) {
- SmallVector<int64_t> staticStrides(offsets.size(), 1);
+ SmallVector<int64_t> staticStrides(tileShape.size(), 1);
result = vector::InsertStridedSliceOp::create(builder, loc, src, result,
offsets, staticStrides);
}