diff options
Diffstat (limited to 'mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp')
-rw-r--r-- | mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 27 |
1 file changed, 23 insertions, 4 deletions
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index f405d0c..61166db 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -339,6 +339,25 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
 }
 
 //===----------------------------------------------------------------------===//
+// ScaledExtPacked816Op
+//===----------------------------------------------------------------------===//
+LogicalResult ScaledExtPacked816Op::verify() {
+  int blockSize = getBlockSize();
+  assert((blockSize == 16 || blockSize == 32) && "invalid block size");
+  int firstScaleByte = getFirstScaleByte();
+  if (blockSize == 16 && !llvm::is_contained({0, 1}, firstScaleByte)) {
+    return emitOpError(
+        "blockSize of 16 can only have firstScaleByte be 0 or 1.");
+  }
+  if (blockSize == 32 && !llvm::is_contained({0, 2}, firstScaleByte)) {
+    return emitOpError(
+        "blockSize of 32 can only have firstScaleByte be 0 or 2.");
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // WMMAOp
 //===----------------------------------------------------------------------===//
 LogicalResult WMMAOp::verify() {
@@ -757,13 +776,13 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
       offset = numElements - 4l;
     }
     Type scaleSrcElemType = scaleSrcType.getElementType();
-    auto newSrcType = VectorType::get(SmallVector<int64_t>({numElements}),
-                                      scaleSrcElemType);
+    auto newSrcType =
+        VectorType::get(ArrayRef{numElements}, scaleSrcElemType);
     Value newScaleSrc =
         vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
     auto extract = vector::ExtractStridedSliceOp::create(
-        rewriter, loc, newScaleSrc, ArrayRef<int64_t>{offset},
-        ArrayRef<int64_t>{size}, ArrayRef<int64_t>{1});
+        rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size},
+        ArrayRef{int64_t(1)});
     rewriter.modifyOpInPlace(op, [&] {
       op->setOperand(opIdx, extract);
       setOpsel(opIdx, opsel);