Diffstat (limited to 'mlir/lib/Dialect')
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp                             |  62
 mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp                        |  21
 mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp        |   7
 mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp                |  13
 mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp |   4
 mlir/lib/Dialect/SCF/IR/SCF.cpp                                          |  96
 mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp                                   |  95
 mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp                       |  29
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp                                     | 377
 mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp               |  48
 mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp                      |  22
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp                      |  25
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp                        |   6
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp                              |  25
 14 files changed, 691 insertions(+), 139 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 61166db..4c4965e 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -360,45 +360,53 @@ LogicalResult ScaledExtPacked816Op::verify() {
 //===----------------------------------------------------------------------===//
 // WMMAOp
 //===----------------------------------------------------------------------===//
-LogicalResult WMMAOp::verify() {
-  Type sourceAType = getSourceA().getType();
-  Type sourceBType = getSourceB().getType();
-  Type destType = getDestC().getType();
-  VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
-  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
-  VectorType destVectorType = dyn_cast<VectorType>(destType);
+ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser,
+                                                IntegerAttr &m, IntegerAttr &n,
+                                                IntegerAttr &k) {
+  SmallVector<int64_t, 3> dimensions;
+  if (parser.parseDimensionList(dimensions, false, false))
+    return failure();
+  if (dimensions.size() != 3)
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected 3 dimensions in MNK dimension list";
 
-  Type sourceAElemType = sourceVectorAType.getElementType();
-  Type sourceBElemType = sourceVectorBType.getElementType();
-  Type destElemType = destVectorType.getElementType();
+  m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
+  n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
+  k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
+  return success();
+}
 
-  if (sourceVectorAType.getNumElements() !=
-      sourceVectorBType.getNumElements()) {
+LogicalResult WMMAOp::verify() {
+  auto sourceAType = cast<VectorType>(getSourceA().getType());
+  auto sourceBType = cast<VectorType>(getSourceB().getType());
+  auto destType = cast<VectorType>(getDestC().getType());
+
+  Type sourceAElemType = sourceAType.getElementType();
+  Type sourceBElemType = sourceBType.getElementType();
+  if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
     return emitOpError("source vectors have different lengths: ")
-           << sourceVectorAType << " vs. " << sourceVectorBType;
+           << sourceAType << " vs. " << sourceBType;
   }
-  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
-  bool isSrcFloat =
-      isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
-          sourceAElemType);
-
-  if (isDestFloat && !isSrcFloat) {
-    return emitOpError("Expected float sources with float destination");
-  }
+  bool isDestFloat = destType.getElementType().isFloat();
+  bool isSrcFloat = sourceAElemType.isFloat();
 
-  if (!isDestFloat && isSrcFloat) {
-    return emitOpError("Expected int sources with int destination");
-  }
+  if (isDestFloat && !isSrcFloat)
+    return emitOpError("expected float sources with float destination");
+  if (!isDestFloat && isSrcFloat)
+    return emitOpError("expected int sources with int destination");
 
-  if (sourceAElemType != sourceBElemType &&
-      !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
-        isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
+  if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
     return emitOpError(
                "source element types much match (except for fp8) but have ")
            << sourceAType << " and " << sourceBType;
   }
+
+  if (!sourceAElemType.isInteger(4) && getK() != 16) {
+    return emitOpError("K dimension must be 16 for source element type ")
+           << sourceAElemType;
+  }
   return success();
 }
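
Note: the rules enforced by the rewritten WMMAOp::verify() are easy to mis-read in diff form, so here is a minimal standalone C++ sketch of the same checks, assuming a reduced stand-in for MLIR's type system (ElemKind, isFloat, isFloat8, and verifyWMMA are all hypothetical names, not the MLIR API):

// Standalone sketch of the WMMAOp verifier rules above (hypothetical types;
// MLIR's Type/VectorType machinery is reduced to a small enum).
#include <cstdint>
#include <optional>
#include <string>

enum class ElemKind { F16, BF16, F8E4M3, F8E5M2, I8, I4, F32, I32 };

static bool isFloat(ElemKind k) {
  return k == ElemKind::F16 || k == ElemKind::BF16 || k == ElemKind::F8E4M3 ||
         k == ElemKind::F8E5M2 || k == ElemKind::F32;
}
static bool isFloat8(ElemKind k) {
  return k == ElemKind::F8E4M3 || k == ElemKind::F8E5M2;
}

// Returns an error message, or std::nullopt on success, mirroring the
// checks in WMMAOp::verify().
std::optional<std::string> verifyWMMA(ElemKind srcA, ElemKind srcB,
                                      ElemKind dst, int64_t k) {
  if (isFloat(dst) && !isFloat(srcA))
    return "expected float sources with float destination";
  if (!isFloat(dst) && isFloat(srcA))
    return "expected int sources with int destination";
  // Mixed source element types are only allowed between the two fp8 formats.
  if (!isFloat8(srcA) && srcA != srcB)
    return "source element types must match (except for fp8)";
  // Only 4-bit integer sources may use K != 16.
  if (srcA != ElemKind::I4 && k != 16)
    return "K dimension must be 16 for this source element type";
  return std::nullopt;
}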
" << sourceBType; } - bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType); - bool isSrcFloat = - isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>( - sourceAElemType); - - if (isDestFloat && !isSrcFloat) { - return emitOpError("Expected float sources with float destination"); - } + bool isDestFloat = destType.getElementType().isFloat(); + bool isSrcFloat = sourceAElemType.isFloat(); - if (!isDestFloat && isSrcFloat) { - return emitOpError("Expected int sources with int destination"); - } + if (isDestFloat && !isSrcFloat) + return emitOpError("expected float sources with float destination"); + if (!isDestFloat && isSrcFloat) + return emitOpError("expected int sources with int destination"); - if (sourceAElemType != sourceBElemType && - !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) && - isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) { + if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) { return emitOpError( "source element types much match (except for fp8) but have ") << sourceAType << " and " << sourceBType; } + + if (!sourceAElemType.isInteger(4) && getK() != 16) { + return emitOpError("K dimension must be 16 for source element type ") + << sourceAElemType; + } return success(); } diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp index 316721b..60ae78b 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp @@ -45,18 +45,15 @@ struct LoopUnroll : public affine::impl::AffineLoopUnrollBase<LoopUnroll> { const std::function<unsigned(AffineForOp)> getUnrollFactor; LoopUnroll() : getUnrollFactor(nullptr) {} - LoopUnroll(const LoopUnroll &other) - - = default; + LoopUnroll(const LoopUnroll &other) = default; explicit LoopUnroll( std::optional<unsigned> unrollFactor = std::nullopt, - bool unrollUpToFactor = false, bool unrollFull = false, + bool unrollUpToFactor = false, const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr) : getUnrollFactor(getUnrollFactor) { if (unrollFactor) this->unrollFactor = *unrollFactor; this->unrollUpToFactor = unrollUpToFactor; - this->unrollFull = unrollFull; } void runOnOperation() override; @@ -85,11 +82,17 @@ static void gatherInnermostLoops(FunctionOpInterface f, } void LoopUnroll::runOnOperation() { + if (!(unrollFactor.getValue() > 0 || unrollFactor.getValue() == -1)) { + emitError(UnknownLoc::get(&getContext()), + "Invalid option: 'unroll-factor' should be greater than 0 or " + "equal to -1"); + return signalPassFailure(); + } FunctionOpInterface func = getOperation(); if (func.isExternal()) return; - if (unrollFull && unrollFullThreshold.hasValue()) { + if (unrollFactor.getValue() == -1 && unrollFullThreshold.hasValue()) { // Store short loops as we walk. SmallVector<AffineForOp, 4> loops; @@ -130,7 +133,7 @@ LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) { return loopUnrollByFactor(forOp, getUnrollFactor(forOp), /*annotateFn=*/nullptr, cleanUpUnroll); // Unroll completely if full loop unroll was specified. - if (unrollFull) + if (unrollFactor.getValue() == -1) return loopUnrollFull(forOp); // Otherwise, unroll by the given unroll factor. 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
index a6159ee..f0ddb50 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -14,13 +14,6 @@
 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 
-namespace mlir {
-namespace bufferization {
-#define GEN_PASS_DEF_TENSORCOPYINSERTION
-#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
-} // namespace bufferization
-} // namespace mlir
-
 using namespace mlir;
 using namespace mlir::bufferization;
 
diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
index 11400de..a15bf89 100644
--- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -59,6 +59,17 @@ struct DimOpInterface
   }
 };
 
+struct ExpandShapeOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<ExpandShapeOpInterface,
+                                                   memref::ExpandShapeOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto expandOp = cast<memref::ExpandShapeOp>(op);
+    assert(value == expandOp.getResult() && "invalid value");
+    cstr.bound(value)[dim] == expandOp.getOutputShape()[dim];
+  }
+};
+
 struct GetGlobalOpInterface
     : public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface,
                                                    GetGlobalOp> {
@@ -123,6 +134,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels(
         memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
    memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
    memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
+    memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>(
+        *ctx);
    memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
    memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx);
    memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx);
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp
index a9125ec..c117d9b 100644
--- a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp
+++ b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp
@@ -189,7 +189,9 @@ class PrepareForOMPOffloadPrivatizationPass
 
     DominanceInfo dom;
     llvm::sort(chainOfOps, [&](Operation *l, Operation *r) {
-      return dom.dominates(l, r);
+      if (l == r)
+        return false;
+      return dom.properlyDominates(l, r);
     });
 
     rewriter.setInsertionPoint(chainOfOps.front());
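
Note: the comparator change above fixes a classic strict-weak-ordering bug. DominanceInfo::dominates is reflexive (an op dominates itself), so using it directly as a sort predicate violates the irreflexivity that llvm::sort/std::sort require, which is undefined behavior. A self-contained illustration with integers standing in for operations, where operator< plays the role of properlyDominates:

// Why the comparator change matters: sort predicates must form a strict
// weak ordering, and a reflexive relation (comp(x, x) == true) breaks that.
#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  std::vector<int> ops = {3, 1, 2, 1};

  // Buggy analog of the old code: <= is reflexive, so comp(x, x) is true.
  // This is undefined behavior for std::sort (debug STL builds can assert
  // or even crash).
  // std::sort(ops.begin(), ops.end(), [](int l, int r) { return l <= r; });

  // Fixed analog: return false for equal elements, compare strictly
  // otherwise, mirroring the properlyDominates-based lambda in the patch.
  std::sort(ops.begin(), ops.end(), [](int l, int r) {
    if (l == r)
      return false;
    return l < r;
  });
  assert(std::is_sorted(ops.begin(), ops.end()));
}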
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 9bd13f3..744a595 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -27,6 +27,7 @@
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/DebugLog.h"
@@ -291,9 +292,102 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
   }
 };
 
+// Pattern to eliminate ExecuteRegionOp results which forward external
+// values from the region. In case there are multiple yield operations,
+// all of them must have the same operands in order for the pattern to be
+// applicable.
+struct ExecuteRegionForwardingEliminator
+    : public OpRewritePattern<ExecuteRegionOp> {
+  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExecuteRegionOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getNumResults() == 0)
+      return failure();
+
+    SmallVector<Operation *> yieldOps;
+    for (Block &block : op.getRegion()) {
+      if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator()))
+        yieldOps.push_back(yield.getOperation());
+    }
+
+    if (yieldOps.empty())
+      return failure();
+
+    // Check if all yield operations have the same operands.
+    auto yieldOpsOperands = yieldOps[0]->getOperands();
+    for (auto *yieldOp : yieldOps) {
+      if (yieldOp->getOperands() != yieldOpsOperands)
+        return failure();
+    }
+
+    SmallVector<Value> externalValues;
+    SmallVector<Value> internalValues;
+    SmallVector<Value> opResultsToReplaceWithExternalValues;
+    SmallVector<Value> opResultsToKeep;
+    for (auto [index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) {
+      if (isValueFromInsideRegion(yieldedValue, op)) {
+        internalValues.push_back(yieldedValue);
+        opResultsToKeep.push_back(op.getResult(index));
+      } else {
+        externalValues.push_back(yieldedValue);
+        opResultsToReplaceWithExternalValues.push_back(op.getResult(index));
+      }
+    }
+    // No yielded external values - nothing to do.
+    if (externalValues.empty())
+      return failure();
+
+    // There are yielded external values - create a new execute_region
+    // returning just the internal values.
+    SmallVector<Type> resultTypes;
+    for (Value value : internalValues)
+      resultTypes.push_back(value.getType());
+    auto newOp =
+        ExecuteRegionOp::create(rewriter, op.getLoc(), TypeRange(resultTypes));
+    newOp->setAttrs(op->getAttrs());
+
+    // Move old op's region to the new operation.
+    rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
+                                newOp.getRegion().end());
+
+    // Replace all yield operations with a new yield operation with updated
+    // results. scf.execute_region must have at least one yield operation.
+    for (auto *yieldOp : yieldOps) {
+      rewriter.setInsertionPoint(yieldOp);
+      rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp,
+                                                ValueRange(internalValues));
+    }
+
+    // Replace the old operation with the external values directly.
+    rewriter.replaceAllUsesWith(opResultsToReplaceWithExternalValues,
+                                externalValues);
+    // Replace the old operation's remaining results with the new operation's
+    // results.
+    rewriter.replaceAllUsesWith(opResultsToKeep, newOp.getResults());
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+private:
+  bool isValueFromInsideRegion(Value value,
+                               ExecuteRegionOp executeRegionOp) const {
+    // Check if the value is defined within the execute_region
+    if (Operation *defOp = value.getDefiningOp())
+      return &executeRegionOp.getRegion() == defOp->getParentRegion();
+
+    // If it's a block argument, check if it's from within the region
+    if (BlockArgument blockArg = dyn_cast<BlockArgument>(value))
+      return &executeRegionOp.getRegion() == blockArg.getParentRegion();
+
+    return false; // Value is from outside the region
+  }
+};
+
 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
+  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner,
+              ExecuteRegionForwardingEliminator>(context);
 }
 
 void ExecuteRegionOp::getSuccessorRegions(
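
Note: the new pattern splits each yielded value into two groups: values defined inside the region must stay results of the new, narrower execute_region, while values defined outside can replace the corresponding op results directly. A standalone sketch of just that partition step, with the MLIR Value/Region machinery reduced to a plain struct for illustration (all names hypothetical):

// Standalone sketch of the partition step in
// ExecuteRegionForwardingEliminator.
#include <vector>

struct Val {
  bool definedInsideRegion; // stand-in for comparing parent regions
};

struct Partition {
  std::vector<Val> internal; // must remain results of the new region op
  std::vector<Val> external; // op results are replaced by these directly
};

Partition partitionYieldedValues(const std::vector<Val> &yielded) {
  Partition p;
  for (const Val &v : yielded) {
    if (v.definedInsideRegion)
      p.internal.push_back(v);
    else
      p.external.push_back(v);
  }
  return p;
}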
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
index 1cba1bb..eb47e85 100644
--- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -12,6 +12,97 @@
 namespace mlir {
 namespace tosa {
 
+llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
+  return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor());
+}
+
+TosaSpecificationVersion getMinVersion(const Profile &profile) {
+  switch (profile) {
+  case Profile::pro_int:
+  case Profile::pro_fp:
+    return TosaSpecificationVersion(1, 0);
+  case Profile::none:
+    return TosaSpecificationVersion(0, 0);
+  }
+  llvm_unreachable("Unknown TOSA profile");
+}
+
+TosaSpecificationVersion getMinVersion(const Extension &extension) {
+  switch (extension) {
+  case Extension::int16:
+  case Extension::int4:
+  case Extension::bf16:
+  case Extension::fp8e4m3:
+  case Extension::fp8e5m2:
+  case Extension::fft:
+  case Extension::variable:
+  case Extension::controlflow:
+  case Extension::doubleround:
+  case Extension::inexactround:
+  case Extension::dynamic:
+    return TosaSpecificationVersion(1, 0);
+  case Extension::mxfp:
+  case Extension::int64:
+    return TosaSpecificationVersion(1, 1);
+  case Extension::none:
+    return TosaSpecificationVersion(0, 0);
+  }
+  llvm_unreachable("Unknown TOSA extension");
+}
+
+TosaSpecificationVersion getMinVersion(const Level &level) {
+  switch (level) {
+  case Level::eightK:
+  case Level::none:
+    return TosaSpecificationVersion(1, 0);
+  }
+  llvm_unreachable("Unknown TOSA level");
+}
+
+FailureOr<TargetEnv>
+TargetEnv::createTargetEnvFromAttr(TargetEnvAttr targetAttr,
+                                   Location targetEnvAttrLoc) {
+  if (failed(verifyTargetInformation(targetAttr, targetEnvAttrLoc)))
+    return failure();
+
+  return TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(),
+                   targetAttr.getProfiles(), targetAttr.getExtensions());
+}
+
+LogicalResult TargetEnv::verifyTargetInformation(TargetEnvAttr targetAttr,
+                                                 Location targetAttrLoc) {
+  TosaSpecificationVersion targetVersion(targetAttr.getSpecificationVersion());
+
+  const auto isCompatibleWithTargetVersion =
+      [&](const auto &targetEnum, Location targetAttrLoc,
+          StringRef enumName) -> LogicalResult {
+    const TosaSpecificationVersion minRequiredVersion =
+        getMinVersion(targetEnum);
+    if (!targetVersion.isBackwardsCompatibleWith(minRequiredVersion))
+      return emitError(targetAttrLoc, enumName)
+             << " '" << stringifyEnum(targetEnum)
+             << "' is not compatible with the target version "
+             << stringifyVersion(targetVersion)
+             << ", minimum required version is "
+             << stringifyVersion(minRequiredVersion);
+    return success();
+  };
+
+  for (const auto &profile : targetAttr.getProfiles())
+    if (failed(
+            isCompatibleWithTargetVersion(profile, targetAttrLoc, "profile")))
+      return failure();
+  for (const auto &extension : targetAttr.getExtensions())
+    if (failed(isCompatibleWithTargetVersion(extension, targetAttrLoc,
+                                             "extension")))
+      return failure();
+  if (failed(isCompatibleWithTargetVersion(targetAttr.getLevel(), targetAttrLoc,
+                                           "level")))
+    return failure();
+
+  return success();
+}
+
 TargetEnvAttr lookupTargetEnv(Operation *op) {
   while (op) {
     op = SymbolTable::getNearestSymbolTable(op);
@@ -39,9 +130,5 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
   return getDefaultTargetEnv(op->getContext());
 }
 
-llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
-  return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor());
-}
-
 } // namespace tosa
 } // namespace mlir
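
Note: the version gating above hinges on TosaSpecificationVersion::isBackwardsCompatibleWith. A minimal sketch of the check, assuming (as one plausible reading) that a target version supports an enum when the major versions match and the target's minor version is at least the required one; all names here are illustrative, not the MLIR API:

// Minimal sketch of the per-enum minimum-version gating introduced here.
#include <cstdint>

struct Version {
  uint32_t majorVer, minorVer;
  // Assumption: backwards compatible == same major, equal-or-newer minor.
  bool isBackwardsCompatibleWith(Version required) const {
    return majorVer == required.majorVer && minorVer >= required.minorVer;
  }
};

enum class Extension { int16, mxfp, int64 };

Version minVersion(Extension e) {
  switch (e) {
  case Extension::int16:
    return {1, 0};
  case Extension::mxfp:
  case Extension::int64:
    return {1, 1}; // the new 1.1-only extensions from this patch
  }
  return {0, 0};
}

bool supports(Version target, Extension e) {
  return target.isBackwardsCompatibleWith(minVersion(e));
}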
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 99b7cda..a85ff10a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -76,28 +76,6 @@ template <typename OpTy>
 struct PoolPadFoldAdaptor;
 
 template <>
-struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
-  using OpTy = tosa::AvgPool2dOp;
-  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
-    const llvm::ArrayRef<int64_t> kernel = op.getKernel();
-    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
-        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
-      return false;
-    return true;
-  }
-  static bool checkPadConstCompliance(OpTy op, Value padConst) {
-    return checkMatchingPadConstAndZp(padConst, op.getInputZp());
-  }
-  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
-                                  Value padInput, ArrayRef<int64_t> newPad) {
-    rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
-        op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
-        op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad),
-        op.getAccType());
-  }
-};
-
-template <>
 struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
   using OpTy = tosa::MaxPool2dOp;
   static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
@@ -245,13 +223,6 @@ struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
 };
 } // namespace
 
-void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                              MLIRContext *context) {
-  results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
-                                PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
-      context);
-}
-
 void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
   results.add<
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 00f84bc..0aff67f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -321,6 +321,19 @@ ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
     }
   }
 
+  // special handling: block_size accepts a *bare* BlockSizeMode enum
+  if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
+    if (name == "block_size" && succeeded(parser.parseOptionalKeyword(&kw))) {
+      auto sym = symbolizeBlockSize(kw);
+      if (!sym)
+        return parser.emitError(parser.getCurrentLocation())
+               << "invalid block_size value: " << kw;
+      auto attr = BlockSizeAttr::get(parser.getContext(), sym.value());
+      outAttrs.push_back(NamedAttribute(name, attr));
+      return success();
+    }
+  }
+
   // Default path: parse any normal attribute literal, including fully qualified
   // enum keyword
   Attribute attr;
@@ -357,7 +370,7 @@ ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
                              result.operands)))
     return failure();
 
-  result.addTypes(fnTy.getResult(0));
+  result.addTypes(fnTy.getResults());
   result.addAttributes(attrs);
 
   return success();
@@ -373,6 +386,8 @@ void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) {
   } else if (auto nanPropagationModeAttr =
                  dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
     parser << nanPropagationModeAttr.getValue();
+  } else if (auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
+    parser << blockSizeAttr.getValue();
   } else {
     parser.printAttribute(attr);
   }
@@ -508,6 +523,33 @@ void ReduceMinOp::print(OpAsmPrinter &parser) {
   printWithNanPropagationHandling(parser, *this);
 }
 
+ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser,
+                                        OperationState &result) {
+  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
+  printWithEnumHandling(parser, *this);
+}
+
+ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser,
+                                         OperationState &result) {
+  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void CastFromBlockScaledOp::print(OpAsmPrinter &parser) {
+  printWithEnumHandling(parser, *this);
+}
+
+ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser,
+                                       OperationState &result) {
+  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
+  printWithEnumHandling(parser, *this);
+}
+
 //===----------------------------------------------------------------------===//
 // Tosa utilities.
 //===----------------------------------------------------------------------===//
@@ -933,32 +975,35 @@ static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
 
 // verify that inType and outType have same element types
 template <typename T>
-static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
-  auto inputType = llvm::dyn_cast<TensorType>(inType);
-  auto outputType = llvm::dyn_cast<TensorType>(outType);
-  if (!inputType) {
-    op.emitOpError("expect shaped tensor for input, got ") << inType;
+static LogicalResult verifySameElementTypes(T op, Type aType, Type bType,
+                                            StringRef aName = "input",
+                                            StringRef bName = "output") {
+  auto aTType = llvm::dyn_cast<TensorType>(aType);
+  auto bTType = llvm::dyn_cast<TensorType>(bType);
+  if (!aTType) {
+    op.emitOpError("expect shaped tensor for ") << aName << ", got " << aType;
     return failure();
   }
-  if (!outputType) {
-    op.emitOpError("expect shaped tensor for output, got ") << outType;
+  if (!bTType) {
+    op.emitOpError("expect shaped tensor for ") << bName << ", got " << bType;
     return failure();
   }
-  auto inputElementType = inputType.getElementType();
-  auto outputElementType = outputType.getElementType();
-  auto inputQuantType =
-      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
-  auto outputQuantType =
-      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
-  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
-      (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
-      inputElementType != outputElementType) {
+  auto aElementType = aTType.getElementType();
+  auto bElementType = bTType.getElementType();
+  auto aQuantType =
+      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
+  auto bQuantType =
+      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
+  if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
+      (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
+      aElementType != bElementType) {
     // only check if both element types are int/index/float/UniformQuantized
     // eg, not sure how to check quant::QuantizedType
     // this happens in test_conv2d_q_grouped_convolution in
     // tfl-to-tosa-pipeline.mlir
-    op.emitOpError("expect input and output to have same element type, got ")
-        << inputElementType << " and " << outputElementType;
+    op.emitOpError("expect ")
+        << aName << " and " << bName << " to have same element type, got "
+        << aElementType << " and " << bElementType;
     return failure();
   }
   return success();
@@ -1846,6 +1891,161 @@ LogicalResult MatMulOp::verify() {
   return success();
 }
 
+LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    MatmulTBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  SmallVector<int64_t, 3> outShape(3, ShapedType::kDynamic);
+
+  const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
+  if (aDataShape.hasRank()) {
+    outShape[0] = aDataShape.getDimSize(0);
+    outShape[1] = aDataShape.getDimSize(1);
+  }
+
+  const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
+  if (aScaleShape.hasRank()) {
+    outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
+                                                     : outShape[0];
+    outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)
+                                                     : outShape[1];
+  }
+
+  // If B batch size is 1, it is broadcast across A's batch size
+  const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
+  if (bDataShape.hasRank()) {
+    const int64_t bDataBatchSize = bDataShape.getDimSize(0);
+    if (bDataBatchSize != 1)
+      outShape[0] =
+          ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
+    outShape[2] = bDataShape.getDimSize(1);
+  }
+
+  const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
+  if (bScaleShape.hasRank()) {
+    const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
+    if (bScaleBatchSize != 1)
+      outShape[0] =
+          ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
+    outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
+                                                     : outShape[2];
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+  return success();
+}
+
+LogicalResult MatmulTBlockScaledOp::verify() {
+  // Verify same input data types
+  const Type aDataType = getAData().getType();
+  const Type bDataType = getBData().getType();
+  if (failed(verifySameElementTypes(*this, aDataType, bDataType, "A_data",
+                                    "B_data")))
+    return failure();
+
+  auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim,
+                                   const StringRef operandName,
+                                   const StringRef dimName) -> LogicalResult {
+    if (ShapedType::isDynamic(currDim)) {
+      currDim = newDim;
+      return success();
+    } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
+      return emitOpError("expected ")
+             << dimName << " of " << operandName << " to match size " << currDim
+             << ", got " << newDim;
+    }
+    return success();
+  };
+
+  // Verify input shape compatibility
+  int64_t N = ShapedType::kDynamic;
+  int64_t D = ShapedType::kDynamic;
+  int64_t H = ShapedType::kDynamic;
+  int64_t W = ShapedType::kDynamic;
+  int64_t C = ShapedType::kDynamic;
+  int64_t multiplesOfC = ShapedType::kDynamic;
+
+  const ShapeAdaptor aDataShape = ShapeAdaptor(aDataType);
+  if (aDataShape.hasRank()) {
+    N = aDataShape.getDimSize(0);
+    H = aDataShape.getDimSize(1);
+    C = aDataShape.getDimSize(2);
+  }
+
+  const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
+  if (aScaleShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(N, aScaleShape.getDimSize(0), "a_scale",
+                                     "batch")) ||
+        failed(tryUpdateDimOrFailure(H, aScaleShape.getDimSize(1), "a_scale",
+                                     "height")))
+      return failure();
+    multiplesOfC = aScaleShape.getDimSize(2);
+  }
+
+  const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
+  if (bDataShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(D, bDataShape.getDimSize(0), "b_data",
+                                     "batch")) ||
+        failed(tryUpdateDimOrFailure(C, bDataShape.getDimSize(2), "b_data",
+                                     "channels")))
+      return failure();
+    W = bDataShape.getDimSize(1);
+  }
+
+  const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
+  if (bScaleShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(D, bScaleShape.getDimSize(0), "b_scale",
+                                     "batch")) ||
+        failed(tryUpdateDimOrFailure(W, bScaleShape.getDimSize(1), "b_scale",
+                                     "width")) ||
+        failed(tryUpdateDimOrFailure(multiplesOfC, bScaleShape.getDimSize(2),
+                                     "b_scale", "C/block_size")))
+      return failure();
+  }
+
+  // Verify batch size is broadcast compatible
+  if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
+    return emitOpError("expect B matrix batch size to be broadcast compatible "
+                       "with A, got D=")
+           << D << " vs N=" << N;
+
+  // Verify C is a multiple of block size
+  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (ShapedType::isStatic(C) && C % blockSize != 0)
+    return emitOpError("expect C to be a multiple of block size, got C=")
+           << C << ", block_size=" << blockSize;
+
+  // Verify multiplesOfC is C / block size
+  if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
+      multiplesOfC != C / blockSize)
+    return emitOpError(
+               "expect scale operands dimension 2 to equal C/block_size (")
+           << C << "/" << blockSize << ")"
+           << ", got " << multiplesOfC;
+
+  // Verify output shape
+  N = ShapedType::isDynamic(N) ? D : N;
+  const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W};
+  const auto outputType = cast<ShapedType>(getResult().getType());
+  if (outputType.hasRank() &&
+      failed(
+          verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) {
+    InFlightDiagnostic opError = emitOpError("expected output shape ");
+    auto stringifyDim = [&](int64_t d) {
+      if (ShapedType::isDynamic(d))
+        opError << "?";
+      else
+        opError << d;
+    };
+    llvm::interleaveComma(outputType.getShape(), opError, stringifyDim);
+    opError << " to be compatible with expected output shape ";
+    llvm::interleaveComma(expectedOutputShape, opError, stringifyDim);
+    return opError;
+  }
+
+  return success();
+}
+
 LogicalResult tosa::PadOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     PadOp::Adaptor adaptor,
@@ -3762,6 +3962,145 @@ LogicalResult RescaleOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    CastFromBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  const ShapeAdaptor inputShape(adaptor.getInputData().getType());
+  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+  return success();
+}
+
+LogicalResult CastFromBlockScaledOp::verify() {
+  const Type inputDataType = getInputData().getType();
+  const Type outputDataType = getResult().getType();
+  if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
+    return emitOpError() << "require compatible shapes for input_data ("
+                         << inputDataType << ") and "
+                         << "output_data (" << outputDataType << ")";
+
+  const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
+
+  if (inputDataShape.hasRank()) {
+    const unsigned int blockSize =
+        BlockSizeAttr::getBlockSizeValue(getBlockSize());
+    const int64_t inputDataLastDim =
+        inputDataShape.getDimSize(inputDataShape.getRank() - 1);
+    if (inputDataLastDim % blockSize != 0)
+      return emitOpError() << "expect last dimension of input_data ("
+                           << inputDataLastDim
+                           << ") to be divisible by block_size (" << blockSize
+                           << ")";
+
+    const Type inputScaleType = getInputScale().getType();
+    const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
+
+    if (inputScaleShape.hasRank()) {
+      SmallVector<int64_t> inputDataDims, inputScaleDims;
+      inputDataShape.getDims(inputDataDims);
+      inputScaleShape.getDims(inputScaleDims);
+
+      if (inputDataDims.size() != inputScaleDims.size() ||
+          failed(verifyCompatibleShape(
+              ArrayRef<int64_t>(inputDataDims).drop_back(1),
+              ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
+        return emitOpError() << "require compatible shapes for input_data ("
+                             << inputDataType << ") and "
+                             << "input_scale (" << inputScaleType
+                             << ") except for the last dimension";
+
+      const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
+                                                inputScaleDims.back()};
+      if (ShapedType::isStatic(inputDataLastDim) &&
+          failed(verifyCompatibleDims(dimsToCheck)))
+        return emitOpError()
+               << "expect last dimension of input_scale ("
+               << inputScaleDims.back()
+               << ") to be equal to last dimension of input_data / block_size ("
+               << inputDataDims.back() / blockSize << ")";
+    }
+  }
+
+  return success();
+}
+
+LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    CastToBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  const ShapeAdaptor inputShape(adaptor.getInputData().getType());
+  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+  if (!inputShape.hasRank())
+    return success();
+
+  // Calculate output_scale shape if ranked input provided
+  SmallVector<int64_t> outputScaleShape;
+  inputShape.getDims(outputScaleShape);
+  const int64_t lastDimLoc = inputShape.getRank() - 1;
+  const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
+  if (ShapedType::isStatic(lastDimSize)) {
+    const unsigned int blockSize =
+        BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
+    outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
+  }
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
+  return success();
+}
+
+LogicalResult CastToBlockScaledOp::verify() {
+  const Type inputDataType = getInputData().getType();
+  const Type outputDataType = getResult(0).getType();
+  if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
+    return emitOpError() << "require compatible shapes for input_data ("
+                         << inputDataType << ") and "
+                         << "output_data (" << outputDataType << ")";
+
+  const unsigned int blockSize =
+      BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
+  if (inputDataShape.hasRank()) {
+    const int64_t inputDataLastDim =
+        inputDataShape.getDimSize(inputDataShape.getRank() - 1);
+    if (ShapedType::isStatic(inputDataLastDim) &&
+        inputDataLastDim % blockSize != 0)
+      return emitOpError() << "expect last dimension of input_data ("
+                           << inputDataLastDim
+                           << ") to be divisible by block_size (" << blockSize
+                           << ")";
+  }
+
+  const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
+  const Type outputScaleType = getResult(1).getType();
+  const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
+  if (outputDataShape.hasRank() && outputScaleShape.hasRank()) {
+    SmallVector<int64_t> outputDataDims, outputScaleDims;
+    outputDataShape.getDims(outputDataDims);
+    outputScaleShape.getDims(outputScaleDims);
+
+    if (outputDataDims.size() != outputScaleDims.size() ||
+        failed(verifyCompatibleShape(
+            ArrayRef<int64_t>(outputDataDims).drop_back(1),
+            ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
+      return emitOpError() << "require compatible shapes for output_data ("
+                           << outputDataType << ") and "
+                           << "output_scale (" << outputScaleType
+                           << ") except for the last dimension";
+
+    const int64_t outputDataLastDim = outputDataDims.back();
+    const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
+                                              outputScaleDims.back()};
+    if (ShapedType::isStatic(outputDataLastDim) &&
+        failed(verifyCompatibleDims(dimsToCheck)))
+      return emitOpError()
+             << "expect last dimension of output_scale ("
+             << outputScaleDims.back()
+             << ") to be equal to last dimension of output_data / block_size ("
+             << outputDataDims.back() / blockSize << ")";
+  }
+
+  return success();
+}
+
 LogicalResult IfOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     IfOp::Adaptor adaptor,
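
Note: all three new block-scaled ops share the same bookkeeping: the blocked dimension must divide evenly by block_size, and the matching scale dimension must equal that dimension divided by block_size, with dynamic dimensions exempt from static checking. A compact standalone sketch of that rule (checkBlockScaledDims is a hypothetical helper; kDynamic mirrors ShapedType::kDynamic):

// Sketch of the shared block-size rule for the new block-scaled TOSA ops.
#include <cstdint>

constexpr int64_t kDynamic = INT64_MIN; // mirrors ShapedType::kDynamic
static bool isStatic(int64_t d) { return d != kDynamic; }

// Returns false if statically-known dimensions contradict the rules.
bool checkBlockScaledDims(int64_t c, int64_t scaleDim, uint32_t blockSize) {
  if (isStatic(c) && c % blockSize != 0)
    return false; // C must be a multiple of block_size
  if (isStatic(c) && isStatic(scaleDim) && scaleDim != c / blockSize)
    return false; // scale dim must equal C / block_size
  return true;    // dynamic dims are deferred, not rejected
}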
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index f072e3e..ab363ee 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -19,12 +19,19 @@ TosaProfileCompliance::TosaProfileCompliance() {
   const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
   const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
   const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
+  const TypeInfo i64T = {mlir::IntegerType::getTypeID(), 64};
   const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
   const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
   const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
   const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
   const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
 
+  // micro-scaling formats
+  const TypeInfo fp6e2m3T = {mlir::Float6E2M3FNType::getTypeID(), 6};
+  const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6};
+  const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4};
+  const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8};
+
   // The profile-based compliance content below is auto-generated by a script
   // in https://git.mlplatform.org/tosa/specification.git
 #include "mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc"
@@ -44,10 +51,11 @@ TosaProfileCompliance::getProfileComplianceMap() {
 
 // Base populating function
 LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands,
-                                                    Value output) {
-  for (auto operand : operands)
+                                                    ValueRange results) {
+  for (const auto &operand : operands)
     addValue(operand);
-  addValue(output);
+  for (const auto &result : results)
+    addValue(result);
   return success();
 }
 
@@ -170,23 +178,6 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
 }
 
 template <>
-LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
-  addValue(op.getInputReal());
-  addValue(op.getInputImag());
-  addValue(op.getOutputReal());
-  addValue(op.getOutputImag());
-  return success();
-}
-
-template <>
-LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
-  addValue(op.getInputReal());
-  addValue(op.getOutputReal());
-  addValue(op.getOutputImag());
-  return success();
-}
-
-template <>
 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
   addValue(op.getOnTrue());
   addValue(op.getOnFalse());
@@ -239,7 +230,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
 // This helper function populates the info for all operands.
 #define POPULATE_PROFILE_INFO_COMMON(tosaOp)                                   \
   if (isa<tosa::tosaOp##Op>(op)) {                                             \
-    return populateProfileInfo(op->getOperands(), op->getResult(0));           \
+    return populateProfileInfo(op->getOperands(), op->getResults());           \
   }
 
   // Skip irrelevant operands when they are independent and not tied to any
@@ -250,8 +241,6 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   POPULATE_PROFILE_INFO_CUSTOM(Conv3D)
   POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D)
   POPULATE_PROFILE_INFO_CUSTOM(Mul)
-  POPULATE_PROFILE_INFO_CUSTOM(FFT2d)
-  POPULATE_PROFILE_INFO_CUSTOM(RFFT2d)
   POPULATE_PROFILE_INFO_CUSTOM(Concat)
   POPULATE_PROFILE_INFO_CUSTOM(Pad)
   POPULATE_PROFILE_INFO_CUSTOM(Reshape)
@@ -269,7 +258,12 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
 
   // For the most of tosa operators, all operands are profile/extension related
   // and hence are all considered in this profile-based compilance check.
+  POPULATE_PROFILE_INFO_COMMON(MatmulTBlockScaled)
+  POPULATE_PROFILE_INFO_COMMON(FFT2d)
+  POPULATE_PROFILE_INFO_COMMON(RFFT2d)
   POPULATE_PROFILE_INFO_COMMON(Cast)
+  POPULATE_PROFILE_INFO_COMMON(CastFromBlockScaled)
+  POPULATE_PROFILE_INFO_COMMON(CastToBlockScaled)
   POPULATE_PROFILE_INFO_COMMON(Const)
   POPULATE_PROFILE_INFO_COMMON(ArgMax)
   POPULATE_PROFILE_INFO_COMMON(Sub)
@@ -623,6 +617,14 @@ TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) {
     return {"fp8e4m3"};
   } else if (typeInfo.typeID == mlir::Float8E5M2Type::getTypeID()) {
     return {"fp8e5m2"};
+  } else if (typeInfo.typeID == mlir::Float6E2M3FNType::getTypeID()) {
+    return {"fp6e2m3"};
+  } else if (typeInfo.typeID == mlir::Float6E3M2FNType::getTypeID()) {
+    return {"fp6e3m2"};
+  } else if (typeInfo.typeID == mlir::Float4E2M1FNType::getTypeID()) {
+    return {"fp4e2m1"};
+  } else if (typeInfo.typeID == mlir::Float8E8M0FNUType::getTypeID()) {
+    return {"fp8e8m0"};
   }
   llvm_unreachable("unknown type");
 }
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 82f2f7e..4d0b61a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -635,6 +635,8 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
   CHECK_RANKS_AND_SIZES(Transpose);
   // Type Conversion
   CHECK_RANKS_AND_SIZES(Cast);
+  CHECK_RANKS_AND_SIZES(CastFromBlockScaled);
+  CHECK_RANKS_AND_SIZES(CastToBlockScaled);
   CHECK_RANKS_AND_SIZES(Rescale);
   // Control Flow Operators
   CHECK_RANKS_AND_SIZES(If);
@@ -657,6 +659,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
   CHECK_SIZES(TransposeConv2D);
   CHECK_SIZES(FFT2d);
   CHECK_SIZES(MatMul);
+  CHECK_SIZES(MatmulTBlockScaled);
   CHECK_SIZES(MaxPool2d);
   CHECK_SIZES(RFFT2d);
   // Scatter/Gather Operators
@@ -1192,9 +1195,9 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
 bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
   if (isa<FloatType>(type)) {
     return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
-               Float8E5M2Type>(type);
-  }
-  if (auto intTy = dyn_cast<IntegerType>(type)) {
+               Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType,
+               Float6E3M2FNType, Float8E8M0FNUType>(type);
+  } else if (auto intTy = dyn_cast<IntegerType>(type)) {
     if (intTy.isSignless()) {
       switch (intTy.getWidth()) {
       case 1:
@@ -1203,6 +1206,7 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
       case 16:
       case 32:
      case 48:
+      case 64:
        return true;
      }
    } else if (allowUnsigned && intTy.isUnsigned()) {
@@ -1220,13 +1224,19 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
 }
 
 void TosaValidation::runOnOperation() {
+  ModuleOp modOp = getOperation();
+  const TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(modOp);
+  const auto maybeTargetEnv =
+      tosa::TargetEnv::createTargetEnvFromAttr(targetEnvAttr, modOp.getLoc());
+  if (failed(maybeTargetEnv))
+    return signalPassFailure();
+  targetEnv = *maybeTargetEnv;
+
   TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
   if (!tosaDialect)
     return;
 
-  targetEnv = tosa::TargetEnv(lookupTargetEnvOrDefault(getOperation()));
-
-  getOperation().walk([&](Operation *op) {
+  modOp.walk([&](Operation *op) {
     if (op->getDialect() != tosaDialect)
       return;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index f77784a..2c37140 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -145,8 +145,26 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
   xegpu::DistributeLayoutAttr layout =
       xegpu::getDistributeLayoutAttr(operandOrResult);
   if (layout && layout.isForSubgroup()) {
-    if (!layout.getEffectiveInstDataAsInt().empty())
-      return layout.getEffectiveInstDataAsInt();
+    if (!layout.getEffectiveInstDataAsInt().empty()) {
+      SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt();
+      // Remove leading unit dimensions from inst_data
+      // For example, if the inst_data is [1, 1, 32]
+      // it will pass [32] as the unroll/blocking size.
+      // Skip it for xegpu nd ops since it will be 2D
+      // TODO: For vectors ops, experiment with the
+      // upstream vector remove leading unit dims patterns,
+      // populateCastAwayVectorLeadingOneDimPatterns.
+      Operation *definingOp = value.getDefiningOp();
+      bool skipLeadingUnitDimRemoval =
+          definingOp &&
+          (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::DpasOp,
+               xegpu::StoreNdOp, xegpu::PrefetchNdOp>(definingOp));
+      if (!skipLeadingUnitDimRemoval) {
+        auto it = llvm::find_if(instData, [](auto val) { return val != 1; });
+        instData.erase(instData.begin(), it);
+      }
+      return instData;
+    }
 
     if (auto type = dyn_cast<ShapedType>(value.getType()))
       return llvm::to_vector(type.getShape());
@@ -354,7 +372,6 @@ void XeGPUBlockingPass::runOnOperation() {
           // To create a new attribute with a different chunk_size:
           auto newEncoding = xegpu::ScatterTensorDescAttr::get(
               ctx, tdescTy.getMemorySpace(), blockedChunkSize);
-
           encoding = newEncoding;
         }
       }
@@ -363,7 +380,7 @@ void XeGPUBlockingPass::runOnOperation() {
           xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
                                      tdescTy.getLayoutAttr().dropInstData());
     } else {
-      newTy = type.clone(tileShape, elemTy);
+      newTy = VectorType::get(tileShape, elemTy);
    }

    if (returnSingleType)
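
Note: the inst_data adjustment above strips leading unit dimensions so that, for example, [1, 1, 32] blocks as [32], except for the listed xegpu ND ops, which keep their 2D shape. The stripping step itself in plain, self-contained C++ (std::find_if standing in for llvm::find_if):

// Drop leading 1s from a blocking shape, as getTileShape now does.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> dropLeadingUnitDims(std::vector<int64_t> instData) {
  auto it = std::find_if(instData.begin(), instData.end(),
                         [](int64_t v) { return v != 1; });
  instData.erase(instData.begin(), it); // erase the leading run of 1s
  return instData;
}

int main() {
  assert((dropLeadingUnitDims({1, 1, 32}) == std::vector<int64_t>{32}));
  assert((dropLeadingUnitDims({8, 16}) == std::vector<int64_t>{8, 16}));
}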
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index aafa1b7..e6e71cc 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -66,8 +66,6 @@ protected:
   Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
                Location loc, PatternRewriter &rewriter) const {
     if (auto vecTy = dyn_cast<VectorType>(destTy)) {
-      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
-             "Expecting blockSize size to match the rank of destTy.");
       auto shape = vecTy.getShape();
       return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
     }
@@ -93,8 +91,6 @@ protected:
                                ArrayRef<int64_t> blockSize, Location loc,
                                PatternRewriter &rewriter) const {
     if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
-      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
-             "Expecting blockSize size to match the rank of src.");
       return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
                                                      blockSize);
     }
@@ -635,7 +631,7 @@ struct UnrollLoadGatherOpWithOffset
     VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
     VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
     Type elemTy = valueTy.getElementType();
-    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+    VectorType newValueTy = VectorType::get(*targetShape, elemTy);
 
     SmallVector<Type> convertedMaskTypes;
     SmallVector<Value> convertedMasks;
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 2c56a43..b4605cd 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -246,11 +246,28 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
   if (!computeShapeRatio(srcShape, shape))
     return {value};
 
+  int64_t srcShapeRank = srcShape.size();
+  int64_t targetShapeRank = shape.size();
+
+  SmallVector<int64_t> adjustedTargetShape(srcShape.size());
+  int64_t rankDiff = srcShapeRank - targetShapeRank;
+  std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff,
+            1);
+  std::copy(shape.begin(), shape.end(), adjustedTargetShape.begin() + rankDiff);
+
   SmallVector<Value> result;
-  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(srcShape, shape)) {
+  for (SmallVector<int64_t> offsets :
+       StaticTileOffsetRange(srcShape, adjustedTargetShape)) {
     SmallVector<int64_t> staticStrides(offsets.size(), 1);
-    result.push_back(vector::ExtractStridedSliceOp::create(
-        builder, loc, value, offsets, shape, staticStrides));
+    Value slice = vector::ExtractStridedSliceOp::create(
+        builder, loc, value, offsets, adjustedTargetShape, staticStrides);
+
+    // Reshape to remove leading unit dims if needed
+    if (srcShapeRank > targetShapeRank) {
+      auto targetTy = VectorType::get(shape, vecTy.getElementType());
+      slice = vector::ShapeCastOp::create(builder, loc, targetTy, slice);
+    }
+    result.push_back(slice);
   }
 
   return result;
@@ -274,7 +291,7 @@ Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
 
   for (auto [src, offsets] :
        llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) {
-    SmallVector<int64_t> staticStrides(offsets.size(), 1);
+    SmallVector<int64_t> staticStrides(tileShape.size(), 1);
     result = vector::InsertStridedSliceOp::create(builder, loc, src, result,
                                                   offsets, staticStrides);
   }
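
Note: extractVectorsWithShapeFromValue now tolerates a target shape of lower rank than the source: the target is left-padded with unit dimensions before tiling, and each extracted slice is shape-cast back down to the requested rank. A standalone sketch of the padding step:

// Left-pad a lower-rank target shape with 1s to match the source rank,
// mirroring the adjustedTargetShape computation in the patch.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> alignRank(const std::vector<int64_t> &srcShape,
                               const std::vector<int64_t> &targetShape) {
  size_t rankDiff = srcShape.size() - targetShape.size();
  std::vector<int64_t> adjusted(srcShape.size(), 1); // leading unit dims
  std::copy(targetShape.begin(), targetShape.end(),
            adjusted.begin() + rankDiff);
  return adjusted;
}

int main() {
  // e.g. extracting [32] tiles out of a [4, 2, 32] vector iterates over
  // [1, 1, 32] slices, which are then shape-cast to vector<32x...>.
  assert((alignRank({4, 2, 32}, {32}) == std::vector<int64_t>{1, 1, 32}));
}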
