diff options
Diffstat (limited to 'mlir')
127 files changed, 1562 insertions, 2079 deletions
diff --git a/mlir/Maintainers.md b/mlir/Maintainers.md new file mode 100644 index 0000000..7c852ef --- /dev/null +++ b/mlir/Maintainers.md @@ -0,0 +1,63 @@ +# MLIR Maintainers + +This file is a list of the +[maintainers](https://llvm.org/docs/DeveloperPolicy.html#maintainers) for MLIR. + +The following people are the active maintainers for the project. For the sake of +simplicity, responsibility areas are subdivided into broad categories, which are +further subdivided into individual components, such as dialects. Please reach +out to them for code reviews, questions about their area of expertise, or other +assistance. + +## Core + +Core components of MLIR, including core IR, analyses and rewriters, fundamental +dialects, build system and language bindings. + +- Alex Zinenko \ + ftynse@gmail.com (email), + [@ftynse](https://github.com/ftynse) (GitHub), + ftynse (Discourse) +- Jacques Pienaar \ + jpienaar@google.com (email), + [@jpienaar](https://github.com/jpienaar) (GitHub), + jpienaar (Discourse) +- Mehdi Amini \ + joker.eph@gmail.com (email), + [@joker-eph](https://github.com/joker-eph) (GitHub), + mehdi_amini (Discourse) + +## Egress + +MLIR components pertaining to egress flows from MLIR, in particular to LLVM IR. + +- Matthias Springer \ + me@m-sp.org (email), + [@matthias-springer](https://github.com/matthias-springer) (GitHub), + matthias-springer (Discourse) +- Andrzej Warzynski \ + andrzej.warzynski@arm.com (email), + [@banach-space](https://github.com/banach-space) (GitHub), + banach-space (Discourse) +- Tobias Gysi \ + tobias.gysi@nextsilicon.com (email), + [@gysit](https://github.com/gysit) (GitHub), + gysit (Discourse) + +## Tensor Compiler + +MLIR components specific to construction of compilers for tensor algebra, in +particular for machine learning compilers. + +- Renato Golin \ + rengolin@gmail.com (email), + [@rengolin](https://github.com/rengolin) (GitHub), + rengolin (Discourse) +- Jacques Pienaar \ + jpienaar@google.com (email), + [@jpienaar](https://github.com/jpienaar) (GitHub), + jpienaar (Discourse) +- Andrzej Warzynski \ + andrzej.warzynski@arm.com (email), + [@banach-space](https://github.com/banach-space) (GitHub), + banach-space (Discourse) diff --git a/mlir/docs/Dialects/Vector.md b/mlir/docs/Dialects/Vector.md index ebeb0a2..6c8949d 100644 --- a/mlir/docs/Dialects/Vector.md +++ b/mlir/docs/Dialects/Vector.md @@ -294,7 +294,7 @@ LLVM instructions are prefixed by the `llvm.` dialect prefix (e.g. `llvm.insertvalue`). Such ops operate exclusively on 1-D vectors and aggregates following the [LLVM LangRef](https://llvm.org/docs/LangRef.html). MLIR operations are prefixed by the `vector.` dialect prefix (e.g. -`vector.insertelement`). Such ops operate exclusively on MLIR `n-D` `vector` +`vector.insert`). Such ops operate exclusively on MLIR `n-D` `vector` types. ### Alternatives For Lowering an n-D Vector Type to LLVM diff --git a/mlir/docs/Dialects/emitc.md b/mlir/docs/Dialects/emitc.md index e2288f5..6d09e93 100644 --- a/mlir/docs/Dialects/emitc.md +++ b/mlir/docs/Dialects/emitc.md @@ -18,6 +18,8 @@ The following convention is followed: GCC or Clang. * If `emitc.array` with a dimension of size zero is used, then the code requires [a GCC extension](https://gcc.gnu.org/onlinedocs/gcc/Zero-Length.html). +* If `aligned_alloc` is passed to an `emitc.call_opaque` operation, then C++17 + or C11 is required. * Else the generated code is compatible with C99. These restrictions are neither inherent to the EmitC dialect itself nor to the diff --git a/mlir/docs/Tutorials/Toy/Ch-4.md b/mlir/docs/Tutorials/Toy/Ch-4.md index 1bba269..e9abe36 100644 --- a/mlir/docs/Tutorials/Toy/Ch-4.md +++ b/mlir/docs/Tutorials/Toy/Ch-4.md @@ -445,7 +445,7 @@ When processing an operation like described, we query if it registered the ```c++ // Ask the operation to infer its output shapes. - LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + LDBG() << "Inferring shape for: " << *op; /// We check if an operation has a particular interface by casting. if (ShapeInference shapeOp = dyn_cast<ShapeInference>(op)) { diff --git a/mlir/docs/Tutorials/Toy/Ch-5.md b/mlir/docs/Tutorials/Toy/Ch-5.md index c750c07..17cd6bb 100644 --- a/mlir/docs/Tutorials/Toy/Ch-5.md +++ b/mlir/docs/Tutorials/Toy/Ch-5.md @@ -91,13 +91,11 @@ doesn't matter. See `ConversionTarget::getOpInfo` for the details. After the conversion target has been defined, we can define how to convert the *illegal* operations into *legal* ones. Similarly to the canonicalization framework introduced in [chapter 3](Ch-3.md), the -[`DialectConversion` framework](../../DialectConversion.md) also uses -[RewritePatterns](../QuickstartRewrites.md) to perform the conversion logic. -These patterns may be the `RewritePatterns` seen before or a new type of pattern -specific to the conversion framework `ConversionPattern`. `ConversionPatterns` +[`DialectConversion` framework](../../DialectConversion.md) uses a special kind +of `ConversionPattern` to perform the conversion logic. `ConversionPatterns` are different from traditional `RewritePatterns` in that they accept an -additional `operands` parameter containing operands that have been -remapped/replaced. This is used when dealing with type conversions, as the +additional `operands` (or `adaptor`) parameter containing operands that have +been remapped/replaced. This is used when dealing with type conversions, as the pattern will want to operate on values of the new type but match against the old. For our lowering, this invariant will be useful as it translates from the [TensorType](../../Dialects/Builtin.md/#rankedtensortype) currently being @@ -106,38 +104,23 @@ look at a snippet of lowering the `toy.transpose` operation: ```c++ /// Lower the `toy.transpose` operation to an affine loop nest. -struct TransposeOpLowering : public mlir::ConversionPattern { - TransposeOpLowering(mlir::MLIRContext *ctx) - : mlir::ConversionPattern(TransposeOp::getOperationName(), 1, ctx) {} - - /// Match and rewrite the given `toy.transpose` operation, with the given - /// operands that have been remapped from `tensor<...>` to `memref<...>`. - llvm::LogicalResult - matchAndRewrite(mlir::Operation *op, ArrayRef<mlir::Value> operands, - mlir::ConversionPatternRewriter &rewriter) const final { - auto loc = op->getLoc(); +struct TransposeOpLowering : public OpConversionPattern<toy::TransposeOp> { + using OpConversionPattern<toy::TransposeOp>::OpConversionPattern; - // Call to a helper function that will lower the current operation to a set - // of affine loops. We provide a functor that operates on the remapped - // operands, as well as the loop induction variables for the inner most - // loop body. - lowerOpToLoops( - op, operands, rewriter, - [loc](mlir::PatternRewriter &rewriter, - ArrayRef<mlir::Value> memRefOperands, - ArrayRef<mlir::Value> loopIvs) { - // Generate an adaptor for the remapped operands of the TransposeOp. - // This allows for using the nice named accessors that are generated - // by the ODS. This adaptor is automatically provided by the ODS - // framework. - TransposeOpAdaptor transposeAdaptor(memRefOperands); - mlir::Value input = transposeAdaptor.input(); - - // Transpose the elements by generating a load from the reverse - // indices. - SmallVector<mlir::Value, 2> reverseIvs(llvm::reverse(loopIvs)); - return mlir::AffineLoadOp::create(rewriter, loc, input, reverseIvs); - }); + LogicalResult + matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + lowerOpToLoops(op, rewriter, + [&](OpBuilder &builder, ValueRange loopIvs) { + Value input = adaptor.getInput(); + + // Transpose the elements by generating a load from the + // reverse indices. + SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs)); + return affine::AffineLoadOp::create(builder, loc, input, + reverseIvs); + }); return success(); } }; diff --git a/mlir/docs/Tutorials/transform/Ch0.md b/mlir/docs/Tutorials/transform/Ch0.md index ac3989a..dc4b753 100644 --- a/mlir/docs/Tutorials/transform/Ch0.md +++ b/mlir/docs/Tutorials/transform/Ch0.md @@ -46,7 +46,7 @@ When no support is available, such an operation can be transformed into a loop: %c8 = arith.constant 8 : index %init = arith.constant 0.0 : f32 %result = scf.for %i = %c0 to %c8 step %c1 iter_args(%partial = %init) -> (f32) { - %element = vector.extractelement %0[%i : index] : vector<8xf32> + %element = vector.extract %0[%i] : f32 into vector<8xf32> %updated = arith.addf %partial, %element : f32 scf.yield %updated : f32 } @@ -145,7 +145,7 @@ linalg.generic { %c0 = arith.constant 0.0 : f32 %0 = arith.cmpf ogt %in_one, %c0 : f32 %1 = arith.select %0, %in_one, %c0 : f32 - linalg.yield %1 : f32 + linalg.yield %1 : f32 } ``` @@ -185,7 +185,7 @@ In the case of `linalg.generic` operations, the iteration space is implicit and For example, tiling the matrix multiplication presented above with tile sizes `(2, 8)`, we obtain a loop nest around a `linalg.generic` expressing the same operation on a `2x8` tensor. ```mlir -// A special "multi-for" loop that supports tensor-insertion semantics +// A special "multi-for" loop that supports tensor-insertion semantics // as opposed to implicit updates. The resulting 8x16 tensor will be produced // by this loop. // The trip count of iterators is computed dividing the original tensor size, @@ -202,9 +202,9 @@ For example, tiling the matrix multiplication presented above with tile sizes `( // Take slices of inputs and outputs. Only the "i" and "j" dimensions are sliced. %lhs_slice = tensor.extract_slice %lhs[%3, 0] [2, 10] [1, 1] : tensor<8x10xf32> to tensor<2x10xf32> - %rhs_slice = tensor.extract_slice %rhs[0, %4] [10, 8] [1, 1] + %rhs_slice = tensor.extract_slice %rhs[0, %4] [10, 8] [1, 1] : tensor<10x16xf32> to tensor<10x8xf32> - %result_slice = tensor.extract_slice %shared[%3, %4] [2, 8] [1, 1] + %result_slice = tensor.extract_slice %shared[%3, %4] [2, 8] [1, 1] : tensor<8x16xf32> to tensor<2x8xf32> // This is exactly the same operation as before, but now operating on smaller @@ -214,7 +214,7 @@ For example, tiling the matrix multiplication presented above with tile sizes `( affine_map<(i, j, k) -> (k, j)>, affine_map<(i, j, k) -> (i, j)>], iterator_types = ["parallel", "parallel", "reduction"] - } ins(%lhs_slice, %rhs_slice : tensor<2x10xf32>, tensor<10x8xf32>) + } ins(%lhs_slice, %rhs_slice : tensor<2x10xf32>, tensor<10x8xf32>) outs(%result_slice : tensor<2x8xf32>) -> tensor<2x8xf32> { ^bb0(%lhs_one: f32, %rhs_one: f32, %init_one: f32): %0 = arith.mulf %lhs_one, %rhs_one : f32 @@ -238,15 +238,15 @@ After materializing loops with tiling, another key code generation transformatio 1. the subset (slice) of the operand that is used by the tile, and 2. the tensor-level structured operation producing the whole tensor that is being sliced. -By inverting the `indexing_map` and applying it to the set of elements accessed through the slice, we can compute the part of the iteration space of the operation defining the full tensor necessary to compute the tile. Thus fusion boils down to replacing the `tensor.extract_slice` operation with the tile of the `linalg.generic` producing the original operand. +By inverting the `indexing_map` and applying it to the set of elements accessed through the slice, we can compute the part of the iteration space of the operation defining the full tensor necessary to compute the tile. Thus fusion boils down to replacing the `tensor.extract_slice` operation with the tile of the `linalg.generic` producing the original operand. Let us assume that the matrix multiplication operation is followed by another operation that multiplies each element of the resulting matrix with itself. This trailing elementwise operation has a 2D iteration space, unlike the 3D one in matrix multiplication. Nevertheless, it is possible to tile the trailing operation and then fuse the producer of its operand, the matmul, into the loop generated by tiling. The untiled dimension will be used in its entirety. ```mlir // Same loop as before. -%0 = scf.forall (%i, %j) in (4, 2) - shared_outs(%shared = %init) +%0 = scf.forall (%i, %j) in (4, 2) + shared_outs(%shared = %init) -> (tensor<8x16xf32>, tensor<8x16xf32>) { // Scale the loop induction variables by the tile sizes. %1 = affine.apply affine_map<(d0) -> (d0 * 2)>(%i) @@ -286,7 +286,7 @@ Let us assume that the matrix multiplication operation is followed by another op indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>], iterator_types = ["parallel", "parallel"] - } ins(%partial : tensor<2x8xf32>) + } ins(%partial : tensor<2x8xf32>) outs(%shared_slice : tensor<2x8xf32>) { ^bb0(%in: f32, %out: f32): %5 = arith.mulf %in, %in : f32 diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp index 2522abe..a552e1f0 100644 --- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp @@ -23,7 +23,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #include <memory> @@ -81,7 +81,7 @@ struct ShapeInferencePass opWorklist.erase(op); // Ask the operation to infer its output shapes. - LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + LDBG() << "Inferring shape for: " << *op; if (auto shapeOp = dyn_cast<ShapeInference>(op)) { shapeOp.inferShapes(); } else { diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp index d65c89c..2969d3a 100644 --- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp @@ -44,7 +44,7 @@ using namespace mlir; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns +// ToyToAffine Conversion Patterns //===----------------------------------------------------------------------===// /// Convert the given RankedTensorType into the corresponding MemRefType. @@ -69,15 +69,13 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc, } /// This defines the function type used to process an iteration of a lowered -/// loop. It takes as input an OpBuilder, an range of memRefOperands -/// corresponding to the operands of the input operation, and the range of loop -/// induction variables for the iteration. It returns a value to store at the -/// current index of the iteration. -using LoopIterationFn = function_ref<Value( - OpBuilder &rewriter, ValueRange memRefOperands, ValueRange loopIvs)>; - -static void lowerOpToLoops(Operation *op, ValueRange operands, - PatternRewriter &rewriter, +/// loop. It takes as input an OpBuilder and the range of loop induction +/// variables for the iteration. It returns a value to store at the current +/// index of the iteration. +using LoopIterationFn = + function_ref<Value(OpBuilder &rewriter, ValueRange loopIvs)>; + +static void lowerOpToLoops(Operation *op, PatternRewriter &rewriter, LoopIterationFn processIteration) { auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin())); auto loc = op->getLoc(); @@ -95,10 +93,10 @@ static void lowerOpToLoops(Operation *op, ValueRange operands, affine::buildAffineLoopNest( rewriter, loc, lowerBounds, tensorType.getShape(), steps, [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { - // Call the processing function with the rewriter, the memref operands, - // and the loop induction variables. This function will return the value - // to store at the current index. - Value valueToStore = processIteration(nestedBuilder, operands, ivs); + // Call the processing function with the rewriter and the loop + // induction variables. This function will return the value to store at + // the current index. + Value valueToStore = processIteration(nestedBuilder, ivs); affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc, ivs); }); @@ -109,38 +107,30 @@ static void lowerOpToLoops(Operation *op, ValueRange operands, namespace { //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Binary operations +// ToyToAffine Conversion Patterns: Binary operations //===----------------------------------------------------------------------===// template <typename BinaryOp, typename LoweredBinaryOp> -struct BinaryOpLowering : public ConversionPattern { - BinaryOpLowering(MLIRContext *ctx) - : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} +struct BinaryOpLowering : public OpConversionPattern<BinaryOp> { + using OpConversionPattern<BinaryOp>::OpConversionPattern; + using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor; LogicalResult - matchAndRewrite(Operation *op, ArrayRef<Value> operands, + matchAndRewrite(BinaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); - lowerOpToLoops(op, operands, rewriter, - [loc](OpBuilder &builder, ValueRange memRefOperands, - ValueRange loopIvs) { - // Generate an adaptor for the remapped operands of the - // BinaryOp. This allows for using the nice named accessors - // that are generated by the ODS. - typename BinaryOp::Adaptor binaryAdaptor(memRefOperands); - - // Generate loads for the element of 'lhs' and 'rhs' at the - // inner loop. - auto loadedLhs = affine::AffineLoadOp::create( - builder, loc, binaryAdaptor.getLhs(), loopIvs); - auto loadedRhs = affine::AffineLoadOp::create( - builder, loc, binaryAdaptor.getRhs(), loopIvs); - - // Create the binary operation performed on the loaded - // values. - return LoweredBinaryOp::create(builder, loc, loadedLhs, - loadedRhs); - }); + lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) { + // Generate loads for the element of 'lhs' and 'rhs' at the + // inner loop. + auto loadedLhs = + affine::AffineLoadOp::create(builder, loc, adaptor.getLhs(), loopIvs); + auto loadedRhs = + affine::AffineLoadOp::create(builder, loc, adaptor.getRhs(), loopIvs); + + // Create the binary operation performed on the loaded + // values. + return LoweredBinaryOp::create(builder, loc, loadedLhs, loadedRhs); + }); return success(); } }; @@ -148,14 +138,15 @@ using AddOpLowering = BinaryOpLowering<toy::AddOp, arith::AddFOp>; using MulOpLowering = BinaryOpLowering<toy::MulOp, arith::MulFOp>; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Constant operations +// ToyToAffine Conversion Patterns: Constant operations //===----------------------------------------------------------------------===// -struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> { - using OpRewritePattern<toy::ConstantOp>::OpRewritePattern; +struct ConstantOpLowering : public OpConversionPattern<toy::ConstantOp> { + using OpConversionPattern<toy::ConstantOp>::OpConversionPattern; - LogicalResult matchAndRewrite(toy::ConstantOp op, - PatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(toy::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { DenseElementsAttr constantValue = op.getValue(); Location loc = op.getLoc(); @@ -216,7 +207,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Func operations +// ToyToAffine Conversion Patterns: Func operations //===----------------------------------------------------------------------===// struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> { @@ -247,7 +238,7 @@ struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Print operations +// ToyToAffine Conversion Patterns: Print operations //===----------------------------------------------------------------------===// struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> { @@ -265,14 +256,15 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Return operations +// ToyToAffine Conversion Patterns: Return operations //===----------------------------------------------------------------------===// -struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> { - using OpRewritePattern<toy::ReturnOp>::OpRewritePattern; +struct ReturnOpLowering : public OpConversionPattern<toy::ReturnOp> { + using OpConversionPattern<toy::ReturnOp>::OpConversionPattern; - LogicalResult matchAndRewrite(toy::ReturnOp op, - PatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(toy::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { // During this lowering, we expect that all function calls have been // inlined. if (op.hasOperand()) @@ -285,32 +277,24 @@ struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Transpose operations +// ToyToAffine Conversion Patterns: Transpose operations //===----------------------------------------------------------------------===// -struct TransposeOpLowering : public ConversionPattern { - TransposeOpLowering(MLIRContext *ctx) - : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} +struct TransposeOpLowering : public OpConversionPattern<toy::TransposeOp> { + using OpConversionPattern<toy::TransposeOp>::OpConversionPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef<Value> operands, + matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); - lowerOpToLoops(op, operands, rewriter, - [loc](OpBuilder &builder, ValueRange memRefOperands, - ValueRange loopIvs) { - // Generate an adaptor for the remapped operands of the - // TransposeOp. This allows for using the nice named - // accessors that are generated by the ODS. - toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); - Value input = transposeAdaptor.getInput(); - - // Transpose the elements by generating a load from the - // reverse indices. - SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs)); - return affine::AffineLoadOp::create(builder, loc, input, - reverseIvs); - }); + lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) { + Value input = adaptor.getInput(); + + // Transpose the elements by generating a load from the + // reverse indices. + SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs)); + return affine::AffineLoadOp::create(builder, loc, input, reverseIvs); + }); return success(); } }; diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp index 2522abe..a552e1f0 100644 --- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp @@ -23,7 +23,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #include <memory> @@ -81,7 +81,7 @@ struct ShapeInferencePass opWorklist.erase(op); // Ask the operation to infer its output shapes. - LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + LDBG() << "Inferring shape for: " << *op; if (auto shapeOp = dyn_cast<ShapeInference>(op)) { shapeOp.inferShapes(); } else { diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp index d65c89c..2969d3a 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp @@ -44,7 +44,7 @@ using namespace mlir; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns +// ToyToAffine Conversion Patterns //===----------------------------------------------------------------------===// /// Convert the given RankedTensorType into the corresponding MemRefType. @@ -69,15 +69,13 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc, } /// This defines the function type used to process an iteration of a lowered -/// loop. It takes as input an OpBuilder, an range of memRefOperands -/// corresponding to the operands of the input operation, and the range of loop -/// induction variables for the iteration. It returns a value to store at the -/// current index of the iteration. -using LoopIterationFn = function_ref<Value( - OpBuilder &rewriter, ValueRange memRefOperands, ValueRange loopIvs)>; - -static void lowerOpToLoops(Operation *op, ValueRange operands, - PatternRewriter &rewriter, +/// loop. It takes as input an OpBuilder and the range of loop induction +/// variables for the iteration. It returns a value to store at the current +/// index of the iteration. +using LoopIterationFn = + function_ref<Value(OpBuilder &rewriter, ValueRange loopIvs)>; + +static void lowerOpToLoops(Operation *op, PatternRewriter &rewriter, LoopIterationFn processIteration) { auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin())); auto loc = op->getLoc(); @@ -95,10 +93,10 @@ static void lowerOpToLoops(Operation *op, ValueRange operands, affine::buildAffineLoopNest( rewriter, loc, lowerBounds, tensorType.getShape(), steps, [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { - // Call the processing function with the rewriter, the memref operands, - // and the loop induction variables. This function will return the value - // to store at the current index. - Value valueToStore = processIteration(nestedBuilder, operands, ivs); + // Call the processing function with the rewriter and the loop + // induction variables. This function will return the value to store at + // the current index. + Value valueToStore = processIteration(nestedBuilder, ivs); affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc, ivs); }); @@ -109,38 +107,30 @@ static void lowerOpToLoops(Operation *op, ValueRange operands, namespace { //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Binary operations +// ToyToAffine Conversion Patterns: Binary operations //===----------------------------------------------------------------------===// template <typename BinaryOp, typename LoweredBinaryOp> -struct BinaryOpLowering : public ConversionPattern { - BinaryOpLowering(MLIRContext *ctx) - : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} +struct BinaryOpLowering : public OpConversionPattern<BinaryOp> { + using OpConversionPattern<BinaryOp>::OpConversionPattern; + using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor; LogicalResult - matchAndRewrite(Operation *op, ArrayRef<Value> operands, + matchAndRewrite(BinaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); - lowerOpToLoops(op, operands, rewriter, - [loc](OpBuilder &builder, ValueRange memRefOperands, - ValueRange loopIvs) { - // Generate an adaptor for the remapped operands of the - // BinaryOp. This allows for using the nice named accessors - // that are generated by the ODS. - typename BinaryOp::Adaptor binaryAdaptor(memRefOperands); - - // Generate loads for the element of 'lhs' and 'rhs' at the - // inner loop. - auto loadedLhs = affine::AffineLoadOp::create( - builder, loc, binaryAdaptor.getLhs(), loopIvs); - auto loadedRhs = affine::AffineLoadOp::create( - builder, loc, binaryAdaptor.getRhs(), loopIvs); - - // Create the binary operation performed on the loaded - // values. - return LoweredBinaryOp::create(builder, loc, loadedLhs, - loadedRhs); - }); + lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) { + // Generate loads for the element of 'lhs' and 'rhs' at the + // inner loop. + auto loadedLhs = + affine::AffineLoadOp::create(builder, loc, adaptor.getLhs(), loopIvs); + auto loadedRhs = + affine::AffineLoadOp::create(builder, loc, adaptor.getRhs(), loopIvs); + + // Create the binary operation performed on the loaded + // values. + return LoweredBinaryOp::create(builder, loc, loadedLhs, loadedRhs); + }); return success(); } }; @@ -148,14 +138,15 @@ using AddOpLowering = BinaryOpLowering<toy::AddOp, arith::AddFOp>; using MulOpLowering = BinaryOpLowering<toy::MulOp, arith::MulFOp>; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Constant operations +// ToyToAffine Conversion Patterns: Constant operations //===----------------------------------------------------------------------===// -struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> { - using OpRewritePattern<toy::ConstantOp>::OpRewritePattern; +struct ConstantOpLowering : public OpConversionPattern<toy::ConstantOp> { + using OpConversionPattern<toy::ConstantOp>::OpConversionPattern; - LogicalResult matchAndRewrite(toy::ConstantOp op, - PatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(toy::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { DenseElementsAttr constantValue = op.getValue(); Location loc = op.getLoc(); @@ -216,7 +207,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Func operations +// ToyToAffine Conversion Patterns: Func operations //===----------------------------------------------------------------------===// struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> { @@ -247,7 +238,7 @@ struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Print operations +// ToyToAffine Conversion Patterns: Print operations //===----------------------------------------------------------------------===// struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> { @@ -265,14 +256,15 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Return operations +// ToyToAffine Conversion Patterns: Return operations //===----------------------------------------------------------------------===// -struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> { - using OpRewritePattern<toy::ReturnOp>::OpRewritePattern; +struct ReturnOpLowering : public OpConversionPattern<toy::ReturnOp> { + using OpConversionPattern<toy::ReturnOp>::OpConversionPattern; - LogicalResult matchAndRewrite(toy::ReturnOp op, - PatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(toy::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { // During this lowering, we expect that all function calls have been // inlined. if (op.hasOperand()) @@ -285,32 +277,24 @@ struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Transpose operations +// ToyToAffine Conversion Patterns: Transpose operations //===----------------------------------------------------------------------===// -struct TransposeOpLowering : public ConversionPattern { - TransposeOpLowering(MLIRContext *ctx) - : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} +struct TransposeOpLowering : public OpConversionPattern<toy::TransposeOp> { + using OpConversionPattern<toy::TransposeOp>::OpConversionPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef<Value> operands, + matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); - lowerOpToLoops(op, operands, rewriter, - [loc](OpBuilder &builder, ValueRange memRefOperands, - ValueRange loopIvs) { - // Generate an adaptor for the remapped operands of the - // TransposeOp. This allows for using the nice named - // accessors that are generated by the ODS. - toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); - Value input = transposeAdaptor.getInput(); - - // Transpose the elements by generating a load from the - // reverse indices. - SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs)); - return affine::AffineLoadOp::create(builder, loc, input, - reverseIvs); - }); + lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) { + Value input = adaptor.getInput(); + + // Transpose the elements by generating a load from the + // reverse indices. + SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs)); + return affine::AffineLoadOp::create(builder, loc, input, reverseIvs); + }); return success(); } }; diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp index e0950ef..987dfa1 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -55,19 +55,18 @@ using namespace mlir; //===----------------------------------------------------------------------===// -// ToyToLLVM RewritePatterns +// ToyToLLVM Conversion Patterns //===----------------------------------------------------------------------===// namespace { /// Lowers `toy.print` to a loop nest calling `printf` on each of the individual /// elements of the array. -class PrintOpLowering : public ConversionPattern { +class PrintOpLowering : public OpConversionPattern<toy::PrintOp> { public: - explicit PrintOpLowering(MLIRContext *context) - : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {} + using OpConversionPattern<toy::PrintOp>::OpConversionPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef<Value> operands, + matchAndRewrite(toy::PrintOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto *context = rewriter.getContext(); auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin())); @@ -108,9 +107,8 @@ public: } // Generate a call to printf for the current element of the loop. - auto printOp = cast<toy::PrintOp>(op); auto elementLoad = - memref::LoadOp::create(rewriter, loc, printOp.getInput(), loopIvs); + memref::LoadOp::create(rewriter, loc, op.getInput(), loopIvs); LLVM::CallOp::create(rewriter, loc, getPrintfType(context), printfRef, ArrayRef<Value>({formatSpecifierCst, elementLoad})); diff --git a/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp index 2522abe..a552e1f0 100644 --- a/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp @@ -23,7 +23,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #include <memory> @@ -81,7 +81,7 @@ struct ShapeInferencePass opWorklist.erase(op); // Ask the operation to infer its output shapes. - LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + LDBG() << "Inferring shape for: " << *op; if (auto shapeOp = dyn_cast<ShapeInference>(op)) { shapeOp.inferShapes(); } else { diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp index d65c89c..cbe4236 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp @@ -44,7 +44,7 @@ using namespace mlir; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns +// ToyToAffine Conversion Patterns //===----------------------------------------------------------------------===// /// Convert the given RankedTensorType into the corresponding MemRefType. @@ -69,15 +69,13 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc, } /// This defines the function type used to process an iteration of a lowered -/// loop. It takes as input an OpBuilder, an range of memRefOperands -/// corresponding to the operands of the input operation, and the range of loop -/// induction variables for the iteration. It returns a value to store at the -/// current index of the iteration. -using LoopIterationFn = function_ref<Value( - OpBuilder &rewriter, ValueRange memRefOperands, ValueRange loopIvs)>; - -static void lowerOpToLoops(Operation *op, ValueRange operands, - PatternRewriter &rewriter, +/// loop. It takes as input an OpBuilder and the range of loop induction +/// variables for the iteration. It returns a value to store at the current +/// index of the iteration. +using LoopIterationFn = + function_ref<Value(OpBuilder &rewriter, ValueRange loopIvs)>; + +static void lowerOpToLoops(Operation *op, PatternRewriter &rewriter, LoopIterationFn processIteration) { auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin())); auto loc = op->getLoc(); @@ -95,10 +93,10 @@ static void lowerOpToLoops(Operation *op, ValueRange operands, affine::buildAffineLoopNest( rewriter, loc, lowerBounds, tensorType.getShape(), steps, [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { - // Call the processing function with the rewriter, the memref operands, + // Call the processing function with the rewriter // and the loop induction variables. This function will return the value // to store at the current index. - Value valueToStore = processIteration(nestedBuilder, operands, ivs); + Value valueToStore = processIteration(nestedBuilder, ivs); affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc, ivs); }); @@ -109,38 +107,30 @@ static void lowerOpToLoops(Operation *op, ValueRange operands, namespace { //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Binary operations +// ToyToAffine Conversion Patterns: Binary operations //===----------------------------------------------------------------------===// template <typename BinaryOp, typename LoweredBinaryOp> -struct BinaryOpLowering : public ConversionPattern { - BinaryOpLowering(MLIRContext *ctx) - : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} +struct BinaryOpLowering : public OpConversionPattern<BinaryOp> { + using OpConversionPattern<BinaryOp>::OpConversionPattern; + using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor; LogicalResult - matchAndRewrite(Operation *op, ArrayRef<Value> operands, + matchAndRewrite(BinaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); - lowerOpToLoops(op, operands, rewriter, - [loc](OpBuilder &builder, ValueRange memRefOperands, - ValueRange loopIvs) { - // Generate an adaptor for the remapped operands of the - // BinaryOp. This allows for using the nice named accessors - // that are generated by the ODS. - typename BinaryOp::Adaptor binaryAdaptor(memRefOperands); - - // Generate loads for the element of 'lhs' and 'rhs' at the - // inner loop. - auto loadedLhs = affine::AffineLoadOp::create( - builder, loc, binaryAdaptor.getLhs(), loopIvs); - auto loadedRhs = affine::AffineLoadOp::create( - builder, loc, binaryAdaptor.getRhs(), loopIvs); - - // Create the binary operation performed on the loaded - // values. - return LoweredBinaryOp::create(builder, loc, loadedLhs, - loadedRhs); - }); + lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) { + // Generate loads for the element of 'lhs' and 'rhs' at the + // inner loop. + auto loadedLhs = + affine::AffineLoadOp::create(builder, loc, adaptor.getLhs(), loopIvs); + auto loadedRhs = + affine::AffineLoadOp::create(builder, loc, adaptor.getRhs(), loopIvs); + + // Create the binary operation performed on the loaded + // values. + return LoweredBinaryOp::create(builder, loc, loadedLhs, loadedRhs); + }); return success(); } }; @@ -148,14 +138,15 @@ using AddOpLowering = BinaryOpLowering<toy::AddOp, arith::AddFOp>; using MulOpLowering = BinaryOpLowering<toy::MulOp, arith::MulFOp>; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Constant operations +// ToyToAffine Conversion Patterns: Constant operations //===----------------------------------------------------------------------===// -struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> { - using OpRewritePattern<toy::ConstantOp>::OpRewritePattern; +struct ConstantOpLowering : public OpConversionPattern<toy::ConstantOp> { + using OpConversionPattern<toy::ConstantOp>::OpConversionPattern; - LogicalResult matchAndRewrite(toy::ConstantOp op, - PatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(toy::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { DenseElementsAttr constantValue = op.getValue(); Location loc = op.getLoc(); @@ -216,7 +207,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Func operations +// ToyToAffine Conversion Patterns: Func operations //===----------------------------------------------------------------------===// struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> { @@ -247,7 +238,7 @@ struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Print operations +// ToyToAffine Conversion Patterns: Print operations //===----------------------------------------------------------------------===// struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> { @@ -265,14 +256,15 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Return operations +// ToyToAffine Conversion Patterns: Return operations //===----------------------------------------------------------------------===// -struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> { - using OpRewritePattern<toy::ReturnOp>::OpRewritePattern; +struct ReturnOpLowering : public OpConversionPattern<toy::ReturnOp> { + using OpConversionPattern<toy::ReturnOp>::OpConversionPattern; - LogicalResult matchAndRewrite(toy::ReturnOp op, - PatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(toy::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { // During this lowering, we expect that all function calls have been // inlined. if (op.hasOperand()) @@ -285,32 +277,24 @@ struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Transpose operations +// ToyToAffine Conversion Patterns: Transpose operations //===----------------------------------------------------------------------===// -struct TransposeOpLowering : public ConversionPattern { - TransposeOpLowering(MLIRContext *ctx) - : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} +struct TransposeOpLowering : public OpConversionPattern<toy::TransposeOp> { + using OpConversionPattern<toy::TransposeOp>::OpConversionPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef<Value> operands, + matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); - lowerOpToLoops(op, operands, rewriter, - [loc](OpBuilder &builder, ValueRange memRefOperands, - ValueRange loopIvs) { - // Generate an adaptor for the remapped operands of the - // TransposeOp. This allows for using the nice named - // accessors that are generated by the ODS. - toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); - Value input = transposeAdaptor.getInput(); - - // Transpose the elements by generating a load from the - // reverse indices. - SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs)); - return affine::AffineLoadOp::create(builder, loc, input, - reverseIvs); - }); + lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) { + Value input = adaptor.getInput(); + + // Transpose the elements by generating a load from the + // reverse indices. + SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs)); + return affine::AffineLoadOp::create(builder, loc, input, reverseIvs); + }); return success(); } }; diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp index 43a84da..8b48a8f 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -55,19 +55,18 @@ using namespace mlir; //===----------------------------------------------------------------------===// -// ToyToLLVM RewritePatterns +// ToyToLLVM Conversion Patterns //===----------------------------------------------------------------------===// namespace { /// Lowers `toy.print` to a loop nest calling `printf` on each of the individual /// elements of the array. -class PrintOpLowering : public ConversionPattern { +class PrintOpLowering : public OpConversionPattern<toy::PrintOp> { public: - explicit PrintOpLowering(MLIRContext *context) - : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {} + using OpConversionPattern<toy::PrintOp>::OpConversionPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef<Value> operands, + matchAndRewrite(toy::PrintOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto *context = rewriter.getContext(); auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin())); @@ -108,9 +107,8 @@ public: } // Generate a call to printf for the current element of the loop. - auto printOp = cast<toy::PrintOp>(op); auto elementLoad = - memref::LoadOp::create(rewriter, loc, printOp.getInput(), loopIvs); + memref::LoadOp::create(rewriter, loc, op.getInput(), loopIvs); LLVM::CallOp::create(rewriter, loc, getPrintfType(context), printfRef, ArrayRef<Value>({formatSpecifierCst, elementLoad})); diff --git a/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp index 2522abe..a552e1f0 100644 --- a/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp @@ -23,7 +23,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #include <memory> @@ -81,7 +81,7 @@ struct ShapeInferencePass opWorklist.erase(op); // Ask the operation to infer its output shapes. - LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + LDBG() << "Inferring shape for: " << *op; if (auto shapeOp = dyn_cast<ShapeInference>(op)) { shapeOp.inferShapes(); } else { diff --git a/mlir/examples/transform/Ch4/lib/MyExtension.cpp b/mlir/examples/transform/Ch4/lib/MyExtension.cpp index fa0ffc9..2159483 100644 --- a/mlir/examples/transform/Ch4/lib/MyExtension.cpp +++ b/mlir/examples/transform/Ch4/lib/MyExtension.cpp @@ -13,11 +13,9 @@ #include "MyExtension.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" -#define DEBUG_TYPE_MATCHER "transform-matcher" -#define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ") -#define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x) +#define DEBUG_TYPE "transform-matcher" #define GET_OP_CLASSES #include "MyExtension.cpp.inc" @@ -124,9 +122,8 @@ mlir::transform::HasOperandSatisfyingOp::apply( // Report failure-to-match for debugging purposes and stop matching this // operand. assert(diag.isSilenceableFailure()); - DEBUG_MATCHER(DBGS_MATCHER() - << "failed to match operand #" << operand.getOperandNumber() - << ": " << diag.getMessage()); + LDBG() << "failed to match operand #" << operand.getOperandNumber() + << ": " << diag.getMessage(); (void)diag.silence(); matchSucceeded = false; break; diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h index 364a70c..b595b6a3 100644 --- a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h +++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h @@ -8,6 +8,11 @@ #ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H #define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H +constexpr const char *alignedAllocFunctionName = "aligned_alloc"; +constexpr const char *mallocFunctionName = "malloc"; +constexpr const char *cppStandardLibraryHeader = "cstdlib"; +constexpr const char *cStandardLibraryHeader = "stdlib.h"; + namespace mlir { class DialectRegistry; class RewritePatternSet; diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index eb18160..cf7596c 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -841,9 +841,13 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> { // MemRefToEmitC //===----------------------------------------------------------------------===// -def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc"> { +def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc", "ModuleOp"> { let summary = "Convert MemRef dialect to EmitC dialect"; let dependentDialects = ["emitc::EmitCDialect"]; + let options = [Option< + "lowerToCpp", "lower-to-cpp", "bool", + /*default=*/"false", + /*description=*/"Target C++ (true) instead of C (false)">]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h index 2cf801d..09700f8 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h @@ -14,7 +14,7 @@ struct LogicalResult; } // namespace llvm namespace mlir { -class ModuleOp; +class Operation; namespace bufferization { struct BufferizationStatistics; @@ -23,12 +23,13 @@ struct OneShotBufferizationOptions; class BufferizationState; /// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in -/// `state`. +/// `state`. This operates on any `SymbolTable` op. llvm::LogicalResult -analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, +analyzeModuleOp(Operation *moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics = nullptr); -/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`. +/// Bufferize an `op`s nested ops that implement `BufferizableOpInterface`. +/// This operates on any `SymbolTable` op. /// /// Note: This function does not run One-Shot Analysis. No buffer copies are /// inserted except two cases: @@ -37,20 +38,20 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, /// - `options.copyBeforeWrite` is not set and `options.noAnalysisFuncFilter` /// is not empty. The FuncOps it contains were not analyzed. Buffer copies /// will be inserted only to these FuncOps. -llvm::LogicalResult -bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options, - BufferizationState &state, - BufferizationStatistics *statistics = nullptr); +llvm::LogicalResult bufferizeModuleOp( + Operation *moduleOp, const OneShotBufferizationOptions &options, + BufferizationState &state, BufferizationStatistics *statistics = nullptr); -/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp. -void removeBufferizationAttributesInModule(ModuleOp moduleOp); +/// Remove bufferization attributes on every FuncOp arguments in the SymbolTable +/// op. +void removeBufferizationAttributesInModule(Operation *moduleOp); -/// Run One-Shot Module Bufferization on the given module. Performs a simple -/// function call analysis to determine which function arguments are +/// Run One-Shot Module Bufferization on the given SymbolTable. Performs a +/// simple function call analysis to determine which function arguments are /// inplaceable. Then analyzes and bufferizes FuncOps one-by-one with One-Shot /// Bufferize. llvm::LogicalResult runOneShotModuleBufferize( - ModuleOp moduleOp, + Operation *moduleOp, const bufferization::OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics = nullptr); diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td index 1dbaf5d..2ed7d38 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -1368,12 +1368,14 @@ def GPU_ShuffleOp : GPU_Op< def GPU_RotateOp : GPU_Op< "rotate", [Pure, AllTypesMatch<["value", "rotateResult"]>]>, - Arguments<(ins AnyIntegerOrFloatOr1DVector:$value, I32:$offset, I32:$width)>, + Arguments<(ins AnyIntegerOrFloatOr1DVector:$value, + ConfinedAttr<I32Attr, [IntMinValue<0>]>:$offset, + ConfinedAttr<I32Attr, [IntPowerOf2]>:$width)>, Results<(outs AnyIntegerOrFloatOr1DVector:$rotateResult, I1:$valid)> { let summary = "Rotate values within a subgroup."; let description = [{ The "rotate" op moves values across lanes in a subgroup (a.k.a., local - invocations) within the same subgroup. The `width` argument specifies the + invocations) within the same subgroup. The `width` attribute specifies the number of lanes that participate in the rotation, and must be uniform across all participating lanes. Further, the first `width` lanes of the subgroup must be active. @@ -1394,9 +1396,7 @@ def GPU_RotateOp : GPU_Op< example: ```mlir - %offset = arith.constant 1 : i32 - %width = arith.constant 16 : i32 - %1, %2 = gpu.rotate %0, %offset, %width : f32 + %1, %2 = gpu.rotate %0, 1, 16 : f32 ``` For lane `k`, returns the value from lane `(k + cst1) % width`. @@ -1406,11 +1406,6 @@ def GPU_RotateOp : GPU_Op< $value `,` $offset `,` $width attr-dict `:` type($value) }]; - let builders = [ - // Helper function that creates a rotate with constant offset/width. - OpBuilder<(ins "Value":$value, "int32_t":$offset, "int32_t":$width)> - ]; - let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td index fa57202..f36b41c 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td @@ -106,7 +106,9 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [ result tensor in the order in which they appear, i.e. `shape(result)[rank(result) + i] = inner_tiles[i]` for `0 <= i < k`. - The following relationship for the tiled dimensions holds: - `shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i]`. + `shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i]`, + where (⌈/⌉ indicates CeilDiv). + Example: If `inner_tiles = [16, 32]`, the result tensor has a shape of `...x16x32`. If `inner_dims_pos = [0, 1]`, the 0th source dimension is tiled @@ -150,9 +152,17 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [ `padding_value` specifies a padding value at the boundary on non-perfectly divisible dimensions. Padding is optional: - - If absent, it is UB if the tile does not perfectly divide the dimension. + - If absent, it is assumed that for all inner tiles, + `shape(source)[inner_dims_pos[i]] % inner_tiles[i] == 0`, i.e. all inner + tiles divide perfectly the corresponding outer dimension in the result + tensor. It is UB if the tile does not perfectly divide the dimension. - If present, it will pad along high dimensions (high-padding) to make the - tile complete. + tile complete. Note that it is not allowed to have artificial padding that + is not strictly required by linalg.pack (i.e., padding past what is needed + to complete the last tile along each packed dimension). It is UB if extra + padding is requested. + It is not possible to verify the requirements statically with dynamic + shapes, so they are treated as UB. Example: ```mlir @@ -167,6 +177,15 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [ // // Note: Only tiled dimensions can be padded. ``` + + Invalid example that has artificial padding: + ```mlir + %0 = linalg.pack %src padding_value(%cst : f32) inner_dims_pos = [0] + inner_tiles = [8] into %dest + : tensor<9xf32> -> tensor<3x8xf32> + // \ + // expect tensor<2x8xf32> because CeilDiv(9, 8) = 2 + ``` }]; let arguments = (ins AnyRankedTensor:$source, AnyRankedTensor:$dest, diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h index 5430fd9..c0c6085 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h @@ -119,6 +119,12 @@ public: Status status; }; +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const CopyMappingInfo &info) { + info.print(os); + return os; +} + } // namespace gpu } // namespace transform } // namespace mlir diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPAttrDefs.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPAttrDefs.td index 704d0b2..72ce4c6 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPAttrDefs.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPAttrDefs.td @@ -23,6 +23,21 @@ class OpenMP_Attr<string name, string attrMnemonic, list<Trait> traits = [], } //===----------------------------------------------------------------------===// +// AtomicControlAttr +//===----------------------------------------------------------------------===// + +// Atomic control attributes hold information about architectural +// characteristics which are required for lowering atomic operations. +def AtomicControlAttr : OpenMP_Attr<"AtomicControl", "atomic_control"> { + let parameters = + (ins DefaultValuedParameter<"bool", "false">:$ignore_denormal_mode, + DefaultValuedParameter<"bool", "false">:$fine_grained_memory, + DefaultValuedParameter<"bool", "false">:$remote_memory); + + let assemblyFormat = "`<` struct(params) `>`"; +} + +//===----------------------------------------------------------------------===// // DeclareTargetAttr //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 8cf18b4..be114ea 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1750,9 +1750,11 @@ def AtomicUpdateOp : OpenMP_Op<"atomic.update", traits = [ operations. }] # clausesDescription; - let arguments = !con((ins Arg<OpenMP_PointerLikeType, - "Address of variable to be updated", - [MemRead, MemWrite]>:$x), clausesArgs); + let arguments = !con( + (ins Arg<OpenMP_PointerLikeType, + "Address of variable to be updated", [MemRead, MemWrite]>:$x, + OptionalAttr<AtomicControlAttr>:$atomic_control), + clausesArgs); // Override region definition. let regions = (region SizedRegion<1>:$region); diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td index a534381b..2513e10 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -380,7 +380,7 @@ def SparseVectorization : Pass<"sparse-vectorization", "ModuleOp"> { After: %3 = memref.load %2[] : memref<f32> - %4 = vector.insertelement %3, %cst[%c0 : index] : vector<32xf32> + %4 = vector.insert %3, %cst [0] : f32 into vector<32xf32> %5 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %4) -> (vector<32xf32>) { %8 = vector.load %0[%arg3] : memref<?xf32>, vector<32xf32> %9 = vector.load %1[%arg3] : memref<1024xf32>, vector<32xf32> diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index 349e8ed..754640d 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -151,7 +151,7 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> : def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>; def Tosa_ScalarTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_AnyNumber], [1]>]>; -def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>; +def Tosa_ScalarInt8Tensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int8]>, TosaScalarTensorOf<[Tosa_Int8], [1]>]>; def Tosa_ScalarIntOrFloatTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>]>; // We include unranked tensors as a supported type for all possible tosa diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 0a5c1e5..3885439 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -646,55 +646,6 @@ def Vector_DeinterleaveOp : }]; } -def Vector_ExtractElementOp : - Vector_Op<"extractelement", [Pure, - DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>, - TypesMatchWith<"result type matches element type of vector operand", - "vector", "result", - "::llvm::cast<VectorType>($_self).getElementType()">]>, - Arguments<(ins AnyVectorOfAnyRank:$vector, - Optional<AnySignlessIntegerOrIndex>:$position)>, - Results<(outs AnyType:$result)> { - let summary = "extractelement operation"; - let description = [{ - Note: This operation is deprecated. Please use vector.extract insert. - - Takes a 0-D or 1-D vector and a optional dynamic index position and - extracts the scalar at that position. - - Note that this instruction resembles vector.extract, but is restricted to - 0-D and 1-D vectors. - If the vector is 0-D, the position must be std::nullopt. - - - It is meant to be closer to LLVM's version: - https://llvm.org/docs/LangRef.html#extractelement-instruction - - Example: - - ```mlir - %c = arith.constant 15 : i32 - %1 = vector.extractelement %0[%c : i32]: vector<16xf32> - %2 = vector.extractelement %z[]: vector<f32> - ``` - }]; - let assemblyFormat = [{ - $vector `[` ($position^ `:` type($position))? `]` attr-dict `:` type($vector) - }]; - - let builders = [ - // 0-D builder. - OpBuilder<(ins "Value":$source)>, - ]; - let extraClassDeclaration = [{ - VectorType getSourceVectorType() { - return ::llvm::cast<VectorType>(getVector().getType()); - } - }]; - let hasVerifier = 1; - let hasFolder = 1; -} - def Vector_ExtractOp : Vector_Op<"extract", [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>, @@ -890,57 +841,6 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [ let hasCanonicalizer = 1; } -def Vector_InsertElementOp : - Vector_Op<"insertelement", [Pure, - DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>, - TypesMatchWith<"source operand type matches element type of result", - "result", "source", - "::llvm::cast<VectorType>($_self).getElementType()">, - AllTypesMatch<["dest", "result"]>]>, - Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, - Optional<AnySignlessIntegerOrIndex>:$position)>, - Results<(outs AnyVectorOfAnyRank:$result)> { - let summary = "insertelement operation"; - let description = [{ - Note: This operation is deprecated. Please use vector.insert instead. - - Takes a scalar source, a 0-D or 1-D destination vector and a dynamic index - position and inserts the source into the destination at the proper position. - - Note that this instruction resembles vector.insert, but is restricted to 0-D - and 1-D vectors. - - It is meant to be closer to LLVM's version: - https://llvm.org/docs/LangRef.html#insertelement-instruction - - Example: - - ```mlir - %c = arith.constant 15 : i32 - %f = arith.constant 0.0f : f32 - %1 = vector.insertelement %f, %0[%c : i32]: vector<16xf32> - %2 = vector.insertelement %f, %z[]: vector<f32> - ``` - }]; - let assemblyFormat = [{ - $source `,` $dest `[` ($position^ `:` type($position))? `]` attr-dict `:` - type($result) - }]; - - let builders = [ - // 0-D builder. - OpBuilder<(ins "Value":$source, "Value":$dest)>, - ]; - let extraClassDeclaration = [{ - Type getSourceType() { return getSource().getType(); } - VectorType getDestVectorType() { - return ::llvm::cast<VectorType>(getDest().getType()); - } - }]; - let hasVerifier = 1; - let hasFolder = 1; -} - def Vector_InsertOp : Vector_Op<"insert", [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>, diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td index 73f6877..38c217f 100644 --- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td @@ -397,8 +397,8 @@ def DotOp : AVX_LowOp<"dot", [Pure, ```mlir %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32> - %1 = vector.extractelement %0[%i0 : i32]: vector<8xf32> - %2 = vector.extractelement %0[%i4 : i32]: vector<8xf32> + %1 = vector.extract %0[%i0] : f32 from vector<8xf32> + %2 = vector.extract %0[%i4] : f32 from vector<8xf32> %d = arith.addf %1, %2 : f32 ``` }]; diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index b3608b4..b5a93a0 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -525,6 +525,11 @@ public: } /// This method erases an operation that is known to have no uses. + /// + /// If the current insertion point is before the erased operation, it is + /// adjusted to the following operation (or the end of the block). If the + /// current insertion point is within the erased operation, the insertion + /// point is left in an invalid state. virtual void eraseOp(Operation *op); /// This method erases all operations in a block. @@ -539,6 +544,9 @@ public: /// somewhere in the middle (or beginning) of the dest block, the source block /// must have no successors. Otherwise, the resulting IR would have /// unreachable operations. + /// + /// If the insertion point is within the source block, it is adjusted to the + /// destination block. virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues = {}); @@ -549,6 +557,9 @@ public: /// /// The source block must have no successors. Otherwise, the resulting IR /// would have unreachable operations. + /// + /// If the insertion point is within the source block, it is adjusted to the + /// destination block. void inlineBlockBefore(Block *source, Operation *op, ValueRange argValues = {}); @@ -558,6 +569,9 @@ public: /// /// The dest block must have no successors. Otherwise, the resulting IR would /// have unreachable operation. + /// + /// If the insertion point is within the source block, it is adjusted to the + /// destination block. void mergeBlocks(Block *source, Block *dest, ValueRange argValues = {}); /// Split the operations starting at "before" (inclusive) out of the given diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td index a8b04d0..bbfa308 100644 --- a/mlir/include/mlir/IR/SymbolInterfaces.td +++ b/mlir/include/mlir/IR/SymbolInterfaces.td @@ -55,19 +55,19 @@ def Symbol : OpInterface<"SymbolOpInterface"> { InterfaceMethod<"Returns true if this symbol has nested visibility.", "bool", "isNested", (ins), [{}], /*defaultImplementation=*/[{ - return getVisibility() == mlir::SymbolTable::Visibility::Nested; + return $_op.getVisibility() == mlir::SymbolTable::Visibility::Nested; }] >, InterfaceMethod<"Returns true if this symbol has private visibility.", "bool", "isPrivate", (ins), [{}], /*defaultImplementation=*/[{ - return getVisibility() == mlir::SymbolTable::Visibility::Private; + return $_op.getVisibility() == mlir::SymbolTable::Visibility::Private; }] >, InterfaceMethod<"Returns true if this symbol has public visibility.", "bool", "isPublic", (ins), [{}], /*defaultImplementation=*/[{ - return getVisibility() == mlir::SymbolTable::Visibility::Public; + return $_op.getVisibility() == mlir::SymbolTable::Visibility::Public; }] >, InterfaceMethod<"Sets the visibility of this symbol.", @@ -79,19 +79,19 @@ def Symbol : OpInterface<"SymbolOpInterface"> { InterfaceMethod<"Sets the visibility of this symbol to be nested.", "void", "setNested", (ins), [{}], /*defaultImplementation=*/[{ - setVisibility(mlir::SymbolTable::Visibility::Nested); + $_op.setVisibility(mlir::SymbolTable::Visibility::Nested); }] >, InterfaceMethod<"Sets the visibility of this symbol to be private.", "void", "setPrivate", (ins), [{}], /*defaultImplementation=*/[{ - setVisibility(mlir::SymbolTable::Visibility::Private); + $_op.setVisibility(mlir::SymbolTable::Visibility::Private); }] >, InterfaceMethod<"Sets the visibility of this symbol to be public.", "void", "setPublic", (ins), [{}], /*defaultImplementation=*/[{ - setVisibility(mlir::SymbolTable::Visibility::Public); + $_op.setVisibility(mlir::SymbolTable::Visibility::Public); }] >, InterfaceMethod<[{ @@ -144,7 +144,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> { // By default, base this on the visibility alone. A symbol can be // discarded as long as it is not public. Only public symbols may be // visible from outside of the IR. - return getVisibility() != ::mlir::SymbolTable::Visibility::Public; + return $_op.getVisibility() != ::mlir::SymbolTable::Visibility::Public; }] >, InterfaceMethod<[{ diff --git a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp index 51fa773..fb5649e 100644 --- a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp @@ -16,6 +16,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include <cassert> #define DEBUG_TYPE "constant-propagation" @@ -46,7 +47,7 @@ void ConstantValue::print(raw_ostream &os) const { LogicalResult SparseConstantPropagation::visitOperation( Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands, ArrayRef<Lattice<ConstantValue> *> results) { - LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n"); + LDBG() << "SCP: Visiting operation: " << *op; // Don't try to simulate the results of a region operation as we can't // guarantee that folding will be out-of-place. We don't allow in-place @@ -98,12 +99,11 @@ LogicalResult SparseConstantPropagation::visitOperation( // Merge in the result of the fold, either a constant or a value. OpFoldResult foldResult = std::get<1>(it); if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(foldResult)) { - LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n"); + LDBG() << "Folded to constant: " << attr; propagateIfChanged(lattice, lattice->join(ConstantValue(attr, op->getDialect()))); } else { - LLVM_DEBUG(llvm::dbgs() - << "Folded to value: " << cast<Value>(foldResult) << "\n"); + LDBG() << "Folded to value: " << cast<Value>(foldResult); AbstractSparseForwardDataFlowAnalysis::join( lattice, *getLatticeElement(cast<Value>(foldResult))); } diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp index 197f97f..509f520 100644 --- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp @@ -294,7 +294,7 @@ RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) { solver.load<LivenessAnalysis>(symbolTable); LDBG() << "Initializing and running solver"; (void)solver.initializeAndRun(op); - LDBG() << "Dumping liveness state for op"; + LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName(); } const Liveness *RunLivenessAnalysis::getLiveness(Value val) { diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp index 176d53e..16f7033 100644 --- a/mlir/lib/Analysis/DataFlowFramework.cpp +++ b/mlir/lib/Analysis/DataFlowFramework.cpp @@ -14,7 +14,7 @@ #include "llvm/ADT/iterator.h" #include "llvm/Config/abi-breaking.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #define DEBUG_TYPE "dataflow" @@ -44,9 +44,8 @@ void AnalysisState::addDependency(ProgramPoint *dependent, (void)inserted; DATAFLOW_DEBUG({ if (inserted) { - llvm::dbgs() << "Creating dependency between " << debugName << " of " - << anchor << "\nand " << debugName << " on " << dependent - << "\n"; + LDBG() << "Creating dependency between " << debugName << " of " << anchor + << "\nand " << debugName << " on " << dependent; } }); } @@ -116,8 +115,7 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) { // Initialize the analyses. for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) { - DATAFLOW_DEBUG(llvm::dbgs() - << "Priming analysis: " << analysis.debugName << "\n"); + DATAFLOW_DEBUG(LDBG() << "Priming analysis: " << analysis.debugName); if (failed(analysis.initialize(top))) return failure(); } @@ -129,8 +127,8 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) { auto [point, analysis] = worklist.front(); worklist.pop(); - DATAFLOW_DEBUG(llvm::dbgs() << "Invoking '" << analysis->debugName - << "' on: " << point << "\n"); + DATAFLOW_DEBUG(LDBG() << "Invoking '" << analysis->debugName + << "' on: " << point); if (failed(analysis->visit(point))) return failure(); } @@ -143,9 +141,9 @@ void DataFlowSolver::propagateIfChanged(AnalysisState *state, assert(isRunning && "DataFlowSolver is not running, should not use propagateIfChanged"); if (changed == ChangeResult::Change) { - DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName - << " of " << state->anchor << "\n" - << "Value: " << *state << "\n"); + DATAFLOW_DEBUG(LDBG() << "Propagating update to " << state->debugName + << " of " << state->anchor << "\n" + << "Value: " << *state); state->onUpdate(this); } } diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp index 9f4a87a..8b14e71 100644 --- a/mlir/lib/AsmParser/DialectSymbolParser.cpp +++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp @@ -89,6 +89,7 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body, nestedPunctuation.pop_back(); return success(); }; + const char *curBufferEnd = state.lex.getBufferEnd(); do { // Handle code completions, which may appear in the middle of the symbol // body. @@ -98,6 +99,12 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body, break; } + if (curBufferEnd == curPtr) { + if (!nestedPunctuation.empty()) + return emitPunctError(); + return emitError("unexpected nul or EOF in pretty dialect name"); + } + char c = *curPtr++; switch (c) { case '\0': diff --git a/mlir/lib/AsmParser/Lexer.cpp b/mlir/lib/AsmParser/Lexer.cpp index 751bd63..8f53529 100644 --- a/mlir/lib/AsmParser/Lexer.cpp +++ b/mlir/lib/AsmParser/Lexer.cpp @@ -37,6 +37,18 @@ Lexer::Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context, AsmParserCodeCompleteContext *codeCompleteContext) : sourceMgr(sourceMgr), context(context), codeCompleteLoc(nullptr) { auto bufferID = sourceMgr.getMainFileID(); + + // Check to see if the main buffer contains the last buffer, and if so the + // last buffer should be used as main file for parsing. + if (sourceMgr.getNumBuffers() > 1) { + unsigned lastFileID = sourceMgr.getNumBuffers(); + const llvm::MemoryBuffer *main = sourceMgr.getMemoryBuffer(bufferID); + const llvm::MemoryBuffer *last = sourceMgr.getMemoryBuffer(lastFileID); + if (main->getBufferStart() <= last->getBufferStart() && + main->getBufferEnd() >= last->getBufferEnd()) { + bufferID = lastFileID; + } + } curBuffer = sourceMgr.getMemoryBuffer(bufferID)->getBuffer(); curPtr = curBuffer.begin(); @@ -71,6 +83,7 @@ Token Lexer::emitError(const char *loc, const Twine &message) { } Token Lexer::lexToken() { + const char *curBufferEnd = curBuffer.end(); while (true) { const char *tokStart = curPtr; @@ -78,6 +91,9 @@ Token Lexer::lexToken() { if (tokStart == codeCompleteLoc) return formToken(Token::code_complete, tokStart); + if (tokStart == curBufferEnd) + return formToken(Token::eof, tokStart); + // Lex the next token. switch (*curPtr++) { default: @@ -102,7 +118,7 @@ Token Lexer::lexToken() { case 0: // This may either be a nul character in the source file or may be the EOF // marker that llvm::MemoryBuffer guarantees will be there. - if (curPtr - 1 == curBuffer.end()) + if (curPtr - 1 == curBufferEnd) return formToken(Token::eof, tokStart); continue; @@ -259,7 +275,11 @@ void Lexer::skipComment() { assert(*curPtr == '/'); ++curPtr; + const char *curBufferEnd = curBuffer.end(); while (true) { + if (curPtr == curBufferEnd) + return; + switch (*curPtr++) { case '\n': case '\r': @@ -267,7 +287,7 @@ void Lexer::skipComment() { return; case 0: // If this is the end of the buffer, end the comment. - if (curPtr - 1 == curBuffer.end()) { + if (curPtr - 1 == curBufferEnd) { --curPtr; return; } @@ -405,6 +425,7 @@ Token Lexer::lexPrefixedIdentifier(const char *tokStart) { Token Lexer::lexString(const char *tokStart) { assert(curPtr[-1] == '"'); + const char *curBufferEnd = curBuffer.end(); while (true) { // Check to see if there is a code completion location within the string. In // these cases we generate a completion location and place the currently @@ -419,7 +440,7 @@ Token Lexer::lexString(const char *tokStart) { case 0: // If this is a random nul character in the middle of a string, just // include it. If it is the end of file, then it is an error. - if (curPtr - 1 != curBuffer.end()) + if (curPtr - 1 != curBufferEnd) continue; [[fallthrough]]; case '\n': diff --git a/mlir/lib/AsmParser/Lexer.h b/mlir/lib/AsmParser/Lexer.h index 4085a9b..670444e 100644 --- a/mlir/lib/AsmParser/Lexer.h +++ b/mlir/lib/AsmParser/Lexer.h @@ -40,6 +40,9 @@ public: /// Returns the start of the buffer. const char *getBufferBegin() { return curBuffer.data(); } + /// Returns the end of the buffer. + const char *getBufferEnd() { return curBuffer.end(); } + /// Return the code completion location of the lexer, or nullptr if there is /// none. const char *getCodeCompleteLoc() const { return codeCompleteLoc; } diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index 75e6563..1817861 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -507,25 +507,27 @@ LogicalResult GPURotateConversion::matchAndRewrite( getTypeConverter<SPIRVTypeConverter>()->getTargetEnv(); unsigned subgroupSize = targetEnv.getAttr().getResourceLimits().getSubgroupSize(); - IntegerAttr widthAttr; - if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) || - widthAttr.getValue().getZExtValue() > subgroupSize) + unsigned width = rotateOp.getWidth(); + if (width > subgroupSize) return rewriter.notifyMatchFailure( - rotateOp, - "rotate width is not a constant or larger than target subgroup size"); + rotateOp, "rotate width is larger than target subgroup size"); Location loc = rotateOp.getLoc(); auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup); + Value offsetVal = + arith::ConstantOp::create(rewriter, loc, adaptor.getOffsetAttr()); + Value widthVal = + arith::ConstantOp::create(rewriter, loc, adaptor.getWidthAttr()); Value rotateResult = spirv::GroupNonUniformRotateKHROp::create( - rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset(), - adaptor.getWidth()); + rewriter, loc, scope, adaptor.getValue(), offsetVal, widthVal); Value validVal; - if (widthAttr.getValue().getZExtValue() == subgroupSize) { + if (width == subgroupSize) { validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter); } else { + IntegerAttr widthAttr = adaptor.getWidthAttr(); Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr); validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult, - laneId, adaptor.getWidth()); + laneId, widthVal); } rewriter.replaceOp(rotateOp, {rotateResult, validVal}); diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index e882845..6bd0e2d 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -19,10 +19,18 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeRange.h" +#include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" +#include <cstdint> using namespace mlir; +static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) { + return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() && + memRefType.getRank() != 0 && + !llvm::is_contained(memRefType.getShape(), 0); +} + namespace { /// Implement the interface to convert MemRef to EmitC. struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface { @@ -89,6 +97,68 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) { return resultTy; } +struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = allocOp.getLoc(); + MemRefType memrefType = allocOp.getType(); + if (!isMemRefTypeLegalForEmitC(memrefType)) { + return rewriter.notifyMatchFailure( + loc, "incompatible memref type for EmitC conversion"); + } + + Type sizeTType = emitc::SizeTType::get(rewriter.getContext()); + Type elementType = memrefType.getElementType(); + IndexType indexType = rewriter.getIndexType(); + emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>( + loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{}, + ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)})); + + int64_t numElements = 1; + for (int64_t dimSize : memrefType.getShape()) { + numElements *= dimSize; + } + Value numElementsValue = rewriter.create<emitc::ConstantOp>( + loc, indexType, rewriter.getIndexAttr(numElements)); + + Value totalSizeBytes = rewriter.create<emitc::MulOp>( + loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue); + + emitc::CallOpaqueOp allocCall; + StringAttr allocFunctionName; + Value alignmentValue; + SmallVector<Value, 2> argsVec; + if (allocOp.getAlignment()) { + allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName); + alignmentValue = rewriter.create<emitc::ConstantOp>( + loc, sizeTType, + rewriter.getIntegerAttr(indexType, + allocOp.getAlignment().value_or(0))); + argsVec.push_back(alignmentValue); + } else { + allocFunctionName = rewriter.getStringAttr(mallocFunctionName); + } + + argsVec.push_back(totalSizeBytes); + ValueRange args(argsVec); + + allocCall = rewriter.create<emitc::CallOpaqueOp>( + loc, + emitc::PointerType::get( + emitc::OpaqueType::get(rewriter.getContext(), "void")), + allocFunctionName, args); + + emitc::PointerType targetPointerType = emitc::PointerType::get(elementType); + emitc::CastOp castOp = rewriter.create<emitc::CastOp>( + loc, targetPointerType, allocCall.getResult(0)); + + rewriter.replaceOp(allocOp, castOp); + return success(); + } +}; + struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> { using OpConversionPattern::OpConversionPattern; @@ -223,9 +293,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> { void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { typeConverter.addConversion( [&](MemRefType memRefType) -> std::optional<Type> { - if (!memRefType.hasStaticShape() || - !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0 || - llvm::is_contained(memRefType.getShape(), 0)) { + if (!isMemRefTypeLegalForEmitC(memRefType)) { return {}; } Type convertedElementType = @@ -252,6 +320,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCConversionPatterns( RewritePatternSet &patterns, const TypeConverter &converter) { - patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad, - ConvertStore>(converter, patterns.getContext()); + patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal, + ConvertLoad, ConvertStore>(converter, patterns.getContext()); } diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp index cf25c09..e78dd76 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp @@ -15,6 +15,7 @@ #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Attributes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -28,9 +29,11 @@ using namespace mlir; namespace { struct ConvertMemRefToEmitCPass : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> { + using Base::Base; void runOnOperation() override { TypeConverter converter; - + ConvertMemRefToEmitCOptions options; + options.lowerToCpp = this->lowerToCpp; // Fallback for other types. converter.addConversion([](Type type) -> std::optional<Type> { if (!emitc::isSupportedEmitCType(type)) @@ -50,6 +53,37 @@ struct ConvertMemRefToEmitCPass if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); + + mlir::ModuleOp module = getOperation(); + module.walk([&](mlir::emitc::CallOpaqueOp callOp) { + if (callOp.getCallee() != alignedAllocFunctionName && + callOp.getCallee() != mallocFunctionName) { + return mlir::WalkResult::advance(); + } + + for (auto &op : *module.getBody()) { + emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op); + if (!includeOp) { + continue; + } + if (includeOp.getIsStandardInclude() && + ((options.lowerToCpp && + includeOp.getInclude() == cppStandardLibraryHeader) || + (!options.lowerToCpp && + includeOp.getInclude() == cStandardLibraryHeader))) { + return mlir::WalkResult::interrupt(); + } + } + + mlir::OpBuilder builder(module.getBody(), module.getBody()->begin()); + StringAttr includeAttr = + builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader + : cStandardLibraryHeader); + builder.create<mlir::emitc::IncludeOp>( + module.getLoc(), includeAttr, + /*is_standard_include=*/builder.getUnitAttr()); + return mlir::WalkResult::interrupt(); + }); } }; } // namespace diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 4307bc6..17a79e3 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1070,39 +1070,6 @@ public: } }; -class VectorExtractElementOpConversion - : public ConvertOpToLLVMPattern<vector::ExtractElementOp> { -public: - using ConvertOpToLLVMPattern< - vector::ExtractElementOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto vectorType = extractEltOp.getSourceVectorType(); - auto llvmType = typeConverter->convertType(vectorType.getElementType()); - - // Bail if result type cannot be lowered. - if (!llvmType) - return failure(); - - if (vectorType.getRank() == 0) { - Location loc = extractEltOp.getLoc(); - auto idxType = rewriter.getIndexType(); - auto zero = LLVM::ConstantOp::create(rewriter, loc, - typeConverter->convertType(idxType), - rewriter.getIntegerAttr(idxType, 0)); - rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( - extractEltOp, llvmType, adaptor.getVector(), zero); - return success(); - } - - rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( - extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition()); - return success(); - } -}; - class VectorExtractOpConversion : public ConvertOpToLLVMPattern<vector::ExtractOp> { public: @@ -1206,39 +1173,6 @@ public: } }; -class VectorInsertElementOpConversion - : public ConvertOpToLLVMPattern<vector::InsertElementOp> { -public: - using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto vectorType = insertEltOp.getDestVectorType(); - auto llvmType = typeConverter->convertType(vectorType); - - // Bail if result type cannot be lowered. - if (!llvmType) - return failure(); - - if (vectorType.getRank() == 0) { - Location loc = insertEltOp.getLoc(); - auto idxType = rewriter.getIndexType(); - auto zero = LLVM::ConstantOp::create(rewriter, loc, - typeConverter->convertType(idxType), - rewriter.getIntegerAttr(idxType, 0)); - rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( - insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero); - return success(); - } - - rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( - insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), - adaptor.getPosition()); - return success(); - } -}; - class VectorInsertOpConversion : public ConvertOpToLLVMPattern<vector::InsertOp> { public: @@ -2244,8 +2178,7 @@ void mlir::populateVectorToLLVMConversionPatterns( VectorGatherOpConversion, VectorScatterOpConversion>( converter, useVectorAlignment); patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion, - VectorExtractElementOpConversion, VectorExtractOpConversion, - VectorFMAOp1DConversion, VectorInsertElementOpConversion, + VectorExtractOpConversion, VectorFMAOp1DConversion, VectorInsertOpConversion, VectorPrintOpConversion, VectorTypeCastOpConversion, VectorScaleOpConversion, VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index b1af5f0..508f4e2 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -690,7 +690,7 @@ struct PrepareTransferWriteConversion /// %lastIndex = arith.subi %length, %c1 : index /// vector.print punctuation <open> /// scf.for %i = %c0 to %length step %c1 { -/// %el = vector.extractelement %v[%i : index] : vector<[4]xi32> +/// %el = vector.extract %v[%i] : i32 from vector<[4]xi32> /// vector.print %el : i32 punctuation <no_punctuation> /// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index /// scf.if %notLastIndex { @@ -1643,7 +1643,7 @@ struct Strategy1d<TransferWriteOp> { /// Is rewritten to approximately the following pseudo-IR: /// ``` /// for i = 0 to 9 { -/// %t = vector.extractelement %vec[i] : vector<9xf32> +/// %t = vector.extract %vec[i] : f32 from vector<9xf32> /// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32> /// } /// ``` diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 986eae3..a4be7d4 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -335,63 +335,6 @@ struct VectorInsertOpConvert final } }; -struct VectorExtractElementOpConvert final - : public OpConversionPattern<vector::ExtractElementOp> { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type resultType = getTypeConverter()->convertType(extractOp.getType()); - if (!resultType) - return failure(); - - if (isa<spirv::ScalarType>(adaptor.getVector().getType())) { - rewriter.replaceOp(extractOp, adaptor.getVector()); - return success(); - } - - APInt cstPos; - if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos))) - rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( - extractOp, resultType, adaptor.getVector(), - rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())})); - else - rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>( - extractOp, resultType, adaptor.getVector(), adaptor.getPosition()); - return success(); - } -}; - -struct VectorInsertElementOpConvert final - : public OpConversionPattern<vector::InsertElementOp> { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type vectorType = getTypeConverter()->convertType(insertOp.getType()); - if (!vectorType) - return failure(); - - if (isa<spirv::ScalarType>(vectorType)) { - rewriter.replaceOp(insertOp, adaptor.getSource()); - return success(); - } - - APInt cstPos; - if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos))) - rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( - insertOp, adaptor.getSource(), adaptor.getDest(), - cstPos.getSExtValue()); - else - rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>( - insertOp, vectorType, insertOp.getDest(), adaptor.getSource(), - adaptor.getPosition()); - return success(); - } -}; - struct VectorInsertStridedSliceOpConvert final : public OpConversionPattern<vector::InsertStridedSliceOp> { using OpConversionPattern::OpConversionPattern; @@ -1107,12 +1050,11 @@ struct VectorToElementOpConvert final void mlir::populateVectorToSPIRVPatterns( const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add< - VectorBitcastConvert, VectorBroadcastConvert, - VectorExtractElementOpConvert, VectorExtractOpConvert, + VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert, VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>, VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert, - VectorToElementOpConvert, VectorInsertElementOpConvert, - VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>, + VectorToElementOpConvert, VectorInsertOpConvert, + VectorReductionPattern<GL_INT_MAX_MIN_OPS>, VectorReductionPattern<CL_INT_MAX_MIN_OPS>, VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>, VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast, diff --git a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp index 3c00b32..6265f46 100644 --- a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp @@ -15,13 +15,13 @@ #include "mlir/Dialect/Affine/Transforms/Transforms.h" #include "mlir/IR/PatternMatch.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" +#include "llvm/Support/InterleavedRange.h" using namespace mlir; using namespace mlir::affine; #define DEBUG_TYPE "decompose-affine-ops" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") /// Count the number of loops surrounding `operand` such that operand could be /// hoisted above. @@ -115,7 +115,7 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter, return rewriter.notifyMatchFailure( op, "only add or mul binary expr can be reassociated"); - LLVM_DEBUG(DBGS() << "Start decomposeIntoFinerGrainedOps: " << op << "\n"); + LDBG() << "Start decomposeIntoFinerGrainedOps: " << op; // 2. Iteratively extract the RHS subexpressions while the top-level binary // expr kind remains the same. @@ -125,11 +125,11 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter, auto currentBinExpr = dyn_cast<AffineBinaryOpExpr>(remainingExp); if (!currentBinExpr || currentBinExpr.getKind() != binExpr.getKind()) { subExpressions.push_back(remainingExp); - LLVM_DEBUG(DBGS() << "--terminal: " << subExpressions.back() << "\n"); + LDBG() << "--terminal: " << subExpressions.back(); break; } subExpressions.push_back(currentBinExpr.getRHS()); - LLVM_DEBUG(DBGS() << "--subExpr: " << subExpressions.back() << "\n"); + LDBG() << "--subExpr: " << subExpressions.back(); remainingExp = currentBinExpr.getLHS(); } @@ -146,9 +146,7 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter, llvm::stable_sort(subExpressions, [&](AffineExpr e1, AffineExpr e2) { return getMaxSymbol(e1) < getMaxSymbol(e2); }); - LLVM_DEBUG( - llvm::interleaveComma(subExpressions, DBGS() << "--sorted subexprs: "); - llvm::dbgs() << "\n"); + LDBG() << "--sorted subexprs: " << llvm::interleaved(subExpressions); // 4. Merge sorted subExpressions iteratively, thus achieving reassociation. auto s0 = getAffineSymbolExpr(0, ctx); @@ -162,7 +160,7 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter, Value tmp = createSubApply(rewriter, op, subExpressions[i]); current = AffineApplyOp::create(rewriter, op.getLoc(), binMap, ValueRange{current, tmp}); - LLVM_DEBUG(DBGS() << "--reassociate into: " << current << "\n"); + LDBG() << "--reassociate into: " << current; } // 5. Replace original op. diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp index 8493b60..2521512 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp @@ -19,11 +19,10 @@ #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/IntEqClasses.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/InterleavedRange.h" #define DEBUG_TYPE "affine-min-max" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") using namespace mlir; using namespace mlir::affine; @@ -39,7 +38,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { ValueRange operands = affineOp.getOperands(); static constexpr bool isMin = std::is_same_v<AffineOp, AffineMinOp>; - LLVM_DEBUG({ DBGS() << "analyzing value: `" << affineOp << "`\n"; }); + LDBG() << "analyzing value: `" << affineOp; // Create a `Variable` list with values corresponding to each of the results // in the affine affineMap. @@ -48,12 +47,9 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { [&](unsigned i) { return Variable(affineMap.getSliceMap(i, 1), operands); }); - LLVM_DEBUG({ - DBGS() << "- constructed variables are: " - << llvm::interleaved_array(llvm::map_range( - variables, [](const Variable &v) { return v.getMap(); })) - << "`\n"; - }); + LDBG() << "- constructed variables are: " + << llvm::interleaved_array(llvm::map_range( + variables, [](const Variable &v) { return v.getMap(); })); // Get the comparison operation. ComparisonOperator cmpOp = @@ -72,10 +68,8 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { // Initialize the bound. Variable *bound = &v; - LLVM_DEBUG({ - DBGS() << "- inspecting variable: #" << i << ", with map: `" << v.getMap() - << "`\n"; - }); + LDBG() << "- inspecting variable: #" << i << ", with map: `" << v.getMap() + << "`\n"; // Check against the other variables. for (size_t j = i + 1; j < variables.size(); ++j) { @@ -87,10 +81,8 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { // Get the bound of the equivalence class or itself. Variable *nv = bounds.lookup_or(jEqClass, &variables[j]); - LLVM_DEBUG({ - DBGS() << "- comparing with variable: #" << jEqClass - << ", with map: " << nv->getMap() << "\n"; - }); + LDBG() << "- comparing with variable: #" << jEqClass + << ", with map: " << nv->getMap(); // Compare the variables. FailureOr<bool> cmpResult = @@ -98,18 +90,14 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { // The variables cannot be compared. if (failed(cmpResult)) { - LLVM_DEBUG({ - DBGS() << "-- classes: #" << i << ", #" << jEqClass - << " cannot be merged\n"; - }); + LDBG() << "-- classes: #" << i << ", #" << jEqClass + << " cannot be merged"; continue; } // Join the equivalent classes and update the bound if necessary. - LLVM_DEBUG({ - DBGS() << "-- merging classes: #" << i << ", #" << jEqClass - << ", is cmp(lhs, rhs): " << *cmpResult << "`\n"; - }); + LDBG() << "-- merging classes: #" << i << ", #" << jEqClass + << ", is cmp(lhs, rhs): " << *cmpResult << "`"; if (*cmpResult) { boundedClasses.join(eqClass, jEqClass); } else { @@ -124,8 +112,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { // Return if there's no simplification. if (bounds.size() >= affineMap.getNumResults()) { - LLVM_DEBUG( - { DBGS() << "- the affine operation couldn't get simplified\n"; }); + LDBG() << "- the affine operation couldn't get simplified"; return false; } @@ -135,13 +122,11 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { for (auto [k, bound] : bounds) results.push_back(bound->getMap().getResult(0)); - LLVM_DEBUG({ - DBGS() << "- starting from map: " << affineMap << "\n"; - DBGS() << "- creating new map with: \n"; - DBGS() << "--- dims: " << affineMap.getNumDims() << "\n"; - DBGS() << "--- syms: " << affineMap.getNumSymbols() << "\n"; - DBGS() << "--- res: " << llvm::interleaved_array(results) << "\n"; - }); + LDBG() << "- starting from map: " << affineMap; + LDBG() << "- creating new map with:"; + LDBG() << "--- dims: " << affineMap.getNumDims(); + LDBG() << "--- syms: " << affineMap.getNumSymbols(); + LDBG() << "--- res: " << llvm::interleaved_array(results); affineMap = AffineMap::get(0, affineMap.getNumSymbols() + affineMap.getNumDims(), @@ -149,7 +134,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { // Update the affine op. rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); }); - LLVM_DEBUG({ DBGS() << "- simplified affine op: `" << affineOp << "`\n"; }); + LDBG() << "- simplified affine op: `" << affineOp << "`"; return true; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index d1d1062..aa53f94 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -1,4 +1,5 @@ -//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===// +//===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries +//----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -8,12 +9,13 @@ // // Module Bufferization is an extension of One-Shot Bufferize that // bufferizes function boundaries. It provides `BufferizableOpInterface` -// implementations for FuncOp, CallOp and ReturnOp. +// implementations for FuncOp, CallOp and ReturnOp. Although it is named +// Module Bufferization, it may operate on any SymbolTable. // -// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`. -// This function analyzes the given module and determines the order of analysis -// and bufferization: Functions that are called are processed before their -// respective callers. +// Module Bufferization is run via `runOneShotModuleBufferize(SymbolTableOp, +// ...)`. This function analyzes the given op and determines the order of +// analysis and bufferization: Functions that are called are processed before +// their respective callers. // // After analyzing a FuncOp, additional information about its bbArgs is // gathered and stored in `FuncAnalysisState`. @@ -309,7 +311,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) { /// Return `failure()` if we are unable to retrieve the called FuncOp from /// any func::CallOp. static LogicalResult getFuncOpsOrderedByCalls( - ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps, + Operation *moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps, SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap, SymbolTableCollection &symbolTables) { // For each FuncOp, the set of functions called by it (i.e. the union of @@ -317,26 +319,29 @@ static LogicalResult getFuncOpsOrderedByCalls( DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy; // For each FuncOp, the number of func::CallOp it contains. DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp; - - for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) { - // Collect function calls and populate the caller map. - numberCallOpsContainedInFuncOp[funcOp] = 0; - WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult { - func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables); - assert(calledFunction && "could not retrieved called func::FuncOp"); - // If the called function does not have any tensors in its signature, then - // it is not necessary to bufferize the callee before the caller. - if (!hasTensorSignature(calledFunction)) - return WalkResult::skip(); - - callerMap[calledFunction].insert(callOp); - if (calledBy[calledFunction].insert(funcOp).second) { - numberCallOpsContainedInFuncOp[funcOp]++; + for (mlir::Region ®ion : moduleOp->getRegions()) { + for (mlir::Block &block : region.getBlocks()) { + for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) { + // Collect function calls and populate the caller map. + numberCallOpsContainedInFuncOp[funcOp] = 0; + WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult { + func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables); + assert(calledFunction && "could not retrieved called func::FuncOp"); + // If the called function does not have any tensors in its signature, + // then it is not necessary to bufferize the callee before the caller. + if (!hasTensorSignature(calledFunction)) + return WalkResult::skip(); + + callerMap[calledFunction].insert(callOp); + if (calledBy[calledFunction].insert(funcOp).second) { + numberCallOpsContainedInFuncOp[funcOp]++; + } + return WalkResult::advance(); + }); + if (res.wasInterrupted()) + return failure(); } - return WalkResult::advance(); - }); - if (res.wasInterrupted()) - return failure(); + } } // Iteratively remove function operations that do not call any of the @@ -447,7 +452,7 @@ static void foldMemRefCasts(func::FuncOp funcOp) { } LogicalResult -mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, +mlir::bufferization::analyzeModuleOp(Operation *moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics) { assert(state.getOptions().bufferizeFunctionBoundaries && @@ -512,19 +517,23 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, } void mlir::bufferization::removeBufferizationAttributesInModule( - ModuleOp moduleOp) { - for (auto op : moduleOp.getOps<func::FuncOp>()) { - for (BlockArgument bbArg : op.getArguments()) - removeBufferizationAttributes(bbArg); + Operation *moduleOp) { + for (mlir::Region ®ion : moduleOp->getRegions()) { + for (mlir::Block &block : region.getBlocks()) { + for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) { + for (BlockArgument bbArg : funcOp.getArguments()) + removeBufferizationAttributes(bbArg); + } + } } } LogicalResult mlir::bufferization::bufferizeModuleOp( - ModuleOp moduleOp, const OneShotBufferizationOptions &options, + Operation *moduleOp, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics) { assert(options.bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); - IRRewriter rewriter(moduleOp.getContext()); + IRRewriter rewriter(moduleOp->getContext()); // A list of non-circular functions in the order in which they are analyzed // and bufferized. @@ -571,12 +580,17 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( } // Bufferize all other ops. - for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) { - // Functions were already bufferized. - if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>()) - continue; - if (failed(bufferizeOp(&op, options, state, statistics))) - return failure(); + for (mlir::Region ®ion : moduleOp->getRegions()) { + for (mlir::Block &block : region.getBlocks()) { + for (mlir::Operation &op : + llvm::make_early_inc_range(block.getOperations())) { + // Functions were already bufferized. + if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>()) + continue; + if (failed(bufferizeOp(&op, options, state, statistics))) + return failure(); + } + } } // Post-pass cleanup of function argument attributes. @@ -586,7 +600,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( } LogicalResult mlir::bufferization::runOneShotModuleBufferize( - ModuleOp moduleOp, const OneShotBufferizationOptions &options, + Operation *moduleOp, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics) { assert(options.bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp index f999c93..a6159ee 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp @@ -33,7 +33,7 @@ LogicalResult mlir::bufferization::insertTensorCopies( // analysis depending on whether function boundary bufferization is enabled or // not. if (options.bufferizeFunctionBoundaries) { - if (failed(analyzeModuleOp(cast<ModuleOp>(op), analysisState, statistics))) + if (failed(analyzeModuleOp(op, analysisState, statistics))) return failure(); } else { if (failed(analyzeOp(op, analysisState, statistics))) diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index d186a48..5a72ef1 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -1395,40 +1395,12 @@ void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value, // RotateOp //===----------------------------------------------------------------------===// -void RotateOp::build(OpBuilder &builder, OperationState &result, Value value, - int32_t offset, int32_t width) { - build(builder, result, value, - arith::ConstantOp::create(builder, result.location, - builder.getI32IntegerAttr(offset)), - arith::ConstantOp::create(builder, result.location, - builder.getI32IntegerAttr(width))); -} - LogicalResult RotateOp::verify() { - auto offsetConstOp = getOffset().getDefiningOp<arith::ConstantOp>(); - if (!offsetConstOp) - return emitOpError() << "offset is not a constant value"; - - auto offsetIntAttr = - llvm::dyn_cast<mlir::IntegerAttr>(offsetConstOp.getValue()); - - auto widthConstOp = getWidth().getDefiningOp<arith::ConstantOp>(); - if (!widthConstOp) - return emitOpError() << "width is not a constant value"; - - auto widthIntAttr = - llvm::dyn_cast<mlir::IntegerAttr>(widthConstOp.getValue()); - - llvm::APInt offsetValue = offsetIntAttr.getValue(); - llvm::APInt widthValue = widthIntAttr.getValue(); - - if (!widthValue.isPowerOf2()) - return emitOpError() << "width must be a power of two"; + uint32_t offset = getOffset(); + uint32_t width = getWidth(); - if (offsetValue.sge(widthValue) || offsetValue.slt(0)) { - int64_t widthValueInt = widthValue.getSExtValue(); - return emitOpError() << "offset must be in the range [0, " << widthValueInt - << ")"; + if (offset >= width) { + return emitOpError() << "offset must be in the range [0, " << width << ")"; } return success(); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 27b6617..b56a212 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -32,6 +32,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" @@ -4622,22 +4623,6 @@ static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos, }); } -/// Returns true if the dimension of `sourceShape` is smaller than the dimension -/// of the `limitShape`. -static bool areAllInBound(ArrayRef<int64_t> sourceShape, - ArrayRef<int64_t> limitShape) { - assert( - sourceShape.size() == limitShape.size() && - "expected source shape rank, and limit of the shape to have same rank"); - return llvm::all_of( - llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) { - int64_t sourceExtent = std::get<0>(it); - int64_t limit = std::get<1>(it); - return ShapedType::isDynamic(sourceExtent) || - ShapedType::isDynamic(limit) || sourceExtent <= limit; - }); -} - template <typename OpTy> static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value, @@ -4696,11 +4681,6 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { // represents full tiles. RankedTensorType expectedPackedType = PackOp::inferPackedType( unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm); - if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) { - return op->emitError("the shape of output is not large enough to hold the " - "packed data. Expected at least ") - << expectedPackedType << ", got " << packedType; - } if (!llvm::all_of( llvm::zip(packedType.getShape().take_back(mixedTiles.size()), mixedTiles), @@ -4717,6 +4697,12 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { return op->emitError("mismatch in inner tile sizes specified and shaped of " "tiled dimension in the packed type"); } + if (failed(verifyCompatibleShape(expectedPackedType.getShape(), + packedType.getShape()))) { + return op->emitError("expected ") + << expectedPackedType << " for the packed domain value, got " + << packedType; + } return success(); } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp b/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp index c926dfb..5c8c2de 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/DebugLog.h" #include "llvm/Support/InterleavedRange.h" #include "llvm/Support/MathExtras.h" @@ -21,7 +22,6 @@ using namespace mlir; #define DEBUG_TYPE "linalg-transforms" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") static Attribute linearId0(MLIRContext *ctx) { return gpu::GPUThreadMappingAttr::get(ctx, gpu::MappingId::LinearDim0); @@ -81,7 +81,7 @@ transform::gpu::CopyMappingInfo::CopyMappingInfo(MLIRContext *ctx, this->threadMapping = llvm::to_vector(ArrayRef(allThreadMappings) .take_back(this->smallestBoundingTileSizes.size())); - LLVM_DEBUG(this->print(DBGS()); llvm::dbgs() << "\n"); + LDBG() << *this; } int64_t transform::gpu::CopyMappingInfo::maxContiguousElementsToTransfer( diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp index 2fe72a3..d4a3e5f 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp @@ -15,14 +15,13 @@ #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/InterleavedRange.h" using namespace mlir; #define DEBUG_TYPE "linalg-transforms" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") //===----------------------------------------------------------------------===// // StructuredMatchOp @@ -39,7 +38,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation( return emitSilenceableError() << "expected a Linalg op"; } // If errors are suppressed, succeed and set all results to empty lists. - LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op"); + LDBG() << "optional nested matcher expected a Linalg op"; results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation())); return DiagnosedSilenceableFailure::success(); } @@ -75,8 +74,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation( // When they are defined in this block, we additionally check if we have // already applied the operation that defines them. If not, the // corresponding results will be set to empty lists. - LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage() - << "\n"); + LDBG() << "optional nested matcher failed: " << diag.getMessage(); (void)diag.silence(); SmallVector<OpOperand *> undefinedOperands; for (OpOperand &terminatorOperand : diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp index 277e50b..9d7f4e0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/PatternMatch.h" namespace mlir { diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index dad3526..57b610b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -932,20 +932,6 @@ struct PackOpTiling continue; } - // If the dimension needs padding, it is not supported because there are - // iterations that only write padding values to the whole tile. The - // consumer fusion is driven by the source, so it is not possible to map - // an empty slice to the tile. - bool needExtraPadding = - ShapedType::isDynamic(destDimSize) || !cstInnerSize || - destDimSize * cstInnerSize.value() != srcDimSize; - // Prioritize the case that the op already says that it does not need - // padding. - if (!packOp.getPaddingValue()) - needExtraPadding = false; - if (needExtraPadding) - return failure(); - // Currently fusing `packOp` as consumer only expects perfect tiling // scenario because even if without padding semantic, the `packOp` may // also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 0170837..793eec7 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1913,14 +1913,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, readVectorSizes.append(sourceShape.begin() + vectorSizes.size(), sourceShape.end()); - ReifiedRankedShapedTypeDims reifiedRetShapes; - LogicalResult status = - cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation()) - .reifyResultShapes(rewriter, reifiedRetShapes); - if (status.failed()) { - LDBG() << "Unable to reify result shapes of " << unpackOp; - return failure(); - } Location loc = unpackOp->getLoc(); auto padValue = arith::ConstantOp::create( diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp index 106c3b4..cce80db 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp @@ -80,10 +80,6 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> { for (auto &&[opOffset, sourceOffset, sourceStride, opSize] : llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(), sourceOp.getMixedStrides(), op.getMixedSizes())) { - // We only support static sizes. - if (isa<Value>(opSize)) { - return failure(); - } sizes.push_back(opSize); Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset), sourceOffsetAttr = diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp index 0e96b59..869d27a 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp @@ -115,8 +115,7 @@ public: bufferization::BufferizationState bufferizationState; - if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()), - updatedOptions, + if (failed(bufferization::bufferizeModuleOp(getOperation(), updatedOptions, bufferizationState))) return failure(); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index ecd93ff..3cafb19 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -3647,6 +3647,22 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() { return std::nullopt; } +static void printInitializationList(OpAsmPrinter &parser, + Block::BlockArgListType blocksArgs, + ValueRange initializers, + StringRef prefix = "") { + assert(blocksArgs.size() == initializers.size() && + "expected same length of arguments and initializers"); + if (initializers.empty()) + return; + + parser << prefix << '('; + llvm::interleaveComma( + llvm::zip(blocksArgs, initializers), parser, + [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); }); + parser << ")"; +} + // parse and print of IfOp refer to the implementation of SCF dialect. ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { // Create the regions for 'then'. @@ -3654,16 +3670,64 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { Region *thenRegion = result.addRegion(); Region *elseRegion = result.addRegion(); - auto &builder = parser.getBuilder(); OpAsmParser::UnresolvedOperand cond; - // Create a i1 tensor type for the boolean condition. - Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1)); - if (parser.parseOperand(cond) || - parser.resolveOperand(cond, i1Type, result.operands)) + + if (parser.parseOperand(cond)) return failure(); - // Parse optional results type list. - if (parser.parseOptionalArrowTypeList(result.types)) + + SmallVector<OpAsmParser::Argument, 4> regionArgs; + SmallVector<OpAsmParser::UnresolvedOperand, 4> operands; + + // Parse the optional block arguments + OptionalParseResult listResult = + parser.parseOptionalAssignmentList(regionArgs, operands); + if (listResult.has_value() && failed(listResult.value())) return failure(); + + // Parse a colon. + if (failed(parser.parseColon())) + return parser.emitError(parser.getCurrentLocation(), + "expected type for condition operand"); + + // Parse the type of the condition operand + Type condType; + if (failed(parser.parseType(condType))) + return parser.emitError(parser.getCurrentLocation(), + "expected type for condition operand"); + + // Resolve operand with provided type + if (failed(parser.resolveOperand(cond, condType, result.operands))) + return failure(); + + // Parse optional block arg types + if (listResult.has_value()) { + FunctionType functionType; + + if (failed(parser.parseType(functionType))) + return parser.emitError(parser.getCurrentLocation()) + << "expected list of types for block arguments " + << "followed by arrow type and list of return types"; + + result.addTypes(functionType.getResults()); + + if (functionType.getNumInputs() != operands.size()) { + return parser.emitError(parser.getCurrentLocation()) + << "expected as many input types as operands " + << "(expected " << operands.size() << " got " + << functionType.getNumInputs() << ")"; + } + + // Resolve input operands. + if (failed(parser.resolveOperands(operands, functionType.getInputs(), + parser.getCurrentLocation(), + result.operands))) + return failure(); + } else { + // Parse optional results type list. + if (parser.parseOptionalArrowTypeList(result.types)) + return failure(); + } + // Parse the 'then' region. if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{})) return failure(); @@ -3681,26 +3745,28 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { } void IfOp::print(OpAsmPrinter &p) { - bool printBlockTerminators = false; - p << " " << getCondition(); - if (!getResults().empty()) { - p << " -> (" << getResultTypes() << ")"; - // Print yield explicitly if the op defines values. - printBlockTerminators = true; + + printInitializationList(p, getThenGraph().front().getArguments(), + getInputList(), " "); + p << " : "; + p << getCondition().getType(); + + if (!getInputList().empty()) { + p << " ("; + llvm::interleaveComma(getInputList().getTypes(), p); + p << ")"; } - p << ' '; - p.printRegion(getThenGraph(), - /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/printBlockTerminators); + p.printArrowTypeList(getResultTypes()); + p << " "; + + p.printRegion(getThenGraph()); // Print the 'else' regions if it exists and has a block. auto &elseRegion = getElseGraph(); if (!elseRegion.empty()) { p << " else "; - p.printRegion(elseRegion, - /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/printBlockTerminators); + p.printRegion(elseRegion); } p.printOptionalAttrDict((*this)->getAttrs()); @@ -3909,22 +3975,6 @@ ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) { parser.parseOptionalAttrDictWithKeyword(result.attributes)); } -static void printInitializationList(OpAsmPrinter &parser, - Block::BlockArgListType blocksArgs, - ValueRange initializers, - StringRef prefix = "") { - assert(blocksArgs.size() == initializers.size() && - "expected same length of arguments and initializers"); - if (initializers.empty()) - return; - - parser << prefix << '('; - llvm::interleaveComma( - llvm::zip(blocksArgs, initializers), parser, - [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); }); - parser << ")"; -} - void WhileOp::print(OpAsmPrinter &parser) { printInitializationList(parser, getCondGraph().front().getArguments(), getInputList(), " "); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 32b5fb6..8ec7765 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -1248,16 +1248,14 @@ bool checkErrorIfCondIf(Operation *op) { // }) // // Simplified: - // %0 = tosa.cond_if %arg2 { - // tosa.yield %arg0 + // %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) { + // ^bb0(%arg3, %arg4): + // tosa.yield %arg3 // } else { - // tosa.yield %arg1 + // ^bb0(%arg3, %arg4): + // tosa.yield %arg4 // } - // - // Unfortunately, the simplified syntax does not encapsulate values - // used in then/else regions (see 'simplified' example above), so it - // must be rewritten to use the generic syntax in order to be conformant - // to the specification. + return failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) || failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else")); } diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp index c0d20d4..14a4fdf 100644 --- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp @@ -21,17 +21,8 @@ #include "llvm/Support/InterleavedRange.h" #define DEBUG_TYPE "transform-dialect" -#define DEBUG_TYPE_FULL "transform-dialect-full" #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") -#ifndef NDEBUG -#define FULL_LDBG(X) \ - DEBUGLOG_WITH_STREAM_AND_TYPE(llvm::dbgs(), DEBUG_TYPE_FULL) -#else -#define FULL_LDBG(X) \ - for (bool _c = false; _c; _c = false) \ - ::llvm::nulls() -#endif +#define FULL_LDBG() LDBG(4) using namespace mlir; @@ -818,16 +809,14 @@ void transform::TransformState::compactOpHandles() { DiagnosedSilenceableFailure transform::TransformState::applyTransform(TransformOpInterface transform) { - LLVM_DEBUG({ - DBGS() << "applying: "; - transform->print(llvm::dbgs(), OpPrintingFlags().skipRegions()); - llvm::dbgs() << "\n"; - }); + LDBG() << "applying: " + << OpWithFlags(transform, OpPrintingFlags().skipRegions()); FULL_LDBG() << "Top-level payload before application:\n" << *getTopLevel(); auto printOnFailureRAII = llvm::make_scope_exit([this] { (void)this; - LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print( - llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm());); + LDBG() << "Failing Top-level payload:\n" + << OpWithFlags(getTopLevel(), + OpPrintingFlags().printGenericOpForm()); }); // Set current transform op. @@ -995,8 +984,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { printOnFailureRAII.release(); DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, { - DBGS() << "Top-level payload:\n"; - getTopLevel()->print(llvm::dbgs()); + LDBG() << "Top-level payload:\n" << *getTopLevel(); }); return result; } @@ -1273,7 +1261,7 @@ void transform::TrackingListener::notifyMatchFailure( LLVM_DEBUG({ Diagnostic diag(loc, DiagnosticSeverity::Remark); reasonCallback(diag); - DBGS() << "Match Failure : " << diag.str(); + LDBG() << "Match Failure : " << diag.str(); }); } diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index bce358d..8789f55 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1258,63 +1258,6 @@ void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results, CanonicalizeContractAdd<arith::AddFOp>>(context); } -//===----------------------------------------------------------------------===// -// ExtractElementOp -//===----------------------------------------------------------------------===// - -void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, - SetIntRangeFn setResultRanges) { - setResultRanges(getResult(), argRanges.front()); -} - -void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result, - Value source) { - result.addOperands({source}); - result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType()); -} - -LogicalResult vector::ExtractElementOp::verify() { - VectorType vectorType = getSourceVectorType(); - if (vectorType.getRank() == 0) { - if (getPosition()) - return emitOpError("expected position to be empty with 0-D vector"); - return success(); - } - if (vectorType.getRank() != 1) - return emitOpError("unexpected >1 vector rank"); - if (!getPosition()) - return emitOpError("expected position for 1-D vector"); - return success(); -} - -OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) { - // Skip the 0-D vector here now. - if (!adaptor.getPosition()) - return {}; - - // Fold extractelement (splat X) -> X. - if (auto splat = getVector().getDefiningOp<vector::SplatOp>()) - return splat.getInput(); - - // Fold extractelement(broadcast(X)) -> X. - if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>()) - if (!llvm::isa<VectorType>(broadcast.getSource().getType())) - return broadcast.getSource(); - - auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector()); - auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition()); - if (!pos || !src) - return {}; - - auto srcElements = src.getValues<Attribute>(); - - uint64_t posIdx = pos.getInt(); - if (posIdx >= srcElements.size()) - return {}; - - return srcElements[posIdx]; -} - // Returns `true` if `index` is either within [0, maxIndex) or equal to // `poisonValue`. static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, @@ -3184,60 +3127,6 @@ void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results, } //===----------------------------------------------------------------------===// -// InsertElementOp -//===----------------------------------------------------------------------===// - -void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, - SetIntRangeFn setResultRanges) { - setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1])); -} - -void InsertElementOp::build(OpBuilder &builder, OperationState &result, - Value source, Value dest) { - build(builder, result, source, dest, {}); -} - -LogicalResult InsertElementOp::verify() { - auto dstVectorType = getDestVectorType(); - if (dstVectorType.getRank() == 0) { - if (getPosition()) - return emitOpError("expected position to be empty with 0-D vector"); - return success(); - } - if (dstVectorType.getRank() != 1) - return emitOpError("unexpected >1 vector rank"); - if (!getPosition()) - return emitOpError("expected position for 1-D vector"); - return success(); -} - -OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) { - // Skip the 0-D vector here. - if (!adaptor.getPosition()) - return {}; - - auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource()); - auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest()); - auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition()); - if (!src || !dst || !pos) - return {}; - - if (src.getType() != getDestVectorType().getElementType()) - return {}; - - auto dstElements = dst.getValues<Attribute>(); - - SmallVector<Attribute> results(dstElements); - - uint64_t posIdx = pos.getInt(); - if (posIdx >= results.size()) - return {}; - results[posIdx] = src; - - return DenseElementsAttr::get(getDestVectorType(), results); -} - -//===----------------------------------------------------------------------===// // InsertOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 5c98417..9332f55 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -156,6 +156,11 @@ void RewriterBase::eraseOp(Operation *op) { assert(op->use_empty() && "expected 'op' to have no uses"); auto *rewriteListener = dyn_cast_if_present<Listener>(listener); + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + // Fast path: If no listener is attached, the op can be dropped in one go. if (!rewriteListener) { op->erase(); @@ -320,6 +325,11 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest, moveOpBefore(&source->front(), dest, before); } + // If the current insertion point is within the source block, adjust the + // insertion point to the destination block. + if (getInsertionBlock() == source) + setInsertionPoint(dest, getInsertionPoint()); + // Erase the source block. assert(source->empty() && "expected 'source' to be empty"); eraseBlock(source); diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index 0db9808..7094c8e 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -901,7 +901,7 @@ LogicalResult PassManager::run(Operation *op) { if (failed(initialize(context, impl->initializationGeneration + 1))) return failure(); initializationKey = newInitKey; - pipelineKey = pipelineInitializationKey; + pipelineInitializationKey = pipelineKey; } // Construct a top level analysis manager for the pipeline. diff --git a/mlir/lib/Support/TypeID.cpp b/mlir/lib/Support/TypeID.cpp index 01ad910..304253c 100644 --- a/mlir/lib/Support/TypeID.cpp +++ b/mlir/lib/Support/TypeID.cpp @@ -27,9 +27,6 @@ namespace { struct ImplicitTypeIDRegistry { /// Lookup or insert a TypeID for the given type name. TypeID lookupOrInsert(StringRef typeName) { - LLVM_DEBUG(llvm::dbgs() << "ImplicitTypeIDRegistry::lookupOrInsert(" - << typeName << ")\n"); - // Perform a heuristic check to see if this type is in an anonymous // namespace. String equality is not valid for anonymous types, so we try to // abort whenever we see them. diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 58e5353..a8a2b2e 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -446,6 +446,19 @@ LogicalResult Serializer::processType(Location loc, Type type, LogicalResult Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, SetVector<StringRef> &serializationCtx) { + + // Map unsigned integer types to singless integer types. + // This is needed otherwise the generated spirv assembly will contain + // twice a type declaration (like OpTypeInt 32 0) which is no permitted and + // such module fails validation. Indeed at MLIR level the two types are + // different and lookup in the cache below misses. + // Note: This conversion needs to happen here before the type is looked up in + // the cache. + if (type.isUnsignedInteger()) { + type = IntegerType::get(loc->getContext(), type.getIntOrFloatBitWidth(), + IntegerType::SignednessSemantics::Signless); + } + typeID = getTypeID(type); if (typeID) return success(); diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 5650de2..4ccb83f 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -60,7 +60,6 @@ #include <vector> #define DEBUG_TYPE "remove-dead-values" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") namespace mlir { #define GEN_PASS_DEF_REMOVEDEADVALUES diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 7502dc6..08803e0 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/Operation.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Rewrite/PatternApplicator.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" @@ -508,9 +509,11 @@ private: class MoveBlockRewrite : public BlockRewrite { public: MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block, - Region *region, Block *insertBeforeBlock) - : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region), - insertBeforeBlock(insertBeforeBlock) {} + Region *previousRegion, Region::iterator previousIt) + : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), + region(previousRegion), + insertBeforeBlock(previousIt == previousRegion->end() ? nullptr + : &*previousIt) {} static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::MoveBlock; @@ -617,9 +620,12 @@ protected: class MoveOperationRewrite : public OperationRewrite { public: MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, - Operation *op, Block *block, Operation *insertBeforeOp) - : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block), - insertBeforeOp(insertBeforeOp) {} + Operation *op, OpBuilder::InsertPoint previous) + : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), + block(previous.getBlock()), + insertBeforeOp(previous.getPoint() == previous.getBlock()->end() + ? nullptr + : &*previous.getPoint()) {} static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::MoveOperation; @@ -1588,23 +1594,30 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( void ConversionPatternRewriterImpl::notifyOperationInserted( Operation *op, OpBuilder::InsertPoint previous) { + // If no previous insertion point is provided, the op used to be detached. + bool wasDetached = !previous.isSet(); LLVM_DEBUG({ - logger.startLine() << "** Insert : '" << op->getName() << "'(" << op - << ")\n"; + logger.startLine() << "** Insert : '" << op->getName() << "' (" << op + << ")"; + if (wasDetached) + logger.getOStream() << " (was detached)"; + logger.getOStream() << "\n"; }); assert(!wasOpReplaced(op->getParentOp()) && "attempting to insert into a block within a replaced/erased op"); - if (!previous.isSet()) { - // This is a newly created op. + if (wasDetached) { + // If the op was detached, it is most likely a newly created op. + // TODO: If the same op is inserted multiple times from a detached state, + // the rollback mechanism may erase the same op multiple times. This is a + // bug in the rollback-based dialect conversion driver. appendRewrite<CreateOperationRewrite>(op); patternNewOps.insert(op); return; } - Operation *prevOp = previous.getPoint() == previous.getBlock()->end() - ? nullptr - : &*previous.getPoint(); - appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp); + + // The op was moved from one place to another. + appendRewrite<MoveOperationRewrite>(op, previous); } void ConversionPatternRewriterImpl::replaceOp( @@ -1669,29 +1682,40 @@ void ConversionPatternRewriterImpl::eraseBlock(Block *block) { void ConversionPatternRewriterImpl::notifyBlockInserted( Block *block, Region *previous, Region::iterator previousIt) { - assert(!wasOpReplaced(block->getParentOp()) && - "attempting to insert into a region within a replaced/erased op"); + // If no previous insertion point is provided, the block used to be detached. + bool wasDetached = !previous; + Operation *newParentOp = block->getParentOp(); LLVM_DEBUG( { - Operation *parent = block->getParentOp(); + Operation *parent = newParentOp; if (parent) { logger.startLine() << "** Insert Block into : '" << parent->getName() - << "'(" << parent << ")\n"; + << "' (" << parent << ")"; } else { logger.startLine() - << "** Insert Block into detached Region (nullptr parent op)'\n"; + << "** Insert Block into detached Region (nullptr parent op)"; } + if (wasDetached) + logger.getOStream() << " (was detached)"; + logger.getOStream() << "\n"; }); + assert(!wasOpReplaced(newParentOp) && + "attempting to insert into a region within a replaced/erased op"); + (void)newParentOp; patternInsertedBlocks.insert(block); - if (!previous) { - // This is a newly created block. + if (wasDetached) { + // If the block was detached, it is most likely a newly created block. + // TODO: If the same block is inserted multiple times from a detached state, + // the rollback mechanism may erase the same block multiple times. This is a + // bug in the rollback-based dialect conversion driver. appendRewrite<CreateBlockRewrite>(block); return; } - Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt; - appendRewrite<MoveBlockRewrite>(block, previous, prevBlock); + + // The block was moved from one place to another. + appendRewrite<MoveBlockRewrite>(block, previous, previousIt); } void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source, @@ -1736,6 +1760,12 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); + + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + SmallVector<SmallVector<Value>> newVals = llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> { return v ? SmallVector<Value>{v} : SmallVector<Value>(); @@ -1751,6 +1781,12 @@ void ConversionPatternRewriter::replaceOpWithMultiple( impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); + + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + impl->replaceOp(op, std::move(newValues)); } @@ -1759,6 +1795,12 @@ void ConversionPatternRewriter::eraseOp(Operation *op) { impl->logger.startLine() << "** Erase : '" << op->getName() << "'(" << op << ")\n"; }); + + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {}); impl->replaceOp(op, std::move(nullRepls)); } @@ -1865,6 +1907,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, moveOpBefore(&source->front(), dest, before); } + // If the current insertion point is within the source block, adjust the + // insertion point to the destination block. + if (getInsertionBlock() == source) + setInsertionPoint(dest, getInsertionPoint()); + // Erase the source block. eraseBlock(source); } @@ -1996,6 +2043,7 @@ private: /// Legalize the resultant IR after successfully applying the given pattern. LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter, + const RewriterState &curState, const SetVector<Operation *> &newOps, const SetVector<Operation *> &modifiedOps, const SetVector<Block *> &insertedBlocks); @@ -2193,23 +2241,39 @@ OperationLegalizer::legalizeWithFold(Operation *op, rewriterImpl.logger.startLine() << "* Fold {\n"; rewriterImpl.logger.indent(); }); - (void)rewriterImpl; + + // Clear pattern state, so that the next pattern application starts with a + // clean slate. (The op/block sets are populated by listener notifications.) + auto cleanup = llvm::make_scope_exit([&]() { + rewriterImpl.patternNewOps.clear(); + rewriterImpl.patternModifiedOps.clear(); + rewriterImpl.patternInsertedBlocks.clear(); + }); + + // Upon failure, undo all changes made by the folder. + RewriterState curState = rewriterImpl.getCurrentState(); // Try to fold the operation. StringRef opName = op->getName().getStringRef(); SmallVector<Value, 2> replacementValues; SmallVector<Operation *, 2> newOps; rewriter.setInsertionPoint(op); + rewriter.startOpModification(op); if (failed(rewriter.tryFold(op, replacementValues, &newOps))) { LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold")); + rewriter.cancelOpModification(op); return failure(); } + rewriter.finalizeOpModification(op); // An empty list of replacement values indicates that the fold was in-place. // As the operation changed, a new legalization needs to be attempted. if (replacementValues.empty()) return legalize(op, rewriter); + // Insert a replacement for 'op' with the folded replacement values. + rewriter.replaceOp(op, replacementValues); + // Recursively legalize any new constant operations. for (Operation *newOp : newOps) { if (failed(legalize(newOp, rewriter))) { @@ -2222,16 +2286,12 @@ OperationLegalizer::legalizeWithFold(Operation *op, "op '" + opName + "' folder rollback of IR modifications requested"); } - // Legalization failed: erase all materialized constants. - for (Operation *op : newOps) - rewriter.eraseOp(op); + rewriterImpl.resetState( + curState, std::string(op->getName().getStringRef()) + " folder"); return failure(); } } - // Insert a replacement for 'op' with the folded replacement values. - rewriter.replaceOp(op, replacementValues); - LLVM_DEBUG(logSuccess(rewriterImpl.logger, "")); return success(); } @@ -2241,6 +2301,32 @@ OperationLegalizer::legalizeWithPattern(Operation *op, ConversionPatternRewriter &rewriter) { auto &rewriterImpl = rewriter.getImpl(); +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + Operation *checkOp; + std::optional<OperationFingerPrint> topLevelFingerPrint; + if (!rewriterImpl.config.allowPatternRollback) { + // The op may be getting erased, so we have to check the parent op. + // (In rare cases, a pattern may even erase the parent op, which will cause + // a crash here. Expensive checks are "best effort".) Skip the check if the + // op does not have a parent op. + if ((checkOp = op->getParentOp())) { + if (!op->getContext()->isMultithreadingEnabled()) { + topLevelFingerPrint = OperationFingerPrint(checkOp); + } else { + // Another thread may be modifying a sibling operation. Therefore, the + // fingerprinting mechanism of the parent op works only in + // single-threaded mode. + LLVM_DEBUG({ + rewriterImpl.logger.startLine() + << "WARNING: Multi-threadeding is enabled. Some dialect " + "conversion expensive checks are skipped in multithreading " + "mode!\n"; + }); + } + } + } +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + // Functor that returns if the given pattern may be applied. auto canApply = [&](const Pattern &pattern) { bool canApply = canApplyPattern(op, pattern, rewriter); @@ -2253,6 +2339,17 @@ OperationLegalizer::legalizeWithPattern(Operation *op, RewriterState curState = rewriterImpl.getCurrentState(); auto onFailure = [&](const Pattern &pattern) { assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + if (!rewriterImpl.config.allowPatternRollback) { + // Returning "failure" after modifying IR is not allowed. + if (checkOp) { + OperationFingerPrint fingerPrintAfterPattern(checkOp); + if (fingerPrintAfterPattern != *topLevelFingerPrint) + llvm::report_fatal_error("pattern '" + pattern.getDebugName() + + "' returned failure but IR did change"); + } + } +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS rewriterImpl.patternNewOps.clear(); rewriterImpl.patternModifiedOps.clear(); rewriterImpl.patternInsertedBlocks.clear(); @@ -2281,7 +2378,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op, moveAndReset(rewriterImpl.patternModifiedOps); SetVector<Block *> insertedBlocks = moveAndReset(rewriterImpl.patternInsertedBlocks); - auto result = legalizePatternResult(op, pattern, rewriter, newOps, + auto result = legalizePatternResult(op, pattern, rewriter, curState, newOps, modifiedOps, insertedBlocks); appliedPatterns.erase(&pattern); if (failed(result)) { @@ -2324,7 +2421,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern, LogicalResult OperationLegalizer::legalizePatternResult( Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter, - const SetVector<Operation *> &newOps, + const RewriterState &curState, const SetVector<Operation *> &newOps, const SetVector<Operation *> &modifiedOps, const SetVector<Block *> &insertedBlocks) { auto &impl = rewriter.getImpl(); @@ -2340,7 +2437,8 @@ LogicalResult OperationLegalizer::legalizePatternResult( return hasRewrite<ModifyOperationRewrite>(newRewrites, op); }; if (!replacedRoot() && !updatedRootInPlace()) - llvm::report_fatal_error("expected pattern to replace the root operation"); + llvm::report_fatal_error( + "expected pattern to replace the root operation or modify it in place"); #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // Legalize each of the actions registered during application. diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp index b639e87f..26c965c 100644 --- a/mlir/lib/Transforms/Utils/Inliner.cpp +++ b/mlir/lib/Transforms/Utils/Inliner.cpp @@ -21,7 +21,7 @@ #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #define DEBUG_TYPE "inlining" @@ -348,13 +348,11 @@ static void collectCallOps(iterator_range<Region::iterator> blocks, // InlinerInterfaceImpl //===----------------------------------------------------------------------===// -#ifndef NDEBUG static std::string getNodeName(CallOpInterface op) { if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee())) return debugString(op); return "_unnamed_callee_"; } -#endif /// Return true if the specified `inlineHistoryID` indicates an inline history /// that already includes `node`. @@ -614,10 +612,10 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{}); LLVM_DEBUG({ - llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n"; + LDBG() << "* Inliner: Initial calls in SCC are: {"; for (unsigned i = 0, e = calls.size(); i < e; ++i) - llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n"; - llvm::dbgs() << "}\n"; + LDBG() << " " << i << ". " << calls[i].call << ","; + LDBG() << "}"; }); // Try to inline each of the call operations. Don't cache the end iterator @@ -635,9 +633,9 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, CallOpInterface call = it.call; LLVM_DEBUG({ if (doInline) - llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n"; + LDBG() << "* Inlining call: " << i << ". " << call; else - llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n"; + LDBG() << "* Not inlining call: " << i << ". " << call; }); if (!doInline) continue; @@ -654,7 +652,7 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, cast<CallableOpInterface>(targetRegion->getParentOp()), targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace); if (failed(inlineResult)) { - LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n"); + LDBG() << "** Failed to inline"; continue; } inlinedAnyCalls = true; @@ -667,19 +665,16 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, auto historyToString = [](InlineHistoryT h) { return h.has_value() ? std::to_string(*h) : "root"; }; - (void)historyToString; - LLVM_DEBUG(llvm::dbgs() - << "* new inlineHistory entry: " << newInlineHistoryID << ". [" - << getNodeName(call) << ", " << historyToString(inlineHistoryID) - << "]\n"); + LDBG() << "* new inlineHistory entry: " << newInlineHistoryID << ". [" + << getNodeName(call) << ", " << historyToString(inlineHistoryID) + << "]"; for (unsigned k = prevSize; k != calls.size(); ++k) { callHistory.push_back(newInlineHistoryID); - LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[i].call - << "}\n with historyID = " << newInlineHistoryID - << ", added due to inlining of\n call {" << call - << "}\n with historyID = " - << historyToString(inlineHistoryID) << "\n"); + LDBG() << "* new call " << k << " {" << calls[k].call + << "}\n with historyID = " << newInlineHistoryID + << ", added due to inlining of\n call {" << call + << "}\n with historyID = " << historyToString(inlineHistoryID); } // If the inlining was successful, Merge the new uses into the source node. diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt index ac8b44f5..89568e7 100644 --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -68,6 +68,7 @@ endif() llvm_canonicalize_cmake_booleans( LLVM_BUILD_EXAMPLES LLVM_HAS_NVPTX_TARGET + LLVM_INCLUDE_SPIRV_TOOLS_TESTS MLIR_ENABLE_BINDINGS_PYTHON MLIR_ENABLE_CUDA_RUNNER MLIR_ENABLE_ROCM_CONVERSIONS @@ -217,6 +218,11 @@ if(MLIR_ENABLE_BINDINGS_PYTHON) ) endif() +if (LLVM_INCLUDE_SPIRV_TOOLS_TESTS) + list(APPEND MLIR_TEST_DEPENDS spirv-as) + list(APPEND MLIR_TEST_DEPENDS spirv-val) +endif() + # This target can be used to just build the dependencies # for the check-mlir target without executing the tests. # This is useful for bots when splitting the build step diff --git a/mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir b/mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir index 00bbd1c..96ad107 100644 --- a/mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir +++ b/mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir @@ -85,11 +85,10 @@ module attributes { // CHECK: spirv.Load "StorageBuffer" %val = memref.load %arg0[%idx0] : memref<2xi32> // CHECK: spirv.CompositeInsert - %vec = vector.insertelement %val, %vec0[%idx0 : index] : vector<2xi32> + %vec = vector.insert %val, %vec0[%idx0] : i32 into vector<2xi32> // CHECK: spirv.VectorShuffle %shuffle = vector.shuffle %vec, %vec[3, 2, 1, 0] : vector<2xi32>, vector<2xi32> - // CHECK: spirv.CompositeExtract - %res = vector.extractelement %shuffle[%idx0 : index] : vector<4xi32> + %res = vector.extract %shuffle[%idx0] : i32 from vector<4xi32> // CHECK: spirv.AccessChain // CHECK: spirv.Store "StorageBuffer" memref.store %res, %arg1[%idx0]: memref<4xi32> @@ -102,9 +101,9 @@ module attributes { // CHECK-SAME: %{{.*}}: memref<2xi32>, %{{.*}}: memref<4xi32> // CHECK: arith.constant // CHECK: memref.load - // CHECK: vector.insertelement + // CHECK: vector.insert // CHECK: vector.shuffle - // CHECK: vector.extractelement + // CHECK: vector.extract // CHECK: memref.store // CHECK: gpu.return } diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir index fb14feb..eb9feaa 100644 --- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir +++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir @@ -51,108 +51,6 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf3 // ----- -// CHECK-LABEL: @extract_element -// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 -// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32 -func.func @extract_element(%arg0 : vector<4xf32>, %id : i32) -> f32 { - %0 = vector.extractelement %arg0[%id : i32] : vector<4xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_cst -// CHECK-SAME: %[[V:.*]]: vector<4xf32> -// CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32> -func.func @extract_element_cst(%arg0 : vector<4xf32>) -> f32 { - %idx = arith.constant 1 : i32 - %0 = vector.extractelement %arg0[%idx : i32] : vector<4xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_index -func.func @extract_element_index(%arg0 : vector<4xf32>, %id : index) -> f32 { - // CHECK: spirv.VectorExtractDynamic - %0 = vector.extractelement %arg0[%id : index] : vector<4xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_size1_vector -// CHECK-SAME:(%[[S:.+]]: f32, -func.func @extract_element_size1_vector(%arg0 : f32, %i: index) -> f32 { - %bcast = vector.broadcast %arg0 : f32 to vector<1xf32> - %0 = vector.extractelement %bcast[%i : index] : vector<1xf32> - // CHECK: spirv.ReturnValue %[[S]] - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_0d_vector -// CHECK-SAME: (%[[S:.+]]: f32) -func.func @extract_element_0d_vector(%arg0 : f32) -> f32 { - %bcast = vector.broadcast %arg0 : f32 to vector<f32> - %0 = vector.extractelement %bcast[] : vector<f32> - // CHECK: spirv.ReturnValue %[[S]] - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @insert_element -// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 -// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32 -func.func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) -> vector<4xf32> { - %0 = vector.insertelement %val, %arg0[%id : i32] : vector<4xf32> - return %0: vector<4xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_cst -// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32> -// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32> -func.func @insert_element_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> { - %idx = arith.constant 2 : i32 - %0 = vector.insertelement %val, %arg0[%idx : i32] : vector<4xf32> - return %0: vector<4xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_index -func.func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> { - // CHECK: spirv.VectorInsertDynamic - %0 = vector.insertelement %val, %arg0[%id : index] : vector<4xf32> - return %0: vector<4xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_size1_vector -// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32 -func.func @insert_element_size1_vector(%scalar: f32, %vector : vector<1xf32>, %i: index) -> vector<1xf32> { - %0 = vector.insertelement %scalar, %vector[%i : index] : vector<1xf32> - // CHECK: spirv.ReturnValue %[[S]] - return %0: vector<1xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_0d_vector -// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32 -func.func @insert_element_0d_vector(%scalar: f32, %vector : vector<f32>) -> vector<f32> { - %0 = vector.insertelement %scalar, %vector[] : vector<f32> - // CHECK: spirv.ReturnValue %[[S]] - return %0: vector<f32> -} - -// ----- - // CHECK-LABEL: @insert_size1_vector // CHECK-SAME: %[[SUB:.*]]: f32, %[[FULL:.*]]: vector<3xf32> // CHECK: %[[RET:.*]] = spirv.CompositeInsert %[[SUB]], %[[FULL]][2 : i32] : f32 into vector<3xf32> diff --git a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir index b96dd37..c71d220 100644 --- a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir @@ -10,16 +10,14 @@ gpu.module @kernels { // CHECK-LABEL: spirv.func @rotate() gpu.func @rotate() kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} { - %offset = arith.constant 4 : i32 - %width = arith.constant 16 : i32 %val = arith.constant 42.0 : f32 + // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32 // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32 - // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32 // CHECK: %{{.+}} = spirv.Constant true - %result, %valid = gpu.rotate %val, %offset, %width : f32 + %result, %valid = gpu.rotate %val, 4, 16 : f32 gpu.return } } @@ -38,18 +36,16 @@ gpu.module @kernels { // CHECK-LABEL: spirv.func @rotate_width_less_than_subgroup_size() gpu.func @rotate_width_less_than_subgroup_size() kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} { - %offset = arith.constant 4 : i32 - %width = arith.constant 8 : i32 %val = arith.constant 42.0 : f32 + // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32 // CHECK: %[[WIDTH:.+]] = spirv.Constant 8 : i32 - // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32 // CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__ // CHECK: %[[INVOCATION_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]] // CHECK: %{{.+}} = spirv.ULessThan %[[INVOCATION_ID]], %[[WIDTH]] - %result, %valid = gpu.rotate %val, %offset, %width : f32 + %result, %valid = gpu.rotate %val, 4, 8 : f32 gpu.return } } @@ -67,34 +63,10 @@ module attributes { gpu.module @kernels { gpu.func @rotate_with_bigger_than_subgroup_size() kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} { - %offset = arith.constant 4 : i32 - %width = arith.constant 32 : i32 %val = arith.constant 42.0 : f32 // expected-error @+1 {{failed to legalize operation 'gpu.rotate'}} - %result, %valid = gpu.rotate %val, %offset, %width : f32 - gpu.return - } -} - -} - -// ----- - -module attributes { - gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>, - #spirv.resource_limits<subgroup_size = 16>> -} { - -gpu.module @kernels { - gpu.func @rotate_non_const_width(%width: i32) kernel - attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} { - %offset = arith.constant 4 : i32 - %val = arith.constant 42.0 : f32 - - // expected-error @+1 {{'gpu.rotate' op width is not a constant value}} - %result, %valid = gpu.rotate %val, %offset, %width : f32 + %result, %valid = gpu.rotate %val, 4, 32 : f32 gpu.return } } diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc.mlir new file mode 100644 index 0000000..e391a89 --- /dev/null +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc.mlir @@ -0,0 +1,72 @@ +// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP +// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP + +func.func @alloc() { + %alloc = memref.alloc() : memref<999xi32> + return +} + +// CPP: module { +// CPP-NEXT: emitc.include <"cstdlib"> +// CPP-LABEL: alloc() +// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index +// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t +// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> +// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> +// CPP-NEXT: return + +// NOCPP: module { +// NOCPP-NEXT: emitc.include <"stdlib.h"> +// NOCPP-LABEL: alloc() +// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index +// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t +// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> +// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> +// NOCPP-NEXT: return + +func.func @alloc_aligned() { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<999xf32> + return +} + +// CPP-LABEL: alloc_aligned +// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t +// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index +// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t +// CPP-NEXT: %[[ALIGNMENT:.*]] = "emitc.constant"() <{value = 64 : index}> : () -> !emitc.size_t +// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "aligned_alloc"(%[[ALIGNMENT]], %[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t, !emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> +// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<f32> +// CPP-NEXT: return + +// NOCPP-LABEL: alloc_aligned +// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t +// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index +// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t +// NOCPP-NEXT: %[[ALIGNMENT:.*]] = "emitc.constant"() <{value = 64 : index}> : () -> !emitc.size_t +// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "aligned_alloc"(%[[ALIGNMENT]], %[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t, !emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> +// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<f32> +// NOCPP-NEXT: return + +func.func @allocating_multi() { + %alloc_5 = memref.alloc() : memref<7x999xi32> + return +} + +// CPP-LABEL: allocating_multi +// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 6993 : index}> : () -> index +// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t +// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void"> +// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> +// CPP-NEXT: return + +// NOCPP-LABEL: allocating_multi +// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 6993 : index}> : () -> index +// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t +// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> +// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> +// NOCPP-NEXT: return + diff --git a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir index fa7a91c..b6f2383 100644 --- a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir +++ b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir @@ -36,7 +36,7 @@ func.func @while_test(%arg0 : tensor<i32>) -> (tensor<i32>) { func.func @if_test(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> (tensor<f32>) { // CHECK: [[EX:%.+]] = tensor.extract [[ARG2]] // CHECK: [[IF:%.+]] = scf.if [[EX]] -> (tensor<f32>) { - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { // CHECK: scf.yield [[ARG0]] tosa.yield %arg0 : tensor<f32> diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir index 8c135d5..31e17fb 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir @@ -274,73 +274,6 @@ func.func @shuffle_2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf3 // ----- //===----------------------------------------------------------------------===// -// vector.extractelement -//===----------------------------------------------------------------------===// - -func.func @extractelement_from_vec_0d_f32(%arg0: vector<f32>) -> f32 { - %1 = vector.extractelement %arg0[] : vector<f32> - return %1 : f32 -} -// CHECK-LABEL: @extractelement_from_vec_0d_f32 -// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64 -// CHECK: llvm.extractelement %{{.*}}[%[[C0]] : {{.*}}] : vector<1xf32> - -// ----- - -func.func @extractelement_from_vec_1d_f32_idx_as_i32(%arg0: vector<16xf32>) -> f32 { - %0 = arith.constant 15 : i32 - %1 = vector.extractelement %arg0[%0 : i32]: vector<16xf32> - return %1 : f32 -} -// CHECK-LABEL: @extractelement_from_vec_1d_f32_idx_as_i32( -// CHECK-SAME: %[[A:.*]]: vector<16xf32>) -// CHECK: %[[C:.*]] = arith.constant 15 : i32 -// CHECK: %[[X:.*]] = llvm.extractelement %[[A]][%[[C]] : i32] : vector<16xf32> -// CHECK: return %[[X]] : f32 - -// ----- - -func.func @extractelement_from_vec_1d_f32_idx_as_i32_scalable(%arg0: vector<[16]xf32>) -> f32 { - %0 = arith.constant 15 : i32 - %1 = vector.extractelement %arg0[%0 : i32]: vector<[16]xf32> - return %1 : f32 -} -// CHECK-LABEL: @extractelement_from_vec_1d_f32_idx_as_i32_scalable( -// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>) -// CHECK: %[[C:.*]] = arith.constant 15 : i32 -// CHECK: %[[X:.*]] = llvm.extractelement %[[A]][%[[C]] : i32] : vector<[16]xf32> -// CHECK: return %[[X]] : f32 - -// ----- -func.func @extractelement_from_vec_1d_f32_idx_as_index(%arg0: vector<16xf32>) -> f32 { - %0 = arith.constant 15 : index - %1 = vector.extractelement %arg0[%0 : index]: vector<16xf32> - return %1 : f32 -} -// CHECK-LABEL: @extractelement_from_vec_1d_f32_idx_as_index( -// CHECK-SAME: %[[A:.*]]: vector<16xf32>) -// CHECK: %[[C:.*]] = arith.constant 15 : index -// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[C]] : index to i64 -// CHECK: %[[X:.*]] = llvm.extractelement %[[A]][%[[I]] : i64] : vector<16xf32> -// CHECK: return %[[X]] : f32 - -// ----- - -func.func @extractelement_from_vec_1d_f32_idx_as_index_scalable(%arg0: vector<[16]xf32>) -> f32 { - %0 = arith.constant 15 : index - %1 = vector.extractelement %arg0[%0 : index]: vector<[16]xf32> - return %1 : f32 -} -// CHECK-LABEL: @extractelement_from_vec_1d_f32_idx_as_index_scalable( -// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>) -// CHECK: %[[C:.*]] = arith.constant 15 : index -// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[C]] : index to i64 -// CHECK: %[[X:.*]] = llvm.extractelement %[[A]][%[[I]] : i64] : vector<[16]xf32> -// CHECK: return %[[X]] : f32 - -// ----- - -//===----------------------------------------------------------------------===// // vector.extract //===----------------------------------------------------------------------===// @@ -592,81 +525,6 @@ func.func @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg : // ----- //===----------------------------------------------------------------------===// -// vector.insertelement -//===----------------------------------------------------------------------===// - -func.func @insertelement_into_vec_0d_f32(%arg0: f32, %arg1: vector<f32>) -> vector<f32> { - %1 = vector.insertelement %arg0, %arg1[] : vector<f32> - return %1 : vector<f32> -} -// CHECK-LABEL: @insertelement_into_vec_0d_f32 -// CHECK-SAME: %[[A:.*]]: f32, -// CHECK: %[[B:.*]] = builtin.unrealized_conversion_cast %{{.*}} : -// CHECK: vector<f32> to vector<1xf32> -// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64 -// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[C0]] : {{.*}}] : vector<1xf32> - -// ----- - -func.func @insertelement_into_vec_1d_f32_idx_as_i32(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> { - %0 = arith.constant 3 : i32 - %1 = vector.insertelement %arg0, %arg1[%0 : i32] : vector<4xf32> - return %1 : vector<4xf32> -} -// CHECK-LABEL: @insertelement_into_vec_1d_f32_idx_as_i32( -// CHECK-SAME: %[[A:.*]]: f32, -// CHECK-SAME: %[[B:.*]]: vector<4xf32>) -// CHECK: %[[C:.*]] = arith.constant 3 : i32 -// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[C]] : i32] : vector<4xf32> -// CHECK: return %[[X]] : vector<4xf32> - -// ----- - -func.func @insertelement_into_vec_1d_f32_idx_as_i32_scalable(%arg0: f32, %arg1: vector<[4]xf32>) -> vector<[4]xf32> { - %0 = arith.constant 3 : i32 - %1 = vector.insertelement %arg0, %arg1[%0 : i32] : vector<[4]xf32> - return %1 : vector<[4]xf32> -} -// CHECK-LABEL: @insertelement_into_vec_1d_f32_idx_as_i32_scalable( -// CHECK-SAME: %[[A:.*]]: f32, -// CHECK-SAME: %[[B:.*]]: vector<[4]xf32>) -// CHECK: %[[C:.*]] = arith.constant 3 : i32 -// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[C]] : i32] : vector<[4]xf32> -// CHECK: return %[[X]] : vector<[4]xf32> - -// ----- - -func.func @insertelement_into_vec_1d_f32_scalable_idx_as_index(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> { - %0 = arith.constant 3 : index - %1 = vector.insertelement %arg0, %arg1[%0 : index] : vector<4xf32> - return %1 : vector<4xf32> -} -// CHECK-LABEL: @insertelement_into_vec_1d_f32_scalable_idx_as_index( -// CHECK-SAME: %[[A:.*]]: f32, -// CHECK-SAME: %[[B:.*]]: vector<4xf32>) -// CHECK: %[[C:.*]] = arith.constant 3 : index -// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[C]] : index to i64 -// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[I]] : i64] : vector<4xf32> -// CHECK: return %[[X]] : vector<4xf32> - -// ----- - -func.func @insertelement_into_vec_1d_f32_scalable_idx_as_index_scalable(%arg0: f32, %arg1: vector<[4]xf32>) -> vector<[4]xf32> { - %0 = arith.constant 3 : index - %1 = vector.insertelement %arg0, %arg1[%0 : index] : vector<[4]xf32> - return %1 : vector<[4]xf32> -} -// CHECK-LABEL: @insertelement_into_vec_1d_f32_scalable_idx_as_index_scalable( -// CHECK-SAME: %[[A:.*]]: f32, -// CHECK-SAME: %[[B:.*]]: vector<[4]xf32>) -// CHECK: %[[C:.*]] = arith.constant 3 : index -// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[C]] : index to i64 -// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[I]] : i64] : vector<[4]xf32> -// CHECK: return %[[X]] : vector<[4]xf32> - -// ----- - -//===----------------------------------------------------------------------===// // vector.insert //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index f43a41a..8918f91 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -400,67 +400,6 @@ func.func @insert_dynamic_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> // ----- -// CHECK-LABEL: @extract_element -// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 -// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32 -func.func @extract_element(%arg0 : vector<4xf32>, %id : i32) -> f32 { - %0 = vector.extractelement %arg0[%id : i32] : vector<4xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_cst -// CHECK-SAME: %[[V:.*]]: vector<4xf32> -// CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32> -func.func @extract_element_cst(%arg0 : vector<4xf32>) -> f32 { - %idx = arith.constant 1 : i32 - %0 = vector.extractelement %arg0[%idx : i32] : vector<4xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_index -func.func @extract_element_index(%arg0 : vector<4xf32>, %id : index) -> f32 { - // CHECK: spirv.VectorExtractDynamic - %0 = vector.extractelement %arg0[%id : index] : vector<4xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_size5_vector -func.func @extract_element_size5_vector(%arg0 : vector<5xf32>, %id : i32) -> f32 { - // CHECK: vector.extractelement - %0 = vector.extractelement %arg0[%id : i32] : vector<5xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_size1_vector -// CHECK-SAME: (%[[S:.+]]: f32 -func.func @extract_element_size1_vector(%arg0 : f32, %i: index) -> f32 { - %bcast = vector.broadcast %arg0 : f32 to vector<1xf32> - %0 = vector.extractelement %bcast[%i : index] : vector<1xf32> - // CHECK: return %[[S]] - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_0d_vector -// CHECK-SAME: (%[[S:.+]]: f32) -func.func @extract_element_0d_vector(%arg0 : f32) -> f32 { - %bcast = vector.broadcast %arg0 : f32 to vector<f32> - %0 = vector.extractelement %bcast[] : vector<f32> - // CHECK: return %[[S]] - return %0: f32 -} - -// ----- - // CHECK-LABEL: @extract_strided_slice // CHECK-SAME: %[[ARG:.+]]: vector<4xf32> // CHECK: spirv.VectorShuffle [1 : i32, 2 : i32] %[[ARG]], %[[ARG]] : vector<4xf32>, vector<4xf32> -> vector<2xf32> @@ -473,67 +412,6 @@ func.func @extract_strided_slice(%arg0: vector<4xf32>) -> (vector<2xf32>, vector // ----- -// CHECK-LABEL: @insert_element -// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 -// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32 -func.func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) -> vector<4xf32> { - %0 = vector.insertelement %val, %arg0[%id : i32] : vector<4xf32> - return %0: vector<4xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_cst -// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32> -// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32> -func.func @insert_element_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> { - %idx = arith.constant 2 : i32 - %0 = vector.insertelement %val, %arg0[%idx : i32] : vector<4xf32> - return %0: vector<4xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_index -func.func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> { - // CHECK: spirv.VectorInsertDynamic - %0 = vector.insertelement %val, %arg0[%id : index] : vector<4xf32> - return %0: vector<4xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_size5_vector -func.func @insert_element_size5_vector(%val: f32, %arg0 : vector<5xf32>, %id : i32) -> vector<5xf32> { - // CHECK: vector.insertelement - %0 = vector.insertelement %val, %arg0[%id : i32] : vector<5xf32> - return %0 : vector<5xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_size1_vector -// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32 -func.func @insert_element_size1_vector(%scalar: f32, %vector : vector<1xf32>, %i: index) -> vector<1xf32> { - %0 = vector.insertelement %scalar, %vector[%i : index] : vector<1xf32> - // CHECK: %[[V:.+]] = builtin.unrealized_conversion_cast %arg0 : f32 to vector<1xf32> - // CHECK: return %[[V]] - return %0: vector<1xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_0d_vector -// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32 -func.func @insert_element_0d_vector(%scalar: f32, %vector : vector<f32>) -> vector<f32> { - %0 = vector.insertelement %scalar, %vector[] : vector<f32> - // CHECK: %[[V:.+]] = builtin.unrealized_conversion_cast %arg0 : f32 to vector<f32> - // CHECK: return %[[V]] - return %0: vector<f32> -} - -// ----- - // CHECK-LABEL: @insert_strided_slice // CHECK-SAME: %[[PART:.+]]: vector<2xf32>, %[[ALL:.+]]: vector<4xf32> // CHECK: spirv.VectorShuffle [0 : i32, 4 : i32, 5 : i32, 3 : i32] %[[ALL]], %[[PART]] : vector<4xf32>, vector<2xf32> -> vector<4xf32> diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir new file mode 100644 index 0000000..e2ab876 --- /dev/null +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir @@ -0,0 +1,33 @@ +// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='builtin.module(test.symbol_scope_isolated(test-one-shot-module-bufferize))' -split-input-file | FileCheck %s + +"test.symbol_scope_isolated"() ({ + // CHECK-LABEL: func @inner_func( + // CHECK-SAME: %[[arg0:.*]]: memref<?xf32 + func.func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) { + // CHECK-NOT: copy + %f = arith.constant 1.0 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + // CHECK: memref.store %{{.*}}, %[[arg0]] + %0 = tensor.insert %f into %t[%c0] : tensor<?xf32> + // CHECK: %[[load:.*]] = memref.load %[[arg0]] + %1 = tensor.extract %0[%c1] : tensor<?xf32> + // CHECK: return %[[arg0]], %[[load]] : memref<?xf32{{.*}}>, f32 + return %0, %1 : tensor<?xf32>, f32 + } + + // CHECK-LABEL: func @call_func_with_non_tensor_return( + // CHECK-SAME: %[[arg0:.*]]: memref<?xf32 + func.func @call_func_with_non_tensor_return( + %t0: tensor<?xf32> {bufferization.writable = true}) -> (f32, tensor<?xf32>) { + // CHECK-NOT: alloc + // CHECK-NOT: copy + // CHECK: %[[call:.*]]:2 = call @inner_func(%[[arg0]]) + %0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32) + // CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref<?xf32,{{.*}}> + return %1, %0 : f32, tensor<?xf32> + } + "test.finish" () : () -> () +}) : () -> () + + diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir index 162ff06..35381da 100644 --- a/mlir/test/Dialect/GPU/invalid.mlir +++ b/mlir/test/Dialect/GPU/invalid.mlir @@ -479,20 +479,16 @@ func.func @shuffle_unsupported_type_vec(%arg0 : vector<[4]xf32>, %arg1 : i32, %a // ----- func.func @rotate_mismatching_type(%arg0 : f32) { - %offset = arith.constant 4 : i32 - %width = arith.constant 16 : i32 // expected-error@+1 {{op failed to verify that all of {value, rotateResult} have same type}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (i32, i1) + %rotate, %valid = "gpu.rotate"(%arg0) { offset = 4 : i32, width = 16 : i32 } : (f32) -> (i32, i1) return } // ----- func.func @rotate_unsupported_type(%arg0 : index) { - %offset = arith.constant 4 : i32 - %width = arith.constant 16 : i32 // expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'index'}} - %rotate, %valid = gpu.rotate %arg0, %offset, %width : index + %rotate, %valid = gpu.rotate %arg0, 4, 16 : index return } @@ -502,55 +498,31 @@ func.func @rotate_unsupported_type_vec(%arg0 : vector<[4]xf32>) { %offset = arith.constant 4 : i32 %width = arith.constant 16 : i32 // expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'vector<[4]xf32>'}} - %rotate, %valid = gpu.rotate %arg0, %offset, %width : vector<[4]xf32> + %rotate, %valid = gpu.rotate %arg0, 4, 16 : vector<[4]xf32> return } // ----- func.func @rotate_unsupported_width(%arg0 : f32) { - %offset = arith.constant 4 : i32 - %width = arith.constant 15 : i32 - // expected-error@+1 {{op width must be a power of two}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) + // expected-error@+1 {{'gpu.rotate' op attribute 'width' failed to satisfy constraint: 32-bit signless integer attribute whose value is a power of two > 0}} + %rotate, %valid = "gpu.rotate"(%arg0) { offset = 4 : i32, width = 15 : i32 } : (f32) -> (f32, i1) return } // ----- func.func @rotate_unsupported_offset(%arg0 : f32) { - %offset = arith.constant 16 : i32 - %width = arith.constant 16 : i32 // expected-error@+1 {{op offset must be in the range [0, 16)}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) + %rotate, %valid = "gpu.rotate"(%arg0) { offset = 16 : i32, width = 16 : i32 }: (f32) -> (f32, i1) return } // ----- func.func @rotate_unsupported_offset_minus(%arg0 : f32) { - %offset = arith.constant -1 : i32 - %width = arith.constant 16 : i32 - // expected-error@+1 {{op offset must be in the range [0, 16)}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) - return -} - -// ----- - -func.func @rotate_offset_non_constant(%arg0 : f32, %offset : i32) { - %width = arith.constant 16 : i32 - // expected-error@+1 {{op offset is not a constant value}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) - return -} - -// ----- - -func.func @rotate_width_non_constant(%arg0 : f32, %width : i32) { - %offset = arith.constant 0 : i32 - // expected-error@+1 {{op width is not a constant value}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) + // expected-error@+1 {{'gpu.rotate' op attribute 'offset' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 0}} + %rotate, %valid = "gpu.rotate"(%arg0) { offset = -1 : i32, width = 16 : i32 } : (f32) -> (f32, i1) return } diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir index 2aef80f..ee1fdfa 100644 --- a/mlir/test/Dialect/GPU/ops.mlir +++ b/mlir/test/Dialect/GPU/ops.mlir @@ -140,9 +140,8 @@ module attributes {gpu.container_module} { // CHECK: gpu.shuffle idx %{{.*}}, %{{.*}}, %{{.*}} : f32 %shfl3, %pred3 = gpu.shuffle idx %arg0, %offset, %width : f32 - // CHECK: gpu.rotate %{{.*}}, %{{.*}}, %{{.*}} : f32 - %rotate_width = arith.constant 16 : i32 - %rotate, %pred4 = gpu.rotate %arg0, %offset, %rotate_width : f32 + // CHECK: gpu.rotate %{{.*}}, 3, 16 : f32 + %rotate, %pred4 = gpu.rotate %arg0, 3, 16 : f32 "gpu.barrier"() : () -> () diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 9cbb56e4..39a7b1b 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -1387,42 +1387,43 @@ func.func @recursive_effect(%arg : tensor<1xf32>) { // CHECK-LABEL: @recursive_effect // CHECK: linalg.map +// ----- + //===----------------------------------------------------------------------===// // linalg.pack //===----------------------------------------------------------------------===// // CHECK-LABEL: func @fold_pack_constant_splat // CHECK-NOT: linalg.pack -// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32> -func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> { +// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32> +func.func @fold_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> { %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32> %0 = linalg.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] - inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32> - return %0 : tensor<8x16x8x32xf32> + inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<4x8x8x32xf32> + return %0 : tensor<4x8x8x32xf32> } // ----- // CHECK-LABEL: func @fold_padding_value_pack_constant_splat // CHECK-NOT: linalg.pack -// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32> -func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> { +// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32> +func.func @fold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> { %pad = arith.constant 1.000000e-01 : f32 %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32> %0 = linalg.pack %cst padding_value(%pad : f32) outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] - inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32> - return %0 : tensor<8x16x8x32xf32> + inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32> + return %0 : tensor<4x8x8x32xf32> } - // ----- // CHECK-LABEL: func @nofold_padding_value_pack_constant_splat // CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32> // CHECK: linalg.pack -func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> { +func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> { %pad = arith.constant 0.0 : f32 %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32> %0 = linalg.pack %cst @@ -1430,8 +1431,8 @@ func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32] - into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32> - return %0 : tensor<8x16x8x32xf32> + into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32> + return %0 : tensor<4x8x8x32xf32> } // ----- diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir index 6fc8d9f..cc26fa4 100644 --- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir +++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir @@ -1295,24 +1295,6 @@ func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate( // ----- -func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(%arg0: tensor<32x64xf32>) -> tensor<8x4x16x8xf32> { - %empty = tensor.empty() : tensor<8x4x16x8xf32> - %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32> - %pack = linalg.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32> - return %pack : tensor<8x4x16x8xf32> -} -// CHECK-LABEL: func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x8xf32> -// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] -// CHECK-SAME: output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32> -// CHECK: %[[PACK:.+]] = linalg.pack %[[EXPANDED]] -// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [8] into %[[EMPTY]] -// CHECK-SAME: : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32> -// CHECK: return %[[PACK]] : tensor<8x4x16x8xf32> - -// ----- - func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> { %6 = tensor.empty(%dim) : tensor<?x256xf32> %unpack = linalg.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32> diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index da1dfc7..40bf4d1 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -1760,6 +1760,7 @@ func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf } // ----- + func.func @pack_mismatch_inner_tile_size_and_output_shape( %input : tensor<?x?xf32>, %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> { // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}} @@ -1824,27 +1825,47 @@ func.func @unpack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: t // ----- +func.func @pack_with_artificial_padding(%input: tensor<9xf32>, %output: tensor<3x8xf32>) -> tensor<3x8xf32> { + %cst = arith.constant 0.0 : f32 + // expected-error@+1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}} + %0 = linalg.pack %input padding_value(%cst : f32) inner_dims_pos = [0] + inner_tiles = [8] into %output + : tensor<9xf32> -> tensor<3x8xf32> + return %0 : tensor<3x8xf32> +} + +// ----- + // The outer dims in the output tensor are incorrectly/unexpectedly transposed. // This could be fixed by adding `outer_dims_perm = [1, 0]` (the default value assumes no transpose). func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<4x16x32x16xf32>) -> tensor<4x16x32x16xf32> { - // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<16x4x32x16xf32>', got 'tensor<4x16x32x16xf32>'}} + // expected-error@+1 {{expected 'tensor<16x4x32x16xf32>' for the packed domain value, got 'tensor<4x16x32x16xf32>'}} %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [32, 16] into %output : tensor<256x128xf32> -> tensor<4x16x32x16xf32> return %0 : tensor<4x16x32x16xf32> } // ----- -func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> { - // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x8x32x16xf32>'}} - %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32> - return %0 : tensor<8x8x32x16xf32> +func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf32>) -> tensor<8x7x16x32xf32> { + // expected-error@+1 {{expected 'tensor<8x8x16x32xf32>' for the packed domain value, got 'tensor<8x7x16x32xf32>'}} + %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x7x16x32xf32> + return %0 : tensor<8x7x16x32xf32> +} + +// ----- + +func.func @unpack_with_artifical_tiles_that_are_dropped(%input: tensor<3x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> { + // expected-error@+1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}} + %0 = linalg.unpack %input inner_dims_pos = [0] inner_tiles = [8] into %output + : tensor<3x8xf32> -> tensor<9xf32> + return %0 : tensor<9xf32> } // ----- -func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x32x16xf32>) -> tensor<256x128xf32> { - // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x32x16xf32>'}} - %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x32x16xf32> -> tensor<256x128xf32> +func.func @unpack_invalid_source_shape(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> { + // expected-error@+1 {{expected 'tensor<8x32x4x32xf32>' for the packed domain value, got 'tensor<8x8x4x32xf32>'}} + %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32> return %0 : tensor<256x128xf32> } diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir index 81fd7a8..9e7681d 100644 --- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir +++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir @@ -326,23 +326,23 @@ module attributes {transform.with_named_sequence} { // ----- // CHECK-LABEL: func.func @pack_with_pad( -func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x16x16x1xf32>) - -> tensor<265x16x16x1xf32> { +func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x12x16x1xf32>) + -> tensor<265x12x16x1xf32> { // CHECK: tensor.pad {{.*}} low[0, 0] - // CHECK: : tensor<4225x12xf32> to tensor<4240x16xf32> + // CHECK: : tensor<4225x12xf32> to tensor<4240x12xf32> // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3]] - // CHECK-SAME: : tensor<4240x16xf32> into tensor<265x16x16x1xf32> + // CHECK-SAME: : tensor<4240x12xf32> into tensor<265x16x12x1xf32> // CHECK: linalg.transpose - // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>) - // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>) + // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x12x1xf32>) + // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x12x16x1xf32>) // CHECK-SAME: permutation = [0, 2, 1, 3] %cst = arith.constant 0.000000e+00 : f32 %0 = linalg.pack %src padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %dest - : tensor<4225x12xf32> -> tensor<265x16x16x1xf32> - return %0 : tensor<265x16x16x1xf32> + : tensor<4225x12xf32> -> tensor<265x12x16x1xf32> + return %0 : tensor<265x12x16x1xf32> } module attributes {transform.with_named_sequence} { diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir index 98e8f50..d41d861 100644 --- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir @@ -941,20 +941,17 @@ module attributes {transform.with_named_sequence} { // CHECK-LABEL: func @test_vectorize_dynamic_shapes_unpack // CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>, +// CHECK-SAME: %[[ARG_1:.*]]: tensor<?x?x16x2xf32> func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?xf32> { // CHECK: %[[C0:.*]] = arith.constant 0 -// CHECK: %[[DIM:.*]] = tensor.dim %arg0, %[[C0]] : tensor<?x?xf32> -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[DIM0:.*]] = tensor.dim %arg0, %[[C1]] : tensor<?x?xf32> -// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 // CHECK: %[[C01:.*]] = arith.constant 0 // CHECK: %[[C02:.*]] = arith.constant 0 -// CHECK: %[[DIM4:.*]] = tensor.dim %arg1, %[[C02]] : tensor<?x?x16x2xf32> -// CHECK: %[[CNST14:.*]] = arith.constant 1 -// CHECK: %[[DIM6:.*]] = tensor.dim %arg1, %[[CNST14]] : tensor<?x?x16x2xf32> +// CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG_1]], %[[C02]] : tensor<?x?x16x2xf32> +// CHECK: %[[C1:.*]] = arith.constant 1 +// CHECK: %[[DIM6:.*]] = tensor.dim %[[ARG_1]], %[[C1]] : tensor<?x?x16x2xf32> // CHECK: %[[CNST16:.*]] = arith.constant 16 : index // CHECK: %[[CNST2:.*]] = arith.constant 2 : index -// CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1> +// CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM_0]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1> // CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32> // CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32> // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x2x1x16xf32> to vector<4x16xf32> diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 4c50ed3..8c846cd 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -1406,7 +1406,7 @@ func.func @omp_atomic_update(%x : memref<i32>, %expr : i32, %xBool : memref<i1>, // CHECK-NEXT: (%[[XVAL:.*]]: i1): // CHECK-NEXT: %[[NEWVAL:.*]] = llvm.icmp "eq" %[[XVAL]], %[[EXPRBOOL]] : i1 // CHECK-NEXT: omp.yield(%[[NEWVAL]] : i1) - // } + // CHECK-NEXT: } omp.atomic.update %xBool : memref<i1> { ^bb0(%xval: i1): %newval = llvm.icmp "eq" %xval, %exprBool : i1 @@ -1562,6 +1562,14 @@ func.func @omp_atomic_update(%x : memref<i32>, %expr : i32, %xBool : memref<i1>, omp.yield(%newval : i32) } + // CHECK: omp.atomic.update %[[X]] : memref<i32> { + // CHECK-NEXT: (%[[XVAL:.*]]: i32): + // CHECK-NEXT: omp.yield(%{{.+}} : i32) + // CHECK-NEXT: } {atomic_control = #omp.atomic_control<ignore_denormal_mode = true, fine_grained_memory = true, remote_memory = true>} + omp.atomic.update %x : memref<i32> { + ^bb0(%xval:i32): + omp.yield(%const:i32) + } {atomic_control = #omp.atomic_control<ignore_denormal_mode = true, fine_grained_memory = true, remote_memory = true>} return } diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir index 0176fc2..6398161 100644 --- a/mlir/test/Dialect/Tosa/availability.mlir +++ b/mlir/test/Dialect/Tosa/availability.mlir @@ -645,7 +645,7 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { // CHECK: tosa.cond_if profiles: [ ] // CHECK: tosa.cond_if extensions: [ [controlflow] ] - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1 : tensor<f32> } else { diff --git a/mlir/test/Dialect/Tosa/controlflow.mlir b/mlir/test/Dialect/Tosa/controlflow.mlir new file mode 100644 index 0000000..06312c7 --- /dev/null +++ b/mlir/test/Dialect/Tosa/controlflow.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-opt -split-input-file %s | FileCheck %s + +// ----- + +func.func @condif_cond_type_check(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { + // CHECK: tosa.cond_if %[[ARG2:.*]] : tensor<i1> -> tensor<f32> { + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { + %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + // CHECK: } else { + } else { + %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + } + return %0 : tensor<f32> +} + +// ----- + +func.func @condif_block_args_check(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { + // CHECK: tosa.cond_if %[[ARG2:.*]] (%[[ARG3:.*]] = %[[ARG0:.*]], %[[ARG4:.*]] = %[[ARG1:.*]]) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> { + // CHECK-NEXT: ^bb0(%[[ARG3]]: tensor<f32>, %[[ARG4]]: tensor<f32>): + %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> { + ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): + %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + // CHECK: } else { + // CHECK-NEXT: ^bb0(%[[ARG3]]: tensor<f32>, %[[ARG4]]: tensor<f32>): + } else { + ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): + %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + } + return %0 : tensor<f32> +} diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir index eb25011..fad1bec 100644 --- a/mlir/test/Dialect/Tosa/error_if_check.mlir +++ b/mlir/test/Dialect/Tosa/error_if_check.mlir @@ -259,7 +259,7 @@ func.func @test_cond_if_else_not_isolated_from_above(%arg0: tensor<f32>, %arg1: func.func @test_cond_if_simplified_form_not_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { // expected-error@+1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the 'then' region is isolated from above.}} - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + %0 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>) { tosa.yield %arg0 : tensor<f32> } else { tosa.yield %arg1 : tensor<f32> diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 716362e..b90d6f5 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -1125,7 +1125,7 @@ func.func @test_sub_with_unequal_result_ranks(%arg0: tensor<1x21x3xf32>, %arg1: // CHECK-LABEL: test_mul_non_scalar_shift_2d func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> { %shift = "tosa.const"() <{values = dense<0> : tensor<1x1xi8>}> : () -> tensor<1x1xi8> - // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}} + // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant unranked tensor of 8-bit signless integer values or tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}} %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1x1xi8>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } @@ -1134,7 +1134,7 @@ func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tenso // CHECK-LABEL: test_mul_non_scalar_shift_1d func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> { %shift = "tosa.const"() <{values = dense<0> : tensor<2xi8>}> : () -> tensor<2xi8> - // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}} + // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant unranked tensor of 8-bit signless integer values or tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}} %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<2xi8>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir index 5630c33..3154f54 100644 --- a/mlir/test/Dialect/Tosa/invalid_extension.mlir +++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir @@ -337,7 +337,7 @@ func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32 // ----- func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { // expected-error@+1 {{'tosa.cond_if' op illegal: requires [controlflow]}} - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1 : tensor<f32> } else { diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 0dddf26..cbe0056 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -1506,13 +1506,13 @@ func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: // ----- func.func @test_cond_if_max_nested_depth(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>, %arg3: tensor<i1>) -> tensor<f32> { - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { - %1 = tosa.cond_if %arg3 -> (tensor<f32>) { - %2 = tosa.cond_if %arg2 -> (tensor<f32>) { - %3 = tosa.cond_if %arg3 -> (tensor<f32>) { - %4 = tosa.cond_if %arg2 -> (tensor<f32>) { + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { + %1 = tosa.cond_if %arg3 : tensor<i1>-> tensor<f32> { + %2 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { + %3 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> { + %4 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { // expected-error@+1 {{'tosa.cond_if' op failed level check: 6 >= MAX_NESTING}} - %5 = tosa.cond_if %arg3 -> (tensor<f32>) { + %5 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> { %res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %res : tensor<f32> } else { diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index ef51197e..30361a8 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -839,7 +839,7 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { // ----- // CHECK-LABEL: cond_if func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1 : tensor<f32> } else { diff --git a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir index 38ac8d8..e957bdd 100644 --- a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir +++ b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir @@ -54,7 +54,7 @@ func.func @test_no_change(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> { // CHECK-LABEL: test_regions // CHECK: %arg0: tensor<i8>, %arg1: tensor<i8> func.func @test_regions(%arg0: tensor<ui8>, %arg1: tensor<ui8>, %arg2: tensor<i1>) -> tensor<ui8> { - // CHECK: tosa.cond_if %arg2 -> (tensor<i8>) + // CHECK: tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<i8>, tensor<i8>) -> tensor<i8> %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({ ^bb0(%arg3: tensor<ui8>, %arg4: tensor<ui8>): // CHECK: %1 = tosa.add %arg0, %arg1 : (tensor<i8>, tensor<i8>) -> tensor<i8> diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index 9d43f89..7b8fc24 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -357,6 +357,17 @@ func.func @test_accepts_unranked_scalar_tensor(%arg0: tensor<1x2x2xf32>, %arg1: // ----- +// CHECK-LABEL: @test_unranked_scalar_i8_tensor +func.func @test_unranked_scalar_i8_tensor(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>, %arg2: tensor<1xi8>) -> tensor<4xi32> { + // CHECK: %[[SHIFT:.*]] = tosa.cast %arg2 : (tensor<1xi8>) -> tensor<1xi8> + %shift = tosa.cast %arg2 : (tensor<1xi8>) -> tensor<*xi8> + // CHECK: tosa.mul %arg0, %arg1, %[[SHIFT]] : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32> + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<*xi8>) -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// ----- + // CHECK-LABEL: @test_table_static func.func @test_table_static(%arg0 : tensor<4x5xi16>, %arg1 : tensor<513xi16>) -> () { // CHECK:tosa.table %arg0, %arg1 : (tensor<4x5xi16>, tensor<513xi16>) -> tensor<4x5xi16> @@ -1166,8 +1177,8 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens %b = tosa.log %arg1 : (tensor<f32>) -> tensor<f32> // CHECK: tosa.cond_if - // CHECK: -> (tensor<f32>) - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + // CHECK: -> tensor<f32> + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { tosa.yield %a : tensor<f32> } else { tosa.yield %b : tensor<f32> @@ -1180,8 +1191,8 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens // CHECK-LABEL: @if_test_dynamic func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () { // CHECK: tosa.cond_if - // CHECK: -> (tensor<?xf32>) - %0 = tosa.cond_if %arg2 -> (tensor<?xf32>) { + // CHECK: -> tensor<?xf32> + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<?xf32> { tosa.yield %arg0 : tensor<2xf32> } else { tosa.yield %arg1 : tensor<3xf32> @@ -1194,8 +1205,8 @@ func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : // CHECK-LABEL: @if_test_unranked func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () { // CHECK: tosa.cond_if - // CHECK: -> (tensor<*xf32>) - %0 = tosa.cond_if %arg2 -> (tensor<*xf32>) { + // CHECK: -> tensor<*xf32> + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<*xf32> { tosa.yield %arg0 : tensor<f32> } else { tosa.yield %arg1 : tensor<3xf32> @@ -1208,8 +1219,8 @@ func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : // CHECK-LABEL: @if_test_propagate func.func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () { // CHECK: tosa.cond_if - // CHECK: -> (tensor<f32>) - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + // CHECK: -> tensor<f32> + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1 : tensor<f32> } else { diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir index b305236..2a937b0 100644 --- a/mlir/test/Dialect/Tosa/verifier.mlir +++ b/mlir/test/Dialect/Tosa/verifier.mlir @@ -500,9 +500,39 @@ func.func @test_cond_if_input_list_mismatch_else_block_2(%arg0: tensor<f32>, %ar // ----- +func.func @test_cond_if_input_list_mismatch_else_block_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { + // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (1) and 'input_list' (2)}} + %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> { + ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): + %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + } else { + ^bb0(%arg3: tensor<f32>): + tosa.yield %arg3 : tensor<f32> + } + return %0 : tensor<f32> +} + +// ----- + +func.func @test_cond_if_input_list_mismatch_else_block_simple_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { + // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (2) and 'input_list' (1)}} + %0 = tosa.cond_if %arg2 (%arg3 = %arg0) : tensor<i1> (tensor<f32>) -> tensor<f32> { + ^bb0(%arg3: tensor<f32>): + tosa.yield %arg3 : tensor<f32> + } else { + ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): + %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + } + return %0 : tensor<f32> +} + +// ----- + func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (2) and 'output_list' (1)}} - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> %2 = tosa.add %1, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1, %2 : tensor<f32>, tensor<f32> @@ -517,7 +547,7 @@ func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (1) and 'output_list' (2)}} - %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) { + %0, %2 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>, tensor<f32>) { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1 : tensor<f32> } else { @@ -531,7 +561,7 @@ func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %a func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (2) and 'output_list' (1)}} - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1 : tensor<f32> } else { @@ -546,7 +576,7 @@ func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg func.func @test_cond_if_output_list_mismatch_else_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (1) and 'output_list' (2)}} - %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) { + %0, %2 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>, tensor<f32>) { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> %2 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1, %2 : tensor<f32>, tensor<f32> @@ -574,6 +604,53 @@ func.func @test_cond_if_cond_input_not_size_one(%arg0: tensor<f32>, %arg1: tenso // ----- +// CHECK-LABEL: cond_if_cond_type +func.func @test_cond_if_cond_type(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { + // expected-error@+2 {{expected ':'}} + // expected-error@+1 {{custom op 'tosa.cond_if' expected type for condition operand}} + %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + tosa.yield %arg0 : tensor<f32> + } else { + tosa.yield %arg1 : tensor<f32> + } + return %0 : tensor<f32> +} + +// ----- + +func.func @test_cond_if_input_list_type_mismatch_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { + // expected-error@+1 {{custom op 'tosa.cond_if' expected as many input types as operands (expected 2 got 0)}} + %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> () -> tensor<f32> { + ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): + %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + } else { + ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): + %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + } + return %0 : tensor<f32> +} + +// ----- + +func.func @test_cond_if_incorrect_type_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { + // expected-error@+2 {{expected non-function type}} + // expected-error@+1 {{custom op 'tosa.cond_if' expected list of types for block arguments followed by arrow type and list of return types}} + %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (%arg3) -> tensor<f32> { + ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): + %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + } else { + ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): + %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + } + return %0 : tensor<f32> +} + +// ----- + func.func @test_while_loop_input_list_mismatch_body_block_in(%arg0: tensor<10xi32>, %arg1: tensor<i32>) { %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32> // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'body_graph' arguments (3) and 'input_list' (2)}} diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 1461c30..9cfebd5 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -2562,118 +2562,6 @@ func.func @insert_2d_splat_constant() // ----- -// CHECK-LABEL: func @insert_element_fold -// CHECK: %[[V:.+]] = arith.constant dense<[0, 1, 7, 3]> : vector<4xi32> -// CHECK: return %[[V]] -func.func @insert_element_fold() -> vector<4xi32> { - %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32> - %s = arith.constant 7 : i32 - %i = arith.constant 2 : i32 - %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32> - return %1 : vector<4xi32> -} - -// ----- - -// CHECK-LABEL: func @insert_element_invalid_fold -func.func @insert_element_invalid_fold() -> vector<1xf32> { - // Out-of-bound index here. - %c26 = arith.constant 26 : index - %cst_2 = arith.constant 1.60215309E+9 : f32 - %cst_20 = arith.constant dense<1.60215309E+9> : vector<1xf32> -// CHECK: vector.insertelement - %46 = vector.insertelement %cst_2, %cst_20[%c26 : index] : vector<1xf32> - return %46 : vector<1xf32> -} - - -// ----- - -// Do not crash on poison -// CHECK-LABEL: func @insert_poison_fold1 -// CHECK: vector.insertelement -func.func @insert_poison_fold1() -> vector<4xi32> { - %v = ub.poison : vector<4xi32> - %s = arith.constant 7 : i32 - %i = arith.constant 2 : i32 - %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32> - return %1 : vector<4xi32> -} - -// ----- - -// Do not crash on poison -// CHECK-LABEL: func @insert_poison_fold2 -// CHECK: vector.insertelement -func.func @insert_poison_fold2() -> vector<4xi32> { - %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32> - %s = ub.poison : i32 - %i = arith.constant 2 : i32 - %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32> - return %1 : vector<4xi32> -} - -// ----- - -// Do not crash on poison -// CHECK-LABEL: func @insert_poison_fold3 -// CHECK: vector.insertelement -func.func @insert_poison_fold3() -> vector<4xi32> { - %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32> - %s = arith.constant 7 : i32 - %i = ub.poison : i32 - %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32> - return %1 : vector<4xi32> -} - -// ----- - -// CHECK-LABEL: func @extract_element_fold -// CHECK: %[[C:.+]] = arith.constant 5 : i32 -// CHECK: return %[[C]] -func.func @extract_element_fold() -> i32 { - %v = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32> - %i = arith.constant 2 : i32 - %1 = vector.extractelement %v[%i : i32] : vector<4xi32> - return %1 : i32 -} - -// CHECK-LABEL: func @extract_element_splat_fold -// CHECK-SAME: (%[[ARG:.+]]: i32) -// CHECK: return %[[ARG]] -func.func @extract_element_splat_fold(%a : i32) -> i32 { - %v = vector.splat %a : vector<4xi32> - %i = arith.constant 2 : i32 - %1 = vector.extractelement %v[%i : i32] : vector<4xi32> - return %1 : i32 -} - -// ----- - -// Do not crash on poison -// CHECK-LABEL: func @extract_element_poison_fold1 -// CHECK: vector.extractelement -func.func @extract_element_poison_fold1() -> i32 { - %v = ub.poison : vector<4xi32> - %i = arith.constant 2 : i32 - %1 = vector.extractelement %v[%i : i32] : vector<4xi32> - return %1 : i32 -} - -// ----- - -// Do not crash on poison -// CHECK-LABEL: func @extract_element_poison_fold2 -// CHECK: vector.extractelement -func.func @extract_element_poison_fold2() -> i32 { - %v = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32> - %i = ub.poison : i32 - %1 = vector.extractelement %v[%i : i32] : vector<4xi32> - return %1 : i32 -} - -// ----- - // CHECK-LABEL: func @reduce_one_element_vector_extract // CHECK-SAME: (%[[V:.+]]: vector<1xf32>) // CHECK: %[[S:.+]] = vector.extract %[[V]][0] : f32 from vector<1xf32> @@ -2933,18 +2821,6 @@ func.func @extract_strided_slice_of_constant_mask() -> vector<5x7xi1>{ // ----- -// CHECK-LABEL: func.func @fold_extractelement_of_broadcast( -// CHECK-SAME: %[[f:.*]]: f32 -// CHECK: return %[[f]] -func.func @fold_extractelement_of_broadcast(%f: f32) -> f32 { - %0 = vector.broadcast %f : f32 to vector<15xf32> - %c5 = arith.constant 5 : index - %1 = vector.extractelement %0 [%c5 : index] : vector<15xf32> - return %1 : f32 -} - -// ----- - // CHECK-LABEL: func.func @fold_0d_vector_reduction func.func @fold_0d_vector_reduction(%arg0: vector<f32>) -> f32 { // CHECK-NEXT: %[[RES:.*]] = vector.extract %arg{{.*}}[] : f32 from vector<f32> diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir index 0263193..2563b48 100644 --- a/mlir/test/Dialect/Vector/int-range-interface.mlir +++ b/mlir/test/Dialect/Vector/int-range-interface.mlir @@ -60,16 +60,6 @@ func.func @vector_extract() -> index { func.return %2 : index } -// CHECK-LABEL: func @vector_extractelement -// CHECK: test.reflect_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} -func.func @vector_extractelement() -> index { - %c0 = arith.constant 0 : index - %0 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : vector<4xindex> - %1 = vector.extractelement %0[%c0 : index] : vector<4xindex> - %2 = test.reflect_bounds %1 : index - func.return %2 : index -} - // CHECK-LABEL: func @vector_add // CHECK: test.reflect_bounds {smax = 12 : index, smin = 10 : index, umax = 12 : index, umin = 10 : index} func.func @vector_add() -> vector<4xindex> { @@ -90,17 +80,6 @@ func.func @vector_insert() -> vector<4xindex> { func.return %3 : vector<4xindex> } -// CHECK-LABEL: func @vector_insertelement -// CHECK: test.reflect_bounds {smax = 8 : index, smin = 5 : index, umax = 8 : index, umin = 5 : index} -func.func @vector_insertelement() -> vector<4xindex> { - %c0 = arith.constant 0 : index - %0 = test.with_bounds { umin = 5 : index, umax = 7 : index, smin = 5 : index, smax = 7 : index } : vector<4xindex> - %1 = test.with_bounds { umin = 6 : index, umax = 8 : index, smin = 6 : index, smax = 8 : index } : index - %2 = vector.insertelement %1, %0[%c0 : index] : vector<4xindex> - %3 = test.reflect_bounds %2 : vector<4xindex> - func.return %3 : vector<4xindex> -} - // CHECK-LABEL: func @test_loaded_vector_extract // No bounds // CHECK: test.reflect_bounds {smax = 2147483647 : si32, smin = -2147483648 : si32, umax = 4294967295 : ui32, umin = 0 : ui32} %{{.*}} : i32 diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index ca837d3..c21de56 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -119,30 +119,6 @@ func.func @shuffle_empty_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) { // ----- -func.func @extract_element(%arg0: vector<f32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{expected position to be empty with 0-D vector}} - %1 = vector.extractelement %arg0[%c : i32] : vector<f32> -} - -// ----- - -func.func @extract_element(%arg0: vector<4xf32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{expected position for 1-D vector}} - %1 = vector.extractelement %arg0[] : vector<4xf32> -} - -// ----- - -func.func @extract_element(%arg0: vector<4x4xf32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{unexpected >1 vector rank}} - %1 = vector.extractelement %arg0[%c : i32] : vector<4x4xf32> -} - -// ----- - func.func @extract_vector_type(%arg0: index) { // expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'index'}} %1 = vector.extract %arg0[] : index from index @@ -192,38 +168,6 @@ func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) { // ----- -func.func @insert_element(%arg0: f32, %arg1: vector<f32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{expected position to be empty with 0-D vector}} - %0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<f32> -} - -// ----- - -func.func @insert_element(%arg0: f32, %arg1: vector<4xf32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{expected position for 1-D vector}} - %0 = vector.insertelement %arg0, %arg1[] : vector<4xf32> -} - -// ----- - -func.func @insert_element(%arg0: f32, %arg1: vector<4x4xf32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{unexpected >1 vector rank}} - %0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<4x4xf32> -} - -// ----- - -func.func @insert_element_wrong_type(%arg0: i32, %arg1: vector<4xf32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{'vector.insertelement' op failed to verify that source operand type matches element type of result}} - %0 = "vector.insertelement" (%arg0, %arg1, %c) : (i32, vector<4xf32>, i32) -> (vector<4xf32>) -} - -// ----- - func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute of rank no greater than dest vector rank}} %1 = vector.insert %a, %b[3, 3, 3, 3, 3, 3] : f32 into vector<4x8x16xf32> diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 6a56116..625ffc1 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -199,22 +199,6 @@ func.func @shuffle_poison_mask(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<4 return %1 : vector<4xf32> } -// CHECK-LABEL: @extract_element_0d -func.func @extract_element_0d(%a: vector<f32>) -> f32 { - // CHECK-NEXT: vector.extractelement %{{.*}}[] : vector<f32> - %1 = vector.extractelement %a[] : vector<f32> - return %1 : f32 -} - -// CHECK-LABEL: @extract_element -func.func @extract_element(%a: vector<16xf32>) -> f32 { - // CHECK: %[[C15:.*]] = arith.constant 15 : i32 - %c = arith.constant 15 : i32 - // CHECK-NEXT: vector.extractelement %{{.*}}[%[[C15]] : i32] : vector<16xf32> - %1 = vector.extractelement %a[%c : i32] : vector<16xf32> - return %1 : f32 -} - // CHECK-LABEL: @extract_const_idx func.func @extract_const_idx(%arg0: vector<4x8x16xf32>) -> (vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32) { @@ -256,22 +240,6 @@ func.func @extract_poison_idx(%a: vector<4x5xf32>) -> f32 { return %0 : f32 } -// CHECK-LABEL: @insert_element_0d -func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> { - // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32> - %1 = vector.insertelement %a, %b[] : vector<f32> - return %1 : vector<f32> -} - -// CHECK-LABEL: @insert_element -func.func @insert_element(%a: f32, %b: vector<16xf32>) -> vector<16xf32> { - // CHECK: %[[C15:.*]] = arith.constant 15 : i32 - %c = arith.constant 15 : i32 - // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[%[[C15]] : i32] : vector<16xf32> - %1 = vector.insertelement %a, %b[%c : i32] : vector<16xf32> - return %1 : vector<16xf32> -} - // CHECK-LABEL: @insert_const_idx func.func @insert_const_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> { diff --git a/mlir/test/IR/test-pattern-logging-listener.mlir b/mlir/test/IR/test-pattern-logging-listener.mlir index c521110..d3d42e3 100644 --- a/mlir/test/IR/test-pattern-logging-listener.mlir +++ b/mlir/test/IR/test-pattern-logging-listener.mlir @@ -8,15 +8,15 @@ // {anonymous_namespace} vs `anonymous_namespace` (and maybe others?) on the // various platforms. -// CHECK: [pattern-logging-listener] +// CHECK: [pattern-logging-listener:1] // CHECK-SAME: ::ReplaceWithNewOp | notifyOperationInserted | test.new_op -// CHECK: [pattern-logging-listener] +// CHECK: [pattern-logging-listener:1] // CHECK-SAME: ::ReplaceWithNewOp | notifyOperationReplaced (with values) | test.replace_with_new_op -// CHECK: [pattern-logging-listener] +// CHECK: [pattern-logging-listener:1] // CHECK-SAME: ::ReplaceWithNewOp | notifyOperationModified | arith.addi -// CHECK: [pattern-logging-listener] +// CHECK: [pattern-logging-listener:1] // CHECK-SAME: ::ReplaceWithNewOp | notifyOperationModified | arith.addi -// CHECK: [pattern-logging-listener] +// CHECK: [pattern-logging-listener:1] // CHECK-SAME: ::ReplaceWithNewOp | notifyOperationErased | test.replace_with_new_op func.func @replace_with_new_op() -> i32 { %a = "test.replace_with_new_op"() : () -> (i32) diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir index 05e6782..a7bb039 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir @@ -81,21 +81,21 @@ func.func private @matmul(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tenso func.func private @mmt4d(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> { %zero = arith.constant 0 : i32 - %A_pack_empty = tensor.empty() : tensor<2x16x8x1xi32> + %A_pack_empty = tensor.empty() : tensor<1x16x8x1xi32> %B_pack_empty = tensor.empty() : tensor<2x16x8x1xi32> - %C_pack_empty = tensor.empty() : tensor<2x2x8x8xi32> + %C_pack_empty = tensor.empty() : tensor<1x2x8x8xi32> // Pack matrices - %A_pack = linalg.pack %A padding_value(%zero : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %A_pack_empty : tensor<7x16xi32> -> tensor<2x16x8x1xi32> + %A_pack = linalg.pack %A padding_value(%zero : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %A_pack_empty : tensor<7x16xi32> -> tensor<1x16x8x1xi32> %B_pack = linalg.pack %B padding_value(%zero : i32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 1] into %B_pack_empty : tensor<16x13xi32> -> tensor<2x16x8x1xi32> - %C_pack = linalg.pack %C padding_value(%zero : i32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_pack_empty : tensor<7x13xi32> -> tensor<2x2x8x8xi32> + %C_pack = linalg.pack %C padding_value(%zero : i32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_pack_empty : tensor<7x13xi32> -> tensor<1x2x8x8xi32> // MMT4D - %mmt4d = linalg.mmt4d ins(%A_pack, %B_pack : tensor<2x16x8x1xi32>, tensor<2x16x8x1xi32>) outs(%C_pack : tensor<2x2x8x8xi32>) -> tensor<2x2x8x8xi32> + %mmt4d = linalg.mmt4d ins(%A_pack, %B_pack : tensor<1x16x8x1xi32>, tensor<2x16x8x1xi32>) outs(%C_pack : tensor<1x2x8x8xi32>) -> tensor<1x2x8x8xi32> // Unpack output %C_out_empty = tensor.empty() : tensor<7x13xi32> - %C_out_unpack = linalg.unpack %mmt4d outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_out_empty : tensor<2x2x8x8xi32> -> tensor<7x13xi32> + %C_out_unpack = linalg.unpack %mmt4d outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_out_empty : tensor<1x2x8x8xi32> -> tensor<7x13xi32> return %C_out_unpack : tensor<7x13xi32> } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir index 6e2a82b..6ec1031 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir @@ -4,14 +4,14 @@ // RUN: FileCheck %s func.func @extract_element_0d(%a: vector<f32>) { - %1 = vector.extractelement %a[] : vector<f32> + %1 = vector.extract %a[] : f32 from vector<f32> // CHECK: 42 vector.print %1: f32 return } func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> (vector<f32>) { - %1 = vector.insertelement %a, %b[] : vector<f32> + %1 = vector.insert %a, %b[] : f32 into vector<f32> return %1: vector<f32> } @@ -58,9 +58,9 @@ func.func @broadcast_0d(%a: f32) { func.func @bitcast_0d() { %0 = arith.constant 42 : i32 %1 = arith.constant dense<0> : vector<i32> - %2 = vector.insertelement %0, %1[] : vector<i32> + %2 = vector.insert %0, %1[] : i32 into vector<i32> %3 = vector.bitcast %2 : vector<i32> to vector<f32> - %4 = vector.extractelement %3[] : vector<f32> + %4 = vector.extract %3[] : f32 from vector<f32> %5 = arith.bitcast %4 : f32 to i32 // CHECK: 42 vector.print %5: i32 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir index b69a200..eb99886 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir @@ -72,7 +72,7 @@ func.func @za0_d_f64() -> i32 { %row = vector.load %mem2[%vnum, %c0] : memref<?x?xf64>, vector<[2]xf64> %inner_add_reduce = scf.for %offset = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_0_f64) -> (f64) { - %t = vector.extractelement %row[%offset : index] : vector<[2]xf64> + %t = vector.extract %row[%offset] : f64 from vector<[2]xf64> %inner_add_reduce_next = arith.addf %inner_iter, %t : f64 scf.yield %inner_add_reduce_next : f64 } @@ -102,7 +102,7 @@ func.func @za0_d_f64() -> i32 { %cmp = arith.cmpf one, %row_1, %row_2 : vector<[2]xf64> %inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) { - %t = vector.extractelement %cmp[%i : index] : vector<[2]xi1> + %t = vector.extract %cmp[%i] : i1 from vector<[2]xi1> %t_i64 = arith.extui %t : i1 to i64 %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64 scf.yield %inner_mul_reduce_next : i64 @@ -125,7 +125,7 @@ func.func @za0_d_f64() -> i32 { %cmp = arith.cmpf oeq, %row_1, %row_2 : vector<[2]xf64> %inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) { - %t = vector.extractelement %cmp[%i : index] : vector<[2]xi1> + %t = vector.extract %cmp[%i] : i1 from vector<[2]xi1> %t_i64 = arith.extui %t : i1 to i64 %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64 scf.yield %inner_mul_reduce_next : i64 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir index 697fb90..ad8e321 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir @@ -36,7 +36,7 @@ func.func @entry() -> i32 { %row = vector.load %za_b[%vnum, %c0] : memref<?x?xi8>, vector<[16]xi8> %inner_mul_reduce = scf.for %offset = %c0 to %svl_b step %c1_index iter_args(%inner_iter = %init_1) -> (i64) { - %t = vector.extractelement %row[%offset : index] : vector<[16]xi8> + %t = vector.extract %row[%offset] : i8 from vector<[16]xi8> %t_i64 = arith.extui %t : i8 to i64 %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64 scf.yield %inner_mul_reduce_next : i64 @@ -64,7 +64,7 @@ func.func @entry() -> i32 { %row = vector.load %za_b[%vnum, %c0] : memref<?x?xi8>, vector<[16]xi8> %inner_mul_reduce = scf.for %offset = %c0 to %svl_b step %c1_index iter_args(%inner_iter = %init_1) -> (i64) { - %t = vector.extractelement %row[%offset : index] : vector<[16]xi8> + %t = vector.extract %row[%offset] : i8 from vector<[16]xi8> %t_i64 = arith.extui %t : i8 to i64 %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64 scf.yield %inner_mul_reduce_next : i64 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot.mlir index 53a7282..aff272c2 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot.mlir @@ -11,8 +11,8 @@ func.func @entry() -> i32 { %b = arith.constant dense<[9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : vector<8xf32> %r = x86vector.avx.intr.dot %a, %b : vector<8xf32> - %1 = vector.extractelement %r[%i0 : i32]: vector<8xf32> - %2 = vector.extractelement %r[%i4 : i32]: vector<8xf32> + %1 = vector.extract %r[%i0] : f32 from vector<8xf32> + %2 = vector.extract %r[%i4] : f32 from vector<8xf32> %d = arith.addf %1, %2 : f32 // CHECK: ( 110, 110, 110, 110, 382, 382, 382, 382 ) diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/sparse-dot-product.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/sparse-dot-product.mlir index bf1caaa..1c56990 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/sparse-dot-product.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/sparse-dot-product.mlir @@ -196,13 +196,13 @@ func.func @memref_dot_optimized(%m_A : memref<?xi64>, %m_B : memref<?xf64>, iter_args(%sum0 = %data_zero, %b_start0 = %c0) -> (f64, index) { %v_A = vector.transfer_read %m_A[%a], %index_padding : memref<?xi64>, vector<8xi64> - %segA_min = vector.extractelement %v_A[%i0 : i32] : vector<8xi64> + %segA_min = vector.extract %v_A[%i0] : i64 from vector<8xi64> %r1, %next_b_start0 = scf.for %b = %b_start0 to %N step %c8 iter_args(%sum1 = %sum0, %b_start1 = %b_start0) -> (f64, index) { %v_C = vector.transfer_read %m_C[%b], %index_padding : memref<?xi64>, vector<8xi64> - %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64> + %segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64> %seg1_done = arith.cmpi "slt", %segB_max, %segA_min : i64 %r2, %next_b_start1 = scf.if %seg1_done -> (f64, index) { @@ -273,10 +273,10 @@ func.func @memref_dot_while(%m_A : memref<?xi64>, %m_B : memref<?xf64>, %v_C = vector.transfer_read %m_C[%b1], %index_padding : memref<?xi64>, vector<8xi64> - %segA_min = vector.extractelement %v_A[%i0 : i32] : vector<8xi64> - %segA_max = vector.extractelement %v_A[%i7 : i32] : vector<8xi64> - %segB_min = vector.extractelement %v_C[%i0 : i32] : vector<8xi64> - %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64> + %segA_min = vector.extract %v_A[%i0] : i64 from vector<8xi64> + %segA_max = vector.extract %v_A[%i7] : i64 from vector<8xi64> + %segB_min = vector.extract %v_C[%i0] : i64 from vector<8xi64> + %segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64> %seg1_done = arith.cmpi "slt", %segB_max, %segA_min : i64 %r2, %a2, %b2 = scf.if %seg1_done -> (f64, index, index) { @@ -370,8 +370,8 @@ func.func @memref_dot_while_branchless(%m_A : memref<?xi64>, %m_B : memref<?xf64 -> f64 %r2 = arith.addf %r1, %subresult : f64 - %segA_max = vector.extractelement %v_A[%i7 : i32] : vector<8xi64> - %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64> + %segA_max = vector.extract %v_A[%i7] : i64 from vector<8xi64> + %segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64> %cond_a = arith.cmpi "sle", %segA_max, %segB_max : i64 %cond_a_i64 = arith.extui %cond_a : i1 to i64 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/compress.mlir b/mlir/test/Integration/Dialect/Vector/CPU/compress.mlir index e9a66cc..1683fa5 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/compress.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/compress.mlir @@ -28,8 +28,7 @@ func.func @printmem16(%A: memref<?xf32>) { %mem = scf.for %i = %c0 to %c16 step %c1 iter_args(%m_iter = %m) -> (vector<16xf32>) { %c = memref.load %A[%i] : memref<?xf32> - %i32 = arith.index_cast %i : index to i32 - %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<16xf32> + %m_new = vector.insert %c, %m_iter[%i] : f32 into vector<16xf32> scf.yield %m_new : vector<16xf32> } vector.print %mem : vector<16xf32> @@ -49,7 +48,7 @@ func.func @entry() { memref.store %z, %A[%i] : memref<?xf32> %i32 = arith.index_cast %i : index to i32 %fi = arith.sitofp %i32 : i32 to f32 - %v_new = vector.insertelement %fi, %v_iter[%i32 : i32] : vector<16xf32> + %v_new = vector.insert %fi, %v_iter[%i] : f32 into vector<16xf32> scf.yield %v_new : vector<16xf32> } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/maskedstore.mlir b/mlir/test/Integration/Dialect/Vector/CPU/maskedstore.mlir index 2dc00df..826da53 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/maskedstore.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/maskedstore.mlir @@ -28,8 +28,7 @@ func.func @printmem16(%A: memref<?xf32>) { %mem = scf.for %i = %c0 to %c16 step %c1 iter_args(%m_iter = %m) -> (vector<16xf32>) { %c = memref.load %A[%i] : memref<?xf32> - %i32 = arith.index_cast %i : index to i32 - %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<16xf32> + %m_new = vector.insert %c, %m_iter[%i] : f32 into vector<16xf32> scf.yield %m_new : vector<16xf32> } vector.print %mem : vector<16xf32> @@ -53,7 +52,7 @@ func.func @entry() { iter_args(%v_iter = %v) -> (vector<16xf32>) { %i32 = arith.index_cast %i : index to i32 %fi = arith.sitofp %i32 : i32 to f32 - %v_new = vector.insertelement %fi, %v_iter[%i32 : i32] : vector<16xf32> + %v_new = vector.insert %fi, %v_iter[%i] : f32 into vector<16xf32> scf.yield %v_new : vector<16xf32> } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/scatter.mlir b/mlir/test/Integration/Dialect/Vector/CPU/scatter.mlir index 54b6e69..22b5eef 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/scatter.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/scatter.mlir @@ -21,8 +21,7 @@ func.func @printmem8(%A: memref<?xf32>) { %mem = scf.for %i = %c0 to %c8 step %c1 iter_args(%m_iter = %m) -> (vector<8xf32>) { %c = memref.load %A[%i] : memref<?xf32> - %i32 = arith.index_cast %i : index to i32 - %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<8xf32> + %m_new = vector.insert %c, %m_iter[%i] : f32 into vector<8xf32> scf.yield %m_new : vector<8xf32> } vector.print %mem : vector<8xf32> diff --git a/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir index 2393bd1..639eed4 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir @@ -200,7 +200,7 @@ func.func @entry() { // CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 ) // 6. Read a scalar from a 2D memref and broadcast the value to a 1D vector. - // Generates a loop with vector.insertelement. + // Generates a loop with vector.insert. call @transfer_read_1d_broadcast(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> () // CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ) diff --git a/mlir/test/Integration/GPU/Vulkan/vector-interleave.mlir b/mlir/test/Integration/GPU/Vulkan/vector-interleave.mlir index e665653..731bd5a 100644 --- a/mlir/test/Integration/GPU/Vulkan/vector-interleave.mlir +++ b/mlir/test/Integration/GPU/Vulkan/vector-interleave.mlir @@ -26,17 +26,17 @@ module attributes { %val2 = memref.load %arg1[%idx0] : memref<2xi32> %val3 = memref.load %arg1[%idx1] : memref<2xi32> - %lhs0 = vector.insertelement %val0, %lhs[%idx0 : index] : vector<2xi32> - %lhs1 = vector.insertelement %val1, %lhs0[%idx1 : index] : vector<2xi32> - %rhs0 = vector.insertelement %val2, %rhs[%idx0 : index] : vector<2xi32> - %rhs1 = vector.insertelement %val3, %rhs0[%idx1 : index] : vector<2xi32> + %lhs0 = vector.insert %val0, %lhs[%idx0] : i32 into vector<2xi32> + %lhs1 = vector.insert %val1, %lhs0[%idx1] : i32 into vector<2xi32> + %rhs0 = vector.insert %val2, %rhs[%idx0] : i32 into vector<2xi32> + %rhs1 = vector.insert %val3, %rhs0[%idx1] : i32 into vector<2xi32> %interleave = vector.interleave %lhs1, %rhs1 : vector<2xi32> -> vector<4xi32> - %res0 = vector.extractelement %interleave[%idx0 : index] : vector<4xi32> - %res1 = vector.extractelement %interleave[%idx1 : index] : vector<4xi32> - %res2 = vector.extractelement %interleave[%idx2 : index] : vector<4xi32> - %res3 = vector.extractelement %interleave[%idx3 : index] : vector<4xi32> + %res0 = vector.extract %interleave[%idx0] : i32 from vector<4xi32> + %res1 = vector.extract %interleave[%idx1] : i32 from vector<4xi32> + %res2 = vector.extract %interleave[%idx2] : i32 from vector<4xi32> + %res3 = vector.extract %interleave[%idx3] : i32 from vector<4xi32> memref.store %res0, %arg2[%idx0]: memref<4xi32> memref.store %res1, %arg2[%idx1]: memref<4xi32> diff --git a/mlir/test/Integration/GPU/Vulkan/vector-shuffle.mlir b/mlir/test/Integration/GPU/Vulkan/vector-shuffle.mlir index dc53fe3..c1b7dba 100644 --- a/mlir/test/Integration/GPU/Vulkan/vector-shuffle.mlir +++ b/mlir/test/Integration/GPU/Vulkan/vector-shuffle.mlir @@ -26,17 +26,17 @@ module attributes { %val2 = memref.load %arg1[%idx0] : memref<2xi32> %val3 = memref.load %arg1[%idx1] : memref<2xi32> - %lhs0 = vector.insertelement %val0, %lhs[%idx0 : index] : vector<2xi32> - %lhs1 = vector.insertelement %val1, %lhs0[%idx1 : index] : vector<2xi32> - %rhs0 = vector.insertelement %val2, %rhs[%idx0 : index] : vector<2xi32> - %rhs1 = vector.insertelement %val3, %rhs0[%idx1 : index] : vector<2xi32> + %lhs0 = vector.insert %val0, %lhs[%idx0] : i32 into vector<2xi32> + %lhs1 = vector.insert %val1, %lhs0[%idx1] : i32 into vector<2xi32> + %rhs0 = vector.insert %val2, %rhs[%idx0] : i32 into vector<2xi32> + %rhs1 = vector.insert %val3, %rhs0[%idx1] : i32 into vector<2xi32> %shuffle = vector.shuffle %lhs1, %rhs1[2, 1, 3, 3] : vector<2xi32>, vector<2xi32> - %res0 = vector.extractelement %shuffle[%idx0 : index] : vector<4xi32> - %res1 = vector.extractelement %shuffle[%idx1 : index] : vector<4xi32> - %res2 = vector.extractelement %shuffle[%idx2 : index] : vector<4xi32> - %res3 = vector.extractelement %shuffle[%idx3 : index] : vector<4xi32> + %res0 = vector.extract %shuffle[%idx0] : i32 from vector<4xi32> + %res1 = vector.extract %shuffle[%idx1] : i32 from vector<4xi32> + %res2 = vector.extract %shuffle[%idx2] : i32 from vector<4xi32> + %res3 = vector.extract %shuffle[%idx3] : i32 from vector<4xi32> memref.store %res0, %arg2[%idx0]: memref<4xi32> memref.store %res1, %arg2[%idx1]: memref<4xi32> diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir index cdbca72..7888462 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir @@ -595,16 +595,17 @@ module attributes {transform.with_named_sequence} { // ----- -// It is valid to fuse the pack op with padding semantics if the tiled -// dimensions do not need padding. +// It is valid to fuse the pack op with padding semantics if it is a perfect +// tiling case. func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<22x2x3x16xf32> { - %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) { - %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> - %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> - %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32> + %0 = scf.forall (%arg2, %arg3) = (0, 0) to (64, 32) step (15, 16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) { + %size = affine.min affine_map<(d0) -> (-d0 + 64, 15)>(%arg2) + %src = tensor.extract_slice %arg0[%arg2, %arg3] [%size, 16] [1, 1] : tensor<64x32xf32> to tensor<?x16xf32> + %dest = tensor.extract_slice %arg4[%arg2, %arg3] [%size, 16] [1, 1] : tensor<64x32xf32> to tensor<?x16xf32> + %2 = linalg.exp ins(%src : tensor<?x16xf32>) outs(%dest : tensor<?x16xf32>) -> tensor<?x16xf32> scf.forall.in_parallel { - tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32> + tensor.parallel_insert_slice %2 into %arg4[%arg2, %arg3] [%size, 16] [1, 1] : tensor<?x16xf32> into tensor<64x32xf32> } } %1 = tensor.empty() : tensor<22x2x3x16xf32> @@ -621,109 +622,39 @@ module attributes {transform.with_named_sequence} { transform.yield } } -// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (-d0 + 64, 15)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0 floordiv 3)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0) -> (d0 ceildiv 3)> +// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0) -> (d0 floordiv 16)> // CHECK: func.func @fuse_pack_consumer_with_padding_semantics( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] // CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<22x2x3x16xf32> // CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16) -// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]]) -// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1] -// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: %{{.*}}:2 = scf.forall (%[[I:.*]], %[[J:.*]]) = (0, 0) to (64, 32) step (15, 16) +// CHECK-SAME: shared_outs(%[[ELEM_OUT:.*]] = %[[ARG1]], %[[PACK_OUT:.*]] = %[[OUT_INIT]]) +// CHECK: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]]) +// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1] +// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1] // CHECK: %[[ELEM:.*]] = linalg.exp // CHECK-SAME: ins(%[[ELEM_SRC]] // CHECK-SAME: outs(%[[ELEM_DEST]] -// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) -// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1] -// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]] -// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32) -// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16] -// CHECK-SAME: into %[[TILED_PACK_DEST]] -// CHECK: scf.forall.in_parallel { -// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] -// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1] - -// ----- - -// It is valid to fuse the pack if the dimension is not tiled even when it needs -// extra padding. - -func.func @fuse_pack_consumer_with_untiled_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<33x2x3x16xf32> { - %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) { - %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> - %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> - %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32> - scf.forall.in_parallel { - tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32> - } - } - %1 = tensor.empty() : tensor<33x2x3x16xf32> - %cst = arith.constant 0.000000e+00 : f32 - %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<33x2x3x16xf32> - return %pack : tensor<33x2x3x16xf32> -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield - } -} -// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> -// CHECK: func.func @fuse_pack_consumer_with_untiled_extra_padding( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<33x2x3x16xf32> -// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16) -// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]]) -// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1] -// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] -// CHECK: %[[ELEM:.*]] = linalg.exp -// CHECK-SAME: ins(%[[ELEM_SRC]] -// CHECK-SAME: outs(%[[ELEM_DEST]] -// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) -// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1] -// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]] +// CHECK-DAG: %[[D0_OFFSET:.*]] = affine.apply #[[MAP1]](%[[I]]) +// CHECK-DAG: %[[D0_SIZE:.*]] = affine.apply #[[MAP2]](%[[SIZE]]) +// CHECK-DAG: %[[D1_OFFSET:.*]] = affine.apply #[[MAP3]](%[[J]]) +// CHECK-DAG: %[[PACK_INIT:.*]] = tensor.extract_slice %[[PACK_OUT]] +// CHECK-SAME: [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1] +// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] // CHECK-SAME: padding_value(%[[PAD_VAL]] : f32) // CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16] // CHECK-SAME: into %[[TILED_PACK_DEST]] // CHECK: scf.forall.in_parallel { -// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] -// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1] - -// ----- - -// If the dimension is tiled and it needs extra padding, do not fuse the pack -// op. - -func.func @nofuse_pack_consumer_with_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> { - %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) { - %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> - %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> - %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32> - scf.forall.in_parallel { - // expected-error @below {{failed to fuse consumer of slice}} - tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32> - } - } - %1 = tensor.empty() : tensor<23x32x3x16xf32> - %cst = arith.constant 0.000000e+00 : f32 - %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32> - return %pack : tensor<23x32x3x16xf32> -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield - } -} +// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT]] +// CHECK-SAME: [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1] // ----- diff --git a/mlir/test/Target/LLVMIR/omptarget-debug-reduc-fn-loc.mlir b/mlir/test/Target/LLVMIR/omptarget-debug-reduc-fn-loc.mlir deleted file mode 100644 index d889ef4..0000000 --- a/mlir/test/Target/LLVMIR/omptarget-debug-reduc-fn-loc.mlir +++ /dev/null @@ -1,121 +0,0 @@ -// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s - -module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui32>>, llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} { - omp.private {type = private} @_QFEi_private_i32 : i32 loc(#loc1) - omp.declare_reduction @add_reduction_i32 : i32 init { - ^bb0(%arg0: i32 loc("test.f90":8:7)): - %0 = llvm.mlir.constant(0 : i32) : i32 loc(#loc2) - omp.yield(%0 : i32) loc(#loc2) - } combiner { - ^bb0(%arg0: i32 loc("test.f90":8:7), %arg1: i32 loc("test.f90":8:7)): - %0 = llvm.add %arg0, %arg1 : i32 loc(#loc2) - omp.yield(%0 : i32) loc(#loc2) - } loc(#loc2) - llvm.func @_QQmain() { - %0 = llvm.mlir.constant(1 : i64) : i64 loc(#loc4) - %1 = llvm.alloca %0 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr<5> loc(#loc4) - %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr loc(#loc4) - %3 = llvm.mlir.constant(1 : i64) : i64 loc(#loc1) - %4 = llvm.alloca %3 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr<5> loc(#loc1) - %5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr loc(#loc1) - %6 = llvm.mlir.constant(8191 : index) : i64 loc(#loc5) - %7 = llvm.mlir.constant(0 : index) : i64 loc(#loc5) - %8 = llvm.mlir.constant(1 : index) : i64 loc(#loc5) - %9 = llvm.mlir.constant(0 : i32) : i32 loc(#loc5) - %10 = llvm.mlir.constant(8192 : index) : i64 loc(#loc5) - %11 = llvm.mlir.addressof @_QFEarr : !llvm.ptr<1> loc(#loc6) - %12 = llvm.addrspacecast %11 : !llvm.ptr<1> to !llvm.ptr loc(#loc6) - llvm.store %9, %2 : i32, !llvm.ptr loc(#loc7) - %15 = omp.map.info var_ptr(%2 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "x"} loc(#loc4) - %16 = omp.map.info var_ptr(%5 : !llvm.ptr, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr {name = "i"} loc(#loc7) - %17 = omp.map.bounds lower_bound(%7 : i64) upper_bound(%6 : i64) extent(%10 : i64) stride(%8 : i64) start_idx(%8 : i64) loc(#loc7) - %18 = omp.map.info var_ptr(%12 : !llvm.ptr, !llvm.array<8192 x i32>) map_clauses(implicit, tofrom) capture(ByRef) bounds(%17) -> !llvm.ptr {name = "arr"} loc(#loc7) - omp.target map_entries(%15 -> %arg0, %16 -> %arg1, %18 -> %arg2 : !llvm.ptr, !llvm.ptr, !llvm.ptr) { - %19 = llvm.mlir.constant(8192 : i32) : i32 loc(#loc5) - %20 = llvm.mlir.constant(1 : i32) : i32 loc(#loc5) - %21 = llvm.mlir.constant(8192 : index) : i64 loc(#loc6) - omp.teams reduction(@add_reduction_i32 %arg0 -> %arg3 : !llvm.ptr) { - omp.parallel private(@_QFEi_private_i32 %arg1 -> %arg4 : !llvm.ptr) { - omp.distribute { - omp.wsloop reduction(@add_reduction_i32 %arg3 -> %arg5 : !llvm.ptr) { - omp.loop_nest (%arg6) : i32 = (%20) to (%19) inclusive step (%20) { - llvm.store %arg6, %arg4 : i32, !llvm.ptr loc(#loc2) - %22 = llvm.load %arg5 : !llvm.ptr -> i32 loc(#loc8) - %23 = llvm.load %arg4 : !llvm.ptr -> i32 loc(#loc8) - %34 = llvm.add %22, %23 : i32 loc(#loc8) - llvm.store %34, %arg5 : i32, !llvm.ptr loc(#loc8) - omp.yield loc(#loc2) - } loc(#loc2) - } {omp.composite} loc(#loc2) - } {omp.composite} loc(#loc2) - omp.terminator loc(#loc2) - } {omp.composite} loc(#loc2) - omp.terminator loc(#loc2) - } loc(#loc2) - omp.terminator loc(#loc2) - } loc(#loc13) - llvm.return loc(#loc9) - } loc(#loc12) - llvm.mlir.global internal @_QFEarr() {addr_space = 1 : i32} : !llvm.array<8192 x i32> { - %0 = llvm.mlir.zero : !llvm.array<8192 x i32> loc(#loc6) - llvm.return %0 : !llvm.array<8192 x i32> loc(#loc6) - } loc(#loc6) -} loc(#loc) - -#loc = loc("test.f90":4:18) -#loc1 = loc("test.f90":4:18) -#loc2 = loc("test.f90":8:7) -#loc3 = loc("test.f90":1:7) -#loc4 = loc("test.f90":3:18) -#loc5 = loc(unknown) -#loc6 = loc("test.f90":5:18) -#loc7 = loc("test.f90":6:7) -#loc8 = loc("test.f90":10:7) -#loc9 = loc("test.f90":16:7) - -#di_file = #llvm.di_file<"target7.f90" in ""> -#di_null_type = #llvm.di_null_type -#di_compile_unit = #llvm.di_compile_unit<id = distinct[0]<>, - sourceLanguage = DW_LANG_Fortran95, file = #di_file, producer = "flang", - isOptimized = false, emissionKind = LineTablesOnly> -#di_subroutine_type = #llvm.di_subroutine_type< - callingConvention = DW_CC_program, types = #di_null_type> -#di_subprogram = #llvm.di_subprogram<id = distinct[1]<>, - compileUnit = #di_compile_unit, scope = #di_file, name = "main", - file = #di_file, subprogramFlags = "Definition|MainSubprogram", - type = #di_subroutine_type> -#di_subprogram1 = #llvm.di_subprogram<compileUnit = #di_compile_unit, - name = "target", file = #di_file, subprogramFlags = "Definition", - type = #di_subroutine_type> - - -#loc12 = loc(fused<#di_subprogram>[#loc3]) -#loc13 = loc(fused<#di_subprogram1>[#loc2]) - -// CHECK-DAG: define internal void @_omp_reduction_shuffle_and_reduce_func -// CHECK-NOT: !dbg -// CHECK: } -// CHECK-DAG: define internal void @_omp_reduction_inter_warp_copy_func -// CHECK-NOT: !dbg -// CHECK: } -// CHECK-DAG: define internal void @"__omp_offloading_{{.*}}__QQmain_l8_omp$reduction$reduction_func.1" -// CHECK-NOT: !dbg -// CHECK: } -// CHECK-DAG: define internal void @_omp_reduction_shuffle_and_reduce_func.2 -// CHECK-NOT: !dbg -// CHECK: } -// CHECK-DAG: define internal void @_omp_reduction_inter_warp_copy_func.3 -// CHECK-NOT: !dbg -// CHECK: } -// CHECK-DAG: define internal void @_omp_reduction_list_to_global_copy_func -// CHECK-NOT: !dbg -// CHECK: } -// CHECK-DAG: define internal void @_omp_reduction_list_to_global_reduce_func -// CHECK-NOT: !dbg -// CHECK: } -// CHECK-DAG: define internal void @_omp_reduction_global_to_list_copy_func -// CHECK-NOT: !dbg -// CHECK: } -// CHECK-DAG: define internal void @_omp_reduction_global_to_list_reduce_func -// CHECK-NOT: !dbg -// CHECK: } diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir index 76d34c2..6aca11e 100644 --- a/mlir/test/Target/SPIRV/constant.mlir +++ b/mlir/test/Target/SPIRV/constant.mlir @@ -1,6 +1,7 @@ // RUN: mlir-translate --no-implicit-module --split-input-file --test-spirv-roundtrip %s | FileCheck %s +// RUN: %if spirv-tools %{ mlir-translate -no-implicit-module --split-input-file -serialize-spirv %s | spirv-val %} -spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { +spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int64, Int16, Int8, Float64, Float16, CooperativeMatrixKHR], [SPV_KHR_vulkan_memory_model, SPV_KHR_cooperative_matrix]> { // CHECK-LABEL: @bool_const spirv.func @bool_const() -> () "None" { // CHECK: spirv.Constant true @@ -305,6 +306,8 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { %coop = spirv.Constant dense<4> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc> spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc> } + + spirv.EntryPoint "GLCompute" @bool_const } // ----- diff --git a/mlir/test/Target/SPIRV/lit.local.cfg b/mlir/test/Target/SPIRV/lit.local.cfg new file mode 100644 index 0000000..6d44394 --- /dev/null +++ b/mlir/test/Target/SPIRV/lit.local.cfg @@ -0,0 +1,4 @@ +if config.spirv_tools_tests: + config.available_features.add("spirv-tools") + config.substitutions.append(("spirv-as", os.path.join(config.llvm_tools_dir, "spirv-as"))) + config.substitutions.append(("spirv-val", os.path.join(config.llvm_tools_dir, "spirv-val"))) diff --git a/mlir/test/Transforms/compose-subview.mlir b/mlir/test/Transforms/compose-subview.mlir index 53fbb8a..d6fa442 100644 --- a/mlir/test/Transforms/compose-subview.mlir +++ b/mlir/test/Transforms/compose-subview.mlir @@ -1,9 +1,9 @@ // RUN: mlir-opt %s -test-compose-subview -split-input-file | FileCheck %s // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> { +// CHECK-SAME: %[[input:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> { func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> { - // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][3, 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: 3456>> + // CHECK: {{.*}} = memref.subview %[[input]][3, 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: 3456>> %0 = memref.subview %input[2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: 2304>> %1 = memref.subview %0[1, 128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: 2304>> to memref<1x128xf32, strided<[1024, 1], offset: 3456>> return %1 : memref<1x128xf32, strided<[1024, 1], offset: 3456>> @@ -12,9 +12,9 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, stri // ----- // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> { +// CHECK-SAME: %[[input:.*]]: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> { func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> { - // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][3, 673] [1, 10] [1, 1] : memref<4x1024xf32> to memref<1x10xf32, strided<[1024, 1], offset: 3745>> + // CHECK: {{.*}} = memref.subview %[[input]][3, 673] [1, 10] [1, 1] : memref<4x1024xf32> to memref<1x10xf32, strided<[1024, 1], offset: 3745>> %0 = memref.subview %input[1, 512] [3, 256] [1, 1] : memref<4x1024xf32> to memref<3x256xf32, strided<[1024, 1], offset: 1536>> %1 = memref.subview %0[1, 128] [2, 128] [1, 1] : memref<3x256xf32, strided<[1024, 1], offset: 1536>> to memref<2x128xf32, strided<[1024, 1], offset: 2688>> %2 = memref.subview %1[1, 33] [1, 10] [1, 1] : memref<2x128xf32, strided<[1024, 1], offset: 2688>> to memref<1x10xf32, strided<[1024, 1], offset: 3745>> @@ -24,12 +24,12 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x10xf32, strid // ----- // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> { +// CHECK-SAME: %[[input:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> { func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> { - // CHECK: %[[VAL_1:.*]] = arith.constant 3 : index + // CHECK: %[[C3:.*]] = arith.constant 3 : index %cst_1 = arith.constant 1 : index %cst_2 = arith.constant 2 : index - // CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>> + // CHECK: {{.*}} = memref.subview %[[input]]{{\[}}%[[C3]], 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>> %0 = memref.subview %input[%cst_2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: ?>> %1 = memref.subview %0[%cst_1, 128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: ?>> to memref<1x128xf32, strided<[1024, 1], offset: ?>> return %1 : memref<1x128xf32, strided<[1024, 1], offset: ?>> @@ -38,13 +38,13 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, stri // ----- // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> { +// CHECK-SAME: %[[input:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> { func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> { - // CHECK: %[[VAL_1:.*]] = arith.constant 3 : index + // CHECK: %[[C3:.*]] = arith.constant 3 : index %cst_2 = arith.constant 2 : index - // CHECK: %[[VAL_2:.*]] = arith.constant 384 : index + // CHECK: %[[C384:.*]] = arith.constant 384 : index %cst_128 = arith.constant 128 : index - // CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_2]]] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>> + // CHECK: {{.*}} = memref.subview %[[input]]{{\[}}%[[C3]], %[[C384]]] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>> %0 = memref.subview %input[%cst_2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: ?>> %1 = memref.subview %0[1, %cst_128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: ?>> to memref<1x128xf32, strided<[1024, 1], offset: ?>> return %1 : memref<1x128xf32, strided<[1024, 1], offset: ?>> @@ -53,9 +53,9 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, stri // ----- // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<8x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> { +// CHECK-SAME: %[[input:.*]]: memref<8x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> { func.func @subview_strided(%input: memref<8x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> { - // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][4, 384] [1, 64] [4, 4] : memref<8x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: 4480>> + // CHECK: {{.*}} = memref.subview %[[input]][4, 384] [1, 64] [4, 4] : memref<8x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: 4480>> %0 = memref.subview %input[2, 256] [2, 256] [2, 2] : memref<8x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: 2304>> %1 = memref.subview %0[1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: 2304>> to memref<1x64xf32, strided<[4096, 4], offset: 4480>> return %1 : memref<1x64xf32, strided<[4096, 4], offset: 4480>> @@ -64,9 +64,9 @@ func.func @subview_strided(%input: memref<8x1024xf32>) -> memref<1x64xf32, strid // ----- // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> { +// CHECK-SAME: %[[input:.*]]: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> { func.func @subview_strided(%input: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> { - // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][7, 7] [2, 2] [8, 8] : memref<30x30xf32> to memref<2x2xf32, strided<[240, 8], offset: 217>> + // CHECK: {{.*}} = memref.subview %[[input]][7, 7] [2, 2] [8, 8] : memref<30x30xf32> to memref<2x2xf32, strided<[240, 8], offset: 217>> %0 = memref.subview %input[1, 1] [12, 12] [2, 2] : memref<30x30xf32> to memref<12x12xf32, strided<[60, 2], offset: 31>> %1 = memref.subview %0[1, 1] [5, 5] [2, 2] : memref<12x12xf32, strided<[60, 2], offset: 31>> to memref<5x5xf32, strided<[120, 4], offset: 93>> %2 = memref.subview %1[1, 1] [2, 2] [2, 2] : memref<5x5xf32, strided<[120, 4], offset: 93>> to memref<2x2xf32, strided<[240, 8], offset: 217>> @@ -76,13 +76,13 @@ func.func @subview_strided(%input: memref<30x30xf32>) -> memref<2x2xf32, strided // ----- // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> { +// CHECK-SAME: %[[input:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> { func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> { - // CHECK: %[[VAL_1:.*]] = arith.constant 4 : index + // CHECK: %[[C4:.*]] = arith.constant 4 : index %cst_2 = arith.constant 2 : index - // CHECK: %[[VAL_2:.*]] = arith.constant 384 : index + // CHECK: %[[C384:.*]] = arith.constant 384 : index %cst_64 = arith.constant 64 : index - // CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_2]]] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>> + // CHECK: {{.*}} = memref.subview %[[input]]{{\[}}%[[C4]], %[[C384]]] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>> %0 = memref.subview %input[%cst_2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: ?>> %1 = memref.subview %0[1, %cst_64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: ?>> to memref<1x64xf32, strided<[4096, 4], offset: ?>> return %1 : memref<1x64xf32, strided<[4096, 4], offset: ?>> @@ -91,13 +91,39 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strid // ----- // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> { +// CHECK-SAME: %[[input:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> { func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> { - // CHECK: %[[VAL_1:.*]] = arith.constant 4 : index + // CHECK: %[[C4:.*]] = arith.constant 4 : index %cst_1 = arith.constant 1 : index %cst_2 = arith.constant 2 : index - // CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], 384] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>> + // CHECK: {{.*}} = memref.subview %[[input]]{{\[}}%[[C4]], 384] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>> %0 = memref.subview %input[%cst_2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: ?>> %1 = memref.subview %0[%cst_1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: ?>> to memref<1x64xf32, strided<[4096, 4], offset: ?>> return %1 : memref<1x64xf32, strided<[4096, 4], offset: ?>> } + +// ----- + +// CHECK-LABEL: func.func @single_dynamic_size_subview( +// CHECK-SAME: %[[input:.*]]: memref<256x?xf32>, +// CHECK-SAME: %{{.*}}: index, +// CHECK-SAME: %[[SIZE_1:.*]]: index) -> memref<8x?xf32> { +func.func @single_dynamic_size_subview(%input: memref<256x?xf32>, %size0 : index, %size1 : index) -> memref<8x?xf32>{ + %subview = memref.subview %input[0, 0][8, %size0][1, 1] : memref<256x?xf32> to memref<8x?xf32> + %subview_1 = memref.subview %subview[0, 0][8, %size1][1, 1] : memref<8x?xf32> to memref<8x?xf32> + // CHECK: %{{.*}} = memref.subview %[[input]][0, 0] [8, %[[SIZE_1]]] [1, 1] : memref<256x?xf32> to memref<8x?xf32> + return %subview_1 : memref<8x?xf32> +} + +// ----- + +// CHECK-LABEL: func.func @all_dynamic_size_subview( +// CHECK-SAME: %[[input:.*]]: memref<256x?xf32>, +// CHECK-SAME: %{{.*}}: index, +// CHECK-SAME: %[[SIZE1:.*]]: index) -> memref<?x?xf32> { +func.func @all_dynamic_size_subview(%input: memref<256x?xf32>, %size0 : index, %size1 : index) -> memref<?x?xf32>{ + %subview = memref.subview %input[0, 0][%size0, %size0][1, 1] : memref<256x?xf32> to memref<?x?xf32> + %subview_1 = memref.subview %subview[0, 0][%size1, %size1][1, 1] : memref<?x?xf32> to memref<?x?xf32> + // CHECK: {{.*}} = memref.subview %[[input]][0, 0] {{\[}}%[[SIZE1]], %[[SIZE1]]] [1, 1] : memref<256x?xf32> to memref<?x?xf32> + return %subview_1 : memref<?x?xf32> +} diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir index db8bd0f..9bffe92 100644 --- a/mlir/test/Transforms/test-legalize-type-conversion.mlir +++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir @@ -104,8 +104,8 @@ func.func @test_signature_conversion_no_converter() { "test.signature_conversion_no_converter"() ({ // expected-error@below {{failed to legalize unresolved materialization from ('f64') to ('f32') that remained live after conversion}} ^bb0(%arg0: f32): - "test.type_consumer"(%arg0) : (f32) -> () // expected-note@below{{see existing live user here}} + "test.type_consumer"(%arg0) : (f32) -> () "test.return"(%arg0) : (f32) -> () }) : () -> () return diff --git a/mlir/test/Transforms/test-legalizer-analysis.mlir b/mlir/test/Transforms/test-legalizer-analysis.mlir index 19a1310..5b07055 100644 --- a/mlir/test/Transforms/test-legalizer-analysis.mlir +++ b/mlir/test/Transforms/test-legalizer-analysis.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns -verify-diagnostics -test-legalize-mode=analysis %s | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns="test-legalize-mode=analysis" -verify-diagnostics %s | FileCheck %s // expected-remark@-2 {{op 'builtin.module' is legalizable}} // expected-remark@+1 {{op 'func.func' is legalizable}} diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir index 5f1148c..dcd0172 100644 --- a/mlir/test/Transforms/test-legalizer-full.mlir +++ b/mlir/test/Transforms/test-legalizer-full.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns -test-legalize-mode=full -split-input-file -verify-diagnostics %s | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns="test-legalize-mode=full" -split-input-file -verify-diagnostics %s | FileCheck %s // CHECK-LABEL: func @multi_level_mapping func.func @multi_level_mapping() { diff --git a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt index 226e0bb..2ee3222 100644 --- a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt @@ -1,5 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRBufferizationTestPasses + TestOneShotModuleBufferize.cpp TestTensorCopyInsertion.cpp TestTensorLikeAndBufferLike.cpp diff --git a/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp new file mode 100644 index 0000000..1e2d4a7 --- /dev/null +++ b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp @@ -0,0 +1,57 @@ +//===- TestOneShotModuleBufferzation.cpp - Bufferization Test -----*- c++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/Transforms.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { +struct TestOneShotModuleBufferizePass + : public PassWrapper<TestOneShotModuleBufferizePass, OperationPass<>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneShotModuleBufferizePass) + + TestOneShotModuleBufferizePass() = default; + TestOneShotModuleBufferizePass(const TestOneShotModuleBufferizePass &pass) + : PassWrapper(pass) {} + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<bufferization::BufferizationDialect>(); + } + StringRef getArgument() const final { + return "test-one-shot-module-bufferize"; + } + StringRef getDescription() const final { + return "Pass to test One Shot Module Bufferization"; + } + + void runOnOperation() override { + + llvm::errs() << "Running TestOneShotModuleBufferize on: " + << getOperation()->getName() << "\n"; + bufferization::OneShotBufferizationOptions opt; + + opt.bufferizeFunctionBoundaries = true; + bufferization::BufferizationState bufferizationState; + + if (failed(bufferization::runOneShotModuleBufferize(getOperation(), opt, + bufferizationState))) + signalPassFailure(); + } +}; +} // namespace + +namespace mlir::test { +void registerTestOneShotModuleBufferizePass() { + PassRegistration<TestOneShotModuleBufferizePass>(); +} +} // namespace mlir::test diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index f79e2cf..53055fe 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -18,6 +18,32 @@ using namespace mlir; using namespace test; //===----------------------------------------------------------------------===// +// OverridenSymbolVisibilityOp +//===----------------------------------------------------------------------===// + +SymbolTable::Visibility OverriddenSymbolVisibilityOp::getVisibility() { + return SymbolTable::Visibility::Private; +} + +static StringLiteral getVisibilityString(SymbolTable::Visibility visibility) { + switch (visibility) { + case SymbolTable::Visibility::Private: + return "private"; + case SymbolTable::Visibility::Nested: + return "nested"; + case SymbolTable::Visibility::Public: + return "public"; + } +} + +void OverriddenSymbolVisibilityOp::setVisibility( + SymbolTable::Visibility visibility) { + + emitOpError("cannot change visibility of symbol to ") + << getVisibilityString(visibility); +} + +//===----------------------------------------------------------------------===// // TestBranchOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index a7c6cd6..2eaad55 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -119,12 +119,28 @@ def SymbolOp : TEST_Op<"symbol", [NoMemoryEffect, Symbol]> { OptionalAttr<StrAttr>:$sym_visibility); } +def OverriddenSymbolVisibilityOp : TEST_Op<"overridden_symbol_visibility", [ + DeclareOpInterfaceMethods<Symbol, ["getVisibility", "setVisibility"]>, +]> { + let summary = "operation overridden symbol visibility accessors"; + let arguments = (ins StrAttr:$sym_name); +} + def SymbolScopeOp : TEST_Op<"symbol_scope", [SymbolTable, SingleBlockImplicitTerminator<"TerminatorOp">]> { let summary = "operation which defines a new symbol table"; let regions = (region SizedRegion<1>:$region); } +def SymbolScopeIsolatedOp + : TEST_Op<"symbol_scope_isolated", [IsolatedFromAbove, SymbolTable, + SingleBlockImplicitTerminator< + "TerminatorOp">]> { + let summary = + "operation which defines a new symbol table that is IsolatedFromAbove"; + let regions = (region SizedRegion<1>:$region); +} + def SymbolTableRegionOp : TEST_Op<"symbol_table_region", [SymbolTable]> { let summary = "operation which defines a new symbol table without a " "restriction on a terminator"; diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 5fcd92e..eda618f 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1361,6 +1361,10 @@ struct TestLegalizePatternDriver : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver) + TestLegalizePatternDriver() = default; + TestLegalizePatternDriver(const TestLegalizePatternDriver &other) + : PassWrapper(other) {} + StringRef getArgument() const final { return "test-legalize-patterns"; } StringRef getDescription() const final { return "Run test dialect legalization patterns"; @@ -1368,8 +1372,6 @@ struct TestLegalizePatternDriver /// The mode of conversion to use with the driver. enum class ConversionMode { Analysis, Full, Partial }; - TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} - void getDependentDialects(DialectRegistry ®istry) const override { registry.insert<func::FuncDialect, test::TestDialect>(); } @@ -1498,24 +1500,19 @@ struct TestLegalizePatternDriver op->emitRemark() << "op '" << op->getName() << "' is legalizable"; } - /// The mode of conversion to use. - ConversionMode mode; + Option<ConversionMode> mode{ + *this, "test-legalize-mode", + llvm::cl::desc("The legalization mode to use with the test driver"), + llvm::cl::init(ConversionMode::Partial), + llvm::cl::values( + clEnumValN(ConversionMode::Analysis, "analysis", + "Perform an analysis conversion"), + clEnumValN(ConversionMode::Full, "full", "Perform a full conversion"), + clEnumValN(ConversionMode::Partial, "partial", + "Perform a partial conversion"))}; }; } // namespace -static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> - legalizerConversionMode( - "test-legalize-mode", - llvm::cl::desc("The legalization mode to use with the test driver"), - llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), - llvm::cl::values( - clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, - "analysis", "Perform an analysis conversion"), - clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", - "Perform a full conversion"), - clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, - "partial", "Perform a partial conversion"))); - //===----------------------------------------------------------------------===// // ConversionPatternRewriter::getRemappedValue testing. This method is used // to get the remapped value of an original value that was replaced using @@ -2201,9 +2198,7 @@ void registerPatternsTestPass() { PassRegistration<TestStrictPatternDriver>(); PassRegistration<TestWalkPatternDriver>(); - PassRegistration<TestLegalizePatternDriver>([] { - return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); - }); + PassRegistration<TestLegalizePatternDriver>(); PassRegistration<TestRemappedValue>(); diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py index 233fef8..feaf5fb 100644 --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -343,7 +343,6 @@ if config.enable_assertions: else: config.available_features.add("noasserts") - def have_host_jit_feature_support(feature_name): mlir_runner_exe = lit.util.which("mlir-runner", config.mlir_tools_dir) diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in index 132aabe..b1185e1 100644 --- a/mlir/test/lit.site.cfg.py.in +++ b/mlir/test/lit.site.cfg.py.in @@ -5,6 +5,7 @@ import sys config.target_triple = "@LLVM_TARGET_TRIPLE@" config.llvm_src_root = "@LLVM_SOURCE_DIR@" config.llvm_tools_dir = lit_config.substitute("@LLVM_TOOLS_DIR@") +config.spirv_tools_tests = @LLVM_INCLUDE_SPIRV_TOOLS_TESTS@ config.llvm_shlib_ext = "@SHLIBEXT@" config.llvm_shlib_dir = lit_config.substitute(path(r"@SHLIBDIR@")) config.python_executable = "@Python3_EXECUTABLE@" @@ -41,7 +42,7 @@ config.mlir_run_amx_tests = @MLIR_RUN_AMX_TESTS@ config.mlir_run_arm_sve_tests = @MLIR_RUN_ARM_SVE_TESTS@ # This is a workaround for the fact that LIT's: # %if <cond> -# requires <cond> to be in the set of available features. +# requires <cond> to be in the set of available features. # TODO: Update LIT's TestRunner so that this is not required. if config.mlir_run_arm_sve_tests: config.available_features.add("mlir_arm_sve_tests") diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 2c09753..14714c45 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -135,6 +135,7 @@ void registerTestShardSimplificationsPass(); void registerTestMultiBuffering(); void registerTestNextAccessPass(); void registerTestNVGPULowerings(); +void registerTestOneShotModuleBufferizePass(); void registerTestOpaqueLoc(); void registerTestOpLoweringPasses(); void registerTestPadFusion(); @@ -281,6 +282,7 @@ void registerTestPasses() { mlir::test::registerTestMultiBuffering(); mlir::test::registerTestNextAccessPass(); mlir::test::registerTestNVGPULowerings(); + mlir::test::registerTestOneShotModuleBufferizePass(); mlir::test::registerTestOpaqueLoc(); mlir::test::registerTestOpLoweringPasses(); mlir::test::registerTestPadFusion(); diff --git a/mlir/unittests/IR/SymbolTableTest.cpp b/mlir/unittests/IR/SymbolTableTest.cpp index cfc3fe0..4b3545b 100644 --- a/mlir/unittests/IR/SymbolTableTest.cpp +++ b/mlir/unittests/IR/SymbolTableTest.cpp @@ -132,4 +132,38 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) { }); } +TEST(SymbolOpInterface, Visibility) { + DialectRegistry registry; + ::test::registerTestDialect(registry); + MLIRContext context(registry); + + constexpr static StringLiteral kInput = R"MLIR( + "test.overridden_symbol_visibility"() {sym_name = "symbol_name"} : () -> () + )MLIR"; + OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(kInput, &context); + auto symOp = cast<SymbolOpInterface>(module->getBody()->front()); + + ASSERT_TRUE(symOp.isPrivate()); + ASSERT_FALSE(symOp.isPublic()); + ASSERT_FALSE(symOp.isNested()); + ASSERT_TRUE(symOp.canDiscardOnUseEmpty()); + + std::string diagStr; + context.getDiagEngine().registerHandler( + [&](Diagnostic &diag) { diagStr += diag.str(); }); + + std::string expectedDiag; + symOp.setPublic(); + expectedDiag += "'test.overridden_symbol_visibility' op cannot change " + "visibility of symbol to public"; + symOp.setNested(); + expectedDiag += "'test.overridden_symbol_visibility' op cannot change " + "visibility of symbol to nested"; + symOp.setPrivate(); + expectedDiag += "'test.overridden_symbol_visibility' op cannot change " + "visibility of symbol to private"; + + ASSERT_EQ(diagStr, expectedDiag); +} + } // namespace |