diff options
author | George Mitenkov <georgemitenk0v@gmail.com> | 2020-07-15 10:02:01 +0300 |
---|---|---|
committer | George Mitenkov <georgemitenk0v@gmail.com> | 2020-07-15 10:29:46 +0300 |
commit | d431951343cdaa301cbd72743fde8114b93f9d33 (patch) | |
tree | 8c3e43a3ab599c7ec735ea930bfc344d2a18eec5 | |
parent | 1919c8bfe8379402401da52d84d5397233cab8b9 (diff) | |
download | llvm-d431951343cdaa301cbd72743fde8114b93f9d33.zip llvm-d431951343cdaa301cbd72743fde8114b93f9d33.tar.gz llvm-d431951343cdaa301cbd72743fde8114b93f9d33.tar.bz2 |
[MLIR][SPIRVToLLVM] SPIRV function fix and nits
This patch addresses the comments from https://reviews.llvm.org/D83030 and
https://reviews.llvm.org/D82639. `this->` is removed when not inside the
template. Also, type conversion for `spv.func` takes `convertRegionTypes()`
in order to apply type conversion on all blocks within the function.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D83786
-rw-r--r-- | mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp | 18 |
1 files changed, 11 insertions, 7 deletions
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp index 5820d90..b070291 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp @@ -83,11 +83,12 @@ static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) { /// Creates `llvm.mlir.constant` with all bits set for the given type. static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter) { - if (srcType.isa<VectorType>()) + if (srcType.isa<VectorType>()) { return rewriter.create<LLVM::ConstantOp>( loc, dstType, SplatElementsAttr::get(srcType.cast<ShapedType>(), minusOneIntegerAttribute(srcType, rewriter))); + } return rewriter.create<LLVM::ConstantOp>( loc, dstType, minusOneIntegerAttribute(srcType, rewriter)); } @@ -239,7 +240,7 @@ public: matchAndRewrite(spirv::BitFieldInsertOp op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { auto srcType = op.getType(); - auto dstType = this->typeConverter.convertType(srcType); + auto dstType = typeConverter.convertType(srcType); if (!dstType) return failure(); Location loc = op.getLoc(); @@ -328,7 +329,7 @@ public: matchAndRewrite(spirv::BitFieldSExtractOp op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { auto srcType = op.getType(); - auto dstType = this->typeConverter.convertType(srcType); + auto dstType = typeConverter.convertType(srcType); if (!dstType) return failure(); Location loc = op.getLoc(); @@ -381,7 +382,7 @@ public: matchAndRewrite(spirv::BitFieldUExtractOp op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { auto srcType = op.getType(); - auto dstType = this->typeConverter.convertType(srcType); + auto dstType = typeConverter.convertType(srcType); if (!dstType) return failure(); Location loc = op.getLoc(); @@ -473,7 +474,7 @@ public: } // Function returns a single result. - auto dstType = this->typeConverter.convertType(callOp.getType(0)); + auto dstType = typeConverter.convertType(callOp.getType(0)); rewriter.replaceOpWithNewOp<LLVM::CallOp>(callOp, dstType, operands, callOp.getAttrs()); return success(); @@ -638,7 +639,7 @@ public: auto funcType = funcOp.getType(); TypeConverter::SignatureConversion signatureConverter( funcType.getNumInputs()); - auto llvmType = this->typeConverter.convertFunctionSignature( + auto llvmType = typeConverter.convertFunctionSignature( funcOp.getType(), /*isVariadic=*/false, signatureConverter); if (!llvmType) return failure(); @@ -675,7 +676,10 @@ public: rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); - rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter); + if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, + &signatureConverter))) { + return failure(); + } rewriter.eraseOp(funcOp); return success(); } |