Diffstat (limited to 'mlir/lib/Conversion')
10 files changed, 256 insertions, 64 deletions
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 478b6aa..3a307a0 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -935,7 +935,7 @@ static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
       .Case([](Float6E2M3FNType) { return 2u; })
       .Case([](Float6E3M2FNType) { return 3u; })
       .Case([](Float4E2M1FNType) { return 4u; })
-      .Default([](Type) { return std::nullopt; });
+      .Default(std::nullopt);
 }
 
 /// If there is a scaled MFMA instruction for the input element types `aType`
@@ -989,21 +989,17 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
                                  smfma.getN(), smfma.getK(), 1u, chipset);
 }
 
-/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
-/// if one exists. This includes checking to ensure the intrinsic is supported
-/// on the architecture you are compiling for.
-static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
-                                                  Chipset chipset) {
-  auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
-  auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
-  auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
-  Type elemSourceType = sourceVectorType.getElementType();
-  Type elemBSourceType = sourceBVectorType.getElementType();
-  Type elemDestType = destVectorType.getElementType();
-
-  const uint32_t k = wmma.getK();
+/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
+/// for RDNA3/4 architectures.
+static std::optional<StringRef>
+wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
+                      Type elemDestType, uint32_t k, bool isRDNA3) {
+  using fp8 = Float8E4M3FNType;
+  using bf8 = Float8E5M2Type;
+
+  // Handle k == 16 for RDNA3/4.
   if (k == 16) {
+    // Common patterns for RDNA3 and RDNA4.
     if (elemSourceType.isF16() && elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
     if (elemSourceType.isBF16() && elemDestType.isF32())
@@ -1014,40 +1010,161 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
       return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
     if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
       return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
-    if (chipset.majorVersion == 11) {
+
+    // RDNA3 specific patterns.
+    if (isRDNA3) {
       if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
         return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+      return std::nullopt;
     }
-  }
-  if (chipset.majorVersion < 12)
-    return std::nullopt;
-  // gfx12+
-  if (k == 16) {
-    if (isa<Float8E4M3FNType>(elemSourceType) &&
-        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+
+    // RDNA4 specific patterns (fp8/bf8).
+    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
+        elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
-    if (isa<Float8E4M3FNType>(elemSourceType) &&
-        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
+        elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
-    if (isa<Float8E5M2Type>(elemSourceType) &&
-        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
+        elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
-    if (isa<Float8E5M2Type>(elemSourceType) &&
-        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
+        elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
     if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
       return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
     return std::nullopt;
   }
-  if (k == 32) {
+
+  // Handle k == 32 for RDNA4.
+  if (k == 32 && !isRDNA3) {
     if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
       return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+  }
+
+  return std::nullopt;
+}
+
+/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
+/// for the gfx1250 architecture.
+static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
+                                                         Type elemBSourceType,
+                                                         Type elemDestType,
+                                                         uint32_t k) {
+  using fp8 = Float8E4M3FNType;
+  using bf8 = Float8E5M2Type;
+
+  if (k == 4) {
+    if (elemSourceType.isF32() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
+
+    return std::nullopt;
+  }
+
+  if (k == 32) {
+    if (elemSourceType.isF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
+    if (elemSourceType.isF16() && elemDestType.isF16())
+      return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isBF16())
+      return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
+    return std::nullopt;
   }
-  llvm_unreachable("unhandled WMMA case");
+
+  if (k == 64) {
+    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
+    }
+    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
+    }
+    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
+    }
+    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
+    }
+    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
+
+    return std::nullopt;
+  }
+
+  if (k == 128) {
+    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
+    }
+    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
+    }
+    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
+    }
+    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
+    }
+
+    return std::nullopt;
+  }
+
+  return std::nullopt;
+}
+
+/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
+/// if one exists. This includes checking to ensure the intrinsic is supported
+/// on the architecture you are compiling for.
+static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
+                                                  Chipset chipset) {
+  auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
+  auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
+  auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
+  Type elemSourceType = sourceVectorType.getElementType();
+  Type elemBSourceType = sourceBVectorType.getElementType();
+  Type elemDestType = destVectorType.getElementType();
+
+  const uint32_t k = wmma.getK();
+  const bool isRDNA3 = chipset.majorVersion == 11;
+  const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
+
+  // Handle RDNA3 and RDNA4.
+  if (isRDNA3 || isRDNA4)
+    return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType,
+                                 k, isRDNA3);
+
+  // Handle gfx1250.
+  if (chipset == Chipset{12, 5, 0})
+    return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
+                                    elemDestType, k);
+
+  return std::nullopt;
 }
 
 namespace {
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 247dba1..cfdcd9c 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -432,7 +432,7 @@ static Value getOriginalVectorValue(Value value) {
               current = op.getSource();
               return false;
             })
-            .Default([](Operation *) { return false; });
+            .Default(false);
 
     if (!skipOp) {
       break;
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 0fe7239..9e46b7d 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -313,25 +313,53 @@ private:
 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
   using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
 
+  // exp(x+I*y) = exp(x)*(cos(y)+I*sin(y))
+  // Handle special cases as the StableHLO implementation does:
+  // 1. When y == 0, set imag(exp(z)) = 0
+  // 2. When exp(x) == inf, use exp(x/2)*(cos(y)+I*sin(y))*exp(x/2)
   LogicalResult
   matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
     auto type = cast<ComplexType>(adaptor.getComplex().getType());
-    auto elementType = cast<FloatType>(type.getElementType());
-    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
-
-    Value real =
-        complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
-    Value imag =
-        complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
-    Value expReal = math::ExpOp::create(rewriter, loc, real, fmf.getValue());
-    Value cosImag = math::CosOp::create(rewriter, loc, imag, fmf.getValue());
+    auto ET = cast<FloatType>(type.getElementType());
+    arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
+    const auto &floatSemantics = ET.getFloatSemantics();
+    ImplicitLocOpBuilder b(loc, rewriter);
+
+    Value x = complex::ReOp::create(b, ET, adaptor.getComplex());
+    Value y = complex::ImOp::create(b, ET, adaptor.getComplex());
+    Value zero = arith::ConstantOp::create(b, ET, b.getZeroAttr(ET));
+    Value half = arith::ConstantOp::create(b, ET, b.getFloatAttr(ET, 0.5));
+    Value inf = arith::ConstantOp::create(
+        b, ET, b.getFloatAttr(ET, APFloat::getInf(floatSemantics)));
+
+    Value exp = math::ExpOp::create(b, x, fmf);
+    Value xHalf = arith::MulFOp::create(b, x, half, fmf);
+    Value expHalf = math::ExpOp::create(b, xHalf, fmf);
+    Value cos = math::CosOp::create(b, y, fmf);
+    Value sin = math::SinOp::create(b, y, fmf);
+
+    Value expIsInf =
+        arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, exp, inf, fmf);
+    Value yIsZero =
+        arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, y, zero);
+
+    // Real path: select between exp(x)*cos(y) and exp(x/2)*cos(y)*exp(x/2)
+    Value realNormal = arith::MulFOp::create(b, exp, cos, fmf);
+    Value expHalfCos = arith::MulFOp::create(b, expHalf, cos, fmf);
+    Value realOverflow = arith::MulFOp::create(b, expHalfCos, expHalf, fmf);
     Value resultReal =
-        arith::MulFOp::create(rewriter, loc, expReal, cosImag, fmf.getValue());
-    Value sinImag = math::SinOp::create(rewriter, loc, imag, fmf.getValue());
-    Value resultImag =
-        arith::MulFOp::create(rewriter, loc, expReal, sinImag, fmf.getValue());
+        arith::SelectOp::create(b, expIsInf, realOverflow, realNormal);
+
+    // Imaginary part: if y == 0 return 0 else select between exp(x)*sin(y) and
+    // exp(x/2)*sin(y)*exp(x/2)
+    Value imagNormal = arith::MulFOp::create(b, exp, sin, fmf);
+    Value expHalfSin = arith::MulFOp::create(b, expHalf, sin, fmf);
+    Value imagOverflow = arith::MulFOp::create(b, expHalfSin, expHalf, fmf);
+    Value imagNonZero =
+        arith::SelectOp::create(b, expIsInf, imagOverflow, imagNormal);
+    Value resultImag = arith::SelectOp::create(b, yIsZero, zero, imagNonZero);
 
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                    resultImag);
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index 25f1e1b..425594b 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -259,7 +259,7 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
           }
           return std::nullopt;
         })
-        .Default([](auto) { return std::nullopt; });
+        .Default(std::nullopt);
   }
 
   static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index a9efada..ec182f1 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -846,13 +846,8 @@ struct NVGPUMBarrierInitLowering
     Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
                                    adaptor.getMbarId(), rewriter);
     Value count = truncToI32(b, adaptor.getCount());
-    if (isMbarrierShared(mbarrierType)) {
-      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
-          op, barrier, count, adaptor.getPredicate());
-    } else {
-      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
-                                                        adaptor.getPredicate());
-    }
+    rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
+                                                      adaptor.getPredicate());
     return success();
   }
 };
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index b711e33..a4c66e1 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -692,7 +692,7 @@ SymbolRefAttr PatternLowering::generateRewriter(
         llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue);
     args.append(mappedArgs.begin(), mappedArgs.end());
     pdl_interp::ApplyRewriteOp::create(builder, rewriter.getLoc(),
-                                       /*resultTypes=*/TypeRange(), rewriteName,
+                                       /*results=*/TypeRange(), rewriteName,
                                        args);
   } else {
     // Otherwise this is a dag rewriter defined using PDL operations.
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
index 7d0a236..76a822b 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Conversion/SCFToGPU/SCFToGPU.h"
 
+#include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -27,6 +28,7 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/Support/DebugLog.h"
 
 #include <optional>
@@ -625,18 +627,49 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
   bool seenSideeffects = false;
   // Whether we have left a nesting scope (and hence are no longer innermost).
   bool leftNestingScope = false;
+  LocalAliasAnalysis aliasAnalysis;
+  llvm::DenseSet<Value> writtenBuffer;
   while (!worklist.empty()) {
     Operation *op = worklist.pop_back_val();
     // Now walk over the body and clone it.
    // TODO: This is only correct if there either is no further scf.parallel
-    //       nested or this code is side-effect free. Otherwise we might need
-    //       predication. We are overly conservative for now and only allow
-    //       side-effects in the innermost scope.
+    //       nested or this code has side effects but the memory buffers it
+    //       writes do not alias the buffers accessed by the inner loop.
+    //       Otherwise we might need predication.
    if (auto nestedParallel = dyn_cast<ParallelOp>(op)) {
       // Before entering a nested scope, make sure there have been no
-      // sideeffects until now.
-      if (seenSideeffects)
-        return failure();
+      // sideeffects until now, or that the nested operations do not access
+      // the buffers written in the outer scope.
+      if (seenSideeffects) {
+        WalkResult walkRes = nestedParallel.walk([&](Operation *nestedOp) {
+          if (isMemoryEffectFree(nestedOp))
+            return WalkResult::advance();
+
+          auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(nestedOp);
+          if (!memEffectInterface)
+            return WalkResult::advance();
+
+          SmallVector<MemoryEffects::EffectInstance> effects;
+          memEffectInterface.getEffects(effects);
+          for (const MemoryEffects::EffectInstance &effect : effects) {
+            if (isa<MemoryEffects::Read>(effect.getEffect()) ||
+                isa<MemoryEffects::Write>(effect.getEffect())) {
+              Value baseBuffer = effect.getValue();
+              if (!baseBuffer)
+                return WalkResult::interrupt();
+              for (Value val : writtenBuffer) {
+                if (aliasAnalysis.alias(baseBuffer, val) !=
+                    AliasResult::NoAlias) {
+                  return WalkResult::interrupt();
+                }
+              }
+            }
+          }
+          return WalkResult::advance();
+        });
+        if (walkRes.wasInterrupted())
+          return failure();
+      }
       // A nested scf.parallel needs insertion of code to compute indices.
       // Insert that now. This will also update the worklist with the loops
       // body.
@@ -650,6 +683,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
       rewriter.setInsertionPointAfter(parent);
       leftNestingScope = true;
       seenSideeffects = false;
+      writtenBuffer.clear();
     } else if (auto reduceOp = dyn_cast<scf::ReduceOp>(op)) {
       // Convert scf.reduction op
       auto parentLoop = op->getParentOfType<ParallelOp>();
@@ -682,6 +716,24 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
       Operation *clone = rewriter.clone(*op, cloningMap);
       cloningMap.map(op->getResults(), clone->getResults());
       // Check for side effects.
+      if (!isMemoryEffectFree(clone)) {
+        // Record the buffer accessed by the operations with write effects.
+        if (auto memEffectInterface =
+                dyn_cast<MemoryEffectOpInterface>(clone)) {
+          SmallVector<MemoryEffects::EffectInstance> effects;
+          memEffectInterface.getEffects(effects);
+          for (const MemoryEffects::EffectInstance &effect : effects) {
+            if (isa<MemoryEffects::Write>(effect.getEffect())) {
+              Value writtenBase = effect.getValue();
+              // Conservatively return failure if we cannot find the written
+              // address.
+              if (!writtenBase)
+                return failure();
+              writtenBuffer.insert(writtenBase);
+            }
+          }
+        }
+      }
       // TODO: Handle region side effects properly.
       seenSideeffects |=
           !isMemoryEffectFree(clone) || clone->getNumRegions() != 0;
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 41d8d53..69a317ec 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -716,7 +716,7 @@ lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
   accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                          llvmType, accumulator);
   return LLVMRedIntrinOp::create(rewriter, loc, llvmType,
-                                 /*startValue=*/accumulator, vectorOperand,
+                                 /*start_value=*/accumulator, vectorOperand,
                                  fmf);
 }
 
@@ -743,7 +743,7 @@ static Value lowerPredicatedReductionWithStartValue(
   Value vectorLength =
       createVectorLengthValue(rewriter, loc, vectorOperand.getType());
   return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
-                                   /*startValue=*/accumulator, vectorOperand,
+                                   /*start_value=*/accumulator, vectorOperand,
                                    mask, vectorLength);
 }
 
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index e2c7d80..91c1aa5 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -46,7 +46,7 @@ static bool isZeroConstant(Value val) {
               [](auto floatAttr) { return floatAttr.getValue().isZero(); })
           .Case<IntegerAttr>(
               [](auto intAttr) { return intAttr.getValue().isZero(); })
-          .Default([](auto) { return false; });
+          .Default(false);
 }
 
 static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index fcbf66d..33e8f2e 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -194,8 +194,8 @@ class CreateNdDescToXeVMPattern
     // If source is a memref, we need to extract the aligned pointer as index.
     // Pointer type is passed as i32 or i64 by type converter.
     if (sourceMemrefTy) {
-      if (!sourceMemrefTy.hasStaticShape()) {
-        return rewriter.notifyMatchFailure(op, "Expected static memref shape.");
+      if (!sourceMemrefTy.hasRank()) {
+        return rewriter.notifyMatchFailure(op, "Expected ranked Memref.");
       }
       baseAddr =
           memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
