aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2023-12-07 08:47:20 +0900
committerGitHub <noreply@github.com>2023-12-07 08:47:20 +0900
commit861600f1751b1a7e84cd99dd79361569542e9c1a (patch)
tree2c9ca8d56f3f146761ed98ef1ce96801d54dbd44 /mlir
parent851f85fffb25143c267dcdbf6acd1916321ad308 (diff)
downloadllvm-861600f1751b1a7e84cd99dd79361569542e9c1a.zip
llvm-861600f1751b1a7e84cd99dd79361569542e9c1a.tar.gz
llvm-861600f1751b1a7e84cd99dd79361569542e9c1a.tar.bz2
[mlir][SparseTensor] Fix invalid IR in `ForallRewriter` pattern (#74547)
The `ForallRewriter` pattern used to generate invalid IR: ``` mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir:0:0: error: 'scf.for' op expects region #0 to have 0 or 1 blocks mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir:0:0: note: see current operation: "scf.for"(%8, %2, %9) ({ ^bb0(%arg5: index): // ... "scf.yield"() : () -> () ^bb1(%10: index): // no predecessors "scf.yield"() : () -> () }) : (index, index, index) -> () ``` This commit fixes tests such as `mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir` when verifying the IR after each pattern application (#74270).
Diffstat (limited to 'mlir')
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp4
1 files changed, 4 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 927fc71..5155cab 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -309,6 +309,10 @@ static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
// }
Value upper = irMap.lookup(forallOp.getUpperBound()[0]);
scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, row, upper, inc);
+ // The scf.for builder creates an empty block. scf.for does not allow multiple
+ // blocks in its region, so delete the block before `cloneRegionBefore` adds
+ // an additional block.
+ rewriter.eraseBlock(forOp.getBody());
rewriter.cloneRegionBefore(forallOp.getRegion(), forOp.getRegion(),
forOp.getRegion().begin(), irMap);