Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--  mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp                              |  19
-rw-r--r--  mlir/lib/Dialect/Affine/IR/AffineOps.cpp                                  | 171
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp |  70
-rw-r--r--  mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt                             |   6
-rw-r--r--  mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp                      | 139
-rw-r--r--  mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp               |  16
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp                 | 140
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp                                  |  36
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp                     |   2
-rw-r--r--  mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp                                   | 152
-rw-r--r--  mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp               |  83
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp                   |  26
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp                    |  48
-rw-r--r--  mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp                             |   2
-rw-r--r--  mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp                                | 141
15 files changed, 660 insertions(+), 391 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index c798adb..61166db 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -339,6 +339,25 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
}
//===----------------------------------------------------------------------===//
+// ScaledExtPacked816Op
+//===----------------------------------------------------------------------===//
+LogicalResult ScaledExtPacked816Op::verify() {
+ int blockSize = getBlockSize();
+ assert((blockSize == 16 || blockSize == 32) && "invalid block size");
+ int firstScaleByte = getFirstScaleByte();
+ if (blockSize == 16 && !llvm::is_contained({0, 1}, firstScaleByte)) {
+ return emitOpError(
+ "blockSize of 16 can only have firstScaleByte be 0 or 1.");
+ }
+ if (blockSize == 32 && !llvm::is_contained({0, 2}, firstScaleByte)) {
+ return emitOpError(
+ "blockSize of 32 can only have firstScaleByte be 0 or 2.");
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// WMMAOp
//===----------------------------------------------------------------------===//
LogicalResult WMMAOp::verify() {
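For reference, the new verifier ties scale placement to block size: a blockSize of 16 permits firstScaleByte 0 or 1, a blockSize of 32 permits 0 or 2. A standalone C++ sketch of that rule (hypothetical helper, not part of the dialect):

```cpp
#include <cassert>

// Mirrors the (blockSize, firstScaleByte) rule enforced by
// ScaledExtPacked816Op::verify() above.
static bool isValidScalePlacement(int blockSize, int firstScaleByte) {
  if (blockSize == 16)
    return firstScaleByte == 0 || firstScaleByte == 1;
  if (blockSize == 32)
    return firstScaleByte == 0 || firstScaleByte == 2;
  return false; // other block sizes are rejected by the assert upstream
}

int main() {
  assert(isValidScalePlacement(16, 1));
  assert(!isValidScalePlacement(32, 1));
  return 0;
}
```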
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 749e2ba..e0a53cd 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2600,6 +2600,65 @@ static LogicalResult foldLoopBounds(AffineForOp forOp) {
return success(folded);
}
+/// Returns constant trip count in trivial cases.
+static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
+ int64_t step = forOp.getStepAsInt();
+ if (!forOp.hasConstantBounds() || step <= 0)
+ return std::nullopt;
+ int64_t lb = forOp.getConstantLowerBound();
+ int64_t ub = forOp.getConstantUpperBound();
+ return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
+}
+
+/// Fold an affine.for whose body contains only its terminator, returning the
+/// replacement values, or an empty vector if no folding applies.
+static SmallVector<OpFoldResult> AffineForEmptyLoopFolder(AffineForOp forOp) {
+ if (!llvm::hasSingleElement(*forOp.getBody()))
+ return {};
+ if (forOp.getNumResults() == 0)
+ return {};
+ std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
+ if (tripCount == 0) {
+ // The initial values of the iteration arguments would be the op's
+ // results.
+ return forOp.getInits();
+ }
+ SmallVector<Value, 4> replacements;
+ auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
+ auto iterArgs = forOp.getRegionIterArgs();
+ bool hasValDefinedOutsideLoop = false;
+ bool iterArgsNotInOrder = false;
+ for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
+ Value val = yieldOp.getOperand(i);
+ BlockArgument *iterArgIt = llvm::find(iterArgs, val);
+ // TODO: It should be possible to perform a replacement by computing the
+ // last value of the IV based on the bounds and the step.
+ if (val == forOp.getInductionVar())
+ return {};
+ if (iterArgIt == iterArgs.end()) {
+ // `val` is defined outside of the loop.
+ assert(forOp.isDefinedOutsideOfLoop(val) &&
+ "must be defined outside of the loop");
+ hasValDefinedOutsideLoop = true;
+ replacements.push_back(val);
+ } else {
+ unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
+ if (pos != i)
+ iterArgsNotInOrder = true;
+ replacements.push_back(forOp.getInits()[pos]);
+ }
+ }
+ // Bail out when the trip count is unknown and the loop returns any value
+ // defined outside of the loop or any iterArg out of order.
+ if (!tripCount.has_value() &&
+ (hasValDefinedOutsideLoop || iterArgsNotInOrder))
+ return {};
+ // Bail out when the loop iterates more than once and it returns any iterArg
+ // out of order.
+ if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
+ return {};
+ return llvm::to_vector_of<OpFoldResult>(replacements);
+}
+
/// Canonicalize the bounds of the given loop.
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
@@ -2631,79 +2690,30 @@ static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
return success();
}
-namespace {
-/// Returns constant trip count in trivial cases.
-static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
- int64_t step = forOp.getStepAsInt();
- if (!forOp.hasConstantBounds() || step <= 0)
- return std::nullopt;
- int64_t lb = forOp.getConstantLowerBound();
- int64_t ub = forOp.getConstantUpperBound();
- return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
+/// Returns true if the affine.for has zero iterations in trivial cases.
+static bool hasTrivialZeroTripCount(AffineForOp op) {
+ return getTrivialConstantTripCount(op) == 0;
}
-/// This is a pattern to fold trivially empty loop bodies.
-/// TODO: This should be moved into the folding hook.
-struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
- using OpRewritePattern<AffineForOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(AffineForOp forOp,
- PatternRewriter &rewriter) const override {
- // Check that the body only contains a yield.
- if (!llvm::hasSingleElement(*forOp.getBody()))
- return failure();
- if (forOp.getNumResults() == 0)
- return success();
- std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
- if (tripCount == 0) {
- // The initial values of the iteration arguments would be the op's
- // results.
- rewriter.replaceOp(forOp, forOp.getInits());
- return success();
- }
- SmallVector<Value, 4> replacements;
- auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
- auto iterArgs = forOp.getRegionIterArgs();
- bool hasValDefinedOutsideLoop = false;
- bool iterArgsNotInOrder = false;
- for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
- Value val = yieldOp.getOperand(i);
- auto *iterArgIt = llvm::find(iterArgs, val);
- // TODO: It should be possible to perform a replacement by computing the
- // last value of the IV based on the bounds and the step.
- if (val == forOp.getInductionVar())
- return failure();
- if (iterArgIt == iterArgs.end()) {
- // `val` is defined outside of the loop.
- assert(forOp.isDefinedOutsideOfLoop(val) &&
- "must be defined outside of the loop");
- hasValDefinedOutsideLoop = true;
- replacements.push_back(val);
- } else {
- unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
- if (pos != i)
- iterArgsNotInOrder = true;
- replacements.push_back(forOp.getInits()[pos]);
- }
- }
- // Bail out when the trip count is unknown and the loop returns any value
- // defined outside of the loop or any iterArg out of order.
- if (!tripCount.has_value() &&
- (hasValDefinedOutsideLoop || iterArgsNotInOrder))
- return failure();
- // Bail out when the loop iterates more than once and it returns any iterArg
- // out of order.
- if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
- return failure();
- rewriter.replaceOp(forOp, replacements);
- return success();
+LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
+ SmallVectorImpl<OpFoldResult> &results) {
+ bool folded = succeeded(foldLoopBounds(*this));
+ folded |= succeeded(canonicalizeLoopBounds(*this));
+ if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
+ // The initial values of the loop-carried variables (iter_args) are the
+ // results of the op. But this must be avoided for an affine.for op that
+ // does not return any results. Since ops that do not return results cannot
+ // be folded away, we would enter an infinite loop of folds on the same
+ // affine.for op.
+ results.assign(getInits().begin(), getInits().end());
+ folded = true;
}
-};
-} // namespace
-
-void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<AffineForEmptyLoopFolder>(context);
+ SmallVector<OpFoldResult> foldResults = AffineForEmptyLoopFolder(*this);
+ if (!foldResults.empty()) {
+ results.assign(foldResults);
+ folded = true;
+ }
+ return success(folded);
}
OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
@@ -2746,27 +2756,6 @@ void AffineForOp::getSuccessorRegions(
regions.push_back(RegionSuccessor(getResults()));
}
-/// Returns true if the affine.for has zero iterations in trivial cases.
-static bool hasTrivialZeroTripCount(AffineForOp op) {
- return getTrivialConstantTripCount(op) == 0;
-}
-
-LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
- SmallVectorImpl<OpFoldResult> &results) {
- bool folded = succeeded(foldLoopBounds(*this));
- folded |= succeeded(canonicalizeLoopBounds(*this));
- if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
- // The initial values of the loop-carried variables (iter_args) are the
- // results of the op. But this must be avoided for an affine.for op that
- // does not return any results. Since ops that do not return results cannot
- // be folded away, we would enter an infinite loop of folds on the same
- // affine.for op.
- results.assign(getInits().begin(), getInits().end());
- folded = true;
- }
- return success(folded);
-}
-
AffineBound AffineForOp::getLowerBound() {
return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap());
}
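The trivial trip count above is a ceiling division over constant bounds. A minimal standalone sketch with a worked example (assumptions: constant lb/ub and a positive step, as in the helper):

```cpp
#include <cstdint>
#include <optional>

// ceil((ub - lb) / step), as in getTrivialConstantTripCount above.
static std::optional<uint64_t> trivialTripCount(int64_t lb, int64_t ub,
                                                int64_t step) {
  if (step <= 0)
    return std::nullopt; // non-positive step: trip count unknown
  return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
}
// e.g. trivialTripCount(0, 10, 3) == 4: iterations at 0, 3, 6, 9.
```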
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
index 70faa71..bc17990 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
@@ -41,18 +41,37 @@ namespace bufferization {
using namespace mlir;
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
- func::ReturnOp returnOp;
+/// Get all the ReturnOps in the funcOp.
+static SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp) {
+ SmallVector<func::ReturnOp> returnOps;
for (Block &b : funcOp.getBody()) {
if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
- if (returnOp)
- return nullptr;
- returnOp = candidateOp;
+ returnOps.push_back(candidateOp);
}
}
- return returnOp;
+ return returnOps;
+}
+
+/// Get the operands at the specified position for all returnOps.
+static SmallVector<Value>
+getReturnOpsOperandInPos(ArrayRef<func::ReturnOp> returnOps, size_t pos) {
+ return llvm::map_to_vector(returnOps, [&](func::ReturnOp returnOp) {
+ return returnOp.getOperand(pos);
+ });
+}
+
+/// Check if all given values are the same buffer as the block argument (modulo
+/// cast ops).
+static bool operandsEqualFuncArgument(ArrayRef<Value> operands,
+ BlockArgument argument) {
+ for (Value val : operands) {
+ while (auto castOp = val.getDefiningOp<memref::CastOp>())
+ val = castOp.getSource();
+
+ if (val != argument)
+ return false;
+ }
+ return true;
}
LogicalResult
@@ -72,40 +91,45 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
for (auto funcOp : module.getOps<func::FuncOp>()) {
if (funcOp.isExternal() || funcOp.isPublic())
continue;
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- // TODO: Support functions with multiple blocks.
- if (!returnOp)
+ SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+ if (returnOps.empty())
continue;
// Compute erased results.
- SmallVector<Value> newReturnValues;
- BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
+ size_t numReturnOps = returnOps.size();
+ size_t numReturnValues = funcOp.getFunctionType().getNumResults();
+ SmallVector<SmallVector<Value>> newReturnValues(numReturnOps);
+ BitVector erasedResultIndices(numReturnValues);
DenseMap<int64_t, int64_t> resultToArgs;
- for (const auto &it : llvm::enumerate(returnOp.getOperands())) {
+ for (size_t i = 0; i < numReturnValues; ++i) {
bool erased = false;
+ SmallVector<Value> returnOperands =
+ getReturnOpsOperandInPos(returnOps, i);
for (BlockArgument bbArg : funcOp.getArguments()) {
- Value val = it.value();
- while (auto castOp = val.getDefiningOp<memref::CastOp>())
- val = castOp.getSource();
-
- if (val == bbArg) {
- resultToArgs[it.index()] = bbArg.getArgNumber();
+ if (operandsEqualFuncArgument(returnOperands, bbArg)) {
+ resultToArgs[i] = bbArg.getArgNumber();
erased = true;
break;
}
}
if (erased) {
- erasedResultIndices.set(it.index());
+ erasedResultIndices.set(i);
} else {
- newReturnValues.push_back(it.value());
+ for (auto [newReturnValue, operand] :
+ llvm::zip(newReturnValues, returnOperands)) {
+ newReturnValue.push_back(operand);
+ }
}
}
// Update function.
if (failed(funcOp.eraseResults(erasedResultIndices)))
return failure();
- returnOp.getOperandsMutable().assign(newReturnValues);
+
+ for (auto [returnOp, newReturnValue] :
+ llvm::zip(returnOps, newReturnValues))
+ returnOp.getOperandsMutable().assign(newReturnValue);
// Update function calls.
for (func::CallOp callOp : callerMap[funcOp]) {
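The generalization above drops a function result only when every return agrees: for each result position, the operands of all ReturnOps must trace back, through memref.cast chains, to the same function argument. A condensed sketch reusing the helpers defined above:

```cpp
// Result position `pos` can be dropped only if every ReturnOp yields the
// same bbArg there (modulo memref.cast). For example, with two returns
// `return %arg0` and `return %c` where %c = memref.cast %arg0, the result
// is droppable; if either return yielded a different buffer, it is not.
static bool canDropResult(ArrayRef<func::ReturnOp> returnOps,
                          BlockArgument bbArg, size_t pos) {
  SmallVector<Value> operands = getReturnOpsOperandInPos(returnOps, pos);
  return operandsEqualFuncArgument(operands, bbArg);
}
```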
diff --git a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
index 70a9c77..ec68acf 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRGPUPipelines
GPUToNVVMPipeline.cpp
+ GPUToXeVMPipeline.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
@@ -11,12 +12,17 @@ add_mlir_dialect_library(MLIRGPUPipelines
MLIRTransforms
MLIRLinalgTransforms
MLIRAffineToStandard
+ MLIRGPUToLLVMSPV
MLIRGPUToNVVMTransforms
MLIRIndexToLLVM
MLIRMathToLLVM
+ MLIRMathToXeVM
MLIRNVGPUToNVVM
MLIRNVVMToLLVM
MLIRReconcileUnrealizedCasts
MLIRSCFToControlFlow
MLIRVectorToSCF
+ MLIRXeGPUTransforms
+ MLIRXeGPUToXeVM
+ MLIRXeVMToLLVM
)
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
new file mode 100644
index 0000000..1a1485b
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -0,0 +1,139 @@
+//===- GPUToXeVMPipeline.cpp - Lowering pipeline to XeVM/LLVM -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pipeline for testing the lowering to XeVM as a
+// generally usable sink pipeline. If XeGPU ops are used, they are expected
+// to already be embedded in the gpu code.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
+#include "mlir/Conversion/Passes.h"
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
+#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Pipelines/Passes.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassOptions.h"
+#include "mlir/Target/LLVM/XeVM/Target.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace {
+//===----------------------------------------------------------------------===//
+// Pre-GPU common pipeline for both Host and GPU.
+//===----------------------------------------------------------------------===//
+void buildPreGPUCommonPassPipeline(
+ OpPassManager &pm, const mlir::gpu::GPUToXeVMPipelineOptions &options) {
+ // builtin.module scope passes.
+ pm.addPass(createCSEPass());
+ pm.addPass(createConvertVectorToSCFPass());
+ {
+ GpuXeVMAttachTargetOptions xevmTargetOptions;
+ xevmTargetOptions.moduleMatcher = options.xevmModuleMatcher;
+ xevmTargetOptions.triple = options.zebinTriple;
+ xevmTargetOptions.chip = options.zebinChip;
+ xevmTargetOptions.optLevel = options.optLevel;
+ xevmTargetOptions.cmdOptions = options.cmdOptions;
+ pm.addPass(createGpuXeVMAttachTarget(xevmTargetOptions));
+ }
+ pm.addPass(createLowerAffinePass());
+ pm.addNestedPass<func::FuncOp>(createGpuAsyncRegionPass());
+}
+
+//===----------------------------------------------------------------------===//
+// GPUModule-specific stuff.
+//===----------------------------------------------------------------------===//
+void buildGPUPassPipeline(OpPassManager &pm,
+ const mlir::gpu::GPUToXeVMPipelineOptions &options) {
+ if (options.xegpuOpLevel == "workgroup") {
+ pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUWgToSgDistribute());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUBlocking());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+ }
+ if (options.xegpuOpLevel == "subgroup" ||
+ options.xegpuOpLevel == "workgroup") {
+ pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUPropagateLayout());
+ pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUSubgroupDistribute());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createLoopInvariantCodeMotionPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUVectorLinearize());
+ }
+ pm.addNestedPass<gpu::GPUModuleOp>(createConvertMathToXeVM());
+ pm.addNestedPass<gpu::GPUModuleOp>(createConvertXeGPUToXeVMPass());
+ {
+ ConvertGpuOpsToLLVMSPVOpsOptions gpuToLLVMSPVOptions;
+ gpuToLLVMSPVOptions.use64bitIndex = options.use64bitIndex;
+ pm.addNestedPass<gpu::GPUModuleOp>(
+ createConvertGpuOpsToLLVMSPVOps(gpuToLLVMSPVOptions));
+ }
+ pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createReconcileUnrealizedCastsPass());
+}
+
+//===----------------------------------------------------------------------===//
+// Post-GPU pipeline for both Host and GPU.
+//===----------------------------------------------------------------------===//
+void buildPostGPUCommonPassPipeline(
+ OpPassManager &pm, const mlir::gpu::GPUToXeVMPipelineOptions &options) {
+ // builtin.module scope passes.
+ pm.addPass(createSCFToControlFlowPass());
+ pm.addPass(memref::createExpandStridedMetadataPass());
+ {
+ GpuToLLVMConversionPassOptions gpuToLLVMOptions;
+ gpuToLLVMOptions.hostBarePtrCallConv = options.hostBarePtrCallConv;
+ gpuToLLVMOptions.kernelBarePtrCallConv = options.kernelBarePtrCallConv;
+ pm.addPass(createGpuToLLVMConversionPass(gpuToLLVMOptions));
+ }
+ pm.addPass(createLowerAffinePass());
+ pm.addPass(createConvertToLLVMPass());
+ pm.addPass(createReconcileUnrealizedCastsPass());
+ // gpu-module-to-binary
+ {
+ GpuModuleToBinaryPassOptions gpuToModuleBinOptions;
+ gpuToModuleBinOptions.compilationTarget = options.binaryFormat;
+ gpuToModuleBinOptions.cmdOptions = options.cmdOptions;
+ pm.addPass(createGpuModuleToBinaryPass(gpuToModuleBinOptions));
+ }
+}
+} // namespace
+
+void mlir::gpu::buildLowerToXeVMPassPipeline(
+ OpPassManager &pm, const GPUToXeVMPipelineOptions &options) {
+ // Pre-GPU common pipelines.
+ buildPreGPUCommonPassPipeline(pm, options);
+
+ // GPUModule-specific stuff.
+ buildGPUPassPipeline(pm, options);
+
+ // Post-GPU pipeline for both Host and GPU.
+ buildPostGPUCommonPassPipeline(pm, options);
+}
+
+void mlir::gpu::registerGPUToXeVMPipeline() {
+ PassPipelineRegistration<GPUToXeVMPipelineOptions>(
+ "gpu-lower-to-xevm-pipeline",
+ "The default GPU to XeVM lowering pipeline. It starts by lowering GPU "
+ "code to the "
+ "specified compilation target (default is fatbin) then lowers the host "
+ "code.",
+ buildLowerToXeVMPassPipeline);
+}
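A hedged sketch of driving the new pipeline programmatically, using only the entry points defined above (the surrounding setup is assumed):

```cpp
#include "mlir/Dialect/GPU/Pipelines/Passes.h"
#include "mlir/Pass/PassManager.h"

// Hypothetical driver: build and run the pipeline with default options
// (binaryFormat defaults to fatbin, per the registration text above).
static mlir::LogicalResult lowerToXeVM(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  mlir::gpu::GPUToXeVMPipelineOptions options;
  mlir::gpu::buildLowerToXeVMPassPipeline(pm, options);
  return pm.run(module);
}
```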
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 6192d79..9a8a63e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2457,26 +2457,24 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
}
// Set options.
- TilingInterface paddedOp;
PadTilingInterfaceOptions options;
options.setPaddingValues(paddingValues)
.setPaddingSizes(getMixedPaddingSizes())
.setPadToMultipleOf(getPadToMultipleOf());
- // Apply padding.
- SmallVector<tensor::PadOp> newPadOps;
- FailureOr<TilingInterface> maybePaddedOp = rewriteAsPaddedOp(
- rewriter, cast<TilingInterface>(targetOp.getOperation()), options,
- newPadOps);
- if (failed(maybePaddedOp)) {
+ auto maybePadOps = rewriteAsPaddedOp(
+ rewriter, cast<TilingInterface>(targetOp.getOperation()), options);
+ if (failed(maybePadOps)) {
auto diag = emitSilenceableError() << "failed to pad op";
diag.attachNote(target->getLoc()) << "target op";
return diag;
}
+ const auto &[paddedOperands, paddedOp, slicedResults] = maybePadOps.value();
// Set transform results.
- paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
- padOps.append(newPadOps.begin(), newPadOps.end());
+ paddedOps.push_back(paddedOp);
+ padOps.append(paddedOperands.begin(), paddedOperands.end());
+ rewriter.replaceOp(targetOp.getOperation(), slicedResults);
}
results.set(cast<OpResult>(getPadded()), paddedOps);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 0956c5d..3e787a2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -95,10 +95,11 @@ static int64_t extractConstantMultiplier(AffineExpr expr) {
/// - affine_map<(d0, d1) -> (d0 * 3 + d1)>
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
-SmallVector<OpFoldResult> linalg::computePaddedShape(
- RewriterBase &rewriter, TypedValue<RankedTensorType> v,
- AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
- const PadTilingInterfaceOptions &options) {
+SmallVector<OpFoldResult>
+linalg::computePaddedShape(OpBuilder &builder, TypedValue<RankedTensorType> v,
+ AffineMap indexingMap,
+ ArrayRef<OpFoldResult> indexingSizes,
+ const PadTilingInterfaceOptions &options) {
Location loc = v.getLoc();
SmallVector<OpFoldResult> paddedShape;
auto tensorType = cast<RankedTensorType>(v.getType());
@@ -109,7 +110,7 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
// "Full-rank" padding specification.
SmallVector<OpFoldResult> paddingSizes =
- getFullRankPaddingSizes(rewriter, indexingSizes, options);
+ getFullRankPaddingSizes(builder, indexingSizes, options);
// For each dimension in the operand's shape, iterate over indexingSizes and
// add the various term contributions.
@@ -147,28 +148,27 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
OpFoldResult paddingDimOfr;
if (options.padToMultipleOf) {
AffineExpr d0, s0;
- bindDims(rewriter.getContext(), d0);
- bindSymbols(rewriter.getContext(), s0);
+ bindDims(builder.getContext(), d0);
+ bindSymbols(builder.getContext(), s0);
AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0);
AffineMap composedMap = projectedMap.compose(ceilMap);
paddingDimOfr = affine::makeComposedFoldedAffineApply(
- rewriter, loc, composedMap,
- {indexingSizes[paddingDim], paddingSize},
+ builder, loc, composedMap, {indexingSizes[paddingDim], paddingSize},
/*composeAffineMin=*/true);
} else {
// Otherwise just set to paddingSize.
paddingDimOfr = affine::makeComposedFoldedAffineApply(
- rewriter, loc, projectedMap, paddingSize);
+ builder, loc, projectedMap, paddingSize);
}
// Adjust for the maximum accessed index, which is (paddingSize - 1) *
// multiplier.
AffineExpr d0;
- bindDims(rewriter.getContext(), d0);
+ bindDims(builder.getContext(), d0);
int64_t multiplier = extractConstantMultiplier(projectedMap.getResult(0));
AffineMap subtractMap = AffineMap::get(1, 0, d0 - multiplier);
OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply(
- rewriter, loc, subtractMap, {paddingDimOfr});
+ builder, loc, subtractMap, {paddingDimOfr});
terms.push_back(maxAccessIdx);
LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
@@ -177,19 +177,19 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
// If there are no terms, just return the dim.
if (terms.empty()) {
paddedShape[resultIndex] =
- createFoldedDimOp(rewriter, loc, v, resultIndex);
+ createFoldedDimOp(builder, loc, v, resultIndex);
continue;
}
// Sum individual terms' contributions.
SmallVector<AffineExpr> dims(terms.size());
- bindDimsList(rewriter.getContext(), MutableArrayRef{dims});
+ bindDimsList(builder.getContext(), MutableArrayRef{dims});
AffineExpr sumExpr = dims.front();
for (unsigned i = 1; i < dims.size(); ++i)
sumExpr = sumExpr + dims[i];
// Add 1 to the maximum accessed index and get the final padded size.
- OpFoldResult paddedDimOfr = affine::makeComposedFoldedAffineApply(
- rewriter, loc, sumExpr + 1, terms);
+ OpFoldResult paddedDimOfr =
+ affine::makeComposedFoldedAffineApply(builder, loc, sumExpr + 1, terms);
paddedShape[resultIndex] = paddedDimOfr;
}
@@ -198,7 +198,7 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
FailureOr<SmallVector<OpFoldResult>>
linalg::computeIndexingMapOpInterfacePaddedShape(
- RewriterBase &rewriter, OpOperand &operandToPad,
+ OpBuilder &builder, OpOperand &operandToPad,
ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) {
auto transferOp =
llvm::dyn_cast<IndexingMapOpInterface>(operandToPad.getOwner());
@@ -206,9 +206,9 @@ linalg::computeIndexingMapOpInterfacePaddedShape(
return failure();
// clang-format off
- assert(llvm::all_of(iterationDomain, [&rewriter](Range r) {
- return r.offset == OpFoldResult(rewriter.getIndexAttr(0)) &&
- r.stride == OpFoldResult(rewriter.getIndexAttr(1));
+ assert(llvm::all_of(iterationDomain, [&builder](Range r) {
+ return r.offset == OpFoldResult(builder.getIndexAttr(0)) &&
+ r.stride == OpFoldResult(builder.getIndexAttr(1));
}) && "expected 0-offset 1-stride loop ranges");
// clang-format on
SmallVector<OpFoldResult> loopUpperBounds;
@@ -218,13 +218,13 @@ linalg::computeIndexingMapOpInterfacePaddedShape(
AffineMap indexingMap = transferOp.getMatchingIndexingMap(&operandToPad);
return computePaddedShape(
- rewriter, cast<TypedValue<RankedTensorType>>(operandToPad.get()),
+ builder, cast<TypedValue<RankedTensorType>>(operandToPad.get()),
indexingMap, loopUpperBounds, options);
}
/// Pad a single operand to `paddedShape` using `paddingValueAttr` as padding
/// Value.
-static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
+static Value padOperand(OpBuilder &builder, TilingInterface opToPad,
TypedValue<RankedTensorType> v,
ArrayRef<OpFoldResult> paddedShape,
Attribute paddingValueAttr) {
@@ -232,15 +232,15 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
if (auto complexTy =
dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) {
if (auto complexAttr = dyn_cast<ArrayAttr>(paddingValueAttr)) {
- paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
+ paddingValue = complex::ConstantOp::create(builder, opToPad.getLoc(),
complexTy, complexAttr);
}
} else if (isa<ub::PoisonAttr>(paddingValueAttr)) {
- paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(),
+ paddingValue = ub::PoisonOp::create(builder, opToPad.getLoc(),
getElementTypeOrSelf(v.getType()));
} else if (auto typedAttr = dyn_cast<TypedAttr>(paddingValueAttr)) {
paddingValue =
- arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr);
+ arith::ConstantOp::create(builder, opToPad.getLoc(), typedAttr);
}
assert(paddingValue && "failed to create value from padding attribute");
@@ -259,49 +259,48 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
RankedTensorType::get(tensorShape, getElementTypeOrSelf(v));
LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: "
<< paddedTensorType);
- return makeComposedPadHighOp(rewriter, opToPad.getLoc(), paddedTensorType, v,
+ return makeComposedPadHighOp(builder, opToPad.getLoc(), paddedTensorType, v,
paddingValue, /*nofold=*/false, dynDims);
}
-FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
- RewriterBase &rewriter, TilingInterface opToPad,
- const PadTilingInterfaceOptions &constOptions,
- SmallVector<tensor::PadOp> &padOps,
+FailureOr<PadTilingInterfaceResult> linalg::rewriteAsPaddedOp(
+ OpBuilder &builder, TilingInterface toPad,
+ PadTilingInterfaceOptions options,
const PadSizeComputationFunction &computePaddingSizeFun) {
- LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
+ LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << toPad << "\n");
+ SmallVector<tensor::PadOp> padOps;
+ Location loc = toPad.getLoc();
- Location loc = opToPad.getLoc();
- PadTilingInterfaceOptions options(constOptions);
// Allow inference of pad values if they are not explicitly specified.
// TODO: be mindful about the value depending on the actual operation.
if (options.paddingValues.empty()) {
- SmallVector<Type> types(opToPad->getOperandTypes());
- llvm::append_range(types, opToPad->getResultTypes());
+ SmallVector<Type> types(toPad->getOperandTypes());
+ llvm::append_range(types, toPad->getResultTypes());
for (Type t : types) {
options.paddingValues.push_back(
- rewriter.getZeroAttr(getElementTypeOrSelf(t)));
+ builder.getZeroAttr(getElementTypeOrSelf(t)));
}
}
- if (llvm::any_of(opToPad->getOperands(),
+ if (llvm::any_of(toPad->getOperands(),
[](Value v) { return isa<MemRefType>(v.getType()); })) {
- return rewriter.notifyMatchFailure(opToPad,
- "expected operation on tensors");
+ LLVM_DEBUG(DBGS() << "Not an operation on tensors: FAIL\n");
+ return failure();
}
- OpBuilder::InsertionGuard g(rewriter);
- // Set IP after opToPad because we also take the dims of opToPad's output.
- rewriter.setInsertionPointAfter(opToPad);
+ OpBuilder::InsertionGuard g(builder);
+ // Set IP after toPad because we also take the dims of toPad's output.
+ builder.setInsertionPointAfter(toPad);
// 1. Get the loopUpperBounds from the TilingInterface.
- SmallVector<Range> iterationDomain = opToPad.getIterationDomain(rewriter);
+ SmallVector<Range> iterationDomain = toPad.getIterationDomain(builder);
// 2. For each operand.
SmallVector<Value> newOperands;
- newOperands.reserve(opToPad->getNumOperands());
- for (OpOperand &opOperand : opToPad->getOpOperands()) {
+ newOperands.reserve(toPad->getNumOperands());
+ for (OpOperand &opOperand : toPad->getOpOperands()) {
Value operand = opOperand.get();
- LLVM_DEBUG(DBGS() << "--start padding oprd: " << operand << "\n");
+ LLVM_DEBUG(DBGS() << "--start padding operand: " << operand << "\n");
// 2.a. Skip scalar-like operands.
Type operandType = operand.getType();
@@ -311,30 +310,31 @@ FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
newOperands.push_back(operand);
continue;
}
+
// 2.a. Compute padded shape.
FailureOr<SmallVector<OpFoldResult>> maybePaddedShape =
- computePaddingSizeFun(rewriter, opOperand, iterationDomain, options);
+ computePaddingSizeFun(builder, opOperand, iterationDomain, options);
if (failed(maybePaddedShape)) {
- return rewriter.notifyMatchFailure(opToPad, "could not pad op");
+ LLVM_DEBUG(DBGS() << "Could not get padded shape of operand: FAIL\n");
+ return failure();
}
// 2.b. Expect proper `paddingValues`.
// TODO: we may want to allow garbage padding in the future, in which case
// we would just not assert.
if (opOperand.getOperandNumber() >= options.paddingValues.size()) {
- return rewriter.notifyMatchFailure(opToPad,
- "--no padding value specified");
+ LLVM_DEBUG(DBGS() << "Too few padding values specified: FAIL\n");
+ return failure();
}
Attribute paddingValueAttr =
options.paddingValues[opOperand.getOperandNumber()];
// 2.c. Perform actual padding.
- Value paddedOperand = padOperand(
- rewriter, opToPad, cast<TypedValue<RankedTensorType>>(operand),
- *maybePaddedShape, paddingValueAttr);
+ Value paddedOperand =
+ padOperand(builder, toPad, cast<TypedValue<RankedTensorType>>(operand),
+ *maybePaddedShape, paddingValueAttr);
LLVM_DEBUG(DBGS() << "--done padding operand: " << paddedOperand << "\n");
- // 2.d. Perform actual padding.
newOperands.push_back(paddedOperand);
if (auto padOp = paddedOperand.getDefiningOp<tensor::PadOp>())
padOps.push_back(padOp);
@@ -342,38 +342,34 @@ FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
// 3. Form the resulting tensor::ExtractSliceOp.
ReifiedRankedShapedTypeDims reifiedResultShapes;
- if (failed(reifyResultShapes(rewriter, opToPad, reifiedResultShapes))) {
- LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
- return rewriter.notifyMatchFailure(opToPad,
- "failed to reify result shapes");
+ if (failed(reifyResultShapes(builder, toPad, reifiedResultShapes))) {
+ LLVM_DEBUG(DBGS() << "Failed to reify result shapes: FAIL\n");
+ return failure();
}
- assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
+ assert(reifiedResultShapes.size() == toPad->getNumResults() &&
"expected same number of results");
- // Clone `opToPad` to operate on the statically padded shapes.
+ // Clone `toPad` to operate on the statically padded shapes.
auto resultTensorTypes =
- ValueRange(newOperands).take_back(opToPad->getNumResults()).getTypes();
- // clone **should** properly notify the rewriter.
+ ValueRange(newOperands).take_back(toPad->getNumResults()).getTypes();
+ // clone **should** properly notify the builder.
TilingInterface paddedOp =
- clone(rewriter, opToPad, resultTensorTypes, newOperands);
+ clone(builder, toPad, resultTensorTypes, newOperands);
LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n");
- // Recover the slice out of the new static results. This keeps the original
- // opToPad around because it uses the dims of the original results.
+ // Recover the slice out of the new static results.
SmallVector<Value> paddedSubtensorResults;
- paddedSubtensorResults.reserve(opToPad->getNumResults());
+ paddedSubtensorResults.reserve(toPad->getNumResults());
for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
Value paddedResult = en.value();
int64_t resultNumber = en.index();
int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
- SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
- SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+ SmallVector<OpFoldResult> offsets(rank, builder.getIndexAttr(0));
+ SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1));
paddedSubtensorResults.push_back(tensor::ExtractSliceOp::create(
- rewriter, loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
+ builder, loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
strides));
}
- rewriter.replaceOp(opToPad, paddedSubtensorResults);
-
- return paddedOp;
+ return PadTilingInterfaceResult{padOps, paddedOp, paddedSubtensorResults};
}
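With the out-parameters gone, callers consume the result as a triple and perform the replacement themselves, as the transform op earlier in this diff now does. A hedged sketch of the new call pattern, assuming the three members aggregate-initialized above:

```cpp
FailureOr<linalg::PadTilingInterfaceResult> maybePadded =
    linalg::rewriteAsPaddedOp(rewriter, opToPad, options);
if (failed(maybePadded))
  return failure();
const auto &[padOps, paddedOp, slicedResults] = *maybePadded;
// Replacement is now the caller's responsibility:
rewriter.replaceOp(opToPad, slicedResults);
```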
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 507597b..94947b7 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2158,11 +2158,45 @@ public:
return success();
}
};
+
+struct ReinterpretCastOpConstantFolder
+ : public OpRewritePattern<ReinterpretCastOp> {
+public:
+ using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ReinterpretCastOp op,
+ PatternRewriter &rewriter) const override {
+ unsigned srcStaticCount = llvm::count_if(
+ llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
+ op.getMixedStrides()),
+ [](OpFoldResult ofr) { return isa<Attribute>(ofr); });
+
+ SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
+ SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
+ SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
+
+ // TODO: Using counting comparison instead of direct comparison because
+ // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
+ // IntegerAttrs, while constifyIndexValues (and therefore
+ // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
+ if (srcStaticCount ==
+ llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
+ [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
+ return failure();
+
+ auto newReinterpretCast = ReinterpretCastOp::create(
+ rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);
+
+ rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
+ return success();
+ }
+};
} // namespace
void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
+ results.add<ReinterpretCastOpExtractStridedMetadataFolder,
+ ReinterpretCastOpConstantFolder>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
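The pattern fires only when constification finds strictly more static values than the op already carries; as the TODO notes, counting sidesteps the IntegerAttr/IndexAttr mismatch. A small sketch of the counting idiom (hypothetical helper):

```cpp
// "Static" here means the OpFoldResult holds an Attribute, not a Value.
static int64_t countStaticValues(ArrayRef<OpFoldResult> vals) {
  return llvm::count_if(
      vals, [](OpFoldResult ofr) { return isa<Attribute>(ofr); });
}
// When the pattern fires, the op is rebuilt with the constified
// offset/sizes/strides, and a memref.cast restores the original result type.
```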
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
index 49b7162..6f815ae 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
@@ -121,7 +121,7 @@ struct EmulateWideIntPass final
[&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
RewritePatternSet patterns(ctx);
- // Add common pattenrs to support contants, functions, etc.
+  // Add common patterns to support constants, functions, etc.
arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
memref::populateMemRefWideIntEmulationPatterns(typeConverter, patterns);
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 90cbbd8..dcfe2c7 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1030,12 +1030,12 @@ struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
//===----------------------------------------------------------------------===//
/// Create and populate an init region for privatization recipes.
-/// Returns the init block on success, or nullptr on failure.
+/// Returns success if the region is populated, failure otherwise.
/// Sets needsFree to indicate if the allocated memory requires deallocation.
-static std::unique_ptr<Block> createInitRegion(OpBuilder &builder, Location loc,
- Type varType, StringRef varName,
- ValueRange bounds,
- bool &needsFree) {
+static LogicalResult createInitRegion(OpBuilder &builder, Location loc,
+ Region &initRegion, Type varType,
+ StringRef varName, ValueRange bounds,
+ bool &needsFree) {
// Create init block with arguments: original value + bounds
SmallVector<Type> argTypes{varType};
SmallVector<Location> argLocs{loc};
@@ -1044,9 +1044,9 @@ static std::unique_ptr<Block> createInitRegion(OpBuilder &builder, Location loc,
argLocs.push_back(loc);
}
- auto initBlock = std::make_unique<Block>();
+ Block *initBlock = builder.createBlock(&initRegion);
initBlock->addArguments(argTypes, argLocs);
- builder.setInsertionPointToStart(initBlock.get());
+ builder.setInsertionPointToStart(initBlock);
Value privatizedValue;
@@ -1060,7 +1060,7 @@ static std::unique_ptr<Block> createInitRegion(OpBuilder &builder, Location loc,
privatizedValue = mappableTy.generatePrivateInit(
builder, loc, typedVar, varName, bounds, {}, needsFree);
if (!privatizedValue)
- return nullptr;
+ return failure();
} else {
assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
auto pointerLikeTy = cast<PointerLikeType>(varType);
@@ -1068,21 +1068,21 @@ static std::unique_ptr<Block> createInitRegion(OpBuilder &builder, Location loc,
privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
blockArgVar, needsFree);
if (!privatizedValue)
- return nullptr;
+ return failure();
}
// Add yield operation to init block
acc::YieldOp::create(builder, loc, privatizedValue);
- return initBlock;
+ return success();
}
/// Create and populate a copy region for firstprivate recipes.
-/// Returns the copy block on success, or nullptr on failure.
+/// Returns success if the region is populated, failure otherwise.
/// TODO: Handle MappableType - it does not yet have a copy API.
-static std::unique_ptr<Block> createCopyRegion(OpBuilder &builder, Location loc,
- Type varType,
- ValueRange bounds) {
+static LogicalResult createCopyRegion(OpBuilder &builder, Location loc,
+ Region &copyRegion, Type varType,
+ ValueRange bounds) {
// Create copy block with arguments: original value + privatized value +
// bounds
SmallVector<Type> copyArgTypes{varType, varType};
@@ -1092,16 +1092,16 @@ static std::unique_ptr<Block> createCopyRegion(OpBuilder &builder, Location loc,
copyArgLocs.push_back(loc);
}
- auto copyBlock = std::make_unique<Block>();
+ Block *copyBlock = builder.createBlock(&copyRegion);
copyBlock->addArguments(copyArgTypes, copyArgLocs);
- builder.setInsertionPointToStart(copyBlock.get());
+ builder.setInsertionPointToStart(copyBlock);
bool isMappable = isa<MappableType>(varType);
bool isPointerLike = isa<PointerLikeType>(varType);
// TODO: Handle MappableType - it does not yet have a copy API.
// Otherwise, for now just fallback to pointer-like behavior.
if (isMappable && !isPointerLike)
- return nullptr;
+ return failure();
// Generate copy region body based on variable type
if (isPointerLike) {
@@ -1113,21 +1113,20 @@ static std::unique_ptr<Block> createCopyRegion(OpBuilder &builder, Location loc,
if (!pointerLikeTy.genCopy(
builder, loc, cast<TypedValue<PointerLikeType>>(privatizedArg),
cast<TypedValue<PointerLikeType>>(originalArg), varType))
- return nullptr;
+ return failure();
}
// Add terminator to copy block
acc::TerminatorOp::create(builder, loc);
- return copyBlock;
+ return success();
}
/// Create and populate a destroy region for privatization recipes.
-/// Returns the destroy block on success, or nullptr if not needed.
-static std::unique_ptr<Block> createDestroyRegion(OpBuilder &builder,
- Location loc, Type varType,
- Value allocRes,
- ValueRange bounds) {
+/// Returns success if the region is populated, failure otherwise.
+static LogicalResult createDestroyRegion(OpBuilder &builder, Location loc,
+ Region &destroyRegion, Type varType,
+ Value allocRes, ValueRange bounds) {
// Create destroy block with arguments: original value + privatized value +
// bounds
SmallVector<Type> destroyArgTypes{varType, varType};
@@ -1137,28 +1136,25 @@ static std::unique_ptr<Block> createDestroyRegion(OpBuilder &builder,
destroyArgLocs.push_back(loc);
}
- auto destroyBlock = std::make_unique<Block>();
+ Block *destroyBlock = builder.createBlock(&destroyRegion);
destroyBlock->addArguments(destroyArgTypes, destroyArgLocs);
- builder.setInsertionPointToStart(destroyBlock.get());
+ builder.setInsertionPointToStart(destroyBlock);
- bool isMappable = isa<MappableType>(varType);
- bool isPointerLike = isa<PointerLikeType>(varType);
- // TODO: Handle MappableType - it does not yet have a deallocation API.
- // Otherwise, for now just fallback to pointer-like behavior.
- if (isMappable && !isPointerLike)
- return nullptr;
-
- assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
- auto pointerLikeTy = cast<PointerLikeType>(varType);
- auto privatizedArg =
+ auto varToFree =
cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1));
- // Pass allocRes to help determine the allocation type
- if (!pointerLikeTy.genFree(builder, loc, privatizedArg, allocRes, varType))
- return nullptr;
+ if (isa<MappableType>(varType)) {
+ auto mappableTy = cast<MappableType>(varType);
+ if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree))
+ return failure();
+ } else {
+ assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
+ auto pointerLikeTy = cast<PointerLikeType>(varType);
+ if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
+ return failure();
+ }
acc::TerminatorOp::create(builder, loc);
-
- return destroyBlock;
+ return success();
}
} // namespace
@@ -1220,40 +1216,33 @@ PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
if (!isMappable && !isPointerLike)
return std::nullopt;
- // Create init and destroy blocks using shared helpers
OpBuilder::InsertionGuard guard(builder);
- // Save the original insertion point for creating the recipe operation later
- auto originalInsertionPoint = builder.saveInsertionPoint();
+ // Create the recipe operation first so regions have proper parent context
+ auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
+ // Populate the init region
bool needsFree = false;
- auto initBlock =
- createInitRegion(builder, loc, varType, varName, bounds, needsFree);
- if (!initBlock)
+ if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
+ varName, bounds, needsFree))) {
+ recipe.erase();
return std::nullopt;
+ }
// Only create destroy region if the allocation needs deallocation
- std::unique_ptr<Block> destroyBlock;
if (needsFree) {
// Extract the allocated value from the init block's yield operation
- auto yieldOp = cast<acc::YieldOp>(initBlock->getTerminator());
+ auto yieldOp =
+ cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
Value allocRes = yieldOp.getOperand(0);
- destroyBlock = createDestroyRegion(builder, loc, varType, allocRes, bounds);
- if (!destroyBlock)
+ if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
+ varType, allocRes, bounds))) {
+ recipe.erase();
return std::nullopt;
+ }
}
- // Now create the recipe operation at the original insertion point and attach
- // the blocks
- builder.restoreInsertionPoint(originalInsertionPoint);
- auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
-
- // Move the blocks into the recipe's regions
- recipe.getInitRegion().push_back(initBlock.release());
- if (destroyBlock)
- recipe.getDestroyRegion().push_back(destroyBlock.release());
-
return recipe;
}
@@ -1299,45 +1288,40 @@ FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
if (!isMappable && !isPointerLike)
return std::nullopt;
- // Create init, copy, and destroy blocks using shared helpers
OpBuilder::InsertionGuard guard(builder);
- // Save the original insertion point for creating the recipe operation later
- auto originalInsertionPoint = builder.saveInsertionPoint();
+ // Create the recipe operation first so regions have proper parent context
+ auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
+ // Populate the init region
bool needsFree = false;
- auto initBlock =
- createInitRegion(builder, loc, varType, varName, bounds, needsFree);
- if (!initBlock)
+ if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
+ varName, bounds, needsFree))) {
+ recipe.erase();
return std::nullopt;
+ }
- auto copyBlock = createCopyRegion(builder, loc, varType, bounds);
- if (!copyBlock)
+ // Populate the copy region
+ if (failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
+ bounds))) {
+ recipe.erase();
return std::nullopt;
+ }
// Only create destroy region if the allocation needs deallocation
- std::unique_ptr<Block> destroyBlock;
if (needsFree) {
// Extract the allocated value from the init block's yield operation
- auto yieldOp = cast<acc::YieldOp>(initBlock->getTerminator());
+ auto yieldOp =
+ cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
Value allocRes = yieldOp.getOperand(0);
- destroyBlock = createDestroyRegion(builder, loc, varType, allocRes, bounds);
- if (!destroyBlock)
+ if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
+ varType, allocRes, bounds))) {
+ recipe.erase();
return std::nullopt;
+ }
}
- // Now create the recipe operation at the original insertion point and attach
- // the blocks
- builder.restoreInsertionPoint(originalInsertionPoint);
- auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
-
- // Move the blocks into the recipe's regions
- recipe.getInitRegion().push_back(initBlock.release());
- recipe.getCopyRegion().push_back(copyBlock.release());
- if (destroyBlock)
- recipe.getDestroyRegion().push_back(destroyBlock.release());
-
return recipe;
}
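Both recipe factories now share the same create/populate/erase-on-failure idiom, which keeps region ownership with the op from the start. A condensed sketch; populateAllRegions is a hypothetical stand-in for the createInitRegion / createCopyRegion / createDestroyRegion calls above:

```cpp
auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
if (failed(populateAllRegions(recipe))) {
  recipe.erase(); // never leave a half-populated recipe in the IR
  return std::nullopt;
}
return recipe;
```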
diff --git a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
index 8e7af05..abc1316 100644
--- a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
@@ -8,8 +8,8 @@
#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h"
#include "mlir/Dialect/SMT/IR/SMTDialect.h"
-#include "mlir/Dialect/Transform/IR/TransformOps.h"
-#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
+#include "mlir/Dialect/SMT/IR/SMTOps.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
using namespace mlir;
@@ -23,6 +23,7 @@ using namespace mlir;
void transform::smt::ConstrainParamsOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getParamsMutable(), effects);
+ producesHandle(getResults(), effects);
}
DiagnosedSilenceableFailure
@@ -37,19 +38,95 @@ transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter,
// and allow for users to attach their own implementation, which would,
// e.g., translate the ops to SMTLIB and hand that over to the user's
// favourite solver. This requires changes to the dialect's verifier.
- return emitDefiniteFailure() << "op does not have interpreted semantics yet";
+ return emitSilenceableFailure(getLoc())
+ << "op does not have interpreted semantics yet";
}
LogicalResult transform::smt::ConstrainParamsOp::verify() {
+ auto yieldTerminator =
+ dyn_cast<mlir::smt::YieldOp>(getRegion().front().back());
+ if (!yieldTerminator)
+ return emitOpError() << "expected '"
+ << mlir::smt::YieldOp::getOperationName()
+ << "' as terminator";
+
+ auto checkTypes = [](size_t idx, Type smtType, StringRef smtDesc,
+ Type paramType, StringRef paramDesc,
+ auto *atOp) -> InFlightDiagnostic {
+ if (!isa<mlir::smt::BoolType, mlir::smt::IntType, mlir::smt::BitVectorType>(
+ smtType))
+ return atOp->emitOpError() << "the type of " << smtDesc << " #" << idx
+ << " is expected to be either a !smt.bool, a "
+ "!smt.int, or a !smt.bv";
+
+ assert(isa<TransformParamTypeInterface>(paramType) &&
+ "ODS specifies params' type should implement param interface");
+ if (isa<transform::AnyParamType>(paramType))
+ return {}; // No further checks can be done.
+
+ // NB: This cast must succeed as long as the only implementors of
+ // TransformParamTypeInterface are AnyParamType and ParamType.
+ Type typeWrappedByParam = cast<ParamType>(paramType).getType();
+
+ if (isa<mlir::smt::IntType>(smtType)) {
+ if (!isa<IntegerType>(typeWrappedByParam))
+ return atOp->emitOpError()
+ << "the type of " << smtDesc << " #" << idx
+ << " is !smt.int though the corresponding " << paramDesc
+ << " type (" << paramType << ") is not wrapping an integer type";
+ } else if (isa<mlir::smt::BoolType>(smtType)) {
+ auto wrappedIntType = dyn_cast<IntegerType>(typeWrappedByParam);
+ if (!wrappedIntType || wrappedIntType.getWidth() != 1)
+ return atOp->emitOpError()
+ << "the type of " << smtDesc << " #" << idx
+ << " is !smt.bool though the corresponding " << paramDesc
+ << " type (" << paramType << ") is not wrapping i1";
+ } else if (auto bvSmtType = dyn_cast<mlir::smt::BitVectorType>(smtType)) {
+ auto wrappedIntType = dyn_cast<IntegerType>(typeWrappedByParam);
+ if (!wrappedIntType || wrappedIntType.getWidth() != bvSmtType.getWidth())
+ return atOp->emitOpError()
+ << "the type of " << smtDesc << " #" << idx << " is " << smtType
+ << " though the corresponding " << paramDesc << " type ("
+ << paramType
+ << ") is not wrapping an integer type of the same bitwidth";
+ }
+
+ return {};
+ };
+
if (getOperands().size() != getBody().getNumArguments())
return emitOpError(
"must have the same number of block arguments as operands");
+ for (auto [idx, operandType, blockArgType] :
+ llvm::enumerate(getOperandTypes(), getBody().getArgumentTypes())) {
+ InFlightDiagnostic typeCheckResult =
+ checkTypes(idx, blockArgType, "block arg", operandType, "operand",
+ /*atOp=*/this);
+ if (LogicalResult(typeCheckResult).failed())
+ return typeCheckResult;
+ }
+
for (auto &op : getBody().getOps()) {
if (!isa<mlir::smt::SMTDialect>(op.getDialect()))
return emitOpError(
"ops contained in region should belong to SMT-dialect");
}
+ if (yieldTerminator->getNumOperands() != getNumResults())
+ return yieldTerminator.emitOpError()
+ << "expected terminator to have as many operands as the parent op "
+ "has results";
+
+ for (auto [idx, termOperandType, resultType] : llvm::enumerate(
+ yieldTerminator->getOperands().getType(), getResultTypes())) {
+ InFlightDiagnostic typeCheckResult =
+ checkTypes(idx, termOperandType, "terminator operand",
+ cast<transform::ParamType>(resultType), "result",
+ /*atOp=*/&yieldTerminator);
+ if (LogicalResult(typeCheckResult).failed())
+ return typeCheckResult;
+ }
+
return success();
}
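The checkTypes lambda encodes a small compatibility relation between SMT types and the types wrapped by transform params. A standalone predicate mirroring it (hypothetical helper, not the MLIR implementation):

```cpp
static bool compatible(Type smtType, Type wrappedByParam) {
  if (isa<mlir::smt::IntType>(smtType))
    return isa<IntegerType>(wrappedByParam); // !smt.int <-> any integer type
  if (isa<mlir::smt::BoolType>(smtType)) {
    auto intTy = dyn_cast<IntegerType>(wrappedByParam);
    return intTy && intTy.getWidth() == 1; // !smt.bool <-> i1
  }
  if (auto bvTy = dyn_cast<mlir::smt::BitVectorType>(smtType)) {
    auto intTy = dyn_cast<IntegerType>(wrappedByParam);
    return intTy && intTy.getWidth() == bvTy.getWidth(); // !smt.bv<w> <-> iw
  }
  return false; // only bool/int/bv are accepted at all
}
```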
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 12e6475..7c019e7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -2032,11 +2032,19 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
}
// Newly created `WarpOp` will yield values in following order:
- // 1. All init args of the `ForOp`.
- // 2. All escaping values.
- // 3. All non-`ForOp` yielded values.
+ // 1. Loop bounds.
+ // 2. All init args of the `ForOp`.
+ // 3. All escaping values.
+ // 4. All non-`ForOp` yielded values.
SmallVector<Value> newWarpOpYieldValues;
SmallVector<Type> newWarpOpDistTypes;
+ newWarpOpYieldValues.insert(
+ newWarpOpYieldValues.end(),
+ {forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()});
+ newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
+ {forOp.getLowerBound().getType(),
+ forOp.getUpperBound().getType(),
+ forOp.getStep().getType()});
for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
newWarpOpYieldValues.push_back(initArg);
// Compute the distributed type for this init arg.
@@ -2072,20 +2080,24 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
// Next, we create a new `ForOp` with the init args yielded by the new
// `WarpOp`.
+ const unsigned initArgsStartIdx = 3; // After loop bounds.
const unsigned escapingValuesStartIdx =
+ initArgsStartIdx +
forOp.getInitArgs().size(); // `ForOp` init args are positioned before
// escaping values in the new `WarpOp`.
SmallVector<Value> newForOpOperands;
- for (size_t i = 0; i < escapingValuesStartIdx; ++i)
+ for (size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i)
newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));
// Create a new `ForOp` outside the new `WarpOp` region.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
auto newForOp = scf::ForOp::create(
- rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), newForOpOperands, /*bodyBuilder=*/nullptr,
- forOp.getUnsignedCmp());
+ rewriter, forOp.getLoc(),
+      /*lowerBound=*/newWarpOp.getResult(newIndices[0]),
+      /*upperBound=*/newWarpOp.getResult(newIndices[1]),
+      /*step=*/newWarpOp.getResult(newIndices[2]), newForOpOperands,
+ /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
// Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
// newly created `ForOp`. This `WarpOp` will contain all ops that were
// contained within the original `ForOp` body.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 1b656d8..ea93085 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -817,6 +817,50 @@ struct LinearizeVectorToElements final
}
};
+/// Convert broadcasts from scalars or 1-element vectors, such as
+///
+/// ```mlir
+/// vector.broadcast %value : f32 to vector<4x4xf32>
+/// ```
+///
+/// to broadcasts to rank-1 vectors, with shape_casts before/after as needed.
+/// The above becomes:
+///
+/// ```mlir
+/// %out_1d = vector.broadcast %value : f32 to vector<16xf32>
+/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
+/// ```
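+///
+/// A single-element vector source is handled the same way; illustratively,
+///
+/// ```mlir
+/// vector.broadcast %v : vector<1x1xf32> to vector<4x4xf32>
+/// ```
+///
+/// also becomes a rank-1 broadcast to vector<16xf32> followed by a
+/// shape_cast back to vector<4x4xf32>.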
+struct LinearizeVectorBroadcast final
+ : public OpConversionPattern<vector::BroadcastOp> {
+ using Base::Base;
+
+ LinearizeVectorBroadcast(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ int numElements = 1;
+ Type sourceType = broadcastOp.getSourceType();
+ if (auto vecType = dyn_cast<VectorType>(sourceType)) {
+ numElements = vecType.getNumElements();
+ }
+
+ if (numElements != 1) {
+ return rewriter.notifyMatchFailure(
+ broadcastOp, "only broadcasts of single elements can be linearized.");
+ }
+
+ auto dstTy = getTypeConverter()->convertType(broadcastOp.getType());
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(broadcastOp, dstTy,
+ adaptor.getSource());
+
+ return success();
+ }
+};
+
} // namespace
/// This method defines the set of operations that are linearizable, and hence
@@ -909,8 +953,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
patterns
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
LinearizeVectorCreateMask, LinearizeVectorLoad, LinearizeVectorStore,
- LinearizeVectorFromElements, LinearizeVectorToElements>(
- typeConverter, patterns.getContext());
+ LinearizeVectorBroadcast, LinearizeVectorFromElements,
+ LinearizeVectorToElements>(typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 025ee9a..c809c502 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -91,7 +91,7 @@ mlir::vector::isTranspose2DSlice(vector::TransposeOp op) {
// Check whether the two source vector dimensions that are greater than one
// must be transposed with each other so that we can apply one of the 2-D
- // transpose pattens. Otherwise, these patterns are not applicable.
+ // transpose patterns. Otherwise, these patterns are not applicable.
if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1],
op.getPermutation()))
return failure();
diff --git a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
index 89b62a2..a514ea9 100644
--- a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
+++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/SymbolTable.h"
@@ -39,28 +40,6 @@ void printElseRegion(OpAsmPrinter &opPrinter, Operation *op,
opPrinter.printKeywordOrString("else ");
opPrinter.printRegion(elseRegion);
}
-
-ParseResult parseWasmVisibility(OpAsmParser &opParser, StringAttr &visibility) {
- std::string keyword;
- auto initLocation = opParser.getCurrentLocation();
- std::ignore = opParser.parseOptionalKeywordOrString(&keyword);
- if (keyword == "nested" or keyword == "") {
- visibility = StringAttr::get(opParser.getContext(), "nested");
- return ParseResult::success();
- }
-
- if (keyword == "public" || keyword == "private") {
- visibility = StringAttr::get(opParser.getContext(), keyword);
- return ParseResult::success();
- }
- opParser.emitError(initLocation, "expecting symbol visibility");
- return ParseResult::failure();
-}
-
-void printWasmVisibility(OpAsmPrinter &opPrinter, Operation *op,
- Attribute visibility) {
- opPrinter.printKeywordOrString(cast<StringAttr>(visibility).strref());
-}
} // namespace
#define GET_OP_CLASSES
@@ -167,10 +146,23 @@ Block *FuncOp::addEntryBlock() {
void FuncOp::build(OpBuilder &odsBuilder, OperationState &odsState,
StringRef symbol, FunctionType funcType) {
- FuncOp::build(odsBuilder, odsState, symbol, funcType, {}, {}, "nested");
+ FuncOp::build(odsBuilder, odsState, symbol, funcType, {}, {});
}
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
+ auto *ctx = parser.getContext();
+ std::string visibilityString;
+ auto loc = parser.getNameLoc();
+ ParseResult res = parser.parseOptionalKeywordOrString(&visibilityString);
+ bool exported{false};
+ if (res.succeeded()) {
+ if (visibilityString != "exported")
+ return parser.emitError(
+ loc, "expecting either `exported` or symbol name. got ")
+ << visibilityString;
+ exported = true;
+ }
+
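+  // The accepted custom syntax thus optionally begins with `exported`, e.g.
+  // (an illustrative sketch, assuming the `wasmssa.func` mnemonic):
+  //   wasmssa.func exported @f(%arg0: i32) -> i32 {...}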
auto buildFuncType = [&parser](Builder &builder, ArrayRef<Type> argTypes,
ArrayRef<Type> results,
function_interface_impl::VariadicFlag,
@@ -191,11 +183,13 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
return builder.getFunctionType(argTypesWithoutLocal, results);
};
-
- return function_interface_impl::parseFunctionOp(
+ auto funcParseRes = function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false,
getFunctionTypeAttrName(result.name), buildFuncType,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+ if (exported)
+ result.addAttribute(getExportedAttrName(result.name), UnitAttr::get(ctx));
+ return funcParseRes;
}
LogicalResult FuncOp::verifyBody() {
@@ -224,9 +218,18 @@ LogicalResult FuncOp::verifyBody() {
}
void FuncOp::print(OpAsmPrinter &p) {
+  // If the function is exported, print the `exported` keyword first, then
+  // temporarily drop the attribute so the generic printer does not emit it.
+ auto exported = getExported();
+ if (exported) {
+ p << " exported";
+ removeExportedAttr();
+ }
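+  // Restoring the attribute afterwards keeps the op unchanged from the
+  // caller's perspective once printing returns.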
function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
+ if (exported)
+ setExported(true);
}
//===----------------------------------------------------------------------===//
@@ -237,38 +240,37 @@ void FuncImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
StringRef symbol, StringRef moduleName,
StringRef importName, FunctionType type) {
FuncImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
- type, {}, {}, odsBuilder.getStringAttr("nested"));
+ type, {}, {});
}
//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
-
-void GlobalOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef symbol, Type type, bool isMutable) {
- GlobalOp::build(odsBuilder, odsState, symbol, type, isMutable,
- odsBuilder.getStringAttr("nested"));
-}
-
// Custom formats
ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
StringAttr symbolName;
Type globalType;
auto *ctx = parser.getContext();
- ParseResult res = parser.parseSymbolName(
- symbolName, SymbolTable::getSymbolAttrName(), result.attributes);
+ std::string visibilityString;
+ auto loc = parser.getNameLoc();
+ ParseResult res = parser.parseOptionalKeywordOrString(&visibilityString);
+ if (res.succeeded()) {
+ if (visibilityString != "exported")
+ return parser.emitError(
+ loc, "expecting either `exported` or symbol name. got ")
+ << visibilityString;
+ result.addAttribute(getExportedAttrName(result.name), UnitAttr::get(ctx));
+ }
+ res = parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(),
+ result.attributes);
res = parser.parseType(globalType);
result.addAttribute(getTypeAttrName(result.name), TypeAttr::get(globalType));
std::string mutableString;
res = parser.parseOptionalKeywordOrString(&mutableString);
if (res.succeeded() && mutableString == "mutable")
result.addAttribute("isMutable", UnitAttr::get(ctx));
- std::string visibilityString;
- res = parser.parseOptionalKeywordOrString(&visibilityString);
- if (res.succeeded())
- result.addAttribute("sym_visibility",
- StringAttr::get(ctx, visibilityString));
+
res = parser.parseColon();
Region *globalInitRegion = result.addRegion();
res = parser.parseRegion(*globalInitRegion);
@@ -276,11 +278,11 @@ ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
}
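+// The custom syntax (parsed above, printed below) is, illustratively
+// (assuming the `wasmssa.global` mnemonic; init region elided):
+//   wasmssa.global exported @g i32 mutable : { ... }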
void GlobalOp::print(OpAsmPrinter &printer) {
+ if (getExported())
+ printer << " exported";
printer << " @" << getSymName().str() << " " << getType();
if (getIsMutable())
printer << " mutable";
- if (auto vis = getSymVisibility())
- printer << " " << *vis;
printer << " :";
Region &body = getRegion();
if (!body.empty()) {
@@ -319,13 +321,6 @@ GlobalGetOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// GlobalImportOp
//===----------------------------------------------------------------------===//
-void GlobalImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef symbol, StringRef moduleName,
- StringRef importName, Type type, bool isMutable) {
- GlobalImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
- type, isMutable, odsBuilder.getStringAttr("nested"));
-}
-
ParseResult GlobalImportOp::parse(OpAsmParser &parser, OperationState &result) {
auto *ctx = parser.getContext();
ParseResult res = parseImportOp(parser, result);
@@ -335,12 +330,8 @@ ParseResult GlobalImportOp::parse(OpAsmParser &parser, OperationState &result) {
res = parser.parseOptionalKeywordOrString(&mutableOrSymVisString);
if (res.succeeded() && mutableOrSymVisString == "mutable") {
result.addAttribute("isMutable", UnitAttr::get(ctx));
- res = parser.parseOptionalKeywordOrString(&mutableOrSymVisString);
}
- if (res.succeeded())
- result.addAttribute("sym_visibility",
- StringAttr::get(ctx, mutableOrSymVisString));
res = parser.parseColon();
Type importedType;
@@ -356,8 +347,6 @@ void GlobalImportOp::print(OpAsmPrinter &printer) {
<< "\" as @" << getSymName();
if (getIsMutable())
printer << " mutable";
- if (auto vis = getSymVisibility())
- printer << " " << *vis;
printer << " : " << getType();
}
@@ -431,27 +420,6 @@ LogicalResult LocalTeeOp::verify() {
Block *LoopOp::getLabelTarget() { return &getBody().front(); }
//===----------------------------------------------------------------------===//
-// MemOp
-//===----------------------------------------------------------------------===//
-
-void MemOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef symbol, LimitType limit) {
- MemOp::build(odsBuilder, odsState, symbol, limit,
- odsBuilder.getStringAttr("nested"));
-}
-
-//===----------------------------------------------------------------------===//
-// MemImportOp
-//===----------------------------------------------------------------------===//
-
-void MemImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef symbol, StringRef moduleName,
- StringRef importName, LimitType limits) {
- MemImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
- limits, odsBuilder.getStringAttr("nested"));
-}
-
-//===----------------------------------------------------------------------===//
// ReinterpretOp
//===----------------------------------------------------------------------===//
@@ -471,24 +439,3 @@ LogicalResult ReinterpretOp::verify() {
//===----------------------------------------------------------------------===//
void ReturnOp::build(OpBuilder &odsBuilder, OperationState &odsState) {}
-
-//===----------------------------------------------------------------------===//
-// TableOp
-//===----------------------------------------------------------------------===//
-
-void TableOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef symbol, TableType type) {
- TableOp::build(odsBuilder, odsState, symbol, type,
- odsBuilder.getStringAttr("nested"));
-}
-
-//===----------------------------------------------------------------------===//
-// TableImportOp
-//===----------------------------------------------------------------------===//
-
-void TableImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef symbol, StringRef moduleName,
- StringRef importName, TableType type) {
- TableImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
- type, odsBuilder.getStringAttr("nested"));
-}