diff options
Diffstat (limited to 'mlir/lib')
6 files changed, 145 insertions, 96 deletions
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp index 25f941d..b9ee0a4 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -217,6 +217,9 @@ updateCalls(ModuleOp module, const AllocDynamicSizesMap &map, } if (!options.filterFn(&callee)) return; + if (callee.isExternal() || callee.isPublic()) + return; + SmallVector<Value, 6> replaceWithNewCallResults; SmallVector<Value, 6> replaceWithOutParams; for (OpResult result : op.getResults()) { @@ -292,14 +295,14 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams( // function. AllocDynamicSizesMap map; for (auto func : module.getOps<func::FuncOp>()) { + if (func.isExternal() || func.isPublic()) + continue; if (!options.filterFn(&func)) continue; SmallVector<BlockArgument, 6> appendedEntryArgs; if (failed( updateFuncOp(func, appendedEntryArgs, options.addResultAttribute))) return failure(); - if (func.isExternal()) - continue; if (failed(updateReturnOps(func, appendedEntryArgs, map, options))) { return failure(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp index eac0e47..15eb51a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp @@ -31,8 +31,10 @@ template <typename T> struct StructuredOpInterface : public RuntimeVerifiableOpInterface::ExternalModel< StructuredOpInterface<T>, T> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { + void + generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, + function_ref<std::string(Operation *, StringRef)> + generateErrorMessage) const { auto linalgOp = llvm::cast<LinalgOp>(op); SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc); @@ -70,7 +72,7 @@ struct StructuredOpInterface builder.createOrFold<index::MinSOp>(loc, startIndex, endIndex); auto cmpOp = builder.createOrFold<index::CmpOp>( loc, index::IndexCmpPredicate::SGE, min, zero); - auto msg = RuntimeVerifiableOpInterface::generateErrorMessage( + auto msg = generateErrorMessage( linalgOp, "unexpected negative result on dimension #" + std::to_string(dim) + " of input/output operand #" + std::to_string(opOperand.getOperandNumber())); @@ -100,7 +102,7 @@ struct StructuredOpInterface cmpOp = builder.createOrFold<index::CmpOp>( loc, predicate, inferredDimSize, actualDimSize); - msg = RuntimeVerifiableOpInterface::generateErrorMessage( + msg = generateErrorMessage( linalgOp, "dimension #" + std::to_string(dim) + " of input/output operand #" + std::to_string(opOperand.getOperandNumber()) + diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index d3a77c0..291da1f 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -37,8 +37,10 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value, struct AssumeAlignmentOpInterface : public RuntimeVerifiableOpInterface::ExternalModel< AssumeAlignmentOpInterface, AssumeAlignmentOp> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { + void + generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, + function_ref<std::string(Operation *, StringRef)> + generateErrorMessage) const { auto assumeOp = cast<AssumeAlignmentOp>(op); Value ptr = ExtractAlignedPointerAsIndexOp::create(builder, loc, assumeOp.getMemref()); @@ -48,9 +50,9 @@ struct AssumeAlignmentOpInterface Value isAligned = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, rest, arith::ConstantIndexOp::create(builder, loc, 0)); - cf::AssertOp::create(builder, loc, isAligned, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "memref is not aligned to " + + cf::AssertOp::create( + builder, loc, isAligned, + generateErrorMessage(op, "memref is not aligned to " + std::to_string(assumeOp.getAlignment()))); } }; @@ -58,8 +60,10 @@ struct AssumeAlignmentOpInterface struct CastOpInterface : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface, CastOp> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { + void + generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, + function_ref<std::string(Operation *, StringRef)> + generateErrorMessage) const { auto castOp = cast<CastOp>(op); auto srcType = cast<BaseMemRefType>(castOp.getSource().getType()); @@ -76,8 +80,7 @@ struct CastOpInterface Value isSameRank = arith::CmpIOp::create( builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank); cf::AssertOp::create(builder, loc, isSameRank, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "rank mismatch")); + generateErrorMessage(op, "rank mismatch")); } // Get source offset and strides. We do not have an op to get offsets and @@ -116,8 +119,8 @@ struct CastOpInterface builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz); cf::AssertOp::create( builder, loc, isSameSz, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "size mismatch of dim " + std::to_string(it.index()))); + generateErrorMessage(op, "size mismatch of dim " + + std::to_string(it.index()))); } // Get result offset and strides. @@ -135,8 +138,7 @@ struct CastOpInterface Value isSameOffset = arith::CmpIOp::create( builder, loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal); cf::AssertOp::create(builder, loc, isSameOffset, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "offset mismatch")); + generateErrorMessage(op, "offset mismatch")); } // Check strides. @@ -153,8 +155,8 @@ struct CastOpInterface builder, loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal); cf::AssertOp::create( builder, loc, isSameStride, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "stride mismatch of dim " + std::to_string(it.index()))); + generateErrorMessage(op, "stride mismatch of dim " + + std::to_string(it.index()))); } } }; @@ -162,8 +164,10 @@ struct CastOpInterface struct CopyOpInterface : public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface, CopyOp> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { + void + generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, + function_ref<std::string(Operation *, StringRef)> + generateErrorMessage) const { auto copyOp = cast<CopyOp>(op); BaseMemRefType sourceType = copyOp.getSource().getType(); BaseMemRefType targetType = copyOp.getTarget().getType(); @@ -193,9 +197,9 @@ struct CopyOpInterface Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i); Value sameDimSize = arith::CmpIOp::create( builder, loc, arith::CmpIPredicate::eq, sourceDim, targetDim); - cf::AssertOp::create(builder, loc, sameDimSize, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "size of " + std::to_string(i) + + cf::AssertOp::create( + builder, loc, sameDimSize, + generateErrorMessage(op, "size of " + std::to_string(i) + "-th source/target dim does not match")); } } @@ -204,16 +208,17 @@ struct CopyOpInterface struct DimOpInterface : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface, DimOp> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { + void + generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, + function_ref<std::string(Operation *, StringRef)> + generateErrorMessage) const { auto dimOp = cast<DimOp>(op); Value rank = RankOp::create(builder, loc, dimOp.getSource()); Value zero = arith::ConstantIndexOp::create(builder, loc, 0); cf::AssertOp::create( builder, loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank), - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "index is out of bounds")); + generateErrorMessage(op, "index is out of bounds")); } }; @@ -223,8 +228,10 @@ template <typename LoadStoreOp> struct LoadStoreOpInterface : public RuntimeVerifiableOpInterface::ExternalModel< LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { + void + generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, + function_ref<std::string(Operation *, StringRef)> + generateErrorMessage) const { auto loadStoreOp = cast<LoadStoreOp>(op); auto memref = loadStoreOp.getMemref(); @@ -245,16 +252,17 @@ struct LoadStoreOpInterface : inBounds; } cf::AssertOp::create(builder, loc, assertCond, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "out-of-bounds access")); + generateErrorMessage(op, "out-of-bounds access")); } }; struct SubViewOpInterface : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface, SubViewOp> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { + void + generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, + function_ref<std::string(Operation *, StringRef)> + generateErrorMessage) const { auto subView = cast<SubViewOp>(op); MemRefType sourceType = subView.getSource().getType(); @@ -277,10 +285,10 @@ struct SubViewOpInterface Value dimSize = metadataOp.getSizes()[i]; Value offsetInBounds = generateInBoundsCheck(builder, loc, offset, zero, dimSize); - cf::AssertOp::create( - builder, loc, offsetInBounds, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "offset " + std::to_string(i) + " is out-of-bounds")); + cf::AssertOp::create(builder, loc, offsetInBounds, + generateErrorMessage(op, "offset " + + std::to_string(i) + + " is out-of-bounds")); // Verify that slice does not run out-of-bounds. Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one); @@ -292,9 +300,9 @@ struct SubViewOpInterface generateInBoundsCheck(builder, loc, lastPos, zero, dimSize); cf::AssertOp::create( builder, loc, lastPosInBounds, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "subview runs out-of-bounds along dimension " + - std::to_string(i))); + generateErrorMessage(op, + "subview runs out-of-bounds along dimension " + + std::to_string(i))); } } }; @@ -302,8 +310,10 @@ struct SubViewOpInterface struct ExpandShapeOpInterface : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface, ExpandShapeOp> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { + void + generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, + function_ref<std::string(Operation *, StringRef)> + generateErrorMessage) const { auto expandShapeOp = cast<ExpandShapeOp>(op); // Verify that the expanded dim sizes are a product of the collapsed dim @@ -333,9 +343,9 @@ struct ExpandShapeOpInterface Value isModZero = arith::CmpIOp::create( builder, loc, arith::CmpIPredicate::eq, mod, arith::ConstantIndexOp::create(builder, loc, 0)); - cf::AssertOp::create(builder, loc, isModZero, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "static result dims in reassoc group do not " + cf::AssertOp::create( + builder, loc, isModZero, + generateErrorMessage(op, "static result dims in reassoc group do not " "divide src dim evenly")); } } diff --git a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp index 838ff1f..c031118 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp @@ -35,8 +35,10 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value, struct CastOpInterface : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface, CastOp> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { + void + generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, + function_ref<std::string(Operation *, StringRef)> + generateErrorMessage) const { auto castOp = cast<CastOp>(op); auto srcType = cast<TensorType>(castOp.getSource().getType()); @@ -53,8 +55,7 @@ struct CastOpInterface Value isSameRank = arith::CmpIOp::create( builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank); cf::AssertOp::create(builder, loc, isSameRank, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "rank mismatch")); + generateErrorMessage(op, "rank mismatch")); } // Check dimension sizes. @@ -76,8 +77,8 @@ struct CastOpInterface builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz); cf::AssertOp::create( builder, loc, isSameSz, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "size mismatch of dim " + std::to_string(it.index()))); + generateErrorMessage(op, "size mismatch of dim " + + std::to_string(it.index()))); } } }; @@ -85,16 +86,17 @@ struct CastOpInterface struct DimOpInterface : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface, DimOp> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { + void + generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, + function_ref<std::string(Operation *, StringRef)> + generateErrorMessage) const { auto dimOp = cast<DimOp>(op); Value rank = RankOp::create(builder, loc, dimOp.getSource()); Value zero = arith::ConstantIndexOp::create(builder, loc, 0); cf::AssertOp::create( builder, loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank), - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "index is out of bounds")); + generateErrorMessage(op, "index is out of bounds")); } }; @@ -104,8 +106,10 @@ template <typename OpTy> struct ExtractInsertOpInterface : public RuntimeVerifiableOpInterface::ExternalModel< ExtractInsertOpInterface<OpTy>, OpTy> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { + void + generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, + function_ref<std::string(Operation *, StringRef)> + generateErrorMessage) const { auto extractInsertOp = cast<OpTy>(op); Value tensor; @@ -135,16 +139,17 @@ struct ExtractInsertOpInterface : inBounds; } cf::AssertOp::create(builder, loc, assertCond, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "out-of-bounds access")); + generateErrorMessage(op, "out-of-bounds access")); } }; struct ExtractSliceOpInterface : public RuntimeVerifiableOpInterface::ExternalModel< ExtractSliceOpInterface, ExtractSliceOp> { - void generateRuntimeVerification(Operation *op, OpBuilder &builder, - Location loc) const { + void + generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc, + function_ref<std::string(Operation *, StringRef)> + generateErrorMessage) const { auto extractSliceOp = cast<ExtractSliceOp>(op); RankedTensorType sourceType = extractSliceOp.getSource().getType(); @@ -166,10 +171,10 @@ struct ExtractSliceOpInterface loc, extractSliceOp.getSource(), i); Value offsetInBounds = generateInBoundsCheck(builder, loc, offset, zero, dimSize); - cf::AssertOp::create( - builder, loc, offsetInBounds, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "offset " + std::to_string(i) + " is out-of-bounds")); + cf::AssertOp::create(builder, loc, offsetInBounds, + generateErrorMessage(op, "offset " + + std::to_string(i) + + " is out-of-bounds")); // Verify that slice does not run out-of-bounds. Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one); @@ -181,7 +186,7 @@ struct ExtractSliceOpInterface generateInBoundsCheck(builder, loc, lastPos, zero, dimSize); cf::AssertOp::create( builder, loc, lastPosInBounds, - RuntimeVerifiableOpInterface::generateErrorMessage( + generateErrorMessage( op, "extract_slice runs out-of-bounds along dimension " + std::to_string(i))); } diff --git a/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp index 8aa194b..f9a54f9 100644 --- a/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp +++ b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp @@ -8,31 +8,5 @@ #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" -namespace mlir { -class Location; -class OpBuilder; - -/// Generate an error message string for the given op and the specified error. -std::string -RuntimeVerifiableOpInterface::generateErrorMessage(Operation *op, - const std::string &msg) { - std::string buffer; - llvm::raw_string_ostream stream(buffer); - OpPrintingFlags flags; - // We may generate a lot of error messages and so we need to ensure the - // printing is fast. - flags.elideLargeElementsAttrs(); - flags.printGenericOpForm(); - flags.skipRegions(); - flags.useLocalScope(); - stream << "ERROR: Runtime op verification failed\n"; - op->print(stream, flags); - stream << "\n^ " << msg; - stream << "\nLocation: "; - op->getLoc().print(stream); - return buffer; -} -} // namespace mlir - /// Include the definitions of the interface. #include "mlir/Interfaces/RuntimeVerifiableOpInterface.cpp.inc" diff --git a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp index a40bc2b..63c71cd 100644 --- a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp +++ b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/AsmState.h" #include "mlir/Transforms/Passes.h" #include "mlir/IR/Builders.h" @@ -25,9 +26,51 @@ struct GenerateRuntimeVerificationPass GenerateRuntimeVerificationPass> { void runOnOperation() override; }; + +/// Default error message generator for runtime verification failures. +/// +/// This class generates error messages with different levels of verbosity: +/// - Level 0: Shows only the error message and operation location +/// - Level 1: Shows the full operation string, error message, and location +/// +/// Clients can call getVerboseLevel() to retrieve the current verbose level +/// and use it to customize their own error message generators with similar +/// behavior patterns. +class DefaultErrMsgGenerator { +private: + unsigned vLevel; + AsmState &state; + +public: + DefaultErrMsgGenerator(unsigned verboseLevel, AsmState &asmState) + : vLevel(verboseLevel), state(asmState) {} + + std::string operator()(Operation *op, StringRef msg) { + std::string buffer; + llvm::raw_string_ostream stream(buffer); + stream << "ERROR: Runtime op verification failed\n"; + if (vLevel == 1) { + op->print(stream, state); + stream << "\n"; + } + stream << "^\nLocation: "; + op->getLoc().print(stream); + return buffer; + } + + unsigned getVerboseLevel() const { return vLevel; } +}; } // namespace void GenerateRuntimeVerificationPass::runOnOperation() { + // Check verboseLevel is in range [0, 1]. + if (verboseLevel > 1) { + getOperation()->emitError( + "generate-runtime-verification pass: set verboseLevel to 0 or 1"); + signalPassFailure(); + return; + } + // The implementation of the RuntimeVerifiableOpInterface may create ops that // can be verified. We don't want to generate verification for IR that // performs verification, so gather all runtime-verifiable ops first. @@ -36,10 +79,22 @@ void GenerateRuntimeVerificationPass::runOnOperation() { ops.push_back(verifiableOp); }); + // We may generate a lot of error messages and so we need to ensure the + // printing is fast. + OpPrintingFlags flags; + flags.elideLargeElementsAttrs(); + flags.skipRegions(); + flags.useLocalScope(); + AsmState state(getOperation(), flags); + + // Client can call getVerboseLevel() to fetch verbose level. + DefaultErrMsgGenerator defaultErrMsgGenerator(verboseLevel.getValue(), state); + OpBuilder builder(getOperation()->getContext()); for (RuntimeVerifiableOpInterface verifiableOp : ops) { builder.setInsertionPoint(verifiableOp); - verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc()); + verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc(), + defaultErrMsgGenerator); }; } |