Diffstat (limited to 'mlir/lib/Conversion')
-rw-r--r--  mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp           177
-rw-r--r--  mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp             2
-rw-r--r--  mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp    54
-rw-r--r--  mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp               2
-rw-r--r--  mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp                 9
-rw-r--r--  mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp           2
-rw-r--r--  mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp                      64
-rw-r--r--  mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp        4
-rw-r--r--  mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp             2
-rw-r--r--  mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp                 4
10 files changed, 256 insertions, 64 deletions
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 478b6aa..3a307a0 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -935,7 +935,7 @@ static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
.Case([](Float6E2M3FNType) { return 2u; })
.Case([](Float6E3M2FNType) { return 3u; })
.Case([](Float4E2M1FNType) { return 4u; })
- .Default([](Type) { return std::nullopt; });
+ .Default(std::nullopt);
}
/// If there is a scaled MFMA instruction for the input element types `aType`
@@ -989,21 +989,17 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
smfma.getN(), smfma.getK(), 1u, chipset);
}
-/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
-/// if one exists. This includes checking to ensure the intrinsic is supported
-/// on the architecture you are compiling for.
-static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
- Chipset chipset) {
- auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
- auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
- auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
- Type elemSourceType = sourceVectorType.getElementType();
- Type elemBSourceType = sourceBVectorType.getElementType();
- Type elemDestType = destVectorType.getElementType();
-
- const uint32_t k = wmma.getK();
+/// Returns the `rocdl` intrinsic corresponding to a WMMA operation with the
+/// given element types and `k` dimension on RDNA3/RDNA4 architectures.
+static std::optional<StringRef>
+wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
+ Type elemDestType, uint32_t k, bool isRDNA3) {
+ using fp8 = Float8E4M3FNType;
+ using bf8 = Float8E5M2Type;
+ // Handle k == 16 for RDNA3/4.
if (k == 16) {
+ // Common patterns for RDNA3 and RDNA4.
if (elemSourceType.isF16() && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
if (elemSourceType.isBF16() && elemDestType.isF32())
@@ -1014,40 +1010,161 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
- if (chipset.majorVersion == 11) {
+
+ // RDNA3 specific patterns.
+ if (isRDNA3) {
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+ return std::nullopt;
}
- }
- if (chipset.majorVersion < 12)
- return std::nullopt;
- // gfx12+
- if (k == 16) {
- if (isa<Float8E4M3FNType>(elemSourceType) &&
- isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+ // RDNA4 specific patterns (fp8/bf8).
+ if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
+ elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
- if (isa<Float8E4M3FNType>(elemSourceType) &&
- isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+ if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
+ elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
- if (isa<Float8E5M2Type>(elemSourceType) &&
- isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+ if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
+ elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
- if (isa<Float8E5M2Type>(elemSourceType) &&
- isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+ if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
+ elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
return std::nullopt;
}
- if (k == 32) {
+
+ // Handle k == 32 for RDNA4.
+ if (k == 32 && !isRDNA3) {
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+ }
+
+ return std::nullopt;
+}
+
+/// Returns the `rocdl` intrinsic corresponding to a WMMA operation with the
+/// given element types and `k` dimension on the gfx1250 architecture.
+static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
+ Type elemBSourceType,
+ Type elemDestType,
+ uint32_t k) {
+ using fp8 = Float8E4M3FNType;
+ using bf8 = Float8E5M2Type;
+
+ if (k == 4) {
+ if (elemSourceType.isF32() && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
+
+ return std::nullopt;
+ }
+
+ if (k == 32) {
+ if (elemSourceType.isF16() && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
+ if (elemSourceType.isBF16() && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
+ if (elemSourceType.isF16() && elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
+ if (elemSourceType.isBF16() && elemDestType.isBF16())
+ return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
+
return std::nullopt;
}
- llvm_unreachable("unhandled WMMA case");
+ if (k == 64) {
+ if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
+ }
+ if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
+ }
+ if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
+ }
+ if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
+ }
+ if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
+ return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
+
+ return std::nullopt;
+ }
+
+ if (k == 128) {
+ if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
+ }
+ if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
+ }
+ if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
+ }
+ if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
+ }
+
+ return std::nullopt;
+ }
+
+ return std::nullopt;
+}
+
+/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
+/// if one exists. This includes checking that the intrinsic is supported on
+/// the target architecture.
+static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
+ Chipset chipset) {
+ auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
+ auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
+ auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
+ Type elemSourceType = sourceVectorType.getElementType();
+ Type elemBSourceType = sourceBVectorType.getElementType();
+ Type elemDestType = destVectorType.getElementType();
+
+ const uint32_t k = wmma.getK();
+ const bool isRDNA3 = chipset.majorVersion == 11;
+ const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
+
+ // Handle RDNA3 and RDNA4.
+ if (isRDNA3 || isRDNA4)
+ return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType,
+ k, isRDNA3);
+
+ // Handle gfx1250.
+ if (chipset == Chipset{12, 5, 0})
+ return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
+ elemDestType, k);
+
+ return std::nullopt;
}
namespace {
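Taken together, the refactor turns the old single function into a two-level dispatch: first on chipset family (RDNA3, RDNA4, gfx1250), then on element types and the `k` dimension. Below is a condensed, standalone sketch of that control flow; the `Chipset` struct is a simplified stand-in, and the two per-family helpers are placeholders that return descriptive strings instead of ROCDL operation names. Illustrative only, not the patch's code.

#include <cstdint>
#include <iostream>
#include <optional>
#include <string>

// Simplified stand-in for amdgpu::Chipset (same field names).
struct Chipset {
  uint32_t majorVersion, minorVersion, steppingVersion;
  bool operator==(const Chipset &o) const {
    return majorVersion == o.majorVersion && minorVersion == o.minorVersion &&
           steppingVersion == o.steppingVersion;
  }
};

// Placeholder for wmmaOpToIntrinsicRDNA: the real helper also inspects the
// source/dest element types; only the k/family skeleton is shown.
static std::optional<std::string> rdnaIntrinsic(uint32_t k, bool isRDNA3) {
  if (k == 16)
    return std::string("16x16x16 intrinsics") +
           (isRDNA3 ? " (plus iu4)" : " (plus fp8/bf8 and iu4)");
  if (k == 32 && !isRDNA3)
    return std::string("rocdl.wmma.i32.16x16x32.iu4");
  return std::nullopt;
}

// Placeholder for wmmaOpToIntrinsicGfx1250 (k in {4, 32, 64, 128}).
static std::optional<std::string> gfx1250Intrinsic(uint32_t k) {
  if (k == 4 || k == 32 || k == 64 || k == 128)
    return "16x16x" + std::to_string(k) + " intrinsics";
  return std::nullopt;
}

// Mirrors the dispatch in the new wmmaOpToIntrinsic.
static std::optional<std::string> selectIntrinsic(Chipset c, uint32_t k) {
  const bool isRDNA3 = c.majorVersion == 11;
  const bool isRDNA4 = c.majorVersion == 12 && c.minorVersion == 0;
  if (isRDNA3 || isRDNA4)
    return rdnaIntrinsic(k, isRDNA3);
  if (c == Chipset{12, 5, 0}) // gfx1250
    return gfx1250Intrinsic(k);
  return std::nullopt; // Unknown chipset: the lowering reports no match.
}

int main() {
  std::cout << selectIntrinsic({11, 0, 0}, 16).value_or("none") << "\n";
  std::cout << selectIntrinsic({12, 5, 0}, 64).value_or("none") << "\n";
  std::cout << selectIntrinsic({12, 5, 0}, 8).value_or("none") << "\n";
}

Note that returning std::nullopt for an unknown chipset replaces the old llvm_unreachable, so unsupported targets now fail the match cleanly instead of asserting.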
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 247dba1..cfdcd9c 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -432,7 +432,7 @@ static Value getOriginalVectorValue(Value value) {
current = op.getSource();
return false;
})
- .Default([](Operation *) { return false; });
+ .Default(false);
if (!skipOp) {
break;
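This one-line change, like the matching `.Default` changes in the other files of this patch, relies on `llvm::TypeSwitch::Default` accepting a plain value as well as a callable; the value form drops the boilerplate lambda. A minimal standalone sketch of the pattern, using a hypothetical Shape hierarchy rather than MLIR types:

#include "llvm/ADT/TypeSwitch.h"
#include <cstdio>
#include <optional>

// Hypothetical LLVM-style RTTI hierarchy, only for illustration.
struct Shape {
  enum Kind { KCircle, KSquare } kind;
  Shape(Kind k) : kind(k) {}
};
struct Circle : Shape {
  Circle() : Shape(KCircle) {}
  static bool classof(const Shape *s) { return s->kind == KCircle; }
};
struct Square : Shape {
  Square() : Shape(KSquare) {}
  static bool classof(const Shape *s) { return s->kind == KSquare; }
};

static std::optional<unsigned> selectCode(Shape *s) {
  return llvm::TypeSwitch<Shape *, std::optional<unsigned>>(s)
      .Case([](Circle *) { return 0u; })
      .Case([](Square *) { return 1u; })
      // Value form of Default: equivalent to
      // .Default([](Shape *) { return std::nullopt; }).
      .Default(std::nullopt);
}

int main() {
  Circle c;
  std::printf("%u\n", selectCode(&c).value_or(99u));
}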
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 0fe7239..9e46b7d 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -313,25 +313,53 @@ private:
struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
+ // exp(x+I*y) = exp(x)*(cos(y)+I*sin(y))
+ // Handle special cases as the StableHLO implementation does:
+ // 1. When y == 0, set imag(exp(z)) = 0
+ // 2. When exp(x) == inf, use exp(x/2)*(cos(y)+I*sin(y))*exp(x/2)
LogicalResult
matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto type = cast<ComplexType>(adaptor.getComplex().getType());
- auto elementType = cast<FloatType>(type.getElementType());
- arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
-
- Value real =
- complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
- Value imag =
- complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
- Value expReal = math::ExpOp::create(rewriter, loc, real, fmf.getValue());
- Value cosImag = math::CosOp::create(rewriter, loc, imag, fmf.getValue());
+ auto ET = cast<FloatType>(type.getElementType());
+ arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
+ const auto &floatSemantics = ET.getFloatSemantics();
+ ImplicitLocOpBuilder b(loc, rewriter);
+
+ Value x = complex::ReOp::create(b, ET, adaptor.getComplex());
+ Value y = complex::ImOp::create(b, ET, adaptor.getComplex());
+ Value zero = arith::ConstantOp::create(b, ET, b.getZeroAttr(ET));
+ Value half = arith::ConstantOp::create(b, ET, b.getFloatAttr(ET, 0.5));
+ Value inf = arith::ConstantOp::create(
+ b, ET, b.getFloatAttr(ET, APFloat::getInf(floatSemantics)));
+
+ Value exp = math::ExpOp::create(b, x, fmf);
+ Value xHalf = arith::MulFOp::create(b, x, half, fmf);
+ Value expHalf = math::ExpOp::create(b, xHalf, fmf);
+ Value cos = math::CosOp::create(b, y, fmf);
+ Value sin = math::SinOp::create(b, y, fmf);
+
+ Value expIsInf =
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, exp, inf, fmf);
+ Value yIsZero =
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, y, zero);
+
+ // Real part: select between exp(x)*cos(y) and exp(x/2)*cos(y)*exp(x/2).
+ Value realNormal = arith::MulFOp::create(b, exp, cos, fmf);
+ Value expHalfCos = arith::MulFOp::create(b, expHalf, cos, fmf);
+ Value realOverflow = arith::MulFOp::create(b, expHalfCos, expHalf, fmf);
Value resultReal =
- arith::MulFOp::create(rewriter, loc, expReal, cosImag, fmf.getValue());
- Value sinImag = math::SinOp::create(rewriter, loc, imag, fmf.getValue());
- Value resultImag =
- arith::MulFOp::create(rewriter, loc, expReal, sinImag, fmf.getValue());
+ arith::SelectOp::create(b, expIsInf, realOverflow, realNormal);
+
+ // Imaginary part: if y == 0, return 0; otherwise select between
+ // exp(x)*sin(y) and exp(x/2)*sin(y)*exp(x/2).
+ Value imagNormal = arith::MulFOp::create(b, exp, sin, fmf);
+ Value expHalfSin = arith::MulFOp::create(b, expHalf, sin, fmf);
+ Value imagOverflow = arith::MulFOp::create(b, expHalfSin, expHalf, fmf);
+ Value imagNonZero =
+ arith::SelectOp::create(b, expIsInf, imagOverflow, imagNormal);
+ Value resultImag = arith::SelectOp::create(b, yIsZero, zero, imagNonZero);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
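A quick numeric illustration of why the lowering needs both special cases; plain C++ doubles stand in for the lowered scalar ops (demo only, not part of the patch):

#include <cmath>
#include <cstdio>

int main() {
  double x = 710.0;              // exp(710) overflows a double.
  double y = 1.5707963267948966; // cos(y) ~ 6.1e-17, so the true value of
                                 // exp(x)*cos(y) is still representable.
  // Naive form: exp(x) is already inf, poisoning the product.
  double naiveRe = std::exp(x) * std::cos(y); // inf
  // Split form used by the lowering: exp(x/2)*cos(y)*exp(x/2).
  double splitRe =
      std::exp(x / 2) * std::cos(y) * std::exp(x / 2); // finite, ~1.4e292
  std::printf("naive %g, split %g\n", naiveRe, splitRe);

  // y == 0 case: exp(x)*sin(0) is inf*0 == NaN, which is why the lowering
  // selects an exact 0 for the imaginary part instead.
  std::printf("naive imag at y==0: %g\n", std::exp(x) * std::sin(0.0));
}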
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index 25f1e1b..425594b 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -259,7 +259,7 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
}
return std::nullopt;
})
- .Default([](auto) { return std::nullopt; });
+ .Default(std::nullopt);
}
static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index a9efada..ec182f1 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -846,13 +846,8 @@ struct NVGPUMBarrierInitLowering
Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Value count = truncToI32(b, adaptor.getCount());
- if (isMbarrierShared(mbarrierType)) {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
- op, barrier, count, adaptor.getPredicate());
- } else {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
- adaptor.getPredicate());
- }
+ rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
+ adaptor.getPredicate());
return success();
}
};
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index b711e33..a4c66e1 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -692,7 +692,7 @@ SymbolRefAttr PatternLowering::generateRewriter(
llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue);
args.append(mappedArgs.begin(), mappedArgs.end());
pdl_interp::ApplyRewriteOp::create(builder, rewriter.getLoc(),
- /*resultTypes=*/TypeRange(), rewriteName,
+ /*results=*/TypeRange(), rewriteName,
args);
} else {
// Otherwise this is a dag rewriter defined using PDL operations.
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
index 7d0a236..76a822b 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -14,6 +14,7 @@
#include "mlir/Conversion/SCFToGPU/SCFToGPU.h"
+#include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -27,6 +28,7 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/DebugLog.h"
#include <optional>
@@ -625,18 +627,49 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
bool seenSideeffects = false;
// Whether we have left a nesting scope (and hence are no longer innermost).
bool leftNestingScope = false;
+ LocalAliasAnalysis aliasAnalysis;
+ llvm::DenseSet<Value> writtenBuffer;
while (!worklist.empty()) {
Operation *op = worklist.pop_back_val();
// Now walk over the body and clone it.
// TODO: This is only correct if there either is no further scf.parallel
- // nested or this code is side-effect free. Otherwise we might need
- // predication. We are overly conservative for now and only allow
- // side-effects in the innermost scope.
+ //       nested or this code has side effects only on buffers that do not
+ //       alias any buffer accessed by the nested loop. Otherwise we might
+ //       need predication.
if (auto nestedParallel = dyn_cast<ParallelOp>(op)) {
// Before entering a nested scope, make sure there have been no
- // sideeffects until now.
- if (seenSideeffects)
- return failure();
+ // sideeffects until now, or that the nested operations do not access any
+ // buffer written by the outer scope.
+ if (seenSideeffects) {
+ WalkResult walkRes = nestedParallel.walk([&](Operation *nestedOp) {
+ if (isMemoryEffectFree(nestedOp))
+ return WalkResult::advance();
+
+ auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(nestedOp);
+ if (!memEffectInterface)
+ return WalkResult::advance();
+
+ SmallVector<MemoryEffects::EffectInstance> effects;
+ memEffectInterface.getEffects(effects);
+ for (const MemoryEffects::EffectInstance &effect : effects) {
+ if (isa<MemoryEffects::Read>(effect.getEffect()) ||
+ isa<MemoryEffects::Write>(effect.getEffect())) {
+ Value baseBuffer = effect.getValue();
+ if (!baseBuffer)
+ return WalkResult::interrupt();
+ for (Value val : writtenBuffer) {
+ if (aliasAnalysis.alias(baseBuffer, val) !=
+ AliasResult::NoAlias) {
+ return WalkResult::interrupt();
+ }
+ }
+ }
+ }
+ return WalkResult::advance();
+ });
+ if (walkRes.wasInterrupted())
+ return failure();
+ }
// A nested scf.parallel needs insertion of code to compute indices.
// Insert that now. This will also update the worklist with the loops
// body.
@@ -650,6 +683,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
rewriter.setInsertionPointAfter(parent);
leftNestingScope = true;
seenSideeffects = false;
+ writtenBuffer.clear();
} else if (auto reduceOp = dyn_cast<scf::ReduceOp>(op)) {
// Convert scf.reduction op
auto parentLoop = op->getParentOfType<ParallelOp>();
@@ -682,6 +716,24 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
Operation *clone = rewriter.clone(*op, cloningMap);
cloningMap.map(op->getResults(), clone->getResults());
// Check for side effects.
+ if (!isMemoryEffectFree(clone)) {
+ // Record the buffer accessed by the operations with write effects.
+ if (auto memEffectInterface =
+ dyn_cast<MemoryEffectOpInterface>(clone)) {
+ SmallVector<MemoryEffects::EffectInstance> effects;
+ memEffectInterface.getEffects(effects);
+ for (const MemoryEffects::EffectInstance &effect : effects) {
+ if (isa<MemoryEffects::Write>(effect.getEffect())) {
+ Value writtenBase = effect.getValue();
+ // Conservatively return failure if we cannot find the written
+ // address.
+ if (!writtenBase)
+ return failure();
+ writtenBuffer.insert(writtenBase);
+ }
+ }
+ }
+ }
// TODO: Handle region side effects properly.
seenSideeffects |=
!isMemoryEffectFree(clone) || clone->getNumRegions() != 0;
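The nested-loop gate and the write-set bookkeeping above combine into one question: may any read or write inside the nested parallel op alias a buffer the outer scope already wrote? A hedged sketch of that check factored into a standalone helper (hypothetical name mayTouchWrittenBuffers; it uses only APIs the patch itself calls):

#include "mlir/Analysis/AliasAnalysis.h"
#include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

/// Returns true if some op under `root` may read or write a buffer that
/// aliases one of the values in `written`, or has an effect on an unknown
/// value (which we conservatively treat as aliasing).
static bool mayTouchWrittenBuffers(Operation *root,
                                   const llvm::DenseSet<Value> &written,
                                   LocalAliasAnalysis &aliasAnalysis) {
  WalkResult res = root->walk([&](Operation *op) {
    if (isMemoryEffectFree(op))
      return WalkResult::advance();
    // Mirrors the patch: ops without the effect interface are skipped.
    auto iface = dyn_cast<MemoryEffectOpInterface>(op);
    if (!iface)
      return WalkResult::advance();
    SmallVector<MemoryEffects::EffectInstance> effects;
    iface.getEffects(effects);
    for (const MemoryEffects::EffectInstance &effect : effects) {
      if (!isa<MemoryEffects::Read>(effect.getEffect()) &&
          !isa<MemoryEffects::Write>(effect.getEffect()))
        continue;
      Value buffer = effect.getValue();
      if (!buffer) // Effect on an unknown value: assume the worst.
        return WalkResult::interrupt();
      for (Value w : written)
        if (aliasAnalysis.alias(buffer, w) != AliasResult::NoAlias)
          return WalkResult::interrupt();
    }
    return WalkResult::advance();
  });
  return res.wasInterrupted();
}

With such a helper, the gate above would read: if (seenSideeffects && mayTouchWrittenBuffers(nestedParallel.getOperation(), writtenBuffer, aliasAnalysis)) return failure();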
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 41d8d53..69a317ec 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -716,7 +716,7 @@ lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
llvmType, accumulator);
return LLVMRedIntrinOp::create(rewriter, loc, llvmType,
- /*startValue=*/accumulator, vectorOperand,
+ /*start_value=*/accumulator, vectorOperand,
fmf);
}
@@ -743,7 +743,7 @@ static Value lowerPredicatedReductionWithStartValue(
Value vectorLength =
createVectorLengthValue(rewriter, loc, vectorOperand.getType());
return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
- /*startValue=*/accumulator, vectorOperand,
+ /*start_value=*/accumulator, vectorOperand,
mask, vectorLength);
}
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index e2c7d80..91c1aa5 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -46,7 +46,7 @@ static bool isZeroConstant(Value val) {
[](auto floatAttr) { return floatAttr.getValue().isZero(); })
.Case<IntegerAttr>(
[](auto intAttr) { return intAttr.getValue().isZero(); })
- .Default([](auto) { return false; });
+ .Default(false);
}
static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index fcbf66d..33e8f2e 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -194,8 +194,8 @@ class CreateNdDescToXeVMPattern
// If source is a memref, we need to extract the aligned pointer as index.
// Pointer type is passed as i32 or i64 by type converter.
if (sourceMemrefTy) {
- if (!sourceMemrefTy.hasStaticShape()) {
- return rewriter.notifyMatchFailure(op, "Expected static memref shape.");
+ if (!sourceMemrefTy.hasRank()) {
+ return rewriter.notifyMatchFailure(op, "Expected ranked Memref.");
}
baseAddr =
memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
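The relaxed precondition means dynamically shaped (but still ranked) memrefs are now accepted; only unranked memrefs are rejected. A small standalone check of which types pass, assuming nothing beyond the builtin type API:

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include <cstdio>

int main() {
  mlir::MLIRContext ctx;
  mlir::Type f16 = mlir::Float16Type::get(&ctx);
  // memref<8x16xf16>: static shape and ranked -> accepted before and after.
  auto staticTy = mlir::MemRefType::get({8, 16}, f16);
  // memref<?x16xf16>: dynamic dim but ranked -> newly accepted.
  auto dynTy = mlir::MemRefType::get({mlir::ShapedType::kDynamic, 16}, f16);
  // memref<*xf16>: unranked -> still rejected.
  auto unrankedTy = mlir::UnrankedMemRefType::get(f16, mlir::Attribute());
  std::printf("static:   hasStaticShape=%d hasRank=%d\n",
              staticTy.hasStaticShape(), staticTy.hasRank());
  std::printf("dynamic:  hasStaticShape=%d hasRank=%d\n",
              dynTy.hasStaticShape(), dynTy.hasRank());
  std::printf("unranked: hasStaticShape=%d hasRank=%d\n",
              unrankedTy.hasStaticShape(), unrankedTy.hasRank());
}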