diff options
Diffstat (limited to 'mlir/lib/Dialect')
10 files changed, 247 insertions, 154 deletions
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp index 25f941d..b9ee0a4 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -217,6 +217,9 @@ updateCalls(ModuleOp module, const AllocDynamicSizesMap &map, } if (!options.filterFn(&callee)) return; + if (callee.isExternal() || callee.isPublic()) + return; + SmallVector<Value, 6> replaceWithNewCallResults; SmallVector<Value, 6> replaceWithOutParams; for (OpResult result : op.getResults()) { @@ -292,14 +295,14 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams( // function. AllocDynamicSizesMap map; for (auto func : module.getOps<func::FuncOp>()) { + if (func.isExternal() || func.isPublic()) + continue; if (!options.filterFn(&func)) continue; SmallVector<BlockArgument, 6> appendedEntryArgs; if (failed( updateFuncOp(func, appendedEntryArgs, options.addResultAttribute))) return failure(); - if (func.isExternal()) - continue; if (failed(updateReturnOps(func, appendedEntryArgs, map, options))) { return failure(); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp index 8295492..04e8836 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp @@ -310,26 +310,30 @@ LogicalResult BlockPrefetch2dOp::verify() { template <typename OpType, typename = std::enable_if_t<llvm::is_one_of< OpType, BlockLoadOp, BlockStoreOp>::value>> LogicalResult verify1DBlockArg(OpType op) { - VectorType vTy; + Type srcOrDstTy; if constexpr (std::is_same_v<OpType, BlockLoadOp>) - vTy = op.getResult().getType(); + srcOrDstTy = op.getResult().getType(); else - vTy = op.getVal().getType(); + srcOrDstTy = op.getVal().getType(); + VectorType vTy = dyn_cast<VectorType>(srcOrDstTy); + // scalar case is always valid + if (!vTy) + return success(); int elemTySize = vTy.getElementType().getIntOrFloatBitWidth() / 8; if (elemTySize == 1) { - llvm::SmallSet<int, 5> validSizes{1, 2, 4, 8, 16}; + llvm::SmallSet<int, 4> validSizes{2, 4, 8, 16}; if (validSizes.contains(vTy.getNumElements())) return success(); else return op.emitOpError( - "vector size must be 1, 2, 4, 8 or 16 for 8-bit element type"); + "vector size must be 2, 4, 8 or 16 for 8-bit element type"); } else { - llvm::SmallSet<int, 4> validSizes{1, 2, 4, 8}; + llvm::SmallSet<int, 3> validSizes{2, 4, 8}; if (validSizes.contains(vTy.getNumElements())) return success(); else return op.emitOpError( - "vector size must be 1, 2, 4 or 8 for element type > 8 bits"); + "vector size must be 2, 4 or 8 for element type > 8 bits"); } } diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp index eac0e47..15eb51a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp @@ -31,8 +31,10 @@ template <typename T> struct StructuredOpInterface : public RuntimeVerifiableOpInterface::ExternalModel< StructuredOpInterface<T>, T> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { + void + generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, + function_ref<std::string(Operation *, StringRef)> + generateErrorMessage) const { auto linalgOp = llvm::cast<LinalgOp>(op); SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc); @@ -70,7 +72,7 @@ struct StructuredOpInterface builder.createOrFold<index::MinSOp>(loc, startIndex, endIndex); auto cmpOp = builder.createOrFold<index::CmpOp>( loc, index::IndexCmpPredicate::SGE, min, zero); - auto msg = RuntimeVerifiableOpInterface::generateErrorMessage( + auto msg = generateErrorMessage( linalgOp, "unexpected negative result on dimension #" + std::to_string(dim) + " of input/output operand #" + std::to_string(opOperand.getOperandNumber())); @@ -100,7 +102,7 @@ struct StructuredOpInterface cmpOp = builder.createOrFold<index::CmpOp>( loc, predicate, inferredDimSize, actualDimSize); - msg = RuntimeVerifiableOpInterface::generateErrorMessage( + msg = generateErrorMessage( linalgOp, "dimension #" + std::to_string(dim) + " of input/output operand #" + std::to_string(opOperand.getOperandNumber()) + diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index d3a77c0..291da1f 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -37,8 +37,10 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value, struct AssumeAlignmentOpInterface : public RuntimeVerifiableOpInterface::ExternalModel< AssumeAlignmentOpInterface, AssumeAlignmentOp> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { + void + generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, + function_ref<std::string(Operation *, StringRef)> + generateErrorMessage) const { auto assumeOp = cast<AssumeAlignmentOp>(op); Value ptr = ExtractAlignedPointerAsIndexOp::create(builder, loc, assumeOp.getMemref()); @@ -48,9 +50,9 @@ struct AssumeAlignmentOpInterface Value isAligned = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, rest, arith::ConstantIndexOp::create(builder, loc, 0)); - cf::AssertOp::create(builder, loc, isAligned, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "memref is not aligned to " + + cf::AssertOp::create( + builder, loc, isAligned, + generateErrorMessage(op, "memref is not aligned to " + std::to_string(assumeOp.getAlignment()))); } }; @@ -58,8 +60,10 @@ struct AssumeAlignmentOpInterface struct CastOpInterface : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface, CastOp> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { + void + generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, + function_ref<std::string(Operation *, StringRef)> + generateErrorMessage) const { auto castOp = cast<CastOp>(op); auto srcType = cast<BaseMemRefType>(castOp.getSource().getType()); @@ -76,8 +80,7 @@ struct CastOpInterface Value isSameRank = arith::CmpIOp::create( builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank); cf::AssertOp::create(builder, loc, isSameRank, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "rank mismatch")); + generateErrorMessage(op, "rank mismatch")); } // Get source offset and strides. We do not have an op to get offsets and @@ -116,8 +119,8 @@ struct CastOpInterface builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz); cf::AssertOp::create( builder, loc, isSameSz, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "size mismatch of dim " + std::to_string(it.index()))); + generateErrorMessage(op, "size mismatch of dim " + + std::to_string(it.index()))); } // Get result offset and strides. @@ -135,8 +138,7 @@ struct CastOpInterface Value isSameOffset = arith::CmpIOp::create( builder, loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal); cf::AssertOp::create(builder, loc, isSameOffset, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "offset mismatch")); + generateErrorMessage(op, "offset mismatch")); } // Check strides. @@ -153,8 +155,8 @@ struct CastOpInterface builder, loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal); cf::AssertOp::create( builder, loc, isSameStride, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "stride mismatch of dim " + std::to_string(it.index()))); + generateErrorMessage(op, "stride mismatch of dim " + + std::to_string(it.index()))); } } }; @@ -162,8 +164,10 @@ struct CastOpInterface struct CopyOpInterface : public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface, CopyOp> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { + void + generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, + function_ref<std::string(Operation *, StringRef)> + generateErrorMessage) const { auto copyOp = cast<CopyOp>(op); BaseMemRefType sourceType = copyOp.getSource().getType(); BaseMemRefType targetType = copyOp.getTarget().getType(); @@ -193,9 +197,9 @@ struct CopyOpInterface Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i); Value sameDimSize = arith::CmpIOp::create( builder, loc, arith::CmpIPredicate::eq, sourceDim, targetDim); - cf::AssertOp::create(builder, loc, sameDimSize, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "size of " + std::to_string(i) + + cf::AssertOp::create( + builder, loc, sameDimSize, + generateErrorMessage(op, "size of " + std::to_string(i) + "-th source/target dim does not match")); } } @@ -204,16 +208,17 @@ struct CopyOpInterface struct DimOpInterface : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface, DimOp> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { + void + generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, + function_ref<std::string(Operation *, StringRef)> + generateErrorMessage) const { auto dimOp = cast<DimOp>(op); Value rank = RankOp::create(builder, loc, dimOp.getSource()); Value zero = arith::ConstantIndexOp::create(builder, loc, 0); cf::AssertOp::create( builder, loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank), - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "index is out of bounds")); + generateErrorMessage(op, "index is out of bounds")); } }; @@ -223,8 +228,10 @@ template <typename LoadStoreOp> struct LoadStoreOpInterface : public RuntimeVerifiableOpInterface::ExternalModel< LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { + void + generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, + function_ref<std::string(Operation *, StringRef)> + generateErrorMessage) const { auto loadStoreOp = cast<LoadStoreOp>(op); auto memref = loadStoreOp.getMemref(); @@ -245,16 +252,17 @@ struct LoadStoreOpInterface : inBounds; } cf::AssertOp::create(builder, loc, assertCond, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "out-of-bounds access")); + generateErrorMessage(op, "out-of-bounds access")); } }; struct SubViewOpInterface : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface, SubViewOp> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { + void + generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, + function_ref<std::string(Operation *, StringRef)> + generateErrorMessage) const { auto subView = cast<SubViewOp>(op); MemRefType sourceType = subView.getSource().getType(); @@ -277,10 +285,10 @@ struct SubViewOpInterface Value dimSize = metadataOp.getSizes()[i]; Value offsetInBounds = generateInBoundsCheck(builder, loc, offset, zero, dimSize); - cf::AssertOp::create( - builder, loc, offsetInBounds, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "offset " + std::to_string(i) + " is out-of-bounds")); + cf::AssertOp::create(builder, loc, offsetInBounds, + generateErrorMessage(op, "offset " + + std::to_string(i) + + " is out-of-bounds")); // Verify that slice does not run out-of-bounds. Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one); @@ -292,9 +300,9 @@ struct SubViewOpInterface generateInBoundsCheck(builder, loc, lastPos, zero, dimSize); cf::AssertOp::create( builder, loc, lastPosInBounds, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "subview runs out-of-bounds along dimension " + - std::to_string(i))); + generateErrorMessage(op, + "subview runs out-of-bounds along dimension " + + std::to_string(i))); } } }; @@ -302,8 +310,10 @@ struct SubViewOpInterface struct ExpandShapeOpInterface : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface, ExpandShapeOp> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { + void + generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, + function_ref<std::string(Operation *, StringRef)> + generateErrorMessage) const { auto expandShapeOp = cast<ExpandShapeOp>(op); // Verify that the expanded dim sizes are a product of the collapsed dim @@ -333,9 +343,9 @@ struct ExpandShapeOpInterface Value isModZero = arith::CmpIOp::create( builder, loc, arith::CmpIPredicate::eq, mod, arith::ConstantIndexOp::create(builder, loc, 0)); - cf::AssertOp::create(builder, loc, isModZero, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "static result dims in reassoc group do not " + cf::AssertOp::create( + builder, loc, isModZero, + generateErrorMessage(op, "static result dims in reassoc group do not " "divide src dim evenly")); } } diff --git a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp index 838ff1f..c031118 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp @@ -35,8 +35,10 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value, struct CastOpInterface : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface, CastOp> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { + void + generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, + function_ref<std::string(Operation *, StringRef)> + generateErrorMessage) const { auto castOp = cast<CastOp>(op); auto srcType = cast<TensorType>(castOp.getSource().getType()); @@ -53,8 +55,7 @@ struct CastOpInterface Value isSameRank = arith::CmpIOp::create( builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank); cf::AssertOp::create(builder, loc, isSameRank, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "rank mismatch")); + generateErrorMessage(op, "rank mismatch")); } // Check dimension sizes. @@ -76,8 +77,8 @@ struct CastOpInterface builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz); cf::AssertOp::create( builder, loc, isSameSz, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "size mismatch of dim " + std::to_string(it.index()))); + generateErrorMessage(op, "size mismatch of dim " + + std::to_string(it.index()))); } } }; @@ -85,16 +86,17 @@ struct CastOpInterface struct DimOpInterface : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface, DimOp> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { + void + generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, + function_ref<std::string(Operation *, StringRef)> + generateErrorMessage) const { auto dimOp = cast<DimOp>(op); Value rank = RankOp::create(builder, loc, dimOp.getSource()); Value zero = arith::ConstantIndexOp::create(builder, loc, 0); cf::AssertOp::create( builder, loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank), - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "index is out of bounds")); + generateErrorMessage(op, "index is out of bounds")); } }; @@ -104,8 +106,10 @@ template <typename OpTy> struct ExtractInsertOpInterface : public RuntimeVerifiableOpInterface::ExternalModel< ExtractInsertOpInterface<OpTy>, OpTy> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { + void + generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, + function_ref<std::string(Operation *, StringRef)> + generateErrorMessage) const { auto extractInsertOp = cast<OpTy>(op); Value tensor; @@ -135,16 +139,17 @@ struct ExtractInsertOpInterface : inBounds; } cf::AssertOp::create(builder, loc, assertCond, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "out-of-bounds access")); + generateErrorMessage(op, "out-of-bounds access")); } }; struct ExtractSliceOpInterface : public RuntimeVerifiableOpInterface::ExternalModel< ExtractSliceOpInterface, ExtractSliceOp> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { + void + generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, + function_ref<std::string(Operation *, StringRef)> + generateErrorMessage) const { auto extractSliceOp = cast<ExtractSliceOp>(op); RankedTensorType sourceType = extractSliceOp.getSource().getType(); @@ -166,10 +171,10 @@ struct ExtractSliceOpInterface loc, extractSliceOp.getSource(), i); Value offsetInBounds = generateInBoundsCheck(builder, loc, offset, zero, dimSize); - cf::AssertOp::create( - builder, loc, offsetInBounds, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "offset " + std::to_string(i) + " is out-of-bounds")); + cf::AssertOp::create(builder, loc, offsetInBounds, + generateErrorMessage(op, "offset " + + std::to_string(i) + + " is out-of-bounds")); // Verify that slice does not run out-of-bounds. Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one); @@ -181,7 +186,7 @@ struct ExtractSliceOpInterface generateInBoundsCheck(builder, loc, lastPos, zero, dimSize); cf::AssertOp::create( builder, loc, lastPosInBounds, - RuntimeVerifiableOpInterface::generateErrorMessage( + generateErrorMessage( op, "extract_slice runs out-of-bounds along dimension " + std::to_string(i))); } diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt index c6a438d..a95906a 100644 --- a/mlir/lib/Dialect/Tosa/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRTosaDialect IR/TosaOps.cpp IR/TosaCanonicalizations.cpp + IR/TargetEnv.cpp Utils/ConversionUtils.cpp Utils/QuantUtils.cpp diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp new file mode 100644 index 0000000..5aad671 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp @@ -0,0 +1,42 @@ +//===-------------- TosaTarget.cpp - TOSA Target utilities ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TargetEnv.h" + +namespace mlir { +namespace tosa { + +TargetEnvAttr lookupTargetEnv(Operation *op) { + while (op) { + op = SymbolTable::getNearestSymbolTable(op); + if (!op) + break; + + if (auto attr = op->getAttrOfType<TargetEnvAttr>(TargetEnvAttr::name)) + return attr; + + op = op->getParentOp(); + } + + return {}; +} + +TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) { + return TargetEnvAttr::get(context, Level::eightK, + {Profile::pro_int, Profile::pro_fp}, {}); +} + +TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) { + if (auto attr = lookupTargetEnv(op)) + return attr; + + return getDefaultTargetEnv(op->getContext()); +} + +} // namespace tosa +} // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index 803993b..41b338d 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRTosaTransforms + TosaAttachTarget.cpp TosaConvertIntegerTypeToSignless.cpp TosaDecomposeTransposeConv.cpp TosaDecomposeDepthwise.cpp diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp new file mode 100644 index 0000000..bcb880a --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp @@ -0,0 +1,87 @@ +//===- TosaAttachTarget.cpp +//------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Attach target information to a TOSA module. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tosa/IR/TargetEnv.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace tosa { + +#define GEN_PASS_DEF_TOSAATTACHTARGET +#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" + +namespace { + +class TosaAttachTarget + : public tosa::impl::TosaAttachTargetBase<TosaAttachTarget> { + using Base::Base; + +public: + void runOnOperation() override { + llvm::SmallVector<Profile, 2> selectedProfiles; + if (!profiles.empty()) { + for (const std::string &prof : profiles) { + std::optional<Profile> profSymbol = symbolizeProfile(prof); + if (!profSymbol) { + llvm::SmallVector<Profile> allProfiles = ProfileAttr::getAllValues(); + llvm::errs() << buildUnkownParameterErrorMessage(allProfiles, + "profile", prof); + return signalPassFailure(); + } + selectedProfiles.push_back(profSymbol.value()); + } + } + + llvm::SmallVector<Extension, 10> selectedExtensions; + if (!extensions.empty()) { + for (const std::string &ext : extensions) { + std::optional<Extension> extSymbol = symbolizeExtension(ext); + if (!extSymbol) { + llvm::SmallVector<Extension> allExtensions = + ExtensionAttr::getAllValues(); + llvm::errs() << buildUnkownParameterErrorMessage(allExtensions, + "extension", ext); + return signalPassFailure(); + } + selectedExtensions.push_back(extSymbol.value()); + } + } + + ModuleOp mod = getOperation(); + MLIRContext *ctx = &getContext(); + const auto targetEnvAttr = + TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions); + mod->setAttr(TargetEnvAttr::name, targetEnvAttr); + } + +private: + template <typename T> + std::string buildUnkownParameterErrorMessage(llvm::SmallVector<T> &enumValues, + std::string enumName, + std::string unknownArgument) { + std::string message; + llvm::raw_string_ostream os(message); + os << "Unknown TOSA " << enumName << " name passed in '" << unknownArgument + << "', supported " << enumName << "s are: "; + llvm::interleaveComma(enumValues, os); + os << "\n"; + return message; + } +}; + +} // namespace + +} // namespace tosa +} // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 4fc7ce8..82f2f7e 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -14,7 +14,6 @@ #include "mlir/Dialect/Tosa/IR/TargetEnv.h" #include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" -#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc" #include <string> @@ -130,28 +129,6 @@ static LogicalResult checkConstantOperandNegate(Operation *op, return success(); } -struct TosaLevel { - int32_t MAX_RANK = 0; - int32_t MAX_KERNEL = 0; - int32_t MAX_STRIDE = 0; - int32_t MAX_SCALE = 0; - int32_t MAX_LOG2_SIZE = 0; - int32_t MAX_NESTING = 0; - int32_t MAX_TENSOR_LIST_SIZE = 0; - - bool operator==(const TosaLevel &rhs) { - return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL && - MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE && - MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE && - MAX_NESTING == rhs.MAX_NESTING && - MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE; - } -}; - -static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64}; -static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048, - 63, 256, 256}; - //===----------------------------------------------------------------------===// // TOSA Validation Pass. //===----------------------------------------------------------------------===// @@ -162,12 +139,9 @@ public: explicit TosaValidation(const TosaValidationOptions &options) : TosaValidation() { - this->profile = options.profile; - this->extension = options.extension; this->strictOpSpecAlignment = options.strictOpSpecAlignment; this->allowInvalidOpDatatypeCombinations = options.allowInvalidOpDatatypeCombinations; - this->level = options.level; } void runOnOperation() final; @@ -207,28 +181,28 @@ private: LogicalResult levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_KERNEL) + if (v > targetEnv.getLevel().MAX_KERNEL) return op->emitOpError() << "failed level check: " << checkDesc; return success(); } LogicalResult levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_STRIDE) + if (v > targetEnv.getLevel().MAX_STRIDE) return op->emitOpError() << "failed level check: " << checkDesc; return success(); } LogicalResult levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_SCALE) + if (v > targetEnv.getLevel().MAX_SCALE) return op->emitOpError() << "failed level check: " << checkDesc; return success(); } LogicalResult levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) + if (v > targetEnv.getLevel().MAX_TENSOR_LIST_SIZE) return op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc; return success(); @@ -285,6 +259,7 @@ private: template <typename T> LogicalResult levelCheckRanks(T tosaOp) { auto op = tosaOp.getOperation(); + const TosaLevel tosaLevel = targetEnv.getLevel(); for (auto v : op->getOperands()) { if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK))) return failure(); @@ -466,7 +441,7 @@ private: int32_t maxNestedDepth = 0; getMaxNestedDepth(op, maxNestedDepth); - if (maxNestedDepth >= tosaLevel.MAX_NESTING) { + if (maxNestedDepth >= targetEnv.getLevel().MAX_NESTING) { op->emitOpError() << "failed level check: " << maxNestedDepth << " >= MAX_NESTING"; return failure(); @@ -523,43 +498,6 @@ private: return success(); } - // configure profile and level values from pass options profileName and - // levelName - void configLevelAndProfile() { - tosaLevel = TOSA_LEVEL_NONE; - if (level == TosaLevelEnum::EightK) { - tosaLevel = TOSA_LEVEL_EIGHTK; - } - - if (!profile.empty()) { - for (std::string &prof : profile) { - auto profSymbol = symbolizeProfile(prof); - if (profSymbol) { - targetEnv.addProfile(profSymbol.value()); - } else { - llvm::errs() << "unknown TOSA profile name passed in: " << prof - << ", supported profiles are `pro_int` and `pro_fp`\n"; - return signalPassFailure(); - } - } - } - - if (!extension.empty()) { - for (std::string &ext : extension) { - auto extSymbol = symbolizeExtension(ext); - if (extSymbol) { - targetEnv.addExtension(extSymbol.value()); - } else { - llvm::errs() << "unknown TOSA extension name passed in: " << ext - << ", supported extension are int16, int4, bf16, " - << "fp8e4m3, fp8e5m2, fft, variable, controlflow, " - << "doubleround, inexactround and dynamic\n"; - return signalPassFailure(); - } - } - } - } - LogicalResult CheckVariable(Operation *op); LogicalResult CheckVariableReadOrWrite(Operation *op); bool isValidElementType(Type type, const bool allowUnsigned = false); @@ -567,7 +505,6 @@ private: SmallVector< std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>> constCheckers; - TosaLevel tosaLevel; DenseMap<StringAttr, mlir::Type> variablesMap; TosaProfileCompliance profileComp; tosa::TargetEnv targetEnv; @@ -576,13 +513,13 @@ private: template <> LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) { auto *op = tosaOp.getOperation(); - if (failed( - levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK))) + if (failed(levelCheckRank(op, tosaOp.getInput(), "operand", + targetEnv.getLevel().MAX_RANK))) return failure(); // rank(output) = rank(input) - 1 if (failed(levelCheckRank(op, tosaOp.getOutput(), "result", - tosaLevel.MAX_RANK - 1))) + targetEnv.getLevel().MAX_RANK - 1))) return failure(); return success(); @@ -594,7 +531,7 @@ LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) { // Only the condition input has rank limitation. if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand", - tosaLevel.MAX_RANK))) + targetEnv.getLevel().MAX_RANK))) return failure(); return success(); @@ -605,7 +542,7 @@ LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) { auto *op = tosaOp.getOperation(); auto variableType = getVariableType(tosaOp); if (failed(levelCheckRank(op, variableType, "variable type", - tosaLevel.MAX_RANK))) + targetEnv.getLevel().MAX_RANK))) return failure(); return success(); @@ -762,7 +699,8 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op, // defined in 1.7. Levels. // For each tensor, the number of tensor elements multiplied by the // element size in bytes must be representable as a tensor_size_t. - const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1; + const int64_t max_size = + (INT64_C(1) << targetEnv.getLevel().MAX_LOG2_SIZE) - 1; if (size > max_size) return op->emitOpError() << "failed level check: " << operandOrResult @@ -772,7 +710,7 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op, } LogicalResult TosaValidation::applyLevelCheck(Operation *op) { - if (tosaLevel == TOSA_LEVEL_NONE) { + if (targetEnv.getLevel() == TOSA_LEVEL_NONE) { // no need to do level checks return success(); } @@ -1282,12 +1220,12 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { } void TosaValidation::runOnOperation() { - configLevelAndProfile(); - TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>(); if (!tosaDialect) return; + targetEnv = tosa::TargetEnv(lookupTargetEnvOrDefault(getOperation())); + getOperation().walk([&](Operation *op) { if (op->getDialect() != tosaDialect) return; |