diff options
Diffstat (limited to 'mlir/lib/Rewrite/ByteCode.cpp')
-rw-r--r-- | mlir/lib/Rewrite/ByteCode.cpp | 748 |
1 files changed, 638 insertions, 110 deletions
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index ef96e25..ea17f99 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -20,6 +20,9 @@ #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/Format.h" +#include "llvm/Support/FormatVariadic.h" +#include <numeric> #define DEBUG_TYPE "pdl-bytecode" @@ -60,6 +63,14 @@ void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, currentPatternBenefits[patternIndex] = benefit; } +/// Cleanup any allocated state after a full match/rewrite has been completed. +/// This method should be called irregardless of whether the match+rewrite was a +/// success or not. +void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() { + allocatedTypeRangeMemory.clear(); + allocatedValueRangeMemory.clear(); +} + //===----------------------------------------------------------------------===// // Bytecode OpCodes //===----------------------------------------------------------------------===// @@ -72,6 +83,8 @@ enum OpCode : ByteCodeField { ApplyRewrite, /// Check if two generic values are equal. AreEqual, + /// Check if two ranges are equal. + AreRangesEqual, /// Unconditional branch. Branch, /// Compare the operand count of an operation with a constant. @@ -80,8 +93,12 @@ enum OpCode : ByteCodeField { CheckOperationName, /// Compare the result count of an operation with a constant. CheckResultCount, + /// Compare a range of types to a constant range of types. + CheckTypes, /// Create an operation. CreateOperation, + /// Create a range of types. + CreateTypes, /// Erase an operation. EraseOp, /// Terminate a matcher or rewrite sequence. @@ -98,14 +115,20 @@ enum OpCode : ByteCodeField { GetOperand2, GetOperand3, GetOperandN, + /// Get a specific operand group of an operation. + GetOperands, /// Get a specific result of an operation. GetResult0, GetResult1, GetResult2, GetResult3, GetResultN, + /// Get a specific result group of an operation. + GetResults, /// Get the type of a value. GetValueType, + /// Get the types of a value range. + GetValueRangeTypes, /// Check if a generic value is not null. IsNotNull, /// Record a successful pattern match. @@ -122,9 +145,9 @@ enum OpCode : ByteCodeField { SwitchResultCount, /// Compare a type with a set of constants. SwitchType, + /// Compare a range of types with a set of constants. + SwitchTypes, }; - -enum class PDLValueKind { Attribute, Operation, Type, Value }; } // end anonymous namespace //===----------------------------------------------------------------------===// @@ -145,11 +168,15 @@ public: SmallVectorImpl<ByteCodeField> &rewriterByteCode, SmallVectorImpl<PDLByteCodePattern> &patterns, ByteCodeField &maxValueMemoryIndex, + ByteCodeField &maxTypeRangeMemoryIndex, + ByteCodeField &maxValueRangeMemoryIndex, llvm::StringMap<PDLConstraintFunction> &constraintFns, llvm::StringMap<PDLRewriteFunction> &rewriteFns) : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), rewriterByteCode(rewriterByteCode), patterns(patterns), - maxValueMemoryIndex(maxValueMemoryIndex) { + maxValueMemoryIndex(maxValueMemoryIndex), + maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), + maxValueRangeMemoryIndex(maxValueRangeMemoryIndex) { for (auto it : llvm::enumerate(constraintFns)) constraintToMemIndex.try_emplace(it.value().first(), it.index()); for (auto it : llvm::enumerate(rewriteFns)) @@ -166,6 +193,13 @@ public: return valueToMemIndex[value]; } + /// Return the range memory index used to store the given range value. + ByteCodeField &getRangeStorageIndex(Value value) { + assert(valueToRangeIndex.count(value) && + "expected range index to be assigned"); + return valueToRangeIndex[value]; + } + /// Return an index to use when referring to the given data that is uniqued in /// the MLIR context. template <typename T> @@ -197,16 +231,20 @@ private: void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); + void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); + void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer); void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); + void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); + void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer); void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); @@ -214,6 +252,7 @@ private: void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); + void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer); void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); @@ -221,6 +260,9 @@ private: /// Mapping from value to its corresponding memory index. DenseMap<Value, ByteCodeField> valueToMemIndex; + /// Mapping from a range value to its corresponding range storage index. + DenseMap<Value, ByteCodeField> valueToRangeIndex; + /// Mapping from the name of an externally registered rewrite to its index in /// the bytecode registry. llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; @@ -246,6 +288,8 @@ private: SmallVectorImpl<ByteCodeField> &rewriterByteCode; SmallVectorImpl<PDLByteCodePattern> &patterns; ByteCodeField &maxValueMemoryIndex; + ByteCodeField &maxTypeRangeMemoryIndex; + ByteCodeField &maxValueRangeMemoryIndex; }; /// This class provides utilities for writing a bytecode stream. @@ -281,19 +325,33 @@ struct ByteCodeWriter { /// Append a range of values that will be read as generic PDLValues. void appendPDLValueList(OperandRange values) { bytecode.push_back(values.size()); - for (Value value : values) { - // Append the type of the value in addition to the value itself. - PDLValueKind kind = - TypeSwitch<Type, PDLValueKind>(value.getType()) - .Case<pdl::AttributeType>( - [](Type) { return PDLValueKind::Attribute; }) - .Case<pdl::OperationType>( - [](Type) { return PDLValueKind::Operation; }) - .Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; }) - .Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; }); - bytecode.push_back(static_cast<ByteCodeField>(kind)); - append(value); - } + for (Value value : values) + appendPDLValue(value); + } + + /// Append a value as a PDLValue. + void appendPDLValue(Value value) { + appendPDLValueKind(value); + append(value); + } + + /// Append the PDLValue::Kind of the given value. + void appendPDLValueKind(Value value) { + // Append the type of the value in addition to the value itself. + PDLValue::Kind kind = + TypeSwitch<Type, PDLValue::Kind>(value.getType()) + .Case<pdl::AttributeType>( + [](Type) { return PDLValue::Kind::Attribute; }) + .Case<pdl::OperationType>( + [](Type) { return PDLValue::Kind::Operation; }) + .Case<pdl::RangeType>([](pdl::RangeType rangeTy) { + if (rangeTy.getElementType().isa<pdl::TypeType>()) + return PDLValue::Kind::TypeRange; + return PDLValue::Kind::ValueRange; + }) + .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; }) + .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; }); + bytecode.push_back(static_cast<ByteCodeField>(kind)); } /// Check if the given class `T` has an iterator type. @@ -334,6 +392,36 @@ struct ByteCodeWriter { /// The main generator producing PDL. Generator &generator; }; + +/// This class represents a live range of PDL Interpreter values, containing +/// information about when values are live within a match/rewrite. +struct ByteCodeLiveRange { + using Set = llvm::IntervalMap<ByteCodeField, char, 16>; + using Allocator = Set::Allocator; + + ByteCodeLiveRange(Allocator &alloc) : liveness(alloc) {} + + /// Union this live range with the one provided. + void unionWith(const ByteCodeLiveRange &rhs) { + for (auto it = rhs.liveness.begin(), e = rhs.liveness.end(); it != e; ++it) + liveness.insert(it.start(), it.stop(), /*dummyValue*/ 0); + } + + /// Returns true if this range overlaps with the one provided. + bool overlaps(const ByteCodeLiveRange &rhs) const { + return llvm::IntervalMapOverlaps<Set, Set>(liveness, rhs.liveness).valid(); + } + + /// A map representing the ranges of the match/rewrite that a value is live in + /// the interpreter. + llvm::IntervalMap<ByteCodeField, char, 16> liveness; + + /// The type range storage index for this range. + Optional<unsigned> typeRangeIndex; + + /// The value range storage index for this range. + Optional<unsigned> valueRangeIndex; +}; } // end anonymous namespace void Generator::generate(ModuleOp module) { @@ -381,15 +469,30 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc, // Rewriters use simplistic allocation scheme that simply assigns an index to // each result. for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { - ByteCodeField index = 0; + ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0; + auto processRewriterValue = [&](Value val) { + valueToMemIndex.try_emplace(val, index++); + if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) { + Type elementTy = rangeType.getElementType(); + if (elementTy.isa<pdl::TypeType>()) + valueToRangeIndex.try_emplace(val, typeRangeIndex++); + else if (elementTy.isa<pdl::ValueType>()) + valueToRangeIndex.try_emplace(val, valueRangeIndex++); + } + }; + for (BlockArgument arg : rewriterFunc.getArguments()) - valueToMemIndex.try_emplace(arg, index++); + processRewriterValue(arg); rewriterFunc.getBody().walk([&](Operation *op) { for (Value result : op->getResults()) - valueToMemIndex.try_emplace(result, index++); + processRewriterValue(result); }); if (index > maxValueMemoryIndex) maxValueMemoryIndex = index; + if (typeRangeIndex > maxTypeRangeMemoryIndex) + maxTypeRangeMemoryIndex = typeRangeIndex; + if (valueRangeIndex > maxValueRangeMemoryIndex) + maxValueRangeMemoryIndex = valueRangeIndex; } // The matcher function uses a more sophisticated numbering that tries to @@ -404,9 +507,8 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc, }); // Liveness info for each of the defs within the matcher. - using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>; - LivenessSet::Allocator allocator; - DenseMap<Value, LivenessSet> valueDefRanges; + ByteCodeLiveRange::Allocator allocator; + DenseMap<Value, ByteCodeLiveRange> valueDefRanges; // Assign the root operation being matched to slot 0. BlockArgument rootOpArg = matcherFunc.getArgument(0); @@ -425,10 +527,19 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc, // Set indices for the range of this block that the value is used. auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; - defRangeIt->second.insert( + defRangeIt->second.liveness.insert( opToIndex[firstUseOrDef], opToIndex[info->getEndOperation(value, firstUseOrDef)], /*dummyValue*/ 0); + + // Check to see if this value is a range type. + if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) { + Type eleType = rangeTy.getElementType(); + if (eleType.isa<pdl::TypeType>()) + defRangeIt->second.typeRangeIndex = 0; + else if (eleType.isa<pdl::ValueType>()) + defRangeIt->second.valueRangeIndex = 0; + } }; // Process the live-ins of this block. @@ -442,37 +553,59 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc, } // Greedily allocate memory slots using the computed def live ranges. - std::vector<LivenessSet> allocatedIndices; + std::vector<ByteCodeLiveRange> allocatedIndices; + ByteCodeField numIndices = 1, numTypeRanges = 0, numValueRanges = 0; for (auto &defIt : valueDefRanges) { ByteCodeField &memIndex = valueToMemIndex[defIt.first]; - LivenessSet &defSet = defIt.second; + ByteCodeLiveRange &defRange = defIt.second; // Try to allocate to an existing index. for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) { - LivenessSet &existingIndex = existingIndexIt.value(); - llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps( - defIt.second, existingIndex); - if (overlaps.valid()) - continue; - // Union the range of the def within the existing index. - for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it) - existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0); - memIndex = existingIndexIt.index() + 1; + ByteCodeLiveRange &existingRange = existingIndexIt.value(); + if (!defRange.overlaps(existingRange)) { + existingRange.unionWith(defRange); + memIndex = existingIndexIt.index() + 1; + + if (defRange.typeRangeIndex) { + if (!existingRange.typeRangeIndex) + existingRange.typeRangeIndex = numTypeRanges++; + valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex; + } else if (defRange.valueRangeIndex) { + if (!existingRange.valueRangeIndex) + existingRange.valueRangeIndex = numValueRanges++; + valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex; + } + break; + } } // If no existing index could be used, add a new one. if (memIndex == 0) { allocatedIndices.emplace_back(allocator); - for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it) - allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0); + ByteCodeLiveRange &newRange = allocatedIndices.back(); + newRange.unionWith(defRange); + + // Allocate an index for type/value ranges. + if (defRange.typeRangeIndex) { + newRange.typeRangeIndex = numTypeRanges; + valueToRangeIndex[defIt.first] = numTypeRanges++; + } else if (defRange.valueRangeIndex) { + newRange.valueRangeIndex = numValueRanges; + valueToRangeIndex[defIt.first] = numValueRanges++; + } + memIndex = allocatedIndices.size(); + ++numIndices; } } // Update the max number of indices. - ByteCodeField numMatcherIndices = allocatedIndices.size() + 1; - if (numMatcherIndices > maxValueMemoryIndex) - maxValueMemoryIndex = numMatcherIndices; + if (numIndices > maxValueMemoryIndex) + maxValueMemoryIndex = numIndices; + if (numTypeRanges > maxTypeRangeMemoryIndex) + maxTypeRangeMemoryIndex = numTypeRanges; + if (numValueRanges > maxValueRangeMemoryIndex) + maxValueRangeMemoryIndex = numValueRanges; } void Generator::generate(Operation *op, ByteCodeWriter &writer) { @@ -481,17 +614,19 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) { pdl_interp::AreEqualOp, pdl_interp::BranchOp, pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp, pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, - pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp, - pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp, + pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp, + pdl_interp::CreateAttributeOp, pdl_interp::CreateOperationOp, + pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp, pdl_interp::EraseOp, pdl_interp::FinalizeOp, pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp, - pdl_interp::GetResultOp, pdl_interp::GetValueTypeOp, + pdl_interp::GetOperandsOp, pdl_interp::GetResultOp, + pdl_interp::GetResultsOp, pdl_interp::GetValueTypeOp, pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp, - pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp, - pdl_interp::SwitchResultCountOp>( + pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp, + pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>( [&](auto interpOp) { this->generate(interpOp, writer); }) .Default([](Operation *) { llvm_unreachable("unknown `pdl_interp` operation"); @@ -515,16 +650,31 @@ void Generator::generate(pdl_interp::ApplyRewriteOp op, op.constParamsAttr()); writer.appendPDLValueList(op.args()); + ResultRange results = op.results(); + writer.append(ByteCodeField(results.size())); + for (Value result : results) { + // In debug mode we also record the expected kind of the result, so that we + // can provide extra verification of the native rewrite function. #ifndef NDEBUG - // In debug mode we also append the number of results so that we can assert - // that the native creation function gave us the correct number of results. - writer.append(ByteCodeField(op.results().size())); + writer.appendPDLValueKind(result); #endif - for (Value result : op.results()) + + // Range results also need to append the range storage index. + if (result.getType().isa<pdl::RangeType>()) + writer.append(getRangeStorageIndex(result)); writer.append(result); + } } void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { - writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors()); + Value lhs = op.lhs(); + if (lhs.getType().isa<pdl::RangeType>()) { + writer.append(OpCode::AreRangesEqual); + writer.appendPDLValueKind(lhs); + writer.append(op.lhs(), op.rhs(), op.getSuccessors()); + return; + } + + writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors()); } void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) { writer.append(OpCode::Branch, SuccessorRange(op.getOperation())); @@ -537,6 +687,7 @@ void Generator::generate(pdl_interp::CheckAttributeOp op, void Generator::generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer) { writer.append(OpCode::CheckOperandCount, op.operation(), op.count(), + static_cast<ByteCodeField>(op.compareAtLeast()), op.getSuccessors()); } void Generator::generate(pdl_interp::CheckOperationNameOp op, @@ -547,11 +698,15 @@ void Generator::generate(pdl_interp::CheckOperationNameOp op, void Generator::generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer) { writer.append(OpCode::CheckResultCount, op.operation(), op.count(), + static_cast<ByteCodeField>(op.compareAtLeast()), op.getSuccessors()); } void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) { writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors()); } +void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) { + writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors()); +} void Generator::generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer) { // Simply repoint the memory index of the result to the constant. @@ -560,7 +715,8 @@ void Generator::generate(pdl_interp::CreateAttributeOp op, void Generator::generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer) { writer.append(OpCode::CreateOperation, op.operation(), - OperationName(op.name(), ctx), op.operands()); + OperationName(op.name(), ctx)); + writer.appendPDLValueList(op.operands()); // Add the attributes. OperandRange attributes = op.attributes(); @@ -570,12 +726,16 @@ void Generator::generate(pdl_interp::CreateOperationOp op, Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx), std::get<1>(it)); } - writer.append(op.types()); + writer.appendPDLValueList(op.types()); } void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { // Simply repoint the memory index of the result to the constant. getMemIndex(op.result()) = getMemIndex(op.value()); } +void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) { + writer.append(OpCode::CreateTypes, op.result(), + getRangeStorageIndex(op.result()), op.value()); +} void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { writer.append(OpCode::EraseOp, op.operation()); } @@ -593,7 +753,8 @@ void Generator::generate(pdl_interp::GetAttributeTypeOp op, } void Generator::generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer) { - writer.append(OpCode::GetDefiningOp, op.operation(), op.value()); + writer.append(OpCode::GetDefiningOp, op.operation()); + writer.appendPDLValue(op.value()); } void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) { uint32_t index = op.index(); @@ -603,6 +764,18 @@ void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) { writer.append(OpCode::GetOperandN, index); writer.append(op.operation(), op.value()); } +void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) { + Value result = op.value(); + Optional<uint32_t> index = op.index(); + writer.append(OpCode::GetOperands, + index.getValueOr(std::numeric_limits<uint32_t>::max()), + op.operation()); + if (result.getType().isa<pdl::RangeType>()) + writer.append(getRangeStorageIndex(result)); + else + writer.append(std::numeric_limits<ByteCodeField>::max()); + writer.append(result); +} void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) { uint32_t index = op.index(); if (index < 4) @@ -611,10 +784,29 @@ void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) { writer.append(OpCode::GetResultN, index); writer.append(op.operation(), op.value()); } +void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) { + Value result = op.value(); + Optional<uint32_t> index = op.index(); + writer.append(OpCode::GetResults, + index.getValueOr(std::numeric_limits<uint32_t>::max()), + op.operation()); + if (result.getType().isa<pdl::RangeType>()) + writer.append(getRangeStorageIndex(result)); + else + writer.append(std::numeric_limits<ByteCodeField>::max()); + writer.append(result); +} void Generator::generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer) { - writer.append(OpCode::GetValueType, op.result(), op.value()); + if (op.getType().isa<pdl::RangeType>()) { + Value result = op.result(); + writer.append(OpCode::GetValueRangeTypes, result, + getRangeStorageIndex(result), op.value()); + } else { + writer.append(OpCode::GetValueType, op.result(), op.value()); + } } + void Generator::generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer) { // InferType maps to a null type as a marker for inferring result types. @@ -628,11 +820,12 @@ void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { patterns.emplace_back(PDLByteCodePattern::create( op, rewriterToAddr[op.rewriter().getLeafReference()])); writer.append(OpCode::RecordMatch, patternIndex, - SuccessorRange(op.getOperation()), op.matchedOps(), - op.inputs()); + SuccessorRange(op.getOperation()), op.matchedOps()); + writer.appendPDLValueList(op.inputs()); } void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) { - writer.append(OpCode::ReplaceOp, op.operation(), op.replValues()); + writer.append(OpCode::ReplaceOp, op.operation()); + writer.appendPDLValueList(op.replValues()); } void Generator::generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer) { @@ -661,6 +854,10 @@ void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) { writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(), op.getSuccessors()); } +void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) { + writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(), + op.getSuccessors()); +} //===----------------------------------------------------------------------===// // PDLByteCode @@ -671,7 +868,8 @@ PDLByteCode::PDLByteCode(ModuleOp module, llvm::StringMap<PDLRewriteFunction> rewriteFns) { Generator generator(module.getContext(), uniquedData, matcherByteCode, rewriterByteCode, patterns, maxValueMemoryIndex, - constraintFns, rewriteFns); + maxTypeRangeCount, maxValueRangeCount, constraintFns, + rewriteFns); generator.generate(module); // Initialize the external functions. @@ -685,6 +883,8 @@ PDLByteCode::PDLByteCode(ModuleOp module, /// bytecode. void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { state.memory.resize(maxValueMemoryIndex, nullptr); + state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange()); + state.valueRangeMemory.resize(maxValueRangeCount, ValueRange()); state.currentPatternBenefits.reserve(patterns.size()); for (const PDLByteCodePattern &pattern : patterns) state.currentPatternBenefits.push_back(pattern.getBenefit()); @@ -697,17 +897,24 @@ namespace { /// This class provides support for executing a bytecode stream. class ByteCodeExecutor { public: - ByteCodeExecutor(const ByteCodeField *curCodeIt, - MutableArrayRef<const void *> memory, - ArrayRef<const void *> uniquedMemory, - ArrayRef<ByteCodeField> code, - ArrayRef<PatternBenefit> currentPatternBenefits, - ArrayRef<PDLByteCodePattern> patterns, - ArrayRef<PDLConstraintFunction> constraintFunctions, - ArrayRef<PDLRewriteFunction> rewriteFunctions) - : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory), - code(code), currentPatternBenefits(currentPatternBenefits), - patterns(patterns), constraintFunctions(constraintFunctions), + ByteCodeExecutor( + const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory, + MutableArrayRef<TypeRange> typeRangeMemory, + std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory, + MutableArrayRef<ValueRange> valueRangeMemory, + std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory, + ArrayRef<const void *> uniquedMemory, ArrayRef<ByteCodeField> code, + ArrayRef<PatternBenefit> currentPatternBenefits, + ArrayRef<PDLByteCodePattern> patterns, + ArrayRef<PDLConstraintFunction> constraintFunctions, + ArrayRef<PDLRewriteFunction> rewriteFunctions) + : curCodeIt(curCodeIt), memory(memory), typeRangeMemory(typeRangeMemory), + allocatedTypeRangeMemory(allocatedTypeRangeMemory), + valueRangeMemory(valueRangeMemory), + allocatedValueRangeMemory(allocatedValueRangeMemory), + uniquedMemory(uniquedMemory), code(code), + currentPatternBenefits(currentPatternBenefits), patterns(patterns), + constraintFunctions(constraintFunctions), rewriteFunctions(rewriteFunctions) {} /// Start executing the code at the current bytecode index. `matches` is an @@ -722,19 +929,25 @@ private: void executeApplyConstraint(PatternRewriter &rewriter); void executeApplyRewrite(PatternRewriter &rewriter); void executeAreEqual(); + void executeAreRangesEqual(); void executeBranch(); void executeCheckOperandCount(); void executeCheckOperationName(); void executeCheckResultCount(); + void executeCheckTypes(); void executeCreateOperation(PatternRewriter &rewriter, Location mainRewriteLoc); + void executeCreateTypes(); void executeEraseOp(PatternRewriter &rewriter); void executeGetAttribute(); void executeGetAttributeType(); void executeGetDefiningOp(); void executeGetOperand(unsigned index); + void executeGetOperands(); void executeGetResult(unsigned index); + void executeGetResults(); void executeGetValueType(); + void executeGetValueRangeTypes(); void executeIsNotNull(); void executeRecordMatch(PatternRewriter &rewriter, SmallVectorImpl<PDLByteCode::MatchResult> &matches); @@ -744,6 +957,7 @@ private: void executeSwitchOperationName(); void executeSwitchResultCount(); void executeSwitchType(); + void executeSwitchTypes(); /// Read a value from the bytecode buffer, optionally skipping a certain /// number of prefix values. These methods always update the buffer to point @@ -763,6 +977,19 @@ private: list.push_back(read<ValueT>()); } + /// Read a list of values from the bytecode buffer. The values may be encoded + /// as either Value or ValueRange elements. + void readValueList(SmallVectorImpl<Value> &list) { + for (unsigned i = 0, e = read(); i != e; ++i) { + if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { + list.push_back(read<Value>()); + } else { + ValueRange *values = read<ValueRange *>(); + list.append(values->begin(), values->end()); + } + } + } + /// Jump to a specific successor based on a predicate value. void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } /// Jump to a specific successor based on a destination index. @@ -771,8 +998,8 @@ private: } /// Handle a switch operation with the provided value and cases. - template <typename T, typename RangeT> - void handleSwitch(const T &value, RangeT &&cases) { + template <typename T, typename RangeT, typename Comparator = std::equal_to<T>> + void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) { LLVM_DEBUG({ llvm::dbgs() << " * Value: " << value << "\n" << " * Cases: "; @@ -783,7 +1010,7 @@ private: // Check to see if the attribute value is within the case list. Jump to // the correct successor index based on the result. for (auto it = cases.begin(), e = cases.end(); it != e; ++it) - if (*it == value) + if (cmp(*it, value)) return selectJump(size_t((it - cases.begin()) + 1)); selectJump(size_t(0)); } @@ -795,7 +1022,9 @@ private: size_t index = *curCodeIt++; // If this type is an SSA value, it can only be stored in non-const memory. - if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size()) + if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *, + Value>::value || + index < memory.size()) return memory[index]; // Otherwise, if this index is not inbounds it is uniqued. @@ -813,17 +1042,21 @@ private: } template <typename T> std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { - switch (static_cast<PDLValueKind>(read())) { - case PDLValueKind::Attribute: + switch (read<PDLValue::Kind>()) { + case PDLValue::Kind::Attribute: return read<Attribute>(); - case PDLValueKind::Operation: + case PDLValue::Kind::Operation: return read<Operation *>(); - case PDLValueKind::Type: + case PDLValue::Kind::Type: return read<Type>(); - case PDLValueKind::Value: + case PDLValue::Kind::Value: return read<Value>(); + case PDLValue::Kind::TypeRange: + return read<TypeRange *>(); + case PDLValue::Kind::ValueRange: + return read<ValueRange *>(); } - llvm_unreachable("unhandled PDLValueKind"); + llvm_unreachable("unhandled PDLValue::Kind"); } template <typename T> std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { @@ -838,12 +1071,20 @@ private: std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { return *curCodeIt++; } + template <typename T> + std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() { + return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>()); + } /// The underlying bytecode buffer. const ByteCodeField *curCodeIt; /// The current execution memory. MutableArrayRef<const void *> memory; + MutableArrayRef<TypeRange> typeRangeMemory; + std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory; + MutableArrayRef<ValueRange> valueRangeMemory; + std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory; /// References to ByteCode data necessary for execution. ArrayRef<const void *> uniquedMemory; @@ -859,8 +1100,21 @@ private: /// overexposing access to information specific solely to the ByteCode. class ByteCodeRewriteResultList : public PDLResultList { public: + ByteCodeRewriteResultList(unsigned maxNumResults) + : PDLResultList(maxNumResults) {} + /// Return the list of PDL results. MutableArrayRef<PDLValue> getResults() { return results; } + + /// Return the type ranges allocated by this list. + MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() { + return allocatedTypeRanges; + } + + /// Return the value ranges allocated by this list. + MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() { + return allocatedValueRanges; + } }; } // end anonymous namespace @@ -893,21 +1147,46 @@ void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { llvm::interleaveComma(args, llvm::dbgs()); llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; }); - ByteCodeRewriteResultList results; + + // Execute the rewrite function. + ByteCodeField numResults = read(); + ByteCodeRewriteResultList results(numResults); rewriteFn(args, constParams, rewriter, results); - // Store the results in the bytecode memory. -#ifndef NDEBUG - ByteCodeField expectedNumberOfResults = read(); - assert(results.getResults().size() == expectedNumberOfResults && + assert(results.getResults().size() == numResults && "native PDL rewrite function returned unexpected number of results"); -#endif // Store the results in the bytecode memory. for (PDLValue &result : results.getResults()) { LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); - memory[read()] = result.getAsOpaquePointer(); + +// In debug mode we also verify the expected kind of the result. +#ifndef NDEBUG + assert(result.getKind() == read<PDLValue::Kind>() && + "native PDL rewrite function returned an unexpected type of result"); +#endif + + // If the result is a range, we need to copy it over to the bytecodes + // range memory. + if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) { + unsigned rangeIndex = read(); + typeRangeMemory[rangeIndex] = *typeRange; + memory[read()] = &typeRangeMemory[rangeIndex]; + } else if (Optional<ValueRange> valueRange = + result.dyn_cast<ValueRange>()) { + unsigned rangeIndex = read(); + valueRangeMemory[rangeIndex] = *valueRange; + memory[read()] = &valueRangeMemory[rangeIndex]; + } else { + memory[read()] = result.getAsOpaquePointer(); + } } + + // Copy over any underlying storage allocated for result ranges. + for (auto &it : results.getAllocatedTypeRanges()) + allocatedTypeRangeMemory.push_back(std::move(it)); + for (auto &it : results.getAllocatedValueRanges()) + allocatedValueRangeMemory.push_back(std::move(it)); } void ByteCodeExecutor::executeAreEqual() { @@ -919,6 +1198,32 @@ void ByteCodeExecutor::executeAreEqual() { selectJump(lhs == rhs); } +void ByteCodeExecutor::executeAreRangesEqual() { + LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n"); + PDLValue::Kind valueKind = read<PDLValue::Kind>(); + const void *lhs = read<const void *>(); + const void *rhs = read<const void *>(); + + switch (valueKind) { + case PDLValue::Kind::TypeRange: { + const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs); + const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs); + LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); + selectJump(*lhsRange == *rhsRange); + break; + } + case PDLValue::Kind::ValueRange: { + const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs); + const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs); + LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); + selectJump(*lhsRange == *rhsRange); + break; + } + default: + llvm_unreachable("unexpected `AreRangesEqual` value kind"); + } +} + void ByteCodeExecutor::executeBranch() { LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); curCodeIt = &code[read<ByteCodeAddr>()]; @@ -928,10 +1233,16 @@ void ByteCodeExecutor::executeCheckOperandCount() { LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); Operation *op = read<Operation *>(); uint32_t expectedCount = read<uint32_t>(); + bool compareAtLeast = read(); LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" - << " * Expected: " << expectedCount << "\n"); - selectJump(op->getNumOperands() == expectedCount); + << " * Expected: " << expectedCount << "\n" + << " * Comparator: " + << (compareAtLeast ? ">=" : "==") << "\n"); + if (compareAtLeast) + selectJump(op->getNumOperands() >= expectedCount); + else + selectJump(op->getNumOperands() == expectedCount); } void ByteCodeExecutor::executeCheckOperationName() { @@ -948,10 +1259,44 @@ void ByteCodeExecutor::executeCheckResultCount() { LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); Operation *op = read<Operation *>(); uint32_t expectedCount = read<uint32_t>(); + bool compareAtLeast = read(); LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" - << " * Expected: " << expectedCount << "\n"); - selectJump(op->getNumResults() == expectedCount); + << " * Expected: " << expectedCount << "\n" + << " * Comparator: " + << (compareAtLeast ? ">=" : "==") << "\n"); + if (compareAtLeast) + selectJump(op->getNumResults() >= expectedCount); + else + selectJump(op->getNumResults() == expectedCount); +} + +void ByteCodeExecutor::executeCheckTypes() { + LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); + TypeRange *lhs = read<TypeRange *>(); + Attribute rhs = read<Attribute>(); + LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); + + selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>()); +} + +void ByteCodeExecutor::executeCreateTypes() { + LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); + unsigned memIndex = read(); + unsigned rangeIndex = read(); + ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>(); + + LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); + + // Allocate a buffer for this type range. + llvm::OwningArrayRef<Type> storage(typesAttr.size()); + llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin()); + allocatedTypeRangeMemory.emplace_back(std::move(storage)); + + // Assign this to the range slot and use the range as the value for the + // memory index. + typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back(); + memory[memIndex] = &typeRangeMemory[rangeIndex]; } void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, @@ -960,22 +1305,26 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, unsigned memIndex = read(); OperationState state(mainRewriteLoc, read<OperationName>()); - readList<Value>(state.operands); + readValueList(state.operands); for (unsigned i = 0, e = read(); i != e; ++i) { Identifier name = read<Identifier>(); if (Attribute attr = read<Attribute>()) state.addAttribute(name, attr); } - bool hasInferredTypes = false; for (unsigned i = 0, e = read(); i != e; ++i) { - Type resultType = read<Type>(); - hasInferredTypes |= !resultType; - state.types.push_back(resultType); - } + if (read<PDLValue::Kind>() == PDLValue::Kind::Type) { + state.types.push_back(read<Type>()); + continue; + } + + // If we find a null range, this signals that the types are infered. + if (TypeRange *resultTypes = read<TypeRange *>()) { + state.types.append(resultTypes->begin(), resultTypes->end()); + continue; + } - // Handle the case where the operation has inferred types. - if (hasInferredTypes) { + // Handle the case where the operation has inferred types. InferTypeOpInterface::Concept *concept = state.name.getAbstractOperation()->getInterface<InferTypeOpInterface>(); @@ -986,7 +1335,9 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, state.attributes.getDictionary(state.getContext()), state.regions, state.types))) return; + break; } + Operation *resultOp = rewriter.createOperation(state); memory[memIndex] = resultOp; @@ -1036,11 +1387,21 @@ void ByteCodeExecutor::executeGetAttributeType() { void ByteCodeExecutor::executeGetDefiningOp() { LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); unsigned memIndex = read(); - Value value = read<Value>(); - Operation *op = value ? value.getDefiningOp() : nullptr; + Operation *op = nullptr; + if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { + Value value = read<Value>(); + if (value) + op = value.getDefiningOp(); + LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); + } else { + ValueRange *values = read<ValueRange *>(); + if (values && !values->empty()) { + op = values->front().getDefiningOp(); + } + LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n"); + } - LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" - << " * Result: " << *op << "\n"); + LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n"); memory[memIndex] = op; } @@ -1056,6 +1417,75 @@ void ByteCodeExecutor::executeGetOperand(unsigned index) { memory[memIndex] = operand.getAsOpaquePointer(); } +/// This function is the internal implementation of `GetResults` and +/// `GetOperands` that provides support for extracting a value range from the +/// given operation. +template <template <typename> class AttrSizedSegmentsT, typename RangeT> +static void * +executeGetOperandsResults(RangeT values, Operation *op, unsigned index, + ByteCodeField rangeIndex, StringRef attrSizedSegments, + MutableArrayRef<ValueRange> &valueRangeMemory) { + // Check for the sentinel index that signals that all values should be + // returned. + if (index == std::numeric_limits<uint32_t>::max()) { + LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n"); + // `values` is already the full value range. + + // Otherwise, check to see if this operation uses AttrSizedSegments. + } else if (op->hasTrait<AttrSizedSegmentsT>()) { + LLVM_DEBUG(llvm::dbgs() + << " * Extracting values from `" << attrSizedSegments << "`\n"); + + auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments); + if (!segmentAttr || segmentAttr.getNumElements() <= index) + return nullptr; + + auto segments = segmentAttr.getValues<int32_t>(); + unsigned startIndex = + std::accumulate(segments.begin(), segments.begin() + index, 0); + values = values.slice(startIndex, *std::next(segments.begin(), index)); + + LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", " + << *std::next(segments.begin(), index) << "]\n"); + + // Otherwise, assume this is the last operand group of the operation. + // FIXME: We currently don't support operations with + // SameVariadicOperandSize/SameVariadicResultSize here given that we don't + // have a way to detect it's presence. + } else if (values.size() >= index) { + LLVM_DEBUG(llvm::dbgs() + << " * Treating values as trailing variadic range\n"); + values = values.drop_front(index); + + // If we couldn't detect a way to compute the values, bail out. + } else { + return nullptr; + } + + // If the range index is valid, we are returning a range. + if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) { + valueRangeMemory[rangeIndex] = values; + return &valueRangeMemory[rangeIndex]; + } + + // If a range index wasn't provided, the range is required to be non-variadic. + return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer(); +} + +void ByteCodeExecutor::executeGetOperands() { + LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n"); + unsigned index = read<uint32_t>(); + Operation *op = read<Operation *>(); + ByteCodeField rangeIndex = read(); + + void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>( + op->getOperands(), op, index, rangeIndex, "operand_segment_sizes", + valueRangeMemory); + if (!result) + LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n"); + memory[read()] = result; +} + void ByteCodeExecutor::executeGetResult(unsigned index) { Operation *op = read<Operation *>(); unsigned memIndex = read(); @@ -1068,6 +1498,20 @@ void ByteCodeExecutor::executeGetResult(unsigned index) { memory[memIndex] = result.getAsOpaquePointer(); } +void ByteCodeExecutor::executeGetResults() { + LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n"); + unsigned index = read<uint32_t>(); + Operation *op = read<Operation *>(); + ByteCodeField rangeIndex = read(); + + void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>( + op->getResults(), op, index, rangeIndex, "result_segment_sizes", + valueRangeMemory); + if (!result) + LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n"); + memory[read()] = result; +} + void ByteCodeExecutor::executeGetValueType() { LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); unsigned memIndex = read(); @@ -1079,6 +1523,28 @@ void ByteCodeExecutor::executeGetValueType() { memory[memIndex] = type.getAsOpaquePointer(); } +void ByteCodeExecutor::executeGetValueRangeTypes() { + LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n"); + unsigned memIndex = read(); + unsigned rangeIndex = read(); + ValueRange *values = read<ValueRange *>(); + if (!values) { + LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n"); + memory[memIndex] = nullptr; + return; + } + + LLVM_DEBUG({ + llvm::dbgs() << " * Values (" << values->size() << "): "; + llvm::interleaveComma(*values, llvm::dbgs()); + llvm::dbgs() << "\n * Result: "; + llvm::interleaveComma(values->getType(), llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + typeRangeMemory[rangeIndex] = values->getType(); + memory[memIndex] = &typeRangeMemory[rangeIndex]; +} + void ByteCodeExecutor::executeIsNotNull() { LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); const void *value = read<const void *>(); @@ -1117,7 +1583,30 @@ void ByteCodeExecutor::executeRecordMatch( LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" << " * Location: " << matchLoc << "\n"); matches.emplace_back(matchLoc, patterns[patternIndex], benefit); - readList<const void *>(matches.back().values); + PDLByteCode::MatchResult &match = matches.back(); + + // Record all of the inputs to the match. If any of the inputs are ranges, we + // will also need to remap the range pointer to memory stored in the match + // state. + unsigned numInputs = read(); + match.values.reserve(numInputs); + match.typeRangeValues.reserve(numInputs); + match.valueRangeValues.reserve(numInputs); + for (unsigned i = 0; i < numInputs; ++i) { + switch (read<PDLValue::Kind>()) { + case PDLValue::Kind::TypeRange: + match.typeRangeValues.push_back(*read<TypeRange *>()); + match.values.push_back(&match.typeRangeValues.back()); + break; + case PDLValue::Kind::ValueRange: + match.valueRangeValues.push_back(*read<ValueRange *>()); + match.values.push_back(&match.valueRangeValues.back()); + break; + default: + match.values.push_back(read<const void *>()); + break; + } + } curCodeIt = dest; } @@ -1125,7 +1614,7 @@ void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); Operation *op = read<Operation *>(); SmallVector<Value, 16> args; - readList<Value>(args); + readValueList(args); LLVM_DEBUG({ llvm::dbgs() << " * Operation: " << *op << "\n" @@ -1198,6 +1687,19 @@ void ByteCodeExecutor::executeSwitchType() { handleSwitch(value, cases); } +void ByteCodeExecutor::executeSwitchTypes() { + LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n"); + TypeRange *value = read<TypeRange *>(); + auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>(); + if (!value) { + LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n"); + return selectJump(size_t(0)); + } + handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) { + return value == caseValue.getAsValueRange<TypeAttr>(); + }); +} + void ByteCodeExecutor::execute( PatternRewriter &rewriter, SmallVectorImpl<PDLByteCode::MatchResult> *matches, @@ -1214,6 +1716,9 @@ void ByteCodeExecutor::execute( case AreEqual: executeAreEqual(); break; + case AreRangesEqual: + executeAreRangesEqual(); + break; case Branch: executeBranch(); break; @@ -1226,9 +1731,15 @@ void ByteCodeExecutor::execute( case CheckResultCount: executeCheckResultCount(); break; + case CheckTypes: + executeCheckTypes(); + break; case CreateOperation: executeCreateOperation(rewriter, *mainRewriteLoc); break; + case CreateTypes: + executeCreateTypes(); + break; case EraseOp: executeEraseOp(rewriter); break; @@ -1257,6 +1768,9 @@ void ByteCodeExecutor::execute( LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n"); executeGetOperand(read<uint32_t>()); break; + case GetOperands: + executeGetOperands(); + break; case GetResult0: case GetResult1: case GetResult2: @@ -1270,9 +1784,15 @@ void ByteCodeExecutor::execute( LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n"); executeGetResult(read<uint32_t>()); break; + case GetResults: + executeGetResults(); + break; case GetValueType: executeGetValueType(); break; + case GetValueRangeTypes: + executeGetValueRangeTypes(); + break; case IsNotNull: executeIsNotNull(); break; @@ -1299,6 +1819,9 @@ void ByteCodeExecutor::execute( case SwitchType: executeSwitchType(); break; + case SwitchTypes: + executeSwitchTypes(); + break; } LLVM_DEBUG(llvm::dbgs() << "\n"); } @@ -1313,9 +1836,12 @@ void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, state.memory[0] = op; // The matcher function always starts at code address 0. - ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData, - matcherByteCode, state.currentPatternBenefits, - patterns, constraintFunctions, rewriteFunctions); + ByteCodeExecutor executor( + matcherByteCode.data(), state.memory, state.typeRangeMemory, + state.allocatedTypeRangeMemory, state.valueRangeMemory, + state.allocatedValueRangeMemory, uniquedData, matcherByteCode, + state.currentPatternBenefits, patterns, constraintFunctions, + rewriteFunctions); executor.execute(rewriter, &matches); // Order the found matches by benefit. @@ -1332,9 +1858,11 @@ void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, // memory buffer. llvm::copy(match.values, state.memory.begin()); - ByteCodeExecutor executor(&rewriterByteCode[match.pattern->getRewriterAddr()], - state.memory, uniquedData, rewriterByteCode, - state.currentPatternBenefits, patterns, - constraintFunctions, rewriteFunctions); + ByteCodeExecutor executor( + &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, + state.typeRangeMemory, state.allocatedTypeRangeMemory, + state.valueRangeMemory, state.allocatedValueRangeMemory, uniquedData, + rewriterByteCode, state.currentPatternBenefits, patterns, + constraintFunctions, rewriteFunctions); executor.execute(rewriter, /*matches=*/nullptr, match.location); } |