Diffstat (limited to 'mlir/lib/Conversion')
5 files changed, 120 insertions, 81 deletions
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index ba57155..03ed4d5 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -240,8 +240,7 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
 
 struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> {
   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
-  using Adaptor =
-      typename ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor;
+  using Adaptor = ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor;
 
   LogicalResult
   matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index 798d8b0..b75968e 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -137,8 +137,7 @@ static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
 /// op to llvm.br.
 struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
   using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
-  using Adaptor =
-      typename ConvertOpToLLVMPattern<cf::BranchOp>::OneToNOpAdaptor;
+  using Adaptor = ConvertOpToLLVMPattern<cf::BranchOp>::OneToNOpAdaptor;
 
   LogicalResult
   matchAndRewrite(cf::BranchOp op, Adaptor adaptor,
@@ -163,8 +162,7 @@ struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
 /// branch op to llvm.cond_br.
 struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
   using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
-  using Adaptor =
-      typename ConvertOpToLLVMPattern<cf::CondBranchOp>::OneToNOpAdaptor;
+  using Adaptor = ConvertOpToLLVMPattern<cf::CondBranchOp>::OneToNOpAdaptor;
 
   LogicalResult
   matchAndRewrite(cf::CondBranchOp op, Adaptor adaptor,
@@ -204,7 +202,7 @@ struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
   using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
+  matchAndRewrite(cf::SwitchOp op, cf::SwitchOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Get or convert default block.
     FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index a2dfc12..a922338 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -68,7 +68,7 @@ struct ClampFOpConversion final
     return LLVM::detail::handleMultidimensionalVectors(
         op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
         [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
-          typename math::ClampFOp::Adaptor adaptor(operands);
+          math::ClampFOp::Adaptor adaptor(operands);
           return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy,
                                         adaptor.getValue(), adaptor.getMin(),
                                         adaptor.getMax());
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 91c1aa5..1b4d1a4 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -97,57 +97,23 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   return success();
 }
 
-static xegpu::CreateNdDescOp
-createNdDescriptor(PatternRewriter &rewriter, Location loc,
-                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
-                   Operation::operand_range offsets) {
+static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
+                                                Location loc,
+                                                xegpu::TensorDescType descType,
+                                                TypedValue<MemRefType> src) {
   MemRefType srcTy = src.getType();
   auto [strides, offset] = srcTy.getStridesAndOffset();
 
   xegpu::CreateNdDescOp ndDesc;
   if (srcTy.hasStaticShape()) {
-    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
-                                           getAsOpFoldResult(offsets));
+    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
   } else {
     // In case of any dynamic shapes, source's shape and strides have to be
     // explicitly provided.
-    SmallVector<Value> sourceDims;
-    unsigned srcRank = srcTy.getRank();
-    for (unsigned i = 0; i < srcRank; ++i)
-      sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
-
-    SmallVector<int64_t> constOffsets;
-    SmallVector<Value> dynOffsets;
-    for (Value offset : offsets) {
-      std::optional<int64_t> staticVal = getConstantIntValue(offset);
-      if (!staticVal)
-        dynOffsets.push_back(offset);
-      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
-    }
-
-    SmallVector<Value> dynShapes;
-    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
-      if (shape == ShapedType::kDynamic)
-        dynShapes.push_back(sourceDims[idx]);
-    }
-
-    // Compute strides in reverse order.
-    SmallVector<Value> dynStrides;
-    Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
-    // Last stride is guaranteed to be static and unit.
-    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
-      accStride =
-          arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
-      if (strides[i] == ShapedType::kDynamic)
-        dynStrides.push_back(accStride);
-    }
-    std::reverse(dynStrides.begin(), dynStrides.end());
-
-    ndDesc = xegpu::CreateNdDescOp::create(
-        rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides,
-        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
-        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
-        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
+    auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
+    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
+                                           meta.getConstifiedMixedSizes(),
+                                           meta.getConstifiedMixedStrides());
   }
 
   return ndDesc;
@@ -392,6 +358,62 @@ static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
       .getResult();
 }
 
+// Collapses shapes of a nD memref to the target rank while applying offsets
+// for the collapsed dimensions. Returns the new memref value and the
+// remaining offsets for the last targetRank dimensions. For example:
+// input: %memref = memref<2x4x8x32xf32>, offsets=[%i0, %i1, %i2, %i3],
+// output: %memref[%i0, %i1, 0, 0] -> memref<8x32xf32>, offsets: [%i2, %i3]
+static std::pair<Value, SmallVector<OpFoldResult>>
+convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc,
+                                    Value memref,
+                                    SmallVector<OpFoldResult> offsets,
+                                    int64_t targetRank) {
+  auto memrefType = cast<MemRefType>(memref.getType());
+  unsigned rank = memrefType.getRank();
+
+  if (rank <= targetRank)
+    return {memref, offsets};
+
+  int64_t numCombinedDims = rank - targetRank;
+  SmallVector<OpFoldResult> subviewOffsets;
+  SmallVector<OpFoldResult> subviewSizes;
+  SmallVector<OpFoldResult> subviewStrides;
+
+  // For the combined dimensions: use the provided offsets, size=1, stride=1
+  for (unsigned i = 0; i < numCombinedDims; ++i) {
+    subviewOffsets.push_back(offsets[i]);
+    subviewSizes.push_back(rewriter.getI64IntegerAttr(1));
+    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
+  }
+
+  // For the last targetRank dimensions: offset=0, use full size, stride=1
+  SmallVector<int64_t> resultShape;
+  auto originalShape = memrefType.getShape();
+  auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref);
+  for (unsigned i = numCombinedDims; i < rank; ++i) {
+    subviewOffsets.push_back(rewriter.getI64IntegerAttr(0));
+    if (ShapedType::isDynamic(originalShape[i])) {
+      subviewSizes.push_back(meta.getSizes()[i]);
+      resultShape.push_back(ShapedType::kDynamic);
+    } else {
+      subviewSizes.push_back(rewriter.getI64IntegerAttr(originalShape[i]));
+      resultShape.push_back(originalShape[i]);
+    }
+    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
+  }
+
+  auto resultType = memref::SubViewOp::inferRankReducedResultType(
+      resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
+  auto subviewOp =
+      memref::SubViewOp::create(rewriter, loc, resultType, memref,
+                                subviewOffsets, subviewSizes, subviewStrides);
+
+  // Return the remaining offsets for the last targetRank dimensions
+  SmallVector<OpFoldResult> newOffsets(offsets.begin() + numCombinedDims,
+                                       offsets.end());
+  return {subviewOp.getResult(), newOffsets};
+}
+
 template <
     typename OpType,
     typename = std::enable_if_t<llvm::is_one_of<
@@ -435,7 +457,8 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
       /*chunk_size=*/IntegerAttr{},
      /*l1_hint=*/xegpu::CachePolicyAttr{},
      /*l2_hint=*/xegpu::CachePolicyAttr{},
-      /*l3_hint=*/xegpu::CachePolicyAttr{});
+      /*l3_hint=*/xegpu::CachePolicyAttr{},
+      /*layout=*/nullptr);
 
   rewriter.replaceOp(readOp, gatherOp.getResult());
   return success();
@@ -469,7 +492,8 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
      /*chunk_size=*/IntegerAttr{},
      /*l1_hint=*/xegpu::CachePolicyAttr{},
      /*l2_hint=*/xegpu::CachePolicyAttr{},
-      /*l3_hint=*/xegpu::CachePolicyAttr{});
+      /*l3_hint=*/xegpu::CachePolicyAttr{},
+      /*layout=*/nullptr);
   rewriter.eraseOp(writeOp);
   return success();
 }
@@ -523,18 +547,19 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
         descShape, elementType, /*array_length=*/1,
         /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
 
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType,
-                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
-                           readOp.getIndices());
-
     DenseI64ArrayAttr transposeAttr =
         !isTransposeLoad ? nullptr
                          : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                   ArrayRef<int64_t>{1, 0});
 
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, readOp.getBase(), getAsOpFoldResult(readOp.getIndices()),
+        vecTy.getRank());
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
+    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                           /*packed=*/nullptr, transposeAttr,
                                           /*l1_hint=*/hint,
                                           /*l2_hint=*/hint, /*l3_hint=*/hint);
@@ -575,21 +600,23 @@ struct TransferWriteLowering
     if (!map.isMinorIdentity())
       return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
 
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, writeOp.getBase(),
+        getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank());
+
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
         xegpu::MemorySpace::Global);
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType,
-                           dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
-                           writeOp.getIndices());
-
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    auto storeOp =
-        xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
-                                 /*l1_hint=*/hint,
-                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
+    auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
+                                            ndDesc, indices,
+                                            /*l1_hint=*/hint,
+                                            /*l2_hint=*/hint, /*l3_hint=*/hint);
 
     rewriter.replaceOp(writeOp, storeOp);
     return success();
@@ -621,7 +648,8 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
        /*chunk_size=*/IntegerAttr{},
        /*l1_hint=*/xegpu::CachePolicyAttr{},
        /*l2_hint=*/xegpu::CachePolicyAttr{},
-       /*l3_hint=*/xegpu::CachePolicyAttr{});
+       /*l3_hint=*/xegpu::CachePolicyAttr{},
+       /*layout=*/nullptr);
 
     auto selectOp =
         arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
@@ -655,7 +683,8 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
        /*chunk_size=*/IntegerAttr{},
        /*l1_hint=*/xegpu::CachePolicyAttr{},
        /*l2_hint=*/xegpu::CachePolicyAttr{},
-       /*l3_hint=*/xegpu::CachePolicyAttr{});
+       /*l3_hint=*/xegpu::CachePolicyAttr{},
+       /*layout=*/nullptr);
     rewriter.eraseOp(scatterOp);
     return success();
   }
@@ -674,19 +703,24 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
     // Boundary check is available only for block instructions.
     bool boundaryCheck = vecTy.getRank() > 1;
 
+    // By default, no specific caching policy is assigned.
+    xegpu::CachePolicyAttr hint = nullptr;
+
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, loadOp.getBase(), getAsOpFoldResult(loadOp.getIndices()),
+        vecTy.getRank());
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
         boundaryCheck, xegpu::MemorySpace::Global);
-    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
-    // By default, no specific caching policy is assigned.
-    xegpu::CachePolicyAttr hint = nullptr;
-    auto loadNdOp = xegpu::LoadNdOp::create(
-        rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
-        /*l1_hint=*/hint,
-        /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+    auto loadNdOp =
+        xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
+                                /*packed=*/nullptr, /*transpose=*/nullptr,
+                                /*l1_hint=*/hint,
+                                /*l2_hint=*/hint, /*l3_hint=*/hint);
 
     rewriter.replaceOp(loadOp, loadNdOp);
     return success();
@@ -708,18 +742,24 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
     // Boundary check is available only for block instructions.
     bool boundaryCheck = vecTy.getRank() > 1;
 
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, storeOp.getBase(),
+        getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank());
+
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
         boundaryCheck, xegpu::MemorySpace::Global);
-    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
     auto storeNdOp =
-        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
+        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
                                  /*l1_hint=*/hint,
                                  /*l2_hint=*/hint, /*l3_hint=*/hint);
+
     rewriter.replaceOp(storeOp, storeNdOp);
     return success();
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 33e8f2e..de552ce 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -562,6 +562,8 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
     VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
     if (!valOrResVecTy)
       valOrResVecTy = VectorType::get(1, data.getType());
+    if (valOrResVecTy.getShape().size() != 1)
+      return rewriter.notifyMatchFailure(op, "Expected 1D data vector.");
 
     int64_t elemBitWidth =
         valOrResVecTy.getElementType().getIntOrFloatBitWidth();
