author | Jakub Kuderski <jakub@nod-labs.com> | 2024-04-01 11:40:09 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-01 11:40:09 -0400 |
commit | 971b852546a7d96bc8887ced913724b884cf40df (patch) | |
tree | 32c89d978a378b32e2647a3d9a0e05d2166ca42f | |
parent | a7206a6fa32ada15578e3afddcc1480364c25f4c (diff) | |
download | llvm-971b852546a7d96bc8887ced913724b884cf40df.zip llvm-971b852546a7d96bc8887ced913724b884cf40df.tar.gz llvm-971b852546a7d96bc8887ced913724b884cf40df.tar.bz2 |
[mlir][NFC] Simplify type checks with isa predicates (#87183)
For more context on isa predicates, see:
https://github.com/llvm/llvm-project/pull/83753.
28 files changed, 83 insertions, 118 deletions
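The mechanical pattern is the same across all the files below. As a minimal sketch (not code from this patch; `anyVectorResult` is a hypothetical helper): `llvm::IsaPred<Types...>`, added to `llvm/ADT/STLExtras.h` by the PR linked above, is a ready-made callable predicate equivalent to a lambda returning `isa<Types...>(x)`, so it can be passed directly to range algorithms such as `llvm::any_of`.

```cpp
// Minimal sketch of the rewrite, assuming an MLIR build environment.
// `anyVectorResult` is a hypothetical helper, not code from this patch.
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include <cassert>

static bool anyVectorResult(mlir::Operation *op) {
  // Before: a hand-written lambda wrapping isa<>.
  bool viaLambda = llvm::any_of(op->getResultTypes(), [](mlir::Type t) {
    return llvm::isa<mlir::VectorType>(t);
  });
  // After: the ready-made predicate. Like isa<A, B>(x), it also accepts
  // multiple types, e.g. llvm::IsaPred<TensorType, VectorType>.
  bool viaPred =
      llvm::any_of(op->getResultTypes(), llvm::IsaPred<mlir::VectorType>);
  assert(viaLambda == viaPred && "both spellings compute the same result");
  return viaPred;
}
```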
```diff
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 73d418c..993c09b 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -545,8 +545,7 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
                                       ConversionPatternRewriter &rewriter,
                                       const LLVMTypeConverter &converter) {
   TypeRange operandTypes(operands);
-  if (llvm::none_of(operandTypes,
-                    [](Type type) { return isa<VectorType>(type); })) {
+  if (llvm::none_of(operandTypes, llvm::IsaPred<VectorType>)) {
     return rewriter.notifyMatchFailure(op, "expected vector operand");
   }
   if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 85fb8a5..399c045 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -202,9 +202,7 @@ template <typename ExtOpTy>
 static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
   if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
     return false;
-  return llvm::all_of(extOp->getUsers(), [](Operation *user) {
-    return isa<vector::ContractionOp>(user);
-  });
+  return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
 }
 
 static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
@@ -345,15 +343,13 @@ getSliceContract(Operation *op,
 static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
                                              bool useNvGpu) {
   auto hasVectorDest = [](Operation *op) {
-    return llvm::any_of(op->getResultTypes(),
-                        [](Type t) { return isa<VectorType>(t); });
+    return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
   };
   BackwardSliceOptions backwardSliceOptions;
   backwardSliceOptions.filter = hasVectorDest;
 
   auto hasVectorSrc = [](Operation *op) {
-    return llvm::any_of(op->getOperandTypes(),
-                        [](Type t) { return isa<VectorType>(t); });
+    return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
   };
   ForwardSliceOptions forwardSliceOptions;
   forwardSliceOptions.filter = hasVectorSrc;
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
index 6124492..69b3d41 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
@@ -136,8 +136,7 @@ static bool isLocallyDefined(Value v, Operation *enclosingOp) {
 
 bool mlir::affine::isLoopMemoryParallel(AffineForOp forOp) {
   // Any memref-typed iteration arguments are treated as serializing.
-  if (llvm::any_of(forOp.getResultTypes(),
-                   [](Type type) { return isa<BaseMemRefType>(type); }))
+  if (llvm::any_of(forOp.getResultTypes(), llvm::IsaPred<BaseMemRefType>))
    return false;
 
   // Collect all load and store ops in loop nest rooted at 'forOp'.
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index 46c7871..71e9648 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -609,9 +609,8 @@ makePattern(const DenseSet<Operation *> &parallelLoops, int vectorRank,
 }
 
 static NestedPattern &vectorTransferPattern() {
-  static auto pattern = affine::matcher::Op([](Operation &op) {
-    return isa<vector::TransferReadOp, vector::TransferWriteOp>(op);
-  });
+  static auto pattern = affine::matcher::Op(
+      llvm::IsaPred<vector::TransferReadOp, vector::TransferWriteOp>);
   return pattern;
 }
 
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index fb45528..84ae4b5 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -211,8 +211,7 @@ static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
   unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps);
 
   // Return common loop depth for loads if there are no store ops.
-  if (all_of(targetDstOps,
-             [&](Operation *op) { return isa<AffineReadOpInterface>(op); }))
+  if (all_of(targetDstOps, llvm::IsaPred<AffineReadOpInterface>))
     return loopDepth;
 
   // Check dependences on all pairs of ops in 'targetDstOps' and store the
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 4cdbbf3..053ea79 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -326,7 +326,7 @@ struct FuncOpInterface
   static bool supportsUnstructuredControlFlow() { return true; }
 
   bool hasTensorSemantics(Operation *op) const {
-    auto isaTensor = [](Type type) { return isa<TensorType>(type); };
+    auto isaTensor = llvm::IsaPred<TensorType>;
 
     // A function has tensor semantics if it has tensor arguments/results.
     auto funcOp = cast<FuncOp>(op);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 33feea0..0a40726 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -67,6 +67,7 @@
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Operation.h"
 
 using namespace mlir;
@@ -277,9 +278,10 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
 
 /// Return "true" if the given function signature has tensor semantics.
 static bool hasTensorSignature(func::FuncOp funcOp) {
-  auto isaTensor = [](Type t) { return isa<TensorType>(t); };
-  return llvm::any_of(funcOp.getFunctionType().getInputs(), isaTensor) ||
-         llvm::any_of(funcOp.getFunctionType().getResults(), isaTensor);
+  return llvm::any_of(funcOp.getFunctionType().getInputs(),
+                      llvm::IsaPred<TensorType>) ||
+         llvm::any_of(funcOp.getFunctionType().getResults(),
+                      llvm::IsaPred<TensorType>);
 }
 
 /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index ab5c418..f4a9dc3 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -224,8 +224,7 @@ LogicalResult emitc::CallOpaqueOp::verify() {
     }
   }
 
-  if (llvm::any_of(getResultTypes(),
-                   [](Type type) { return isa<ArrayType>(type); })) {
+  if (llvm::any_of(getResultTypes(), llvm::IsaPred<ArrayType>)) {
     return emitOpError() << "cannot return array type";
   }
 
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index fc3a437..b584f63 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -296,22 +296,14 @@ checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
                                  "scf.forall op requires a mapping attribute");
   }
 
-  bool hasBlockMapping =
-      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
-        return isa<GPUBlockMappingAttr>(attr);
-      });
-  bool hasWarpgroupMapping =
-      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
-        return isa<GPUWarpgroupMappingAttr>(attr);
-      });
-  bool hasWarpMapping =
-      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
-        return isa<GPUWarpMappingAttr>(attr);
-      });
-  bool hasThreadMapping =
-      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
-        return isa<GPUThreadMappingAttr>(attr);
-      });
+  bool hasBlockMapping = llvm::any_of(forallOp.getMapping().value(),
+                                      llvm::IsaPred<GPUBlockMappingAttr>);
+  bool hasWarpgroupMapping = llvm::any_of(
+      forallOp.getMapping().value(), llvm::IsaPred<GPUWarpgroupMappingAttr>);
+  bool hasWarpMapping = llvm::any_of(forallOp.getMapping().value(),
+                                     llvm::IsaPred<GPUWarpMappingAttr>);
+  bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(),
+                                       llvm::IsaPred<GPUThreadMappingAttr>);
   int64_t countMappingTypes = 0;
   countMappingTypes += hasBlockMapping ? 1 : 0;
   countMappingTypes += hasWarpgroupMapping ? 1 : 0;
diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
index 40903f1..b2fa3a9 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
@@ -232,9 +232,8 @@ private:
   // control flow code.
   static bool areAllUsersExecuteOrAwait(Value token) {
     return !token.use_empty() &&
-           llvm::all_of(token.getUsers(), [](Operation *user) {
-             return isa<async::ExecuteOp, async::AwaitOp>(user);
-           });
+           llvm::all_of(token.getUsers(),
+                        llvm::IsaPred<async::ExecuteOp, async::AwaitOp>);
   }
 
   // Add the `asyncToken` as dependency as needed after `op`.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 3ba6ac6..e5c19a9 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2786,10 +2786,8 @@ LogicalResult LLVM::BitcastOp::verify() {
   if (!resultType)
     return success();
 
-  auto isVector = [](Type type) {
-    return llvm::isa<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>(
-        type);
-  };
+  auto isVector =
+      llvm::IsaPred<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>;
 
   // Due to bitcast requiring both operands to be of the same size, it is not
   // possible for only one of the two to be a pointer of vectors.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6954eee..2d7219f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -28,6 +28,7 @@
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSupport.h"
@@ -119,8 +120,7 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                    TypeRange inputTypes, TypeRange outputTypes,
                                    ArrayRef<NamedAttribute> attrs,
                                    RegionBuilderFn regionBuilder) {
-  assert(llvm::all_of(outputTypes,
-                      [](Type t) { return llvm::isa<ShapedType>(t); }));
+  assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));
 
   SmallVector<Type, 8> argTypes;
   SmallVector<Location, 8> argLocs;
@@ -162,7 +162,7 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
       resultTensorTypes.value_or(TypeRange());
   if (!resultTensorTypes)
     copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
-            [](Type type) { return llvm::isa<RankedTensorType>(type); });
+            llvm::IsaPred<RankedTensorType>);
 
   state.addOperands(inputs);
   state.addOperands(outputs);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index 5508aaf..28d6752 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -27,8 +27,7 @@ static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
 
   // TODO: The conversion pattern can be made to work for `any_of` here, but
   // it's more complex as it requires tracking which operands are scalars.
-  return llvm::all_of(op->getOperandTypes(),
-                      [](Type type) { return isa<RankedTensorType>(type); });
+  return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
 }
 
 /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c74ab1e..25785653 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -3537,15 +3537,14 @@ private:
   // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
   // must be block arguments or extension of block arguments.
   bool setOperKind(Operation *reduceOp) {
-    int numBlockArguments = llvm::count_if(
-        reduceOp->getOperands(), [](Value v) { return isa<BlockArgument>(v); });
+    int numBlockArguments =
+        llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
     switch (numBlockArguments) {
     case 1: {
       // Will be convolution if feeder is a MulOp.
       // Otherwise, if it can be pooling.
-      auto feedValIt = llvm::find_if(reduceOp->getOperands(), [](Value v) {
-        return !isa<BlockArgument>(v);
-      });
+      auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
+                                         llvm::IsaPred<BlockArgument>);
       Operation *feedOp = (*feedValIt).getDefiningOp();
       if (isCastOfBlockArgument(feedOp)) {
         oper = Pool;
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index c09a340..9ba96e4 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -457,7 +457,7 @@ static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
 }
 
 static bool isComputeOperation(Operation *op) {
-  return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op);
+  return isa<acc::ParallelOp, acc::LoopOp>(op);
 }
 
 namespace {
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index d532d46..2ff3efd 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -125,8 +125,7 @@ LogicalResult KHRCooperativeMatrixMulAddOp::verify() {
   if (getMatrixOperands()) {
     Type elementTypes[] = {typeA.getElementType(), typeB.getElementType(),
                            typeC.getElementType()};
-    if (!llvm::all_of(elementTypes,
-                      [](Type ty) { return isa<IntegerType>(ty); })) {
+    if (!llvm::all_of(elementTypes, llvm::IsaPred<IntegerType>)) {
       return emitOpError("Matrix Operands require all matrix element types to "
                          "be Integer Types");
     }
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index f5a3717..58c3f4c 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -65,9 +65,8 @@ LogicalResult shape::getShapeVec(Value input,
 }
 
 static bool isErrorPropagationPossible(TypeRange operandTypes) {
-  return llvm::any_of(operandTypes, [](Type ty) {
-    return llvm::isa<SizeType, ShapeType, ValueShapeType>(ty);
-  });
+  return llvm::any_of(operandTypes,
+                      llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
 }
 
 static LogicalResult verifySizeOrIndexOp(Operation *op) {
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index d4e0f8a..2efc157 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -188,9 +188,8 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
 /// Returns a tuple corresponding to whether range has tensor or vector type.
 template <typename iterator_range>
 static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
-  return std::make_tuple(
-      llvm::any_of(types, [](Type t) { return isa<TensorType>(t); }),
-      llvm::any_of(types, [](Type t) { return isa<VectorType>(t); }));
+  return {llvm::any_of(types, llvm::IsaPred<TensorType>),
+          llvm::any_of(types, llvm::IsaPred<VectorType>)};
 }
 
 static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
@@ -202,7 +201,7 @@ static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
   };
   if (inferred.size() != existing.size())
     return false;
-  for (auto [inferredDim, existingDim] : llvm::zip(inferred, existing))
+  for (auto [inferredDim, existingDim] : llvm::zip_equal(inferred, existing))
     if (!isCompatible(inferredDim, existingDim))
       return false;
   return true;
@@ -238,8 +237,8 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
                 std::get<1>(resultsHasTensorVectorType)))
     return op->emitError("cannot broadcast vector with tensor");
 
-  auto rankedOperands = make_filter_range(
-      op->getOperandTypes(), [](Type t) { return isa<RankedTensorType>(t); });
+  auto rankedOperands =
+      make_filter_range(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
 
   // If all operands are unranked, then all result shapes are possible.
   if (rankedOperands.empty())
@@ -257,8 +256,8 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
     return op->emitOpError("operands don't have broadcast-compatible shapes");
   }
 
-  auto rankedResults = make_filter_range(
-      op->getResultTypes(), [](Type t) { return isa<RankedTensorType>(t); });
+  auto rankedResults =
+      make_filter_range(op->getResultTypes(), llvm::IsaPred<RankedTensorType>);
 
   // If all of the results are unranked then no further verification.
   if (rankedResults.empty())
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 578b249..c8d06ba15 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -819,7 +819,7 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   assert(outputs.size() == 1 && "expected one output");
   return llvm::all_of(
       std::initializer_list<Type>{inputs.front(), outputs.front()},
-      [](Type ty) { return isa<transform::TransformHandleTypeInterface>(ty); });
+      llvm::IsaPred<transform::TransformHandleTypeInterface>);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index e566bfa..3e64258 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -898,13 +898,12 @@ static LogicalResult verifyOutputShape(
   AffineMap resMap = op.getIndexingMapsArray()[2];
   auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
-                                   /*symCount=*/0, extents, ctx);
+                                   /*symbolCount=*/0, extents, ctx);
 
   // Compose the resMap with the extentsMap, which is a constant map.
   AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
-  assert(
-      llvm::all_of(expectedMap.getResults(),
-                   [](AffineExpr e) { return isa<AffineConstantExpr>(e); }) &&
-      "expected constant extent along all dimensions.");
+  assert(llvm::all_of(expectedMap.getResults(),
+                      llvm::IsaPred<AffineConstantExpr>) &&
+         "expected constant extent along all dimensions.");
   // Extract the expected shape and build the type.
   auto expectedShape = llvm::to_vector<4>(
       llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index b3ab4a9..a67e03e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -598,9 +598,8 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
     }
 
     // Do not process warp ops that contain only TransferWriteOps.
-    if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
-          return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
-        }))
+    if (llvm::all_of(warpOp.getOps(),
+                     llvm::IsaPred<vector::TransferWriteOp, vector::YieldOp>))
       return failure();
 
     SmallVector<Value> yieldValues = {writeOp.getVector()};
@@ -746,8 +745,8 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *yieldOperand = getWarpResult(
-        warpOp, [](Operation *op) { return isa<arith::ConstantOp>(op); });
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
     if (!yieldOperand)
       return failure();
     auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
@@ -1060,8 +1059,8 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand = getWarpResult(
-        warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
     if (!operand)
       return failure();
     unsigned int operandNumber = operand->getOperandNumber();
@@ -1097,8 +1096,8 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand = getWarpResult(
-        warpOp, [](Operation *op) { return isa<vector::ShapeCastOp>(op); });
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
     if (!operand)
       return failure();
 
@@ -1156,8 +1155,8 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *yieldOperand = getWarpResult(
-        warpOp, [](Operation *op) { return isa<vector::CreateMaskOp>(op); });
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
     if (!yieldOperand)
       return failure();
 
@@ -1222,8 +1221,8 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand = getWarpResult(
-        warpOp, [](Operation *op) { return isa<vector::ExtractOp>(op); });
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
     if (!operand)
       return failure();
     unsigned int operandNumber = operand->getOperandNumber();
@@ -1325,9 +1324,8 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
         warpShuffleFromIdxFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
-      return isa<vector::ExtractElementOp>(op);
-    });
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
     if (!operand)
       return failure();
     unsigned int operandNumber = operand->getOperandNumber();
@@ -1422,8 +1420,8 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
 
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand = getWarpResult(
-        warpOp, [](Operation *op) { return isa<vector::InsertElementOp>(op); });
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
     if (!operand)
       return failure();
     unsigned int operandNumber = operand->getOperandNumber();
@@ -1503,8 +1501,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
 
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand = getWarpResult(
-        warpOp, [](Operation *op) { return isa<vector::InsertOp>(op); });
+    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
     if (!operand)
       return failure();
     unsigned int operandNumber = operand->getOperandNumber();
@@ -1808,8 +1805,8 @@ struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
 
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *yieldOperand = getWarpResult(
-        warpOp, [](Operation *op) { return isa<vector::ReductionOp>(op); });
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
     if (!yieldOperand)
       return failure();
 
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 00a0f05..6cdc268 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -359,9 +359,7 @@ bool AffineMap::isSingleConstant() const {
 }
 
 bool AffineMap::isConstant() const {
-  return llvm::all_of(getResults(), [](AffineExpr expr) {
-    return isa<AffineConstantExpr>(expr);
-  });
+  return llvm::all_of(getResults(), llvm::IsaPred<AffineConstantExpr>);
 }
 
 int64_t AffineMap::getSingleConstantResult() const {
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index d6d5983..ca5ff9f 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1288,9 +1288,7 @@ LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) {
 }
 
 LogicalResult OpTrait::impl::verifyElementwise(Operation *op) {
-  auto isMappableType = [](Type type) {
-    return llvm::isa<VectorType, TensorType>(type);
-  };
+  auto isMappableType = llvm::IsaPred<VectorType, TensorType>;
   auto resultMappableTypes = llvm::to_vector<1>(
       llvm::make_filter_range(op->getResultTypes(), isMappableType));
   auto operandMappableTypes = llvm::to_vector<2>(
diff --git a/mlir/lib/TableGen/Class.cpp b/mlir/lib/TableGen/Class.cpp
index 9092adcc..fedf64f 100644
--- a/mlir/lib/TableGen/Class.cpp
+++ b/mlir/lib/TableGen/Class.cpp
@@ -369,9 +369,7 @@ void Class::finalize() {
 
 Visibility Class::getLastVisibilityDecl() const {
   auto reverseDecls = llvm::reverse(declarations);
-  auto it = llvm::find_if(reverseDecls, [](auto &decl) {
-    return isa<VisibilityDeclaration>(decl);
-  });
+  auto it = llvm::find_if(reverseDecls, llvm::IsaPred<VisibilityDeclaration>);
   return it == reverseDecls.end()
              ? (isStruct ? Visibility::Public : Visibility::Private)
              : cast<VisibilityDeclaration>(**it).getVisibility();
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 95c7af2..0b07b4b 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1000,8 +1000,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
           "with multiple blocks needs variables declared at top");
   }
 
-  if (llvm::any_of(functionOp.getResultTypes(),
-                   [](Type type) { return isa<ArrayType>(type); })) {
+  if (llvm::any_of(functionOp.getResultTypes(), llvm::IsaPred<ArrayType>)) {
     return functionOp.emitOpError() << "cannot emit array type as result type";
   }
 
@@ -1576,7 +1575,7 @@ LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
 }
 
 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
-  if (llvm::any_of(types, [](Type type) { return isa<ArrayType>(type); })) {
+  if (llvm::any_of(types, llvm::IsaPred<ArrayType>)) {
     return emitError(loc, "cannot emit tuple of array type");
   }
   os << "std::tuple<";
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 4a4e878..9a74ac1 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1031,9 +1031,9 @@ Serializer::processBlock(Block *block, bool omitLabel,
   // into multiple basic blocks. If that's the case, we need to emit the merge
   // right now and then create new blocks for further serialization of the ops
   // in this block.
-  if (emitMerge && llvm::any_of(block->getOperations(), [](Operation &op) {
-        return isa<spirv::LoopOp, spirv::SelectionOp>(op);
-      })) {
+  if (emitMerge &&
+      llvm::any_of(block->getOperations(),
+                   llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
     if (failed(emitMerge()))
       return failure();
     emitMerge = nullptr;
@@ -1045,7 +1045,7 @@ Serializer::processBlock(Block *block, bool omitLabel,
   }
 
   // Process each op in this block except the terminator.
-  for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) {
+  for (Operation &op : llvm::drop_end(*block)) {
     if (failed(processOperation(&op)))
       return failure();
   }
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 2ec0b96..8671c10 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2825,9 +2825,9 @@ static void computeNecessaryMaterializations(
     }
 
     // Check to see if this is an argument materialization.
-    auto isBlockArg = [](Value v) { return isa<BlockArgument>(v); };
-    if (llvm::any_of(op->getOperands(), isBlockArg) ||
-        llvm::any_of(inverseMapping[op->getResult(0)], isBlockArg)) {
+    if (llvm::any_of(op->getOperands(), llvm::IsaPred<BlockArgument>) ||
+        llvm::any_of(inverseMapping[op->getResult(0)],
+                     llvm::IsaPred<BlockArgument>)) {
       mat->setMaterializationKind(MaterializationKind::Argument);
     }
 
diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
index 7cb957d..fef9d8e 100644
--- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
+++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
@@ -392,8 +392,7 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
       // Argument materialization.
      assert(castKind == getCastKindName(CastKind::Argument) &&
             "unexpected value of cast kind attribute");
-      assert(llvm::all_of(operands,
-                          [&](Value v) { return isa<BlockArgument>(v); }));
+      assert(llvm::all_of(operands, llvm::IsaPred<BlockArgument>));
       maybeResult = typeConverter.materializeArgumentConversion(
           rewriter, castOp->getLoc(), resultTypes.front(),
           castOp.getOperands());
```
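Besides `IsaPred`, the patch folds in a few neighboring `llvm/ADT/STLExtras.h` cleanups visible in the hunks above: `find_if` with a negated lambda becomes `find_if_not` (Vectorization.cpp), `zip` becomes `zip_equal` where the ranges are known to have equal length (Traits.cpp), and `make_range(begin, std::prev(end))` becomes `drop_end` (Serializer.cpp). A standalone sketch of these helpers on plain integers (illustrative data, not patch code):

```cpp
// Standalone sketch of the other STLExtras helpers used in this patch; the
// vectors are made-up illustrative data. Compile against LLVM's ADT headers.
#include "llvm/ADT/STLExtras.h"
#include <vector>

void stlExtrasCleanups() {
  std::vector<int> xs = {1, 2, 3, 4};
  std::vector<int> ys = {5, 6, 7, 8};

  // find_if_not: first element that fails the predicate, replacing a
  // find_if whose lambda negates the condition.
  auto it = llvm::find_if_not(xs, [](int x) { return x < 3; }); // -> 3
  (void)it;

  // zip_equal: like zip, but asserts the ranges have the same length, so a
  // silent truncation bug becomes a loud failure in debug builds.
  for (auto [x, y] : llvm::zip_equal(xs, ys))
    (void)(x + y);

  // drop_end: visit all but the last N (default 1) elements, replacing
  // make_range(begin, std::prev(end)).
  for (int x : llvm::drop_end(xs))
    (void)x;
}
```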