Diffstat (limited to 'mlir/lib/Dialect'):

 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp                              | 11
 mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp            |  3
 mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp | 13
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp                                | 33
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp                                  | 52
 mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp            |  2
 mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp                    | 32
 mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp                   | 47
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp             | 10
 9 files changed, 84 insertions(+), 119 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index d5c7190..f405d0c 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -26,6 +26,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -40,6 +41,15 @@ using namespace mlir::amdgpu;
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
+namespace {
+struct AMDGPUInlinerInterface final : DialectInlinerInterface {
+ using DialectInlinerInterface::DialectInlinerInterface;
+ bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
+ return true;
+ }
+};
+} // namespace
+
void AMDGPUDialect::initialize() {
addOperations<
#define GET_OP_LIST
@@ -49,6 +59,7 @@ void AMDGPUDialect::initialize() {
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
>();
+ addInterfaces<AMDGPUInlinerInterface>();
}
//===----------------------------------------------------------------------===//
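
Note: registering AMDGPUInlinerInterface makes amdgpu ops unconditionally
legal to inline. A minimal sketch of the effect (hypothetical function names;
amdgpu.lds_barrier stands in for any amdgpu op):

    func.func @helper() {
      amdgpu.lds_barrier
      return
    }
    func.func @caller() {
      func.call @helper() : () -> ()
      return
    }

After the inliner runs, the func.call in @caller is replaced by the callee
body, so amdgpu.lds_barrier executes directly in @caller.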
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index c64e10f5..d018cdd 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -123,8 +123,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
vector::OuterProductOp, vector::ScanOp>(
[&](Operation *op) { return converter.isLegal(op); });
target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
- arith::ConstantOp, arith::SelectOp, vector::SplatOp,
- vector::BroadcastOp>();
+ arith::ConstantOp, arith::SelectOp, vector::BroadcastOp>();
}
void EmulateUnsupportedFloatsPass::runOnOperation() {
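
Note: with vector.splat removed from the vector dialect, only
vector.broadcast needs to remain on the legal-op list here. The pass itself
still emulates unsupported-float arithmetic by extending, computing, and
truncating. A sketch, assuming bf16 is configured as a source type and f32 as
the target type:

    %x = arith.extf %a : bf16 to f32
    %y = arith.extf %b : bf16 to f32
    %s = arith.addf %x, %y : f32
    %r = arith.truncf %s : f32 to bf16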
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
index a50ddbe..624519f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
@@ -55,16 +55,6 @@ static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
return returnOp;
}
-/// Return the func::FuncOp called by `callOp`.
-static func::FuncOp getCalledFunction(CallOpInterface callOp) {
- SymbolRefAttr sym =
- llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
- if (!sym)
- return nullptr;
- return dyn_cast_or_null<func::FuncOp>(
- SymbolTable::lookupNearestSymbolFrom(callOp, sym));
-}
-
LogicalResult
mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
IRRewriter rewriter(module.getContext());
@@ -72,7 +62,8 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
DenseMap<func::FuncOp, DenseSet<func::CallOp>> callerMap;
// Collect the mapping of functions to their call sites.
module.walk([&](func::CallOp callOp) {
- if (func::FuncOp calledFunc = getCalledFunction(callOp)) {
+ if (func::FuncOp calledFunc =
+ dyn_cast_or_null<func::FuncOp>(callOp.resolveCallable())) {
callerMap[calledFunc].insert(callOp);
}
});
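
Note: resolveCallable() is CallOpInterface's own symbol resolution, which is
why the local helper above is redundant. For context, the surrounding
dropEquivalentBufferResults pass removes function results that are equivalent
to a function argument and rewrites call sites (collected in callerMap above)
to use the argument directly. A minimal sketch of that rewrite:

    // before
    func.func @f(%m: memref<?xf32>) -> memref<?xf32> {
      return %m : memref<?xf32>
    }
    // after
    func.func @f(%m: memref<?xf32>) {
      return
    }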
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7f419a0..5edcc40b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1593,6 +1593,39 @@ mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
return {id, std::move(args)};
}
+mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ // Fill the Intrinsic Args: dst, mbar, src, size.
+ args.push_back(mt.lookupValue(thisOp.getDstMem()));
+ args.push_back(mt.lookupValue(thisOp.getMbar()));
+ args.push_back(mt.lookupValue(thisOp.getSrcMem()));
+ args.push_back(mt.lookupValue(thisOp.getSize()));
+
+ // Multicast mask, if available.
+ mlir::Value multicastMask = thisOp.getMulticastMask();
+ const bool hasMulticastMask = static_cast<bool>(multicastMask);
+ llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
+ args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused);
+
+ // Cache hint, if available.
+ mlir::Value cacheHint = thisOp.getL2CacheHint();
+ const bool hasCacheHint = static_cast<bool>(cacheHint);
+ llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0);
+ args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+
+ // Flag arguments for multicast and cache hint.
+ args.push_back(builder.getInt1(hasMulticastMask));
+ args.push_back(builder.getInt1(hasCacheHint));
+
+ llvm::Intrinsic::ID id =
+ llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
+
+ return {id, std::move(args)};
+}
+
mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 14e235f..a7e3ba8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1665,10 +1665,10 @@ static bool hasZeroDimVectors(Operation *op) {
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
}
-/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
-/// 1s, are considered to be 'broadcastlike'.
+/// All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are
+/// considered to be 'broadcastlike'.
static bool isBroadcastLike(Operation *op) {
- if (isa<BroadcastOp, SplatOp>(op))
+ if (isa<BroadcastOp>(op))
return true;
auto shapeCast = dyn_cast<ShapeCastOp>(op);
@@ -3249,12 +3249,11 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
};
/// Consider the defining operation `defOp` of `value`. If `defOp` is a
-/// vector.splat or a vector.broadcast with a scalar operand, return the scalar
-/// value that is splatted. Otherwise return null.
+/// vector.broadcast with a scalar operand, return the scalar value that is
+/// splatted. Otherwise return null.
///
-/// Examples:
+/// Example:
///
-/// scalar_source --> vector.splat --> value - return scalar_source
/// scalar_source --> vector.broadcast --> value - return scalar_source
static Value getScalarSplatSource(Value value) {
// Block argument:
@@ -3262,10 +3261,6 @@ static Value getScalarSplatSource(Value value) {
if (!defOp)
return {};
- // Splat:
- if (auto splat = dyn_cast<vector::SplatOp>(defOp))
- return splat.getInput();
-
auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
// Not broadcast (and not splat):
@@ -7511,41 +7506,6 @@ void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
patterns.getContext(), benefit);
}
-//===----------------------------------------------------------------------===//
-// SplatOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
- auto constOperand = adaptor.getInput();
- if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
- return {};
-
- // SplatElementsAttr::get treats single value for second arg as being a splat.
- return SplatElementsAttr::get(getType(), {constOperand});
-}
-
-// Canonicalizer for vector.splat. It always gets canonicalized to a
-// vector.broadcast.
-class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> {
-public:
- using Base::Base;
- LogicalResult matchAndRewrite(SplatOp splatOp,
- PatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
- splatOp.getOperand());
- return success();
- }
-};
-void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<SplatToBroadcastPattern>(context);
-}
-
-void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
- SetIntRangeFn setResultRanges) {
- setResultRanges(getResult(), argRanges.front());
-}
-
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
CombiningKind kind, Value v1, Value acc,
arith::FastMathFlagsAttr fastmath,
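
Note: with the fold, canonicalization, and range-inference hooks above deleted
along with vector.splat itself, scalar-to-vector duplication is now spelled
exclusively as vector.broadcast. A before/after sketch of the now-retired
canonicalization:

    // previously accepted:
    %v = vector.splat %s : vector<4x2xf32>
    // canonical form, and now the only form:
    %v = vector.broadcast %s : f32 to vector<4x2xf32>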
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 255f2bf..3a3231d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -90,7 +90,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
Operation *maskOp = mask.getDefiningOp();
SmallVector<vector::ExtractOp, 2> extractOps;
- // TODO: add support to `vector.splat`.
+ // TODO: add support to `vector.broadcast`.
// Finding the mask creation operation.
while (maskOp &&
!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 71fba71c..1b656d8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -590,32 +590,6 @@ struct LinearizeVectorBitCast final
}
};
-/// This pattern converts the SplatOp to work on a linearized vector.
-/// Following,
-/// vector.splat %value : vector<4x4xf32>
-/// is converted to:
-/// %out_1d = vector.splat %value : vector<16xf32>
-/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
-struct LinearizeVectorSplat final
- : public OpConversionPattern<vector::SplatOp> {
- using Base::Base;
-
- LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
- PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit) {}
-
- LogicalResult
- matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto dstTy = getTypeConverter()->convertType(splatOp.getType());
- if (!dstTy)
- return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
- rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
- dstTy);
- return success();
- }
-};
-
/// This pattern converts the CreateMaskOp to work on a linearized vector.
/// It currently supports only 2D masks with a unit outer dimension.
/// Following,
@@ -934,9 +908,9 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
RewritePatternSet &patterns) {
patterns
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
- LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
- LinearizeVectorStore, LinearizeVectorFromElements,
- LinearizeVectorToElements>(typeConverter, patterns.getContext());
+ LinearizeVectorCreateMask, LinearizeVectorLoad, LinearizeVectorStore,
+ LinearizeVectorFromElements, LinearizeVectorToElements>(
+ typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
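
Note: the deleted LinearizeVectorSplat followed the same contract as the
remaining linearization patterns: rewrite the op on a flattened 1-D vector,
then restore the original shape with a shape_cast. A sketch of that contract,
adapted from the removed pattern's documentation with vector.broadcast in
place of vector.splat:

    %out_1d = vector.broadcast %value : f32 to vector<16xf32>
    %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>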
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index d6a6d7cd..726da1e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -878,7 +878,7 @@ struct BubbleUpBitCastForStridedSliceInsert
// This transforms IR like:
// %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
// Into:
-// %cst = vector.splat %c0_f32 : vector<4xf32>
+// %cst = vector.broadcast %c0_f32 : f32 to vector<4xf32>
// %1 = vector.extract_strided_slice %0 {
// offsets = [0], sizes = [4], strides = [1]
// } : vector<8xf16> to vector<4xf16>
@@ -987,8 +987,8 @@ static Type cloneOrReplace(Type type, Type newElementType) {
return newElementType;
}
-/// If `value` is the result of a splat or broadcast operation, return the input
-/// of the splat/broadcast operation.
+/// If `value` is the result of a broadcast operation, return the input
+/// of the broadcast operation.
static Value getBroadcastLikeSource(Value value) {
Operation *op = value.getDefiningOp();
@@ -998,13 +998,10 @@ static Value getBroadcastLikeSource(Value value) {
if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
return broadcast.getSource();
- if (auto splat = dyn_cast<vector::SplatOp>(op))
- return splat.getInput();
-
return {};
}
-/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
+/// Reorders elementwise(broadcast) to broadcast(elementwise).
///
/// Example:
/// ```
@@ -1017,9 +1014,6 @@ static Value getBroadcastLikeSource(Value value) {
/// %r = arith.addi %arg0, %arg1 : index
/// %b = vector.broadcast %r : index to vector<1x4xindex>
/// ```
-///
-/// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
-/// ops.
struct ReorderElementwiseOpsOnBroadcast final
: public OpTraitRewritePattern<OpTrait::Elementwise> {
using OpTraitRewritePattern::OpTraitRewritePattern;
@@ -1045,29 +1039,29 @@ struct ReorderElementwiseOpsOnBroadcast final
Type resultElemType = resultType.getElementType();
// Get the type of the first non-constant operand
- Value splatSource;
+ Value broadcastSource;
for (Value operand : op->getOperands()) {
Operation *definingOp = operand.getDefiningOp();
if (!definingOp)
return failure();
if (definingOp->hasTrait<OpTrait::ConstantLike>())
continue;
- splatSource = getBroadcastLikeSource(operand);
+ broadcastSource = getBroadcastLikeSource(operand);
break;
}
- if (!splatSource)
+ if (!broadcastSource)
return failure();
Type unbroadcastResultType =
- cloneOrReplace(splatSource.getType(), resultElemType);
+ cloneOrReplace(broadcastSource.getType(), resultElemType);
// Make sure that all operands are broadcast from identically-shaped types:
- // * scalar (`vector.broadcast` + `vector.splat`), or
+ // * scalar (`vector.broadcast`), or
// * vector (`vector.broadcast`).
// Otherwise the re-ordering wouldn't be safe.
- if (!llvm::all_of(op->getOperands(), [splatSource](Value val) {
+ if (!llvm::all_of(op->getOperands(), [broadcastSource](Value val) {
if (auto source = getBroadcastLikeSource(val))
return haveSameShapeAndScaling(source.getType(),
- splatSource.getType());
+ broadcastSource.getType());
SplatElementsAttr splatConst;
return matchPattern(val, m_Constant(&splatConst));
})) {
@@ -1271,19 +1265,18 @@ public:
}
};
-/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
+/// Pattern to rewrite vector.store(vector.broadcast) -> vector/memref.store.
///
/// Example:
/// ```
-/// %0 = vector.splat %arg2 : vector<1xf32>
+/// %0 = vector.broadcast %arg2 : f32 to vector<1xf32>
/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
/// ```
/// Gets converted to:
/// ```
/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
/// ```
-class StoreOpFromSplatOrBroadcast final
- : public OpRewritePattern<vector::StoreOp> {
+class StoreOpFromBroadcast final : public OpRewritePattern<vector::StoreOp> {
public:
using Base::Base;
@@ -1308,9 +1301,9 @@ public:
return rewriter.notifyMatchFailure(
op, "value to store is not from a broadcast");
- // Checking for single use so we can remove splat.
- Operation *splat = toStore.getDefiningOp();
- if (!splat->hasOneUse())
+ // Checking for single use so we can remove broadcast.
+ Operation *broadcast = toStore.getDefiningOp();
+ if (!broadcast->hasOneUse())
return rewriter.notifyMatchFailure(op, "expected single op use");
Value base = op.getBase();
@@ -1321,7 +1314,7 @@ public:
} else {
rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
}
- rewriter.eraseOp(splat);
+ rewriter.eraseOp(broadcast);
return success();
}
};
@@ -2391,8 +2384,8 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
PatternBenefit benefit) {
// TODO: Consider converting these patterns to canonicalizations.
- patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(
- patterns.getContext(), benefit);
+ patterns.add<ExtractOpFromLoad, StoreOpFromBroadcast>(patterns.getContext(),
+ benefit);
}
void mlir::vector::populateChainedVectorReductionFoldingPatterns(
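
Note: the "before" half of the ReorderElementwiseOpsOnBroadcast example falls
outside the hunk shown above. A reconstructed sketch using the same
%arg0/%arg1 names as the doc comment:

    // before: elementwise op on broadcast operands
    %a = vector.broadcast %arg0 : index to vector<1x4xindex>
    %b = vector.broadcast %arg1 : index to vector<1x4xindex>
    %r = arith.addi %a, %b : vector<1x4xindex>
    // after: broadcast of the scalar elementwise op
    %r = arith.addi %arg0, %arg1 : index
    %b = vector.broadcast %r : index to vector<1x4xindex>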
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index f1dbc5d..26770b3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -195,8 +195,7 @@ static bool requireTranspose(const xegpu::LayoutAttr layout,
/// }
/// return %0
/// }
-struct MoveFuncBodyToWarpExecuteOnLane0
- : public OpRewritePattern<gpu::GPUFuncOp> {
+struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
PatternRewriter &rewriter) const override {
@@ -1447,6 +1446,11 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
/*pattern benefit=*/highPatternBenefit);
}
+void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<MoveFuncBodyToWarpOp>(patterns.getContext());
+}
+
void XeGPUSubgroupDistributePass::runOnOperation() {
// Step 1: Attach layouts to op operands.
// TODO: Following assumptions are made:
@@ -1473,7 +1477,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
// gpu.warp_execute_on_lane_0 operation.
{
RewritePatternSet patterns(&getContext());
- patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
+ xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();