Diffstat (limited to 'mlir/lib/Conversion')
37 files changed, 490 insertions, 495 deletions
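[Editor's note] Most of the mechanical churn in this diff is a single migration: call sites of the deprecated builder member `rewriter.create<OpTy>(loc, ...)` are rewritten to the static entry point `OpTy::create(rewriter, loc, ...)`. A minimal, self-contained C++ sketch of the two shapes follows; the mock OpBuilder/WaitDscntOp types are stand-ins, not the real MLIR classes.

#include <iostream>
#include <utility>

struct OpBuilder {
  // Deprecated shape: builder.create<OpTy>(args...) forwards to the op.
  template <typename OpTy, typename... Args>
  OpTy create(Args &&...args) {
    return OpTy::create(*this, std::forward<Args>(args)...);
  }
};

struct WaitDscntOp {
  int count;
  // New shape: OpTy::create(builder, args...) is the single entry point.
  static WaitDscntOp create(OpBuilder &, int count) { return {count}; }
};

int main() {
  OpBuilder rewriter;
  WaitDscntOp oldStyle = rewriter.create<WaitDscntOp>(0);  // form removed below
  WaitDscntOp newStyle = WaitDscntOp::create(rewriter, 0); // form introduced below
  std::cout << oldStyle.count + newStyle.count << "\n";
}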
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index bc0d9bf..64720bf 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -232,8 +232,8 @@ struct FatRawBufferCastLowering Value result = MemRefDescriptor::poison( rewriter, loc, getTypeConverter()->convertType(op.getResult().getType())); - result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, - kAllocatedPtrPosInMemRefDescriptor); + SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor}; + result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos); result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor); result = LLVM::InsertValueOp::create(rewriter, loc, result, offset, @@ -481,16 +481,16 @@ struct MemoryCounterWaitOpLowering if (chipset.majorVersion >= 12) { Location loc = op.getLoc(); if (std::optional<int> ds = adaptor.getDs()) - rewriter.create<ROCDL::WaitDscntOp>(loc, *ds); + ROCDL::WaitDscntOp::create(rewriter, loc, *ds); if (std::optional<int> load = adaptor.getLoad()) - rewriter.create<ROCDL::WaitLoadcntOp>(loc, *load); + ROCDL::WaitLoadcntOp::create(rewriter, loc, *load); if (std::optional<int> store = adaptor.getStore()) - rewriter.create<ROCDL::WaitStorecntOp>(loc, *store); + ROCDL::WaitStorecntOp::create(rewriter, loc, *store); if (std::optional<int> exp = adaptor.getExp()) - rewriter.create<ROCDL::WaitExpcntOp>(loc, *exp); + ROCDL::WaitExpcntOp::create(rewriter, loc, *exp); rewriter.eraseOp(op); return success(); diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index 8c68b57..8230591 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -449,7 +449,7 @@ LogicalResult ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op, PatternRewriter &rewriter) const { Location loc = op.getLoc(); - constexpr int64_t opWidth = 2; + constexpr int64_t opOutWidth = 2; Value in = op.getIn(); Value scale = op.getScale(); @@ -460,6 +460,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op, Type scaleType = getElementTypeOrSelf(scale); Type outType = getElementTypeOrSelf(out); + int64_t opInWidth = 32 / inType.getIntOrFloatBitWidth(); + VectorType outVecType = dyn_cast<VectorType>(out.getType()); VectorType scaleVecType = dyn_cast<VectorType>(scale.getType()); @@ -473,7 +475,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op, else if (scaleType.getIntOrFloatBitWidth() > 32) scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale); - VectorType extScaleResultType = VectorType::get(opWidth, outType); + VectorType extScaleResultType = VectorType::get(opOutWidth, outType); if (!outVecType) { Value inCast = vector::BroadcastOp::create(rewriter, loc, @@ -487,10 +489,11 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op, VectorType inVecType = cast<VectorType>(in.getType()); Value origScale = getOriginalVectorValue(op.getScale()); + VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType()); ArrayRef<int64_t> inShape = inVecType.getShape(); SmallVector<int64_t> originalScaleShape; - if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType())) + if (origScaleVecType) llvm::append_range(originalScaleShape, origScaleVecType.getShape()); originalScaleShape.insert(originalScaleShape.end(), @@ 
-524,19 +527,26 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op, Value blockResult = rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero); - for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i); + for (int64_t i = 0, inSliceWidth = std::min(opInWidth, blockSize - i); i < blockSize; - i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) { - Value slice = vector::ExtractStridedSliceOp::create( - rewriter, loc, block1D, i, sliceWidth, 1); - // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1 - Value scaleExt = amdgpu::ScaledExtPackedOp::create( - rewriter, loc, extScaleResultType, slice, uniformScale, 0); - if (sliceWidth != opWidth) - scaleExt = vector::ExtractStridedSliceOp::create( - rewriter, loc, scaleExt, 0, sliceWidth, 1); - blockResult = vector::InsertStridedSliceOp::create( - rewriter, loc, scaleExt, blockResult, i, 1); + i += inSliceWidth, inSliceWidth = std::min(opInWidth, blockSize - i)) { + Value inSlice = vector::ExtractStridedSliceOp::create( + rewriter, loc, block1D, i, inSliceWidth, 1); + for (int64_t j = 0, + outSliceWidth = std::min(opOutWidth, inSliceWidth - j); + j < inSliceWidth; j += outSliceWidth, + outSliceWidth = std::min(opOutWidth, inSliceWidth - j)) { + // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1 + Value scaleExt = amdgpu::ScaledExtPackedOp::create( + rewriter, loc, extScaleResultType, inSlice, uniformScale, + j / opOutWidth); + if (outSliceWidth < opOutWidth) { + scaleExt = vector::ExtractStridedSliceOp::create( + rewriter, loc, scaleExt, 0, outSliceWidth, 1); + } + blockResult = vector::InsertStridedSliceOp::create( + rewriter, loc, scaleExt, blockResult, i + j, 1); + } } VectorType resultType = VectorType::get(ratio, outType); @@ -555,7 +565,7 @@ LogicalResult ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, PatternRewriter &rewriter) const { Location loc = op.getLoc(); - constexpr int64_t opWidth = 2; + constexpr int64_t opInWidth = 2; Value in = op.getIn(); Value scale = op.getScale(); @@ -568,7 +578,6 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, VectorType outVecType = dyn_cast<VectorType>(out.getType()); VectorType scaleVecType = dyn_cast<VectorType>(scale.getType()); - if (outVecType && outVecType.isScalable()) return failure(); @@ -581,8 +590,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, Value zero = arith::ConstantOp::create(rewriter, loc, outType, rewriter.getFloatAttr(outType, 0.0)); - unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth(); - VectorType truncScaleResultType = VectorType::get(numPackedElem, outType); + int64_t opOutWidth = 32 / outType.getIntOrFloatBitWidth(); + VectorType truncScaleResultType = VectorType::get(opOutWidth, outType); if (!outVecType) { Type inVecType = VectorType::get(1, inType); @@ -598,16 +607,16 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, VectorType inVecType = cast<VectorType>(in.getType()); Value origScale = getOriginalVectorValue(op.getScale()); + VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType()); ArrayRef<int64_t> inShape = inVecType.getShape(); - SmallVector<int64_t> originalScaleShape; - if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType())) - llvm::append_range(originalScaleShape, origScaleVecType.getShape()); + SmallVector<int64_t> scaleShape; + if (origScaleVecType) + llvm::append_range(scaleShape, origScaleVecType.getShape()); - 
originalScaleShape.insert(originalScaleShape.end(), - inShape.size() - originalScaleShape.size(), 1); + scaleShape.insert(scaleShape.end(), inShape.size() - scaleShape.size(), 1); - auto maybeRatio = computeShapeRatio(inShape, originalScaleShape); + auto maybeRatio = computeShapeRatio(inShape, scaleShape); assert(maybeRatio && "failed to derive block size from broadcast or splat operation"); @@ -633,20 +642,36 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, Value blockResult = rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero); - for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i); - i < blockSize; - i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) { - Value slice = vector::ExtractStridedSliceOp::create( - rewriter, loc, block1D, i, sliceWidth, 1); - // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1 - Value scaleTrunc = amdgpu::PackedScaledTruncOp::create( - rewriter, loc, truncScaleResultType, slice, uniformScale, 0, - /*existing=*/nullptr); - int64_t packedWidth = - cast<VectorType>(scaleTrunc.getType()).getNumElements(); - if (packedWidth != opWidth) + for (int64_t i = 0, outSliceWidth = std::min(opOutWidth, blockSize - i); + i < blockSize; i += outSliceWidth, + outSliceWidth = std::min(opOutWidth, blockSize - i)) { + Value scaleTrunc; + // Case where <= 2 elements are being truncated. + if (outSliceWidth <= opInWidth) { + Value slice = vector::ExtractStridedSliceOp::create( + rewriter, loc, block1D, i, outSliceWidth, 1); + // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1 + scaleTrunc = amdgpu::PackedScaledTruncOp::create( + rewriter, loc, truncScaleResultType, slice, uniformScale, 0, + /*existing=*/nullptr); + } else { + scaleTrunc = vector::BroadcastOp::create(rewriter, loc, + truncScaleResultType, zero); + for (int64_t j = 0, + inSliceWidth = std::min(opInWidth, outSliceWidth - j); + j < outSliceWidth; j += opInWidth, + inSliceWidth = std::min(opInWidth, outSliceWidth - j)) { + Value slice = vector::ExtractStridedSliceOp::create( + rewriter, loc, block1D, i + j, inSliceWidth, 1); + scaleTrunc = amdgpu::PackedScaledTruncOp::create( + rewriter, loc, truncScaleResultType, slice, uniformScale, + j / opInWidth, scaleTrunc); + } + } + if (outSliceWidth != opOutWidth) { scaleTrunc = vector::ExtractStridedSliceOp::create( - rewriter, loc, scaleTrunc, 0, sliceWidth, 1); + rewriter, loc, scaleTrunc, 0, outSliceWidth, 1); + } blockResult = vector::InsertStridedSliceOp::create( rewriter, loc, scaleTrunc, blockResult, i, 1); } diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 59b3fe2..515fe5c 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -402,8 +402,8 @@ public: Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType); // Actual cast (may change bitwidth) - auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(), - castDestType, actualOp); + auto cast = + emitc::CastOp::create(rewriter, op.getLoc(), castDestType, actualOp); // Cast to the expected output type auto result = adaptValueType(cast, rewriter, opReturnType); @@ -507,8 +507,8 @@ public: Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType); - Value arithmeticResult = rewriter.template create<EmitCOp>( - op.getLoc(), arithmeticType, lhs, rhs); + Value 
arithmeticResult = + EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs); Value result = adaptValueType(arithmeticResult, rewriter, type); @@ -547,8 +547,8 @@ public: Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType); - Value arithmeticResult = rewriter.template create<EmitCOp>( - op.getLoc(), arithmeticType, lhs, rhs); + Value arithmeticResult = + EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs); Value result = adaptValueType(arithmeticResult, rewriter, type); @@ -748,8 +748,8 @@ public: } Value fpCastOperand = adaptor.getIn(); if (actualOperandType != operandType) { - fpCastOperand = rewriter.template create<emitc::CastOp>( - castOp.getLoc(), actualOperandType, fpCastOperand); + fpCastOperand = emitc::CastOp::create(rewriter, castOp.getLoc(), + actualOperandType, fpCastOperand); } rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand); diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp index 1510b0b..e34b368 100644 --- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp +++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp @@ -12,7 +12,6 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp index 79e1683..29e6552 100644 --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -18,7 +18,6 @@ #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp index 30a7170..3edcbb8 100644 --- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp +++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp @@ -68,9 +68,8 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> { scf::YieldOp::create(rewriter, loc, acc); }; - auto size = rewriter - .create<scf::ForOp>(loc, zero, rank, one, ValueRange(one), - loopBody) + auto size = scf::ForOp::create(rewriter, loc, zero, rank, one, + ValueRange(one), loopBody) .getResult(0); MemRefType memrefType = MemRefType::get({ShapedType::kDynamic}, diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index f84375b..785cb82 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -43,7 +43,7 @@ add_subdirectory(MathToSPIRV) add_subdirectory(MemRefToEmitC) add_subdirectory(MemRefToLLVM) add_subdirectory(MemRefToSPIRV) -add_subdirectory(MeshToMPI) +add_subdirectory(ShardToMPI) add_subdirectory(MPIToLLVM) add_subdirectory(NVGPUToNVVM) add_subdirectory(NVVMToLLVM) diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index 6f0fc29..35ad99c 100644 --- 
a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -64,10 +64,46 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns( patterns.getContext(), "__ocml_cabs_f32"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>( patterns.getContext(), "__ocml_cabs_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>( + patterns.getContext(), "__ocml_carg_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>( + patterns.getContext(), "__ocml_carg_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>( + patterns.getContext(), "__ocml_conj_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>( + patterns.getContext(), "__ocml_conj_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>( + patterns.getContext(), "__ocml_ccos_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>( + patterns.getContext(), "__ocml_ccos_f64"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>( patterns.getContext(), "__ocml_cexp_f32"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>( patterns.getContext(), "__ocml_cexp_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float32Type>>( + patterns.getContext(), "__ocml_clog_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>( + patterns.getContext(), "__ocml_clog_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>( + patterns.getContext(), "__ocml_cpow_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>( + patterns.getContext(), "__ocml_cpow_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>( + patterns.getContext(), "__ocml_csin_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>( + patterns.getContext(), "__ocml_csin_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float32Type>>( + patterns.getContext(), "__ocml_csqrt_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float64Type>>( + patterns.getContext(), "__ocml_csqrt_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float32Type>>( + patterns.getContext(), "__ocml_ctan_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float64Type>>( + patterns.getContext(), "__ocml_ctan_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float32Type>>( + patterns.getContext(), "__ocml_ctanh_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float64Type>>( + patterns.getContext(), "__ocml_ctanh_f64"); } namespace { @@ -86,7 +122,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() { ConversionTarget target(getContext()); target.addLegalDialect<func::FuncDialect>(); - target.addIllegalOp<complex::AbsOp, complex::ExpOp>(); + target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp, + complex::CosOp, complex::ExpOp, complex::LogOp, + complex::PowOp, complex::SinOp, complex::SqrtOp, + complex::TanOp, complex::TanhOp>(); if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index eeff8a9..5ad514d 100644 --- 
a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -12,7 +12,6 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" #include <type_traits> diff --git a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp index c8311eb..5ac838c 100644 --- a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp +++ b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp @@ -144,12 +144,11 @@ ControlFlowToSCFTransformation::createUnreachableTerminator(Location loc, return emitError(loc, "Cannot create unreachable terminator for '") << parentOp->getName() << "'"; - return builder - .create<func::ReturnOp>( - loc, llvm::map_to_vector(funcOp.getResultTypes(), - [&](Type type) { - return getUndefValue(loc, builder, type); - })) + return func::ReturnOp::create( + builder, loc, + llvm::map_to_vector( + funcOp.getResultTypes(), + [&](Type type) { return getUndefValue(loc, builder, type); })) .getOperation(); } diff --git a/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp b/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp index c9b1dc1..ee6d7d5 100644 --- a/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp +++ b/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp @@ -9,8 +9,6 @@ #include "mlir/Conversion/ConvertToEmitC/ConvertToEmitCPass.h" #include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h" -#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp index 252245d..c70b5f0 100644 --- a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp +++ b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp @@ -9,7 +9,6 @@ #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" -#include "llvm/ADT/DenseSet.h" using namespace mlir; diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index 63eb6c58..3cfbd89 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -579,8 +579,8 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder, auto function = [&] { if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName)) return function; - return OpBuilder::atBlockEnd(module.getBody()) - .create<LLVM::LLVMFuncOp>(loc, functionName, functionType); + auto builder = OpBuilder::atBlockEnd(module.getBody()); + return LLVM::LLVMFuncOp::create(builder, loc, functionName, functionType); }(); return LLVM::CallOp::create(builder, loc, function, arguments); } diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index a19194e..1817861 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -507,25 +507,27 @@ LogicalResult GPURotateConversion::matchAndRewrite( getTypeConverter<SPIRVTypeConverter>()->getTargetEnv(); unsigned 
subgroupSize = targetEnv.getAttr().getResourceLimits().getSubgroupSize(); - IntegerAttr widthAttr; - if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) || - widthAttr.getValue().getZExtValue() > subgroupSize) + unsigned width = rotateOp.getWidth(); + if (width > subgroupSize) return rewriter.notifyMatchFailure( - rotateOp, - "rotate width is not a constant or larger than target subgroup size"); + rotateOp, "rotate width is larger than target subgroup size"); Location loc = rotateOp.getLoc(); auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup); + Value offsetVal = + arith::ConstantOp::create(rewriter, loc, adaptor.getOffsetAttr()); + Value widthVal = + arith::ConstantOp::create(rewriter, loc, adaptor.getWidthAttr()); Value rotateResult = spirv::GroupNonUniformRotateKHROp::create( - rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset(), - adaptor.getWidth()); + rewriter, loc, scope, adaptor.getValue(), offsetVal, widthVal); Value validVal; - if (widthAttr.getValue().getZExtValue() == subgroupSize) { + if (width == subgroupSize) { validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter); } else { + IntegerAttr widthAttr = adaptor.getWidthAttr(); Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr); validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult, - laneId, adaptor.getWidth()); + laneId, widthVal); } rewriter.replaceOp(rotateOp, {rotateResult, validVal}); @@ -559,8 +561,8 @@ static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc, builder, loc, builder.getI32Type(), builder.getIntegerAttr(builder.getI32Type(), *clusterSize)); - return builder - .create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue) + return NonUniformOp::create(builder, loc, type, scope, groupOp, arg, + clusterSizeValue) .getResult(); } diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index ecd5b63..2568044 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -272,14 +272,13 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( // Allocate memory, copy, and free the source if necessary. Value memory = - toDynamic - ? builder - .create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize) - .getResult() - : LLVM::AllocaOp::create(builder, loc, getPtrType(), - IntegerType::get(getContext(), 8), - allocationSize, - /*alignment=*/0); + toDynamic ? 
LLVM::CallOp::create(builder, loc, mallocFunc.value(), + allocationSize) + .getResult() + : LLVM::AllocaOp::create(builder, loc, getPtrType(), + IntegerType::get(getContext(), 8), + allocationSize, + /*alignment=*/0); Value source = desc.memRefDescPtr(builder, loc); LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false); if (!toDynamic) diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp index 5b68eb8..e5496e5 100644 --- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp +++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp @@ -35,7 +35,7 @@ static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc, if (!(ret = moduleOp.lookupSymbol<Op>(name))) { ConversionPatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(moduleOp.getBody()); - ret = rewriter.template create<Op>(loc, std::forward<Args>(args)...); + ret = Op::create(rewriter, loc, std::forward<Args>(args)...); } return ret; } diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp index 08a4566..855c582 100644 --- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp +++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp @@ -17,7 +17,6 @@ #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -699,7 +698,8 @@ static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { scf::IfOp ifOp = scf::IfOp::create(builder, elementType, inputEqZero, /*addThenBlock=*/true, /*addElseBlock=*/true); - ifOp.getThenBodyBuilder().create<scf::YieldOp>(loc, bitWidthValue); + auto thenBuilder = ifOp.getThenBodyBuilder(); + scf::YieldOp::create(thenBuilder, loc, bitWidthValue); auto elseBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, &ifOp.getElseRegion().front()); diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index e882845..6bd0e2d 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -19,10 +19,18 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeRange.h" +#include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" +#include <cstdint> using namespace mlir; +static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) { + return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() && + memRefType.getRank() != 0 && + !llvm::is_contained(memRefType.getShape(), 0); +} + namespace { /// Implement the interface to convert MemRef to EmitC. 
struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface { @@ -89,6 +97,68 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) { return resultTy; } +struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = allocOp.getLoc(); + MemRefType memrefType = allocOp.getType(); + if (!isMemRefTypeLegalForEmitC(memrefType)) { + return rewriter.notifyMatchFailure( + loc, "incompatible memref type for EmitC conversion"); + } + + Type sizeTType = emitc::SizeTType::get(rewriter.getContext()); + Type elementType = memrefType.getElementType(); + IndexType indexType = rewriter.getIndexType(); + emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>( + loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{}, + ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)})); + + int64_t numElements = 1; + for (int64_t dimSize : memrefType.getShape()) { + numElements *= dimSize; + } + Value numElementsValue = rewriter.create<emitc::ConstantOp>( + loc, indexType, rewriter.getIndexAttr(numElements)); + + Value totalSizeBytes = rewriter.create<emitc::MulOp>( + loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue); + + emitc::CallOpaqueOp allocCall; + StringAttr allocFunctionName; + Value alignmentValue; + SmallVector<Value, 2> argsVec; + if (allocOp.getAlignment()) { + allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName); + alignmentValue = rewriter.create<emitc::ConstantOp>( + loc, sizeTType, + rewriter.getIntegerAttr(indexType, + allocOp.getAlignment().value_or(0))); + argsVec.push_back(alignmentValue); + } else { + allocFunctionName = rewriter.getStringAttr(mallocFunctionName); + } + + argsVec.push_back(totalSizeBytes); + ValueRange args(argsVec); + + allocCall = rewriter.create<emitc::CallOpaqueOp>( + loc, + emitc::PointerType::get( + emitc::OpaqueType::get(rewriter.getContext(), "void")), + allocFunctionName, args); + + emitc::PointerType targetPointerType = emitc::PointerType::get(elementType); + emitc::CastOp castOp = rewriter.create<emitc::CastOp>( + loc, targetPointerType, allocCall.getResult(0)); + + rewriter.replaceOp(allocOp, castOp); + return success(); + } +}; + struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> { using OpConversionPattern::OpConversionPattern; @@ -223,9 +293,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> { void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { typeConverter.addConversion( [&](MemRefType memRefType) -> std::optional<Type> { - if (!memRefType.hasStaticShape() || - !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0 || - llvm::is_contained(memRefType.getShape(), 0)) { + if (!isMemRefTypeLegalForEmitC(memRefType)) { return {}; } Type convertedElementType = @@ -252,6 +320,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCConversionPatterns( RewritePatternSet &patterns, const TypeConverter &converter) { - patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad, - ConvertStore>(converter, patterns.getContext()); + patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal, + ConvertLoad, ConvertStore>(converter, patterns.getContext()); } diff --git 
a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp index cf25c09..e78dd76 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp @@ -15,6 +15,7 @@ #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Attributes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -28,9 +29,11 @@ using namespace mlir; namespace { struct ConvertMemRefToEmitCPass : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> { + using Base::Base; void runOnOperation() override { TypeConverter converter; - + ConvertMemRefToEmitCOptions options; + options.lowerToCpp = this->lowerToCpp; // Fallback for other types. converter.addConversion([](Type type) -> std::optional<Type> { if (!emitc::isSupportedEmitCType(type)) @@ -50,6 +53,37 @@ struct ConvertMemRefToEmitCPass if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); + + mlir::ModuleOp module = getOperation(); + module.walk([&](mlir::emitc::CallOpaqueOp callOp) { + if (callOp.getCallee() != alignedAllocFunctionName && + callOp.getCallee() != mallocFunctionName) { + return mlir::WalkResult::advance(); + } + + for (auto &op : *module.getBody()) { + emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op); + if (!includeOp) { + continue; + } + if (includeOp.getIsStandardInclude() && + ((options.lowerToCpp && + includeOp.getInclude() == cppStandardLibraryHeader) || + (!options.lowerToCpp && + includeOp.getInclude() == cStandardLibraryHeader))) { + return mlir::WalkResult::interrupt(); + } + } + + mlir::OpBuilder builder(module.getBody(), module.getBody()->begin()); + StringAttr includeAttr = + builder.getStringAttr(options.lowerToCpp ? 
cppStandardLibraryHeader + : cStandardLibraryHeader); + builder.create<mlir::emitc::IncludeOp>( + module.getLoc(), includeAttr, + /*is_standard_include=*/builder.getUnitAttr()); + return mlir::WalkResult::interrupt(); + }); } }; } // namespace diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 53a1912..6ba5bfe4 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -575,8 +575,8 @@ private: Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr, idxPlusOne); - return rewriter - .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr) + return LLVM::LoadOp::create(rewriter, loc, + getTypeConverter()->getIndexType(), sizePtr) .getResult(); } diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 905287e1..5d13353 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -21,7 +21,6 @@ #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp index 3e434ea..5bd1d49 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -49,7 +49,7 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList, assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type"); predList.emplace_back(pos, builder.getIsNotNull()); - if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) { + if (auto attr = val.getDefiningOp<pdl::AttributeOp>()) { // If the attribute has a type or value, add a constraint. if (Value type = attr.getValueType()) getTreePredicates(predList, type, builder, inputs, builder.getType(pos)); diff --git a/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp b/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp index e1a9fa59..2d9c661f 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp @@ -14,9 +14,7 @@ #include "RootOrdering.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SmallVector.h" -#include <queue> #include <utility> using namespace mlir; diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp index 240491a..807be7e 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -582,6 +582,7 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp, // block. This should be reconsidered if we allow break/continue in SCF. 
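// [Editor's note on the WhileLowering change just below] condOp.getArgs()
// yields a non-owning view of the terminator's operands, and
// replaceOpWithNewOp erases that terminator, so the values must be copied
// into owned storage (the new SmallVector) before the replacement; the copy
// is what rewriter.replaceOp(whileOp, args) later consumes. A self-contained
// C++ analogue of the hazard; std::span stands in for the operand range and
// the names are hypothetical:
#include <iostream>
#include <span>
#include <vector>

int main() {
  std::vector<int> operands{1, 2, 3};
  std::span<const int> args(operands); // non-owning, like condOp.getArgs()
  std::vector<int> copy(args.begin(), args.end()); // copy *before* erasure
  operands.clear(); // analogous to erasing the op that owns the operands
  // `args` now dangles; `copy` remains valid for the later replacement.
  for (int v : copy)
    std::cout << v << ' ';
  std::cout << '\n';
}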
rewriter.setInsertionPointToEnd(before); auto condOp = cast<ConditionOp>(before->getTerminator()); + SmallVector<Value> args = llvm::to_vector(condOp.getArgs()); rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(), after, condOp.getArgs(), continuation, ValueRange()); @@ -593,7 +594,7 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp, // Replace the op with values "yielded" from the "before" region, which are // visible by dominance. - rewriter.replaceOp(whileOp, condOp.getArgs()); + rewriter.replaceOp(whileOp, args); return success(); } diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp index f191f35..badd2f6 100644 --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -25,9 +25,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/IRMapping.h" #include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/Support/Debug.h" #include <optional> diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index aae3271..9b61540 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -1493,11 +1493,11 @@ public: Value extended; if (op2TypeWidth < dstTypeWidth) { if (isUnsignedIntegerOrVector(op2Type)) { - extended = rewriter.template create<LLVM::ZExtOp>( - loc, dstType, adaptor.getOperand2()); + extended = + LLVM::ZExtOp::create(rewriter, loc, dstType, adaptor.getOperand2()); } else { - extended = rewriter.template create<LLVM::SExtOp>( - loc, dstType, adaptor.getOperand2()); + extended = + LLVM::SExtOp::create(rewriter, loc, dstType, adaptor.getOperand2()); } } else if (op2TypeWidth == dstTypeWidth) { extended = adaptor.getOperand2(); @@ -1505,8 +1505,8 @@ public: return failure(); } - Value result = rewriter.template create<LLVMOp>( - loc, dstType, adaptor.getOperand1(), extended); + Value result = + LLVMOp::create(rewriter, loc, dstType, adaptor.getOperand1(), extended); rewriter.replaceOp(op, result); return success(); } diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp index 7025c5a..0ff9fb3 100644 --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -14,7 +14,6 @@ #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/STLExtras.h" diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt index 15560aa..564f36f 100644 --- a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt +++ b/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt @@ -1,8 +1,8 @@ -add_mlir_conversion_library(MLIRMeshToMPI - MeshToMPI.cpp +add_mlir_conversion_library(MLIRShardToMPI + ShardToMPI.cpp ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MeshToMPI + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ShardToMPI DEPENDS MLIRConversionPassIncGen @@ -17,7 +17,7 @@ add_mlir_conversion_library(MLIRMeshToMPI MLIRLinalgTransforms MLIRMemRefDialect MLIRPass - MLIRMeshDialect + MLIRShardDialect MLIRMPIDialect MLIRTransforms ) diff --git 
a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp index 63b1fda..fd40e7c 100644 --- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp +++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp @@ -1,4 +1,4 @@ -//===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===// +//===- ShardToMPI.cpp - Shard to MPI dialect conversion -----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,11 +6,11 @@ // //===----------------------------------------------------------------------===// // -// This file implements a translation of Mesh communication ops tp MPI ops. +// This file implements a translation of Shard communication ops to MPI ops. // //===----------------------------------------------------------------------===// -#include "mlir/Conversion/MeshToMPI/MeshToMPI.h" +#include "mlir/Conversion/ShardToMPI/ShardToMPI.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -20,11 +20,11 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MPI/IR/MPI.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Mesh/IR/MeshDialect.h" -#include "mlir/Dialect/Mesh/IR/MeshOps.h" -#include "mlir/Dialect/Mesh/Transforms/Simplifications.h" -#include "mlir/Dialect/Mesh/Transforms/Transforms.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Shard/IR/ShardDialect.h" +#include "mlir/Dialect/Shard/IR/ShardOps.h" +#include "mlir/Dialect/Shard/Transforms/Simplifications.h" +#include "mlir/Dialect/Shard/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Builders.h" @@ -35,16 +35,16 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#define DEBUG_TYPE "mesh-to-mpi" +#define DEBUG_TYPE "shard-to-mpi" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") namespace mlir { -#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS +#define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; -using namespace mesh; +using namespace shard; namespace { /// Converts a vector of OpFoldResults (ints) into vector of Values of the @@ -177,9 +177,8 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> { auto type = RankedTensorType::get({nSplits, 2}, i64); Value resHaloSizes = haloSizes.empty() - ? rewriter - .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0}, - i64) + ? tensor::EmptyOp::create(rewriter, loc, + std::array<int64_t, 2>{0, 0}, i64) .getResult() : tensor::FromElementsOp::create(rewriter, loc, type, haloSizes) .getResult(); @@ -188,18 +187,18 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> { // maxSplitSize+1}. Store the offsets in the tensor but set trailing // elements for smaller split-groups to -1. 
Computing the max size of the // split groups needs using collectiveProcessGroupSize (which needs the - // MeshOp) + // GridOp) Value resOffsets; if (adaptor.getStaticShardedDimsOffsets().empty()) { resOffsets = tensor::EmptyOp::create(rewriter, loc, std::array<int64_t, 2>{0, 0}, i64); } else { SymbolTableCollection symbolTableCollection; - auto meshOp = getMesh(op, symbolTableCollection); + auto gridOp = getGrid(op, symbolTableCollection); int64_t maxSplitSize = 0; for (auto axes : splitAxes) { int64_t splitSize = - collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape()); + collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape()); assert(splitSize != ShapedType::kDynamic); maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize); } @@ -218,7 +217,7 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> { int64_t curr = 0; for (auto [i, axes] : llvm::enumerate(splitAxes)) { int64_t splitSize = - collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape()); + collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape()); assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize); ++splitSize; // add one for the total size ArrayRef<Value> values(&offsets[curr], splitSize); @@ -264,20 +263,20 @@ struct ConvertProcessMultiIndexOp SymbolTableCollection symbolTableCollection; Location loc = op.getLoc(); - auto meshOp = getMesh(op, symbolTableCollection); - // For now we only support static mesh shapes - if (ShapedType::isDynamicShape(meshOp.getShape())) + auto gridOp = getGrid(op, symbolTableCollection); + // For now we only support static grid shapes + if (ShapedType::isDynamicShape(gridOp.getShape())) return failure(); SmallVector<Value> dims; llvm::transform( - meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { + gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) { return arith::ConstantIndexOp::create(rewriter, loc, i).getResult(); }); - Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), meshOp); + Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), gridOp); auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims); - // optionally extract subset of mesh axes + // optionally extract subset of grid axes auto axes = adaptor.getAxes(); if (!axes.empty()) { SmallVector<Value> subIndex; @@ -306,13 +305,11 @@ public: auto ctx = op.getContext(); Value commWorld = mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx)); - auto rank = - rewriter - .create<mpi::CommRankOp>( - loc, - TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()}, - commWorld) - .getRank(); + auto rank = mpi::CommRankOp::create( + rewriter, loc, + TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()}, + commWorld) + .getRank(); rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(), rank); return success(); @@ -338,12 +335,12 @@ struct ConvertNeighborsLinearIndicesOp Location loc = op.getLoc(); SymbolTableCollection symbolTableCollection; - auto meshOp = getMesh(op, symbolTableCollection); + auto gridOp = getGrid(op, symbolTableCollection); auto mIdx = adaptor.getDevice(); auto orgIdx = mIdx[axes[0]]; SmallVector<Value> dims; llvm::transform( - meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { + gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) { return arith::ConstantIndexOp::create(rewriter, loc, i).getResult(); }); Value dimSz = dims[axes[0]]; @@ -394,14 +391,14 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> { auto sharding = 
op.getSharding().getDefiningOp<ShardingOp>(); if (!sharding) { return op->emitError() - << "Expected SharingOp as defining op for sharding" + << "Expected ShardingOp as defining op for sharding" << " but found " << adaptor.getSharding()[0].getDefiningOp(); } // Compute the sharded shape by applying the sharding to the input shape. // If shardedDimsOffsets is not defined in the sharding, the shard shape is // computed by dividing the dimension size by the number of shards in that - // dimension (which is given by the size of the mesh axes provided in + // dimension (which is given by the size of the grid axes provided in // split-axes). Odd elements get distributed to trailing shards. If a // shardedDimsOffsets is provided, the shard shape is computed by // subtracting the offset of the current shard from the offset of the next @@ -431,11 +428,11 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> { SmallVector<Value> multiIdx = getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index); - // Get the MeshOp, the mesh shape is needed to compute the sharded shape. + // Get the GridOp, the grid shape is needed to compute the sharded shape. SymbolTableCollection symbolTableCollection; - auto meshOp = getMesh(sharding, symbolTableCollection); - // For now we only support static mesh shapes - if (ShapedType::isDynamicShape(meshOp.getShape())) + auto gridOp = getGrid(sharding, symbolTableCollection); + // For now we only support static grid shapes + if (ShapedType::isDynamicShape(gridOp.getShape())) return failure(); auto splitAxes = sharding.getSplitAxes().getAxes(); @@ -455,7 +452,7 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> { tmp); } - // With static mesh shape the sizes of the split axes are known. + // With static grid shape the sizes of the split axes are known. // Hence the start/pos for each split axes in shardDimsOffsets can be // computed statically. int64_t pos = 0; @@ -475,10 +472,10 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> { // Create a value from the static position in shardDimsOffsets. Value posVal = arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(pos)); - // Get the index of the local shard in the mesh axis. + // Get the index of the local shard in the grid axis. Value idx = multiIdx[axes[0]]; auto numShards = - collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape()); + collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape()); if (shardedDimsOffs) { // If sharded dims offsets are provided, use them to compute the // sharded shape. 
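[Editor's note] The shard-shape rule spelled out in the comments above (dimension size divided by the number of shards along the split grid axes, odd elements distributed to trailing shards) reduces to simple integer arithmetic. A sketch under that stated rule; the helper name is hypothetical and the exact tie-breaking in the pass may differ.

#include <cstdint>
#include <iostream>

// Size of shard `shardIdx` of a dimension of `dimSize` elements split over
// `numShards` shards: every shard gets dimSize / numShards elements, and the
// trailing dimSize % numShards shards get one extra element each.
int64_t shardedDimSize(int64_t dimSize, int64_t numShards, int64_t shardIdx) {
  int64_t base = dimSize / numShards;
  int64_t rem = dimSize % numShards;
  return base + (shardIdx >= numShards - rem ? 1 : 0);
}

int main() {
  for (int64_t i = 0; i < 4; ++i) // a dim of 10 over 4 shards -> 2 2 3 3
    std::cout << shardedDimSize(10, 4, i) << ' ';
  std::cout << '\n';
}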
@@ -556,13 +553,13 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> { matchAndRewrite(AllReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SymbolTableCollection symbolTableCollection; - auto mesh = adaptor.getMesh(); - mlir::mesh::MeshOp meshOp = getMesh(op, symbolTableCollection); - if (!meshOp) - return op->emitError() << "No mesh found for AllReduceOp"; - if (ShapedType::isDynamicShape(meshOp.getShape())) + auto grid = adaptor.getGrid(); + mlir::shard::GridOp gridOp = getGrid(op, symbolTableCollection); + if (!gridOp) + return op->emitError() << "No grid found for AllReduceOp"; + if (ShapedType::isDynamicShape(gridOp.getShape())) return op->emitError() - << "Dynamic mesh shape not supported in AllReduceOp"; + << "Dynamic grid shape not supported in AllReduceOp"; ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter); Value input = adaptor.getInput(); @@ -592,27 +589,27 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> { linalg::CopyOp::create(iBuilder, input, buffer); // Get an MPI_Comm_split for the AllReduce operation. - // The color is the linear index of the process in the mesh along the - // non-reduced axes. The key is the linear index of the process in the mesh + // The color is the linear index of the process in the grid along the + // non-reduced axes. The key is the linear index of the process in the grid // along the reduced axes. - SmallVector<Type> indexResultTypes(meshOp.getShape().size(), + SmallVector<Type> indexResultTypes(gridOp.getShape().size(), iBuilder.getIndexType()); SmallVector<Value> myMultiIndex = - ProcessMultiIndexOp::create(iBuilder, indexResultTypes, mesh) + ProcessMultiIndexOp::create(iBuilder, indexResultTypes, grid) .getResult(); Value zero = arith::ConstantIndexOp::create(iBuilder, 0); SmallVector<Value> multiKey(myMultiIndex.size(), zero); - auto redAxes = adaptor.getMeshAxes(); + auto redAxes = adaptor.getGridAxes(); for (auto axis : redAxes) { multiKey[axis] = myMultiIndex[axis]; myMultiIndex[axis] = zero; } Value color = - createProcessLinearIndex(mesh, myMultiIndex, redAxes, iBuilder); + createProcessLinearIndex(grid, myMultiIndex, redAxes, iBuilder); color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color); - Value key = createProcessLinearIndex(mesh, multiKey, redAxes, iBuilder); + Value key = createProcessLinearIndex(grid, multiKey, redAxes, iBuilder); key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key); // Finally split the communicator @@ -698,15 +695,14 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { } auto rank = cast<ShapedType>(array.getType()).getRank(); auto opSplitAxes = adaptor.getSplitAxes().getAxes(); - auto mesh = adaptor.getMesh(); - auto meshOp = getMesh(op, symbolTableCollection); + auto grid = adaptor.getGrid(); + auto gridOp = getGrid(op, symbolTableCollection); // subviews need Index values for (auto &sz : haloSizes) { if (auto value = dyn_cast<Value>(sz)) - sz = - rewriter - .create<arith::IndexCastOp>(loc, rewriter.getIndexType(), value) - .getResult(); + sz = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(), + value) + .getResult(); } // most of the offset/size/stride data is the same for all dims @@ -745,10 +741,10 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0 auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr); - SmallVector<Type> 
indexResultTypes(meshOp.getShape().size(), + SmallVector<Type> indexResultTypes(gridOp.getShape().size(), rewriter.getIndexType()); auto myMultiIndex = - ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, mesh) + ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, grid) .getResult(); // traverse all split axes from high to low dim for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) { @@ -758,9 +754,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2); // Get the linearized ids of the neighbors (down and up) for the // given split - auto tmp = rewriter - .create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex, - splitAxes) + auto tmp = NeighborsLinearIndicesOp::create(rewriter, loc, grid, + myMultiIndex, splitAxes) .getResults(); // MPI operates on i32... Value neighbourIDs[2] = { @@ -791,7 +786,7 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1] : haloSizes[currHaloDim * 2]; // Check if we need to send and/or receive - // Processes on the mesh borders have only one neighbor + // Processes on the grid borders have only one neighbor auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1]; auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0]; auto hasFrom = arith::CmpIOp::create( @@ -869,8 +864,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { } }; -struct ConvertMeshToMPIPass - : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> { +struct ConvertShardToMPIPass + : public impl::ConvertShardToMPIPassBase<ConvertShardToMPIPass> { using Base::Base; /// Run the dialect converter on the module. @@ -879,12 +874,12 @@ struct ConvertMeshToMPIPass RewritePatternSet patterns(ctxt); ConversionTarget target(getContext()); - // Define a type converter to convert mesh::ShardingType, + // Define a type converter to convert shard::ShardingType, // mostly for use in return operations. TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); - // convert mesh::ShardingType to a tuple of RankedTensorTypes + // convert shard::ShardingType to a tuple of RankedTensorTypes typeConverter.addConversion( [](ShardingType type, SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> { @@ -920,10 +915,10 @@ struct ConvertMeshToMPIPass return results; }); - // No mesh dialect should left after conversion... - target.addIllegalDialect<mesh::MeshDialect>(); - // ...except the global MeshOp. MeshShapeOp which will get folded later. - target.addLegalOp<mesh::MeshOp, mesh::MeshShapeOp>(); + // No shard dialect should left after conversion... + target.addIllegalDialect<shard::ShardDialect>(); + // ...except the global GridOp. GridShapeOp which will get folded later. + target.addLegalOp<shard::GridOp, shard::GridShapeOp>(); // Allow all the stuff that our patterns will convert to target.addLegalDialect< BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect, @@ -951,7 +946,7 @@ struct ConvertMeshToMPIPass // Folding patterns cannot be mixed with conversion patterns -> extra pass. 
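// [Editor's note on the ConvertAllReduceOp hunk above] The MPI_Comm_split
// color/key pair is derived from the process's multi-index by zeroing the
// reduced axes for the color and keeping only the reduced axes for the key,
// then linearizing both against the grid shape. A self-contained sketch of
// that arithmetic, assuming a plain row-major linearization:
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int64_t linearize(const std::vector<int64_t> &idx,
                  const std::vector<int64_t> &shape) {
  int64_t lin = 0;
  for (std::size_t i = 0; i < shape.size(); ++i)
    lin = lin * shape[i] + idx[i];
  return lin;
}

int main() {
  std::vector<int64_t> gridShape{2, 3}; // static grid, as required above
  std::vector<int64_t> myIdx{1, 2};     // this process's multi-index
  std::vector<int64_t> redAxes{1};      // reduce along grid axis 1

  std::vector<int64_t> colorIdx = myIdx;
  std::vector<int64_t> keyIdx(myIdx.size(), 0);
  for (int64_t axis : redAxes) {
    keyIdx[axis] = myIdx[axis]; // key: position along the reduced axes
    colorIdx[axis] = 0;         // color: constant along the reduced axes
  }
  std::cout << "color=" << linearize(colorIdx, gridShape)
            << " key=" << linearize(keyIdx, gridShape) << '\n';
}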
patterns.clear(); SymbolTableCollection symbolTableCollection; - mlir::mesh::populateFoldingPatterns(patterns, symbolTableCollection); + mlir::shard::populateFoldingPatterns(patterns, symbolTableCollection); (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index ec55091..0e3de06 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -22,7 +22,6 @@ #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" @@ -570,10 +569,9 @@ static Value createLinalgBodyCalculationForElementwiseOp( // to UIToFP. if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) { auto unrealizedCast = - rewriter - .create<UnrealizedConversionCastOp>( - loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), - args[0]) + UnrealizedConversionCastOp::create( + rewriter, loc, + rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0]) .getResult(0); return arith::UIToFPOp::create(rewriter, loc, resultTypes[0], unrealizedCast); @@ -869,14 +867,13 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, // Emit 'linalg.generic' op auto resultTensor = - opBuilder - .create<linalg::GenericOp>( - loc, outputTensor.getType(), operand, outputTensor, affineMaps, - getNParallelLoopsAttrs(rank), - [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) { - // Emit 'linalg.yield' op - linalg::YieldOp::create(opBuilder, loc, blockArgs.front()); - }) + linalg::GenericOp::create( + opBuilder, loc, outputTensor.getType(), operand, outputTensor, + affineMaps, getNParallelLoopsAttrs(rank), + [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) { + // Emit 'linalg.yield' op + linalg::YieldOp::create(opBuilder, loc, blockArgs.front()); + }) .getResult(0); // Cast to original operand type if necessary @@ -1156,11 +1153,9 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, inputs.push_back(input); // First fill the output buffer with the init value. 
- auto emptyTensor = - rewriter - .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(), - dynDims) - .getResult(); + auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, reduceShape, + resultTy.getElementType(), dynDims) + .getResult(); auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter); if (!fillValueAttr) @@ -1168,10 +1163,10 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, op, "No initial value found for reduction operation"); auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr); - auto filledTensor = rewriter - .create<linalg::FillOp>(loc, ValueRange{fillValue}, - ValueRange{emptyTensor}) - .result(); + auto filledTensor = + linalg::FillOp::create(rewriter, loc, ValueRange{fillValue}, + ValueRange{emptyTensor}) + .result(); outputs.push_back(filledTensor); bool isNanIgnoreMode = false; @@ -1187,14 +1182,12 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, auto trueAttr = rewriter.getBoolAttr(true); auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr); auto emptyBoolTensor = - rewriter - .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(), - dynDims) + tensor::EmptyOp::create(rewriter, loc, reduceShape, + trueValue.getType(), dynDims) .getResult(); auto allResultsNaNTensor = - rewriter - .create<linalg::FillOp>(loc, ValueRange{trueValue}, - ValueRange{emptyBoolTensor}) + linalg::FillOp::create(rewriter, loc, ValueRange{trueValue}, + ValueRange{emptyBoolTensor}) .result(); // Note that because the linalg::ReduceOp has two variadic arguments // (inputs and outputs) and it has the SameVariadicOperandSize trait we @@ -1262,22 +1255,19 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false)); auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr); auto emptyNanTensor = - rewriter - .create<tensor::EmptyOp>(loc, reduceShape, - resultTy.getElementType(), dynDims) + tensor::EmptyOp::create(rewriter, loc, reduceShape, + resultTy.getElementType(), dynDims) .getResult(); auto nanFilledTensor = - rewriter - .create<linalg::FillOp>(loc, ValueRange{nanValue}, - ValueRange{emptyNanTensor}) + linalg::FillOp::create(rewriter, loc, ValueRange{nanValue}, + ValueRange{emptyNanTensor}) .result(); // Create an empty tensor, no need to fill this since it will be // overwritten by the select. auto finalEmptyTensor = - rewriter - .create<tensor::EmptyOp>(loc, reduceShape, - resultTy.getElementType(), dynDims) + tensor::EmptyOp::create(rewriter, loc, reduceShape, + resultTy.getElementType(), dynDims) .getResult(); // Do a selection between the tensors akin to: @@ -1504,12 +1494,11 @@ public: Value shift = shiftConstant ?
shiftConstant : blockArgs[shiftArg]; if (valueTy.isUnsignedInteger()) { - value = nestedBuilder - .create<UnrealizedConversionCastOp>( - nestedLoc, - nestedBuilder.getIntegerType( - valueTy.getIntOrFloatBitWidth()), - value) + value = UnrealizedConversionCastOp::create( + nestedBuilder, nestedLoc, + nestedBuilder.getIntegerType( + valueTy.getIntOrFloatBitWidth()), + value) .getResult(0); } if (valueTy.getIntOrFloatBitWidth() < 32) { @@ -1558,9 +1547,8 @@ public: } if (outIntType.isUnsignedInteger()) { - value = nestedBuilder - .create<UnrealizedConversionCastOp>(nestedLoc, - outIntType, value) + value = UnrealizedConversionCastOp::create(nestedBuilder, nestedLoc, + outIntType, value) .getResult(0); } linalg::YieldOp::create(nestedBuilder, loc, value); @@ -2096,10 +2084,9 @@ public: Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis); // First fill the output buffer with the init value. - auto emptyTensor = rewriter - .create<tensor::EmptyOp>(loc, inputTy.getShape(), - inputTy.getElementType(), - ArrayRef<Value>({dynDims})) + auto emptyTensor = tensor::EmptyOp::create( + rewriter, loc, inputTy.getShape(), + inputTy.getElementType(), ArrayRef<Value>({dynDims})) .getResult(); SmallVector<AffineMap, 2> affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank())}; @@ -2242,23 +2229,22 @@ public: } // First fill the output buffer for the index. - auto emptyTensorIdx = rewriter - .create<tensor::EmptyOp>(loc, resultTy.getShape(), - outElementTy, dynDims) - .getResult(); + auto emptyTensorIdx = + tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), + outElementTy, dynDims) + .getResult(); auto fillValueIdx = arith::ConstantOp::create( rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0)); auto filledTensorIdx = - rewriter - .create<linalg::FillOp>(loc, ValueRange{fillValueIdx}, - ValueRange{emptyTensorIdx}) + linalg::FillOp::create(rewriter, loc, ValueRange{fillValueIdx}, + ValueRange{emptyTensorIdx}) .result(); // Second fill the output buffer for the running max. 
- auto emptyTensorMax = rewriter - .create<tensor::EmptyOp>(loc, resultTy.getShape(), - inElementTy, dynDims) - .getResult(); + auto emptyTensorMax = + tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), inElementTy, + dynDims) + .getResult(); auto fillValueMaxAttr = createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter); @@ -2269,9 +2255,8 @@ public: auto fillValueMax = arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr); auto filledTensorMax = - rewriter - .create<linalg::FillOp>(loc, ValueRange{fillValueMax}, - ValueRange{emptyTensorMax}) + linalg::FillOp::create(rewriter, loc, ValueRange{fillValueMax}, + ValueRange{emptyTensorMax}) .result(); // We need to reduce along the arg-max axis, with parallel operations along @@ -2372,9 +2357,8 @@ public: auto loc = op.getLoc(); auto emptyTensor = - rewriter - .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy, - dynamicDims) + tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), + resultElementTy, dynamicDims) .getResult(); SmallVector<AffineMap, 2> affineMaps = { @@ -2449,10 +2433,10 @@ public: } } - auto emptyTensor = rewriter - .create<tensor::EmptyOp>(loc, resultTy.getShape(), - resultElementTy, dynDims) - .getResult(); + auto emptyTensor = + tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), + resultElementTy, dynDims) + .getResult(); SmallVector<AffineMap, 2> affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank()), @@ -2586,10 +2570,10 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> { tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes); auto fillValueAttr = rewriter.getZeroAttr(type.getElementType()); auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr); - auto filledTensor = rewriter - .create<linalg::FillOp>(loc, ValueRange{fillValue}, - ValueRange{emptyTensor}) - .result(); + auto filledTensor = + linalg::FillOp::create(rewriter, loc, ValueRange{fillValue}, + ValueRange{emptyTensor}) + .result(); return filledTensor; } diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index 3a20524..da1fb20 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -64,19 +64,20 @@ linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias, Value conv, Value result, ArrayRef<AffineMap> indexingMaps) { ShapedType resultTy = cast<ShapedType>(conv.getType()); - return rewriter - .create<linalg::GenericOp>( - loc, resultTy, ValueRange({bias, conv}), result, indexingMaps, - getNParallelLoopsAttrs(resultTy.getRank()), - [](OpBuilder &builder, Location loc, ValueRange args) { - Value biasVal = args[0]; - Type resType = args[1].getType(); - if (resType != biasVal.getType()) { - biasVal = arith::ExtSIOp::create(builder, loc, resType, biasVal); - } - Value added = arith::AddIOp::create(builder, loc, biasVal, args[1]); - linalg::YieldOp::create(builder, loc, added); - }) + return linalg::GenericOp::create( + rewriter, loc, resultTy, ValueRange({bias, conv}), result, + indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()), + [](OpBuilder &builder, Location loc, ValueRange args) { + Value biasVal = args[0]; + Type resType = args[1].getType(); + if (resType != biasVal.getType()) { + biasVal = + arith::ExtSIOp::create(builder, loc, resType, biasVal); + } + Value added = + arith::AddIOp::create(builder, loc, biasVal, args[1]); + linalg::YieldOp::create(builder, loc, added); + }) 
.getResult(0); } @@ -124,23 +125,23 @@ static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter, indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank)); // Build the broadcast-like operation as a linalg.generic. - return rewriter - .create<linalg::GenericOp>( - loc, resultTy, ValueRange({source}), result, indexingMaps, - getNParallelLoopsAttrs(resultTy.getRank()), - [&resultTy](OpBuilder &builder, Location loc, ValueRange args) { - Value biasVal = args[0]; - Type resType = args[1].getType(); - if (resType != biasVal.getType()) { - biasVal = - resultTy.getElementType().isFloat() - ? arith::ExtFOp::create(builder, loc, resType, biasVal) - .getResult() - : arith::ExtSIOp::create(builder, loc, resType, biasVal) - .getResult(); - } - linalg::YieldOp::create(builder, loc, biasVal); - }) + return linalg::GenericOp::create( + rewriter, loc, resultTy, ValueRange({source}), result, + indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()), + [&resultTy](OpBuilder &builder, Location loc, ValueRange args) { + Value biasVal = args[0]; + Type resType = args[1].getType(); + if (resType != biasVal.getType()) { + biasVal = + resultTy.getElementType().isFloat() + ? arith::ExtFOp::create(builder, loc, resType, biasVal) + .getResult() + : arith::ExtSIOp::create(builder, loc, resType, + biasVal) + .getResult(); + } + linalg::YieldOp::create(builder, loc, biasVal); + }) .getResult(0); } @@ -397,21 +398,19 @@ public: auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp); auto kZpVal = arith::ConstantOp::create(rewriter, loc, kZp); - Value conv = - rewriter - .create<LinalgConvQOp>( - loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal}, - ValueRange{broadcastBias}, strideAttr, dilationAttr) - ->getResult(0); + Value conv = LinalgConvQOp::create( + rewriter, loc, resultTy, + ValueRange{input, weight, iZpVal, kZpVal}, + ValueRange{broadcastBias}, strideAttr, dilationAttr) + ->getResult(0); rewriter.replaceOp(op, conv); return success(); } - Value conv = rewriter - .create<LinalgConvOp>( - loc, accTy, ValueRange{input, weight}, - ValueRange{broadcastBias}, strideAttr, dilationAttr) + Value conv = LinalgConvOp::create( + rewriter, loc, accTy, ValueRange{input, weight}, + ValueRange{broadcastBias}, strideAttr, dilationAttr) ->getResult(0); // We may need to truncate back to the result type if the accumulator was @@ -529,9 +528,8 @@ public: Value emptyTensor = tensor::EmptyOp::create( rewriter, loc, linalgConvTy.getShape(), accETy, filteredDims); Value zero = arith::ConstantOp::create(rewriter, loc, resultZeroAttr); - Value zeroTensor = rewriter - .create<linalg::FillOp>(loc, ValueRange{zero}, - ValueRange{emptyTensor}) + Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero}, + ValueRange{emptyTensor}) .result(); Value biasEmptyTensor = tensor::EmptyOp::create( @@ -544,10 +542,9 @@ public: indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank)); if (hasNullZps) { - Value conv = rewriter - .create<linalg::DepthwiseConv2DNhwcHwcmOp>( - loc, linalgConvTy, ValueRange{input, weight}, - ValueRange{zeroTensor}, strideAttr, dilationAttr) + Value conv = linalg::DepthwiseConv2DNhwcHwcmOp::create( + rewriter, loc, linalgConvTy, ValueRange{input, weight}, + ValueRange{zeroTensor}, strideAttr, dilationAttr) .getResult(0); // We may need to truncate back to the result type if the accumulator was @@ -565,22 +562,20 @@ public: rewriter, loc, resultTy, conv, reassociationMap); Value result = - rewriter - .create<linalg::GenericOp>( - loc, resultTy, 
ValueRange({bias, convReshape}), - biasEmptyTensor, indexingMaps, - getNParallelLoopsAttrs(resultRank), - [&](OpBuilder &nestedBuilder, Location nestedLoc, - ValueRange args) { - Value added; - if (llvm::isa<FloatType>(inputETy)) - added = arith::AddFOp::create(nestedBuilder, loc, args[0], - args[1]); - else - added = arith::AddIOp::create(nestedBuilder, loc, args[0], - args[1]); - linalg::YieldOp::create(nestedBuilder, nestedLoc, added); - }) + linalg::GenericOp::create( + rewriter, loc, resultTy, ValueRange({bias, convReshape}), + biasEmptyTensor, indexingMaps, getNParallelLoopsAttrs(resultRank), + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange args) { + Value added; + if (llvm::isa<FloatType>(inputETy)) + added = arith::AddFOp::create(nestedBuilder, loc, args[0], + args[1]); + else + added = arith::AddIOp::create(nestedBuilder, loc, args[0], + args[1]); + linalg::YieldOp::create(nestedBuilder, nestedLoc, added); + }) .getResult(0); rewriter.replaceOp(op, result); } else { @@ -588,12 +583,11 @@ public: IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal); auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp); auto kZpVal = arith::ConstantOp::create(rewriter, loc, wZp); - Value conv = - rewriter - .create<linalg::DepthwiseConv2DNhwcHwcmQOp>( - loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal}, - ValueRange{zeroTensor}, strideAttr, dilationAttr) - .getResult(0); + Value conv = linalg::DepthwiseConv2DNhwcHwcmQOp::create( + rewriter, loc, linalgConvTy, + ValueRange{input, weight, iZpVal, kZpVal}, + ValueRange{zeroTensor}, strideAttr, dilationAttr) + .getResult(0); SmallVector<ReassociationExprs, 4> reassociationMap; createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter); Value convReshape = tensor::CollapseShapeOp::create( @@ -639,9 +633,8 @@ public: auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, outputTy.getShape(), outputTy.getElementType(), filteredDims); - Value zeroTensor = rewriter - .create<linalg::FillOp>(loc, ValueRange{zero}, - ValueRange{emptyTensor}) + Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero}, + ValueRange{emptyTensor}) .result(); FailureOr<int64_t> maybeAZp = op.getAZeroPoint(); @@ -910,20 +903,18 @@ public: rewriter, loc, accTy.getShape(), accETy, dynamicDims); Value filledEmptyTensor = - rewriter - .create<linalg::FillOp>(loc, ValueRange{initialValue}, - ValueRange{poolEmptyTensor}) + linalg::FillOp::create(rewriter, loc, ValueRange{initialValue}, + ValueRange{poolEmptyTensor}) .result(); Value fakeWindowDims = tensor::EmptyOp::create(rewriter, loc, kernel, accETy); // Sum across the pooled region. 
- Value poolingOp = rewriter - .create<linalg::PoolingNhwcSumOp>( - loc, ArrayRef<Type>{accTy}, - ValueRange{paddedInput, fakeWindowDims}, - filledEmptyTensor, strideAttr, dilationAttr) + Value poolingOp = linalg::PoolingNhwcSumOp::create( + rewriter, loc, ArrayRef<Type>{accTy}, + ValueRange{paddedInput, fakeWindowDims}, + filledEmptyTensor, strideAttr, dilationAttr) .getResult(0); // Normalize the summed value by the number of elements grouped in each @@ -1050,10 +1041,9 @@ public: Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8); auto scaled = - rewriter - .create<tosa::ApplyScaleOp>( - loc, rewriter.getI32Type(), poolVal, multiplier, shift, - rewriter.getStringAttr("SINGLE_ROUND")) + tosa::ApplyScaleOp::create( + rewriter, loc, rewriter.getI32Type(), poolVal, multiplier, + shift, rewriter.getStringAttr("SINGLE_ROUND")) .getResult(); // If we have quantization information we need to apply output diff --git a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp index b83f5ec9..f8efb34 100644 --- a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp +++ b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp @@ -13,7 +13,6 @@ #include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h" #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" using namespace mlir; diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp index 77aab85..a425eff 100644 --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -482,14 +482,12 @@ struct CombineTransferReadOpTranspose final permutationMap.compose(transferReadOp.getPermutationMap()); auto loc = op.getLoc(); - Value result = - rewriter - .create<vector::TransferReadOp>( - loc, resultType, transferReadOp.getBase(), - transferReadOp.getIndices(), AffineMapAttr::get(newMap), - transferReadOp.getPadding(), transferReadOp.getMask(), - transferReadOp.getInBoundsAttr()) - .getResult(); + Value result = vector::TransferReadOp::create( + rewriter, loc, resultType, transferReadOp.getBase(), + transferReadOp.getIndices(), AffineMapAttr::get(newMap), + transferReadOp.getPadding(), transferReadOp.getMask(), + transferReadOp.getInBoundsAttr()) + .getResult(); // Fuse through the integer extend op. 
if (extOp) { diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 9cd491c..17a79e3 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -29,7 +29,9 @@ #include "mlir/Target/LLVMIR/TypeToLLVM.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/APFloat.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/Support/Casting.h" + #include <optional> using namespace mlir; @@ -1068,39 +1070,6 @@ public: } }; -class VectorExtractElementOpConversion - : public ConvertOpToLLVMPattern<vector::ExtractElementOp> { -public: - using ConvertOpToLLVMPattern< - vector::ExtractElementOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto vectorType = extractEltOp.getSourceVectorType(); - auto llvmType = typeConverter->convertType(vectorType.getElementType()); - - // Bail if result type cannot be lowered. - if (!llvmType) - return failure(); - - if (vectorType.getRank() == 0) { - Location loc = extractEltOp.getLoc(); - auto idxType = rewriter.getIndexType(); - auto zero = LLVM::ConstantOp::create(rewriter, loc, - typeConverter->convertType(idxType), - rewriter.getIntegerAttr(idxType, 0)); - rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( - extractEltOp, llvmType, adaptor.getVector(), zero); - return success(); - } - - rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( - extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition()); - return success(); - } -}; - class VectorExtractOpConversion : public ConvertOpToLLVMPattern<vector::ExtractOp> { public: @@ -1204,39 +1173,6 @@ public: } }; -class VectorInsertElementOpConversion - : public ConvertOpToLLVMPattern<vector::InsertElementOp> { -public: - using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto vectorType = insertEltOp.getDestVectorType(); - auto llvmType = typeConverter->convertType(vectorType); - - // Bail if result type cannot be lowered. 
- if (!llvmType) - return failure(); - - if (vectorType.getRank() == 0) { - Location loc = insertEltOp.getLoc(); - auto idxType = rewriter.getIndexType(); - auto zero = LLVM::ConstantOp::create(rewriter, loc, - typeConverter->convertType(idxType), - rewriter.getIntegerAttr(idxType, 0)); - rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( - insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero); - return success(); - } - - rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( - insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), - adaptor.getPosition()); - return success(); - } -}; - class VectorInsertOpConversion : public ConvertOpToLLVMPattern<vector::InsertOp> { public: @@ -2242,8 +2178,7 @@ void mlir::populateVectorToLLVMConversionPatterns( VectorGatherOpConversion, VectorScatterOpConversion>( converter, useVectorAlignment); patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion, - VectorExtractElementOpConversion, VectorExtractOpConversion, - VectorFMAOp1DConversion, VectorInsertElementOpConversion, + VectorExtractOpConversion, VectorFMAOp1DConversion, VectorInsertOpConversion, VectorPrintOpConversion, VectorTypeCastOpConversion, VectorScaleOpConversion, VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index 4c1047a..508f4e2 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -24,7 +24,6 @@ #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -691,7 +690,7 @@ struct PrepareTransferWriteConversion /// %lastIndex = arith.subi %length, %c1 : index /// vector.print punctuation <open> /// scf.for %i = %c0 to %length step %c1 { -/// %el = vector.extractelement %v[%i : index] : vector<[4]xi32> +/// %el = vector.extract %v[%i] : i32 from vector<[4]xi32> /// vector.print %el : i32 punctuation <no_punctuation> /// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index /// scf.if %notLastIndex { @@ -1644,7 +1643,7 @@ struct Strategy1d<TransferWriteOp> { /// Is rewritten to approximately the following pseudo-IR: /// ``` /// for i = 0 to 9 { -/// %t = vector.extractelement %vec[i] : vector<9xf32> +/// %t = vector.extract %vec[i] : f32 from vector<9xf32> /// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32> /// } /// ``` diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 986eae3..a4be7d4 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -335,63 +335,6 @@ struct VectorInsertOpConvert final } }; -struct VectorExtractElementOpConvert final - : public OpConversionPattern<vector::ExtractElementOp> { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type resultType = getTypeConverter()->convertType(extractOp.getType()); - if (!resultType) - return failure(); - - if (isa<spirv::ScalarType>(adaptor.getVector().getType())) { - rewriter.replaceOp(extractOp, adaptor.getVector()); - return success(); - } - - APInt cstPos; - if (matchPattern(adaptor.getPosition(), 
m_ConstantInt(&cstPos))) - rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( - extractOp, resultType, adaptor.getVector(), - rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())})); - else - rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>( - extractOp, resultType, adaptor.getVector(), adaptor.getPosition()); - return success(); - } -}; - -struct VectorInsertElementOpConvert final - : public OpConversionPattern<vector::InsertElementOp> { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type vectorType = getTypeConverter()->convertType(insertOp.getType()); - if (!vectorType) - return failure(); - - if (isa<spirv::ScalarType>(vectorType)) { - rewriter.replaceOp(insertOp, adaptor.getSource()); - return success(); - } - - APInt cstPos; - if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos))) - rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( - insertOp, adaptor.getSource(), adaptor.getDest(), - cstPos.getSExtValue()); - else - rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>( - insertOp, vectorType, insertOp.getDest(), adaptor.getSource(), - adaptor.getPosition()); - return success(); - } -}; - struct VectorInsertStridedSliceOpConvert final : public OpConversionPattern<vector::InsertStridedSliceOp> { using OpConversionPattern::OpConversionPattern; @@ -1107,12 +1050,11 @@ struct VectorToElementOpConvert final void mlir::populateVectorToSPIRVPatterns( const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add< - VectorBitcastConvert, VectorBroadcastConvert, - VectorExtractElementOpConvert, VectorExtractOpConvert, + VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert, VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>, VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert, - VectorToElementOpConvert, VectorInsertElementOpConvert, - VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>, + VectorToElementOpConvert, VectorInsertOpConvert, + VectorReductionPattern<GL_INT_MAX_MIN_OPS>, VectorReductionPattern<CL_INT_MAX_MIN_OPS>, VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>, VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast, diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index 2411af0..4dfcb2b 100644 --- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -10,7 +10,6 @@ #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" -#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/XeVMDialect.h"
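Note on the recurring mechanical change in these files: the rewriter.create<OpTy>(loc, ...) member template is consistently replaced by the static OpTy::create(rewriter, loc, ...) form, with the argument list after the location unchanged. Below is a minimal sketch of a rewrite pattern written in the new style; the pattern itself (folding arith.maxsi into cmp + select) is a made-up example, not part of this change, and only the create-call shape is the point.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical pattern: lower arith.maxsi to an explicit compare + select,
// using the static OpTy::create(rewriter, loc, ...) builders this diff adopts.
struct MaxSIToCmpSelect : OpRewritePattern<arith::MaxSIOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::MaxSIOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Old spelling: rewriter.create<arith::CmpIOp>(loc, ...).
    Value cmp = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::sgt, op.getLhs(), op.getRhs());
    // replaceOpWithNewOp is untouched by the migration.
    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, op.getLhs(),
                                                 op.getRhs());
    return success();
  }
};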
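The TOSA lowerings above repeatedly allocate an accumulator with tensor.empty and initialize it with linalg.fill before reducing into it. A self-contained sketch of that empty-plus-fill idiom in the new builder style follows; the static 4x8/f32 shape is illustrative only, since the real code derives shape, element type, and dynamic sizes from the op being lowered.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;

// Create a zero-initialized f32 accumulator tensor of shape 4x8.
static Value createZeroInitAccumulator(OpBuilder &b, Location loc) {
  auto f32 = b.getF32Type();
  // tensor.empty carries the shape; its contents are undefined until filled.
  Value empty = tensor::EmptyOp::create(b, loc, ArrayRef<int64_t>{4, 8}, f32);
  Value zero = arith::ConstantOp::create(b, loc, b.getZeroAttr(f32));
  // linalg.fill broadcasts the scalar into the tensor; .result() is the
  // filled value, matching the usage throughout the diff.
  return linalg::FillOp::create(b, loc, ValueRange{zero}, ValueRange{empty})
      .result();
}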
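ConvertShardToMPIPass installs a one-to-many type conversion: a callback taking SmallVectorImpl<Type>& may append several result types for a single input type, which is how one shard::ShardingType value becomes a tuple of RankedTensorTypes. Here is a sketch of that hook using ComplexType as a stand-in source type; that substitution is an assumption made purely for a self-contained example, while the real pass matches ShardingType.

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static TypeConverter makeOneToNConverter() {
  TypeConverter converter;
  // Fallback: every type not matched below converts to itself.
  converter.addConversion([](Type type) { return type; });
  // 1:N hook: appending N types to `results` expands one value into N.
  // ComplexType stands in for shard::ShardingType here.
  converter.addConversion(
      [](ComplexType type, SmallVectorImpl<Type> &results)
          -> std::optional<LogicalResult> {
        results.push_back(type.getElementType()); // real part
        results.push_back(type.getElementType()); // imaginary part
        return success();
      });
  return converter;
}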
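The VectorToLLVM and VectorToSPIRV hunks delete the conversion patterns for vector.extractelement and vector.insertelement; vector.extract and vector.insert subsume them, taking dynamic positions directly (as the updated VectorToSCF comments show). A sketch of building the replacement op, assuming the OpFoldResult-based builder overload, which accepts constant attributes and SSA index values through the same entry point:

#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;

// Extract vec[idx] for vec : vector<Nxf32> and a dynamic idx : index,
// i.e. what vector.extractelement used to express.
static Value extractAtDynamicIndex(OpBuilder &b, Location loc, Value vec,
                                   Value idx) {
  // A Value converts implicitly to OpFoldResult; a constant position would
  // be passed as an IntegerAttr (or via the ArrayRef<int64_t> overload).
  return vector::ExtractOp::create(b, loc, vec, ArrayRef<OpFoldResult>{idx});
  // The insert side is symmetric, assuming the matching overload:
  //   vector::InsertOp::create(b, loc, scalar, vec,
  //                            ArrayRef<OpFoldResult>{idx});
}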