aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Rewrite/ByteCode.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Rewrite/ByteCode.cpp')
-rw-r--r--mlir/lib/Rewrite/ByteCode.cpp748
1 files changed, 638 insertions, 110 deletions
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index ef96e25..ea17f99 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -20,6 +20,9 @@
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <numeric>
#define DEBUG_TYPE "pdl-bytecode"
@@ -60,6 +63,14 @@ void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
currentPatternBenefits[patternIndex] = benefit;
}
+/// Cleanup any allocated state after a full match/rewrite has been completed.
+/// This method should be called irregardless of whether the match+rewrite was a
+/// success or not.
+void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
+ allocatedTypeRangeMemory.clear();
+ allocatedValueRangeMemory.clear();
+}
+
//===----------------------------------------------------------------------===//
// Bytecode OpCodes
//===----------------------------------------------------------------------===//
@@ -72,6 +83,8 @@ enum OpCode : ByteCodeField {
ApplyRewrite,
/// Check if two generic values are equal.
AreEqual,
+ /// Check if two ranges are equal.
+ AreRangesEqual,
/// Unconditional branch.
Branch,
/// Compare the operand count of an operation with a constant.
@@ -80,8 +93,12 @@ enum OpCode : ByteCodeField {
CheckOperationName,
/// Compare the result count of an operation with a constant.
CheckResultCount,
+ /// Compare a range of types to a constant range of types.
+ CheckTypes,
/// Create an operation.
CreateOperation,
+ /// Create a range of types.
+ CreateTypes,
/// Erase an operation.
EraseOp,
/// Terminate a matcher or rewrite sequence.
@@ -98,14 +115,20 @@ enum OpCode : ByteCodeField {
GetOperand2,
GetOperand3,
GetOperandN,
+ /// Get a specific operand group of an operation.
+ GetOperands,
/// Get a specific result of an operation.
GetResult0,
GetResult1,
GetResult2,
GetResult3,
GetResultN,
+ /// Get a specific result group of an operation.
+ GetResults,
/// Get the type of a value.
GetValueType,
+ /// Get the types of a value range.
+ GetValueRangeTypes,
/// Check if a generic value is not null.
IsNotNull,
/// Record a successful pattern match.
@@ -122,9 +145,9 @@ enum OpCode : ByteCodeField {
SwitchResultCount,
/// Compare a type with a set of constants.
SwitchType,
+ /// Compare a range of types with a set of constants.
+ SwitchTypes,
};
-
-enum class PDLValueKind { Attribute, Operation, Type, Value };
} // end anonymous namespace
//===----------------------------------------------------------------------===//
@@ -145,11 +168,15 @@ public:
SmallVectorImpl<ByteCodeField> &rewriterByteCode,
SmallVectorImpl<PDLByteCodePattern> &patterns,
ByteCodeField &maxValueMemoryIndex,
+ ByteCodeField &maxTypeRangeMemoryIndex,
+ ByteCodeField &maxValueRangeMemoryIndex,
llvm::StringMap<PDLConstraintFunction> &constraintFns,
llvm::StringMap<PDLRewriteFunction> &rewriteFns)
: ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
rewriterByteCode(rewriterByteCode), patterns(patterns),
- maxValueMemoryIndex(maxValueMemoryIndex) {
+ maxValueMemoryIndex(maxValueMemoryIndex),
+ maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
+ maxValueRangeMemoryIndex(maxValueRangeMemoryIndex) {
for (auto it : llvm::enumerate(constraintFns))
constraintToMemIndex.try_emplace(it.value().first(), it.index());
for (auto it : llvm::enumerate(rewriteFns))
@@ -166,6 +193,13 @@ public:
return valueToMemIndex[value];
}
+ /// Return the range memory index used to store the given range value.
+ ByteCodeField &getRangeStorageIndex(Value value) {
+ assert(valueToRangeIndex.count(value) &&
+ "expected range index to be assigned");
+ return valueToRangeIndex[value];
+ }
+
/// Return an index to use when referring to the given data that is uniqued in
/// the MLIR context.
template <typename T>
@@ -197,16 +231,20 @@ private:
void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer);
void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
@@ -214,6 +252,7 @@ private:
void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
@@ -221,6 +260,9 @@ private:
/// Mapping from value to its corresponding memory index.
DenseMap<Value, ByteCodeField> valueToMemIndex;
+ /// Mapping from a range value to its corresponding range storage index.
+ DenseMap<Value, ByteCodeField> valueToRangeIndex;
+
/// Mapping from the name of an externally registered rewrite to its index in
/// the bytecode registry.
llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
@@ -246,6 +288,8 @@ private:
SmallVectorImpl<ByteCodeField> &rewriterByteCode;
SmallVectorImpl<PDLByteCodePattern> &patterns;
ByteCodeField &maxValueMemoryIndex;
+ ByteCodeField &maxTypeRangeMemoryIndex;
+ ByteCodeField &maxValueRangeMemoryIndex;
};
/// This class provides utilities for writing a bytecode stream.
@@ -281,19 +325,33 @@ struct ByteCodeWriter {
/// Append a range of values that will be read as generic PDLValues.
void appendPDLValueList(OperandRange values) {
bytecode.push_back(values.size());
- for (Value value : values) {
- // Append the type of the value in addition to the value itself.
- PDLValueKind kind =
- TypeSwitch<Type, PDLValueKind>(value.getType())
- .Case<pdl::AttributeType>(
- [](Type) { return PDLValueKind::Attribute; })
- .Case<pdl::OperationType>(
- [](Type) { return PDLValueKind::Operation; })
- .Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; })
- .Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; });
- bytecode.push_back(static_cast<ByteCodeField>(kind));
- append(value);
- }
+ for (Value value : values)
+ appendPDLValue(value);
+ }
+
+ /// Append a value as a PDLValue.
+ void appendPDLValue(Value value) {
+ appendPDLValueKind(value);
+ append(value);
+ }
+
+ /// Append the PDLValue::Kind of the given value.
+ void appendPDLValueKind(Value value) {
+ // Append the type of the value in addition to the value itself.
+ PDLValue::Kind kind =
+ TypeSwitch<Type, PDLValue::Kind>(value.getType())
+ .Case<pdl::AttributeType>(
+ [](Type) { return PDLValue::Kind::Attribute; })
+ .Case<pdl::OperationType>(
+ [](Type) { return PDLValue::Kind::Operation; })
+ .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
+ if (rangeTy.getElementType().isa<pdl::TypeType>())
+ return PDLValue::Kind::TypeRange;
+ return PDLValue::Kind::ValueRange;
+ })
+ .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
+ .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
+ bytecode.push_back(static_cast<ByteCodeField>(kind));
}
/// Check if the given class `T` has an iterator type.
@@ -334,6 +392,36 @@ struct ByteCodeWriter {
/// The main generator producing PDL.
Generator &generator;
};
+
+/// This class represents a live range of PDL Interpreter values, containing
+/// information about when values are live within a match/rewrite.
+struct ByteCodeLiveRange {
+ using Set = llvm::IntervalMap<ByteCodeField, char, 16>;
+ using Allocator = Set::Allocator;
+
+ ByteCodeLiveRange(Allocator &alloc) : liveness(alloc) {}
+
+ /// Union this live range with the one provided.
+ void unionWith(const ByteCodeLiveRange &rhs) {
+ for (auto it = rhs.liveness.begin(), e = rhs.liveness.end(); it != e; ++it)
+ liveness.insert(it.start(), it.stop(), /*dummyValue*/ 0);
+ }
+
+ /// Returns true if this range overlaps with the one provided.
+ bool overlaps(const ByteCodeLiveRange &rhs) const {
+ return llvm::IntervalMapOverlaps<Set, Set>(liveness, rhs.liveness).valid();
+ }
+
+ /// A map representing the ranges of the match/rewrite that a value is live in
+ /// the interpreter.
+ llvm::IntervalMap<ByteCodeField, char, 16> liveness;
+
+ /// The type range storage index for this range.
+ Optional<unsigned> typeRangeIndex;
+
+ /// The value range storage index for this range.
+ Optional<unsigned> valueRangeIndex;
+};
} // end anonymous namespace
void Generator::generate(ModuleOp module) {
@@ -381,15 +469,30 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc,
// Rewriters use simplistic allocation scheme that simply assigns an index to
// each result.
for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
- ByteCodeField index = 0;
+ ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
+ auto processRewriterValue = [&](Value val) {
+ valueToMemIndex.try_emplace(val, index++);
+ if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
+ Type elementTy = rangeType.getElementType();
+ if (elementTy.isa<pdl::TypeType>())
+ valueToRangeIndex.try_emplace(val, typeRangeIndex++);
+ else if (elementTy.isa<pdl::ValueType>())
+ valueToRangeIndex.try_emplace(val, valueRangeIndex++);
+ }
+ };
+
for (BlockArgument arg : rewriterFunc.getArguments())
- valueToMemIndex.try_emplace(arg, index++);
+ processRewriterValue(arg);
rewriterFunc.getBody().walk([&](Operation *op) {
for (Value result : op->getResults())
- valueToMemIndex.try_emplace(result, index++);
+ processRewriterValue(result);
});
if (index > maxValueMemoryIndex)
maxValueMemoryIndex = index;
+ if (typeRangeIndex > maxTypeRangeMemoryIndex)
+ maxTypeRangeMemoryIndex = typeRangeIndex;
+ if (valueRangeIndex > maxValueRangeMemoryIndex)
+ maxValueRangeMemoryIndex = valueRangeIndex;
}
// The matcher function uses a more sophisticated numbering that tries to
@@ -404,9 +507,8 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc,
});
// Liveness info for each of the defs within the matcher.
- using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>;
- LivenessSet::Allocator allocator;
- DenseMap<Value, LivenessSet> valueDefRanges;
+ ByteCodeLiveRange::Allocator allocator;
+ DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
// Assign the root operation being matched to slot 0.
BlockArgument rootOpArg = matcherFunc.getArgument(0);
@@ -425,10 +527,19 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc,
// Set indices for the range of this block that the value is used.
auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
- defRangeIt->second.insert(
+ defRangeIt->second.liveness.insert(
opToIndex[firstUseOrDef],
opToIndex[info->getEndOperation(value, firstUseOrDef)],
/*dummyValue*/ 0);
+
+ // Check to see if this value is a range type.
+ if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
+ Type eleType = rangeTy.getElementType();
+ if (eleType.isa<pdl::TypeType>())
+ defRangeIt->second.typeRangeIndex = 0;
+ else if (eleType.isa<pdl::ValueType>())
+ defRangeIt->second.valueRangeIndex = 0;
+ }
};
// Process the live-ins of this block.
@@ -442,37 +553,59 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc,
}
// Greedily allocate memory slots using the computed def live ranges.
- std::vector<LivenessSet> allocatedIndices;
+ std::vector<ByteCodeLiveRange> allocatedIndices;
+ ByteCodeField numIndices = 1, numTypeRanges = 0, numValueRanges = 0;
for (auto &defIt : valueDefRanges) {
ByteCodeField &memIndex = valueToMemIndex[defIt.first];
- LivenessSet &defSet = defIt.second;
+ ByteCodeLiveRange &defRange = defIt.second;
// Try to allocate to an existing index.
for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
- LivenessSet &existingIndex = existingIndexIt.value();
- llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps(
- defIt.second, existingIndex);
- if (overlaps.valid())
- continue;
- // Union the range of the def within the existing index.
- for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
- existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0);
- memIndex = existingIndexIt.index() + 1;
+ ByteCodeLiveRange &existingRange = existingIndexIt.value();
+ if (!defRange.overlaps(existingRange)) {
+ existingRange.unionWith(defRange);
+ memIndex = existingIndexIt.index() + 1;
+
+ if (defRange.typeRangeIndex) {
+ if (!existingRange.typeRangeIndex)
+ existingRange.typeRangeIndex = numTypeRanges++;
+ valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
+ } else if (defRange.valueRangeIndex) {
+ if (!existingRange.valueRangeIndex)
+ existingRange.valueRangeIndex = numValueRanges++;
+ valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
+ }
+ break;
+ }
}
// If no existing index could be used, add a new one.
if (memIndex == 0) {
allocatedIndices.emplace_back(allocator);
- for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
- allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0);
+ ByteCodeLiveRange &newRange = allocatedIndices.back();
+ newRange.unionWith(defRange);
+
+ // Allocate an index for type/value ranges.
+ if (defRange.typeRangeIndex) {
+ newRange.typeRangeIndex = numTypeRanges;
+ valueToRangeIndex[defIt.first] = numTypeRanges++;
+ } else if (defRange.valueRangeIndex) {
+ newRange.valueRangeIndex = numValueRanges;
+ valueToRangeIndex[defIt.first] = numValueRanges++;
+ }
+
memIndex = allocatedIndices.size();
+ ++numIndices;
}
}
// Update the max number of indices.
- ByteCodeField numMatcherIndices = allocatedIndices.size() + 1;
- if (numMatcherIndices > maxValueMemoryIndex)
- maxValueMemoryIndex = numMatcherIndices;
+ if (numIndices > maxValueMemoryIndex)
+ maxValueMemoryIndex = numIndices;
+ if (numTypeRanges > maxTypeRangeMemoryIndex)
+ maxTypeRangeMemoryIndex = numTypeRanges;
+ if (numValueRanges > maxValueRangeMemoryIndex)
+ maxValueRangeMemoryIndex = numValueRanges;
}
void Generator::generate(Operation *op, ByteCodeWriter &writer) {
@@ -481,17 +614,19 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
pdl_interp::AreEqualOp, pdl_interp::BranchOp,
pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
- pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp,
- pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
+ pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
+ pdl_interp::CreateAttributeOp, pdl_interp::CreateOperationOp,
+ pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
pdl_interp::EraseOp, pdl_interp::FinalizeOp,
pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp,
pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp,
- pdl_interp::GetResultOp, pdl_interp::GetValueTypeOp,
+ pdl_interp::GetOperandsOp, pdl_interp::GetResultOp,
+ pdl_interp::GetResultsOp, pdl_interp::GetValueTypeOp,
pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
- pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
- pdl_interp::SwitchResultCountOp>(
+ pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp,
+ pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
[&](auto interpOp) { this->generate(interpOp, writer); })
.Default([](Operation *) {
llvm_unreachable("unknown `pdl_interp` operation");
@@ -515,16 +650,31 @@ void Generator::generate(pdl_interp::ApplyRewriteOp op,
op.constParamsAttr());
writer.appendPDLValueList(op.args());
+ ResultRange results = op.results();
+ writer.append(ByteCodeField(results.size()));
+ for (Value result : results) {
+ // In debug mode we also record the expected kind of the result, so that we
+ // can provide extra verification of the native rewrite function.
#ifndef NDEBUG
- // In debug mode we also append the number of results so that we can assert
- // that the native creation function gave us the correct number of results.
- writer.append(ByteCodeField(op.results().size()));
+ writer.appendPDLValueKind(result);
#endif
- for (Value result : op.results())
+
+ // Range results also need to append the range storage index.
+ if (result.getType().isa<pdl::RangeType>())
+ writer.append(getRangeStorageIndex(result));
writer.append(result);
+ }
}
void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
- writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors());
+ Value lhs = op.lhs();
+ if (lhs.getType().isa<pdl::RangeType>()) {
+ writer.append(OpCode::AreRangesEqual);
+ writer.appendPDLValueKind(lhs);
+ writer.append(op.lhs(), op.rhs(), op.getSuccessors());
+ return;
+ }
+
+ writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors());
}
void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
@@ -537,6 +687,7 @@ void Generator::generate(pdl_interp::CheckAttributeOp op,
void Generator::generate(pdl_interp::CheckOperandCountOp op,
ByteCodeWriter &writer) {
writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
+ static_cast<ByteCodeField>(op.compareAtLeast()),
op.getSuccessors());
}
void Generator::generate(pdl_interp::CheckOperationNameOp op,
@@ -547,11 +698,15 @@ void Generator::generate(pdl_interp::CheckOperationNameOp op,
void Generator::generate(pdl_interp::CheckResultCountOp op,
ByteCodeWriter &writer) {
writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
+ static_cast<ByteCodeField>(op.compareAtLeast()),
op.getSuccessors());
}
void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
}
+void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
+ writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors());
+}
void Generator::generate(pdl_interp::CreateAttributeOp op,
ByteCodeWriter &writer) {
// Simply repoint the memory index of the result to the constant.
@@ -560,7 +715,8 @@ void Generator::generate(pdl_interp::CreateAttributeOp op,
void Generator::generate(pdl_interp::CreateOperationOp op,
ByteCodeWriter &writer) {
writer.append(OpCode::CreateOperation, op.operation(),
- OperationName(op.name(), ctx), op.operands());
+ OperationName(op.name(), ctx));
+ writer.appendPDLValueList(op.operands());
// Add the attributes.
OperandRange attributes = op.attributes();
@@ -570,12 +726,16 @@ void Generator::generate(pdl_interp::CreateOperationOp op,
Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx),
std::get<1>(it));
}
- writer.append(op.types());
+ writer.appendPDLValueList(op.types());
}
void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
// Simply repoint the memory index of the result to the constant.
getMemIndex(op.result()) = getMemIndex(op.value());
}
+void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
+ writer.append(OpCode::CreateTypes, op.result(),
+ getRangeStorageIndex(op.result()), op.value());
+}
void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
writer.append(OpCode::EraseOp, op.operation());
}
@@ -593,7 +753,8 @@ void Generator::generate(pdl_interp::GetAttributeTypeOp op,
}
void Generator::generate(pdl_interp::GetDefiningOpOp op,
ByteCodeWriter &writer) {
- writer.append(OpCode::GetDefiningOp, op.operation(), op.value());
+ writer.append(OpCode::GetDefiningOp, op.operation());
+ writer.appendPDLValue(op.value());
}
void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
uint32_t index = op.index();
@@ -603,6 +764,18 @@ void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
writer.append(OpCode::GetOperandN, index);
writer.append(op.operation(), op.value());
}
+void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
+ Value result = op.value();
+ Optional<uint32_t> index = op.index();
+ writer.append(OpCode::GetOperands,
+ index.getValueOr(std::numeric_limits<uint32_t>::max()),
+ op.operation());
+ if (result.getType().isa<pdl::RangeType>())
+ writer.append(getRangeStorageIndex(result));
+ else
+ writer.append(std::numeric_limits<ByteCodeField>::max());
+ writer.append(result);
+}
void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
uint32_t index = op.index();
if (index < 4)
@@ -611,10 +784,29 @@ void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
writer.append(OpCode::GetResultN, index);
writer.append(op.operation(), op.value());
}
+void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
+ Value result = op.value();
+ Optional<uint32_t> index = op.index();
+ writer.append(OpCode::GetResults,
+ index.getValueOr(std::numeric_limits<uint32_t>::max()),
+ op.operation());
+ if (result.getType().isa<pdl::RangeType>())
+ writer.append(getRangeStorageIndex(result));
+ else
+ writer.append(std::numeric_limits<ByteCodeField>::max());
+ writer.append(result);
+}
void Generator::generate(pdl_interp::GetValueTypeOp op,
ByteCodeWriter &writer) {
- writer.append(OpCode::GetValueType, op.result(), op.value());
+ if (op.getType().isa<pdl::RangeType>()) {
+ Value result = op.result();
+ writer.append(OpCode::GetValueRangeTypes, result,
+ getRangeStorageIndex(result), op.value());
+ } else {
+ writer.append(OpCode::GetValueType, op.result(), op.value());
+ }
}
+
void Generator::generate(pdl_interp::InferredTypesOp op,
ByteCodeWriter &writer) {
// InferType maps to a null type as a marker for inferring result types.
@@ -628,11 +820,12 @@ void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
patterns.emplace_back(PDLByteCodePattern::create(
op, rewriterToAddr[op.rewriter().getLeafReference()]));
writer.append(OpCode::RecordMatch, patternIndex,
- SuccessorRange(op.getOperation()), op.matchedOps(),
- op.inputs());
+ SuccessorRange(op.getOperation()), op.matchedOps());
+ writer.appendPDLValueList(op.inputs());
}
void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
- writer.append(OpCode::ReplaceOp, op.operation(), op.replValues());
+ writer.append(OpCode::ReplaceOp, op.operation());
+ writer.appendPDLValueList(op.replValues());
}
void Generator::generate(pdl_interp::SwitchAttributeOp op,
ByteCodeWriter &writer) {
@@ -661,6 +854,10 @@ void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
op.getSuccessors());
}
+void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
+ writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(),
+ op.getSuccessors());
+}
//===----------------------------------------------------------------------===//
// PDLByteCode
@@ -671,7 +868,8 @@ PDLByteCode::PDLByteCode(ModuleOp module,
llvm::StringMap<PDLRewriteFunction> rewriteFns) {
Generator generator(module.getContext(), uniquedData, matcherByteCode,
rewriterByteCode, patterns, maxValueMemoryIndex,
- constraintFns, rewriteFns);
+ maxTypeRangeCount, maxValueRangeCount, constraintFns,
+ rewriteFns);
generator.generate(module);
// Initialize the external functions.
@@ -685,6 +883,8 @@ PDLByteCode::PDLByteCode(ModuleOp module,
/// bytecode.
void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
state.memory.resize(maxValueMemoryIndex, nullptr);
+ state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
+ state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
state.currentPatternBenefits.reserve(patterns.size());
for (const PDLByteCodePattern &pattern : patterns)
state.currentPatternBenefits.push_back(pattern.getBenefit());
@@ -697,17 +897,24 @@ namespace {
/// This class provides support for executing a bytecode stream.
class ByteCodeExecutor {
public:
- ByteCodeExecutor(const ByteCodeField *curCodeIt,
- MutableArrayRef<const void *> memory,
- ArrayRef<const void *> uniquedMemory,
- ArrayRef<ByteCodeField> code,
- ArrayRef<PatternBenefit> currentPatternBenefits,
- ArrayRef<PDLByteCodePattern> patterns,
- ArrayRef<PDLConstraintFunction> constraintFunctions,
- ArrayRef<PDLRewriteFunction> rewriteFunctions)
- : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory),
- code(code), currentPatternBenefits(currentPatternBenefits),
- patterns(patterns), constraintFunctions(constraintFunctions),
+ ByteCodeExecutor(
+ const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
+ MutableArrayRef<TypeRange> typeRangeMemory,
+ std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
+ MutableArrayRef<ValueRange> valueRangeMemory,
+ std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
+ ArrayRef<const void *> uniquedMemory, ArrayRef<ByteCodeField> code,
+ ArrayRef<PatternBenefit> currentPatternBenefits,
+ ArrayRef<PDLByteCodePattern> patterns,
+ ArrayRef<PDLConstraintFunction> constraintFunctions,
+ ArrayRef<PDLRewriteFunction> rewriteFunctions)
+ : curCodeIt(curCodeIt), memory(memory), typeRangeMemory(typeRangeMemory),
+ allocatedTypeRangeMemory(allocatedTypeRangeMemory),
+ valueRangeMemory(valueRangeMemory),
+ allocatedValueRangeMemory(allocatedValueRangeMemory),
+ uniquedMemory(uniquedMemory), code(code),
+ currentPatternBenefits(currentPatternBenefits), patterns(patterns),
+ constraintFunctions(constraintFunctions),
rewriteFunctions(rewriteFunctions) {}
/// Start executing the code at the current bytecode index. `matches` is an
@@ -722,19 +929,25 @@ private:
void executeApplyConstraint(PatternRewriter &rewriter);
void executeApplyRewrite(PatternRewriter &rewriter);
void executeAreEqual();
+ void executeAreRangesEqual();
void executeBranch();
void executeCheckOperandCount();
void executeCheckOperationName();
void executeCheckResultCount();
+ void executeCheckTypes();
void executeCreateOperation(PatternRewriter &rewriter,
Location mainRewriteLoc);
+ void executeCreateTypes();
void executeEraseOp(PatternRewriter &rewriter);
void executeGetAttribute();
void executeGetAttributeType();
void executeGetDefiningOp();
void executeGetOperand(unsigned index);
+ void executeGetOperands();
void executeGetResult(unsigned index);
+ void executeGetResults();
void executeGetValueType();
+ void executeGetValueRangeTypes();
void executeIsNotNull();
void executeRecordMatch(PatternRewriter &rewriter,
SmallVectorImpl<PDLByteCode::MatchResult> &matches);
@@ -744,6 +957,7 @@ private:
void executeSwitchOperationName();
void executeSwitchResultCount();
void executeSwitchType();
+ void executeSwitchTypes();
/// Read a value from the bytecode buffer, optionally skipping a certain
/// number of prefix values. These methods always update the buffer to point
@@ -763,6 +977,19 @@ private:
list.push_back(read<ValueT>());
}
+ /// Read a list of values from the bytecode buffer. The values may be encoded
+ /// as either Value or ValueRange elements.
+ void readValueList(SmallVectorImpl<Value> &list) {
+ for (unsigned i = 0, e = read(); i != e; ++i) {
+ if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
+ list.push_back(read<Value>());
+ } else {
+ ValueRange *values = read<ValueRange *>();
+ list.append(values->begin(), values->end());
+ }
+ }
+ }
+
/// Jump to a specific successor based on a predicate value.
void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
/// Jump to a specific successor based on a destination index.
@@ -771,8 +998,8 @@ private:
}
/// Handle a switch operation with the provided value and cases.
- template <typename T, typename RangeT>
- void handleSwitch(const T &value, RangeT &&cases) {
+ template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
+ void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
LLVM_DEBUG({
llvm::dbgs() << " * Value: " << value << "\n"
<< " * Cases: ";
@@ -783,7 +1010,7 @@ private:
// Check to see if the attribute value is within the case list. Jump to
// the correct successor index based on the result.
for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
- if (*it == value)
+ if (cmp(*it, value))
return selectJump(size_t((it - cases.begin()) + 1));
selectJump(size_t(0));
}
@@ -795,7 +1022,9 @@ private:
size_t index = *curCodeIt++;
// If this type is an SSA value, it can only be stored in non-const memory.
- if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size())
+ if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
+ Value>::value ||
+ index < memory.size())
return memory[index];
// Otherwise, if this index is not inbounds it is uniqued.
@@ -813,17 +1042,21 @@ private:
}
template <typename T>
std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
- switch (static_cast<PDLValueKind>(read())) {
- case PDLValueKind::Attribute:
+ switch (read<PDLValue::Kind>()) {
+ case PDLValue::Kind::Attribute:
return read<Attribute>();
- case PDLValueKind::Operation:
+ case PDLValue::Kind::Operation:
return read<Operation *>();
- case PDLValueKind::Type:
+ case PDLValue::Kind::Type:
return read<Type>();
- case PDLValueKind::Value:
+ case PDLValue::Kind::Value:
return read<Value>();
+ case PDLValue::Kind::TypeRange:
+ return read<TypeRange *>();
+ case PDLValue::Kind::ValueRange:
+ return read<ValueRange *>();
}
- llvm_unreachable("unhandled PDLValueKind");
+ llvm_unreachable("unhandled PDLValue::Kind");
}
template <typename T>
std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
@@ -838,12 +1071,20 @@ private:
std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
return *curCodeIt++;
}
+ template <typename T>
+ std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
+ return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
+ }
/// The underlying bytecode buffer.
const ByteCodeField *curCodeIt;
/// The current execution memory.
MutableArrayRef<const void *> memory;
+ MutableArrayRef<TypeRange> typeRangeMemory;
+ std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
+ MutableArrayRef<ValueRange> valueRangeMemory;
+ std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
/// References to ByteCode data necessary for execution.
ArrayRef<const void *> uniquedMemory;
@@ -859,8 +1100,21 @@ private:
/// overexposing access to information specific solely to the ByteCode.
class ByteCodeRewriteResultList : public PDLResultList {
public:
+ ByteCodeRewriteResultList(unsigned maxNumResults)
+ : PDLResultList(maxNumResults) {}
+
/// Return the list of PDL results.
MutableArrayRef<PDLValue> getResults() { return results; }
+
+ /// Return the type ranges allocated by this list.
+ MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
+ return allocatedTypeRanges;
+ }
+
+ /// Return the value ranges allocated by this list.
+ MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
+ return allocatedValueRanges;
+ }
};
} // end anonymous namespace
@@ -893,21 +1147,46 @@ void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
llvm::interleaveComma(args, llvm::dbgs());
llvm::dbgs() << "\n * Parameters: " << constParams << "\n";
});
- ByteCodeRewriteResultList results;
+
+ // Execute the rewrite function.
+ ByteCodeField numResults = read();
+ ByteCodeRewriteResultList results(numResults);
rewriteFn(args, constParams, rewriter, results);
- // Store the results in the bytecode memory.
-#ifndef NDEBUG
- ByteCodeField expectedNumberOfResults = read();
- assert(results.getResults().size() == expectedNumberOfResults &&
+ assert(results.getResults().size() == numResults &&
"native PDL rewrite function returned unexpected number of results");
-#endif
// Store the results in the bytecode memory.
for (PDLValue &result : results.getResults()) {
LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
- memory[read()] = result.getAsOpaquePointer();
+
+// In debug mode we also verify the expected kind of the result.
+#ifndef NDEBUG
+ assert(result.getKind() == read<PDLValue::Kind>() &&
+ "native PDL rewrite function returned an unexpected type of result");
+#endif
+
+ // If the result is a range, we need to copy it over to the bytecodes
+ // range memory.
+ if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
+ unsigned rangeIndex = read();
+ typeRangeMemory[rangeIndex] = *typeRange;
+ memory[read()] = &typeRangeMemory[rangeIndex];
+ } else if (Optional<ValueRange> valueRange =
+ result.dyn_cast<ValueRange>()) {
+ unsigned rangeIndex = read();
+ valueRangeMemory[rangeIndex] = *valueRange;
+ memory[read()] = &valueRangeMemory[rangeIndex];
+ } else {
+ memory[read()] = result.getAsOpaquePointer();
+ }
}
+
+ // Copy over any underlying storage allocated for result ranges.
+ for (auto &it : results.getAllocatedTypeRanges())
+ allocatedTypeRangeMemory.push_back(std::move(it));
+ for (auto &it : results.getAllocatedValueRanges())
+ allocatedValueRangeMemory.push_back(std::move(it));
}
void ByteCodeExecutor::executeAreEqual() {
@@ -919,6 +1198,32 @@ void ByteCodeExecutor::executeAreEqual() {
selectJump(lhs == rhs);
}
+void ByteCodeExecutor::executeAreRangesEqual() {
+ LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
+ PDLValue::Kind valueKind = read<PDLValue::Kind>();
+ const void *lhs = read<const void *>();
+ const void *rhs = read<const void *>();
+
+ switch (valueKind) {
+ case PDLValue::Kind::TypeRange: {
+ const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
+ const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
+ LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
+ selectJump(*lhsRange == *rhsRange);
+ break;
+ }
+ case PDLValue::Kind::ValueRange: {
+ const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
+ const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
+ LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
+ selectJump(*lhsRange == *rhsRange);
+ break;
+ }
+ default:
+ llvm_unreachable("unexpected `AreRangesEqual` value kind");
+ }
+}
+
void ByteCodeExecutor::executeBranch() {
LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
curCodeIt = &code[read<ByteCodeAddr>()];
@@ -928,10 +1233,16 @@ void ByteCodeExecutor::executeCheckOperandCount() {
LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
Operation *op = read<Operation *>();
uint32_t expectedCount = read<uint32_t>();
+ bool compareAtLeast = read();
LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n"
- << " * Expected: " << expectedCount << "\n");
- selectJump(op->getNumOperands() == expectedCount);
+ << " * Expected: " << expectedCount << "\n"
+ << " * Comparator: "
+ << (compareAtLeast ? ">=" : "==") << "\n");
+ if (compareAtLeast)
+ selectJump(op->getNumOperands() >= expectedCount);
+ else
+ selectJump(op->getNumOperands() == expectedCount);
}
void ByteCodeExecutor::executeCheckOperationName() {
@@ -948,10 +1259,44 @@ void ByteCodeExecutor::executeCheckResultCount() {
LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
Operation *op = read<Operation *>();
uint32_t expectedCount = read<uint32_t>();
+ bool compareAtLeast = read();
LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n"
- << " * Expected: " << expectedCount << "\n");
- selectJump(op->getNumResults() == expectedCount);
+ << " * Expected: " << expectedCount << "\n"
+ << " * Comparator: "
+ << (compareAtLeast ? ">=" : "==") << "\n");
+ if (compareAtLeast)
+ selectJump(op->getNumResults() >= expectedCount);
+ else
+ selectJump(op->getNumResults() == expectedCount);
+}
+
+void ByteCodeExecutor::executeCheckTypes() {
+ LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
+ TypeRange *lhs = read<TypeRange *>();
+ Attribute rhs = read<Attribute>();
+ LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
+
+ selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
+}
+
+void ByteCodeExecutor::executeCreateTypes() {
+ LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
+ unsigned memIndex = read();
+ unsigned rangeIndex = read();
+ ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
+
+ LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n");
+
+ // Allocate a buffer for this type range.
+ llvm::OwningArrayRef<Type> storage(typesAttr.size());
+ llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
+ allocatedTypeRangeMemory.emplace_back(std::move(storage));
+
+ // Assign this to the range slot and use the range as the value for the
+ // memory index.
+ typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
+ memory[memIndex] = &typeRangeMemory[rangeIndex];
}
void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
@@ -960,22 +1305,26 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
unsigned memIndex = read();
OperationState state(mainRewriteLoc, read<OperationName>());
- readList<Value>(state.operands);
+ readValueList(state.operands);
for (unsigned i = 0, e = read(); i != e; ++i) {
Identifier name = read<Identifier>();
if (Attribute attr = read<Attribute>())
state.addAttribute(name, attr);
}
- bool hasInferredTypes = false;
for (unsigned i = 0, e = read(); i != e; ++i) {
- Type resultType = read<Type>();
- hasInferredTypes |= !resultType;
- state.types.push_back(resultType);
- }
+ if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
+ state.types.push_back(read<Type>());
+ continue;
+ }
+
+ // If we find a null range, this signals that the types are infered.
+ if (TypeRange *resultTypes = read<TypeRange *>()) {
+ state.types.append(resultTypes->begin(), resultTypes->end());
+ continue;
+ }
- // Handle the case where the operation has inferred types.
- if (hasInferredTypes) {
+ // Handle the case where the operation has inferred types.
InferTypeOpInterface::Concept *concept =
state.name.getAbstractOperation()->getInterface<InferTypeOpInterface>();
@@ -986,7 +1335,9 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
state.attributes.getDictionary(state.getContext()), state.regions,
state.types)))
return;
+ break;
}
+
Operation *resultOp = rewriter.createOperation(state);
memory[memIndex] = resultOp;
@@ -1036,11 +1387,21 @@ void ByteCodeExecutor::executeGetAttributeType() {
void ByteCodeExecutor::executeGetDefiningOp() {
LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
unsigned memIndex = read();
- Value value = read<Value>();
- Operation *op = value ? value.getDefiningOp() : nullptr;
+ Operation *op = nullptr;
+ if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
+ Value value = read<Value>();
+ if (value)
+ op = value.getDefiningOp();
+ LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
+ } else {
+ ValueRange *values = read<ValueRange *>();
+ if (values && !values->empty()) {
+ op = values->front().getDefiningOp();
+ }
+ LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n");
+ }
- LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
- << " * Result: " << *op << "\n");
+ LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n");
memory[memIndex] = op;
}
@@ -1056,6 +1417,75 @@ void ByteCodeExecutor::executeGetOperand(unsigned index) {
memory[memIndex] = operand.getAsOpaquePointer();
}
+/// This function is the internal implementation of `GetResults` and
+/// `GetOperands` that provides support for extracting a value range from the
+/// given operation.
+template <template <typename> class AttrSizedSegmentsT, typename RangeT>
+static void *
+executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
+ ByteCodeField rangeIndex, StringRef attrSizedSegments,
+ MutableArrayRef<ValueRange> &valueRangeMemory) {
+ // Check for the sentinel index that signals that all values should be
+ // returned.
+ if (index == std::numeric_limits<uint32_t>::max()) {
+ LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n");
+ // `values` is already the full value range.
+
+ // Otherwise, check to see if this operation uses AttrSizedSegments.
+ } else if (op->hasTrait<AttrSizedSegmentsT>()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << " * Extracting values from `" << attrSizedSegments << "`\n");
+
+ auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments);
+ if (!segmentAttr || segmentAttr.getNumElements() <= index)
+ return nullptr;
+
+ auto segments = segmentAttr.getValues<int32_t>();
+ unsigned startIndex =
+ std::accumulate(segments.begin(), segments.begin() + index, 0);
+ values = values.slice(startIndex, *std::next(segments.begin(), index));
+
+ LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", "
+ << *std::next(segments.begin(), index) << "]\n");
+
+ // Otherwise, assume this is the last operand group of the operation.
+ // FIXME: We currently don't support operations with
+ // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
+ // have a way to detect it's presence.
+ } else if (values.size() >= index) {
+ LLVM_DEBUG(llvm::dbgs()
+ << " * Treating values as trailing variadic range\n");
+ values = values.drop_front(index);
+
+ // If we couldn't detect a way to compute the values, bail out.
+ } else {
+ return nullptr;
+ }
+
+ // If the range index is valid, we are returning a range.
+ if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
+ valueRangeMemory[rangeIndex] = values;
+ return &valueRangeMemory[rangeIndex];
+ }
+
+ // If a range index wasn't provided, the range is required to be non-variadic.
+ return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
+}
+
+void ByteCodeExecutor::executeGetOperands() {
+ LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
+ unsigned index = read<uint32_t>();
+ Operation *op = read<Operation *>();
+ ByteCodeField rangeIndex = read();
+
+ void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
+ op->getOperands(), op, index, rangeIndex, "operand_segment_sizes",
+ valueRangeMemory);
+ if (!result)
+ LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n");
+ memory[read()] = result;
+}
+
void ByteCodeExecutor::executeGetResult(unsigned index) {
Operation *op = read<Operation *>();
unsigned memIndex = read();
@@ -1068,6 +1498,20 @@ void ByteCodeExecutor::executeGetResult(unsigned index) {
memory[memIndex] = result.getAsOpaquePointer();
}
+void ByteCodeExecutor::executeGetResults() {
+ LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
+ unsigned index = read<uint32_t>();
+ Operation *op = read<Operation *>();
+ ByteCodeField rangeIndex = read();
+
+ void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
+ op->getResults(), op, index, rangeIndex, "result_segment_sizes",
+ valueRangeMemory);
+ if (!result)
+ LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n");
+ memory[read()] = result;
+}
+
void ByteCodeExecutor::executeGetValueType() {
LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
unsigned memIndex = read();
@@ -1079,6 +1523,28 @@ void ByteCodeExecutor::executeGetValueType() {
memory[memIndex] = type.getAsOpaquePointer();
}
+void ByteCodeExecutor::executeGetValueRangeTypes() {
+ LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
+ unsigned memIndex = read();
+ unsigned rangeIndex = read();
+ ValueRange *values = read<ValueRange *>();
+ if (!values) {
+ LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n");
+ memory[memIndex] = nullptr;
+ return;
+ }
+
+ LLVM_DEBUG({
+ llvm::dbgs() << " * Values (" << values->size() << "): ";
+ llvm::interleaveComma(*values, llvm::dbgs());
+ llvm::dbgs() << "\n * Result: ";
+ llvm::interleaveComma(values->getType(), llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+ typeRangeMemory[rangeIndex] = values->getType();
+ memory[memIndex] = &typeRangeMemory[rangeIndex];
+}
+
void ByteCodeExecutor::executeIsNotNull() {
LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
const void *value = read<const void *>();
@@ -1117,7 +1583,30 @@ void ByteCodeExecutor::executeRecordMatch(
LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n"
<< " * Location: " << matchLoc << "\n");
matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
- readList<const void *>(matches.back().values);
+ PDLByteCode::MatchResult &match = matches.back();
+
+ // Record all of the inputs to the match. If any of the inputs are ranges, we
+ // will also need to remap the range pointer to memory stored in the match
+ // state.
+ unsigned numInputs = read();
+ match.values.reserve(numInputs);
+ match.typeRangeValues.reserve(numInputs);
+ match.valueRangeValues.reserve(numInputs);
+ for (unsigned i = 0; i < numInputs; ++i) {
+ switch (read<PDLValue::Kind>()) {
+ case PDLValue::Kind::TypeRange:
+ match.typeRangeValues.push_back(*read<TypeRange *>());
+ match.values.push_back(&match.typeRangeValues.back());
+ break;
+ case PDLValue::Kind::ValueRange:
+ match.valueRangeValues.push_back(*read<ValueRange *>());
+ match.values.push_back(&match.valueRangeValues.back());
+ break;
+ default:
+ match.values.push_back(read<const void *>());
+ break;
+ }
+ }
curCodeIt = dest;
}
@@ -1125,7 +1614,7 @@ void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
Operation *op = read<Operation *>();
SmallVector<Value, 16> args;
- readList<Value>(args);
+ readValueList(args);
LLVM_DEBUG({
llvm::dbgs() << " * Operation: " << *op << "\n"
@@ -1198,6 +1687,19 @@ void ByteCodeExecutor::executeSwitchType() {
handleSwitch(value, cases);
}
+void ByteCodeExecutor::executeSwitchTypes() {
+ LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
+ TypeRange *value = read<TypeRange *>();
+ auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
+ if (!value) {
+ LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
+ return selectJump(size_t(0));
+ }
+ handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
+ return value == caseValue.getAsValueRange<TypeAttr>();
+ });
+}
+
void ByteCodeExecutor::execute(
PatternRewriter &rewriter,
SmallVectorImpl<PDLByteCode::MatchResult> *matches,
@@ -1214,6 +1716,9 @@ void ByteCodeExecutor::execute(
case AreEqual:
executeAreEqual();
break;
+ case AreRangesEqual:
+ executeAreRangesEqual();
+ break;
case Branch:
executeBranch();
break;
@@ -1226,9 +1731,15 @@ void ByteCodeExecutor::execute(
case CheckResultCount:
executeCheckResultCount();
break;
+ case CheckTypes:
+ executeCheckTypes();
+ break;
case CreateOperation:
executeCreateOperation(rewriter, *mainRewriteLoc);
break;
+ case CreateTypes:
+ executeCreateTypes();
+ break;
case EraseOp:
executeEraseOp(rewriter);
break;
@@ -1257,6 +1768,9 @@ void ByteCodeExecutor::execute(
LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
executeGetOperand(read<uint32_t>());
break;
+ case GetOperands:
+ executeGetOperands();
+ break;
case GetResult0:
case GetResult1:
case GetResult2:
@@ -1270,9 +1784,15 @@ void ByteCodeExecutor::execute(
LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
executeGetResult(read<uint32_t>());
break;
+ case GetResults:
+ executeGetResults();
+ break;
case GetValueType:
executeGetValueType();
break;
+ case GetValueRangeTypes:
+ executeGetValueRangeTypes();
+ break;
case IsNotNull:
executeIsNotNull();
break;
@@ -1299,6 +1819,9 @@ void ByteCodeExecutor::execute(
case SwitchType:
executeSwitchType();
break;
+ case SwitchTypes:
+ executeSwitchTypes();
+ break;
}
LLVM_DEBUG(llvm::dbgs() << "\n");
}
@@ -1313,9 +1836,12 @@ void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
state.memory[0] = op;
// The matcher function always starts at code address 0.
- ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData,
- matcherByteCode, state.currentPatternBenefits,
- patterns, constraintFunctions, rewriteFunctions);
+ ByteCodeExecutor executor(
+ matcherByteCode.data(), state.memory, state.typeRangeMemory,
+ state.allocatedTypeRangeMemory, state.valueRangeMemory,
+ state.allocatedValueRangeMemory, uniquedData, matcherByteCode,
+ state.currentPatternBenefits, patterns, constraintFunctions,
+ rewriteFunctions);
executor.execute(rewriter, &matches);
// Order the found matches by benefit.
@@ -1332,9 +1858,11 @@ void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
// memory buffer.
llvm::copy(match.values, state.memory.begin());
- ByteCodeExecutor executor(&rewriterByteCode[match.pattern->getRewriterAddr()],
- state.memory, uniquedData, rewriterByteCode,
- state.currentPatternBenefits, patterns,
- constraintFunctions, rewriteFunctions);
+ ByteCodeExecutor executor(
+ &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
+ state.typeRangeMemory, state.allocatedTypeRangeMemory,
+ state.valueRangeMemory, state.allocatedValueRangeMemory, uniquedData,
+ rewriterByteCode, state.currentPatternBenefits, patterns,
+ constraintFunctions, rewriteFunctions);
executor.execute(rewriter, /*matches=*/nullptr, match.location);
}