diff options
Diffstat (limited to 'mlir/lib/Dialect')
22 files changed, 785 insertions, 142 deletions
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 2a8c330..f0de4db 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -320,6 +320,51 @@ LogicalResult ConvertF32x2ToF4x2Op::verify() { return success(); } +LogicalResult ConvertF8x2ToF16x2Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType())) + return emitOpError("Only ") + << mlir::Float8E4M3FNType::get(ctx) << " and " + << mlir::Float8E5M2Type::get(ctx) + << " types are supported for conversions from f8x2 to f16x2."; + + return success(); +} + +LogicalResult ConvertF8x2ToBF16x2Op::verify() { + mlir::MLIRContext *ctx = getContext(); + if (!llvm::isa<Float8E8M0FNUType>(getSrcType())) + return emitOpError("Only ") + << mlir::Float8E8M0FNUType::get(ctx) + << " type is supported for conversions from f8x2 to bf16x2."; + + return success(); +} + +LogicalResult ConvertF6x2ToF16x2Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType())) + return emitOpError("Only ") + << mlir::Float6E2M3FNType::get(ctx) << " and " + << mlir::Float6E3M2FNType::get(ctx) + << " types are supported for conversions from f6x2 to f16x2."; + + return success(); +} + +LogicalResult ConvertF4x2ToF16x2Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa<Float4E2M1FNType>(getSrcType())) + return emitOpError("Only ") + << mlir::Float4E2M1FNType::get(ctx) + << " type is supported for conversions from f4x2 to f16x2."; + + return success(); +} + LogicalResult BulkStoreOp::verify() { if (getInitVal() != 0) return emitOpError("only 0 is supported for initVal, got ") << getInitVal(); @@ -2187,6 +2232,98 @@ ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd, } } +NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op); + + bool hasRelu = curOp.getRelu(); + + llvm::Intrinsic::ID intId = + llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType()) + .Case<Float8E4M3FNType>([&](Float8E4M3FNType type) { + return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu + : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn; + }) + .Case<Float8E5M2Type>([&](Float8E5M2Type type) { + return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu + : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn; + }) + .Default([](mlir::Type type) { + llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); + + llvm::Value *packedI16 = + builder.CreateBitCast(mt.lookupValue(curOp.getSrc()), + llvm::Type::getInt16Ty(builder.getContext())); + + return {intId, {packedI16}}; +} + +NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op); + + llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2; + llvm::Value *packedI16 = + builder.CreateBitCast(mt.lookupValue(curOp.getSrc()), + llvm::Type::getInt16Ty(builder.getContext())); + + return {intId, {packedI16}}; +} + +NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op); + + bool hasRelu = curOp.getRelu(); + + llvm::Intrinsic::ID intId = + llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType()) + .Case<Float6E2M3FNType>([&](Float6E2M3FNType type) { + return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu + : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn; + }) + .Case<Float6E3M2FNType>([&](Float6E3M2FNType type) { + return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu + : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn; + }) + .Default([](mlir::Type type) { + llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); + + llvm::Value *packedI16 = + builder.CreateBitCast(mt.lookupValue(curOp.getSrc()), + llvm::Type::getInt16Ty(builder.getContext())); + + return {intId, {packedI16}}; +} + +NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op); + + bool hasRelu = curOp.getRelu(); + + llvm::Intrinsic::ID intId = + llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType()) + .Case<Float4E2M1FNType>([&](Float4E2M1FNType type) { + return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu + : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn; + }) + .Default([](mlir::Type type) { + llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); + + llvm::Value *extendedI16 = + builder.CreateZExt(mt.lookupValue(curOp.getSrc()), + llvm::Type::getInt16Ty(builder.getContext())); + + return {intId, {extendedI16}}; +} + llvm::Intrinsic::ID Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt index d4ff095..37a45d4 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt @@ -18,4 +18,5 @@ add_mlir_dialect_library(MLIRLLVMIRTransforms MLIRPass MLIRTransforms MLIRNVVMDialect + MLIROpenMPDialect ) diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 9a8a63e..794dda9 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -437,13 +437,15 @@ transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter, for (auto [pos, dim] : llvm::enumerate(type.getShape())) { if (!ShapedType::isDynamic(dim)) continue; - Value cst = rewriter.create<arith::ConstantIndexOp>(tensor.getLoc(), pos); - auto dimOp = rewriter.create<tensor::DimOp>(tensor.getLoc(), tensor, cst); + Value cst = + arith::ConstantIndexOp::create(rewriter, tensor.getLoc(), pos); + auto dimOp = + tensor::DimOp::create(rewriter, tensor.getLoc(), tensor, cst); preservedOps.insert(dimOp); dynamicDims.push_back(dimOp); } - auto allocation = rewriter.create<bufferization::AllocTensorOp>( - tensor.getLoc(), type, dynamicDims); + auto allocation = bufferization::AllocTensorOp::create( + rewriter, tensor.getLoc(), type, dynamicDims); // Set memory space if provided. if (getMemorySpaceAttr()) allocation.setMemorySpaceAttr(getMemorySpaceAttr()); @@ -452,8 +454,8 @@ transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter, // Only insert a materialization (typically bufferizes to a copy) when the // value may be read from. if (needsMaterialization) { - auto copy = rewriter.create<bufferization::MaterializeInDestinationOp>( - tensor.getLoc(), tensor, allocated); + auto copy = bufferization::MaterializeInDestinationOp::create( + rewriter, tensor.getLoc(), tensor, allocated); preservedOps.insert(copy); promoted.push_back(copy.getResult()); } else { diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp index 15eb51a..5e10ba3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" @@ -43,6 +44,33 @@ struct StructuredOpInterface auto zero = arith::ConstantIndexOp::create(builder, loc, 0); auto one = arith::ConstantIndexOp::create(builder, loc, 1); + Value iterationDomainIsNonDegenerate; + for (auto [start, end] : llvm::zip(starts, ends)) { + auto startValue = getValueOrCreateConstantIndexOp(builder, loc, start); + auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end); + + // Loop Trip count > 0 iff start < end + Value dimensionHasNonZeroTripCount = index::CmpOp::create( + builder, loc, index::IndexCmpPredicate::SLT, startValue, endValue); + + if (!iterationDomainIsNonDegenerate) { + iterationDomainIsNonDegenerate = dimensionHasNonZeroTripCount; + } else { + // Iteration domain is non-degenerate iff all dimensions have loop trip + // count > 0 + iterationDomainIsNonDegenerate = + arith::AndIOp::create(builder, loc, iterationDomainIsNonDegenerate, + dimensionHasNonZeroTripCount); + } + } + + if (!iterationDomainIsNonDegenerate) + return; + + auto ifOp = scf::IfOp::create(builder, loc, iterationDomainIsNonDegenerate, + /*withElseRegion=*/false); + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + // Subtract one from the loop ends before composing with the indexing map transform(ends, ends.begin(), [&](OpFoldResult end) { auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end); @@ -110,6 +138,7 @@ struct StructuredOpInterface builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg); } } + builder.setInsertionPointAfter(ifOp); } }; diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 94947b7..c551fba 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1437,6 +1437,13 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor, atLeastOneReplacement |= replaceConstantUsesOf( builder, getLoc(), getStrides(), getConstifiedMixedStrides()); + // extract_strided_metadata(cast(x)) -> extract_strided_metadata(x). + if (auto prev = getSource().getDefiningOp<CastOp>()) + if (isa<MemRefType>(prev.getSource().getType())) { + getSourceMutable().assign(prev.getSource()); + atLeastOneReplacement = true; + } + return success(atLeastOneReplacement); } @@ -1744,11 +1751,11 @@ OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) { } TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() { - return cast<TypedValue<PtrLikeTypeInterface>>(getSource()); + return getSource(); } TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() { - return cast<TypedValue<PtrLikeTypeInterface>>(getDest()); + return getDest(); } bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt, diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index d35566a..bd02516 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -1033,91 +1033,6 @@ class ExtractStridedMetadataOpReinterpretCastFolder } }; -/// Replace `base, offset, sizes, strides = -/// extract_strided_metadata( -/// cast(src) to dstTy)` -/// With -/// ``` -/// base, ... = extract_strided_metadata(src) -/// offset = !dstTy.srcOffset.isDynamic() -/// ? dstTy.srcOffset -/// : extract_strided_metadata(src).offset -/// sizes = for each srcSize in dstTy.srcSizes: -/// !srcSize.isDynamic() -/// ? srcSize -// : extract_strided_metadata(src).sizes[i] -/// strides = for each srcStride in dstTy.srcStrides: -/// !srcStrides.isDynamic() -/// ? srcStrides -/// : extract_strided_metadata(src).strides[i] -/// ``` -/// -/// In other words, consume the `cast` and apply its effects -/// on the offset, sizes, and strides or compute them directly from `src`. -class ExtractStridedMetadataOpCastFolder - : public OpRewritePattern<memref::ExtractStridedMetadataOp> { - using OpRewritePattern::OpRewritePattern; - - LogicalResult - matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp, - PatternRewriter &rewriter) const override { - Value source = extractStridedMetadataOp.getSource(); - auto castOp = source.getDefiningOp<memref::CastOp>(); - if (!castOp) - return failure(); - - Location loc = extractStridedMetadataOp.getLoc(); - // Check if the source is suitable for extract_strided_metadata. - SmallVector<Type> inferredReturnTypes; - if (failed(extractStridedMetadataOp.inferReturnTypes( - rewriter.getContext(), loc, {castOp.getSource()}, - /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{}, - inferredReturnTypes))) - return rewriter.notifyMatchFailure(castOp, - "cast source's type is incompatible"); - - auto memrefType = cast<MemRefType>(source.getType()); - unsigned rank = memrefType.getRank(); - SmallVector<OpFoldResult> results; - results.resize_for_overwrite(rank * 2 + 2); - - auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create( - rewriter, loc, castOp.getSource()); - - // Register the base_buffer. - results[0] = newExtractStridedMetadata.getBaseBuffer(); - - auto getConstantOrValue = [&rewriter](int64_t constant, - OpFoldResult ofr) -> OpFoldResult { - return ShapedType::isStatic(constant) - ? OpFoldResult(rewriter.getIndexAttr(constant)) - : ofr; - }; - - auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset(); - assert(sourceStrides.size() == rank && "unexpected number of strides"); - - // Register the new offset. - results[1] = - getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset()); - - const unsigned sizeStartIdx = 2; - const unsigned strideStartIdx = sizeStartIdx + rank; - ArrayRef<int64_t> sourceSizes = memrefType.getShape(); - - SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes(); - SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides(); - for (unsigned i = 0; i < rank; ++i) { - results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]); - results[strideStartIdx + i] = - getConstantOrValue(sourceStrides[i], strides[i]); - } - rewriter.replaceOp(extractStridedMetadataOp, - getValueOrCreateConstantIndexOp(rewriter, loc, results)); - return success(); - } -}; - /// Replace `base, offset, sizes, strides = extract_strided_metadata( /// memory_space_cast(src) to dstTy)` /// with @@ -1209,7 +1124,6 @@ void memref::populateExpandStridedMetadataPatterns( RewriteExtractAlignedPointerAsIndexOfViewLikeOp, ExtractStridedMetadataOpReinterpretCastFolder, ExtractStridedMetadataOpSubviewFolder, - ExtractStridedMetadataOpCastFolder, ExtractStridedMetadataOpMemorySpaceCastFolder, ExtractStridedMetadataOpAssumeAlignmentFolder, ExtractStridedMetadataOpExtractStridedMetadataFolder>( @@ -1226,7 +1140,6 @@ void memref::populateResolveExtractStridedMetadataPatterns( ExtractStridedMetadataOpSubviewFolder, RewriteExtractAlignedPointerAsIndexOfViewLikeOp, ExtractStridedMetadataOpReinterpretCastFolder, - ExtractStridedMetadataOpCastFolder, ExtractStridedMetadataOpMemorySpaceCastFolder, ExtractStridedMetadataOpAssumeAlignmentFolder, ExtractStridedMetadataOpExtractStridedMetadataFolder>( diff --git a/mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt new file mode 100644 index 0000000..f305068 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_dialect_library(MLIROpenACCAnalysis + OpenACCSupport.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC + + LINK_LIBS PUBLIC + MLIRIR + MLIROpenACCDialect + MLIROpenACCUtils + MLIRSupport +) + diff --git a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp new file mode 100644 index 0000000..f6b4534 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp @@ -0,0 +1,26 @@ +//===- OpenACCSupport.cpp - OpenACCSupport Implementation -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the OpenACCSupport analysis interface. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h" +#include "mlir/Dialect/OpenACC/OpenACCUtils.h" + +namespace mlir { +namespace acc { + +std::string OpenACCSupport::getVariableName(Value v) { + if (impl) + return impl->getVariableName(v); + return acc::getVariableName(v); +} + +} // namespace acc +} // namespace mlir diff --git a/mlir/lib/Dialect/OpenACC/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/CMakeLists.txt index 7117520..e8a916e 100644 --- a/mlir/lib/Dialect/OpenACC/CMakeLists.txt +++ b/mlir/lib/Dialect/OpenACC/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(Analysis) add_subdirectory(IR) add_subdirectory(Utils) add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 5ca0100..ca46629 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -610,6 +610,20 @@ LogicalResult acc::FirstprivateOp::verify() { } //===----------------------------------------------------------------------===// +// FirstprivateMapInitialOp +//===----------------------------------------------------------------------===// +LogicalResult acc::FirstprivateMapInitialOp::verify() { + if (getDataClause() != acc::DataClause::acc_firstprivate) + return emitError("data clause associated with firstprivate operation must " + "match its intent"); + if (failed(checkVarAndVarType(*this))) + return failure(); + if (failed(checkNoModifier(*this))) + return failure(); + return success(); +} + +//===----------------------------------------------------------------------===// // ReductionOp //===----------------------------------------------------------------------===// LogicalResult acc::ReductionOp::verify() { diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp index 1223325..89adda82 100644 --- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp +++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/OpenACC/OpenACCUtils.h" #include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/TypeSwitch.h" mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region ®ion) { @@ -78,3 +79,30 @@ mlir::acc::VariableTypeCategory mlir::acc::getTypeCategory(mlir::Value var) { pointerLikeTy.getElementType()); return typeCategory; } + +std::string mlir::acc::getVariableName(mlir::Value v) { + Value current = v; + + // Walk through view operations until a name is found or can't go further + while (Operation *definingOp = current.getDefiningOp()) { + // Check for `acc.var_name` attribute + if (auto varNameAttr = + definingOp->getAttrOfType<VarNameAttr>(getVarNameAttrName())) + return varNameAttr.getName().str(); + + // If it is a data entry operation, get name via getVarName + if (isa<ACC_DATA_ENTRY_OPS>(definingOp)) + if (auto name = acc::getVarName(definingOp)) + return name->str(); + + // If it's a view operation, continue to the source + if (auto viewOp = dyn_cast<ViewLikeOpInterface>(definingOp)) { + current = viewOp.getViewSource(); + continue; + } + + break; + } + + return ""; +} diff --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt index 57a6d34..f3c02da 100644 --- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt +++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt @@ -1,3 +1,5 @@ +add_subdirectory(Transforms) + add_mlir_dialect_library(MLIROpenMPDialect IR/OpenMPDialect.cpp diff --git a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt new file mode 100644 index 0000000..b9b8eda --- /dev/null +++ b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_dialect_library(MLIROpenMPTransforms + OpenMPOffloadPrivatizationPrepare.cpp + + DEPENDS + MLIROpenMPPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRFuncDialect + MLIRLLVMDialect + MLIROpenMPDialect + MLIRPass + MLIRTransforms + ) diff --git a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp new file mode 100644 index 0000000..a9125ec --- /dev/null +++ b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp @@ -0,0 +1,445 @@ +//===- OpenMPOffloadPrivatizationPrepare.cpp - Prepare OMP privatization --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "llvm/Support/DebugLog.h" +#include "llvm/Support/FormatVariadic.h" +#include <cstdint> +#include <iterator> +#include <utility> + +//===----------------------------------------------------------------------===// +// A pass that prepares OpenMP code for translation of delayed privatization +// in the context of deferred target tasks. Deferred target tasks are created +// when the nowait clause is used on the target directive. +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "omp-prepare-for-offload-privatization" + +namespace mlir { +namespace omp { + +#define GEN_PASS_DEF_PREPAREFOROMPOFFLOADPRIVATIZATIONPASS +#include "mlir/Dialect/OpenMP/Transforms/Passes.h.inc" + +} // namespace omp +} // namespace mlir + +using namespace mlir; +namespace { + +//===----------------------------------------------------------------------===// +// PrepareForOMPOffloadPrivatizationPass +//===----------------------------------------------------------------------===// + +class PrepareForOMPOffloadPrivatizationPass + : public omp::impl::PrepareForOMPOffloadPrivatizationPassBase< + PrepareForOMPOffloadPrivatizationPass> { + + void runOnOperation() override { + ModuleOp mod = getOperation(); + + // In this pass, we make host-allocated privatized variables persist for + // deferred target tasks by copying them to the heap. Once the target task + // is done, this heap memory is freed. Since all of this happens on the host + // we can skip device modules. + auto offloadModuleInterface = + dyn_cast<omp::OffloadModuleInterface>(mod.getOperation()); + if (offloadModuleInterface && offloadModuleInterface.getIsTargetDevice()) + return; + + getOperation()->walk([&](omp::TargetOp targetOp) { + if (!hasPrivateVars(targetOp) || !isTargetTaskDeferred(targetOp)) + return; + IRRewriter rewriter(&getContext()); + OperandRange privateVars = targetOp.getPrivateVars(); + SmallVector<mlir::Value> newPrivVars; + Value fakeDependVar; + omp::TaskOp cleanupTaskOp; + + newPrivVars.reserve(privateVars.size()); + std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms(); + for (auto [privVarIdx, privVarSymPair] : + llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) { + Value privVar = std::get<0>(privVarSymPair); + Attribute privSym = std::get<1>(privVarSymPair); + + omp::PrivateClauseOp privatizer = findPrivatizer(targetOp, privSym); + if (!privatizer.needsMap()) { + newPrivVars.push_back(privVar); + continue; + } + bool isFirstPrivate = privatizer.getDataSharingType() == + omp::DataSharingClauseType::FirstPrivate; + + Value mappedValue = targetOp.getMappedValueForPrivateVar(privVarIdx); + auto mapInfoOp = cast<omp::MapInfoOp>(mappedValue.getDefiningOp()); + + if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy) { + newPrivVars.push_back(privVar); + continue; + } + + // For deferred target tasks (!$omp target nowait), we need to keep + // a copy of the original, i.e. host variable being privatized so + // that it is available when the target task is eventually executed. + // We do this by first allocating as much heap memory as is needed by + // the original variable. Then, we use the init and copy regions of the + // privatizer, an instance of omp::PrivateClauseOp to set up the heap- + // allocated copy. + // After the target task is done, we need to use the dealloc region + // of the privatizer to clean up everything. We also need to free + // the heap memory we allocated. But due to the deferred nature + // of the target task, we cannot simply deallocate right after the + // omp.target operation else we may end up freeing memory before + // its eventual use by the target task. So, we create a dummy + // dependence between the target task and new omp.task. In the omp.task, + // we do all the cleanup. So, we end up with the following structure + // + // omp.target map_entries(..) ... nowait depend(out:fakeDependVar) { + // ... + // omp.terminator + // } + // omp.task depend(in: fakeDependVar) { + // /*cleanup_code*/ + // omp.terminator + // } + // fakeDependVar is the address of the first heap-allocated copy of the + // host variable being privatized. + + bool needsCleanupTask = !privatizer.getDeallocRegion().empty(); + + // Allocate heap memory that corresponds to the type of memory + // pointed to by varPtr + // For boxchars this won't be a pointer. But, MapsForPrivatizedSymbols + // should have mapped the pointer to the boxchar so use that as varPtr. + Value varPtr = mapInfoOp.getVarPtr(); + Type varType = mapInfoOp.getVarType(); + bool isPrivatizedByValue = + !isa<LLVM::LLVMPointerType>(privVar.getType()); + + assert(isa<LLVM::LLVMPointerType>(varPtr.getType())); + Value heapMem = + allocateHeapMem(targetOp, varPtr, varType, mod, rewriter); + if (!heapMem) + targetOp.emitError( + "Unable to allocate heap memory when trying to move " + "a private variable out of the stack and into the " + "heap for use by a deferred target task"); + + if (needsCleanupTask && !fakeDependVar) + fakeDependVar = heapMem; + + // The types of private vars should match before and after the + // transformation. In particular, if the type is a pointer, + // simply record the newly allocated malloc location as the + // new private variable. If, however, the type is not a pointer + // then, we need to load the value from the newly allocated + // location. We'll insert that load later after we have updated + // the malloc'd location with the contents of the original + // variable. + if (!isPrivatizedByValue) + newPrivVars.push_back(heapMem); + + // We now need to copy the original private variable into the newly + // allocated location in the heap. + // Find the earliest insertion point for the copy. This will be before + // the first in the list of omp::MapInfoOp instances that use varPtr. + // After the copy these omp::MapInfoOp instances will refer to heapMem + // instead. + Operation *varPtrDefiningOp = varPtr.getDefiningOp(); + DenseSet<Operation *> users; + if (varPtrDefiningOp) { + users.insert(varPtrDefiningOp->user_begin(), + varPtrDefiningOp->user_end()); + } else { + auto blockArg = cast<BlockArgument>(varPtr); + users.insert(blockArg.user_begin(), blockArg.user_end()); + } + auto usesVarPtr = [&users](Operation *op) -> bool { + return users.count(op); + }; + + SmallVector<Operation *> chainOfOps; + chainOfOps.push_back(mapInfoOp); + for (auto member : mapInfoOp.getMembers()) { + omp::MapInfoOp memberMap = + cast<omp::MapInfoOp>(member.getDefiningOp()); + if (usesVarPtr(memberMap)) + chainOfOps.push_back(memberMap); + if (memberMap.getVarPtrPtr()) { + Operation *defOp = memberMap.getVarPtrPtr().getDefiningOp(); + if (defOp && usesVarPtr(defOp)) + chainOfOps.push_back(defOp); + } + } + + DominanceInfo dom; + llvm::sort(chainOfOps, [&](Operation *l, Operation *r) { + return dom.dominates(l, r); + }); + + rewriter.setInsertionPoint(chainOfOps.front()); + + Operation *firstOp = chainOfOps.front(); + Location loc = firstOp->getLoc(); + + // Create a llvm.func for 'region' that is marked always_inline and call + // it. + auto createAlwaysInlineFuncAndCallIt = + [&](Region ®ion, llvm::StringRef funcName, + llvm::ArrayRef<Value> args, bool returnsValue) -> Value { + assert(!region.empty() && "region cannot be empty"); + LLVM::LLVMFuncOp func = createFuncOpForRegion( + loc, mod, region, funcName, rewriter, returnsValue); + auto call = LLVM::CallOp::create(rewriter, loc, func, args); + return call.getResult(); + }; + + Value moldArg, newArg; + if (isPrivatizedByValue) { + moldArg = LLVM::LoadOp::create(rewriter, loc, varType, varPtr); + newArg = LLVM::LoadOp::create(rewriter, loc, varType, heapMem); + } else { + moldArg = varPtr; + newArg = heapMem; + } + + Value initializedVal; + if (!privatizer.getInitRegion().empty()) + initializedVal = createAlwaysInlineFuncAndCallIt( + privatizer.getInitRegion(), + llvm::formatv("{0}_{1}", privatizer.getSymName(), "init").str(), + {moldArg, newArg}, /*returnsValue=*/true); + else + initializedVal = newArg; + + if (isFirstPrivate && !privatizer.getCopyRegion().empty()) + initializedVal = createAlwaysInlineFuncAndCallIt( + privatizer.getCopyRegion(), + llvm::formatv("{0}_{1}", privatizer.getSymName(), "copy").str(), + {moldArg, initializedVal}, /*returnsValue=*/true); + + if (isPrivatizedByValue) + (void)LLVM::StoreOp::create(rewriter, loc, initializedVal, heapMem); + + // clone origOp, replace all uses of varPtr with heapMem and + // erase origOp. + auto cloneModifyAndErase = [&](Operation *origOp) -> Operation * { + Operation *clonedOp = rewriter.clone(*origOp); + rewriter.replaceAllOpUsesWith(origOp, clonedOp); + rewriter.modifyOpInPlace(clonedOp, [&]() { + clonedOp->replaceUsesOfWith(varPtr, heapMem); + }); + rewriter.eraseOp(origOp); + return clonedOp; + }; + + // Now that we have set up the heap-allocated copy of the private + // variable, rewrite all the uses of the original variable with + // the heap-allocated variable. + rewriter.setInsertionPoint(targetOp); + mapInfoOp = cast<omp::MapInfoOp>(cloneModifyAndErase(mapInfoOp)); + rewriter.setInsertionPoint(mapInfoOp); + + // Fix any members that may use varPtr to now use heapMem + for (auto member : mapInfoOp.getMembers()) { + auto memberMapInfoOp = cast<omp::MapInfoOp>(member.getDefiningOp()); + if (!usesVarPtr(memberMapInfoOp)) + continue; + memberMapInfoOp = + cast<omp::MapInfoOp>(cloneModifyAndErase(memberMapInfoOp)); + rewriter.setInsertionPoint(memberMapInfoOp); + + if (memberMapInfoOp.getVarPtrPtr()) { + Operation *varPtrPtrdefOp = + memberMapInfoOp.getVarPtrPtr().getDefiningOp(); + rewriter.setInsertionPoint(cloneModifyAndErase(varPtrPtrdefOp)); + } + } + + // If the type of the private variable is not a pointer, + // which is typically the case with !fir.boxchar types, then + // we need to ensure that the new private variable is also + // not a pointer. Insert a load from heapMem right before + // targetOp. + if (isPrivatizedByValue) { + rewriter.setInsertionPoint(targetOp); + auto newPrivVar = LLVM::LoadOp::create(rewriter, mapInfoOp.getLoc(), + varType, heapMem); + newPrivVars.push_back(newPrivVar); + } + + // Deallocate + if (needsCleanupTask) { + if (!cleanupTaskOp) { + assert(fakeDependVar && + "Need a valid value to set up a dependency"); + rewriter.setInsertionPointAfter(targetOp); + omp::TaskOperands taskOperands; + auto inDepend = omp::ClauseTaskDependAttr::get( + rewriter.getContext(), omp::ClauseTaskDepend::taskdependin); + taskOperands.dependKinds.push_back(inDepend); + taskOperands.dependVars.push_back(fakeDependVar); + cleanupTaskOp = omp::TaskOp::create(rewriter, loc, taskOperands); + Block *taskBlock = rewriter.createBlock(&cleanupTaskOp.getRegion()); + rewriter.setInsertionPointToEnd(taskBlock); + omp::TerminatorOp::create(rewriter, cleanupTaskOp.getLoc()); + } + rewriter.setInsertionPointToStart( + &*cleanupTaskOp.getRegion().getBlocks().begin()); + (void)createAlwaysInlineFuncAndCallIt( + privatizer.getDeallocRegion(), + llvm::formatv("{0}_{1}", privatizer.getSymName(), "dealloc") + .str(), + {initializedVal}, /*returnsValue=*/false); + llvm::FailureOr<LLVM::LLVMFuncOp> freeFunc = + LLVM::lookupOrCreateFreeFn(rewriter, mod); + assert(llvm::succeeded(freeFunc) && + "Could not find free in the module"); + (void)LLVM::CallOp::create(rewriter, loc, freeFunc.value(), + ValueRange{heapMem}); + } + } + assert(newPrivVars.size() == privateVars.size() && + "The number of private variables must match before and after " + "transformation"); + if (fakeDependVar) { + omp::ClauseTaskDependAttr outDepend = omp::ClauseTaskDependAttr::get( + rewriter.getContext(), omp::ClauseTaskDepend::taskdependout); + SmallVector<Attribute> newDependKinds; + if (!targetOp.getDependVars().empty()) { + std::optional<ArrayAttr> dependKinds = targetOp.getDependKinds(); + assert(dependKinds && "bad depend clause in omp::TargetOp"); + llvm::copy(*dependKinds, std::back_inserter(newDependKinds)); + } + newDependKinds.push_back(outDepend); + ArrayAttr newDependKindsAttr = + ArrayAttr::get(rewriter.getContext(), newDependKinds); + targetOp.getDependVarsMutable().append(fakeDependVar); + targetOp.setDependKindsAttr(newDependKindsAttr); + } + rewriter.setInsertionPoint(targetOp); + targetOp.getPrivateVarsMutable().clear(); + targetOp.getPrivateVarsMutable().assign(newPrivVars); + }); + } + +private: + bool hasPrivateVars(omp::TargetOp targetOp) const { + return !targetOp.getPrivateVars().empty(); + } + + bool isTargetTaskDeferred(omp::TargetOp targetOp) const { + return targetOp.getNowait(); + } + + template <typename OpTy> + omp::PrivateClauseOp findPrivatizer(OpTy op, Attribute privSym) const { + SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym); + omp::PrivateClauseOp privatizer = + SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>( + op, privatizerName); + return privatizer; + } + + // Get the (compile-time constant) size of varType as per the + // given DataLayout dl. + std::int64_t getSizeInBytes(const DataLayout &dl, Type varType) const { + llvm::TypeSize size = dl.getTypeSize(varType); + unsigned short alignment = dl.getTypeABIAlignment(varType); + return llvm::alignTo(size, alignment); + } + + LLVM::LLVMFuncOp getMalloc(ModuleOp mod, IRRewriter &rewriter) const { + llvm::FailureOr<LLVM::LLVMFuncOp> mallocCall = + LLVM::lookupOrCreateMallocFn(rewriter, mod, rewriter.getI64Type()); + assert(llvm::succeeded(mallocCall) && + "Could not find malloc in the module"); + return mallocCall.value(); + } + + Value allocateHeapMem(omp::TargetOp targetOp, Value privVar, Type varType, + ModuleOp mod, IRRewriter &rewriter) const { + OpBuilder::InsertionGuard guard(rewriter); + Value varPtr = privVar; + Operation *definingOp = varPtr.getDefiningOp(); + BlockArgument blockArg; + if (!definingOp) { + blockArg = mlir::dyn_cast<BlockArgument>(varPtr); + rewriter.setInsertionPointToStart(blockArg.getParentBlock()); + } else { + rewriter.setInsertionPoint(definingOp); + } + Location loc = definingOp ? definingOp->getLoc() : blockArg.getLoc(); + LLVM::LLVMFuncOp mallocFn = getMalloc(mod, rewriter); + + assert(mod.getDataLayoutSpec() && + "MLIR module with no datalayout spec not handled yet"); + + const DataLayout &dl = DataLayout(mod); + std::int64_t distance = getSizeInBytes(dl, varType); + + Value sizeBytes = LLVM::ConstantOp::create( + rewriter, loc, mallocFn.getFunctionType().getParamType(0), distance); + + auto mallocCallOp = + LLVM::CallOp::create(rewriter, loc, mallocFn, ValueRange{sizeBytes}); + return mallocCallOp.getResult(); + } + + // Create a function for srcRegion and attribute it to be always_inline. + // The big assumption here is that srcRegion is one of init, copy or dealloc + // regions of a omp::PrivateClauseop. Accordingly, the return type is assumed + // to either be the same as the types of the two arguments of the region (for + // init and copy regions) or void as would be the case for dealloc regions. + LLVM::LLVMFuncOp createFuncOpForRegion(Location loc, ModuleOp mod, + Region &srcRegion, + llvm::StringRef funcName, + IRRewriter &rewriter, + bool returnsValue = false) { + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end()); + Region clonedRegion; + IRMapping mapper; + srcRegion.cloneInto(&clonedRegion, mapper); + + SmallVector<Type> paramTypes; + llvm::copy(srcRegion.getArgumentTypes(), std::back_inserter(paramTypes)); + Type resultType = returnsValue + ? srcRegion.getArgument(0).getType() + : LLVM::LLVMVoidType::get(rewriter.getContext()); + LLVM::LLVMFunctionType funcType = + LLVM::LLVMFunctionType::get(resultType, paramTypes); + + LLVM::LLVMFuncOp func = + LLVM::LLVMFuncOp::create(rewriter, loc, funcName, funcType); + func.setAlwaysInline(true); + rewriter.inlineRegionBefore(clonedRegion, func.getRegion(), + func.getRegion().end()); + for (auto &block : func.getRegion().getBlocks()) { + if (isa<omp::YieldOp>(block.getTerminator())) { + omp::YieldOp yieldOp = cast<omp::YieldOp>(block.getTerminator()); + rewriter.setInsertionPoint(yieldOp); + rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(yieldOp, TypeRange(), + yieldOp.getOperands()); + } + } + return func; + } +}; +} // namespace diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index a9da6c2..9bd13f3 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -2490,8 +2490,8 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> { changed = true; if (!constantTrue) - constantTrue = rewriter.create<arith::ConstantOp>( - op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)); + constantTrue = arith::ConstantOp::create( + rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)); rewriter.modifyOpInPlace(use.getOwner(), [&]() { use.set(constantTrue); }); @@ -2500,8 +2500,8 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> { changed = true; if (!constantFalse) - constantFalse = rewriter.create<arith::ConstantOp>( - op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0)); + constantFalse = arith::ConstantOp::create( + rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0)); rewriter.modifyOpInPlace(use.getOwner(), [&]() { use.set(constantFalse); }); diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp index 5dc61a2..335ca1a 100644 --- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp +++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp @@ -69,10 +69,10 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder, Sharding sourceSharding, TypedValue<ShapedType> sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) { - TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>( + TypedValue<ShapedType> targetShard = AllSliceOp::create(builder, sourceShard, grid, ArrayRef<GridAxis>(splitGridAxis), splitTensorAxis) - .getResult()); + .getResult(); Sharding targetSharding = targetShardingInSplitLastAxis( builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis); return {targetShard, targetSharding}; @@ -204,9 +204,8 @@ static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding( APInt(64, splitTensorAxis)); ShapedType targetShape = shardShapedType(sourceUnshardedShape, grid, targetSharding); - TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>( - tensor::CastOp::create(builder, targetShape, allGatherResult) - .getResult()); + TypedValue<ShapedType> targetShard = + tensor::CastOp::create(builder, targetShape, allGatherResult).getResult(); return {targetShard, targetSharding}; } @@ -336,8 +335,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis)); ShapedType targetShape = shardShapedType(sourceUnshardedShape, grid, targetSharding); - TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>( - tensor::CastOp::create(builder, targetShape, allToAllResult).getResult()); + TypedValue<ShapedType> targetShard = + tensor::CastOp::create(builder, targetShape, allToAllResult).getResult(); return {targetShard, targetSharding}; } @@ -510,8 +509,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source, auto targetSharding = target.getSharding(); ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder); return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding, - cast<TypedValue<ShapedType>>(source.getSrc()), - sourceShardValue); + source.getSrc(), sourceShardValue); } TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source, diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index caf8016..99b7cda 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1001,8 +1001,12 @@ OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) { !outputTy.hasStaticShape()) return {}; - if (inputTy.getDimSize(getAxis()) == 1) - return DenseElementsAttr::get(outputTy, 0); + const Type outputElementTy = getElementTypeOrSelf(outputTy); + if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) { + const auto outputElemIntTy = cast<IntegerType>(outputElementTy); + const APInt zero = APInt::getZero(outputElemIntTy.getWidth()); + return DenseElementsAttr::get(outputTy, zero); + } return {}; } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp index 8f46ad6..ef49c86 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp @@ -74,9 +74,9 @@ struct MixedSizeInputShuffleOpRewrite final for (int64_t i = 0; i < origNumElems; ++i) promoteMask[i] = i; - Value promotedInput = rewriter.create<vector::ShuffleOp>( - shuffleOp.getLoc(), promotedType, inputToPromote, inputToPromote, - promoteMask); + Value promotedInput = + vector::ShuffleOp::create(rewriter, shuffleOp.getLoc(), promotedType, + inputToPromote, inputToPromote, promoteMask); // Create the final shuffle with the promoted inputs. Value promotedV1 = promoteV1 ? promotedInput : shuffleOp.getV1(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 7c019e7..8b5e950 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -341,13 +341,18 @@ private: /// Return the distributed vector type based on the original type and the /// distribution map. The map is expected to have a dimension equal to the /// original type rank and should be a projection where the results are the -/// distributed dimensions. The number of results should be equal to the number +/// distributed dimensions. If the number of results is zero there is no +/// distribution (i.e. original type is returned). +/// Otherwise, The number of results should be equal to the number /// of warp sizes which is currently limited to 1. /// Example: For a vector<16x32x64> distributed with a map(d0, d1, d2) -> (d1) /// and a warp size of 16 would distribute the second dimension (associated to /// d1) and return vector<16x2x64> static VectorType getDistributedType(VectorType originalType, AffineMap map, int64_t warpSize) { + // If the map has zero results, return the original type. + if (map.getNumResults() == 0) + return originalType; SmallVector<int64_t> targetShape(originalType.getShape()); for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { unsigned position = map.getDimPosition(i); diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 1599ae9..24e9095 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -736,7 +736,7 @@ OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc, OpBuilder &builder) { auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a); auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b); - return builder.create<ArithOp>(loc, aVal, bVal).getResult(); + return ArithOp::create(builder, loc, aVal, bVal).getResult(); } // a helper utility to perform division operation on OpFoldResult and int64_t. diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index 26770b3..d09dc19 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -1505,14 +1505,19 @@ void XeGPUSubgroupDistributePass::runOnOperation() { return AffineMap::get(val.getContext()); // Get the layout of the vector type. xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val); - // If no layout is specified, assume the inner most dimension is distributed - // for now. + // If no layout is specified, that means no distribution. if (!layout) - return AffineMap::getMultiDimMapWithTargets( - vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext()); + return AffineMap::getMultiDimMapWithTargets(vecRank, {}, + val.getContext()); + // Expecting vector and layout rank to match. + assert(layout.getRank() == vecRank && + "Expecting vector and layout rank to match"); + // A dimension is distributed only if layout suggests there are + // multiple lanes assigned for this dimension and the shape can be evenly + // distributed to those lanes. SmallVector<unsigned int> distributedDims; for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) { - if (v > 1) + if (v > 1 && vecType.getShape()[i] % v == 0) distributedDims.push_back(i); } return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims, @@ -1525,15 +1530,13 @@ void XeGPUSubgroupDistributePass::runOnOperation() { auto warpReduction = [](Location loc, OpBuilder &builder, Value input, vector::CombiningKind kind, uint32_t size) { // First reduce on a single thread to get per lane reduction value. - Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input); + Value laneVal = vector::ReductionOp::create(builder, loc, kind, input); // Parallel reduction using butterfly shuffles. for (uint64_t i = 1; i < size; i <<= 1) { - Value shuffled = - builder - .create<gpu::ShuffleOp>(loc, laneVal, i, - /*width=*/size, - /*mode=*/gpu::ShuffleMode::XOR) - .getShuffleResult(); + Value shuffled = gpu::ShuffleOp::create(builder, loc, laneVal, i, + /*width=*/size, + /*mode=*/gpu::ShuffleMode::XOR) + .getShuffleResult(); laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled); } return laneVal; diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 31a967d..9fc5ad9 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -825,7 +825,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType), baseTileValues); - auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr); + auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr); // Get subgroup id Value sgId = @@ -837,25 +837,26 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { SmallVector<Value, 2> strideConsts; strideConsts.push_back( - rewriter.create<arith::ConstantIndexOp>(loc, colStride)); + arith::ConstantIndexOp::create(rewriter, loc, colStride)); if (rows > 1) strideConsts.insert( strideConsts.begin(), - rewriter.create<arith::ConstantIndexOp>(loc, rowStride)); + arith::ConstantIndexOp::create(rewriter, loc, rowStride)); SmallVector<Value> newConstOps; for (auto offsets : *sgOffsets) { // Multiply offset with stride, broadcast it and add to baseConstVec - Value mulOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0); + Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0); for (size_t i = 0; i < strideConsts.size(); ++i) { - Value mul = rewriter.create<arith::MulIOp>( - loc, rewriter.getIndexType(), offsets[i], strideConsts[i]); - mulOffset = rewriter.create<arith::AddIOp>( - loc, rewriter.getIndexType(), mulOffset, mul); + Value mul = + arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(), + offsets[i], strideConsts[i]); + mulOffset = arith::AddIOp::create( + rewriter, loc, rewriter.getIndexType(), mulOffset, mul); } // Broadcast to baseConstVec size - auto bcastOffset = rewriter.create<vector::BroadcastOp>( - loc, baseConstVec.getType(), mulOffset); + auto bcastOffset = vector::BroadcastOp::create( + rewriter, loc, baseConstVec.getType(), mulOffset); auto finalConst = arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset); setLayoutIfNeeded(baseConstVec); @@ -1138,8 +1139,8 @@ struct WgToSgVectorShapeCastOp SmallVector<Value> newShapeCastOps; for (auto src : adaptor.getSource()) { - auto newShapeCast = - rewriter.create<vector::ShapeCastOp>(op.getLoc(), newResultType, src); + auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(), + newResultType, src); if (!layout.getEffectiveLaneLayoutAsInt().empty() || !layout.getEffectiveInstDataAsInt().empty()) xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0), @@ -1201,9 +1202,9 @@ struct WgToSgMultiDimReductionOp SmallVector<Value> newReductions; for (auto sgSrc : adaptor.getSource()) { - auto newOp = rewriter.create<vector::MultiDimReductionOp>( - op.getLoc(), newDstType, op.getKind(), sgSrc, adaptor.getAcc()[0], - op.getReductionDims()); + auto newOp = vector::MultiDimReductionOp::create( + rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc, + adaptor.getAcc()[0], op.getReductionDims()); if (!layout.getEffectiveLaneLayoutAsInt().empty() || !layout.getEffectiveInstDataAsInt().empty()) xegpu::setDistributeLayoutAttr(newOp->getResult(0), |