diff options
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp | 4 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 10 | ||||
-rw-r--r-- | mlir/lib/Dialect/Shard/IR/ShardOps.cpp | 2 | ||||
-rw-r--r-- | mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp | 7 | ||||
-rw-r--r-- | mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp | 4 | ||||
-rw-r--r-- | mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp | 74 | ||||
-rw-r--r-- | mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 7 | ||||
-rw-r--r-- | mlir/lib/TableGen/CodeGenHelpers.cpp | 4 | ||||
-rw-r--r-- | mlir/lib/Target/LLVMIR/ModuleImport.cpp | 3 | ||||
-rw-r--r-- | mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 2 |
10 files changed, 71 insertions, 46 deletions
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp index f0d8b78..610ce1f 100644 --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -407,11 +407,11 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> { if (auto vectorType = dyn_cast<VectorType>(operandType)) nanAttr = DenseElementsAttr::get(vectorType, nan); - Value NanValue = + Value nanValue = spirv::ConstantOp::create(rewriter, loc, operandType, nanAttr); Value lhs = spirv::SelectOp::create(rewriter, loc, cmpNegativeWithFractionalExp, - NanValue, adaptor.getLhs()); + nanValue, adaptor.getLhs()); Value abs = spirv::GLFAbsOp::create(rewriter, loc, lhs); // TODO: The following just forcefully casts y into an integer value in diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index d8f983f..6192d79 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3024,10 +3024,10 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); } if (dynamicPointParseResult.has_value()) { - Type ChunkSizesType; + Type chunkSizesType; if (failed(*dynamicPointParseResult) || parser.parseComma() || - parser.parseType(ChunkSizesType) || - parser.resolveOperand(dynamicChunkSizes, ChunkSizesType, + parser.parseType(chunkSizesType) || + parser.resolveOperand(dynamicChunkSizes, chunkSizesType, result.operands)) { return failure(); } @@ -3399,9 +3399,9 @@ void transform::ContinuousTileSizesOp::getEffects( } static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op, - Type targetType, Type tile_sizes, + Type targetType, Type tileSizes, Type) { - printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes}); + printer.printFunctionalType(TypeRange{targetType}, TypeRange{tileSizes}); } static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser, diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp index 135c033..645cbff 100644 --- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp +++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp @@ -158,7 +158,7 @@ static FailureOr<GridOp> getGridAndVerify(Operation *op, } template <typename It> -bool isUnique(It begin, It end) { +static bool isUnique(It begin, It end) { if (begin == end) { return true; } diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp index 5aad671..1cba1bb 100644 --- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tosa/IR/TargetEnv.h" +#include "llvm/Support/FormatVariadic.h" namespace mlir { namespace tosa { @@ -27,7 +28,7 @@ TargetEnvAttr lookupTargetEnv(Operation *op) { } TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) { - return TargetEnvAttr::get(context, Level::eightK, + return TargetEnvAttr::get(context, SpecificationVersion::V_1_0, Level::eightK, {Profile::pro_int, Profile::pro_fp}, {}); } @@ -38,5 +39,9 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) { return getDefaultTargetEnv(op->getContext()); } +llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) { + return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor()); +} + } // namespace tosa } // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp index bcb880a..a0661e4 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp @@ -61,8 +61,8 @@ public: ModuleOp mod = getOperation(); MLIRContext *ctx = &getContext(); - const auto targetEnvAttr = - TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions); + const auto targetEnvAttr = TargetEnvAttr::get( + ctx, specificationVersion, level, selectedProfiles, selectedExtensions); mod->setAttr(TargetEnvAttr::name, targetEnvAttr); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index 20f9333..f072e3e 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -335,16 +335,15 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { //===----------------------------------------------------------------------===// template <typename T> -FailureOr<SmallVector<T>> -TosaProfileCompliance::getOperatorDefinition(Operation *op, - CheckCondition &condition) { +FailureOr<OpComplianceInfo<T>> +TosaProfileCompliance::getOperatorDefinition(Operation *op) { const std::string opName = op->getName().getStringRef().str(); const auto complianceMap = getProfileComplianceMap<T>(); const auto it = complianceMap.find(opName); if (it == complianceMap.end()) return {}; - return findMatchedProfile<T>(op, it->second, condition); + return findMatchedEntry<T>(op, it->second); } template <typename T> @@ -356,22 +355,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension( if (specRequiredModeSet.size() == 0) return success(); - CheckCondition condition = CheckCondition::invalid; - const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition); - if (failed(maybeOpRequiredMode)) { + const auto maybeOpDefinition = getOperatorDefinition<T>(op); + if (failed(maybeOpDefinition)) { // Operators such as control-flow and shape ops do not have an operand type // restriction. When the profile compliance information of operation is not // found, confirm if the target have enabled the profile required from the // specification. - int mode_count = 0; + int modeCount = 0; for (const auto &cands : specRequiredModeSet) { if (targetEnv.allowsAnyOf(cands)) return success(); - mode_count += cands.size(); + modeCount += cands.size(); } op->emitOpError() << "illegal: requires" - << (mode_count > 1 ? " any of " : " ") << "[" + << (modeCount > 1 ? " any of " : " ") << "[" << llvm::join(stringifyProfile<T>(specRequiredModeSet), ", ") << "] but not enabled in target\n"; @@ -381,7 +379,10 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension( // Find the required profiles or extensions according to the operand type // combination. - const auto opRequiredMode = maybeOpRequiredMode.value(); + const auto opDefinition = maybeOpDefinition.value(); + const SmallVector<T> opRequiredMode = opDefinition.mode; + const CheckCondition condition = opDefinition.condition; + if (opRequiredMode.size() == 0) { // No matched restriction found. return success(); @@ -437,6 +438,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension( } } + // Ensure the matched op compliance version does not exceed the target + // specification version. + const VersionedTypeInfo versionedTypeInfo = + opDefinition.operandTypeInfoSet[0]; + const TosaSpecificationVersion complianceVersion{versionedTypeInfo.second}; + const TosaSpecificationVersion targetVersion{targetEnv.getSpecVersion()}; + if (!targetVersion.isBackwardsCompatibleWith(complianceVersion)) { + op->emitOpError() << "illegal: the target specification version (" + << stringifyVersion(targetVersion) + << ") is not backwards compatible with the op compliance " + "specification version (" + << stringifyVersion(complianceVersion) << ")\n"; + return failure(); + } + return success(); } @@ -461,14 +477,14 @@ TosaProfileCompliance::checkExtension(Operation *op, } LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { - CheckCondition condition = CheckCondition::invalid; - const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition); - const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition); + const auto maybeProfDef = getOperatorDefinition<Profile>(op); + const auto maybeExtDef = getOperatorDefinition<Extension>(op); if (failed(maybeProfDef) && failed(maybeExtDef)) return success(); - const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) || - (succeeded(maybeExtDef) && !maybeExtDef->empty()); + const bool hasEntry = + (succeeded(maybeProfDef) && !maybeProfDef->mode.empty()) || + (succeeded(maybeExtDef) && !maybeExtDef->mode.empty()); if (!hasEntry) { std::string message; llvm::raw_string_ostream os(message); @@ -488,7 +504,9 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { SmallVector<TypeInfo> bestTypeInfo; const auto searchBestMatch = [&](auto map) { for (const auto &complianceInfos : map[opName]) { - for (const auto &typeInfos : complianceInfos.operandTypeInfoSet) { + for (const auto &versionedTypeInfos : + complianceInfos.operandTypeInfoSet) { + const SmallVector<TypeInfo> typeInfos = versionedTypeInfos.first; const int matches = llvm::count_if( llvm::zip_equal(current, typeInfos), [&](const auto zipType) { return isSameTypeInfo(std::get<0>(zipType), @@ -520,9 +538,8 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { // Find the profiles or extensions requirement according to the signature of // type of the operand list. template <typename T> -SmallVector<T> TosaProfileCompliance::findMatchedProfile( - Operation *op, SmallVector<OpComplianceInfo<T>> compInfo, - CheckCondition &condition) { +OpComplianceInfo<T> TosaProfileCompliance::findMatchedEntry( + Operation *op, SmallVector<OpComplianceInfo<T>> compInfo) { assert(compInfo.size() != 0 && "profile-based compliance information is empty"); @@ -533,27 +550,30 @@ SmallVector<T> TosaProfileCompliance::findMatchedProfile( return {}; for (size_t i = 0; i < compInfo.size(); i++) { - SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet; - for (SmallVector<TypeInfo> expected : sets) { + SmallVector<VersionedTypeInfo> sets = compInfo[i].operandTypeInfoSet; + for (const auto &set : sets) { + SmallVector<TypeInfo> expected = set.first; assert(present.size() == expected.size() && "the entries for profile-based compliance do not match between " "the generated metadata and the type definition retrieved from " " the operation"); - bool is_found = true; + bool isFound = true; // Compare the type signature between the given operation and the // compliance metadata. for (size_t j = 0; j < expected.size(); j++) { if (!isSameTypeInfo(present[j], expected[j])) { // Verify the next mode set from the list. - is_found = false; + isFound = false; break; } } - if (is_found == true) { - condition = compInfo[i].condition; - return compInfo[i].mode; + if (isFound == true) { + SmallVector<VersionedTypeInfo> typeInfoSet{set}; + OpComplianceInfo<T> info{compInfo[i].mode, typeInfoSet, + compInfo[i].condition}; + return info; } } } diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 81b5788..e0a8ac4 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -20,8 +20,8 @@ #define DEBUG_TYPE "xegpu" -namespace mlir { -namespace xegpu { +using namespace mlir; +using namespace mlir::xegpu; static bool isSharedMemory(const MemRefType &memrefTy) { Attribute attr = memrefTy.getMemorySpace(); @@ -1133,9 +1133,6 @@ LogicalResult MemDescSubviewOp::verify() { return success(); } -} // namespace xegpu -} // namespace mlir - namespace mlir { #include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc> } // namespace mlir diff --git a/mlir/lib/TableGen/CodeGenHelpers.cpp b/mlir/lib/TableGen/CodeGenHelpers.cpp index cb90ef8..d52d5e7 100644 --- a/mlir/lib/TableGen/CodeGenHelpers.cpp +++ b/mlir/lib/TableGen/CodeGenHelpers.cpp @@ -49,9 +49,7 @@ StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter( raw_ostream &os, const RecordKeeper &records, StringRef tag) : os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {} -void StaticVerifierFunctionEmitter::emitOpConstraints( - ArrayRef<const Record *> opDefs) { - NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace()); +void StaticVerifierFunctionEmitter::emitOpConstraints() { emitTypeConstraints(); emitAttrConstraints(); emitPropConstraints(); diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 9603813..857e31b 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -2604,6 +2604,7 @@ static constexpr std::array kExplicitLLVMFuncOpAttributes{ StringLiteral("denormal-fp-math-f32"), StringLiteral("fp-contract"), StringLiteral("frame-pointer"), + StringLiteral("inlinehint"), StringLiteral("instrument-function-entry"), StringLiteral("instrument-function-exit"), StringLiteral("memory"), @@ -2643,6 +2644,8 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func, funcOp.setNoInline(true); if (func->hasFnAttribute(llvm::Attribute::AlwaysInline)) funcOp.setAlwaysInline(true); + if (func->hasFnAttribute(llvm::Attribute::InlineHint)) + funcOp.setInlineHint(true); if (func->hasFnAttribute(llvm::Attribute::OptimizeNone)) funcOp.setOptimizeNone(true); if (func->hasFnAttribute(llvm::Attribute::Convergent)) diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 845a14f..147613f 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1652,6 +1652,8 @@ static void convertFunctionAttributes(LLVMFuncOp func, llvmFunc->addFnAttr(llvm::Attribute::NoInline); if (func.getAlwaysInlineAttr()) llvmFunc->addFnAttr(llvm::Attribute::AlwaysInline); + if (func.getInlineHintAttr()) + llvmFunc->addFnAttr(llvm::Attribute::InlineHint); if (func.getOptimizeNoneAttr()) llvmFunc->addFnAttr(llvm::Attribute::OptimizeNone); if (func.getConvergentAttr()) |