//===- InterpreterPass.cpp - Transform dialect interpreter pass -----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Dialect/Transform/Transforms/Passes.h" #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h" using namespace mlir; namespace mlir { namespace transform { #define GEN_PASS_DEF_INTERPRETERPASS #include "mlir/Dialect/Transform/Transforms/Passes.h.inc" } // namespace transform } // namespace mlir /// Returns the payload operation to be used as payload root: /// - the operation nested under `passRoot` that has the given tag attribute, /// must be unique; /// - the `passRoot` itself if the tag is empty. static Operation *findPayloadRoot(Operation *passRoot, StringRef tag) { // Fast return. if (tag.empty()) return passRoot; // Walk to do a lookup. Operation *target = nullptr; auto tagAttrName = StringAttr::get( passRoot->getContext(), transform::TransformDialect::kTargetTagAttrName); WalkResult walkResult = passRoot->walk([&](Operation *op) { auto attr = op->getAttrOfType(tagAttrName); if (!attr || attr.getValue() != tag) return WalkResult::advance(); if (!target) { target = op; return WalkResult::advance(); } InFlightDiagnostic diag = op->emitError() << "repeated operation with the target tag '" << tag << "'"; diag.attachNote(target->getLoc()) << "previously seen operation"; return WalkResult::interrupt(); }); if (!target) { passRoot->emitError() << "could not find the operation with transform.target_tag=\"" << tag << "\" attribute"; return nullptr; } return walkResult.wasInterrupted() ? nullptr : target; } namespace { class InterpreterPass : public transform::impl::InterpreterPassBase { // Parses the pass arguments to bind trailing arguments of the entry point. std::optional> parseArguments(Operation *payloadRoot) { MLIRContext *context = payloadRoot->getContext(); SmallVector, 2> trailingBindings; trailingBindings.resize(debugBindTrailingArgs.size()); // Construct lists of op names to match. SmallVector> debugBindNames; debugBindNames.reserve(debugBindTrailingArgs.size()); for (auto &&[position, nameString] : llvm::enumerate(debugBindTrailingArgs)) { StringRef name = nameString; // Parse the integer literals. if (name.starts_with("#")) { debugBindNames.push_back(std::nullopt); StringRef lhs = ""; StringRef rhs = name.drop_front(); do { std::tie(lhs, rhs) = rhs.split(';'); int64_t value; if (lhs.getAsInteger(10, value)) { emitError(UnknownLoc::get(context)) << "couldn't parse integer pass argument " << name; return std::nullopt; } trailingBindings[position].push_back( Builder(context).getI64IntegerAttr(value)); } while (!rhs.empty()); } else if (name.starts_with("^")) { debugBindNames.emplace_back(OperationName(name.drop_front(), context)); } else { debugBindNames.emplace_back(OperationName(name, context)); } } // Collect operations or results for extra bindings. payloadRoot->walk([&](Operation *payload) { for (auto &&[position, name] : llvm::enumerate(debugBindNames)) { if (!name || payload->getName() != *name) continue; if (StringRef(*std::next(debugBindTrailingArgs.begin(), position)) .starts_with("^")) { llvm::append_range(trailingBindings[position], payload->getResults()); } else { trailingBindings[position].push_back(payload); } } }); RaggedArray bindings; bindings.push_back(ArrayRef{payloadRoot}); for (SmallVector &trailing : trailingBindings) bindings.push_back(std::move(trailing)); return bindings; } public: using Base::Base; void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp transformModule = transform::detail::getPreloadedTransformModule(context); Operation *payloadRoot = findPayloadRoot(getOperation(), debugPayloadRootTag); if (!payloadRoot) return signalPassFailure(); Operation *transformEntryPoint = transform::detail::findTransformEntryPoint( getOperation(), transformModule, entryPoint); if (!transformEntryPoint) return signalPassFailure(); std::optional> bindings = parseArguments(payloadRoot); if (!bindings) return signalPassFailure(); if (failed(transform::applyTransformNamedSequence( *bindings, cast(transformEntryPoint), transformModule, options.enableExpensiveChecks(!disableExpensiveChecks)))) { return signalPassFailure(); } } private: /// Transform interpreter options. transform::TransformOptions options; }; } // namespace