aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp83
1 files changed, 80 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
index 8e7af05..abc1316 100644
--- a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
@@ -8,8 +8,8 @@
#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h"
#include "mlir/Dialect/SMT/IR/SMTDialect.h"
-#include "mlir/Dialect/Transform/IR/TransformOps.h"
-#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
+#include "mlir/Dialect/SMT/IR/SMTOps.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
using namespace mlir;
@@ -23,6 +23,7 @@ using namespace mlir;
void transform::smt::ConstrainParamsOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getParamsMutable(), effects);
+ producesHandle(getResults(), effects);
}
DiagnosedSilenceableFailure
@@ -37,19 +38,95 @@ transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter,
// and allow for users to attach their own implementation, which would,
// e.g., translate the ops to SMTLIB and hand that over to the user's
// favourite solver. This requires changes to the dialect's verifier.
- return emitDefiniteFailure() << "op does not have interpreted semantics yet";
+ return emitSilenceableFailure(getLoc())
+ << "op does not have interpreted semantics yet";
}
LogicalResult transform::smt::ConstrainParamsOp::verify() {
+ auto yieldTerminator =
+ dyn_cast<mlir::smt::YieldOp>(getRegion().front().back());
+ if (!yieldTerminator)
+ return emitOpError() << "expected '"
+ << mlir::smt::YieldOp::getOperationName()
+ << "' as terminator";
+
+ auto checkTypes = [](size_t idx, Type smtType, StringRef smtDesc,
+ Type paramType, StringRef paramDesc,
+ auto *atOp) -> InFlightDiagnostic {
+ if (!isa<mlir::smt::BoolType, mlir::smt::IntType, mlir::smt::BitVectorType>(
+ smtType))
+ return atOp->emitOpError() << "the type of " << smtDesc << " #" << idx
+ << " is expected to be either a !smt.bool, a "
+ "!smt.int, or a !smt.bv";
+
+ assert(isa<TransformParamTypeInterface>(paramType) &&
+ "ODS specifies params' type should implement param interface");
+ if (isa<transform::AnyParamType>(paramType))
+ return {}; // No further checks can be done.
+
+ // NB: This cast must succeed as long as the only implementors of
+ // TransformParamTypeInterface are AnyParamType and ParamType.
+ Type typeWrappedByParam = cast<ParamType>(paramType).getType();
+
+ if (isa<mlir::smt::IntType>(smtType)) {
+ if (!isa<IntegerType>(typeWrappedByParam))
+ return atOp->emitOpError()
+ << "the type of " << smtDesc << " #" << idx
+ << " is !smt.int though the corresponding " << paramDesc
+ << " type (" << paramType << ") is not wrapping an integer type";
+ } else if (isa<mlir::smt::BoolType>(smtType)) {
+ auto wrappedIntType = dyn_cast<IntegerType>(typeWrappedByParam);
+ if (!wrappedIntType || wrappedIntType.getWidth() != 1)
+ return atOp->emitOpError()
+ << "the type of " << smtDesc << " #" << idx
+ << " is !smt.bool though the corresponding " << paramDesc
+ << " type (" << paramType << ") is not wrapping i1";
+ } else if (auto bvSmtType = dyn_cast<mlir::smt::BitVectorType>(smtType)) {
+ auto wrappedIntType = dyn_cast<IntegerType>(typeWrappedByParam);
+ if (!wrappedIntType || wrappedIntType.getWidth() != bvSmtType.getWidth())
+ return atOp->emitOpError()
+ << "the type of " << smtDesc << " #" << idx << " is " << smtType
+ << " though the corresponding " << paramDesc << " type ("
+ << paramType
+ << ") is not wrapping an integer type of the same bitwidth";
+ }
+
+ return {};
+ };
+
if (getOperands().size() != getBody().getNumArguments())
return emitOpError(
"must have the same number of block arguments as operands");
+ for (auto [idx, operandType, blockArgType] :
+ llvm::enumerate(getOperandTypes(), getBody().getArgumentTypes())) {
+ InFlightDiagnostic typeCheckResult =
+ checkTypes(idx, blockArgType, "block arg", operandType, "operand",
+ /*atOp=*/this);
+ if (LogicalResult(typeCheckResult).failed())
+ return typeCheckResult;
+ }
+
for (auto &op : getBody().getOps()) {
if (!isa<mlir::smt::SMTDialect>(op.getDialect()))
return emitOpError(
"ops contained in region should belong to SMT-dialect");
}
+ if (yieldTerminator->getNumOperands() != getNumResults())
+ return yieldTerminator.emitOpError()
+ << "expected terminator to have as many operands as the parent op "
+ "has results";
+
+ for (auto [idx, termOperandType, resultType] : llvm::enumerate(
+ yieldTerminator->getOperands().getType(), getResultTypes())) {
+ InFlightDiagnostic typeCheckResult =
+ checkTypes(idx, termOperandType, "terminator operand",
+ cast<transform::ParamType>(resultType), "result",
+ /*atOp=*/&yieldTerminator);
+ if (LogicalResult(typeCheckResult).failed())
+ return typeCheckResult;
+ }
+
return success();
}