diff options
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp | 83 |
1 files changed, 80 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp index 8e7af05..abc1316 100644 --- a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp @@ -8,8 +8,8 @@ #include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h" #include "mlir/Dialect/SMT/IR/SMTDialect.h" -#include "mlir/Dialect/Transform/IR/TransformOps.h" -#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h" +#include "mlir/Dialect/SMT/IR/SMTOps.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" using namespace mlir; @@ -23,6 +23,7 @@ using namespace mlir; void transform::smt::ConstrainParamsOp::getEffects( SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { onlyReadsHandle(getParamsMutable(), effects); + producesHandle(getResults(), effects); } DiagnosedSilenceableFailure @@ -37,19 +38,95 @@ transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter, // and allow for users to attach their own implementation, which would, // e.g., translate the ops to SMTLIB and hand that over to the user's // favourite solver. This requires changes to the dialect's verifier. - return emitDefiniteFailure() << "op does not have interpreted semantics yet"; + return emitSilenceableFailure(getLoc()) + << "op does not have interpreted semantics yet"; } LogicalResult transform::smt::ConstrainParamsOp::verify() { + auto yieldTerminator = + dyn_cast<mlir::smt::YieldOp>(getRegion().front().back()); + if (!yieldTerminator) + return emitOpError() << "expected '" + << mlir::smt::YieldOp::getOperationName() + << "' as terminator"; + + auto checkTypes = [](size_t idx, Type smtType, StringRef smtDesc, + Type paramType, StringRef paramDesc, + auto *atOp) -> InFlightDiagnostic { + if (!isa<mlir::smt::BoolType, mlir::smt::IntType, mlir::smt::BitVectorType>( + smtType)) + return atOp->emitOpError() << "the type of " << smtDesc << " #" << idx + << " is expected to be either a !smt.bool, a " + "!smt.int, or a !smt.bv"; + + assert(isa<TransformParamTypeInterface>(paramType) && + "ODS specifies params' type should implement param interface"); + if (isa<transform::AnyParamType>(paramType)) + return {}; // No further checks can be done. + + // NB: This cast must succeed as long as the only implementors of + // TransformParamTypeInterface are AnyParamType and ParamType. + Type typeWrappedByParam = cast<ParamType>(paramType).getType(); + + if (isa<mlir::smt::IntType>(smtType)) { + if (!isa<IntegerType>(typeWrappedByParam)) + return atOp->emitOpError() + << "the type of " << smtDesc << " #" << idx + << " is !smt.int though the corresponding " << paramDesc + << " type (" << paramType << ") is not wrapping an integer type"; + } else if (isa<mlir::smt::BoolType>(smtType)) { + auto wrappedIntType = dyn_cast<IntegerType>(typeWrappedByParam); + if (!wrappedIntType || wrappedIntType.getWidth() != 1) + return atOp->emitOpError() + << "the type of " << smtDesc << " #" << idx + << " is !smt.bool though the corresponding " << paramDesc + << " type (" << paramType << ") is not wrapping i1"; + } else if (auto bvSmtType = dyn_cast<mlir::smt::BitVectorType>(smtType)) { + auto wrappedIntType = dyn_cast<IntegerType>(typeWrappedByParam); + if (!wrappedIntType || wrappedIntType.getWidth() != bvSmtType.getWidth()) + return atOp->emitOpError() + << "the type of " << smtDesc << " #" << idx << " is " << smtType + << " though the corresponding " << paramDesc << " type (" + << paramType + << ") is not wrapping an integer type of the same bitwidth"; + } + + return {}; + }; + if (getOperands().size() != getBody().getNumArguments()) return emitOpError( "must have the same number of block arguments as operands"); + for (auto [idx, operandType, blockArgType] : + llvm::enumerate(getOperandTypes(), getBody().getArgumentTypes())) { + InFlightDiagnostic typeCheckResult = + checkTypes(idx, blockArgType, "block arg", operandType, "operand", + /*atOp=*/this); + if (LogicalResult(typeCheckResult).failed()) + return typeCheckResult; + } + for (auto &op : getBody().getOps()) { if (!isa<mlir::smt::SMTDialect>(op.getDialect())) return emitOpError( "ops contained in region should belong to SMT-dialect"); } + if (yieldTerminator->getNumOperands() != getNumResults()) + return yieldTerminator.emitOpError() + << "expected terminator to have as many operands as the parent op " + "has results"; + + for (auto [idx, termOperandType, resultType] : llvm::enumerate( + yieldTerminator->getOperands().getType(), getResultTypes())) { + InFlightDiagnostic typeCheckResult = + checkTypes(idx, termOperandType, "terminator operand", + cast<transform::ParamType>(resultType), "result", + /*atOp=*/&yieldTerminator); + if (LogicalResult(typeCheckResult).failed()) + return typeCheckResult; + } + return success(); } |