Diffstat (limited to 'mlir/lib')
-rw-r--r--  mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp                  |   6
-rw-r--r--  mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h                    |  21
-rw-r--r--  mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp                 |   5
-rw-r--r--  mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp                 |  21
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp |   7
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp         |  13
-rw-r--r--  mlir/lib/Dialect/SCF/IR/SCF.cpp                                   |  96
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp                            |  94
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp                |  29
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TosaOps.cpp                              | 218
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp        |  15
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp               |  19
-rw-r--r--  mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt              |   2
-rw-r--r--  mlir/lib/Target/SPIRV/Serialization/Serializer.cpp                |   6
14 files changed, 467 insertions(+), 85 deletions(-)
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index 2285d26..eb662a1 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -507,7 +507,8 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType}, /*isVarArg=*/true); LLVM::LLVMFuncOp printfDecl = - getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType); + getOrDefineFunction(moduleOp, loc, rewriter, funcName, printfType); + printfDecl.setCConv(callingConvention); // Create the global op or find an existing one. LLVM::GlobalOp global = getOrCreateStringConstant( @@ -530,7 +531,8 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( printfArgs.push_back(stringStart); printfArgs.append(argsRange.begin(), argsRange.end()); - LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs); + auto call = LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs); + call.setCConv(callingConvention); rewriter.eraseOp(gpuPrintfOp); return success(); } diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h index 66d3bb4..ec74787 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h @@ -10,6 +10,7 @@ #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" namespace mlir { @@ -142,13 +143,23 @@ struct GPUPrintfOpToHIPLowering : public ConvertOpToLLVMPattern<gpu::PrintfOp> { /// This pass will add a declaration of printf() to the GPUModule if needed /// and separate out the format strings into global constants. For some /// runtimes, such as OpenCL on AMD, this is sufficient setup, as the compiler -/// will lower printf calls to appropriate device-side code +/// will lower printf calls to appropriate device-side code. +/// However not all backends use the same calling convention and function +/// naming. +/// For example, the LLVM SPIRV backend requires calling convention +/// LLVM::cconv::CConv::SPIR_FUNC and function name needs to be +/// mangled as "_Z6printfPU3AS2Kcz". +/// Default callingConvention is LLVM::cconv::CConv::C and +/// funcName is "printf" but they can be customized as needed. struct GPUPrintfOpToLLVMCallLowering : public ConvertOpToLLVMPattern<gpu::PrintfOp> { - GPUPrintfOpToLLVMCallLowering(const LLVMTypeConverter &converter, - int addressSpace = 0) + GPUPrintfOpToLLVMCallLowering( + const LLVMTypeConverter &converter, int addressSpace = 0, + LLVM::cconv::CConv callingConvention = LLVM::cconv::CConv::C, + StringRef funcName = "printf") : ConvertOpToLLVMPattern<gpu::PrintfOp>(converter), - addressSpace(addressSpace) {} + addressSpace(addressSpace), callingConvention(callingConvention), + funcName(funcName) {} LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, @@ -156,6 +167,8 @@ struct GPUPrintfOpToLLVMCallLowering private: int addressSpace; + LLVM::cconv::CConv callingConvention; + StringRef funcName; }; /// Lowering of gpu.printf to a vprintf standard library. 
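The change above makes GPUPrintfOpToLLVMCallLowering configurable: besides the address space, the constructor now takes a calling convention and callee name. A minimal sketch of how a target-specific pass could wire this up is below; the helper name and include paths are assumptions for illustration, while the pattern, constructor arguments, and enum values come from this change (they mirror the GPUToLLVMSPV hunk that follows).

```cpp
// Illustrative only: helper name and include paths are assumed; the pattern
// and its constructor arguments are taken from this change.
#include "GPUOpsLowering.h"                       // GPUPrintfOpToLLVMCallLowering
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"        // LLVM::cconv::CConv
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Lower gpu.printf to a SPIR_FUNC call to the OpenCL-mangled printf symbol,
// with the format string placed in the constant address space (2).
static void populateSpirvStylePrintfLowering(const LLVMTypeConverter &converter,
                                             RewritePatternSet &patterns) {
  patterns.add<GPUPrintfOpToLLVMCallLowering>(
      converter, /*addressSpace=*/2, LLVM::cconv::CConv::SPIR_FUNC,
      /*funcName=*/"_Z6printfPU3AS2Kcz");
}
```

With the defaults (LLVM::cconv::CConv::C and "printf"), existing users of the pattern are unaffected.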
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp index c2363a1..25f1e1b 100644 --- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp +++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp @@ -470,10 +470,13 @@ struct GPUToLLVMSPVConversionPass final gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp, gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp, gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp, - gpu::ThreadIdOp>(); + gpu::ThreadIdOp, gpu::PrintfOp>(); populateGpuToLLVMSPVConversionPatterns(converter, patterns); populateGpuMemorySpaceAttributeConversions(converter); + patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/2, + LLVM::cconv::CConv::SPIR_FUNC, + "_Z6printfPU3AS2Kcz"); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp index 316721b..60ae78b 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp @@ -45,18 +45,15 @@ struct LoopUnroll : public affine::impl::AffineLoopUnrollBase<LoopUnroll> { const std::function<unsigned(AffineForOp)> getUnrollFactor; LoopUnroll() : getUnrollFactor(nullptr) {} - LoopUnroll(const LoopUnroll &other) - - = default; + LoopUnroll(const LoopUnroll &other) = default; explicit LoopUnroll( std::optional<unsigned> unrollFactor = std::nullopt, - bool unrollUpToFactor = false, bool unrollFull = false, + bool unrollUpToFactor = false, const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr) : getUnrollFactor(getUnrollFactor) { if (unrollFactor) this->unrollFactor = *unrollFactor; this->unrollUpToFactor = unrollUpToFactor; - this->unrollFull = unrollFull; } void runOnOperation() override; @@ -85,11 +82,17 @@ static void gatherInnermostLoops(FunctionOpInterface f, } void LoopUnroll::runOnOperation() { + if (!(unrollFactor.getValue() > 0 || unrollFactor.getValue() == -1)) { + emitError(UnknownLoc::get(&getContext()), + "Invalid option: 'unroll-factor' should be greater than 0 or " + "equal to -1"); + return signalPassFailure(); + } FunctionOpInterface func = getOperation(); if (func.isExternal()) return; - if (unrollFull && unrollFullThreshold.hasValue()) { + if (unrollFactor.getValue() == -1 && unrollFullThreshold.hasValue()) { // Store short loops as we walk. SmallVector<AffineForOp, 4> loops; @@ -130,7 +133,7 @@ LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) { return loopUnrollByFactor(forOp, getUnrollFactor(forOp), /*annotateFn=*/nullptr, cleanUpUnroll); // Unroll completely if full loop unroll was specified. - if (unrollFull) + if (unrollFactor.getValue() == -1) return loopUnrollFull(forOp); // Otherwise, unroll by the given unroll factor. if (unrollUpToFactor) @@ -141,9 +144,9 @@ LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) { std::unique_ptr<InterfacePass<FunctionOpInterface>> mlir::affine::createLoopUnrollPass( - int unrollFactor, bool unrollUpToFactor, bool unrollFull, + int unrollFactor, bool unrollUpToFactor, const std::function<unsigned(AffineForOp)> &getUnrollFactor) { return std::make_unique<LoopUnroll>( unrollFactor == -1 ? 
std::nullopt : std::optional<unsigned>(unrollFactor), - unrollUpToFactor, unrollFull, getUnrollFactor); + unrollUpToFactor, getUnrollFactor); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp index a6159ee..f0ddb50 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp @@ -14,13 +14,6 @@ #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" -namespace mlir { -namespace bufferization { -#define GEN_PASS_DEF_TENSORCOPYINSERTION -#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" -} // namespace bufferization -} // namespace mlir - using namespace mlir; using namespace mlir::bufferization; diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp index 11400de..a15bf89 100644 --- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp @@ -59,6 +59,17 @@ struct DimOpInterface } }; +struct ExpandShapeOpInterface + : public ValueBoundsOpInterface::ExternalModel<ExpandShapeOpInterface, + memref::ExpandShapeOp> { + void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, + ValueBoundsConstraintSet &cstr) const { + auto expandOp = cast<memref::ExpandShapeOp>(op); + assert(value == expandOp.getResult() && "invalid value"); + cstr.bound(value)[dim] == expandOp.getOutputShape()[dim]; + } +}; + struct GetGlobalOpInterface : public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface, GetGlobalOp> { @@ -123,6 +134,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels( memref::AllocOpInterface<memref::AllocaOp>>(*ctx); memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx); memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx); + memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>( + *ctx); memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx); memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx); memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 9bd13f3..744a595 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -27,6 +27,7 @@ #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Casting.h" #include "llvm/Support/DebugLog.h" @@ -291,9 +292,102 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> { } }; +// Pattern to eliminate ExecuteRegionOp results which forward external +// values from the region. In case there are multiple yield operations, +// all of them must have the same operands in order for the pattern to be +// applicable. 
+struct ExecuteRegionForwardingEliminator + : public OpRewritePattern<ExecuteRegionOp> { + using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(ExecuteRegionOp op, + PatternRewriter &rewriter) const override { + if (op.getNumResults() == 0) + return failure(); + + SmallVector<Operation *> yieldOps; + for (Block &block : op.getRegion()) { + if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator())) + yieldOps.push_back(yield.getOperation()); + } + + if (yieldOps.empty()) + return failure(); + + // Check if all yield operations have the same operands. + auto yieldOpsOperands = yieldOps[0]->getOperands(); + for (auto *yieldOp : yieldOps) { + if (yieldOp->getOperands() != yieldOpsOperands) + return failure(); + } + + SmallVector<Value> externalValues; + SmallVector<Value> internalValues; + SmallVector<Value> opResultsToReplaceWithExternalValues; + SmallVector<Value> opResultsToKeep; + for (auto [index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) { + if (isValueFromInsideRegion(yieldedValue, op)) { + internalValues.push_back(yieldedValue); + opResultsToKeep.push_back(op.getResult(index)); + } else { + externalValues.push_back(yieldedValue); + opResultsToReplaceWithExternalValues.push_back(op.getResult(index)); + } + } + // No yielded external values - nothing to do. + if (externalValues.empty()) + return failure(); + + // There are yielded external values - create a new execute_region returning + // just the internal values. + SmallVector<Type> resultTypes; + for (Value value : internalValues) + resultTypes.push_back(value.getType()); + auto newOp = + ExecuteRegionOp::create(rewriter, op.getLoc(), TypeRange(resultTypes)); + newOp->setAttrs(op->getAttrs()); + + // Move old op's region to the new operation. + rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), + newOp.getRegion().end()); + + // Replace all yield operations with a new yield operation with updated + // results. scf.execute_region must have at least one yield operation. + for (auto *yieldOp : yieldOps) { + rewriter.setInsertionPoint(yieldOp); + rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, + ValueRange(internalValues)); + } + + // Replace the old operation with the external values directly. + rewriter.replaceAllUsesWith(opResultsToReplaceWithExternalValues, + externalValues); + // Replace the old operation's remaining results with the new operation's + // results. 
+ rewriter.replaceAllUsesWith(opResultsToKeep, newOp.getResults()); + rewriter.eraseOp(op); + return success(); + } + +private: + bool isValueFromInsideRegion(Value value, + ExecuteRegionOp executeRegionOp) const { + // Check if the value is defined within the execute_region + if (Operation *defOp = value.getDefiningOp()) + return &executeRegionOp.getRegion() == defOp->getParentRegion(); + + // If it's a block argument, check if it's from within the region + if (BlockArgument blockArg = dyn_cast<BlockArgument>(value)) + return &executeRegionOp.getRegion() == blockArg.getParentRegion(); + + return false; // Value is from outside the region + } +}; + void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context); + results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner, + ExecuteRegionForwardingEliminator>(context); } void ExecuteRegionOp::getSuccessorRegions( diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp index 1cba1bb..32eb286 100644 --- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp @@ -12,6 +12,96 @@ namespace mlir { namespace tosa { +llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) { + return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor()); +} + +TosaSpecificationVersion getMinVersion(const Profile &profile) { + switch (profile) { + case Profile::pro_int: + case Profile::pro_fp: + return TosaSpecificationVersion(1, 0); + case Profile::none: + return TosaSpecificationVersion(0, 0); + } + llvm_unreachable("Unknown TOSA profile"); +} + +TosaSpecificationVersion getMinVersion(const Extension &extension) { + switch (extension) { + case Extension::int16: + case Extension::int4: + case Extension::bf16: + case Extension::fp8e4m3: + case Extension::fp8e5m2: + case Extension::fft: + case Extension::variable: + case Extension::controlflow: + case Extension::doubleround: + case Extension::inexactround: + case Extension::dynamic: + return TosaSpecificationVersion(1, 0); + case Extension::mxfp: + return TosaSpecificationVersion(1, 1); + case Extension::none: + return TosaSpecificationVersion(0, 0); + } + llvm_unreachable("Unknown TOSA extension"); +} + +TosaSpecificationVersion getMinVersion(const Level &level) { + switch (level) { + case Level::eightK: + case Level::none: + return TosaSpecificationVersion(1, 0); + } + llvm_unreachable("Unknown TOSA level"); +} + +FailureOr<TargetEnv> +TargetEnv::createTargetEnvFromAttr(TargetEnvAttr targetAttr, + Location targetEnvAttrLoc) { + if (failed(verifyTargetInformation(targetAttr, targetEnvAttrLoc))) + return failure(); + + return TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(), + targetAttr.getProfiles(), targetAttr.getExtensions()); +} + +LogicalResult TargetEnv::verifyTargetInformation(TargetEnvAttr targetAttr, + Location targetAttrLoc) { + TosaSpecificationVersion targetVersion(targetAttr.getSpecificationVersion()); + + const auto isCompatibleWithTargetVersion = + [&](const auto &targetEnum, Location targetAttrLoc, + StringRef enumName) -> LogicalResult { + const TosaSpecificationVersion minRequiredVersion = + getMinVersion(targetEnum); + if (!targetVersion.isBackwardsCompatibleWith(minRequiredVersion)) + return emitError(targetAttrLoc, enumName) + << " '" << stringifyEnum(targetEnum) + << "' is not compatible with the target version " + << stringifyVersion(targetVersion) + << ", minimum 
required version is " + << stringifyVersion(minRequiredVersion); + return success(); + }; + + for (const auto &profile : targetAttr.getProfiles()) + if (failed( + isCompatibleWithTargetVersion(profile, targetAttrLoc, "profile"))) + return failure(); + for (const auto &extension : targetAttr.getExtensions()) + if (failed(isCompatibleWithTargetVersion(extension, targetAttrLoc, + "extension"))) + return failure(); + if (failed(isCompatibleWithTargetVersion(targetAttr.getLevel(), targetAttrLoc, + "level"))) + return failure(); + + return success(); +} + TargetEnvAttr lookupTargetEnv(Operation *op) { while (op) { op = SymbolTable::getNearestSymbolTable(op); @@ -39,9 +129,5 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) { return getDefaultTargetEnv(op->getContext()); } -llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) { - return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor()); -} - } // namespace tosa } // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 99b7cda..a85ff10a 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -76,28 +76,6 @@ template <typename OpTy> struct PoolPadFoldAdaptor; template <> -struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> { - using OpTy = tosa::AvgPool2dOp; - static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) { - const llvm::ArrayRef<int64_t> kernel = op.getKernel(); - if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] || - newPad[0] >= kernel[0] || newPad[1] >= kernel[0]) - return false; - return true; - } - static bool checkPadConstCompliance(OpTy op, Value padConst) { - return checkMatchingPadConstAndZp(padConst, op.getInputZp()); - } - static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op, - Value padInput, ArrayRef<int64_t> newPad) { - rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>( - op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(), - op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad), - op.getAccType()); - } -}; - -template <> struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> { using OpTy = tosa::MaxPool2dOp; static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) { @@ -245,13 +223,6 @@ struct FoldPadToTensorOp : public OpRewritePattern<OpTy> { }; } // namespace -void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add<FoldPadToTensorOp<tosa::AvgPool2dOp, - PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>( - context); -} - void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add< diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 00f84bc..6cd0eae 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -321,6 +321,19 @@ ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser, } } + // special handling: block_size accepts a *bare* BlockSizeMode enum + if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) { + if (name == "block_size" && succeeded(parser.parseOptionalKeyword(&kw))) { + auto sym = symbolizeBlockSize(kw); + if (!sym) + return parser.emitError(parser.getCurrentLocation()) + << "invalid block_size value: " << kw; + auto attr = BlockSizeAttr::get(parser.getContext(), sym.value()); + outAttrs.push_back(NamedAttribute(name, attr)); + return success(); + } + } + // Default path: 
parse any normal attribute literal, including fully qualified // enum keyword Attribute attr; @@ -373,6 +386,8 @@ void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) { } else if (auto nanPropagationModeAttr = dyn_cast<tosa::NanPropagationModeAttr>(attr)) { parser << nanPropagationModeAttr.getValue(); + } else if (auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) { + parser << blockSizeAttr.getValue(); } else { parser.printAttribute(attr); } @@ -508,6 +523,15 @@ void ReduceMinOp::print(OpAsmPrinter &parser) { printWithNanPropagationHandling(parser, *this); } +ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseWithEnumHandling<tosa::BlockSize>(parser, result); +} + +void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) { + printWithEnumHandling(parser, *this); +} + //===----------------------------------------------------------------------===// // Tosa utilities. //===----------------------------------------------------------------------===// @@ -933,32 +957,35 @@ static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) { // verify that inType and outType have same element types template <typename T> -static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) { - auto inputType = llvm::dyn_cast<TensorType>(inType); - auto outputType = llvm::dyn_cast<TensorType>(outType); - if (!inputType) { - op.emitOpError("expect shaped tensor for input, got ") << inType; +static LogicalResult verifySameElementTypes(T op, Type aType, Type bType, + StringRef aName = "input", + StringRef bName = "output") { + auto aTType = llvm::dyn_cast<TensorType>(aType); + auto bTType = llvm::dyn_cast<TensorType>(bType); + if (!aTType) { + op.emitOpError("expect shaped tensor for") << aName << ", got " << aType; return failure(); } - if (!outputType) { - op.emitOpError("expect shaped tensor for output, got ") << outType; + if (!bTType) { + op.emitOpError("expect shaped tensor for") << bName << ", got" << bType; return failure(); } - auto inputElementType = inputType.getElementType(); - auto outputElementType = outputType.getElementType(); - auto inputQuantType = - llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType); - auto outputQuantType = - llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType); - if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) && - (outputElementType.isIntOrIndexOrFloat() || outputQuantType) && - inputElementType != outputElementType) { + auto aElementType = aTType.getElementType(); + auto bElementType = bTType.getElementType(); + auto aQuantType = + llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType); + auto bQuantType = + llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType); + if ((aElementType.isIntOrIndexOrFloat() || aQuantType) && + (bElementType.isIntOrIndexOrFloat() || bQuantType) && + aElementType != bElementType) { // only check if both element types are int/index/float/UniformQuantized // eg, not sure how to check quant::QuantizedType // this happens in test_conv2d_q_grouped_convolution in // tfl-to-tosa-pipeline.mlir - op.emitOpError("expect input and output to have same element type, got ") - << inputElementType << " and " << outputElementType; + op.emitOpError("expect ") + << aName << " and " << bName << " to have same element type, got " + << aElementType << " and " << bElementType; return failure(); } return success(); @@ -1846,6 +1873,161 @@ LogicalResult MatMulOp::verify() { return success(); } 
+LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents( + MLIRContext *context, ::std::optional<Location> location, + MatmulTBlockScaledOp::Adaptor adaptor, + SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { + SmallVector<int64_t, 3> outShape(3, ShapedType::kDynamic); + + const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType()); + if (aDataShape.hasRank()) { + outShape[0] = aDataShape.getDimSize(0); + outShape[1] = aDataShape.getDimSize(1); + } + + const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType()); + if (aScaleShape.hasRank()) { + outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0) + : outShape[0]; + outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1) + : outShape[1]; + } + + // If B batch size is 1, it is broadcast across A's batch size + const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType()); + if (bDataShape.hasRank()) { + const int64_t bDataBatchSize = bDataShape.getDimSize(0); + if (bDataBatchSize != 1) + outShape[0] = + ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0]; + outShape[2] = bDataShape.getDimSize(1); + } + + const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType()); + if (bScaleShape.hasRank()) { + const int64_t bScaleBatchSize = bScaleShape.getDimSize(0); + if (bScaleBatchSize != 1) + outShape[0] = + ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0]; + outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1) + : outShape[2]; + } + + inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); + return success(); +} + +LogicalResult MatmulTBlockScaledOp::verify() { + // Verify same input data types + const Type aDataType = getAData().getType(); + const Type bDataType = getBData().getType(); + if (failed(verifySameElementTypes(*this, aDataType, bDataType, "A_data", + "B_data"))) + return failure(); + + auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim, + const StringRef operandName, + const StringRef dimName) -> LogicalResult { + if (ShapedType::isDynamic(currDim)) { + currDim = newDim; + return success(); + } else if (ShapedType::isStatic(newDim) && currDim != newDim) { + return emitOpError("expected ") + << dimName << " of " << operandName << " to match size " << currDim + << ", got " << newDim; + } + return success(); + }; + + // Verify input shape compatibility + int64_t N = ShapedType::kDynamic; + int64_t D = ShapedType::kDynamic; + int64_t H = ShapedType::kDynamic; + int64_t W = ShapedType::kDynamic; + int64_t C = ShapedType::kDynamic; + int64_t multiplesOfC = ShapedType::kDynamic; + + const ShapeAdaptor aDataShape = ShapeAdaptor(aDataType); + if (aDataShape.hasRank()) { + N = aDataShape.getDimSize(0); + H = aDataShape.getDimSize(1); + C = aDataShape.getDimSize(2); + } + + const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType()); + if (aScaleShape.hasRank()) { + if (failed(tryUpdateDimOrFailure(N, aScaleShape.getDimSize(0), "a_scale", + "batch")) || + failed(tryUpdateDimOrFailure(H, aScaleShape.getDimSize(1), "a_scale", + "height"))) + return failure(); + multiplesOfC = aScaleShape.getDimSize(2); + } + + const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType); + if (bDataShape.hasRank()) { + if (failed(tryUpdateDimOrFailure(D, bDataShape.getDimSize(0), "b_data", + "batch")) || + failed(tryUpdateDimOrFailure(C, bDataShape.getDimSize(2), "b_data", + "channels"))) + return failure(); + W = bDataShape.getDimSize(1); + } + + const 
ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType()); + if (bScaleShape.hasRank()) { + if (failed(tryUpdateDimOrFailure(D, bScaleShape.getDimSize(0), "b_scale", + "batch")) || + failed(tryUpdateDimOrFailure(W, bScaleShape.getDimSize(1), "b_scale", + "width")) || + failed(tryUpdateDimOrFailure(multiplesOfC, bScaleShape.getDimSize(2), + "b_scale", "C/block_size"))) + return failure(); + } + + // Verify batch size is broadcast compatible + if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1) + return emitOpError("expect B matrix batch size to be broadcast compatible " + "with A, got D=") + << D << " vs N=" << N; + + // Verify C is a multiple of block size + const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize()); + if (ShapedType::isStatic(C) && C % blockSize != 0) + return emitOpError("expect C to be a multiple of block size, got C=") + << C << ", block_size=" << blockSize; + + // Verify multiplesOfC is C / block size + if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) && + multiplesOfC != C / blockSize) + return emitOpError( + "expect scale operands dimension 2 to equal C/block_size (") + << C << "/" << blockSize << ")" + << ", got " << multiplesOfC; + + // Verify output shape + N = ShapedType::isDynamic(N) ? D : N; + const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W}; + const auto outputType = cast<ShapedType>(getResult().getType()); + if (outputType.hasRank() && + failed( + verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) { + InFlightDiagnostic opError = emitOpError("expected output shape "); + auto stringifyDim = [&](int64_t d) { + if (ShapedType::isDynamic(d)) + opError << "?"; + else + opError << d; + }; + llvm::interleaveComma(outputType.getShape(), opError, stringifyDim); + opError << " to be compatible with expected output shape "; + llvm::interleaveComma(expectedOutputShape, opError, stringifyDim); + return opError; + } + + return success(); +} + LogicalResult tosa::PadOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional<Location> location, PadOp::Adaptor adaptor, diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index f072e3e..e965ae0 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -25,6 +25,12 @@ TosaProfileCompliance::TosaProfileCompliance() { const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8}; const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8}; + // micro-scaling formats + const TypeInfo fp6e2m3T = {mlir::Float6E2M3FNType::getTypeID(), 6}; + const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6}; + const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4}; + const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8}; + // The profile-based compliance content below is auto-generated by a script // in https://git.mlplatform.org/tosa/specification.git #include "mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc" @@ -269,6 +275,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { // For the most of tosa operators, all operands are profile/extension related // and hence are all considered in this profile-based compilance check. 
+ POPULATE_PROFILE_INFO_COMMON(MatmulTBlockScaled) POPULATE_PROFILE_INFO_COMMON(Cast) POPULATE_PROFILE_INFO_COMMON(Const) POPULATE_PROFILE_INFO_COMMON(ArgMax) @@ -623,6 +630,14 @@ TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) { return {"fp8e4m3"}; } else if (typeInfo.typeID == mlir::Float8E5M2Type::getTypeID()) { return {"fp8e5m2"}; + } else if (typeInfo.typeID == mlir::Float6E2M3FNType::getTypeID()) { + return {"fp6e2m3"}; + } else if (typeInfo.typeID == mlir::Float6E3M2FNType::getTypeID()) { + return {"fp6e3m2"}; + } else if (typeInfo.typeID == mlir::Float4E2M1FNType::getTypeID()) { + return {"fp4e2m1"}; + } else if (typeInfo.typeID == mlir::Float8E8M0FNUType::getTypeID()) { + return {"fp8e8m0"}; } llvm_unreachable("unknown type"); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 82f2f7e..3f874d9 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -657,6 +657,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) { CHECK_SIZES(TransposeConv2D); CHECK_SIZES(FFT2d); CHECK_SIZES(MatMul); + CHECK_SIZES(MatmulTBlockScaled); CHECK_SIZES(MaxPool2d); CHECK_SIZES(RFFT2d); // Scatter/Gather Operators @@ -1192,9 +1193,9 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) { bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { if (isa<FloatType>(type)) { return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType, - Float8E5M2Type>(type); - } - if (auto intTy = dyn_cast<IntegerType>(type)) { + Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType, + Float6E3M2FNType, Float8E8M0FNUType>(type); + } else if (auto intTy = dyn_cast<IntegerType>(type)) { if (intTy.isSignless()) { switch (intTy.getWidth()) { case 1: @@ -1220,13 +1221,19 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { } void TosaValidation::runOnOperation() { + ModuleOp modOp = getOperation(); + const TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(modOp); + const auto maybeTargetEnv = + tosa::TargetEnv::createTargetEnvFromAttr(targetEnvAttr, modOp.getLoc()); + if (failed(maybeTargetEnv)) + return signalPassFailure(); + targetEnv = *maybeTargetEnv; + TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>(); if (!tosaDialect) return; - targetEnv = tosa::TargetEnv(lookupTargetEnvOrDefault(getOperation())); - - getOperation().walk([&](Operation *op) { + modOp.walk([&](Operation *op) { if (op->getDialect() != tosaDialect) return; diff --git a/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt b/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt index 6ef1529..c712c64b 100644 --- a/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt +++ b/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt @@ -21,6 +21,6 @@ set_property(TARGET MLIRSparseTensorRuntime PROPERTY CXX_STANDARD 17) check_cxx_compiler_flag(-Wweak-vtables COMPILER_SUPPORTS_WARNING_WEAK_VTABLES) if(COMPILER_SUPPORTS_WARNING_WEAK_VTABLES) - target_compile_options(MLIRSparseTensorRuntime PUBLIC + target_compile_options(MLIRSparseTensorRuntime PRIVATE "-Wweak-vtables") endif() diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index b56e778..b88fbaa 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -260,9 +260,9 @@ static std::string getDecorationName(StringRef 
attrName) { } template <typename AttrTy, typename EmitF> -LogicalResult processDecorationList(Location loc, Decoration decoration, - Attribute attrList, StringRef attrName, - EmitF emitter) { +static LogicalResult processDecorationList(Location loc, Decoration decoration, + Attribute attrList, + StringRef attrName, EmitF emitter) { auto arrayAttr = dyn_cast<ArrayAttr>(attrList); if (!arrayAttr) { return emitError(loc, "expecting array attribute of ") |
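One usage note on the affine LoopUnroll rework earlier in this set: the separate unrollFull flag is gone, unroll-factor=-1 now requests complete unrolling, and any value that is neither -1 nor positive is rejected by the new option check. A hedged sketch of driving this through the textual pipeline API follows; the driver scaffolding and registration step are assumptions, while the pass name, option name, and -1/positive semantics are taken from the change.

```cpp
// Sketch only: assumes the affine passes are registered with the pipeline
// parser (e.g. via registerAllPasses() in the hosting tool).
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

static mlir::LogicalResult addFullAffineUnroll(mlir::PassManager &pm) {
  // -1 selects full unrolling; a positive value unrolls by that factor;
  // anything else fails with "Invalid option: 'unroll-factor' ...".
  return mlir::parsePassPipeline(
      "func.func(affine-loop-unroll{unroll-factor=-1})", pm);
}
```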
