author | Matthias Springer <me@m-sp.org> | 2023-12-07 08:47:20 +0900 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-07 08:47:20 +0900 |
commit | 861600f1751b1a7e84cd99dd79361569542e9c1a (patch) | |
tree | 2c9ca8d56f3f146761ed98ef1ce96801d54dbd44 /mlir | |
parent | 851f85fffb25143c267dcdbf6acd1916321ad308 (diff) | |
[mlir][SparseTensor] Fix invalid IR in `ForallRewriter` pattern (#74547)
The `ForallRewriter` pattern used to generate invalid IR:
```
mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir:0:0: error: 'scf.for' op expects region #0 to have 0 or 1 blocks
mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir:0:0: note: see current operation:
"scf.for"(%8, %2, %9) ({
^bb0(%arg5: index):
// ...
"scf.yield"() : () -> ()
^bb1(%10: index): // no predecessors
"scf.yield"() : () -> ()
}) : (index, index, index) -> ()
```
This commit fixes tests such as
`mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir` when verifying the
IR after each pattern application (#74270).
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp | 4 |
1 files changed, 4 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 927fc71..5155cab 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -309,6 +309,10 @@ static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
   // }
   Value upper = irMap.lookup(forallOp.getUpperBound()[0]);
   scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, row, upper, inc);
+  // The scf.for builder creates an empty block. scf.for does not allow multiple
+  // blocks in its region, so delete the block before `cloneRegionBefore` adds
+  // an additional block.
+  rewriter.eraseBlock(forOp.getBody());
   rewriter.cloneRegionBefore(forallOp.getRegion(), forOp.getRegion(),
                              forOp.getRegion().begin(), irMap);
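
The failure mode this patch addresses can be illustrated with a toy model in plain C++. This is only a sketch and does not use MLIR: `Block`, `Region`, `buildForRegion`, `cloneRegionBefore`, `verifyForRegion`, and `loopRegionIsValid` below are hypothetical stand-ins written for this illustration. They assume only what the commit message states: the `scf.for` builder leaves one empty block in the region, cloning adds another block in front of it, and the verifier accepts 0 or 1 blocks.

```cpp
#include <list>
#include <string>

// Toy stand-ins for illustration only; these are NOT the MLIR classes.
struct Block {
  std::string label;
};

struct Region {
  std::list<Block> blocks;
};

// Models the verifier rule quoted in the error message:
// the loop region must have 0 or 1 blocks.
bool verifyForRegion(const Region &region) {
  return region.blocks.size() <= 1;
}

// Models the loop builder, which leaves one empty body block behind.
Region buildForRegion() {
  Region region;
  region.blocks.push_back(Block{"empty block from builder"});
  return region;
}

// Models cloning a region: copies every block of `source` into `dest`,
// inserting the copies before the position `before`.
void cloneRegionBefore(const Region &source, Region &dest,
                       std::list<Block>::iterator before) {
  dest.blocks.insert(before, source.blocks.begin(), source.blocks.end());
}

// Builds the loop region, optionally erases the builder-created block,
// clones the forall body in, and reports whether the result verifies.
bool loopRegionIsValid(bool eraseBuilderBlockFirst) {
  Region forallRegion;
  forallRegion.blocks.push_back(Block{"forall body"});

  Region forRegion = buildForRegion();
  if (eraseBuilderBlockFirst)
    forRegion.blocks.clear(); // plays the role of the added eraseBlock call

  cloneRegionBefore(forallRegion, forRegion, forRegion.blocks.begin());
  return verifyForRegion(forRegion);
}
```

In this model, `loopRegionIsValid(false)` leaves two blocks in the loop region and fails the check, mirroring the `^bb0`/`^bb1` pair in the error output, while `loopRegionIsValid(true)` leaves exactly one block and passes.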