Diffstat (limited to 'mlir/lib/Dialect')
35 files changed, 1586 insertions, 361 deletions
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 61166db..585b6da 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -360,45 +360,53 @@ LogicalResult ScaledExtPacked816Op::verify() { //===----------------------------------------------------------------------===// // WMMAOp //===----------------------------------------------------------------------===// -LogicalResult WMMAOp::verify() { - Type sourceAType = getSourceA().getType(); - Type sourceBType = getSourceB().getType(); - Type destType = getDestC().getType(); - VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType); - VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType); - VectorType destVectorType = dyn_cast<VectorType>(destType); +ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser, + IntegerAttr &m, IntegerAttr &n, + IntegerAttr &k) { + SmallVector<int64_t, 3> dimensions; + if (parser.parseDimensionList(dimensions, false, false)) + return failure(); + if (dimensions.size() != 3) + return parser.emitError(parser.getCurrentLocation()) + << "expected 3 dimensions in MNK dimension list"; - Type sourceAElemType = sourceVectorAType.getElementType(); - Type sourceBElemType = sourceVectorBType.getElementType(); - Type destElemType = destVectorType.getElementType(); + m = parser.getBuilder().getI32IntegerAttr(dimensions[0]); + n = parser.getBuilder().getI32IntegerAttr(dimensions[1]); + k = parser.getBuilder().getI32IntegerAttr(dimensions[2]); + return success(); +} - if (sourceVectorAType.getNumElements() != - sourceVectorBType.getNumElements()) { +LogicalResult WMMAOp::verify() { + auto sourceAType = cast<VectorType>(getSourceA().getType()); + auto sourceBType = cast<VectorType>(getSourceB().getType()); + auto destType = cast<VectorType>(getDestC().getType()); + + Type sourceAElemType = sourceAType.getElementType(); + Type sourceBElemType = sourceBType.getElementType(); + if (sourceAType.getNumElements() != sourceBType.getNumElements()) { return emitOpError("source vectors have different lengths: ") - << sourceVectorAType << " vs. " << sourceVectorBType; + << sourceAType << " vs. 
" << sourceBType; } - bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType); - bool isSrcFloat = - isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>( - sourceAElemType); - - if (isDestFloat && !isSrcFloat) { - return emitOpError("Expected float sources with float destination"); - } + bool isDestFloat = destType.getElementType().isFloat(); + bool isSrcFloat = sourceAElemType.isFloat(); - if (!isDestFloat && isSrcFloat) { - return emitOpError("Expected int sources with int destination"); - } + if (isDestFloat && !isSrcFloat) + return emitOpError("expected float sources with float destination"); + if (!isDestFloat && isSrcFloat) + return emitOpError("expected int sources with int destination"); - if (sourceAElemType != sourceBElemType && - !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) && - isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) { + if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) { return emitOpError( "source element types much match (except for fp8) but have ") << sourceAType << " and " << sourceBType; } + + if (!sourceAElemType.isInteger(4) && getK() != 16) { + return emitOpError("K dimension must be 16 for source element type ") + << sourceAElemType; + } return success(); } @@ -414,11 +422,11 @@ LogicalResult MFMAOp::verify() { Type sourceElem = sourceType, destElem = destType; uint32_t sourceLen = 1, destLen = 1; - if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) { + if (auto sourceVector = dyn_cast<VectorType>(sourceType)) { sourceLen = sourceVector.getNumElements(); sourceElem = sourceVector.getElementType(); } - if (auto destVector = llvm::dyn_cast<VectorType>(destType)) { + if (auto destVector = dyn_cast<VectorType>(destType)) { destLen = destVector.getNumElements(); destElem = destVector.getElementType(); } @@ -443,7 +451,7 @@ LogicalResult MFMAOp::verify() { return emitOpError("expected both non-small-float source operand types " "to match exactly"); } - // Normalize the wider integer types the compiler expects to i8 + // Normalize the wider integer types the compiler expects to i8. 
if (sourceElem.isInteger(32)) { sourceLen *= 4; sourceElem = b.getI8Type(); diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp index 316721b..60ae78b 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp @@ -45,18 +45,15 @@ struct LoopUnroll : public affine::impl::AffineLoopUnrollBase<LoopUnroll> { const std::function<unsigned(AffineForOp)> getUnrollFactor; LoopUnroll() : getUnrollFactor(nullptr) {} - LoopUnroll(const LoopUnroll &other) - - = default; + LoopUnroll(const LoopUnroll &other) = default; explicit LoopUnroll( std::optional<unsigned> unrollFactor = std::nullopt, - bool unrollUpToFactor = false, bool unrollFull = false, + bool unrollUpToFactor = false, const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr) : getUnrollFactor(getUnrollFactor) { if (unrollFactor) this->unrollFactor = *unrollFactor; this->unrollUpToFactor = unrollUpToFactor; - this->unrollFull = unrollFull; } void runOnOperation() override; @@ -85,11 +82,17 @@ static void gatherInnermostLoops(FunctionOpInterface f, } void LoopUnroll::runOnOperation() { + if (!(unrollFactor.getValue() > 0 || unrollFactor.getValue() == -1)) { + emitError(UnknownLoc::get(&getContext()), + "Invalid option: 'unroll-factor' should be greater than 0 or " + "equal to -1"); + return signalPassFailure(); + } FunctionOpInterface func = getOperation(); if (func.isExternal()) return; - if (unrollFull && unrollFullThreshold.hasValue()) { + if (unrollFactor.getValue() == -1 && unrollFullThreshold.hasValue()) { // Store short loops as we walk. SmallVector<AffineForOp, 4> loops; @@ -130,7 +133,7 @@ LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) { return loopUnrollByFactor(forOp, getUnrollFactor(forOp), /*annotateFn=*/nullptr, cleanUpUnroll); // Unroll completely if full loop unroll was specified. - if (unrollFull) + if (unrollFactor.getValue() == -1) return loopUnrollFull(forOp); // Otherwise, unroll by the given unroll factor. if (unrollUpToFactor) @@ -141,9 +144,9 @@ LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) { std::unique_ptr<InterfacePass<FunctionOpInterface>> mlir::affine::createLoopUnrollPass( - int unrollFactor, bool unrollUpToFactor, bool unrollFull, + int unrollFactor, bool unrollUpToFactor, const std::function<unsigned(AffineForOp)> &getUnrollFactor) { return std::make_unique<LoopUnroll>( unrollFactor == -1 ? 
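// (Note on the LoopUnroll change above: with the unroll-full option removed,
// a factor of -1 now requests a full unroll, as the new validation message
// states. Assuming the usual pass registration and option spelling, which
// this hunk does not show:
//
//   mlir-opt --affine-loop-unroll="unroll-factor=4"    # unroll by 4
//   mlir-opt --affine-loop-unroll="unroll-factor=-1"   # fully unroll
// )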
std::nullopt : std::optional<unsigned>(unrollFactor), - unrollUpToFactor, unrollFull, getUnrollFactor); + unrollUpToFactor, getUnrollFactor); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp index a6159ee..f0ddb50 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp @@ -14,13 +14,6 @@ #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" -namespace mlir { -namespace bufferization { -#define GEN_PASS_DEF_TENSORCOPYINSERTION -#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" -} // namespace bufferization -} // namespace mlir - using namespace mlir; using namespace mlir::bufferization; diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 2a8c330..f0de4db 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -320,6 +320,51 @@ LogicalResult ConvertF32x2ToF4x2Op::verify() { return success(); } +LogicalResult ConvertF8x2ToF16x2Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType())) + return emitOpError("Only ") + << mlir::Float8E4M3FNType::get(ctx) << " and " + << mlir::Float8E5M2Type::get(ctx) + << " types are supported for conversions from f8x2 to f16x2."; + + return success(); +} + +LogicalResult ConvertF8x2ToBF16x2Op::verify() { + mlir::MLIRContext *ctx = getContext(); + if (!llvm::isa<Float8E8M0FNUType>(getSrcType())) + return emitOpError("Only ") + << mlir::Float8E8M0FNUType::get(ctx) + << " type is supported for conversions from f8x2 to bf16x2."; + + return success(); +} + +LogicalResult ConvertF6x2ToF16x2Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType())) + return emitOpError("Only ") + << mlir::Float6E2M3FNType::get(ctx) << " and " + << mlir::Float6E3M2FNType::get(ctx) + << " types are supported for conversions from f6x2 to f16x2."; + + return success(); +} + +LogicalResult ConvertF4x2ToF16x2Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa<Float4E2M1FNType>(getSrcType())) + return emitOpError("Only ") + << mlir::Float4E2M1FNType::get(ctx) + << " type is supported for conversions from f4x2 to f16x2."; + + return success(); +} + LogicalResult BulkStoreOp::verify() { if (getInitVal() != 0) return emitOpError("only 0 is supported for initVal, got ") << getInitVal(); @@ -2187,6 +2232,98 @@ ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd, } } +NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op); + + bool hasRelu = curOp.getRelu(); + + llvm::Intrinsic::ID intId = + llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType()) + .Case<Float8E4M3FNType>([&](Float8E4M3FNType type) { + return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu + : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn; + }) + .Case<Float8E5M2Type>([&](Float8E5M2Type type) { + return hasRelu ? 
llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu + : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn; + }) + .Default([](mlir::Type type) { + llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); + + llvm::Value *packedI16 = + builder.CreateBitCast(mt.lookupValue(curOp.getSrc()), + llvm::Type::getInt16Ty(builder.getContext())); + + return {intId, {packedI16}}; +} + +NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op); + + llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2; + llvm::Value *packedI16 = + builder.CreateBitCast(mt.lookupValue(curOp.getSrc()), + llvm::Type::getInt16Ty(builder.getContext())); + + return {intId, {packedI16}}; +} + +NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op); + + bool hasRelu = curOp.getRelu(); + + llvm::Intrinsic::ID intId = + llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType()) + .Case<Float6E2M3FNType>([&](Float6E2M3FNType type) { + return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu + : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn; + }) + .Case<Float6E3M2FNType>([&](Float6E3M2FNType type) { + return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu + : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn; + }) + .Default([](mlir::Type type) { + llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); + + llvm::Value *packedI16 = + builder.CreateBitCast(mt.lookupValue(curOp.getSrc()), + llvm::Type::getInt16Ty(builder.getContext())); + + return {intId, {packedI16}}; +} + +NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op); + + bool hasRelu = curOp.getRelu(); + + llvm::Intrinsic::ID intId = + llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType()) + .Case<Float4E2M1FNType>([&](Float4E2M1FNType type) { + return hasRelu ? 
llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu + : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn; + }) + .Default([](mlir::Type type) { + llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); + + llvm::Value *extendedI16 = + builder.CreateZExt(mt.lookupValue(curOp.getSrc()), + llvm::Type::getInt16Ty(builder.getContext())); + + return {intId, {extendedI16}}; +} + llvm::Intrinsic::ID Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt index d4ff095..37a45d4 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt @@ -18,4 +18,5 @@ add_mlir_dialect_library(MLIRLLVMIRTransforms MLIRPass MLIRTransforms MLIRNVVMDialect + MLIROpenMPDialect ) diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 9a8a63e..794dda9 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -437,13 +437,15 @@ transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter, for (auto [pos, dim] : llvm::enumerate(type.getShape())) { if (!ShapedType::isDynamic(dim)) continue; - Value cst = rewriter.create<arith::ConstantIndexOp>(tensor.getLoc(), pos); - auto dimOp = rewriter.create<tensor::DimOp>(tensor.getLoc(), tensor, cst); + Value cst = + arith::ConstantIndexOp::create(rewriter, tensor.getLoc(), pos); + auto dimOp = + tensor::DimOp::create(rewriter, tensor.getLoc(), tensor, cst); preservedOps.insert(dimOp); dynamicDims.push_back(dimOp); } - auto allocation = rewriter.create<bufferization::AllocTensorOp>( - tensor.getLoc(), type, dynamicDims); + auto allocation = bufferization::AllocTensorOp::create( + rewriter, tensor.getLoc(), type, dynamicDims); // Set memory space if provided. if (getMemorySpaceAttr()) allocation.setMemorySpaceAttr(getMemorySpaceAttr()); @@ -452,8 +454,8 @@ transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter, // Only insert a materialization (typically bufferizes to a copy) when the // value may be read from. 
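// For a readable tensor, the materialization emitted here looks roughly like
// the following sketch (types and SSA names illustrative only):
//
//   %copy = bufferization.materialize_in_destination %tensor in %alloc
//       : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>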
if (needsMaterialization) { - auto copy = rewriter.create<bufferization::MaterializeInDestinationOp>( - tensor.getLoc(), tensor, allocated); + auto copy = bufferization::MaterializeInDestinationOp::create( + rewriter, tensor.getLoc(), tensor, allocated); preservedOps.insert(copy); promoted.push_back(copy.getResult()); } else { diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp index 15eb51a..5e10ba3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" @@ -43,6 +44,33 @@ struct StructuredOpInterface auto zero = arith::ConstantIndexOp::create(builder, loc, 0); auto one = arith::ConstantIndexOp::create(builder, loc, 1); + Value iterationDomainIsNonDegenerate; + for (auto [start, end] : llvm::zip(starts, ends)) { + auto startValue = getValueOrCreateConstantIndexOp(builder, loc, start); + auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end); + + // Loop Trip count > 0 iff start < end + Value dimensionHasNonZeroTripCount = index::CmpOp::create( + builder, loc, index::IndexCmpPredicate::SLT, startValue, endValue); + + if (!iterationDomainIsNonDegenerate) { + iterationDomainIsNonDegenerate = dimensionHasNonZeroTripCount; + } else { + // Iteration domain is non-degenerate iff all dimensions have loop trip + // count > 0 + iterationDomainIsNonDegenerate = + arith::AndIOp::create(builder, loc, iterationDomainIsNonDegenerate, + dimensionHasNonZeroTripCount); + } + } + + if (!iterationDomainIsNonDegenerate) + return; + + auto ifOp = scf::IfOp::create(builder, loc, iterationDomainIsNonDegenerate, + /*withElseRegion=*/false); + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + // Subtract one from the loop ends before composing with the indexing map transform(ends, ends.begin(), [&](OpFoldResult end) { auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end); @@ -110,6 +138,7 @@ struct StructuredOpInterface builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg); } } + builder.setInsertionPointAfter(ifOp); } }; diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 94947b7..c551fba 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1437,6 +1437,13 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor, atLeastOneReplacement |= replaceConstantUsesOf( builder, getLoc(), getStrides(), getConstifiedMixedStrides()); + // extract_strided_metadata(cast(x)) -> extract_strided_metadata(x). 
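// For example (a sketch with illustrative types):
//
//   %c = memref.cast %m : memref<4x4xf32> to memref<?x?xf32>
//   %base, %off, %sz:2, %st:2 = memref.extract_strided_metadata %c
//       : memref<?x?xf32> -> memref<f32>, index, index, index, index, index
//
// now draws its metadata from %m directly, letting the cast become dead.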
+ if (auto prev = getSource().getDefiningOp<CastOp>()) + if (isa<MemRefType>(prev.getSource().getType())) { + getSourceMutable().assign(prev.getSource()); + atLeastOneReplacement = true; + } + return success(atLeastOneReplacement); } @@ -1744,11 +1751,11 @@ OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) { } TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() { - return cast<TypedValue<PtrLikeTypeInterface>>(getSource()); + return getSource(); } TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() { - return cast<TypedValue<PtrLikeTypeInterface>>(getDest()); + return getDest(); } bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt, diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp index 11400de..a15bf89 100644 --- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp @@ -59,6 +59,17 @@ struct DimOpInterface } }; +struct ExpandShapeOpInterface + : public ValueBoundsOpInterface::ExternalModel<ExpandShapeOpInterface, + memref::ExpandShapeOp> { + void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, + ValueBoundsConstraintSet &cstr) const { + auto expandOp = cast<memref::ExpandShapeOp>(op); + assert(value == expandOp.getResult() && "invalid value"); + cstr.bound(value)[dim] == expandOp.getOutputShape()[dim]; + } +}; + struct GetGlobalOpInterface : public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface, GetGlobalOp> { @@ -123,6 +134,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels( memref::AllocOpInterface<memref::AllocaOp>>(*ctx); memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx); memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx); + memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>( + *ctx); memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx); memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx); memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index d35566a..bd02516 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -1033,91 +1033,6 @@ class ExtractStridedMetadataOpReinterpretCastFolder } }; -/// Replace `base, offset, sizes, strides = -/// extract_strided_metadata( -/// cast(src) to dstTy)` -/// With -/// ``` -/// base, ... = extract_strided_metadata(src) -/// offset = !dstTy.srcOffset.isDynamic() -/// ? dstTy.srcOffset -/// : extract_strided_metadata(src).offset -/// sizes = for each srcSize in dstTy.srcSizes: -/// !srcSize.isDynamic() -/// ? srcSize -// : extract_strided_metadata(src).sizes[i] -/// strides = for each srcStride in dstTy.srcStrides: -/// !srcStrides.isDynamic() -/// ? srcStrides -/// : extract_strided_metadata(src).strides[i] -/// ``` -/// -/// In other words, consume the `cast` and apply its effects -/// on the offset, sizes, and strides or compute them directly from `src`. 
-class ExtractStridedMetadataOpCastFolder - : public OpRewritePattern<memref::ExtractStridedMetadataOp> { - using OpRewritePattern::OpRewritePattern; - - LogicalResult - matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp, - PatternRewriter &rewriter) const override { - Value source = extractStridedMetadataOp.getSource(); - auto castOp = source.getDefiningOp<memref::CastOp>(); - if (!castOp) - return failure(); - - Location loc = extractStridedMetadataOp.getLoc(); - // Check if the source is suitable for extract_strided_metadata. - SmallVector<Type> inferredReturnTypes; - if (failed(extractStridedMetadataOp.inferReturnTypes( - rewriter.getContext(), loc, {castOp.getSource()}, - /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{}, - inferredReturnTypes))) - return rewriter.notifyMatchFailure(castOp, - "cast source's type is incompatible"); - - auto memrefType = cast<MemRefType>(source.getType()); - unsigned rank = memrefType.getRank(); - SmallVector<OpFoldResult> results; - results.resize_for_overwrite(rank * 2 + 2); - - auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create( - rewriter, loc, castOp.getSource()); - - // Register the base_buffer. - results[0] = newExtractStridedMetadata.getBaseBuffer(); - - auto getConstantOrValue = [&rewriter](int64_t constant, - OpFoldResult ofr) -> OpFoldResult { - return ShapedType::isStatic(constant) - ? OpFoldResult(rewriter.getIndexAttr(constant)) - : ofr; - }; - - auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset(); - assert(sourceStrides.size() == rank && "unexpected number of strides"); - - // Register the new offset. - results[1] = - getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset()); - - const unsigned sizeStartIdx = 2; - const unsigned strideStartIdx = sizeStartIdx + rank; - ArrayRef<int64_t> sourceSizes = memrefType.getShape(); - - SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes(); - SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides(); - for (unsigned i = 0; i < rank; ++i) { - results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]); - results[strideStartIdx + i] = - getConstantOrValue(sourceStrides[i], strides[i]); - } - rewriter.replaceOp(extractStridedMetadataOp, - getValueOrCreateConstantIndexOp(rewriter, loc, results)); - return success(); - } -}; - /// Replace `base, offset, sizes, strides = extract_strided_metadata( /// memory_space_cast(src) to dstTy)` /// with @@ -1209,7 +1124,6 @@ void memref::populateExpandStridedMetadataPatterns( RewriteExtractAlignedPointerAsIndexOfViewLikeOp, ExtractStridedMetadataOpReinterpretCastFolder, ExtractStridedMetadataOpSubviewFolder, - ExtractStridedMetadataOpCastFolder, ExtractStridedMetadataOpMemorySpaceCastFolder, ExtractStridedMetadataOpAssumeAlignmentFolder, ExtractStridedMetadataOpExtractStridedMetadataFolder>( @@ -1226,7 +1140,6 @@ void memref::populateResolveExtractStridedMetadataPatterns( ExtractStridedMetadataOpSubviewFolder, RewriteExtractAlignedPointerAsIndexOfViewLikeOp, ExtractStridedMetadataOpReinterpretCastFolder, - ExtractStridedMetadataOpCastFolder, ExtractStridedMetadataOpMemorySpaceCastFolder, ExtractStridedMetadataOpAssumeAlignmentFolder, ExtractStridedMetadataOpExtractStridedMetadataFolder>( diff --git a/mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt new file mode 100644 index 0000000..f305068 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt @@ -0,0 +1,13 @@ 
+add_mlir_dialect_library(MLIROpenACCAnalysis + OpenACCSupport.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC + + LINK_LIBS PUBLIC + MLIRIR + MLIROpenACCDialect + MLIROpenACCUtils + MLIRSupport +) + diff --git a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp new file mode 100644 index 0000000..f6b4534 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp @@ -0,0 +1,26 @@ +//===- OpenACCSupport.cpp - OpenACCSupport Implementation -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the OpenACCSupport analysis interface. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h" +#include "mlir/Dialect/OpenACC/OpenACCUtils.h" + +namespace mlir { +namespace acc { + +std::string OpenACCSupport::getVariableName(Value v) { + if (impl) + return impl->getVariableName(v); + return acc::getVariableName(v); +} + +} // namespace acc +} // namespace mlir diff --git a/mlir/lib/Dialect/OpenACC/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/CMakeLists.txt index 7117520..e8a916e 100644 --- a/mlir/lib/Dialect/OpenACC/CMakeLists.txt +++ b/mlir/lib/Dialect/OpenACC/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(Analysis) add_subdirectory(IR) add_subdirectory(Utils) add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 5ca0100..ca46629 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -610,6 +610,20 @@ LogicalResult acc::FirstprivateOp::verify() { } //===----------------------------------------------------------------------===// +// FirstprivateMapInitialOp +//===----------------------------------------------------------------------===// +LogicalResult acc::FirstprivateMapInitialOp::verify() { + if (getDataClause() != acc::DataClause::acc_firstprivate) + return emitError("data clause associated with firstprivate operation must " + "match its intent"); + if (failed(checkVarAndVarType(*this))) + return failure(); + if (failed(checkNoModifier(*this))) + return failure(); + return success(); +} + +//===----------------------------------------------------------------------===// // ReductionOp //===----------------------------------------------------------------------===// LogicalResult acc::ReductionOp::verify() { diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp index 1223325..89adda82 100644 --- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp +++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/OpenACC/OpenACCUtils.h" #include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/TypeSwitch.h" mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region ®ion) { @@ -78,3 +79,30 @@ mlir::acc::VariableTypeCategory mlir::acc::getTypeCategory(mlir::Value var) { pointerLikeTy.getElementType()); return typeCategory; } + +std::string mlir::acc::getVariableName(mlir::Value v) { + Value current = v; + + // Walk through view operations until a name is found or can't go 
further + while (Operation *definingOp = current.getDefiningOp()) { + // Check for `acc.var_name` attribute + if (auto varNameAttr = + definingOp->getAttrOfType<VarNameAttr>(getVarNameAttrName())) + return varNameAttr.getName().str(); + + // If it is a data entry operation, get name via getVarName + if (isa<ACC_DATA_ENTRY_OPS>(definingOp)) + if (auto name = acc::getVarName(definingOp)) + return name->str(); + + // If it's a view operation, continue to the source + if (auto viewOp = dyn_cast<ViewLikeOpInterface>(definingOp)) { + current = viewOp.getViewSource(); + continue; + } + + break; + } + + return ""; +} diff --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt index 57a6d34..f3c02da 100644 --- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt +++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt @@ -1,3 +1,5 @@ +add_subdirectory(Transforms) + add_mlir_dialect_library(MLIROpenMPDialect IR/OpenMPDialect.cpp diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index fd4cabbad..1b069c6 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -32,7 +32,6 @@ #include "llvm/ADT/StringRef.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/bit.h" -#include "llvm/Frontend/OpenMP/OMPConstants.h" #include "llvm/Support/InterleavedRange.h" #include <cstddef> #include <iterator> @@ -1737,10 +1736,10 @@ static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) { // Parser, printer and verifier for Target //===----------------------------------------------------------------------===// -// Helper function to get bitwise AND of `value` and 'flag' -static uint64_t mapTypeToBitFlag(uint64_t value, - llvm::omp::OpenMPOffloadMappingFlags flag) { - return value & llvm::to_underlying(flag); +// Helper function to get bitwise AND of `value` and 'flag' then return it as a +// boolean +static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag) { + return (value & flag) == flag; } /// Parses a map_entries map type from a string format back into its numeric @@ -1748,10 +1747,9 @@ static uint64_t mapTypeToBitFlag(uint64_t value, /// /// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `? /// `close, `? `present, `? 
( `to` | `from` | `delete` `)` )+ `)` ) -static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) { - llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; - +static ParseResult parseMapClause(OpAsmParser &parser, + ClauseMapFlagsAttr &mapType) { + ClauseMapFlags mapTypeBits = ClauseMapFlags::none; // This simply verifies the correct keyword is read in, the // keyword itself is stored inside of the operation auto parseTypeAndMod = [&]() -> ParseResult { @@ -1760,35 +1758,64 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) { return failure(); if (mapTypeMod == "always") - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS; + mapTypeBits |= ClauseMapFlags::always; if (mapTypeMod == "implicit") - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT; + mapTypeBits |= ClauseMapFlags::implicit; if (mapTypeMod == "ompx_hold") - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD; + mapTypeBits |= ClauseMapFlags::ompx_hold; if (mapTypeMod == "close") - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE; + mapTypeBits |= ClauseMapFlags::close; if (mapTypeMod == "present") - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT; + mapTypeBits |= ClauseMapFlags::present; if (mapTypeMod == "to") - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; + mapTypeBits |= ClauseMapFlags::to; if (mapTypeMod == "from") - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + mapTypeBits |= ClauseMapFlags::from; if (mapTypeMod == "tofrom") - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + mapTypeBits |= ClauseMapFlags::to | ClauseMapFlags::from; if (mapTypeMod == "delete") - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE; + mapTypeBits |= ClauseMapFlags::del; + + if (mapTypeMod == "storage") + mapTypeBits |= ClauseMapFlags::storage; if (mapTypeMod == "return_param") - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM; + mapTypeBits |= ClauseMapFlags::return_param; + + if (mapTypeMod == "private") + mapTypeBits |= ClauseMapFlags::priv; + + if (mapTypeMod == "literal") + mapTypeBits |= ClauseMapFlags::literal; + + if (mapTypeMod == "attach") + mapTypeBits |= ClauseMapFlags::attach; + + if (mapTypeMod == "attach_always") + mapTypeBits |= ClauseMapFlags::attach_always; + + if (mapTypeMod == "attach_none") + mapTypeBits |= ClauseMapFlags::attach_none; + + if (mapTypeMod == "attach_auto") + mapTypeBits |= ClauseMapFlags::attach_auto; + + if (mapTypeMod == "ref_ptr") + mapTypeBits |= ClauseMapFlags::ref_ptr; + + if (mapTypeMod == "ref_ptee") + mapTypeBits |= ClauseMapFlags::ref_ptee; + + if (mapTypeMod == "ref_ptr_ptee") + mapTypeBits |= ClauseMapFlags::ref_ptr_ptee; return success(); }; @@ -1796,9 +1823,8 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) { if (parser.parseCommaSeparatedList(parseTypeAndMod)) return failure(); - mapType = parser.getBuilder().getIntegerAttr( - parser.getBuilder().getIntegerType(64, /*isSigned=*/false), - llvm::to_underlying(mapTypeBits)); + mapType = + parser.getBuilder().getAttr<mlir::omp::ClauseMapFlagsAttr>(mapTypeBits); return success(); } @@ -1806,60 +1832,62 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) { /// Prints a map_entries map type from its numeric value out into its string /// 
format. static void printMapClause(OpAsmPrinter &p, Operation *op, - IntegerAttr mapType) { - uint64_t mapTypeBits = mapType.getUInt(); - - bool emitAllocRelease = true; + ClauseMapFlagsAttr mapType) { llvm::SmallVector<std::string, 4> mapTypeStrs; + ClauseMapFlags mapFlags = mapType.getValue(); // handling of always, close, present placed at the beginning of the string // to aid readability - if (mapTypeToBitFlag(mapTypeBits, - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS)) + if (mapTypeToBool(mapFlags, ClauseMapFlags::always)) mapTypeStrs.push_back("always"); - if (mapTypeToBitFlag(mapTypeBits, - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT)) + if (mapTypeToBool(mapFlags, ClauseMapFlags::implicit)) mapTypeStrs.push_back("implicit"); - if (mapTypeToBitFlag(mapTypeBits, - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD)) + if (mapTypeToBool(mapFlags, ClauseMapFlags::ompx_hold)) mapTypeStrs.push_back("ompx_hold"); - if (mapTypeToBitFlag(mapTypeBits, - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE)) + if (mapTypeToBool(mapFlags, ClauseMapFlags::close)) mapTypeStrs.push_back("close"); - if (mapTypeToBitFlag(mapTypeBits, - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT)) + if (mapTypeToBool(mapFlags, ClauseMapFlags::present)) mapTypeStrs.push_back("present"); // special handling of to/from/tofrom/delete and release/alloc, release + // alloc are the absence of one of the other flags, whereas tofrom requires // both the to and from flag to be set. - bool to = mapTypeToBitFlag(mapTypeBits, - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); - bool from = mapTypeToBitFlag( - mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); - if (to && from) { - emitAllocRelease = false; + bool to = mapTypeToBool(mapFlags, ClauseMapFlags::to); + bool from = mapTypeToBool(mapFlags, ClauseMapFlags::from); + + if (to && from) mapTypeStrs.push_back("tofrom"); - } else if (from) { - emitAllocRelease = false; + else if (from) mapTypeStrs.push_back("from"); - } else if (to) { - emitAllocRelease = false; + else if (to) mapTypeStrs.push_back("to"); - } - if (mapTypeToBitFlag(mapTypeBits, - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) { - emitAllocRelease = false; + + if (mapTypeToBool(mapFlags, ClauseMapFlags::del)) mapTypeStrs.push_back("delete"); - } - if (mapTypeToBitFlag( - mapTypeBits, - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) { - emitAllocRelease = false; + if (mapTypeToBool(mapFlags, ClauseMapFlags::return_param)) mapTypeStrs.push_back("return_param"); - } - if (emitAllocRelease) - mapTypeStrs.push_back("exit_release_or_enter_alloc"); + if (mapTypeToBool(mapFlags, ClauseMapFlags::storage)) + mapTypeStrs.push_back("storage"); + if (mapTypeToBool(mapFlags, ClauseMapFlags::priv)) + mapTypeStrs.push_back("private"); + if (mapTypeToBool(mapFlags, ClauseMapFlags::literal)) + mapTypeStrs.push_back("literal"); + if (mapTypeToBool(mapFlags, ClauseMapFlags::attach)) + mapTypeStrs.push_back("attach"); + if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_always)) + mapTypeStrs.push_back("attach_always"); + if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_none)) + mapTypeStrs.push_back("attach_none"); + if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_auto)) + mapTypeStrs.push_back("attach_auto"); + if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr)) + mapTypeStrs.push_back("ref_ptr"); + if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptee)) + mapTypeStrs.push_back("ref_ptee"); + if (mapTypeToBool(mapFlags,
ClauseMapFlags::ref_ptr_ptee)) + mapTypeStrs.push_back("ref_ptr_ptee"); + if (mapFlags == ClauseMapFlags::none) + mapTypeStrs.push_back("none"); for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) { p << mapTypeStrs[i]; @@ -1963,21 +1991,15 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) { return emitError(op->getLoc(), "missing map operation"); if (auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) { - uint64_t mapTypeBits = mapInfoOp.getMapType(); - - bool to = mapTypeToBitFlag( - mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); - bool from = mapTypeToBitFlag( - mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); - bool del = mapTypeToBitFlag( - mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE); - - bool always = mapTypeToBitFlag( - mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS); - bool close = mapTypeToBitFlag( - mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE); - bool implicit = mapTypeToBitFlag( - mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT); + mlir::omp::ClauseMapFlags mapTypeBits = mapInfoOp.getMapType(); + + bool to = mapTypeToBool(mapTypeBits, ClauseMapFlags::to); + bool from = mapTypeToBool(mapTypeBits, ClauseMapFlags::from); + bool del = mapTypeToBool(mapTypeBits, ClauseMapFlags::del); + + bool always = mapTypeToBool(mapTypeBits, ClauseMapFlags::always); + bool close = mapTypeToBool(mapTypeBits, ClauseMapFlags::close); + bool implicit = mapTypeToBool(mapTypeBits, ClauseMapFlags::implicit); if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del) return emitError(op->getLoc(), diff --git a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt new file mode 100644 index 0000000..b9b8eda --- /dev/null +++ b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_dialect_library(MLIROpenMPTransforms + OpenMPOffloadPrivatizationPrepare.cpp + + DEPENDS + MLIROpenMPPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRFuncDialect + MLIRLLVMDialect + MLIROpenMPDialect + MLIRPass + MLIRTransforms + ) diff --git a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp new file mode 100644 index 0000000..c117d9b --- /dev/null +++ b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp @@ -0,0 +1,447 @@ +//===- OpenMPOffloadPrivatizationPrepare.cpp - Prepare OMP privatization --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "llvm/Support/DebugLog.h" +#include "llvm/Support/FormatVariadic.h" +#include <cstdint> +#include <iterator> +#include <utility> + +//===----------------------------------------------------------------------===// +// A pass that prepares OpenMP code for translation of delayed privatization +// in the context of deferred target tasks. 
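// (Note on the ClauseMapFlags printer/parser above: the flags still
// round-trip through omp.map.info's map_clauses list. A sketch based on the
// existing format, with illustrative types and names:
//
//   %m = omp.map.info var_ptr(%x : !llvm.ptr, i32) map_clauses(always, tofrom)
//          capture(ByRef) -> !llvm.ptr {name = "x"}
//
// The added modifiers such as storage, attach, or ref_ptr print the same
// way, e.g. map_clauses(storage, attach).)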
Deferred target tasks are created +// when the nowait clause is used on the target directive. +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "omp-prepare-for-offload-privatization" + +namespace mlir { +namespace omp { + +#define GEN_PASS_DEF_PREPAREFOROMPOFFLOADPRIVATIZATIONPASS +#include "mlir/Dialect/OpenMP/Transforms/Passes.h.inc" + +} // namespace omp +} // namespace mlir + +using namespace mlir; +namespace { + +//===----------------------------------------------------------------------===// +// PrepareForOMPOffloadPrivatizationPass +//===----------------------------------------------------------------------===// + +class PrepareForOMPOffloadPrivatizationPass + : public omp::impl::PrepareForOMPOffloadPrivatizationPassBase< + PrepareForOMPOffloadPrivatizationPass> { + + void runOnOperation() override { + ModuleOp mod = getOperation(); + + // In this pass, we make host-allocated privatized variables persist for + // deferred target tasks by copying them to the heap. Once the target task + // is done, this heap memory is freed. Since all of this happens on the host + // we can skip device modules. + auto offloadModuleInterface = + dyn_cast<omp::OffloadModuleInterface>(mod.getOperation()); + if (offloadModuleInterface && offloadModuleInterface.getIsTargetDevice()) + return; + + getOperation()->walk([&](omp::TargetOp targetOp) { + if (!hasPrivateVars(targetOp) || !isTargetTaskDeferred(targetOp)) + return; + IRRewriter rewriter(&getContext()); + OperandRange privateVars = targetOp.getPrivateVars(); + SmallVector<mlir::Value> newPrivVars; + Value fakeDependVar; + omp::TaskOp cleanupTaskOp; + + newPrivVars.reserve(privateVars.size()); + std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms(); + for (auto [privVarIdx, privVarSymPair] : + llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) { + Value privVar = std::get<0>(privVarSymPair); + Attribute privSym = std::get<1>(privVarSymPair); + + omp::PrivateClauseOp privatizer = findPrivatizer(targetOp, privSym); + if (!privatizer.needsMap()) { + newPrivVars.push_back(privVar); + continue; + } + bool isFirstPrivate = privatizer.getDataSharingType() == + omp::DataSharingClauseType::FirstPrivate; + + Value mappedValue = targetOp.getMappedValueForPrivateVar(privVarIdx); + auto mapInfoOp = cast<omp::MapInfoOp>(mappedValue.getDefiningOp()); + + if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy) { + newPrivVars.push_back(privVar); + continue; + } + + // For deferred target tasks (!$omp target nowait), we need to keep + // a copy of the original, i.e. host variable being privatized so + // that it is available when the target task is eventually executed. + // We do this by first allocating as much heap memory as is needed by + // the original variable. Then, we use the init and copy regions of the + // privatizer, an instance of omp::PrivateClauseOp to set up the heap- + // allocated copy. + // After the target task is done, we need to use the dealloc region + // of the privatizer to clean up everything. We also need to free + // the heap memory we allocated. But due to the deferred nature + // of the target task, we cannot simply deallocate right after the + // omp.target operation else we may end up freeing memory before + // its eventual use by the target task. So, we create a dummy + // dependence between the target task and new omp.task. In the omp.task, + // we do all the cleanup. 
So, we end up with the following structure + // + // omp.target map_entries(..) ... nowait depend(out:fakeDependVar) { + // ... + // omp.terminator + // } + // omp.task depend(in: fakeDependVar) { + // /*cleanup_code*/ + // omp.terminator + // } + // fakeDependVar is the address of the first heap-allocated copy of the + // host variable being privatized. + + bool needsCleanupTask = !privatizer.getDeallocRegion().empty(); + + // Allocate heap memory that corresponds to the type of memory + // pointed to by varPtr + // For boxchars this won't be a pointer. But, MapsForPrivatizedSymbols + // should have mapped the pointer to the boxchar so use that as varPtr. + Value varPtr = mapInfoOp.getVarPtr(); + Type varType = mapInfoOp.getVarType(); + bool isPrivatizedByValue = + !isa<LLVM::LLVMPointerType>(privVar.getType()); + + assert(isa<LLVM::LLVMPointerType>(varPtr.getType())); + Value heapMem = + allocateHeapMem(targetOp, varPtr, varType, mod, rewriter); + if (!heapMem) + targetOp.emitError( + "Unable to allocate heap memory when trying to move " + "a private variable out of the stack and into the " + "heap for use by a deferred target task"); + + if (needsCleanupTask && !fakeDependVar) + fakeDependVar = heapMem; + + // The types of private vars should match before and after the + // transformation. In particular, if the type is a pointer, + // simply record the newly allocated malloc location as the + // new private variable. If, however, the type is not a pointer + // then, we need to load the value from the newly allocated + // location. We'll insert that load later after we have updated + // the malloc'd location with the contents of the original + // variable. + if (!isPrivatizedByValue) + newPrivVars.push_back(heapMem); + + // We now need to copy the original private variable into the newly + // allocated location in the heap. + // Find the earliest insertion point for the copy. This will be before + // the first in the list of omp::MapInfoOp instances that use varPtr. + // After the copy these omp::MapInfoOp instances will refer to heapMem + // instead. + Operation *varPtrDefiningOp = varPtr.getDefiningOp(); + DenseSet<Operation *> users; + if (varPtrDefiningOp) { + users.insert(varPtrDefiningOp->user_begin(), + varPtrDefiningOp->user_end()); + } else { + auto blockArg = cast<BlockArgument>(varPtr); + users.insert(blockArg.user_begin(), blockArg.user_end()); + } + auto usesVarPtr = [&users](Operation *op) -> bool { + return users.count(op); + }; + + SmallVector<Operation *> chainOfOps; + chainOfOps.push_back(mapInfoOp); + for (auto member : mapInfoOp.getMembers()) { + omp::MapInfoOp memberMap = + cast<omp::MapInfoOp>(member.getDefiningOp()); + if (usesVarPtr(memberMap)) + chainOfOps.push_back(memberMap); + if (memberMap.getVarPtrPtr()) { + Operation *defOp = memberMap.getVarPtrPtr().getDefiningOp(); + if (defOp && usesVarPtr(defOp)) + chainOfOps.push_back(defOp); + } + } + + DominanceInfo dom; + llvm::sort(chainOfOps, [&](Operation *l, Operation *r) { + if (l == r) + return false; + return dom.properlyDominates(l, r); + }); + + rewriter.setInsertionPoint(chainOfOps.front()); + + Operation *firstOp = chainOfOps.front(); + Location loc = firstOp->getLoc(); + + // Create a llvm.func for 'region' that is marked always_inline and call + // it. 
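// For an init region, the outlined function looks roughly like the following
// sketch (the symbol name comes from the formatv call further down; the
// exact attribute spelling is an assumption):
//
//   llvm.func @privatizer_init(%arg0: !llvm.ptr, %arg1: !llvm.ptr)
//       -> !llvm.ptr attributes {always_inline} {
//     ...                      // body cloned from the init region
//     llvm.return %res : !llvm.ptr
//   }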
+ auto createAlwaysInlineFuncAndCallIt = + [&](Region ®ion, llvm::StringRef funcName, + llvm::ArrayRef<Value> args, bool returnsValue) -> Value { + assert(!region.empty() && "region cannot be empty"); + LLVM::LLVMFuncOp func = createFuncOpForRegion( + loc, mod, region, funcName, rewriter, returnsValue); + auto call = LLVM::CallOp::create(rewriter, loc, func, args); + return call.getResult(); + }; + + Value moldArg, newArg; + if (isPrivatizedByValue) { + moldArg = LLVM::LoadOp::create(rewriter, loc, varType, varPtr); + newArg = LLVM::LoadOp::create(rewriter, loc, varType, heapMem); + } else { + moldArg = varPtr; + newArg = heapMem; + } + + Value initializedVal; + if (!privatizer.getInitRegion().empty()) + initializedVal = createAlwaysInlineFuncAndCallIt( + privatizer.getInitRegion(), + llvm::formatv("{0}_{1}", privatizer.getSymName(), "init").str(), + {moldArg, newArg}, /*returnsValue=*/true); + else + initializedVal = newArg; + + if (isFirstPrivate && !privatizer.getCopyRegion().empty()) + initializedVal = createAlwaysInlineFuncAndCallIt( + privatizer.getCopyRegion(), + llvm::formatv("{0}_{1}", privatizer.getSymName(), "copy").str(), + {moldArg, initializedVal}, /*returnsValue=*/true); + + if (isPrivatizedByValue) + (void)LLVM::StoreOp::create(rewriter, loc, initializedVal, heapMem); + + // clone origOp, replace all uses of varPtr with heapMem and + // erase origOp. + auto cloneModifyAndErase = [&](Operation *origOp) -> Operation * { + Operation *clonedOp = rewriter.clone(*origOp); + rewriter.replaceAllOpUsesWith(origOp, clonedOp); + rewriter.modifyOpInPlace(clonedOp, [&]() { + clonedOp->replaceUsesOfWith(varPtr, heapMem); + }); + rewriter.eraseOp(origOp); + return clonedOp; + }; + + // Now that we have set up the heap-allocated copy of the private + // variable, rewrite all the uses of the original variable with + // the heap-allocated variable. + rewriter.setInsertionPoint(targetOp); + mapInfoOp = cast<omp::MapInfoOp>(cloneModifyAndErase(mapInfoOp)); + rewriter.setInsertionPoint(mapInfoOp); + + // Fix any members that may use varPtr to now use heapMem + for (auto member : mapInfoOp.getMembers()) { + auto memberMapInfoOp = cast<omp::MapInfoOp>(member.getDefiningOp()); + if (!usesVarPtr(memberMapInfoOp)) + continue; + memberMapInfoOp = + cast<omp::MapInfoOp>(cloneModifyAndErase(memberMapInfoOp)); + rewriter.setInsertionPoint(memberMapInfoOp); + + if (memberMapInfoOp.getVarPtrPtr()) { + Operation *varPtrPtrdefOp = + memberMapInfoOp.getVarPtrPtr().getDefiningOp(); + rewriter.setInsertionPoint(cloneModifyAndErase(varPtrPtrdefOp)); + } + } + + // If the type of the private variable is not a pointer, + // which is typically the case with !fir.boxchar types, then + // we need to ensure that the new private variable is also + // not a pointer. Insert a load from heapMem right before + // targetOp. 
+ if (isPrivatizedByValue) { + rewriter.setInsertionPoint(targetOp); + auto newPrivVar = LLVM::LoadOp::create(rewriter, mapInfoOp.getLoc(), + varType, heapMem); + newPrivVars.push_back(newPrivVar); + } + + // Deallocate + if (needsCleanupTask) { + if (!cleanupTaskOp) { + assert(fakeDependVar && + "Need a valid value to set up a dependency"); + rewriter.setInsertionPointAfter(targetOp); + omp::TaskOperands taskOperands; + auto inDepend = omp::ClauseTaskDependAttr::get( + rewriter.getContext(), omp::ClauseTaskDepend::taskdependin); + taskOperands.dependKinds.push_back(inDepend); + taskOperands.dependVars.push_back(fakeDependVar); + cleanupTaskOp = omp::TaskOp::create(rewriter, loc, taskOperands); + Block *taskBlock = rewriter.createBlock(&cleanupTaskOp.getRegion()); + rewriter.setInsertionPointToEnd(taskBlock); + omp::TerminatorOp::create(rewriter, cleanupTaskOp.getLoc()); + } + rewriter.setInsertionPointToStart( + &*cleanupTaskOp.getRegion().getBlocks().begin()); + (void)createAlwaysInlineFuncAndCallIt( + privatizer.getDeallocRegion(), + llvm::formatv("{0}_{1}", privatizer.getSymName(), "dealloc") + .str(), + {initializedVal}, /*returnsValue=*/false); + llvm::FailureOr<LLVM::LLVMFuncOp> freeFunc = + LLVM::lookupOrCreateFreeFn(rewriter, mod); + assert(llvm::succeeded(freeFunc) && + "Could not find free in the module"); + (void)LLVM::CallOp::create(rewriter, loc, freeFunc.value(), + ValueRange{heapMem}); + } + } + assert(newPrivVars.size() == privateVars.size() && + "The number of private variables must match before and after " + "transformation"); + if (fakeDependVar) { + omp::ClauseTaskDependAttr outDepend = omp::ClauseTaskDependAttr::get( + rewriter.getContext(), omp::ClauseTaskDepend::taskdependout); + SmallVector<Attribute> newDependKinds; + if (!targetOp.getDependVars().empty()) { + std::optional<ArrayAttr> dependKinds = targetOp.getDependKinds(); + assert(dependKinds && "bad depend clause in omp::TargetOp"); + llvm::copy(*dependKinds, std::back_inserter(newDependKinds)); + } + newDependKinds.push_back(outDepend); + ArrayAttr newDependKindsAttr = + ArrayAttr::get(rewriter.getContext(), newDependKinds); + targetOp.getDependVarsMutable().append(fakeDependVar); + targetOp.setDependKindsAttr(newDependKindsAttr); + } + rewriter.setInsertionPoint(targetOp); + targetOp.getPrivateVarsMutable().clear(); + targetOp.getPrivateVarsMutable().assign(newPrivVars); + }); + } + +private: + bool hasPrivateVars(omp::TargetOp targetOp) const { + return !targetOp.getPrivateVars().empty(); + } + + bool isTargetTaskDeferred(omp::TargetOp targetOp) const { + return targetOp.getNowait(); + } + + template <typename OpTy> + omp::PrivateClauseOp findPrivatizer(OpTy op, Attribute privSym) const { + SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym); + omp::PrivateClauseOp privatizer = + SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>( + op, privatizerName); + return privatizer; + } + + // Get the (compile-time constant) size of varType as per the + // given DataLayout dl. 
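// E.g., a 12-byte type with 8-byte ABI alignment yields
// llvm::alignTo(12, 8) == 16 bytes of heap storage.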
+ std::int64_t getSizeInBytes(const DataLayout &dl, Type varType) const { + llvm::TypeSize size = dl.getTypeSize(varType); + unsigned short alignment = dl.getTypeABIAlignment(varType); + return llvm::alignTo(size, alignment); + } + + LLVM::LLVMFuncOp getMalloc(ModuleOp mod, IRRewriter &rewriter) const { + llvm::FailureOr<LLVM::LLVMFuncOp> mallocCall = + LLVM::lookupOrCreateMallocFn(rewriter, mod, rewriter.getI64Type()); + assert(llvm::succeeded(mallocCall) && + "Could not find malloc in the module"); + return mallocCall.value(); + } + + Value allocateHeapMem(omp::TargetOp targetOp, Value privVar, Type varType, + ModuleOp mod, IRRewriter &rewriter) const { + OpBuilder::InsertionGuard guard(rewriter); + Value varPtr = privVar; + Operation *definingOp = varPtr.getDefiningOp(); + BlockArgument blockArg; + if (!definingOp) { + blockArg = mlir::dyn_cast<BlockArgument>(varPtr); + rewriter.setInsertionPointToStart(blockArg.getParentBlock()); + } else { + rewriter.setInsertionPoint(definingOp); + } + Location loc = definingOp ? definingOp->getLoc() : blockArg.getLoc(); + LLVM::LLVMFuncOp mallocFn = getMalloc(mod, rewriter); + + assert(mod.getDataLayoutSpec() && + "MLIR module with no datalayout spec not handled yet"); + + const DataLayout &dl = DataLayout(mod); + std::int64_t distance = getSizeInBytes(dl, varType); + + Value sizeBytes = LLVM::ConstantOp::create( + rewriter, loc, mallocFn.getFunctionType().getParamType(0), distance); + + auto mallocCallOp = + LLVM::CallOp::create(rewriter, loc, mallocFn, ValueRange{sizeBytes}); + return mallocCallOp.getResult(); + } + + // Create a function for srcRegion and attribute it to be always_inline. + // The big assumption here is that srcRegion is one of init, copy or dealloc + // regions of an omp::PrivateClauseOp. Accordingly, the return type is assumed + // to either be the same as the types of the two arguments of the region (for + // init and copy regions) or void as would be the case for dealloc regions. + LLVM::LLVMFuncOp createFuncOpForRegion(Location loc, ModuleOp mod, + Region &srcRegion, + llvm::StringRef funcName, + IRRewriter &rewriter, + bool returnsValue = false) { + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end()); + Region clonedRegion; + IRMapping mapper; + srcRegion.cloneInto(&clonedRegion, mapper); + + SmallVector<Type> paramTypes; + llvm::copy(srcRegion.getArgumentTypes(), std::back_inserter(paramTypes)); + Type resultType = returnsValue + ?
srcRegion.getArgument(0).getType() + : LLVM::LLVMVoidType::get(rewriter.getContext()); + LLVM::LLVMFunctionType funcType = + LLVM::LLVMFunctionType::get(resultType, paramTypes); + + LLVM::LLVMFuncOp func = + LLVM::LLVMFuncOp::create(rewriter, loc, funcName, funcType); + func.setAlwaysInline(true); + rewriter.inlineRegionBefore(clonedRegion, func.getRegion(), + func.getRegion().end()); + for (auto &block : func.getRegion().getBlocks()) { + if (isa<omp::YieldOp>(block.getTerminator())) { + omp::YieldOp yieldOp = cast<omp::YieldOp>(block.getTerminator()); + rewriter.setInsertionPoint(yieldOp); + rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(yieldOp, TypeRange(), + yieldOp.getOperands()); + } + } + return func; + } +}; +} // namespace diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index a9da6c2..744a595 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -27,6 +27,7 @@ #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Casting.h" #include "llvm/Support/DebugLog.h" @@ -291,9 +292,102 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> { } }; +// Pattern to eliminate ExecuteRegionOp results which forward external +// values from the region. In case there are multiple yield operations, +// all of them must have the same operands in order for the pattern to be +// applicable. +struct ExecuteRegionForwardingEliminator + : public OpRewritePattern<ExecuteRegionOp> { + using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(ExecuteRegionOp op, + PatternRewriter &rewriter) const override { + if (op.getNumResults() == 0) + return failure(); + + SmallVector<Operation *> yieldOps; + for (Block &block : op.getRegion()) { + if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator())) + yieldOps.push_back(yield.getOperation()); + } + + if (yieldOps.empty()) + return failure(); + + // Check if all yield operations have the same operands. + auto yieldOpsOperands = yieldOps[0]->getOperands(); + for (auto *yieldOp : yieldOps) { + if (yieldOp->getOperands() != yieldOpsOperands) + return failure(); + } + + SmallVector<Value> externalValues; + SmallVector<Value> internalValues; + SmallVector<Value> opResultsToReplaceWithExternalValues; + SmallVector<Value> opResultsToKeep; + for (auto [index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) { + if (isValueFromInsideRegion(yieldedValue, op)) { + internalValues.push_back(yieldedValue); + opResultsToKeep.push_back(op.getResult(index)); + } else { + externalValues.push_back(yieldedValue); + opResultsToReplaceWithExternalValues.push_back(op.getResult(index)); + } + } + // No yielded external values - nothing to do. + if (externalValues.empty()) + return failure(); + + // There are yielded external values - create a new execute_region returning + // just the internal values. + SmallVector<Type> resultTypes; + for (Value value : internalValues) + resultTypes.push_back(value.getType()); + auto newOp = + ExecuteRegionOp::create(rewriter, op.getLoc(), TypeRange(resultTypes)); + newOp->setAttrs(op->getAttrs()); + + // Move old op's region to the new operation. + rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), + newOp.getRegion().end()); + + // Replace all yield operations with a new yield operation with updated + // results. 
scf.execute_region must have at least one yield operation. + for (auto *yieldOp : yieldOps) { + rewriter.setInsertionPoint(yieldOp); + rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, + ValueRange(internalValues)); + } + + // Replace the old operation with the external values directly. + rewriter.replaceAllUsesWith(opResultsToReplaceWithExternalValues, + externalValues); + // Replace the old operation's remaining results with the new operation's + // results. + rewriter.replaceAllUsesWith(opResultsToKeep, newOp.getResults()); + rewriter.eraseOp(op); + return success(); + } + +private: + bool isValueFromInsideRegion(Value value, + ExecuteRegionOp executeRegionOp) const { + // Check if the value is defined within the execute_region + if (Operation *defOp = value.getDefiningOp()) + return &executeRegionOp.getRegion() == defOp->getParentRegion(); + + // If it's a block argument, check if it's from within the region + if (BlockArgument blockArg = dyn_cast<BlockArgument>(value)) + return &executeRegionOp.getRegion() == blockArg.getParentRegion(); + + return false; // Value is from outside the region + } +}; + void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context); + results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner, + ExecuteRegionForwardingEliminator>(context); } void ExecuteRegionOp::getSuccessorRegions( @@ -2490,8 +2584,8 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> { changed = true; if (!constantTrue) - constantTrue = rewriter.create<arith::ConstantOp>( - op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)); + constantTrue = arith::ConstantOp::create( + rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)); rewriter.modifyOpInPlace(use.getOwner(), [&]() { use.set(constantTrue); }); @@ -2500,8 +2594,8 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> { changed = true; if (!constantFalse) - constantFalse = rewriter.create<arith::ConstantOp>( - op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0)); + constantFalse = arith::ConstantOp::create( + rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0)); rewriter.modifyOpInPlace(use.getOwner(), [&]() { use.set(constantFalse); }); diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp index 5dc61a2..335ca1a 100644 --- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp +++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp @@ -69,10 +69,10 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder, Sharding sourceSharding, TypedValue<ShapedType> sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) { - TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>( + TypedValue<ShapedType> targetShard = AllSliceOp::create(builder, sourceShard, grid, ArrayRef<GridAxis>(splitGridAxis), splitTensorAxis) - .getResult()); + .getResult(); Sharding targetSharding = targetShardingInSplitLastAxis( builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis); return {targetShard, targetSharding}; @@ -204,9 +204,8 @@ static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding( APInt(64, splitTensorAxis)); ShapedType targetShape = shardShapedType(sourceUnshardedShape, grid, targetSharding); - TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>( - tensor::CastOp::create(builder, targetShape, allGatherResult) - .getResult()); + TypedValue<ShapedType> targetShard = 
+ tensor::CastOp::create(builder, targetShape, allGatherResult).getResult(); return {targetShard, targetSharding}; } @@ -336,8 +335,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis)); ShapedType targetShape = shardShapedType(sourceUnshardedShape, grid, targetSharding); - TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>( - tensor::CastOp::create(builder, targetShape, allToAllResult).getResult()); + TypedValue<ShapedType> targetShard = + tensor::CastOp::create(builder, targetShape, allToAllResult).getResult(); return {targetShard, targetSharding}; } @@ -510,8 +509,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source, auto targetSharding = target.getSharding(); ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder); return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding, - cast<TypedValue<ShapedType>>(source.getSrc()), - sourceShardValue); + source.getSrc(), sourceShardValue); } TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp index 73e0f3d..f53d272 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp @@ -159,14 +159,22 @@ IterationGraphSorter::IterationGraphSorter( loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)), strategy(strategy) { // One map per tensor. - assert(loop2InsLvl.size() == ins.size()); + assert(this->loop2InsLvl.size() == this->ins.size()); // All the affine maps have the same number of dimensions (loops). assert(llvm::all_equal(llvm::map_range( - loop2InsLvl, [](AffineMap m) { return m.getNumDims(); }))); + this->loop2InsLvl, [](AffineMap m) { return m.getNumDims(); }))); // The number of results of the map should match the rank of the tensor. - assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](auto mvPair) { + assert(llvm::all_of(llvm::zip(this->loop2InsLvl, this->ins), [](auto mvPair) { auto [m, v] = mvPair; - return m.getNumResults() == cast<ShapedType>(v.getType()).getRank(); + + // For ranked types the rank must match. + // Simply return true for UnrankedTensorType + if (auto shapedType = llvm::dyn_cast<ShapedType>(v.getType())) { + return !shapedType.hasRank() || + (m.getNumResults() == shapedType.getRank()); + } + // Non-shaped (scalar) types behave like rank-0. 
+ return m.getNumResults() == 0; })); itGraph.resize(getNumLoops(), std::vector<bool>(getNumLoops(), false)); diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp index 1cba1bb..eb47e85 100644 --- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp @@ -12,6 +12,97 @@ namespace mlir { namespace tosa { +llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) { + return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor()); +} + +TosaSpecificationVersion getMinVersion(const Profile &profile) { + switch (profile) { + case Profile::pro_int: + case Profile::pro_fp: + return TosaSpecificationVersion(1, 0); + case Profile::none: + return TosaSpecificationVersion(0, 0); + } + llvm_unreachable("Unknown TOSA profile"); +} + +TosaSpecificationVersion getMinVersion(const Extension &extension) { + switch (extension) { + case Extension::int16: + case Extension::int4: + case Extension::bf16: + case Extension::fp8e4m3: + case Extension::fp8e5m2: + case Extension::fft: + case Extension::variable: + case Extension::controlflow: + case Extension::doubleround: + case Extension::inexactround: + case Extension::dynamic: + return TosaSpecificationVersion(1, 0); + case Extension::mxfp: + case Extension::int64: + return TosaSpecificationVersion(1, 1); + case Extension::none: + return TosaSpecificationVersion(0, 0); + } + llvm_unreachable("Unknown TOSA extension"); +} + +TosaSpecificationVersion getMinVersion(const Level &level) { + switch (level) { + case Level::eightK: + case Level::none: + return TosaSpecificationVersion(1, 0); + } + llvm_unreachable("Unknown TOSA level"); +} + +FailureOr<TargetEnv> +TargetEnv::createTargetEnvFromAttr(TargetEnvAttr targetAttr, + Location targetEnvAttrLoc) { + if (failed(verifyTargetInformation(targetAttr, targetEnvAttrLoc))) + return failure(); + + return TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(), + targetAttr.getProfiles(), targetAttr.getExtensions()); +} + +LogicalResult TargetEnv::verifyTargetInformation(TargetEnvAttr targetAttr, + Location targetAttrLoc) { + TosaSpecificationVersion targetVersion(targetAttr.getSpecificationVersion()); + + const auto isCompatibleWithTargetVersion = + [&](const auto &targetEnum, Location targetAttrLoc, + StringRef enumName) -> LogicalResult { + const TosaSpecificationVersion minRequiredVersion = + getMinVersion(targetEnum); + if (!targetVersion.isBackwardsCompatibleWith(minRequiredVersion)) + return emitError(targetAttrLoc, enumName) + << " '" << stringifyEnum(targetEnum) + << "' is not compatible with the target version " + << stringifyVersion(targetVersion) + << ", minimum required version is " + << stringifyVersion(minRequiredVersion); + return success(); + }; + + for (const auto &profile : targetAttr.getProfiles()) + if (failed( + isCompatibleWithTargetVersion(profile, targetAttrLoc, "profile"))) + return failure(); + for (const auto &extension : targetAttr.getExtensions()) + if (failed(isCompatibleWithTargetVersion(extension, targetAttrLoc, + "extension"))) + return failure(); + if (failed(isCompatibleWithTargetVersion(targetAttr.getLevel(), targetAttrLoc, + "level"))) + return failure(); + + return success(); +} + TargetEnvAttr lookupTargetEnv(Operation *op) { while (op) { op = SymbolTable::getNearestSymbolTable(op); @@ -39,9 +130,5 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) { return getDefaultTargetEnv(op->getContext()); } -llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) { 
- return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor()); -} - } // namespace tosa } // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index caf8016..a85ff10a 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -76,28 +76,6 @@ template <typename OpTy> struct PoolPadFoldAdaptor; template <> -struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> { - using OpTy = tosa::AvgPool2dOp; - static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) { - const llvm::ArrayRef<int64_t> kernel = op.getKernel(); - if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] || - newPad[0] >= kernel[0] || newPad[1] >= kernel[0]) - return false; - return true; - } - static bool checkPadConstCompliance(OpTy op, Value padConst) { - return checkMatchingPadConstAndZp(padConst, op.getInputZp()); - } - static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op, - Value padInput, ArrayRef<int64_t> newPad) { - rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>( - op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(), - op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad), - op.getAccType()); - } -}; - -template <> struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> { using OpTy = tosa::MaxPool2dOp; static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) { @@ -245,13 +223,6 @@ struct FoldPadToTensorOp : public OpRewritePattern<OpTy> { }; } // namespace -void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add<FoldPadToTensorOp<tosa::AvgPool2dOp, - PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>( - context); -} - void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add< @@ -1001,8 +972,12 @@ OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) { !outputTy.hasStaticShape()) return {}; - if (inputTy.getDimSize(getAxis()) == 1) - return DenseElementsAttr::get(outputTy, 0); + const Type outputElementTy = getElementTypeOrSelf(outputTy); + if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) { + const auto outputElemIntTy = cast<IntegerType>(outputElementTy); + const APInt zero = APInt::getZero(outputElemIntTy.getWidth()); + return DenseElementsAttr::get(outputTy, zero); + } return {}; } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 00f84bc..0aff67f 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -321,6 +321,19 @@ ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser, } } + // special handling: block_size accepts a *bare* BlockSize enum + if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) { + if (name == "block_size" && succeeded(parser.parseOptionalKeyword(&kw))) { + auto sym = symbolizeBlockSize(kw); + if (!sym) + return parser.emitError(parser.getCurrentLocation()) + << "invalid block_size value: " << kw; + auto attr = BlockSizeAttr::get(parser.getContext(), sym.value()); + outAttrs.push_back(NamedAttribute(name, attr)); + return success(); + } + } + // Default path: parse any normal attribute literal, including fully qualified // enum keyword Attribute attr; @@ -357,7 +370,7 @@ ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) { result.operands))) return failure(); - result.addTypes(fnTy.getResult(0)); + result.addTypes(fnTy.getResults());
result.addAttributes(attrs); + return success(); @@ -373,6 +386,8 @@ void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) { } else if (auto nanPropagationModeAttr = dyn_cast<tosa::NanPropagationModeAttr>(attr)) { parser << nanPropagationModeAttr.getValue(); + } else if (auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) { + parser << blockSizeAttr.getValue(); } else { parser.printAttribute(attr); } @@ -508,6 +523,33 @@ void ReduceMinOp::print(OpAsmPrinter &parser) { printWithNanPropagationHandling(parser, *this); } +ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseWithEnumHandling<tosa::BlockSize>(parser, result); +} + +void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) { + printWithEnumHandling(parser, *this); +} + +ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseWithEnumHandling<tosa::BlockSize>(parser, result); +} + +void CastFromBlockScaledOp::print(OpAsmPrinter &parser) { + printWithEnumHandling(parser, *this); +} + +ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseWithEnumHandling<tosa::BlockSize>(parser, result); +} + +void CastToBlockScaledOp::print(OpAsmPrinter &parser) { + printWithEnumHandling(parser, *this); +} + //===----------------------------------------------------------------------===// // Tosa utilities. //===----------------------------------------------------------------------===// @@ -933,32 +975,35 @@ static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) { // verify that inType and outType have same element types template <typename T> -static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) { - auto inputType = llvm::dyn_cast<TensorType>(inType); - auto outputType = llvm::dyn_cast<TensorType>(outType); - if (!inputType) { - op.emitOpError("expect shaped tensor for input, got ") << inType; +static LogicalResult verifySameElementTypes(T op, Type aType, Type bType, + StringRef aName = "input", + StringRef bName = "output") { + auto aTType = llvm::dyn_cast<TensorType>(aType); + auto bTType = llvm::dyn_cast<TensorType>(bType); + if (!aTType) { + op.emitOpError("expect shaped tensor for ") << aName << ", got " << aType; return failure(); } - if (!outputType) { - op.emitOpError("expect shaped tensor for output, got ") << outType; + if (!bTType) { + op.emitOpError("expect shaped tensor for ") << bName << ", got " << bType; return failure(); } - auto inputElementType = inputType.getElementType(); - auto outputElementType = outputType.getElementType(); - auto inputQuantType = - llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType); - auto outputQuantType = - llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType); - if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) && - (outputElementType.isIntOrIndexOrFloat() || outputQuantType) && - inputElementType != outputElementType) { + auto aElementType = aTType.getElementType(); + auto bElementType = bTType.getElementType(); + auto aQuantType = + llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType); + auto bQuantType = + llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType); + if ((aElementType.isIntOrIndexOrFloat() || aQuantType) && + (bElementType.isIntOrIndexOrFloat() || bQuantType) && + aElementType != bElementType) { // only check if both element types are int/index/float/UniformQuantized // eg, not sure how to check quant::QuantizedType
// this happens in test_conv2d_q_grouped_convolution in // tfl-to-tosa-pipeline.mlir - op.emitOpError("expect input and output to have same element type, got ") - << inputElementType << " and " << outputElementType; + op.emitOpError("expect ") + << aName << " and " << bName << " to have same element type, got " + << aElementType << " and " << bElementType; return failure(); } return success(); @@ -1846,6 +1891,161 @@ LogicalResult MatMulOp::verify() { return success(); } +LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents( + MLIRContext *context, ::std::optional<Location> location, + MatmulTBlockScaledOp::Adaptor adaptor, + SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { + SmallVector<int64_t, 3> outShape(3, ShapedType::kDynamic); + + const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType()); + if (aDataShape.hasRank()) { + outShape[0] = aDataShape.getDimSize(0); + outShape[1] = aDataShape.getDimSize(1); + } + + const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType()); + if (aScaleShape.hasRank()) { + outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0) + : outShape[0]; + outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1) + : outShape[1]; + } + + // If B batch size is 1, it is broadcast across A's batch size + const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType()); + if (bDataShape.hasRank()) { + const int64_t bDataBatchSize = bDataShape.getDimSize(0); + if (bDataBatchSize != 1) + outShape[0] = + ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0]; + outShape[2] = bDataShape.getDimSize(1); + } + + const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType()); + if (bScaleShape.hasRank()) { + const int64_t bScaleBatchSize = bScaleShape.getDimSize(0); + if (bScaleBatchSize != 1) + outShape[0] = + ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0]; + outShape[2] = ShapedType::isDynamic(outShape[2]) ? 
bScaleShape.getDimSize(1) + : outShape[2]; + } + + inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); + return success(); +} + +LogicalResult MatmulTBlockScaledOp::verify() { + // Verify same input data types + const Type aDataType = getAData().getType(); + const Type bDataType = getBData().getType(); + if (failed(verifySameElementTypes(*this, aDataType, bDataType, "A_data", + "B_data"))) + return failure(); + + auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim, + const StringRef operandName, + const StringRef dimName) -> LogicalResult { + if (ShapedType::isDynamic(currDim)) { + currDim = newDim; + return success(); + } else if (ShapedType::isStatic(newDim) && currDim != newDim) { + return emitOpError("expected ") + << dimName << " of " << operandName << " to match size " << currDim + << ", got " << newDim; + } + return success(); + }; + + // Verify input shape compatibility + int64_t N = ShapedType::kDynamic; + int64_t D = ShapedType::kDynamic; + int64_t H = ShapedType::kDynamic; + int64_t W = ShapedType::kDynamic; + int64_t C = ShapedType::kDynamic; + int64_t multiplesOfC = ShapedType::kDynamic; + + const ShapeAdaptor aDataShape = ShapeAdaptor(aDataType); + if (aDataShape.hasRank()) { + N = aDataShape.getDimSize(0); + H = aDataShape.getDimSize(1); + C = aDataShape.getDimSize(2); + } + + const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType()); + if (aScaleShape.hasRank()) { + if (failed(tryUpdateDimOrFailure(N, aScaleShape.getDimSize(0), "a_scale", + "batch")) || + failed(tryUpdateDimOrFailure(H, aScaleShape.getDimSize(1), "a_scale", + "height"))) + return failure(); + multiplesOfC = aScaleShape.getDimSize(2); + } + + const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType); + if (bDataShape.hasRank()) { + if (failed(tryUpdateDimOrFailure(D, bDataShape.getDimSize(0), "b_data", + "batch")) || + failed(tryUpdateDimOrFailure(C, bDataShape.getDimSize(2), "b_data", + "channels"))) + return failure(); + W = bDataShape.getDimSize(1); + } + + const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType()); + if (bScaleShape.hasRank()) { + if (failed(tryUpdateDimOrFailure(D, bScaleShape.getDimSize(0), "b_scale", + "batch")) || + failed(tryUpdateDimOrFailure(W, bScaleShape.getDimSize(1), "b_scale", + "width")) || + failed(tryUpdateDimOrFailure(multiplesOfC, bScaleShape.getDimSize(2), + "b_scale", "C/block_size"))) + return failure(); + } + + // Verify batch size is broadcast compatible + if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1) + return emitOpError("expect B matrix batch size to be broadcast compatible " + "with A, got D=") + << D << " vs N=" << N; + + // Verify C is a multiple of block size + const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize()); + if (ShapedType::isStatic(C) && C % blockSize != 0) + return emitOpError("expect C to be a multiple of block size, got C=") + << C << ", block_size=" << blockSize; + + // Verify multiplesOfC is C / block size + if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) && + multiplesOfC != C / blockSize) + return emitOpError( + "expect scale operands dimension 2 to equal C/block_size (") + << C << "/" << blockSize << ")" + << ", got " << multiplesOfC; + + // Verify output shape + N = ShapedType::isDynamic(N) ? 
D : N; + const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W}; + const auto outputType = cast<ShapedType>(getResult().getType()); + if (outputType.hasRank() && + failed( + verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) { + InFlightDiagnostic opError = emitOpError("expected output shape "); + auto stringifyDim = [&](int64_t d) { + if (ShapedType::isDynamic(d)) + opError << "?"; + else + opError << d; + }; + llvm::interleaveComma(outputType.getShape(), opError, stringifyDim); + opError << " to be compatible with expected output shape "; + llvm::interleaveComma(expectedOutputShape, opError, stringifyDim); + return opError; + } + + return success(); +} + LogicalResult tosa::PadOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional<Location> location, PadOp::Adaptor adaptor, @@ -3762,6 +3962,145 @@ LogicalResult RescaleOp::inferReturnTypeComponents( return success(); } +LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents( + MLIRContext *context, ::std::optional<Location> location, + CastFromBlockScaledOp::Adaptor adaptor, + SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { + const ShapeAdaptor inputShape(adaptor.getInputData().getType()); + inferredReturnShapes.push_back(ShapedTypeComponents(inputShape)); + return success(); +} + +LogicalResult CastFromBlockScaledOp::verify() { + const Type inputDataType = getInputData().getType(); + const Type outputDataType = getResult().getType(); + if (failed(verifyCompatibleShape(inputDataType, outputDataType))) + return emitOpError() << "require compatible shapes for input_data (" + << inputDataType << ") and " + << "output_data (" << outputDataType << ")"; + + const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType); + + if (inputDataShape.hasRank()) { + const unsigned int blockSize = + BlockSizeAttr::getBlockSizeValue(getBlockSize()); + const int64_t inputDataLastDim = + inputDataShape.getDimSize(inputDataShape.getRank() - 1); + if (inputDataLastDim % blockSize != 0) + return emitOpError() << "expect last dimension of input_data (" + << inputDataLastDim + << ") to be divisible by block_size (" << blockSize + << ")"; + + const Type inputScaleType = getInputScale().getType(); + const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType); + + if (inputScaleShape.hasRank()) { + SmallVector<int64_t> inputDataDims, inputScaleDims; + inputDataShape.getDims(inputDataDims); + inputScaleShape.getDims(inputScaleDims); + + if (inputDataDims.size() != inputScaleDims.size() || + failed(verifyCompatibleShape( + ArrayRef<int64_t>(inputDataDims).drop_back(1), + ArrayRef<int64_t>(inputScaleDims).drop_back(1)))) + return emitOpError() << "require compatible shapes for input_data (" + << inputDataType << ") and " + << "input_scale (" << inputScaleType + << ") except for the last dimension"; + + const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize, + inputScaleDims.back()}; + if (ShapedType::isStatic(inputDataLastDim) && + failed(verifyCompatibleDims(dimsToCheck))) + return emitOpError() + << "expect last dimension of input_scale (" + << inputScaleDims.back() + << ") to be equal to last dimension of input_data / block_size (" + << inputDataDims.back() / blockSize << ")"; + } + } + + return success(); +} + +LogicalResult CastToBlockScaledOp::inferReturnTypeComponents( + MLIRContext *context, ::std::optional<Location> location, + CastToBlockScaledOp::Adaptor adaptor, + SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { + const ShapeAdaptor 
inputShape(adaptor.getInputData().getType()); + inferredReturnShapes.push_back(ShapedTypeComponents(inputShape)); + if (!inputShape.hasRank()) + return success(); + + // Calculate output_scale shape if ranked input provided + SmallVector<int64_t> outputScaleShape; + inputShape.getDims(outputScaleShape); + const int64_t lastDimLoc = inputShape.getRank() - 1; + const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc); + if (ShapedType::isStatic(lastDimSize)) { + const unsigned int blockSize = + BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize()); + outputScaleShape[lastDimLoc] = lastDimSize / blockSize; + } + inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape)); + return success(); +} + +LogicalResult CastToBlockScaledOp::verify() { + const Type inputDataType = getInputData().getType(); + const Type outputDataType = getResult(0).getType(); + if (failed(verifyCompatibleShape(inputDataType, outputDataType))) + return emitOpError() << "require compatible shapes for input_data (" + << inputDataType << ") and " + << "output_data (" << outputDataType << ")"; + + const unsigned int blockSize = + BlockSizeAttr::getBlockSizeValue(getBlockSize()); + const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType); + if (inputDataShape.hasRank()) { + const int64_t inputDataLastDim = + inputDataShape.getDimSize(inputDataShape.getRank() - 1); + if (ShapedType::isStatic(inputDataLastDim) && + inputDataLastDim % blockSize != 0) + return emitOpError() << "expect last dimension of input_data (" + << inputDataLastDim + << ") to be divisible by block_size (" << blockSize + << ")"; + } + + const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType); + const Type outputScaleType = getResult(1).getType(); + const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType); + if (outputDataShape.hasRank() && outputScaleShape.hasRank()) { + SmallVector<int64_t> outputDataDims, outputScaleDims; + outputDataShape.getDims(outputDataDims); + outputScaleShape.getDims(outputScaleDims); + + if (outputDataDims.size() != outputScaleDims.size() || + failed(verifyCompatibleShape( + ArrayRef<int64_t>(outputDataDims).drop_back(1), + ArrayRef<int64_t>(outputScaleDims).drop_back(1)))) + return emitOpError() << "require compatible shapes for output_data (" + << outputDataType << ") and " + << "output_scale (" << outputScaleType + << ") except for the last dimension"; + + const int64_t outputDataLastDim = outputDataDims.back(); + const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize, + outputScaleDims.back()}; + if (ShapedType::isStatic(outputDataLastDim) && + failed(verifyCompatibleDims(dimsToCheck))) + return emitOpError() + << "expect last dimension of output_scale (" + << outputScaleDims.back() + << ") to be equal to last dimension of output_data / block_size (" + << outputDataDims.back() / blockSize << ")"; + } + + return success(); +} + LogicalResult IfOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional<Location> location, IfOp::Adaptor adaptor, diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index f072e3e..ab363ee 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -19,12 +19,19 @@ TosaProfileCompliance::TosaProfileCompliance() { const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16}; const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32}; const TypeInfo i48T = 
{mlir::IntegerType::getTypeID(), 48}; + const TypeInfo i64T = {mlir::IntegerType::getTypeID(), 64}; const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16}; const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16}; const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32}; const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8}; const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8}; + // micro-scaling formats + const TypeInfo fp6e2m3T = {mlir::Float6E2M3FNType::getTypeID(), 6}; + const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6}; + const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4}; + const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8}; + // The profile-based compliance content below is auto-generated by a script // in https://git.mlplatform.org/tosa/specification.git #include "mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc" @@ -44,10 +51,11 @@ TosaProfileCompliance::getProfileComplianceMap() { // Base populating function LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands, - Value output) { - for (auto operand : operands) + ValueRange results) { + for (const auto &operand : operands) addValue(operand); - addValue(output); + for (const auto &result : results) + addValue(result); return success(); } @@ -170,23 +178,6 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) { } template <> -LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) { - addValue(op.getInputReal()); - addValue(op.getInputImag()); - addValue(op.getOutputReal()); - addValue(op.getOutputImag()); - return success(); -} - -template <> -LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) { - addValue(op.getInputReal()); - addValue(op.getOutputReal()); - addValue(op.getOutputImag()); - return success(); -} - -template <> LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) { addValue(op.getOnTrue()); addValue(op.getOnFalse()); @@ -239,7 +230,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { // This helper function populates the info for all operands. #define POPULATE_PROFILE_INFO_COMMON(tosaOp) \ if (isa<tosa::tosaOp##Op>(op)) { \ - return populateProfileInfo(op->getOperands(), op->getResult(0)); \ + return populateProfileInfo(op->getOperands(), op->getResults()); \ } // Skip irrelevant operands when they are independent and not tied to any @@ -250,8 +241,6 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { POPULATE_PROFILE_INFO_CUSTOM(Conv3D) POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D) POPULATE_PROFILE_INFO_CUSTOM(Mul) - POPULATE_PROFILE_INFO_CUSTOM(FFT2d) - POPULATE_PROFILE_INFO_CUSTOM(RFFT2d) POPULATE_PROFILE_INFO_CUSTOM(Concat) POPULATE_PROFILE_INFO_CUSTOM(Pad) POPULATE_PROFILE_INFO_CUSTOM(Reshape) @@ -269,7 +258,12 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { // For the most of tosa operators, all operands are profile/extension related // and hence are all considered in this profile-based compilance check. 
+ POPULATE_PROFILE_INFO_COMMON(MatmulTBlockScaled) + POPULATE_PROFILE_INFO_COMMON(FFT2d) + POPULATE_PROFILE_INFO_COMMON(RFFT2d) POPULATE_PROFILE_INFO_COMMON(Cast) + POPULATE_PROFILE_INFO_COMMON(CastFromBlockScaled) + POPULATE_PROFILE_INFO_COMMON(CastToBlockScaled) POPULATE_PROFILE_INFO_COMMON(Const) POPULATE_PROFILE_INFO_COMMON(ArgMax) POPULATE_PROFILE_INFO_COMMON(Sub) @@ -623,6 +617,14 @@ TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) { return {"fp8e4m3"}; } else if (typeInfo.typeID == mlir::Float8E5M2Type::getTypeID()) { return {"fp8e5m2"}; + } else if (typeInfo.typeID == mlir::Float6E2M3FNType::getTypeID()) { + return {"fp6e2m3"}; + } else if (typeInfo.typeID == mlir::Float6E3M2FNType::getTypeID()) { + return {"fp6e3m2"}; + } else if (typeInfo.typeID == mlir::Float4E2M1FNType::getTypeID()) { + return {"fp4e2m1"}; + } else if (typeInfo.typeID == mlir::Float8E8M0FNUType::getTypeID()) { + return {"fp8e8m0"}; } llvm_unreachable("unknown type"); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 82f2f7e..4d0b61a 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -635,6 +635,8 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) { CHECK_RANKS_AND_SIZES(Transpose); // Type Conversion CHECK_RANKS_AND_SIZES(Cast); + CHECK_RANKS_AND_SIZES(CastFromBlockScaled); + CHECK_RANKS_AND_SIZES(CastToBlockScaled); CHECK_RANKS_AND_SIZES(Rescale); // Control Flow Operators CHECK_RANKS_AND_SIZES(If); @@ -657,6 +659,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) { CHECK_SIZES(TransposeConv2D); CHECK_SIZES(FFT2d); CHECK_SIZES(MatMul); + CHECK_SIZES(MatmulTBlockScaled); CHECK_SIZES(MaxPool2d); CHECK_SIZES(RFFT2d); // Scatter/Gather Operators @@ -1192,9 +1195,9 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) { bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { if (isa<FloatType>(type)) { return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType, - Float8E5M2Type>(type); - } - if (auto intTy = dyn_cast<IntegerType>(type)) { + Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType, + Float6E3M2FNType, Float8E8M0FNUType>(type); + } else if (auto intTy = dyn_cast<IntegerType>(type)) { if (intTy.isSignless()) { switch (intTy.getWidth()) { case 1: @@ -1203,6 +1206,7 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { case 16: case 32: case 48: + case 64: return true; } } else if (allowUnsigned && intTy.isUnsigned()) { @@ -1220,13 +1224,19 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { } void TosaValidation::runOnOperation() { + ModuleOp modOp = getOperation(); + const TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(modOp); + const auto maybeTargetEnv = + tosa::TargetEnv::createTargetEnvFromAttr(targetEnvAttr, modOp.getLoc()); + if (failed(maybeTargetEnv)) + return signalPassFailure(); + targetEnv = *maybeTargetEnv; + TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>(); if (!tosaDialect) return; - targetEnv = tosa::TargetEnv(lookupTargetEnvOrDefault(getOperation())); - - getOperation().walk([&](Operation *op) { + modOp.walk([&](Operation *op) { if (op->getDialect() != tosaDialect) return; diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp index 8f46ad6..ef49c86 100644 --- 
a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp @@ -74,9 +74,9 @@ struct MixedSizeInputShuffleOpRewrite final for (int64_t i = 0; i < origNumElems; ++i) promoteMask[i] = i; - Value promotedInput = rewriter.create<vector::ShuffleOp>( - shuffleOp.getLoc(), promotedType, inputToPromote, inputToPromote, - promoteMask); + Value promotedInput = + vector::ShuffleOp::create(rewriter, shuffleOp.getLoc(), promotedType, + inputToPromote, inputToPromote, promoteMask); // Create the final shuffle with the promoted inputs. Value promotedV1 = promoteV1 ? promotedInput : shuffleOp.getV1(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 7c019e7..8b5e950 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -341,13 +341,18 @@ private: /// Return the distributed vector type based on the original type and the /// distribution map. The map is expected to have a dimension equal to the /// original type rank and should be a projection where the results are the -/// distributed dimensions. The number of results should be equal to the number +/// distributed dimensions. If the number of results is zero there is no +/// distribution (i.e. original type is returned). +/// Otherwise, The number of results should be equal to the number /// of warp sizes which is currently limited to 1. /// Example: For a vector<16x32x64> distributed with a map(d0, d1, d2) -> (d1) /// and a warp size of 16 would distribute the second dimension (associated to /// d1) and return vector<16x2x64> static VectorType getDistributedType(VectorType originalType, AffineMap map, int64_t warpSize) { + // If the map has zero results, return the original type. + if (map.getNumResults() == 0) + return originalType; SmallVector<int64_t> targetShape(originalType.getShape()); for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { unsigned position = map.getDimPosition(i); diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 1599ae9..24e9095 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -736,7 +736,7 @@ OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc, OpBuilder &builder) { auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a); auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b); - return builder.create<ArithOp>(loc, aVal, bVal).getResult(); + return ArithOp::create(builder, loc, aVal, bVal).getResult(); } // a helper utility to perform division operation on OpFoldResult and int64_t. 
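The getDistributedType change above makes a zero-result distribution map mean "no distribution at all" instead of being rejected. A minimal standalone sketch of the shape arithmetic, under the assumption that the map is represented here as a plain list of dimension positions (distributeShape and its signature are illustrative, not part of the patch):

    // Divide each mapped dimension by the warp size; an empty map now
    // returns the shape unchanged, matching the new early return.
    // For shape {16, 32, 64}, mapResults {1}, warpSize 16 -> {16, 2, 64}.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    std::vector<std::int64_t>
    distributeShape(std::vector<std::int64_t> shape,
                    const std::vector<unsigned> &mapResults,
                    std::int64_t warpSize) {
      if (mapResults.empty()) // zero results: no distribution requested
        return shape;
      for (unsigned pos : mapResults) {
        assert(shape[pos] % warpSize == 0 && "dimension must divide evenly");
        shape[pos] /= warpSize;
      }
      return shape;
    }

The real implementation reads the positions from an AffineMap and currently limits the number of results to one; the plain vector above merely stands in for that map.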
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp index f77784a..2c37140 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp @@ -145,8 +145,26 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const { xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(operandOrResult); if (layout && layout.isForSubgroup()) { - if (!layout.getEffectiveInstDataAsInt().empty()) - return layout.getEffectiveInstDataAsInt(); + if (!layout.getEffectiveInstDataAsInt().empty()) { + SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt(); + // Remove leading unit dimensions from inst_data. + // For example, if the inst_data is [1, 1, 32], + // it will pass [32] as the unroll/blocking size. + // Skip this for xegpu nd ops, since their inst_data will be 2D. + // TODO: For vector ops, experiment with the + // upstream vector remove leading unit dims patterns, + // populateCastAwayVectorLeadingOneDimPatterns. + Operation *definingOp = value.getDefiningOp(); + bool skipLeadingUnitDimRemoval = + definingOp && + (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::DpasOp, + xegpu::StoreNdOp, xegpu::PrefetchNdOp>(definingOp)); + if (!skipLeadingUnitDimRemoval) { + auto it = llvm::find_if(instData, [](auto val) { return val != 1; }); + instData.erase(instData.begin(), it); + } + return instData; + } if (auto type = dyn_cast<ShapedType>(value.getType())) return llvm::to_vector(type.getShape()); @@ -354,7 +372,6 @@ void XeGPUBlockingPass::runOnOperation() { // To create a new attribute with a different chunk_size: auto newEncoding = xegpu::ScatterTensorDescAttr::get( ctx, tdescTy.getMemorySpace(), blockedChunkSize); - encoding = newEncoding; } } @@ -363,7 +380,7 @@ xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding, tdescTy.getLayoutAttr().dropInstData()); } else { - newTy = type.clone(tileShape, elemTy); + newTy = VectorType::get(tileShape, elemTy); } if (returnSingleType) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index 26770b3..d09dc19 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -1505,14 +1505,19 @@ void XeGPUSubgroupDistributePass::runOnOperation() { return AffineMap::get(val.getContext()); // Get the layout of the vector type. xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val); - // If no layout is specified, assume the inner most dimension is distributed - // for now. + // If no layout is specified, that means no distribution. if (!layout) - return AffineMap::getMultiDimMapWithTargets( - vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext()); + return AffineMap::getMultiDimMapWithTargets(vecRank, {}, + val.getContext()); + // Expecting vector and layout rank to match. + assert(layout.getRank() == vecRank && + "Expecting vector and layout rank to match"); + // A dimension is distributed only if the layout suggests there are + // multiple lanes assigned for this dimension and the shape can be evenly + // distributed to those lanes.
SmallVector<unsigned int> distributedDims; for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) { - if (v > 1) + if (v > 1 && vecType.getShape()[i] % v == 0) distributedDims.push_back(i); } return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims, @@ -1525,15 +1530,13 @@ void XeGPUSubgroupDistributePass::runOnOperation() { auto warpReduction = [](Location loc, OpBuilder &builder, Value input, vector::CombiningKind kind, uint32_t size) { // First reduce on a single thread to get per lane reduction value. - Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input); + Value laneVal = vector::ReductionOp::create(builder, loc, kind, input); // Parallel reduction using butterfly shuffles. for (uint64_t i = 1; i < size; i <<= 1) { - Value shuffled = - builder - .create<gpu::ShuffleOp>(loc, laneVal, i, - /*width=*/size, - /*mode=*/gpu::ShuffleMode::XOR) - .getShuffleResult(); + Value shuffled = gpu::ShuffleOp::create(builder, loc, laneVal, i, + /*width=*/size, + /*mode=*/gpu::ShuffleMode::XOR) + .getShuffleResult(); laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled); } return laneVal; diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index aafa1b7..e6e71cc 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -66,8 +66,6 @@ protected: Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize, Location loc, PatternRewriter &rewriter) const { if (auto vecTy = dyn_cast<VectorType>(destTy)) { - assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) && - "Expecting blockSize size to match the rank of destTy."); auto shape = vecTy.getShape(); return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape); } @@ -93,8 +91,6 @@ protected: ArrayRef<int64_t> blockSize, Location loc, PatternRewriter &rewriter) const { if (auto vecTy = dyn_cast<VectorType>(src.getType())) { - assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) && - "Expecting blockSize size to match the rank of src."); return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src, blockSize); } @@ -635,7 +631,7 @@ struct UnrollLoadGatherOpWithOffset VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType()); VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType()); Type elemTy = valueTy.getElementType(); - VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy); + VectorType newValueTy = VectorType::get(*targetShape, elemTy); SmallVector<Type> convertedMaskTypes; SmallVector<Value> convertedMasks; diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 31a967d..9fc5ad9 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -825,7 +825,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType), baseTileValues); - auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr); + auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr); // Get subgroup id Value sgId = @@ -837,25 +837,26 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { SmallVector<Value, 2> strideConsts; strideConsts.push_back( - rewriter.create<arith::ConstantIndexOp>(loc, colStride)); + 
arith::ConstantIndexOp::create(rewriter, loc, colStride)); if (rows > 1) strideConsts.insert( strideConsts.begin(), - rewriter.create<arith::ConstantIndexOp>(loc, rowStride)); + arith::ConstantIndexOp::create(rewriter, loc, rowStride)); SmallVector<Value> newConstOps; for (auto offsets : *sgOffsets) { // Multiply offset with stride, broadcast it and add to baseConstVec - Value mulOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0); + Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0); for (size_t i = 0; i < strideConsts.size(); ++i) { - Value mul = rewriter.create<arith::MulIOp>( - loc, rewriter.getIndexType(), offsets[i], strideConsts[i]); - mulOffset = rewriter.create<arith::AddIOp>( - loc, rewriter.getIndexType(), mulOffset, mul); + Value mul = + arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(), + offsets[i], strideConsts[i]); + mulOffset = arith::AddIOp::create( + rewriter, loc, rewriter.getIndexType(), mulOffset, mul); } // Broadcast to baseConstVec size - auto bcastOffset = rewriter.create<vector::BroadcastOp>( - loc, baseConstVec.getType(), mulOffset); + auto bcastOffset = vector::BroadcastOp::create( + rewriter, loc, baseConstVec.getType(), mulOffset); auto finalConst = arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset); setLayoutIfNeeded(baseConstVec); @@ -1138,8 +1139,8 @@ struct WgToSgVectorShapeCastOp SmallVector<Value> newShapeCastOps; for (auto src : adaptor.getSource()) { - auto newShapeCast = - rewriter.create<vector::ShapeCastOp>(op.getLoc(), newResultType, src); + auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(), + newResultType, src); if (!layout.getEffectiveLaneLayoutAsInt().empty() || !layout.getEffectiveInstDataAsInt().empty()) xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0), @@ -1201,9 +1202,9 @@ struct WgToSgMultiDimReductionOp SmallVector<Value> newReductions; for (auto sgSrc : adaptor.getSource()) { - auto newOp = rewriter.create<vector::MultiDimReductionOp>( - op.getLoc(), newDstType, op.getKind(), sgSrc, adaptor.getAcc()[0], - op.getReductionDims()); + auto newOp = vector::MultiDimReductionOp::create( + rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc, + adaptor.getAcc()[0], op.getReductionDims()); if (!layout.getEffectiveLaneLayoutAsInt().empty() || !layout.getEffectiveInstDataAsInt().empty()) xegpu::setDistributeLayoutAttr(newOp->getResult(0), diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp index 2c56a43..b4605cd 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp @@ -246,11 +246,28 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, if (!computeShapeRatio(srcShape, shape)) return {value}; + int64_t srcShapeRank = srcShape.size(); + int64_t targetShapeRank = shape.size(); + + SmallVector<int64_t> adjustedTargetShape(srcShape.size()); + int64_t rankDiff = srcShapeRank - targetShapeRank; + std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff, + 1); + std::copy(shape.begin(), shape.end(), adjustedTargetShape.begin() + rankDiff); + SmallVector<Value> result; - for (SmallVector<int64_t> offsets : StaticTileOffsetRange(srcShape, shape)) { + for (SmallVector<int64_t> offsets : + StaticTileOffsetRange(srcShape, adjustedTargetShape)) { SmallVector<int64_t> staticStrides(offsets.size(), 1); - result.push_back(vector::ExtractStridedSliceOp::create( - builder, loc, value, offsets, shape, staticStrides)); + Value slice = 
vector::ExtractStridedSliceOp::create( + builder, loc, value, offsets, adjustedTargetShape, staticStrides); + + // Reshape to remove leading unit dims if needed + if (srcShapeRank > targetShapeRank) { + auto targetTy = VectorType::get(shape, vecTy.getElementType()); + slice = vector::ShapeCastOp::create(builder, loc, targetTy, slice); + } + result.push_back(slice); } return result; @@ -274,7 +291,7 @@ Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc, for (auto [src, offsets] : llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) { - SmallVector<int64_t> staticStrides(offsets.size(), 1); + SmallVector<int64_t> staticStrides(tileShape.size(), 1); result = vector::InsertStridedSliceOp::create(builder, loc, src, result, offsets, staticStrides); }
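The rank-alignment step added to extractVectorsWithShapeFromValue above can be read in isolation: the lower-rank tile shape is left-padded with unit dimensions so the tile offsets can be enumerated at full source rank, and the unit dimensions are shape-cast away after each extraction. A hedged sketch of just the padding computation (alignTileShape is a hypothetical name; it assumes tile has rank less than or equal to src):

    // Left-pad `tile` with 1s until it matches the rank of `src`.
    // E.g. src = {2, 4, 8}, tile = {4, 8} -> {1, 4, 8}.
    #include <algorithm>
    #include <cstdint>
    #include <vector>

    std::vector<std::int64_t>
    alignTileShape(const std::vector<std::int64_t> &src,
                   const std::vector<std::int64_t> &tile) {
      std::vector<std::int64_t> adjusted(src.size(), 1);
      std::copy(tile.begin(), tile.end(),
                adjusted.begin() + (src.size() - tile.size()));
      return adjusted;
    }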
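One closing note on the OpenMP changes near the top of this section: getSizeInBytes rounds a type's store size up to its ABI alignment before the size reaches the malloc call. A minimal sketch of that rounding, with llvm::alignTo replaced by hand-written arithmetic for illustration (the power-of-two restriction is an assumption of this sketch; alignTo itself also handles other alignments):

    // Round `size` up to the next multiple of `align`, mirroring what
    // llvm::alignTo computes inside getSizeInBytes. Assumes `align` is a
    // nonzero power of two, which holds for ABI alignments.
    #include <cassert>
    #include <cstdint>

    std::int64_t roundUpToAlignment(std::int64_t size, std::int64_t align) {
      assert(align > 0 && (align & (align - 1)) == 0 && "power-of-two align");
      return (size + align - 1) & ~(align - 1);
    }
    // roundUpToAlignment(12, 8) == 16; roundUpToAlignment(16, 8) == 16.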
