diff options
Diffstat (limited to 'mlir')
73 files changed, 1685 insertions, 1692 deletions
diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index 2d9f78e..16c898b 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -123,7 +123,6 @@ else() endif() add_definitions(-DMLIR_ROCM_CONVERSIONS_ENABLED=${MLIR_ENABLE_ROCM_CONVERSIONS}) -set(MLIR_ENABLE_DEPRECATED_GPU_SERIALIZATION 0 CACHE BOOL "Enable deprecated GPU serialization passes") set(MLIR_ENABLE_CUDA_RUNNER 0 CACHE BOOL "Enable building the mlir CUDA runner") set(MLIR_ENABLE_ROCM_RUNNER 0 CACHE BOOL "Enable building the mlir ROCm runner") set(MLIR_ENABLE_SYCL_RUNNER 0 CACHE BOOL "Enable building the mlir Sycl runner") diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h index 5885fac..8f7466a 100644 --- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h @@ -147,25 +147,11 @@ protected: // Registration //===----------------------------------------------------------------------===// -/// Register pass to serialize GPU kernel functions to a CUBIN binary -/// annotation. -LLVM_DEPRECATED("use Target attributes instead", "") -void registerGpuSerializeToCubinPass(); - /// Register pass to serialize GPU kernel functions to a HSAco binary /// annotation. LLVM_DEPRECATED("use Target attributes instead", "") void registerGpuSerializeToHsacoPass(); -/// Create an instance of the GPU kernel function to CUBIN binary serialization -/// pass with optLevel (default level 2). -LLVM_DEPRECATED("use Target attributes instead", "") -std::unique_ptr<Pass> createGpuSerializeToCubinPass(StringRef triple, - StringRef chip, - StringRef features, - int optLevel = 2, - bool dumpPtx = false); - /// Create an instance of the GPU kernel function to HSAco binary serialization /// pass. 
LLVM_DEPRECATED("use Target attributes instead", "") diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td index feb3578..b88f118 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -513,7 +513,11 @@ def LLVM_DbgLabelOp : LLVM_IntrOp<"dbg.label", [], [], [], 0> { }); }]; let mlirBuilder = [{ - $_op = $_builder.create<$_qualCppClassName>($_location, $_label_attr($label)); + DILabelAttr labelAttr = $_label_attr($label); + // Drop the intrinsic if the label translation fails due to cyclic metadata. + if (!labelAttr) + return success(); + $_op = $_builder.create<$_qualCppClassName>($_location, labelAttr); }]; let assemblyFormat = "$label attr-dict"; } diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index d9b130b..3da5dee 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -49,7 +49,7 @@ class LLVM_ArithmeticOpBase<Type type, string mnemonic, } class LLVM_IntArithmeticOp<string mnemonic, string instName, list<Trait> traits = []> : - LLVM_ArithmeticOpBase<AnyInteger, mnemonic, instName, traits> { + LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName, traits> { let arguments = commonArgs; string mlirBuilder = [{ $res = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs); @@ -57,7 +57,7 @@ class LLVM_IntArithmeticOp<string mnemonic, string instName, } class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName, list<Trait> traits = []> : - LLVM_ArithmeticOpBase<AnyInteger, mnemonic, instName, + LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName, !listconcat([DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)> { dag iofArg = ( ins DefaultValuedAttr<LLVM_IntegerOverflowFlagsAttr, "{}">:$overflowFlags); @@ -143,9 +143,9 @@ class LLVM_ArithmeticCmpOp<string mnemonic, 
list<Trait> traits = []> : // Other integer operations. def LLVM_ICmpOp : LLVM_ArithmeticCmpOp<"icmp", [Pure]> { let arguments = (ins ICmpPredicate:$predicate, - AnyTypeOf<[LLVM_ScalarOrVectorOf<AnyInteger>, + AnyTypeOf<[LLVM_ScalarOrVectorOf<AnySignlessInteger>, LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$lhs, - AnyTypeOf<[LLVM_ScalarOrVectorOf<AnyInteger>, + AnyTypeOf<[LLVM_ScalarOrVectorOf<AnySignlessInteger>, LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$rhs); let hasCustomAssemblyFormat = 1; string llvmInstName = "ICmp"; @@ -204,7 +204,7 @@ def LLVM_AllocaOp : LLVM_Op<"alloca", DeclareOpInterfaceMethods<DestructurableAllocationOpInterface>, DeclareOpInterfaceMethods<GetResultPtrElementType>]>, LLVM_MemOpPatterns { - let arguments = (ins AnyInteger:$arraySize, + let arguments = (ins AnySignlessInteger:$arraySize, OptionalAttr<I64Attr>:$alignment, TypeAttr:$elem_type, UnitAttr:$inalloca); @@ -250,7 +250,7 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure, DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>, DeclareOpInterfaceMethods<GetResultPtrElementType>]> { let arguments = (ins LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$base, - Variadic<LLVM_ScalarOrVectorOf<AnyInteger>>:$dynamicIndices, + Variadic<LLVM_ScalarOrVectorOf<AnySignlessInteger>>:$dynamicIndices, DenseI32ArrayAttr:$rawConstantIndices, TypeAttr:$elem_type, UnitAttr:$inbounds); @@ -499,37 +499,37 @@ def LLVM_AddrSpaceCastOp : LLVM_CastOp<"addrspacecast", "AddrSpaceCast", let hasFolder = 1; } def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "IntToPtr", - LLVM_ScalarOrVectorOf<AnyInteger>, + LLVM_ScalarOrVectorOf<AnySignlessInteger>, LLVM_ScalarOrVectorOf<LLVM_AnyPointer>>; def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "PtrToInt", LLVM_ScalarOrVectorOf<LLVM_AnyPointer>, - LLVM_ScalarOrVectorOf<AnyInteger>>; + LLVM_ScalarOrVectorOf<AnySignlessInteger>>; def LLVM_SExtOp : LLVM_CastOp<"sext", "SExt", - LLVM_ScalarOrVectorOf<AnyInteger>, - LLVM_ScalarOrVectorOf<AnyInteger>> { + 
LLVM_ScalarOrVectorOf<AnySignlessInteger>, + LLVM_ScalarOrVectorOf<AnySignlessInteger>> { let hasVerifier = 1; } def LLVM_ZExtOp : LLVM_CastOp<"zext", "ZExt", - LLVM_ScalarOrVectorOf<AnyInteger>, - LLVM_ScalarOrVectorOf<AnyInteger>> { + LLVM_ScalarOrVectorOf<AnySignlessInteger>, + LLVM_ScalarOrVectorOf<AnySignlessInteger>> { let hasFolder = 1; let hasVerifier = 1; } def LLVM_TruncOp : LLVM_CastOp<"trunc", "Trunc", - LLVM_ScalarOrVectorOf<AnyInteger>, - LLVM_ScalarOrVectorOf<AnyInteger>>; + LLVM_ScalarOrVectorOf<AnySignlessInteger>, + LLVM_ScalarOrVectorOf<AnySignlessInteger>>; def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "SIToFP", - LLVM_ScalarOrVectorOf<AnyInteger>, + LLVM_ScalarOrVectorOf<AnySignlessInteger>, LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>; def LLVM_UIToFPOp : LLVM_CastOp<"uitofp", "UIToFP", - LLVM_ScalarOrVectorOf<AnyInteger>, + LLVM_ScalarOrVectorOf<AnySignlessInteger>, LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>; def LLVM_FPToSIOp : LLVM_CastOp<"fptosi", "FPToSI", LLVM_ScalarOrVectorOf<LLVM_AnyFloat>, - LLVM_ScalarOrVectorOf<AnyInteger>>; + LLVM_ScalarOrVectorOf<AnySignlessInteger>>; def LLVM_FPToUIOp : LLVM_CastOp<"fptoui", "FPToUI", LLVM_ScalarOrVectorOf<LLVM_AnyFloat>, - LLVM_ScalarOrVectorOf<AnyInteger>>; + LLVM_ScalarOrVectorOf<AnySignlessInteger>>; def LLVM_FPExtOp : LLVM_CastOp<"fpext", "FPExt", LLVM_ScalarOrVectorOf<LLVM_AnyFloat>, LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>; @@ -671,7 +671,7 @@ def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [Pure, "LLVM::getVectorElementType($_self)">]> { let summary = "Extract an element from an LLVM vector."; - let arguments = (ins LLVM_AnyVector:$vector, AnyInteger:$position); + let arguments = (ins LLVM_AnyVector:$vector, AnySignlessInteger:$position); let results = (outs LLVM_Type:$res); let assemblyFormat = [{ @@ -733,7 +733,7 @@ def LLVM_InsertElementOp : LLVM_Op<"insertelement", [Pure, let summary = "Insert an element into an LLVM vector."; let arguments = (ins LLVM_AnyVector:$vector, LLVM_PrimitiveType:$value, 
- AnyInteger:$position); + AnySignlessInteger:$position); let results = (outs LLVM_AnyVector:$res); let builders = [LLVM_OneResultOpBuilder]; @@ -971,7 +971,7 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch", DeclareOpInterfaceMethods<BranchWeightOpInterface>, Pure]> { let arguments = (ins - AnyInteger:$value, + AnySignlessInteger:$value, Variadic<AnyType>:$defaultOperands, VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands, OptionalAttr<AnyIntElementsAttr>:$case_values, @@ -1647,7 +1647,7 @@ def LLVM_ConstantOp // Atomic operations. // -def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyPointer, AnyInteger]>; +def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyPointer, AnySignlessInteger]>; def LLVM_AtomicRMWOp : LLVM_MemAccessOpBase<"atomicrmw", [ TypesMatchWith<"result #0 and operand #1 have the same type", @@ -1696,7 +1696,7 @@ def LLVM_AtomicRMWOp : LLVM_MemAccessOpBase<"atomicrmw", [ let hasVerifier = 1; } -def LLVM_AtomicCmpXchgType : AnyTypeOf<[AnyInteger, LLVM_AnyPointer]>; +def LLVM_AtomicCmpXchgType : AnyTypeOf<[AnySignlessInteger, LLVM_AnyPointer]>; def LLVM_AtomicCmpXchgOp : LLVM_MemAccessOpBase<"cmpxchg", [ TypesMatchWith<"operand #1 and operand #2 have the same type", diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h index 010dde5..11b2c7a 100644 --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h @@ -45,6 +45,9 @@ struct MathPolynomialApproximationOptions { bool enableAvx2 = false; }; +void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns); +void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns); + void populateMathPolynomialApproximationPatterns( RewritePatternSet &patterns, const MathPolynomialApproximationOptions &options = {}); diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h index 
a00c9c3..1c81d80 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h @@ -145,12 +145,9 @@ constexpr bool isComplexPrimaryType(PrimaryType valTy) { /// The actions performed by @newSparseTensor. enum class Action : uint32_t { kEmpty = 0, - kEmptyForward = 1, - kFromCOO = 2, - kFromReader = 4, - kToCOO = 5, - kPack = 7, - kSortCOOInPlace = 8, + kFromReader = 1, + kPack = 2, + kSortCOOInPlace = 3, }; /// This enum defines all supported storage format without the level properties. diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 0ee9e71..0ecded7 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -990,6 +990,26 @@ def Tosa_ClzOp : Tosa_ElementwiseOp<"clz", [SameOperandsAndResultElementType]> { } //===----------------------------------------------------------------------===// +// Operator: cos +//===----------------------------------------------------------------------===// +def Tosa_CosOp : Tosa_ElementwiseOp<"cos", + [SameOperandsAndResultElementType]> { + let summary = "Elementwise cos op"; + + let description = [{ + Elementwise cosine operation for values given in radians. 
+ }]; + + let arguments = (ins + Tosa_FloatTensor:$input + ); + + let results = (outs + Tosa_FloatTensor:$output + ); +} + +//===----------------------------------------------------------------------===// // Operator: exp //===----------------------------------------------------------------------===// def Tosa_ExpOp : Tosa_ElementwiseOp<"exp", [SameOperandsAndResultElementType]> { @@ -1149,6 +1169,26 @@ def Tosa_RsqrtOp : Tosa_ElementwiseOp<"rsqrt", } //===----------------------------------------------------------------------===// +// Operator: sin +//===----------------------------------------------------------------------===// +def Tosa_SinOp : Tosa_ElementwiseOp<"sin", + [SameOperandsAndResultElementType]> { + let summary = "Elementwise sin op"; + + let description = [{ + Elementwise sine operation for values given in radians. + }]; + + let arguments = (ins + Tosa_FloatTensor:$input + ); + + let results = (outs + Tosa_FloatTensor:$output + ); +} + +//===----------------------------------------------------------------------===// // TOSA Spec Section 2.6 // Operator Class: Elementwise unary/binary/ternary operators. // Operator Subclass: Elementwise ternary ops. diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index c55ddaa..5a4d6ff 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -113,6 +113,8 @@ def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8, def Tosa_Int32Tensor : TensorOf<[Tosa_Int32]>; def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>; +def Tosa_FloatTensor : TensorOf<[Tosa_Float]>; + // Either ranked or unranked tensor of TOSA supported element types. 
def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>; def Tosa_Tensor_Plus_F64 : TensorOf<[Tosa_AnyNumber_Plus_F64]>; diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h index 2453d84..9892253 100644 --- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h +++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h @@ -257,6 +257,9 @@ SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0, std::pair<AffineExpr, SmallVector<OpFoldResult>> computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<OpFoldResult> strides, ArrayRef<OpFoldResult> indices); +std::pair<AffineExpr, SmallVector<OpFoldResult>> +computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides, + ArrayRef<Value> indices); //===----------------------------------------------------------------------===// // Utilities for decomposing larger shapes diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h index eff1aca..fe0e08b 100644 --- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h +++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h @@ -149,13 +149,6 @@ public: MLIR_SPARSETENSOR_FOREVERY_V(DECL_GETVALUES) #undef DECL_GETVALUES - /// Element-wise forwarding insertions. The first argument is the - /// dimension-coordinates for the value being inserted. -#define DECL_FORWARDINGINSERT(VNAME, V) \ - virtual void forwardingInsert(const uint64_t *, V); - MLIR_SPARSETENSOR_FOREVERY_V(DECL_FORWARDINGINSERT) -#undef DECL_FORWARDINGINSERT - /// Element-wise insertion in lexicographic coordinate order. The first /// argument is the level-coordinates for the value being inserted. #define DECL_LEXINSERT(VNAME, V) virtual void lexInsert(const uint64_t *, V); @@ -171,9 +164,6 @@ public: MLIR_SPARSETENSOR_FOREVERY_V(DECL_EXPINSERT) #undef DECL_EXPINSERT - /// Finalizes forwarding insertions. 
- virtual void endForwardingInsert() = 0; - /// Finalizes lexicographic insertions. virtual void endLexInsert() = 0; @@ -248,7 +238,7 @@ public: static SparseTensorStorage<P, C, V> * newEmpty(uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank, const uint64_t *lvlSizes, const LevelType *lvlTypes, - const uint64_t *dim2lvl, const uint64_t *lvl2dim, bool forwarding); + const uint64_t *dim2lvl, const uint64_t *lvl2dim); /// Allocates a new sparse tensor and initializes it from the given COO. static SparseTensorStorage<P, C, V> * @@ -284,13 +274,6 @@ public: *out = &values; } - /// Partially specialize forwarding insertions based on template types. - void forwardingInsert(const uint64_t *dimCoords, V val) final { - assert(dimCoords && coo); - map.pushforward(dimCoords, lvlCursor.data()); - coo->add(lvlCursor, val); - } - /// Partially specialize lexicographical insertions based on template types. void lexInsert(const uint64_t *lvlCoords, V val) final { assert(lvlCoords); @@ -345,21 +328,6 @@ public: } } - /// Finalizes forwarding insertions. - void endForwardingInsert() final { - // Ensure COO is sorted. - assert(coo); - coo->sort(); - // Now actually insert the `elements`. - const auto &elements = coo->getElements(); - const uint64_t nse = elements.size(); - assert(values.size() == 0); - values.reserve(nse); - fromCOO(elements, 0, nse, 0); - delete coo; - coo = nullptr; - } - /// Finalizes lexicographic insertions. 
void endLexInsert() final { if (!allDense) { @@ -653,13 +621,10 @@ template <typename P, typename C, typename V> SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newEmpty( uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank, const uint64_t *lvlSizes, const LevelType *lvlTypes, - const uint64_t *dim2lvl, const uint64_t *lvl2dim, bool forwarding) { - SparseTensorCOO<V> *lvlCOO = nullptr; - if (forwarding) - lvlCOO = new SparseTensorCOO<V>(lvlRank, lvlSizes); + const uint64_t *dim2lvl, const uint64_t *lvl2dim) { return new SparseTensorStorage<P, C, V>(dimRank, dimSizes, lvlRank, lvlSizes, - lvlTypes, dim2lvl, lvl2dim, lvlCOO, - !forwarding); + lvlTypes, dim2lvl, lvl2dim, nullptr, + true); } template <typename P, typename C, typename V> diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h index 8b0829a..d916186 100644 --- a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h +++ b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h @@ -38,15 +38,12 @@ extern "C" { /// This is the "swiss army knife" method for materializing sparse /// tensors into the computation. The types of the `ptr` argument and /// the result depend on the action, as explained in the following table, -/// where "STS" means a sparse-tensor-storage object and "COO" means -/// a coordinate-scheme object. +/// where "STS" means a sparse-tensor-storage object. 
/// /// Action: `ptr`: Returns: +/// --------------------------------------------------------------------------- /// kEmpty - STS, empty -/// kEmptyForward - STS, empty, with forwarding COO -/// kFromCOO COO STS, copied from the COO source /// kFromReader reader STS, input from reader -/// kToCOO STS COO, copied from the STS source /// kPack buffers STS, from level buffers /// kSortCOOInPlace STS STS, sorted in place MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_newSparseTensor( // NOLINT @@ -80,14 +77,6 @@ MLIR_SPARSETENSOR_FOREVERY_O(DECL_SPARSEPOSITIONS) MLIR_SPARSETENSOR_FOREVERY_O(DECL_SPARSECOORDINATES) #undef DECL_SPARSECOORDINATES -/// Tensor-storage method for a dim to lvl forwarding insertion. -#define DECL_FORWARDINGINSERT(VNAME, V) \ - MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_forwardingInsert##VNAME( \ - void *tensor, StridedMemRefType<V, 0> *vref, \ - StridedMemRefType<index_type, 1> *dimCoordsRef); \ - MLIR_SPARSETENSOR_FOREVERY_V(DECL_FORWARDINGINSERT) -#undef DECL_FORWARDINGINSERT - /// Tensor-storage method to insert elements in lexicographical /// level-coordinate order. #define DECL_LEXINSERT(VNAME, V) \ @@ -160,21 +149,12 @@ MLIR_CRUNNERUTILS_EXPORT index_type sparseLvlSize(void *tensor, index_type l); /// Tensor-storage method to get the size of the given dimension. MLIR_CRUNNERUTILS_EXPORT index_type sparseDimSize(void *tensor, index_type d); -/// Tensor-storage method to finalize forwarding insertions. -MLIR_CRUNNERUTILS_EXPORT void endForwardingInsert(void *tensor); - /// Tensor-storage method to finalize lexicographic insertions. MLIR_CRUNNERUTILS_EXPORT void endLexInsert(void *tensor); /// Releases the memory for the tensor-storage object. MLIR_CRUNNERUTILS_EXPORT void delSparseTensor(void *tensor); -/// Releases the memory for the coordinate-scheme object. 
-#define DECL_DELCOO(VNAME, V) \ - MLIR_CRUNNERUTILS_EXPORT void delSparseTensorCOO##VNAME(void *coo); -MLIR_SPARSETENSOR_FOREVERY_V(DECL_DELCOO) -#undef DECL_DELCOO - /// Helper function to read a sparse tensor filename from the environment, /// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc. MLIR_CRUNNERUTILS_EXPORT char *getTensorFilename(index_type id); diff --git a/mlir/include/mlir/Interfaces/FunctionInterfaces.td b/mlir/include/mlir/Interfaces/FunctionInterfaces.td index 98e0025..970a781 100644 --- a/mlir/include/mlir/Interfaces/FunctionInterfaces.td +++ b/mlir/include/mlir/Interfaces/FunctionInterfaces.td @@ -147,12 +147,12 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [ }]; let extraSharedClassDeclaration = [{ /// Block list iterator types. - using BlockListType = Region::BlockListType; + using BlockListType = ::mlir::Region::BlockListType; using iterator = BlockListType::iterator; using reverse_iterator = BlockListType::reverse_iterator; /// Block argument iterator types. - using BlockArgListType = Region::BlockArgListType; + using BlockArgListType = ::mlir::Region::BlockArgListType; using args_iterator = BlockArgListType::iterator; //===------------------------------------------------------------------===// @@ -163,7 +163,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [ bool isExternal() { return empty(); } /// Return the region containing the body of this function. - Region &getFunctionBody() { return $_op->getRegion(0); } + ::mlir::Region &getFunctionBody() { return $_op->getRegion(0); } /// Delete all blocks from this function. void eraseBody() { @@ -183,39 +183,39 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [ bool empty() { return getFunctionBody().empty(); } /// Push a new block to the back of the body region. 
- void push_back(Block *block) { getFunctionBody().push_back(block); } + void push_back(::mlir::Block *block) { getFunctionBody().push_back(block); } /// Push a new block to the front of the body region. - void push_front(Block *block) { getFunctionBody().push_front(block); } + void push_front(::mlir::Block *block) { getFunctionBody().push_front(block); } /// Return the last block in the body region. - Block &back() { return getFunctionBody().back(); } + ::mlir::Block &back() { return getFunctionBody().back(); } /// Return the first block in the body region. - Block &front() { return getFunctionBody().front(); } + ::mlir::Block &front() { return getFunctionBody().front(); } /// Add an entry block to an empty function, and set up the block arguments /// to match the signature of the function. The newly inserted entry block /// is returned. - Block *addEntryBlock() { + ::mlir::Block *addEntryBlock() { assert(empty() && "function already has an entry block"); - Block *entry = new Block(); + ::mlir::Block *entry = new ::mlir::Block(); push_back(entry); // FIXME: Allow for passing in locations for these arguments instead of using // the operations location. - ArrayRef<Type> inputTypes = $_op.getArgumentTypes(); - SmallVector<Location> locations(inputTypes.size(), - $_op.getOperation()->getLoc()); + ::llvm::ArrayRef<::mlir::Type> inputTypes = $_op.getArgumentTypes(); + ::llvm::SmallVector<::mlir::Location> locations(inputTypes.size(), + $_op.getOperation()->getLoc()); entry->addArguments(inputTypes, locations); return entry; } /// Add a normal block to the end of the function's block list. The function /// should at least already have an entry block. 
- Block *addBlock() { + ::mlir::Block *addBlock() { assert(!empty() && "function should at least have an entry block"); - push_back(new Block()); + push_back(new ::mlir::Block()); return &back(); } @@ -230,8 +230,8 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [ /// - the argument/result attributes may need an update: if the new type /// has less parameters we drop the extra attributes, if there are more /// parameters they won't have any attributes. - void setType(Type newType) { - function_interface_impl::setFunctionType($_op, newType); + void setType(::mlir::Type newType) { + ::mlir::function_interface_impl::setFunctionType($_op, newType); } //===------------------------------------------------------------------===// @@ -245,7 +245,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [ unsigned getNumResults() { return $_op.getResultTypes().size(); } /// Returns the entry block function argument at the given index. - BlockArgument getArgument(unsigned idx) { + ::mlir::BlockArgument getArgument(unsigned idx) { return getFunctionBody().getArgument(idx); } @@ -256,8 +256,8 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [ /// Insert a single argument of type `argType` with attributes `argAttrs` and /// location `argLoc` at `argIndex`. - void insertArgument(unsigned argIndex, Type argType, DictionaryAttr argAttrs, - Location argLoc) { + void insertArgument(unsigned argIndex, ::mlir::Type argType, ::mlir::DictionaryAttr argAttrs, + ::mlir::Location argLoc) { insertArguments({argIndex}, {argType}, {argAttrs}, {argLoc}); } @@ -265,20 +265,20 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [ /// listed indices. `argIndices` must be sorted. Arguments are inserted in the /// order they are listed, such that arguments with identical index will /// appear in the same order that they were listed here. 
- void insertArguments(ArrayRef<unsigned> argIndices, TypeRange argTypes, - ArrayRef<DictionaryAttr> argAttrs, - ArrayRef<Location> argLocs) { + void insertArguments(::llvm::ArrayRef<unsigned> argIndices, ::mlir::TypeRange argTypes, + ::llvm::ArrayRef<::mlir::DictionaryAttr> argAttrs, + ::llvm::ArrayRef<::mlir::Location> argLocs) { unsigned originalNumArgs = $_op.getNumArguments(); - Type newType = $_op.getTypeWithArgsAndResults( + ::mlir::Type newType = $_op.getTypeWithArgsAndResults( argIndices, argTypes, /*resultIndices=*/{}, /*resultTypes=*/{}); - function_interface_impl::insertFunctionArguments( + ::mlir::function_interface_impl::insertFunctionArguments( $_op, argIndices, argTypes, argAttrs, argLocs, originalNumArgs, newType); } /// Insert a single result of type `resultType` at `resultIndex`. - void insertResult(unsigned resultIndex, Type resultType, - DictionaryAttr resultAttrs) { + void insertResult(unsigned resultIndex, ::mlir::Type resultType, + ::mlir::DictionaryAttr resultAttrs) { insertResults({resultIndex}, {resultType}, {resultAttrs}); } @@ -286,41 +286,41 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [ /// `resultIndices` must be sorted. Results are inserted in the order they are /// listed, such that results with identical index will appear in the same /// order that they were listed here. 
- void insertResults(ArrayRef<unsigned> resultIndices, TypeRange resultTypes, - ArrayRef<DictionaryAttr> resultAttrs) { + void insertResults(::llvm::ArrayRef<unsigned> resultIndices, ::mlir::TypeRange resultTypes, + ::llvm::ArrayRef<::mlir::DictionaryAttr> resultAttrs) { unsigned originalNumResults = $_op.getNumResults(); - Type newType = $_op.getTypeWithArgsAndResults( + ::mlir::Type newType = $_op.getTypeWithArgsAndResults( /*argIndices=*/{}, /*argTypes=*/{}, resultIndices, resultTypes); - function_interface_impl::insertFunctionResults( + ::mlir::function_interface_impl::insertFunctionResults( $_op, resultIndices, resultTypes, resultAttrs, originalNumResults, newType); } /// Erase a single argument at `argIndex`. void eraseArgument(unsigned argIndex) { - BitVector argsToErase($_op.getNumArguments()); + ::llvm::BitVector argsToErase($_op.getNumArguments()); argsToErase.set(argIndex); eraseArguments(argsToErase); } /// Erases the arguments listed in `argIndices`. - void eraseArguments(const BitVector &argIndices) { - Type newType = $_op.getTypeWithoutArgs(argIndices); - function_interface_impl::eraseFunctionArguments( + void eraseArguments(const ::llvm::BitVector &argIndices) { + ::mlir::Type newType = $_op.getTypeWithoutArgs(argIndices); + ::mlir::function_interface_impl::eraseFunctionArguments( $_op, argIndices, newType); } /// Erase a single result at `resultIndex`. void eraseResult(unsigned resultIndex) { - BitVector resultsToErase($_op.getNumResults()); + ::llvm::BitVector resultsToErase($_op.getNumResults()); resultsToErase.set(resultIndex); eraseResults(resultsToErase); } /// Erases the results listed in `resultIndices`. 
- void eraseResults(const BitVector &resultIndices) { - Type newType = $_op.getTypeWithoutResults(resultIndices); - function_interface_impl::eraseFunctionResults( + void eraseResults(const ::llvm::BitVector &resultIndices) { + ::mlir::Type newType = $_op.getTypeWithoutResults(resultIndices); + ::mlir::function_interface_impl::eraseFunctionResults( $_op, resultIndices, newType); } @@ -328,13 +328,13 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [ /// results inserted. This is used to update the function's signature in /// the `insertArguments` and `insertResults` methods. The arrays must be /// sorted by increasing index. - Type getTypeWithArgsAndResults( - ArrayRef<unsigned> argIndices, TypeRange argTypes, - ArrayRef<unsigned> resultIndices, TypeRange resultTypes) { - SmallVector<Type> argStorage, resultStorage; - TypeRange newArgTypes = insertTypesInto( + ::mlir::Type getTypeWithArgsAndResults( + ::llvm::ArrayRef<unsigned> argIndices, ::mlir::TypeRange argTypes, + ::llvm::ArrayRef<unsigned> resultIndices, ::mlir::TypeRange resultTypes) { + ::llvm::SmallVector<::mlir::Type> argStorage, resultStorage; + ::mlir::TypeRange newArgTypes = insertTypesInto( $_op.getArgumentTypes(), argIndices, argTypes, argStorage); - TypeRange newResultTypes = insertTypesInto( + ::mlir::TypeRange newResultTypes = insertTypesInto( $_op.getResultTypes(), resultIndices, resultTypes, resultStorage); return $_op.cloneTypeWith(newArgTypes, newResultTypes); } @@ -342,24 +342,24 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [ /// Return the type of this function without the specified arguments and /// results. This is used to update the function's signature in the /// `eraseArguments` and `eraseResults` methods. 
- Type getTypeWithoutArgsAndResults( - const BitVector &argIndices, const BitVector &resultIndices) { - SmallVector<Type> argStorage, resultStorage; - TypeRange newArgTypes = filterTypesOut( + ::mlir::Type getTypeWithoutArgsAndResults( + const ::llvm::BitVector &argIndices, const ::llvm::BitVector &resultIndices) { + ::llvm::SmallVector<::mlir::Type> argStorage, resultStorage; + ::mlir::TypeRange newArgTypes = filterTypesOut( $_op.getArgumentTypes(), argIndices, argStorage); - TypeRange newResultTypes = filterTypesOut( + ::mlir::TypeRange newResultTypes = filterTypesOut( $_op.getResultTypes(), resultIndices, resultStorage); return $_op.cloneTypeWith(newArgTypes, newResultTypes); } - Type getTypeWithoutArgs(const BitVector &argIndices) { - SmallVector<Type> argStorage; - TypeRange newArgTypes = filterTypesOut( + ::mlir::Type getTypeWithoutArgs(const ::llvm::BitVector &argIndices) { + ::llvm::SmallVector<::mlir::Type> argStorage; + ::mlir::TypeRange newArgTypes = filterTypesOut( $_op.getArgumentTypes(), argIndices, argStorage); return $_op.cloneTypeWith(newArgTypes, $_op.getResultTypes()); } - Type getTypeWithoutResults(const BitVector &resultIndices) { - SmallVector<Type> resultStorage; - TypeRange newResultTypes = filterTypesOut( + ::mlir::Type getTypeWithoutResults(const ::llvm::BitVector &resultIndices) { + ::llvm::SmallVector<::mlir::Type> resultStorage; + ::mlir::TypeRange newResultTypes = filterTypesOut( $_op.getResultTypes(), resultIndices, resultStorage); return $_op.cloneTypeWith($_op.getArgumentTypes(), newResultTypes); } @@ -369,88 +369,88 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [ //===------------------------------------------------------------------===// /// Return all of the attributes for the argument at 'index'. 
- ArrayRef<NamedAttribute> getArgAttrs(unsigned index) { - return function_interface_impl::getArgAttrs($_op, index); + ::llvm::ArrayRef<::mlir::NamedAttribute> getArgAttrs(unsigned index) { + return ::mlir::function_interface_impl::getArgAttrs($_op, index); } /// Return an ArrayAttr containing all argument attribute dictionaries of /// this function, or nullptr if no arguments have attributes. - ArrayAttr getAllArgAttrs() { return $_op.getArgAttrsAttr(); } + ::mlir::ArrayAttr getAllArgAttrs() { return $_op.getArgAttrsAttr(); } /// Return all argument attributes of this function. - void getAllArgAttrs(SmallVectorImpl<DictionaryAttr> &result) { - if (ArrayAttr argAttrs = getAllArgAttrs()) { - auto argAttrRange = argAttrs.template getAsRange<DictionaryAttr>(); + void getAllArgAttrs(::llvm::SmallVectorImpl<::mlir::DictionaryAttr> &result) { + if (::mlir::ArrayAttr argAttrs = getAllArgAttrs()) { + auto argAttrRange = argAttrs.template getAsRange<::mlir::DictionaryAttr>(); result.append(argAttrRange.begin(), argAttrRange.end()); } else { result.append($_op.getNumArguments(), - DictionaryAttr::get(this->getOperation()->getContext())); + ::mlir::DictionaryAttr::get(this->getOperation()->getContext())); } } /// Return the specified attribute, if present, for the argument at 'index', /// null otherwise. - Attribute getArgAttr(unsigned index, StringAttr name) { + ::mlir::Attribute getArgAttr(unsigned index, ::mlir::StringAttr name) { auto argDict = getArgAttrDict(index); return argDict ? argDict.get(name) : nullptr; } - Attribute getArgAttr(unsigned index, StringRef name) { + ::mlir::Attribute getArgAttr(unsigned index, ::llvm::StringRef name) { auto argDict = getArgAttrDict(index); return argDict ? 
argDict.get(name) : nullptr; } template <typename AttrClass> - AttrClass getArgAttrOfType(unsigned index, StringAttr name) { + AttrClass getArgAttrOfType(unsigned index, ::mlir::StringAttr name) { return ::llvm::dyn_cast_or_null<AttrClass>(getArgAttr(index, name)); } template <typename AttrClass> - AttrClass getArgAttrOfType(unsigned index, StringRef name) { + AttrClass getArgAttrOfType(unsigned index, ::llvm::StringRef name) { return ::llvm::dyn_cast_or_null<AttrClass>(getArgAttr(index, name)); } /// Set the attributes held by the argument at 'index'. - void setArgAttrs(unsigned index, ArrayRef<NamedAttribute> attributes) { - function_interface_impl::setArgAttrs($_op, index, attributes); + void setArgAttrs(unsigned index, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + ::mlir::function_interface_impl::setArgAttrs($_op, index, attributes); } /// Set the attributes held by the argument at 'index'. `attributes` may be /// null, in which case any existing argument attributes are removed. 
- void setArgAttrs(unsigned index, DictionaryAttr attributes) { - function_interface_impl::setArgAttrs($_op, index, attributes); + void setArgAttrs(unsigned index, ::mlir::DictionaryAttr attributes) { + ::mlir::function_interface_impl::setArgAttrs($_op, index, attributes); } - void setAllArgAttrs(ArrayRef<DictionaryAttr> attributes) { + void setAllArgAttrs(::llvm::ArrayRef<::mlir::DictionaryAttr> attributes) { assert(attributes.size() == $_op.getNumArguments()); - function_interface_impl::setAllArgAttrDicts($_op, attributes); + ::mlir::function_interface_impl::setAllArgAttrDicts($_op, attributes); } - void setAllArgAttrs(ArrayRef<Attribute> attributes) { + void setAllArgAttrs(::llvm::ArrayRef<::mlir::Attribute> attributes) { assert(attributes.size() == $_op.getNumArguments()); - function_interface_impl::setAllArgAttrDicts($_op, attributes); + ::mlir::function_interface_impl::setAllArgAttrDicts($_op, attributes); } - void setAllArgAttrs(ArrayAttr attributes) { + void setAllArgAttrs(::mlir::ArrayAttr attributes) { assert(attributes.size() == $_op.getNumArguments()); $_op.setArgAttrsAttr(attributes); } /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. - void setArgAttr(unsigned index, StringAttr name, Attribute value) { - function_interface_impl::setArgAttr($_op, index, name, value); + void setArgAttr(unsigned index, ::mlir::StringAttr name, ::mlir::Attribute value) { + ::mlir::function_interface_impl::setArgAttr($_op, index, name, value); } - void setArgAttr(unsigned index, StringRef name, Attribute value) { + void setArgAttr(unsigned index, ::llvm::StringRef name, ::mlir::Attribute value) { setArgAttr(index, - StringAttr::get(this->getOperation()->getContext(), name), + ::mlir::StringAttr::get(this->getOperation()->getContext(), name), value); } /// Remove the attribute 'name' from the argument at 'index'. 
Return the /// attribute that was erased, or nullptr if there was no attribute with /// such name. - Attribute removeArgAttr(unsigned index, StringAttr name) { - return function_interface_impl::removeArgAttr($_op, index, name); + ::mlir::Attribute removeArgAttr(unsigned index, ::mlir::StringAttr name) { + return ::mlir::function_interface_impl::removeArgAttr($_op, index, name); } - Attribute removeArgAttr(unsigned index, StringRef name) { + ::mlir::Attribute removeArgAttr(unsigned index, ::llvm::StringRef name) { return removeArgAttr( - index, StringAttr::get(this->getOperation()->getContext(), name)); + index, ::mlir::StringAttr::get(this->getOperation()->getContext(), name)); } //===------------------------------------------------------------------===// @@ -458,102 +458,102 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [ //===------------------------------------------------------------------===// /// Return all of the attributes for the result at 'index'. - ArrayRef<NamedAttribute> getResultAttrs(unsigned index) { - return function_interface_impl::getResultAttrs($_op, index); + ::llvm::ArrayRef<::mlir::NamedAttribute> getResultAttrs(unsigned index) { + return ::mlir::function_interface_impl::getResultAttrs($_op, index); } /// Return an ArrayAttr containing all result attribute dictionaries of this /// function, or nullptr if no result have attributes. - ArrayAttr getAllResultAttrs() { return $_op.getResAttrsAttr(); } + ::mlir::ArrayAttr getAllResultAttrs() { return $_op.getResAttrsAttr(); } /// Return all result attributes of this function. 
- void getAllResultAttrs(SmallVectorImpl<DictionaryAttr> &result) { - if (ArrayAttr argAttrs = getAllResultAttrs()) { - auto argAttrRange = argAttrs.template getAsRange<DictionaryAttr>(); + void getAllResultAttrs(::llvm::SmallVectorImpl<::mlir::DictionaryAttr> &result) { + if (::mlir::ArrayAttr argAttrs = getAllResultAttrs()) { + auto argAttrRange = argAttrs.template getAsRange<::mlir::DictionaryAttr>(); result.append(argAttrRange.begin(), argAttrRange.end()); } else { result.append($_op.getNumResults(), - DictionaryAttr::get(this->getOperation()->getContext())); + ::mlir::DictionaryAttr::get(this->getOperation()->getContext())); } } /// Return the specified attribute, if present, for the result at 'index', /// null otherwise. - Attribute getResultAttr(unsigned index, StringAttr name) { + ::mlir::Attribute getResultAttr(unsigned index, ::mlir::StringAttr name) { auto argDict = getResultAttrDict(index); return argDict ? argDict.get(name) : nullptr; } - Attribute getResultAttr(unsigned index, StringRef name) { + ::mlir::Attribute getResultAttr(unsigned index, ::llvm::StringRef name) { auto argDict = getResultAttrDict(index); return argDict ? argDict.get(name) : nullptr; } template <typename AttrClass> - AttrClass getResultAttrOfType(unsigned index, StringAttr name) { + AttrClass getResultAttrOfType(unsigned index, ::mlir::StringAttr name) { return ::llvm::dyn_cast_or_null<AttrClass>(getResultAttr(index, name)); } template <typename AttrClass> - AttrClass getResultAttrOfType(unsigned index, StringRef name) { + AttrClass getResultAttrOfType(unsigned index, ::llvm::StringRef name) { return ::llvm::dyn_cast_or_null<AttrClass>(getResultAttr(index, name)); } /// Set the attributes held by the result at 'index'. 
- void setResultAttrs(unsigned index, ArrayRef<NamedAttribute> attributes) { - function_interface_impl::setResultAttrs($_op, index, attributes); + void setResultAttrs(unsigned index, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + ::mlir::function_interface_impl::setResultAttrs($_op, index, attributes); } /// Set the attributes held by the result at 'index'. `attributes` may be /// null, in which case any existing argument attributes are removed. - void setResultAttrs(unsigned index, DictionaryAttr attributes) { - function_interface_impl::setResultAttrs($_op, index, attributes); + void setResultAttrs(unsigned index, ::mlir::DictionaryAttr attributes) { + ::mlir::function_interface_impl::setResultAttrs($_op, index, attributes); } - void setAllResultAttrs(ArrayRef<DictionaryAttr> attributes) { + void setAllResultAttrs(::llvm::ArrayRef<::mlir::DictionaryAttr> attributes) { assert(attributes.size() == $_op.getNumResults()); - function_interface_impl::setAllResultAttrDicts( + ::mlir::function_interface_impl::setAllResultAttrDicts( $_op, attributes); } - void setAllResultAttrs(ArrayRef<Attribute> attributes) { + void setAllResultAttrs(::llvm::ArrayRef<::mlir::Attribute> attributes) { assert(attributes.size() == $_op.getNumResults()); - function_interface_impl::setAllResultAttrDicts( + ::mlir::function_interface_impl::setAllResultAttrDicts( $_op, attributes); } - void setAllResultAttrs(ArrayAttr attributes) { + void setAllResultAttrs(::mlir::ArrayAttr attributes) { assert(attributes.size() == $_op.getNumResults()); $_op.setResAttrsAttr(attributes); } /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. 
- void setResultAttr(unsigned index, StringAttr name, Attribute value) { - function_interface_impl::setResultAttr($_op, index, name, value); + void setResultAttr(unsigned index, ::mlir::StringAttr name, ::mlir::Attribute value) { + ::mlir::function_interface_impl::setResultAttr($_op, index, name, value); } - void setResultAttr(unsigned index, StringRef name, Attribute value) { + void setResultAttr(unsigned index, ::llvm::StringRef name, ::mlir::Attribute value) { setResultAttr(index, - StringAttr::get(this->getOperation()->getContext(), name), + ::mlir::StringAttr::get(this->getOperation()->getContext(), name), value); } /// Remove the attribute 'name' from the result at 'index'. Return the /// attribute that was erased, or nullptr if there was no attribute with /// such name. - Attribute removeResultAttr(unsigned index, StringAttr name) { - return function_interface_impl::removeResultAttr($_op, index, name); + ::mlir::Attribute removeResultAttr(unsigned index, ::mlir::StringAttr name) { + return ::mlir::function_interface_impl::removeResultAttr($_op, index, name); } /// Returns the dictionary attribute corresponding to the argument at /// 'index'. If there are no argument attributes at 'index', a null /// attribute is returned. - DictionaryAttr getArgAttrDict(unsigned index) { + ::mlir::DictionaryAttr getArgAttrDict(unsigned index) { assert(index < $_op.getNumArguments() && "invalid argument number"); - return function_interface_impl::getArgAttrDict($_op, index); + return ::mlir::function_interface_impl::getArgAttrDict($_op, index); } /// Returns the dictionary attribute corresponding to the result at 'index'. /// If there are no result attributes at 'index', a null attribute is /// returned. 
- DictionaryAttr getResultAttrDict(unsigned index) { + ::mlir::DictionaryAttr getResultAttrDict(unsigned index) { assert(index < $_op.getNumResults() && "invalid result number"); - return function_interface_impl::getResultAttrDict($_op, index); + return ::mlir::function_interface_impl::getResultAttrDict($_op, index); } }]; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 0d7722a..7e8e67a 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -24,9 +24,11 @@ namespace mlir { // Forward declarations. class Attribute; class Block; +struct ConversionConfig; class ConversionPatternRewriter; class MLIRContext; class Operation; +struct OperationConverter; class Type; class Value; @@ -657,12 +659,13 @@ struct ConversionPatternRewriterImpl; /// hooks. class ConversionPatternRewriter final : public PatternRewriter { public: - explicit ConversionPatternRewriter(MLIRContext *ctx); ~ConversionPatternRewriter() override; /// Apply a signature conversion to the entry block of the given region. This /// replaces the entry block with a new block containing the updated /// signature. The new entry block to the region is returned for convenience. + /// If no block argument types are changing, the entry original block will be + /// left in place and returned. /// /// If provided, `converter` will be used for any materializations. Block * @@ -671,8 +674,11 @@ public: const TypeConverter *converter = nullptr); /// Convert the types of block arguments within the given region. This - /// replaces each block with a new block containing the updated signature. The - /// entry block may have a special conversion if `entryConversion` is + /// replaces each block with a new block containing the updated signature. If + /// an updated signature would match the current signature, the respective + /// block is left in place as is. 
+ /// + /// The entry block may have a special conversion if `entryConversion` is /// provided. On success, the new entry block to the region is returned for /// convenience. Otherwise, failure is returned. FailureOr<Block *> convertRegionTypes( @@ -681,7 +687,8 @@ public: /// Convert the types of block arguments within the given region except for /// the entry region. This replaces each non-entry block with a new block - /// containing the updated signature. + /// containing the updated signature. If an updated signature would match the + /// current signature, the respective block is left in place as is. /// /// If special conversion behavior is needed for the non-entry blocks (for /// example, we need to convert only a subset of a BB arguments), such @@ -758,6 +765,15 @@ public: detail::ConversionPatternRewriterImpl &getImpl(); private: + // Allow OperationConverter to construct new rewriters. + friend struct OperationConverter; + + /// Conversion pattern rewriters must not be used outside of dialect + /// conversions. They apply some IR rewrites in a delayed fashion and could + /// bring the IR into an inconsistent state when used standalone. + explicit ConversionPatternRewriter(MLIRContext *ctx, + const ConversionConfig &config); + // Hide unsupported pattern rewriter API. using OpBuilder::setListener; @@ -1057,6 +1073,30 @@ public: #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH //===----------------------------------------------------------------------===// +// ConversionConfig +//===----------------------------------------------------------------------===// + +/// Dialect conversion configuration. +struct ConversionConfig { + /// An optional callback used to notify about match failure diagnostics during + /// the conversion. Diagnostics reported to this callback may only be + /// available in debug mode. + function_ref<void(Diagnostic &)> notifyCallback = nullptr; + + /// Partial conversion only. 
All operations that are found not to be + /// legalizable are placed in this set. (Note that if there is an op + /// explicitly marked as illegal, the conversion terminates and the set will + /// not necessarily be complete.) + DenseSet<Operation *> *unlegalizedOps = nullptr; + + /// Analysis conversion only. All operations that are found to be legalizable + /// are placed in this set. Note that no actual rewrites are applied to the + /// IR during an analysis conversion and only pre-existing operations are + /// added to the set. + DenseSet<Operation *> *legalizableOps = nullptr; +}; + +//===----------------------------------------------------------------------===// // Op Conversion Entry Points //===----------------------------------------------------------------------===// @@ -1069,19 +1109,16 @@ public: /// Apply a partial conversion on the given operations and all nested /// operations. This method converts as many operations to the target as /// possible, ignoring operations that failed to legalize. This method only -/// returns failure if there ops explicitly marked as illegal. If an -/// `unconvertedOps` set is provided, all operations that are found not to be -/// legalizable to the given `target` are placed within that set. (Note that if -/// there is an op explicitly marked as illegal, the conversion terminates and -/// the `unconvertedOps` set will not necessarily be complete.) +/// returns failure if there ops explicitly marked as illegal. 
LogicalResult -applyPartialConversion(ArrayRef<Operation *> ops, const ConversionTarget &target, +applyPartialConversion(ArrayRef<Operation *> ops, + const ConversionTarget &target, const FrozenRewritePatternSet &patterns, - DenseSet<Operation *> *unconvertedOps = nullptr); + ConversionConfig config = ConversionConfig()); LogicalResult applyPartialConversion(Operation *op, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, - DenseSet<Operation *> *unconvertedOps = nullptr); + ConversionConfig config = ConversionConfig()); /// Apply a complete conversion on the given operations, and all nested /// operations. This method returns failure if the conversion of any operation @@ -1089,31 +1126,27 @@ applyPartialConversion(Operation *op, const ConversionTarget &target, /// within 'ops'. LogicalResult applyFullConversion(ArrayRef<Operation *> ops, const ConversionTarget &target, - const FrozenRewritePatternSet &patterns); + const FrozenRewritePatternSet &patterns, + ConversionConfig config = ConversionConfig()); LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target, - const FrozenRewritePatternSet &patterns); + const FrozenRewritePatternSet &patterns, + ConversionConfig config = ConversionConfig()); /// Apply an analysis conversion on the given operations, and all nested /// operations. This method analyzes which operations would be successfully /// converted to the target if a conversion was applied. All operations that /// were found to be legalizable to the given 'target' are placed within the -/// provided 'convertedOps' set; note that no actual rewrites are applied to the -/// operations on success and only pre-existing operations are added to the set. -/// This method only returns failure if there are unreachable blocks in any of -/// the regions nested within 'ops'. There's an additional argument -/// `notifyCallback` which is used for collecting match failure diagnostics -/// generated during the conversion. 
Diagnostics are only reported to this -/// callback may only be available in debug mode. -LogicalResult applyAnalysisConversion( - ArrayRef<Operation *> ops, ConversionTarget &target, - const FrozenRewritePatternSet &patterns, - DenseSet<Operation *> &convertedOps, - function_ref<void(Diagnostic &)> notifyCallback = nullptr); -LogicalResult applyAnalysisConversion( - Operation *op, ConversionTarget &target, - const FrozenRewritePatternSet &patterns, - DenseSet<Operation *> &convertedOps, - function_ref<void(Diagnostic &)> notifyCallback = nullptr); +/// provided 'config.legalizableOps' set; note that no actual rewrites are +/// applied to the operations on success. This method only returns failure if +/// there are unreachable blocks in any of the regions nested within 'ops'. +LogicalResult +applyAnalysisConversion(ArrayRef<Operation *> ops, ConversionTarget &target, + const FrozenRewritePatternSet &patterns, + ConversionConfig config = ConversionConfig()); +LogicalResult +applyAnalysisConversion(Operation *op, ConversionTarget &target, + const FrozenRewritePatternSet &patterns, + ConversionConfig config = ConversionConfig()); } // namespace mlir #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_ diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp index 1b72961..23e9572 100644 --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -148,10 +148,10 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> { return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { + auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy); auto splatAttr = SplatElementsAttr::get( - mlir::VectorType::get( - {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, - floatType), + mlir::VectorType::get({numElements.getKnownMinValue()}, floatType, + 
{numElements.isScalable()}), floatOne); auto one = rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); @@ -207,10 +207,10 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> { return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { + auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy); auto splatAttr = SplatElementsAttr::get( - mlir::VectorType::get( - {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, - floatType), + mlir::VectorType::get({numElements.getKnownMinValue()}, floatType, + {numElements.isScalable()}), floatOne); auto one = rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); @@ -266,10 +266,10 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> { return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { + auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy); auto splatAttr = SplatElementsAttr::get( - mlir::VectorType::get( - {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, - floatType), + mlir::VectorType::get({numElements.getKnownMinValue()}, floatType, + {numElements.isScalable()}), floatOne); auto one = rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 7eb32eb..7c477f2 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -384,23 +384,23 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) { auto intTy = cast<IntegerType>(elementTy); - int32_t min = static_cast<int32_t>( - cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue()); - 
int32_t max = static_cast<int32_t>( - cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue()); + int64_t min = + cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue(); + int64_t max = + cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue(); if (intTy.isUnsignedInteger()) { - min = std::max<int32_t>(min, 0); - max = std::min<int32_t>( + min = std::max(min, (int64_t)0); + max = std::min( max, APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue()); } else { - min = std::max<int32_t>( - min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth()) - .getSExtValue()); - max = std::min<int32_t>( - max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth()) - .getSExtValue()); + min = + std::max(min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth()) + .getSExtValue()); + max = + std::min(max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth()) + .getSExtValue()); } auto minVal = rewriter.create<arith::ConstantIntOp>( diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp index e88f82c..26dfb38 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp @@ -40,12 +40,12 @@ namespace { //===----------------------------------------------------------------------===// // Common match failure reasons. 
-static constexpr StringLiteral MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE( +static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple( "op vector size is not multiple of SME tiles"); -static constexpr StringLiteral MATCH_FAILURE_UNSUPPORTED_MASK_OP( +static constexpr StringLiteral kMatchFailureUnsupportedMaskOp( "op mask is unsupported for legalization/decomposition"); static constexpr StringLiteral - MATCH_FAILURE_NON_PERMUTATION_MAP("op affine map is not a permutation"); + kMatchFailureNonPermutationMap("op affine map is not a permutation"); /// An SMESubTile represents a single SME-sized sub-tile from decomposing a /// larger vector type. The (`row`, `col`) are the position of the tile in the @@ -174,8 +174,8 @@ struct LegalizeVectorOuterProductOpsByDecomposition OneToNPatternRewriter &rewriter) const override { auto vectorType = outerProductOp.getResultVectorType(); if (!isMultipleOfSMETileVectorType(vectorType)) - return rewriter.notifyMatchFailure( - outerProductOp, MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE); + return rewriter.notifyMatchFailure(outerProductOp, + kMatchFailureNotSMETileTypeMultiple); Value mask; Operation *rootOp = outerProductOp; @@ -188,7 +188,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition if (!isSupportedMaskOp(mask)) return rewriter.notifyMatchFailure(outerProductOp, - MATCH_FAILURE_UNSUPPORTED_MASK_OP); + kMatchFailureUnsupportedMaskOp); ValueRange accSMETiles = adaptor.getAcc(); auto smeTileType = getSMETileTypeForElement(vectorType.getElementType()); @@ -252,18 +252,18 @@ struct LegalizeTransferReadOpsByDecomposition OneToNPatternRewriter &rewriter) const override { auto vectorType = readOp.getVectorType(); if (!isMultipleOfSMETileVectorType(vectorType)) - return rewriter.notifyMatchFailure( - readOp, MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE); + return rewriter.notifyMatchFailure(readOp, + kMatchFailureNotSMETileTypeMultiple); auto mask = readOp.getMask(); if (!isSupportedMaskOp(mask)) return 
rewriter.notifyMatchFailure(readOp, - MATCH_FAILURE_UNSUPPORTED_MASK_OP); + kMatchFailureUnsupportedMaskOp); auto permutationMap = readOp.getPermutationMap(); if (!permutationMap.isPermutation()) return rewriter.notifyMatchFailure(readOp, - MATCH_FAILURE_NON_PERMUTATION_MAP); + kMatchFailureNonPermutationMap); // Note: For 2D vector types the only non-identity permutation is a simple // tranpose [1, 0]. @@ -300,18 +300,18 @@ struct LegalizeTransferWriteOpsByDecomposition OneToNPatternRewriter &rewriter) const override { auto vectorType = writeOp.getVectorType(); if (!isMultipleOfSMETileVectorType(vectorType)) - return rewriter.notifyMatchFailure( - writeOp, MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE); + return rewriter.notifyMatchFailure(writeOp, + kMatchFailureNotSMETileTypeMultiple); auto mask = writeOp.getMask(); if (!isSupportedMaskOp(mask)) return rewriter.notifyMatchFailure(writeOp, - MATCH_FAILURE_UNSUPPORTED_MASK_OP); + kMatchFailureUnsupportedMaskOp); auto permutationMap = writeOp.getPermutationMap(); if (!permutationMap.isPermutation()) return rewriter.notifyMatchFailure(writeOp, - MATCH_FAILURE_NON_PERMUTATION_MAP); + kMatchFailureNonPermutationMap); // Note: For 2D vector types the only non-identity permutation is a simple // tranpose [1, 0]. 
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt index e5776e1..51cfa22 100644 --- a/mlir/lib/Dialect/GPU/CMakeLists.txt +++ b/mlir/lib/Dialect/GPU/CMakeLists.txt @@ -1,11 +1,3 @@ -if ("NVPTX" IN_LIST LLVM_TARGETS_TO_BUILD) - set(NVPTX_LIBS - NVPTXCodeGen - NVPTXDesc - NVPTXInfo - ) -endif() - if (MLIR_ENABLE_ROCM_CONVERSIONS) set(AMDGPU_LIBS IRReader @@ -60,7 +52,6 @@ add_mlir_dialect_library(MLIRGPUTransforms Transforms/ParallelLoopMapper.cpp Transforms/ROCDLAttachTarget.cpp Transforms/SerializeToBlob.cpp - Transforms/SerializeToCubin.cpp Transforms/SerializeToHsaco.cpp Transforms/ShuffleRewriter.cpp Transforms/SPIRVAttachTarget.cpp @@ -74,7 +65,6 @@ add_mlir_dialect_library(MLIRGPUTransforms Core MC Target - ${NVPTX_LIBS} ${AMDGPU_LIBS} DEPENDS @@ -110,48 +100,6 @@ add_mlir_dialect_library(MLIRGPUTransforms add_subdirectory(TransformOps) add_subdirectory(Pipelines) -if(MLIR_ENABLE_CUDA_RUNNER) - if(NOT MLIR_ENABLE_CUDA_CONVERSIONS) - message(SEND_ERROR - "Building mlir with cuda support requires the NVPTX backend") - endif() - - # Configure CUDA language support. Using check_language first allows us to - # give a custom error message. - include(CheckLanguage) - check_language(CUDA) - if (CMAKE_CUDA_COMPILER) - enable_language(CUDA) - else() - message(SEND_ERROR - "Building mlir with cuda support requires a working CUDA install") - endif() - - # Enable gpu-to-cubin pass. - target_compile_definitions(obj.MLIRGPUTransforms - PRIVATE - MLIR_GPU_TO_CUBIN_PASS_ENABLE=1 - ) - - # Add CUDA headers includes and the libcuda.so library. - target_include_directories(obj.MLIRGPUTransforms - PRIVATE - ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} - ) - - # Add link path for the cuda driver library. 
- find_library(CUDA_DRIVER_LIBRARY cuda HINTS ${CMAKE_CUDA_IMPLICIT_LINK_DIRECTORIES} REQUIRED) - get_filename_component(CUDA_DRIVER_LIBRARY_PATH "${CUDA_DRIVER_LIBRARY}" DIRECTORY) - target_link_directories(MLIRGPUTransforms PRIVATE ${CUDA_DRIVER_LIBRARY_PATH}) - - target_link_libraries(MLIRGPUTransforms - PRIVATE - MLIRNVVMToLLVMIRTranslation - cuda - ) - -endif() - if(MLIR_ENABLE_ROCM_CONVERSIONS) if (NOT ("AMDGPU" IN_LIST LLVM_TARGETS_TO_BUILD)) message(SEND_ERROR diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 30b6cd7..33ce5c1 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -648,6 +648,8 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result, TypeRange workgroupAttributions, TypeRange privateAttributions, Value clusterSizeX, Value clusterSizeY, Value clusterSizeZ) { + OpBuilder::InsertionGuard g(builder); + // Add a WorkGroup attribution attribute. This attribute is required to // identify private attributions in the list of block argguments. result.addAttribute(getNumWorkgroupAttributionsAttrName(), @@ -674,7 +676,7 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result, // attributions, where the first kNumConfigRegionAttributes arguments have // `index` type and the rest have the same types as the data operands. Region *kernelRegion = result.addRegion(); - Block *body = new Block(); + Block *body = builder.createBlock(kernelRegion); // TODO: Allow passing in proper locations here. for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i) body->addArgument(builder.getIndexType(), result.location); @@ -683,7 +685,6 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result, body->addArgument(argTy, result.location); for (Type argTy : privateAttributions) body->addArgument(argTy, result.location); - kernelRegion->push_back(body); // Fill OperandSegmentSize Attribute. 
SmallVector<int32_t, 11> segmentSizes(11, 1); segmentSizes.front() = asyncDependencies.size(); @@ -1325,6 +1326,8 @@ void GPUFuncOp::build(OpBuilder &builder, OperationState &result, TypeRange workgroupAttributions, TypeRange privateAttributions, ArrayRef<NamedAttribute> attrs) { + OpBuilder::InsertionGuard g(builder); + result.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); result.addAttribute(getFunctionTypeAttrName(result.name), @@ -1333,7 +1336,7 @@ void GPUFuncOp::build(OpBuilder &builder, OperationState &result, builder.getI64IntegerAttr(workgroupAttributions.size())); result.addAttributes(attrs); Region *body = result.addRegion(); - Block *entryBlock = new Block; + Block *entryBlock = builder.createBlock(body); // TODO: Allow passing in proper locations here. for (Type argTy : type.getInputs()) @@ -1342,8 +1345,6 @@ void GPUFuncOp::build(OpBuilder &builder, OperationState &result, entryBlock->addArgument(argTy, result.location); for (Type argTy : privateAttributions) entryBlock->addArgument(argTy, result.location); - - body->getBlocks().push_back(entryBlock); } /// Parses a GPU function memory attribution. diff --git a/mlir/lib/Dialect/GPU/Transforms/SerializeToCubin.cpp b/mlir/lib/Dialect/GPU/Transforms/SerializeToCubin.cpp deleted file mode 100644 index 34ad4e6..0000000 --- a/mlir/lib/Dialect/GPU/Transforms/SerializeToCubin.cpp +++ /dev/null @@ -1,180 +0,0 @@ -//===- LowerGPUToCUBIN.cpp - Convert GPU kernel to CUBIN blob -------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements a pass that serializes a gpu module into CUBIN blob and -// adds that blob as a string attribute of the module. 
-// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/GPU/Transforms/Passes.h" -#include "mlir/Dialect/LLVMIR/NVVMDialect.h" -#include "llvm/Support/Debug.h" - -#if MLIR_GPU_TO_CUBIN_PASS_ENABLE -#include "mlir/Pass/Pass.h" -#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Export.h" -#include "llvm/Support/TargetSelect.h" -#include "llvm/Support/Threading.h" - -#include <cuda.h> - -using namespace mlir; - -static void emitCudaError(const llvm::Twine &expr, const char *buffer, - CUresult result, Location loc) { - const char *error = nullptr; - cuGetErrorString(result, &error); - emitError(loc, - expr.concat(error ? " failed with error code " + llvm::Twine{error} - : llvm::Twine(" failed with unknown error ")) - .concat("[") - .concat(buffer) - .concat("]")); -} - -#define RETURN_ON_CUDA_ERROR(expr) \ - do { \ - if (auto status = (expr)) { \ - emitCudaError(#expr, jitErrorBuffer, status, loc); \ - return {}; \ - } \ - } while (false) - -namespace { -class SerializeToCubinPass - : public PassWrapper<SerializeToCubinPass, gpu::SerializeToBlobPass> { - static llvm::once_flag initializeBackendOnce; - -public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SerializeToCubinPass) - - SerializeToCubinPass(StringRef triple = "nvptx64-nvidia-cuda", - StringRef chip = "sm_35", StringRef features = "+ptx60", - int optLevel = 2, bool dumpPtx = false); - - StringRef getArgument() const override { return "gpu-to-cubin"; } - StringRef getDescription() const override { - return "Lower GPU kernel function to CUBIN binary annotations"; - } - -private: - // Serializes PTX to CUBIN. - std::unique_ptr<std::vector<char>> - serializeISA(const std::string &isa) override; -}; -} // namespace - -// Sets the 'option' to 'value' unless it already has a value. 
-static void maybeSetOption(Pass::Option<std::string> &option, StringRef value) { - if (!option.hasValue()) - option = value.str(); -} - -llvm::once_flag SerializeToCubinPass::initializeBackendOnce; - -SerializeToCubinPass::SerializeToCubinPass(StringRef triple, StringRef chip, - StringRef features, int optLevel, - bool dumpPtx) { - // No matter how this pass is constructed, ensure that the NVPTX backend - // is initialized exactly once. - llvm::call_once(initializeBackendOnce, []() { - // Initialize LLVM NVPTX backend. -#if LLVM_HAS_NVPTX_TARGET - LLVMInitializeNVPTXTarget(); - LLVMInitializeNVPTXTargetInfo(); - LLVMInitializeNVPTXTargetMC(); - LLVMInitializeNVPTXAsmPrinter(); -#endif - }); - - maybeSetOption(this->triple, triple); - maybeSetOption(this->chip, chip); - maybeSetOption(this->features, features); - this->dumpPtx = dumpPtx; - if (this->optLevel.getNumOccurrences() == 0) - this->optLevel.setValue(optLevel); -} - -std::unique_ptr<std::vector<char>> -SerializeToCubinPass::serializeISA(const std::string &isa) { - Location loc = getOperation().getLoc(); - char jitErrorBuffer[4096] = {0}; - - RETURN_ON_CUDA_ERROR(cuInit(0)); - - // Linking requires a device context. - CUdevice device; - RETURN_ON_CUDA_ERROR(cuDeviceGet(&device, 0)); - CUcontext context; - // Use the primary context. - RETURN_ON_CUDA_ERROR(cuDevicePrimaryCtxRetain(&context, device)); - // Push the primary context so that the next CUDA operations - // actually use it. 
- RETURN_ON_CUDA_ERROR(cuCtxPushCurrent(context)); - CUlinkState linkState; - - CUjit_option jitOptions[] = {CU_JIT_ERROR_LOG_BUFFER, - CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES}; - void *jitOptionsVals[] = {jitErrorBuffer, - reinterpret_cast<void *>(sizeof(jitErrorBuffer))}; - - RETURN_ON_CUDA_ERROR(cuLinkCreate(2, /* number of jit options */ - jitOptions, /* jit options */ - jitOptionsVals, /* jit option values */ - &linkState)); - - auto kernelName = getOperation().getName().str(); - if (dumpPtx) { - llvm::dbgs() << " Kernel Name : [" << kernelName << "]\n"; - llvm::dbgs() << isa << "\n"; - } - RETURN_ON_CUDA_ERROR(cuLinkAddData( - linkState, CUjitInputType::CU_JIT_INPUT_PTX, - const_cast<void *>(static_cast<const void *>(isa.c_str())), isa.length(), - kernelName.c_str(), 0, /* number of jit options */ - nullptr, /* jit options */ - nullptr /* jit option values */ - )); - - void *cubinData; - size_t cubinSize; - RETURN_ON_CUDA_ERROR(cuLinkComplete(linkState, &cubinData, &cubinSize)); - - char *cubinAsChar = static_cast<char *>(cubinData); - auto result = - std::make_unique<std::vector<char>>(cubinAsChar, cubinAsChar + cubinSize); - - // This will also destroy the cubin data. - RETURN_ON_CUDA_ERROR(cuLinkDestroy(linkState)); - // Pop and release the primary context. - CUcontext poppedContext; - RETURN_ON_CUDA_ERROR(cuCtxPopCurrent(&poppedContext)); - RETURN_ON_CUDA_ERROR(cuDevicePrimaryCtxRelease(device)); - - return result; -} - -// Register pass to serialize GPU kernel functions to a CUBIN binary annotation. 
-void mlir::registerGpuSerializeToCubinPass() { - PassRegistration<SerializeToCubinPass> registerSerializeToCubin( - [] { return std::make_unique<SerializeToCubinPass>(); }); -} - -std::unique_ptr<Pass> mlir::createGpuSerializeToCubinPass(StringRef triple, - StringRef arch, - StringRef features, - int optLevel, - bool dumpPtx) { - return std::make_unique<SerializeToCubinPass>(triple, arch, features, - optLevel, dumpPtx); -} - -#else // MLIR_GPU_TO_CUBIN_PASS_ENABLE -void mlir::registerGpuSerializeToCubinPass() {} -#endif // MLIR_GPU_TO_CUBIN_PASS_ENABLE diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index 7eed792..3627ff6 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -1041,6 +1041,11 @@ int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) { LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { LinalgOp linalgOp = cast<LinalgOp>(op); + // Mixed tensor/buffer operands are not allowed. + if (!linalgOp.hasPureTensorSemantics() && + !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0) + return op->emitOpError("expected to have pure tensor or buffer semantics"); + // Before checking indexing maps, we need to make sure the attributes // referenced by it are valid. 
if (linalgOp.hasDynamicIndexingMaps()) diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp index 71e4e13..962cb28 100644 --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -1471,6 +1471,16 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op, //----------------------------------------------------------------------------// +void mlir::populatePolynomialApproximateTanhPattern( + RewritePatternSet &patterns) { + patterns.add<TanhApproximation>(patterns.getContext()); +} + +void mlir::populatePolynomialApproximateErfPattern( + RewritePatternSet &patterns) { + patterns.add<ErfPolynomialApproximation>(patterns.getContext()); +} + void mlir::populateMathPolynomialApproximationPatterns( RewritePatternSet &patterns, const MathPolynomialApproximationOptions &options) { diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp index 7cbe0de..c4d8b0b 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -593,7 +593,6 @@ static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) { Operation *definingOp = operand.getDefiningOp(); assert(definingOp); ShardOp shardOp = llvm::cast<ShardOp>(definingOp); - assert(shardOp.getAnnotateForUsers()); return shardOp.getShard(); }); return res; @@ -615,34 +614,46 @@ static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) { assert(result.hasOneUse()); Operation *userOp = *result.getUsers().begin(); ShardOp shardOp = llvm::cast<ShardOp>(userOp); - assert(!shardOp.getAnnotateForUsers()); return shardOp.getShard(); }); return res; } static LogicalResult -spmdizeOperation(Operation &op, IRMapping &spmdizationMap, +spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder 
&builder) { - ShardOp shardOp = llvm::dyn_cast<ShardOp>(op); - if (shardOp) { - if (!shardOp.getAnnotateForUsers()) { - return success(); - } - + Value targetSpmdValue; + + // Check if 2 shard ops are chained. If not there is no need for resharding + // as the source and target shared the same sharding. + ShardOp srcShardOp = + dyn_cast_or_null<ShardOp>(shardOp.getOperand().getDefiningOp()); + if (!srcShardOp) { + targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand()); + } else { // Insert resharding. - ShardOp srcShardOp = - llvm::cast<ShardOp>(shardOp.getOperand().getDefiningOp()); - assert(!srcShardOp.getAnnotateForUsers()); + assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers()); TypedValue<ShapedType> srcSpmdValue = spmdizationMap.lookup(srcShardOp.getOperand()) .cast<TypedValue<ShapedType>>(); - Value targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue, - symbolTableCollection); - assert(!spmdizationMap.contains(shardOp.getResult())); - spmdizationMap.map(shardOp.getResult(), targetSpmdValue); - return success(); + targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue, + symbolTableCollection); + } + + assert(!spmdizationMap.contains(shardOp.getResult())); + spmdizationMap.map(shardOp.getResult(), targetSpmdValue); + return success(); +} + +static LogicalResult +spmdizeOperation(Operation &op, IRMapping &spmdizationMap, + SymbolTableCollection &symbolTableCollection, + OpBuilder &builder) { + ShardOp shardOp = llvm::dyn_cast<ShardOp>(op); + if (shardOp) { + return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection, + builder); } SmallVector<Value> spmdizedOperands; diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 119df9a..233e702 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -306,17 +306,18 @@ void ConditionOp::getSuccessorRegions( void ForOp::build(OpBuilder &builder, OperationState &result, Value lb, Value ub, 
Value step, ValueRange iterArgs, BodyBuilderFn bodyBuilder) { + OpBuilder::InsertionGuard guard(builder); + result.addOperands({lb, ub, step}); result.addOperands(iterArgs); for (Value v : iterArgs) result.addTypes(v.getType()); Type t = lb.getType(); Region *bodyRegion = result.addRegion(); - bodyRegion->push_back(new Block); - Block &bodyBlock = bodyRegion->front(); - bodyBlock.addArgument(t, result.location); + Block *bodyBlock = builder.createBlock(bodyRegion); + bodyBlock->addArgument(t, result.location); for (Value v : iterArgs) - bodyBlock.addArgument(v.getType(), v.getLoc()); + bodyBlock->addArgument(v.getType(), v.getLoc()); // Create the default terminator if the builder is not provided and if the // iteration arguments are not provided. Otherwise, leave this to the caller @@ -325,9 +326,9 @@ void ForOp::build(OpBuilder &builder, OperationState &result, Value lb, ForOp::ensureTerminator(*bodyRegion, builder, result.location); } else if (bodyBuilder) { OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToStart(&bodyBlock); - bodyBuilder(builder, result.location, bodyBlock.getArgument(0), - bodyBlock.getArguments().drop_front()); + builder.setInsertionPointToStart(bodyBlock); + bodyBuilder(builder, result.location, bodyBlock->getArgument(0), + bodyBlock->getArguments().drop_front()); } } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 950ee59..62d0785 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -1330,6 +1330,7 @@ NARY_SHAPE_INFER(tosa::CastOp) NARY_SHAPE_INFER(tosa::CeilOp) NARY_SHAPE_INFER(tosa::ClampOp) NARY_SHAPE_INFER(tosa::ClzOp) +NARY_SHAPE_INFER(tosa::CosOp) NARY_SHAPE_INFER(tosa::DivOp) NARY_SHAPE_INFER(tosa::ExpOp) NARY_SHAPE_INFER(tosa::FloorOp) @@ -1352,6 +1353,7 @@ NARY_SHAPE_INFER(tosa::ReciprocalOp) NARY_SHAPE_INFER(tosa::RescaleOp) NARY_SHAPE_INFER(tosa::ReverseOp) NARY_SHAPE_INFER(tosa::RsqrtOp) +NARY_SHAPE_INFER(tosa::SinOp) 
NARY_SHAPE_INFER(tosa::SelectOp) NARY_SHAPE_INFER(tosa::SubOp) NARY_SHAPE_INFER(tosa::TanhOp) diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp index baaa581..4c96065 100644 --- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp +++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp @@ -7,13 +7,12 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Utils/IndexingUtils.h" - +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/MLIRContext.h" #include "llvm/ADT/STLExtras.h" - #include <numeric> #include <optional> @@ -306,6 +305,14 @@ mlir::computeLinearIndex(OpFoldResult sourceOffset, return {expr, values}; } +std::pair<AffineExpr, SmallVector<OpFoldResult>> +mlir::computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides, + ArrayRef<Value> indices) { + return computeLinearIndex( + sourceOffset, getAsIndexOpFoldResult(sourceOffset.getContext(), strides), + getAsOpFoldResult(ValueRange(indices))); +} + //===----------------------------------------------------------------------===// // TileOffsetRange //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 36fb667..fc11ae6 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -724,9 +724,8 @@ BitCastRewriter::BitCastRewriter(VectorType sourceVectorType, static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter, VectorType preconditionType, Operation *op) { - if (!preconditionType || preconditionType.getRank() != 1 || - preconditionType.isScalable()) - return rewriter.notifyMatchFailure(op, "scalable or >1-D vector"); + if 
(!preconditionType || preconditionType.isScalable()) + return rewriter.notifyMatchFailure(op, "scalable vector"); // TODO: consider relaxing this restriction in the future if we find ways // to really work with subbyte elements across the MLIR/LLVM boundary. @@ -743,6 +742,9 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter, if (!enumerator.sourceVectorType || !enumerator.targetVectorType) return rewriter.notifyMatchFailure(op, "types are not vector"); + if (!preconditionType || preconditionType.getRank() != 1) + return rewriter.notifyMatchFailure(op, "unsupported >1-D vector"); + return commonConversionPrecondition(rewriter, preconditionType, op); } @@ -855,7 +857,6 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc, "Expected i4 type"); // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>. - int64_t vecDimSize = srcVecType.getShape().back(); SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape()); constexpr int64_t i4Toi8BitwidthFactor = 2; i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor; @@ -871,16 +872,8 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc, Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues); Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues); - // 3. Interleave low and high i8 elements using a shuffle. - SmallVector<int64_t> interleaveMaskValues; - interleaveMaskValues.reserve(vecDimSize); - for (int i = 0, end = vecDimSize / 2; i < end; ++i) { - interleaveMaskValues.push_back(i); - interleaveMaskValues.push_back(i + (vecDimSize / 2)); - } - - return rewriter.create<vector::ShuffleOp>( - loc, low, high, rewriter.getI64ArrayAttr(interleaveMaskValues)); + // 3. Interleave low and high i8 elements. 
+ return rewriter.create<vector::InterleaveOp>(loc, low, high); } namespace { @@ -1008,8 +1001,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> { /// %1 = arith.shli %0, 4 : vector<4xi8> /// %2 = arith.shrsi %1, 4 : vector<4xi8> /// %3 = arith.shrsi %0, 4 : vector<4xi8> -/// %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7] -/// : vector<4xi8>, vector<4xi8> +/// %4 = vector.interleave %2, %3 : vector<4xi8> /// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32> /// /// arith.sitofp %in : vector<8xi4> to vector<8xf32> @@ -1018,8 +1010,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> { /// %1 = arith.shli %0, 4 : vector<4xi8> /// %2 = arith.shrsi %1, 4 : vector<4xi8> /// %3 = arith.shrsi %0, 4 : vector<4xi8> -/// %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7] -/// : vector<4xi8>, vector<4xi8> +/// %4 = vector.interleave %2, %3 : vector<4xi8> /// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32> /// template <typename ConversionOpType> diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index 04e5a81..0ffef6a 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" @@ -577,7 +578,6 @@ public: if (transferReadOp.getMask()) return failure(); - SmallVector<Value> collapsedIndices; int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank(); // 1. 
Collapse the source memref @@ -599,12 +599,14 @@ public: // 2.2 New indices // If all the collapsed indices are zero then no extra logic is needed. // Otherwise, a new offset/index has to be computed. + SmallVector<Value> collapsedIndices; if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(), firstDimToCollapse, collapsedIndices))) { - // Copy all the leading indices - collapsedIndices = transferReadOp.getIndices(); - collapsedIndices.resize(firstDimToCollapse); + // Copy all the leading indices. + SmallVector<Value> indices = transferReadOp.getIndices(); + collapsedIndices.append(indices.begin(), + indices.begin() + firstDimToCollapse); // Compute the remaining trailing index/offset required for reading from // the collapsed memref: @@ -621,24 +623,26 @@ public: // memref<1x86xi32>, vector<2xi32> // one would get the following offset: // %offset = %arg0 * 43 - AffineExpr offsetExpr, idxExpr; - bindSymbols(rewriter.getContext(), offsetExpr, idxExpr); - - int64_t outputRank = transferReadOp.getIndices().size(); - OpFoldResult offset = + OpFoldResult collapsedOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult(); - for (int64_t i = firstDimToCollapse; i < outputRank; ++i) { - int64_t dim = dyn_cast<ShapedType>(source.getType()).getDimSize(i); - offset = affine::makeComposedFoldedAffineApply( - rewriter, loc, offsetExpr + dim * idxExpr, - {offset, transferReadOp.getIndices()[i]}); - } - if (offset.is<Value>()) { - collapsedIndices.push_back(offset.get<Value>()); + auto sourceShape = sourceType.getShape(); + auto collapsedStrides = computeSuffixProduct(ArrayRef<int64_t>( + sourceShape.begin() + firstDimToCollapse, sourceShape.end())); + + // Compute the collapsed offset. 
+ ArrayRef<Value> indicesToCollapse(indices.begin() + firstDimToCollapse, + indices.end()); + auto &&[collapsedExpr, collapsedVals] = computeLinearIndex( + collapsedOffset, collapsedStrides, indicesToCollapse); + collapsedOffset = affine::makeComposedFoldedAffineApply( + rewriter, loc, collapsedExpr, collapsedVals); + + if (collapsedOffset.is<Value>()) { + collapsedIndices.push_back(collapsedOffset.get<Value>()); } else { collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>( - loc, *getConstantIntValue(offset))); + loc, *getConstantIntValue(collapsedOffset))); } } @@ -710,6 +714,7 @@ public: firstContiguousInnerDim, collapsedIndices))) return failure(); + Value collapsedSource = collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim); MemRefType collapsedSourceType = diff --git a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp index 9e8b240..bbe10b0 100644 --- a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp +++ b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp @@ -74,13 +74,6 @@ MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETCOORDINATES) MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETVALUES) #undef IMPL_GETVALUES -#define IMPL_FORWARDINGINSERT(VNAME, V) \ - void SparseTensorStorageBase::forwardingInsert(const uint64_t *, V) { \ - FATAL_PIV("forwardingInsert" #VNAME); \ - } -MLIR_SPARSETENSOR_FOREVERY_V(IMPL_FORWARDINGINSERT) -#undef IMPL_FORWARDINGINSERT - #define IMPL_LEXINSERT(VNAME, V) \ void SparseTensorStorageBase::lexInsert(const uint64_t *, V) { \ FATAL_PIV("lexInsert" #VNAME); \ diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp index a5e75a7..0bc90b2 100644 --- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp +++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp @@ -117,20 +117,7 @@ extern "C" { switch (action) { \ case Action::kEmpty: { \ return SparseTensorStorage<P, C, V>::newEmpty( \ - dimRank, dimSizes, lvlRank, 
lvlSizes, lvlTypes, dim2lvl, lvl2dim, \ - false); \ - } \ - case Action::kEmptyForward: { \ - return SparseTensorStorage<P, C, V>::newEmpty( \ - dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim, \ - true); \ - } \ - case Action::kFromCOO: { \ - assert(ptr && "Received nullptr for SparseTensorCOO object"); \ - auto &coo = *static_cast<SparseTensorCOO<V> *>(ptr); \ - return SparseTensorStorage<P, C, V>::newFromCOO( \ - dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim, \ - coo); \ + dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim); \ } \ case Action::kFromReader: { \ assert(ptr && "Received nullptr for SparseTensorReader object"); \ @@ -138,11 +125,6 @@ extern "C" { return static_cast<void *>(reader.readSparseTensor<P, C, V>( \ lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim)); \ } \ - case Action::kToCOO: { \ - assert(ptr && "Received nullptr for SparseTensorStorage object"); \ - auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr); \ - return tensor.toCOO(); \ - } \ case Action::kPack: { \ assert(ptr && "Received nullptr for SparseTensorStorage object"); \ intptr_t *buffers = static_cast<intptr_t *>(ptr); \ @@ -341,21 +323,6 @@ MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATES) #undef IMPL_SPARSECOORDINATES #undef IMPL_GETOVERHEAD -#define IMPL_FORWARDINGINSERT(VNAME, V) \ - void _mlir_ciface_forwardingInsert##VNAME( \ - void *t, StridedMemRefType<V, 0> *vref, \ - StridedMemRefType<index_type, 1> *dimCoordsRef) { \ - assert(t &&vref); \ - ASSERT_NO_STRIDE(dimCoordsRef); \ - const index_type *dimCoords = MEMREF_GET_PAYLOAD(dimCoordsRef); \ - assert(dimCoords); \ - const V *value = MEMREF_GET_PAYLOAD(vref); \ - static_cast<SparseTensorStorageBase *>(t)->forwardingInsert(dimCoords, \ - *value); \ - } -MLIR_SPARSETENSOR_FOREVERY_V(IMPL_FORWARDINGINSERT) -#undef IMPL_FORWARDINGINSERT - #define IMPL_LEXINSERT(VNAME, V) \ void _mlir_ciface_lexInsert##VNAME( \ void *t, StridedMemRefType<index_type, 1> 
*lvlCoordsRef, \ @@ -427,8 +394,8 @@ void _mlir_ciface_getSparseTensorReaderDimSizes( const uint64_t cSize = MEMREF_GET_USIZE(cref); \ const uint64_t vSize = MEMREF_GET_USIZE(vref); \ ASSERT_USIZE_EQ(lvl2dimRef, dimRank); \ - assert(cSize >= lvlRank * vSize); \ - assert(vSize >= reader.getNSE() && "Not enough space in buffers"); \ + assert(cSize >= lvlRank * reader.getNSE()); \ + assert(vSize >= reader.getNSE()); \ (void)dimRank; \ (void)cSize; \ (void)vSize; \ @@ -488,10 +455,6 @@ index_type sparseDimSize(void *tensor, index_type d) { return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d); } -void endForwardingInsert(void *tensor) { - return static_cast<SparseTensorStorageBase *>(tensor)->endForwardingInsert(); -} - void endLexInsert(void *tensor) { return static_cast<SparseTensorStorageBase *>(tensor)->endLexInsert(); } @@ -500,13 +463,6 @@ void delSparseTensor(void *tensor) { delete static_cast<SparseTensorStorageBase *>(tensor); } -#define IMPL_DELCOO(VNAME, V) \ - void delSparseTensorCOO##VNAME(void *coo) { \ - delete static_cast<SparseTensorCOO<V> *>(coo); \ - } -MLIR_SPARSETENSOR_FOREVERY_V(IMPL_DELCOO) -#undef IMPL_DELCOO - char *getTensorFilename(index_type id) { constexpr size_t bufSize = 80; char var[bufSize]; diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp index 6521295..c631617 100644 --- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp +++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp @@ -99,21 +99,31 @@ DIFileAttr DebugImporter::translateImpl(llvm::DIFile *node) { } DILabelAttr DebugImporter::translateImpl(llvm::DILabel *node) { - return DILabelAttr::get(context, translate(node->getScope()), + // Return nullptr if the scope or type is a cyclic dependency. 
+ DIScopeAttr scope = translate(node->getScope()); + if (node->getScope() && !scope) + return nullptr; + return DILabelAttr::get(context, scope, getStringAttrOrNull(node->getRawName()), translate(node->getFile()), node->getLine()); } DILexicalBlockAttr DebugImporter::translateImpl(llvm::DILexicalBlock *node) { - return DILexicalBlockAttr::get(context, translate(node->getScope()), - translate(node->getFile()), node->getLine(), - node->getColumn()); + // Return nullptr if the scope or type is a cyclic dependency. + DIScopeAttr scope = translate(node->getScope()); + if (node->getScope() && !scope) + return nullptr; + return DILexicalBlockAttr::get(context, scope, translate(node->getFile()), + node->getLine(), node->getColumn()); } DILexicalBlockFileAttr DebugImporter::translateImpl(llvm::DILexicalBlockFile *node) { - return DILexicalBlockFileAttr::get(context, translate(node->getScope()), - translate(node->getFile()), + // Return nullptr if the scope or type is a cyclic dependency. + DIScopeAttr scope = translate(node->getScope()); + if (node->getScope() && !scope) + return nullptr; + return DILexicalBlockFileAttr::get(context, scope, translate(node->getFile()), node->getDiscriminator()); } @@ -135,11 +145,14 @@ DebugImporter::translateImpl(llvm::DIGlobalVariable *node) { } DILocalVariableAttr DebugImporter::translateImpl(llvm::DILocalVariable *node) { - return DILocalVariableAttr::get(context, translate(node->getScope()), - getStringAttrOrNull(node->getRawName()), - translate(node->getFile()), node->getLine(), - node->getArg(), node->getAlignInBits(), - translate(node->getType())); + // Return nullptr if the scope or type is a cyclic dependency. 
+ DIScopeAttr scope = translate(node->getScope()); + if (node->getScope() && !scope) + return nullptr; + return DILocalVariableAttr::get( + context, scope, getStringAttrOrNull(node->getRawName()), + translate(node->getFile()), node->getLine(), node->getArg(), + node->getAlignInBits(), translate(node->getType())); } DIScopeAttr DebugImporter::translateImpl(llvm::DIScope *node) { diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 97ccb2b..d63ea12 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -1966,6 +1966,13 @@ ModuleImport::processDebugIntrinsic(llvm::DbgVariableIntrinsic *dbgIntr, // TODO: find a way to support this case. if (isMetadataKillLocation(dbgIntr)) return emitUnsupportedWarning(); + // Drop debug intrinsics if the associated variable information cannot be + // translated due to cyclic debug metadata. + // TODO: Support cyclic debug metadata. + DILocalVariableAttr localVariableAttr = + matchLocalVariableAttr(dbgIntr->getArgOperand(1)); + if (!localVariableAttr) + return emitUnsupportedWarning(); FailureOr<Value> argOperand = convertMetadataValue(dbgIntr->getArgOperand(0)); if (failed(argOperand)) return emitError(loc) << "failed to convert a debug intrinsic operand: " @@ -1991,8 +1998,6 @@ ModuleImport::processDebugIntrinsic(llvm::DbgVariableIntrinsic *dbgIntr, } else { builder.setInsertionPointAfterValue(*argOperand); } - DILocalVariableAttr localVariableAttr = - matchLocalVariableAttr(dbgIntr->getArgOperand(1)); auto locationExprAttr = debugImporter->translateExpression(dbgIntr->getExpression()); Operation *op = diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 4989ddc..857b601 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -152,519 +152,25 @@ namespace { /// This class contains a snapshot of the current conversion 
rewriter state. /// This is useful when saving and undoing a set of rewrites. struct RewriterState { - RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations, - unsigned numReplacements, unsigned numArgReplacements, - unsigned numRewrites, unsigned numIgnoredOperations) - : numCreatedOps(numCreatedOps), - numUnresolvedMaterializations(numUnresolvedMaterializations), - numReplacements(numReplacements), - numArgReplacements(numArgReplacements), numRewrites(numRewrites), - numIgnoredOperations(numIgnoredOperations) {} - - /// The current number of created operations. - unsigned numCreatedOps; - - /// The current number of unresolved materializations. - unsigned numUnresolvedMaterializations; - - /// The current number of replacements queued. - unsigned numReplacements; - - /// The current number of argument replacements queued. - unsigned numArgReplacements; + RewriterState(unsigned numRewrites, unsigned numIgnoredOperations, + unsigned numErased) + : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations), + numErased(numErased) {} /// The current number of rewrites performed. unsigned numRewrites; /// The current number of ignored operations. unsigned numIgnoredOperations; -}; - -//===----------------------------------------------------------------------===// -// OpReplacement - -/// This class represents one requested operation replacement via 'replaceOp' or -/// 'eraseOp`. -struct OpReplacement { - OpReplacement(const TypeConverter *converter = nullptr) - : converter(converter) {} - - /// An optional type converter that can be used to materialize conversions - /// between the new and old values if necessary. - const TypeConverter *converter; -}; - -//===----------------------------------------------------------------------===// -// UnresolvedMaterialization - -/// This class represents an unresolved materialization, i.e. 
a materialization -/// that was inserted during conversion that needs to be legalized at the end of -/// the conversion process. -class UnresolvedMaterialization { -public: - /// The type of materialization. - enum Kind { - /// This materialization materializes a conversion for an illegal block - /// argument type, to a legal one. - Argument, - - /// This materialization materializes a conversion from an illegal type to a - /// legal one. - Target - }; - - UnresolvedMaterialization(UnrealizedConversionCastOp op = nullptr, - const TypeConverter *converter = nullptr, - Kind kind = Target, Type origOutputType = nullptr) - : op(op), converterAndKind(converter, kind), - origOutputType(origOutputType) {} - - /// Return the temporary conversion operation inserted for this - /// materialization. - UnrealizedConversionCastOp getOp() const { return op; } - /// Return the type converter of this materialization (which may be null). - const TypeConverter *getConverter() const { - return converterAndKind.getPointer(); - } - - /// Return the kind of this materialization. - Kind getKind() const { return converterAndKind.getInt(); } - - /// Set the kind of this materialization. - void setKind(Kind kind) { converterAndKind.setInt(kind); } - - /// Return the original illegal output type of the input values. - Type getOrigOutputType() const { return origOutputType; } - -private: - /// The unresolved materialization operation created during conversion. - UnrealizedConversionCastOp op; - - /// The corresponding type converter to use when resolving this - /// materialization, and the kind of this materialization. - llvm::PointerIntPair<const TypeConverter *, 1, Kind> converterAndKind; - - /// The original output type. This is only used for argument conversions. - Type origOutputType; + /// The current number of erased operations/blocks. + unsigned numErased; }; -} // namespace - -/// Build an unresolved materialization operation given an output type and set -/// of input operands. 
-static Value buildUnresolvedMaterialization( - UnresolvedMaterialization::Kind kind, Block *insertBlock, - Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType, - Type origOutputType, const TypeConverter *converter, - SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) { - // Avoid materializing an unnecessary cast. - if (inputs.size() == 1 && inputs.front().getType() == outputType) - return inputs.front(); - - // Create an unresolved materialization. We use a new OpBuilder to avoid - // tracking the materialization like we do for other operations. - OpBuilder builder(insertBlock, insertPt); - auto convertOp = - builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs); - unresolvedMaterializations.emplace_back(convertOp, converter, kind, - origOutputType); - return convertOp.getResult(0); -} -static Value buildUnresolvedArgumentMaterialization( - PatternRewriter &rewriter, Location loc, ValueRange inputs, - Type origOutputType, Type outputType, const TypeConverter *converter, - SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) { - return buildUnresolvedMaterialization( - UnresolvedMaterialization::Argument, rewriter.getInsertionBlock(), - rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType, - converter, unresolvedMaterializations); -} -static Value buildUnresolvedTargetMaterialization( - Location loc, Value input, Type outputType, const TypeConverter *converter, - SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) { - Block *insertBlock = input.getParentBlock(); - Block::iterator insertPt = insertBlock->begin(); - if (OpResult inputRes = dyn_cast<OpResult>(input)) - insertPt = ++inputRes.getOwner()->getIterator(); - - return buildUnresolvedMaterialization( - UnresolvedMaterialization::Target, insertBlock, insertPt, loc, input, - outputType, outputType, converter, unresolvedMaterializations); -} - 
-//===----------------------------------------------------------------------===// -// ArgConverter -//===----------------------------------------------------------------------===// -namespace { -/// This class provides a simple interface for converting the types of block -/// arguments. This is done by creating a new block that contains the new legal -/// types and extracting the block that contains the old illegal types to allow -/// for undoing pending rewrites in the case of failure. -struct ArgConverter { - ArgConverter( - PatternRewriter &rewriter, - SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) - : rewriter(rewriter), - unresolvedMaterializations(unresolvedMaterializations) {} - - /// This structure contains the information pertaining to an argument that has - /// been converted. - struct ConvertedArgInfo { - ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize, - Value castValue = nullptr) - : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {} - - /// The start index of in the new argument list that contains arguments that - /// replace the original. - unsigned newArgIdx; - - /// The number of arguments that replaced the original argument. - unsigned newArgSize; - - /// The cast value that was created to cast from the new arguments to the - /// old. This only used if 'newArgSize' > 1. - Value castValue; - }; - - /// This structure contains information pertaining to a block that has had its - /// signature converted. - struct ConvertedBlockInfo { - ConvertedBlockInfo(Block *origBlock, const TypeConverter *converter) - : origBlock(origBlock), converter(converter) {} - - /// The original block that was requested to have its signature converted. - Block *origBlock; - - /// The conversion information for each of the arguments. The information is - /// std::nullopt if the argument was dropped during conversion. 
- SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo; - - /// The type converter used to convert the arguments. - const TypeConverter *converter; - }; - - //===--------------------------------------------------------------------===// - // Rewrite Application - //===--------------------------------------------------------------------===// - - /// Erase any rewrites registered for the blocks within the given operation - /// which is about to be removed. This merely drops the rewrites without - /// undoing them. - void notifyOpRemoved(Operation *op); - - /// Cleanup and undo any generated conversions for the arguments of block. - /// This method replaces the new block with the original, reverting the IR to - /// its original state. - void discardRewrites(Block *block); - - /// Fully replace uses of the old arguments with the new. - void applyRewrites(ConversionValueMapping &mapping); - - /// Materialize any necessary conversions for converted arguments that have - /// live users, using the provided `findLiveUser` to search for a user that - /// survives the conversion process. - LogicalResult - materializeLiveConversions(ConversionValueMapping &mapping, - OpBuilder &builder, - function_ref<Operation *(Value)> findLiveUser); - - //===--------------------------------------------------------------------===// - // Conversion - //===--------------------------------------------------------------------===// - - /// Attempt to convert the signature of the given block, if successful a new - /// block is returned containing the new arguments. Returns `block` if it did - /// not require conversion. - FailureOr<Block *> - convertSignature(Block *block, const TypeConverter *converter, - ConversionValueMapping &mapping, - SmallVectorImpl<BlockArgument> &argReplacements); - - /// Apply the given signature conversion on the given block. The new block - /// containing the updated signature is returned. If no conversions were - /// necessary, e.g. 
if the block has no arguments, `block` is returned. - /// `converter` is used to generate any necessary cast operations that - /// translate between the origin argument types and those specified in the - /// signature conversion. - Block *applySignatureConversion( - Block *block, const TypeConverter *converter, - TypeConverter::SignatureConversion &signatureConversion, - ConversionValueMapping &mapping, - SmallVectorImpl<BlockArgument> &argReplacements); - - /// A collection of blocks that have had their arguments converted. This is a - /// map from the new replacement block, back to the original block. - llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo; - - /// The pattern rewriter to use when materializing conversions. - PatternRewriter &rewriter; - - /// An ordered set of unresolved materializations during conversion. - SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations; -}; -} // namespace - -//===----------------------------------------------------------------------===// -// Rewrite Application - -void ArgConverter::notifyOpRemoved(Operation *op) { - if (conversionInfo.empty()) - return; - - for (Region &region : op->getRegions()) { - for (Block &block : region) { - // Drop any rewrites from within. - for (Operation &nestedOp : block) - if (nestedOp.getNumRegions()) - notifyOpRemoved(&nestedOp); - - // Check if this block was converted. - auto *it = conversionInfo.find(&block); - if (it == conversionInfo.end()) - continue; - - // Drop all uses of the original arguments and delete the original block. - Block *origBlock = it->second.origBlock; - for (BlockArgument arg : origBlock->getArguments()) - arg.dropAllUses(); - conversionInfo.erase(it); - } - } -} - -void ArgConverter::discardRewrites(Block *block) { - auto *it = conversionInfo.find(block); - if (it == conversionInfo.end()) - return; - Block *origBlock = it->second.origBlock; - - // Drop all uses of the new block arguments and replace uses of the new block. 
- for (int i = block->getNumArguments() - 1; i >= 0; --i) - block->getArgument(i).dropAllUses(); - block->replaceAllUsesWith(origBlock); - - // Move the operations back the original block, move the original block back - // into its original location and the delete the new block. - origBlock->getOperations().splice(origBlock->end(), block->getOperations()); - block->getParent()->getBlocks().insert(Region::iterator(block), origBlock); - block->erase(); - - conversionInfo.erase(it); -} - -void ArgConverter::applyRewrites(ConversionValueMapping &mapping) { - for (auto &info : conversionInfo) { - ConvertedBlockInfo &blockInfo = info.second; - Block *origBlock = blockInfo.origBlock; - - // Process the remapping for each of the original arguments. - for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) { - std::optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i]; - BlockArgument origArg = origBlock->getArgument(i); - - // Handle the case of a 1->0 value mapping. - if (!argInfo) { - if (Value newArg = mapping.lookupOrNull(origArg, origArg.getType())) - origArg.replaceAllUsesWith(newArg); - continue; - } - - // Otherwise this is a 1->1+ value mapping. - Value castValue = argInfo->castValue; - assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping"); - - // If the argument is still used, replace it with the generated cast. - if (!origArg.use_empty()) { - origArg.replaceAllUsesWith( - mapping.lookupOrDefault(castValue, origArg.getType())); - } - } - - delete origBlock; - blockInfo.origBlock = nullptr; - } -} - -LogicalResult ArgConverter::materializeLiveConversions( - ConversionValueMapping &mapping, OpBuilder &builder, - function_ref<Operation *(Value)> findLiveUser) { - for (auto &info : conversionInfo) { - Block *newBlock = info.first; - ConvertedBlockInfo &blockInfo = info.second; - Block *origBlock = blockInfo.origBlock; - - // Process the remapping for each of the original arguments. 
- for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) { - // If the type of this argument changed and the argument is still live, we - // need to materialize a conversion. - BlockArgument origArg = origBlock->getArgument(i); - if (mapping.lookupOrNull(origArg, origArg.getType())) - continue; - Operation *liveUser = findLiveUser(origArg); - if (!liveUser) - continue; - - Value replacementValue = mapping.lookupOrDefault(origArg); - bool isDroppedArg = replacementValue == origArg; - if (isDroppedArg) - rewriter.setInsertionPointToStart(newBlock); - else - rewriter.setInsertionPointAfterValue(replacementValue); - Value newArg; - if (blockInfo.converter) { - newArg = blockInfo.converter->materializeSourceConversion( - rewriter, origArg.getLoc(), origArg.getType(), - isDroppedArg ? ValueRange() : ValueRange(replacementValue)); - assert((!newArg || newArg.getType() == origArg.getType()) && - "materialization hook did not provide a value of the expected " - "type"); - } - if (!newArg) { - InFlightDiagnostic diag = - emitError(origArg.getLoc()) - << "failed to materialize conversion for block argument #" << i - << " that remained live after conversion, type was " - << origArg.getType(); - if (!isDroppedArg) - diag << ", with target type " << replacementValue.getType(); - diag.attachNote(liveUser->getLoc()) - << "see existing live user here: " << *liveUser; - return failure(); - } - mapping.map(origArg, newArg); - } - } - return success(); -} - -//===----------------------------------------------------------------------===// -// Conversion - -FailureOr<Block *> ArgConverter::convertSignature( - Block *block, const TypeConverter *converter, - ConversionValueMapping &mapping, - SmallVectorImpl<BlockArgument> &argReplacements) { - // Check if the block was already converted. - // * If the block is mapped in `conversionInfo`, it is a converted block. 
- // * If the block is detached, conservatively assume that it is going to be - // deleted; it is likely the old block (before it was converted). - if (conversionInfo.count(block) || !block->getParent()) - return block; - // If a converter wasn't provided, and the block wasn't already converted, - // there is nothing we can do. - if (!converter) - return failure(); - - // Try to convert the signature for the block with the provided converter. - if (auto conversion = converter->convertBlockSignature(block)) - return applySignatureConversion(block, converter, *conversion, mapping, - argReplacements); - return failure(); -} - -Block *ArgConverter::applySignatureConversion( - Block *block, const TypeConverter *converter, - TypeConverter::SignatureConversion &signatureConversion, - ConversionValueMapping &mapping, - SmallVectorImpl<BlockArgument> &argReplacements) { - // If no arguments are being changed or added, there is nothing to do. - unsigned origArgCount = block->getNumArguments(); - auto convertedTypes = signatureConversion.getConvertedTypes(); - if (origArgCount == 0 && convertedTypes.empty()) - return block; - - // Split the block at the beginning to get a new block to use for the updated - // signature. - Block *newBlock = block->splitBlock(block->begin()); - block->replaceAllUsesWith(newBlock); - // Unlink the block, but do not erase it yet, so that the change can be rolled - // back. - block->getParent()->getBlocks().remove(block); - - // Map all new arguments to the location of the argument they originate from. 
- SmallVector<Location> newLocs(convertedTypes.size(), - rewriter.getUnknownLoc()); - for (unsigned i = 0; i < origArgCount; ++i) { - auto inputMap = signatureConversion.getInputMapping(i); - if (!inputMap || inputMap->replacementValue) - continue; - Location origLoc = block->getArgument(i).getLoc(); - for (unsigned j = 0; j < inputMap->size; ++j) - newLocs[inputMap->inputNo + j] = origLoc; - } - - SmallVector<Value, 4> newArgRange( - newBlock->addArguments(convertedTypes, newLocs)); - ArrayRef<Value> newArgs(newArgRange); - - // Remap each of the original arguments as determined by the signature - // conversion. - ConvertedBlockInfo info(block, converter); - info.argInfo.resize(origArgCount); - - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(newBlock); - for (unsigned i = 0; i != origArgCount; ++i) { - auto inputMap = signatureConversion.getInputMapping(i); - if (!inputMap) - continue; - BlockArgument origArg = block->getArgument(i); - - // If inputMap->replacementValue is not nullptr, then the argument is - // dropped and a replacement value is provided to be the remappedValue. - if (inputMap->replacementValue) { - assert(inputMap->size == 0 && - "invalid to provide a replacement value when the argument isn't " - "dropped"); - mapping.map(origArg, inputMap->replacementValue); - argReplacements.push_back(origArg); - continue; - } - - // Otherwise, this is a 1->1+ mapping. - auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size); - Value newArg; - - // If this is a 1->1 mapping and the types of new and replacement arguments - // match (i.e. it's an identity map), then the argument is mapped to its - // original type. - // FIXME: We simply pass through the replacement argument if there wasn't a - // converter, which isn't great as it allows implicit type conversions to - // appear. 
We should properly restructure this code to handle cases where a - // converter isn't provided and also to properly handle the case where an - // argument materialization is actually a temporary source materialization - // (e.g. in the case of 1->N). - if (replArgs.size() == 1 && - (!converter || replArgs[0].getType() == origArg.getType())) { - newArg = replArgs.front(); - } else { - Type origOutputType = origArg.getType(); - - // Legalize the argument output type. - Type outputType = origOutputType; - if (Type legalOutputType = converter->convertType(outputType)) - outputType = legalOutputType; - - newArg = buildUnresolvedArgumentMaterialization( - rewriter, origArg.getLoc(), replArgs, origOutputType, outputType, - converter, unresolvedMaterializations); - } - - mapping.map(origArg, newArg); - argReplacements.push_back(origArg); - info.argInfo[i] = - ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg); - } - - conversionInfo.insert({newBlock, std::move(info)}); - return newBlock; -} //===----------------------------------------------------------------------===// // IR rewrites //===----------------------------------------------------------------------===// -namespace { /// An IR rewrite that can be committed (upon success) or rolled back (upon /// failure). /// @@ -685,19 +191,29 @@ public: MoveBlock, SplitBlock, BlockTypeConversion, + ReplaceBlockArg, // Operation rewrites MoveOperation, - ModifyOperation + ModifyOperation, + ReplaceOperation, + CreateOperation, + UnresolvedMaterialization }; virtual ~IRRewrite() = default; - /// Roll back the rewrite. + /// Roll back the rewrite. Operations may be erased during rollback. virtual void rollback() = 0; - /// Commit the rewrite. + /// Commit the rewrite. Operations may be unlinked from their blocks during + /// the commit phase, but they must not be erased yet. This is because + /// internal dialect conversion state (such as `mapping`) may still be using + /// them. Operations must be erased during cleanup. 
virtual void commit() {} + /// Cleanup operations. Cleanup is called after commit. + virtual void cleanup() {} + Kind getKind() const { return kind; } static bool classof(const IRRewrite *rewrite) { return true; } @@ -706,6 +222,14 @@ protected: IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl) : kind(kind), rewriterImpl(rewriterImpl) {} + /// Erase the given op (unless it was already erased). + void eraseOp(Operation *op); + + /// Erase the given block (unless it was already erased). + void eraseBlock(Block *block); + + const ConversionConfig &getConfig() const; + const Kind kind; ConversionPatternRewriterImpl &rewriterImpl; }; @@ -718,7 +242,7 @@ public: static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() >= Kind::CreateBlock && - rewrite->getKind() <= Kind::BlockTypeConversion; + rewrite->getKind() <= Kind::ReplaceBlockArg; } protected: @@ -748,8 +272,10 @@ public: auto &blockOps = block->getOperations(); while (!blockOps.empty()) blockOps.remove(blockOps.begin()); - block->dropAllDefinedValueUses(); - block->erase(); + if (block->getParent()) + eraseBlock(block); + else + delete block; } }; @@ -787,6 +313,8 @@ public: void commit() override { // Erase the block. assert(block && "expected block"); + assert(block->empty() && "expected empty block"); + block->dropAllDefinedValueUses(); delete block; block = nullptr; } @@ -885,8 +413,7 @@ public: // Merge back the block that was split out. originalBlock->getOperations().splice(originalBlock->end(), block->getOperations()); - block->dropAllDefinedValueUses(); - block->erase(); + eraseBlock(block); } private: @@ -894,20 +421,80 @@ private: Block *originalBlock; }; +/// This structure contains the information pertaining to an argument that has +/// been converted. 
+struct ConvertedArgInfo { + ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize, + Value castValue = nullptr) + : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {} + + /// The start index in the new argument list of the arguments that + /// replace the original. + unsigned newArgIdx; + + /// The number of arguments that replaced the original argument. + unsigned newArgSize; + + /// The cast value that was created to cast from the new arguments to the + /// old. This is only used if 'newArgSize' > 1. + Value castValue; +}; + /// Block type conversion. This rewrite is partially reflected in the IR. class BlockTypeConversionRewrite : public BlockRewrite { public: - BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl, - Block *block) - : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block) {} + BlockTypeConversionRewrite( + ConversionPatternRewriterImpl &rewriterImpl, Block *block, + Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo, + const TypeConverter *converter) + : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block), + origBlock(origBlock), argInfo(argInfo), converter(converter) {} static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::BlockTypeConversion; } - // TODO: Block type conversions are currently committed in - // `ArgConverter::applyRewrites`. This should be done in the "commit" method. + /// Materialize any necessary conversions for converted arguments that have + /// live users, using the provided `findLiveUser` to search for a user that + /// survives the conversion process. + LogicalResult + materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser); + + void commit() override; + void rollback() override; + +private: + /// The original block that was requested to have its signature converted. + Block *origBlock; + + /// The conversion information for each of the arguments. 
The information is + /// std::nullopt if the argument was dropped during conversion. + SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo; + + /// The type converter used to convert the arguments. + const TypeConverter *converter; +}; + +/// Replacing a block argument. This rewrite is not immediately reflected in the +/// IR. An internal IR mapping is updated, but the actual replacement is delayed +/// until the rewrite is committed. +class ReplaceBlockArgRewrite : public BlockRewrite { +public: + ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl, + Block *block, BlockArgument arg) + : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::ReplaceBlockArg; + } + + void commit() override; + + void rollback() override; + +private: + BlockArgument arg; }; /// An operation rewrite. @@ -918,7 +505,7 @@ public: static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() >= Kind::MoveOperation && - rewrite->getKind() <= Kind::ModifyOperation; + rewrite->getKind() <= Kind::UnresolvedMaterialization; } protected: @@ -953,8 +540,8 @@ private: // The block in which this operation was previously contained. Block *block; - // The original successor of this operation before it was moved. "nullptr" if - // this operation was the only operation in the region. + // The original successor of this operation before it was moved. "nullptr" + // if this operation was the only operation in the region. Operation *insertBeforeOp; }; @@ -1019,6 +606,118 @@ private: SmallVector<Block *, 2> successors; void *propertiesStorage = nullptr; }; + +/// Replacing an operation. Erasing an operation is treated as a special case +/// with "null" replacements. This rewrite is not immediately reflected in the +/// IR. An internal IR mapping is updated, but values are not replaced and the +/// original op is not erased until the rewrite is committed. 
+class ReplaceOperationRewrite : public OperationRewrite { +public: + ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, + Operation *op, const TypeConverter *converter, + bool changedResults) + : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op), + converter(converter), changedResults(changedResults) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::ReplaceOperation; + } + + void commit() override; + + void rollback() override; + + void cleanup() override; + + const TypeConverter *getConverter() const { return converter; } + + bool hasChangedResults() const { return changedResults; } + +private: + /// An optional type converter that can be used to materialize conversions + /// between the new and old values if necessary. + const TypeConverter *converter; + + /// A boolean flag that indicates whether result types have changed or not. + bool changedResults; +}; + +class CreateOperationRewrite : public OperationRewrite { +public: + CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, + Operation *op) + : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::CreateOperation; + } + + void rollback() override; +}; + +/// The type of materialization. +enum MaterializationKind { + /// This materialization materializes a conversion for an illegal block + /// argument type, to a legal one. + Argument, + + /// This materialization materializes a conversion from an illegal type to a + /// legal one. + Target +}; + +/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast" +/// op. Unresolved materializations are erased at the end of the dialect +/// conversion. 
+class UnresolvedMaterializationRewrite : public OperationRewrite { +public: + UnresolvedMaterializationRewrite( + ConversionPatternRewriterImpl &rewriterImpl, + UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr, + MaterializationKind kind = MaterializationKind::Target, + Type origOutputType = nullptr) + : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op), + converterAndKind(converter, kind), origOutputType(origOutputType) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::UnresolvedMaterialization; + } + + UnrealizedConversionCastOp getOperation() const { + return cast<UnrealizedConversionCastOp>(op); + } + + void rollback() override; + + void cleanup() override; + + /// Return the type converter of this materialization (which may be null). + const TypeConverter *getConverter() const { + return converterAndKind.getPointer(); + } + + /// Return the kind of this materialization. + MaterializationKind getMaterializationKind() const { + return converterAndKind.getInt(); + } + + /// Set the kind of this materialization. + void setMaterializationKind(MaterializationKind kind) { + converterAndKind.setInt(kind); + } + + /// Return the original illegal output type of the input values. + Type getOrigOutputType() const { return origOutputType; } + +private: + /// The corresponding type converter to use when resolving this + /// materialization, and the kind of this materialization. + llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind> + converterAndKind; + + /// The original output type. This is only used for argument conversions. + Type origOutputType; +}; } // namespace /// Return "true" if there is an operation rewrite that matches the specified @@ -1031,23 +730,35 @@ static bool hasRewrite(R &&rewrites, Operation *op) { }); } +/// Find the single rewrite object of the specified type and block among the +/// given rewrites. 
In debug mode, asserts that there is no more than one such +/// object. Return "nullptr" if no object was found. +template <typename RewriteTy, typename R> +static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) { + RewriteTy *result = nullptr; + for (auto &rewrite : rewrites) { + auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get()); + if (rewriteTy && rewriteTy->getBlock() == block) { +#ifndef NDEBUG + assert(!result && "expected single matching rewrite"); + result = rewriteTy; +#else + return rewriteTy; +#endif // NDEBUG + } + } + return result; +} + //===----------------------------------------------------------------------===// // ConversionPatternRewriterImpl //===----------------------------------------------------------------------===// namespace mlir { namespace detail { struct ConversionPatternRewriterImpl : public RewriterBase::Listener { - explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter) - : argConverter(rewriter, unresolvedMaterializations), - notifyCallback(nullptr) {} - - /// Cleanup and destroy any generated rewrite operations. This method is - /// invoked when the conversion process fails. - void discardRewrites(); - - /// Apply all requested operation rewrites. This method is invoked when the - /// conversion process succeeds. - void applyRewrites(); + explicit ConversionPatternRewriterImpl(MLIRContext *ctx, + const ConversionConfig &config) + : eraseRewriter(ctx), config(config) {} //===--------------------------------------------------------------------===// // State Management @@ -1056,6 +767,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// Return the current state of the rewriter. RewriterState getCurrentState(); + /// Apply all requested operation rewrites. This method is invoked when the + /// conversion process succeeds. + void applyRewrites(); + /// Reset the state of the rewriter to a previously saved point. 
void resetState(RewriterState state); @@ -1092,11 +807,18 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { // Type Conversion //===--------------------------------------------------------------------===// - /// Convert the signature of the given block. + /// Attempt to convert the signature of the given block, if successful a new + /// block is returned containing the new arguments. Returns `block` if it did + /// not require conversion. FailureOr<Block *> convertBlockSignature( Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion *conversion = nullptr); + /// Convert the types of non-entry block arguments within the given region. + LogicalResult convertNonEntryRegionTypes( + Region *region, const TypeConverter &converter, + ArrayRef<TypeConverter::SignatureConversion> blockConversions = {}); + /// Apply a signature conversion on the given region, using `converter` for /// materializations if not null. Block * @@ -1109,10 +831,37 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion); - /// Convert the types of non-entry block arguments within the given region. - LogicalResult convertNonEntryRegionTypes( - Region *region, const TypeConverter &converter, - ArrayRef<TypeConverter::SignatureConversion> blockConversions = {}); + /// Apply the given signature conversion on the given block. The new block + /// containing the updated signature is returned. If no conversions were + /// necessary, e.g. if the block has no arguments, `block` is returned. + /// `converter` is used to generate any necessary cast operations that + /// translate between the origin argument types and those specified in the + /// signature conversion. 
+ Block *applySignatureConversion( + Block *block, const TypeConverter *converter, + TypeConverter::SignatureConversion &signatureConversion); + + //===--------------------------------------------------------------------===// + // Materializations + //===--------------------------------------------------------------------===// + /// Build an unresolved materialization operation given an output type and set + /// of input operands. + Value buildUnresolvedMaterialization(MaterializationKind kind, + Block *insertBlock, + Block::iterator insertPt, Location loc, + ValueRange inputs, Type outputType, + Type origOutputType, + const TypeConverter *converter); + + Value buildUnresolvedArgumentMaterialization(Block *block, Location loc, + ValueRange inputs, + Type origOutputType, + Type outputType, + const TypeConverter *converter); + + Value buildUnresolvedTargetMaterialization(Location loc, Value input, + Type outputType, + const TypeConverter *converter); //===--------------------------------------------------------------------===// // Rewriter Notification Hooks @@ -1145,28 +894,51 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { function_ref<void(Diagnostic &)> reasonCallback) override; //===--------------------------------------------------------------------===// - // State + // IR Erasure //===--------------------------------------------------------------------===// - // Mapping between replaced values that differ in type. This happens when - // replacing a value with one of a different type. - ConversionValueMapping mapping; + /// A rewriter that keeps track of erased ops and blocks. It ensures that no + /// operation or block is erased multiple times. This rewriter assumes that + /// no new IR is created between calls to `eraseOp`/`eraseBlock`. 
+ struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener { + public: + SingleEraseRewriter(MLIRContext *context) + : RewriterBase(context, /*listener=*/this) {} + + /// Erase the given op (unless it was already erased). + void eraseOp(Operation *op) override { + if (erased.contains(op)) + return; + op->dropAllUses(); + RewriterBase::eraseOp(op); + } - /// Utility used to convert block arguments. - ArgConverter argConverter; + /// Erase the given block (unless it was already erased). + void eraseBlock(Block *block) override { + if (erased.contains(block)) + return; + assert(block->empty() && "expected empty block"); + block->dropAllDefinedValueUses(); + RewriterBase::eraseBlock(block); + } - /// Ordered vector of all of the newly created operations during conversion. - SmallVector<Operation *> createdOps; + void notifyOperationErased(Operation *op) override { erased.insert(op); } + void notifyBlockErased(Block *block) override { erased.insert(block); } - /// Ordered vector of all unresolved type conversion materializations during - /// conversion. - SmallVector<UnresolvedMaterialization> unresolvedMaterializations; + /// Pointers to all erased operations and blocks. + SetVector<void *> erased; + }; + + //===--------------------------------------------------------------------===// + // State + //===--------------------------------------------------------------------===// - /// Ordered map of requested operation replacements. - llvm::MapVector<Operation *, OpReplacement> replacements; + /// This rewriter must be used for erasing ops/blocks. + SingleEraseRewriter eraseRewriter; - /// Ordered vector of any requested block argument replacements. - SmallVector<BlockArgument, 4> argReplacements; + // Mapping between replaced values that differ in type. This happens when + // replacing a value with one of a different type. + ConversionValueMapping mapping; /// Ordered list of block operations (creations, splits, motions). 
SmallVector<std::unique_ptr<IRRewrite>> rewrites; @@ -1182,11 +954,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// operation was ignored. SetVector<Operation *> ignoredOps; - /// A vector of indices into `replacements` of operations that were replaced - /// with values with different result types than the original operation, e.g. - /// 1->N conversion of some kind. - SmallVector<unsigned, 4> operationsWithChangedResults; - /// The current type converter, or nullptr if no type converter is currently /// active. const TypeConverter *currentTypeConverter = nullptr; @@ -1195,8 +962,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// converting the arguments of blocks within that region. DenseMap<Region *, const TypeConverter *> regionToConverter; - /// This allows the user to collect the match failure message. - function_ref<void(Diagnostic &)> notifyCallback; + /// Dialect conversion configuration. + const ConversionConfig &config; #ifndef NDEBUG /// A set of operations that have pending updates. This tracking isn't @@ -1211,154 +978,193 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { } // namespace detail } // namespace mlir -void BlockTypeConversionRewrite::rollback() { - // Undo the type conversion. - rewriterImpl.argConverter.discardRewrites(block); -} - -/// Detach any operations nested in the given operation from their parent -/// blocks, and erase the given operation. This can be used when the nested -/// operations are scheduled for erasure themselves, so deleting the regions of -/// the given operation together with their content would result in double-free. -/// This happens, for example, when rolling back op creation in the reverse -/// order and if the nested ops were created before the parent op. This function -/// does not need to collect nested ops recursively because it is expected to -/// also be called for each nested op when it is about to be deleted. 
-static void detachNestedAndErase(Operation *op) { - for (Region &region : op->getRegions()) { - for (Block &block : region.getBlocks()) { - while (!block.getOperations().empty()) - block.getOperations().remove(block.getOperations().begin()); - block.dropAllDefinedValueUses(); +void IRRewrite::eraseOp(Operation *op) { + rewriterImpl.eraseRewriter.eraseOp(op); +} + +void IRRewrite::eraseBlock(Block *block) { + rewriterImpl.eraseRewriter.eraseBlock(block); +} + +const ConversionConfig &IRRewrite::getConfig() const { + return rewriterImpl.config; +} + +void BlockTypeConversionRewrite::commit() { + // Process the remapping for each of the original arguments. + for (auto [origArg, info] : + llvm::zip_equal(origBlock->getArguments(), argInfo)) { + // Handle the case of a 1->0 value mapping. + if (!info) { + if (Value newArg = + rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType())) + origArg.replaceAllUsesWith(newArg); + continue; + } + + // Otherwise this is a 1->1+ value mapping. + Value castValue = info->castValue; + assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping"); + + // If the argument is still used, replace it with the generated cast. + if (!origArg.use_empty()) { + origArg.replaceAllUsesWith( + rewriterImpl.mapping.lookupOrDefault(castValue, origArg.getType())); } } - op->dropAllUses(); - op->erase(); + + assert(origBlock->empty() && "expected empty block"); + origBlock->dropAllDefinedValueUses(); + delete origBlock; + origBlock = nullptr; } -void ConversionPatternRewriterImpl::discardRewrites() { - undoRewrites(); +void BlockTypeConversionRewrite::rollback() { + // Drop all uses of the new block arguments and replace uses of the new block. + for (int i = block->getNumArguments() - 1; i >= 0; --i) + block->getArgument(i).dropAllUses(); + block->replaceAllUsesWith(origBlock); - // Remove any newly created ops.
- for (UnresolvedMaterialization &materialization : unresolvedMaterializations) - detachNestedAndErase(materialization.getOp()); - for (auto *op : llvm::reverse(createdOps)) - detachNestedAndErase(op); + // Move the operations back to the original block, move the original block back + // into its original location and then delete the new block. + origBlock->getOperations().splice(origBlock->end(), block->getOperations()); + block->getParent()->getBlocks().insert(Region::iterator(block), origBlock); + eraseBlock(block); } -void ConversionPatternRewriterImpl::applyRewrites() { - // Apply all of the rewrites replacements requested during conversion. - for (auto &repl : replacements) { - for (OpResult result : repl.first->getResults()) - if (Value newValue = mapping.lookupOrNull(result, result.getType())) - result.replaceAllUsesWith(newValue); - - // If this operation defines any regions, drop any pending argument - // rewrites. - if (repl.first->getNumRegions()) - argConverter.notifyOpRemoved(repl.first); - } - - // Apply all of the requested argument replacements. - for (BlockArgument arg : argReplacements) { - Value repl = mapping.lookupOrNull(arg, arg.getType()); - if (!repl) +LogicalResult BlockTypeConversionRewrite::materializeLiveConversions( + function_ref<Operation *(Value)> findLiveUser) { + // Process the remapping for each of the original arguments. + for (auto it : llvm::enumerate(origBlock->getArguments())) { + BlockArgument origArg = it.value(); + // Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used. + OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl); + builder.setInsertionPointToStart(block); + + // If the type of this argument changed and the argument is still live, we + // need to materialize a conversion.
+ if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType())) continue; - - if (isa<BlockArgument>(repl)) { - arg.replaceAllUsesWith(repl); + Operation *liveUser = findLiveUser(origArg); + if (!liveUser) continue; + + Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg); + bool isDroppedArg = replacementValue == origArg; + if (!isDroppedArg) + builder.setInsertionPointAfterValue(replacementValue); + Value newArg; + if (converter) { + newArg = converter->materializeSourceConversion( + builder, origArg.getLoc(), origArg.getType(), + isDroppedArg ? ValueRange() : ValueRange(replacementValue)); + assert((!newArg || newArg.getType() == origArg.getType()) && + "materialization hook did not provide a value of the expected " + "type"); } + if (!newArg) { + InFlightDiagnostic diag = + emitError(origArg.getLoc()) + << "failed to materialize conversion for block argument #" + << it.index() << " that remained live after conversion, type was " + << origArg.getType(); + if (!isDroppedArg) + diag << ", with target type " << replacementValue.getType(); + diag.attachNote(liveUser->getLoc()) + << "see existing live user here: " << *liveUser; + return failure(); + } + rewriterImpl.mapping.map(origArg, newArg); + } + return success(); +} - // If the replacement value is an operation, we check to make sure that we - // don't replace uses that are within the parent operation of the - // replacement value. 
- Operation *replOp = cast<OpResult>(repl).getOwner(); - Block *replBlock = replOp->getBlock(); - arg.replaceUsesWithIf(repl, [&](OpOperand &operand) { - Operation *user = operand.getOwner(); - return user->getBlock() != replBlock || replOp->isBeforeInBlock(user); - }); +void ReplaceBlockArgRewrite::commit() { + Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType()); + if (!repl) + return; + + if (isa<BlockArgument>(repl)) { + arg.replaceAllUsesWith(repl); + return; } - // Drop all of the unresolved materialization operations created during - // conversion. - for (auto &mat : unresolvedMaterializations) { - mat.getOp()->dropAllUses(); - mat.getOp()->erase(); + // If the replacement value is an operation, we check to make sure that we + // don't replace uses that are within the parent operation of the + // replacement value. + Operation *replOp = cast<OpResult>(repl).getOwner(); + Block *replBlock = replOp->getBlock(); + arg.replaceUsesWithIf(repl, [&](OpOperand &operand) { + Operation *user = operand.getOwner(); + return user->getBlock() != replBlock || replOp->isBeforeInBlock(user); + }); +} + +void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); } + +void ReplaceOperationRewrite::commit() { + for (OpResult result : op->getResults()) + if (Value newValue = + rewriterImpl.mapping.lookupOrNull(result, result.getType())) + result.replaceAllUsesWith(newValue); + if (getConfig().unlegalizedOps) + getConfig().unlegalizedOps->erase(op); + // Do not erase the operation yet. It may still be referenced in `mapping`. 
+ op->getBlock()->getOperations().remove(op); +} + +void ReplaceOperationRewrite::rollback() { + for (auto result : op->getResults()) + rewriterImpl.mapping.erase(result); +} + +void ReplaceOperationRewrite::cleanup() { eraseOp(op); } + +void CreateOperationRewrite::rollback() { + for (Region &region : op->getRegions()) { + while (!region.getBlocks().empty()) + region.getBlocks().remove(region.getBlocks().begin()); } + op->dropAllUses(); + eraseOp(op); +} - // In a second pass, erase all of the replaced operations in reverse. This - // allows processing nested operations before their parent region is - // destroyed. Because we process in reverse order, producers may be deleted - // before their users (a pattern deleting a producer and then the consumer) - // so we first drop all uses explicitly. - for (auto &repl : llvm::reverse(replacements)) { - repl.first->dropAllUses(); - repl.first->erase(); +void UnresolvedMaterializationRewrite::rollback() { + if (getMaterializationKind() == MaterializationKind::Target) { + for (Value input : op->getOperands()) + rewriterImpl.mapping.erase(input); } + eraseOp(op); +} - argConverter.applyRewrites(mapping); +void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); } +void ConversionPatternRewriterImpl::applyRewrites() { // Commit all rewrites. for (auto &rewrite : rewrites) rewrite->commit(); + for (auto &rewrite : rewrites) + rewrite->cleanup(); } //===----------------------------------------------------------------------===// // State Management RewriterState ConversionPatternRewriterImpl::getCurrentState() { - return RewriterState(createdOps.size(), unresolvedMaterializations.size(), - replacements.size(), argReplacements.size(), - rewrites.size(), ignoredOps.size()); + return RewriterState(rewrites.size(), ignoredOps.size(), + eraseRewriter.erased.size()); } void ConversionPatternRewriterImpl::resetState(RewriterState state) { - // Reset any replaced arguments.
- for (BlockArgument replacedArg : - llvm::drop_begin(argReplacements, state.numArgReplacements)) - mapping.erase(replacedArg); - argReplacements.resize(state.numArgReplacements); - // Undo any rewrites. undoRewrites(state.numRewrites); - // Reset any replaced operations and undo any saved mappings. - for (auto &repl : llvm::drop_begin(replacements, state.numReplacements)) - for (auto result : repl.first->getResults()) - mapping.erase(result); - while (replacements.size() != state.numReplacements) - replacements.pop_back(); - - // Pop all of the newly inserted materializations. - while (unresolvedMaterializations.size() != - state.numUnresolvedMaterializations) { - UnresolvedMaterialization mat = unresolvedMaterializations.pop_back_val(); - UnrealizedConversionCastOp op = mat.getOp(); - - // If this was a target materialization, drop the mapping that was inserted. - if (mat.getKind() == UnresolvedMaterialization::Target) { - for (Value input : op->getOperands()) - mapping.erase(input); - } - detachNestedAndErase(op); - } - - // Pop all of the newly created operations. - while (createdOps.size() != state.numCreatedOps) { - detachNestedAndErase(createdOps.back()); - createdOps.pop_back(); - } - // Pop all of the recorded ignored operations that are no longer valid. while (ignoredOps.size() != state.numIgnoredOperations) ignoredOps.pop_back(); - // Reset operations with changed results. - while (!operationsWithChangedResults.empty() && - operationsWithChangedResults.back() >= state.numReplacements) - operationsWithChangedResults.pop_back(); + while (eraseRewriter.erased.size() != state.numErased) + eraseRewriter.erased.pop_back(); } void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) { @@ -1413,8 +1219,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( if (currentTypeConverter && desiredType && newOperandType != desiredType) { Location operandLoc = inputLoc ? 
*inputLoc : operand.getLoc(); Value castValue = buildUnresolvedTargetMaterialization( - operandLoc, newOperand, desiredType, currentTypeConverter, - unresolvedMaterializations); + operandLoc, newOperand, desiredType, currentTypeConverter); mapping.map(mapping.lookupOrDefault(newOperand), castValue); newOperand = castValue; } @@ -1425,7 +1230,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const { // Check to see if this operation was replaced or its parent ignored. - return replacements.count(op) || ignoredOps.count(op->getParentOp()); + return ignoredOps.count(op->getParentOp()) || + hasRewrite<ReplaceOperationRewrite>(rewrites, op); } void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) { @@ -1447,18 +1253,18 @@ void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) { FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature( Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion *conversion) { - FailureOr<Block *> result = - conversion ? argConverter.applySignatureConversion( - block, converter, *conversion, mapping, argReplacements) - : argConverter.convertSignature(block, converter, mapping, - argReplacements); - if (failed(result)) + if (conversion) + return applySignatureConversion(block, converter, *conversion); + + // If a converter wasn't provided, and the block wasn't already converted, + // there is nothing we can do. + if (!converter) return failure(); - if (Block *newBlock = *result) { - if (newBlock != block) - appendRewrite<BlockTypeConversionRewrite>(newBlock); - } - return result; + + // Try to convert the signature for the block with the provided converter. 
+ if (auto conversion = converter->convertBlockSignature(block)) + return applySignatureConversion(block, converter, *conversion); + return failure(); } Block *ConversionPatternRewriterImpl::applySignatureConversion( @@ -1512,6 +1318,145 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes( return success(); } +Block *ConversionPatternRewriterImpl::applySignatureConversion( + Block *block, const TypeConverter *converter, + TypeConverter::SignatureConversion &signatureConversion) { + MLIRContext *ctx = eraseRewriter.getContext(); + + // If no arguments are being changed or added, there is nothing to do. + unsigned origArgCount = block->getNumArguments(); + auto convertedTypes = signatureConversion.getConvertedTypes(); + if (llvm::equal(block->getArgumentTypes(), convertedTypes)) + return block; + + // Split the block at the beginning to get a new block to use for the updated + // signature. + Block *newBlock = block->splitBlock(block->begin()); + block->replaceAllUsesWith(newBlock); + // Unlink the block, but do not erase it yet, so that the change can be rolled + // back. + block->getParent()->getBlocks().remove(block); + + // Map all new arguments to the location of the argument they originate from. + SmallVector<Location> newLocs(convertedTypes.size(), + Builder(ctx).getUnknownLoc()); + for (unsigned i = 0; i < origArgCount; ++i) { + auto inputMap = signatureConversion.getInputMapping(i); + if (!inputMap || inputMap->replacementValue) + continue; + Location origLoc = block->getArgument(i).getLoc(); + for (unsigned j = 0; j < inputMap->size; ++j) + newLocs[inputMap->inputNo + j] = origLoc; + } + + SmallVector<Value, 4> newArgRange( + newBlock->addArguments(convertedTypes, newLocs)); + ArrayRef<Value> newArgs(newArgRange); + + // Remap each of the original arguments as determined by the signature + // conversion. 
+ SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo; + argInfo.resize(origArgCount); + + for (unsigned i = 0; i != origArgCount; ++i) { + auto inputMap = signatureConversion.getInputMapping(i); + if (!inputMap) + continue; + BlockArgument origArg = block->getArgument(i); + + // If inputMap->replacementValue is not nullptr, then the argument is + // dropped and a replacement value is provided to be the remappedValue. + if (inputMap->replacementValue) { + assert(inputMap->size == 0 && + "invalid to provide a replacement value when the argument isn't " + "dropped"); + mapping.map(origArg, inputMap->replacementValue); + appendRewrite<ReplaceBlockArgRewrite>(block, origArg); + continue; + } + + // Otherwise, this is a 1->1+ mapping. + auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size); + Value newArg; + + // If this is a 1->1 mapping and the types of new and replacement arguments + // match (i.e. it's an identity map), then the argument is mapped to its + // original type. + // FIXME: We simply pass through the replacement argument if there wasn't a + // converter, which isn't great as it allows implicit type conversions to + // appear. We should properly restructure this code to handle cases where a + // converter isn't provided and also to properly handle the case where an + // argument materialization is actually a temporary source materialization + // (e.g. in the case of 1->N). + if (replArgs.size() == 1 && + (!converter || replArgs[0].getType() == origArg.getType())) { + newArg = replArgs.front(); + } else { + Type origOutputType = origArg.getType(); + + // Legalize the argument output type. 
+ Type outputType = origOutputType; + if (Type legalOutputType = converter->convertType(outputType)) + outputType = legalOutputType; + + newArg = buildUnresolvedArgumentMaterialization( + newBlock, origArg.getLoc(), replArgs, origOutputType, outputType, + converter); + } + + mapping.map(origArg, newArg); + appendRewrite<ReplaceBlockArgRewrite>(block, origArg); + argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg); + } + + appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo, + converter); + return newBlock; +} + +//===----------------------------------------------------------------------===// +// Materializations +//===----------------------------------------------------------------------===// + +/// Build an unresolved materialization operation given an output type and set +/// of input operands. +Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization( + MaterializationKind kind, Block *insertBlock, Block::iterator insertPt, + Location loc, ValueRange inputs, Type outputType, Type origOutputType, + const TypeConverter *converter) { + // Avoid materializing an unnecessary cast. + if (inputs.size() == 1 && inputs.front().getType() == outputType) + return inputs.front(); + + // Create an unresolved materialization. We use a new OpBuilder to avoid + // tracking the materialization like we do for other operations. 
+ OpBuilder builder(insertBlock, insertPt); + auto convertOp = + builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs); + appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind, + origOutputType); + return convertOp.getResult(0); +} +Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization( + Block *block, Location loc, ValueRange inputs, Type origOutputType, + Type outputType, const TypeConverter *converter) { + return buildUnresolvedMaterialization(MaterializationKind::Argument, block, + block->begin(), loc, inputs, outputType, + origOutputType, converter); +} +Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization( + Location loc, Value input, Type outputType, + const TypeConverter *converter) { + Block *insertBlock = input.getParentBlock(); + Block::iterator insertPt = insertBlock->begin(); + if (OpResult inputRes = dyn_cast<OpResult>(input)) + insertPt = ++inputRes.getOwner()->getIterator(); + + return buildUnresolvedMaterialization(MaterializationKind::Target, + insertBlock, insertPt, loc, input, + outputType, outputType, converter); +} + //===----------------------------------------------------------------------===// // Rewriter Notification Hooks @@ -1523,7 +1468,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted( }); if (!previous.isSet()) { // This is a newly created op. 
- createdOps.push_back(op); + appendRewrite<CreateOperationRewrite>(op); return; } Operation *prevOp = previous.getPoint() == previous.getBlock()->end() @@ -1535,7 +1480,12 @@ void ConversionPatternRewriterImpl::notifyOperationInserted( void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, ValueRange newValues) { assert(newValues.size() == op->getNumResults()); - assert(!replacements.count(op) && "operation was already replaced"); +#ifndef NDEBUG + for (auto &rewrite : rewrites) + if (auto *opReplacement = dyn_cast<ReplaceOperationRewrite>(rewrite.get())) + assert(opReplacement->getOperation() != op && + "operation was already replaced"); +#endif // NDEBUG // Track if any of the results changed, e.g. erased and replaced with null. bool resultChanged = false; @@ -1550,11 +1500,9 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, mapping.map(result, newValue); resultChanged |= (newValue.getType() != result.getType()); } - if (resultChanged) - operationsWithChangedResults.push_back(replacements.size()); - // Record the requested operation replacement. - replacements.insert(std::make_pair(op, OpReplacement(currentTypeConverter))); + appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter, + resultChanged); // Mark this operation as recursively ignored so that we don't need to // convert any nested operations. 
@@ -1594,8 +1542,8 @@ void ConversionPatternRewriterImpl::notifyMatchFailure( Diagnostic diag(loc, DiagnosticSeverity::Remark); reasonCallback(diag); logger.startLine() << "** Failure : " << diag.str() << "\n"; - if (notifyCallback) - notifyCallback(diag); + if (config.notifyCallback) + config.notifyCallback(diag); }); } @@ -1603,9 +1551,10 @@ void ConversionPatternRewriterImpl::notifyMatchFailure( // ConversionPatternRewriter //===----------------------------------------------------------------------===// -ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx) +ConversionPatternRewriter::ConversionPatternRewriter( + MLIRContext *ctx, const ConversionConfig &config) : PatternRewriter(ctx), - impl(new detail::ConversionPatternRewriterImpl(*this)) { + impl(new detail::ConversionPatternRewriterImpl(ctx, config)) { setListener(impl.get()); } @@ -1649,8 +1598,6 @@ void ConversionPatternRewriter::eraseOp(Operation *op) { } void ConversionPatternRewriter::eraseBlock(Block *block) { - impl->notifyBlockIsBeingErased(block); - // Mark all ops for erasure. for (Operation &op : *block) eraseOp(&op); @@ -1659,6 +1606,7 @@ void ConversionPatternRewriter::eraseBlock(Block *block) { // object and will be actually destroyed when rewrites are applied. This // allows us to keep the operations in the block live and undo the removal by // re-inserting the block. 
+ impl->notifyBlockIsBeingErased(block); block->getParent()->getBlocks().remove(block); } @@ -1688,7 +1636,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, << "'(in region of '" << parentOp->getName() << "'(" << from.getOwner()->getParentOp() << ")\n"; }); - impl->argReplacements.push_back(from); + impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from); impl->mapping.map(impl->mapping.lookupOrDefault(from), to); } @@ -2022,13 +1970,16 @@ OperationLegalizer::legalizeWithFold(Operation *op, rewriter.replaceOp(op, replacementValues); // Recursively legalize any new constant operations. - for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size(); + for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size(); i != e; ++i) { - Operation *cstOp = rewriterImpl.createdOps[i]; - if (failed(legalize(cstOp, rewriter))) { + auto *createOp = + dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get()); + if (!createOp) + continue; + if (failed(legalize(createOp->getOperation(), rewriter))) { LLVM_DEBUG(logFailure(rewriterImpl.logger, "failed to legalize generated constant '{0}'", - cstOp->getName())); + createOp->getOperation()->getName())); rewriterImpl.resetState(curState); return failure(); } @@ -2054,12 +2005,12 @@ OperationLegalizer::legalizeWithPattern(Operation *op, assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); LLVM_DEBUG({ logFailure(rewriterImpl.logger, "pattern failed to match"); - if (rewriterImpl.notifyCallback) { + if (rewriterImpl.config.notifyCallback) { Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark); diag << "Failed to apply pattern \"" << pattern.getDebugName() << "\" on op:\n" << *op; - rewriterImpl.notifyCallback(diag); + rewriterImpl.config.notifyCallback(diag); } }); rewriterImpl.resetState(curState); @@ -2112,16 +2063,13 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern, #ifndef NDEBUG 
assert(impl.pendingRootUpdates.empty() && "dangling root updates"); - // Check that the root was either replaced or updated in place. + auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites); auto replacedRoot = [&] { - return llvm::any_of( - llvm::drop_begin(impl.replacements, curState.numReplacements), - [op](auto &it) { return it.first == op; }); + return hasRewrite<ReplaceOperationRewrite>(newRewrites, op); }; auto updatedRootInPlace = [&] { - return hasRewrite<ModifyOperationRewrite>( - llvm::drop_begin(impl.rewrites, curState.numRewrites), op); + return hasRewrite<ModifyOperationRewrite>(newRewrites, op); }; assert((replacedRoot() || updatedRootInPlace()) && "expected pattern to replace the root operation"); @@ -2154,7 +2102,8 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites( if (!rewrite) continue; Block *block = rewrite->getBlock(); - if (isa<BlockTypeConversionRewrite, EraseBlockRewrite>(rewrite)) + if (isa<BlockTypeConversionRewrite, EraseBlockRewrite, + ReplaceBlockArgRewrite>(rewrite)) continue; // Only check blocks outside of the current operation. Operation *parentOp = block->getParentOp(); @@ -2177,9 +2126,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites( // blocks in regions created by this pattern will already be legalized later // on. If we haven't built the set yet, build it now. if (operationsToIgnore.empty()) { - auto createdOps = ArrayRef<Operation *>(impl.createdOps) - .drop_front(state.numCreatedOps); - operationsToIgnore.insert(createdOps.begin(), createdOps.end()); + for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e; + ++i) { + auto *createOp = + dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get()); + if (!createOp) + continue; + operationsToIgnore.insert(createOp->getOperation()); + } } // If this operation should be considered for re-legalization, try it. 
@@ -2197,8 +2151,11 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites( LogicalResult OperationLegalizer::legalizePatternCreatedOperations( ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, RewriterState &state, RewriterState &newState) { - for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) { - Operation *op = impl.createdOps[i]; + for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) { + auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get()); + if (!createOp) + continue; + Operation *op = createOp->getOperation(); if (failed(legalize(op, rewriter))) { LLVM_DEBUG(logFailure(impl.logger, "failed to legalize generated operation '{0}'({1})", @@ -2432,21 +2389,21 @@ enum OpConversionMode { /// applied to the operations on success. Analysis, }; +} // namespace +namespace mlir { // This class converts operations to a given conversion target via a set of // rewrite patterns. The conversion behaves differently depending on the // conversion mode. struct OperationConverter { explicit OperationConverter(const ConversionTarget &target, const FrozenRewritePatternSet &patterns, - OpConversionMode mode, - DenseSet<Operation *> *trackedOps = nullptr) - : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {} + const ConversionConfig &config, + OpConversionMode mode) + : opLegalizer(target, patterns), config(config), mode(mode) {} /// Converts the given operations to the conversion target. - LogicalResult - convertOperations(ArrayRef<Operation *> ops, - function_ref<void(Diagnostic &)> notifyCallback = nullptr); + LogicalResult convertOperations(ArrayRef<Operation *> ops); private: /// Converts an operation with the given rewriter. @@ -2483,16 +2440,13 @@ private: /// The legalizer to use when converting operations. OperationLegalizer opLegalizer; + /// Dialect conversion configuration. 
+ ConversionConfig config; + /// The conversion mode to use when legalizing operations. OpConversionMode mode; - - /// A set of pre-existing operations. When mode == OpConversionMode::Analysis, - /// this is populated with ops found to be legalizable to the target. - /// When mode == OpConversionMode::Partial, this is populated with ops found - /// *not* to be legalizable to the target. - DenseSet<Operation *> *trackedOps; }; -} // namespace +} // namespace mlir LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter, Operation *op) { @@ -2504,28 +2458,27 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter, return op->emitError() << "failed to legalize operation '" << op->getName() << "'"; // Partial conversions allow conversions to fail iff the operation was not - // explicitly marked as illegal. If the user provided a nonlegalizableOps - // set, non-legalizable ops are included. + // explicitly marked as illegal. If the user provided a `unlegalizedOps` + // set, non-legalizable ops are added to that set. if (mode == OpConversionMode::Partial) { if (opLegalizer.isIllegal(op)) return op->emitError() << "failed to legalize operation '" << op->getName() << "' that was explicitly marked illegal"; - if (trackedOps) - trackedOps->insert(op); + if (config.unlegalizedOps) + config.unlegalizedOps->insert(op); } } else if (mode == OpConversionMode::Analysis) { // Analysis conversions don't fail if any operations fail to legalize, // they are only interested in the operations that were successfully // legalized. 
- trackedOps->insert(op); + if (config.legalizableOps) + config.legalizableOps->insert(op); } return success(); } -LogicalResult OperationConverter::convertOperations( - ArrayRef<Operation *> ops, - function_ref<void(Diagnostic &)> notifyCallback) { +LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) { if (ops.empty()) return success(); const ConversionTarget &target = opLegalizer.getTarget(); @@ -2546,33 +2499,25 @@ LogicalResult OperationConverter::convertOperations( } // Convert each operation and discard rewrites on failure. - ConversionPatternRewriter rewriter(ops.front()->getContext()); + ConversionPatternRewriter rewriter(ops.front()->getContext(), config); ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); - rewriterImpl.notifyCallback = notifyCallback; for (auto *op : toConvert) if (failed(convert(rewriter, op))) - return rewriterImpl.discardRewrites(), failure(); + return rewriterImpl.undoRewrites(), failure(); // Now that all of the operations have been converted, finalize the conversion // process to ensure any lingering conversion artifacts are cleaned up and // legalized. if (failed(finalize(rewriter))) - return rewriterImpl.discardRewrites(), failure(); + return rewriterImpl.undoRewrites(), failure(); // After a successful conversion, apply rewrites if this is not an analysis // conversion. if (mode == OpConversionMode::Analysis) { - rewriterImpl.discardRewrites(); + rewriterImpl.undoRewrites(); } else { rewriterImpl.applyRewrites(); - - // It is possible for a later pattern to erase an op that was originally - // identified as illegal and added to the trackedOps, remove it now after - // replacements have been computed. 
- if (trackedOps) - for (auto &repl : rewriterImpl.replacements) - trackedOps->erase(repl.first); } return success(); } @@ -2586,21 +2531,20 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) { failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl))) return failure(); - if (rewriterImpl.operationsWithChangedResults.empty()) - return success(); - // Process requested operation replacements. - for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size(); - i != e; ++i) { - unsigned replIdx = rewriterImpl.operationsWithChangedResults[i]; - auto &repl = *(rewriterImpl.replacements.begin() + replIdx); - for (OpResult result : repl.first->getResults()) { + for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) { + auto *opReplacement = + dyn_cast<ReplaceOperationRewrite>(rewriterImpl.rewrites[i].get()); + if (!opReplacement || !opReplacement->hasChangedResults()) + continue; + Operation *op = opReplacement->getOperation(); + for (OpResult result : op->getResults()) { Value newValue = rewriterImpl.mapping.lookupOrNull(result); // If the operation result was replaced with null, all of the uses of this // value should be replaced. if (!newValue) { - if (failed(legalizeErasedResult(repl.first, result, rewriterImpl))) + if (failed(legalizeErasedResult(op, result, rewriterImpl))) return failure(); continue; } @@ -2614,15 +2558,11 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) { inverseMapping = rewriterImpl.mapping.getInverse(); // Legalize this result. 
- rewriter.setInsertionPoint(repl.first); - if (failed(legalizeChangedResultType(repl.first, result, newValue, - repl.second.converter, rewriter, - rewriterImpl, *inverseMapping))) + rewriter.setInsertionPoint(op); + if (failed(legalizeChangedResultType( + op, result, newValue, opReplacement->getConverter(), rewriter, + rewriterImpl, *inverseMapping))) return failure(); - - // Update the end iterator for this loop in the case it was updated - // when legalizing generated conversion operations. - e = rewriterImpl.operationsWithChangedResults.size(); } } return success(); @@ -2639,8 +2579,17 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes( }); return liveUserIt == val.user_end() ? nullptr : *liveUserIt; }; - return rewriterImpl.argConverter.materializeLiveConversions( - rewriterImpl.mapping, rewriter, findLiveUser); + // Note: `rewrites` may be reallocated as the loop is running. + for (int64_t i = 0; i < static_cast<int64_t>(rewriterImpl.rewrites.size()); + ++i) { + auto &rewrite = rewriterImpl.rewrites[i]; + if (auto *blockTypeConversionRewrite = + dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) + if (failed(blockTypeConversionRewrite->materializeLiveConversions( + findLiveUser))) + return failure(); + } + return success(); } /// Replace the results of a materialization operation with the given values. @@ -2672,11 +2621,12 @@ replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl, /// Compute all of the unresolved materializations that will persist beyond the /// conversion process, and require inserting a proper user materialization for. 
static void computeNecessaryMaterializations( - DenseMap<Operation *, UnresolvedMaterialization *> &materializationOps, + DenseMap<Operation *, UnresolvedMaterializationRewrite *> + &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap<Value, SmallVector<Value>> &inverseMapping, - SetVector<UnresolvedMaterialization *> &necessaryMaterializations) { + SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) { auto isLive = [&](Value value) { auto findFn = [&](Operation *user) { auto matIt = materializationOps.find(user); @@ -2711,14 +2661,17 @@ static void computeNecessaryMaterializations( return Value(); }; - SetVector<UnresolvedMaterialization *> worklist; - for (auto &mat : rewriterImpl.unresolvedMaterializations) { - materializationOps.try_emplace(mat.getOp(), &mat); - worklist.insert(&mat); + SetVector<UnresolvedMaterializationRewrite *> worklist; + for (auto &rewrite : rewriterImpl.rewrites) { + auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get()); + if (!mat) + continue; + materializationOps.try_emplace(mat->getOperation(), mat); + worklist.insert(mat); } while (!worklist.empty()) { - UnresolvedMaterialization *mat = worklist.pop_back_val(); - UnrealizedConversionCastOp op = mat->getOp(); + UnresolvedMaterializationRewrite *mat = worklist.pop_back_val(); + UnrealizedConversionCastOp op = mat->getOperation(); // We currently only handle target materializations here. 
assert(op->getNumResults() == 1 && "unexpected materialization type"); @@ -2760,7 +2713,7 @@ static void computeNecessaryMaterializations( auto isBlockArg = [](Value v) { return isa<BlockArgument>(v); }; if (llvm::any_of(op->getOperands(), isBlockArg) || llvm::any_of(inverseMapping[op->getResult(0)], isBlockArg)) { - mat->setKind(UnresolvedMaterialization::Argument); + mat->setMaterializationKind(MaterializationKind::Argument); } // If the materialization does not have any live users, we don't need to @@ -2770,7 +2723,7 @@ static void computeNecessaryMaterializations( // value replacement even if the types differ in some cases. When those // patterns are fixed, we can drop the argument special case here. bool isMaterializationLive = isLive(opResult); - if (mat->getKind() == UnresolvedMaterialization::Argument) + if (mat->getMaterializationKind() == MaterializationKind::Argument) isMaterializationLive |= llvm::any_of(inverseMapping[opResult], isLive); if (!isMaterializationLive) continue; @@ -2790,8 +2743,9 @@ static void computeNecessaryMaterializations( /// Legalize the given unresolved materialization. Returns success if the /// materialization was legalized, failure otherise. 
static LogicalResult legalizeUnresolvedMaterialization( - UnresolvedMaterialization &mat, - DenseMap<Operation *, UnresolvedMaterialization *> &materializationOps, + UnresolvedMaterializationRewrite &mat, + DenseMap<Operation *, UnresolvedMaterializationRewrite *> + &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap<Value, SmallVector<Value>> &inverseMapping) { @@ -2811,7 +2765,7 @@ static LogicalResult legalizeUnresolvedMaterialization( return Value(); }; - UnrealizedConversionCastOp op = mat.getOp(); + UnrealizedConversionCastOp op = mat.getOperation(); if (!rewriterImpl.ignoredOps.insert(op)) return success(); @@ -2861,8 +2815,8 @@ static LogicalResult legalizeUnresolvedMaterialization( rewriter.setInsertionPoint(op); Value newMaterialization; - switch (mat.getKind()) { - case UnresolvedMaterialization::Argument: + switch (mat.getMaterializationKind()) { + case MaterializationKind::Argument: // Try to materialize an argument conversion. // FIXME: The current argument materialization hook expects the original // output type, even though it doesn't use that as the actual output type @@ -2879,7 +2833,7 @@ static LogicalResult legalizeUnresolvedMaterialization( // If an argument materialization failed, fallback to trying a target // materialization. 
[[fallthrough]]; - case UnresolvedMaterialization::Target: + case MaterializationKind::Target: newMaterialization = converter->materializeTargetConversion( rewriter, op->getLoc(), outputType, inputOperands); break; @@ -2907,14 +2861,12 @@ LogicalResult OperationConverter::legalizeUnresolvedMaterializations( ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping) { - if (rewriterImpl.unresolvedMaterializations.empty()) - return success(); inverseMapping = rewriterImpl.mapping.getInverse(); // As an initial step, compute all of the inserted materializations that we // expect to persist beyond the conversion process. - DenseMap<Operation *, UnresolvedMaterialization *> materializationOps; - SetVector<UnresolvedMaterialization *> necessaryMaterializations; + DenseMap<Operation *, UnresolvedMaterializationRewrite *> materializationOps; + SetVector<UnresolvedMaterializationRewrite *> necessaryMaterializations; computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl, *inverseMapping, necessaryMaterializations); @@ -3535,57 +3487,51 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) { //===----------------------------------------------------------------------===// // Partial Conversion -LogicalResult -mlir::applyPartialConversion(ArrayRef<Operation *> ops, - const ConversionTarget &target, - const FrozenRewritePatternSet &patterns, - DenseSet<Operation *> *unconvertedOps) { - OperationConverter opConverter(target, patterns, OpConversionMode::Partial, - unconvertedOps); +LogicalResult mlir::applyPartialConversion( + ArrayRef<Operation *> ops, const ConversionTarget &target, + const FrozenRewritePatternSet &patterns, ConversionConfig config) { + OperationConverter opConverter(target, patterns, config, + OpConversionMode::Partial); return opConverter.convertOperations(ops); } LogicalResult mlir::applyPartialConversion(Operation *op, const 
ConversionTarget &target, const FrozenRewritePatternSet &patterns, - DenseSet<Operation *> *unconvertedOps) { - return applyPartialConversion(llvm::ArrayRef(op), target, patterns, - unconvertedOps); + ConversionConfig config) { + return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config); } //===----------------------------------------------------------------------===// // Full Conversion -LogicalResult -mlir::applyFullConversion(ArrayRef<Operation *> ops, - const ConversionTarget &target, - const FrozenRewritePatternSet &patterns) { - OperationConverter opConverter(target, patterns, OpConversionMode::Full); +LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops, + const ConversionTarget &target, + const FrozenRewritePatternSet &patterns, + ConversionConfig config) { + OperationConverter opConverter(target, patterns, config, + OpConversionMode::Full); return opConverter.convertOperations(ops); } -LogicalResult -mlir::applyFullConversion(Operation *op, const ConversionTarget &target, - const FrozenRewritePatternSet &patterns) { - return applyFullConversion(llvm::ArrayRef(op), target, patterns); +LogicalResult mlir::applyFullConversion(Operation *op, + const ConversionTarget &target, + const FrozenRewritePatternSet &patterns, + ConversionConfig config) { + return applyFullConversion(llvm::ArrayRef(op), target, patterns, config); } //===----------------------------------------------------------------------===// // Analysis Conversion -LogicalResult -mlir::applyAnalysisConversion(ArrayRef<Operation *> ops, - ConversionTarget &target, - const FrozenRewritePatternSet &patterns, - DenseSet<Operation *> &convertedOps, - function_ref<void(Diagnostic &)> notifyCallback) { - OperationConverter opConverter(target, patterns, OpConversionMode::Analysis, - &convertedOps); - return opConverter.convertOperations(ops, notifyCallback); +LogicalResult mlir::applyAnalysisConversion( + ArrayRef<Operation *> ops, ConversionTarget &target, + const 
FrozenRewritePatternSet &patterns, ConversionConfig config) { + OperationConverter opConverter(target, patterns, config, + OpConversionMode::Analysis); + return opConverter.convertOperations(ops); } LogicalResult mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target, const FrozenRewritePatternSet &patterns, - DenseSet<Operation *> &convertedOps, - function_ref<void(Diagnostic &)> notifyCallback) { - return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, - convertedOps, notifyCallback); + ConversionConfig config) { + return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config); } diff --git a/mlir/test/CAPI/llvm.c b/mlir/test/CAPI/llvm.c index 5a78fac..1817988 100644 --- a/mlir/test/CAPI/llvm.c +++ b/mlir/test/CAPI/llvm.c @@ -15,6 +15,7 @@ #include "mlir-c/Support.h" #include <assert.h> +#include <inttypes.h> #include <math.h> #include <stdio.h> #include <stdlib.h> @@ -105,7 +106,7 @@ static int testStructTypeCreation(MlirContext ctx) { // CHECK: i8 // CHECK: i32 // CHECK: i64 - fprintf(stderr, "num elements: %ld\n", + fprintf(stderr, "num elements: %" PRIdPTR "\n", mlirLLVMStructTypeGetNumElementTypes(literal)); for (intptr_t i = 0; i < 3; ++i) { mlirTypeDump(mlirLLVMStructTypeGetElementType(literal, i)); diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir index 3de2f11d..56129db 100644 --- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir +++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir @@ -77,6 +77,18 @@ func.func @log1p_2dvector_fmf(%arg0 : vector<4x3xf32>) { // ----- +// CHECK-LABEL: func @log1p_scalable_vector( +// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32> +func.func @log1p_scalable_vector(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> { + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32> + // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %[[VEC]] : vector<[4]xf32> + // CHECK: %[[LOG:.*]] = 
llvm.intr.log(%[[ADD]]) : (vector<[4]xf32>) -> vector<[4]xf32> + %0 = math.log1p %arg0 : vector<[4]xf32> + func.return %0 : vector<[4]xf32> +} + +// ----- + // CHECK-LABEL: func @expm1( // CHECK-SAME: f32 func.func @expm1(%arg0 : f32) { @@ -113,6 +125,18 @@ func.func @expm1_vector(%arg0 : vector<4xf32>) { // ----- +// CHECK-LABEL: func @expm1_scalable_vector( +// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32> +func.func @expm1_scalable_vector(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> { + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32> + // CHECK: %[[EXP:.*]] = llvm.intr.exp(%[[VEC]]) : (vector<[4]xf32>) -> vector<[4]xf32> + // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : vector<[4]xf32> + %0 = math.expm1 %arg0 : vector<[4]xf32> + func.return %0 : vector<[4]xf32> +} + +// ----- + // CHECK-LABEL: func @expm1_vector_fmf( // CHECK-SAME: vector<4xf32> func.func @expm1_vector_fmf(%arg0 : vector<4xf32>) { @@ -177,6 +201,16 @@ func.func @cttz_vec(%arg0 : vector<4xi32>) { // ----- +// CHECK-LABEL: func @cttz_scalable_vec( +// CHECK-SAME: %[[VEC:.*]]: vector<[4]xi32> +func.func @cttz_scalable_vec(%arg0 : vector<[4]xi32>) -> vector<[4]xi32> { + // CHECK: "llvm.intr.cttz"(%[[VEC]]) <{is_zero_poison = false}> : (vector<[4]xi32>) -> vector<[4]xi32> + %0 = math.cttz %arg0 : vector<[4]xi32> + func.return %0 : vector<[4]xi32> +} + +// ----- + // CHECK-LABEL: func @ctpop( // CHECK-SAME: i32 func.func @ctpop(%arg0 : i32) { @@ -197,6 +231,16 @@ func.func @ctpop_vector(%arg0 : vector<3xi32>) { // ----- +// CHECK-LABEL: func @ctpop_scalable_vector( +// CHECK-SAME: %[[VEC:.*]]: vector<[4]xi32> +func.func @ctpop_scalable_vector(%arg0 : vector<[4]xi32>) -> vector<[4]xi32> { + // CHECK: llvm.intr.ctpop(%[[VEC]]) : (vector<[4]xi32>) -> vector<[4]xi32> + %0 = math.ctpop %arg0 : vector<[4]xi32> + func.return %0 : vector<[4]xi32> +} + +// ----- + // CHECK-LABEL: func @rsqrt_double( // CHECK-SAME: f64 func.func @rsqrt_double(%arg0 : f64) { 
@@ -233,6 +277,18 @@ func.func @rsqrt_vector(%arg0 : vector<4xf32>) { // ----- +// CHECK-LABEL: func @rsqrt_scalable_vector( +// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32> +func.func @rsqrt_scalable_vector(%arg0 : vector<[4]xf32>) -> vector<[4]xf32>{ + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32> + // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%[[VEC]]) : (vector<[4]xf32>) -> vector<[4]xf32> + // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<[4]xf32> + %0 = math.rsqrt %arg0 : vector<[4]xf32> + func.return %0 : vector<[4]xf32> +} + +// ----- + // CHECK-LABEL: func @rsqrt_vector_fmf( // CHECK-SAME: vector<4xf32> func.func @rsqrt_vector_fmf(%arg0 : vector<4xf32>) { @@ -245,6 +301,18 @@ func.func @rsqrt_vector_fmf(%arg0 : vector<4xf32>) { // ----- +// CHECK-LABEL: func @rsqrt_scalable_vector_fmf( +// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32> +func.func @rsqrt_scalable_vector_fmf(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> { + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32> + // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%[[VEC]]) {fastmathFlags = #llvm.fastmath<fast>} : (vector<[4]xf32>) -> vector<[4]xf32> + // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] {fastmathFlags = #llvm.fastmath<fast>} : vector<[4]xf32> + %0 = math.rsqrt %arg0 fastmath<fast> : vector<[4]xf32> + func.return %0 : vector<[4]xf32> +} + +// ----- + // CHECK-LABEL: func @rsqrt_multidim_vector( func.func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) { // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>> @@ -258,6 +326,19 @@ func.func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) { // ----- +// CHECK-LABEL: func @rsqrt_multidim_scalable_vector( +func.func @rsqrt_multidim_scalable_vector(%arg0 : vector<4x[4]xf32>) -> vector<4x[4]xf32> { + // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<[4]xf32>> + // CHECK: %[[ONE:.*]] 
= llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32> + // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%[[EXTRACT]]) : (vector<[4]xf32>) -> vector<[4]xf32> + // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<[4]xf32> + // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[DIV]], %{{.*}}[0] : !llvm.array<4 x vector<[4]xf32>> + %0 = math.rsqrt %arg0 : vector<4x[4]xf32> + func.return %0 : vector<4x[4]xf32> +} + +// ----- + // CHECK-LABEL: func @fpowi( // CHECK-SAME: f64 func.func @fpowi(%arg0 : f64, %arg1 : i32) { diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index febe74e..1fa783f 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -759,6 +759,21 @@ func.func @test_i8(%arg0: tensor<1xi8>) -> () { // ----- +// CHECK-LABEL: @test_i64 +func.func @test_i64(%arg0: tensor<1xi64>) -> () { + // CHECK: linalg.generic + // CHECK: ^bb0(%[[ARG1:.+]]: i64, + // CHECK-DAG: %[[C127:.+]] = arith.constant -9223372036854775808 + // CHECK-DAG: %[[C126:.+]] = arith.constant 9223372036854775807 + // CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C127]], %[[ARG1]] + // CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C126]], %[[LOWER]] + %0 = tosa.clamp %arg0 {min_int = -9223372036854775808 : i64, max_int = 9223372036854775807 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi64>) -> tensor<1xi64> + + return +} + +// ----- + // CHECK-LABEL: @test_clamp_f16 func.func @test_clamp_f16(%arg0: tensor<1xf16>) -> () { // CHECK: linalg.generic diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 7adde31..206d7e9 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -102,17 +102,16 @@ func.func @tensor.cast.unranked(%a : tensor<*xf32>, %b : tensor<*xf32>, %c : ten // ----- // CHECK-LABEL: func @linalg_effects( -// 
CHECK-SAME: %[[A:[a-z0-9]*]]: tensor<?x?xf32> -// CHECK-SAME: %[[B:[a-z0-9]*]]: memref<?x?xf32> -// CHECK-SAME: %[[C:[a-z0-9]*]]: tensor<?x?xf32> -func.func @linalg_effects(%a : tensor<?x?xf32>, %b : memref<?x?xf32>, %c : tensor<?x?xf32>) { +func.func @linalg_effects( + %a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : tensor<?x?xf32>, + %d : memref<?x?xf32>, %e : memref<?x?xf32>, %f : memref<?x?xf32>) { // CHECK-NOT: %{{.*}} = linalg.matmul - %t = linalg.matmul ins(%a, %b : tensor<?x?xf32>, memref<?x?xf32>) + %t = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>) outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: linalg.matmul - linalg.matmul ins(%a, %c : tensor<?x?xf32>, tensor<?x?xf32>) - outs(%b : memref<?x?xf32>) + linalg.matmul ins(%d, %e : memref<?x?xf32>, memref<?x?xf32>) + outs(%f : memref<?x?xf32>) return } @@ -889,11 +888,11 @@ func.func @fold_multi_use_generic_op_with_consumer(%arg0 : tensor<?x?x?xf32>) -> // ----- #map = affine_map<(d0) -> (d0)> -func.func @identity_mixed(%arg0 : tensor<?xf32>, %arg1: memref<?xf32>) { +func.func @identity_buffer(%arg0 : memref<?xf32>, %arg1: memref<?xf32>) { linalg.generic { indexing_maps = [#map, #map], iterator_types = ["parallel"] - } ins(%arg0 : tensor<?xf32>) + } ins(%arg0 : memref<?xf32>) outs(%arg1 : memref<?xf32>) { ^bb0(%arg2 : f32, %arg3 : f32): linalg.yield %arg2 : f32 @@ -901,14 +900,13 @@ func.func @identity_mixed(%arg0 : tensor<?xf32>, %arg1: memref<?xf32>) { return } -// There was a crash in EraseIdentityGenericOp for generic with mixed semantics. -// For now, check generic remained unchanged. -// CHECK-LABEL: func @identity_mixed -// CHECK-SAME: (%[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: memref<?xf32>) +// Do not erase ops with buffer semantics. 
+// CHECK-LABEL: func @identity_buffer +// CHECK-SAME: (%[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf32>) // CHECK: linalg.generic { // CHECK-SAME: indexing_maps = [#map, #map], // CHECK-SAME: iterator_types = ["parallel"] -// CHECK-SAME: } ins(%[[ARG1]] : tensor<?xf32>) +// CHECK-SAME: } ins(%[[ARG1]] : memref<?xf32>) // CHECK-SAME: outs(%[[ARG2]] : memref<?xf32>) { // ----- @@ -916,12 +914,12 @@ func.func @identity_mixed(%arg0 : tensor<?xf32>, %arg1: memref<?xf32>) { // Just make sure that we don't crash. // CHECK-LABEL: func @dedeplicate_regression_test -func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: memref<4xf32>) { +func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: tensor<4xf32>) { %36 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} - ins(%1, %1 : memref<4xf32>, memref<4xf32>) + ins(%1, %1 : tensor<4xf32>, tensor<4xf32>) outs(%0 : tensor<4xf32>) { ^bb0(%in: f32, %in_24: f32, %out: f32): linalg.yield %in : f32 @@ -937,31 +935,6 @@ func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: memref<4xf32>) { // ----- -#map = affine_map<(d0) -> (d0)> -func.func @cast_producer_mixed(%arg0 : tensor<5xf32>, %arg1: memref<?xf32>) { - %0 = tensor.cast %arg0 : tensor<5xf32> to tensor<?xf32> - linalg.generic { - indexing_maps = [#map, #map], - iterator_types = ["parallel"] - } ins(%0 : tensor<?xf32>) - outs(%arg1 : memref<?xf32>) { - ^bb0(%arg2 : f32, %arg3 : f32): - linalg.yield %arg2 : f32 - } - return -} - -// We need a mixed linalg as a bridge between tensor and memref worlds. 
-// CHECK-LABEL: func @cast_producer_mixed -// CHECK-SAME: (%[[ARG1:.*]]: tensor<5xf32>, %[[ARG2:.*]]: memref<?xf32>) -// CHECK: linalg.generic { -// CHECK-SAME: indexing_maps = [#map, #map], -// CHECK-SAME: iterator_types = ["parallel"] -// CHECK-SAME: } ins(%[[ARG1]] : tensor<5xf32>) -// CHECK-SAME: outs(%[[ARG2]] : memref<?xf32>) { - -// ----- - // CHECK-LABEL: dead_softmax func.func @dead_softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> { %0 = tensor.empty() : tensor<16x64x256xf32> diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir index 9d8421c..15a4f6c 100644 --- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir +++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir @@ -1110,43 +1110,3 @@ module { // CHECK-DAG: %[[T3:.+]] = arith.addf %[[T2]], %[[B1]] // CHECK: linalg.yield %[[T3]] : f32 // CHECK: return %[[GENERIC]] - -// ----- - -// CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)> -#map0 = affine_map<(d0, d1) -> (d0, d1)> - -// CHECK-LABEL: @mixed_fusion -func.func @mixed_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>, %arg8 : memref<?x?xf32>) -{ - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32> - %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32> - %2 = tensor.empty(%0, %1) : tensor<?x?xf32> - %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} - ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) - outs(%2 : tensor<?x?xf32>) { - ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): - %4 = arith.addf %arg3, %arg4 : f32 - linalg.yield %4 : f32 - } -> tensor<?x?xf32> - // CHECK: linalg.generic { - // CHECK-SAME: indexing_maps = {{\[}}[[$MAP0]], [[$MAP0]], [[$MAP0]], [[$MAP0]]{{\]}} - linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} - ins(%3, %arg2 : tensor<?x?xf32>, 
tensor<?x?xf32>) - outs(%arg8 : memref<?x?xf32>) { - // CHECK: ^{{[a-zA-Z0-9_]*}} - // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]] - // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]] - // CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]] - ^bb0(%arg5: f32, %arg6: f32, %arg7: f32): - // CHECK: [[T1:%[a-zA-Z0-9_]*]] = arith.addf [[ARG0]], [[ARG1]] - // CHECK-NOT: linalg.yield - // CHECK: arith.mulf [[T1]], [[ARG2]] - // CHECK: linalg.yield - %5 = arith.mulf %arg5, %arg6 : f32 - linalg.yield %5 : f32 - } - return -} diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index 916c04f..44c81c3 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -770,3 +770,13 @@ func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>, -> tensor<8x8xf32> return %res : tensor<8x8xf32> } + +// ----- + +func.func @mixed_semantics(%a: tensor<?x?xf32>, %b: tensor<?x?xf32>, %c: memref<?x?xf32>) { + // expected-error @+1 {{expected to have pure tensor or buffer semantics}} + linalg.matmul ins(%a, %b: tensor<?x?xf32>, tensor<?x?xf32>) + outs(%c: memref<?x?xf32>) + return +} + diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir index 2fb8029..572d3eb 100644 --- a/mlir/test/Dialect/Mesh/spmdization.mlir +++ b/mlir/test/Dialect/Mesh/spmdization.mlir @@ -127,3 +127,17 @@ func.func @multiple_chained_ops( // CHECK: return %[[RESHARD3]] : tensor<1xi8> return %7 : tensor<2xi8> } + +// CHECK-LABEL: func @incomplete_sharding +func.func @incomplete_sharding( + // CHECK-SAME: %[[ARG:.*]]: tensor<4x16xf32> + %arg0: tensor<8x16xf32> +// CHECK-SAME: -> tensor<4x16xf32> { +) -> tensor<8x16xf32> { + %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> annotate_for_users : tensor<8x16xf32> + // CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32> + %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32> + %2 = mesh.shard %1 to <@mesh_1d, [[0]]> : tensor<8x16xf32> + // CHECK: return %[[RES]] : 
tensor<4x16xf32> + return %2 : tensor<8x16xf32> +} diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 3d68464..01b2707 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -376,6 +376,13 @@ func.func @test_clz(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { } // ----- +// CHECK-LABEL: cos +func.func @test_cos(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = tosa.cos %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- // CHECK-LABEL: exp func.func @test_exp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = tosa.exp %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> @@ -425,6 +432,13 @@ func.func @test_rsqrt(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { } // ----- +// CHECK-LABEL: sin +func.func @test_sin(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = tosa.sin %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- // CHECK-LABEL: select func.func @test_select(%arg0: tensor<1x1x1xi1>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x1x1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir index 02063a8..94e78ce 100644 --- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir +++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir @@ -195,53 +195,89 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> { // CHECK-LABEL: func.func @aligned_extsi( func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> { - // CHECK: arith.shli - // CHECK: arith.shrsi - // CHECK: arith.shrsi - // CHECK: vector.shuffle - // CHECK: arith.extsi %{{.*}} : vector<8xi8> to vector<8xi32> +// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> 
{ +// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8> +// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8> +// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32> %0 = arith.extsi %a : vector<8xi4> to vector<8xi32> return %0 : vector<8xi32> } +// CHECK-LABEL: func.func @aligned_extsi_2d( +func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> { +// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> { +// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8> +// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> +// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8> +// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8> +// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32> + %0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32> + return %0 : vector<8x32xi32> +} + // CHECK-LABEL: func.func @aligned_extsi_base_case( func.func @aligned_extsi_base_case(%a: vector<8xi4>) -> vector<8xi8> { - // CHECK: arith.shli - // CHECK: arith.shrsi - // CHECK: arith.shrsi - // CHECK: vector.shuffle - // CHECK-NOT: arith.extsi +// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> { +// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8> +// CHECK: 
%[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8> %0 = arith.extsi %a : vector<8xi4> to vector<8xi8> return %0 : vector<8xi8> } // CHECK-LABEL: func.func @aligned_sitofp( func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> { - // CHECK: arith.shli - // CHECK: arith.shrsi - // CHECK: arith.shrsi - // CHECK: shuffle - // CHECK: arith.sitofp %{{.*}} : vector<8xi8> to vector<8xf32> +// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> { +// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8> +// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8> +// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32> %0 = arith.sitofp %a : vector<8xi4> to vector<8xf32> return %0 : vector<8xf32> } +// CHECK-LABEL: func.func @aligned_sitofp_2d( +func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> { +// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> { +// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8> +// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> +// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8> +// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> +// CHECK: %[[INTERLEAVE:.*]] = 
vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8> +// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32> + %0 = arith.sitofp %a : vector<8x32xi4> to vector<8x32xf32> + return %0 : vector<8x32xf32> +} + // CHECK-LABEL: func.func @i4_transpose( -// CHECK-SAME: %[[A:[0-9a-z]*]] func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> { - // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi4> to vector<8x16xi8> - // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8> - // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4> +// CHECK-SAME: %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> { +// CHECK: %[[EXT:.*]] = vector.interleave +// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8> +// CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4> %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4> return %0 : vector<16x8xi4> } // CHECK-LABEL: func.func @i7_transpose( -// CHECK-SAME: %[[A:[0-9a-z]*]] func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> { - // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi7> to vector<8x16xi8> - // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8> - // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7> +// CHECK-SAME: %[[IN:.*]]: vector<8x16xi7>) -> vector<16x8xi7> { +// CHECK: %[[EXT:.*]] = arith.extsi %[[IN]] : vector<8x16xi7> to vector<8x16xi8> +// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8> +// CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7> %0 = vector.transpose %a, [1, 0] : vector<8x16xi7> to vector<16x8xi7> return %0 : vector<16x8xi7> } diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir 
b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir index 1775b5f..788ae9a 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir @@ -83,7 +83,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices( return } -// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 * 43)> +// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)> // CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices( // CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index, @@ -92,7 +92,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices( // CHECK: %[[C_0:.*]] = arith.constant 0 : i32 // CHECK: %[[C_0_IDX:.*]] = arith.constant 0 : index // CHECK: %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[M_IN]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32> -// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_2]], %[[IDX_1]]] +// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]] // CHECK: %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32> // CHECK: %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32> // CHECK: vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32> @@ -459,3 +459,37 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>, // CHECK-128B-LABEL: func @fold_unit_dims_entirely( // CHECK-128B-NOT: memref.collapse_shape + +// ----- + +func.func @regression_non_contiguous_dim_read(%subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, + %idx0 : index, %idx1 : index) -> vector<2x2xf32> { + %c0 = arith.constant 0 : index + %cst_1 = arith.constant 0.000000e+00 : f32 + %8 = vector.transfer_read %subview[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, 
true]} : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32> + return %8 : vector<2x2xf32> +} + +// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)> +// CHECK-LABEL: func.func @regression_non_contiguous_dim_read( +// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>> +// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]() + +// CHECK-128B-LABEL: func @regression_non_contiguous_dim_read( +// CHECK-128B: memref.collapse_shape + +// ----- + +func.func @unsupported_non_contiguous_dim_write(%value : vector<2x2xf32>, + %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, + %idx0 : index, %idx1 : index) { + %c0 = arith.constant 0 : index + vector.transfer_write %value, %subview[%c0, %idx0, %idx1, %c0] {in_bounds = [true, true]} : vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> + return +} + +// CHECK-LABEL: func.func @unsupported_non_contiguous_dim_write( +// CHECK-NOT: memref.collapse_shape + +// CHECK-128B-LABEL: func @unsupported_non_contiguous_dim_write( +// CHECK-128B-NOT: memref.collapse_shape diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir index 44ff1af..12f13e8 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir @@ -1,13 +1,8 @@ // RUN: mlir-opt %s \ -// RUN: -transform-interpreter \ -// RUN: -test-transform-dialect-erase-schedule \ +// RUN: -transform-interpreter -test-transform-dialect-erase-schedule \ // RUN: -lower-vector-mask \ // RUN: -one-shot-bufferize="bufferize-function-boundaries" \ -// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \ -// RUN: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \ -// RUN: -allocate-arm-sme-tiles 
-convert-arm-sme-to-scf \ -// RUN: -convert-arm-sme-to-llvm -cse -canonicalize \ -// RUN: -test-lower-to-llvm | \ +// RUN: -test-lower-to-arm-sme -test-lower-to-llvm | \ // RUN: %mcr_aarch64_cmd \ // RUN: -e=entry -entry-point-result=void \ // RUN: -march=aarch64 -mattr="+sve,+sme" \ diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir index c781d5e..34c5351 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir @@ -1,12 +1,7 @@ // RUN: mlir-opt %s \ // RUN: -transform-interpreter -test-transform-dialect-erase-schedule \ -// RUN: -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \ -// RUN: -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf \ -// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \ -// RUN: -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \ -// RUN: -convert-arm-sme-to-llvm \ -// RUN: -convert-vector-to-llvm=enable-arm-sve \ -// RUN: -cse -canonicalize -test-lower-to-llvm | \ +// RUN: -one-shot-bufferize="bufferize-function-boundaries" \ +// RUN: -test-lower-to-arm-sme -test-lower-to-llvm | \ // RUN: %mcr_aarch64_cmd \ // RUN: -e=main -entry-point-result=void \ // RUN: -march=aarch64 -mattr="+sve,+sme" \ diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir index 31c3202..2bfdaa8 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir @@ -1,12 +1,6 @@ // RUN: mlir-opt %s \ // RUN: -transform-interpreter -test-transform-dialect-erase-schedule \ -// RUN: -canonicalize \ -// RUN: -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf \ -// RUN: 
-enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \ -// RUN: -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \ -// RUN: -convert-arm-sme-to-llvm \ -// RUN: -convert-vector-to-llvm=enable-arm-sve \ -// RUN: -cse -canonicalize -test-lower-to-llvm | \ +// RUN: -test-lower-to-arm-sme -test-lower-to-llvm | \ // RUN: %mcr_aarch64_cmd \ // RUN: -e=main -entry-point-result=void \ // RUN: -march=aarch64 -mattr="+sve,+sme" \ diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir index d5c3506..e376bdd 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir @@ -1,11 +1,7 @@ // RUN: mlir-opt %s \ // RUN: -transform-interpreter -test-transform-dialect-erase-schedule \ // RUN: -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \ -// RUN: -arm-sme-vector-legalization -canonicalize -cse \ -// RUN: -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf \ -// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \ -// RUN: -convert-vector-to-scf=full-unroll -convert-arm-sme-to-llvm \ -// RUN: -test-lower-to-llvm | \ +// RUN: -test-lower-to-arm-sme -test-lower-to-llvm | \ // RUN: %mcr_aarch64_cmd \ // RUN: -e=main -entry-point-result=void \ // RUN: -march=aarch64 -mattr="+sve,+sme" \ diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir index 42fe21c..ee3866de 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir @@ -1,10 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \ -// RUN: 
-allocate-arm-sme-tiles -convert-arm-sme-to-scf \ -// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \ -// RUN: -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \ -// RUN: -convert-arm-sme-to-llvm -convert-vector-to-llvm=enable-arm-sve -cse \ -// RUN: -canonicalize -test-lower-to-llvm -verify-diagnostics | \ +// RUN: -test-lower-to-arm-sme -test-lower-to-llvm -verify-diagnostics | \ // RUN: %mcr_aarch64_cmd \ // RUN: -e=main -entry-point-result=void \ // RUN: -march=aarch64 -mattr="+sve,+sme" \ diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir index 59b4a7e..06b1c10 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir @@ -1,9 +1,5 @@ // DEFINE: %{entry_point} = test_load_store_zaq0 -// DEFINE: %{compile} = mlir-opt %s \ -// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \ -// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \ -// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \ -// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm +// DEFINE: %{compile} = mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm // DEFINE: %{run} = %mcr_aarch64_cmd \ // DEFINE: -march=aarch64 -mattr=+sve,+sme \ // DEFINE: -e %{entry_point} -entry-point-result=void \ diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir index 064141c..27be801 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir @@ -1,9 +1,5 @@ // DEFINE: %{entry_point} = entry -// DEFINE: %{compile} = mlir-opt %s \ -// DEFINE: 
-enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \ -// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \ -// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \ -// DEFINE: -test-lower-to-llvm +// DEFINE: %{compile} = mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm // DEFINE: %{run} = %mcr_aarch64_cmd \ // DEFINE: -march=aarch64 -mattr=+sve,+sme \ // DEFINE: -e %{entry_point} -entry-point-result=void \ diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir index 0827d9b..9d836d9 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir @@ -1,10 +1,4 @@ -// RUN: mlir-opt %s -arm-sme-vector-legalization -cse -canonicalize \ -// RUN: -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf \ -// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \ -// RUN: -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \ -// RUN: -convert-arm-sme-to-llvm \ -// RUN: -convert-vector-to-llvm=enable-arm-sve \ -// RUN: -cse -canonicalize -test-lower-to-llvm | \ +// RUN: mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm | \ // RUN: %mcr_aarch64_cmd \ // RUN: -e=main -entry-point-result=void \ // RUN: -march=aarch64 -mattr="+sve,+sme" \ diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir index f081838..a06ad37 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir @@ -1,11 +1,7 @@ +// DEFINE: %{opts} = // DEFINE: %{entry} = main -// DEFINE: %{fusion_opts} = 
-arm-sme-outer-product-fusion // DEFINE: %{compile} = mlir-opt %s \ -// DEFINE: -convert-vector-to-arm-sme -convert-arith-to-arm-sme %{fusion_opts} \ -// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \ -// DEFINE: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \ -// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \ -// DEFINE: -test-lower-to-llvm -o %t +// DEFINE: -test-lower-to-arm-sme=%{opts} -test-lower-to-llvm -o %t // DEFINE: %{run} = %mcr_aarch64_cmd %t \ // DEFINE: -march=aarch64 -mattr=+sve,+sme \ // DEFINE: -e %{entry} -entry-point-result=void \ @@ -18,7 +14,7 @@ // Check result is the same when outerproducts are not combined into widening // variant. -// REDEFINE: %{fusion_opts} = +// REDEFINE: %{opts} = fuse-outer-products=false // RUN: %{run} | FileCheck %s func.func @main() { diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir index 5f41b37..7e7869d 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir @@ -1,10 +1,6 @@ // DEFINE: %{entry_point} = test_outerproduct_no_accumulator_4x4xf32 // DEFINE: %{compile} = mlir-opt %s \ -// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \ -// DEFINE: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \ -// DEFINE: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \ -// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \ -// DEFINE: -test-lower-to-llvm -o %t +// DEFINE: -test-lower-to-arm-sme -test-lower-to-llvm -o %t // DEFINE: %{run} = %mcr_aarch64_cmd %t \ // DEFINE: -march=aarch64 -mattr=+sve,+sme \ // DEFINE: -e %{entry_point} -entry-point-result=void \ diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir 
b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir index a1bb9b7..46bf799 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir @@ -1,10 +1,6 @@ // DEFINE: %{entry_point} = test_outerproduct_no_accumulator_2x2xf64 // DEFINE: %{compile} = mlir-opt %s \ -// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \ -// DEFINE: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \ -// DEFINE: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \ -// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \ -// DEFINE: -test-lower-to-llvm -o %t +// DEFINE: -test-lower-to-arm-sme -test-lower-to-llvm -o %t // DEFINE: %{run} = %mcr_aarch64_cmd %t \ // DEFINE: -march=aarch64 -mattr=+sve,+sme-f64f64 \ // DEFINE: -e %{entry_point} -entry-point-result=void \ diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-i8i8i32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-i8i8i32.mlir index 1770e57..9a353ec 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-i8i8i32.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-i8i8i32.mlir @@ -1,11 +1,5 @@ // DEFINE: %{entry} = main -// DEFINE: %{compile} = mlir-opt %s \ -// DEFINE: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \ -// DEFINE: -arm-sme-outer-product-fusion \ -// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \ -// DEFINE: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \ -// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \ -// DEFINE: -test-lower-to-llvm +// DEFINE: %{compile} = mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm // DEFINE: %{run} = %mcr_aarch64_cmd \ // DEFINE: -march=aarch64 -mattr=+sve,+sme \ // DEFINE: -e %{entry} -entry-point-result=void \ diff --git 
a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir index 6e028d5..52f5688 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir @@ -1,9 +1,5 @@ // DEFINE: %{entry_point} = entry -// DEFINE: %{compile} = mlir-opt %s \ -// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \ -// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \ -// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \ -// DEFINE: -test-lower-to-llvm +// DEFINE: %{compile} = mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm // DEFINE: %{run} = %mcr_aarch64_cmd \ // DEFINE: -march=aarch64 -mattr=+sve,+sme \ // DEFINE: -e %{entry_point} -entry-point-result=void \ diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir index c0c1f55..710cc66 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir @@ -1,10 +1,5 @@ // DEFINE: %{entry_point} = entry -// DEFINE: %{compile} = mlir-opt %s \ -// DEFINE: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \ -// DEFINE: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \ -// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \ -// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \ -// DEFINE: -test-lower-to-llvm +// DEFINE: %{compile} = mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm // DEFINE: %{run} = %mcr_aarch64_cmd \ // DEFINE: -march=aarch64 -mattr=+sve,+sme \ // DEFINE: -e %{entry_point} -entry-point-result=void \ diff --git 
a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir index eee3c56..88bc0d0 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir @@ -1,9 +1,5 @@ // DEFINE: %{entry_point} = entry -// DEFINE: %{compile} = mlir-opt %s \ -// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \ -// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \ -// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \ -// DEFINE: -test-lower-to-llvm +// DEFINE: %{compile} = mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm // DEFINE: %{run} = %mcr_aarch64_cmd \ // DEFINE: -march=aarch64 -mattr=+sve,+sme \ // DEFINE: -e %{entry_point} -entry-point-result=void \ diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir index 223bc8c..e149174 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir @@ -1,8 +1,4 @@ -// RUN: mlir-opt %s -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \ -// RUN: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \ -// RUN: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \ -// RUN: -convert-arm-sme-to-llvm -cse -canonicalize \ -// RUN: -test-lower-to-llvm | \ +// RUN: mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm | \ // RUN: %mcr_aarch64_cmd \ // RUN: -march=aarch64 -mattr=+sve,+sme \ // RUN: -e entry -entry-point-result=i32 \ diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir index 2f151e2..b29790db 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir +++ 
b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir @@ -1,9 +1,5 @@ // DEFINE: %{entry_point} = za0_d_f64 -// DEFINE: %{compile} = mlir-opt %s \ -// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \ -// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \ -// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \ -// DEFINE: -test-lower-to-llvm +// DEFINE: %{compile} = mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm // DEFINE: %{run} = %mcr_aarch64_cmd \ // DEFINE: -march=aarch64 -mattr=+sve,+sme \ // DEFINE: -e %{entry_point} -entry-point-result=i32 \ diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir index f28bf19..c8c401b 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir @@ -1,8 +1,5 @@ // DEFINE: %{entry_point} = entry -// DEFINE: %{compile} = mlir-opt %s -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \ -// DEFINE: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \ -// DEFINE: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \ -// DEFINE: -convert-arm-sme-to-llvm -test-lower-to-llvm +// DEFINE: %{compile} = mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm // DEFINE: %{run} = %mcr_aarch64_cmd \ // DEFINE: -march=aarch64 -mattr=+sve,+sme \ // DEFINE: -e %{entry_point} -entry-point-result=i32 \ diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir new file mode 100644 index 0000000..07989bd --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir @@ -0,0 +1,25 @@ +// RUN: mlir-opt %s -test-lower-to-llvm | \ +// RUN: %mcr_aarch64_cmd -e entry -entry-point-result=void \ +// RUN: 
-shared-libs=%mlir_c_runner_utils,%mlir_arm_runner_utils \ +// RUN: -march=aarch64 -mattr=+sve | \ +// RUN: FileCheck %s + +func.func @entry() { + %f1 = arith.constant 1.0 : f32 + %f2 = arith.constant 2.0 : f32 + %v1 = vector.splat %f1 : vector<[4]xf32> + %v2 = vector.splat %f2 : vector<[4]xf32> + vector.print %v1 : vector<[4]xf32> + vector.print %v2 : vector<[4]xf32> + // + // Test vectors: + // + // CHECK: ( 1, 1, 1, 1 + // CHECK: ( 2, 2, 2, 2 + + %v3 = vector.interleave %v1, %v2 : vector<[4]xf32> + vector.print %v3 : vector<[8]xf32> + // CHECK: ( 1, 2, 1, 2, 1, 2, 1, 2 + + return +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir new file mode 100644 index 0000000..0bc78af --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-opt %s -test-lower-to-llvm | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_c_runner_utils | \ +// RUN: FileCheck %s + +func.func @entry() { + %f1 = arith.constant 1.0 : f32 + %f2 = arith.constant 2.0 : f32 + %v1 = vector.splat %f1 : vector<2x4xf32> + %v2 = vector.splat %f2 : vector<2x4xf32> + vector.print %v1 : vector<2x4xf32> + vector.print %v2 : vector<2x4xf32> + // + // Test vectors: + // + // CHECK: ( ( 1, 1, 1, 1 ), ( 1, 1, 1, 1 ) ) + // CHECK: ( ( 2, 2, 2, 2 ), ( 2, 2, 2, 2 ) ) + + %v3 = vector.interleave %v1, %v2 : vector<2x4xf32> + vector.print %v3 : vector<2x8xf32> + // CHECK: ( ( 1, 2, 1, 2, 1, 2, 1, 2 ), ( 1, 2, 1, 2, 1, 2, 1, 2 ) ) + + return +} diff --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll index 0962134..9a4e939 100644 --- a/mlir/test/Target/LLVMIR/Import/import-failure.ll +++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll @@ -59,13 +59,15 @@ define void @unhandled_intrinsic() gc "example" { ; // ----- +; Check that debug intrinsics with an unsupported argument are 
dropped. + declare void @llvm.dbg.value(metadata, metadata, metadata) ; CHECK: import-failure.ll -; CHECK-SAME: warning: dropped intrinsic: call void @llvm.dbg.value(metadata !DIArgList(i64 %arg1, i64 undef), metadata !3, metadata !DIExpression(DW_OP_LLVM_arg, 0, DW_OP_LLVM_arg, 1, DW_OP_constu, 1, DW_OP_mul, DW_OP_plus, DW_OP_stack_value)), !dbg !5 +; CHECK-SAME: warning: dropped intrinsic: call void @llvm.dbg.value(metadata !DIArgList(i64 %{{.*}}, i64 undef), metadata !3, metadata !DIExpression(DW_OP_LLVM_arg, 0, DW_OP_LLVM_arg, 1, DW_OP_constu, 1, DW_OP_mul, DW_OP_plus, DW_OP_stack_value)) ; CHECK: import-failure.ll -; CHECK-SAME: warning: dropped intrinsic: call void @llvm.dbg.value(metadata !6, metadata !3, metadata !DIExpression()), !dbg !5 -define void @dropped_instruction(i64 %arg1) { +; CHECK-SAME: warning: dropped intrinsic: call void @llvm.dbg.value(metadata !6, metadata !3, metadata !DIExpression()) +define void @unsupported_argument(i64 %arg1) { call void @llvm.dbg.value(metadata !DIArgList(i64 %arg1, i64 undef), metadata !3, metadata !DIExpression(DW_OP_LLVM_arg, 0, DW_OP_LLVM_arg, 1, DW_OP_constu, 1, DW_OP_mul, DW_OP_plus, DW_OP_stack_value)), !dbg !5 call void @llvm.dbg.value(metadata !6, metadata !3, metadata !DIExpression()), !dbg !5 ret void @@ -83,6 +85,38 @@ define void @dropped_instruction(i64 %arg1) { ; // ----- +; Check that debug intrinsics that depend on cyclic metadata are dropped. 
+ +declare void @llvm.dbg.value(metadata, metadata, metadata) + +; CHECK: import-failure.ll +; CHECK-SAME: warning: dropped instruction: call void @llvm.dbg.label(metadata !{{.*}}) +; CHECK: import-failure.ll +; CHECK-SAME: warning: dropped intrinsic: call void @llvm.dbg.value(metadata i64 %{{.*}}, metadata !3, metadata !DIExpression()) +define void @cylic_metadata(i64 %arg1) { + call void @llvm.dbg.value(metadata i64 %arg1, metadata !10, metadata !DIExpression()), !dbg !14 + call void @llvm.dbg.label(metadata !13), !dbg !14 + ret void +} + +!llvm.dbg.cu = !{!1} +!llvm.module.flags = !{!0} +!0 = !{i32 2, !"Debug Info Version", i32 3} +!1 = distinct !DICompileUnit(language: DW_LANG_C, file: !2) +!2 = !DIFile(filename: "import-failure.ll", directory: "/") +!3 = !DICompositeType(tag: DW_TAG_array_type, size: 42, baseType: !4) +!4 = !DIDerivedType(tag: DW_TAG_pointer_type, baseType: !3) +!5 = distinct !DISubprogram(name: "class_method", scope: !2, file: !2, type: !6, spFlags: DISPFlagDefinition, unit: !1) +!6 = !DISubroutineType(types: !7) +!7 = !{!3} +!10 = !DILocalVariable(scope: !5, name: "arg1", file: !2, line: 1, arg: 1, align: 64); +!11 = !DILexicalBlock(scope: !5) +!12 = !DILexicalBlockFile(scope: !11, discriminator: 0) +!13 = !DILabel(scope: !12, name: "label", file: !2, line: 42) +!14 = !DILocation(line: 1, column: 2, scope: !5) + +; // ----- + ; global_dtors with non-null data fields cannot be represented in MLIR. 
; CHECK: <unknown> ; CHECK-SAME: error: unhandled global variable: @llvm.global_dtors diff --git a/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt b/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt new file mode 100644 index 0000000..e942c7b --- /dev/null +++ b/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt @@ -0,0 +1,18 @@ +# Exclude tests from libMLIR.so +add_mlir_library(MLIRArmSMETestPasses + TestLowerToArmSME.cpp + + EXCLUDE_FROM_LIBMLIR + + LINK_LIBS PUBLIC + MLIRArithToArmSME + MLIRArmSMEToLLVM + MLIRArmSMEToSCF + MLIRArmSMETransforms + MLIRArmSVETransforms + MLIRIR + MLIRPass + MLIRTransforms + MLIRVectorToArmSME + MLIRVectorToSCF + ) diff --git a/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp b/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp new file mode 100644 index 0000000..48d4a58 --- /dev/null +++ b/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp @@ -0,0 +1,99 @@ +//===- TestLowerToArmSME.cpp - Test lowering to ArmSME as a sink pass -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass for testing the lowering to ArmSME as a +// generally usable sink pass. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h" +#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h" +#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h" +#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h" +#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" +#include "mlir/Dialect/ArmSME/Transforms/Passes.h" +#include "mlir/Dialect/ArmSVE/Transforms/Passes.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassOptions.h" +#include "mlir/Transforms/Passes.h" + +using namespace mlir; + +namespace { +struct TestLowerToArmSMEOptions + : public PassPipelineOptions<TestLowerToArmSMEOptions> { + PassOptions::Option<bool> fuseOuterProducts{ + *this, "fuse-outer-products", + llvm::cl::desc("Fuse outer product operations via " + "'-arm-sme-outer-product-fusion' pass"), + llvm::cl::init(true)}; +}; + +void buildTestLowerToArmSME(OpPassManager &pm, + const TestLowerToArmSMEOptions &options) { + // Legalize vector operations so they can be converted to ArmSME. + pm.addPass(arm_sme::createVectorLegalizationPass()); + + // Sprinkle some cleanups. + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + + // Passes that convert operations on vectors to ArmSME operations. + + // Convert Arith to ArmSME. + pm.addPass(createArithToArmSMEConversionPass()); + // Convert Vector to ArmSME. + pm.addPass(createConvertVectorToArmSMEPass()); + + // Fuse outer products. + if (options.fuseOuterProducts) + pm.addPass(arm_sme::createOuterProductFusionPass()); + + // Convert operations on high-level vectors to loops. + + // Convert ArmSME to SCF. + pm.addPass(createConvertArmSMEToSCFPass()); + + // Convert Vector to SCF (with full unroll enabled). + pm.addPass(createConvertVectorToSCFPass( + VectorTransferToSCFOptions().enableFullUnroll())); + + // Allocate tiles for ArmSME operations. 
+ // + // Later passes may create further ArmSME ops that implement the + // ArmSMETileOpInterface, but tiles are allocated for root operations, + // all of which should now exist. + pm.addPass(arm_sme::createTileAllocationPass()); + + // Enable streaming-mode and ZA. + pm.addPass(arm_sme::createEnableArmStreamingPass( + arm_sme::ArmStreamingMode::StreamingLocally, arm_sme::ArmZaMode::NewZA, + /*onlyIfRequiredByOps=*/true)); + + // Convert ArmSME to LLVM. + pm.addPass(createConvertArmSMEToLLVMPass()); + + // Sprinkle some cleanups. + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); +} +} // namespace + +namespace mlir { +namespace test { +void registerTestLowerToArmSME() { + PassPipelineRegistration<TestLowerToArmSMEOptions>( + "test-lower-to-arm-sme", + "An example pipeline to lower operations on vectors (arith, vector) to " + "LLVM via ArmSME.", + buildTestLowerToArmSME); +} +} // namespace test +} // namespace mlir diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt index 30a17c2..e20cd44 100644 --- a/mlir/test/lib/Dialect/CMakeLists.txt +++ b/mlir/test/lib/Dialect/CMakeLists.txt @@ -1,5 +1,6 @@ add_subdirectory(Affine) add_subdirectory(Arith) +add_subdirectory(ArmSME) add_subdirectory(Bufferization) add_subdirectory(ControlFlow) add_subdirectory(DLTI) diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 108cfe8..bde4255 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1152,8 +1152,10 @@ struct TestLegalizePatternDriver // Handle a partial conversion. 
if (mode == ConversionMode::Partial) { DenseSet<Operation *> unlegalizedOps; - if (failed(applyPartialConversion( - getOperation(), target, std::move(patterns), &unlegalizedOps))) { + ConversionConfig config; + config.unlegalizedOps = &unlegalizedOps; + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns), config))) { getOperation()->emitRemark() << "applyPartialConversion failed"; } // Emit remarks for each legalizable operation. @@ -1181,8 +1183,10 @@ struct TestLegalizePatternDriver // Analyze the convertible operations. DenseSet<Operation *> legalizedOps; + ConversionConfig config; + config.legalizableOps = &legalizedOps; if (failed(applyAnalysisConversion(getOperation(), target, - std::move(patterns), legalizedOps))) + std::move(patterns), config))) return signalPassFailure(); // Emit remarks for each legalizable operation. @@ -1806,8 +1810,10 @@ struct TestMergeBlocksPatternDriver }); DenseSet<Operation *> unlegalizedOps; + ConversionConfig config; + config.unlegalizedOps = &unlegalizedOps; (void)applyPartialConversion(getOperation(), target, std::move(patterns), - &unlegalizedOps); + config); for (auto *op : unlegalizedOps) op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; } diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt index 68aa6ba..701fc46 100644 --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -17,6 +17,7 @@ if(MLIR_INCLUDE_TESTS) MLIRTestFuncToLLVM MLIRAffineTransformsTestPasses MLIRArithTestPasses + MLIRArmSMETestPasses MLIRBufferizationTestPasses MLIRControlFlowTestPasses MLIRDLTITestPasses diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index f11c6b4..4dfa05c 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -109,6 +109,7 @@ void registerTestLoopFusion(); void registerTestCFGLoopInfoPass(); void registerTestLoopMappingPass(); void 
registerTestLoopUnrollingPass(); +void registerTestLowerToArmSME(); void registerTestLowerToLLVM(); void registerTestMakeIsolatedFromAbovePass(); void registerTestMatchReductionPass(); @@ -233,6 +234,7 @@ void registerTestPasses() { mlir::test::registerTestCFGLoopInfoPass(); mlir::test::registerTestLoopMappingPass(); mlir::test::registerTestLoopUnrollingPass(); + mlir::test::registerTestLowerToArmSME(); mlir::test::registerTestLowerToLLVM(); mlir::test::registerTestMakeIsolatedFromAbovePass(); mlir::test::registerTestMatchReductionPass(); diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp index 3a6bcbd..9d2f690 100644 --- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp +++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp @@ -77,7 +77,7 @@ protected: } // Inserts an Integer or a Vector of Integers constant of value 'val'. - spirv::ConstantOp AddConstInt(Type type, const APInt &val) { + spirv::ConstantOp addConstInt(Type type, const APInt &val) { OpBuilder builder(module->getRegion()); auto loc = UnknownLoc::get(&context); @@ -181,8 +181,8 @@ TEST_F(SerializationTest, SignlessVsSignedIntegerConstantBitExtension) { APInt signedIntConstVal(signedInt16Type.getWidth(), -1, signedInt16Type.getSignedness()); - AddConstInt(signlessInt16Type, signlessIntConstVal); - AddConstInt(signedInt16Type, signedIntConstVal); + addConstInt(signlessInt16Type, signlessIntConstVal); + addConstInt(signedInt16Type, signedIntConstVal); ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary))); auto hasSignlessVal = [&](spirv::Opcode opcode, ArrayRef<uint32_t> operands) { diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp index 2e1309a..16de34c 100644 --- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp +++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp @@ -421,7 +421,7 @@ TEST(InterfaceAttachmentTest, PromisedInterfaces) { // Attribute interfaces use the 
exact same mechanism as types, so just check // that the promise mechanism works for attributes. MLIRContext context; - auto testDialect = context.getOrLoadDialect<test::TestDialect>(); + auto *testDialect = context.getOrLoadDialect<test::TestDialect>(); auto attr = test::SimpleAAttr::get(&context); // `SimpleAAttr` doesn't implement nor promises the diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp index 8a4f67b..9d75615 100644 --- a/mlir/unittests/IR/OperationSupportTest.cpp +++ b/mlir/unittests/IR/OperationSupportTest.cpp @@ -295,9 +295,9 @@ TEST(OperationEquivalenceTest, HashWorksWithFlags) { MLIRContext context; context.getOrLoadDialect<test::TestDialect>(); - auto op1 = createOp(&context); + auto *op1 = createOp(&context); // `op1` has an unknown loc. - auto op2 = createOp(&context); + auto *op2 = createOp(&context); op2->setLoc(NameLoc::get(StringAttr::get(&context, "foo"))); auto getHash = [](Operation *op, OperationEquivalence::Flags flags) { return OperationEquivalence::computeHash( diff --git a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp index a00ebba..26bfbd5 100644 --- a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp +++ b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp @@ -37,7 +37,7 @@ using namespace mlir; class MLIRTargetLLVMNVVM : public ::testing::Test { protected: - virtual void SetUp() { + void SetUp() override { registerBuiltinDialectTranslation(registry); registerLLVMDialectTranslation(registry); registerGPUDialectTranslation(registry); @@ -85,7 +85,7 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(SerializeNVVMMToLLVM)) { serializer.serializeToObject(gpuModule, options); // Check that the serializer was successful. ASSERT_TRUE(object != std::nullopt); - ASSERT_TRUE(object->size() > 0); + ASSERT_TRUE(!object->empty()); // Read the serialized module. 
llvm::MemoryBufferRef buffer(StringRef(object->data(), object->size()), @@ -121,7 +121,7 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(SerializeNVVMToPTX)) { serializer.serializeToObject(gpuModule, options); // Check that the serializer was successful. ASSERT_TRUE(object != std::nullopt); - ASSERT_TRUE(object->size() > 0); + ASSERT_TRUE(!object->empty()); ASSERT_TRUE( StringRef(object->data(), object->size()).contains("nvvm_kernel")); @@ -151,6 +151,6 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(SerializeNVVMToBinary)) { serializer.serializeToObject(gpuModule, options); // Check that the serializer was successful. ASSERT_TRUE(object != std::nullopt); - ASSERT_TRUE(object->size() > 0); + ASSERT_TRUE(!object->empty()); } } |