Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--  mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp | 96
-rw-r--r--  mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 2
-rw-r--r--  mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp | 4
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 6
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp | 4
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 3
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Loops.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp | 5
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 5
-rw-r--r--  mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 4
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp | 12
-rw-r--r--  mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 5
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 7
-rw-r--r--  mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 2
-rw-r--r--  mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 11
-rw-r--r--  mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 8
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp | 24
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp | 11
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp | 22
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h | 17
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 67
-rw-r--r--  mlir/lib/Dialect/Transform/IR/TransformDialect.cpp | 15
-rw-r--r--  mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 18
24 files changed, 207 insertions, 147 deletions
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
index 9196d2e..39e398b 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
@@ -170,7 +170,7 @@ public:
op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
op1.getAcc());
})
- .Default([&](auto) { llvm_unreachable("unexpected extend op!"); });
+ .DefaultUnreachable("unexpected extend op!");
} else if (kind == arm_sme::CombiningKind::Sub) {
TypeSwitch<Operation *>(extOp)
.Case<arith::ExtFOp>([&](auto) {
@@ -188,7 +188,7 @@ public:
op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
op1.getAcc());
})
- .Default([&](auto) { llvm_unreachable("unexpected extend op!"); });
+ .DefaultUnreachable("unexpected extend op!");
} else {
llvm_unreachable("unexpected arm_sme::CombiningKind!");
}
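
Note: DefaultUnreachable(msg) is the TypeSwitch shorthand adopted throughout this patch wherever a Default case existed only to assert. A minimal before/after sketch; SomeExtOp is a placeholder, not code from this commit:

  // Before: a Default case whose only job is to assert.
  TypeSwitch<Operation *>(extOp)
      .Case<SomeExtOp>([&](auto) { /* rewrite */ })
      .Default([&](auto) { llvm_unreachable("unexpected extend op!"); });

  // After: DefaultUnreachable drops the lambda and keeps the message.
  TypeSwitch<Operation *>(extOp)
      .Case<SomeExtOp>([&](auto) { /* rewrite */ })
      .DefaultUnreachable("unexpected extend op!");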
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index e30e094..25f941d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -23,6 +23,8 @@ namespace bufferization {
using namespace mlir;
using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
+using AllocDynamicSizesMap =
+ llvm::DenseMap<func::FuncOp, SmallVector<SmallVector<Value>>>;
/// Return `true` if the given MemRef type has a fully dynamic layout.
static bool hasFullyDynamicLayoutMap(MemRefType type) {
@@ -43,6 +45,50 @@ static bool hasStaticIdentityLayout(MemRefType type) {
return type.getLayout().isIdentity();
}
+/// Return the dynamic sizes of the `memref`, derived from its defining op. If
+/// the complete set of dynamic sizes cannot be captured, return an empty
+/// vector. Currently, only function block arguments are supported as size
+/// sources.
+static SmallVector<Value> getDynamicSize(Value memref, func::FuncOp funcOp) {
+ Operation *defOp = memref.getDefiningOp();
+ if (!defOp)
+ return {};
+ auto operands = defOp->getOperands();
+ SmallVector<Value> dynamicSizes;
+ for (Value size : operands) {
+ if (!isa<IndexType>(size.getType()))
+ continue;
+
+ BlockArgument sizeSrc = dyn_cast<BlockArgument>(size);
+ if (!sizeSrc)
+ return {};
+ auto arguments = funcOp.getArguments();
+ auto iter = llvm::find(arguments, sizeSrc);
+ if (iter == arguments.end())
+ return {};
+ dynamicSizes.push_back(*iter);
+ }
+ return dynamicSizes;
+}
+
+/// Map dynamic sizes captured in the callee back to the corresponding values
+/// at the call site, using the call's operand-to-argument correspondence.
+static SmallVector<Value> mapDynamicSizeAtCaller(func::CallOp call,
+ func::FuncOp callee,
+ ValueRange dynamicSizes) {
+ SmallVector<Value> mappedDynamicSizes;
+ for (Value size : dynamicSizes) {
+ for (auto [src, dst] :
+ llvm::zip_first(call.getOperands(), callee.getArguments())) {
+ if (size != dst)
+ continue;
+ mappedDynamicSizes.push_back(src);
+ }
+ }
+ assert(mappedDynamicSizes.size() == dynamicSizes.size() &&
+ "could not find all dynamic sizes");
+ return mappedDynamicSizes;
+}
+
// Updates the func op and entry block.
//
// Any args appended to the entry block are added to `appendedEntryArgs`.
@@ -109,6 +155,7 @@ updateFuncOp(func::FuncOp func,
// the given out-params.
static LogicalResult
updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
+ AllocDynamicSizesMap &map,
const bufferization::BufferResultsToOutParamsOpts &options) {
auto res = func.walk([&](func::ReturnOp op) {
SmallVector<Value, 6> copyIntoOutParams;
@@ -120,12 +167,22 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
keepAsReturnOperands.push_back(operand);
}
OpBuilder builder(op);
+ SmallVector<SmallVector<Value>> dynamicSizes;
for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
- if (options.hoistStaticAllocs &&
+ bool hoistStaticAllocs =
+ options.hoistStaticAllocs &&
+ cast<MemRefType>(orig.getType()).hasStaticShape();
+ bool hoistDynamicAllocs =
+ options.hoistDynamicAllocs &&
+ !cast<MemRefType>(orig.getType()).hasStaticShape();
+ if ((hoistStaticAllocs || hoistDynamicAllocs) &&
isa_and_nonnull<bufferization::AllocationOpInterface>(
- orig.getDefiningOp()) &&
- mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
+ orig.getDefiningOp())) {
orig.replaceAllUsesWith(arg);
+ if (hoistDynamicAllocs) {
+ SmallVector<Value> dynamicSize = getDynamicSize(orig, func);
+ dynamicSizes.push_back(dynamicSize);
+ }
orig.getDefiningOp()->erase();
} else {
if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg)))
@@ -134,6 +191,10 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
}
func::ReturnOp::create(builder, op.getLoc(), keepAsReturnOperands);
op.erase();
+ auto dynamicSizePair =
+ std::pair<func::FuncOp, SmallVector<SmallVector<Value>>>(func,
+ dynamicSizes);
+ map.insert(dynamicSizePair);
return WalkResult::advance();
});
return failure(res.wasInterrupted());
@@ -142,7 +203,7 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
// Updates all CallOps in the scope of the given ModuleOp by allocating
// temporary buffers for newly introduced out params.
static LogicalResult
-updateCalls(ModuleOp module,
+updateCalls(ModuleOp module, const AllocDynamicSizesMap &map,
const bufferization::BufferResultsToOutParamsOpts &options) {
bool didFail = false;
SymbolTable symtab(module);
@@ -166,8 +227,15 @@ updateCalls(ModuleOp module,
}
SmallVector<Value, 6> outParams;
OpBuilder builder(op);
+ SmallVector<SmallVector<Value>> dynamicSizes = map.lookup(callee);
+ size_t dynamicSizesIndex = 0;
for (Value memref : replaceWithOutParams) {
- if (!cast<MemRefType>(memref.getType()).hasStaticShape()) {
+ SmallVector<Value> dynamicSize = dynamicSizes.size() > dynamicSizesIndex
+ ? dynamicSizes[dynamicSizesIndex]
+ : SmallVector<Value>();
+ bool memrefStaticShape =
+ cast<MemRefType>(memref.getType()).hasStaticShape();
+ if (!memrefStaticShape && dynamicSize.empty()) {
op.emitError()
<< "cannot create out param for dynamically shaped result";
didFail = true;
@@ -177,8 +245,15 @@ updateCalls(ModuleOp module,
auto allocType =
MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
AffineMap(), memrefType.getMemorySpace());
+
+ if (memrefStaticShape) {
+ dynamicSize = {};
+ } else {
+ ++dynamicSizesIndex;
+ dynamicSize = mapDynamicSizeAtCaller(op, callee, dynamicSize);
+ }
auto maybeOutParam =
- options.allocationFn(builder, op.getLoc(), allocType);
+ options.allocationFn(builder, op.getLoc(), allocType, dynamicSize);
if (failed(maybeOutParam)) {
op.emitError() << "failed to create allocation op";
didFail = true;
@@ -213,6 +288,9 @@ updateCalls(ModuleOp module,
LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
ModuleOp module,
const bufferization::BufferResultsToOutParamsOpts &options) {
+ // Maps each function to the shape sources of the dynamically shaped memrefs
+ // it returns.
+ AllocDynamicSizesMap map;
for (auto func : module.getOps<func::FuncOp>()) {
if (!options.filterFn(&func))
continue;
@@ -222,11 +300,11 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
return failure();
if (func.isExternal())
continue;
- if (failed(updateReturnOps(func, appendedEntryArgs, options))) {
+ if (failed(updateReturnOps(func, appendedEntryArgs, map, options))) {
return failure();
}
}
- if (failed(updateCalls(module, options)))
+ if (failed(updateCalls(module, map, options)))
return failure();
return success();
}
@@ -243,6 +321,8 @@ struct BufferResultsToOutParamsPass
options.addResultAttribute = true;
if (hoistStaticAllocs)
options.hoistStaticAllocs = true;
+ if (hoistDynamicAllocs)
+ options.hoistDynamicAllocs = true;
if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
options)))
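
Note: a hedged sketch of what the new hoist-dynamic-allocs path is meant to achieve (illustrative IR, not a test from this commit). A dynamically shaped result whose allocation sizes come from function arguments is hoisted into an out-param, and the caller allocates the buffer using the sizes mapped through the call operands:

  // Callee before: returns a dynamically sized allocation.
  func.func @producer(%n: index) -> memref<?xf32> {
    %alloc = memref.alloc(%n) : memref<?xf32>
    return %alloc : memref<?xf32>
  }

  // Callee after: the result becomes an out-param; the alloc is erased and
  // its dynamic size (%n, a block argument) is recorded in the map.
  func.func @producer(%n: index, %out: memref<?xf32>) {
    return
  }

  // Caller after: updateCalls() allocates with the caller-side value found by
  // mapDynamicSizeAtCaller(), then passes the buffer as the out-param.
  %buf = memref.alloc(%n) : memref<?xf32>
  call @producer(%n, %buf) : (index, memref<?xf32>) -> ()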
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index c0f9132..19eba6b 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -375,7 +375,7 @@ void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
os << shape.back() << 'x' << fragTy.getElementType();
os << ", \"" << fragTy.getOperand() << "\"" << '>';
})
- .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
+ .DefaultUnreachable("unexpected 'gpu' type kind");
}
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 2561f66..0a3ef7d 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -847,9 +847,7 @@ getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
return GpuLaneIdBuilder(ctx, warpSize, useLinearMapping,
*maybeMaskingAttr);
})
- .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
- llvm_unreachable("unknown mapping attribute");
- });
+ .DefaultUnreachable("unknown mapping attribute");
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index ef38027..cee943d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -1096,10 +1096,8 @@ static Value memsetGetStored(MemsetIntr op, const MemorySlot &slot,
Value intVal = buildMemsetValue(type.getWidth());
return LLVM::BitcastOp::create(builder, op.getLoc(), type, intVal);
})
- .Default([](Type) -> Value {
- llvm_unreachable(
- "getStored should not be called on memset to unsupported type");
- });
+ .DefaultUnreachable(
+ "getStored should not be called on memset to unsupported type");
}
template <class MemsetIntr>
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index 297640c..705d07d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -45,9 +45,7 @@ static StringRef getTypeKeyword(Type type) {
.Case<LLVMStructType>([&](Type) { return "struct"; })
.Case<LLVMTargetExtType>([&](Type) { return "target"; })
.Case<LLVMX86AMXType>([&](Type) { return "x86_amx"; })
- .Default([](Type) -> StringRef {
- llvm_unreachable("unexpected 'llvm' type kind");
- });
+ .DefaultUnreachable("unexpected 'llvm' type kind");
}
/// Prints a structure type. Keeps track of known struct names to handle self-
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 682bf8c..e8f8824 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -648,6 +648,9 @@ LogicalResult MmaOp::verify() {
expectedB.emplace_back(unitB, multiplicandFragType);
allowedShapes.push_back({16, 8, kFactor});
allowedShapes.push_back({16, 8, kFactor * 2});
+
+ if (resultPtxType() != accumPtxType())
+ return emitOpError("ctype does not match dtype");
}
// In the M=8 case, there is only 1 possible case per data type.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 38f1a8b..42160a1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -192,7 +192,7 @@ static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter,
.Case([&](affine::AffineForOp affineForOp) {
allIvs.push_back(affineForOp.getInductionVar());
})
- .Default([&](Operation *op) { assert(false && "unexpected op"); });
+ .DefaultUnreachable("unexpected op");
}
assert(linalgOp.getNumLoops() == allIvs.size() &&
"expected the number of loops and induction variables to match");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
index 00a076b..c904556 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
@@ -48,10 +48,7 @@ ElementwiseKind getKind(Operation *op) {
.Case([](SquareOp) { return ElementwiseKind::square; })
.Case([](TanhOp) { return ElementwiseKind::tanh; })
.Case([](ErfOp) { return ElementwiseKind::erf; })
- .Default([&](Operation *op) {
- llvm_unreachable("unhandled case in named to elementwise");
- return ElementwiseKind::sub;
- });
+ .DefaultUnreachable("unhandled case in named to elementwise");
}
template <typename NamedOpTy>
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index e9a8b25..7863c21 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1427,10 +1427,7 @@ FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
.Case([&](linalg::PoolingNchwMaxOp op) {
return std::make_tuple(0, 1, 2, 3);
})
- .Default([&](Operation *op) {
- llvm_unreachable("unexpected conv2d/pool2d operation.");
- return std::make_tuple(0, 0, 0, 0);
- });
+ .DefaultUnreachable("unexpected conv2d/pool2d operation.");
// Only handle the case where at least one of the window dimensions is
// of size 1. Other cases can rely on tiling to reduce to such cases.
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 3593b53..24d3722 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -604,9 +604,7 @@ static Operation *materializeTiledShape(OpBuilder &builder, Location loc,
builder, loc, valueToTile, sliceParams.offsets,
sliceParams.sizes, sliceParams.strides);
})
- .Default([](ShapedType) -> Operation * {
- llvm_unreachable("Unexpected shaped type");
- });
+ .DefaultUnreachable("Unexpected shaped type");
return sliceOp;
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 24da447..214410f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -315,7 +315,7 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
op, op.getType(), subViewOp.getSource(), sourceIndices,
op.getTranspose(), op.getNumTiles());
})
- .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
+ .DefaultUnreachable("unexpected operation");
return success();
}
@@ -367,7 +367,7 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
op.getMask(), op.getPassThru());
})
- .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
+ .DefaultUnreachable("unexpected operation");
return success();
}
@@ -415,7 +415,7 @@ LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
op.getMask(), op.getPassThru());
})
- .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
+ .DefaultUnreachable("unexpected operation");
return success();
}
@@ -482,7 +482,7 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
op, op.getSrc(), subViewOp.getSource(), sourceIndices,
op.getLeadDimension(), op.getTransposeAttr());
})
- .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
+ .DefaultUnreachable("unexpected operation");
return success();
}
@@ -535,7 +535,7 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
op.getValueToStore());
})
- .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
+ .DefaultUnreachable("unexpected operation");
return success();
}
@@ -584,7 +584,7 @@ LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
op.getValueToStore());
})
- .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
+ .DefaultUnreachable("unexpected operation");
return success();
}
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 5672942..fd4cabbad 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -3425,10 +3425,7 @@ void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
}
llvm_unreachable("Unexpected generatee argument");
})
- .Default([&](Operation *op) {
- assert(false && "TODO: Custom name for this operation");
- return "transformed";
- });
+ .DefaultUnreachable("TODO: Custom name for this operation");
}
setNameFn(result, cliName);
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 36685d3..29b770f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -2177,10 +2177,9 @@ cloneAsInsertSlices(RewriterBase &rewriter,
auto clonedOp = cloneAsInsertSlice(rewriter, op);
clonedSlices.push_back(clonedOp);
})
- .Default([&](Operation *op) {
- // Assert here assuming this has already been checked.
- assert(0 && "unexpected slice type while cloning as insert slice");
- });
+ // Unreachable: the slice op kind has already been checked by the caller.
+ .DefaultUnreachable(
+ "unexpected slice type while cloning as insert slice");
}
return clonedSlices;
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index c8efdf0..24c33f9 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -987,7 +987,7 @@ void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
.Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
ImageType, SampledImageType, StructType, MatrixType, TensorArmType>(
[&](auto type) { print(type, os); })
- .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
+ .DefaultUnreachable("Unhandled SPIR-V type");
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 7e9a80e..f895807 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -57,7 +57,7 @@ public:
for (Type elementType : concreteType.getElementTypes())
add(elementType);
})
- .Default([](SPIRVType) { llvm_unreachable("Unhandled type"); });
+ .DefaultUnreachable("Unhandled type");
}
void add(Type type) { add(cast<SPIRVType>(type)); }
@@ -107,7 +107,7 @@ public:
for (Type elementType : concreteType.getElementTypes())
add(elementType);
})
- .Default([](SPIRVType) { llvm_unreachable("Unhandled type"); });
+ .DefaultUnreachable("Unhandled type");
}
void add(Type type) { add(cast<SPIRVType>(type)); }
@@ -198,8 +198,7 @@ Type CompositeType::getElementType(unsigned index) const {
.Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
.Case<StructType>(
[index](StructType type) { return type.getElementType(index); })
- .Default(
- [](Type) -> Type { llvm_unreachable("invalid composite type"); });
+ .DefaultUnreachable("Invalid composite type");
}
unsigned CompositeType::getNumElements() const {
@@ -207,9 +206,7 @@ unsigned CompositeType::getNumElements() const {
.Case<ArrayType, StructType, TensorArmType, VectorType>(
[](auto type) { return type.getNumElements(); })
.Case<MatrixType>([](MatrixType type) { return type.getNumColumns(); })
- .Default([](SPIRVType) -> unsigned {
- llvm_unreachable("Invalid type for number of elements query");
- });
+ .DefaultUnreachable("Invalid type for number of elements query");
}
bool CompositeType::hasCompileTimeKnownNumElements() const {
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 122f61e0..88e1ab6 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -622,7 +622,7 @@ static spirv::Dim convertRank(int64_t rank) {
}
static spirv::ImageFormat getImageFormat(Type elementType) {
- return llvm::TypeSwitch<Type, spirv::ImageFormat>(elementType)
+ return TypeSwitch<Type, spirv::ImageFormat>(elementType)
.Case<Float16Type>([](Float16Type) { return spirv::ImageFormat::R16f; })
.Case<Float32Type>([](Float32Type) { return spirv::ImageFormat::R32f; })
.Case<IntegerType>([](IntegerType intType) {
@@ -639,11 +639,7 @@ static spirv::ImageFormat getImageFormat(Type elementType) {
llvm_unreachable("Unhandled integer type!");
}
})
- .Default([](Type) {
- llvm_unreachable("Unhandled element type!");
- // We need to return something here to satisfy the type switch.
- return spirv::ImageFormat::R32f;
- });
+ .DefaultUnreachable("Unhandled element type!");
#undef BIT_WIDTH_CASE
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index a1e35b8..0fc5cc7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -59,7 +59,7 @@ struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
// Flattens an affine expression into a list of AffineDimExprs.
struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
- explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
+ explicit AffineDimCollector(unsigned dimNum) : dims(dimNum) {};
void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
BitVector dims;
};
@@ -67,7 +67,7 @@ struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
// Flattens an affine expression into a list of AffineDimExprs.
struct AffineExprAdmissibleVisitor
: public AffineExprVisitor<AffineExprAdmissibleVisitor> {
- explicit AffineExprAdmissibleVisitor(bool isOutput) : isOutput(isOutput){};
+ explicit AffineExprAdmissibleVisitor(bool isOutput) : isOutput(isOutput) {};
// We only allow AffineDimExpr on output.
void visitAddExpr(AffineBinaryOpExpr expr) {
@@ -407,7 +407,10 @@ public:
};
struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
- using OpRewritePattern::OpRewritePattern;
+ GenericOpScheduler(MLIRContext *context,
+ sparse_tensor::LoopOrderingStrategy strategy)
+ : OpRewritePattern<linalg::GenericOp>(context), strategy(strategy) {}
+
LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
PatternRewriter &rewriter) const override {
if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
@@ -420,7 +423,8 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
if (linalgOp->hasAttr(sorted))
return failure();
- auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp);
+ // Pass strategy to IterationGraphSorter.
+ auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp, strategy);
bool isAdmissible = false;
AffineMap order;
// A const list of all masks that we used for iteration graph
@@ -582,6 +586,9 @@ private:
// TODO: convert more than one?
return failure();
}
+
+private:
+ sparse_tensor::LoopOrderingStrategy strategy;
};
//===----------------------------------------------------------------------===//
@@ -786,12 +793,13 @@ struct ForeachOpDemapper
} // namespace
-void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
- ReinterpretMapScope scope) {
+void mlir::populateSparseReinterpretMap(
+ RewritePatternSet &patterns, ReinterpretMapScope scope,
+ sparse_tensor::LoopOrderingStrategy strategy) {
if (scope == ReinterpretMapScope::kAll ||
scope == ReinterpretMapScope::kGenericOnly) {
- patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
- patterns.getContext());
+ patterns.add<GenericOpReinterpretMap>(patterns.getContext());
+ patterns.add<GenericOpScheduler>(patterns.getContext(), strategy);
}
if (scope == ReinterpretMapScope::kAll ||
scope == ReinterpretMapScope::kExceptGeneric) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 153b9b1..b660e22 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -67,12 +67,13 @@ struct SparseReinterpretMap
SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
scope = options.scope;
+ loopOrderingStrategy = options.loopOrderingStrategy;
}
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
- populateSparseReinterpretMap(patterns, scope);
+ populateSparseReinterpretMap(patterns, scope, loopOrderingStrategy);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
@@ -438,6 +439,14 @@ mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
return std::make_unique<SparseReinterpretMap>(options);
}
+std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass(
+ ReinterpretMapScope scope, sparse_tensor::LoopOrderingStrategy strategy) {
+ SparseReinterpretMapOptions options;
+ options.scope = scope;
+ options.loopOrderingStrategy = strategy;
+ return std::make_unique<SparseReinterpretMap>(options);
+}
+
std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
return std::make_unique<PreSparsificationRewritePass>();
}
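
Note: a hedged usage sketch for the new overload; pm is an assumed PassManager, and the enumerators are the ones declared in this patch:

  // Schedule the pass with an explicit loop ordering strategy.
  pm.addPass(mlir::createSparseReinterpretMapPass(
      ReinterpretMapScope::kGenericOnly,
      sparse_tensor::LoopOrderingStrategy::kDefault));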
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
index c7e463a..73e0f3d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
@@ -100,7 +100,15 @@ AffineMap IterationGraphSorter::topoSort() {
// We always prefer a parallel loop over a reduction loop because putting
// a reduction loop early might make the loop sequence inadmissible.
auto &it = !parIt.empty() ? parIt : redIt;
- auto src = it.back();
+
+ // Select loop based on strategy.
+ unsigned src;
+ switch (strategy) {
+ case sparse_tensor::LoopOrderingStrategy::kDefault:
+ src = it.back();
+ break;
+ }
+
loopOrder.push_back(src);
it.pop_back();
// Update in-degree, and push 0-degree node into worklist.
@@ -122,8 +130,8 @@ AffineMap IterationGraphSorter::topoSort() {
return AffineMap();
}
-IterationGraphSorter
-IterationGraphSorter::fromGenericOp(linalg::GenericOp genericOp) {
+IterationGraphSorter IterationGraphSorter::fromGenericOp(
+ linalg::GenericOp genericOp, sparse_tensor::LoopOrderingStrategy strategy) {
// Must be a demapped sparse kernel.
assert(!hasAnyNonIdentityOperandsOrResults(genericOp) &&
hasAnySparseOperandOrResult(genericOp) &&
@@ -140,14 +148,16 @@ IterationGraphSorter::fromGenericOp(linalg::GenericOp genericOp) {
genericOp.getIteratorTypesArray();
return IterationGraphSorter(std::move(ins), std::move(loopMap), out, outMap,
- std::move(iterTypes));
+ std::move(iterTypes), strategy);
}
IterationGraphSorter::IterationGraphSorter(
SmallVector<Value> &&ins, SmallVector<AffineMap> &&loop2InsLvl, Value out,
- AffineMap loop2OutLvl, SmallVector<utils::IteratorType> &&iterTypes)
+ AffineMap loop2OutLvl, SmallVector<utils::IteratorType> &&iterTypes,
+ sparse_tensor::LoopOrderingStrategy strategy)
: ins(std::move(ins)), loop2InsLvl(std::move(loop2InsLvl)), out(out),
- loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)) {
+ loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)),
+ strategy(strategy) {
// One map per tensor.
assert(loop2InsLvl.size() == ins.size());
// All the affine maps have the same number of dimensions (loops).
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
index a6abe9e..b2a16e9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
@@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_ITERATIONGRAPHSORTER_H_
#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_ITERATIONGRAPHSORTER_H_
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/AffineMap.h"
namespace mlir {
@@ -41,9 +42,12 @@ enum class SortMask : unsigned {
class IterationGraphSorter {
public:
- /// Factory method that construct an iteration graph sorter
- /// for the given linalg.generic operation.
- static IterationGraphSorter fromGenericOp(linalg::GenericOp genericOp);
+ /// Factory method that constructs an iteration graph sorter
+ /// for the given linalg.generic operation with a specific loop ordering
+ /// strategy.
+ static IterationGraphSorter
+ fromGenericOp(linalg::GenericOp genericOp,
+ sparse_tensor::LoopOrderingStrategy strategy);
/// Returns a permutation that represents the scheduled loop order.
/// Note that the returned AffineMap could be null if the kernel
@@ -58,7 +62,9 @@ private:
IterationGraphSorter(SmallVector<Value> &&ins,
SmallVector<AffineMap> &&loop2InsLvl, Value out,
AffineMap loop2OutLvl,
- SmallVector<utils::IteratorType> &&iterTypes);
+ SmallVector<utils::IteratorType> &&iterTypes,
+ sparse_tensor::LoopOrderingStrategy strategy =
+ sparse_tensor::LoopOrderingStrategy::kDefault);
// Adds all the constraints in the given loop to level map.
void addConstraints(Value t, AffineMap loop2LvlMap);
@@ -84,6 +90,9 @@ private:
// InDegree used for topo sort.
std::vector<unsigned> inDegree;
+
+ // Loop ordering strategy.
+ sparse_tensor::LoopOrderingStrategy strategy;
};
} // namespace sparse_tensor
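
Note: LoopOrderingStrategy currently has the single kDefault enumerator, so the selection switch in topoSort() has one case. A sketch of how a further strategy would slot in; kReverse is hypothetical, not part of this patch:

  switch (strategy) {
  case sparse_tensor::LoopOrderingStrategy::kDefault:
    src = it.back(); // current behavior: pick the most recently added node
    break;
  case sparse_tensor::LoopOrderingStrategy::kReverse: // hypothetical
    src = it.front(); // pick the oldest ready node instead
    break;
  }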
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 332f1a0..c51b5e9 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -905,56 +905,29 @@ static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
return shapeAdaptor.getNumElements() == 1 ? success() : failure();
}
-// Returns the first declaration point prior to this operation or failure if
-// not found.
-static FailureOr<tosa::VariableOp> findVariableDecl(Operation *op,
- StringRef symName) {
- ModuleOp module = op->getParentOfType<ModuleOp>();
- tosa::VariableOp varOp = nullptr;
-
- // TODO: Adopt SymbolTable trait to Varible ops.
- // Currently, the variable's definition point is searched via walk(),
- // starting from the top-level ModuleOp and stopping at the point of use. Once
- // TOSA control flow and variable extensions reach the complete state, may
- // leverage MLIR's Symbol Table functionality to look up symbol and enhance
- // the search to a TOSA specific graph traversal over the IR structure.
- module.walk([&](Operation *tempOp) {
- // Reach this op itself.
- if (tempOp == op) {
- return WalkResult::interrupt();
- }
-
- if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
- if (symName == tosaOp.getName()) {
- varOp = tosaOp;
- return WalkResult::interrupt();
- }
- }
-
- return WalkResult::advance();
- });
-
- if (varOp)
- return varOp;
-
- return failure();
-}
-
template <typename T>
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
- StringRef symName = op.getName();
- FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName);
- if (failed(varOp))
+ Operation *symTableOp =
+ op->template getParentWithTrait<OpTrait::SymbolTable>();
+ if (!symTableOp)
+ // If the operation is not enclosed in a symbol table scope, we cannot
+ // verify it against its declaration.
+ return success();
+
+ SymbolTable symTable(symTableOp);
+ const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());
+
+ // Verify prior declaration
+ if (!varOp)
return op->emitOpError("'")
- << symName << "' has not been declared by 'tosa.variable'";
+ << op.getName() << "' has not been declared by 'tosa.variable'";
// Verify type and shape
- auto variableType = getVariableType(varOp.value());
+ auto variableType = getVariableType(varOp);
if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
"the input tensor")
.failed())
return failure();
-
return success();
}
@@ -1418,7 +1391,7 @@ static void buildVariableOp(OpBuilder &builder, OperationState &result,
ArrayRef<int64_t> shape = shapedType.getShape();
auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
- result.addAttribute("name", nameAttr);
+ result.addAttribute("sym_name", nameAttr);
result.addAttribute("var_shape", varShapeAttr);
result.addAttribute("type", elementTypeAttr);
result.addAttribute("initial_value", initialValue);
@@ -4160,16 +4133,6 @@ LogicalResult tosa::SelectOp::verify() {
return success();
}
-LogicalResult tosa::VariableOp::verify() {
- StringRef symName = getName();
- FailureOr<tosa::VariableOp> varOp = findVariableDecl(*this, symName);
- if (succeeded(varOp))
- return emitOpError("illegal to have multiple declaration of '")
- << symName << "'";
-
- return success();
-}
-
LogicalResult tosa::VariableReadOp::verify() {
if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
.failed())
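
Note: the core of the TosaOps.cpp change is replacing the top-down module walk with a symbol-table lookup, which also makes the dedicated duplicate-declaration verifier redundant, since the SymbolTable trait already rejects duplicate symbol names. Condensed shape of the new lookup, mirroring the hunk above:

  // Find the nearest enclosing symbol table (typically the module).
  Operation *symTableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
  if (!symTableOp)
    return success(); // nothing to verify against

  // A symbol-table lookup replaces the previous walk to the point of use.
  SymbolTable symTable(symTableOp);
  auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());

The "sym_name" rename in buildVariableOp follows from the same adoption: symbol operations store their name in the standard sym_name attribute.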
diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index a500228..45cef9c1 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Transform/IR/Utils.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Verifier.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -140,6 +141,20 @@ LogicalResult transform::TransformDialect::verifyOperationAttribute(
"operations with symbol tables";
}
+ // Pre-verify calls and callables because call graph construction below
+ // assumes they are valid, but this verifier runs before verifying the
+ // nested operations.
+ WalkResult walkResult = op->walk([](Operation *nested) {
+ if (!isa<CallableOpInterface, CallOpInterface>(nested))
+ return WalkResult::advance();
+
+ if (failed(verify(nested, /*verifyRecursively=*/false)))
+ return WalkResult::interrupt();
+ return WalkResult::advance();
+ });
+ if (walkResult.wasInterrupted())
+ return failure();
+
const mlir::CallGraph callgraph(op);
for (auto scc = llvm::scc_begin(&callgraph); !scc.isAtEnd(); ++scc) {
if (!scc.hasCycle())
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 3385b2a..365afab 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -2097,17 +2097,11 @@ void transform::IncludeOp::getEffects(
getOperation(), getTarget());
if (!callee)
return defaultEffects();
- DiagnosedSilenceableFailure earlyVerifierResult =
- verifyNamedSequenceOp(callee, /*emitWarnings=*/false);
- if (!earlyVerifierResult.succeeded()) {
- (void)earlyVerifierResult.silence();
- return defaultEffects();
- }
for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
consumesHandle(getOperation()->getOpOperand(i), effects);
- else
+ else if (callee.getArgAttr(i, TransformDialect::kArgReadOnlyAttrName))
onlyReadsHandle(getOperation()->getOpOperand(i), effects);
}
}
@@ -2597,10 +2591,7 @@ transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
.Case([&](TransformParamTypeInterface param) {
return llvm::range_size(state.getParams(getHandle()));
})
- .Default([](Type) {
- llvm_unreachable("unknown kind of transform dialect type");
- return 0;
- });
+ .DefaultUnreachable("unknown kind of transform dialect type");
results.setParams(cast<OpResult>(getNum()),
rewriter.getI64IntegerAttr(numAssociations));
return DiagnosedSilenceableFailure::success();
@@ -2657,10 +2648,7 @@ transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
.Case<TransformParamTypeInterface>([&](auto x) {
return llvm::range_size(state.getParams(getHandle()));
})
- .Default([](auto x) {
- llvm_unreachable("unknown transform dialect type interface");
- return -1;
- });
+ .DefaultUnreachable("unknown transform dialect type interface");
auto produceNumOpsError = [&]() {
return emitSilenceableError()