//===- RuntimeOpVerification.cpp - Op Verification ------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tensor/Transforms/RuntimeOpVerification.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" using namespace mlir; namespace mlir { namespace tensor { namespace { /// Generate a runtime check for lb <= value < ub. Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value, Value lb, Value ub) { Value inBounds1 = builder.createOrFold( loc, arith::CmpIPredicate::sge, value, lb); Value inBounds2 = builder.createOrFold( loc, arith::CmpIPredicate::slt, value, ub); Value inBounds = builder.createOrFold(loc, inBounds1, inBounds2); return inBounds; } struct CastOpInterface : public RuntimeVerifiableOpInterface::ExternalModel { void generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc) const { auto castOp = cast(op); auto srcType = cast(castOp.getSource().getType()); // Nothing to check if the result is an unranked tensor. auto resultType = dyn_cast(castOp.getType()); if (!resultType) return; if (isa(srcType)) { // Check rank. Value srcRank = RankOp::create(builder, loc, castOp.getSource()); Value resultRank = arith::ConstantIndexOp::create(builder, loc, resultType.getRank()); Value isSameRank = arith::CmpIOp::create( builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank); cf::AssertOp::create(builder, loc, isSameRank, RuntimeVerifiableOpInterface::generateErrorMessage( op, "rank mismatch")); } // Check dimension sizes. for (const auto &it : llvm::enumerate(resultType.getShape())) { // Static dim size -> static/dynamic dim size does not need verification. if (auto rankedSrcType = dyn_cast(srcType)) if (!rankedSrcType.isDynamicDim(it.index())) continue; // Static/dynamic dim size -> dynamic dim size does not need verification. if (resultType.isDynamicDim(it.index())) continue; Value srcDimSz = DimOp::create(builder, loc, castOp.getSource(), it.index()); Value resultDimSz = arith::ConstantIndexOp::create(builder, loc, it.value()); Value isSameSz = arith::CmpIOp::create( builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz); cf::AssertOp::create( builder, loc, isSameSz, RuntimeVerifiableOpInterface::generateErrorMessage( op, "size mismatch of dim " + std::to_string(it.index()))); } } }; struct DimOpInterface : public RuntimeVerifiableOpInterface::ExternalModel { void generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc) const { auto dimOp = cast(op); Value rank = RankOp::create(builder, loc, dimOp.getSource()); Value zero = arith::ConstantIndexOp::create(builder, loc, 0); cf::AssertOp::create( builder, loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank), RuntimeVerifiableOpInterface::generateErrorMessage( op, "index is out of bounds")); } }; /// Verifies that the indices on extract/insert ops are in-bounds of the /// tensor's index space: 0 <= index#i < dim#i template struct ExtractInsertOpInterface : public RuntimeVerifiableOpInterface::ExternalModel< ExtractInsertOpInterface, OpTy> { void generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc) const { auto extractInsertOp = cast(op); Value tensor; if constexpr (std::is_same_v) { tensor = extractInsertOp.getTensor(); } else if constexpr (std::is_same_v) { tensor = extractInsertOp.getDest(); } else { llvm_unreachable("invalid op"); } auto tensorType = cast(tensor.getType()); auto rank = tensorType.getRank(); if (rank == 0) { // Nothing to check for 0-d tensors. return; } auto indices = extractInsertOp.getIndices(); auto zero = arith::ConstantIndexOp::create(builder, loc, 0); Value assertCond; for (auto i : llvm::seq(0, rank)) { Value dimOp = builder.createOrFold(loc, tensor, i); Value inBounds = generateInBoundsCheck(builder, loc, indices[i], zero, dimOp); assertCond = i > 0 ? builder.createOrFold(loc, assertCond, inBounds) : inBounds; } cf::AssertOp::create(builder, loc, assertCond, RuntimeVerifiableOpInterface::generateErrorMessage( op, "out-of-bounds access")); } }; struct ExtractSliceOpInterface : public RuntimeVerifiableOpInterface::ExternalModel< ExtractSliceOpInterface, ExtractSliceOp> { void generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc) const { auto extractSliceOp = cast(op); RankedTensorType sourceType = extractSliceOp.getSource().getType(); // For each dimension, assert that: // 0 <= offset < dim_size // 0 <= offset + (size - 1) * stride < dim_size Value zero = arith::ConstantIndexOp::create(builder, loc, 0); Value one = arith::ConstantIndexOp::create(builder, loc, 1); for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) { Value offset = getValueOrCreateConstantIndexOp( builder, loc, extractSliceOp.getMixedOffsets()[i]); Value size = getValueOrCreateConstantIndexOp( builder, loc, extractSliceOp.getMixedSizes()[i]); Value stride = getValueOrCreateConstantIndexOp( builder, loc, extractSliceOp.getMixedStrides()[i]); // Verify that offset is in-bounds. Value dimSize = builder.createOrFold( loc, extractSliceOp.getSource(), i); Value offsetInBounds = generateInBoundsCheck(builder, loc, offset, zero, dimSize); cf::AssertOp::create( builder, loc, offsetInBounds, RuntimeVerifiableOpInterface::generateErrorMessage( op, "offset " + std::to_string(i) + " is out-of-bounds")); // Verify that slice does not run out-of-bounds. Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one); Value sizeMinusOneTimesStride = arith::MulIOp::create(builder, loc, sizeMinusOne, stride); Value lastPos = arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride); Value lastPosInBounds = generateInBoundsCheck(builder, loc, lastPos, zero, dimSize); cf::AssertOp::create( builder, loc, lastPosInBounds, RuntimeVerifiableOpInterface::generateErrorMessage( op, "extract_slice runs out-of-bounds along dimension " + std::to_string(i))); } } }; } // namespace } // namespace tensor } // namespace mlir void mlir::tensor::registerRuntimeVerifiableOpInterfaceExternalModels( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) { CastOp::attachInterface(*ctx); DimOp::attachInterface(*ctx); ExtractOp::attachInterface>(*ctx); ExtractSliceOp::attachInterface(*ctx); InsertOp::attachInterface>(*ctx); // Load additional dialects of which ops may get created. ctx->loadDialect(); }); }