diff options
Diffstat (limited to 'mlir')
344 files changed, 4851 insertions, 4503 deletions
diff --git a/mlir/Maintainers.md b/mlir/Maintainers.md new file mode 100644 index 0000000..7c852ef --- /dev/null +++ b/mlir/Maintainers.md @@ -0,0 +1,63 @@ +# MLIR Maintainers + +This file is a list of the +[maintainers](https://llvm.org/docs/DeveloperPolicy.html#maintainers) for MLIR. + +The following people are the active maintainers for the project. For the sake of +simplicity, responsibility areas are subdivided into broad categories, which are +further subdivided into individual components, such as dialects. Please reach +out to them for code reviews, questions about their area of expertise, or other +assistance. + +## Core + +Core components of MLIR, including core IR, analyses and rewriters, fundamental +dialects, build system and language bindings. + +- Alex Zinenko \ + ftynse@gmail.com (email), + [@ftynse](https://github.com/ftynse) (GitHub), + ftynse (Discourse) +- Jacques Pienaar \ + jpienaar@google.com (email), + [@jpienaar](https://github.com/jpienaar) (GitHub), + jpienaar (Discourse) +- Mehdi Amini \ + joker.eph@gmail.com (email), + [@joker-eph](https://github.com/joker-eph) (GitHub), + mehdi_amini (Discourse) + +## Egress + +MLIR components pertaining to egress flows from MLIR, in particular to LLVM IR. + +- Matthias Springer \ + me@m-sp.org (email), + [@matthias-springer](https://github.com/matthias-springer) (GitHub), + matthias-springer (Discourse) +- Andrzej Warzynski \ + andrzej.warzynski@arm.com (email), + [@banach-space](https://github.com/banach-space) (GitHub), + banach-space (Discourse) +- Tobias Gysi \ + tobias.gysi@nextsilicon.com (email), + [@gysit](https://github.com/gysit) (GitHub), + gysit (Discourse) + +## Tensor Compiler + +MLIR components specific to construction of compilers for tensor algebra, in +particular for machine learning compilers. + +- Renato Golin \ + rengolin@gmail.com (email), + [@rengolin](https://github.com/rengolin) (GitHub), + rengolin (Discourse) +- Jacques Pienaar \ + jpienaar@google.com (email), + [@jpienaar](https://github.com/jpienaar) (GitHub), + jpienaar (Discourse) +- Andrzej Warzynski \ + andrzej.warzynski@arm.com (email), + [@banach-space](https://github.com/banach-space) (GitHub), + banach-space (Discourse) diff --git a/mlir/docs/Dialects/Vector.md b/mlir/docs/Dialects/Vector.md index ebeb0a2..6c8949d 100644 --- a/mlir/docs/Dialects/Vector.md +++ b/mlir/docs/Dialects/Vector.md @@ -294,7 +294,7 @@ LLVM instructions are prefixed by the `llvm.` dialect prefix (e.g. `llvm.insertvalue`). Such ops operate exclusively on 1-D vectors and aggregates following the [LLVM LangRef](https://llvm.org/docs/LangRef.html). MLIR operations are prefixed by the `vector.` dialect prefix (e.g. -`vector.insertelement`). Such ops operate exclusively on MLIR `n-D` `vector` +`vector.insert`). Such ops operate exclusively on MLIR `n-D` `vector` types. ### Alternatives For Lowering an n-D Vector Type to LLVM diff --git a/mlir/docs/Dialects/emitc.md b/mlir/docs/Dialects/emitc.md index e2288f5..6d09e93 100644 --- a/mlir/docs/Dialects/emitc.md +++ b/mlir/docs/Dialects/emitc.md @@ -18,6 +18,8 @@ The following convention is followed: GCC or Clang. * If `emitc.array` with a dimension of size zero is used, then the code requires [a GCC extension](https://gcc.gnu.org/onlinedocs/gcc/Zero-Length.html). +* If `aligned_alloc` is passed to an `emitc.call_opaque` operation, then C++17 + or C11 is required. * Else the generated code is compatible with C99. These restrictions are neither inherent to the EmitC dialect itself nor to the diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md index bf590ac..7e1c5fe 100644 --- a/mlir/docs/Interfaces.md +++ b/mlir/docs/Interfaces.md @@ -563,7 +563,7 @@ def MyInterface : OpInterface<"MyInterface"> { template <typename ConcreteOp> struct Model : public Concept { Operation *create(OpBuilder &builder, Location loc) const override { - return builder.create<ConcreteOp>(loc); + return ConcreteOp::create(builder, loc); } } }; @@ -574,7 +574,7 @@ def MyInterface : OpInterface<"MyInterface"> { }], "Operation *", "create", (ins "OpBuilder &":$builder, "Location":$loc), /*methodBody=*/[{ - return builder.create<ConcreteOp>(loc); + return ConcreteOp::create(builder, loc); }]>, InterfaceMethod<[{ diff --git a/mlir/docs/PDLL.md b/mlir/docs/PDLL.md index 9839d1d..c6e352f 100644 --- a/mlir/docs/PDLL.md +++ b/mlir/docs/PDLL.md @@ -1483,7 +1483,7 @@ be defined by specifying a string code block after the rewrite declaration: ```pdll Rewrite BuildOp(value: Value) -> (foo: Op<my_dialect.foo>, bar: Op<my_dialect.bar>) [{ - return {rewriter.create<my_dialect::FooOp>(value), rewriter.create<my_dialect::BarOp>()}; + return {my_dialect::FooOp::create(rewriter, value), my_dialect::BarOp::create(rewriter)}; }]; Pattern { @@ -1508,7 +1508,7 @@ translated into: ```c++ std::tuple<my_dialect::FooOp, my_dialect::BarOp> BuildOp(Value value) { - return {rewriter.create<my_dialect::FooOp>(value), rewriter.create<my_dialect::BarOp>()}; + return {my_dialect::FooOp::create(rewriter, value), my_dialect::BarOp::create(rewriter)}; } ``` @@ -1530,7 +1530,7 @@ below describes the various result translation scenarios: ```pdll Rewrite createOp() [{ - rewriter.create<my_dialect::FooOp>(); + my_dialect::FooOp::create(rewriter); }]; ``` @@ -1538,7 +1538,7 @@ In the case where a native `Rewrite` has no results, the native function returns ```c++ void createOp(PatternRewriter &rewriter) { - rewriter.create<my_dialect::FooOp>(); + my_dialect::FooOp::create(rewriter); } ``` @@ -1546,7 +1546,7 @@ void createOp(PatternRewriter &rewriter) { ```pdll Rewrite createOp() -> Op<my_dialect.foo> [{ - return rewriter.create<my_dialect::FooOp>(); + return my_dialect::FooOp::create(rewriter); }]; ``` @@ -1555,7 +1555,7 @@ native type for that single result: ```c++ my_dialect::FooOp createOp(PatternRewriter &rewriter) { - return rewriter.create<my_dialect::FooOp>(); + return my_dialect::FooOp::create(rewriter); } ``` diff --git a/mlir/docs/Tutorials/QuickstartRewrites.md b/mlir/docs/Tutorials/QuickstartRewrites.md index 0c89065..cbb6f03 100644 --- a/mlir/docs/Tutorials/QuickstartRewrites.md +++ b/mlir/docs/Tutorials/QuickstartRewrites.md @@ -130,7 +130,7 @@ def : Pat<(TF_LeakyReluOp:$old_value, $arg, F32Attr:$a), ```c++ static Value createTFLLeakyRelu(PatternRewriter &rewriter, Operation *op, Value operand, Attribute attr) { - return rewriter.create<mlir::TFL::LeakyReluOp>( + return mlir::TFL::LeakyReluOp::create(rewriter, op->getLoc(), operands[0].getType(), /*arg=*/operands[0], /*alpha=*/cast<FloatAttr>(attrs[0])); } @@ -194,10 +194,10 @@ LogicalResult circt::MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) { // mul(x, c) -> shl(x, log2(c)), where c is a power of two. if (inputs.size() == 2 && matchPattern(inputs.back(), m_RConstant(value)) && value.isPowerOf2()) { - auto shift = rewriter.create<rtl::ConstantOp>(op.getLoc(), op.getType(), + auto shift = rtl::ConstantOp::create(rewriter, op.getLoc(), op.getType(), value.exactLogBase2()); auto shlOp = - rewriter.create<comb::ShlOp>(op.getLoc(), inputs[0], shift); + comb::ShlOp::create(rewriter, op.getLoc(), inputs[0], shift); rewriter.replaceOpWithNewOp<MulOp>(op, op.getType(), ArrayRef<Value>(shlOp)); return success(); diff --git a/mlir/docs/Tutorials/Toy/Ch-2.md b/mlir/docs/Tutorials/Toy/Ch-2.md index 039417c..81e4161 100644 --- a/mlir/docs/Tutorials/Toy/Ch-2.md +++ b/mlir/docs/Tutorials/Toy/Ch-2.md @@ -521,7 +521,7 @@ def ConstantOp : Toy_Op<"constant"> { // Add custom build methods for the constant operation. These methods populate // the `state` that MLIR uses to create operations, i.e. these are used when - // using `builder.create<ConstantOp>(...)`. + // using `ConstantOp::create(builder, ...)`. let builders = [ // Build a constant with a given constant tensor value. OpBuilder<(ins "DenseElementsAttr":$value), [{ diff --git a/mlir/docs/Tutorials/Toy/Ch-4.md b/mlir/docs/Tutorials/Toy/Ch-4.md index 1275d36..e9abe36 100644 --- a/mlir/docs/Tutorials/Toy/Ch-4.md +++ b/mlir/docs/Tutorials/Toy/Ch-4.md @@ -300,7 +300,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { Operation *materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const final { - return builder.create<CastOp>(conversionLoc, resultType, input); + return CastOp::create(builder, conversionLoc, resultType, input); } }; ``` @@ -445,7 +445,7 @@ When processing an operation like described, we query if it registered the ```c++ // Ask the operation to infer its output shapes. - LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + LDBG() << "Inferring shape for: " << *op; /// We check if an operation has a particular interface by casting. if (ShapeInference shapeOp = dyn_cast<ShapeInference>(op)) { diff --git a/mlir/docs/Tutorials/Toy/Ch-5.md b/mlir/docs/Tutorials/Toy/Ch-5.md index d483cd8..17cd6bb 100644 --- a/mlir/docs/Tutorials/Toy/Ch-5.md +++ b/mlir/docs/Tutorials/Toy/Ch-5.md @@ -91,13 +91,11 @@ doesn't matter. See `ConversionTarget::getOpInfo` for the details. After the conversion target has been defined, we can define how to convert the *illegal* operations into *legal* ones. Similarly to the canonicalization framework introduced in [chapter 3](Ch-3.md), the -[`DialectConversion` framework](../../DialectConversion.md) also uses -[RewritePatterns](../QuickstartRewrites.md) to perform the conversion logic. -These patterns may be the `RewritePatterns` seen before or a new type of pattern -specific to the conversion framework `ConversionPattern`. `ConversionPatterns` +[`DialectConversion` framework](../../DialectConversion.md) uses a special kind +of `ConversionPattern` to perform the conversion logic. `ConversionPatterns` are different from traditional `RewritePatterns` in that they accept an -additional `operands` parameter containing operands that have been -remapped/replaced. This is used when dealing with type conversions, as the +additional `operands` (or `adaptor`) parameter containing operands that have +been remapped/replaced. This is used when dealing with type conversions, as the pattern will want to operate on values of the new type but match against the old. For our lowering, this invariant will be useful as it translates from the [TensorType](../../Dialects/Builtin.md/#rankedtensortype) currently being @@ -106,38 +104,23 @@ look at a snippet of lowering the `toy.transpose` operation: ```c++ /// Lower the `toy.transpose` operation to an affine loop nest. -struct TransposeOpLowering : public mlir::ConversionPattern { - TransposeOpLowering(mlir::MLIRContext *ctx) - : mlir::ConversionPattern(TransposeOp::getOperationName(), 1, ctx) {} - - /// Match and rewrite the given `toy.transpose` operation, with the given - /// operands that have been remapped from `tensor<...>` to `memref<...>`. - llvm::LogicalResult - matchAndRewrite(mlir::Operation *op, ArrayRef<mlir::Value> operands, - mlir::ConversionPatternRewriter &rewriter) const final { - auto loc = op->getLoc(); +struct TransposeOpLowering : public OpConversionPattern<toy::TransposeOp> { + using OpConversionPattern<toy::TransposeOp>::OpConversionPattern; - // Call to a helper function that will lower the current operation to a set - // of affine loops. We provide a functor that operates on the remapped - // operands, as well as the loop induction variables for the inner most - // loop body. - lowerOpToLoops( - op, operands, rewriter, - [loc](mlir::PatternRewriter &rewriter, - ArrayRef<mlir::Value> memRefOperands, - ArrayRef<mlir::Value> loopIvs) { - // Generate an adaptor for the remapped operands of the TransposeOp. - // This allows for using the nice named accessors that are generated - // by the ODS. This adaptor is automatically provided by the ODS - // framework. - TransposeOpAdaptor transposeAdaptor(memRefOperands); - mlir::Value input = transposeAdaptor.input(); - - // Transpose the elements by generating a load from the reverse - // indices. - SmallVector<mlir::Value, 2> reverseIvs(llvm::reverse(loopIvs)); - return rewriter.create<mlir::AffineLoadOp>(loc, input, reverseIvs); - }); + LogicalResult + matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + lowerOpToLoops(op, rewriter, + [&](OpBuilder &builder, ValueRange loopIvs) { + Value input = adaptor.getInput(); + + // Transpose the elements by generating a load from the + // reverse indices. + SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs)); + return affine::AffineLoadOp::create(builder, loc, input, + reverseIvs); + }); return success(); } }; diff --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md index e8a68b5..529de55 100644 --- a/mlir/docs/Tutorials/Toy/Ch-6.md +++ b/mlir/docs/Tutorials/Toy/Ch-6.md @@ -47,7 +47,7 @@ static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter, // Insert the printf function into the body of the parent module. PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); - rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType); + LLVM::LLVMFuncOp::create(rewriter, module.getLoc(), "printf", llvmFnType); return SymbolRefAttr::get("printf", context); } ``` diff --git a/mlir/docs/Tutorials/Toy/Ch-7.md b/mlir/docs/Tutorials/Toy/Ch-7.md index dce3490..0f50c49 100644 --- a/mlir/docs/Tutorials/Toy/Ch-7.md +++ b/mlir/docs/Tutorials/Toy/Ch-7.md @@ -488,9 +488,9 @@ mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder, mlir::Type type, mlir::Location loc) { if (isa<StructType>(type)) - return builder.create<StructConstantOp>(loc, type, + return StructConstantOp::create(builder, loc, type, cast<mlir::ArrayAttr>(value)); - return builder.create<ConstantOp>(loc, type, + return ConstantOp::create(builder, loc, type, cast<mlir::DenseElementsAttr>(value)); } ``` diff --git a/mlir/docs/Tutorials/transform/Ch0.md b/mlir/docs/Tutorials/transform/Ch0.md index ac3989a..dc4b753 100644 --- a/mlir/docs/Tutorials/transform/Ch0.md +++ b/mlir/docs/Tutorials/transform/Ch0.md @@ -46,7 +46,7 @@ When no support is available, such an operation can be transformed into a loop: %c8 = arith.constant 8 : index %init = arith.constant 0.0 : f32 %result = scf.for %i = %c0 to %c8 step %c1 iter_args(%partial = %init) -> (f32) { - %element = vector.extractelement %0[%i : index] : vector<8xf32> + %element = vector.extract %0[%i] : f32 into vector<8xf32> %updated = arith.addf %partial, %element : f32 scf.yield %updated : f32 } @@ -145,7 +145,7 @@ linalg.generic { %c0 = arith.constant 0.0 : f32 %0 = arith.cmpf ogt %in_one, %c0 : f32 %1 = arith.select %0, %in_one, %c0 : f32 - linalg.yield %1 : f32 + linalg.yield %1 : f32 } ``` @@ -185,7 +185,7 @@ In the case of `linalg.generic` operations, the iteration space is implicit and For example, tiling the matrix multiplication presented above with tile sizes `(2, 8)`, we obtain a loop nest around a `linalg.generic` expressing the same operation on a `2x8` tensor. ```mlir -// A special "multi-for" loop that supports tensor-insertion semantics +// A special "multi-for" loop that supports tensor-insertion semantics // as opposed to implicit updates. The resulting 8x16 tensor will be produced // by this loop. // The trip count of iterators is computed dividing the original tensor size, @@ -202,9 +202,9 @@ For example, tiling the matrix multiplication presented above with tile sizes `( // Take slices of inputs and outputs. Only the "i" and "j" dimensions are sliced. %lhs_slice = tensor.extract_slice %lhs[%3, 0] [2, 10] [1, 1] : tensor<8x10xf32> to tensor<2x10xf32> - %rhs_slice = tensor.extract_slice %rhs[0, %4] [10, 8] [1, 1] + %rhs_slice = tensor.extract_slice %rhs[0, %4] [10, 8] [1, 1] : tensor<10x16xf32> to tensor<10x8xf32> - %result_slice = tensor.extract_slice %shared[%3, %4] [2, 8] [1, 1] + %result_slice = tensor.extract_slice %shared[%3, %4] [2, 8] [1, 1] : tensor<8x16xf32> to tensor<2x8xf32> // This is exactly the same operation as before, but now operating on smaller @@ -214,7 +214,7 @@ For example, tiling the matrix multiplication presented above with tile sizes `( affine_map<(i, j, k) -> (k, j)>, affine_map<(i, j, k) -> (i, j)>], iterator_types = ["parallel", "parallel", "reduction"] - } ins(%lhs_slice, %rhs_slice : tensor<2x10xf32>, tensor<10x8xf32>) + } ins(%lhs_slice, %rhs_slice : tensor<2x10xf32>, tensor<10x8xf32>) outs(%result_slice : tensor<2x8xf32>) -> tensor<2x8xf32> { ^bb0(%lhs_one: f32, %rhs_one: f32, %init_one: f32): %0 = arith.mulf %lhs_one, %rhs_one : f32 @@ -238,15 +238,15 @@ After materializing loops with tiling, another key code generation transformatio 1. the subset (slice) of the operand that is used by the tile, and 2. the tensor-level structured operation producing the whole tensor that is being sliced. -By inverting the `indexing_map` and applying it to the set of elements accessed through the slice, we can compute the part of the iteration space of the operation defining the full tensor necessary to compute the tile. Thus fusion boils down to replacing the `tensor.extract_slice` operation with the tile of the `linalg.generic` producing the original operand. +By inverting the `indexing_map` and applying it to the set of elements accessed through the slice, we can compute the part of the iteration space of the operation defining the full tensor necessary to compute the tile. Thus fusion boils down to replacing the `tensor.extract_slice` operation with the tile of the `linalg.generic` producing the original operand. Let us assume that the matrix multiplication operation is followed by another operation that multiplies each element of the resulting matrix with itself. This trailing elementwise operation has a 2D iteration space, unlike the 3D one in matrix multiplication. Nevertheless, it is possible to tile the trailing operation and then fuse the producer of its operand, the matmul, into the loop generated by tiling. The untiled dimension will be used in its entirety. ```mlir // Same loop as before. -%0 = scf.forall (%i, %j) in (4, 2) - shared_outs(%shared = %init) +%0 = scf.forall (%i, %j) in (4, 2) + shared_outs(%shared = %init) -> (tensor<8x16xf32>, tensor<8x16xf32>) { // Scale the loop induction variables by the tile sizes. %1 = affine.apply affine_map<(d0) -> (d0 * 2)>(%i) @@ -286,7 +286,7 @@ Let us assume that the matrix multiplication operation is followed by another op indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>], iterator_types = ["parallel", "parallel"] - } ins(%partial : tensor<2x8xf32>) + } ins(%partial : tensor<2x8xf32>) outs(%shared_slice : tensor<2x8xf32>) { ^bb0(%in: f32, %out: f32): %5 = arith.mulf %in, %in : f32 diff --git a/mlir/docs/Tutorials/transform/Ch3.md b/mlir/docs/Tutorials/transform/Ch3.md index fa788d1..eeab770 100644 --- a/mlir/docs/Tutorials/transform/Ch3.md +++ b/mlir/docs/Tutorials/transform/Ch3.md @@ -139,7 +139,21 @@ void MyExtension::init() { ``` This type is now directly available in the Transform dialect and can be used in operations. +In the previous tablegen definition, the type of `$call` must be `Transform_ConcreteOp<“func.call”>`, +By adding `CallOpInterfaceHandle` as an allowed type for `$call`, the corresponding handle +is allowed to be to any op implementing the interface. +```tablegen +def ChangeCallTargetOp : ... { + let arguments = (ins + // Allow the handle to be to concrete `func.call` ops as well as any op implementing + // the `CallOpInterface`. + AnyTypeOf<[Transform_ConcreteOpType<"func.call">, CallOpInterfaceHandle]>:$call, + StrAttr:$new_target); +} +``` + +We can then add the following code to `sequence.mlir` and run it with the interpreter. ```mlir // Cast to our new type. @@ -172,7 +186,7 @@ def CallToOp : Op<Transform_Dialect, "my.call_to_op", let results = (outs TransformHandleTypeInterface:$transformed); // Provide nice syntax. - let assemblyFormat = "$call attr-dict `:` functional-type(inputs, outputs)"; + let assemblyFormat = "$call attr-dict `:` functional-type(operands, results)"; // Declare the function implementing the interface for a single payload operation. let extraClassDeclaration = [{ diff --git a/mlir/examples/standalone/standalone-opt/CMakeLists.txt b/mlir/examples/standalone/standalone-opt/CMakeLists.txt index 27f8128..4b38de7 100644 --- a/mlir/examples/standalone/standalone-opt/CMakeLists.txt +++ b/mlir/examples/standalone/standalone-opt/CMakeLists.txt @@ -1,12 +1,10 @@ -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) set(LIBS - ${dialect_libs} - ${conversion_libs} - MLIRArithDialect - MLIROptLib - MLIRStandalone - ) + MLIRArithDialect + MLIROptLib + MLIRRegisterAllDialects + MLIRRegisterAllPasses + MLIRStandalone + ) add_llvm_executable(standalone-opt standalone-opt.cpp) llvm_update_compile_flags(standalone-opt) diff --git a/mlir/examples/standalone/standalone-opt/standalone-opt.cpp b/mlir/examples/standalone/standalone-opt/standalone-opt.cpp index e39fa96..eebfcb7 100644 --- a/mlir/examples/standalone/standalone-opt/standalone-opt.cpp +++ b/mlir/examples/standalone/standalone-opt/standalone-opt.cpp @@ -6,6 +6,8 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" diff --git a/mlir/examples/toy/Ch2/include/toy/Ops.td b/mlir/examples/toy/Ch2/include/toy/Ops.td index ef65c9c..91bf83a 100644 --- a/mlir/examples/toy/Ch2/include/toy/Ops.td +++ b/mlir/examples/toy/Ch2/include/toy/Ops.td @@ -70,7 +70,7 @@ def ConstantOp : Toy_Op<"constant", [Pure]> { // Add custom build methods for the constant operation. These method populates // the `state` that MLIR uses to create operations, i.e. these are used when - // using `builder.create<ConstantOp>(...)`. + // using `ConstantOp::create(builder, ...)`. let builders = [ // Build a constant with a given constant tensor value. OpBuilder<(ins "DenseElementsAttr":$value), [{ diff --git a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp index 96925be..39ae6a0 100644 --- a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp @@ -121,8 +121,8 @@ private: llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(), getType(VarType{})); auto funcType = builder.getFunctionType(argTypes, {}); - return builder.create<mlir::toy::FuncOp>(location, proto.getName(), - funcType); + return mlir::toy::FuncOp::create(builder, location, proto.getName(), + funcType); } /// Emit a new function and add it to the MLIR module. @@ -166,7 +166,7 @@ private: if (!entryBlock.empty()) returnOp = dyn_cast<ReturnOp>(entryBlock.back()); if (!returnOp) { - builder.create<ReturnOp>(loc(funcAST.getProto()->loc())); + ReturnOp::create(builder, loc(funcAST.getProto()->loc())); } else if (returnOp.hasOperand()) { // Otherwise, if this return operation has an operand then add a result to // the function. @@ -202,9 +202,9 @@ private: // support '+' and '*'. switch (binop.getOp()) { case '+': - return builder.create<AddOp>(location, lhs, rhs); + return AddOp::create(builder, location, lhs, rhs); case '*': - return builder.create<MulOp>(location, lhs, rhs); + return MulOp::create(builder, location, lhs, rhs); } emitError(location, "invalid binary operator '") << binop.getOp() << "'"; @@ -235,8 +235,8 @@ private: } // Otherwise, this return operation has zero operands. - builder.create<ReturnOp>(location, - expr ? ArrayRef(expr) : ArrayRef<mlir::Value>()); + ReturnOp::create(builder, location, + expr ? ArrayRef(expr) : ArrayRef<mlir::Value>()); return mlir::success(); } @@ -280,7 +280,7 @@ private: // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` // method. - return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute); + return ConstantOp::create(builder, loc(lit.loc()), type, dataAttribute); } /// Recursive helper function to accumulate the data that compose an array @@ -325,13 +325,13 @@ private: "does not accept multiple arguments"); return nullptr; } - return builder.create<TransposeOp>(location, operands[0]); + return TransposeOp::create(builder, location, operands[0]); } // Otherwise this is a call to a user-defined function. Calls to // user-defined functions are mapped to a custom call that takes the callee // name as an attribute. - return builder.create<GenericCallOp>(location, callee, operands); + return GenericCallOp::create(builder, location, callee, operands); } /// Emit a print expression. It emits specific operations for two builtins: @@ -341,13 +341,13 @@ private: if (!arg) return mlir::failure(); - builder.create<PrintOp>(loc(call.loc()), arg); + PrintOp::create(builder, loc(call.loc()), arg); return mlir::success(); } /// Emit a constant for a single number (FIXME: semantic? broadcast?) mlir::Value mlirGen(NumberExprAST &num) { - return builder.create<ConstantOp>(loc(num.loc()), num.getValue()); + return ConstantOp::create(builder, loc(num.loc()), num.getValue()); } /// Dispatch codegen for the right expression subclass using RTTI. @@ -391,8 +391,8 @@ private: // with specific shape, we emit a "reshape" operation. It will get // optimized out later as needed. if (!vardecl.getType().shape.empty()) { - value = builder.create<ReshapeOp>(loc(vardecl.loc()), - getType(vardecl.getType()), value); + value = ReshapeOp::create(builder, loc(vardecl.loc()), + getType(vardecl.getType()), value); } // Register the value in the symbol table. diff --git a/mlir/examples/toy/Ch3/include/toy/Ops.td b/mlir/examples/toy/Ch3/include/toy/Ops.td index 4859804..027b076 100644 --- a/mlir/examples/toy/Ch3/include/toy/Ops.td +++ b/mlir/examples/toy/Ch3/include/toy/Ops.td @@ -69,7 +69,7 @@ def ConstantOp : Toy_Op<"constant", [Pure]> { // Add custom build methods for the constant operation. These method populates // the `state` that MLIR uses to create operations, i.e. these are used when - // using `builder.create<ConstantOp>(...)`. + // using `ConstantOp::create(builder, ...)`. let builders = [ // Build a constant with a given constant tensor value. OpBuilder<(ins "DenseElementsAttr":$value), [{ diff --git a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp index c8cba82..0573af6 100644 --- a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp @@ -121,8 +121,8 @@ private: llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(), getType(VarType{})); auto funcType = builder.getFunctionType(argTypes, /*results=*/{}); - return builder.create<mlir::toy::FuncOp>(location, proto.getName(), - funcType); + return mlir::toy::FuncOp::create(builder, location, proto.getName(), + funcType); } /// Emit a new function and add it to the MLIR module. @@ -166,7 +166,7 @@ private: if (!entryBlock.empty()) returnOp = dyn_cast<ReturnOp>(entryBlock.back()); if (!returnOp) { - builder.create<ReturnOp>(loc(funcAST.getProto()->loc())); + ReturnOp::create(builder, loc(funcAST.getProto()->loc())); } else if (returnOp.hasOperand()) { // Otherwise, if this return operation has an operand then add a result to // the function. @@ -202,9 +202,9 @@ private: // support '+' and '*'. switch (binop.getOp()) { case '+': - return builder.create<AddOp>(location, lhs, rhs); + return AddOp::create(builder, location, lhs, rhs); case '*': - return builder.create<MulOp>(location, lhs, rhs); + return MulOp::create(builder, location, lhs, rhs); } emitError(location, "invalid binary operator '") << binop.getOp() << "'"; @@ -235,8 +235,8 @@ private: } // Otherwise, this return operation has zero operands. - builder.create<ReturnOp>(location, - expr ? ArrayRef(expr) : ArrayRef<mlir::Value>()); + ReturnOp::create(builder, location, + expr ? ArrayRef(expr) : ArrayRef<mlir::Value>()); return mlir::success(); } @@ -280,7 +280,7 @@ private: // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` // method. - return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute); + return ConstantOp::create(builder, loc(lit.loc()), type, dataAttribute); } /// Recursive helper function to accumulate the data that compose an array @@ -325,13 +325,13 @@ private: "does not accept multiple arguments"); return nullptr; } - return builder.create<TransposeOp>(location, operands[0]); + return TransposeOp::create(builder, location, operands[0]); } // Otherwise this is a call to a user-defined function. Calls to // user-defined functions are mapped to a custom call that takes the callee // name as an attribute. - return builder.create<GenericCallOp>(location, callee, operands); + return GenericCallOp::create(builder, location, callee, operands); } /// Emit a print expression. It emits specific operations for two builtins: @@ -341,13 +341,13 @@ private: if (!arg) return mlir::failure(); - builder.create<PrintOp>(loc(call.loc()), arg); + PrintOp::create(builder, loc(call.loc()), arg); return mlir::success(); } /// Emit a constant for a single number (FIXME: semantic? broadcast?) mlir::Value mlirGen(NumberExprAST &num) { - return builder.create<ConstantOp>(loc(num.loc()), num.getValue()); + return ConstantOp::create(builder, loc(num.loc()), num.getValue()); } /// Dispatch codegen for the right expression subclass using RTTI. @@ -391,8 +391,8 @@ private: // with specific shape, we emit a "reshape" operation. It will get // optimized out later as needed. if (!vardecl.getType().shape.empty()) { - value = builder.create<ReshapeOp>(loc(vardecl.loc()), - getType(vardecl.getType()), value); + value = ReshapeOp::create(builder, loc(vardecl.loc()), + getType(vardecl.getType()), value); } // Register the value in the symbol table. diff --git a/mlir/examples/toy/Ch4/include/toy/Ops.td b/mlir/examples/toy/Ch4/include/toy/Ops.td index 0b32b1b..6c6b739 100644 --- a/mlir/examples/toy/Ch4/include/toy/Ops.td +++ b/mlir/examples/toy/Ch4/include/toy/Ops.td @@ -72,7 +72,7 @@ def ConstantOp : Toy_Op<"constant", [Pure]> { // Add custom build methods for the constant operation. These method populates // the `state` that MLIR uses to create operations, i.e. these are used when - // using `builder.create<ConstantOp>(...)`. + // using `ConstantOp::create(builder, ...)`. let builders = [ // Build a constant with a given constant tensor value. OpBuilder<(ins "DenseElementsAttr":$value), [{ diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp index 076a75a..1e5e672 100644 --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -91,7 +91,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { Operation *materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const final { - return builder.create<CastOp>(conversionLoc, resultType, input); + return CastOp::create(builder, conversionLoc, resultType, input); } }; @@ -206,7 +206,8 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) { llvm::LogicalResult ConstantOp::verify() { // If the return type of the constant is not an unranked tensor, the shape // must match the shape of the attribute holding the data. - auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType()); + auto resultType = + llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType()); if (!resultType) return success(); @@ -395,7 +396,8 @@ llvm::LogicalResult ReturnOp::verify() { auto resultType = results.front(); // Check that the result type of the function matches the operand type. - if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) || + if (inputType == resultType || + llvm::isa<mlir::UnrankedTensorType>(inputType) || llvm::isa<mlir::UnrankedTensorType>(resultType)) return mlir::success(); diff --git a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp index 9371815..7d676f1 100644 --- a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp @@ -121,8 +121,8 @@ private: llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(), getType(VarType{})); auto funcType = builder.getFunctionType(argTypes, /*results=*/{}); - return builder.create<mlir::toy::FuncOp>(location, proto.getName(), - funcType); + return mlir::toy::FuncOp::create(builder, location, proto.getName(), + funcType); } /// Emit a new function and add it to the MLIR module. @@ -166,7 +166,7 @@ private: if (!entryBlock.empty()) returnOp = dyn_cast<ReturnOp>(entryBlock.back()); if (!returnOp) { - builder.create<ReturnOp>(loc(funcAST.getProto()->loc())); + ReturnOp::create(builder, loc(funcAST.getProto()->loc())); } else if (returnOp.hasOperand()) { // Otherwise, if this return operation has an operand then add a result to // the function. @@ -206,9 +206,9 @@ private: // support '+' and '*'. switch (binop.getOp()) { case '+': - return builder.create<AddOp>(location, lhs, rhs); + return AddOp::create(builder, location, lhs, rhs); case '*': - return builder.create<MulOp>(location, lhs, rhs); + return MulOp::create(builder, location, lhs, rhs); } emitError(location, "invalid binary operator '") << binop.getOp() << "'"; @@ -239,8 +239,8 @@ private: } // Otherwise, this return operation has zero operands. - builder.create<ReturnOp>(location, - expr ? ArrayRef(expr) : ArrayRef<mlir::Value>()); + ReturnOp::create(builder, location, + expr ? ArrayRef(expr) : ArrayRef<mlir::Value>()); return mlir::success(); } @@ -284,7 +284,7 @@ private: // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` // method. - return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute); + return ConstantOp::create(builder, loc(lit.loc()), type, dataAttribute); } /// Recursive helper function to accumulate the data that compose an array @@ -329,13 +329,13 @@ private: "does not accept multiple arguments"); return nullptr; } - return builder.create<TransposeOp>(location, operands[0]); + return TransposeOp::create(builder, location, operands[0]); } // Otherwise this is a call to a user-defined function. Calls to // user-defined functions are mapped to a custom call that takes the callee // name as an attribute. - return builder.create<GenericCallOp>(location, callee, operands); + return GenericCallOp::create(builder, location, callee, operands); } /// Emit a print expression. It emits specific operations for two builtins: @@ -345,13 +345,13 @@ private: if (!arg) return mlir::failure(); - builder.create<PrintOp>(loc(call.loc()), arg); + PrintOp::create(builder, loc(call.loc()), arg); return mlir::success(); } /// Emit a constant for a single number (FIXME: semantic? broadcast?) mlir::Value mlirGen(NumberExprAST &num) { - return builder.create<ConstantOp>(loc(num.loc()), num.getValue()); + return ConstantOp::create(builder, loc(num.loc()), num.getValue()); } /// Dispatch codegen for the right expression subclass using RTTI. @@ -395,8 +395,8 @@ private: // with specific shape, we emit a "reshape" operation. It will get // optimized out later as needed. if (!vardecl.getType().shape.empty()) { - value = builder.create<ReshapeOp>(loc(vardecl.loc()), - getType(vardecl.getType()), value); + value = ReshapeOp::create(builder, loc(vardecl.loc()), + getType(vardecl.getType()), value); } // Register the value in the symbol table. diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp index 2522abe..a552e1f0 100644 --- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp @@ -23,7 +23,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #include <memory> @@ -81,7 +81,7 @@ struct ShapeInferencePass opWorklist.erase(op); // Ask the operation to infer its output shapes. - LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + LDBG() << "Inferring shape for: " << *op; if (auto shapeOp = dyn_cast<ShapeInference>(op)) { shapeOp.inferShapes(); } else { diff --git a/mlir/examples/toy/Ch5/CMakeLists.txt b/mlir/examples/toy/Ch5/CMakeLists.txt index f4f0fec..454ca56 100644 --- a/mlir/examples/toy/Ch5/CMakeLists.txt +++ b/mlir/examples/toy/Ch5/CMakeLists.txt @@ -27,12 +27,8 @@ add_toy_chapter(toyc-ch5 include_directories(${CMAKE_CURRENT_BINARY_DIR}) include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) target_link_libraries(toyc-ch5 PRIVATE - ${dialect_libs} - ${extension_libs} MLIRAnalysis MLIRCallInterfaces MLIRCastInterfaces @@ -40,6 +36,9 @@ target_link_libraries(toyc-ch5 MLIRIR MLIRParser MLIRPass + MLIRRegisterAllDialects + MLIRRegisterAllExtensions MLIRSideEffectInterfaces MLIRSupport - MLIRTransforms) + MLIRTransforms + ) diff --git a/mlir/examples/toy/Ch5/include/toy/Ops.td b/mlir/examples/toy/Ch5/include/toy/Ops.td index d11d18dc..6a136ec 100644 --- a/mlir/examples/toy/Ch5/include/toy/Ops.td +++ b/mlir/examples/toy/Ch5/include/toy/Ops.td @@ -72,7 +72,7 @@ def ConstantOp : Toy_Op<"constant", [Pure]> { // Add custom build methods for the constant operation. These method populates // the `state` that MLIR uses to create operations, i.e. these are used when - // using `builder.create<ConstantOp>(...)`. + // using `ConstantOp::create(builder, ...)`. let builders = [ // Build a constant with a given constant tensor value. OpBuilder<(ins "DenseElementsAttr":$value), [{ diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp index fb7c742..69fb69f 100644 --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -91,7 +91,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { Operation *materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const final { - return builder.create<CastOp>(conversionLoc, resultType, input); + return CastOp::create(builder, conversionLoc, resultType, input); } }; @@ -206,7 +206,8 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) { llvm::LogicalResult ConstantOp::verify() { // If the return type of the constant is not an unranked tensor, the shape // must match the shape of the attribute holding the data. - auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType()); + auto resultType = + llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType()); if (!resultType) return success(); @@ -395,7 +396,8 @@ llvm::LogicalResult ReturnOp::verify() { auto resultType = results.front(); // Check that the result type of the function matches the operand type. - if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) || + if (inputType == resultType || + llvm::isa<mlir::UnrankedTensorType>(inputType) || llvm::isa<mlir::UnrankedTensorType>(resultType)) return mlir::success(); diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp index bf2bc43..2969d3a 100644 --- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp @@ -44,7 +44,7 @@ using namespace mlir; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns +// ToyToAffine Conversion Patterns //===----------------------------------------------------------------------===// /// Convert the given RankedTensorType into the corresponding MemRefType. @@ -55,7 +55,7 @@ static MemRefType convertTensorToMemRef(RankedTensorType type) { /// Insert an allocation and deallocation for the given MemRefType. static Value insertAllocAndDealloc(MemRefType type, Location loc, PatternRewriter &rewriter) { - auto alloc = rewriter.create<memref::AllocOp>(loc, type); + auto alloc = memref::AllocOp::create(rewriter, loc, type); // Make sure to allocate at the beginning of the block. auto *parentBlock = alloc->getBlock(); @@ -63,21 +63,19 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc, // Make sure to deallocate this alloc at the end of the block. This is fine // as toy functions have no control flow. - auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc); + auto dealloc = memref::DeallocOp::create(rewriter, loc, alloc); dealloc->moveBefore(&parentBlock->back()); return alloc; } /// This defines the function type used to process an iteration of a lowered -/// loop. It takes as input an OpBuilder, an range of memRefOperands -/// corresponding to the operands of the input operation, and the range of loop -/// induction variables for the iteration. It returns a value to store at the -/// current index of the iteration. -using LoopIterationFn = function_ref<Value( - OpBuilder &rewriter, ValueRange memRefOperands, ValueRange loopIvs)>; - -static void lowerOpToLoops(Operation *op, ValueRange operands, - PatternRewriter &rewriter, +/// loop. It takes as input an OpBuilder and the range of loop induction +/// variables for the iteration. It returns a value to store at the current +/// index of the iteration. +using LoopIterationFn = + function_ref<Value(OpBuilder &rewriter, ValueRange loopIvs)>; + +static void lowerOpToLoops(Operation *op, PatternRewriter &rewriter, LoopIterationFn processIteration) { auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin())); auto loc = op->getLoc(); @@ -95,12 +93,12 @@ static void lowerOpToLoops(Operation *op, ValueRange operands, affine::buildAffineLoopNest( rewriter, loc, lowerBounds, tensorType.getShape(), steps, [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { - // Call the processing function with the rewriter, the memref operands, - // and the loop induction variables. This function will return the value - // to store at the current index. - Value valueToStore = processIteration(nestedBuilder, operands, ivs); - nestedBuilder.create<affine::AffineStoreOp>(loc, valueToStore, alloc, - ivs); + // Call the processing function with the rewriter and the loop + // induction variables. This function will return the value to store at + // the current index. + Value valueToStore = processIteration(nestedBuilder, ivs); + affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc, + ivs); }); // Replace this operation with the generated alloc. @@ -109,38 +107,30 @@ static void lowerOpToLoops(Operation *op, ValueRange operands, namespace { //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Binary operations +// ToyToAffine Conversion Patterns: Binary operations //===----------------------------------------------------------------------===// template <typename BinaryOp, typename LoweredBinaryOp> -struct BinaryOpLowering : public ConversionPattern { - BinaryOpLowering(MLIRContext *ctx) - : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} +struct BinaryOpLowering : public OpConversionPattern<BinaryOp> { + using OpConversionPattern<BinaryOp>::OpConversionPattern; + using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor; LogicalResult - matchAndRewrite(Operation *op, ArrayRef<Value> operands, + matchAndRewrite(BinaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); - lowerOpToLoops(op, operands, rewriter, - [loc](OpBuilder &builder, ValueRange memRefOperands, - ValueRange loopIvs) { - // Generate an adaptor for the remapped operands of the - // BinaryOp. This allows for using the nice named accessors - // that are generated by the ODS. - typename BinaryOp::Adaptor binaryAdaptor(memRefOperands); - - // Generate loads for the element of 'lhs' and 'rhs' at the - // inner loop. - auto loadedLhs = builder.create<affine::AffineLoadOp>( - loc, binaryAdaptor.getLhs(), loopIvs); - auto loadedRhs = builder.create<affine::AffineLoadOp>( - loc, binaryAdaptor.getRhs(), loopIvs); - - // Create the binary operation performed on the loaded - // values. - return builder.create<LoweredBinaryOp>(loc, loadedLhs, - loadedRhs); - }); + lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) { + // Generate loads for the element of 'lhs' and 'rhs' at the + // inner loop. + auto loadedLhs = + affine::AffineLoadOp::create(builder, loc, adaptor.getLhs(), loopIvs); + auto loadedRhs = + affine::AffineLoadOp::create(builder, loc, adaptor.getRhs(), loopIvs); + + // Create the binary operation performed on the loaded + // values. + return LoweredBinaryOp::create(builder, loc, loadedLhs, loadedRhs); + }); return success(); } }; @@ -148,14 +138,15 @@ using AddOpLowering = BinaryOpLowering<toy::AddOp, arith::AddFOp>; using MulOpLowering = BinaryOpLowering<toy::MulOp, arith::MulFOp>; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Constant operations +// ToyToAffine Conversion Patterns: Constant operations //===----------------------------------------------------------------------===// -struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> { - using OpRewritePattern<toy::ConstantOp>::OpRewritePattern; +struct ConstantOpLowering : public OpConversionPattern<toy::ConstantOp> { + using OpConversionPattern<toy::ConstantOp>::OpConversionPattern; - LogicalResult matchAndRewrite(toy::ConstantOp op, - PatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(toy::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { DenseElementsAttr constantValue = op.getValue(); Location loc = op.getLoc(); @@ -174,11 +165,11 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> { if (!valueShape.empty()) { for (auto i : llvm::seq<int64_t>(0, *llvm::max_element(valueShape))) constantIndices.push_back( - rewriter.create<arith::ConstantIndexOp>(loc, i)); + arith::ConstantIndexOp::create(rewriter, loc, i)); } else { // This is the case of a tensor of rank 0. constantIndices.push_back( - rewriter.create<arith::ConstantIndexOp>(loc, 0)); + arith::ConstantIndexOp::create(rewriter, loc, 0)); } // The constant operation represents a multi-dimensional constant, so we @@ -191,9 +182,9 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> { // The last dimension is the base case of the recursion, at this point // we store the element at the given index. if (dimension == valueShape.size()) { - rewriter.create<affine::AffineStoreOp>( - loc, rewriter.create<arith::ConstantOp>(loc, *valueIt++), alloc, - llvm::ArrayRef(indices)); + affine::AffineStoreOp::create( + rewriter, loc, arith::ConstantOp::create(rewriter, loc, *valueIt++), + alloc, llvm::ArrayRef(indices)); return; } @@ -216,7 +207,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Func operations +// ToyToAffine Conversion Patterns: Func operations //===----------------------------------------------------------------------===// struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> { @@ -238,8 +229,8 @@ struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> { } // Create a new non-toy function, with the same region. - auto func = rewriter.create<mlir::func::FuncOp>(op.getLoc(), op.getName(), - op.getFunctionType()); + auto func = mlir::func::FuncOp::create(rewriter, op.getLoc(), op.getName(), + op.getFunctionType()); rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end()); rewriter.eraseOp(op); return success(); @@ -247,7 +238,7 @@ struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Print operations +// ToyToAffine Conversion Patterns: Print operations //===----------------------------------------------------------------------===// struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> { @@ -265,14 +256,15 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Return operations +// ToyToAffine Conversion Patterns: Return operations //===----------------------------------------------------------------------===// -struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> { - using OpRewritePattern<toy::ReturnOp>::OpRewritePattern; +struct ReturnOpLowering : public OpConversionPattern<toy::ReturnOp> { + using OpConversionPattern<toy::ReturnOp>::OpConversionPattern; - LogicalResult matchAndRewrite(toy::ReturnOp op, - PatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(toy::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { // During this lowering, we expect that all function calls have been // inlined. if (op.hasOperand()) @@ -285,32 +277,24 @@ struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Transpose operations +// ToyToAffine Conversion Patterns: Transpose operations //===----------------------------------------------------------------------===// -struct TransposeOpLowering : public ConversionPattern { - TransposeOpLowering(MLIRContext *ctx) - : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} +struct TransposeOpLowering : public OpConversionPattern<toy::TransposeOp> { + using OpConversionPattern<toy::TransposeOp>::OpConversionPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef<Value> operands, + matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); - lowerOpToLoops(op, operands, rewriter, - [loc](OpBuilder &builder, ValueRange memRefOperands, - ValueRange loopIvs) { - // Generate an adaptor for the remapped operands of the - // TransposeOp. This allows for using the nice named - // accessors that are generated by the ODS. - toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); - Value input = transposeAdaptor.getInput(); - - // Transpose the elements by generating a load from the - // reverse indices. - SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs)); - return builder.create<affine::AffineLoadOp>(loc, input, - reverseIvs); - }); + lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) { + Value input = adaptor.getInput(); + + // Transpose the elements by generating a load from the + // reverse indices. + SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs)); + return affine::AffineLoadOp::create(builder, loc, input, reverseIvs); + }); return success(); } }; diff --git a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp index 9371815..7d676f1 100644 --- a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp @@ -121,8 +121,8 @@ private: llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(), getType(VarType{})); auto funcType = builder.getFunctionType(argTypes, /*results=*/{}); - return builder.create<mlir::toy::FuncOp>(location, proto.getName(), - funcType); + return mlir::toy::FuncOp::create(builder, location, proto.getName(), + funcType); } /// Emit a new function and add it to the MLIR module. @@ -166,7 +166,7 @@ private: if (!entryBlock.empty()) returnOp = dyn_cast<ReturnOp>(entryBlock.back()); if (!returnOp) { - builder.create<ReturnOp>(loc(funcAST.getProto()->loc())); + ReturnOp::create(builder, loc(funcAST.getProto()->loc())); } else if (returnOp.hasOperand()) { // Otherwise, if this return operation has an operand then add a result to // the function. @@ -206,9 +206,9 @@ private: // support '+' and '*'. switch (binop.getOp()) { case '+': - return builder.create<AddOp>(location, lhs, rhs); + return AddOp::create(builder, location, lhs, rhs); case '*': - return builder.create<MulOp>(location, lhs, rhs); + return MulOp::create(builder, location, lhs, rhs); } emitError(location, "invalid binary operator '") << binop.getOp() << "'"; @@ -239,8 +239,8 @@ private: } // Otherwise, this return operation has zero operands. - builder.create<ReturnOp>(location, - expr ? ArrayRef(expr) : ArrayRef<mlir::Value>()); + ReturnOp::create(builder, location, + expr ? ArrayRef(expr) : ArrayRef<mlir::Value>()); return mlir::success(); } @@ -284,7 +284,7 @@ private: // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` // method. - return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute); + return ConstantOp::create(builder, loc(lit.loc()), type, dataAttribute); } /// Recursive helper function to accumulate the data that compose an array @@ -329,13 +329,13 @@ private: "does not accept multiple arguments"); return nullptr; } - return builder.create<TransposeOp>(location, operands[0]); + return TransposeOp::create(builder, location, operands[0]); } // Otherwise this is a call to a user-defined function. Calls to // user-defined functions are mapped to a custom call that takes the callee // name as an attribute. - return builder.create<GenericCallOp>(location, callee, operands); + return GenericCallOp::create(builder, location, callee, operands); } /// Emit a print expression. It emits specific operations for two builtins: @@ -345,13 +345,13 @@ private: if (!arg) return mlir::failure(); - builder.create<PrintOp>(loc(call.loc()), arg); + PrintOp::create(builder, loc(call.loc()), arg); return mlir::success(); } /// Emit a constant for a single number (FIXME: semantic? broadcast?) mlir::Value mlirGen(NumberExprAST &num) { - return builder.create<ConstantOp>(loc(num.loc()), num.getValue()); + return ConstantOp::create(builder, loc(num.loc()), num.getValue()); } /// Dispatch codegen for the right expression subclass using RTTI. @@ -395,8 +395,8 @@ private: // with specific shape, we emit a "reshape" operation. It will get // optimized out later as needed. if (!vardecl.getType().shape.empty()) { - value = builder.create<ReshapeOp>(loc(vardecl.loc()), - getType(vardecl.getType()), value); + value = ReshapeOp::create(builder, loc(vardecl.loc()), + getType(vardecl.getType()), value); } // Register the value in the symbol table. diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp index 2522abe..a552e1f0 100644 --- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp @@ -23,7 +23,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #include <memory> @@ -81,7 +81,7 @@ struct ShapeInferencePass opWorklist.erase(op); // Ask the operation to infer its output shapes. - LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + LDBG() << "Inferring shape for: " << *op; if (auto shapeOp = dyn_cast<ShapeInference>(op)) { shapeOp.inferShapes(); } else { diff --git a/mlir/examples/toy/Ch5/toyc.cpp b/mlir/examples/toy/Ch5/toyc.cpp index 6a0c631..afdf782 100644 --- a/mlir/examples/toy/Ch5/toyc.cpp +++ b/mlir/examples/toy/Ch5/toyc.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Diagnostics.h" #include "toy/AST.h" #include "toy/Dialect.h" diff --git a/mlir/examples/toy/Ch6/CMakeLists.txt b/mlir/examples/toy/Ch6/CMakeLists.txt index 283b895..73df602 100644 --- a/mlir/examples/toy/Ch6/CMakeLists.txt +++ b/mlir/examples/toy/Ch6/CMakeLists.txt @@ -37,14 +37,8 @@ add_toy_chapter(toyc-ch6 include_directories(${CMAKE_CURRENT_BINARY_DIR}) include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) target_link_libraries(toyc-ch6 PRIVATE - ${dialect_libs} - ${conversion_libs} - ${extension_libs} MLIRAnalysis MLIRBuiltinToLLVMIRTranslation MLIRCallInterfaces @@ -58,8 +52,11 @@ target_link_libraries(toyc-ch6 MLIRMemRefDialect MLIRParser MLIRPass + MLIRRegisterAllDialects + MLIRRegisterAllExtensions + MLIRRegisterAllPasses MLIRSideEffectInterfaces MLIRSupport MLIRTargetLLVMIRExport MLIRTransforms - ) + ) diff --git a/mlir/examples/toy/Ch6/include/toy/Ops.td b/mlir/examples/toy/Ch6/include/toy/Ops.td index 63950f4..897b36d 100644 --- a/mlir/examples/toy/Ch6/include/toy/Ops.td +++ b/mlir/examples/toy/Ch6/include/toy/Ops.td @@ -72,7 +72,7 @@ def ConstantOp : Toy_Op<"constant", [Pure]> { // Add custom build methods for the constant operation. These method populates // the `state` that MLIR uses to create operations, i.e. these are used when - // using `builder.create<ConstantOp>(...)`. + // using `ConstantOp::create(builder, ...)`. let builders = [ // Build a constant with a given constant tensor value. OpBuilder<(ins "DenseElementsAttr":$value), [{ diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp index fb7c742..69fb69f 100644 --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -91,7 +91,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { Operation *materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const final { - return builder.create<CastOp>(conversionLoc, resultType, input); + return CastOp::create(builder, conversionLoc, resultType, input); } }; @@ -206,7 +206,8 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) { llvm::LogicalResult ConstantOp::verify() { // If the return type of the constant is not an unranked tensor, the shape // must match the shape of the attribute holding the data. - auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType()); + auto resultType = + llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType()); if (!resultType) return success(); @@ -395,7 +396,8 @@ llvm::LogicalResult ReturnOp::verify() { auto resultType = results.front(); // Check that the result type of the function matches the operand type. - if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) || + if (inputType == resultType || + llvm::isa<mlir::UnrankedTensorType>(inputType) || llvm::isa<mlir::UnrankedTensorType>(resultType)) return mlir::success(); diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp index bf2bc43..2969d3a 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp @@ -44,7 +44,7 @@ using namespace mlir; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns +// ToyToAffine Conversion Patterns //===----------------------------------------------------------------------===// /// Convert the given RankedTensorType into the corresponding MemRefType. @@ -55,7 +55,7 @@ static MemRefType convertTensorToMemRef(RankedTensorType type) { /// Insert an allocation and deallocation for the given MemRefType. static Value insertAllocAndDealloc(MemRefType type, Location loc, PatternRewriter &rewriter) { - auto alloc = rewriter.create<memref::AllocOp>(loc, type); + auto alloc = memref::AllocOp::create(rewriter, loc, type); // Make sure to allocate at the beginning of the block. auto *parentBlock = alloc->getBlock(); @@ -63,21 +63,19 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc, // Make sure to deallocate this alloc at the end of the block. This is fine // as toy functions have no control flow. - auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc); + auto dealloc = memref::DeallocOp::create(rewriter, loc, alloc); dealloc->moveBefore(&parentBlock->back()); return alloc; } /// This defines the function type used to process an iteration of a lowered -/// loop. It takes as input an OpBuilder, an range of memRefOperands -/// corresponding to the operands of the input operation, and the range of loop -/// induction variables for the iteration. It returns a value to store at the -/// current index of the iteration. -using LoopIterationFn = function_ref<Value( - OpBuilder &rewriter, ValueRange memRefOperands, ValueRange loopIvs)>; - -static void lowerOpToLoops(Operation *op, ValueRange operands, - PatternRewriter &rewriter, +/// loop. It takes as input an OpBuilder and the range of loop induction +/// variables for the iteration. It returns a value to store at the current +/// index of the iteration. +using LoopIterationFn = + function_ref<Value(OpBuilder &rewriter, ValueRange loopIvs)>; + +static void lowerOpToLoops(Operation *op, PatternRewriter &rewriter, LoopIterationFn processIteration) { auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin())); auto loc = op->getLoc(); @@ -95,12 +93,12 @@ static void lowerOpToLoops(Operation *op, ValueRange operands, affine::buildAffineLoopNest( rewriter, loc, lowerBounds, tensorType.getShape(), steps, [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { - // Call the processing function with the rewriter, the memref operands, - // and the loop induction variables. This function will return the value - // to store at the current index. - Value valueToStore = processIteration(nestedBuilder, operands, ivs); - nestedBuilder.create<affine::AffineStoreOp>(loc, valueToStore, alloc, - ivs); + // Call the processing function with the rewriter and the loop + // induction variables. This function will return the value to store at + // the current index. + Value valueToStore = processIteration(nestedBuilder, ivs); + affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc, + ivs); }); // Replace this operation with the generated alloc. @@ -109,38 +107,30 @@ static void lowerOpToLoops(Operation *op, ValueRange operands, namespace { //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Binary operations +// ToyToAffine Conversion Patterns: Binary operations //===----------------------------------------------------------------------===// template <typename BinaryOp, typename LoweredBinaryOp> -struct BinaryOpLowering : public ConversionPattern { - BinaryOpLowering(MLIRContext *ctx) - : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} +struct BinaryOpLowering : public OpConversionPattern<BinaryOp> { + using OpConversionPattern<BinaryOp>::OpConversionPattern; + using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor; LogicalResult - matchAndRewrite(Operation *op, ArrayRef<Value> operands, + matchAndRewrite(BinaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); - lowerOpToLoops(op, operands, rewriter, - [loc](OpBuilder &builder, ValueRange memRefOperands, - ValueRange loopIvs) { - // Generate an adaptor for the remapped operands of the - // BinaryOp. This allows for using the nice named accessors - // that are generated by the ODS. - typename BinaryOp::Adaptor binaryAdaptor(memRefOperands); - - // Generate loads for the element of 'lhs' and 'rhs' at the - // inner loop. - auto loadedLhs = builder.create<affine::AffineLoadOp>( - loc, binaryAdaptor.getLhs(), loopIvs); - auto loadedRhs = builder.create<affine::AffineLoadOp>( - loc, binaryAdaptor.getRhs(), loopIvs); - - // Create the binary operation performed on the loaded - // values. - return builder.create<LoweredBinaryOp>(loc, loadedLhs, - loadedRhs); - }); + lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) { + // Generate loads for the element of 'lhs' and 'rhs' at the + // inner loop. + auto loadedLhs = + affine::AffineLoadOp::create(builder, loc, adaptor.getLhs(), loopIvs); + auto loadedRhs = + affine::AffineLoadOp::create(builder, loc, adaptor.getRhs(), loopIvs); + + // Create the binary operation performed on the loaded + // values. + return LoweredBinaryOp::create(builder, loc, loadedLhs, loadedRhs); + }); return success(); } }; @@ -148,14 +138,15 @@ using AddOpLowering = BinaryOpLowering<toy::AddOp, arith::AddFOp>; using MulOpLowering = BinaryOpLowering<toy::MulOp, arith::MulFOp>; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Constant operations +// ToyToAffine Conversion Patterns: Constant operations //===----------------------------------------------------------------------===// -struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> { - using OpRewritePattern<toy::ConstantOp>::OpRewritePattern; +struct ConstantOpLowering : public OpConversionPattern<toy::ConstantOp> { + using OpConversionPattern<toy::ConstantOp>::OpConversionPattern; - LogicalResult matchAndRewrite(toy::ConstantOp op, - PatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(toy::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { DenseElementsAttr constantValue = op.getValue(); Location loc = op.getLoc(); @@ -174,11 +165,11 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> { if (!valueShape.empty()) { for (auto i : llvm::seq<int64_t>(0, *llvm::max_element(valueShape))) constantIndices.push_back( - rewriter.create<arith::ConstantIndexOp>(loc, i)); + arith::ConstantIndexOp::create(rewriter, loc, i)); } else { // This is the case of a tensor of rank 0. constantIndices.push_back( - rewriter.create<arith::ConstantIndexOp>(loc, 0)); + arith::ConstantIndexOp::create(rewriter, loc, 0)); } // The constant operation represents a multi-dimensional constant, so we @@ -191,9 +182,9 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> { // The last dimension is the base case of the recursion, at this point // we store the element at the given index. if (dimension == valueShape.size()) { - rewriter.create<affine::AffineStoreOp>( - loc, rewriter.create<arith::ConstantOp>(loc, *valueIt++), alloc, - llvm::ArrayRef(indices)); + affine::AffineStoreOp::create( + rewriter, loc, arith::ConstantOp::create(rewriter, loc, *valueIt++), + alloc, llvm::ArrayRef(indices)); return; } @@ -216,7 +207,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Func operations +// ToyToAffine Conversion Patterns: Func operations //===----------------------------------------------------------------------===// struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> { @@ -238,8 +229,8 @@ struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> { } // Create a new non-toy function, with the same region. - auto func = rewriter.create<mlir::func::FuncOp>(op.getLoc(), op.getName(), - op.getFunctionType()); + auto func = mlir::func::FuncOp::create(rewriter, op.getLoc(), op.getName(), + op.getFunctionType()); rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end()); rewriter.eraseOp(op); return success(); @@ -247,7 +238,7 @@ struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Print operations +// ToyToAffine Conversion Patterns: Print operations //===----------------------------------------------------------------------===// struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> { @@ -265,14 +256,15 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Return operations +// ToyToAffine Conversion Patterns: Return operations //===----------------------------------------------------------------------===// -struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> { - using OpRewritePattern<toy::ReturnOp>::OpRewritePattern; +struct ReturnOpLowering : public OpConversionPattern<toy::ReturnOp> { + using OpConversionPattern<toy::ReturnOp>::OpConversionPattern; - LogicalResult matchAndRewrite(toy::ReturnOp op, - PatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(toy::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { // During this lowering, we expect that all function calls have been // inlined. if (op.hasOperand()) @@ -285,32 +277,24 @@ struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Transpose operations +// ToyToAffine Conversion Patterns: Transpose operations //===----------------------------------------------------------------------===// -struct TransposeOpLowering : public ConversionPattern { - TransposeOpLowering(MLIRContext *ctx) - : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} +struct TransposeOpLowering : public OpConversionPattern<toy::TransposeOp> { + using OpConversionPattern<toy::TransposeOp>::OpConversionPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef<Value> operands, + matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); - lowerOpToLoops(op, operands, rewriter, - [loc](OpBuilder &builder, ValueRange memRefOperands, - ValueRange loopIvs) { - // Generate an adaptor for the remapped operands of the - // TransposeOp. This allows for using the nice named - // accessors that are generated by the ODS. - toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); - Value input = transposeAdaptor.getInput(); - - // Transpose the elements by generating a load from the - // reverse indices. - SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs)); - return builder.create<affine::AffineLoadOp>(loc, input, - reverseIvs); - }); + lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) { + Value input = adaptor.getInput(); + + // Transpose the elements by generating a load from the + // reverse indices. + SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs)); + return affine::AffineLoadOp::create(builder, loc, input, reverseIvs); + }); return success(); } }; diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp index 22f75e0..987dfa1 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -55,19 +55,18 @@ using namespace mlir; //===----------------------------------------------------------------------===// -// ToyToLLVM RewritePatterns +// ToyToLLVM Conversion Patterns //===----------------------------------------------------------------------===// namespace { /// Lowers `toy.print` to a loop nest calling `printf` on each of the individual /// elements of the array. -class PrintOpLowering : public ConversionPattern { +class PrintOpLowering : public OpConversionPattern<toy::PrintOp> { public: - explicit PrintOpLowering(MLIRContext *context) - : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {} + using OpConversionPattern<toy::PrintOp>::OpConversionPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef<Value> operands, + matchAndRewrite(toy::PrintOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto *context = rewriter.getContext(); auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin())); @@ -86,12 +85,12 @@ public: // Create a loop for each of the dimensions within the shape. SmallVector<Value, 4> loopIvs; for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) { - auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); auto upperBound = - rewriter.create<arith::ConstantIndexOp>(loc, memRefShape[i]); - auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); + arith::ConstantIndexOp::create(rewriter, loc, memRefShape[i]); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); auto loop = - rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step); + scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step); for (Operation &nested : make_early_inc_range(*loop.getBody())) rewriter.eraseOp(&nested); loopIvs.push_back(loop.getInductionVar()); @@ -101,19 +100,17 @@ public: // Insert a newline after each of the inner dimensions of the shape. if (i != e - 1) - rewriter.create<LLVM::CallOp>(loc, getPrintfType(context), printfRef, - newLineCst); - rewriter.create<scf::YieldOp>(loc); + LLVM::CallOp::create(rewriter, loc, getPrintfType(context), printfRef, + newLineCst); + scf::YieldOp::create(rewriter, loc); rewriter.setInsertionPointToStart(loop.getBody()); } // Generate a call to printf for the current element of the loop. - auto printOp = cast<toy::PrintOp>(op); auto elementLoad = - rewriter.create<memref::LoadOp>(loc, printOp.getInput(), loopIvs); - rewriter.create<LLVM::CallOp>( - loc, getPrintfType(context), printfRef, - ArrayRef<Value>({formatSpecifierCst, elementLoad})); + memref::LoadOp::create(rewriter, loc, op.getInput(), loopIvs); + LLVM::CallOp::create(rewriter, loc, getPrintfType(context), printfRef, + ArrayRef<Value>({formatSpecifierCst, elementLoad})); // Notify the rewriter that this operation has been removed. rewriter.eraseOp(op); @@ -142,8 +139,8 @@ private: // Insert the printf function into the body of the parent module. PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); - rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", - getPrintfType(context)); + LLVM::LLVMFuncOp::create(rewriter, module.getLoc(), "printf", + getPrintfType(context)); return SymbolRefAttr::get(context, "printf"); } @@ -159,19 +156,19 @@ private: builder.setInsertionPointToStart(module.getBody()); auto type = LLVM::LLVMArrayType::get( IntegerType::get(builder.getContext(), 8), value.size()); - global = builder.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true, - LLVM::Linkage::Internal, name, - builder.getStringAttr(value), - /*alignment=*/0); + global = LLVM::GlobalOp::create(builder, loc, type, /*isConstant=*/true, + LLVM::Linkage::Internal, name, + builder.getStringAttr(value), + /*alignment=*/0); } // Get the pointer to the first character in the global string. - Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global); - Value cst0 = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(), - builder.getIndexAttr(0)); - return builder.create<LLVM::GEPOp>( - loc, LLVM::LLVMPointerType::get(builder.getContext()), global.getType(), - globalPtr, ArrayRef<Value>({cst0, cst0})); + Value globalPtr = LLVM::AddressOfOp::create(builder, loc, global); + Value cst0 = LLVM::ConstantOp::create(builder, loc, builder.getI64Type(), + builder.getIndexAttr(0)); + return LLVM::GEPOp::create( + builder, loc, LLVM::LLVMPointerType::get(builder.getContext()), + global.getType(), globalPtr, ArrayRef<Value>({cst0, cst0})); } }; } // namespace diff --git a/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp index 9371815..7d676f1 100644 --- a/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp @@ -121,8 +121,8 @@ private: llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(), getType(VarType{})); auto funcType = builder.getFunctionType(argTypes, /*results=*/{}); - return builder.create<mlir::toy::FuncOp>(location, proto.getName(), - funcType); + return mlir::toy::FuncOp::create(builder, location, proto.getName(), + funcType); } /// Emit a new function and add it to the MLIR module. @@ -166,7 +166,7 @@ private: if (!entryBlock.empty()) returnOp = dyn_cast<ReturnOp>(entryBlock.back()); if (!returnOp) { - builder.create<ReturnOp>(loc(funcAST.getProto()->loc())); + ReturnOp::create(builder, loc(funcAST.getProto()->loc())); } else if (returnOp.hasOperand()) { // Otherwise, if this return operation has an operand then add a result to // the function. @@ -206,9 +206,9 @@ private: // support '+' and '*'. switch (binop.getOp()) { case '+': - return builder.create<AddOp>(location, lhs, rhs); + return AddOp::create(builder, location, lhs, rhs); case '*': - return builder.create<MulOp>(location, lhs, rhs); + return MulOp::create(builder, location, lhs, rhs); } emitError(location, "invalid binary operator '") << binop.getOp() << "'"; @@ -239,8 +239,8 @@ private: } // Otherwise, this return operation has zero operands. - builder.create<ReturnOp>(location, - expr ? ArrayRef(expr) : ArrayRef<mlir::Value>()); + ReturnOp::create(builder, location, + expr ? ArrayRef(expr) : ArrayRef<mlir::Value>()); return mlir::success(); } @@ -284,7 +284,7 @@ private: // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` // method. - return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute); + return ConstantOp::create(builder, loc(lit.loc()), type, dataAttribute); } /// Recursive helper function to accumulate the data that compose an array @@ -329,13 +329,13 @@ private: "does not accept multiple arguments"); return nullptr; } - return builder.create<TransposeOp>(location, operands[0]); + return TransposeOp::create(builder, location, operands[0]); } // Otherwise this is a call to a user-defined function. Calls to // user-defined functions are mapped to a custom call that takes the callee // name as an attribute. - return builder.create<GenericCallOp>(location, callee, operands); + return GenericCallOp::create(builder, location, callee, operands); } /// Emit a print expression. It emits specific operations for two builtins: @@ -345,13 +345,13 @@ private: if (!arg) return mlir::failure(); - builder.create<PrintOp>(loc(call.loc()), arg); + PrintOp::create(builder, loc(call.loc()), arg); return mlir::success(); } /// Emit a constant for a single number (FIXME: semantic? broadcast?) mlir::Value mlirGen(NumberExprAST &num) { - return builder.create<ConstantOp>(loc(num.loc()), num.getValue()); + return ConstantOp::create(builder, loc(num.loc()), num.getValue()); } /// Dispatch codegen for the right expression subclass using RTTI. @@ -395,8 +395,8 @@ private: // with specific shape, we emit a "reshape" operation. It will get // optimized out later as needed. if (!vardecl.getType().shape.empty()) { - value = builder.create<ReshapeOp>(loc(vardecl.loc()), - getType(vardecl.getType()), value); + value = ReshapeOp::create(builder, loc(vardecl.loc()), + getType(vardecl.getType()), value); } // Register the value in the symbol table. diff --git a/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp index 2522abe..a552e1f0 100644 --- a/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp @@ -23,7 +23,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #include <memory> @@ -81,7 +81,7 @@ struct ShapeInferencePass opWorklist.erase(op); // Ask the operation to infer its output shapes. - LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + LDBG() << "Inferring shape for: " << *op; if (auto shapeOp = dyn_cast<ShapeInference>(op)) { shapeOp.inferShapes(); } else { diff --git a/mlir/examples/toy/Ch6/toyc.cpp b/mlir/examples/toy/Ch6/toyc.cpp index dccab91..4a5e109 100644 --- a/mlir/examples/toy/Ch6/toyc.cpp +++ b/mlir/examples/toy/Ch6/toyc.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" #include "toy/AST.h" diff --git a/mlir/examples/toy/Ch7/CMakeLists.txt b/mlir/examples/toy/Ch7/CMakeLists.txt index 362ab51..a489ae5 100644 --- a/mlir/examples/toy/Ch7/CMakeLists.txt +++ b/mlir/examples/toy/Ch7/CMakeLists.txt @@ -36,14 +36,8 @@ add_toy_chapter(toyc-ch7 include_directories(${CMAKE_CURRENT_BINARY_DIR}) include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) target_link_libraries(toyc-ch7 PRIVATE - ${dialect_libs} - ${conversion_libs} - ${extension_libs} MLIRAnalysis MLIRBuiltinToLLVMIRTranslation MLIRCallInterfaces @@ -56,7 +50,10 @@ target_link_libraries(toyc-ch7 MLIRMemRefDialect MLIRParser MLIRPass + MLIRRegisterAllDialects + MLIRRegisterAllExtensions + MLIRRegisterAllPasses MLIRSideEffectInterfaces MLIRTargetLLVMIRExport MLIRTransforms - ) + ) diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td index bdf8ad0b..9151396 100644 --- a/mlir/examples/toy/Ch7/include/toy/Ops.td +++ b/mlir/examples/toy/Ch7/include/toy/Ops.td @@ -93,7 +93,7 @@ def ConstantOp : Toy_Op<"constant", // Add custom build methods for the constant operation. These method populates // the `state` that MLIR uses to create operations, i.e. these are used when - // using `builder.create<ConstantOp>(...)`. + // using `ConstantOp::create(builder, ...)`. let builders = [ // Build a constant with a given constant tensor value. OpBuilder<(ins "DenseElementsAttr":$value), [{ diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp index 52881db..4d2f063 100644 --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -97,7 +97,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { Operation *materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const final { - return builder.create<CastOp>(conversionLoc, resultType, input); + return CastOp::create(builder, conversionLoc, resultType, input); } }; @@ -429,7 +429,8 @@ llvm::LogicalResult ReturnOp::verify() { auto resultType = results.front(); // Check that the result type of the function matches the operand type. - if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) || + if (inputType == resultType || + llvm::isa<mlir::UnrankedTensorType>(inputType) || llvm::isa<mlir::UnrankedTensorType>(resultType)) return mlir::success(); @@ -657,8 +658,8 @@ mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder, mlir::Type type, mlir::Location loc) { if (llvm::isa<StructType>(type)) - return builder.create<StructConstantOp>(loc, type, - llvm::cast<mlir::ArrayAttr>(value)); - return builder.create<ConstantOp>(loc, type, - llvm::cast<mlir::DenseElementsAttr>(value)); + return StructConstantOp::create(builder, loc, type, + llvm::cast<mlir::ArrayAttr>(value)); + return ConstantOp::create(builder, loc, type, + llvm::cast<mlir::DenseElementsAttr>(value)); } diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp index bf2bc43..cbe4236 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp @@ -44,7 +44,7 @@ using namespace mlir; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns +// ToyToAffine Conversion Patterns //===----------------------------------------------------------------------===// /// Convert the given RankedTensorType into the corresponding MemRefType. @@ -55,7 +55,7 @@ static MemRefType convertTensorToMemRef(RankedTensorType type) { /// Insert an allocation and deallocation for the given MemRefType. static Value insertAllocAndDealloc(MemRefType type, Location loc, PatternRewriter &rewriter) { - auto alloc = rewriter.create<memref::AllocOp>(loc, type); + auto alloc = memref::AllocOp::create(rewriter, loc, type); // Make sure to allocate at the beginning of the block. auto *parentBlock = alloc->getBlock(); @@ -63,21 +63,19 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc, // Make sure to deallocate this alloc at the end of the block. This is fine // as toy functions have no control flow. - auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc); + auto dealloc = memref::DeallocOp::create(rewriter, loc, alloc); dealloc->moveBefore(&parentBlock->back()); return alloc; } /// This defines the function type used to process an iteration of a lowered -/// loop. It takes as input an OpBuilder, an range of memRefOperands -/// corresponding to the operands of the input operation, and the range of loop -/// induction variables for the iteration. It returns a value to store at the -/// current index of the iteration. -using LoopIterationFn = function_ref<Value( - OpBuilder &rewriter, ValueRange memRefOperands, ValueRange loopIvs)>; - -static void lowerOpToLoops(Operation *op, ValueRange operands, - PatternRewriter &rewriter, +/// loop. It takes as input an OpBuilder and the range of loop induction +/// variables for the iteration. It returns a value to store at the current +/// index of the iteration. +using LoopIterationFn = + function_ref<Value(OpBuilder &rewriter, ValueRange loopIvs)>; + +static void lowerOpToLoops(Operation *op, PatternRewriter &rewriter, LoopIterationFn processIteration) { auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin())); auto loc = op->getLoc(); @@ -95,12 +93,12 @@ static void lowerOpToLoops(Operation *op, ValueRange operands, affine::buildAffineLoopNest( rewriter, loc, lowerBounds, tensorType.getShape(), steps, [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { - // Call the processing function with the rewriter, the memref operands, + // Call the processing function with the rewriter // and the loop induction variables. This function will return the value // to store at the current index. - Value valueToStore = processIteration(nestedBuilder, operands, ivs); - nestedBuilder.create<affine::AffineStoreOp>(loc, valueToStore, alloc, - ivs); + Value valueToStore = processIteration(nestedBuilder, ivs); + affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc, + ivs); }); // Replace this operation with the generated alloc. @@ -109,38 +107,30 @@ static void lowerOpToLoops(Operation *op, ValueRange operands, namespace { //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Binary operations +// ToyToAffine Conversion Patterns: Binary operations //===----------------------------------------------------------------------===// template <typename BinaryOp, typename LoweredBinaryOp> -struct BinaryOpLowering : public ConversionPattern { - BinaryOpLowering(MLIRContext *ctx) - : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} +struct BinaryOpLowering : public OpConversionPattern<BinaryOp> { + using OpConversionPattern<BinaryOp>::OpConversionPattern; + using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor; LogicalResult - matchAndRewrite(Operation *op, ArrayRef<Value> operands, + matchAndRewrite(BinaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); - lowerOpToLoops(op, operands, rewriter, - [loc](OpBuilder &builder, ValueRange memRefOperands, - ValueRange loopIvs) { - // Generate an adaptor for the remapped operands of the - // BinaryOp. This allows for using the nice named accessors - // that are generated by the ODS. - typename BinaryOp::Adaptor binaryAdaptor(memRefOperands); - - // Generate loads for the element of 'lhs' and 'rhs' at the - // inner loop. - auto loadedLhs = builder.create<affine::AffineLoadOp>( - loc, binaryAdaptor.getLhs(), loopIvs); - auto loadedRhs = builder.create<affine::AffineLoadOp>( - loc, binaryAdaptor.getRhs(), loopIvs); - - // Create the binary operation performed on the loaded - // values. - return builder.create<LoweredBinaryOp>(loc, loadedLhs, - loadedRhs); - }); + lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) { + // Generate loads for the element of 'lhs' and 'rhs' at the + // inner loop. + auto loadedLhs = + affine::AffineLoadOp::create(builder, loc, adaptor.getLhs(), loopIvs); + auto loadedRhs = + affine::AffineLoadOp::create(builder, loc, adaptor.getRhs(), loopIvs); + + // Create the binary operation performed on the loaded + // values. + return LoweredBinaryOp::create(builder, loc, loadedLhs, loadedRhs); + }); return success(); } }; @@ -148,14 +138,15 @@ using AddOpLowering = BinaryOpLowering<toy::AddOp, arith::AddFOp>; using MulOpLowering = BinaryOpLowering<toy::MulOp, arith::MulFOp>; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Constant operations +// ToyToAffine Conversion Patterns: Constant operations //===----------------------------------------------------------------------===// -struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> { - using OpRewritePattern<toy::ConstantOp>::OpRewritePattern; +struct ConstantOpLowering : public OpConversionPattern<toy::ConstantOp> { + using OpConversionPattern<toy::ConstantOp>::OpConversionPattern; - LogicalResult matchAndRewrite(toy::ConstantOp op, - PatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(toy::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { DenseElementsAttr constantValue = op.getValue(); Location loc = op.getLoc(); @@ -174,11 +165,11 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> { if (!valueShape.empty()) { for (auto i : llvm::seq<int64_t>(0, *llvm::max_element(valueShape))) constantIndices.push_back( - rewriter.create<arith::ConstantIndexOp>(loc, i)); + arith::ConstantIndexOp::create(rewriter, loc, i)); } else { // This is the case of a tensor of rank 0. constantIndices.push_back( - rewriter.create<arith::ConstantIndexOp>(loc, 0)); + arith::ConstantIndexOp::create(rewriter, loc, 0)); } // The constant operation represents a multi-dimensional constant, so we @@ -191,9 +182,9 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> { // The last dimension is the base case of the recursion, at this point // we store the element at the given index. if (dimension == valueShape.size()) { - rewriter.create<affine::AffineStoreOp>( - loc, rewriter.create<arith::ConstantOp>(loc, *valueIt++), alloc, - llvm::ArrayRef(indices)); + affine::AffineStoreOp::create( + rewriter, loc, arith::ConstantOp::create(rewriter, loc, *valueIt++), + alloc, llvm::ArrayRef(indices)); return; } @@ -216,7 +207,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Func operations +// ToyToAffine Conversion Patterns: Func operations //===----------------------------------------------------------------------===// struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> { @@ -238,8 +229,8 @@ struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> { } // Create a new non-toy function, with the same region. - auto func = rewriter.create<mlir::func::FuncOp>(op.getLoc(), op.getName(), - op.getFunctionType()); + auto func = mlir::func::FuncOp::create(rewriter, op.getLoc(), op.getName(), + op.getFunctionType()); rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end()); rewriter.eraseOp(op); return success(); @@ -247,7 +238,7 @@ struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Print operations +// ToyToAffine Conversion Patterns: Print operations //===----------------------------------------------------------------------===// struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> { @@ -265,14 +256,15 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Return operations +// ToyToAffine Conversion Patterns: Return operations //===----------------------------------------------------------------------===// -struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> { - using OpRewritePattern<toy::ReturnOp>::OpRewritePattern; +struct ReturnOpLowering : public OpConversionPattern<toy::ReturnOp> { + using OpConversionPattern<toy::ReturnOp>::OpConversionPattern; - LogicalResult matchAndRewrite(toy::ReturnOp op, - PatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(toy::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { // During this lowering, we expect that all function calls have been // inlined. if (op.hasOperand()) @@ -285,32 +277,24 @@ struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Transpose operations +// ToyToAffine Conversion Patterns: Transpose operations //===----------------------------------------------------------------------===// -struct TransposeOpLowering : public ConversionPattern { - TransposeOpLowering(MLIRContext *ctx) - : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} +struct TransposeOpLowering : public OpConversionPattern<toy::TransposeOp> { + using OpConversionPattern<toy::TransposeOp>::OpConversionPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef<Value> operands, + matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); - lowerOpToLoops(op, operands, rewriter, - [loc](OpBuilder &builder, ValueRange memRefOperands, - ValueRange loopIvs) { - // Generate an adaptor for the remapped operands of the - // TransposeOp. This allows for using the nice named - // accessors that are generated by the ODS. - toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); - Value input = transposeAdaptor.getInput(); - - // Transpose the elements by generating a load from the - // reverse indices. - SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs)); - return builder.create<affine::AffineLoadOp>(loc, input, - reverseIvs); - }); + lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) { + Value input = adaptor.getInput(); + + // Transpose the elements by generating a load from the + // reverse indices. + SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs)); + return affine::AffineLoadOp::create(builder, loc, input, reverseIvs); + }); return success(); } }; diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp index 54eeb27..8b48a8f 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -55,19 +55,18 @@ using namespace mlir; //===----------------------------------------------------------------------===// -// ToyToLLVM RewritePatterns +// ToyToLLVM Conversion Patterns //===----------------------------------------------------------------------===// namespace { /// Lowers `toy.print` to a loop nest calling `printf` on each of the individual /// elements of the array. -class PrintOpLowering : public ConversionPattern { +class PrintOpLowering : public OpConversionPattern<toy::PrintOp> { public: - explicit PrintOpLowering(MLIRContext *context) - : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {} + using OpConversionPattern<toy::PrintOp>::OpConversionPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef<Value> operands, + matchAndRewrite(toy::PrintOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto *context = rewriter.getContext(); auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin())); @@ -86,12 +85,12 @@ public: // Create a loop for each of the dimensions within the shape. SmallVector<Value, 4> loopIvs; for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) { - auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); auto upperBound = - rewriter.create<arith::ConstantIndexOp>(loc, memRefShape[i]); - auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); + arith::ConstantIndexOp::create(rewriter, loc, memRefShape[i]); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); auto loop = - rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step); + scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step); for (Operation &nested : make_early_inc_range(*loop.getBody())) rewriter.eraseOp(&nested); loopIvs.push_back(loop.getInductionVar()); @@ -101,19 +100,17 @@ public: // Insert a newline after each of the inner dimensions of the shape. if (i != e - 1) - rewriter.create<LLVM::CallOp>(loc, getPrintfType(context), printfRef, - newLineCst); - rewriter.create<scf::YieldOp>(loc); + LLVM::CallOp::create(rewriter, loc, getPrintfType(context), printfRef, + newLineCst); + scf::YieldOp::create(rewriter, loc); rewriter.setInsertionPointToStart(loop.getBody()); } // Generate a call to printf for the current element of the loop. - auto printOp = cast<toy::PrintOp>(op); auto elementLoad = - rewriter.create<memref::LoadOp>(loc, printOp.getInput(), loopIvs); - rewriter.create<LLVM::CallOp>( - loc, getPrintfType(context), printfRef, - ArrayRef<Value>({formatSpecifierCst, elementLoad})); + memref::LoadOp::create(rewriter, loc, op.getInput(), loopIvs); + LLVM::CallOp::create(rewriter, loc, getPrintfType(context), printfRef, + ArrayRef<Value>({formatSpecifierCst, elementLoad})); // Notify the rewriter that this operation has been removed. rewriter.eraseOp(op); @@ -142,8 +139,8 @@ private: // Insert the printf function into the body of the parent module. PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); - rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", - getPrintfType(context)); + LLVM::LLVMFuncOp::create(rewriter, module.getLoc(), "printf", + getPrintfType(context)); return SymbolRefAttr::get(context, "printf"); } @@ -159,19 +156,19 @@ private: builder.setInsertionPointToStart(module.getBody()); auto type = LLVM::LLVMArrayType::get( IntegerType::get(builder.getContext(), 8), value.size()); - global = builder.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true, - LLVM::Linkage::Internal, name, - builder.getStringAttr(value), - /*alignment=*/0); + global = LLVM::GlobalOp::create(builder, loc, type, /*isConstant=*/true, + LLVM::Linkage::Internal, name, + builder.getStringAttr(value), + /*alignment=*/0); } // Get the pointer to the first character in the global string. - Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global); - Value cst0 = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(), - builder.getIndexAttr(0)); - return builder.create<LLVM::GEPOp>( - loc, LLVM::LLVMPointerType::get(builder.getContext()), global.getType(), - globalPtr, ArrayRef<Value>({cst0, cst0})); + Value globalPtr = LLVM::AddressOfOp::create(builder, loc, global); + Value cst0 = LLVM::ConstantOp::create(builder, loc, builder.getI64Type(), + builder.getIndexAttr(0)); + return LLVM::GEPOp::create( + builder, loc, LLVM::LLVMPointerType::get(builder.getContext()), + global.getType(), globalPtr, ArrayRef<Value>({cst0, cst0})); } }; } // namespace diff --git a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp index 2490f17..75dbc91 100644 --- a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp @@ -183,8 +183,8 @@ private: argTypes.push_back(type); } auto funcType = builder.getFunctionType(argTypes, /*results=*/{}); - return builder.create<mlir::toy::FuncOp>(location, proto.getName(), - funcType); + return mlir::toy::FuncOp::create(builder, location, proto.getName(), + funcType); } /// Emit a new function and add it to the MLIR module. @@ -227,7 +227,7 @@ private: if (!entryBlock.empty()) returnOp = dyn_cast<ReturnOp>(entryBlock.back()); if (!returnOp) { - builder.create<ReturnOp>(loc(funcAST.getProto()->loc())); + ReturnOp::create(builder, loc(funcAST.getProto()->loc())); } else if (returnOp.hasOperand()) { // Otherwise, if this return operation has an operand then add a result to // the function. @@ -333,7 +333,7 @@ private: emitError(location, "invalid access into struct expression"); return nullptr; } - return builder.create<StructAccessOp>(location, lhs, *accessIndex); + return StructAccessOp::create(builder, location, lhs, *accessIndex); } // Otherwise, this is a normal binary op. @@ -345,9 +345,9 @@ private: // support '+' and '*'. switch (binop.getOp()) { case '+': - return builder.create<AddOp>(location, lhs, rhs); + return AddOp::create(builder, location, lhs, rhs); case '*': - return builder.create<MulOp>(location, lhs, rhs); + return MulOp::create(builder, location, lhs, rhs); } emitError(location, "invalid binary operator '") << binop.getOp() << "'"; @@ -378,8 +378,8 @@ private: } // Otherwise, this return operation has zero operands. - builder.create<ReturnOp>(location, - expr ? ArrayRef(expr) : ArrayRef<mlir::Value>()); + ReturnOp::create(builder, location, + expr ? ArrayRef(expr) : ArrayRef<mlir::Value>()); return mlir::success(); } @@ -464,7 +464,7 @@ private: // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` // method. - return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute); + return ConstantOp::create(builder, loc(lit.loc()), type, dataAttribute); } /// Emit a struct literal. It will be emitted as an array of @@ -477,7 +477,8 @@ private: // Build the MLIR op `toy.struct_constant`. This invokes the // `StructConstantOp::build` method. - return builder.create<StructConstantOp>(loc(lit.loc()), dataType, dataAttr); + return StructConstantOp::create(builder, loc(lit.loc()), dataType, + dataAttr); } /// Recursive helper function to accumulate the data that compose an array @@ -522,7 +523,7 @@ private: "does not accept multiple arguments"); return nullptr; } - return builder.create<TransposeOp>(location, operands[0]); + return TransposeOp::create(builder, location, operands[0]); } // Otherwise this is a call to a user-defined function. Calls to @@ -534,8 +535,9 @@ private: return nullptr; } mlir::toy::FuncOp calledFunc = calledFuncIt->second; - return builder.create<GenericCallOp>( - location, calledFunc.getFunctionType().getResult(0), callee, operands); + return GenericCallOp::create(builder, location, + calledFunc.getFunctionType().getResult(0), + callee, operands); } /// Emit a print expression. It emits specific operations for two builtins: @@ -545,13 +547,13 @@ private: if (!arg) return mlir::failure(); - builder.create<PrintOp>(loc(call.loc()), arg); + PrintOp::create(builder, loc(call.loc()), arg); return mlir::success(); } /// Emit a constant for a single number (FIXME: semantic? broadcast?) mlir::Value mlirGen(NumberExprAST &num) { - return builder.create<ConstantOp>(loc(num.loc()), num.getValue()); + return ConstantOp::create(builder, loc(num.loc()), num.getValue()); } /// Dispatch codegen for the right expression subclass using RTTI. @@ -613,8 +615,8 @@ private: // declared with specific shape, we emit a "reshape" operation. It will // get optimized out later as needed. } else if (!varType.shape.empty()) { - value = builder.create<ReshapeOp>(loc(vardecl.loc()), - getType(varType.shape), value); + value = ReshapeOp::create(builder, loc(vardecl.loc()), + getType(varType.shape), value); } // Register the value in the symbol table. diff --git a/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp index 2522abe..a552e1f0 100644 --- a/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp @@ -23,7 +23,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #include <memory> @@ -81,7 +81,7 @@ struct ShapeInferencePass opWorklist.erase(op); // Ask the operation to infer its output shapes. - LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + LDBG() << "Inferring shape for: " << *op; if (auto shapeOp = dyn_cast<ShapeInference>(op)) { shapeOp.inferShapes(); } else { diff --git a/mlir/examples/toy/Ch7/toyc.cpp b/mlir/examples/toy/Ch7/toyc.cpp index dd86265..32208ecca 100644 --- a/mlir/examples/toy/Ch7/toyc.cpp +++ b/mlir/examples/toy/Ch7/toyc.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" #include "toy/AST.h" diff --git a/mlir/examples/transform-opt/CMakeLists.txt b/mlir/examples/transform-opt/CMakeLists.txt index 8e23555..07d58f6 100644 --- a/mlir/examples/transform-opt/CMakeLists.txt +++ b/mlir/examples/transform-opt/CMakeLists.txt @@ -1,18 +1,14 @@ -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) - set(LIBS MLIRAnalysis MLIRIR MLIRParser + MLIRRegisterAllDialects + MLIRRegisterAllExtensions + MLIRRegisterAllPasses MLIRSupport MLIRTransformDialect MLIRTransformDialectTransforms MLIRTransforms - ${dialect_libs} - ${conversion_libs} - ${extension_libs} ) add_mlir_tool(mlir-transform-opt diff --git a/mlir/examples/transform-opt/mlir-transform-opt.cpp b/mlir/examples/transform-opt/mlir-transform-opt.cpp index 1a29913..4b12e76 100644 --- a/mlir/examples/transform-opt/mlir-transform-opt.cpp +++ b/mlir/examples/transform-opt/mlir-transform-opt.cpp @@ -22,6 +22,7 @@ #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/InitLLVM.h" +#include "llvm/Support/ManagedStatic.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" #include <cstdlib> diff --git a/mlir/examples/transform/Ch3/include/MyExtension.td b/mlir/examples/transform/Ch3/include/MyExtension.td index 5a78186..49874a7 100644 --- a/mlir/examples/transform/Ch3/include/MyExtension.td +++ b/mlir/examples/transform/Ch3/include/MyExtension.td @@ -46,9 +46,9 @@ def ChangeCallTargetOp : Op<Transform_Dialect, "my.change_call_target", // We use a string attribute as the symbol may not exist in the transform IR so the // verification may fail. let arguments = (ins - // Specify the type constraint on the input accepting only `func.call` payload - // operations. - Transform_ConcreteOpType<"func.call">:$call, + // Allow the handle to be to concrete func.call ops as well as any op implementing + // the CallOpInterface. + AnyTypeOf<[Transform_ConcreteOpType<"func.call">, CallOpInterfaceHandle]>:$call, StrAttr:$new_target); // The results are empty as the transformation does not produce any new payload. diff --git a/mlir/examples/transform/Ch4/lib/MyExtension.cpp b/mlir/examples/transform/Ch4/lib/MyExtension.cpp index fa0ffc9..2159483 100644 --- a/mlir/examples/transform/Ch4/lib/MyExtension.cpp +++ b/mlir/examples/transform/Ch4/lib/MyExtension.cpp @@ -13,11 +13,9 @@ #include "MyExtension.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" -#define DEBUG_TYPE_MATCHER "transform-matcher" -#define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ") -#define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x) +#define DEBUG_TYPE "transform-matcher" #define GET_OP_CLASSES #include "MyExtension.cpp.inc" @@ -124,9 +122,8 @@ mlir::transform::HasOperandSatisfyingOp::apply( // Report failure-to-match for debugging purposes and stop matching this // operand. assert(diag.isSilenceableFailure()); - DEBUG_MATCHER(DBGS_MATCHER() - << "failed to match operand #" << operand.getOperandNumber() - << ": " << diag.getMessage()); + LDBG() << "failed to match operand #" << operand.getOperandNumber() + << ": " << diag.getMessage(); (void)diag.silence(); matchSucceeded = false; break; diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h index 364a70c..b595b6a3 100644 --- a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h +++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h @@ -8,6 +8,11 @@ #ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H #define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H +constexpr const char *alignedAllocFunctionName = "aligned_alloc"; +constexpr const char *mallocFunctionName = "malloc"; +constexpr const char *cppStandardLibraryHeader = "cstdlib"; +constexpr const char *cStandardLibraryHeader = "stdlib.h"; + namespace mlir { class DialectRegistry; class RewritePatternSet; diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index eb18160..cf7596c 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -841,9 +841,13 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> { // MemRefToEmitC //===----------------------------------------------------------------------===// -def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc"> { +def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc", "ModuleOp"> { let summary = "Convert MemRef dialect to EmitC dialect"; let dependentDialects = ["emitc::EmitCDialect"]; + let options = [Option< + "lowerToCpp", "lower-to-cpp", "bool", + /*default=*/"false", + /*description=*/"Target C++ (true) instead of C (false)">]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h index 2cf801d..09700f8 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h @@ -14,7 +14,7 @@ struct LogicalResult; } // namespace llvm namespace mlir { -class ModuleOp; +class Operation; namespace bufferization { struct BufferizationStatistics; @@ -23,12 +23,13 @@ struct OneShotBufferizationOptions; class BufferizationState; /// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in -/// `state`. +/// `state`. This operates on any `SymbolTable` op. llvm::LogicalResult -analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, +analyzeModuleOp(Operation *moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics = nullptr); -/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`. +/// Bufferize an `op`s nested ops that implement `BufferizableOpInterface`. +/// This operates on any `SymbolTable` op. /// /// Note: This function does not run One-Shot Analysis. No buffer copies are /// inserted except two cases: @@ -37,20 +38,20 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, /// - `options.copyBeforeWrite` is not set and `options.noAnalysisFuncFilter` /// is not empty. The FuncOps it contains were not analyzed. Buffer copies /// will be inserted only to these FuncOps. -llvm::LogicalResult -bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options, - BufferizationState &state, - BufferizationStatistics *statistics = nullptr); +llvm::LogicalResult bufferizeModuleOp( + Operation *moduleOp, const OneShotBufferizationOptions &options, + BufferizationState &state, BufferizationStatistics *statistics = nullptr); -/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp. -void removeBufferizationAttributesInModule(ModuleOp moduleOp); +/// Remove bufferization attributes on every FuncOp arguments in the SymbolTable +/// op. +void removeBufferizationAttributesInModule(Operation *moduleOp); -/// Run One-Shot Module Bufferization on the given module. Performs a simple -/// function call analysis to determine which function arguments are +/// Run One-Shot Module Bufferization on the given SymbolTable. Performs a +/// simple function call analysis to determine which function arguments are /// inplaceable. Then analyzes and bufferizes FuncOps one-by-one with One-Shot /// Bufferize. llvm::LogicalResult runOneShotModuleBufferize( - ModuleOp moduleOp, + Operation *moduleOp, const bufferization::OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics = nullptr); diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td index 74c4913..1893c10 100644 --- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td @@ -51,11 +51,6 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> { ``` }]; let dependentDialects = ["emitc::EmitCDialect"]; - let options = [Option< - "namedAttribute", "named-attribute", "std::string", - /*default=*/"", - "Attribute key used to extract field names from function argument's " - "dictionary attributes">]; } #endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h index a4e8fe1..bdf6d09 100644 --- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h @@ -29,8 +29,7 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder); void populateExpressionPatterns(RewritePatternSet &patterns); /// Populates 'patterns' with func-related patterns. -void populateFuncPatterns(RewritePatternSet &patterns, - StringRef namedAttribute); +void populateFuncPatterns(RewritePatternSet &patterns); } // namespace emitc } // namespace mlir diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td index 1dbaf5d..2ed7d38 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -1368,12 +1368,14 @@ def GPU_ShuffleOp : GPU_Op< def GPU_RotateOp : GPU_Op< "rotate", [Pure, AllTypesMatch<["value", "rotateResult"]>]>, - Arguments<(ins AnyIntegerOrFloatOr1DVector:$value, I32:$offset, I32:$width)>, + Arguments<(ins AnyIntegerOrFloatOr1DVector:$value, + ConfinedAttr<I32Attr, [IntMinValue<0>]>:$offset, + ConfinedAttr<I32Attr, [IntPowerOf2]>:$width)>, Results<(outs AnyIntegerOrFloatOr1DVector:$rotateResult, I1:$valid)> { let summary = "Rotate values within a subgroup."; let description = [{ The "rotate" op moves values across lanes in a subgroup (a.k.a., local - invocations) within the same subgroup. The `width` argument specifies the + invocations) within the same subgroup. The `width` attribute specifies the number of lanes that participate in the rotation, and must be uniform across all participating lanes. Further, the first `width` lanes of the subgroup must be active. @@ -1394,9 +1396,7 @@ def GPU_RotateOp : GPU_Op< example: ```mlir - %offset = arith.constant 1 : i32 - %width = arith.constant 16 : i32 - %1, %2 = gpu.rotate %0, %offset, %width : f32 + %1, %2 = gpu.rotate %0, 1, 16 : f32 ``` For lane `k`, returns the value from lane `(k + cst1) % width`. @@ -1406,11 +1406,6 @@ def GPU_RotateOp : GPU_Op< $value `,` $offset `,` $width attr-dict `:` type($value) }]; - let builders = [ - // Helper function that creates a rotate with constant offset/width. - OpBuilder<(ins "Value":$value, "int32_t":$offset, "int32_t":$width)> - ]; - let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h index e355bb8..f3bd5c0 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -32,11 +32,6 @@ #include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Support/ThreadLocalCache.h" #include "llvm/ADT/PointerEmbeddedInt.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/InstrTypes.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" namespace llvm { class Type; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td index caba614..8c6f1ee 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -555,8 +555,6 @@ def LLVM_AssumeOp let builders = [ OpBuilder<(ins "Value":$cond)>, - OpBuilder<(ins "Value":$cond, - "ArrayRef<llvm::OperandBundleDefT<Value>>":$opBundles)>, OpBuilder<(ins "Value":$cond, "llvm::StringRef":$tag, "ValueRange":$args)>, OpBuilder<(ins "Value":$cond, "AssumeAlignTag":$tag, "Value":$ptr, "Value":$align)>, diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td index fa57202..f36b41c 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td @@ -106,7 +106,9 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [ result tensor in the order in which they appear, i.e. `shape(result)[rank(result) + i] = inner_tiles[i]` for `0 <= i < k`. - The following relationship for the tiled dimensions holds: - `shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i]`. + `shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i]`, + where (⌈/⌉ indicates CeilDiv). + Example: If `inner_tiles = [16, 32]`, the result tensor has a shape of `...x16x32`. If `inner_dims_pos = [0, 1]`, the 0th source dimension is tiled @@ -150,9 +152,17 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [ `padding_value` specifies a padding value at the boundary on non-perfectly divisible dimensions. Padding is optional: - - If absent, it is UB if the tile does not perfectly divide the dimension. + - If absent, it is assumed that for all inner tiles, + `shape(source)[inner_dims_pos[i]] % inner_tiles[i] == 0`, i.e. all inner + tiles divide perfectly the corresponding outer dimension in the result + tensor. It is UB if the tile does not perfectly divide the dimension. - If present, it will pad along high dimensions (high-padding) to make the - tile complete. + tile complete. Note that it is not allowed to have artificial padding that + is not strictly required by linalg.pack (i.e., padding past what is needed + to complete the last tile along each packed dimension). It is UB if extra + padding is requested. + It is not possible to verify the requirements statically with dynamic + shapes, so they are treated as UB. Example: ```mlir @@ -167,6 +177,15 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [ // // Note: Only tiled dimensions can be padded. ``` + + Invalid example that has artificial padding: + ```mlir + %0 = linalg.pack %src padding_value(%cst : f32) inner_dims_pos = [0] + inner_tiles = [8] into %dest + : tensor<9xf32> -> tensor<3x8xf32> + // \ + // expect tensor<2x8xf32> because CeilDiv(9, 8) = 2 + ``` }]; let arguments = (ins AnyRankedTensor:$source, AnyRankedTensor:$dest, diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h index 5430fd9..c0c6085 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h @@ -119,6 +119,12 @@ public: Status status; }; +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const CopyMappingInfo &info) { + info.print(os); + return os; +} + } // namespace gpu } // namespace transform } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 8d45c40..61ce23f 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -1191,6 +1191,7 @@ def PadTilingInterfaceOp : Op<Transform_Dialect, "structured.pad_tiling_interfac iteration domain induces a padding of the operands that is consistent across the op semantics and, unlike for simple elementwise ops, may not be trivially deducible or specifiable on operands only (e.g. convolutions). + Currently, only a limited set of projected permutation maps are supported. The specification of `padding_sizes` follows that of `tile_sizes` during tiling: the value "0" on a particular iterator encode "no padding". Like in diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index e625eef..d4ffe0a 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -611,6 +611,13 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, /// affine.apply operations. /// The `indexingMap` + `indexingSizes` encoding suits StructuredOps and /// provides a gentle portability path for Linalg-like ops with affine maps. +/// The padded shape is computed by evaluating the maximum accessed index per +/// dimension, which may involve multiplying by constant factors derived from +/// the affine indexing expressions. Currently, only a limited set of projected +/// permuation indexing maps are supported, such as +/// - affine_map<(d0, d1, d2) -> (d0, d1)> +/// - affine_map<(d0, d1, d2) -> (d0, d1 + d2)> +/// - affine_map<(d0, d1) -> (d0 * 3 + d1)> /// In the future, more general interfaces can be devised to encode similar /// shape evolutions and map between an op and its operands. SmallVector<OpFoldResult> diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 96b9adc..e1e99c3 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -134,6 +134,24 @@ def OpenACC_VariableTypeCategory : I32BitEnumAttr< let printBitEnumPrimaryGroups = 1; } +// These are parallelism determination modes for `acc loop`. +// In the enum names, we use the "loop_" prefix because "auto" is +// a language keyword - and thus for consistency all other cases +// do the same. +def OpenACC_LoopSeq : I32EnumAttrCase<"loop_seq", 0>; +def OpenACC_LoopAuto : I32EnumAttrCase<"loop_auto", 1>; +def OpenACC_LoopIndependent : I32EnumAttrCase<"loop_independent", 2>; + +def OpenACC_LoopParMode : I32EnumAttr< + "LoopParMode", + "Encodes the options for loop parallelism determination mode", + [ + OpenACC_LoopAuto, OpenACC_LoopIndependent, + OpenACC_LoopSeq]> { + let cppNamespace = "::mlir::acc"; + let genSpecializedAttr = 0; +} + // Type used in operation below. def IntOrIndex : AnyTypeOf<[AnyInteger, Index]>; @@ -2373,6 +2391,11 @@ def OpenACC_LoopOp : OpenACC_Op<"loop", // Return whether this LoopOp has a gang, worker, or vector applying to the // 'default'/None device-type. bool hasDefaultGangWorkerVector(); + + // Used to obtain the parallelism mode for the requested device type. + // This first checks if the mode is set for the device_type requested. + // And if not, it returns the non-device_type mode. + LoopParMode getDefaultOrDeviceTypeParallelism(DeviceType); }]; let hasCustomAssemblyFormat = 1; @@ -2404,6 +2427,53 @@ def OpenACC_LoopOp : OpenACC_Op<"loop", }]; let hasVerifier = 1; + + let builders = [ + OpBuilder<(ins "::mlir::ValueRange":$lowerbounds, + "::mlir::ValueRange":$upperbounds, + "::mlir::ValueRange":$steps, + "LoopParMode":$parMode), [{ + auto deviceNoneAttr = mlir::acc::DeviceTypeAttr::get( + $_builder.getContext(), mlir::acc::DeviceType::None); + auto arrOfDeviceNone = mlir::ArrayAttr::get( + $_builder.getContext(), deviceNoneAttr); + build($_builder, $_state, + /*results=*/{}, + /*lowerbound=*/lowerbounds, + /*upperbound=*/upperbounds, + /*step=*/steps, + /*inclusiveUpperbound=*/nullptr, + /*collapse=*/nullptr, + /*collapseDeviceType=*/nullptr, + /*gangOperands=*/{}, + /*gangOperandsArgType=*/nullptr, + /*gangOperandsSegments=*/nullptr, + /*gangOperandsDeviceType=*/nullptr, + /*workerNumOperands=*/{}, + /*workerNumOperandsDeviceType=*/nullptr, + /*vectorOperands=*/{}, + /*vectorOperandsDeviceType=*/nullptr, + /*seq=*/parMode == LoopParMode::loop_seq ? + arrOfDeviceNone : nullptr, + /*independent=*/parMode == LoopParMode::loop_independent ? + arrOfDeviceNone : nullptr, + /*auto_=*/parMode == LoopParMode::loop_auto ? + arrOfDeviceNone : nullptr, + /*gang=*/nullptr, + /*worker=*/nullptr, + /*vector=*/nullptr, + /*tileOperands=*/{}, + /*tileOperandsSegments=*/nullptr, + /*tileOperandsDeviceType=*/nullptr, + /*cacheOperands=*/{}, + /*privateOperands=*/{}, + /*privatizationRecipes=*/nullptr, + /*reductionOperands=*/{}, + /*reductionRecipes=*/nullptr, + /*combined=*/nullptr); + }] + > + ]; } // Yield operation for the acc.loop and acc.parallel operations. diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPAttrDefs.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPAttrDefs.td index 704d0b2..72ce4c6 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPAttrDefs.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPAttrDefs.td @@ -23,6 +23,21 @@ class OpenMP_Attr<string name, string attrMnemonic, list<Trait> traits = [], } //===----------------------------------------------------------------------===// +// AtomicControlAttr +//===----------------------------------------------------------------------===// + +// Atomic control attributes hold information about architectural +// characteristics which are required for lowering atomic operations. +def AtomicControlAttr : OpenMP_Attr<"AtomicControl", "atomic_control"> { + let parameters = + (ins DefaultValuedParameter<"bool", "false">:$ignore_denormal_mode, + DefaultValuedParameter<"bool", "false">:$fine_grained_memory, + DefaultValuedParameter<"bool", "false">:$remote_memory); + + let assemblyFormat = "`<` struct(params) `>`"; +} + +//===----------------------------------------------------------------------===// // DeclareTargetAttr //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 8cf18b4..be114ea 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1750,9 +1750,11 @@ def AtomicUpdateOp : OpenMP_Op<"atomic.update", traits = [ operations. }] # clausesDescription; - let arguments = !con((ins Arg<OpenMP_PointerLikeType, - "Address of variable to be updated", - [MemRead, MemWrite]>:$x), clausesArgs); + let arguments = !con( + (ins Arg<OpenMP_PointerLikeType, + "Address of variable to be updated", [MemRead, MemWrite]>:$x, + OptionalAttr<AtomicControlAttr>:$atomic_control), + clausesArgs); // Override region definition. let regions = (region SizedRegion<1>:$region); diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td index a534381b..2513e10 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -380,7 +380,7 @@ def SparseVectorization : Pass<"sparse-vectorization", "ModuleOp"> { After: %3 = memref.load %2[] : memref<f32> - %4 = vector.insertelement %3, %cst[%c0 : index] : vector<32xf32> + %4 = vector.insert %3, %cst [0] : f32 into vector<32xf32> %5 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %4) -> (vector<32xf32>) { %8 = vector.load %0[%arg3] : memref<?xf32>, vector<32xf32> %9 = vector.load %1[%arg3] : memref<1024xf32>, vector<32xf32> diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index 349e8ed..754640d 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -151,7 +151,7 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> : def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>; def Tosa_ScalarTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_AnyNumber], [1]>]>; -def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>; +def Tosa_ScalarInt8Tensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int8]>, TosaScalarTensorOf<[Tosa_Int8], [1]>]>; def Tosa_ScalarIntOrFloatTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>]>; // We include unranked tensors as a supported type for all possible tosa diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 0a5c1e5..3885439 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -646,55 +646,6 @@ def Vector_DeinterleaveOp : }]; } -def Vector_ExtractElementOp : - Vector_Op<"extractelement", [Pure, - DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>, - TypesMatchWith<"result type matches element type of vector operand", - "vector", "result", - "::llvm::cast<VectorType>($_self).getElementType()">]>, - Arguments<(ins AnyVectorOfAnyRank:$vector, - Optional<AnySignlessIntegerOrIndex>:$position)>, - Results<(outs AnyType:$result)> { - let summary = "extractelement operation"; - let description = [{ - Note: This operation is deprecated. Please use vector.extract insert. - - Takes a 0-D or 1-D vector and a optional dynamic index position and - extracts the scalar at that position. - - Note that this instruction resembles vector.extract, but is restricted to - 0-D and 1-D vectors. - If the vector is 0-D, the position must be std::nullopt. - - - It is meant to be closer to LLVM's version: - https://llvm.org/docs/LangRef.html#extractelement-instruction - - Example: - - ```mlir - %c = arith.constant 15 : i32 - %1 = vector.extractelement %0[%c : i32]: vector<16xf32> - %2 = vector.extractelement %z[]: vector<f32> - ``` - }]; - let assemblyFormat = [{ - $vector `[` ($position^ `:` type($position))? `]` attr-dict `:` type($vector) - }]; - - let builders = [ - // 0-D builder. - OpBuilder<(ins "Value":$source)>, - ]; - let extraClassDeclaration = [{ - VectorType getSourceVectorType() { - return ::llvm::cast<VectorType>(getVector().getType()); - } - }]; - let hasVerifier = 1; - let hasFolder = 1; -} - def Vector_ExtractOp : Vector_Op<"extract", [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>, @@ -890,57 +841,6 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [ let hasCanonicalizer = 1; } -def Vector_InsertElementOp : - Vector_Op<"insertelement", [Pure, - DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>, - TypesMatchWith<"source operand type matches element type of result", - "result", "source", - "::llvm::cast<VectorType>($_self).getElementType()">, - AllTypesMatch<["dest", "result"]>]>, - Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, - Optional<AnySignlessIntegerOrIndex>:$position)>, - Results<(outs AnyVectorOfAnyRank:$result)> { - let summary = "insertelement operation"; - let description = [{ - Note: This operation is deprecated. Please use vector.insert instead. - - Takes a scalar source, a 0-D or 1-D destination vector and a dynamic index - position and inserts the source into the destination at the proper position. - - Note that this instruction resembles vector.insert, but is restricted to 0-D - and 1-D vectors. - - It is meant to be closer to LLVM's version: - https://llvm.org/docs/LangRef.html#insertelement-instruction - - Example: - - ```mlir - %c = arith.constant 15 : i32 - %f = arith.constant 0.0f : f32 - %1 = vector.insertelement %f, %0[%c : i32]: vector<16xf32> - %2 = vector.insertelement %f, %z[]: vector<f32> - ``` - }]; - let assemblyFormat = [{ - $source `,` $dest `[` ($position^ `:` type($position))? `]` attr-dict `:` - type($result) - }]; - - let builders = [ - // 0-D builder. - OpBuilder<(ins "Value":$source, "Value":$dest)>, - ]; - let extraClassDeclaration = [{ - Type getSourceType() { return getSource().getType(); } - VectorType getDestVectorType() { - return ::llvm::cast<VectorType>(getDest().getType()); - } - }]; - let hasVerifier = 1; - let hasFolder = 1; -} - def Vector_InsertOp : Vector_Op<"insert", [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>, diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td index 73f6877..38c217f 100644 --- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td @@ -397,8 +397,8 @@ def DotOp : AVX_LowOp<"dot", [Pure, ```mlir %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32> - %1 = vector.extractelement %0[%i0 : i32]: vector<8xf32> - %2 = vector.extractelement %0[%i4 : i32]: vector<8xf32> + %1 = vector.extract %0[%i0] : f32 from vector<8xf32> + %2 = vector.extract %0[%i4] : f32 from vector<8xf32> %d = arith.addf %1, %2 : f32 ``` }]; diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h index 4ed0423..7ff718a 100644 --- a/mlir/include/mlir/IR/Diagnostics.h +++ b/mlir/include/mlir/IR/Diagnostics.h @@ -639,6 +639,10 @@ public: /// verified correctly, failure otherwise. LogicalResult verify(); + /// Register this handler with the given context. This is intended for use + /// with the splitAndProcessBuffer function. + void registerInContext(MLIRContext *ctx); + private: /// Process a single diagnostic. void process(Diagnostic &diag); diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index fa8a487..edc8ab4 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -1102,6 +1102,29 @@ inline raw_ostream &operator<<(raw_ostream &os, const Operation &op) { return os; } +/// A wrapper class that allows for printing an operation with a set of flags, +/// useful to act as a "stream modifier" to customize printing an operation +/// with a stream using the operator<< overload, e.g.: +/// llvm::dbgs() << OpWithFlags(op, OpPrintingFlags().skipRegions()); +class OpWithFlags { +public: + OpWithFlags(Operation *op, OpPrintingFlags flags = {}) + : op(op), theFlags(flags) {} + OpPrintingFlags &flags() { return theFlags; } + const OpPrintingFlags &flags() const { return theFlags; } + +private: + Operation *op; + OpPrintingFlags theFlags; + friend raw_ostream &operator<<(raw_ostream &os, const OpWithFlags &op); +}; + +inline raw_ostream &operator<<(raw_ostream &os, + const OpWithFlags &opWithFlags) { + opWithFlags.op->print(os, opWithFlags.flags()); + return os; +} + } // namespace mlir namespace llvm { diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index b3608b4..b5a93a0 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -525,6 +525,11 @@ public: } /// This method erases an operation that is known to have no uses. + /// + /// If the current insertion point is before the erased operation, it is + /// adjusted to the following operation (or the end of the block). If the + /// current insertion point is within the erased operation, the insertion + /// point is left in an invalid state. virtual void eraseOp(Operation *op); /// This method erases all operations in a block. @@ -539,6 +544,9 @@ public: /// somewhere in the middle (or beginning) of the dest block, the source block /// must have no successors. Otherwise, the resulting IR would have /// unreachable operations. + /// + /// If the insertion point is within the source block, it is adjusted to the + /// destination block. virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues = {}); @@ -549,6 +557,9 @@ public: /// /// The source block must have no successors. Otherwise, the resulting IR /// would have unreachable operations. + /// + /// If the insertion point is within the source block, it is adjusted to the + /// destination block. void inlineBlockBefore(Block *source, Operation *op, ValueRange argValues = {}); @@ -558,6 +569,9 @@ public: /// /// The dest block must have no successors. Otherwise, the resulting IR would /// have unreachable operation. + /// + /// If the insertion point is within the source block, it is adjusted to the + /// destination block. void mergeBlocks(Block *source, Block *dest, ValueRange argValues = {}); /// Split the operations starting at "before" (inclusive) out of the given diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h index 2162a74..8959dab 100644 --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -200,7 +200,7 @@ public: // If the construction invariants fail then we return a null attribute. if (failed(ConcreteT::verifyInvariants(emitErrorFn, args...))) return ConcreteT(); - return UniquerT::template get<ConcreteT>(ctx, args...); + return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...); } /// Get an instance of the concrete type from a void pointer. diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td index a8b04d0..bbfa308 100644 --- a/mlir/include/mlir/IR/SymbolInterfaces.td +++ b/mlir/include/mlir/IR/SymbolInterfaces.td @@ -55,19 +55,19 @@ def Symbol : OpInterface<"SymbolOpInterface"> { InterfaceMethod<"Returns true if this symbol has nested visibility.", "bool", "isNested", (ins), [{}], /*defaultImplementation=*/[{ - return getVisibility() == mlir::SymbolTable::Visibility::Nested; + return $_op.getVisibility() == mlir::SymbolTable::Visibility::Nested; }] >, InterfaceMethod<"Returns true if this symbol has private visibility.", "bool", "isPrivate", (ins), [{}], /*defaultImplementation=*/[{ - return getVisibility() == mlir::SymbolTable::Visibility::Private; + return $_op.getVisibility() == mlir::SymbolTable::Visibility::Private; }] >, InterfaceMethod<"Returns true if this symbol has public visibility.", "bool", "isPublic", (ins), [{}], /*defaultImplementation=*/[{ - return getVisibility() == mlir::SymbolTable::Visibility::Public; + return $_op.getVisibility() == mlir::SymbolTable::Visibility::Public; }] >, InterfaceMethod<"Sets the visibility of this symbol.", @@ -79,19 +79,19 @@ def Symbol : OpInterface<"SymbolOpInterface"> { InterfaceMethod<"Sets the visibility of this symbol to be nested.", "void", "setNested", (ins), [{}], /*defaultImplementation=*/[{ - setVisibility(mlir::SymbolTable::Visibility::Nested); + $_op.setVisibility(mlir::SymbolTable::Visibility::Nested); }] >, InterfaceMethod<"Sets the visibility of this symbol to be private.", "void", "setPrivate", (ins), [{}], /*defaultImplementation=*/[{ - setVisibility(mlir::SymbolTable::Visibility::Private); + $_op.setVisibility(mlir::SymbolTable::Visibility::Private); }] >, InterfaceMethod<"Sets the visibility of this symbol to be public.", "void", "setPublic", (ins), [{}], /*defaultImplementation=*/[{ - setVisibility(mlir::SymbolTable::Visibility::Public); + $_op.setVisibility(mlir::SymbolTable::Visibility::Public); }] >, InterfaceMethod<[{ @@ -144,7 +144,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> { // By default, base this on the visibility alone. A symbol can be // discarded as long as it is not public. Only public symbols may be // visible from outside of the IR. - return getVisibility() != ::mlir::SymbolTable::Visibility::Public; + return $_op.getVisibility() != ::mlir::SymbolTable::Visibility::Public; }] >, InterfaceMethod<[{ diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index 856170e..7628171 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -14,200 +14,15 @@ #ifndef MLIR_INITALLDIALECTS_H_ #define MLIR_INITALLDIALECTS_H_ -#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" -#include "mlir/Dialect/AMX/AMXDialect.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h" -#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h" -#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h" -#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h" -#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" -#include "mlir/Dialect/ArmSME/IR/ArmSME.h" -#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" -#include "mlir/Dialect/Async/IR/Async.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Complex/IR/Complex.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" -#include "mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h" -#include "mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/DLTI/DLTI.h" -#include "mlir/Dialect/EmitC/IR/EmitC.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.h" -#include "mlir/Dialect/GPU/Transforms/BufferDeallocationOpInterfaceImpl.h" -#include "mlir/Dialect/IRDL/IR/IRDL.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/LLVMIR/NVVMDialect.h" -#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" -#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" -#include "mlir/Dialect/LLVMIR/XeVMDialect.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h" -#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h" -#include "mlir/Dialect/MLProgram/IR/MLProgram.h" -#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/MPI/IR/MPI.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h" -#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h" -#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" -#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h" -#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h" -#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" -#include "mlir/Dialect/OpenACC/OpenACC.h" -#include "mlir/Dialect/OpenMP/OpenMPDialect.h" -#include "mlir/Dialect/PDL/IR/PDL.h" -#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" -#include "mlir/Dialect/Ptr/IR/PtrDialect.h" -#include "mlir/Dialect/Quant/IR/Quant.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h" -#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" -#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h" -#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/SMT/IR/SMTDialect.h" -#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" -#include "mlir/Dialect/Shape/IR/Shape.h" -#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Shard/IR/ShardDialect.h" -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" -#include "mlir/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" -#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h" -#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h" -#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" -#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Tensor/Transforms/RuntimeOpVerification.h" -#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h" -#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Transform/IR/TransformDialect.h" -#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" -#include "mlir/Dialect/UB/IR/UBOps.h" -#include "mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h" -#include "mlir/Dialect/X86Vector/X86VectorDialect.h" -#include "mlir/Dialect/XeGPU/IR/XeGPU.h" -#include "mlir/IR/Dialect.h" -#include "mlir/Interfaces/CastInterfaces.h" -#include "mlir/Target/LLVM/NVVM/Target.h" -#include "mlir/Target/LLVM/ROCDL/Target.h" -#include "mlir/Target/SPIRV/Target.h" - namespace mlir { +class DialectRegistry; +class MLIRContext; /// Add all the MLIR dialects to the provided registry. -inline void registerAllDialects(DialectRegistry ®istry) { - // clang-format off - registry.insert<acc::OpenACCDialect, - affine::AffineDialect, - amdgpu::AMDGPUDialect, - amx::AMXDialect, - arith::ArithDialect, - arm_neon::ArmNeonDialect, - arm_sme::ArmSMEDialect, - arm_sve::ArmSVEDialect, - async::AsyncDialect, - bufferization::BufferizationDialect, - cf::ControlFlowDialect, - complex::ComplexDialect, - DLTIDialect, - emitc::EmitCDialect, - func::FuncDialect, - gpu::GPUDialect, - index::IndexDialect, - irdl::IRDLDialect, - linalg::LinalgDialect, - LLVM::LLVMDialect, - math::MathDialect, - memref::MemRefDialect, - shard::ShardDialect, - ml_program::MLProgramDialect, - mpi::MPIDialect, - nvgpu::NVGPUDialect, - NVVM::NVVMDialect, - omp::OpenMPDialect, - pdl::PDLDialect, - pdl_interp::PDLInterpDialect, - ptr::PtrDialect, - quant::QuantDialect, - ROCDL::ROCDLDialect, - scf::SCFDialect, - shape::ShapeDialect, - smt::SMTDialect, - sparse_tensor::SparseTensorDialect, - spirv::SPIRVDialect, - tensor::TensorDialect, - tosa::TosaDialect, - transform::TransformDialect, - ub::UBDialect, - vector::VectorDialect, - x86vector::X86VectorDialect, - xegpu::XeGPUDialect, - xevm::XeVMDialect>(); - // clang-format on - - // Register all external models. - affine::registerValueBoundsOpInterfaceExternalModels(registry); - arith::registerBufferDeallocationOpInterfaceExternalModels(registry); - arith::registerBufferizableOpInterfaceExternalModels(registry); - arith::registerBufferViewFlowOpInterfaceExternalModels(registry); - arith::registerShardingInterfaceExternalModels(registry); - arith::registerValueBoundsOpInterfaceExternalModels(registry); - bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( - registry); - builtin::registerCastOpInterfaceExternalModels(registry); - cf::registerBufferizableOpInterfaceExternalModels(registry); - cf::registerBufferDeallocationOpInterfaceExternalModels(registry); - gpu::registerBufferDeallocationOpInterfaceExternalModels(registry); - gpu::registerValueBoundsOpInterfaceExternalModels(registry); - LLVM::registerInlinerInterface(registry); - NVVM::registerInlinerInterface(registry); - linalg::registerAllDialectInterfaceImplementations(registry); - linalg::registerRuntimeVerifiableOpInterfaceExternalModels(registry); - memref::registerAllocationOpInterfaceExternalModels(registry); - memref::registerBufferViewFlowOpInterfaceExternalModels(registry); - memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry); - memref::registerValueBoundsOpInterfaceExternalModels(registry); - memref::registerMemorySlotExternalModels(registry); - ml_program::registerBufferizableOpInterfaceExternalModels(registry); - scf::registerBufferDeallocationOpInterfaceExternalModels(registry); - scf::registerBufferizableOpInterfaceExternalModels(registry); - scf::registerValueBoundsOpInterfaceExternalModels(registry); - shape::registerBufferizableOpInterfaceExternalModels(registry); - sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry); - tensor::registerBufferizableOpInterfaceExternalModels(registry); - tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry); - tensor::registerInferTypeOpInterfaceExternalModels(registry); - tensor::registerRuntimeVerifiableOpInterfaceExternalModels(registry); - tensor::registerSubsetOpInterfaceExternalModels(registry); - tensor::registerTilingInterfaceExternalModels(registry); - tensor::registerValueBoundsOpInterfaceExternalModels(registry); - tosa::registerShardingInterfaceExternalModels(registry); - vector::registerBufferizableOpInterfaceExternalModels(registry); - vector::registerSubsetOpInterfaceExternalModels(registry); - vector::registerValueBoundsOpInterfaceExternalModels(registry); - NVVM::registerNVVMTargetInterfaceExternalModels(registry); - ROCDL::registerROCDLTargetInterfaceExternalModels(registry); - spirv::registerSPIRVTargetInterfaceExternalModels(registry); -} +void registerAllDialects(DialectRegistry ®istry); /// Append all the MLIR dialects to the registry contained in the given context. -inline void registerAllDialects(MLIRContext &context) { - DialectRegistry registry; - registerAllDialects(registry); - context.appendDialectRegistry(registry); -} +void registerAllDialects(MLIRContext &context); } // namespace mlir diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h index d5a9a2c..a7f64d9 100644 --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -14,110 +14,15 @@ #ifndef MLIR_INITALLEXTENSIONS_H_ #define MLIR_INITALLEXTENSIONS_H_ -#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" -#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" -#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" -#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" -#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h" -#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" -#include "mlir/Conversion/GPUCommon/GPUToLLVM.h" -#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h" -#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" -#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h" -#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" -#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" -#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" -#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" -#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" -#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" -#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" -#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" -#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h" -#include "mlir/Dialect/AMX/Transforms.h" -#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" -#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h" -#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h" -#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" -#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h" -#include "mlir/Dialect/Func/Extensions/AllExtensions.h" -#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h" -#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h" -#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" -#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h" -#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h" -#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" -#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h" -#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h" -#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" -#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h" -#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h" -#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h" -#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" -#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h" -#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" -#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" - -#include <cstdlib> - namespace mlir { +class DialectRegistry; /// This function may be called to register all MLIR dialect extensions with the /// provided registry. /// If you're building a compiler, you generally shouldn't use this: you would /// individually register the specific extensions that are useful for the /// pipelines and transformations you are using. -inline void registerAllExtensions(DialectRegistry ®istry) { - // Register all conversions to LLVM extensions. - registerConvertArithToEmitCInterface(registry); - arith::registerConvertArithToLLVMInterface(registry); - registerConvertComplexToLLVMInterface(registry); - cf::registerConvertControlFlowToLLVMInterface(registry); - func::registerAllExtensions(registry); - tensor::registerAllExtensions(registry); - registerConvertFuncToEmitCInterface(registry); - registerConvertFuncToLLVMInterface(registry); - index::registerConvertIndexToLLVMInterface(registry); - registerConvertMathToLLVMInterface(registry); - mpi::registerConvertMPIToLLVMInterface(registry); - registerConvertMemRefToEmitCInterface(registry); - registerConvertMemRefToLLVMInterface(registry); - registerConvertNVVMToLLVMInterface(registry); - registerConvertOpenMPToLLVMInterface(registry); - registerConvertSCFToEmitCInterface(registry); - ub::registerConvertUBToLLVMInterface(registry); - registerConvertAMXToLLVMInterface(registry); - gpu::registerConvertGpuToLLVMInterface(registry); - NVVM::registerConvertGpuToNVVMInterface(registry); - vector::registerConvertVectorToLLVMInterface(registry); - registerConvertXeVMToLLVMInterface(registry); - - // Register all transform dialect extensions. - affine::registerTransformDialectExtension(registry); - bufferization::registerTransformDialectExtension(registry); - dlti::registerTransformDialectExtension(registry); - func::registerTransformDialectExtension(registry); - gpu::registerTransformDialectExtension(registry); - linalg::registerTransformDialectExtension(registry); - memref::registerTransformDialectExtension(registry); - nvgpu::registerTransformDialectExtension(registry); - scf::registerTransformDialectExtension(registry); - sparse_tensor::registerTransformDialectExtension(registry); - tensor::registerTransformDialectExtension(registry); - transform::registerDebugExtension(registry); - transform::registerIRDLExtension(registry); - transform::registerLoopExtension(registry); - transform::registerPDLExtension(registry); - transform::registerTuneExtension(registry); - vector::registerTransformDialectExtension(registry); - arm_neon::registerTransformDialectExtension(registry); - arm_sve::registerTransformDialectExtension(registry); - - // Translation extensions need to be registered by calling - // `registerAllToLLVMIRTranslations` (see All.h). -} +void registerAllExtensions(DialectRegistry ®istry); } // namespace mlir diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h index 002ff61..4554290 100644 --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -1,4 +1,4 @@ -//===- LinkAllPassesAndDialects.h - MLIR Registration -----------*- C++ -*-===// +//===- InitAllPasses.h - MLIR Registration ----------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,50 +6,14 @@ // //===----------------------------------------------------------------------===// // -// This file defines a helper to trigger the registration of all dialects and -// passes to the system. +// This file defines a helper to trigger the registration of all passes to the +// system. // //===----------------------------------------------------------------------===// #ifndef MLIR_INITALLPASSES_H_ #define MLIR_INITALLPASSES_H_ -#include "mlir/Conversion/Passes.h" -#include "mlir/Dialect/AMDGPU/Transforms/Passes.h" -#include "mlir/Dialect/Affine/Passes.h" -#include "mlir/Dialect/Arith/Transforms/Passes.h" -#include "mlir/Dialect/ArmSME/Transforms/Passes.h" -#include "mlir/Dialect/ArmSVE/Transforms/Passes.h" -#include "mlir/Dialect/Async/Passes.h" -#include "mlir/Dialect/Bufferization/Pipelines/Passes.h" -#include "mlir/Dialect/Bufferization/Transforms/Passes.h" -#include "mlir/Dialect/EmitC/Transforms/Passes.h" -#include "mlir/Dialect/Func/Transforms/Passes.h" -#include "mlir/Dialect/GPU/Pipelines/Passes.h" -#include "mlir/Dialect/GPU/Transforms/Passes.h" -#include "mlir/Dialect/LLVMIR/Transforms/Passes.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/MLProgram/Transforms/Passes.h" -#include "mlir/Dialect/Math/Transforms/Passes.h" -#include "mlir/Dialect/MemRef/Transforms/Passes.h" -#include "mlir/Dialect/NVGPU/Transforms/Passes.h" -#include "mlir/Dialect/OpenACC/Transforms/Passes.h" -#include "mlir/Dialect/Quant/Transforms/Passes.h" -#include "mlir/Dialect/SCF/Transforms/Passes.h" -#include "mlir/Dialect/SPIRV/Transforms/Passes.h" -#include "mlir/Dialect/Shape/Transforms/Passes.h" -#include "mlir/Dialect/Shard/Transforms/Passes.h" -#include "mlir/Dialect/SparseTensor/Pipelines/Passes.h" -#include "mlir/Dialect/SparseTensor/Transforms/Passes.h" -#include "mlir/Dialect/Tensor/Transforms/Passes.h" -#include "mlir/Dialect/Tosa/Transforms/Passes.h" -#include "mlir/Dialect/Transform/Transforms/Passes.h" -#include "mlir/Dialect/Vector/Transforms/Passes.h" -#include "mlir/Dialect/XeGPU/Transforms/Passes.h" -#include "mlir/Transforms/Passes.h" - -#include <cstdlib> - namespace mlir { // This function may be called to register the MLIR passes with the @@ -59,49 +23,7 @@ namespace mlir { // registry, since it would already be calling the creation routine of the // individual passes. // The global registry is interesting to interact with the command-line tools. -inline void registerAllPasses() { - // General passes - registerTransformsPasses(); - - // Conversion passes - registerConversionPasses(); - - // Dialect passes - acc::registerOpenACCPasses(); - affine::registerAffinePasses(); - amdgpu::registerAMDGPUPasses(); - registerAsyncPasses(); - arith::registerArithPasses(); - bufferization::registerBufferizationPasses(); - func::registerFuncPasses(); - registerGPUPasses(); - registerLinalgPasses(); - registerNVGPUPasses(); - registerSparseTensorPasses(); - LLVM::registerLLVMPasses(); - math::registerMathPasses(); - memref::registerMemRefPasses(); - shard::registerShardPasses(); - ml_program::registerMLProgramPasses(); - quant::registerQuantPasses(); - registerSCFPasses(); - registerShapePasses(); - spirv::registerSPIRVPasses(); - tensor::registerTensorPasses(); - tosa::registerTosaOptPasses(); - transform::registerTransformPasses(); - vector::registerVectorPasses(); - arm_sme::registerArmSMEPasses(); - arm_sve::registerArmSVEPasses(); - emitc::registerEmitCPasses(); - xegpu::registerXeGPUPasses(); - - // Dialect pipelines - bufferization::registerBufferizationPipelines(); - sparse_tensor::registerSparseTensorPipelines(); - tosa::registerTosaToLinalgPipelines(); - gpu::registerGPUToNVVMPipeline(); -} +void registerAllPasses(); } // namespace mlir diff --git a/mlir/include/mlir/Support/ToolUtilities.h b/mlir/include/mlir/Support/ToolUtilities.h index cb6ba29..657f117 100644 --- a/mlir/include/mlir/Support/ToolUtilities.h +++ b/mlir/include/mlir/Support/ToolUtilities.h @@ -21,10 +21,16 @@ namespace llvm { class MemoryBuffer; +class MemoryBufferRef; } // namespace llvm namespace mlir { +// A function that processes a chunk of a buffer and writes the result to an +// output stream. using ChunkBufferHandler = function_ref<LogicalResult( + std::unique_ptr<llvm::MemoryBuffer> chunkBuffer, + const llvm::MemoryBufferRef &sourceBuffer, raw_ostream &os)>; +using NoSourceChunkBufferHandler = function_ref<LogicalResult( std::unique_ptr<llvm::MemoryBuffer> chunkBuffer, raw_ostream &os)>; extern inline const char *const kDefaultSplitMarker = "// -----"; @@ -45,6 +51,15 @@ splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer, ChunkBufferHandler processChunkBuffer, raw_ostream &os, llvm::StringRef inputSplitMarker = kDefaultSplitMarker, llvm::StringRef outputSplitMarker = ""); + +/// Same as above, but for case where the original buffer is not used while +/// processing the chunk. +LogicalResult +splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer, + NoSourceChunkBufferHandler processChunkBuffer, + raw_ostream &os, + llvm::StringRef inputSplitMarker = kDefaultSplitMarker, + llvm::StringRef outputSplitMarker = ""); } // namespace mlir #endif // MLIR_SUPPORT_TOOLUTILITIES_H diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h index 60615cf6..e4670cb 100644 --- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h @@ -28,6 +28,7 @@ #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/SPIRV/SPIRVToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/VCIX/VCIXToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h" namespace mlir { class DialectRegistry; @@ -47,6 +48,7 @@ static inline void registerAllToLLVMIRTranslations(DialectRegistry ®istry) { registerROCDLDialectTranslation(registry); registerSPIRVDialectTranslation(registry); registerVCIXDialectTranslation(registry); + registerXeVMDialectTranslation(registry); // Extension required for translating GPU offloading Ops. gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry); @@ -63,6 +65,7 @@ registerAllGPUToLLVMIRTranslations(DialectRegistry ®istry) { registerNVVMDialectTranslation(registry); registerROCDLDialectTranslation(registry); registerSPIRVDialectTranslation(registry); + registerXeVMDialectTranslation(registry); // Extension required for translating GPU offloading Ops. gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry); diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h new file mode 100644 index 0000000..b4f6750 --- /dev/null +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h @@ -0,0 +1,31 @@ +//===-- XeVMToLLVMIRTranslation.h - XeVM to LLVM IR -------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This provides registration calls for XeVM dialect to LLVM IR translation. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_LLVMIR_DIALECT_XEVM_XEVMTOLLVMIRTRANSLATION_H +#define MLIR_TARGET_LLVMIR_DIALECT_XEVM_XEVMTOLLVMIRTRANSLATION_H + +namespace mlir { + +class DialectRegistry; +class MLIRContext; + +/// Register the XeVM dialect and the translation from it to the LLVM IR in the +/// given registry; +void registerXeVMDialectTranslation(mlir::DialectRegistry ®istry); + +/// Register the XeVM dialect and the translation from it in the registry +/// associated with the given context. +void registerXeVMDialectTranslation(mlir::MLIRContext &context); + +} // namespace mlir + +#endif // MLIR_TARGET_LLVMIR_DIALECT_XEVM_XEVMTOLLVMIRTRANSLATION_H diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h index c484072..17ef8e4 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -19,6 +19,7 @@ #include "mlir/Target/LLVMIR/Import.h" #include "mlir/Target/LLVMIR/LLVMImportInterface.h" #include "mlir/Target/LLVMIR/TypeFromLLVM.h" +#include "llvm/IR/Module.h" namespace llvm { class BasicBlock; diff --git a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp index 51fa773..fb5649e 100644 --- a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp @@ -16,6 +16,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include <cassert> #define DEBUG_TYPE "constant-propagation" @@ -46,7 +47,7 @@ void ConstantValue::print(raw_ostream &os) const { LogicalResult SparseConstantPropagation::visitOperation( Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands, ArrayRef<Lattice<ConstantValue> *> results) { - LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n"); + LDBG() << "SCP: Visiting operation: " << *op; // Don't try to simulate the results of a region operation as we can't // guarantee that folding will be out-of-place. We don't allow in-place @@ -98,12 +99,11 @@ LogicalResult SparseConstantPropagation::visitOperation( // Merge in the result of the fold, either a constant or a value. OpFoldResult foldResult = std::get<1>(it); if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(foldResult)) { - LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n"); + LDBG() << "Folded to constant: " << attr; propagateIfChanged(lattice, lattice->join(ConstantValue(attr, op->getDialect()))); } else { - LLVM_DEBUG(llvm::dbgs() - << "Folded to value: " << cast<Value>(foldResult) << "\n"); + LDBG() << "Folded to value: " << cast<Value>(foldResult); AbstractSparseForwardDataFlowAnalysis::join( lattice, *getLatticeElement(cast<Value>(foldResult))); } diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp index 1abdfcb..10874fd 100644 --- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp @@ -23,12 +23,11 @@ #include "mlir/Support/LLVM.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include <cassert> #include <optional> #define DEBUG_TYPE "dead-code-analysis" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") using namespace mlir; using namespace mlir::dataflow; @@ -127,7 +126,8 @@ DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver) } LogicalResult DeadCodeAnalysis::initialize(Operation *top) { - LDBG("Initializing DeadCodeAnalysis for top-level op: " << top->getName()); + LDBG() << "Initializing DeadCodeAnalysis for top-level op: " + << top->getName(); // Mark the top-level blocks as executable. for (Region ®ion : top->getRegions()) { if (region.empty()) @@ -135,7 +135,7 @@ LogicalResult DeadCodeAnalysis::initialize(Operation *top) { auto *state = getOrCreate<Executable>(getProgramPointBefore(®ion.front())); propagateIfChanged(state, state->setToLive()); - LDBG("Marked entry block live for region in op: " << top->getName()); + LDBG() << "Marked entry block live for region in op: " << top->getName(); } // Mark as overdefined the predecessors of symbol callables with potentially @@ -146,18 +146,18 @@ LogicalResult DeadCodeAnalysis::initialize(Operation *top) { } void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { - LDBG("[init] Entering initializeSymbolCallables for top-level op: " - << top->getName()); + LDBG() << "[init] Entering initializeSymbolCallables for top-level op: " + << top->getName(); analysisScope = top; auto walkFn = [&](Operation *symTable, bool allUsesVisible) { - LDBG("[init] Processing symbol table op: " << symTable->getName()); + LDBG() << "[init] Processing symbol table op: " << symTable->getName(); Region &symbolTableRegion = symTable->getRegion(0); Block *symbolTableBlock = &symbolTableRegion.front(); bool foundSymbolCallable = false; for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) { - LDBG("[init] Found CallableOpInterface: " - << callable.getOperation()->getName()); + LDBG() << "[init] Found CallableOpInterface: " + << callable.getOperation()->getName(); Region *callableRegion = callable.getCallableRegion(); if (!callableRegion) continue; @@ -171,8 +171,8 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { auto *state = getOrCreate<PredecessorState>(getProgramPointAfter(callable)); propagateIfChanged(state, state->setHasUnknownPredecessors()); - LDBG("[init] Marked callable as having unknown predecessors: " - << callable.getOperation()->getName()); + LDBG() << "[init] Marked callable as having unknown predecessors: " + << callable.getOperation()->getName(); } foundSymbolCallable = true; } @@ -187,15 +187,15 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { if (!uses) { // If we couldn't gather the symbol uses, conservatively assume that // we can't track information for any nested symbols. - LDBG("[init] Could not gather symbol uses, conservatively marking " - "all nested callables as having unknown predecessors"); + LDBG() << "[init] Could not gather symbol uses, conservatively marking " + "all nested callables as having unknown predecessors"; return top->walk([&](CallableOpInterface callable) { auto *state = getOrCreate<PredecessorState>(getProgramPointAfter(callable)); propagateIfChanged(state, state->setHasUnknownPredecessors()); - LDBG("[init] Marked nested callable as " - "having unknown predecessors: " - << callable.getOperation()->getName()); + LDBG() << "[init] Marked nested callable as " + "having unknown predecessors: " + << callable.getOperation()->getName(); }); } @@ -209,15 +209,15 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { continue; auto *state = getOrCreate<PredecessorState>(getProgramPointAfter(symbol)); propagateIfChanged(state, state->setHasUnknownPredecessors()); - LDBG("[init] Found non-call use for symbol, " - "marked as having unknown predecessors: " - << symbol->getName()); + LDBG() << "[init] Found non-call use for symbol, " + "marked as having unknown predecessors: " + << symbol->getName(); } }; SymbolTable::walkSymbolTables(top, /*allSymUsesVisible=*/!top->getBlock(), walkFn); - LDBG("[init] Finished initializeSymbolCallables for top-level op: " - << top->getName()); + LDBG() << "[init] Finished initializeSymbolCallables for top-level op: " + << top->getName(); } /// Returns true if the operation is a returning terminator in region @@ -229,14 +229,14 @@ static bool isRegionOrCallableReturn(Operation *op) { } LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) { - LDBG("[init] Entering initializeRecursively for op: " << op->getName() - << " at " << op); + LDBG() << "[init] Entering initializeRecursively for op: " << op->getName() + << " at " << op; // Initialize the analysis by visiting every op with control-flow semantics. if (op->getNumRegions() || op->getNumSuccessors() || isRegionOrCallableReturn(op) || isa<CallOpInterface>(op)) { - LDBG("[init] Visiting op with control-flow semantics: " << *op); - // When the liveness of the parent block changes, make sure to re-invoke the - // analysis on the op. + LDBG() << "[init] Visiting op with control-flow semantics: " << *op; + // When the liveness of the parent block changes, make sure to + // re-invoke the analysis on the op. if (op->getBlock()) getOrCreate<Executable>(getProgramPointBefore(op->getBlock())) ->blockContentSubscribe(this); @@ -246,21 +246,21 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) { } // Recurse on nested operations. for (Region ®ion : op->getRegions()) { - LDBG("[init] Recursing into region of op: " << op->getName()); + LDBG() << "[init] Recursing into region of op: " << op->getName(); for (Operation &nestedOp : region.getOps()) { - LDBG("[init] Recursing into nested op: " << nestedOp.getName() << " at " - << &nestedOp); + LDBG() << "[init] Recursing into nested op: " << nestedOp.getName() + << " at " << &nestedOp; if (failed(initializeRecursively(&nestedOp))) return failure(); } } - LDBG("[init] Finished initializeRecursively for op: " << op->getName() - << " at " << op); + LDBG() << "[init] Finished initializeRecursively for op: " << op->getName() + << " at " << op; return success(); } void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) { - LDBG("Marking edge live from block " << from << " to block " << to); + LDBG() << "Marking edge live from block " << from << " to block " << to; auto *state = getOrCreate<Executable>(getProgramPointBefore(to)); propagateIfChanged(state, state->setToLive()); auto *edgeState = @@ -269,35 +269,35 @@ void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) { } void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) { - LDBG("Marking entry blocks live for op: " << op->getName()); + LDBG() << "Marking entry blocks live for op: " << op->getName(); for (Region ®ion : op->getRegions()) { if (region.empty()) continue; auto *state = getOrCreate<Executable>(getProgramPointBefore(®ion.front())); propagateIfChanged(state, state->setToLive()); - LDBG("Marked entry block live for region in op: " << op->getName()); + LDBG() << "Marked entry block live for region in op: " << op->getName(); } } LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) { - LDBG("Visiting program point: " << point << " " << *point); + LDBG() << "Visiting program point: " << point << " " << *point; if (point->isBlockStart()) return success(); Operation *op = point->getPrevOp(); - LDBG("Visiting operation: " << *op); + LDBG() << "Visiting operation: " << *op; // If the parent block is not executable, there is nothing to do. if (op->getBlock() != nullptr && !getOrCreate<Executable>(getProgramPointBefore(op->getBlock())) ->isLive()) { - LDBG("Parent block not live, skipping op: " << *op); + LDBG() << "Parent block not live, skipping op: " << *op; return success(); } // We have a live call op. Add this as a live predecessor of the callee. if (auto call = dyn_cast<CallOpInterface>(op)) { - LDBG("Visiting call operation: " << *op); + LDBG() << "Visiting call operation: " << *op; visitCallOperation(call); } @@ -305,12 +305,12 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) { if (op->getNumRegions()) { // Check if we can reason about the region control-flow. if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { - LDBG("Visiting region branch operation: " << *op); + LDBG() << "Visiting region branch operation: " << *op; visitRegionBranchOperation(branch); // Check if this is a callable operation. } else if (auto callable = dyn_cast<CallableOpInterface>(op)) { - LDBG("Visiting callable operation: " << *op); + LDBG() << "Visiting callable operation: " << *op; const auto *callsites = getOrCreateFor<PredecessorState>( getProgramPointAfter(op), getProgramPointAfter(callable)); @@ -322,19 +322,19 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) { // Otherwise, conservatively mark all entry blocks as executable. } else { - LDBG("Marking all entry blocks live for op: " << *op); + LDBG() << "Marking all entry blocks live for op: " << *op; markEntryBlocksLive(op); } } if (isRegionOrCallableReturn(op)) { if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) { - LDBG("Visiting region terminator: " << *op); + LDBG() << "Visiting region terminator: " << *op; // Visit the exiting terminator of a region. visitRegionTerminator(op, branch); } else if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) { - LDBG("Visiting callable terminator: " << *op); + LDBG() << "Visiting callable terminator: " << *op; // Visit the exiting terminator of a callable. visitCallableTerminator(op, callable); } @@ -343,12 +343,12 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) { if (op->getNumSuccessors()) { // Check if we can reason about the control-flow. if (auto branch = dyn_cast<BranchOpInterface>(op)) { - LDBG("Visiting branch operation: " << *op); + LDBG() << "Visiting branch operation: " << *op; visitBranchOperation(branch); // Otherwise, conservatively mark all successors as exectuable. } else { - LDBG("Marking all successors live for op: " << *op); + LDBG() << "Marking all successors live for op: " << *op; for (Block *successor : op->getSuccessors()) markEdgeLive(op->getBlock(), successor); } @@ -358,7 +358,7 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) { } void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) { - LDBG("visitCallOperation: " << call.getOperation()->getName()); + LDBG() << "visitCallOperation: " << call.getOperation()->getName(); Operation *callableOp = call.resolveCallableInTable(&symbolTable); // A call to a externally-defined callable has unknown predecessors. @@ -381,15 +381,15 @@ void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) { auto *callsites = getOrCreate<PredecessorState>(getProgramPointAfter(callableOp)); propagateIfChanged(callsites, callsites->join(call)); - LDBG("Added callsite as predecessor for callable: " - << callableOp->getName()); + LDBG() << "Added callsite as predecessor for callable: " + << callableOp->getName(); } else { // Mark this call op's predecessors as overdefined. auto *predecessors = getOrCreate<PredecessorState>(getProgramPointAfter(call)); propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors()); - LDBG("Marked call op's predecessors as unknown for: " - << call.getOperation()->getName()); + LDBG() << "Marked call op's predecessors as unknown for: " + << call.getOperation()->getName(); } } @@ -421,7 +421,7 @@ DeadCodeAnalysis::getOperandValues(Operation *op) { } void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) { - LDBG("visitBranchOperation: " << branch.getOperation()->getName()); + LDBG() << "visitBranchOperation: " << branch.getOperation()->getName(); // Try to deduce a single successor for the branch. std::optional<SmallVector<Attribute>> operands = getOperandValues(branch); if (!operands) @@ -429,18 +429,18 @@ void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) { if (Block *successor = branch.getSuccessorForOperands(*operands)) { markEdgeLive(branch->getBlock(), successor); - LDBG("Branch has single successor: " << successor); + LDBG() << "Branch has single successor: " << successor; } else { // Otherwise, mark all successors as executable and outgoing edges. for (Block *successor : branch->getSuccessors()) markEdgeLive(branch->getBlock(), successor); - LDBG("Branch has multiple/all successors live"); + LDBG() << "Branch has multiple/all successors live"; } } void DeadCodeAnalysis::visitRegionBranchOperation( RegionBranchOpInterface branch) { - LDBG("visitRegionBranchOperation: " << branch.getOperation()->getName()); + LDBG() << "visitRegionBranchOperation: " << branch.getOperation()->getName(); // Try to deduce which regions are executable. std::optional<SmallVector<Attribute>> operands = getOperandValues(branch); if (!operands) @@ -457,19 +457,19 @@ void DeadCodeAnalysis::visitRegionBranchOperation( // Mark the entry block as executable. auto *state = getOrCreate<Executable>(point); propagateIfChanged(state, state->setToLive()); - LDBG("Marked region successor live: " << point); + LDBG() << "Marked region successor live: " << point; // Add the parent op as a predecessor. auto *predecessors = getOrCreate<PredecessorState>(point); propagateIfChanged( predecessors, predecessors->join(branch, successor.getSuccessorInputs())); - LDBG("Added region branch as predecessor for successor: " << point); + LDBG() << "Added region branch as predecessor for successor: " << point; } } void DeadCodeAnalysis::visitRegionTerminator(Operation *op, RegionBranchOpInterface branch) { - LDBG("visitRegionTerminator: " << *op); + LDBG() << "visitRegionTerminator: " << *op; std::optional<SmallVector<Attribute>> operands = getOperandValues(op); if (!operands) return; @@ -488,7 +488,7 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op, auto *state = getOrCreate<Executable>(getProgramPointBefore(®ion->front())); propagateIfChanged(state, state->setToLive()); - LDBG("Marked region entry block live for region: " << region); + LDBG() << "Marked region entry block live for region: " << region; predecessors = getOrCreate<PredecessorState>( getProgramPointBefore(®ion->front())); } else { @@ -498,14 +498,14 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op, } propagateIfChanged(predecessors, predecessors->join(op, successor.getSuccessorInputs())); - LDBG("Added region terminator as predecessor for successor: " - << (successor.getSuccessor() ? "region entry" : "parent op")); + LDBG() << "Added region terminator as predecessor for successor: " + << (successor.getSuccessor() ? "region entry" : "parent op"); } } void DeadCodeAnalysis::visitCallableTerminator(Operation *op, CallableOpInterface callable) { - LDBG("visitCallableTerminator: " << *op); + LDBG() << "visitCallableTerminator: " << *op; // Add as predecessors to all callsites this return op. auto *callsites = getOrCreateFor<PredecessorState>( getProgramPointAfter(op), getProgramPointAfter(callable)); @@ -516,15 +516,15 @@ void DeadCodeAnalysis::visitCallableTerminator(Operation *op, getOrCreate<PredecessorState>(getProgramPointAfter(predecessor)); if (canResolve) { propagateIfChanged(predecessors, predecessors->join(op)); - LDBG("Added callable terminator as predecessor for callsite: " - << predecessor->getName()); + LDBG() << "Added callable terminator as predecessor for callsite: " + << predecessor->getName(); } else { // If the terminator is not a return-like, then conservatively assume we // can't resolve the predecessor. propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors()); - LDBG("Could not resolve callable terminator for callsite: " - << predecessor->getName()); + LDBG() << "Could not resolve callable terminator for callsite: " + << predecessor->getName(); } } } diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp index 6a12fe3..509f520 100644 --- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp @@ -10,7 +10,7 @@ #include <cassert> #include <mlir/Analysis/DataFlow/LivenessAnalysis.h> -#include <llvm/Support/Debug.h> +#include <llvm/Support/DebugLog.h> #include <mlir/Analysis/DataFlow/SparseAnalysis.h> #include <mlir/Analysis/DataFlow/Utils.h> #include <mlir/Analysis/DataFlowFramework.h> @@ -21,8 +21,6 @@ #include <mlir/Support/LLVM.h> #define DEBUG_TYPE "liveness-analysis" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") using namespace mlir; using namespace mlir::dataflow; @@ -81,16 +79,15 @@ ChangeResult Liveness::meet(const AbstractSparseLattice &other) { LogicalResult LivenessAnalysis::visitOperation(Operation *op, ArrayRef<Liveness *> operands, ArrayRef<const Liveness *> results) { - LLVM_DEBUG(DBGS() << "[visitOperation] Enter: "; - op->print(llvm::dbgs(), OpPrintingFlags().skipRegions()); - llvm::dbgs() << "\n"); + LDBG() << "[visitOperation] Enter: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); // This marks values of type (1.a) and (4) liveness as "live". if (!isMemoryEffectFree(op) || op->hasTrait<OpTrait::ReturnLike>()) { - LDBG("[visitOperation] Operation has memory effects or is " - "return-like, marking operands live"); + LDBG() << "[visitOperation] Operation has memory effects or is " + "return-like, marking operands live"; for (auto *operand : operands) { - LDBG(" [visitOperation] Marking operand live: " - << operand << " (" << operand->isLive << ")"); + LDBG() << " [visitOperation] Marking operand live: " << operand << " (" + << operand->isLive << ")"; propagateIfChanged(operand, operand->markLive()); } } @@ -99,28 +96,28 @@ LivenessAnalysis::visitOperation(Operation *op, ArrayRef<Liveness *> operands, bool foundLiveResult = false; for (const Liveness *r : results) { if (r->isLive && !foundLiveResult) { - LDBG("[visitOperation] Found live result, " - "meeting all operands with result: " - << r); + LDBG() << "[visitOperation] Found live result, " + "meeting all operands with result: " + << r; // It is assumed that each operand is used to compute each result of an // op. Thus, if at least one result is live, each operand is live. for (Liveness *operand : operands) { - LDBG(" [visitOperation] Meeting operand: " << operand - << " with result: " << r); + LDBG() << " [visitOperation] Meeting operand: " << operand + << " with result: " << r; meet(operand, *r); } foundLiveResult = true; } - LDBG("[visitOperation] Adding dependency for result: " << r << " after op: " - << *op); + LDBG() << "[visitOperation] Adding dependency for result: " << r + << " after op: " << *op; addDependency(const_cast<Liveness *>(r), getProgramPointAfter(op)); } return success(); } void LivenessAnalysis::visitBranchOperand(OpOperand &operand) { - LDBG("Visiting branch operand: " << operand.get() - << " in op: " << *operand.getOwner()); + LDBG() << "Visiting branch operand: " << operand.get() + << " in op: " << *operand.getOwner(); // We know (at the moment) and assume (for the future) that `operand` is a // non-forwarded branch operand of a `RegionBranchOpInterface`, // `BranchOpInterface`, `RegionBranchTerminatorOpInterface` or return-like op. @@ -152,9 +149,9 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) { for (Value result : op->getResults()) { if (getLatticeElement(result)->isLive) { mayLive = true; - LDBG("[visitBranchOperand] Non-forwarded branch " - "operand may be live due to live result: " - << result); + LDBG() << "[visitBranchOperand] Non-forwarded branch " + "operand may be live due to live result: " + << result; break; } } @@ -174,8 +171,8 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) { // Therefore, we conservatively consider the non-forwarded operand of the // branch operation may live. mayLive = true; - LDBG("[visitBranchOperand] Non-forwarded branch operand may " - "be live due to branch op interface"); + LDBG() << "[visitBranchOperand] Non-forwarded branch operand may " + "be live due to branch op interface"; } else { Operation *parentOp = op->getParentOp(); assert(isa<RegionBranchOpInterface>(parentOp) && @@ -191,9 +188,9 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) { for (Value result : parentOp->getResults()) { if (getLatticeElement(result)->isLive) { mayLive = true; - LDBG("[visitBranchOperand] Non-forwarded branch " - "operand may be live due to parent live result: " - << result); + LDBG() << "[visitBranchOperand] Non-forwarded branch " + "operand may be live due to parent live result: " + << result; break; } } @@ -214,9 +211,9 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) { for (Operation &nestedOp : *block) { if (!isMemoryEffectFree(&nestedOp)) { mayLive = true; - LDBG("Non-forwarded branch operand may be " - "live due to memory effect in block: " - << block); + LDBG() << "Non-forwarded branch operand may be " + "live due to memory effect in block: " + << block; break; } } @@ -224,7 +221,7 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) { if (mayLive) { Liveness *operandLiveness = getLatticeElement(operand.get()); - LDBG("Marking branch operand live: " << operand.get()); + LDBG() << "Marking branch operand live: " << operand.get(); propagateIfChanged(operandLiveness, operandLiveness->markLive()); } @@ -236,7 +233,7 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) { SmallVector<const Liveness *, 4> resultsLiveness; for (const Value result : op->getResults()) resultsLiveness.push_back(getLatticeElement(result)); - LDBG("Visiting operation for non-forwarded branch operand: " << *op); + LDBG() << "Visiting operation for non-forwarded branch operand: " << *op; (void)visitOperation(op, operandLiveness, resultsLiveness); // We also visit the parent op with the parent's results and this operand if @@ -249,14 +246,14 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) { SmallVector<const Liveness *, 4> parentResultsLiveness; for (const Value parentResult : parentOp->getResults()) parentResultsLiveness.push_back(getLatticeElement(parentResult)); - LDBG("Visiting parent operation for non-forwarded branch operand: " - << *parentOp); + LDBG() << "Visiting parent operation for non-forwarded branch operand: " + << *parentOp; (void)visitOperation(parentOp, operandLiveness, parentResultsLiveness); } void LivenessAnalysis::visitCallOperand(OpOperand &operand) { - LDBG("Visiting call operand: " << operand.get() - << " in op: " << *operand.getOwner()); + LDBG() << "Visiting call operand: " << operand.get() + << " in op: " << *operand.getOwner(); // We know (at the moment) and assume (for the future) that `operand` is a // non-forwarded call operand of an op implementing `CallOpInterface`. assert(isa<CallOpInterface>(operand.getOwner()) && @@ -269,18 +266,18 @@ void LivenessAnalysis::visitCallOperand(OpOperand &operand) { // This marks values of type (1.c) liveness as "live". A non-forwarded // call operand is live. Liveness *operandLiveness = getLatticeElement(operand.get()); - LDBG("Marking call operand live: " << operand.get()); + LDBG() << "Marking call operand live: " << operand.get(); propagateIfChanged(operandLiveness, operandLiveness->markLive()); } void LivenessAnalysis::setToExitState(Liveness *lattice) { - LDBG("setToExitState for lattice: " << lattice); + LDBG() << "setToExitState for lattice: " << lattice; if (lattice->isLive) { - LDBG("Lattice already live, nothing to do"); + LDBG() << "Lattice already live, nothing to do"; return; } // This marks values of type (2) liveness as "live". - LDBG("Marking lattice live due to exit state"); + LDBG() << "Marking lattice live due to exit state"; (void)lattice->markLive(); propagateIfChanged(lattice, ChangeResult::Change); } @@ -290,14 +287,14 @@ void LivenessAnalysis::setToExitState(Liveness *lattice) { //===----------------------------------------------------------------------===// RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) { - LDBG("Constructing RunLivenessAnalysis for op: " << op->getName()); + LDBG() << "Constructing RunLivenessAnalysis for op: " << op->getName(); SymbolTableCollection symbolTable; loadBaselineAnalyses(solver); solver.load<LivenessAnalysis>(symbolTable); - LDBG("Initializing and running solver"); + LDBG() << "Initializing and running solver"; (void)solver.initializeAndRun(op); - LDBG("Dumping liveness state for op"); + LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName(); } const Liveness *RunLivenessAnalysis::getLiveness(Value val) { diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp index 176d53e..16f7033 100644 --- a/mlir/lib/Analysis/DataFlowFramework.cpp +++ b/mlir/lib/Analysis/DataFlowFramework.cpp @@ -14,7 +14,7 @@ #include "llvm/ADT/iterator.h" #include "llvm/Config/abi-breaking.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #define DEBUG_TYPE "dataflow" @@ -44,9 +44,8 @@ void AnalysisState::addDependency(ProgramPoint *dependent, (void)inserted; DATAFLOW_DEBUG({ if (inserted) { - llvm::dbgs() << "Creating dependency between " << debugName << " of " - << anchor << "\nand " << debugName << " on " << dependent - << "\n"; + LDBG() << "Creating dependency between " << debugName << " of " << anchor + << "\nand " << debugName << " on " << dependent; } }); } @@ -116,8 +115,7 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) { // Initialize the analyses. for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) { - DATAFLOW_DEBUG(llvm::dbgs() - << "Priming analysis: " << analysis.debugName << "\n"); + DATAFLOW_DEBUG(LDBG() << "Priming analysis: " << analysis.debugName); if (failed(analysis.initialize(top))) return failure(); } @@ -129,8 +127,8 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) { auto [point, analysis] = worklist.front(); worklist.pop(); - DATAFLOW_DEBUG(llvm::dbgs() << "Invoking '" << analysis->debugName - << "' on: " << point << "\n"); + DATAFLOW_DEBUG(LDBG() << "Invoking '" << analysis->debugName + << "' on: " << point); if (failed(analysis->visit(point))) return failure(); } @@ -143,9 +141,9 @@ void DataFlowSolver::propagateIfChanged(AnalysisState *state, assert(isRunning && "DataFlowSolver is not running, should not use propagateIfChanged"); if (changed == ChangeResult::Change) { - DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName - << " of " << state->anchor << "\n" - << "Value: " << *state << "\n"); + DATAFLOW_DEBUG(LDBG() << "Propagating update to " << state->debugName + << " of " << state->anchor << "\n" + << "Value: " << *state); state->onUpdate(this); } } diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp index 9f4a87a..8b14e71 100644 --- a/mlir/lib/AsmParser/DialectSymbolParser.cpp +++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp @@ -89,6 +89,7 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body, nestedPunctuation.pop_back(); return success(); }; + const char *curBufferEnd = state.lex.getBufferEnd(); do { // Handle code completions, which may appear in the middle of the symbol // body. @@ -98,6 +99,12 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body, break; } + if (curBufferEnd == curPtr) { + if (!nestedPunctuation.empty()) + return emitPunctError(); + return emitError("unexpected nul or EOF in pretty dialect name"); + } + char c = *curPtr++; switch (c) { case '\0': diff --git a/mlir/lib/AsmParser/Lexer.cpp b/mlir/lib/AsmParser/Lexer.cpp index 751bd63..8f53529 100644 --- a/mlir/lib/AsmParser/Lexer.cpp +++ b/mlir/lib/AsmParser/Lexer.cpp @@ -37,6 +37,18 @@ Lexer::Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context, AsmParserCodeCompleteContext *codeCompleteContext) : sourceMgr(sourceMgr), context(context), codeCompleteLoc(nullptr) { auto bufferID = sourceMgr.getMainFileID(); + + // Check to see if the main buffer contains the last buffer, and if so the + // last buffer should be used as main file for parsing. + if (sourceMgr.getNumBuffers() > 1) { + unsigned lastFileID = sourceMgr.getNumBuffers(); + const llvm::MemoryBuffer *main = sourceMgr.getMemoryBuffer(bufferID); + const llvm::MemoryBuffer *last = sourceMgr.getMemoryBuffer(lastFileID); + if (main->getBufferStart() <= last->getBufferStart() && + main->getBufferEnd() >= last->getBufferEnd()) { + bufferID = lastFileID; + } + } curBuffer = sourceMgr.getMemoryBuffer(bufferID)->getBuffer(); curPtr = curBuffer.begin(); @@ -71,6 +83,7 @@ Token Lexer::emitError(const char *loc, const Twine &message) { } Token Lexer::lexToken() { + const char *curBufferEnd = curBuffer.end(); while (true) { const char *tokStart = curPtr; @@ -78,6 +91,9 @@ Token Lexer::lexToken() { if (tokStart == codeCompleteLoc) return formToken(Token::code_complete, tokStart); + if (tokStart == curBufferEnd) + return formToken(Token::eof, tokStart); + // Lex the next token. switch (*curPtr++) { default: @@ -102,7 +118,7 @@ Token Lexer::lexToken() { case 0: // This may either be a nul character in the source file or may be the EOF // marker that llvm::MemoryBuffer guarantees will be there. - if (curPtr - 1 == curBuffer.end()) + if (curPtr - 1 == curBufferEnd) return formToken(Token::eof, tokStart); continue; @@ -259,7 +275,11 @@ void Lexer::skipComment() { assert(*curPtr == '/'); ++curPtr; + const char *curBufferEnd = curBuffer.end(); while (true) { + if (curPtr == curBufferEnd) + return; + switch (*curPtr++) { case '\n': case '\r': @@ -267,7 +287,7 @@ void Lexer::skipComment() { return; case 0: // If this is the end of the buffer, end the comment. - if (curPtr - 1 == curBuffer.end()) { + if (curPtr - 1 == curBufferEnd) { --curPtr; return; } @@ -405,6 +425,7 @@ Token Lexer::lexPrefixedIdentifier(const char *tokStart) { Token Lexer::lexString(const char *tokStart) { assert(curPtr[-1] == '"'); + const char *curBufferEnd = curBuffer.end(); while (true) { // Check to see if there is a code completion location within the string. In // these cases we generate a completion location and place the currently @@ -419,7 +440,7 @@ Token Lexer::lexString(const char *tokStart) { case 0: // If this is a random nul character in the middle of a string, just // include it. If it is the end of file, then it is an error. - if (curPtr - 1 != curBuffer.end()) + if (curPtr - 1 != curBufferEnd) continue; [[fallthrough]]; case '\n': diff --git a/mlir/lib/AsmParser/Lexer.h b/mlir/lib/AsmParser/Lexer.h index 4085a9b..670444e 100644 --- a/mlir/lib/AsmParser/Lexer.h +++ b/mlir/lib/AsmParser/Lexer.h @@ -40,6 +40,9 @@ public: /// Returns the start of the buffer. const char *getBufferBegin() { return curBuffer.data(); } + /// Returns the end of the buffer. + const char *getBufferEnd() { return curBuffer.end(); } + /// Return the code completion location of the lexer, or nullptr if there is /// none. const char *getCodeCompleteLoc() const { return codeCompleteLoc; } diff --git a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt index 8b9a395..ccda668 100644 --- a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt +++ b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt @@ -1,19 +1,16 @@ # Dialect registration. -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) add_mlir_upstream_c_api_library(MLIRCAPIRegisterEverything RegisterEverything.cpp LINK_LIBS PUBLIC - ${dialect_libs} ${translation_libs} - ${conversion_libs} - ${extension_libs} MLIRBuiltinToLLVMIRTranslation MLIRCAPIIR - MLIRLLVMToLLVMIRTranslation MLIRCAPITransforms + MLIRLLVMToLLVMIRTranslation + MLIRRegisterAllDialects + MLIRRegisterAllExtensions + MLIRRegisterAllPasses ) diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt index d25c84a..191b5ab6 100644 --- a/mlir/lib/CMakeLists.txt +++ b/mlir/lib/CMakeLists.txt @@ -20,3 +20,37 @@ add_subdirectory(Target) add_subdirectory(Tools) add_subdirectory(Transforms) add_subdirectory(ExecutionEngine) + +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) + +add_mlir_library(MLIRRegisterAllDialects + RegisterAllDialects.cpp + + PARTIAL_SOURCES_INTENDED + + LINK_LIBS PUBLIC + ${dialect_libs} + ) + +add_mlir_library(MLIRRegisterAllPasses + RegisterAllPasses.cpp + + PARTIAL_SOURCES_INTENDED + + LINK_LIBS PUBLIC + ${dialect_libs} # Some passes are part of the dialect libs + ${conversion_libs} + ) + +add_mlir_library(MLIRRegisterAllExtensions + RegisterAllExtensions.cpp + + PARTIAL_SOURCES_INTENDED + + LINK_LIBS PUBLIC + ${dialect_libs} + ${conversion_libs} + ${extension_libs} + ) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index b6f6167..64720bf 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -481,16 +481,16 @@ struct MemoryCounterWaitOpLowering if (chipset.majorVersion >= 12) { Location loc = op.getLoc(); if (std::optional<int> ds = adaptor.getDs()) - rewriter.create<ROCDL::WaitDscntOp>(loc, *ds); + ROCDL::WaitDscntOp::create(rewriter, loc, *ds); if (std::optional<int> load = adaptor.getLoad()) - rewriter.create<ROCDL::WaitLoadcntOp>(loc, *load); + ROCDL::WaitLoadcntOp::create(rewriter, loc, *load); if (std::optional<int> store = adaptor.getStore()) - rewriter.create<ROCDL::WaitStorecntOp>(loc, *store); + ROCDL::WaitStorecntOp::create(rewriter, loc, *store); if (std::optional<int> exp = adaptor.getExp()) - rewriter.create<ROCDL::WaitExpcntOp>(loc, *exp); + ROCDL::WaitExpcntOp::create(rewriter, loc, *exp); rewriter.eraseOp(op); return success(); diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 59b3fe2..515fe5c 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -402,8 +402,8 @@ public: Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType); // Actual cast (may change bitwidth) - auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(), - castDestType, actualOp); + auto cast = + emitc::CastOp::create(rewriter, op.getLoc(), castDestType, actualOp); // Cast to the expected output type auto result = adaptValueType(cast, rewriter, opReturnType); @@ -507,8 +507,8 @@ public: Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType); - Value arithmeticResult = rewriter.template create<EmitCOp>( - op.getLoc(), arithmeticType, lhs, rhs); + Value arithmeticResult = + EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs); Value result = adaptValueType(arithmeticResult, rewriter, type); @@ -547,8 +547,8 @@ public: Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType); - Value arithmeticResult = rewriter.template create<EmitCOp>( - op.getLoc(), arithmeticType, lhs, rhs); + Value arithmeticResult = + EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs); Value result = adaptValueType(arithmeticResult, rewriter, type); @@ -748,8 +748,8 @@ public: } Value fpCastOperand = adaptor.getIn(); if (actualOperandType != operandType) { - fpCastOperand = rewriter.template create<emitc::CastOp>( - castOp.getLoc(), actualOperandType, fpCastOperand); + fpCastOperand = emitc::CastOp::create(rewriter, castOp.getLoc(), + actualOperandType, fpCastOperand); } rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand); diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp index 30a7170..3edcbb8 100644 --- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp +++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp @@ -68,9 +68,8 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> { scf::YieldOp::create(rewriter, loc, acc); }; - auto size = rewriter - .create<scf::ForOp>(loc, zero, rank, one, ValueRange(one), - loopBody) + auto size = scf::ForOp::create(rewriter, loc, zero, rank, one, + ValueRange(one), loopBody) .getResult(0); MemRefType memrefType = MemRefType::get({ShapedType::kDynamic}, diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index 6f0fc29..35ad99c 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -64,10 +64,46 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns( patterns.getContext(), "__ocml_cabs_f32"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>( patterns.getContext(), "__ocml_cabs_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>( + patterns.getContext(), "__ocml_carg_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>( + patterns.getContext(), "__ocml_carg_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>( + patterns.getContext(), "__ocml_conj_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>( + patterns.getContext(), "__ocml_conj_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>( + patterns.getContext(), "__ocml_ccos_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>( + patterns.getContext(), "__ocml_ccos_f64"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>( patterns.getContext(), "__ocml_cexp_f32"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>( patterns.getContext(), "__ocml_cexp_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float32Type>>( + patterns.getContext(), "__ocml_clog_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>( + patterns.getContext(), "__ocml_clog_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>( + patterns.getContext(), "__ocml_cpow_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>( + patterns.getContext(), "__ocml_cpow_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>( + patterns.getContext(), "__ocml_csin_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>( + patterns.getContext(), "__ocml_csin_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float32Type>>( + patterns.getContext(), "__ocml_csqrt_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float64Type>>( + patterns.getContext(), "__ocml_csqrt_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float32Type>>( + patterns.getContext(), "__ocml_ctan_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float64Type>>( + patterns.getContext(), "__ocml_ctan_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float32Type>>( + patterns.getContext(), "__ocml_ctanh_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float64Type>>( + patterns.getContext(), "__ocml_ctanh_f64"); } namespace { @@ -86,7 +122,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() { ConversionTarget target(getContext()); target.addLegalDialect<func::FuncDialect>(); - target.addIllegalOp<complex::AbsOp, complex::ExpOp>(); + target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp, + complex::CosOp, complex::ExpOp, complex::LogOp, + complex::PowOp, complex::SinOp, complex::SqrtOp, + complex::TanOp, complex::TanhOp>(); if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp index c8311eb..5ac838c 100644 --- a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp +++ b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp @@ -144,12 +144,11 @@ ControlFlowToSCFTransformation::createUnreachableTerminator(Location loc, return emitError(loc, "Cannot create unreachable terminator for '") << parentOp->getName() << "'"; - return builder - .create<func::ReturnOp>( - loc, llvm::map_to_vector(funcOp.getResultTypes(), - [&](Type type) { - return getUndefValue(loc, builder, type); - })) + return func::ReturnOp::create( + builder, loc, + llvm::map_to_vector( + funcOp.getResultTypes(), + [&](Type type) { return getUndefValue(loc, builder, type); })) .getOperation(); } diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index 63eb6c58..3cfbd89 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -579,8 +579,8 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder, auto function = [&] { if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName)) return function; - return OpBuilder::atBlockEnd(module.getBody()) - .create<LLVM::LLVMFuncOp>(loc, functionName, functionType); + auto builder = OpBuilder::atBlockEnd(module.getBody()); + return LLVM::LLVMFuncOp::create(builder, loc, functionName, functionType); }(); return LLVM::CallOp::create(builder, loc, function, arguments); } diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index a19194e..1817861 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -507,25 +507,27 @@ LogicalResult GPURotateConversion::matchAndRewrite( getTypeConverter<SPIRVTypeConverter>()->getTargetEnv(); unsigned subgroupSize = targetEnv.getAttr().getResourceLimits().getSubgroupSize(); - IntegerAttr widthAttr; - if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) || - widthAttr.getValue().getZExtValue() > subgroupSize) + unsigned width = rotateOp.getWidth(); + if (width > subgroupSize) return rewriter.notifyMatchFailure( - rotateOp, - "rotate width is not a constant or larger than target subgroup size"); + rotateOp, "rotate width is larger than target subgroup size"); Location loc = rotateOp.getLoc(); auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup); + Value offsetVal = + arith::ConstantOp::create(rewriter, loc, adaptor.getOffsetAttr()); + Value widthVal = + arith::ConstantOp::create(rewriter, loc, adaptor.getWidthAttr()); Value rotateResult = spirv::GroupNonUniformRotateKHROp::create( - rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset(), - adaptor.getWidth()); + rewriter, loc, scope, adaptor.getValue(), offsetVal, widthVal); Value validVal; - if (widthAttr.getValue().getZExtValue() == subgroupSize) { + if (width == subgroupSize) { validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter); } else { + IntegerAttr widthAttr = adaptor.getWidthAttr(); Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr); validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult, - laneId, adaptor.getWidth()); + laneId, widthVal); } rewriter.replaceOp(rotateOp, {rotateResult, validVal}); @@ -559,8 +561,8 @@ static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc, builder, loc, builder.getI32Type(), builder.getIntegerAttr(builder.getI32Type(), *clusterSize)); - return builder - .create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue) + return NonUniformOp::create(builder, loc, type, scope, groupOp, arg, + clusterSizeValue) .getResult(); } diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index ecd5b63..2568044 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -272,14 +272,13 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( // Allocate memory, copy, and free the source if necessary. Value memory = - toDynamic - ? builder - .create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize) - .getResult() - : LLVM::AllocaOp::create(builder, loc, getPtrType(), - IntegerType::get(getContext(), 8), - allocationSize, - /*alignment=*/0); + toDynamic ? LLVM::CallOp::create(builder, loc, mallocFunc.value(), + allocationSize) + .getResult() + : LLVM::AllocaOp::create(builder, loc, getPtrType(), + IntegerType::get(getContext(), 8), + allocationSize, + /*alignment=*/0); Value source = desc.memRefDescPtr(builder, loc); LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false); if (!toDynamic) diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp index 5b68eb8..e5496e5 100644 --- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp +++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp @@ -35,7 +35,7 @@ static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc, if (!(ret = moduleOp.lookupSymbol<Op>(name))) { ConversionPatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(moduleOp.getBody()); - ret = rewriter.template create<Op>(loc, std::forward<Args>(args)...); + ret = Op::create(rewriter, loc, std::forward<Args>(args)...); } return ret; } diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp index b09afd9..855c582 100644 --- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp +++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp @@ -698,7 +698,8 @@ static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { scf::IfOp ifOp = scf::IfOp::create(builder, elementType, inputEqZero, /*addThenBlock=*/true, /*addElseBlock=*/true); - ifOp.getThenBodyBuilder().create<scf::YieldOp>(loc, bitWidthValue); + auto thenBuilder = ifOp.getThenBodyBuilder(); + scf::YieldOp::create(thenBuilder, loc, bitWidthValue); auto elseBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, &ifOp.getElseRegion().front()); diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index e882845..6bd0e2d 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -19,10 +19,18 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeRange.h" +#include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" +#include <cstdint> using namespace mlir; +static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) { + return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() && + memRefType.getRank() != 0 && + !llvm::is_contained(memRefType.getShape(), 0); +} + namespace { /// Implement the interface to convert MemRef to EmitC. struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface { @@ -89,6 +97,68 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) { return resultTy; } +struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = allocOp.getLoc(); + MemRefType memrefType = allocOp.getType(); + if (!isMemRefTypeLegalForEmitC(memrefType)) { + return rewriter.notifyMatchFailure( + loc, "incompatible memref type for EmitC conversion"); + } + + Type sizeTType = emitc::SizeTType::get(rewriter.getContext()); + Type elementType = memrefType.getElementType(); + IndexType indexType = rewriter.getIndexType(); + emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>( + loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{}, + ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)})); + + int64_t numElements = 1; + for (int64_t dimSize : memrefType.getShape()) { + numElements *= dimSize; + } + Value numElementsValue = rewriter.create<emitc::ConstantOp>( + loc, indexType, rewriter.getIndexAttr(numElements)); + + Value totalSizeBytes = rewriter.create<emitc::MulOp>( + loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue); + + emitc::CallOpaqueOp allocCall; + StringAttr allocFunctionName; + Value alignmentValue; + SmallVector<Value, 2> argsVec; + if (allocOp.getAlignment()) { + allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName); + alignmentValue = rewriter.create<emitc::ConstantOp>( + loc, sizeTType, + rewriter.getIntegerAttr(indexType, + allocOp.getAlignment().value_or(0))); + argsVec.push_back(alignmentValue); + } else { + allocFunctionName = rewriter.getStringAttr(mallocFunctionName); + } + + argsVec.push_back(totalSizeBytes); + ValueRange args(argsVec); + + allocCall = rewriter.create<emitc::CallOpaqueOp>( + loc, + emitc::PointerType::get( + emitc::OpaqueType::get(rewriter.getContext(), "void")), + allocFunctionName, args); + + emitc::PointerType targetPointerType = emitc::PointerType::get(elementType); + emitc::CastOp castOp = rewriter.create<emitc::CastOp>( + loc, targetPointerType, allocCall.getResult(0)); + + rewriter.replaceOp(allocOp, castOp); + return success(); + } +}; + struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> { using OpConversionPattern::OpConversionPattern; @@ -223,9 +293,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> { void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { typeConverter.addConversion( [&](MemRefType memRefType) -> std::optional<Type> { - if (!memRefType.hasStaticShape() || - !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0 || - llvm::is_contained(memRefType.getShape(), 0)) { + if (!isMemRefTypeLegalForEmitC(memRefType)) { return {}; } Type convertedElementType = @@ -252,6 +320,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCConversionPatterns( RewritePatternSet &patterns, const TypeConverter &converter) { - patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad, - ConvertStore>(converter, patterns.getContext()); + patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal, + ConvertLoad, ConvertStore>(converter, patterns.getContext()); } diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp index cf25c09..e78dd76 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp @@ -15,6 +15,7 @@ #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Attributes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -28,9 +29,11 @@ using namespace mlir; namespace { struct ConvertMemRefToEmitCPass : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> { + using Base::Base; void runOnOperation() override { TypeConverter converter; - + ConvertMemRefToEmitCOptions options; + options.lowerToCpp = this->lowerToCpp; // Fallback for other types. converter.addConversion([](Type type) -> std::optional<Type> { if (!emitc::isSupportedEmitCType(type)) @@ -50,6 +53,37 @@ struct ConvertMemRefToEmitCPass if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); + + mlir::ModuleOp module = getOperation(); + module.walk([&](mlir::emitc::CallOpaqueOp callOp) { + if (callOp.getCallee() != alignedAllocFunctionName && + callOp.getCallee() != mallocFunctionName) { + return mlir::WalkResult::advance(); + } + + for (auto &op : *module.getBody()) { + emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op); + if (!includeOp) { + continue; + } + if (includeOp.getIsStandardInclude() && + ((options.lowerToCpp && + includeOp.getInclude() == cppStandardLibraryHeader) || + (!options.lowerToCpp && + includeOp.getInclude() == cStandardLibraryHeader))) { + return mlir::WalkResult::interrupt(); + } + } + + mlir::OpBuilder builder(module.getBody(), module.getBody()->begin()); + StringAttr includeAttr = + builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader + : cStandardLibraryHeader); + builder.create<mlir::emitc::IncludeOp>( + module.getLoc(), includeAttr, + /*is_standard_include=*/builder.getUnitAttr()); + return mlir::WalkResult::interrupt(); + }); } }; } // namespace diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 53a1912..6ba5bfe4 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -575,8 +575,8 @@ private: Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr, idxPlusOne); - return rewriter - .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr) + return LLVM::LoadOp::create(rewriter, loc, + getTypeConverter()->getIndexType(), sizePtr) .getResult(); } diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp index 240491a..807be7e 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -582,6 +582,7 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp, // block. This should be reconsidered if we allow break/continue in SCF. rewriter.setInsertionPointToEnd(before); auto condOp = cast<ConditionOp>(before->getTerminator()); + SmallVector<Value> args = llvm::to_vector(condOp.getArgs()); rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(), after, condOp.getArgs(), continuation, ValueRange()); @@ -593,7 +594,7 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp, // Replace the op with values "yielded" from the "before" region, which are // visible by dominance. - rewriter.replaceOp(whileOp, condOp.getArgs()); + rewriter.replaceOp(whileOp, args); return success(); } diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index aae3271..9b61540 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -1493,11 +1493,11 @@ public: Value extended; if (op2TypeWidth < dstTypeWidth) { if (isUnsignedIntegerOrVector(op2Type)) { - extended = rewriter.template create<LLVM::ZExtOp>( - loc, dstType, adaptor.getOperand2()); + extended = + LLVM::ZExtOp::create(rewriter, loc, dstType, adaptor.getOperand2()); } else { - extended = rewriter.template create<LLVM::SExtOp>( - loc, dstType, adaptor.getOperand2()); + extended = + LLVM::SExtOp::create(rewriter, loc, dstType, adaptor.getOperand2()); } } else if (op2TypeWidth == dstTypeWidth) { extended = adaptor.getOperand2(); @@ -1505,8 +1505,8 @@ public: return failure(); } - Value result = rewriter.template create<LLVMOp>( - loc, dstType, adaptor.getOperand1(), extended); + Value result = + LLVMOp::create(rewriter, loc, dstType, adaptor.getOperand1(), extended); rewriter.replaceOp(op, result); return success(); } diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp index 8525543..fd40e7c 100644 --- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp +++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp @@ -177,9 +177,8 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> { auto type = RankedTensorType::get({nSplits, 2}, i64); Value resHaloSizes = haloSizes.empty() - ? rewriter - .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0}, - i64) + ? tensor::EmptyOp::create(rewriter, loc, + std::array<int64_t, 2>{0, 0}, i64) .getResult() : tensor::FromElementsOp::create(rewriter, loc, type, haloSizes) .getResult(); @@ -306,13 +305,11 @@ public: auto ctx = op.getContext(); Value commWorld = mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx)); - auto rank = - rewriter - .create<mpi::CommRankOp>( - loc, - TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()}, - commWorld) - .getRank(); + auto rank = mpi::CommRankOp::create( + rewriter, loc, + TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()}, + commWorld) + .getRank(); rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(), rank); return success(); @@ -703,10 +700,9 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { // subviews need Index values for (auto &sz : haloSizes) { if (auto value = dyn_cast<Value>(sz)) - sz = - rewriter - .create<arith::IndexCastOp>(loc, rewriter.getIndexType(), value) - .getResult(); + sz = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(), + value) + .getResult(); } // most of the offset/size/stride data is the same for all dims @@ -758,9 +754,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2); // Get the linearized ids of the neighbors (down and up) for the // given split - auto tmp = rewriter - .create<NeighborsLinearIndicesOp>(loc, grid, myMultiIndex, - splitAxes) + auto tmp = NeighborsLinearIndicesOp::create(rewriter, loc, grid, + myMultiIndex, splitAxes) .getResults(); // MPI operates on i32... Value neighbourIDs[2] = { diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 5c7c027..0e3de06 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -569,10 +569,9 @@ static Value createLinalgBodyCalculationForElementwiseOp( // to UIToFP. if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) { auto unrealizedCast = - rewriter - .create<UnrealizedConversionCastOp>( - loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), - args[0]) + UnrealizedConversionCastOp::create( + rewriter, loc, + rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0]) .getResult(0); return arith::UIToFPOp::create(rewriter, loc, resultTypes[0], unrealizedCast); @@ -868,14 +867,13 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, // Emit 'linalg.generic' op auto resultTensor = - opBuilder - .create<linalg::GenericOp>( - loc, outputTensor.getType(), operand, outputTensor, affineMaps, - getNParallelLoopsAttrs(rank), - [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) { - // Emit 'linalg.yield' op - linalg::YieldOp::create(opBuilder, loc, blockArgs.front()); - }) + linalg::GenericOp::create( + opBuilder, loc, outputTensor.getType(), operand, outputTensor, + affineMaps, getNParallelLoopsAttrs(rank), + [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) { + // Emit 'linalg.yield' op + linalg::YieldOp::create(opBuilder, loc, blockArgs.front()); + }) .getResult(0); // Cast to original operand type if necessary @@ -1155,11 +1153,9 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, inputs.push_back(input); // First fill the output buffer with the init value. - auto emptyTensor = - rewriter - .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(), - dynDims) - .getResult(); + auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, reduceShape, + resultTy.getElementType(), dynDims) + .getResult(); auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter); if (!fillValueAttr) @@ -1167,10 +1163,10 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, op, "No initial value found for reduction operation"); auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr); - auto filledTensor = rewriter - .create<linalg::FillOp>(loc, ValueRange{fillValue}, - ValueRange{emptyTensor}) - .result(); + auto filledTensor = + linalg::FillOp::create(rewriter, loc, ValueRange{fillValue}, + ValueRange{emptyTensor}) + .result(); outputs.push_back(filledTensor); bool isNanIgnoreMode = false; @@ -1186,14 +1182,12 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, auto trueAttr = rewriter.getBoolAttr(true); auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr); auto emptyBoolTensor = - rewriter - .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(), - dynDims) + tensor::EmptyOp::create(rewriter, loc, reduceShape, + trueValue.getType(), dynDims) .getResult(); auto allResultsNaNTensor = - rewriter - .create<linalg::FillOp>(loc, ValueRange{trueValue}, - ValueRange{emptyBoolTensor}) + linalg::FillOp::create(rewriter, loc, ValueRange{trueValue}, + ValueRange{emptyBoolTensor}) .result(); // Note that because the linalg::ReduceOp has two variadic arguments // (inputs and outputs) and it has the SameVariadicOperandSize trait we @@ -1261,22 +1255,19 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false)); auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr); auto emptyNanTensor = - rewriter - .create<tensor::EmptyOp>(loc, reduceShape, - resultTy.getElementType(), dynDims) + tensor::EmptyOp::create(rewriter, loc, reduceShape, + resultTy.getElementType(), dynDims) .getResult(); auto nanFilledTensor = - rewriter - .create<linalg::FillOp>(loc, ValueRange{nanValue}, - ValueRange{emptyNanTensor}) + linalg::FillOp::create(rewriter, loc, ValueRange{nanValue}, + ValueRange{emptyNanTensor}) .result(); // Create an empty tensor, non need to fill this since it will be // overwritten by the select. auto finalEmptyTensor = - rewriter - .create<tensor::EmptyOp>(loc, reduceShape, - resultTy.getElementType(), dynDims) + tensor::EmptyOp::create(rewriter, loc, reduceShape, + resultTy.getElementType(), dynDims) .getResult(); // Do a selection between the tensors akin to: @@ -1503,12 +1494,11 @@ public: Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg]; if (valueTy.isUnsignedInteger()) { - value = nestedBuilder - .create<UnrealizedConversionCastOp>( - nestedLoc, - nestedBuilder.getIntegerType( - valueTy.getIntOrFloatBitWidth()), - value) + value = UnrealizedConversionCastOp::create( + nestedBuilder, nestedLoc, + nestedBuilder.getIntegerType( + valueTy.getIntOrFloatBitWidth()), + value) .getResult(0); } if (valueTy.getIntOrFloatBitWidth() < 32) { @@ -1557,9 +1547,8 @@ public: } if (outIntType.isUnsignedInteger()) { - value = nestedBuilder - .create<UnrealizedConversionCastOp>(nestedLoc, - outIntType, value) + value = UnrealizedConversionCastOp::create(nestedBuilder, nestedLoc, + outIntType, value) .getResult(0); } linalg::YieldOp::create(nestedBuilder, loc, value); @@ -2095,10 +2084,9 @@ public: Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis); // First fill the output buffer with the init value. - auto emptyTensor = rewriter - .create<tensor::EmptyOp>(loc, inputTy.getShape(), - inputTy.getElementType(), - ArrayRef<Value>({dynDims})) + auto emptyTensor = tensor::EmptyOp::create( + rewriter, loc, inputTy.getShape(), + inputTy.getElementType(), ArrayRef<Value>({dynDims})) .getResult(); SmallVector<AffineMap, 2> affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank())}; @@ -2241,23 +2229,22 @@ public: } // First fill the output buffer for the index. - auto emptyTensorIdx = rewriter - .create<tensor::EmptyOp>(loc, resultTy.getShape(), - outElementTy, dynDims) - .getResult(); + auto emptyTensorIdx = + tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), + outElementTy, dynDims) + .getResult(); auto fillValueIdx = arith::ConstantOp::create( rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0)); auto filledTensorIdx = - rewriter - .create<linalg::FillOp>(loc, ValueRange{fillValueIdx}, - ValueRange{emptyTensorIdx}) + linalg::FillOp::create(rewriter, loc, ValueRange{fillValueIdx}, + ValueRange{emptyTensorIdx}) .result(); // Second fill the output buffer for the running max. - auto emptyTensorMax = rewriter - .create<tensor::EmptyOp>(loc, resultTy.getShape(), - inElementTy, dynDims) - .getResult(); + auto emptyTensorMax = + tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), inElementTy, + dynDims) + .getResult(); auto fillValueMaxAttr = createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter); @@ -2268,9 +2255,8 @@ public: auto fillValueMax = arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr); auto filledTensorMax = - rewriter - .create<linalg::FillOp>(loc, ValueRange{fillValueMax}, - ValueRange{emptyTensorMax}) + linalg::FillOp::create(rewriter, loc, ValueRange{fillValueMax}, + ValueRange{emptyTensorMax}) .result(); // We need to reduce along the arg-max axis, with parallel operations along @@ -2371,9 +2357,8 @@ public: auto loc = op.getLoc(); auto emptyTensor = - rewriter - .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy, - dynamicDims) + tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), + resultElementTy, dynamicDims) .getResult(); SmallVector<AffineMap, 2> affineMaps = { @@ -2448,10 +2433,10 @@ public: } } - auto emptyTensor = rewriter - .create<tensor::EmptyOp>(loc, resultTy.getShape(), - resultElementTy, dynDims) - .getResult(); + auto emptyTensor = + tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), + resultElementTy, dynDims) + .getResult(); SmallVector<AffineMap, 2> affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank()), @@ -2585,10 +2570,10 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> { tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes); auto fillValueAttr = rewriter.getZeroAttr(type.getElementType()); auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr); - auto filledTensor = rewriter - .create<linalg::FillOp>(loc, ValueRange{fillValue}, - ValueRange{emptyTensor}) - .result(); + auto filledTensor = + linalg::FillOp::create(rewriter, loc, ValueRange{fillValue}, + ValueRange{emptyTensor}) + .result(); return filledTensor; } diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index 3a20524..da1fb20 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -64,19 +64,20 @@ linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias, Value conv, Value result, ArrayRef<AffineMap> indexingMaps) { ShapedType resultTy = cast<ShapedType>(conv.getType()); - return rewriter - .create<linalg::GenericOp>( - loc, resultTy, ValueRange({bias, conv}), result, indexingMaps, - getNParallelLoopsAttrs(resultTy.getRank()), - [](OpBuilder &builder, Location loc, ValueRange args) { - Value biasVal = args[0]; - Type resType = args[1].getType(); - if (resType != biasVal.getType()) { - biasVal = arith::ExtSIOp::create(builder, loc, resType, biasVal); - } - Value added = arith::AddIOp::create(builder, loc, biasVal, args[1]); - linalg::YieldOp::create(builder, loc, added); - }) + return linalg::GenericOp::create( + rewriter, loc, resultTy, ValueRange({bias, conv}), result, + indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()), + [](OpBuilder &builder, Location loc, ValueRange args) { + Value biasVal = args[0]; + Type resType = args[1].getType(); + if (resType != biasVal.getType()) { + biasVal = + arith::ExtSIOp::create(builder, loc, resType, biasVal); + } + Value added = + arith::AddIOp::create(builder, loc, biasVal, args[1]); + linalg::YieldOp::create(builder, loc, added); + }) .getResult(0); } @@ -124,23 +125,23 @@ static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter, indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank)); // Build the broadcast-like operation as a linalg.generic. - return rewriter - .create<linalg::GenericOp>( - loc, resultTy, ValueRange({source}), result, indexingMaps, - getNParallelLoopsAttrs(resultTy.getRank()), - [&resultTy](OpBuilder &builder, Location loc, ValueRange args) { - Value biasVal = args[0]; - Type resType = args[1].getType(); - if (resType != biasVal.getType()) { - biasVal = - resultTy.getElementType().isFloat() - ? arith::ExtFOp::create(builder, loc, resType, biasVal) - .getResult() - : arith::ExtSIOp::create(builder, loc, resType, biasVal) - .getResult(); - } - linalg::YieldOp::create(builder, loc, biasVal); - }) + return linalg::GenericOp::create( + rewriter, loc, resultTy, ValueRange({source}), result, + indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()), + [&resultTy](OpBuilder &builder, Location loc, ValueRange args) { + Value biasVal = args[0]; + Type resType = args[1].getType(); + if (resType != biasVal.getType()) { + biasVal = + resultTy.getElementType().isFloat() + ? arith::ExtFOp::create(builder, loc, resType, biasVal) + .getResult() + : arith::ExtSIOp::create(builder, loc, resType, + biasVal) + .getResult(); + } + linalg::YieldOp::create(builder, loc, biasVal); + }) .getResult(0); } @@ -397,21 +398,19 @@ public: auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp); auto kZpVal = arith::ConstantOp::create(rewriter, loc, kZp); - Value conv = - rewriter - .create<LinalgConvQOp>( - loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal}, - ValueRange{broadcastBias}, strideAttr, dilationAttr) - ->getResult(0); + Value conv = LinalgConvQOp::create( + rewriter, loc, resultTy, + ValueRange{input, weight, iZpVal, kZpVal}, + ValueRange{broadcastBias}, strideAttr, dilationAttr) + ->getResult(0); rewriter.replaceOp(op, conv); return success(); } - Value conv = rewriter - .create<LinalgConvOp>( - loc, accTy, ValueRange{input, weight}, - ValueRange{broadcastBias}, strideAttr, dilationAttr) + Value conv = LinalgConvOp::create( + rewriter, loc, accTy, ValueRange{input, weight}, + ValueRange{broadcastBias}, strideAttr, dilationAttr) ->getResult(0); // We may need to truncate back to the result type if the accumulator was @@ -529,9 +528,8 @@ public: Value emptyTensor = tensor::EmptyOp::create( rewriter, loc, linalgConvTy.getShape(), accETy, filteredDims); Value zero = arith::ConstantOp::create(rewriter, loc, resultZeroAttr); - Value zeroTensor = rewriter - .create<linalg::FillOp>(loc, ValueRange{zero}, - ValueRange{emptyTensor}) + Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero}, + ValueRange{emptyTensor}) .result(); Value biasEmptyTensor = tensor::EmptyOp::create( @@ -544,10 +542,9 @@ public: indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank)); if (hasNullZps) { - Value conv = rewriter - .create<linalg::DepthwiseConv2DNhwcHwcmOp>( - loc, linalgConvTy, ValueRange{input, weight}, - ValueRange{zeroTensor}, strideAttr, dilationAttr) + Value conv = linalg::DepthwiseConv2DNhwcHwcmOp::create( + rewriter, loc, linalgConvTy, ValueRange{input, weight}, + ValueRange{zeroTensor}, strideAttr, dilationAttr) .getResult(0); // We may need to truncate back to the result type if the accumulator was @@ -565,22 +562,20 @@ public: rewriter, loc, resultTy, conv, reassociationMap); Value result = - rewriter - .create<linalg::GenericOp>( - loc, resultTy, ValueRange({bias, convReshape}), - biasEmptyTensor, indexingMaps, - getNParallelLoopsAttrs(resultRank), - [&](OpBuilder &nestedBuilder, Location nestedLoc, - ValueRange args) { - Value added; - if (llvm::isa<FloatType>(inputETy)) - added = arith::AddFOp::create(nestedBuilder, loc, args[0], - args[1]); - else - added = arith::AddIOp::create(nestedBuilder, loc, args[0], - args[1]); - linalg::YieldOp::create(nestedBuilder, nestedLoc, added); - }) + linalg::GenericOp::create( + rewriter, loc, resultTy, ValueRange({bias, convReshape}), + biasEmptyTensor, indexingMaps, getNParallelLoopsAttrs(resultRank), + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange args) { + Value added; + if (llvm::isa<FloatType>(inputETy)) + added = arith::AddFOp::create(nestedBuilder, loc, args[0], + args[1]); + else + added = arith::AddIOp::create(nestedBuilder, loc, args[0], + args[1]); + linalg::YieldOp::create(nestedBuilder, nestedLoc, added); + }) .getResult(0); rewriter.replaceOp(op, result); } else { @@ -588,12 +583,11 @@ public: IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal); auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp); auto kZpVal = arith::ConstantOp::create(rewriter, loc, wZp); - Value conv = - rewriter - .create<linalg::DepthwiseConv2DNhwcHwcmQOp>( - loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal}, - ValueRange{zeroTensor}, strideAttr, dilationAttr) - .getResult(0); + Value conv = linalg::DepthwiseConv2DNhwcHwcmQOp::create( + rewriter, loc, linalgConvTy, + ValueRange{input, weight, iZpVal, kZpVal}, + ValueRange{zeroTensor}, strideAttr, dilationAttr) + .getResult(0); SmallVector<ReassociationExprs, 4> reassociationMap; createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter); Value convReshape = tensor::CollapseShapeOp::create( @@ -639,9 +633,8 @@ public: auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, outputTy.getShape(), outputTy.getElementType(), filteredDims); - Value zeroTensor = rewriter - .create<linalg::FillOp>(loc, ValueRange{zero}, - ValueRange{emptyTensor}) + Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero}, + ValueRange{emptyTensor}) .result(); FailureOr<int64_t> maybeAZp = op.getAZeroPoint(); @@ -910,20 +903,18 @@ public: rewriter, loc, accTy.getShape(), accETy, dynamicDims); Value filledEmptyTensor = - rewriter - .create<linalg::FillOp>(loc, ValueRange{initialValue}, - ValueRange{poolEmptyTensor}) + linalg::FillOp::create(rewriter, loc, ValueRange{initialValue}, + ValueRange{poolEmptyTensor}) .result(); Value fakeWindowDims = tensor::EmptyOp::create(rewriter, loc, kernel, accETy); // Sum across the pooled region. - Value poolingOp = rewriter - .create<linalg::PoolingNhwcSumOp>( - loc, ArrayRef<Type>{accTy}, - ValueRange{paddedInput, fakeWindowDims}, - filledEmptyTensor, strideAttr, dilationAttr) + Value poolingOp = linalg::PoolingNhwcSumOp::create( + rewriter, loc, ArrayRef<Type>{accTy}, + ValueRange{paddedInput, fakeWindowDims}, + filledEmptyTensor, strideAttr, dilationAttr) .getResult(0); // Normalize the summed value by the number of elements grouped in each @@ -1050,10 +1041,9 @@ public: Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8); auto scaled = - rewriter - .create<tosa::ApplyScaleOp>( - loc, rewriter.getI32Type(), poolVal, multiplier, shift, - rewriter.getStringAttr("SINGLE_ROUND")) + tosa::ApplyScaleOp::create( + rewriter, loc, rewriter.getI32Type(), poolVal, multiplier, + shift, rewriter.getStringAttr("SINGLE_ROUND")) .getResult(); // If we have quantization information we need to apply output diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp index 77aab85..a425eff 100644 --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -482,14 +482,12 @@ struct CombineTransferReadOpTranspose final permutationMap.compose(transferReadOp.getPermutationMap()); auto loc = op.getLoc(); - Value result = - rewriter - .create<vector::TransferReadOp>( - loc, resultType, transferReadOp.getBase(), - transferReadOp.getIndices(), AffineMapAttr::get(newMap), - transferReadOp.getPadding(), transferReadOp.getMask(), - transferReadOp.getInBoundsAttr()) - .getResult(); + Value result = vector::TransferReadOp::create( + rewriter, loc, resultType, transferReadOp.getBase(), + transferReadOp.getIndices(), AffineMapAttr::get(newMap), + transferReadOp.getPadding(), transferReadOp.getMask(), + transferReadOp.getInBoundsAttr()) + .getResult(); // Fuse through the integer extend op. if (extOp) { diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 9cd491c..17a79e3 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -29,7 +29,9 @@ #include "mlir/Target/LLVMIR/TypeToLLVM.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/APFloat.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/Support/Casting.h" + #include <optional> using namespace mlir; @@ -1068,39 +1070,6 @@ public: } }; -class VectorExtractElementOpConversion - : public ConvertOpToLLVMPattern<vector::ExtractElementOp> { -public: - using ConvertOpToLLVMPattern< - vector::ExtractElementOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto vectorType = extractEltOp.getSourceVectorType(); - auto llvmType = typeConverter->convertType(vectorType.getElementType()); - - // Bail if result type cannot be lowered. - if (!llvmType) - return failure(); - - if (vectorType.getRank() == 0) { - Location loc = extractEltOp.getLoc(); - auto idxType = rewriter.getIndexType(); - auto zero = LLVM::ConstantOp::create(rewriter, loc, - typeConverter->convertType(idxType), - rewriter.getIntegerAttr(idxType, 0)); - rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( - extractEltOp, llvmType, adaptor.getVector(), zero); - return success(); - } - - rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( - extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition()); - return success(); - } -}; - class VectorExtractOpConversion : public ConvertOpToLLVMPattern<vector::ExtractOp> { public: @@ -1204,39 +1173,6 @@ public: } }; -class VectorInsertElementOpConversion - : public ConvertOpToLLVMPattern<vector::InsertElementOp> { -public: - using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto vectorType = insertEltOp.getDestVectorType(); - auto llvmType = typeConverter->convertType(vectorType); - - // Bail if result type cannot be lowered. - if (!llvmType) - return failure(); - - if (vectorType.getRank() == 0) { - Location loc = insertEltOp.getLoc(); - auto idxType = rewriter.getIndexType(); - auto zero = LLVM::ConstantOp::create(rewriter, loc, - typeConverter->convertType(idxType), - rewriter.getIntegerAttr(idxType, 0)); - rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( - insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero); - return success(); - } - - rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( - insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), - adaptor.getPosition()); - return success(); - } -}; - class VectorInsertOpConversion : public ConvertOpToLLVMPattern<vector::InsertOp> { public: @@ -2242,8 +2178,7 @@ void mlir::populateVectorToLLVMConversionPatterns( VectorGatherOpConversion, VectorScatterOpConversion>( converter, useVectorAlignment); patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion, - VectorExtractElementOpConversion, VectorExtractOpConversion, - VectorFMAOp1DConversion, VectorInsertElementOpConversion, + VectorExtractOpConversion, VectorFMAOp1DConversion, VectorInsertOpConversion, VectorPrintOpConversion, VectorTypeCastOpConversion, VectorScaleOpConversion, VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index b1af5f0..508f4e2 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -690,7 +690,7 @@ struct PrepareTransferWriteConversion /// %lastIndex = arith.subi %length, %c1 : index /// vector.print punctuation <open> /// scf.for %i = %c0 to %length step %c1 { -/// %el = vector.extractelement %v[%i : index] : vector<[4]xi32> +/// %el = vector.extract %v[%i] : i32 from vector<[4]xi32> /// vector.print %el : i32 punctuation <no_punctuation> /// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index /// scf.if %notLastIndex { @@ -1643,7 +1643,7 @@ struct Strategy1d<TransferWriteOp> { /// Is rewritten to approximately the following pseudo-IR: /// ``` /// for i = 0 to 9 { -/// %t = vector.extractelement %vec[i] : vector<9xf32> +/// %t = vector.extract %vec[i] : f32 from vector<9xf32> /// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32> /// } /// ``` diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 986eae3..a4be7d4 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -335,63 +335,6 @@ struct VectorInsertOpConvert final } }; -struct VectorExtractElementOpConvert final - : public OpConversionPattern<vector::ExtractElementOp> { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type resultType = getTypeConverter()->convertType(extractOp.getType()); - if (!resultType) - return failure(); - - if (isa<spirv::ScalarType>(adaptor.getVector().getType())) { - rewriter.replaceOp(extractOp, adaptor.getVector()); - return success(); - } - - APInt cstPos; - if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos))) - rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( - extractOp, resultType, adaptor.getVector(), - rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())})); - else - rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>( - extractOp, resultType, adaptor.getVector(), adaptor.getPosition()); - return success(); - } -}; - -struct VectorInsertElementOpConvert final - : public OpConversionPattern<vector::InsertElementOp> { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type vectorType = getTypeConverter()->convertType(insertOp.getType()); - if (!vectorType) - return failure(); - - if (isa<spirv::ScalarType>(vectorType)) { - rewriter.replaceOp(insertOp, adaptor.getSource()); - return success(); - } - - APInt cstPos; - if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos))) - rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( - insertOp, adaptor.getSource(), adaptor.getDest(), - cstPos.getSExtValue()); - else - rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>( - insertOp, vectorType, insertOp.getDest(), adaptor.getSource(), - adaptor.getPosition()); - return success(); - } -}; - struct VectorInsertStridedSliceOpConvert final : public OpConversionPattern<vector::InsertStridedSliceOp> { using OpConversionPattern::OpConversionPattern; @@ -1107,12 +1050,11 @@ struct VectorToElementOpConvert final void mlir::populateVectorToSPIRVPatterns( const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add< - VectorBitcastConvert, VectorBroadcastConvert, - VectorExtractElementOpConvert, VectorExtractOpConvert, + VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert, VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>, VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert, - VectorToElementOpConvert, VectorInsertElementOpConvert, - VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>, + VectorToElementOpConvert, VectorInsertOpConvert, + VectorReductionPattern<GL_INT_MAX_MIN_OPS>, VectorReductionPattern<CL_INT_MAX_MIN_OPS>, VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>, VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast, diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp index 748ff1e..6f3110c 100644 --- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp +++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp @@ -96,9 +96,8 @@ static Value getStride(Location loc, MemRefType mType, Value base, MemRefDescriptor memrefDescriptor(base); auto attr = rewriter.getI64IntegerAttr(bytes); Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr); - return rewriter - .create<LLVM::MulOp>(loc, llvmInt64Type, scale, - memrefDescriptor.stride(rewriter, loc, preLast)) + return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, + memrefDescriptor.stride(rewriter, loc, preLast)) .getResult(); } // Use direct constant for static stride. diff --git a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp index 3c00b32..6265f46 100644 --- a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp @@ -15,13 +15,13 @@ #include "mlir/Dialect/Affine/Transforms/Transforms.h" #include "mlir/IR/PatternMatch.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" +#include "llvm/Support/InterleavedRange.h" using namespace mlir; using namespace mlir::affine; #define DEBUG_TYPE "decompose-affine-ops" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") /// Count the number of loops surrounding `operand` such that operand could be /// hoisted above. @@ -115,7 +115,7 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter, return rewriter.notifyMatchFailure( op, "only add or mul binary expr can be reassociated"); - LLVM_DEBUG(DBGS() << "Start decomposeIntoFinerGrainedOps: " << op << "\n"); + LDBG() << "Start decomposeIntoFinerGrainedOps: " << op; // 2. Iteratively extract the RHS subexpressions while the top-level binary // expr kind remains the same. @@ -125,11 +125,11 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter, auto currentBinExpr = dyn_cast<AffineBinaryOpExpr>(remainingExp); if (!currentBinExpr || currentBinExpr.getKind() != binExpr.getKind()) { subExpressions.push_back(remainingExp); - LLVM_DEBUG(DBGS() << "--terminal: " << subExpressions.back() << "\n"); + LDBG() << "--terminal: " << subExpressions.back(); break; } subExpressions.push_back(currentBinExpr.getRHS()); - LLVM_DEBUG(DBGS() << "--subExpr: " << subExpressions.back() << "\n"); + LDBG() << "--subExpr: " << subExpressions.back(); remainingExp = currentBinExpr.getLHS(); } @@ -146,9 +146,7 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter, llvm::stable_sort(subExpressions, [&](AffineExpr e1, AffineExpr e2) { return getMaxSymbol(e1) < getMaxSymbol(e2); }); - LLVM_DEBUG( - llvm::interleaveComma(subExpressions, DBGS() << "--sorted subexprs: "); - llvm::dbgs() << "\n"); + LDBG() << "--sorted subexprs: " << llvm::interleaved(subExpressions); // 4. Merge sorted subExpressions iteratively, thus achieving reassociation. auto s0 = getAffineSymbolExpr(0, ctx); @@ -162,7 +160,7 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter, Value tmp = createSubApply(rewriter, op, subExpressions[i]); current = AffineApplyOp::create(rewriter, op.getLoc(), binMap, ValueRange{current, tmp}); - LLVM_DEBUG(DBGS() << "--reassociate into: " << current << "\n"); + LDBG() << "--reassociate into: " << current; } // 5. Replace original op. diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp index 8493b60..2521512 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp @@ -19,11 +19,10 @@ #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/IntEqClasses.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/InterleavedRange.h" #define DEBUG_TYPE "affine-min-max" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") using namespace mlir; using namespace mlir::affine; @@ -39,7 +38,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { ValueRange operands = affineOp.getOperands(); static constexpr bool isMin = std::is_same_v<AffineOp, AffineMinOp>; - LLVM_DEBUG({ DBGS() << "analyzing value: `" << affineOp << "`\n"; }); + LDBG() << "analyzing value: `" << affineOp; // Create a `Variable` list with values corresponding to each of the results // in the affine affineMap. @@ -48,12 +47,9 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { [&](unsigned i) { return Variable(affineMap.getSliceMap(i, 1), operands); }); - LLVM_DEBUG({ - DBGS() << "- constructed variables are: " - << llvm::interleaved_array(llvm::map_range( - variables, [](const Variable &v) { return v.getMap(); })) - << "`\n"; - }); + LDBG() << "- constructed variables are: " + << llvm::interleaved_array(llvm::map_range( + variables, [](const Variable &v) { return v.getMap(); })); // Get the comparison operation. ComparisonOperator cmpOp = @@ -72,10 +68,8 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { // Initialize the bound. Variable *bound = &v; - LLVM_DEBUG({ - DBGS() << "- inspecting variable: #" << i << ", with map: `" << v.getMap() - << "`\n"; - }); + LDBG() << "- inspecting variable: #" << i << ", with map: `" << v.getMap() + << "`\n"; // Check against the other variables. for (size_t j = i + 1; j < variables.size(); ++j) { @@ -87,10 +81,8 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { // Get the bound of the equivalence class or itself. Variable *nv = bounds.lookup_or(jEqClass, &variables[j]); - LLVM_DEBUG({ - DBGS() << "- comparing with variable: #" << jEqClass - << ", with map: " << nv->getMap() << "\n"; - }); + LDBG() << "- comparing with variable: #" << jEqClass + << ", with map: " << nv->getMap(); // Compare the variables. FailureOr<bool> cmpResult = @@ -98,18 +90,14 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { // The variables cannot be compared. if (failed(cmpResult)) { - LLVM_DEBUG({ - DBGS() << "-- classes: #" << i << ", #" << jEqClass - << " cannot be merged\n"; - }); + LDBG() << "-- classes: #" << i << ", #" << jEqClass + << " cannot be merged"; continue; } // Join the equivalent classes and update the bound if necessary. - LLVM_DEBUG({ - DBGS() << "-- merging classes: #" << i << ", #" << jEqClass - << ", is cmp(lhs, rhs): " << *cmpResult << "`\n"; - }); + LDBG() << "-- merging classes: #" << i << ", #" << jEqClass + << ", is cmp(lhs, rhs): " << *cmpResult << "`"; if (*cmpResult) { boundedClasses.join(eqClass, jEqClass); } else { @@ -124,8 +112,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { // Return if there's no simplification. if (bounds.size() >= affineMap.getNumResults()) { - LLVM_DEBUG( - { DBGS() << "- the affine operation couldn't get simplified\n"; }); + LDBG() << "- the affine operation couldn't get simplified"; return false; } @@ -135,13 +122,11 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { for (auto [k, bound] : bounds) results.push_back(bound->getMap().getResult(0)); - LLVM_DEBUG({ - DBGS() << "- starting from map: " << affineMap << "\n"; - DBGS() << "- creating new map with: \n"; - DBGS() << "--- dims: " << affineMap.getNumDims() << "\n"; - DBGS() << "--- syms: " << affineMap.getNumSymbols() << "\n"; - DBGS() << "--- res: " << llvm::interleaved_array(results) << "\n"; - }); + LDBG() << "- starting from map: " << affineMap; + LDBG() << "- creating new map with:"; + LDBG() << "--- dims: " << affineMap.getNumDims(); + LDBG() << "--- syms: " << affineMap.getNumSymbols(); + LDBG() << "--- res: " << llvm::interleaved_array(results); affineMap = AffineMap::get(0, affineMap.getNumSymbols() + affineMap.getNumDims(), @@ -149,7 +134,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { // Update the affine op. rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); }); - LLVM_DEBUG({ DBGS() << "- simplified affine op: `" << affineOp << "`\n"; }); + LDBG() << "- simplified affine op: `" << affineOp << "`"; return true; } diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp index 45b896d..1aa8064 100644 --- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp +++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp @@ -145,8 +145,8 @@ protected: return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc, lhs, rhs); case MMLA::Bfloat: - return rewriter.create<arm_neon::BfmmlaOp>(loc, acc.getType(), acc, lhs, - rhs); + return arm_neon::BfmmlaOp::create(rewriter, loc, acc.getType(), acc, lhs, + rhs); case MMLA::Nop: llvm_unreachable("Uninitialized operation type"); } @@ -226,8 +226,9 @@ public: // Initial accumulator for the final result. This is the un-tiled result if // tiling is done. - Value result = rewriter.create<arith::ConstantOp>( - loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType())); + Value result = + arith::ConstantOp::create(rewriter, loc, op.getResultType(), + rewriter.getZeroAttr(op.getResultType())); SmallVector<int64_t, 3> loopOrder = {0, 1}; if (iterationBounds.size() == 3) @@ -263,8 +264,9 @@ public: if (dimM == 1) { auto expandRowVector = [&](Value tiledOperand, VectorType expandedTypeType) { - auto emptyOperand = rewriter.create<arith::ConstantOp>( - loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType)); + auto emptyOperand = + arith::ConstantOp::create(rewriter, loc, expandedTypeType, + rewriter.getZeroAttr(expandedTypeType)); SmallVector<int64_t> offsets( cast<ShapedType>(emptyOperand.getType()).getRank(), 0); SmallVector<int64_t> strides( @@ -280,8 +282,8 @@ public: // using the instruction for unsigned by signed multiplication with // reversed operands. if (swapOperands) - tiledAcc = rewriter.create<vector::TransposeOp>( - loc, tiledAcc, ArrayRef<int64_t>({1, 0})); + tiledAcc = vector::TransposeOp::create(rewriter, loc, tiledAcc, + ArrayRef<int64_t>({1, 0})); // Collapse tiled operands to 1D vectors required by the ArmNeon ops auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>( @@ -309,8 +311,8 @@ public: // Because of the reversed operands the result is obtained transposed. // Transpose it back, if (swapOperands) - tiledRes = rewriter.create<vector::TransposeOp>( - loc, tiledRes, ArrayRef<int64_t>({1, 0})); + tiledRes = vector::TransposeOp::create(rewriter, loc, tiledRes, + ArrayRef<int64_t>({1, 0})); // With vecmat, only one row of tiled ACC can be inserted into the final // result diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp index fcfeb9c..35b0bd1 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp @@ -214,13 +214,13 @@ Value VectorContractRewriter::createMMLA(PatternRewriter &rewriter, switch (mmlaOp) { case MMLA::SignedInt: - return rewriter.create<arm_sve::SmmlaOp>(loc, resTy, acc, lhs, rhs); + return arm_sve::SmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs); case MMLA::UnsignedInt: - return rewriter.create<arm_sve::UmmlaOp>(loc, resTy, acc, lhs, rhs); + return arm_sve::UmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs); case MMLA::MixedInt: - return rewriter.create<arm_sve::UsmmlaOp>(loc, resTy, acc, lhs, rhs); + return arm_sve::UsmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs); case MMLA::Bfloat: - return rewriter.create<arm_sve::BfmmlaOp>(loc, resTy, acc, lhs, rhs); + return arm_sve::BfmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs); default: llvm_unreachable("Uninitialized operation kind"); } @@ -316,62 +316,63 @@ Value VectorContractRewriter::lower(vector::ContractionOp op, for (int64_t i = 0; i < M; i += 2) { // Extract two consecutive rows of the LHS tile. auto r0 = - rewriter.create<vector::ExtractOp>(loc, lhs, ArrayRef<int64_t>{i}); + vector::ExtractOp::create(rewriter, loc, lhs, ArrayRef<int64_t>{i}); auto r1 = - rewriter.create<vector::ExtractOp>(loc, lhs, ArrayRef<int64_t>{i + 1}); + vector::ExtractOp::create(rewriter, loc, lhs, ArrayRef<int64_t>{i + 1}); // Concatenate to obtain a 2 x K x <input-type> flattened sub-tile. SmallVector<int64_t> shuffleIdx(2 * K); std::iota(shuffleIdx.begin(), shuffleIdx.end(), 0); - auto t = rewriter.create<vector::ShuffleOp>(loc, r0, r1, shuffleIdx); + auto t = vector::ShuffleOp::create(rewriter, loc, r0, r1, shuffleIdx); // Turn it into a scalable vector. - auto s = rewriter.create<vector::ScalableInsertOp>( - loc, t, rewriter.create<ub::PoisonOp>(loc, flatLhsType), 0); + auto s = vector::ScalableInsertOp::create( + rewriter, loc, t, ub::PoisonOp::create(rewriter, loc, flatLhsType), 0); // Replicate the sub-tile VSCALE times to fill the entire vector. - auto r = rewriter.create<arm_sve::DupQLaneOp>(loc, s, 0); + auto r = arm_sve::DupQLaneOp::create(rewriter, loc, s, 0); lhsTile.push_back(r); } // "Flatten" the RHS tile from <[N]xK> to <[N*K]>. - auto rhs = rewriter.create<vector::ShapeCastOp>(this->rhs.getLoc(), - flatRhsTileType, this->rhs); + auto rhs = vector::ShapeCastOp::create(rewriter, this->rhs.getLoc(), + flatRhsTileType, this->rhs); // Extract the RHS sub-tiles with logical shape <Kx[2]>. SmallVector<Value> rhsTile; for (int64_t j = 0; j < N; j += 2) - rhsTile.push_back(rewriter.create<vector::ScalableExtractOp>( - loc, flatRhsType, rhs, j * K)); + rhsTile.push_back(vector::ScalableExtractOp::create( + rewriter, loc, flatRhsType, rhs, j * K)); // Extract and pack the ACC sub-tiles. SmallVector<Value> accTile; for (int64_t i = 0; i < M; i += 2) { // Extract two consecutive rows of the accumulator tile. - auto r0 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(), - ArrayRef<int64_t>{i}); - auto r1 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(), - ArrayRef<int64_t>{i + 1}); + auto r0 = vector::ExtractOp::create(rewriter, loc, op.getAcc(), + ArrayRef<int64_t>{i}); + auto r1 = vector::ExtractOp::create(rewriter, loc, op.getAcc(), + ArrayRef<int64_t>{i + 1}); Value accTileVec; if (swapOperands) { // We are performing the operation with swapped LHS and RHS we need to // transpose each individual 2x2 tile of the accumulator and (later) the // final result. - accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1); + accTileVec = vector::InterleaveOp::create(rewriter, loc, r0, r1); } else { // Bitcast accumulator rows to double-width integer elements, so // subsequent interleave/deinterleave work on pairs of elements. - auto r0I64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r0); - auto r1I64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r1); + auto r0I64 = vector::BitCastOp::create(rewriter, loc, accRow64Ty, r0); + auto r1I64 = vector::BitCastOp::create(rewriter, loc, accRow64Ty, r1); // Interleave the rows, effectively flattening each 2x2 tile into 4 // consecutive elements. - auto intrI64 = rewriter.create<vector::InterleaveOp>(loc, r0I64, r1I64); + auto intrI64 = vector::InterleaveOp::create(rewriter, loc, r0I64, r1I64); // Bitcast back to original element type. - accTileVec = rewriter.create<vector::BitCastOp>(loc, accRowX2Ty, intrI64); + accTileVec = + vector::BitCastOp::create(rewriter, loc, accRowX2Ty, intrI64); } // Extract ACC sub-tiles. for (int64_t j = 0; j < N; j += 2) - accTile.push_back(rewriter.create<vector::ScalableExtractOp>( - loc, flatAccType, accTileVec, j * 2)); + accTile.push_back(vector::ScalableExtractOp::create( + rewriter, loc, flatAccType, accTileVec, j * 2)); } // Emit sub-tile matrix multiplications. @@ -384,13 +385,13 @@ Value VectorContractRewriter::lower(vector::ContractionOp op, } // Unpack the OUT sub-tiles and insert into the result. - Value result = rewriter.create<ub::PoisonOp>(loc, op.getResultType()); + Value result = ub::PoisonOp::create(rewriter, loc, op.getResultType()); for (int64_t i = 0; i < M / 2; ++i) { // Collect a number of sub-tiles in a row. - Value row = rewriter.create<ub::PoisonOp>(loc, accRowX2Ty); + Value row = ub::PoisonOp::create(rewriter, loc, accRowX2Ty); for (int64_t j = 0; j < N / 2; ++j) - row = rewriter.create<vector::ScalableInsertOp>( - loc, outTile[i * N / 2 + j], row, j * 4); + row = vector::ScalableInsertOp::create( + rewriter, loc, outTile[i * N / 2 + j], row, j * 4); // Unpack the row to obtain two rows of the output. If we have the out // sub-tiles transposed we obtain two consecutive output rows by @@ -398,22 +399,22 @@ Value VectorContractRewriter::lower(vector::ContractionOp op, // Otherwise, the interleave is by pairs. Value out0, out1; if (swapOperands) { - auto tmp = rewriter.create<vector::DeinterleaveOp>(loc, row); + auto tmp = vector::DeinterleaveOp::create(rewriter, loc, row); out0 = tmp.getRes1(); out1 = tmp.getRes2(); } else { // Deinterleave by pairs. - auto row64 = rewriter.create<vector::BitCastOp>(loc, accRowX264Ty, row); - auto deintr64 = rewriter.create<vector::DeinterleaveOp>(loc, row64); + auto row64 = vector::BitCastOp::create(rewriter, loc, accRowX264Ty, row); + auto deintr64 = vector::DeinterleaveOp::create(rewriter, loc, row64); // Bitcast back into original element type and insert into the result. - out0 = - rewriter.create<vector::BitCastOp>(loc, accRowTy, deintr64.getRes1()); - out1 = - rewriter.create<vector::BitCastOp>(loc, accRowTy, deintr64.getRes2()); + out0 = vector::BitCastOp::create(rewriter, loc, accRowTy, + deintr64.getRes1()); + out1 = vector::BitCastOp::create(rewriter, loc, accRowTy, + deintr64.getRes2()); } - result = rewriter.create<vector::InsertOp>(loc, out0, result, i * 2); - result = rewriter.create<vector::InsertOp>(loc, out1, result, i * 2 + 1); + result = vector::InsertOp::create(rewriter, loc, out0, result, i * 2); + result = vector::InsertOp::create(rewriter, loc, out1, result, i * 2 + 1); } return result; diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 825f63e..f7b0b87 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -18,7 +18,6 @@ #include "mlir/IR/Value.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "llvm/ADT/ScopeExit.h" -#include "llvm/Support/Debug.h" //===----------------------------------------------------------------------===// // BufferizableOpInterface @@ -35,8 +34,6 @@ namespace bufferization { MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::AnalysisState) #define DEBUG_TYPE "bufferizable-op-interface" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << (X)) using namespace mlir; using namespace bufferization; @@ -691,8 +688,8 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value, if (failed(bufferType)) return failure(); ensureToBufferOpIsValid(value, *bufferType); - return rewriter - .create<bufferization::ToBufferOp>(value.getLoc(), *bufferType, value) + return bufferization::ToBufferOp::create(rewriter, value.getLoc(), + *bufferType, value) .getResult(); } @@ -775,9 +772,8 @@ FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc, // Default bufferallocation via AllocOp. if (bufferAlignment != 0) - return b - .create<memref::AllocOp>(loc, type, dynShape, - b.getI64IntegerAttr(bufferAlignment)) + return memref::AllocOp::create(b, loc, type, dynShape, + b.getI64IntegerAttr(bufferAlignment)) .getResult(); return memref::AllocOp::create(b, loc, type, dynShape).getResult(); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp index f0d65b0..e9ad13f 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp @@ -482,10 +482,10 @@ func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction( // Build the first for loop that computes aliasing with retained // memrefs. - Value noRetainAlias = - builder - .create<scf::ForOp>( - loc, c0, toRetainSize, c1, trueValue, + Value + noRetainAlias = + scf::ForOp::create( + builder, loc, c0, toRetainSize, c1, trueValue, [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { Value retainValue = memref::LoadOp::create( @@ -512,14 +512,14 @@ func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction( builder, loc, iterArgs[0], doesntAlias); scf::YieldOp::create(builder, loc, yieldValue); }) - .getResult(0); + .getResult(0); // Build the second for loop that adds aliasing with previously // deallocated memrefs. - Value noAlias = - builder - .create<scf::ForOp>( - loc, c0, outerIter, c1, noRetainAlias, + Value + noAlias = + scf::ForOp::create( + builder, loc, c0, outerIter, c1, noRetainAlias, [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { Value prevDeallocValue = memref::LoadOp::create( @@ -531,7 +531,7 @@ func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction( builder, loc, iterArgs[0], doesntAlias); scf::YieldOp::create(builder, loc, yieldValue); }) - .getResult(0); + .getResult(0); Value shouldDealoc = arith::AndIOp::create(builder, loc, noAlias, cond); memref::StoreOp::create(builder, loc, shouldDealoc, deallocCondsMemref, diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index d1d1062..aa53f94 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -1,4 +1,5 @@ -//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===// +//===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries +//----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -8,12 +9,13 @@ // // Module Bufferization is an extension of One-Shot Bufferize that // bufferizes function boundaries. It provides `BufferizableOpInterface` -// implementations for FuncOp, CallOp and ReturnOp. +// implementations for FuncOp, CallOp and ReturnOp. Although it is named +// Module Bufferization, it may operate on any SymbolTable. // -// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`. -// This function analyzes the given module and determines the order of analysis -// and bufferization: Functions that are called are processed before their -// respective callers. +// Module Bufferization is run via `runOneShotModuleBufferize(SymbolTableOp, +// ...)`. This function analyzes the given op and determines the order of +// analysis and bufferization: Functions that are called are processed before +// their respective callers. // // After analyzing a FuncOp, additional information about its bbArgs is // gathered and stored in `FuncAnalysisState`. @@ -309,7 +311,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) { /// Return `failure()` if we are unable to retrieve the called FuncOp from /// any func::CallOp. static LogicalResult getFuncOpsOrderedByCalls( - ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps, + Operation *moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps, SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap, SymbolTableCollection &symbolTables) { // For each FuncOp, the set of functions called by it (i.e. the union of @@ -317,26 +319,29 @@ static LogicalResult getFuncOpsOrderedByCalls( DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy; // For each FuncOp, the number of func::CallOp it contains. DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp; - - for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) { - // Collect function calls and populate the caller map. - numberCallOpsContainedInFuncOp[funcOp] = 0; - WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult { - func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables); - assert(calledFunction && "could not retrieved called func::FuncOp"); - // If the called function does not have any tensors in its signature, then - // it is not necessary to bufferize the callee before the caller. - if (!hasTensorSignature(calledFunction)) - return WalkResult::skip(); - - callerMap[calledFunction].insert(callOp); - if (calledBy[calledFunction].insert(funcOp).second) { - numberCallOpsContainedInFuncOp[funcOp]++; + for (mlir::Region ®ion : moduleOp->getRegions()) { + for (mlir::Block &block : region.getBlocks()) { + for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) { + // Collect function calls and populate the caller map. + numberCallOpsContainedInFuncOp[funcOp] = 0; + WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult { + func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables); + assert(calledFunction && "could not retrieved called func::FuncOp"); + // If the called function does not have any tensors in its signature, + // then it is not necessary to bufferize the callee before the caller. + if (!hasTensorSignature(calledFunction)) + return WalkResult::skip(); + + callerMap[calledFunction].insert(callOp); + if (calledBy[calledFunction].insert(funcOp).second) { + numberCallOpsContainedInFuncOp[funcOp]++; + } + return WalkResult::advance(); + }); + if (res.wasInterrupted()) + return failure(); } - return WalkResult::advance(); - }); - if (res.wasInterrupted()) - return failure(); + } } // Iteratively remove function operations that do not call any of the @@ -447,7 +452,7 @@ static void foldMemRefCasts(func::FuncOp funcOp) { } LogicalResult -mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, +mlir::bufferization::analyzeModuleOp(Operation *moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics) { assert(state.getOptions().bufferizeFunctionBoundaries && @@ -512,19 +517,23 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, } void mlir::bufferization::removeBufferizationAttributesInModule( - ModuleOp moduleOp) { - for (auto op : moduleOp.getOps<func::FuncOp>()) { - for (BlockArgument bbArg : op.getArguments()) - removeBufferizationAttributes(bbArg); + Operation *moduleOp) { + for (mlir::Region ®ion : moduleOp->getRegions()) { + for (mlir::Block &block : region.getBlocks()) { + for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) { + for (BlockArgument bbArg : funcOp.getArguments()) + removeBufferizationAttributes(bbArg); + } + } } } LogicalResult mlir::bufferization::bufferizeModuleOp( - ModuleOp moduleOp, const OneShotBufferizationOptions &options, + Operation *moduleOp, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics) { assert(options.bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); - IRRewriter rewriter(moduleOp.getContext()); + IRRewriter rewriter(moduleOp->getContext()); // A list of non-circular functions in the order in which they are analyzed // and bufferized. @@ -571,12 +580,17 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( } // Bufferize all other ops. - for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) { - // Functions were already bufferized. - if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>()) - continue; - if (failed(bufferizeOp(&op, options, state, statistics))) - return failure(); + for (mlir::Region ®ion : moduleOp->getRegions()) { + for (mlir::Block &block : region.getBlocks()) { + for (mlir::Operation &op : + llvm::make_early_inc_range(block.getOperations())) { + // Functions were already bufferized. + if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>()) + continue; + if (failed(bufferizeOp(&op, options, state, statistics))) + return failure(); + } + } } // Post-pass cleanup of function argument attributes. @@ -586,7 +600,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( } LogicalResult mlir::bufferization::runOneShotModuleBufferize( - ModuleOp moduleOp, const OneShotBufferizationOptions &options, + Operation *moduleOp, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics) { assert(options.bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp index 605a487..b8ddee6 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp @@ -18,11 +18,9 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/Value.h" #include "mlir/Interfaces/SideEffectInterfaces.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #define DEBUG_TYPE "optimize-allocation-liveness" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") namespace mlir { namespace bufferization { @@ -65,8 +63,8 @@ Operation *findUserWithFreeSideEffect(Value value) { for (const auto &effect : effects) { if (isa<MemoryEffects::Free>(effect.getEffect())) { if (freeOpUser) { - LDBG("Multiple users with free effect found: " << *freeOpUser - << " and " << *user); + LDBG() << "Multiple users with free effect found: " << *freeOpUser + << " and " << *user; return nullptr; } freeOpUser = user; @@ -121,7 +119,7 @@ public: return WalkResult::advance(); auto allocOp = memEffectOp; - LDBG("Checking alloc op: " << allocOp); + LDBG() << "Checking alloc op: " << allocOp; SmallVector<OpResult> allocationResults = collectAllocations(allocOp); // Multiple allocations from a single op are not considered here yet. @@ -129,7 +127,7 @@ public: return WalkResult::advance(); OpResult allocResult = allocationResults[0]; - LDBG("On allocation result: " << allocResult); + LDBG() << "On allocation result: " << allocResult; auto *deallocOp = findUserWithFreeSideEffect(allocResult); if (!deallocOp || (deallocOp->getBlock() != allocOp->getBlock())) { @@ -159,12 +157,12 @@ public: if (lastUser == nullptr) { return WalkResult::advance(); } - LDBG("Last user found: " << *lastUser); + LDBG() << "Last user found: " << *lastUser; assert(lastUser->getBlock() == allocOp->getBlock()); assert(lastUser->getBlock() == deallocOp->getBlock()); // Move the dealloc op after the last user. deallocOp->moveAfter(lastUser); - LDBG("Moved dealloc op after: " << *lastUser); + LDBG() << "Moved dealloc op after: " << *lastUser; return WalkResult::advance(); }); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp index 64c178d..725fa24 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp @@ -750,17 +750,16 @@ Value BufferDeallocation::materializeMemrefWithGuaranteedOwnership( // Insert a runtime check and only clone if we still don't have ownership at // runtime. - Value maybeClone = builder - .create<scf::IfOp>( - memref.getLoc(), condition, - [&](OpBuilder &builder, Location loc) { - scf::YieldOp::create(builder, loc, newMemref); - }, - [&](OpBuilder &builder, Location loc) { - Value clone = bufferization::CloneOp::create( - builder, loc, newMemref); - scf::YieldOp::create(builder, loc, clone); - }) + Value maybeClone = scf::IfOp::create( + builder, memref.getLoc(), condition, + [&](OpBuilder &builder, Location loc) { + scf::YieldOp::create(builder, loc, newMemref); + }, + [&](OpBuilder &builder, Location loc) { + Value clone = bufferization::CloneOp::create( + builder, loc, newMemref); + scf::YieldOp::create(builder, loc, clone); + }) .getResult(0); Value trueVal = buildBoolValue(builder, memref.getLoc(), true); state.updateOwnership(maybeClone, trueVal); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp index f999c93..a6159ee 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp @@ -33,7 +33,7 @@ LogicalResult mlir::bufferization::insertTensorCopies( // analysis depending on whether function boundary bufferization is enabled or // not. if (options.bufferizeFunctionBoundaries) { - if (failed(analyzeModuleOp(cast<ModuleOp>(op), analysisState, statistics))) + if (failed(analyzeModuleOp(op, analysisState, statistics))) return failure(); } else { if (failed(analyzeOp(op, analysisState, statistics))) diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp index 612e809..fa05ad8 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp @@ -31,7 +31,7 @@ struct WrapFuncInClassPass Operation *rootOp = getOperation(); RewritePatternSet patterns(&getContext()); - populateFuncPatterns(patterns, namedAttribute); + populateFuncPatterns(patterns); walkAndApplyPatterns(rootOp, std::move(patterns)); } @@ -43,8 +43,8 @@ struct WrapFuncInClassPass class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> { public: - WrapFuncInClass(MLIRContext *context, StringRef attrName) - : OpRewritePattern<emitc::FuncOp>(context), attributeName(attrName) {} + WrapFuncInClass(MLIRContext *context) + : OpRewritePattern<emitc::FuncOp>(context) {} LogicalResult matchAndRewrite(emitc::FuncOp funcOp, PatternRewriter &rewriter) const override { @@ -101,12 +101,8 @@ public: rewriter.replaceOp(funcOp, newClassOp); return success(); } - -private: - StringRef attributeName; }; -void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns, - StringRef namedAttribute) { - patterns.add<WrapFuncInClass>(patterns.getContext(), namedAttribute); +void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns) { + patterns.add<WrapFuncInClass>(patterns.getContext()); } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index d186a48..5a72ef1 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -1395,40 +1395,12 @@ void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value, // RotateOp //===----------------------------------------------------------------------===// -void RotateOp::build(OpBuilder &builder, OperationState &result, Value value, - int32_t offset, int32_t width) { - build(builder, result, value, - arith::ConstantOp::create(builder, result.location, - builder.getI32IntegerAttr(offset)), - arith::ConstantOp::create(builder, result.location, - builder.getI32IntegerAttr(width))); -} - LogicalResult RotateOp::verify() { - auto offsetConstOp = getOffset().getDefiningOp<arith::ConstantOp>(); - if (!offsetConstOp) - return emitOpError() << "offset is not a constant value"; - - auto offsetIntAttr = - llvm::dyn_cast<mlir::IntegerAttr>(offsetConstOp.getValue()); - - auto widthConstOp = getWidth().getDefiningOp<arith::ConstantOp>(); - if (!widthConstOp) - return emitOpError() << "width is not a constant value"; - - auto widthIntAttr = - llvm::dyn_cast<mlir::IntegerAttr>(widthConstOp.getValue()); - - llvm::APInt offsetValue = offsetIntAttr.getValue(); - llvm::APInt widthValue = widthIntAttr.getValue(); - - if (!widthValue.isPowerOf2()) - return emitOpError() << "width must be a power of two"; + uint32_t offset = getOffset(); + uint32_t width = getWidth(); - if (offsetValue.sge(widthValue) || offsetValue.slt(0)) { - int64_t widthValueInt = widthValue.getSExtValue(); - return emitOpError() << "offset must be in the range [0, " << widthValueInt - << ")"; + if (offset >= width) { + return emitOpError() << "offset must be in the range [0, " << width << ")"; } return success(); diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp index 1d8279c..21cb2f6 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -39,7 +39,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/InterleavedRange.h" #include "llvm/Support/LogicalResult.h" @@ -51,11 +51,6 @@ using namespace mlir::transform; using namespace mlir::transform::gpu; #define DEBUG_TYPE "gpu-transforms" -#define DEBUG_TYPE_ALIAS "gpu-transforms-alias" - -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") -#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ") //===----------------------------------------------------------------------===// // Apply...ConversionPatternsOp @@ -471,7 +466,7 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl( RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp, scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes, ForallRewriteResult &result, const GpuIdBuilder &gpuIdBuilder) { - LDBG("--start rewriteOneForallCommonImpl"); + LDBG() << "--start rewriteOneForallCommonImpl"; // Step 1. Complete the mapping to a full mapping (with 1s) if necessary. auto numParallelIterations = @@ -506,14 +501,14 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl( // Otherwise, we have a new insertion without a size -> use size 1. tmpMappingSizes.push_back(1); } - LDBG("----tmpMappingSizes extracted from scf.forall op: " - << llvm::interleaved(tmpMappingSizes)); + LDBG() << "----tmpMappingSizes extracted from scf.forall op: " + << llvm::interleaved(tmpMappingSizes); // Step 2. sort the values by the corresponding DeviceMappingAttrInterface. SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey( forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator); - LDBG("----forallMappingSizes: " << llvm::interleaved(forallMappingSizes)); - LDBG("----forallMappingAttrs: " << llvm::interleaved(forallMappingAttrs)); + LDBG() << "----forallMappingSizes: " << llvm::interleaved(forallMappingSizes); + LDBG() << "----forallMappingAttrs: " << llvm::interleaved(forallMappingAttrs); // Step 3. Generate the mappingIdOps using the provided generator. Location loc = forallOp.getLoc(); @@ -522,24 +517,24 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl( SmallVector<int64_t> originalBasis(availableMappingSizes); bool originalBasisWasProvided = !originalBasis.empty(); if (!originalBasisWasProvided) { - LDBG("----originalBasis was not provided, deriving it and there will be no " - "predication"); + LDBG() << "----originalBasis was not provided, deriving it and there will " + "be no " + "predication"; originalBasis = forallMappingSizes; while (originalBasis.size() < 3) originalBasis.push_back(1); } else { - LDBG("----originalBasis was provided, using it, there will be predication"); + LDBG() << "----originalBasis was provided, using it, there will be " + "predication"; } - LLVM_DEBUG( - llvm::interleaveComma(originalBasis, DBGS() << "------originalBasis: "); - llvm::dbgs() << "\n"); + LDBG() << "------originalBasis: " << llvm::interleaved(originalBasis); IdBuilderResult builderResult = gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis); if (!builderResult.errorMsg.empty()) return definiteFailureHelper(transformOp, forallOp, builderResult.errorMsg); - LLVM_DEBUG(DBGS() << builderResult); + LDBG() << builderResult; // Step 4. Map the induction variables to the mappingIdOps, this may involve // a permutation. @@ -550,7 +545,7 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl( forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) { auto mappingAttr = cast<DeviceMappingAttrInterface>(dim); Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()]; - LDBG("----map: " << iv << " to " << peIdOp); + LDBG() << "----map: " << iv << " to " << peIdOp; bvm.map(iv, peIdOp); } @@ -596,9 +591,9 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl( // Step 8. Erase old op. rewriter.eraseOp(forallOp); - LDBG("----result forallMappingSizes: " - << llvm::interleaved(forallMappingSizes)); - LDBG("----result mappingIdOps: " << llvm::interleaved(mappingIdOps)); + LDBG() << "----result forallMappingSizes: " + << llvm::interleaved(forallMappingSizes); + LDBG() << "----result mappingIdOps: " << llvm::interleaved(mappingIdOps); result = ForallRewriteResult{forallMappingSizes, mappingIdOps}; return DiagnosedSilenceableFailure::success(); @@ -612,7 +607,7 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl( RewriterBase &rewriter, TransformOpInterface transformOp, scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims, const GpuIdBuilder &gpuIdBuilder) { - LDBG("Start mapForallToBlocksImpl"); + LDBG() << "Start mapForallToBlocksImpl"; { // GPU-specific verifications. There is no better place to anchor @@ -893,7 +888,7 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl( RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp, Operation *target, ArrayRef<int64_t> blockDims, int64_t warpSize, bool syncAfterDistribute) { - LDBG("Start mapNestedForallToThreadsImpl"); + LDBG() << "Start mapNestedForallToThreadsImpl"; if (blockDims.size() != 3) { return definiteFailureHelper(transformOp, target, "requires size-3 thread mapping"); diff --git a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp index 2fba09b..05bd917 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp @@ -27,7 +27,8 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" +#include "llvm/Support/InterleavedRange.h" using namespace mlir; using namespace mlir::gpu; @@ -36,10 +37,6 @@ using namespace mlir::transform::gpu; #define DEBUG_TYPE "gpu-transforms" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") -#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ") - /// Build predicates to filter execution by only the activeIds. Along each /// dimension, 3 cases appear: /// 1. activeMappingSize > availableMappingSize: this is an unsupported case @@ -54,15 +51,9 @@ buildPredicates(RewriterBase &rewriter, Location loc, ArrayRef<Value> activeIds, ArrayRef<int64_t> activeMappingSizes, ArrayRef<int64_t> availableMappingSizes, std::string &errorMsg) { - // clang-format off - LLVM_DEBUG( - llvm::interleaveComma( - activeMappingSizes, DBGS() << "----activeMappingSizes: "); - DBGS() << "\n"; - llvm::interleaveComma( - availableMappingSizes, DBGS() << "----availableMappingSizes: "); - DBGS() << "\n";); - // clang-format on + LDBG() << "----activeMappingSizes: " << llvm::interleaved(activeMappingSizes); + LDBG() << "----availableMappingSizes: " + << llvm::interleaved(availableMappingSizes); SmallVector<Value> predicateOps; for (auto [activeId, activeMappingSize, availableMappingSize] : @@ -88,10 +79,8 @@ buildPredicates(RewriterBase &rewriter, Location loc, ArrayRef<Value> activeIds, template <typename ThreadOrBlockIdOp> static Value buildLinearId(RewriterBase &rewriter, Location loc, ArrayRef<OpFoldResult> originalBasisOfr) { - LLVM_DEBUG(llvm::interleaveComma( - originalBasisOfr, - DBGS() << "----buildLinearId with originalBasisOfr: "); - llvm::dbgs() << "\n"); + LDBG() << "----buildLinearId with originalBasisOfr: " + << llvm::interleaved(originalBasisOfr); assert(originalBasisOfr.size() == 3 && "expected 3 sizes"); IndexType indexType = rewriter.getIndexType(); AffineExpr tx, ty, tz, bdx, bdy; @@ -157,7 +146,7 @@ commonLinearIdBuilderFn(int64_t multiplicity = 1, mask.createLogicalLinearMappingId(rewriter, scaledLinearIdI64); scaledLinearId = arith::IndexCastUIOp::create( rewriter, loc, rewriter.getIndexType(), logicalLinearIdI64); - LDBG("------adjusting linearId with mask: " << scaledLinearId); + LDBG() << "------adjusting linearId with mask: " << scaledLinearId; } // 3. Compute remapped indices. @@ -179,7 +168,7 @@ commonLinearIdBuilderFn(int64_t multiplicity = 1, if (mask) { Value isActiveIdPredicate = mask.createIsActiveIdPredicate(rewriter, scaledLinearIdI64); - LDBG("------adjusting predicate with mask: " << isActiveIdPredicate); + LDBG() << "------adjusting predicate with mask: " << isActiveIdPredicate; predicateOps.push_back(isActiveIdPredicate); } else { // 4.b. Otherwise, handle predicates using physicalLinearId. diff --git a/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp index d88f4d5..8e05436 100644 --- a/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp @@ -60,14 +60,12 @@ struct GpuShuffleRewriter : public OpRewritePattern<gpu::ShuffleOp> { // Shuffle the values. ValueRange loRes = - rewriter - .create<gpu::ShuffleOp>(op.getLoc(), lo, op.getOffset(), - op.getWidth(), op.getMode()) + gpu::ShuffleOp::create(rewriter, op.getLoc(), lo, op.getOffset(), + op.getWidth(), op.getMode()) .getResults(); ValueRange hiRes = - rewriter - .create<gpu::ShuffleOp>(op.getLoc(), hi, op.getOffset(), - op.getWidth(), op.getMode()) + gpu::ShuffleOp::create(rewriter, op.getLoc(), hi, op.getOffset(), + op.getWidth(), op.getMode()) .getResults(); // Convert lo back to i64. diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp index b9e2dd5..b45fdf3 100644 --- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp @@ -197,10 +197,9 @@ Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc, // Parallel reduction using butterfly shuffles. for (unsigned i = ci.clusterStride; i < ci.clusterStride * ci.clusterSize; i <<= 1) { - Value shuffled = builder - .create<gpu::ShuffleOp>(loc, packFn(laneVal), i, - /*width=*/ci.subgroupSize, - /*mode=*/gpu::ShuffleMode::XOR) + Value shuffled = gpu::ShuffleOp::create(builder, loc, packFn(laneVal), i, + /*width=*/ci.subgroupSize, + /*mode=*/gpu::ShuffleMode::XOR) .getShuffleResult(); laneVal = vector::makeArithReduction(builder, loc, gpu::convertReductionKind(mode), diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt index d987b72..ff55f17 100644 --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -21,10 +21,7 @@ add_mlir_dialect_library(MLIRLLVMDialect intrinsics_gen LINK_COMPONENTS - AsmParser BinaryFormat - BitReader - BitWriter Core LINK_LIBS PUBLIC diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index d42ce96..422039f 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -26,8 +26,7 @@ #include "llvm/ADT/APFloat.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/Type.h" +#include "llvm/IR/DataLayout.h" #include "llvm/Support/Error.h" #include <numeric> @@ -4064,28 +4063,9 @@ void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state, } void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state, - Value cond, - ArrayRef<llvm::OperandBundleDefT<Value>> opBundles) { - SmallVector<ValueRange> opBundleOperands; - SmallVector<Attribute> opBundleTags; - opBundleOperands.reserve(opBundles.size()); - opBundleTags.reserve(opBundles.size()); - - for (const llvm::OperandBundleDefT<Value> &bundle : opBundles) { - opBundleOperands.emplace_back(bundle.inputs()); - opBundleTags.push_back( - StringAttr::get(builder.getContext(), bundle.getTag())); - } - - auto opBundleTagsAttr = ArrayAttr::get(builder.getContext(), opBundleTags); - return build(builder, state, cond, opBundleOperands, opBundleTagsAttr); -} - -void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state, Value cond, llvm::StringRef tag, ValueRange args) { - llvm::OperandBundleDefT<Value> opBundle( - tag.str(), SmallVector<Value>(args.begin(), args.end())); - return build(builder, state, cond, opBundle); + return build(builder, state, cond, ArrayRef<ValueRange>(args), + builder.getStrArrayAttr(tag)); } void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state, diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 6e29b12..cffe310 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -30,15 +30,8 @@ #include "mlir/IR/Types.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/AsmParser/Parser.h" -#include "llvm/IR/Attributes.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/IntrinsicsNVPTX.h" -#include "llvm/IR/Type.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" #include <cassert> #include <optional> diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp index 1a9ccf5..17371ec 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp @@ -24,7 +24,6 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/IR/Type.h" using namespace mlir; using namespace ROCDL; diff --git a/mlir/lib/Dialect/LLVMIR/IR/VCIXDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/VCIXDialect.cpp index bd9d3528..1d4a0af 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/VCIXDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/VCIXDialect.cpp @@ -20,11 +20,6 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/AsmParser/Parser.h" -#include "llvm/IR/Attributes.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/Type.h" -#include "llvm/Support/SourceMgr.h" using namespace mlir; using namespace vcix; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index f49d9a1..73ae029 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -476,10 +476,10 @@ inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps, SmallVector<unsigned, 2>(ac.begin(), ac.end()), SmallVector<unsigned, 2>(bc.begin(), bc.end()), SmallVector<unsigned, 2>(ra.begin(), ra.end())}; - llvm::sort(dimensions.batch.begin(), dimensions.batch.end()); - llvm::sort(dimensions.m.begin(), dimensions.m.end()); - llvm::sort(dimensions.n.begin(), dimensions.n.end()); - llvm::sort(dimensions.k.begin(), dimensions.k.end()); + llvm::sort(dimensions.batch); + llvm::sort(dimensions.m); + llvm::sort(dimensions.n); + llvm::sort(dimensions.k); return dimensions; } @@ -797,12 +797,12 @@ inferConvolutionDimsImpl(LinalgOp linalgOp, SmallVector<unsigned, 2>(depth.begin(), depth.end()), /*strides=*/SmallVector<int64_t, 2>{}, /*dilations=*/SmallVector<int64_t, 2>{}}; - llvm::sort(dimensions.batch.begin(), dimensions.batch.end()); - llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end()); - llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end()); - llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end()); - llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end()); - llvm::sort(dimensions.depth.begin(), dimensions.depth.end()); + llvm::sort(dimensions.batch); + llvm::sort(dimensions.outputImage); + llvm::sort(dimensions.outputChannel); + llvm::sort(dimensions.filterLoop); + llvm::sort(dimensions.inputChannel); + llvm::sort(dimensions.depth); // Use the op carried strides/dilations attribute if present. auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides"); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 4fee81a..34c63d3 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -32,6 +32,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" @@ -791,9 +792,8 @@ struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> { tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(), padOp.getResultType().getElementType()); Value replacement = - rewriter - .create<FillOp>(fillOp.getLoc(), ValueRange{padValue}, - ValueRange{emptyTensor}) + FillOp::create(rewriter, fillOp.getLoc(), ValueRange{padValue}, + ValueRange{emptyTensor}) .getResult(0); if (replacement.getType() != padOp.getResultType()) { replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(), @@ -2154,9 +2154,8 @@ struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> { // Create broadcast(transpose(input)). Value transposeResult = - rewriter - .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit, - resultPerms) + TransposeOp::create(rewriter, loc, broadcastOp.getInput(), + transposeInit, resultPerms) ->getResult(0); rewriter.replaceOpWithNewOp<BroadcastOp>( transposeOp, transposeResult, transposeOp.getInit(), resultDimensions); @@ -2294,9 +2293,39 @@ Speculation::Speculatability BroadcastOp::getSpeculatability() { return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation())); } +/// Fold back-to-back broadcasts together. +struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> { + using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, + PatternRewriter &rewriter) const override { + auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>(); + if (!defBroadcastOp) + return failure(); + ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions(); + ArrayRef<int64_t> dimensions = broadcastOp.getDimensions(); + SmallVector<int64_t> foldedDims(dimensions); + Value init = broadcastOp.getInit(); + int64_t initRank = cast<ShapedType>(init.getType()).getRank(); + // Mapping from input dims to init dims. + SmallVector<int64_t> dimMap; + for (auto dim : llvm::seq<int64_t>(0, initRank)) { + if (!llvm::is_contained(dimensions, dim)) + dimMap.push_back(dim); + } + for (auto dim : defDimensions) + foldedDims.push_back(dimMap[dim]); + + llvm::sort(foldedDims); + rewriter.replaceOpWithNewOp<BroadcastOp>( + broadcastOp, defBroadcastOp.getInput(), init, foldedDims); + return success(); + } +}; + void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<EraseIdentityLinalgOp<BroadcastOp>>(context); + results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context); } //===----------------------------------------------------------------------===// @@ -4624,22 +4653,6 @@ static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos, }); } -/// Returns true if the dimension of `sourceShape` is smaller than the dimension -/// of the `limitShape`. -static bool areAllInBound(ArrayRef<int64_t> sourceShape, - ArrayRef<int64_t> limitShape) { - assert( - sourceShape.size() == limitShape.size() && - "expected source shape rank, and limit of the shape to have same rank"); - return llvm::all_of( - llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) { - int64_t sourceExtent = std::get<0>(it); - int64_t limit = std::get<1>(it); - return ShapedType::isDynamic(sourceExtent) || - ShapedType::isDynamic(limit) || sourceExtent <= limit; - }); -} - template <typename OpTy> static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value, @@ -4698,11 +4711,6 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { // represents full tiles. RankedTensorType expectedPackedType = PackOp::inferPackedType( unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm); - if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) { - return op->emitError("the shape of output is not large enough to hold the " - "packed data. Expected at least ") - << expectedPackedType << ", got " << packedType; - } if (!llvm::all_of( llvm::zip(packedType.getShape().take_back(mixedTiles.size()), mixedTiles), @@ -4719,6 +4727,12 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { return op->emitError("mismatch in inner tile sizes specified and shaped of " "tiled dimension in the packed type"); } + if (failed(verifyCompatibleShape(expectedPackedType.getShape(), + packedType.getShape()))) { + return op->emitError("expected ") + << expectedPackedType << " for the packed domain value, got " + << packedType; + } return success(); } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp b/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp index ce1b1b9..5c8c2de 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp @@ -12,6 +12,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/InterleavedRange.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" @@ -21,8 +22,6 @@ using namespace mlir; #define DEBUG_TYPE "linalg-transforms" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") static Attribute linearId0(MLIRContext *ctx) { return gpu::GPUThreadMappingAttr::get(ctx, gpu::MappingId::LinearDim0); @@ -43,9 +42,8 @@ transform::gpu::CopyMappingInfo::CopyMappingInfo(MLIRContext *ctx, assert(!copySizes.empty() && copySizes.size() <= 3 && "only 1,2,3-D copies are supported for now"); - LDBG("START CopyMappingInfo, favorPredication: " << favorPredication); - LLVM_DEBUG(DBGS() << "--copy shape: " << llvm::interleaved(copySizes) - << "\n"); + LDBG() << "START CopyMappingInfo, favorPredication: " << favorPredication; + LDBG() << "--copy shape: " << llvm::interleaved(copySizes); // Greedily find the largest vector size that can be used to copy the most // minor dimension: we are in the business of filling kMaxVectorLoadBitWidth @@ -53,20 +51,19 @@ transform::gpu::CopyMappingInfo::CopyMappingInfo(MLIRContext *ctx, int64_t desiredVectorSize = CopyMappingInfo::maxContiguousElementsToTransfer( desiredBitAlignment, copySizes.back(), elementalBitwidth); - LDBG("--greedily determined vectorSize: " - << desiredVectorSize << " elements of " << elementalBitwidth - << "b each -> " << (desiredVectorSize * elementalBitwidth) - << "b total out of a max of " << kMaxVectorLoadBitWidth << "b"); + LDBG() << "--greedily determined vectorSize: " << desiredVectorSize + << " elements of " << elementalBitwidth << "b each -> " + << (desiredVectorSize * elementalBitwidth) + << "b total out of a max of " << kMaxVectorLoadBitWidth << "b"; status = inferNumThreads(totalNumThreads, copySizes, desiredVectorSize, favorPredication); if (status == Status::Invalid) return; - LLVM_DEBUG(DBGS() << "--copy: " << llvm::interleaved(copySizes) << "\n" - << "--numThreads: " << llvm::interleaved(this->numThreads) - << "\n" - << "--vectorSize: " << this->vectorSize << "\n"); + LDBG() << "--copy: " << llvm::interleaved(copySizes) << "\n" + << "--numThreads: " << llvm::interleaved(this->numThreads) << "\n" + << "--vectorSize: " << this->vectorSize; assert(this->numThreads.size() == copySizes.size() && "compute copy mapping expected same number of threads and copy sizes"); @@ -84,7 +81,7 @@ transform::gpu::CopyMappingInfo::CopyMappingInfo(MLIRContext *ctx, this->threadMapping = llvm::to_vector(ArrayRef(allThreadMappings) .take_back(this->smallestBoundingTileSizes.size())); - LLVM_DEBUG(this->print(DBGS()); llvm::dbgs() << "\n"); + LDBG() << *this; } int64_t transform::gpu::CopyMappingInfo::maxContiguousElementsToTransfer( @@ -140,7 +137,7 @@ static SmallVector<int64_t> maximizeNumThreads(ArrayRef<int64_t> sizes, "currentIndex out of bounds"); std::string indent(2 * currentIndex, '-'); if (static_cast<size_t>(currentIndex) == sizes.size() - 1) { - LDBG(indent << "mandated globalBest: " << sizes[currentIndex]); + LDBG() << indent << "mandated globalBest: " << sizes[currentIndex]; return SmallVector<int64_t>{sizes[currentIndex]}; } @@ -149,16 +146,16 @@ static SmallVector<int64_t> maximizeNumThreads(ArrayRef<int64_t> sizes, SmallVector<int64_t> factors = getFactors(s); SmallVector<int64_t> localThreadsPerDim; localThreadsPerDim.reserve(sizes.size()); - LDBG(indent << "maximizeNumThreads in " << s - << " with limit: " << maxNumThreads); + LDBG() << indent << "maximizeNumThreads in " << s + << " with limit: " << maxNumThreads; for (auto factor : factors) { auto nestedThreadsPerDim = maximizeNumThreads(sizes, currentIndex + 1, maxNumThreads / factor); int64_t localBest = factor * product(nestedThreadsPerDim); if (localBest > best && localBest <= maxNumThreads) { - LDBG(indent << "new localBest: " << localBest); - LDBG(indent << "nestedThreadsPerDim: " - << llvm::interleaved(nestedThreadsPerDim)); + LDBG() << indent << "new localBest: " << localBest; + LDBG() << indent << "nestedThreadsPerDim: " + << llvm::interleaved(nestedThreadsPerDim); localThreadsPerDim.clear(); localThreadsPerDim.push_back(factor); llvm::append_range(localThreadsPerDim, nestedThreadsPerDim); @@ -166,8 +163,8 @@ static SmallVector<int64_t> maximizeNumThreads(ArrayRef<int64_t> sizes, } } - LDBG(indent << "found globalBest: " << best); - LDBG(indent << "numThreads: " << llvm::interleaved(localThreadsPerDim)); + LDBG() << indent << "found globalBest: " << best; + LDBG() << indent << "numThreads: " << llvm::interleaved(localThreadsPerDim); return localThreadsPerDim; } @@ -192,8 +189,8 @@ transform::gpu::CopyMappingInfo::inferNumThreads(int64_t totalNumThreads, if (status == Status::Success || status == Status::Invalid) return status; - LDBG("requires predication, try reducing vector size to " - << (localVectorSize / 2)); + LDBG() << "requires predication, try reducing vector size to " + << (localVectorSize / 2); } } @@ -210,8 +207,8 @@ transform::gpu::CopyMappingInfo::inferNumThreadsImpl( assert(sizes.back() % desiredVectorSize == 0 && "most-minor size not divisible by actualVectorSize"); - LDBG("inferNumThreadsImpl with totalNumThreads: " - << totalNumThreads << " and vectorSize: " << desiredVectorSize); + LDBG() << "inferNumThreadsImpl with totalNumThreads: " << totalNumThreads + << " and vectorSize: " << desiredVectorSize; // Scale the most minor size to account for the chosen vector size and // maximize the number of threads without exceeding the total number of @@ -219,22 +216,22 @@ transform::gpu::CopyMappingInfo::inferNumThreadsImpl( SmallVector<int64_t> scaledSizes(sizes); scaledSizes.back() /= desiredVectorSize; if (scaledSizes.back() > totalNumThreads) { - LDBG("--Too few threads given the required vector size -> FAIL"); + LDBG() << "--Too few threads given the required vector size -> FAIL"; return Status::Invalid; } SmallVector<int64_t> inferredNumThreads = maximizeNumThreads(scaledSizes, 0, totalNumThreads); - LDBG("inferred numThreads: " << llvm::interleaved(inferredNumThreads)); - LDBG("computed actualVectorSize: " << desiredVectorSize); + LDBG() << "inferred numThreads: " << llvm::interleaved(inferredNumThreads); + LDBG() << "computed actualVectorSize: " << desiredVectorSize; // Corner case: we cannot use more threads than available. If the dimension of // the copy is so bad it is because higher-level tiling did not do its job, we // do not try to recover from it here. int64_t totalNumThreadsUsed = product(inferredNumThreads); - LDBG("--totalNumThreadsUsed: " << totalNumThreadsUsed); + LDBG() << "--totalNumThreadsUsed: " << totalNumThreadsUsed; if (totalNumThreadsUsed == 0 || totalNumThreadsUsed > totalNumThreads) { - LDBG("--Too few threads given the required vector size -> FAIL"); + LDBG() << "--Too few threads given the required vector size -> FAIL"; return Status::Invalid; } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp index 2fe72a3..d4a3e5f 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp @@ -15,14 +15,13 @@ #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/InterleavedRange.h" using namespace mlir; #define DEBUG_TYPE "linalg-transforms" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") //===----------------------------------------------------------------------===// // StructuredMatchOp @@ -39,7 +38,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation( return emitSilenceableError() << "expected a Linalg op"; } // If errors are suppressed, succeed and set all results to empty lists. - LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op"); + LDBG() << "optional nested matcher expected a Linalg op"; results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation())); return DiagnosedSilenceableFailure::success(); } @@ -75,8 +74,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation( // When they are defined in this block, we additionally check if we have // already applied the operation that defines them. If not, the // corresponding results will be set to empty lists. - LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage() - << "\n"); + LDBG() << "optional nested matcher failed: " << diag.getMessage(); (void)diag.silence(); SmallVector<OpOperand *> undefinedOperands; for (OpOperand &terminatorOperand : diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 9f523e9d..bdfc8d0 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -40,7 +40,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/LogicalResult.h" #include <type_traits> @@ -49,9 +49,6 @@ using namespace mlir::linalg; using namespace mlir::transform; #define DEBUG_TYPE "linalg-transforms" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") -#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n") /// Attempts to apply the pattern specified as template argument to the given /// operation. The pattern is expected to have a `returningMatchAndRewrite` @@ -773,7 +770,7 @@ static bool sameOrEquivalentIterArg(Value src, Value dst) { static std::tuple<SmallVector<Operation *>, Operation *> tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) { - LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n"); + LDBG() << "Try to fuse a direct extract use"; auto tileableProducer = dyn_cast<TilingInterface>(producerOp); if (!tileableProducer) { diag.attachNote(producerOp->getLoc()) @@ -838,7 +835,7 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, // Tile the producer. int64_t resultNumber = cast<OpResult>(sliceOpToTile.getSource()).getResultNumber(); - LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); + LDBG() << "resultNumber: " << resultNumber; SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets(); SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes(); @@ -855,7 +852,7 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, #ifndef NDEBUG for (auto *tiledOp : tileAndFuseResult->tiledOps) { - LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n"); + LDBG() << "tiledProducer: " << *tiledOp; } #endif @@ -894,7 +891,7 @@ static SmallVector<Operation *> tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) { - LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n"); + LDBG() << "Try to fuse an extract use through block argument"; auto tileableProducer = dyn_cast<TilingInterface>(producerOp); if (!tileableProducer) { @@ -947,7 +944,7 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( // Replace the use in the tileableProducer before tiling: clone, replace and // then tile. int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber(); - LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); + LDBG() << "resultNumber: " << resultNumber; // Gather destination tensors. SmallVector<Value> destinationTensors; @@ -996,7 +993,7 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) { - LLVM_DEBUG(DBGS() << "Try to fuse an use by cloning\n"); + LDBG() << "Try to fuse an use by cloning"; // Gather all uses inside the containing op. SmallVector<OpOperand *> uses; @@ -1030,7 +1027,7 @@ static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) && "Parallel insert slice is not a valid clone destination"); unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber(); - LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); + LDBG() << "resultNumber: " << resultNumber; OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(use->getOwner()); @@ -1113,7 +1110,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter, auto [tiledOps, newContainingOp] = tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp); if (!tiledOps.empty()) { - LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp); + LDBG() << "\nFused a direct extract use\n" << *containingOp; fusedOps.append(tiledOps); if (newContainingOp) { // Update handles associated with the containing op so we don't need to @@ -1139,8 +1136,8 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter, tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( rewriter, diag, producerOp, containingOp); if (!tiledContainingOpOperand.empty()) { - LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n" - << *containingOp); + LDBG() << "\nFused an extract use through block argument\n" + << *containingOp; fusedOps.append(tiledContainingOpOperand); continue; } @@ -1148,7 +1145,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter, Operation *cloned = cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp); if (cloned) { - LLVM_DEBUG(DBGS() << "\nFused an use by cloning\n" << *containingOp); + LDBG() << "\nFused an use by cloning\n" << *containingOp; fusedOps.push_back(cloned); continue; } @@ -4136,9 +4133,8 @@ DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target, Value extracted = tensor::ExtractSliceOp::create( rewriter, target.getLoc(), target.getDest(), target.getMixedOffsets(), target.getMixedSizes(), target.getMixedStrides()); - Value copied = rewriter - .create<linalg::CopyOp>(target.getLoc(), - target.getSource(), extracted) + Value copied = linalg::CopyOp::create(rewriter, target.getLoc(), + target.getSource(), extracted) .getResult(0); // Reset the insertion point. rewriter.setInsertionPoint(target); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index 91a297f..0a9c176 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -1143,10 +1143,9 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp, // Insert an unPackOp right after the packed generic. Value unPackOpRes = - rewriter - .create<linalg::UnPackOp>(genericOp.getLoc(), newResult, - destPack.getSource(), innerDimsPos, - mixedTiles, outerDimsPerm) + linalg::UnPackOp::create(rewriter, genericOp.getLoc(), newResult, + destPack.getSource(), innerDimsPos, mixedTiles, + outerDimsPerm) .getResult(); return std::make_tuple(newGenericOp, unPackOpRes); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 745a40db..7f9ba1b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -267,8 +267,8 @@ expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest, assert(rankReductionStrategy == ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape && "unknown rank reduction strategy"); - return rewriter - .create<tensor::ExpandShapeOp>(loc, origResultType, result, reassociation) + return tensor::ExpandShapeOp::create(rewriter, loc, origResultType, result, + reassociation) .getResult(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 4a66b8b..3bd763e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1572,12 +1572,12 @@ static Value getCollapsedOpOperand(Location loc, LinalgOp op, // Insert a reshape to collapse the dimensions. if (isa<MemRefType>(operand.getType())) { - return builder - .create<memref::CollapseShapeOp>(loc, operand, operandReassociation) + return memref::CollapseShapeOp::create(builder, loc, operand, + operandReassociation) .getResult(); } - return builder - .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation) + return tensor::CollapseShapeOp::create(builder, loc, operand, + operandReassociation) .getResult(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp index a45a4e3..9d7f4e0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/PatternMatch.h" namespace mlir { @@ -81,9 +82,8 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> { ArrayRef<ReassociationIndices> reassociation) const { if (operand.getType() == newOperandType) return operand; - return rewriter - .create<tensor::ExpandShapeOp>(loc, newOperandType, operand, - reassociation) + return tensor::ExpandShapeOp::create(rewriter, loc, newOperandType, operand, + reassociation) .getResult(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp index 2c62cb6..2e62523 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp @@ -55,6 +55,28 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes, return paddingSizes; } +/// Extracts the constant multiplier from an affine expression of the form +/// `d * c` or `c * d`, where `d` is an AffineDimExpr and `c` is an +/// AffineConstantExpr. Returns 1 if the expression is not a simple +/// multiplication of a dimension and a constant. +static int64_t extractConstantMultiplier(AffineExpr expr) { + if (auto binOp = dyn_cast<AffineBinaryOpExpr>(expr)) { + if (binOp.getKind() == AffineExprKind::Mul) { + auto lhsD = dyn_cast<AffineDimExpr>(binOp.getLHS()); + auto rhsC = dyn_cast<AffineConstantExpr>(binOp.getRHS()); + if (lhsD && rhsC) { + return rhsC.getValue(); + } + auto lhsC = dyn_cast<AffineConstantExpr>(binOp.getLHS()); + auto rhsD = dyn_cast<AffineDimExpr>(binOp.getRHS()); + if (lhsC && rhsD) { + return lhsC.getValue(); + } + } + } + return 1; +} + /// Compute the padded shape of the given value `v` of `RankedTensorType` given /// - `indexingSizes` a list of OpFoldResult. /// - an `indexingMap` that encodes how the shape of varies with increases @@ -63,6 +85,13 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes, /// The `indexingMap` + `indexingSizes` encoding suits StructuredOps. /// The implementaiton below iteratively combines increases from contributing /// dimensions using affine.apply operations. +/// The padded shape is computed by evaluating the maximum accessed index per +/// dimension, which may involve multiplying by constant factors derived from +/// the affine indexing expressions. Currently, only a limited set of projected +/// permutation indexing maps are supported, such as +/// - affine_map<(d0, d1, d2) -> (d0, d1)> +/// - affine_map<(d0, d1, d2) -> (d0, d1 + d2)> +/// - affine_map<(d0, d1) -> (d0 * 3 + d1)> /// In the future, more general interfaces can be devised to encode similar /// shape evolutions and map between an op and its operands. SmallVector<OpFoldResult> linalg::computePaddedShape( @@ -114,24 +143,33 @@ SmallVector<OpFoldResult> linalg::computePaddedShape( /*compressDims=*/true); // If we are padding to the next multiple of, compose with ceil(sz) * sz. + OpFoldResult paddingDimOfr; if (options.padToMultipleOf) { AffineExpr d0, s0; bindDims(rewriter.getContext(), d0); bindSymbols(rewriter.getContext(), s0); AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0); AffineMap composedMap = projectedMap.compose(ceilMap); - OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply( + paddingDimOfr = affine::makeComposedFoldedAffineApply( rewriter, loc, composedMap, {indexingSizes[paddingDim], paddingSize}, /*composeAffineMin=*/true); - terms.push_back(paddingDimOfr); } else { // Otherwise just set to paddingSize. - OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply( + paddingDimOfr = affine::makeComposedFoldedAffineApply( rewriter, loc, projectedMap, paddingSize); - terms.push_back(paddingDimOfr); } + // Adjust for the maximum accessed index, which is (paddingSize - 1) * + // multiplier. + AffineExpr d0; + bindDims(rewriter.getContext(), d0); + int64_t multiplier = extractConstantMultiplier(projectedMap.getResult(0)); + AffineMap subtractMap = AffineMap::get(1, 0, d0 - multiplier); + OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply( + rewriter, loc, subtractMap, {paddingDimOfr}); + terms.push_back(maxAccessIdx); + LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n"); } @@ -148,8 +186,9 @@ SmallVector<OpFoldResult> linalg::computePaddedShape( AffineExpr sumExpr = dims.front(); for (unsigned i = 1; i < dims.size(); ++i) sumExpr = sumExpr + dims[i]; - OpFoldResult paddedDimOfr = - affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, terms); + // Add 1 to the maximum accessed index and get the final padded size. + OpFoldResult paddedDimOfr = affine::makeComposedFoldedAffineApply( + rewriter, loc, sumExpr + 1, terms); paddedShape[resultIndex] = paddedDimOfr; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp index b5c5aea..dd84379 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp @@ -333,17 +333,16 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, for (auto it : llvm::zip(paddedSubtensorResults, opToPad.getDpsInitsMutable())) { if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp::LinalgCopy) { - replacements.push_back(rewriter - .create<linalg::CopyOp>(loc, std::get<0>(it), - std::get<1>(it).get()) + replacements.push_back(linalg::CopyOp::create(rewriter, loc, + std::get<0>(it), + std::get<1>(it).get()) .getResult(0)); } else if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp:: BufferizationMaterializeInDestination) { replacements.push_back( - rewriter - .create<bufferization::MaterializeInDestinationOp>( - loc, std::get<0>(it), std::get<1>(it).get()) + bufferization::MaterializeInDestinationOp::create( + rewriter, loc, std::get<0>(it), std::get<1>(it).get()) ->getResult(0)); } else { llvm_unreachable("unsupported copy back op"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index dad3526..57b610b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -932,20 +932,6 @@ struct PackOpTiling continue; } - // If the dimension needs padding, it is not supported because there are - // iterations that only write padding values to the whole tile. The - // consumer fusion is driven by the source, so it is not possible to map - // an empty slice to the tile. - bool needExtraPadding = - ShapedType::isDynamic(destDimSize) || !cstInnerSize || - destDimSize * cstInnerSize.value() != srcDimSize; - // Prioritize the case that the op already says that it does not need - // padding. - if (!packOp.getPaddingValue()) - needExtraPadding = false; - if (needExtraPadding) - return failure(); - // Currently fusing `packOp` as consumer only expects perfect tiling // scenario because even if without padding semantic, the `packOp` may // also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 1f1e617..bb725f2 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -947,9 +947,9 @@ DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp, auto getIdxValue = [&](OpFoldResult ofr) { if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) return val; - return rewriter - .create<arith::ConstantIndexOp>( - padOp.getLoc(), cast<IntegerAttr>(cast<Attribute>(ofr)).getInt()) + return arith::ConstantIndexOp::create( + rewriter, padOp.getLoc(), + cast<IntegerAttr>(cast<Attribute>(ofr)).getInt()) .getResult(); }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp index 99fb8c7..35453e2 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp @@ -70,9 +70,8 @@ FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter, input = tensor::EmptyOp::create(rewriter, loc, newFilterShape, elementTy) .getResult(); } else { - input = rewriter - .create<memref::AllocOp>( - loc, MemRefType::get(newFilterShape, elementTy)) + input = memref::AllocOp::create(rewriter, loc, + MemRefType::get(newFilterShape, elementTy)) .getResult(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 78c6bd1..793eec7 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -38,7 +38,8 @@ #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" +#include "llvm/Support/InterleavedRange.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include <optional> @@ -48,9 +49,6 @@ using namespace mlir::linalg; #define DEBUG_TYPE "linalg-vectorization" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") - /// Try to vectorize `convOp` as a convolution. static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, @@ -403,12 +401,8 @@ LogicalResult VectorizationState::initState(RewriterBase &rewriter, scalableVecDims.append(linalgOp.getNumLoops(), false); } - LDBG("Canonical vector shape: "); - LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs())); - LLVM_DEBUG(llvm::dbgs() << "\n"); - LDBG("Scalable vector dims: "); - LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs())); - LLVM_DEBUG(llvm::dbgs() << "\n"); + LDBG() << "Canonical vector shape: " << llvm::interleaved(canonicalVecShape); + LDBG() << "Scalable vector dims: " << llvm::interleaved(scalableVecDims); if (ShapedType::isDynamicShape(canonicalVecShape)) return failure(); @@ -452,14 +446,14 @@ Value VectorizationState::getOrCreateMaskFor( : AffineMap::getMultiDimIdentityMap( linalgOp.getNumLoops(), rewriter.getContext()); - LDBG("Masking map: " << maskingMap << "\n"); + LDBG() << "Masking map: " << maskingMap; // Return the active mask for the masking map of this operation if it was // already created. auto activeMaskIt = activeMaskCache.find(maskingMap); if (activeMaskIt != activeMaskCache.end()) { Value mask = activeMaskIt->second; - LDBG("Reusing mask: " << mask << "\n"); + LDBG() << "Reusing mask: " << mask; return mask; } @@ -474,12 +468,10 @@ Value VectorizationState::getOrCreateMaskFor( auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap); auto maskShape = maskType.getShape(); - LDBG("Mask shape: "); - LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs())); - LLVM_DEBUG(llvm::dbgs() << "\n"); + LDBG() << "Mask shape: " << llvm::interleaved(maskShape); if (permutedStaticSizes == maskShape) { - LDBG("Masking is not needed for masking map: " << maskingMap << "\n"); + LDBG() << "Masking is not needed for masking map: " << maskingMap; activeMaskCache[maskingMap] = Value(); return Value(); } @@ -494,8 +486,9 @@ Value VectorizationState::getOrCreateMaskFor( ? true : std::get<0>(it) == std::get<1>(it); })) { - LDBG("Dynamic + static dimensions match vector sizes, masking is not " - "required.\n"); + LDBG() + << "Dynamic + static dimensions match vector sizes, masking is not " + "required."; activeMaskCache[maskingMap] = Value(); return Value(); } @@ -510,7 +503,7 @@ Value VectorizationState::getOrCreateMaskFor( // Create the mask based on the dimension values. Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(), maskType, upperBounds); - LDBG("Creating new mask: " << mask << "\n"); + LDBG() << "Creating new mask: " << mask; activeMaskCache[maskingMap] = mask; return mask; } @@ -519,7 +512,7 @@ Operation * VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp, std::optional<AffineMap> maybeIndexingMap) { - LDBG("Trying to mask: " << *opToMask << "\n"); + LDBG() << "Trying to mask: " << *opToMask; std::optional<AffineMap> maybeMaskingMap = std::nullopt; if (maybeIndexingMap) @@ -530,7 +523,7 @@ VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask, getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap); if (!mask) { - LDBG("No mask required\n"); + LDBG() << "No mask required"; return opToMask; } @@ -544,7 +537,7 @@ VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask, rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx), maskOpTerminator); - LDBG("Masked operation: " << *maskOp << "\n"); + LDBG() << "Masked operation: " << *maskOp; return maskOp; } @@ -748,7 +741,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value, maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds)); } - LDBG("vectorized op: " << *write << "\n"); + LDBG() << "vectorized op: " << *write; if (!write->getResults().empty()) return write->getResult(0); return Value(); @@ -1090,7 +1083,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, } if (!leadingIdxsLoopInvariant) { - LDBG("Found gather load: " << extractOp); + LDBG() << "Found gather load: " << extractOp; return VectorMemoryAccessKind::Gather; } @@ -1104,7 +1097,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, // If the trailing index is loop invariant then this is a scalar load. if (leadingIdxsLoopInvariant && isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) { - LDBG("Found scalar broadcast load: " << extractOp); + LDBG() << "Found scalar broadcast load: " << extractOp; return VectorMemoryAccessKind::ScalarBroadcast; } @@ -1122,12 +1115,12 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, isContiguousLoad &= (foundIndexOp && isRowVector); if (isContiguousLoad) { - LDBG("Found contigous load: " << extractOp); + LDBG() << "Found contigous load: " << extractOp; return VectorMemoryAccessKind::Contiguous; } // 4. Fallback case - gather load. - LDBG("Found gather load: " << extractOp); + LDBG() << "Found gather load: " << extractOp; return VectorMemoryAccessKind::Gather; } @@ -1171,7 +1164,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, maskConstantOp, passThruConstantOp); gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp); - LDBG("Vectorised as gather load: " << extractOp << "\n"); + LDBG() << "Vectorised as gather load: " << extractOp; return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp}; } @@ -1235,7 +1228,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, auto *maskedReadOp = mlir::vector::maskOperation(rewriter, transferReadOp, allTrue); - LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n"); + LDBG() << "Vectorised as scalar broadcast load: " << extractOp; return VectorizationHookResult{VectorizationHookStatus::NewOp, maskedReadOp}; } @@ -1262,7 +1255,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs, /*padding=*/std::nullopt, permutationMap, inBounds); - LDBG("Vectorised as contiguous load: " << extractOp); + LDBG() << "Vectorised as contiguous load: " << extractOp; return VectorizationHookResult{VectorizationHookStatus::NewOp, transferReadOp}; } @@ -1310,7 +1303,7 @@ static VectorizationHookResult vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, Operation *op, const IRMapping &bvm, ArrayRef<CustomVectorizationHook> customVectorizationHooks) { - LDBG("vectorize op " << *op << "\n"); + LDBG() << "vectorize op " << *op; // 1. Try to apply any CustomVectorizationHook. if (!customVectorizationHooks.empty()) { @@ -1425,7 +1418,7 @@ static LogicalResult vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) { - LDBG("Vectorizing operation as linalg generic\n"); + LDBG() << "Vectorizing operation as linalg generic/n"; Block *block = linalgOp.getBlock(); // 2. Values defined above the region can only be broadcast for now. Make them @@ -1490,8 +1483,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, readValue = vector::ExtractOp::create(rewriter, loc, readValue, ArrayRef<int64_t>()); - LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue - << "\n"); + LDBG() << "New vectorized bbarg(" << bbarg.getArgNumber() + << "): " << readValue; bvm.map(bbarg, readValue); bvm.map(opOperand->get(), readValue); } @@ -1523,13 +1516,13 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, VectorizationHookResult result = vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks); if (result.status == VectorizationHookStatus::Failure) { - LDBG("failed to vectorize: " << op << "\n"); + LDBG() << "failed to vectorize: " << op; return failure(); } if (result.status == VectorizationHookStatus::NewOp) { Operation *maybeMaskedOp = state.maskOperation(rewriter, result.newOp, linalgOp); - LDBG("New vector op: " << *maybeMaskedOp << "\n"); + LDBG() << "New vector op: " << *maybeMaskedOp; bvm.map(op.getResults(), maybeMaskedOp->getResults()); } } @@ -1920,14 +1913,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, readVectorSizes.append(sourceShape.begin() + vectorSizes.size(), sourceShape.end()); - ReifiedRankedShapedTypeDims reifiedRetShapes; - LogicalResult status = - cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation()) - .reifyResultShapes(rewriter, reifiedRetShapes); - if (status.failed()) { - LDBG("Unable to reify result shapes of " << unpackOp); - return failure(); - } Location loc = unpackOp->getLoc(); auto padValue = arith::ConstantOp::create( @@ -2010,7 +1995,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, // ops that may not commute (e.g. linear reduction + non-linear instructions). static LogicalResult reductionPreconditions(LinalgOp op) { if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) { - LDBG("reduction precondition failed: no reduction iterator\n"); + LDBG() << "reduction precondition failed: no reduction iterator"; return failure(); } for (OpOperand &opOperand : op.getDpsInitsMutable()) { @@ -2020,7 +2005,7 @@ static LogicalResult reductionPreconditions(LinalgOp op) { Operation *reduceOp = matchLinalgReduction(&opOperand); if (!reduceOp || !getCombinerOpKind(reduceOp)) { - LDBG("reduction precondition failed: reduction detection failed\n"); + LDBG() << "reduction precondition failed: reduction detection failed"; return failure(); } } @@ -2031,13 +2016,13 @@ static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv) { if (flatten1DDepthwiseConv) { - LDBG("Vectorization of flattened convs with dynamic shapes is not " - "supported\n"); + LDBG() << "Vectorization of flattened convs with dynamic shapes is not " + "supported"; return failure(); } if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) { - LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n"); + LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported"; return failure(); } @@ -2047,8 +2032,8 @@ vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape(); auto shapeWithoutCh = lhsShape.drop_back(1); if (ShapedType::isDynamicShape(shapeWithoutCh)) { - LDBG("Dynamically-shaped op vectorization precondition failed: only " - "channel dim can be dynamic\n"); + LDBG() << "Dynamically-shaped op vectorization precondition failed: only " + "channel dim can be dynamic"; return failure(); } @@ -2071,7 +2056,7 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, op.getOperation())) return failure(); - LDBG("Dynamically-shaped op meets vectorization pre-conditions\n"); + LDBG() << "Dynamically-shaped op meets vectorization pre-conditions"; return success(); } @@ -2083,7 +2068,7 @@ vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp, if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) { return !getConstantIntValue(res).has_value(); })) { - LDBG("Inner-tiles must be constant: " << unpackOp << "\n"); + LDBG() << "Inner-tiles must be constant: " << unpackOp; return failure(); } ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape(); @@ -2123,7 +2108,7 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp, !sourceType.hasStaticShape() && inputVectorSizes.empty(); if (!padValue && isOutOfBoundsRead) { - LDBG("Failed to get a pad value for out-of-bounds read access\n"); + LDBG() << "Failed to get a pad value for out-of-bounds read access"; return failure(); } return success(); @@ -2153,7 +2138,7 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state, Operation *reduceOp = matchLinalgReduction(outOperand); auto maybeKind = getCombinerOpKind(reduceOp); if (!maybeKind) { - LDBG("Failed to determine contraction combining kind.\n"); + LDBG() << "Failed to determine contraction combining kind."; return failure(); } @@ -2163,7 +2148,7 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state, AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0]; AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1]; if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) { - LDBG("Contractions with broadcasts are not supported.\n"); + LDBG() << "Contractions with broadcasts are not supported."; return failure(); } @@ -2198,8 +2183,8 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state, } // Create contraction. - Operation *contractOp = rewriter.create<vector::ContractionOp>( - loc, /*lhs=*/vecOperands[0], + Operation *contractOp = vector::ContractionOp::create( + rewriter, loc, /*lhs=*/vecOperands[0], /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2], linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind); contractOp = state.maskOperation(rewriter, contractOp, linalgOp); @@ -2355,7 +2340,7 @@ static LogicalResult vectorizeLinalgOpPrecondition( if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition( linalgOp, flatten1DDepthwiseConv))) { - LDBG("Dynamically-shaped op failed vectorization pre-conditions\n"); + LDBG() << "Dynamically-shaped op failed vectorization pre-conditions"; return failure(); } @@ -2397,11 +2382,11 @@ static LogicalResult vectorizeLinalgOpPrecondition( // all indexing maps are projected permutations. For convs and stencils the // logic will need to evolve. if (!allIndexingsAreProjectedPermutation(linalgOp)) { - LDBG("precondition failed: not projected permutations\n"); + LDBG() << "precondition failed: not projected permutations"; return failure(); } if (failed(reductionPreconditions(linalgOp))) { - LDBG("precondition failed: reduction preconditions\n"); + LDBG() << "precondition failed: reduction preconditions"; return failure(); } return success(); @@ -2413,7 +2398,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp, auto padValue = packOp.getPaddingValue(); Attribute cstAttr; if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) { - LDBG("pad value is not constant: " << packOp << "\n"); + LDBG() << "pad value is not constant: " << packOp; return failure(); } ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape(); @@ -2433,7 +2418,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp, if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) { return !getConstantIntValue(v).has_value(); })) { - LDBG("inner_tiles must be constant: " << packOp << "\n"); + LDBG() << "inner_tiles must be constant: " << packOp; return failure(); } @@ -2445,7 +2430,7 @@ vectorizePadOpPrecondition(tensor::PadOp padOp, ArrayRef<int64_t> inputVectorSizes) { auto padValue = padOp.getConstantPaddingValue(); if (!padValue) { - LDBG("pad value is not constant: " << padOp << "\n"); + LDBG() << "pad value is not constant: " << padOp; return failure(); } @@ -2472,7 +2457,7 @@ vectorizePadOpPrecondition(tensor::PadOp padOp, return (!pad.has_value() || pad.value() != 0) && resultTensorShape[pos] != 1; })) { - LDBG("low pad must all be zero for all non unit dims: " << padOp << "\n"); + LDBG() << "low pad must all be zero for all non unit dims: " << padOp; return failure(); } @@ -2541,13 +2526,14 @@ vectorizeScalableVectorPrecondition(Operation *op, case utils::IteratorType::reduction: { // Check 3. above is met. if (iterators.size() != inputVectorSizes.size()) { - LDBG("Non-trailing reduction dim requested for scalable " - "vectorization\n"); + LDBG() << "Non-trailing reduction dim requested for scalable " + "vectorization"; return failure(); } if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) { - LDBG("Scalable vectorization of the reduction dim in Matmul-like ops " - "is not supported\n"); + LDBG() + << "Scalable vectorization of the reduction dim in Matmul-like ops " + "is not supported"; return failure(); } break; @@ -2555,8 +2541,8 @@ vectorizeScalableVectorPrecondition(Operation *op, case utils::IteratorType::parallel: { // Check 1. and 2. above are met. if (seenNonUnitParallel) { - LDBG("Inner parallel dim not requested for scalable " - "vectorization\n"); + LDBG() << "Inner parallel dim not requested for scalable " + "vectorization"; return failure(); } break; @@ -2572,8 +2558,9 @@ vectorizeScalableVectorPrecondition(Operation *op, // * iterators = [..., parallel, reduction] // * scalable flags = [..., true, true] if (iterators.back() == utils::IteratorType::reduction) { - LDBG("Higher dim than the trailing reduction dim requested for scalable " - "vectorization\n"); + LDBG() << "Higher dim than the trailing reduction dim requested for " + "scalable " + "vectorizatio"; return failure(); } scalableFlags.pop_back(); @@ -2656,18 +2643,15 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize( ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract, bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes, bool createNamedContraction) { - LDBG("Attempting to vectorize:\n" << *op << "\n"); - LDBG("Input vector sizes: "); - LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs())); - LLVM_DEBUG(llvm::dbgs() << "\n"); - LDBG("Input scalable vector dims: "); - LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs())); - LLVM_DEBUG(llvm::dbgs() << "\n"); + LDBG() << "Attempting to vectorize: " << *op; + LDBG() << "Input vector sizes: " << llvm::interleaved(inputVectorSizes); + LDBG() << "Input scalable vector dims: " + << llvm::interleaved(inputScalableVecDims); if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims, vectorizeNDExtract, flatten1DDepthwiseConv))) { - LDBG("Vectorization pre-conditions failed\n"); + LDBG() << "Vectorization pre-conditions failed"; return failure(); } @@ -2677,7 +2661,7 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize( if (failed(state.initState(rewriter, linalgOp, inputVectorSizes, inputScalableVecDims, assumeDynamicDimsMatchVecSizes))) { - LDBG("Vectorization state couldn't be initialized\n"); + LDBG() << "Vectorization state couldn't be initialized"; return failure(); } } @@ -2698,7 +2682,7 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize( return success(); } - LDBG("Unsupported convolution can't be vectorized.\n"); + LDBG() << "Unsupported convolution can't be vectorized."; return failure(); } @@ -2707,8 +2691,9 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize( return vectorizeAsLinalgContraction(rewriter, state, linalgOp, results); - LDBG("Vectorize generic by broadcasting to the canonical vector " - "shape\n"); + LDBG() + << "Vectorize generic by broadcasting to the canonical vector " + "shape"; // Pre-process before proceeding. convertAffineApply(rewriter, linalgOp); @@ -2739,7 +2724,7 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize( .Default([](auto) { return failure(); }); if (failed(vectorizeResult)) { - LDBG("Vectorization failed\n"); + LDBG() << "Vectorization failed"; return failure(); } @@ -3244,8 +3229,8 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, ValueRange values) { if (firstOp->getBlock() != secondOp->getBlock() || !firstOp->isBeforeInBlock(secondOp)) { - LDBG("interleavedUses precondition failed, firstOp: " - << *firstOp << ", second op: " << *secondOp << "\n"); + LDBG() << "interleavedUses precondition failed, firstOp: " << *firstOp + << ", second op: " << *secondOp; return true; } for (auto v : values) { @@ -3257,8 +3242,8 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, if (owner->getBlock() == firstOp->getBlock() && (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner))) continue; - LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp - << ", second op: " << *secondOp << "\n"); + LDBG() << " found interleaved op " << *owner << ", firstOp: " << *firstOp + << ", second op: " << *secondOp; return true; } } @@ -3721,8 +3706,8 @@ struct Conv1DGenerator } } - return rewriter - .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding) + return vector::TransferWriteOp::create(rewriter, loc, res, resShaped, + resPadding) .getOperation(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp index 669fefc..b80b27f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp @@ -398,10 +398,9 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter, retRows = GMatrix.rows; auto matmulType = RankedTensorType::get({retRows, filterW}, elementType); - auto empty = - builder - .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType) - .getResult(); + auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(), + elementType) + .getResult(); auto init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0); @@ -422,10 +421,9 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter, auto matmulType = RankedTensorType::get({retRows, GTMatrix.cols}, elementType); - auto empty = - builder - .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType) - .getResult(); + auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(), + elementType) + .getResult(); auto init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0); @@ -547,10 +545,9 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input, retRows = BTMatrix.rows; auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType); - auto empty = - builder - .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType) - .getResult(); + auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(), + elementType) + .getResult(); auto init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0); @@ -572,10 +569,9 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input, retCols = BMatrix.cols; auto matmulType = RankedTensorType::get({retRows, retCols}, elementType); - auto empty = - builder - .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType) - .getResult(); + auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(), + elementType) + .getResult(); auto init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0); Value B = @@ -661,9 +657,8 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc, {inputShape[0] * inputShape[1], inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]}, outputElementType); - Value empty = rewriter - .create<tensor::EmptyOp>(loc, matmulType.getShape(), - outputElementType) + Value empty = tensor::EmptyOp::create(rewriter, loc, matmulType.getShape(), + outputElementType) .getResult(); Value zero = arith::ConstantOp::create( rewriter, loc, rewriter.getZeroAttr(outputElementType)); @@ -782,9 +777,8 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value, auto matmulType = RankedTensorType::get({retRows, valueW}, elementType); Value init = outInitVal; if (rightTransform || scalarFactor != 1) { - auto empty = builder - .create<tensor::EmptyOp>(loc, matmulType.getShape(), - elementType) + auto empty = tensor::EmptyOp::create(builder, loc, + matmulType.getShape(), elementType) .getResult(); init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0); } @@ -802,9 +796,8 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value, RankedTensorType::get({retRows, AMatrix.cols}, elementType); Value init = outInitVal; if (scalarFactor != 1) { - auto empty = builder - .create<tensor::EmptyOp>(loc, matmulType.getShape(), - elementType) + auto empty = tensor::EmptyOp::create(builder, loc, + matmulType.getShape(), elementType) .getResult(); init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0); } @@ -827,23 +820,21 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value, AffineMap::get(2, 0, context), identityAffineMap, identityAffineMap}; matmulRetValue = - rewriter - .create<linalg::GenericOp>( - loc, matmulType, - ValueRange{scalarFactorValue, matmulRetValue}, - ValueRange{outInitVal}, affineMaps, - llvm::ArrayRef<utils::IteratorType>{ - utils::IteratorType::parallel, - utils::IteratorType::parallel}, - [&](OpBuilder &nestedBuilder, Location nestedLoc, - ValueRange args) { - auto mulf = arith::MulFOp::create(nestedBuilder, nestedLoc, - args[0], args[1]); - auto addf = arith::AddFOp::create( - nestedBuilder, nestedLoc, mulf.getResult(), args[2]); - linalg::YieldOp::create(nestedBuilder, nestedLoc, - addf.getResult()); - }) + linalg::GenericOp::create( + rewriter, loc, matmulType, + ValueRange{scalarFactorValue, matmulRetValue}, + ValueRange{outInitVal}, affineMaps, + llvm::ArrayRef<utils::IteratorType>{ + utils::IteratorType::parallel, utils::IteratorType::parallel}, + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange args) { + auto mulf = arith::MulFOp::create(nestedBuilder, nestedLoc, + args[0], args[1]); + auto addf = arith::AddFOp::create(nestedBuilder, nestedLoc, + mulf.getResult(), args[2]); + linalg::YieldOp::create(nestedBuilder, nestedLoc, + addf.getResult()); + }) .getResult(0); } diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp index 106c3b4..cce80db 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp @@ -80,10 +80,6 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> { for (auto &&[opOffset, sourceOffset, sourceStride, opSize] : llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(), sourceOp.getMixedStrides(), op.getMixedSizes())) { - // We only support static sizes. - if (isa<Value>(opSize)) { - return failure(); - } sizes.push_back(opSize); Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset), sourceOffsetAttr = diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp index 66c1aa6..d5e2b97 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp @@ -56,9 +56,8 @@ FailureOr<Value> memref::buildIndependentOp(OpBuilder &b, // Create a memref::SubViewOp. SmallVector<OpFoldResult> offsets(newSizes.size(), b.getIndexAttr(0)); SmallVector<OpFoldResult> strides(newSizes.size(), b.getIndexAttr(1)); - return b - .create<SubViewOp>(loc, newAllocaOp, offsets, allocaOp.getMixedSizes(), - strides) + return SubViewOp::create(b, loc, newAllocaOp, offsets, + allocaOp.getMixedSizes(), strides) .getResult(); } diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index 1f03e9a..d3a77c0 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -185,9 +185,8 @@ struct CopyOpInterface int64_t dim) -> Value { return type.isDynamicDim(dim) ? DimOp::create(builder, loc, memRef, dim).getResult() - : builder - .create<arith::ConstantIndexOp>(loc, - type.getDimSize(dim)) + : arith::ConstantIndexOp::create(builder, loc, + type.getDimSize(dim)) .getResult(); }; Value sourceDim = getDimSize(copyOp.getSource(), rankedSourceType, i); diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp index 97fe3cb..5af46a4 100644 --- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp +++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp @@ -237,8 +237,8 @@ LogicalResult resolveSourceIndicesExpandShape( llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; }); SmallVector<Value> groupIndices = llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; }); - Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>( - loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds); + Value collapsedIndex = affine::AffineLinearizeIndexOp::create( + rewriter, loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds); sourceIndices.push_back(collapsedIndex); } return success(); @@ -250,8 +250,8 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, ValueRange indices, SmallVectorImpl<Value> &sourceIndices) { // Note: collapse_shape requires a strided memref, we can do this. - auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>( - loc, collapseShapeOp.getSrc()); + auto metadata = memref::ExtractStridedMetadataOp::create( + rewriter, loc, collapseShapeOp.getSrc()); SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes(); for (auto [index, group] : llvm::zip(indices, collapseShapeOp.getReassociationIndices())) { @@ -265,8 +265,8 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, SmallVector<OpFoldResult> basis = llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; }); - auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>( - loc, index, basis, /*hasOuterBound=*/true); + auto delinearize = affine::AffineDelinearizeIndexOp::create( + rewriter, loc, index, basis, /*hasOuterBound=*/true); llvm::append_range(sourceIndices, delinearize.getResults()); } if (collapseShapeOp.getReassociationIndices().empty()) { diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp index f5f0bfa..bc3e8b2 100644 --- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp +++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp @@ -38,9 +38,6 @@ using namespace mlir::NVVM; using namespace mlir::transform; #define DEBUG_TYPE "nvgpu-transforms" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") -#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n") //===----------------------------------------------------------------------===// // Apply...ConversionPatternsOp diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index e73bdd3..9d5dfc1 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -2957,6 +2957,23 @@ bool acc::LoopOp::hasDefaultGangWorkerVector() { getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static); } +acc::LoopParMode +acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) { + if (hasSeq(deviceType)) + return LoopParMode::loop_seq; + if (hasAuto(deviceType)) + return LoopParMode::loop_auto; + if (hasIndependent(deviceType)) + return LoopParMode::loop_independent; + if (hasSeq()) + return LoopParMode::loop_seq; + if (hasAuto()) + return LoopParMode::loop_auto; + assert(hasIndependent() && + "loop must have default auto, seq, or independent"); + return LoopParMode::loop_independent; +} + void acc::LoopOp::addGangOperands( MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes, llvm::ArrayRef<GangArgType> argTypes, mlir::ValueRange values) { diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp index 58cd160..9e37bc5 100644 --- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp @@ -148,16 +148,14 @@ flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input, auto axisValue = arith::ConstantIndexOp::create(builder, loc, axis); auto axisNextValue = arith::ConstantIndexOp::create(builder, loc, axis + 1); auto shapeLeft = - builder - .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType}, - inputShape, axisValue) + shape::SplitAtOp::create(builder, loc, TypeRange{shapeType, shapeType}, + inputShape, axisValue) .getResult(0); auto sizeLeft = shape::NumElementsOp::create(builder, loc, indexType, shapeLeft); auto shapeRight = - builder - .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType}, - inputShape, axisNextValue) + shape::SplitAtOp::create(builder, loc, TypeRange{shapeType, shapeType}, + inputShape, axisNextValue) .getResult(1); auto sizeRight = shape::NumElementsOp::create(builder, loc, indexType, shapeRight); @@ -557,25 +555,24 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op, SmallVector<AffineMap> indexingMaps{ builder.getMultiDimIdentityMap(inputRank), channelAxisAffineMap, channelAxisAffineMap, builder.getMultiDimIdentityMap(inputRank)}; - auto result = builder - .create<linalg::GenericOp>( - loc, - init.getType(), // resultType - ValueRange{input, scales, zeroPoints}, // inputs - ValueRange{init}, // outputs - indexingMaps, iteratorTypes, - [&](OpBuilder &builder, Location loc, ValueRange args) { - assert(args.size() == 4); - auto input = args[0]; - auto scale = args[1]; - auto zeroPoint = args[2]; - - auto result = - convertRanked(builder, loc, op, input, {}, scale, - zeroPoint, quantizedType); - - linalg::YieldOp::create(builder, loc, result); - }) + auto result = linalg::GenericOp::create( + builder, loc, + init.getType(), // resultType + ValueRange{input, scales, zeroPoints}, // inputs + ValueRange{init}, // outputs + indexingMaps, iteratorTypes, + [&](OpBuilder &builder, Location loc, ValueRange args) { + assert(args.size() == 4); + auto input = args[0]; + auto scale = args[1]; + auto zeroPoint = args[2]; + + auto result = + convertRanked(builder, loc, op, input, {}, scale, + zeroPoint, quantizedType); + + linalg::YieldOp::create(builder, loc, result); + }) .getResult(0); return result; @@ -660,25 +657,24 @@ Value convertSubChannel(OpBuilder &builder, Location loc, Operation *op, SmallVector<AffineMap> indexingMaps{ builder.getMultiDimIdentityMap(inputRank), affineMap, affineMap, builder.getMultiDimIdentityMap(inputRank)}; - auto result = builder - .create<linalg::GenericOp>( - loc, - init.getType(), // resultType - ValueRange{input, scales, zeroPoints}, // inputs - ValueRange{init}, // outputs - indexingMaps, iteratorTypes, - [&](OpBuilder &builder, Location loc, ValueRange args) { - assert(args.size() == 4); - auto input = args[0]; - auto scale = args[1]; - auto zeroPoint = args[2]; - - auto result = - convertRanked(builder, loc, op, input, {}, scale, - zeroPoint, quantizedType); - - linalg::YieldOp::create(builder, loc, result); - }) + auto result = linalg::GenericOp::create( + builder, loc, + init.getType(), // resultType + ValueRange{input, scales, zeroPoints}, // inputs + ValueRange{init}, // outputs + indexingMaps, iteratorTypes, + [&](OpBuilder &builder, Location loc, ValueRange args) { + assert(args.size() == 4); + auto input = args[0]; + auto scale = args[1]; + auto zeroPoint = args[2]; + + auto result = + convertRanked(builder, loc, op, input, {}, scale, + zeroPoint, quantizedType); + + linalg::YieldOp::create(builder, loc, result); + }) .getResult(0); return result; diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index 64c4d60..f8799c5 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -497,10 +497,10 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs, size_t idx = it.index(); Value val = it.value(); if (tensorIndices.contains(idx)) { - result.push_back(rewriter - .create<bufferization::ToTensorOp>( - val.getLoc(), oldBbArgs[idx].getType(), val) - .getResult()); + result.push_back( + bufferization::ToTensorOp::create(rewriter, val.getLoc(), + oldBbArgs[idx].getType(), val) + .getResult()); } else { result.push_back(val); } diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp index 5982856..1130538 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp @@ -19,12 +19,10 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/MapVector.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/MathExtras.h" #define DEBUG_TYPE "scf-loop-pipelining" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") using namespace mlir; using namespace mlir::scf; @@ -100,7 +98,7 @@ public: bool LoopPipelinerInternal::initializeLoopInfo( ForOp op, const PipeliningOption &options) { - LDBG("Start initializeLoopInfo"); + LDBG() << "Start initializeLoopInfo"; forOp = op; ub = forOp.getUpperBound(); lb = forOp.getLowerBound(); @@ -109,7 +107,7 @@ bool LoopPipelinerInternal::initializeLoopInfo( std::vector<std::pair<Operation *, unsigned>> schedule; options.getScheduleFn(forOp, schedule); if (schedule.empty()) { - LDBG("--empty schedule -> BAIL"); + LDBG() << "--empty schedule -> BAIL"; return false; } @@ -126,7 +124,7 @@ bool LoopPipelinerInternal::initializeLoopInfo( auto stepCst = getConstantIntValue(step); if (!upperBoundCst || !lowerBoundCst || !stepCst) { if (!options.supportDynamicLoops) { - LDBG("--dynamic loop not supported -> BAIL"); + LDBG() << "--dynamic loop not supported -> BAIL"; return false; } } else { @@ -134,21 +132,21 @@ bool LoopPipelinerInternal::initializeLoopInfo( int64_t lbImm = lowerBoundCst.value(); int64_t stepImm = stepCst.value(); if (stepImm <= 0) { - LDBG("--invalid loop step -> BAIL"); + LDBG() << "--invalid loop step -> BAIL"; return false; } int64_t numIteration = llvm::divideCeilSigned(ubImm - lbImm, stepImm); if (numIteration >= maxStage) { dynamicLoop = false; } else if (!options.supportDynamicLoops) { - LDBG("--fewer loop iterations than pipeline stages -> BAIL"); + LDBG() << "--fewer loop iterations than pipeline stages -> BAIL"; return false; } } peelEpilogue = options.peelEpilogue; predicateFn = options.predicateFn; if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) { - LDBG("--no epilogue or predicate set -> BAIL"); + LDBG() << "--no epilogue or predicate set -> BAIL"; return false; } @@ -156,13 +154,13 @@ bool LoopPipelinerInternal::initializeLoopInfo( for (Operation &op : forOp.getBody()->without_terminator()) { if (!stages.contains(&op)) { op.emitOpError("not assigned a pipeline stage"); - LDBG("--op not assigned a pipeline stage: " << op << " -> BAIL"); + LDBG() << "--op not assigned a pipeline stage: " << op << " -> BAIL"; return false; } } if (!verifySchedule()) { - LDBG("--invalid schedule: " << op << " -> BAIL"); + LDBG() << "--invalid schedule: " << op << " -> BAIL"; return false; } @@ -173,15 +171,16 @@ bool LoopPipelinerInternal::initializeLoopInfo( (void)stageNum; if (op == forOp.getBody()->getTerminator()) { op->emitError("terminator should not be assigned a stage"); - LDBG("--terminator should not be assigned stage: " << *op << " -> BAIL"); + LDBG() << "--terminator should not be assigned stage: " << *op + << " -> BAIL"; return false; } if (op->getBlock() != forOp.getBody()) { op->emitOpError("the owning Block of all operations assigned a stage " "should be the loop body block"); - LDBG("--the owning Block of all operations assigned a stage " - "should be the loop body block: " - << *op << " -> BAIL"); + LDBG() << "--the owning Block of all operations assigned a stage " + "should be the loop body block: " + << *op << " -> BAIL"; return false; } } @@ -196,8 +195,8 @@ bool LoopPipelinerInternal::initializeLoopInfo( return !def || (!stages.contains(def) && forOp->isAncestor(def)); })) { - LDBG("--only support loop carried dependency with a distance of 1 or " - "defined outside of the loop -> BAIL"); + LDBG() << "--only support loop carried dependency with a distance of 1 or " + "defined outside of the loop -> BAIL"; return false; } annotateFn = options.annotateFn; diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 4025ec6..5731795 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -24,14 +24,12 @@ #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include <cstdint> using namespace mlir; #define DEBUG_TYPE "scf-utils" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields( RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest, @@ -525,13 +523,13 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp, // If any control operand of any inner loop of `forOp` is defined within // `forOp`, no unroll jam. if (!areInnerBoundsInvariant(forOp)) { - LDBG("failed to unroll and jam: inner bounds are not invariant"); + LDBG() << "failed to unroll and jam: inner bounds are not invariant"; return failure(); } // Currently, for operations with results are not supported. if (forOp->getNumResults() > 0) { - LDBG("failed to unroll and jam: unsupported loop with results"); + LDBG() << "failed to unroll and jam: unsupported loop with results"; return failure(); } @@ -540,16 +538,17 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp, std::optional<uint64_t> tripCount = getConstantTripCount(forOp); if (!tripCount.has_value()) { // If the trip count is dynamic, do not unroll & jam. - LDBG("failed to unroll and jam: trip count could not be determined"); + LDBG() << "failed to unroll and jam: trip count could not be determined"; return failure(); } if (unrollJamFactor > *tripCount) { - LDBG("unroll and jam factor is greater than trip count, set factor to trip " - "count"); + LDBG() << "unroll and jam factor is greater than trip count, set factor to " + "trip " + "count"; unrollJamFactor = *tripCount; } else if (*tripCount % unrollJamFactor != 0) { - LDBG("failed to unroll and jam: unsupported trip count that is not a " - "multiple of unroll jam factor"); + LDBG() << "failed to unroll and jam: unsupported trip count that is not a " + "multiple of unroll jam factor"; return failure(); } @@ -828,9 +827,8 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc, productOf = v; } if (!productOf) { - productOf = rewriter - .create<arith::ConstantOp>( - loc, rewriter.getOneAttr(getType(values.front()))) + productOf = arith::ConstantOp::create( + rewriter, loc, rewriter.getOneAttr(getType(values.front()))) .getResult(); } return productOf.value(); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp index f2f7f70..9bee200 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -92,11 +92,13 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface { /// as necessary. void handleTerminator(Operation *op, Block *newDest) const final { if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) { - OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest); + auto builder = OpBuilder(op); + spirv::BranchOp::create(builder, op->getLoc(), newDest); op->erase(); } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) { - OpBuilder(op).create<spirv::BranchOp>(retValOp->getLoc(), newDest, - retValOp->getOperands()); + auto builder = OpBuilder(op); + spirv::BranchOp::create(builder, retValOp->getLoc(), newDest, + retValOp->getOperands()); op->erase(); } } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index 81365b4..3911ec0 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -58,7 +58,17 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, spirv::PointerType::get(spirv::StructType::get(varType), *storageClass); } auto varPtrType = cast<spirv::PointerType>(varType); - auto varPointeeType = cast<spirv::StructType>(varPtrType.getPointeeType()); + Type pointeeType = varPtrType.getPointeeType(); + + // Images are an opaque type and so we can just return a pointer to an image. + // Note that currently only sampled images are supported in the SPIR-V + // lowering. + if (isa<spirv::SampledImageType>(pointeeType)) + return spirv::GlobalVariableOp::create(builder, funcOp.getLoc(), varType, + varName, abiInfo.getDescriptorSet(), + abiInfo.getBinding()); + + auto varPointeeType = cast<spirv::StructType>(pointeeType); // Set the offset information. varPointeeType = diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index e24f0f8..5ba8289 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1702,9 +1702,8 @@ struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> { return failure(); Location loc = op.getLoc(); Value constShape = - rewriter - .create<ConstShapeOp>(loc, - rewriter.getIndexTensorAttr(type.getShape())) + ConstShapeOp::create(rewriter, loc, + rewriter.getIndexTensorAttr(type.getShape())) .getResult(); if (constShape.getType() != op.getResult().getType()) constShape = tensor::CastOp::create(rewriter, loc, diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp index 5fe5566..3e3d476 100644 --- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp +++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp @@ -70,10 +70,8 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder, TypedValue<ShapedType> sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) { TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>( - builder - .create<AllSliceOp>(sourceShard, grid, - ArrayRef<GridAxis>(splitGridAxis), - splitTensorAxis) + AllSliceOp::create(builder, sourceShard, grid, + ArrayRef<GridAxis>(splitGridAxis), splitTensorAxis) .getResult()); Sharding targetSharding = targetShardingInSplitLastAxis( builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis); @@ -420,16 +418,15 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, GridOp grid, // Finally update the halo. auto updateHaloResult = - builder - .create<UpdateHaloOp>( - sourceShard.getLoc(), - RankedTensorType::get(outShape, - sourceShard.getType().getElementType()), - initOprnd, grid.getSymName(), - GridAxesArrayAttr::get(builder.getContext(), - sourceSharding.getSplitAxes()), - targetSharding.getDynamicHaloSizes(), - targetSharding.getStaticHaloSizes()) + UpdateHaloOp::create( + builder, sourceShard.getLoc(), + RankedTensorType::get(outShape, + sourceShard.getType().getElementType()), + initOprnd, grid.getSymName(), + GridAxesArrayAttr::get(builder.getContext(), + sourceSharding.getSplitAxes()), + targetSharding.getDynamicHaloSizes(), + targetSharding.getStaticHaloSizes()) .getResult(); return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult), targetSharding); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp index a52872d..3b4140e 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp @@ -931,10 +931,9 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func, FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc( builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm, ny, args.drop_back(nTrailingP), createPartitionFunc); - Value p = builder - .create<func::CallOp>(loc, partitionFunc, - TypeRange{IndexType::get(context)}, - args.drop_back(nTrailingP)) + Value p = func::CallOp::create(builder, loc, partitionFunc, + TypeRange{IndexType::get(context)}, + args.drop_back(nTrailingP)) .getResult(0); Value lenLow = arith::SubIOp::create(builder, loc, p, lo); @@ -1028,9 +1027,8 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module, FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc( builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, xPerm, ny, operands, createBinarySearchFunc); - Value p = builder - .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()}, - operands) + Value p = func::CallOp::create(builder, loc, searchFunc, + TypeRange{c1.getType()}, operands) .getResult(0); // Move the value at data[i] to a temporary location. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp index a317abd..0bd1d34 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp @@ -98,10 +98,10 @@ static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc, Value numT = constantIndex(builder, loc, numThreads); gpu::KernelDim3 gridSize = {one, one, one}; gpu::KernelDim3 blckSize = {numT, one, one}; - return builder - .create<gpu::LaunchFuncOp>(loc, gpuFunc, gridSize, blckSize, - /*dynSharedMemSz*/ none, args, - builder.getType<gpu::AsyncTokenType>(), tokens) + return gpu::LaunchFuncOp::create(builder, loc, gpuFunc, gridSize, blckSize, + /*dynSharedMemSz*/ none, args, + builder.getType<gpu::AsyncTokenType>(), + tokens) .getAsyncToken(); } @@ -1168,7 +1168,7 @@ struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> { using OpRewritePattern<scf::ParallelOp>::OpRewritePattern; ForallRewriter(MLIRContext *context, unsigned nT) - : OpRewritePattern(context), numThreads(nT){}; + : OpRewritePattern(context), numThreads(nT) {}; LogicalResult matchAndRewrite(scf::ParallelOp forallOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp index dfb1274..9cd4896 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp @@ -443,8 +443,8 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() { addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp, ValueRange inputs, Location loc) -> Value { - return builder - .create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs) + return UnrealizedConversionCastOp::create(builder, loc, TypeRange(spTp), + inputs) .getResult(0); }); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index 70795e2..7a26cd3 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -412,13 +412,13 @@ static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem, if (memTp.getRank() > 1) return mem; // Truncate linear memrefs to given size. - return builder - .create<memref::SubViewOp>( - loc, MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()), - mem, ValueRange{}, ValueRange{sz}, ValueRange{}, - ArrayRef<int64_t>{0}, // static offset - ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size - ArrayRef<int64_t>{1}) // static stride + return memref::SubViewOp::create( + builder, loc, + MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()), + mem, ValueRange{}, ValueRange{sz}, ValueRange{}, + ArrayRef<int64_t>{0}, // static offset + ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size + ArrayRef<int64_t>{1}) // static stride .getResult(); } @@ -449,7 +449,7 @@ class SparseInsertGenerator public: SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params, bool genCall) - : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp){}; + : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp) {}; /// Generates code along an insertion path without the need for a "cursor". /// This current insertion strategy comes at the expense of some testing diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index b444ac5..79f4e7f 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -904,9 +904,8 @@ public: dstTp->withoutDimToLvl(), !srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity()); SmallVector<Value> dynSizes; - Value buffer = rewriter - .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(), - nnz, Attribute()) + Value buffer = AllocTensorOp::create(rewriter, loc, bufferTp, dynSizes, + Value(), nnz, Attribute()) .getResult(); // Convert src coordinates to dst coordinates by first collapsing it to 1D @@ -1013,9 +1012,8 @@ public: !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity()); Value buffer = - rewriter - .create<AllocTensorOp>(loc, bufferTp, dstDynSizes, Value(), - /*sizeHint=*/nnz, Attribute()) + AllocTensorOp::create(rewriter, loc, bufferTp, dstDynSizes, Value(), + /*sizeHint=*/nnz, Attribute()) .getResult(); // Implement the sparse2sparse reshape as follows: diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp index 0e96b59..869d27a 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp @@ -115,8 +115,7 @@ public: bufferization::BufferizationState bufferizationState; - if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()), - updatedOptions, + if (failed(bufferization::bufferizeModuleOp(getOperation(), updatedOptions, bufferizationState))) return failure(); diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index bc11e56..c3356c1 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -784,8 +784,8 @@ struct PadOpInterface auto toValue = [&](OpFoldResult ofr) { if (auto value = dyn_cast<Value>(ofr)) return value; - return rewriter - .create<arith::ConstantIndexOp>(loc, *getConstantIntValue(ofr)) + return arith::ConstantIndexOp::create(rewriter, loc, + *getConstantIntValue(ofr)) .getResult(); }; @@ -919,9 +919,8 @@ struct ReshapeOpInterface auto memrefType = MemRefType::get( srcType.getShape(), srcType.getElementType(), AffineMap(), cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace()); - srcBuffer = rewriter - .create<bufferization::ToBufferOp>( - op->getLoc(), memrefType, *tensorAlloc) + srcBuffer = bufferization::ToBufferOp::create(rewriter, op->getLoc(), + memrefType, *tensorAlloc) .getResult(); } diff --git a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp index 43d9d70..9fd27d3 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp @@ -130,8 +130,7 @@ FailureOr<Value> tensor::buildIndependentOp(OpBuilder &b, // Create a tensor::ExtractSliceOp. SmallVector<OpFoldResult> offsets(newSizes.size(), b.getIndexAttr(0)); SmallVector<OpFoldResult> strides(newSizes.size(), b.getIndexAttr(1)); - return b - .create<ExtractSliceOp>(loc, newEmptyOp, offsets, emptyOp.getMixedSizes(), - strides) + return ExtractSliceOp::create(b, loc, newEmptyOp, offsets, + emptyOp.getMixedSizes(), strides) .getResult(); } diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp index e0af2f7..2ec23e1 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp @@ -385,10 +385,9 @@ struct BubbleUpExpandShapeThroughExtractSlice return getValueOrCreateConstantIndexOp(rewriter, loc, ofr); }); OpFoldResult collapsedOffset = - rewriter - .create<affine::AffineLinearizeIndexOp>(loc, offsetVals, - reassocGroupSizes, - /*disjoint=*/true) + affine::AffineLinearizeIndexOp::create(rewriter, loc, offsetVals, + reassocGroupSizes, + /*disjoint=*/true) .getResult(); collapsedOffsets.push_back(collapsedOffset); collapsedSizes.push_back(collapsedSize); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 1ad2c80..6d2cbb5 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -707,9 +707,8 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> { auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes); replaceWithSlice = - rewriter - .create<tosa::SliceOp>(sliceOp.getLoc(), sliceOp.getType(), - input, start_op, size_op) + tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(), + input, start_op, size_op) .getResult(); break; } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index ecd93ff..3cafb19 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -3647,6 +3647,22 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() { return std::nullopt; } +static void printInitializationList(OpAsmPrinter &parser, + Block::BlockArgListType blocksArgs, + ValueRange initializers, + StringRef prefix = "") { + assert(blocksArgs.size() == initializers.size() && + "expected same length of arguments and initializers"); + if (initializers.empty()) + return; + + parser << prefix << '('; + llvm::interleaveComma( + llvm::zip(blocksArgs, initializers), parser, + [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); }); + parser << ")"; +} + // parse and print of IfOp refer to the implementation of SCF dialect. ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { // Create the regions for 'then'. @@ -3654,16 +3670,64 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { Region *thenRegion = result.addRegion(); Region *elseRegion = result.addRegion(); - auto &builder = parser.getBuilder(); OpAsmParser::UnresolvedOperand cond; - // Create a i1 tensor type for the boolean condition. - Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1)); - if (parser.parseOperand(cond) || - parser.resolveOperand(cond, i1Type, result.operands)) + + if (parser.parseOperand(cond)) return failure(); - // Parse optional results type list. - if (parser.parseOptionalArrowTypeList(result.types)) + + SmallVector<OpAsmParser::Argument, 4> regionArgs; + SmallVector<OpAsmParser::UnresolvedOperand, 4> operands; + + // Parse the optional block arguments + OptionalParseResult listResult = + parser.parseOptionalAssignmentList(regionArgs, operands); + if (listResult.has_value() && failed(listResult.value())) return failure(); + + // Parse a colon. + if (failed(parser.parseColon())) + return parser.emitError(parser.getCurrentLocation(), + "expected type for condition operand"); + + // Parse the type of the condition operand + Type condType; + if (failed(parser.parseType(condType))) + return parser.emitError(parser.getCurrentLocation(), + "expected type for condition operand"); + + // Resolve operand with provided type + if (failed(parser.resolveOperand(cond, condType, result.operands))) + return failure(); + + // Parse optional block arg types + if (listResult.has_value()) { + FunctionType functionType; + + if (failed(parser.parseType(functionType))) + return parser.emitError(parser.getCurrentLocation()) + << "expected list of types for block arguments " + << "followed by arrow type and list of return types"; + + result.addTypes(functionType.getResults()); + + if (functionType.getNumInputs() != operands.size()) { + return parser.emitError(parser.getCurrentLocation()) + << "expected as many input types as operands " + << "(expected " << operands.size() << " got " + << functionType.getNumInputs() << ")"; + } + + // Resolve input operands. + if (failed(parser.resolveOperands(operands, functionType.getInputs(), + parser.getCurrentLocation(), + result.operands))) + return failure(); + } else { + // Parse optional results type list. + if (parser.parseOptionalArrowTypeList(result.types)) + return failure(); + } + // Parse the 'then' region. if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{})) return failure(); @@ -3681,26 +3745,28 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { } void IfOp::print(OpAsmPrinter &p) { - bool printBlockTerminators = false; - p << " " << getCondition(); - if (!getResults().empty()) { - p << " -> (" << getResultTypes() << ")"; - // Print yield explicitly if the op defines values. - printBlockTerminators = true; + + printInitializationList(p, getThenGraph().front().getArguments(), + getInputList(), " "); + p << " : "; + p << getCondition().getType(); + + if (!getInputList().empty()) { + p << " ("; + llvm::interleaveComma(getInputList().getTypes(), p); + p << ")"; } - p << ' '; - p.printRegion(getThenGraph(), - /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/printBlockTerminators); + p.printArrowTypeList(getResultTypes()); + p << " "; + + p.printRegion(getThenGraph()); // Print the 'else' regions if it exists and has a block. auto &elseRegion = getElseGraph(); if (!elseRegion.empty()) { p << " else "; - p.printRegion(elseRegion, - /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/printBlockTerminators); + p.printRegion(elseRegion); } p.printOptionalAttrDict((*this)->getAttrs()); @@ -3909,22 +3975,6 @@ ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) { parser.parseOptionalAttrDictWithKeyword(result.attributes)); } -static void printInitializationList(OpAsmPrinter &parser, - Block::BlockArgListType blocksArgs, - ValueRange initializers, - StringRef prefix = "") { - assert(blocksArgs.size() == initializers.size() && - "expected same length of arguments and initializers"); - if (initializers.empty()) - return; - - parser << prefix << '('; - llvm::interleaveComma( - llvm::zip(blocksArgs, initializers), parser, - [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); }); - parser << ")"; -} - void WhileOp::print(OpAsmPrinter &parser) { printInitializationList(parser, getCondGraph().front().getArguments(), getInputList(), " "); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp index 9474299..0bec0da 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp @@ -81,9 +81,8 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> { dyn_cast<RankedTensorType>(input.getType()).getElementType()); auto revisedInputShapeValue = getTosaConstShape(rewriter, op.getLoc(), revisedInputShape); - input = rewriter - .create<tosa::ReshapeOp>(op.getLoc(), inputType, input, - revisedInputShapeValue) + input = tosa::ReshapeOp::create(rewriter, op.getLoc(), inputType, input, + revisedInputShapeValue) .getResult(); Type resultETy = resultType.getElementType(); @@ -162,9 +161,8 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> { shiftType, rewriter.getIntegerAttr(shiftElementType, 0)); Value constZero = tosa::ConstOp::create(rewriter, op.getLoc(), shiftType, shiftZeroAttr); - Value mulValue = rewriter - .create<tosa::MulOp>(op.getLoc(), mulShapeType, input, - weight, constZero) + Value mulValue = tosa::MulOp::create(rewriter, op.getLoc(), mulShapeType, + input, weight, constZero) .getResult(); // Reshape output to [N, H, W, C * M]. diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 32b5fb6..8ec7765 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -1248,16 +1248,14 @@ bool checkErrorIfCondIf(Operation *op) { // }) // // Simplified: - // %0 = tosa.cond_if %arg2 { - // tosa.yield %arg0 + // %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) { + // ^bb0(%arg3, %arg4): + // tosa.yield %arg3 // } else { - // tosa.yield %arg1 + // ^bb0(%arg3, %arg4): + // tosa.yield %arg4 // } - // - // Unfortunately, the simplified syntax does not encapsulate values - // used in then/else regions (see 'simplified' example above), so it - // must be rewritten to use the generic syntax in order to be conformant - // to the specification. + return failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) || failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else")); } diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp index 4662836..14a4fdf 100644 --- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp @@ -16,15 +16,13 @@ #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/iterator.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/InterleavedRange.h" #define DEBUG_TYPE "transform-dialect" -#define DEBUG_TYPE_FULL "transform-dialect-full" #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << (X)) -#define FULL_LDBG(X) DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, (DBGS() << (X))) +#define FULL_LDBG() LDBG(4) using namespace mlir; @@ -486,24 +484,20 @@ void transform::TransformState::recordOpHandleInvalidationOne( newlyInvalidated.count(otherHandle)) return; - FULL_LDBG("--recordOpHandleInvalidationOne\n"); - DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, { - (DBGS() << "--ancestors: " - << llvm::interleaved(llvm::make_pointee_range(potentialAncestors)) - << "\n"); - }); + FULL_LDBG() << "--recordOpHandleInvalidationOne"; + FULL_LDBG() << "--ancestors: " + << llvm::interleaved( + llvm::make_pointee_range(potentialAncestors)); Operation *owner = consumingHandle.getOwner(); unsigned operandNo = consumingHandle.getOperandNumber(); for (Operation *ancestor : potentialAncestors) { // clang-format off - DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, - { (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); }); - DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, - { (DBGS() << "----of payload with name: " - << payloadOp->getName().getIdentifier() << "\n"); }); - DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, - { (DBGS() << "----of payload: " << *payloadOp << "\n"); }); + FULL_LDBG() << "----handle one ancestor: " << *ancestor;; + + FULL_LDBG() << "----of payload with name: " + << payloadOp->getName().getIdentifier(); + FULL_LDBG() << "----of payload: " << *payloadOp; // clang-format on if (!ancestor->isAncestor(payloadOp)) continue; @@ -609,10 +603,8 @@ void transform::TransformState::recordOpHandleInvalidation( transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { if (potentialAncestors.empty()) { - DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, { - (DBGS() << "----recording invalidation for empty handle: " << handle.get() - << "\n"); - }); + FULL_LDBG() << "----recording invalidation for empty handle: " + << handle.get(); Operation *owner = handle.getOwner(); unsigned operandNo = handle.getOperandNumber(); @@ -709,7 +701,7 @@ void transform::TransformState::recordValueHandleInvalidation( LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl( transform::TransformOpInterface transform, transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { - FULL_LDBG("--Start checkAndRecordHandleInvalidation\n"); + FULL_LDBG() << "--Start checkAndRecordHandleInvalidation"; auto memoryEffectsIface = cast<MemoryEffectOpInterface>(transform.getOperation()); SmallVector<MemoryEffects::EffectInstance> effects; @@ -717,9 +709,7 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl( transform::TransformMappingResource::get(), effects); for (OpOperand &target : transform->getOpOperands()) { - DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, { - (DBGS() << "----iterate on handle: " << target.get() << "\n"); - }); + FULL_LDBG() << "----iterate on handle: " << target.get(); // If the operand uses an invalidated handle, report it. If the operation // allows handles to point to repeated payload operations, only report // pre-existing invalidation errors. Otherwise, also report invalidations @@ -727,14 +717,14 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl( auto it = invalidatedHandles.find(target.get()); auto nit = newlyInvalidated.find(target.get()); if (it != invalidatedHandles.end()) { - FULL_LDBG("--End checkAndRecordHandleInvalidation, found already " - "invalidated -> FAILURE\n"); + FULL_LDBG() << "--End checkAndRecordHandleInvalidation, found already " + "invalidated -> FAILURE"; return it->getSecond()(transform->getLoc()), failure(); } if (!transform.allowsRepeatedHandleOperands() && nit != newlyInvalidated.end()) { - FULL_LDBG("--End checkAndRecordHandleInvalidation, found newly " - "invalidated (by this op) -> FAILURE\n"); + FULL_LDBG() << "--End checkAndRecordHandleInvalidation, found newly " + "invalidated (by this op) -> FAILURE"; return nit->getSecond()(transform->getLoc()), failure(); } @@ -745,27 +735,28 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl( effect.getValue() == target.get(); }; if (llvm::any_of(effects, consumesTarget)) { - FULL_LDBG("----found consume effect\n"); + FULL_LDBG() << "----found consume effect"; if (llvm::isa<transform::TransformHandleTypeInterface>( target.get().getType())) { - FULL_LDBG("----recordOpHandleInvalidation\n"); + FULL_LDBG() << "----recordOpHandleInvalidation"; SmallVector<Operation *> payloadOps = llvm::to_vector(getPayloadOps(target.get())); recordOpHandleInvalidation(target, payloadOps, nullptr, newlyInvalidated); } else if (llvm::isa<transform::TransformValueHandleTypeInterface>( target.get().getType())) { - FULL_LDBG("----recordValueHandleInvalidation\n"); + FULL_LDBG() << "----recordValueHandleInvalidation"; recordValueHandleInvalidation(target, newlyInvalidated); } else { - FULL_LDBG("----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n"); + FULL_LDBG() + << "----not a TransformHandle -> SKIP AND DROP ON THE FLOOR"; } } else { - FULL_LDBG("----no consume effect -> SKIP\n"); + FULL_LDBG() << "----no consume effect -> SKIP"; } } - FULL_LDBG("--End checkAndRecordHandleInvalidation -> SUCCESS\n"); + FULL_LDBG() << "--End checkAndRecordHandleInvalidation -> SUCCESS"; return success(); } @@ -818,18 +809,14 @@ void transform::TransformState::compactOpHandles() { DiagnosedSilenceableFailure transform::TransformState::applyTransform(TransformOpInterface transform) { - LLVM_DEBUG({ - DBGS() << "applying: "; - transform->print(llvm::dbgs(), OpPrintingFlags().skipRegions()); - llvm::dbgs() << "\n"; - }); - DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, - DBGS() << "Top-level payload before application:\n" - << *getTopLevel() << "\n"); + LDBG() << "applying: " + << OpWithFlags(transform, OpPrintingFlags().skipRegions()); + FULL_LDBG() << "Top-level payload before application:\n" << *getTopLevel(); auto printOnFailureRAII = llvm::make_scope_exit([this] { (void)this; - LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print( - llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm());); + LDBG() << "Failing Top-level payload:\n" + << OpWithFlags(getTopLevel(), + OpPrintingFlags().printGenericOpForm()); }); // Set current transform op. @@ -837,47 +824,45 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { // Expensive checks to detect invalid transform IR. if (options.getExpensiveChecksEnabled()) { - FULL_LDBG("ExpensiveChecksEnabled\n"); + FULL_LDBG() << "ExpensiveChecksEnabled"; if (failed(checkAndRecordHandleInvalidation(transform))) return DiagnosedSilenceableFailure::definiteFailure(); for (OpOperand &operand : transform->getOpOperands()) { - DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, { - (DBGS() << "iterate on handle: " << operand.get() << "\n"); - }); + FULL_LDBG() << "iterate on handle: " << operand.get(); if (!isHandleConsumed(operand.get(), transform)) { - FULL_LDBG("--handle not consumed -> SKIP\n"); + FULL_LDBG() << "--handle not consumed -> SKIP"; continue; } if (transform.allowsRepeatedHandleOperands()) { - FULL_LDBG("--op allows repeated handles -> SKIP\n"); + FULL_LDBG() << "--op allows repeated handles -> SKIP"; continue; } - FULL_LDBG("--handle is consumed\n"); + FULL_LDBG() << "--handle is consumed"; Type operandType = operand.get().getType(); if (llvm::isa<TransformHandleTypeInterface>(operandType)) { - FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n"); + FULL_LDBG() << "--checkRepeatedConsumptionInOperand for Operation*"; DiagnosedSilenceableFailure check = checkRepeatedConsumptionInOperand<Operation *>( getPayloadOpsView(operand.get()), transform, operand.getOperandNumber()); if (!check.succeeded()) { - FULL_LDBG("----FAILED\n"); + FULL_LDBG() << "----FAILED"; return check; } } else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) { - FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n"); + FULL_LDBG() << "--checkRepeatedConsumptionInOperand For Value"; DiagnosedSilenceableFailure check = checkRepeatedConsumptionInOperand<Value>( getPayloadValuesView(operand.get()), transform, operand.getOperandNumber()); if (!check.succeeded()) { - FULL_LDBG("----FAILED\n"); + FULL_LDBG() << "----FAILED"; return check; } } else { - FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n"); + FULL_LDBG() << "--not a TransformHandle -> SKIP AND DROP ON THE FLOOR"; } } } @@ -999,8 +984,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { printOnFailureRAII.release(); DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, { - DBGS() << "Top-level payload:\n"; - getTopLevel()->print(llvm::dbgs()); + LDBG() << "Top-level payload:\n" << *getTopLevel(); }); return result; } @@ -1277,7 +1261,7 @@ void transform::TrackingListener::notifyMatchFailure( LLVM_DEBUG({ Diagnostic diag(loc, DiagnosticSeverity::Remark); reasonCallback(diag); - DBGS() << "Match Failure : " << diag.str() << "\n"; + LDBG() << "Match Failure : " << diag.str(); }); } diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 4e9f93b..8789f55 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -372,9 +372,8 @@ SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc, llvm::transform(foldResults, std::back_inserter(values), [&](OpFoldResult foldResult) { if (auto attr = dyn_cast<Attribute>(foldResult)) - return builder - .create<arith::ConstantIndexOp>( - loc, cast<IntegerAttr>(attr).getInt()) + return arith::ConstantIndexOp::create( + builder, loc, cast<IntegerAttr>(attr).getInt()) .getResult(); return cast<Value>(foldResult); @@ -1259,63 +1258,6 @@ void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results, CanonicalizeContractAdd<arith::AddFOp>>(context); } -//===----------------------------------------------------------------------===// -// ExtractElementOp -//===----------------------------------------------------------------------===// - -void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, - SetIntRangeFn setResultRanges) { - setResultRanges(getResult(), argRanges.front()); -} - -void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result, - Value source) { - result.addOperands({source}); - result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType()); -} - -LogicalResult vector::ExtractElementOp::verify() { - VectorType vectorType = getSourceVectorType(); - if (vectorType.getRank() == 0) { - if (getPosition()) - return emitOpError("expected position to be empty with 0-D vector"); - return success(); - } - if (vectorType.getRank() != 1) - return emitOpError("unexpected >1 vector rank"); - if (!getPosition()) - return emitOpError("expected position for 1-D vector"); - return success(); -} - -OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) { - // Skip the 0-D vector here now. - if (!adaptor.getPosition()) - return {}; - - // Fold extractelement (splat X) -> X. - if (auto splat = getVector().getDefiningOp<vector::SplatOp>()) - return splat.getInput(); - - // Fold extractelement(broadcast(X)) -> X. - if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>()) - if (!llvm::isa<VectorType>(broadcast.getSource().getType())) - return broadcast.getSource(); - - auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector()); - auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition()); - if (!pos || !src) - return {}; - - auto srcElements = src.getValues<Attribute>(); - - uint64_t posIdx = pos.getInt(); - if (posIdx >= srcElements.size()) - return {}; - - return srcElements[posIdx]; -} - // Returns `true` if `index` is either within [0, maxIndex) or equal to // `poisonValue`. static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, @@ -3185,60 +3127,6 @@ void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results, } //===----------------------------------------------------------------------===// -// InsertElementOp -//===----------------------------------------------------------------------===// - -void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, - SetIntRangeFn setResultRanges) { - setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1])); -} - -void InsertElementOp::build(OpBuilder &builder, OperationState &result, - Value source, Value dest) { - build(builder, result, source, dest, {}); -} - -LogicalResult InsertElementOp::verify() { - auto dstVectorType = getDestVectorType(); - if (dstVectorType.getRank() == 0) { - if (getPosition()) - return emitOpError("expected position to be empty with 0-D vector"); - return success(); - } - if (dstVectorType.getRank() != 1) - return emitOpError("unexpected >1 vector rank"); - if (!getPosition()) - return emitOpError("expected position for 1-D vector"); - return success(); -} - -OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) { - // Skip the 0-D vector here. - if (!adaptor.getPosition()) - return {}; - - auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource()); - auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest()); - auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition()); - if (!src || !dst || !pos) - return {}; - - if (src.getType() != getDestVectorType().getElementType()) - return {}; - - auto dstElements = dst.getValues<Attribute>(); - - SmallVector<Attribute> results(dstElements); - - uint64_t posIdx = pos.getInt(); - if (posIdx >= results.size()) - return {}; - results[posIdx] = src; - - return DenseElementsAttr::get(getDestVectorType(), results); -} - -//===----------------------------------------------------------------------===// // InsertOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp index 2484670..e062f55 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp @@ -248,11 +248,10 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> { scf::YieldOp::create(b, loc, result); }; - result = - rewriter - .create<scf::IfOp>(loc, condition, /*thenBuilder=*/loadBuilder, + result = scf::IfOp::create(rewriter, loc, condition, + /*thenBuilder=*/loadBuilder, /*elseBuilder=*/passThruBuilder) - .getResult(0); + .getResult(0); } rewriter.replaceOp(op, result); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index e910932..4baeb11 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -142,8 +142,8 @@ struct TransferReadPermutationLowering // Transpose result of transfer_read. SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end()); - return rewriter - .create<vector::TransposeOp>(op.getLoc(), newRead, transposePerm) + return vector::TransposeOp::create(rewriter, op.getLoc(), newRead, + transposePerm) .getResult(); } }; @@ -371,8 +371,8 @@ struct TransferOpReduceRank rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(), AffineMapAttr::get(newMap), op.getPadding(), op.getMask(), newInBoundsAttr); - return rewriter - .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead) + return vector::BroadcastOp::create(rewriter, op.getLoc(), originalVecType, + newRead) .getVector(); } }; diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 58e94ea..bb0f339 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -451,10 +451,9 @@ struct WarpOpTransferWrite : public WarpDistributionPattern { } SmallVector<Value> delinearized; if (map.getNumResults() > 1) { - delinearized = rewriter - .create<mlir::affine::AffineDelinearizeIndexOp>( - newWarpOp.getLoc(), newWarpOp.getLaneid(), - delinearizedIdSizes) + delinearized = mlir::affine::AffineDelinearizeIndexOp::create( + rewriter, newWarpOp.getLoc(), newWarpOp.getLaneid(), + delinearizedIdSizes) .getResults(); } else { // If there is only one map result, we can elide the delinearization @@ -1538,19 +1537,18 @@ struct WarpOpInsertScalar : public WarpDistributionPattern { arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane); Value newResult = - rewriter - .create<scf::IfOp>( - loc, isInsertingLane, - /*thenBuilder=*/ - [&](OpBuilder &builder, Location loc) { - Value newInsert = vector::InsertOp::create( - builder, loc, newSource, distributedVec, newPos); - scf::YieldOp::create(builder, loc, newInsert); - }, - /*elseBuilder=*/ - [&](OpBuilder &builder, Location loc) { - scf::YieldOp::create(builder, loc, distributedVec); - }) + scf::IfOp::create( + rewriter, loc, isInsertingLane, + /*thenBuilder=*/ + [&](OpBuilder &builder, Location loc) { + Value newInsert = vector::InsertOp::create( + builder, loc, newSource, distributedVec, newPos); + scf::YieldOp::create(builder, loc, newInsert); + }, + /*elseBuilder=*/ + [&](OpBuilder &builder, Location loc) { + scf::YieldOp::create(builder, loc, distributedVec); + }) .getResult(0); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult); return success(); @@ -1661,10 +1659,9 @@ struct WarpOpInsert : public WarpDistributionPattern { auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) { scf::YieldOp::create(builder, loc, distributedDest); }; - newResult = rewriter - .create<scf::IfOp>(loc, isInsertingLane, - /*thenBuilder=*/insertingBuilder, - /*elseBuilder=*/nonInsertingBuilder) + newResult = scf::IfOp::create(rewriter, loc, isInsertingLane, + /*thenBuilder=*/insertingBuilder, + /*elseBuilder=*/nonInsertingBuilder) .getResult(0); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index 73388a5..9889d7f2 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -466,9 +466,9 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, newOp = mlir::vector::maskOperation(rewriter, newOp, newMask); } - return rewriter - .create<vector::BroadcastOp>(loc, contractOp->getResultTypes()[0], - newOp->getResults()[0]) + return vector::BroadcastOp::create(rewriter, loc, + contractOp->getResultTypes()[0], + newOp->getResults()[0]) .getResult(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index e6bb96f..f78e579 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -32,7 +32,7 @@ #include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include <cstdint> @@ -41,9 +41,6 @@ using namespace mlir; #define DEBUG_TYPE "vector-narrow-type-emulation" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") using VectorValue = TypedValue<VectorType>; using MemRefValue = TypedValue<MemRefType>; @@ -135,17 +132,16 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter, return vector::CreateMaskOp::create(rewriter, loc, newMaskType, newMaskOperands); }) - .Case<vector::ConstantMaskOp>( - [&](auto constantMaskOp) -> std::optional<Operation *> { - // Take the shape of mask, compress its trailing dimension: - SmallVector<int64_t> maskDimSizes( - constantMaskOp.getMaskDimSizes()); - int64_t &maskIndex = maskDimSizes.back(); - maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex, - numSrcElemsPerDest); - return vector::ConstantMaskOp::create( - rewriter, loc, newMaskType, maskDimSizes); - }) + .Case<vector::ConstantMaskOp>([&](auto constantMaskOp) + -> std::optional<Operation *> { + // Take the shape of mask, compress its trailing dimension: + SmallVector<int64_t> maskDimSizes(constantMaskOp.getMaskDimSizes()); + int64_t &maskIndex = maskDimSizes.back(); + maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex, + numSrcElemsPerDest); + return vector::ConstantMaskOp::create(rewriter, loc, newMaskType, + maskDimSizes); + }) .Case<arith::ConstantOp>([&](auto constantOp) -> std::optional<Operation *> { // TODO: Support multiple dimensions. @@ -232,9 +228,8 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc, auto resultVectorType = VectorType::get({numElemsToExtract}, vectorType.getElementType()); - return rewriter - .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, src, - offsets, sizes, strides) + return vector::ExtractStridedSliceOp::create(rewriter, loc, resultVectorType, + src, offsets, sizes, strides) ->getResult(0); } @@ -1526,11 +1521,11 @@ BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType, "requires -D non-scalable vector type"); int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth(); int64_t mostMinorSourceDim = sourceVectorType.getShape().back(); - LDBG("sourceVectorType: " << sourceVectorType); + LDBG() << "sourceVectorType: " << sourceVectorType; int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth(); int64_t mostMinorTargetDim = targetVectorType.getShape().back(); - LDBG("targetVectorType: " << targetVectorType); + LDBG() << "targetVectorType: " << targetVectorType; int64_t bitwidth = targetBitWidth * mostMinorTargetDim; (void)mostMinorSourceDim; @@ -1555,7 +1550,7 @@ BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType, BitCastRewriter::BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType) : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) { - LDBG("\n" << enumerator.sourceElementRanges); + LDBG() << "\n" << enumerator.sourceElementRanges; } /// Verify that the precondition type meets the common preconditions for any diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index 2676d25..48d680c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -330,8 +330,8 @@ createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc, } reducedOperands.push_back(operand); } - return rewriter - .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands) + return vector::CreateMaskOp::create(rewriter, loc, reducedType, + reducedOperands) .getResult(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp index 05b0074..5e12dc4 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp @@ -348,24 +348,23 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp, Location loc = xferOp.getLoc(); Value zero = arith::ConstantIndexOp::create(b, loc, 0); Value memref = xferOp.getBase(); - return b - .create<scf::IfOp>( - loc, inBoundsCond, - [&](OpBuilder &b, Location loc) { - Value res = - castToCompatibleMemRefType(b, memref, compatibleMemRefType); - scf::ValueVector viewAndIndices{res}; - llvm::append_range(viewAndIndices, xferOp.getIndices()); - scf::YieldOp::create(b, loc, viewAndIndices); - }, - [&](OpBuilder &b, Location loc) { - Value casted = - castToCompatibleMemRefType(b, alloc, compatibleMemRefType); - scf::ValueVector viewAndIndices{casted}; - viewAndIndices.insert(viewAndIndices.end(), - xferOp.getTransferRank(), zero); - scf::YieldOp::create(b, loc, viewAndIndices); - }) + return scf::IfOp::create( + b, loc, inBoundsCond, + [&](OpBuilder &b, Location loc) { + Value res = + castToCompatibleMemRefType(b, memref, compatibleMemRefType); + scf::ValueVector viewAndIndices{res}; + llvm::append_range(viewAndIndices, xferOp.getIndices()); + scf::YieldOp::create(b, loc, viewAndIndices); + }, + [&](OpBuilder &b, Location loc) { + Value casted = + castToCompatibleMemRefType(b, alloc, compatibleMemRefType); + scf::ValueVector viewAndIndices{casted}; + viewAndIndices.insert(viewAndIndices.end(), + xferOp.getTransferRank(), zero); + scf::YieldOp::create(b, loc, viewAndIndices); + }) ->getResults(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 73ca327..c51c7b7 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -410,9 +410,8 @@ FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp, oldMaskType.getScalableDims().drop_front(unusedDimsBitVector.count()); VectorType maskOpType = VectorType::get(newShape, rewriter.getI1Type(), newShapeScalableDims); - mask = rewriter - .create<vector::ShapeCastOp>(contractOp.getLoc(), maskOpType, - maskingOp.getMask()) + mask = vector::ShapeCastOp::create(rewriter, contractOp.getLoc(), + maskOpType, maskingOp.getMask()) .getResult(); } @@ -1006,26 +1005,39 @@ struct ReorderElementwiseOpsOnBroadcast final "might be a scalar"); } - // Get the type of the lhs operand - auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp(); - if (!lhsBcastOrSplat || - !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat)) + // Get the type of the first non-constant operand + Operation *firstBroadcastOrSplat = nullptr; + for (Value operand : op->getOperands()) { + Operation *definingOp = operand.getDefiningOp(); + if (!definingOp) + return failure(); + if (definingOp->hasTrait<OpTrait::ConstantLike>()) + continue; + if (!isa<vector::BroadcastOp, vector::SplatOp>(*definingOp)) + return failure(); + firstBroadcastOrSplat = definingOp; + break; + } + if (!firstBroadcastOrSplat) return failure(); - auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType(); + Type firstBroadcastOrSplatType = + firstBroadcastOrSplat->getOperand(0).getType(); // Make sure that all operands are broadcast from identical types: // * scalar (`vector.broadcast` + `vector.splat`), or // * vector (`vector.broadcast`). // Otherwise the re-ordering wouldn't be safe. - if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) { - auto bcast = val.getDefiningOp<vector::BroadcastOp>(); - if (bcast) - return (bcast.getOperand().getType() == lhsBcastOrSplatType); - auto splat = val.getDefiningOp<vector::SplatOp>(); - if (splat) - return (splat.getOperand().getType() == lhsBcastOrSplatType); - return false; - })) { + if (!llvm::all_of( + op->getOperands(), [&firstBroadcastOrSplatType](Value val) { + if (auto bcastOp = val.getDefiningOp<vector::BroadcastOp>()) + return (bcastOp.getOperand().getType() == + firstBroadcastOrSplatType); + if (auto splatOp = val.getDefiningOp<vector::SplatOp>()) + return (splatOp.getOperand().getType() == + firstBroadcastOrSplatType); + SplatElementsAttr splatConst; + return matchPattern(val, m_Constant(&splatConst)); + })) { return failure(); } @@ -1033,13 +1045,28 @@ struct ReorderElementwiseOpsOnBroadcast final SmallVector<Value> srcValues; srcValues.reserve(op->getNumOperands()); for (Value operand : op->getOperands()) { - srcValues.push_back(operand.getDefiningOp()->getOperand(0)); + SplatElementsAttr splatConst; + if (matchPattern(operand, m_Constant(&splatConst))) { + Attribute newConst; + if (auto shapedTy = dyn_cast<ShapedType>(firstBroadcastOrSplatType)) { + newConst = splatConst.resizeSplat(shapedTy); + } else { + newConst = splatConst.getSplatValue<Attribute>(); + } + Operation *newConstOp = + operand.getDefiningOp()->getDialect()->materializeConstant( + rewriter, newConst, firstBroadcastOrSplatType, + operand.getLoc()); + srcValues.push_back(newConstOp->getResult(0)); + } else { + srcValues.push_back(operand.getDefiningOp()->getOperand(0)); + } } // Create the "elementwise" Op Operation *elementwiseOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues, - lhsBcastOrSplatType, op->getAttrs()); + firstBroadcastOrSplatType, op->getAttrs()); // Replace the original Op with the elementwise Op auto vectorType = op->getResultTypes()[0]; diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 062c51f..501abec 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -16,13 +16,11 @@ #include "mlir/Interfaces/VectorInterfaces.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/InterleavedRange.h" #include <optional> #define DEBUG_TYPE "vector-unroll" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") using namespace mlir; using namespace mlir::vector; @@ -90,10 +88,9 @@ static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, /// std::nullopt if the op shouldn't be or cannot be unrolled. static std::optional<SmallVector<int64_t>> getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) { - LDBG(""); - LDBG("Get unroll shape for op " << op->getName().getStringRef()); + LDBG() << "Get unroll shape for op " << op->getName().getStringRef(); if (options.filterConstraint && failed(options.filterConstraint(op))) { - LDBG("--no filter constraint -> BAIL"); + LDBG() << "--no filter constraint -> BAIL"; return std::nullopt; } assert(options.nativeShape && @@ -101,33 +98,33 @@ getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) { "shape call back function to be set"); auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op); if (!unrollableVectorOp) { - LDBG("--not an unrollable op -> BAIL"); + LDBG() << "--not an unrollable op -> BAIL"; return std::nullopt; } auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll(); if (!maybeUnrollShape) { - LDBG("--could not get shape of op " << *op << " -> BAIL"); + LDBG() << "--could not get shape of op " << *op << " -> BAIL"; return std::nullopt; } - LDBG("--vector op shape: " << llvm::interleaved(*maybeUnrollShape)); + LDBG() << "--vector op shape: " << llvm::interleaved(*maybeUnrollShape); std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op); if (!targetShape) { - LDBG("--no unrolling target shape defined " << *op << "-> SKIP"); + LDBG() << "--no unrolling target shape defined " << *op << "-> SKIP"; return std::nullopt; } - LDBG("--target shape: " << llvm::interleaved(*targetShape)); + LDBG() << "--target shape: " << llvm::interleaved(*targetShape); auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape); if (!maybeShapeRatio) { - LDBG("--could not compute integral shape ratio -> BAIL"); + LDBG() << "--could not compute integral shape ratio -> BAIL"; return std::nullopt; } if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) { - LDBG("--no unrolling needed -> SKIP"); + LDBG() << "--no unrolling needed -> SKIP"; return std::nullopt; } - LDBG("--found an integral shape ratio to unroll to -> SUCCESS"); + LDBG() << "--found an integral shape ratio to unroll to -> SUCCESS"; return targetShape; } diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index c045063..10ed2bc 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -27,13 +27,11 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/InterleavedRange.h" #define DEBUG_TYPE "vector-utils" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") - using namespace mlir; /// Helper function that creates a memref::DimOp or tensor::DimOp depending on @@ -369,14 +367,14 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc, LogicalResult vector::isValidMaskedInputVector(ArrayRef<int64_t> shape, ArrayRef<int64_t> inputVectorSizes) { - LDBG("Iteration space static sizes:" << llvm::interleaved(shape)); + LDBG() << "Iteration space static sizes:" << llvm::interleaved(shape); if (inputVectorSizes.size() != shape.size()) { - LDBG("Input vector sizes don't match the number of loops"); + LDBG() << "Input vector sizes don't match the number of loops"; return failure(); } if (ShapedType::isDynamicShape(inputVectorSizes)) { - LDBG("Input vector sizes can't have dynamic dimensions"); + LDBG() << "Input vector sizes can't have dynamic dimensions"; return failure(); } if (!llvm::all_of(llvm::zip(shape, inputVectorSizes), @@ -386,8 +384,9 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape, return ShapedType::isDynamic(staticSize) || staticSize <= inputSize; })) { - LDBG("Input vector sizes must be greater than or equal to iteration space " - "static sizes"); + LDBG() << "Input vector sizes must be greater than or equal to iteration " + "space " + "static sizes"; return failure(); } return success(); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp index 4656f11..d82c541 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp @@ -17,6 +17,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/DebugLog.h" namespace mlir { namespace xegpu { @@ -26,8 +27,6 @@ namespace xegpu { } // namespace mlir #define DEBUG_TYPE "xegpu-blocking" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") using namespace mlir; @@ -53,7 +52,7 @@ resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) { // We only interest in the case where all inputs and outputs have the // identical VectorTypes if (!hasIdenticalVectorTypes(inputs) || !hasIdenticalVectorTypes(outputs)) { - LDBG("skip unrealized conversion cast op not emulating pack/unpack."); + LDBG() << "skip unrealized conversion cast op not emulating pack/unpack."; return; } @@ -149,7 +148,7 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const { if (auto type = dyn_cast<ShapedType>(value.getType())) return llvm::to_vector(type.getShape()); } - LDBG("failed to getTileShape for: " << value); + LDBG() << "failed to getTileShape for: " << value; return std::nullopt; } @@ -214,7 +213,7 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const { return layout && layout.isWgLayout(); }); if (hasWgLayoutOperands || hasWgLayoutResults) { - LDBG("skip unrolling for op with workgroup level layout: " << *op); + LDBG() << "skip unrolling for op with workgroup level layout: " << *op; return false; } diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index a6208b4..ec8fad4 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -17,7 +17,7 @@ #include "mlir/Dialect/XeGPU/Transforms/Transforms.h" #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" namespace mlir { namespace xegpu { @@ -27,8 +27,6 @@ namespace xegpu { } // namespace mlir #define DEBUG_TYPE "xegpu-unroll" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") using namespace mlir; @@ -44,11 +42,10 @@ protected: /// Return the target shape for the given `op`. Return std::nullopt if the /// op shouldn't be or cannot be unrolled. std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const { - LDBG(""); - LDBG("Get unroll shape for: " << *op); + LDBG() << "Get unroll shape for: " << *op; if (options.filterConstraint && failed(options.filterConstraint(op))) { - LDBG("--no filter constraint -> BAIL"); + LDBG() << "--no filter constraint -> BAIL"; return std::nullopt; } diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 229a289..850f70c 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -207,7 +207,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> { // Subtract startOfRange from the original subgroup id to get the adjusted // sg id Value startOfRangeVal = - rewriter.create<arith::ConstantIndexOp>(loc, startOfRange); + arith::ConstantIndexOp::create(rewriter, loc, startOfRange); adjustedSgId = rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal); } @@ -431,8 +431,8 @@ struct WgToSgVectorBroadcastOp SmallVector<Value> newBroadcastOps; for (auto operand : adaptor.getOperands().front()) { - auto newBroadcast = rewriter.create<vector::BroadcastOp>( - op.getLoc(), newResultType, operand); + auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(), + newResultType, operand); xegpu::setLayoutAttr(newBroadcast->getResult(0), layout.dropSgLayoutAndData()); newBroadcastOps.push_back(newBroadcast.getResult()); @@ -563,8 +563,8 @@ struct WgToSgConvertLayoutOp if (input && target) { // keep the ConvertLayoutOp for rest fields, e.g., inst_data. for (auto [i, src] : llvm::enumerate(adaptor.getSource())) { - auto newOp = rewriter.create<xegpu::ConvertLayoutOp>( - op.getLoc(), src.getType(), src, input, target); + auto newOp = xegpu::ConvertLayoutOp::create( + rewriter, op.getLoc(), src.getType(), src, input, target); newOps[i] = newOp; } } diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp index 3e33795..776b5c6 100644 --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -821,15 +821,7 @@ SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler( for (unsigned i = 0, e = mgr.getNumBuffers(); i != e; ++i) (void)impl->computeExpectedDiags(out, mgr, mgr.getMemoryBuffer(i + 1)); - // Register a handler to verify the diagnostics. - setHandler([&](Diagnostic &diag) { - // Process the main diagnostics. - process(diag); - - // Process each of the notes. - for (auto ¬e : diag.getNotes()) - process(note); - }); + registerInContext(ctx); } SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler( @@ -862,6 +854,17 @@ LogicalResult SourceMgrDiagnosticVerifierHandler::verify() { return impl->status; } +void SourceMgrDiagnosticVerifierHandler::registerInContext(MLIRContext *ctx) { + ctx->getDiagEngine().registerHandler([&](Diagnostic &diag) { + // Process the main diagnostics. + process(diag); + + // Process each of the notes. + for (auto ¬e : diag.getNotes()) + process(note); + }); +} + /// Process a single diagnostic. void SourceMgrDiagnosticVerifierHandler::process(Diagnostic &diag) { return process(diag.getLocation(), diag.str(), diag.getSeverity()); diff --git a/mlir/lib/IR/PatternLoggingListener.cpp b/mlir/lib/IR/PatternLoggingListener.cpp index ce2123a..0db13ab 100644 --- a/mlir/lib/IR/PatternLoggingListener.cpp +++ b/mlir/lib/IR/PatternLoggingListener.cpp @@ -1,50 +1,48 @@ #include "mlir/IR/PatternMatch.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #define DEBUG_TYPE "pattern-logging-listener" -#define DBGS() (llvm::dbgs() << "[" << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") using namespace mlir; void RewriterBase::PatternLoggingListener::notifyOperationInserted( Operation *op, InsertPoint previous) { - LDBG(patternName << " | notifyOperationInserted" - << " | " << op->getName()); + LDBG() << patternName << " | notifyOperationInserted" + << " | " << op->getName(); ForwardingListener::notifyOperationInserted(op, previous); } void RewriterBase::PatternLoggingListener::notifyOperationModified( Operation *op) { - LDBG(patternName << " | notifyOperationModified" - << " | " << op->getName()); + LDBG() << patternName << " | notifyOperationModified" + << " | " << op->getName(); ForwardingListener::notifyOperationModified(op); } void RewriterBase::PatternLoggingListener::notifyOperationReplaced( Operation *op, Operation *newOp) { - LDBG(patternName << " | notifyOperationReplaced (with op)" - << " | " << op->getName() << " | " << newOp->getName()); + LDBG() << patternName << " | notifyOperationReplaced (with op)" + << " | " << op->getName() << " | " << newOp->getName(); ForwardingListener::notifyOperationReplaced(op, newOp); } void RewriterBase::PatternLoggingListener::notifyOperationReplaced( Operation *op, ValueRange replacement) { - LDBG(patternName << " | notifyOperationReplaced (with values)" - << " | " << op->getName()); + LDBG() << patternName << " | notifyOperationReplaced (with values)" + << " | " << op->getName(); ForwardingListener::notifyOperationReplaced(op, replacement); } void RewriterBase::PatternLoggingListener::notifyOperationErased( Operation *op) { - LDBG(patternName << " | notifyOperationErased" - << " | " << op->getName()); + LDBG() << patternName << " | notifyOperationErased" + << " | " << op->getName(); ForwardingListener::notifyOperationErased(op); } void RewriterBase::PatternLoggingListener::notifyPatternBegin( const Pattern &pattern, Operation *op) { - LDBG(patternName << " | notifyPatternBegin" - << " | " << op->getName()); + LDBG() << patternName << " | notifyPatternBegin" + << " | " << op->getName(); ForwardingListener::notifyPatternBegin(pattern, op); } diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 5c98417..9332f55 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -156,6 +156,11 @@ void RewriterBase::eraseOp(Operation *op) { assert(op->use_empty() && "expected 'op' to have no uses"); auto *rewriteListener = dyn_cast_if_present<Listener>(listener); + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + // Fast path: If no listener is attached, the op can be dropped in one go. if (!rewriteListener) { op->erase(); @@ -320,6 +325,11 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest, moveOpBefore(&source->front(), dest, before); } + // If the current insertion point is within the source block, adjust the + // insertion point to the destination block. + if (getInsertionBlock() == source) + setInsertionPoint(dest, getInsertionPoint()); + // Erase the source block. assert(source->empty() && "expected 'source' to be empty"); eraseBlock(source); diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index 0db9808..7094c8e 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -901,7 +901,7 @@ LogicalResult PassManager::run(Operation *op) { if (failed(initialize(context, impl->initializationGeneration + 1))) return failure(); initializationKey = newInitKey; - pipelineKey = pipelineInitializationKey; + pipelineInitializationKey = pipelineKey; } // Construct a top level analysis manager for the pipeline. diff --git a/mlir/lib/RegisterAllDialects.cpp b/mlir/lib/RegisterAllDialects.cpp new file mode 100644 index 0000000..7a345ed --- /dev/null +++ b/mlir/lib/RegisterAllDialects.cpp @@ -0,0 +1,207 @@ +//===- RegisterAllDialects.cpp - MLIR Dialects Registration -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a helper to trigger the registration of all dialects and +// passes to the system. +// +//===----------------------------------------------------------------------===// + +#include "mlir/InitAllDialects.h" + +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" +#include "mlir/Dialect/AMX/AMXDialect.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h" +#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/Async/IR/Async.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h" +#include "mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/GPU/Transforms/BufferDeallocationOpInterfaceImpl.h" +#include "mlir/Dialect/IRDL/IR/IRDL.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" +#include "mlir/Dialect/LLVMIR/XeVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h" +#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h" +#include "mlir/Dialect/MLProgram/IR/MLProgram.h" +#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/MPI/IR/MPI.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h" +#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" +#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h" +#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" +#include "mlir/Dialect/Ptr/IR/PtrDialect.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" +#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h" +#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/SMT/IR/SMTDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Shard/IR/ShardDialect.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" +#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/Transforms/RuntimeOpVerification.h" +#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h" +#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h" +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" +#include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/Target/LLVM/NVVM/Target.h" +#include "mlir/Target/LLVM/ROCDL/Target.h" +#include "mlir/Target/SPIRV/Target.h" + +/// Add all the MLIR dialects to the provided registry. +void mlir::registerAllDialects(DialectRegistry ®istry) { + // clang-format off + registry.insert<acc::OpenACCDialect, + affine::AffineDialect, + amdgpu::AMDGPUDialect, + amx::AMXDialect, + arith::ArithDialect, + arm_neon::ArmNeonDialect, + arm_sme::ArmSMEDialect, + arm_sve::ArmSVEDialect, + async::AsyncDialect, + bufferization::BufferizationDialect, + cf::ControlFlowDialect, + complex::ComplexDialect, + DLTIDialect, + emitc::EmitCDialect, + func::FuncDialect, + gpu::GPUDialect, + index::IndexDialect, + irdl::IRDLDialect, + linalg::LinalgDialect, + LLVM::LLVMDialect, + math::MathDialect, + memref::MemRefDialect, + shard::ShardDialect, + ml_program::MLProgramDialect, + mpi::MPIDialect, + nvgpu::NVGPUDialect, + NVVM::NVVMDialect, + omp::OpenMPDialect, + pdl::PDLDialect, + pdl_interp::PDLInterpDialect, + ptr::PtrDialect, + quant::QuantDialect, + ROCDL::ROCDLDialect, + scf::SCFDialect, + shape::ShapeDialect, + smt::SMTDialect, + sparse_tensor::SparseTensorDialect, + spirv::SPIRVDialect, + tensor::TensorDialect, + tosa::TosaDialect, + transform::TransformDialect, + ub::UBDialect, + vector::VectorDialect, + x86vector::X86VectorDialect, + xegpu::XeGPUDialect, + xevm::XeVMDialect>(); + // clang-format on + + // Register all external models. + affine::registerValueBoundsOpInterfaceExternalModels(registry); + arith::registerBufferDeallocationOpInterfaceExternalModels(registry); + arith::registerBufferizableOpInterfaceExternalModels(registry); + arith::registerBufferViewFlowOpInterfaceExternalModels(registry); + arith::registerShardingInterfaceExternalModels(registry); + arith::registerValueBoundsOpInterfaceExternalModels(registry); + bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( + registry); + builtin::registerCastOpInterfaceExternalModels(registry); + cf::registerBufferizableOpInterfaceExternalModels(registry); + cf::registerBufferDeallocationOpInterfaceExternalModels(registry); + gpu::registerBufferDeallocationOpInterfaceExternalModels(registry); + gpu::registerValueBoundsOpInterfaceExternalModels(registry); + LLVM::registerInlinerInterface(registry); + NVVM::registerInlinerInterface(registry); + linalg::registerAllDialectInterfaceImplementations(registry); + linalg::registerRuntimeVerifiableOpInterfaceExternalModels(registry); + memref::registerAllocationOpInterfaceExternalModels(registry); + memref::registerBufferViewFlowOpInterfaceExternalModels(registry); + memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry); + memref::registerValueBoundsOpInterfaceExternalModels(registry); + memref::registerMemorySlotExternalModels(registry); + ml_program::registerBufferizableOpInterfaceExternalModels(registry); + scf::registerBufferDeallocationOpInterfaceExternalModels(registry); + scf::registerBufferizableOpInterfaceExternalModels(registry); + scf::registerValueBoundsOpInterfaceExternalModels(registry); + shape::registerBufferizableOpInterfaceExternalModels(registry); + sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry); + tensor::registerBufferizableOpInterfaceExternalModels(registry); + tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry); + tensor::registerInferTypeOpInterfaceExternalModels(registry); + tensor::registerRuntimeVerifiableOpInterfaceExternalModels(registry); + tensor::registerSubsetOpInterfaceExternalModels(registry); + tensor::registerTilingInterfaceExternalModels(registry); + tensor::registerValueBoundsOpInterfaceExternalModels(registry); + tosa::registerShardingInterfaceExternalModels(registry); + vector::registerBufferizableOpInterfaceExternalModels(registry); + vector::registerSubsetOpInterfaceExternalModels(registry); + vector::registerValueBoundsOpInterfaceExternalModels(registry); + NVVM::registerNVVMTargetInterfaceExternalModels(registry); + ROCDL::registerROCDLTargetInterfaceExternalModels(registry); + spirv::registerSPIRVTargetInterfaceExternalModels(registry); +} + +/// Append all the MLIR dialects to the registry contained in the given context. +void mlir::registerAllDialects(MLIRContext &context) { + DialectRegistry registry; + registerAllDialects(registry); + context.appendDialectRegistry(registry); +} diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp new file mode 100644 index 0000000..8f7c67c --- /dev/null +++ b/mlir/lib/RegisterAllExtensions.cpp @@ -0,0 +1,115 @@ +//===- RegisterAllExtensions.cpp - MLIR Extension Registration --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a helper to trigger the registration of all dialect +// extensions to the system. +// +//===----------------------------------------------------------------------===// + +#include "mlir/InitAllExtensions.h" + +#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/GPUCommon/GPUToLLVM.h" +#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h" +#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" +#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" +#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" +#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h" +#include "mlir/Dialect/AMX/Transforms.h" +#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" +#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h" +#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h" +#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" +#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h" +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h" +#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h" +#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" +#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h" +#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h" +#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" +#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h" +#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h" +#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" +#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h" +#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h" +#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h" +#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" +#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h" +#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" + +/// This function may be called to register all MLIR dialect extensions with the +/// provided registry. +/// If you're building a compiler, you generally shouldn't use this: you would +/// individually register the specific extensions that are useful for the +/// pipelines and transformations you are using. +void mlir::registerAllExtensions(DialectRegistry ®istry) { + // Register all conversions to LLVM extensions. + registerConvertArithToEmitCInterface(registry); + arith::registerConvertArithToLLVMInterface(registry); + registerConvertComplexToLLVMInterface(registry); + cf::registerConvertControlFlowToLLVMInterface(registry); + func::registerAllExtensions(registry); + tensor::registerAllExtensions(registry); + registerConvertFuncToEmitCInterface(registry); + registerConvertFuncToLLVMInterface(registry); + index::registerConvertIndexToLLVMInterface(registry); + registerConvertMathToLLVMInterface(registry); + mpi::registerConvertMPIToLLVMInterface(registry); + registerConvertMemRefToEmitCInterface(registry); + registerConvertMemRefToLLVMInterface(registry); + registerConvertNVVMToLLVMInterface(registry); + registerConvertOpenMPToLLVMInterface(registry); + registerConvertSCFToEmitCInterface(registry); + ub::registerConvertUBToLLVMInterface(registry); + registerConvertAMXToLLVMInterface(registry); + gpu::registerConvertGpuToLLVMInterface(registry); + NVVM::registerConvertGpuToNVVMInterface(registry); + vector::registerConvertVectorToLLVMInterface(registry); + registerConvertXeVMToLLVMInterface(registry); + + // Register all transform dialect extensions. + affine::registerTransformDialectExtension(registry); + bufferization::registerTransformDialectExtension(registry); + dlti::registerTransformDialectExtension(registry); + func::registerTransformDialectExtension(registry); + gpu::registerTransformDialectExtension(registry); + linalg::registerTransformDialectExtension(registry); + memref::registerTransformDialectExtension(registry); + nvgpu::registerTransformDialectExtension(registry); + scf::registerTransformDialectExtension(registry); + sparse_tensor::registerTransformDialectExtension(registry); + tensor::registerTransformDialectExtension(registry); + transform::registerDebugExtension(registry); + transform::registerIRDLExtension(registry); + transform::registerLoopExtension(registry); + transform::registerPDLExtension(registry); + transform::registerTuneExtension(registry); + vector::registerTransformDialectExtension(registry); + arm_neon::registerTransformDialectExtension(registry); + arm_sve::registerTransformDialectExtension(registry); + + // Translation extensions need to be registered by calling + // `registerAllToLLVMIRTranslations` (see All.h). +} diff --git a/mlir/lib/RegisterAllPasses.cpp b/mlir/lib/RegisterAllPasses.cpp new file mode 100644 index 0000000..1ed3a37 --- /dev/null +++ b/mlir/lib/RegisterAllPasses.cpp @@ -0,0 +1,99 @@ +//===- RegisterAllPasses.cpp - MLIR Registration ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a helper to trigger the registration of all passes to the +// system. +// +//===----------------------------------------------------------------------===// + +#include "mlir/InitAllPasses.h" + +#include "mlir/Conversion/Passes.h" +#include "mlir/Dialect/AMDGPU/Transforms/Passes.h" +#include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/ArmSME/Transforms/Passes.h" +#include "mlir/Dialect/ArmSVE/Transforms/Passes.h" +#include "mlir/Dialect/Async/Passes.h" +#include "mlir/Dialect/Bufferization/Pipelines/Passes.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/EmitC/Transforms/Passes.h" +#include "mlir/Dialect/Func/Transforms/Passes.h" +#include "mlir/Dialect/GPU/Pipelines/Passes.h" +#include "mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/Transforms/Passes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/MLProgram/Transforms/Passes.h" +#include "mlir/Dialect/Math/Transforms/Passes.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/NVGPU/Transforms/Passes.h" +#include "mlir/Dialect/OpenACC/Transforms/Passes.h" +#include "mlir/Dialect/Quant/Transforms/Passes.h" +#include "mlir/Dialect/SCF/Transforms/Passes.h" +#include "mlir/Dialect/SPIRV/Transforms/Passes.h" +#include "mlir/Dialect/Shape/Transforms/Passes.h" +#include "mlir/Dialect/Shard/Transforms/Passes.h" +#include "mlir/Dialect/SparseTensor/Pipelines/Passes.h" +#include "mlir/Dialect/SparseTensor/Transforms/Passes.h" +#include "mlir/Dialect/Tensor/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Transform/Transforms/Passes.h" +#include "mlir/Dialect/Vector/Transforms/Passes.h" +#include "mlir/Dialect/XeGPU/Transforms/Passes.h" +#include "mlir/Transforms/Passes.h" + +// This function may be called to register the MLIR passes with the +// global registry. +// If you're building a compiler, you likely don't need this: you would build a +// pipeline programmatically without the need to register with the global +// registry, since it would already be calling the creation routine of the +// individual passes. +// The global registry is interesting to interact with the command-line tools. +void mlir::registerAllPasses() { + // General passes + registerTransformsPasses(); + + // Conversion passes + registerConversionPasses(); + + // Dialect passes + acc::registerOpenACCPasses(); + affine::registerAffinePasses(); + amdgpu::registerAMDGPUPasses(); + registerAsyncPasses(); + arith::registerArithPasses(); + bufferization::registerBufferizationPasses(); + func::registerFuncPasses(); + registerGPUPasses(); + registerLinalgPasses(); + registerNVGPUPasses(); + registerSparseTensorPasses(); + LLVM::registerLLVMPasses(); + math::registerMathPasses(); + memref::registerMemRefPasses(); + shard::registerShardPasses(); + ml_program::registerMLProgramPasses(); + quant::registerQuantPasses(); + registerSCFPasses(); + registerShapePasses(); + spirv::registerSPIRVPasses(); + tensor::registerTensorPasses(); + tosa::registerTosaOptPasses(); + transform::registerTransformPasses(); + vector::registerVectorPasses(); + arm_sme::registerArmSMEPasses(); + arm_sve::registerArmSVEPasses(); + emitc::registerEmitCPasses(); + xegpu::registerXeGPUPasses(); + + // Dialect pipelines + bufferization::registerBufferizationPipelines(); + sparse_tensor::registerSparseTensorPipelines(); + tosa::registerTosaToLinalgPipelines(); + gpu::registerGPUToNVVMPipeline(); +} diff --git a/mlir/lib/Support/ToolUtilities.cpp b/mlir/lib/Support/ToolUtilities.cpp index 748f928..2cf33eb 100644 --- a/mlir/lib/Support/ToolUtilities.cpp +++ b/mlir/lib/Support/ToolUtilities.cpp @@ -14,6 +14,8 @@ #include "mlir/Support/LLVM.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" +#include <string> +#include <utility> using namespace mlir; @@ -22,18 +24,18 @@ mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer, ChunkBufferHandler processChunkBuffer, raw_ostream &os, llvm::StringRef inputSplitMarker, llvm::StringRef outputSplitMarker) { + llvm::MemoryBufferRef originalBufferRef = originalBuffer->getMemBufferRef(); // If splitting is disabled, we process the full input buffer. if (inputSplitMarker.empty()) - return processChunkBuffer(std::move(originalBuffer), os); + return processChunkBuffer(std::move(originalBuffer), originalBufferRef, os); const int inputSplitMarkerLen = inputSplitMarker.size(); - auto *origMemBuffer = originalBuffer.get(); SmallVector<StringRef, 8> rawSourceBuffers; const int checkLen = 2; // Split dropping the last checkLen chars to enable flagging near misses. - origMemBuffer->getBuffer().split(rawSourceBuffers, - inputSplitMarker.drop_back(checkLen)); + originalBufferRef.getBuffer().split(rawSourceBuffers, + inputSplitMarker.drop_back(checkLen)); if (rawSourceBuffers.empty()) return success(); @@ -79,11 +81,17 @@ mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer, auto interleaveFn = [&](StringRef subBuffer) { auto splitLoc = SMLoc::getFromPointer(subBuffer.data()); unsigned splitLine = fileSourceMgr.getLineAndColumn(splitLoc).first; - auto subMemBuffer = llvm::MemoryBuffer::getMemBufferCopy( - subBuffer, Twine("within split at ") + - origMemBuffer->getBufferIdentifier() + ":" + - Twine(splitLine) + " offset "); - if (failed(processChunkBuffer(std::move(subMemBuffer), os))) + std::string name((Twine("within split at ") + + originalBufferRef.getBufferIdentifier() + ":" + + Twine(splitLine) + " offset ") + .str()); + // Use MemoryBufferRef to avoid copying the buffer & keep at same location + // relative to the original buffer. + auto subMemBuffer = + llvm::MemoryBuffer::getMemBuffer(llvm::MemoryBufferRef(subBuffer, name), + /*RequiresNullTerminator=*/false); + if (failed( + processChunkBuffer(std::move(subMemBuffer), originalBufferRef, os))) hadFailure = true; }; llvm::interleave(sourceBuffers, os, interleaveFn, @@ -92,3 +100,16 @@ mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer, // If any fails, then return a failure of the tool. return failure(hadFailure); } + +LogicalResult +mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer, + NoSourceChunkBufferHandler processChunkBuffer, + raw_ostream &os, llvm::StringRef inputSplitMarker, + llvm::StringRef outputSplitMarker) { + auto process = [&](std::unique_ptr<llvm::MemoryBuffer> chunkBuffer, + const llvm::MemoryBufferRef &, raw_ostream &os) { + return processChunkBuffer(std::move(chunkBuffer), os); + }; + return splitAndProcessBuffer(std::move(originalBuffer), process, os, + inputSplitMarker, outputSplitMarker); +} diff --git a/mlir/lib/Support/TypeID.cpp b/mlir/lib/Support/TypeID.cpp index 01ad910..304253c 100644 --- a/mlir/lib/Support/TypeID.cpp +++ b/mlir/lib/Support/TypeID.cpp @@ -27,9 +27,6 @@ namespace { struct ImplicitTypeIDRegistry { /// Lookup or insert a TypeID for the given type name. TypeID lookupOrInsert(StringRef typeName) { - LLVM_DEBUG(llvm::dbgs() << "ImplicitTypeIDRegistry::lookupOrInsert(" - << typeName << ")\n"); - // Perform a heuristic check to see if this type is in an anonymous // namespace. String equality is not valid for anonymous types, so we try to // abort whenever we see them. diff --git a/mlir/lib/Target/LLVM/CMakeLists.txt b/mlir/lib/Target/LLVM/CMakeLists.txt index 7c6fc37..f6e44c6 100644 --- a/mlir/lib/Target/LLVM/CMakeLists.txt +++ b/mlir/lib/Target/LLVM/CMakeLists.txt @@ -8,6 +8,7 @@ add_mlir_library(MLIRTargetLLVM intrinsics_gen LINK_COMPONENTS + BitWriter Core IPO IRReader diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp index b4d53d4..55c8a64 100644 --- a/mlir/lib/Target/LLVM/NVVM/Target.cpp +++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp @@ -16,7 +16,9 @@ #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/Target/LLVM/NVVM/Utils.h" #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt index af22a7f..9ea5c683 100644 --- a/mlir/lib/Target/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt @@ -60,6 +60,7 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration MLIRROCDLToLLVMIRTranslation MLIRSPIRVToLLVMIRTranslation MLIRVCIXToLLVMIRTranslation + MLIRXeVMToLLVMIRTranslation ) add_mlir_translation_library(MLIRTargetLLVMIRImport diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt index f030fa7..86c731a 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt @@ -10,3 +10,4 @@ add_subdirectory(OpenMP) add_subdirectory(ROCDL) add_subdirectory(SPIRV) add_subdirectory(VCIX) +add_subdirectory(XeVM) diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 9f18199..49e1e55 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -3877,29 +3877,28 @@ static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, llvm::SmallVector<size_t> indices(indexAttr.size()); std::iota(indices.begin(), indices.end(), 0); - llvm::sort(indices.begin(), indices.end(), - [&](const size_t a, const size_t b) { - auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]); - auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]); - for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) { - int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt(); - int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt(); - - if (aIndex == bIndex) - continue; - - if (aIndex < bIndex) - return first; - - if (aIndex > bIndex) - return !first; - } - - // Iterated the up until the end of the smallest member and - // they were found to be equal up to that point, so select - // the member with the lowest index count, so the "parent" - return memberIndicesA.size() < memberIndicesB.size(); - }); + llvm::sort(indices, [&](const size_t a, const size_t b) { + auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]); + auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]); + for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) { + int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt(); + int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt(); + + if (aIndex == bIndex) + continue; + + if (aIndex < bIndex) + return first; + + if (aIndex > bIndex) + return !first; + } + + // Iterated the up until the end of the smallest member and + // they were found to be equal up to that point, so select + // the member with the lowest index count, so the "parent" + return memberIndicesA.size() < memberIndicesB.size(); + }); return llvm::cast<omp::MapInfoOp>( mapInfo.getMembers()[indices.front()].getDefiningOp()); diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt new file mode 100644 index 0000000..6308d7e --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt @@ -0,0 +1,21 @@ +set(LLVM_OPTIONAL_SOURCES + XeVMToLLVMIRTranslation.cpp +) + +add_mlir_translation_library(MLIRXeVMToLLVMIRTranslation + XeVMToLLVMIRTranslation.cpp + + DEPENDS + MLIRXeVMConversionsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRDialectUtils + MLIRIR + MLIRLLVMDialect + MLIRXeVMDialect + MLIRSupport + MLIRTargetLLVMIRExport +) diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp new file mode 100644 index 0000000..73b166d --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp @@ -0,0 +1,103 @@ +//===-- XeVMToLLVMIRTranslation.cpp - Translate XeVM to LLVM IR -*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between the MLIR XeVM dialect and +// LLVM IR. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h" +#include "mlir/Dialect/LLVMIR/XeVMDialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Operation.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" + +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" + +#include "llvm/IR/ConstantRange.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::LLVM; + +namespace { +/// Implementation of the dialect interface that converts operations belonging +/// to the XeVM dialect to LLVM IR. +class XeVMDialectLLVMIRTranslationInterface + : public LLVMTranslationDialectInterface { +public: + using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; + + /// Attaches module-level metadata for functions marked as kernels. + LogicalResult + amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions, + NamedAttribute attribute, + LLVM::ModuleTranslation &moduleTranslation) const final { + StringRef attrName = attribute.getName().getValue(); + if (attrName == mlir::xevm::XeVMDialect::getCacheControlsAttrName()) { + auto cacheControlsArray = dyn_cast<ArrayAttr>(attribute.getValue()); + if (cacheControlsArray.size() != 2) { + return op->emitOpError( + "Expected both L1 and L3 cache control attributes!"); + } + if (instructions.size() != 1) { + return op->emitOpError("Expecting a single instruction"); + } + return handleDecorationCacheControl(instructions.front(), + cacheControlsArray.getValue()); + } + auto func = dyn_cast<LLVM::LLVMFuncOp>(op); + if (!func) + return failure(); + + return success(); + } + +private: + static LogicalResult handleDecorationCacheControl(llvm::Instruction *inst, + ArrayRef<Attribute> attrs) { + SmallVector<llvm::Metadata *> decorations; + llvm::LLVMContext &ctx = inst->getContext(); + llvm::Type *i32Ty = llvm::IntegerType::getInt32Ty(ctx); + llvm::transform( + attrs, std::back_inserter(decorations), + [&ctx, i32Ty](Attribute attr) -> llvm::Metadata * { + auto valuesArray = dyn_cast<ArrayAttr>(attr).getValue(); + std::array<llvm::Metadata *, 4> metadata; + llvm::transform( + valuesArray, metadata.begin(), [i32Ty](Attribute valueAttr) { + return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get( + i32Ty, cast<IntegerAttr>(valueAttr).getValue())); + }); + return llvm::MDNode::get(ctx, metadata); + }); + constexpr llvm::StringLiteral decorationCacheControlMDName = + "spirv.DecorationCacheControlINTEL"; + inst->setMetadata(decorationCacheControlMDName, + llvm::MDNode::get(ctx, decorations)); + return success(); + } +}; +} // namespace + +void mlir::registerXeVMDialectTranslation(::mlir::DialectRegistry ®istry) { + registry.insert<xevm::XeVMDialect>(); + registry.addExtension(+[](MLIRContext *ctx, xevm::XeVMDialect *dialect) { + dialect->addInterfaces<XeVMDialectLLVMIRTranslationInterface>(); + }); +} + +void mlir::registerXeVMDialectTranslation(::mlir::MLIRContext &context) { + DialectRegistry registry; + registerXeVMDialectTranslation(registry); + context.appendDialectRegistry(registry); +} diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 94db7f8..58e3c44 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -142,6 +142,7 @@ static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder, // TODO: Implement the `convertInstruction` hooks in the // `LLVMDialectLLVMIRImportInterface` and move the following include there. #include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc" + return failure(); } @@ -1626,12 +1627,11 @@ FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) { // Convert dso_local_equivalent. if (auto *dsoLocalEquivalent = dyn_cast<llvm::DSOLocalEquivalent>(constant)) { Type type = convertType(dsoLocalEquivalent->getType()); - return builder - .create<DSOLocalEquivalentOp>( - loc, type, - FlatSymbolRefAttr::get( - builder.getContext(), - dsoLocalEquivalent->getGlobalValue()->getName())) + return DSOLocalEquivalentOp::create( + builder, loc, type, + FlatSymbolRefAttr::get( + builder.getContext(), + dsoLocalEquivalent->getGlobalValue()->getName())) .getResult(); } @@ -1736,9 +1736,9 @@ FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) { FlatSymbolRefAttr::get(context, blockAddr->getFunction()->getName()); auto blockTag = BlockTagAttr::get(context, blockAddr->getBasicBlock()->getNumber()); - return builder - .create<BlockAddressOp>(loc, convertType(blockAddr->getType()), - BlockAddressAttr::get(context, fnSym, blockTag)) + return BlockAddressOp::create( + builder, loc, convertType(blockAddr->getType()), + BlockAddressAttr::get(context, fnSym, blockTag)) .getRes(); } @@ -2228,17 +2228,16 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { if (!resultTy) return failure(); ArrayAttr operandAttrs = convertAsmInlineOperandAttrs(*callInst); - return builder - .create<InlineAsmOp>( - loc, resultTy, *operands, - builder.getStringAttr(asmI->getAsmString()), - builder.getStringAttr(asmI->getConstraintString()), - asmI->hasSideEffects(), asmI->isAlignStack(), - convertTailCallKindFromLLVM(callInst->getTailCallKind()), - AsmDialectAttr::get( - mlirModule.getContext(), - convertAsmDialectFromLLVM(asmI->getDialect())), - operandAttrs) + return InlineAsmOp::create( + builder, loc, resultTy, *operands, + builder.getStringAttr(asmI->getAsmString()), + builder.getStringAttr(asmI->getConstraintString()), + asmI->hasSideEffects(), asmI->isAlignStack(), + convertTailCallKindFromLLVM(callInst->getTailCallKind()), + AsmDialectAttr::get( + mlirModule.getContext(), + convertAsmDialectFromLLVM(asmI->getDialect())), + operandAttrs) .getOperation(); } bool isIncompatibleCall; diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 58e5353..a8a2b2e 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -446,6 +446,19 @@ LogicalResult Serializer::processType(Location loc, Type type, LogicalResult Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, SetVector<StringRef> &serializationCtx) { + + // Map unsigned integer types to singless integer types. + // This is needed otherwise the generated spirv assembly will contain + // twice a type declaration (like OpTypeInt 32 0) which is no permitted and + // such module fails validation. Indeed at MLIR level the two types are + // different and lookup in the cache below misses. + // Note: This conversion needs to happen here before the type is looked up in + // the cache. + if (type.isUnsignedInteger()) { + type = IntegerType::get(loc->getContext(), type.getIntOrFloatBitWidth(), + IntegerType::SignednessSemantics::Signless); + } + typeID = getTypeID(type); if (typeID) return success(); diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp index 8f78590..bdcdaa4 100644 --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -508,13 +508,20 @@ performActions(raw_ostream &os, /// Parses the memory buffer. If successfully, run a series of passes against /// it and print the result. -static LogicalResult processBuffer(raw_ostream &os, - std::unique_ptr<MemoryBuffer> ownedBuffer, - const MlirOptMainConfig &config, - DialectRegistry ®istry, - llvm::ThreadPoolInterface *threadPool) { +static LogicalResult +processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer, + llvm::MemoryBufferRef sourceBuffer, + const MlirOptMainConfig &config, DialectRegistry ®istry, + SourceMgrDiagnosticVerifierHandler *verifyHandler, + llvm::ThreadPoolInterface *threadPool) { // Tell sourceMgr about this buffer, which is what the parser will pick up. auto sourceMgr = std::make_shared<SourceMgr>(); + // Add the original buffer to the source manager to use for determining + // locations. + sourceMgr->AddNewSourceBuffer( + llvm::MemoryBuffer::getMemBuffer(sourceBuffer, + /*RequiresNullTerminator=*/false), + SMLoc()); sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); // Create a context just for the current buffer. Disable threading on creation @@ -522,6 +529,8 @@ static LogicalResult processBuffer(raw_ostream &os, MLIRContext context(registry, MLIRContext::Threading::DISABLED); if (threadPool) context.setThreadPool(*threadPool); + if (verifyHandler) + verifyHandler->registerInContext(&context); StringRef irdlFile = config.getIrdlFile(); if (!irdlFile.empty() && failed(loadIRDLDialects(irdlFile, context))) @@ -545,17 +554,12 @@ static LogicalResult processBuffer(raw_ostream &os, return performActions(os, sourceMgr, &context, config); } - SourceMgrDiagnosticVerifierHandler sourceMgrHandler( - *sourceMgr, &context, config.verifyDiagnosticsLevel()); - // Do any processing requested by command line flags. We don't care whether // these actions succeed or fail, we only care what diagnostics they produce // and whether they match our expectations. (void)performActions(os, sourceMgr, &context, config); - // Verify the diagnostic handler to make sure that each of the diagnostics - // matched. - return sourceMgrHandler.verify(); + return success(); } std::pair<std::string, std::string> @@ -624,14 +628,31 @@ LogicalResult mlir::MlirOptMain(llvm::raw_ostream &outputStream, if (threadPoolCtx.isMultithreadingEnabled()) threadPool = &threadPoolCtx.getThreadPool(); + SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer( + llvm::MemoryBuffer::getMemBuffer(buffer->getMemBufferRef(), + /*RequiresNullTerminator=*/false), + SMLoc()); + // Note: this creates a verifier handler independent of the the flag set, as + // internally if the flag is not set, a new scoped diagnostic handler is + // created which would intercept the diagnostics and verify them. + SourceMgrDiagnosticVerifierHandler sourceMgrHandler( + sourceMgr, &threadPoolCtx, config.verifyDiagnosticsLevel()); auto chunkFn = [&](std::unique_ptr<MemoryBuffer> chunkBuffer, - raw_ostream &os) { - return processBuffer(os, std::move(chunkBuffer), config, registry, - threadPool); + llvm::MemoryBufferRef sourceBuffer, raw_ostream &os) { + return processBuffer( + os, std::move(chunkBuffer), sourceBuffer, config, registry, + config.shouldVerifyDiagnostics() ? &sourceMgrHandler : nullptr, + threadPool); }; - return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream, - config.inputSplitMarker(), - config.outputSplitMarker()); + LogicalResult status = splitAndProcessBuffer( + llvm::MemoryBuffer::getMemBuffer(buffer->getMemBufferRef(), + /*RequiresNullTerminator=*/false), + chunkFn, outputStream, config.inputSplitMarker(), + config.outputSplitMarker()); + if (config.shouldVerifyDiagnostics() && failed(sourceMgrHandler.verify())) + status = failure(); + return status; } LogicalResult mlir::MlirOptMain(int argc, char **argv, diff --git a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp index c11cb8d..e1c8afb 100644 --- a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp +++ b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp @@ -135,6 +135,13 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv, // Processes the memory buffer with a new MLIRContext. auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer, raw_ostream &os) { + // Many of the translations expect a null-terminated buffer while splitting + // the buffer does not guarantee null-termination. Make a copy of the buffer + // to ensure null-termination. + if (!ownedBuffer->getBuffer().ends_with('\0')) { + ownedBuffer = llvm::MemoryBuffer::getMemBufferCopy( + ownedBuffer->getBuffer(), ownedBuffer->getBufferIdentifier()); + } // Temporary buffers for chained translation processing. std::string dataIn; std::string dataOut; diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index ddd5f2b..4ccb83f 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -36,6 +36,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Value.h" @@ -51,6 +52,7 @@ #include "mlir/Transforms/Passes.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include <cassert> #include <cstddef> #include <memory> @@ -58,8 +60,6 @@ #include <vector> #define DEBUG_TYPE "remove-dead-values" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") namespace mlir { #define GEN_PASS_DEF_REMOVEDEADVALUES @@ -119,21 +119,21 @@ static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet, RunLivenessAnalysis &la) { for (Value value : values) { if (nonLiveSet.contains(value)) { - LDBG("Value " << value << " is already marked non-live (dead)"); + LDBG() << "Value " << value << " is already marked non-live (dead)"; continue; } const Liveness *liveness = la.getLiveness(value); if (!liveness) { - LDBG("Value " << value - << " has no liveness info, conservatively considered live"); + LDBG() << "Value " << value + << " has no liveness info, conservatively considered live"; return true; } if (liveness->isLive) { - LDBG("Value " << value << " is live according to liveness analysis"); + LDBG() << "Value " << value << " is live according to liveness analysis"; return true; } else { - LDBG("Value " << value << " is dead according to liveness analysis"); + LDBG() << "Value " << value << " is dead according to liveness analysis"; } } return false; @@ -148,8 +148,8 @@ static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet, for (auto [index, value] : llvm::enumerate(values)) { if (nonLiveSet.contains(value)) { lives.reset(index); - LDBG("Value " << value << " is already marked non-live (dead) at index " - << index); + LDBG() << "Value " << value + << " is already marked non-live (dead) at index " << index; continue; } @@ -161,17 +161,17 @@ static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet, // (because they weren't erased) and also their liveness is null because // liveness analysis ran before their creation. if (!liveness) { - LDBG("Value " << value << " at index " << index - << " has no liveness info, conservatively considered live"); + LDBG() << "Value " << value << " at index " << index + << " has no liveness info, conservatively considered live"; continue; } if (!liveness->isLive) { lives.reset(index); - LDBG("Value " << value << " at index " << index - << " is dead according to liveness analysis"); + LDBG() << "Value " << value << " at index " << index + << " is dead according to liveness analysis"; } else { - LDBG("Value " << value << " at index " << index - << " is live according to liveness analysis"); + LDBG() << "Value " << value << " at index " << index + << " is live according to liveness analysis"; } } @@ -187,8 +187,8 @@ static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range, if (!nonLive[index]) continue; nonLiveSet.insert(result); - LDBG("Marking value " << result << " as non-live (dead) at index " - << index); + LDBG() << "Marking value " << result << " as non-live (dead) at index " + << index; } } @@ -258,16 +258,18 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) { static void processSimpleOp(Operation *op, RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, RDVFinalCleanupList &cl) { - LDBG("Processing simple op: " << *op); + LDBG() << "Processing simple op: " << *op; if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) { - LDBG("Simple op is not memory effect free or has live results, skipping: " - << *op); + LDBG() + << "Simple op is not memory effect free or has live results, skipping: " + << *op; return; } - LDBG("Simple op has all dead results and is memory effect free, scheduling " - "for removal: " - << *op); + LDBG() + << "Simple op has all dead results and is memory effect free, scheduling " + "for removal: " + << *op; cl.operations.push_back(op); collectNonLiveValues(nonLiveSet, op->getResults(), BitVector(op->getNumResults(), true)); @@ -286,10 +288,10 @@ static void processSimpleOp(Operation *op, RunLivenessAnalysis &la, static void processFuncOp(FunctionOpInterface funcOp, Operation *module, RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, RDVFinalCleanupList &cl) { - LDBG("Processing function op: " << funcOp.getOperation()->getName()); + LDBG() << "Processing function op: " << funcOp.getOperation()->getName(); if (funcOp.isPublic() || funcOp.isExternal()) { - LDBG("Function is public or external, skipping: " - << funcOp.getOperation()->getName()); + LDBG() << "Function is public or external, skipping: " + << funcOp.getOperation()->getName(); return; } @@ -409,9 +411,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, RDVFinalCleanupList &cl) { - LLVM_DEBUG(DBGS() << "Processing region branch op: "; regionBranchOp->print( - llvm::dbgs(), OpPrintingFlags().skipRegions()); - llvm::dbgs() << "\n"); + LDBG() << "Processing region branch op: " + << OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions()); // Mark live results of `regionBranchOp` in `liveResults`. auto markLiveResults = [&](BitVector &liveResults) { liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la); @@ -697,7 +698,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, RDVFinalCleanupList &cl) { - LDBG("Processing branch op: " << *branchOp); + LDBG() << "Processing branch op: " << *branchOp; unsigned numSuccessors = branchOp->getNumSuccessors(); for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) { diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index d224f73..08803e0 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -14,8 +14,10 @@ #include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Iterators.h" +#include "mlir/IR/Operation.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Rewrite/PatternApplicator.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" @@ -130,11 +132,6 @@ struct ConversionValueMapping { /// value. ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const; - /// Lookup the given value within the map, or return an empty vector if the - /// value is not mapped. If it is mapped, this follows the same behavior - /// as `lookupOrDefault`. - ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const; - template <typename T> struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {}; @@ -237,15 +234,6 @@ ConversionValueMapping::lookupOrDefault(Value from, return !desiredValue.empty() ? std::move(desiredValue) : std::move(current); } -ValueVector ConversionValueMapping::lookupOrNull(Value from, - TypeRange desiredTypes) const { - ValueVector result = lookupOrDefault(from, desiredTypes); - if (result == ValueVector{from} || - (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes)) - return {}; - return result; -} - //===----------------------------------------------------------------------===// // Rewriter and Translation State //===----------------------------------------------------------------------===// @@ -521,9 +509,11 @@ private: class MoveBlockRewrite : public BlockRewrite { public: MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block, - Region *region, Block *insertBeforeBlock) - : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region), - insertBeforeBlock(insertBeforeBlock) {} + Region *previousRegion, Region::iterator previousIt) + : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), + region(previousRegion), + insertBeforeBlock(previousIt == previousRegion->end() ? nullptr + : &*previousIt) {} static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::MoveBlock; @@ -630,9 +620,12 @@ protected: class MoveOperationRewrite : public OperationRewrite { public: MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, - Operation *op, Block *block, Operation *insertBeforeOp) - : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block), - insertBeforeOp(insertBeforeOp) {} + Operation *op, OpBuilder::InsertPoint previous) + : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), + block(previous.getBlock()), + insertBeforeOp(previous.getPoint() == previous.getBlock()->end() + ? nullptr + : &*previous.getPoint()) {} static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::MoveOperation; @@ -926,6 +919,23 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// Return "true" if the given operation was replaced or erased. bool wasOpReplaced(Operation *op) const; + /// Lookup the most recently mapped values with the desired types in the + /// mapping. + /// + /// Special cases: + /// - If the desired type range is empty, simply return the most recently + /// mapped values. + /// - If there is no mapping to the desired types, also return the most + /// recently mapped values. + /// - If there is no mapping for the given values at all, return the given + /// value. + ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const; + + /// Lookup the given value within the map, or return an empty vector if the + /// value is not mapped. If it is mapped, this follows the same behavior + /// as `lookupOrDefault`. + ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const; + //===--------------------------------------------------------------------===// // IR Rewrites / Type Conversion //===--------------------------------------------------------------------===// @@ -1248,6 +1258,22 @@ void ConversionPatternRewriterImpl::applyRewrites() { // State Management //===----------------------------------------------------------------------===// +ValueVector +ConversionPatternRewriterImpl::lookupOrDefault(Value from, + TypeRange desiredTypes) const { + return mapping.lookupOrDefault(from, desiredTypes); +} + +ValueVector +ConversionPatternRewriterImpl::lookupOrNull(Value from, + TypeRange desiredTypes) const { + ValueVector result = lookupOrDefault(from, desiredTypes); + if (result == ValueVector{from} || + (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes)) + return {}; + return result; +} + RewriterState ConversionPatternRewriterImpl::getCurrentState() { return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size()); } @@ -1295,7 +1321,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( // The current pattern does not have a type converter. I.e., it does not // distinguish between legal and illegal types. For each operand, simply // pass through the most recently mapped values. - remapped.push_back(mapping.lookupOrDefault(operand)); + remapped.push_back(lookupOrDefault(operand)); continue; } @@ -1314,7 +1340,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( continue; } - ValueVector repl = mapping.lookupOrDefault(operand, legalTypes); + ValueVector repl = lookupOrDefault(operand, legalTypes); if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) { // Mapped values have the correct type or there is an existing // materialization. Or the operand is not mapped at all and has the @@ -1324,7 +1350,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( } // Create a materialization for the most recently mapped values. - repl = mapping.lookupOrDefault(operand); + repl = lookupOrDefault(operand); ValueRange castValues = buildUnresolvedMaterialization( MaterializationKind::Target, computeInsertPoint(repl), operandLoc, /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes, @@ -1519,7 +1545,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( // Try to find a replacement value with the same type in the conversion value // mapping. This includes cached materializations. We try to reuse those // instead of generating duplicate IR. - ValueVector repl = mapping.lookupOrNull(value, value.getType()); + ValueVector repl = lookupOrNull(value, value.getType()); if (!repl.empty()) return repl.front(); @@ -1535,7 +1561,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( // No replacement value was found. Get the latest replacement value // (regardless of the type) and build a source materialization to the // original type. - repl = mapping.lookupOrNull(value); + repl = lookupOrNull(value); if (repl.empty()) { // No replacement value is registered in the mapping. This means that the // value is dropped and no longer needed. (If the value were still needed, @@ -1568,23 +1594,30 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( void ConversionPatternRewriterImpl::notifyOperationInserted( Operation *op, OpBuilder::InsertPoint previous) { + // If no previous insertion point is provided, the op used to be detached. + bool wasDetached = !previous.isSet(); LLVM_DEBUG({ - logger.startLine() << "** Insert : '" << op->getName() << "'(" << op - << ")\n"; + logger.startLine() << "** Insert : '" << op->getName() << "' (" << op + << ")"; + if (wasDetached) + logger.getOStream() << " (was detached)"; + logger.getOStream() << "\n"; }); assert(!wasOpReplaced(op->getParentOp()) && "attempting to insert into a block within a replaced/erased op"); - if (!previous.isSet()) { - // This is a newly created op. + if (wasDetached) { + // If the op was detached, it is most likely a newly created op. + // TODO: If the same op is inserted multiple times from a detached state, + // the rollback mechanism may erase the same op multiple times. This is a + // bug in the rollback-based dialect conversion driver. appendRewrite<CreateOperationRewrite>(op); patternNewOps.insert(op); return; } - Operation *prevOp = previous.getPoint() == previous.getBlock()->end() - ? nullptr - : &*previous.getPoint(); - appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp); + + // The op was moved from one place to another. + appendRewrite<MoveOperationRewrite>(op, previous); } void ConversionPatternRewriterImpl::replaceOp( @@ -1649,29 +1682,40 @@ void ConversionPatternRewriterImpl::eraseBlock(Block *block) { void ConversionPatternRewriterImpl::notifyBlockInserted( Block *block, Region *previous, Region::iterator previousIt) { - assert(!wasOpReplaced(block->getParentOp()) && - "attempting to insert into a region within a replaced/erased op"); + // If no previous insertion point is provided, the block used to be detached. + bool wasDetached = !previous; + Operation *newParentOp = block->getParentOp(); LLVM_DEBUG( { - Operation *parent = block->getParentOp(); + Operation *parent = newParentOp; if (parent) { logger.startLine() << "** Insert Block into : '" << parent->getName() - << "'(" << parent << ")\n"; + << "' (" << parent << ")"; } else { logger.startLine() - << "** Insert Block into detached Region (nullptr parent op)'\n"; + << "** Insert Block into detached Region (nullptr parent op)"; } + if (wasDetached) + logger.getOStream() << " (was detached)"; + logger.getOStream() << "\n"; }); + assert(!wasOpReplaced(newParentOp) && + "attempting to insert into a region within a replaced/erased op"); + (void)newParentOp; patternInsertedBlocks.insert(block); - if (!previous) { - // This is a newly created block. + if (wasDetached) { + // If the block was detached, it is most likely a newly created block. + // TODO: If the same block is inserted multiple times from a detached state, + // the rollback mechanism may erase the same block multiple times. This is a + // bug in the rollback-based dialect conversion driver. appendRewrite<CreateBlockRewrite>(block); return; } - Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt; - appendRewrite<MoveBlockRewrite>(block, previous, prevBlock); + + // The block was moved from one place to another. + appendRewrite<MoveBlockRewrite>(block, previous, previousIt); } void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source, @@ -1716,6 +1760,12 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); + + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + SmallVector<SmallVector<Value>> newVals = llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> { return v ? SmallVector<Value>{v} : SmallVector<Value>(); @@ -1731,6 +1781,12 @@ void ConversionPatternRewriter::replaceOpWithMultiple( impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); + + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + impl->replaceOp(op, std::move(newValues)); } @@ -1739,6 +1795,12 @@ void ConversionPatternRewriter::eraseOp(Operation *op) { impl->logger.startLine() << "** Erase : '" << op->getName() << "'(" << op << ")\n"; }); + + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {}); impl->replaceOp(op, std::move(nullRepls)); } @@ -1845,6 +1907,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, moveOpBefore(&source->front(), dest, before); } + // If the current insertion point is within the source block, adjust the + // insertion point to the destination block. + if (getInsertionBlock() == source) + setInsertionPoint(dest, getInsertionPoint()); + // Erase the source block. eraseBlock(source); } @@ -1976,6 +2043,7 @@ private: /// Legalize the resultant IR after successfully applying the given pattern. LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter, + const RewriterState &curState, const SetVector<Operation *> &newOps, const SetVector<Operation *> &modifiedOps, const SetVector<Block *> &insertedBlocks); @@ -2092,8 +2160,9 @@ OperationLegalizer::legalize(Operation *op, // If the operation has no regions, just print it here. if (!isIgnored && op->getNumRegions() == 0) { - op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm()); - logger.getOStream() << "\n\n"; + logger.startLine() << OpWithFlags(op, + OpPrintingFlags().printGenericOpForm()) + << "\n"; } }); @@ -2172,23 +2241,39 @@ OperationLegalizer::legalizeWithFold(Operation *op, rewriterImpl.logger.startLine() << "* Fold {\n"; rewriterImpl.logger.indent(); }); - (void)rewriterImpl; + + // Clear pattern state, so that the next pattern application starts with a + // clean slate. (The op/block sets are populated by listener notifications.) + auto cleanup = llvm::make_scope_exit([&]() { + rewriterImpl.patternNewOps.clear(); + rewriterImpl.patternModifiedOps.clear(); + rewriterImpl.patternInsertedBlocks.clear(); + }); + + // Upon failure, undo all changes made by the folder. + RewriterState curState = rewriterImpl.getCurrentState(); // Try to fold the operation. StringRef opName = op->getName().getStringRef(); SmallVector<Value, 2> replacementValues; SmallVector<Operation *, 2> newOps; rewriter.setInsertionPoint(op); + rewriter.startOpModification(op); if (failed(rewriter.tryFold(op, replacementValues, &newOps))) { LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold")); + rewriter.cancelOpModification(op); return failure(); } + rewriter.finalizeOpModification(op); // An empty list of replacement values indicates that the fold was in-place. // As the operation changed, a new legalization needs to be attempted. if (replacementValues.empty()) return legalize(op, rewriter); + // Insert a replacement for 'op' with the folded replacement values. + rewriter.replaceOp(op, replacementValues); + // Recursively legalize any new constant operations. for (Operation *newOp : newOps) { if (failed(legalize(newOp, rewriter))) { @@ -2201,16 +2286,12 @@ OperationLegalizer::legalizeWithFold(Operation *op, "op '" + opName + "' folder rollback of IR modifications requested"); } - // Legalization failed: erase all materialized constants. - for (Operation *op : newOps) - rewriter.eraseOp(op); + rewriterImpl.resetState( + curState, std::string(op->getName().getStringRef()) + " folder"); return failure(); } } - // Insert a replacement for 'op' with the folded replacement values. - rewriter.replaceOp(op, replacementValues); - LLVM_DEBUG(logSuccess(rewriterImpl.logger, "")); return success(); } @@ -2220,6 +2301,32 @@ OperationLegalizer::legalizeWithPattern(Operation *op, ConversionPatternRewriter &rewriter) { auto &rewriterImpl = rewriter.getImpl(); +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + Operation *checkOp; + std::optional<OperationFingerPrint> topLevelFingerPrint; + if (!rewriterImpl.config.allowPatternRollback) { + // The op may be getting erased, so we have to check the parent op. + // (In rare cases, a pattern may even erase the parent op, which will cause + // a crash here. Expensive checks are "best effort".) Skip the check if the + // op does not have a parent op. + if ((checkOp = op->getParentOp())) { + if (!op->getContext()->isMultithreadingEnabled()) { + topLevelFingerPrint = OperationFingerPrint(checkOp); + } else { + // Another thread may be modifying a sibling operation. Therefore, the + // fingerprinting mechanism of the parent op works only in + // single-threaded mode. + LLVM_DEBUG({ + rewriterImpl.logger.startLine() + << "WARNING: Multi-threadeding is enabled. Some dialect " + "conversion expensive checks are skipped in multithreading " + "mode!\n"; + }); + } + } + } +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + // Functor that returns if the given pattern may be applied. auto canApply = [&](const Pattern &pattern) { bool canApply = canApplyPattern(op, pattern, rewriter); @@ -2232,6 +2339,17 @@ OperationLegalizer::legalizeWithPattern(Operation *op, RewriterState curState = rewriterImpl.getCurrentState(); auto onFailure = [&](const Pattern &pattern) { assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + if (!rewriterImpl.config.allowPatternRollback) { + // Returning "failure" after modifying IR is not allowed. + if (checkOp) { + OperationFingerPrint fingerPrintAfterPattern(checkOp); + if (fingerPrintAfterPattern != *topLevelFingerPrint) + llvm::report_fatal_error("pattern '" + pattern.getDebugName() + + "' returned failure but IR did change"); + } + } +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS rewriterImpl.patternNewOps.clear(); rewriterImpl.patternModifiedOps.clear(); rewriterImpl.patternInsertedBlocks.clear(); @@ -2260,7 +2378,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op, moveAndReset(rewriterImpl.patternModifiedOps); SetVector<Block *> insertedBlocks = moveAndReset(rewriterImpl.patternInsertedBlocks); - auto result = legalizePatternResult(op, pattern, rewriter, newOps, + auto result = legalizePatternResult(op, pattern, rewriter, curState, newOps, modifiedOps, insertedBlocks); appliedPatterns.erase(&pattern); if (failed(result)) { @@ -2303,7 +2421,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern, LogicalResult OperationLegalizer::legalizePatternResult( Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter, - const SetVector<Operation *> &newOps, + const RewriterState &curState, const SetVector<Operation *> &newOps, const SetVector<Operation *> &modifiedOps, const SetVector<Block *> &insertedBlocks) { auto &impl = rewriter.getImpl(); @@ -2319,7 +2437,8 @@ LogicalResult OperationLegalizer::legalizePatternResult( return hasRewrite<ModifyOperationRewrite>(newRewrites, op); }; if (!replacedRoot() && !updatedRootInPlace()) - llvm::report_fatal_error("expected pattern to replace the root operation"); + llvm::report_fatal_error( + "expected pattern to replace the root operation or modify it in place"); #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // Legalize each of the actions registered during application. diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp index b639e87f..26c965c 100644 --- a/mlir/lib/Transforms/Utils/Inliner.cpp +++ b/mlir/lib/Transforms/Utils/Inliner.cpp @@ -21,7 +21,7 @@ #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #define DEBUG_TYPE "inlining" @@ -348,13 +348,11 @@ static void collectCallOps(iterator_range<Region::iterator> blocks, // InlinerInterfaceImpl //===----------------------------------------------------------------------===// -#ifndef NDEBUG static std::string getNodeName(CallOpInterface op) { if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee())) return debugString(op); return "_unnamed_callee_"; } -#endif /// Return true if the specified `inlineHistoryID` indicates an inline history /// that already includes `node`. @@ -614,10 +612,10 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{}); LLVM_DEBUG({ - llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n"; + LDBG() << "* Inliner: Initial calls in SCC are: {"; for (unsigned i = 0, e = calls.size(); i < e; ++i) - llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n"; - llvm::dbgs() << "}\n"; + LDBG() << " " << i << ". " << calls[i].call << ","; + LDBG() << "}"; }); // Try to inline each of the call operations. Don't cache the end iterator @@ -635,9 +633,9 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, CallOpInterface call = it.call; LLVM_DEBUG({ if (doInline) - llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n"; + LDBG() << "* Inlining call: " << i << ". " << call; else - llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n"; + LDBG() << "* Not inlining call: " << i << ". " << call; }); if (!doInline) continue; @@ -654,7 +652,7 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, cast<CallableOpInterface>(targetRegion->getParentOp()), targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace); if (failed(inlineResult)) { - LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n"); + LDBG() << "** Failed to inline"; continue; } inlinedAnyCalls = true; @@ -667,19 +665,16 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, auto historyToString = [](InlineHistoryT h) { return h.has_value() ? std::to_string(*h) : "root"; }; - (void)historyToString; - LLVM_DEBUG(llvm::dbgs() - << "* new inlineHistory entry: " << newInlineHistoryID << ". [" - << getNodeName(call) << ", " << historyToString(inlineHistoryID) - << "]\n"); + LDBG() << "* new inlineHistory entry: " << newInlineHistoryID << ". [" + << getNodeName(call) << ", " << historyToString(inlineHistoryID) + << "]"; for (unsigned k = prevSize; k != calls.size(); ++k) { callHistory.push_back(newInlineHistoryID); - LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[i].call - << "}\n with historyID = " << newInlineHistoryID - << ", added due to inlining of\n call {" << call - << "}\n with historyID = " - << historyToString(inlineHistoryID) << "\n"); + LDBG() << "* new call " << k << " {" << calls[k].call + << "}\n with historyID = " << newInlineHistoryID + << ", added due to inlining of\n call {" << call + << "}\n with historyID = " << historyToString(inlineHistoryID); } // If the inlining was successful, Merge the new uses into the source node. diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt index ac8b44f5..89568e7 100644 --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -68,6 +68,7 @@ endif() llvm_canonicalize_cmake_booleans( LLVM_BUILD_EXAMPLES LLVM_HAS_NVPTX_TARGET + LLVM_INCLUDE_SPIRV_TOOLS_TESTS MLIR_ENABLE_BINDINGS_PYTHON MLIR_ENABLE_CUDA_RUNNER MLIR_ENABLE_ROCM_CONVERSIONS @@ -217,6 +218,11 @@ if(MLIR_ENABLE_BINDINGS_PYTHON) ) endif() +if (LLVM_INCLUDE_SPIRV_TOOLS_TESTS) + list(APPEND MLIR_TEST_DEPENDS spirv-as) + list(APPEND MLIR_TEST_DEPENDS spirv-val) +endif() + # This target can be used to just build the dependencies # for the check-mlir target without executing the tests. # This is useful for bots when splitting the build step diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir index bae7c59..ae59f28 100644 --- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir +++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir @@ -2,8 +2,26 @@ // CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32 // CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64 +// CHECK-DAG: @__ocml_carg_f32(complex<f32>) -> f32 +// CHECK-DAG: @__ocml_carg_f64(complex<f64>) -> f64 +// CHECK-DAG: @__ocml_ccos_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_ccos_f64(complex<f64>) -> complex<f64> // CHECK-DAG: @__ocml_cexp_f32(complex<f32>) -> complex<f32> // CHECK-DAG: @__ocml_cexp_f64(complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_clog_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_clog_f64(complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_conj_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_conj_f64(complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_cpow_f32(complex<f32>, complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_cpow_f64(complex<f64>, complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_csin_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_csin_f64(complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_csqrt_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_csqrt_f64(complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_ctan_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_ctan_f64(complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_ctanh_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_ctanh_f64(complex<f64>) -> complex<f64> //CHECK-LABEL: @abs_caller func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) { @@ -15,6 +33,26 @@ func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) { return %rf, %rd : f32, f64 } +//CHECK-LABEL: @angle_caller +func.func @angle_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) { + // CHECK: %[[AF:.*]] = call @__ocml_carg_f32(%{{.*}}) + %af = complex.angle %f : complex<f32> + // CHECK: %[[AD:.*]] = call @__ocml_carg_f64(%{{.*}}) + %ad = complex.angle %d : complex<f64> + // CHECK: return %[[AF]], %[[AD]] + return %af, %ad : f32, f64 +} + +//CHECK-LABEL: @cos_caller +func.func @cos_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[CF:.*]] = call @__ocml_ccos_f32(%{{.*}}) + %cf = complex.cos %f : complex<f32> + // CHECK: %[[CD:.*]] = call @__ocml_ccos_f64(%{{.*}}) + %cd = complex.cos %d : complex<f64> + // CHECK: return %[[CF]], %[[CD]] + return %cf, %cd : complex<f32>, complex<f64> +} + //CHECK-LABEL: @exp_caller func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { // CHECK: %[[EF:.*]] = call @__ocml_cexp_f32(%{{.*}}) @@ -24,3 +62,73 @@ func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, comp // CHECK: return %[[EF]], %[[ED]] return %ef, %ed : complex<f32>, complex<f64> } + +//CHECK-LABEL: @log_caller +func.func @log_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[LF:.*]] = call @__ocml_clog_f32(%{{.*}}) + %lf = complex.log %f : complex<f32> + // CHECK: %[[LD:.*]] = call @__ocml_clog_f64(%{{.*}}) + %ld = complex.log %d : complex<f64> + // CHECK: return %[[LF]], %[[LD]] + return %lf, %ld : complex<f32>, complex<f64> +} + +//CHECK-LABEL: @conj_caller +func.func @conj_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[CF:.*]] = call @__ocml_conj_f32(%{{.*}}) + %cf2 = complex.conj %f : complex<f32> + // CHECK: %[[CD:.*]] = call @__ocml_conj_f64(%{{.*}}) + %cd2 = complex.conj %d : complex<f64> + // CHECK: return %[[CF]], %[[CD]] + return %cf2, %cd2 : complex<f32>, complex<f64> +} + +//CHECK-LABEL: @pow_caller +func.func @pow_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[PF:.*]] = call @__ocml_cpow_f32(%{{.*}}, %{{.*}}) + %pf = complex.pow %f, %f : complex<f32> + // CHECK: %[[PD:.*]] = call @__ocml_cpow_f64(%{{.*}}, %{{.*}}) + %pd = complex.pow %d, %d : complex<f64> + // CHECK: return %[[PF]], %[[PD]] + return %pf, %pd : complex<f32>, complex<f64> +} + +//CHECK-LABEL: @sin_caller +func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}}) + %sf2 = complex.sin %f : complex<f32> + // CHECK: %[[SD:.*]] = call @__ocml_csin_f64(%{{.*}}) + %sd2 = complex.sin %d : complex<f64> + // CHECK: return %[[SF]], %[[SD]] + return %sf2, %sd2 : complex<f32>, complex<f64> +} + +//CHECK-LABEL: @sqrt_caller +func.func @sqrt_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[SF:.*]] = call @__ocml_csqrt_f32(%{{.*}}) + %sf = complex.sqrt %f : complex<f32> + // CHECK: %[[SD:.*]] = call @__ocml_csqrt_f64(%{{.*}}) + %sd = complex.sqrt %d : complex<f64> + // CHECK: return %[[SF]], %[[SD]] + return %sf, %sd : complex<f32>, complex<f64> +} + +//CHECK-LABEL: @tan_caller +func.func @tan_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[TF:.*]] = call @__ocml_ctan_f32(%{{.*}}) + %tf2 = complex.tan %f : complex<f32> + // CHECK: %[[TD:.*]] = call @__ocml_ctan_f64(%{{.*}}) + %td2 = complex.tan %d : complex<f64> + // CHECK: return %[[TF]], %[[TD]] + return %tf2, %td2 : complex<f32>, complex<f64> +} + +//CHECK-LABEL: @tanh_caller +func.func @tanh_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[TF:.*]] = call @__ocml_ctanh_f32(%{{.*}}) + %tf = complex.tanh %f : complex<f32> + // CHECK: %[[TD:.*]] = call @__ocml_ctanh_f64(%{{.*}}) + %td = complex.tanh %d : complex<f64> + // CHECK: return %[[TF]], %[[TD]] + return %tf, %td : complex<f32>, complex<f64> +} diff --git a/mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir b/mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir index 00bbd1c..96ad107 100644 --- a/mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir +++ b/mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir @@ -85,11 +85,10 @@ module attributes { // CHECK: spirv.Load "StorageBuffer" %val = memref.load %arg0[%idx0] : memref<2xi32> // CHECK: spirv.CompositeInsert - %vec = vector.insertelement %val, %vec0[%idx0 : index] : vector<2xi32> + %vec = vector.insert %val, %vec0[%idx0] : i32 into vector<2xi32> // CHECK: spirv.VectorShuffle %shuffle = vector.shuffle %vec, %vec[3, 2, 1, 0] : vector<2xi32>, vector<2xi32> - // CHECK: spirv.CompositeExtract - %res = vector.extractelement %shuffle[%idx0 : index] : vector<4xi32> + %res = vector.extract %shuffle[%idx0] : i32 from vector<4xi32> // CHECK: spirv.AccessChain // CHECK: spirv.Store "StorageBuffer" memref.store %res, %arg1[%idx0]: memref<4xi32> @@ -102,9 +101,9 @@ module attributes { // CHECK-SAME: %{{.*}}: memref<2xi32>, %{{.*}}: memref<4xi32> // CHECK: arith.constant // CHECK: memref.load - // CHECK: vector.insertelement + // CHECK: vector.insert // CHECK: vector.shuffle - // CHECK: vector.extractelement + // CHECK: vector.extract // CHECK: memref.store // CHECK: gpu.return } diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir index fb14feb..eb9feaa 100644 --- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir +++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir @@ -51,108 +51,6 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf3 // ----- -// CHECK-LABEL: @extract_element -// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 -// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32 -func.func @extract_element(%arg0 : vector<4xf32>, %id : i32) -> f32 { - %0 = vector.extractelement %arg0[%id : i32] : vector<4xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_cst -// CHECK-SAME: %[[V:.*]]: vector<4xf32> -// CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32> -func.func @extract_element_cst(%arg0 : vector<4xf32>) -> f32 { - %idx = arith.constant 1 : i32 - %0 = vector.extractelement %arg0[%idx : i32] : vector<4xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_index -func.func @extract_element_index(%arg0 : vector<4xf32>, %id : index) -> f32 { - // CHECK: spirv.VectorExtractDynamic - %0 = vector.extractelement %arg0[%id : index] : vector<4xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_size1_vector -// CHECK-SAME:(%[[S:.+]]: f32, -func.func @extract_element_size1_vector(%arg0 : f32, %i: index) -> f32 { - %bcast = vector.broadcast %arg0 : f32 to vector<1xf32> - %0 = vector.extractelement %bcast[%i : index] : vector<1xf32> - // CHECK: spirv.ReturnValue %[[S]] - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_0d_vector -// CHECK-SAME: (%[[S:.+]]: f32) -func.func @extract_element_0d_vector(%arg0 : f32) -> f32 { - %bcast = vector.broadcast %arg0 : f32 to vector<f32> - %0 = vector.extractelement %bcast[] : vector<f32> - // CHECK: spirv.ReturnValue %[[S]] - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @insert_element -// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 -// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32 -func.func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) -> vector<4xf32> { - %0 = vector.insertelement %val, %arg0[%id : i32] : vector<4xf32> - return %0: vector<4xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_cst -// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32> -// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32> -func.func @insert_element_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> { - %idx = arith.constant 2 : i32 - %0 = vector.insertelement %val, %arg0[%idx : i32] : vector<4xf32> - return %0: vector<4xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_index -func.func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> { - // CHECK: spirv.VectorInsertDynamic - %0 = vector.insertelement %val, %arg0[%id : index] : vector<4xf32> - return %0: vector<4xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_size1_vector -// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32 -func.func @insert_element_size1_vector(%scalar: f32, %vector : vector<1xf32>, %i: index) -> vector<1xf32> { - %0 = vector.insertelement %scalar, %vector[%i : index] : vector<1xf32> - // CHECK: spirv.ReturnValue %[[S]] - return %0: vector<1xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_0d_vector -// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32 -func.func @insert_element_0d_vector(%scalar: f32, %vector : vector<f32>) -> vector<f32> { - %0 = vector.insertelement %scalar, %vector[] : vector<f32> - // CHECK: spirv.ReturnValue %[[S]] - return %0: vector<f32> -} - -// ----- - // CHECK-LABEL: @insert_size1_vector // CHECK-SAME: %[[SUB:.*]]: f32, %[[FULL:.*]]: vector<3xf32> // CHECK: %[[RET:.*]] = spirv.CompositeInsert %[[SUB]], %[[FULL]][2 : i32] : f32 into vector<3xf32> diff --git a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir index b96dd37..c71d220 100644 --- a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir @@ -10,16 +10,14 @@ gpu.module @kernels { // CHECK-LABEL: spirv.func @rotate() gpu.func @rotate() kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} { - %offset = arith.constant 4 : i32 - %width = arith.constant 16 : i32 %val = arith.constant 42.0 : f32 + // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32 // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32 - // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32 // CHECK: %{{.+}} = spirv.Constant true - %result, %valid = gpu.rotate %val, %offset, %width : f32 + %result, %valid = gpu.rotate %val, 4, 16 : f32 gpu.return } } @@ -38,18 +36,16 @@ gpu.module @kernels { // CHECK-LABEL: spirv.func @rotate_width_less_than_subgroup_size() gpu.func @rotate_width_less_than_subgroup_size() kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} { - %offset = arith.constant 4 : i32 - %width = arith.constant 8 : i32 %val = arith.constant 42.0 : f32 + // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32 // CHECK: %[[WIDTH:.+]] = spirv.Constant 8 : i32 - // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32 // CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__ // CHECK: %[[INVOCATION_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]] // CHECK: %{{.+}} = spirv.ULessThan %[[INVOCATION_ID]], %[[WIDTH]] - %result, %valid = gpu.rotate %val, %offset, %width : f32 + %result, %valid = gpu.rotate %val, 4, 8 : f32 gpu.return } } @@ -67,34 +63,10 @@ module attributes { gpu.module @kernels { gpu.func @rotate_with_bigger_than_subgroup_size() kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} { - %offset = arith.constant 4 : i32 - %width = arith.constant 32 : i32 %val = arith.constant 42.0 : f32 // expected-error @+1 {{failed to legalize operation 'gpu.rotate'}} - %result, %valid = gpu.rotate %val, %offset, %width : f32 - gpu.return - } -} - -} - -// ----- - -module attributes { - gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>, - #spirv.resource_limits<subgroup_size = 16>> -} { - -gpu.module @kernels { - gpu.func @rotate_non_const_width(%width: i32) kernel - attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} { - %offset = arith.constant 4 : i32 - %val = arith.constant 42.0 : f32 - - // expected-error @+1 {{'gpu.rotate' op width is not a constant value}} - %result, %valid = gpu.rotate %val, %offset, %width : f32 + %result, %valid = gpu.rotate %val, 4, 32 : f32 gpu.return } } diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc.mlir new file mode 100644 index 0000000..e391a89 --- /dev/null +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc.mlir @@ -0,0 +1,72 @@ +// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP +// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP + +func.func @alloc() { + %alloc = memref.alloc() : memref<999xi32> + return +} + +// CPP: module { +// CPP-NEXT: emitc.include <"cstdlib"> +// CPP-LABEL: alloc() +// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index +// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t +// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> +// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> +// CPP-NEXT: return + +// NOCPP: module { +// NOCPP-NEXT: emitc.include <"stdlib.h"> +// NOCPP-LABEL: alloc() +// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index +// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t +// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> +// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> +// NOCPP-NEXT: return + +func.func @alloc_aligned() { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<999xf32> + return +} + +// CPP-LABEL: alloc_aligned +// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t +// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index +// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t +// CPP-NEXT: %[[ALIGNMENT:.*]] = "emitc.constant"() <{value = 64 : index}> : () -> !emitc.size_t +// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "aligned_alloc"(%[[ALIGNMENT]], %[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t, !emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> +// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<f32> +// CPP-NEXT: return + +// NOCPP-LABEL: alloc_aligned +// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t +// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index +// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t +// NOCPP-NEXT: %[[ALIGNMENT:.*]] = "emitc.constant"() <{value = 64 : index}> : () -> !emitc.size_t +// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "aligned_alloc"(%[[ALIGNMENT]], %[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t, !emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> +// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<f32> +// NOCPP-NEXT: return + +func.func @allocating_multi() { + %alloc_5 = memref.alloc() : memref<7x999xi32> + return +} + +// CPP-LABEL: allocating_multi +// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 6993 : index}> : () -> index +// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t +// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void"> +// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> +// CPP-NEXT: return + +// NOCPP-LABEL: allocating_multi +// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 6993 : index}> : () -> index +// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t +// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> +// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> +// NOCPP-NEXT: return + diff --git a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir index fa7a91c..b6f2383 100644 --- a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir +++ b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir @@ -36,7 +36,7 @@ func.func @while_test(%arg0 : tensor<i32>) -> (tensor<i32>) { func.func @if_test(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> (tensor<f32>) { // CHECK: [[EX:%.+]] = tensor.extract [[ARG2]] // CHECK: [[IF:%.+]] = scf.if [[EX]] -> (tensor<f32>) { - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { // CHECK: scf.yield [[ARG0]] tosa.yield %arg0 : tensor<f32> diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir index 8c135d5..31e17fb 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir @@ -274,73 +274,6 @@ func.func @shuffle_2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf3 // ----- //===----------------------------------------------------------------------===// -// vector.extractelement -//===----------------------------------------------------------------------===// - -func.func @extractelement_from_vec_0d_f32(%arg0: vector<f32>) -> f32 { - %1 = vector.extractelement %arg0[] : vector<f32> - return %1 : f32 -} -// CHECK-LABEL: @extractelement_from_vec_0d_f32 -// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64 -// CHECK: llvm.extractelement %{{.*}}[%[[C0]] : {{.*}}] : vector<1xf32> - -// ----- - -func.func @extractelement_from_vec_1d_f32_idx_as_i32(%arg0: vector<16xf32>) -> f32 { - %0 = arith.constant 15 : i32 - %1 = vector.extractelement %arg0[%0 : i32]: vector<16xf32> - return %1 : f32 -} -// CHECK-LABEL: @extractelement_from_vec_1d_f32_idx_as_i32( -// CHECK-SAME: %[[A:.*]]: vector<16xf32>) -// CHECK: %[[C:.*]] = arith.constant 15 : i32 -// CHECK: %[[X:.*]] = llvm.extractelement %[[A]][%[[C]] : i32] : vector<16xf32> -// CHECK: return %[[X]] : f32 - -// ----- - -func.func @extractelement_from_vec_1d_f32_idx_as_i32_scalable(%arg0: vector<[16]xf32>) -> f32 { - %0 = arith.constant 15 : i32 - %1 = vector.extractelement %arg0[%0 : i32]: vector<[16]xf32> - return %1 : f32 -} -// CHECK-LABEL: @extractelement_from_vec_1d_f32_idx_as_i32_scalable( -// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>) -// CHECK: %[[C:.*]] = arith.constant 15 : i32 -// CHECK: %[[X:.*]] = llvm.extractelement %[[A]][%[[C]] : i32] : vector<[16]xf32> -// CHECK: return %[[X]] : f32 - -// ----- -func.func @extractelement_from_vec_1d_f32_idx_as_index(%arg0: vector<16xf32>) -> f32 { - %0 = arith.constant 15 : index - %1 = vector.extractelement %arg0[%0 : index]: vector<16xf32> - return %1 : f32 -} -// CHECK-LABEL: @extractelement_from_vec_1d_f32_idx_as_index( -// CHECK-SAME: %[[A:.*]]: vector<16xf32>) -// CHECK: %[[C:.*]] = arith.constant 15 : index -// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[C]] : index to i64 -// CHECK: %[[X:.*]] = llvm.extractelement %[[A]][%[[I]] : i64] : vector<16xf32> -// CHECK: return %[[X]] : f32 - -// ----- - -func.func @extractelement_from_vec_1d_f32_idx_as_index_scalable(%arg0: vector<[16]xf32>) -> f32 { - %0 = arith.constant 15 : index - %1 = vector.extractelement %arg0[%0 : index]: vector<[16]xf32> - return %1 : f32 -} -// CHECK-LABEL: @extractelement_from_vec_1d_f32_idx_as_index_scalable( -// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>) -// CHECK: %[[C:.*]] = arith.constant 15 : index -// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[C]] : index to i64 -// CHECK: %[[X:.*]] = llvm.extractelement %[[A]][%[[I]] : i64] : vector<[16]xf32> -// CHECK: return %[[X]] : f32 - -// ----- - -//===----------------------------------------------------------------------===// // vector.extract //===----------------------------------------------------------------------===// @@ -592,81 +525,6 @@ func.func @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg : // ----- //===----------------------------------------------------------------------===// -// vector.insertelement -//===----------------------------------------------------------------------===// - -func.func @insertelement_into_vec_0d_f32(%arg0: f32, %arg1: vector<f32>) -> vector<f32> { - %1 = vector.insertelement %arg0, %arg1[] : vector<f32> - return %1 : vector<f32> -} -// CHECK-LABEL: @insertelement_into_vec_0d_f32 -// CHECK-SAME: %[[A:.*]]: f32, -// CHECK: %[[B:.*]] = builtin.unrealized_conversion_cast %{{.*}} : -// CHECK: vector<f32> to vector<1xf32> -// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64 -// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[C0]] : {{.*}}] : vector<1xf32> - -// ----- - -func.func @insertelement_into_vec_1d_f32_idx_as_i32(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> { - %0 = arith.constant 3 : i32 - %1 = vector.insertelement %arg0, %arg1[%0 : i32] : vector<4xf32> - return %1 : vector<4xf32> -} -// CHECK-LABEL: @insertelement_into_vec_1d_f32_idx_as_i32( -// CHECK-SAME: %[[A:.*]]: f32, -// CHECK-SAME: %[[B:.*]]: vector<4xf32>) -// CHECK: %[[C:.*]] = arith.constant 3 : i32 -// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[C]] : i32] : vector<4xf32> -// CHECK: return %[[X]] : vector<4xf32> - -// ----- - -func.func @insertelement_into_vec_1d_f32_idx_as_i32_scalable(%arg0: f32, %arg1: vector<[4]xf32>) -> vector<[4]xf32> { - %0 = arith.constant 3 : i32 - %1 = vector.insertelement %arg0, %arg1[%0 : i32] : vector<[4]xf32> - return %1 : vector<[4]xf32> -} -// CHECK-LABEL: @insertelement_into_vec_1d_f32_idx_as_i32_scalable( -// CHECK-SAME: %[[A:.*]]: f32, -// CHECK-SAME: %[[B:.*]]: vector<[4]xf32>) -// CHECK: %[[C:.*]] = arith.constant 3 : i32 -// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[C]] : i32] : vector<[4]xf32> -// CHECK: return %[[X]] : vector<[4]xf32> - -// ----- - -func.func @insertelement_into_vec_1d_f32_scalable_idx_as_index(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> { - %0 = arith.constant 3 : index - %1 = vector.insertelement %arg0, %arg1[%0 : index] : vector<4xf32> - return %1 : vector<4xf32> -} -// CHECK-LABEL: @insertelement_into_vec_1d_f32_scalable_idx_as_index( -// CHECK-SAME: %[[A:.*]]: f32, -// CHECK-SAME: %[[B:.*]]: vector<4xf32>) -// CHECK: %[[C:.*]] = arith.constant 3 : index -// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[C]] : index to i64 -// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[I]] : i64] : vector<4xf32> -// CHECK: return %[[X]] : vector<4xf32> - -// ----- - -func.func @insertelement_into_vec_1d_f32_scalable_idx_as_index_scalable(%arg0: f32, %arg1: vector<[4]xf32>) -> vector<[4]xf32> { - %0 = arith.constant 3 : index - %1 = vector.insertelement %arg0, %arg1[%0 : index] : vector<[4]xf32> - return %1 : vector<[4]xf32> -} -// CHECK-LABEL: @insertelement_into_vec_1d_f32_scalable_idx_as_index_scalable( -// CHECK-SAME: %[[A:.*]]: f32, -// CHECK-SAME: %[[B:.*]]: vector<[4]xf32>) -// CHECK: %[[C:.*]] = arith.constant 3 : index -// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[C]] : index to i64 -// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[I]] : i64] : vector<[4]xf32> -// CHECK: return %[[X]] : vector<[4]xf32> - -// ----- - -//===----------------------------------------------------------------------===// // vector.insert //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index f43a41a..8918f91 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -400,67 +400,6 @@ func.func @insert_dynamic_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> // ----- -// CHECK-LABEL: @extract_element -// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 -// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32 -func.func @extract_element(%arg0 : vector<4xf32>, %id : i32) -> f32 { - %0 = vector.extractelement %arg0[%id : i32] : vector<4xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_cst -// CHECK-SAME: %[[V:.*]]: vector<4xf32> -// CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32> -func.func @extract_element_cst(%arg0 : vector<4xf32>) -> f32 { - %idx = arith.constant 1 : i32 - %0 = vector.extractelement %arg0[%idx : i32] : vector<4xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_index -func.func @extract_element_index(%arg0 : vector<4xf32>, %id : index) -> f32 { - // CHECK: spirv.VectorExtractDynamic - %0 = vector.extractelement %arg0[%id : index] : vector<4xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_size5_vector -func.func @extract_element_size5_vector(%arg0 : vector<5xf32>, %id : i32) -> f32 { - // CHECK: vector.extractelement - %0 = vector.extractelement %arg0[%id : i32] : vector<5xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_size1_vector -// CHECK-SAME: (%[[S:.+]]: f32 -func.func @extract_element_size1_vector(%arg0 : f32, %i: index) -> f32 { - %bcast = vector.broadcast %arg0 : f32 to vector<1xf32> - %0 = vector.extractelement %bcast[%i : index] : vector<1xf32> - // CHECK: return %[[S]] - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_0d_vector -// CHECK-SAME: (%[[S:.+]]: f32) -func.func @extract_element_0d_vector(%arg0 : f32) -> f32 { - %bcast = vector.broadcast %arg0 : f32 to vector<f32> - %0 = vector.extractelement %bcast[] : vector<f32> - // CHECK: return %[[S]] - return %0: f32 -} - -// ----- - // CHECK-LABEL: @extract_strided_slice // CHECK-SAME: %[[ARG:.+]]: vector<4xf32> // CHECK: spirv.VectorShuffle [1 : i32, 2 : i32] %[[ARG]], %[[ARG]] : vector<4xf32>, vector<4xf32> -> vector<2xf32> @@ -473,67 +412,6 @@ func.func @extract_strided_slice(%arg0: vector<4xf32>) -> (vector<2xf32>, vector // ----- -// CHECK-LABEL: @insert_element -// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 -// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32 -func.func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) -> vector<4xf32> { - %0 = vector.insertelement %val, %arg0[%id : i32] : vector<4xf32> - return %0: vector<4xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_cst -// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32> -// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32> -func.func @insert_element_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> { - %idx = arith.constant 2 : i32 - %0 = vector.insertelement %val, %arg0[%idx : i32] : vector<4xf32> - return %0: vector<4xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_index -func.func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> { - // CHECK: spirv.VectorInsertDynamic - %0 = vector.insertelement %val, %arg0[%id : index] : vector<4xf32> - return %0: vector<4xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_size5_vector -func.func @insert_element_size5_vector(%val: f32, %arg0 : vector<5xf32>, %id : i32) -> vector<5xf32> { - // CHECK: vector.insertelement - %0 = vector.insertelement %val, %arg0[%id : i32] : vector<5xf32> - return %0 : vector<5xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_size1_vector -// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32 -func.func @insert_element_size1_vector(%scalar: f32, %vector : vector<1xf32>, %i: index) -> vector<1xf32> { - %0 = vector.insertelement %scalar, %vector[%i : index] : vector<1xf32> - // CHECK: %[[V:.+]] = builtin.unrealized_conversion_cast %arg0 : f32 to vector<1xf32> - // CHECK: return %[[V]] - return %0: vector<1xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_0d_vector -// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32 -func.func @insert_element_0d_vector(%scalar: f32, %vector : vector<f32>) -> vector<f32> { - %0 = vector.insertelement %scalar, %vector[] : vector<f32> - // CHECK: %[[V:.+]] = builtin.unrealized_conversion_cast %arg0 : f32 to vector<f32> - // CHECK: return %[[V]] - return %0: vector<f32> -} - -// ----- - // CHECK-LABEL: @insert_strided_slice // CHECK-SAME: %[[PART:.+]]: vector<2xf32>, %[[ALL:.+]]: vector<4xf32> // CHECK: spirv.VectorShuffle [0 : i32, 4 : i32, 5 : i32, 3 : i32] %[[ALL]], %[[PART]] : vector<4xf32>, vector<2xf32> -> vector<4xf32> diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir new file mode 100644 index 0000000..e2ab876 --- /dev/null +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir @@ -0,0 +1,33 @@ +// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='builtin.module(test.symbol_scope_isolated(test-one-shot-module-bufferize))' -split-input-file | FileCheck %s + +"test.symbol_scope_isolated"() ({ + // CHECK-LABEL: func @inner_func( + // CHECK-SAME: %[[arg0:.*]]: memref<?xf32 + func.func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) { + // CHECK-NOT: copy + %f = arith.constant 1.0 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + // CHECK: memref.store %{{.*}}, %[[arg0]] + %0 = tensor.insert %f into %t[%c0] : tensor<?xf32> + // CHECK: %[[load:.*]] = memref.load %[[arg0]] + %1 = tensor.extract %0[%c1] : tensor<?xf32> + // CHECK: return %[[arg0]], %[[load]] : memref<?xf32{{.*}}>, f32 + return %0, %1 : tensor<?xf32>, f32 + } + + // CHECK-LABEL: func @call_func_with_non_tensor_return( + // CHECK-SAME: %[[arg0:.*]]: memref<?xf32 + func.func @call_func_with_non_tensor_return( + %t0: tensor<?xf32> {bufferization.writable = true}) -> (f32, tensor<?xf32>) { + // CHECK-NOT: alloc + // CHECK-NOT: copy + // CHECK: %[[call:.*]]:2 = call @inner_func(%[[arg0]]) + %0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32) + // CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref<?xf32,{{.*}}> + return %1, %0 : f32, tensor<?xf32> + } + "test.finish" () : () -> () +}) : () -> () + + diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir index c67a0c1..029fa78 100644 --- a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir +++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=emitc.name_hint' %s | FileCheck %s +// RUN: mlir-opt --wrap-emitc-func-in-class %s | FileCheck %s module attributes { } { emitc.func @model(%arg0: !emitc.array<1xf32> {emitc.name_hint = "another_feature"}, diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir index 162ff06..35381da 100644 --- a/mlir/test/Dialect/GPU/invalid.mlir +++ b/mlir/test/Dialect/GPU/invalid.mlir @@ -479,20 +479,16 @@ func.func @shuffle_unsupported_type_vec(%arg0 : vector<[4]xf32>, %arg1 : i32, %a // ----- func.func @rotate_mismatching_type(%arg0 : f32) { - %offset = arith.constant 4 : i32 - %width = arith.constant 16 : i32 // expected-error@+1 {{op failed to verify that all of {value, rotateResult} have same type}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (i32, i1) + %rotate, %valid = "gpu.rotate"(%arg0) { offset = 4 : i32, width = 16 : i32 } : (f32) -> (i32, i1) return } // ----- func.func @rotate_unsupported_type(%arg0 : index) { - %offset = arith.constant 4 : i32 - %width = arith.constant 16 : i32 // expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'index'}} - %rotate, %valid = gpu.rotate %arg0, %offset, %width : index + %rotate, %valid = gpu.rotate %arg0, 4, 16 : index return } @@ -502,55 +498,31 @@ func.func @rotate_unsupported_type_vec(%arg0 : vector<[4]xf32>) { %offset = arith.constant 4 : i32 %width = arith.constant 16 : i32 // expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'vector<[4]xf32>'}} - %rotate, %valid = gpu.rotate %arg0, %offset, %width : vector<[4]xf32> + %rotate, %valid = gpu.rotate %arg0, 4, 16 : vector<[4]xf32> return } // ----- func.func @rotate_unsupported_width(%arg0 : f32) { - %offset = arith.constant 4 : i32 - %width = arith.constant 15 : i32 - // expected-error@+1 {{op width must be a power of two}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) + // expected-error@+1 {{'gpu.rotate' op attribute 'width' failed to satisfy constraint: 32-bit signless integer attribute whose value is a power of two > 0}} + %rotate, %valid = "gpu.rotate"(%arg0) { offset = 4 : i32, width = 15 : i32 } : (f32) -> (f32, i1) return } // ----- func.func @rotate_unsupported_offset(%arg0 : f32) { - %offset = arith.constant 16 : i32 - %width = arith.constant 16 : i32 // expected-error@+1 {{op offset must be in the range [0, 16)}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) + %rotate, %valid = "gpu.rotate"(%arg0) { offset = 16 : i32, width = 16 : i32 }: (f32) -> (f32, i1) return } // ----- func.func @rotate_unsupported_offset_minus(%arg0 : f32) { - %offset = arith.constant -1 : i32 - %width = arith.constant 16 : i32 - // expected-error@+1 {{op offset must be in the range [0, 16)}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) - return -} - -// ----- - -func.func @rotate_offset_non_constant(%arg0 : f32, %offset : i32) { - %width = arith.constant 16 : i32 - // expected-error@+1 {{op offset is not a constant value}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) - return -} - -// ----- - -func.func @rotate_width_non_constant(%arg0 : f32, %width : i32) { - %offset = arith.constant 0 : i32 - // expected-error@+1 {{op width is not a constant value}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) + // expected-error@+1 {{'gpu.rotate' op attribute 'offset' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 0}} + %rotate, %valid = "gpu.rotate"(%arg0) { offset = -1 : i32, width = 16 : i32 } : (f32) -> (f32, i1) return } diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir index 2aef80f..ee1fdfa 100644 --- a/mlir/test/Dialect/GPU/ops.mlir +++ b/mlir/test/Dialect/GPU/ops.mlir @@ -140,9 +140,8 @@ module attributes {gpu.container_module} { // CHECK: gpu.shuffle idx %{{.*}}, %{{.*}}, %{{.*}} : f32 %shfl3, %pred3 = gpu.shuffle idx %arg0, %offset, %width : f32 - // CHECK: gpu.rotate %{{.*}}, %{{.*}}, %{{.*}} : f32 - %rotate_width = arith.constant 16 : i32 - %rotate, %pred4 = gpu.rotate %arg0, %offset, %rotate_width : f32 + // CHECK: gpu.rotate %{{.*}}, 3, 16 : f32 + %rotate, %pred4 = gpu.rotate %arg0, 3, 16 : f32 "gpu.barrier"() : () -> () diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 9cbb56e4..5c5f7e8 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -1176,6 +1176,52 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>) // ----- +// CHECK-LABEL: @broadcast_broadcast_fold +// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32> +// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x3xf32> +// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32> +// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2] +// CHECK-NOT: linalg.broadcast +// CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32> +func.func @broadcast_broadcast_fold(%input: tensor<2xf32>, + %init1: tensor<2x3xf32>, + %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { + %broadcast1 = linalg.broadcast + ins(%input: tensor<2xf32>) + outs(%init1: tensor<2x3xf32>) + dimensions = [1] + %broadcast2 = linalg.broadcast + ins(%broadcast1: tensor<2x3xf32>) + outs(%init2: tensor<2x3x4xf32>) + dimensions = [2] + func.return %broadcast2 : tensor<2x3x4xf32> +} + +// ----- + +// CHECK-LABEL: @broadcast_broadcast_fold +// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32> +// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32> +// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32> +// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2] +// CHECK-NOT: linalg.broadcast +// CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32> +func.func @broadcast_broadcast_fold(%input: tensor<2xf32>, + %init1: tensor<2x4xf32>, + %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { + %broadcast1 = linalg.broadcast + ins(%input: tensor<2xf32>) + outs(%init1: tensor<2x4xf32>) + dimensions = [1] + %broadcast2 = linalg.broadcast + ins(%broadcast1: tensor<2x4xf32>) + outs(%init2: tensor<2x3x4xf32>) + dimensions = [1] + func.return %broadcast2 : tensor<2x3x4xf32> +} + +// ----- + func.func @transpose_1d(%input: tensor<16xf32>, %init: tensor<16xf32>) -> tensor<16xf32> { %transpose = linalg.transpose @@ -1387,42 +1433,43 @@ func.func @recursive_effect(%arg : tensor<1xf32>) { // CHECK-LABEL: @recursive_effect // CHECK: linalg.map +// ----- + //===----------------------------------------------------------------------===// // linalg.pack //===----------------------------------------------------------------------===// // CHECK-LABEL: func @fold_pack_constant_splat // CHECK-NOT: linalg.pack -// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32> -func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> { +// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32> +func.func @fold_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> { %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32> %0 = linalg.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] - inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32> - return %0 : tensor<8x16x8x32xf32> + inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<4x8x8x32xf32> + return %0 : tensor<4x8x8x32xf32> } // ----- // CHECK-LABEL: func @fold_padding_value_pack_constant_splat // CHECK-NOT: linalg.pack -// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32> -func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> { +// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32> +func.func @fold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> { %pad = arith.constant 1.000000e-01 : f32 %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32> %0 = linalg.pack %cst padding_value(%pad : f32) outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] - inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32> - return %0 : tensor<8x16x8x32xf32> + inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32> + return %0 : tensor<4x8x8x32xf32> } - // ----- // CHECK-LABEL: func @nofold_padding_value_pack_constant_splat // CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32> // CHECK: linalg.pack -func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> { +func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> { %pad = arith.constant 0.0 : f32 %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32> %0 = linalg.pack %cst @@ -1430,8 +1477,8 @@ func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32] - into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32> - return %0 : tensor<8x16x8x32xf32> + into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32> + return %0 : tensor<4x8x8x32xf32> } // ----- diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir index 6fc8d9f..cc26fa4 100644 --- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir +++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir @@ -1295,24 +1295,6 @@ func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate( // ----- -func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(%arg0: tensor<32x64xf32>) -> tensor<8x4x16x8xf32> { - %empty = tensor.empty() : tensor<8x4x16x8xf32> - %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32> - %pack = linalg.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32> - return %pack : tensor<8x4x16x8xf32> -} -// CHECK-LABEL: func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x8xf32> -// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] -// CHECK-SAME: output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32> -// CHECK: %[[PACK:.+]] = linalg.pack %[[EXPANDED]] -// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [8] into %[[EMPTY]] -// CHECK-SAME: : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32> -// CHECK: return %[[PACK]] : tensor<8x4x16x8xf32> - -// ----- - func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> { %6 = tensor.empty(%dim) : tensor<?x256xf32> %unpack = linalg.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32> diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index da1dfc7..40bf4d1 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -1760,6 +1760,7 @@ func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf } // ----- + func.func @pack_mismatch_inner_tile_size_and_output_shape( %input : tensor<?x?xf32>, %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> { // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}} @@ -1824,27 +1825,47 @@ func.func @unpack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: t // ----- +func.func @pack_with_artificial_padding(%input: tensor<9xf32>, %output: tensor<3x8xf32>) -> tensor<3x8xf32> { + %cst = arith.constant 0.0 : f32 + // expected-error@+1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}} + %0 = linalg.pack %input padding_value(%cst : f32) inner_dims_pos = [0] + inner_tiles = [8] into %output + : tensor<9xf32> -> tensor<3x8xf32> + return %0 : tensor<3x8xf32> +} + +// ----- + // The outer dims in the output tensor are incorrectly/unexpectedly transposed. // This could be fixed by adding `outer_dims_perm = [1, 0]` (the default value assumes no transpose). func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<4x16x32x16xf32>) -> tensor<4x16x32x16xf32> { - // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<16x4x32x16xf32>', got 'tensor<4x16x32x16xf32>'}} + // expected-error@+1 {{expected 'tensor<16x4x32x16xf32>' for the packed domain value, got 'tensor<4x16x32x16xf32>'}} %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [32, 16] into %output : tensor<256x128xf32> -> tensor<4x16x32x16xf32> return %0 : tensor<4x16x32x16xf32> } // ----- -func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> { - // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x8x32x16xf32>'}} - %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32> - return %0 : tensor<8x8x32x16xf32> +func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf32>) -> tensor<8x7x16x32xf32> { + // expected-error@+1 {{expected 'tensor<8x8x16x32xf32>' for the packed domain value, got 'tensor<8x7x16x32xf32>'}} + %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x7x16x32xf32> + return %0 : tensor<8x7x16x32xf32> +} + +// ----- + +func.func @unpack_with_artifical_tiles_that_are_dropped(%input: tensor<3x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> { + // expected-error@+1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}} + %0 = linalg.unpack %input inner_dims_pos = [0] inner_tiles = [8] into %output + : tensor<3x8xf32> -> tensor<9xf32> + return %0 : tensor<9xf32> } // ----- -func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x32x16xf32>) -> tensor<256x128xf32> { - // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x32x16xf32>'}} - %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x32x16xf32> -> tensor<256x128xf32> +func.func @unpack_invalid_source_shape(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> { + // expected-error@+1 {{expected 'tensor<8x32x4x32xf32>' for the packed domain value, got 'tensor<8x8x4x32xf32>'}} + %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32> return %0 : tensor<256x128xf32> } diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir index 81fd7a8..9e7681d 100644 --- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir +++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir @@ -326,23 +326,23 @@ module attributes {transform.with_named_sequence} { // ----- // CHECK-LABEL: func.func @pack_with_pad( -func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x16x16x1xf32>) - -> tensor<265x16x16x1xf32> { +func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x12x16x1xf32>) + -> tensor<265x12x16x1xf32> { // CHECK: tensor.pad {{.*}} low[0, 0] - // CHECK: : tensor<4225x12xf32> to tensor<4240x16xf32> + // CHECK: : tensor<4225x12xf32> to tensor<4240x12xf32> // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3]] - // CHECK-SAME: : tensor<4240x16xf32> into tensor<265x16x16x1xf32> + // CHECK-SAME: : tensor<4240x12xf32> into tensor<265x16x12x1xf32> // CHECK: linalg.transpose - // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>) - // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>) + // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x12x1xf32>) + // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x12x16x1xf32>) // CHECK-SAME: permutation = [0, 2, 1, 3] %cst = arith.constant 0.000000e+00 : f32 %0 = linalg.pack %src padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %dest - : tensor<4225x12xf32> -> tensor<265x16x16x1xf32> - return %0 : tensor<265x16x16x1xf32> + : tensor<4225x12xf32> -> tensor<265x12x16x1xf32> + return %0 : tensor<265x12x16x1xf32> } module attributes {transform.with_named_sequence} { diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir index 78619b6..981f5dc 100644 --- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir @@ -52,22 +52,22 @@ module { // CHECK-LABEL: @generic // CHECK-SAME: %[[T0:.*]]: tensor<7x5xf32>, -// CHECK-SAME: %[[T1:.*]]: tensor<7x11x12xf32>) - func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> { +// CHECK-SAME: %[[T1:.*]]: tensor<7x11x11xf32>) + func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> { // CHECK-DAG: %[[CST:.*]] = arith.constant 0. // CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[2, 0] // CHECK: : tensor<7x5xf32> to tensor<9x5xf32> // CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[2, 4, 2] { - // CHECK: : tensor<7x11x12xf32> to tensor<9x15x14xf32> + // CHECK: : tensor<7x11x11xf32> to tensor<9x15x13xf32> // CHECK-NEXT: linalg.generic - // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<9x15x14xf32> to tensor<7x11x12xf32> - %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) { + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<9x15x13xf32> to tensor<7x11x11xf32> + %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) { ^bb0(%in: f32, %out: f32): linalg.yield %in : f32 - } -> tensor<7x11x12xf32> - return %0 : tensor<7x11x12xf32> + } -> tensor<7x11x11xf32> + return %0 : tensor<7x11x11xf32> } module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { @@ -83,7 +83,7 @@ module { // ----- // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 5)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 4)> // CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)> #map = affine_map<(d0, d1, d2) -> (d0, d1)> @@ -272,3 +272,136 @@ module attributes {transform.with_named_sequence} { } } +// ----- + +// CHECK-LABEL: pad_conv +func.func @pad_conv(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> { + + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12] + // CHECK: : tensor<1x16x16x4xf32> to tensor<1x16x18x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12] + // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0] + // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32> + // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32> + + %0 = linalg.conv_2d_nhwc_fhwc + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>) + outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + return %0 : tensor<1x14x14x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of { + padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16 + 2)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16)> + +// CHECK-LABEL: pad_conv_dynamic +func.func @pad_conv_dynamic(%arg0: tensor<1x16x?x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32> { + + // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[D0_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32> + // CHECK: %[[D0_1:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x16x?x4xf32> + // CHECK: %[[H0:.*]] = affine.apply #[[$MAP0]]()[%[[D0_0]], %[[D0_1]]] + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H0]], 12] + // CHECK: : tensor<1x16x?x4xf32> to tensor<1x16x?x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12] + // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32> + // CHECK: %[[D1_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32> + // CHECK: %[[H1:.*]] = affine.apply #[[$MAP1]]()[%[[D0_0]], %[[D1_0]]] + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H1]], 0] + // CHECK: : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32> + // CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32> + // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, %[[D2_0]], 16] [1, 1, 1, 1] : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32> + + %0 = linalg.conv_2d_nhwc_fhwc + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<1x16x?x4xf32>, tensor<16x3x3x4xf32>) + outs(%arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32> + return %0 : tensor<1x14x?x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of { + padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: pad_conv_strided +func.func @pad_conv_strided(%arg0: tensor<1x42x42x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> { + + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 6, 12] + // CHECK: : tensor<1x42x42x4xf32> to tensor<1x42x48x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12] + // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0] + // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32> + // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32> + + %0 = linalg.conv_2d_nhwc_fhwc + {dilations = dense<1> : tensor<2xi64>, strides = dense<3> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<1x42x42x4xf32>, tensor<16x3x3x4xf32>) + outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + return %0 : tensor<1x14x14x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of { + padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: pad_conv_dilated +func.func @pad_conv_dilated(%arg0: tensor<1x18x18x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> { + + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12] + // CHECK: : tensor<1x18x18x4xf32> to tensor<1x18x20x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12] + // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0] + // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32> + // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32> + + %0 = linalg.conv_2d_nhwc_fhwc + {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<1x18x18x4xf32>, tensor<16x3x3x4xf32>) + outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + return %0 : tensor<1x14x14x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of { + padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir index 26c03ed..f741876 100644 --- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir @@ -69,22 +69,22 @@ module { // CHECK-LABEL: @generic // CHECK-SAME: %[[T0:.*]]: tensor<7x5xf32>, -// CHECK-SAME: %[[T1:.*]]: tensor<7x11x12xf32>) - func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> { +// CHECK-SAME: %[[T1:.*]]: tensor<7x11x11xf32>) + func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> { // CHECK-DAG: %[[CST:.*]] = arith.constant 0. // CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[1, 0] // CHECK: : tensor<7x5xf32> to tensor<8x5xf32> // CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[1, 3, 1] { - // CHECK: : tensor<7x11x12xf32> to tensor<8x14x13xf32> + // CHECK: : tensor<7x11x11xf32> to tensor<8x14x12xf32> // CHECK-NEXT: linalg.generic - // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<8x14x13xf32> to tensor<7x11x12xf32> - %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) { + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<8x14x12xf32> to tensor<7x11x11xf32> + %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) { ^bb0(%in: f32, %out: f32): linalg.yield %in : f32 - } -> tensor<7x11x12xf32> - return %0 : tensor<7x11x12xf32> + } -> tensor<7x11x11xf32> + return %0 : tensor<7x11x11xf32> } module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { @@ -102,7 +102,7 @@ module { // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (-s0 + 8)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 13)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 12)> // CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)> #map = affine_map<(d0, d1, d2) -> (d0, d1)> @@ -127,13 +127,13 @@ module { // CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<?x11x?xf32> // CHECK: %[[H2:.*]] = affine.apply #[[$MAP1]]()[%[[D2_0]]] // CHECK: tensor.pad %{{.*}} low[0, 0, 0] high[%[[H1]], 3, %[[H2]]] { - // CHECK: : tensor<?x11x?xf32> to tensor<8x14x13xf32> + // CHECK: : tensor<?x11x?xf32> to tensor<8x14x12xf32> // // CHECK: %[[D0_2:.*]] = tensor.dim %{{.*}}, %[[C0]] : tensor<?x5xf32> // CHECK: %[[D2_1:.*]] = affine.apply #[[$MAP2]]()[%[[D0_2]]] - // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x13xf32>) { - // CHECK: } -> tensor<8x14x13xf32> - // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x13xf32> to tensor<?x11x?xf32> + // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x12xf32>) { + // CHECK: } -> tensor<8x14x12xf32> + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x12xf32> to tensor<?x11x?xf32> // %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<?x5xf32>) outs(%arg1 : tensor<?x11x?xf32>) { ^bb0(%in: f32, %out: f32): diff --git a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir index c3ee892..d7722ea 100644 --- a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir @@ -230,18 +230,17 @@ func.func @vectorize_nd_tensor_extract_index_from_tensor(%arg0: tensor<3x3xf32>, // CHECK-SAME: %[[ARG4:.*]]: tensor<4x7x3x2xf32> // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[PV:.*]] = ub.poison : i32 -// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<7x2x4x3xindex> +// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<4x3xindex> // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<true> : vector<4x7x3x2xi1> // CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32> // CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32> // CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32> // CHECK: %[[CAST:.*]] = arith.index_cast %[[V0]] : vector<4x3xi32> to vector<4x3xindex> -// CHECK: %[[B1:.*]] = vector.broadcast %[[CAST]] : vector<4x3xindex> to vector<7x2x4x3xindex> // CHECK: %[[CAST_1:.*]] = arith.index_cast %[[V1]] : vector<4x3xi32> to vector<4x3xindex> -// CHECK: %[[B2:.*]] = vector.broadcast %[[CAST_1]] : vector<4x3xindex> to vector<7x2x4x3xindex> -// CHECK: %[[MULI:.*]] = arith.muli %[[B1]], %[[CST]] : vector<7x2x4x3xindex> -// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[MULI]] : vector<7x2x4x3xindex> -// CHECK: %[[T:.*]] = vector.transpose %[[ADDI]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex> +// CHECK: %[[MULI:.*]] = arith.muli %[[CAST]], %[[CST]] : vector<4x3xindex> +// CHECK: %[[ADDI:.*]] = arith.addi %[[CAST_1]], %[[MULI]] : vector<4x3xindex> +// CHECK: %[[B:.*]] = vector.broadcast %[[ADDI]] : vector<4x3xindex> to vector<7x2x4x3xindex> +// CHECK: %[[T:.*]] = vector.transpose %[[B]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex> // CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[T]]], %[[CST_1]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<4x7x3x2xindex>, vector<4x7x3x2xi1>, vector<4x7x3x2xf32> into vector<4x7x3x2xf32> // CHECK: vector.transfer_write %[[GATHER]], %[[ARG4]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x7x3x2xf32>, tensor<4x7x3x2xf32> @@ -270,20 +269,16 @@ func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load(% // CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32> // CHECK-SAME: %[[ARG1:.*]]: index // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[CST:.*]] = arith.constant dense<768> : vector<1x8xindex> -// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<128> : vector<1x8xindex> // CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32> -// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<true> : vector<8x1xi1> -// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex> +// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<true> : vector<8x1xi1> +// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<[0, 98304, 196608, 294912, 393216, 491520, 589824, 688128]> : vector<8xindex> // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32> -// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex> // CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index -// CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex> -// CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex> -// CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex> +// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_1]] : vector<8xindex> to vector<1x8xindex> +// CHECK: %[[T:.*]] = vector.transpose %[[B1]], [1, 0] : vector<1x8xindex> to vector<8x1xindex> // CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<8x1xindex> // CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[T]] : vector<8x1xindex> -// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32> +// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_0]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32> // CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32> // ----- @@ -309,15 +304,13 @@ func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>) // CHECK-LABEL: func.func @index_from_output_column_vector_gather_load( // CHECK-SAME: %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> { -// CHECK: %[[C128:.*]] = arith.constant dense<128> : vector<1x8xindex> +// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 128, 256, 384, 512, 640, 768, 896]> : vector<8xindex> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32> // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1> -// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex> // CHECK: %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32> // CHECK: %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex> -// CHECK: %[[MUL:.*]] = arith.muli %[[B]], %[[C128]] : vector<1x8xindex> -// CHECK: %[[TR:.*]] = vector.transpose %[[MUL]], [1, 0] : vector<1x8xindex> to vector<8x1xindex> +// CHECK: %[[TR:.*]] = vector.transpose %[[B]], [1, 0] : vector<1x8xindex> to vector<8x1xindex> // CHECK: %[[GATHER:.*]] = vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]]] {{\[}}%[[TR]]], %[[MASK]], %[[PASS_THRU]] : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32> // CHECK: %[[RES:.*]] = vector.transfer_write %[[GATHER]], %[[OUT]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32> // CHECK: return %[[RES]] : tensor<8x1xf32> @@ -420,12 +413,12 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16 // CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1> // CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32> // CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<16> : vector<1x4xindex> +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<16> : vector<4xindex> // CHECK: %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex> // CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex> -// CHECK: %[[VAL_10:.*]] = vector.broadcast %[[VAL_9]] : vector<4xindex> to vector<1x4xindex> -// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_7]] : vector<1x4xindex> -// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_7]] : vector<1x4xindex> +// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_9]], %[[VAL_7]] : vector<4xindex> +// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_7]] : vector<4xindex> +// CHECK: %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : vector<4xindex> to vector<1x4xindex> // CHECK: %[[VAL_13:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_12]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> // CHECK: %[[VAL_14:.*]] = vector.transfer_write %[[VAL_13]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32> // CHECK: return %[[VAL_14]] : tensor<1x4xf32> @@ -450,14 +443,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_gather(%arg0: tensor<80x16xf32 // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_maxsi_gather( // CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> { -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex> -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<1264> : vector<1x4xindex> +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[1264, 1265, 1266, 1267]> : vector<4xindex> // CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1> // CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32> // CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index // CHECK: %[[VAL_7:.*]] = vector.broadcast %[[VAL_2]] : vector<4xindex> to vector<1x4xindex> -// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_3]] : vector<1x4xindex> -// CHECK: %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_8]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> +// CHECK: %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_7]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> // CHECK: %[[VAL_10:.*]] = vector.transfer_write %[[VAL_9]], %[[VAL_1]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32> // CHECK: return %[[VAL_10]] : tensor<1x4xf32> // CHECK: } @@ -519,13 +510,13 @@ func.func @vectorize_reverse_like_tensor_extract(%arg0: tensor<1x2x3xf32>, %arg1 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]] // CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]] // CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]] -// CHECK-DAG: %[[CST:.+]] = arith.constant dense<3> : vector<1x1x3xindex> +// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1> // CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32> // CHECK-DAG: %[[INIT_IDX:.+]] = arith.constant dense<[2, 1, 0]> : vector<3xindex> -// CHECK: %[[T0:.+]] = vector.broadcast %[[ARG2]] : index to vector<1x1x3xindex> -// CHECK: %[[T1:.+]] = arith.muli %[[T0]], %[[CST]] : vector<1x1x3xindex> +// CHECK: %[[T0:.+]] = arith.muli %[[ARG2]], %[[C3]] : index +// CHECK: %[[T1:.+]] = vector.broadcast %[[T0]] : index to vector<1x1x3xindex> // CHECK: %[[T2:.+]] = vector.broadcast %[[INIT_IDX]] // CHECK: %[[T3:.+]] = arith.addi %[[T2]], %[[T1]] // CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[T3]]], %[[MASK]], %[[PASSTHRU]] diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir index 98e8f50..d41d861 100644 --- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir @@ -941,20 +941,17 @@ module attributes {transform.with_named_sequence} { // CHECK-LABEL: func @test_vectorize_dynamic_shapes_unpack // CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>, +// CHECK-SAME: %[[ARG_1:.*]]: tensor<?x?x16x2xf32> func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?xf32> { // CHECK: %[[C0:.*]] = arith.constant 0 -// CHECK: %[[DIM:.*]] = tensor.dim %arg0, %[[C0]] : tensor<?x?xf32> -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[DIM0:.*]] = tensor.dim %arg0, %[[C1]] : tensor<?x?xf32> -// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 // CHECK: %[[C01:.*]] = arith.constant 0 // CHECK: %[[C02:.*]] = arith.constant 0 -// CHECK: %[[DIM4:.*]] = tensor.dim %arg1, %[[C02]] : tensor<?x?x16x2xf32> -// CHECK: %[[CNST14:.*]] = arith.constant 1 -// CHECK: %[[DIM6:.*]] = tensor.dim %arg1, %[[CNST14]] : tensor<?x?x16x2xf32> +// CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG_1]], %[[C02]] : tensor<?x?x16x2xf32> +// CHECK: %[[C1:.*]] = arith.constant 1 +// CHECK: %[[DIM6:.*]] = tensor.dim %[[ARG_1]], %[[C1]] : tensor<?x?x16x2xf32> // CHECK: %[[CNST16:.*]] = arith.constant 16 : index // CHECK: %[[CNST2:.*]] = arith.constant 2 : index -// CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1> +// CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM_0]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1> // CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32> // CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32> // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x2x1x16xf32> to vector<4x16xf32> diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 4c50ed3..8c846cd 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -1406,7 +1406,7 @@ func.func @omp_atomic_update(%x : memref<i32>, %expr : i32, %xBool : memref<i1>, // CHECK-NEXT: (%[[XVAL:.*]]: i1): // CHECK-NEXT: %[[NEWVAL:.*]] = llvm.icmp "eq" %[[XVAL]], %[[EXPRBOOL]] : i1 // CHECK-NEXT: omp.yield(%[[NEWVAL]] : i1) - // } + // CHECK-NEXT: } omp.atomic.update %xBool : memref<i1> { ^bb0(%xval: i1): %newval = llvm.icmp "eq" %xval, %exprBool : i1 @@ -1562,6 +1562,14 @@ func.func @omp_atomic_update(%x : memref<i32>, %expr : i32, %xBool : memref<i1>, omp.yield(%newval : i32) } + // CHECK: omp.atomic.update %[[X]] : memref<i32> { + // CHECK-NEXT: (%[[XVAL:.*]]: i32): + // CHECK-NEXT: omp.yield(%{{.+}} : i32) + // CHECK-NEXT: } {atomic_control = #omp.atomic_control<ignore_denormal_mode = true, fine_grained_memory = true, remote_memory = true>} + omp.atomic.update %x : memref<i32> { + ^bb0(%xval:i32): + omp.yield(%const:i32) + } {atomic_control = #omp.atomic_control<ignore_denormal_mode = true, fine_grained_memory = true, remote_memory = true>} return } diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir index bd51a07..f3a3218 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir @@ -66,3 +66,27 @@ spirv.module Logical GLSL450 attributes {spirv.target_env = #spirv.target_env<#s // CHECK: spirv.EntryPoint "GLCompute" [[FN]], [[VAR0]], [[VAR1]] // CHECK: spirv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1 } // end spirv.module + +// ----- + +module { + spirv.module Logical GLSL450 attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Sampled1D], []>, #spirv.resource_limits<>>} { + // CHECK-DAG: spirv.GlobalVariable @[[IMAGE_GV:.*]] bind(0, 0) : !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant> + // CHECK: spirv.func @read_image + spirv.func @read_image(%arg0: !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) "None" attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} { + // CHECK: %[[IMAGE_ADDR:.*]] = spirv.mlir.addressof @[[IMAGE_GV]] : !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant> + %cst0_i32 = spirv.Constant 0 : i32 + // CHECK: spirv.Load "UniformConstant" %[[IMAGE_ADDR]] + %0 = spirv.Load "UniformConstant" %arg0 : !spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>> + %1 = spirv.Image %0 : !spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>> + %2 = spirv.ImageFetch %1, %cst0_i32 : !spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>, i32 -> vector<4xf32> + %3 = spirv.CompositeExtract %2[0 : i32] : vector<4xf32> + %cst0_i32_0 = spirv.Constant 0 : i32 + %cst0_i32_1 = spirv.Constant 0 : i32 + %cst1_i32 = spirv.Constant 1 : i32 + %4 = spirv.AccessChain %arg1[%cst0_i32_0, %cst0_i32] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer> + spirv.Store "StorageBuffer" %4, %3 : f32 + spirv.Return + } + } +} diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir index 0176fc2..6398161 100644 --- a/mlir/test/Dialect/Tosa/availability.mlir +++ b/mlir/test/Dialect/Tosa/availability.mlir @@ -645,7 +645,7 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { // CHECK: tosa.cond_if profiles: [ ] // CHECK: tosa.cond_if extensions: [ [controlflow] ] - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1 : tensor<f32> } else { diff --git a/mlir/test/Dialect/Tosa/controlflow.mlir b/mlir/test/Dialect/Tosa/controlflow.mlir new file mode 100644 index 0000000..06312c7 --- /dev/null +++ b/mlir/test/Dialect/Tosa/controlflow.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-opt -split-input-file %s | FileCheck %s + +// ----- + +func.func @condif_cond_type_check(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { + // CHECK: tosa.cond_if %[[ARG2:.*]] : tensor<i1> -> tensor<f32> { + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { + %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + // CHECK: } else { + } else { + %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + } + return %0 : tensor<f32> +} + +// ----- + +func.func @condif_block_args_check(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { + // CHECK: tosa.cond_if %[[ARG2:.*]] (%[[ARG3:.*]] = %[[ARG0:.*]], %[[ARG4:.*]] = %[[ARG1:.*]]) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> { + // CHECK-NEXT: ^bb0(%[[ARG3]]: tensor<f32>, %[[ARG4]]: tensor<f32>): + %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> { + ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): + %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + // CHECK: } else { + // CHECK-NEXT: ^bb0(%[[ARG3]]: tensor<f32>, %[[ARG4]]: tensor<f32>): + } else { + ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): + %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + } + return %0 : tensor<f32> +} diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir index eb25011..fad1bec 100644 --- a/mlir/test/Dialect/Tosa/error_if_check.mlir +++ b/mlir/test/Dialect/Tosa/error_if_check.mlir @@ -259,7 +259,7 @@ func.func @test_cond_if_else_not_isolated_from_above(%arg0: tensor<f32>, %arg1: func.func @test_cond_if_simplified_form_not_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { // expected-error@+1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the 'then' region is isolated from above.}} - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + %0 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>) { tosa.yield %arg0 : tensor<f32> } else { tosa.yield %arg1 : tensor<f32> diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 716362e..b90d6f5 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -1125,7 +1125,7 @@ func.func @test_sub_with_unequal_result_ranks(%arg0: tensor<1x21x3xf32>, %arg1: // CHECK-LABEL: test_mul_non_scalar_shift_2d func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> { %shift = "tosa.const"() <{values = dense<0> : tensor<1x1xi8>}> : () -> tensor<1x1xi8> - // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}} + // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant unranked tensor of 8-bit signless integer values or tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}} %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1x1xi8>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } @@ -1134,7 +1134,7 @@ func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tenso // CHECK-LABEL: test_mul_non_scalar_shift_1d func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> { %shift = "tosa.const"() <{values = dense<0> : tensor<2xi8>}> : () -> tensor<2xi8> - // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}} + // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant unranked tensor of 8-bit signless integer values or tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}} %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<2xi8>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir index 5630c33..3154f54 100644 --- a/mlir/test/Dialect/Tosa/invalid_extension.mlir +++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir @@ -337,7 +337,7 @@ func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32 // ----- func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { // expected-error@+1 {{'tosa.cond_if' op illegal: requires [controlflow]}} - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1 : tensor<f32> } else { diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 0dddf26..cbe0056 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -1506,13 +1506,13 @@ func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: // ----- func.func @test_cond_if_max_nested_depth(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>, %arg3: tensor<i1>) -> tensor<f32> { - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { - %1 = tosa.cond_if %arg3 -> (tensor<f32>) { - %2 = tosa.cond_if %arg2 -> (tensor<f32>) { - %3 = tosa.cond_if %arg3 -> (tensor<f32>) { - %4 = tosa.cond_if %arg2 -> (tensor<f32>) { + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { + %1 = tosa.cond_if %arg3 : tensor<i1>-> tensor<f32> { + %2 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { + %3 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> { + %4 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { // expected-error@+1 {{'tosa.cond_if' op failed level check: 6 >= MAX_NESTING}} - %5 = tosa.cond_if %arg3 -> (tensor<f32>) { + %5 = tosa.cond_if %arg3 : tensor<i1> -> tensor<f32> { %res = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %res : tensor<f32> } else { diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index ef51197e..30361a8 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -839,7 +839,7 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { // ----- // CHECK-LABEL: cond_if func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1 : tensor<f32> } else { diff --git a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir index 38ac8d8..e957bdd 100644 --- a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir +++ b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir @@ -54,7 +54,7 @@ func.func @test_no_change(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> { // CHECK-LABEL: test_regions // CHECK: %arg0: tensor<i8>, %arg1: tensor<i8> func.func @test_regions(%arg0: tensor<ui8>, %arg1: tensor<ui8>, %arg2: tensor<i1>) -> tensor<ui8> { - // CHECK: tosa.cond_if %arg2 -> (tensor<i8>) + // CHECK: tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<i8>, tensor<i8>) -> tensor<i8> %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({ ^bb0(%arg3: tensor<ui8>, %arg4: tensor<ui8>): // CHECK: %1 = tosa.add %arg0, %arg1 : (tensor<i8>, tensor<i8>) -> tensor<i8> diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index 9d43f89..7b8fc24 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -357,6 +357,17 @@ func.func @test_accepts_unranked_scalar_tensor(%arg0: tensor<1x2x2xf32>, %arg1: // ----- +// CHECK-LABEL: @test_unranked_scalar_i8_tensor +func.func @test_unranked_scalar_i8_tensor(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>, %arg2: tensor<1xi8>) -> tensor<4xi32> { + // CHECK: %[[SHIFT:.*]] = tosa.cast %arg2 : (tensor<1xi8>) -> tensor<1xi8> + %shift = tosa.cast %arg2 : (tensor<1xi8>) -> tensor<*xi8> + // CHECK: tosa.mul %arg0, %arg1, %[[SHIFT]] : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32> + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<*xi8>) -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// ----- + // CHECK-LABEL: @test_table_static func.func @test_table_static(%arg0 : tensor<4x5xi16>, %arg1 : tensor<513xi16>) -> () { // CHECK:tosa.table %arg0, %arg1 : (tensor<4x5xi16>, tensor<513xi16>) -> tensor<4x5xi16> @@ -1166,8 +1177,8 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens %b = tosa.log %arg1 : (tensor<f32>) -> tensor<f32> // CHECK: tosa.cond_if - // CHECK: -> (tensor<f32>) - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + // CHECK: -> tensor<f32> + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { tosa.yield %a : tensor<f32> } else { tosa.yield %b : tensor<f32> @@ -1180,8 +1191,8 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens // CHECK-LABEL: @if_test_dynamic func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () { // CHECK: tosa.cond_if - // CHECK: -> (tensor<?xf32>) - %0 = tosa.cond_if %arg2 -> (tensor<?xf32>) { + // CHECK: -> tensor<?xf32> + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<?xf32> { tosa.yield %arg0 : tensor<2xf32> } else { tosa.yield %arg1 : tensor<3xf32> @@ -1194,8 +1205,8 @@ func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : // CHECK-LABEL: @if_test_unranked func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () { // CHECK: tosa.cond_if - // CHECK: -> (tensor<*xf32>) - %0 = tosa.cond_if %arg2 -> (tensor<*xf32>) { + // CHECK: -> tensor<*xf32> + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<*xf32> { tosa.yield %arg0 : tensor<f32> } else { tosa.yield %arg1 : tensor<3xf32> @@ -1208,8 +1219,8 @@ func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : // CHECK-LABEL: @if_test_propagate func.func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () { // CHECK: tosa.cond_if - // CHECK: -> (tensor<f32>) - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + // CHECK: -> tensor<f32> + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1 : tensor<f32> } else { diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir index b305236..2a937b0 100644 --- a/mlir/test/Dialect/Tosa/verifier.mlir +++ b/mlir/test/Dialect/Tosa/verifier.mlir @@ -500,9 +500,39 @@ func.func @test_cond_if_input_list_mismatch_else_block_2(%arg0: tensor<f32>, %ar // ----- +func.func @test_cond_if_input_list_mismatch_else_block_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { + // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (1) and 'input_list' (2)}} + %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> { + ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): + %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + } else { + ^bb0(%arg3: tensor<f32>): + tosa.yield %arg3 : tensor<f32> + } + return %0 : tensor<f32> +} + +// ----- + +func.func @test_cond_if_input_list_mismatch_else_block_simple_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { + // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' arguments (2) and 'input_list' (1)}} + %0 = tosa.cond_if %arg2 (%arg3 = %arg0) : tensor<i1> (tensor<f32>) -> tensor<f32> { + ^bb0(%arg3: tensor<f32>): + tosa.yield %arg3 : tensor<f32> + } else { + ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): + %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + } + return %0 : tensor<f32> +} + +// ----- + func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (2) and 'output_list' (1)}} - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> %2 = tosa.add %1, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1, %2 : tensor<f32>, tensor<f32> @@ -517,7 +547,7 @@ func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor<f32>, %arg func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (1) and 'output_list' (2)}} - %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) { + %0, %2 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>, tensor<f32>) { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1 : tensor<f32> } else { @@ -531,7 +561,7 @@ func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor<f32>, %a func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (2) and 'output_list' (1)}} - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1 : tensor<f32> } else { @@ -546,7 +576,7 @@ func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor<f32>, %arg func.func @test_cond_if_output_list_mismatch_else_block_2(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (1) and 'output_list' (2)}} - %0, %2 = tosa.cond_if %arg2 -> (tensor<f32>, tensor<f32>) { + %0, %2 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<f32>, tensor<f32>) { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> %2 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1, %2 : tensor<f32>, tensor<f32> @@ -574,6 +604,53 @@ func.func @test_cond_if_cond_input_not_size_one(%arg0: tensor<f32>, %arg1: tenso // ----- +// CHECK-LABEL: cond_if_cond_type +func.func @test_cond_if_cond_type(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { + // expected-error@+2 {{expected ':'}} + // expected-error@+1 {{custom op 'tosa.cond_if' expected type for condition operand}} + %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + tosa.yield %arg0 : tensor<f32> + } else { + tosa.yield %arg1 : tensor<f32> + } + return %0 : tensor<f32> +} + +// ----- + +func.func @test_cond_if_input_list_type_mismatch_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { + // expected-error@+1 {{custom op 'tosa.cond_if' expected as many input types as operands (expected 2 got 0)}} + %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> () -> tensor<f32> { + ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): + %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + } else { + ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): + %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + } + return %0 : tensor<f32> +} + +// ----- + +func.func @test_cond_if_incorrect_type_simple(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> { + // expected-error@+2 {{expected non-function type}} + // expected-error@+1 {{custom op 'tosa.cond_if' expected list of types for block arguments followed by arrow type and list of return types}} + %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (%arg3) -> tensor<f32> { + ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): + %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + } else { + ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): + %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32> + tosa.yield %1 : tensor<f32> + } + return %0 : tensor<f32> +} + +// ----- + func.func @test_while_loop_input_list_mismatch_body_block_in(%arg0: tensor<10xi32>, %arg1: tensor<i32>) { %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32> // expected-error@+1 {{'tosa.while_loop' op require same number of values in 'body_graph' arguments (3) and 'input_list' (2)}} diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 1461c30..9cfebd5 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -2562,118 +2562,6 @@ func.func @insert_2d_splat_constant() // ----- -// CHECK-LABEL: func @insert_element_fold -// CHECK: %[[V:.+]] = arith.constant dense<[0, 1, 7, 3]> : vector<4xi32> -// CHECK: return %[[V]] -func.func @insert_element_fold() -> vector<4xi32> { - %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32> - %s = arith.constant 7 : i32 - %i = arith.constant 2 : i32 - %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32> - return %1 : vector<4xi32> -} - -// ----- - -// CHECK-LABEL: func @insert_element_invalid_fold -func.func @insert_element_invalid_fold() -> vector<1xf32> { - // Out-of-bound index here. - %c26 = arith.constant 26 : index - %cst_2 = arith.constant 1.60215309E+9 : f32 - %cst_20 = arith.constant dense<1.60215309E+9> : vector<1xf32> -// CHECK: vector.insertelement - %46 = vector.insertelement %cst_2, %cst_20[%c26 : index] : vector<1xf32> - return %46 : vector<1xf32> -} - - -// ----- - -// Do not crash on poison -// CHECK-LABEL: func @insert_poison_fold1 -// CHECK: vector.insertelement -func.func @insert_poison_fold1() -> vector<4xi32> { - %v = ub.poison : vector<4xi32> - %s = arith.constant 7 : i32 - %i = arith.constant 2 : i32 - %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32> - return %1 : vector<4xi32> -} - -// ----- - -// Do not crash on poison -// CHECK-LABEL: func @insert_poison_fold2 -// CHECK: vector.insertelement -func.func @insert_poison_fold2() -> vector<4xi32> { - %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32> - %s = ub.poison : i32 - %i = arith.constant 2 : i32 - %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32> - return %1 : vector<4xi32> -} - -// ----- - -// Do not crash on poison -// CHECK-LABEL: func @insert_poison_fold3 -// CHECK: vector.insertelement -func.func @insert_poison_fold3() -> vector<4xi32> { - %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32> - %s = arith.constant 7 : i32 - %i = ub.poison : i32 - %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32> - return %1 : vector<4xi32> -} - -// ----- - -// CHECK-LABEL: func @extract_element_fold -// CHECK: %[[C:.+]] = arith.constant 5 : i32 -// CHECK: return %[[C]] -func.func @extract_element_fold() -> i32 { - %v = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32> - %i = arith.constant 2 : i32 - %1 = vector.extractelement %v[%i : i32] : vector<4xi32> - return %1 : i32 -} - -// CHECK-LABEL: func @extract_element_splat_fold -// CHECK-SAME: (%[[ARG:.+]]: i32) -// CHECK: return %[[ARG]] -func.func @extract_element_splat_fold(%a : i32) -> i32 { - %v = vector.splat %a : vector<4xi32> - %i = arith.constant 2 : i32 - %1 = vector.extractelement %v[%i : i32] : vector<4xi32> - return %1 : i32 -} - -// ----- - -// Do not crash on poison -// CHECK-LABEL: func @extract_element_poison_fold1 -// CHECK: vector.extractelement -func.func @extract_element_poison_fold1() -> i32 { - %v = ub.poison : vector<4xi32> - %i = arith.constant 2 : i32 - %1 = vector.extractelement %v[%i : i32] : vector<4xi32> - return %1 : i32 -} - -// ----- - -// Do not crash on poison -// CHECK-LABEL: func @extract_element_poison_fold2 -// CHECK: vector.extractelement -func.func @extract_element_poison_fold2() -> i32 { - %v = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32> - %i = ub.poison : i32 - %1 = vector.extractelement %v[%i : i32] : vector<4xi32> - return %1 : i32 -} - -// ----- - // CHECK-LABEL: func @reduce_one_element_vector_extract // CHECK-SAME: (%[[V:.+]]: vector<1xf32>) // CHECK: %[[S:.+]] = vector.extract %[[V]][0] : f32 from vector<1xf32> @@ -2933,18 +2821,6 @@ func.func @extract_strided_slice_of_constant_mask() -> vector<5x7xi1>{ // ----- -// CHECK-LABEL: func.func @fold_extractelement_of_broadcast( -// CHECK-SAME: %[[f:.*]]: f32 -// CHECK: return %[[f]] -func.func @fold_extractelement_of_broadcast(%f: f32) -> f32 { - %0 = vector.broadcast %f : f32 to vector<15xf32> - %c5 = arith.constant 5 : index - %1 = vector.extractelement %0 [%c5 : index] : vector<15xf32> - return %1 : f32 -} - -// ----- - // CHECK-LABEL: func.func @fold_0d_vector_reduction func.func @fold_0d_vector_reduction(%arg0: vector<f32>) -> f32 { // CHECK-NEXT: %[[RES:.*]] = vector.extract %arg{{.*}}[] : f32 from vector<f32> diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir index 0263193..2563b48 100644 --- a/mlir/test/Dialect/Vector/int-range-interface.mlir +++ b/mlir/test/Dialect/Vector/int-range-interface.mlir @@ -60,16 +60,6 @@ func.func @vector_extract() -> index { func.return %2 : index } -// CHECK-LABEL: func @vector_extractelement -// CHECK: test.reflect_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} -func.func @vector_extractelement() -> index { - %c0 = arith.constant 0 : index - %0 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : vector<4xindex> - %1 = vector.extractelement %0[%c0 : index] : vector<4xindex> - %2 = test.reflect_bounds %1 : index - func.return %2 : index -} - // CHECK-LABEL: func @vector_add // CHECK: test.reflect_bounds {smax = 12 : index, smin = 10 : index, umax = 12 : index, umin = 10 : index} func.func @vector_add() -> vector<4xindex> { @@ -90,17 +80,6 @@ func.func @vector_insert() -> vector<4xindex> { func.return %3 : vector<4xindex> } -// CHECK-LABEL: func @vector_insertelement -// CHECK: test.reflect_bounds {smax = 8 : index, smin = 5 : index, umax = 8 : index, umin = 5 : index} -func.func @vector_insertelement() -> vector<4xindex> { - %c0 = arith.constant 0 : index - %0 = test.with_bounds { umin = 5 : index, umax = 7 : index, smin = 5 : index, smax = 7 : index } : vector<4xindex> - %1 = test.with_bounds { umin = 6 : index, umax = 8 : index, smin = 6 : index, smax = 8 : index } : index - %2 = vector.insertelement %1, %0[%c0 : index] : vector<4xindex> - %3 = test.reflect_bounds %2 : vector<4xindex> - func.return %3 : vector<4xindex> -} - // CHECK-LABEL: func @test_loaded_vector_extract // No bounds // CHECK: test.reflect_bounds {smax = 2147483647 : si32, smin = -2147483648 : si32, umax = 4294967295 : ui32, umin = 0 : ui32} %{{.*}} : i32 diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index ca837d3..c21de56 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -119,30 +119,6 @@ func.func @shuffle_empty_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) { // ----- -func.func @extract_element(%arg0: vector<f32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{expected position to be empty with 0-D vector}} - %1 = vector.extractelement %arg0[%c : i32] : vector<f32> -} - -// ----- - -func.func @extract_element(%arg0: vector<4xf32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{expected position for 1-D vector}} - %1 = vector.extractelement %arg0[] : vector<4xf32> -} - -// ----- - -func.func @extract_element(%arg0: vector<4x4xf32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{unexpected >1 vector rank}} - %1 = vector.extractelement %arg0[%c : i32] : vector<4x4xf32> -} - -// ----- - func.func @extract_vector_type(%arg0: index) { // expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'index'}} %1 = vector.extract %arg0[] : index from index @@ -192,38 +168,6 @@ func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) { // ----- -func.func @insert_element(%arg0: f32, %arg1: vector<f32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{expected position to be empty with 0-D vector}} - %0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<f32> -} - -// ----- - -func.func @insert_element(%arg0: f32, %arg1: vector<4xf32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{expected position for 1-D vector}} - %0 = vector.insertelement %arg0, %arg1[] : vector<4xf32> -} - -// ----- - -func.func @insert_element(%arg0: f32, %arg1: vector<4x4xf32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{unexpected >1 vector rank}} - %0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<4x4xf32> -} - -// ----- - -func.func @insert_element_wrong_type(%arg0: i32, %arg1: vector<4xf32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{'vector.insertelement' op failed to verify that source operand type matches element type of result}} - %0 = "vector.insertelement" (%arg0, %arg1, %c) : (i32, vector<4xf32>, i32) -> (vector<4xf32>) -} - -// ----- - func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute of rank no greater than dest vector rank}} %1 = vector.insert %a, %b[3, 3, 3, 3, 3, 3] : f32 into vector<4x8x16xf32> diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 6a56116..625ffc1 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -199,22 +199,6 @@ func.func @shuffle_poison_mask(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<4 return %1 : vector<4xf32> } -// CHECK-LABEL: @extract_element_0d -func.func @extract_element_0d(%a: vector<f32>) -> f32 { - // CHECK-NEXT: vector.extractelement %{{.*}}[] : vector<f32> - %1 = vector.extractelement %a[] : vector<f32> - return %1 : f32 -} - -// CHECK-LABEL: @extract_element -func.func @extract_element(%a: vector<16xf32>) -> f32 { - // CHECK: %[[C15:.*]] = arith.constant 15 : i32 - %c = arith.constant 15 : i32 - // CHECK-NEXT: vector.extractelement %{{.*}}[%[[C15]] : i32] : vector<16xf32> - %1 = vector.extractelement %a[%c : i32] : vector<16xf32> - return %1 : f32 -} - // CHECK-LABEL: @extract_const_idx func.func @extract_const_idx(%arg0: vector<4x8x16xf32>) -> (vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32) { @@ -256,22 +240,6 @@ func.func @extract_poison_idx(%a: vector<4x5xf32>) -> f32 { return %0 : f32 } -// CHECK-LABEL: @insert_element_0d -func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> { - // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32> - %1 = vector.insertelement %a, %b[] : vector<f32> - return %1 : vector<f32> -} - -// CHECK-LABEL: @insert_element -func.func @insert_element(%a: f32, %b: vector<16xf32>) -> vector<16xf32> { - // CHECK: %[[C15:.*]] = arith.constant 15 : i32 - %c = arith.constant 15 : i32 - // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[%[[C15]] : i32] : vector<16xf32> - %1 = vector.insertelement %a, %b[%c : i32] : vector<16xf32> - return %1 : vector<16xf32> -} - // CHECK-LABEL: @insert_const_idx func.func @insert_const_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> { diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir index b826cdc..f8638ab 100644 --- a/mlir/test/Dialect/Vector/vector-sink.mlir +++ b/mlir/test/Dialect/Vector/vector-sink.mlir @@ -257,6 +257,70 @@ func.func @broadcast_scalar_extsi_scalable(%a : i8) -> vector<2x[4]xi32> { return %r : vector<2x[4]xi32> } +// ----- + +// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const( +// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> { +// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index +// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[NEW_CST]] : index +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex> +// CHECK: return %[[BCAST]] : vector<1x4xindex> + +func.func @broadcast_scalar_and_splat_const(%arg0: index) -> vector<1x4xindex> { + %0 = vector.broadcast %arg0 : index to vector<1x4xindex> + %cst = arith.constant dense<2> : vector<1x4xindex> + %2 = arith.addi %0, %cst : vector<1x4xindex> + return %2 : vector<1x4xindex> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_const_first( +// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> { +// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index +// CHECK: %[[SUB:.*]] = arith.subi %[[NEW_CST]], %[[ARG_0]] : index +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[SUB]] : index to vector<1x4xindex> +// CHECK: return %[[BCAST]] : vector<1x4xindex> + +func.func @broadcast_scalar_and_splat_const_const_first(%arg0: index) -> vector<1x4xindex> { + %0 = vector.broadcast %arg0 : index to vector<1x4xindex> + %cst = arith.constant dense<2> : vector<1x4xindex> + %2 = arith.subi %cst, %0 : vector<1x4xindex> + return %2 : vector<1x4xindex> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_vector_and_splat_const( +// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> { +// CHECK: %[[NEW_CST:.*]] = arith.constant dense<2.000000e+00> : vector<4xf32> +// CHECK: %[[ADD:.*]] = arith.mulf %[[ARG_0]], %[[NEW_CST]] : vector<4xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : vector<4xf32> to vector<3x4xf32> +// CHECK: return %[[BCAST]] : vector<3x4xf32> + +func.func @broadcast_vector_and_splat_const(%arg0: vector<4xf32>) -> vector<3x4xf32> { + %0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32> + %cst = arith.constant dense<2.000000e+00> : vector<3x4xf32> + %2 = arith.mulf %0, %cst : vector<3x4xf32> + return %2 : vector<3x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @negative_broadcast_with_non_splat_const( +// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> { +// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : index to vector<1x4xindex> +// CHECK-DAG: %[[CST:.*]] = arith.constant dense<{{\[}}[0, 1, 2, 3]]> : vector<1x4xindex> +// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<1x4xindex> +// CHECK: return %[[ADD]] : vector<1x4xindex> + +func.func @negative_broadcast_with_non_splat_const(%arg0: index) -> vector<1x4xindex> { + %0 = vector.broadcast %arg0 : index to vector<1x4xindex> + %cst = arith.constant dense<[[0, 1, 2, 3]]> : vector<1x4xindex> + %2 = arith.addi %0, %cst : vector<1x4xindex> + return %2 : vector<1x4xindex> +} + //===----------------------------------------------------------------------===// // [Pattern: ReorderElementwiseOpsOnTranspose] //===----------------------------------------------------------------------===// diff --git a/mlir/test/Examples/transform/Ch3/ops.mlir b/mlir/test/Examples/transform/Ch3/ops.mlir index b2d47cc..707a09f 100644 --- a/mlir/test/Examples/transform/Ch3/ops.mlir +++ b/mlir/test/Examples/transform/Ch3/ops.mlir @@ -30,9 +30,30 @@ module attributes {transform.with_named_sequence} { // ----- func.func private @orig() +func.func private @updated() // CHECK-LABEL: func @test2 func.func @test2() { + // CHECK: call @updated + call @orig() : () -> () + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + %call = transform.structured.match ops{["func.call"]} in %arg0 : (!transform.any_op) -> !transform.my.call_op_interface + // CHECK: transform.my.change_call_target %{{.*}}, "updated" : !transform.my.call_op_interface + transform.my.change_call_target %call, "updated" : !transform.my.call_op_interface + transform.yield + } +} + +// ----- + +func.func private @orig() + +// CHECK-LABEL: func @test3 +func.func @test3() { // CHECK: "my.mm4" call @orig() : () -> () return diff --git a/mlir/test/Examples/transform/Ch3/sequence.mlir b/mlir/test/Examples/transform/Ch3/sequence.mlir index 4d28518..877b006 100644 --- a/mlir/test/Examples/transform/Ch3/sequence.mlir +++ b/mlir/test/Examples/transform/Ch3/sequence.mlir @@ -101,11 +101,12 @@ module attributes {transform.with_named_sequence} { %_1, %outline_target = transform.structured.fuse_into_containing_op %matmul_fused_2 into %loop_third : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) %func, %call = transform.loop.outline %outline_target {func_name = "outlined"} - : (!transform.any_op) -> (!transform.any_op, !transform.op<"func.call">) - - // Rewrite the call target. - transform.my.change_call_target %call, "microkernel" : !transform.op<"func.call"> - + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + // Cast to our new type. + %casted = transform.cast %call : !transform.any_op to !transform.my.call_op_interface + // Using our new operation. + transform.my.change_call_target %casted, "microkernel" : !transform.my.call_op_interface + transform.yield } } diff --git a/mlir/test/IR/diagnostic-nosplit.mlir b/mlir/test/IR/diagnostic-nosplit.mlir new file mode 100644 index 0000000..ecfb9c6 --- /dev/null +++ b/mlir/test/IR/diagnostic-nosplit.mlir @@ -0,0 +1,13 @@ +// RUN: not mlir-opt %s -o - --split-input-file 2>&1 | FileCheck %s +// This test verifies that diagnostic handler doesn't emit splits. + + +// ----- + + + +func.func @constant_out_of_range() { + // CHECK: mlir:11:8: error: 'arith.constant' + %x = "arith.constant"() {value = 100} : () -> i1 + return +} diff --git a/mlir/test/IR/test-pattern-logging-listener.mlir b/mlir/test/IR/test-pattern-logging-listener.mlir index c521110..d3d42e3 100644 --- a/mlir/test/IR/test-pattern-logging-listener.mlir +++ b/mlir/test/IR/test-pattern-logging-listener.mlir @@ -8,15 +8,15 @@ // {anonymous_namespace} vs `anonymous_namespace` (and maybe others?) on the // various platforms. -// CHECK: [pattern-logging-listener] +// CHECK: [pattern-logging-listener:1] // CHECK-SAME: ::ReplaceWithNewOp | notifyOperationInserted | test.new_op -// CHECK: [pattern-logging-listener] +// CHECK: [pattern-logging-listener:1] // CHECK-SAME: ::ReplaceWithNewOp | notifyOperationReplaced (with values) | test.replace_with_new_op -// CHECK: [pattern-logging-listener] +// CHECK: [pattern-logging-listener:1] // CHECK-SAME: ::ReplaceWithNewOp | notifyOperationModified | arith.addi -// CHECK: [pattern-logging-listener] +// CHECK: [pattern-logging-listener:1] // CHECK-SAME: ::ReplaceWithNewOp | notifyOperationModified | arith.addi -// CHECK: [pattern-logging-listener] +// CHECK: [pattern-logging-listener:1] // CHECK-SAME: ::ReplaceWithNewOp | notifyOperationErased | test.replace_with_new_op func.func @replace_with_new_op() -> i32 { %a = "test.replace_with_new_op"() : () -> (i32) diff --git a/mlir/test/IR/top-level.mlir b/mlir/test/IR/top-level.mlir index b571d94..e0adb4d82 100644 --- a/mlir/test/IR/top-level.mlir +++ b/mlir/test/IR/top-level.mlir @@ -6,10 +6,10 @@ func.func private @foo() // ----- -// expected-error@-3 {{source must contain a single top-level operation, found: 2}} +// expected-error@-9 {{source must contain a single top-level operation, found: 2}} func.func private @bar() func.func private @baz() // ----- -// expected-error@-3 {{source must contain a single top-level operation, found: 0}} +// expected-error@-15 {{source must contain a single top-level operation, found: 0}} diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir index 05e6782..a7bb039 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir @@ -81,21 +81,21 @@ func.func private @matmul(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tenso func.func private @mmt4d(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> { %zero = arith.constant 0 : i32 - %A_pack_empty = tensor.empty() : tensor<2x16x8x1xi32> + %A_pack_empty = tensor.empty() : tensor<1x16x8x1xi32> %B_pack_empty = tensor.empty() : tensor<2x16x8x1xi32> - %C_pack_empty = tensor.empty() : tensor<2x2x8x8xi32> + %C_pack_empty = tensor.empty() : tensor<1x2x8x8xi32> // Pack matrices - %A_pack = linalg.pack %A padding_value(%zero : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %A_pack_empty : tensor<7x16xi32> -> tensor<2x16x8x1xi32> + %A_pack = linalg.pack %A padding_value(%zero : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %A_pack_empty : tensor<7x16xi32> -> tensor<1x16x8x1xi32> %B_pack = linalg.pack %B padding_value(%zero : i32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 1] into %B_pack_empty : tensor<16x13xi32> -> tensor<2x16x8x1xi32> - %C_pack = linalg.pack %C padding_value(%zero : i32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_pack_empty : tensor<7x13xi32> -> tensor<2x2x8x8xi32> + %C_pack = linalg.pack %C padding_value(%zero : i32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_pack_empty : tensor<7x13xi32> -> tensor<1x2x8x8xi32> // MMT4D - %mmt4d = linalg.mmt4d ins(%A_pack, %B_pack : tensor<2x16x8x1xi32>, tensor<2x16x8x1xi32>) outs(%C_pack : tensor<2x2x8x8xi32>) -> tensor<2x2x8x8xi32> + %mmt4d = linalg.mmt4d ins(%A_pack, %B_pack : tensor<1x16x8x1xi32>, tensor<2x16x8x1xi32>) outs(%C_pack : tensor<1x2x8x8xi32>) -> tensor<1x2x8x8xi32> // Unpack output %C_out_empty = tensor.empty() : tensor<7x13xi32> - %C_out_unpack = linalg.unpack %mmt4d outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_out_empty : tensor<2x2x8x8xi32> -> tensor<7x13xi32> + %C_out_unpack = linalg.unpack %mmt4d outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_out_empty : tensor<1x2x8x8xi32> -> tensor<7x13xi32> return %C_out_unpack : tensor<7x13xi32> } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir index 6e2a82b..6ec1031 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir @@ -4,14 +4,14 @@ // RUN: FileCheck %s func.func @extract_element_0d(%a: vector<f32>) { - %1 = vector.extractelement %a[] : vector<f32> + %1 = vector.extract %a[] : f32 from vector<f32> // CHECK: 42 vector.print %1: f32 return } func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> (vector<f32>) { - %1 = vector.insertelement %a, %b[] : vector<f32> + %1 = vector.insert %a, %b[] : f32 into vector<f32> return %1: vector<f32> } @@ -58,9 +58,9 @@ func.func @broadcast_0d(%a: f32) { func.func @bitcast_0d() { %0 = arith.constant 42 : i32 %1 = arith.constant dense<0> : vector<i32> - %2 = vector.insertelement %0, %1[] : vector<i32> + %2 = vector.insert %0, %1[] : i32 into vector<i32> %3 = vector.bitcast %2 : vector<i32> to vector<f32> - %4 = vector.extractelement %3[] : vector<f32> + %4 = vector.extract %3[] : f32 from vector<f32> %5 = arith.bitcast %4 : f32 to i32 // CHECK: 42 vector.print %5: i32 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir index b69a200..eb99886 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir @@ -72,7 +72,7 @@ func.func @za0_d_f64() -> i32 { %row = vector.load %mem2[%vnum, %c0] : memref<?x?xf64>, vector<[2]xf64> %inner_add_reduce = scf.for %offset = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_0_f64) -> (f64) { - %t = vector.extractelement %row[%offset : index] : vector<[2]xf64> + %t = vector.extract %row[%offset] : f64 from vector<[2]xf64> %inner_add_reduce_next = arith.addf %inner_iter, %t : f64 scf.yield %inner_add_reduce_next : f64 } @@ -102,7 +102,7 @@ func.func @za0_d_f64() -> i32 { %cmp = arith.cmpf one, %row_1, %row_2 : vector<[2]xf64> %inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) { - %t = vector.extractelement %cmp[%i : index] : vector<[2]xi1> + %t = vector.extract %cmp[%i] : i1 from vector<[2]xi1> %t_i64 = arith.extui %t : i1 to i64 %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64 scf.yield %inner_mul_reduce_next : i64 @@ -125,7 +125,7 @@ func.func @za0_d_f64() -> i32 { %cmp = arith.cmpf oeq, %row_1, %row_2 : vector<[2]xf64> %inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) { - %t = vector.extractelement %cmp[%i : index] : vector<[2]xi1> + %t = vector.extract %cmp[%i] : i1 from vector<[2]xi1> %t_i64 = arith.extui %t : i1 to i64 %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64 scf.yield %inner_mul_reduce_next : i64 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir index 697fb90..ad8e321 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir @@ -36,7 +36,7 @@ func.func @entry() -> i32 { %row = vector.load %za_b[%vnum, %c0] : memref<?x?xi8>, vector<[16]xi8> %inner_mul_reduce = scf.for %offset = %c0 to %svl_b step %c1_index iter_args(%inner_iter = %init_1) -> (i64) { - %t = vector.extractelement %row[%offset : index] : vector<[16]xi8> + %t = vector.extract %row[%offset] : i8 from vector<[16]xi8> %t_i64 = arith.extui %t : i8 to i64 %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64 scf.yield %inner_mul_reduce_next : i64 @@ -64,7 +64,7 @@ func.func @entry() -> i32 { %row = vector.load %za_b[%vnum, %c0] : memref<?x?xi8>, vector<[16]xi8> %inner_mul_reduce = scf.for %offset = %c0 to %svl_b step %c1_index iter_args(%inner_iter = %init_1) -> (i64) { - %t = vector.extractelement %row[%offset : index] : vector<[16]xi8> + %t = vector.extract %row[%offset] : i8 from vector<[16]xi8> %t_i64 = arith.extui %t : i8 to i64 %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64 scf.yield %inner_mul_reduce_next : i64 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot.mlir index 53a7282..aff272c2 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot.mlir @@ -11,8 +11,8 @@ func.func @entry() -> i32 { %b = arith.constant dense<[9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : vector<8xf32> %r = x86vector.avx.intr.dot %a, %b : vector<8xf32> - %1 = vector.extractelement %r[%i0 : i32]: vector<8xf32> - %2 = vector.extractelement %r[%i4 : i32]: vector<8xf32> + %1 = vector.extract %r[%i0] : f32 from vector<8xf32> + %2 = vector.extract %r[%i4] : f32 from vector<8xf32> %d = arith.addf %1, %2 : f32 // CHECK: ( 110, 110, 110, 110, 382, 382, 382, 382 ) diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/sparse-dot-product.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/sparse-dot-product.mlir index bf1caaa..1c56990 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/sparse-dot-product.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/sparse-dot-product.mlir @@ -196,13 +196,13 @@ func.func @memref_dot_optimized(%m_A : memref<?xi64>, %m_B : memref<?xf64>, iter_args(%sum0 = %data_zero, %b_start0 = %c0) -> (f64, index) { %v_A = vector.transfer_read %m_A[%a], %index_padding : memref<?xi64>, vector<8xi64> - %segA_min = vector.extractelement %v_A[%i0 : i32] : vector<8xi64> + %segA_min = vector.extract %v_A[%i0] : i64 from vector<8xi64> %r1, %next_b_start0 = scf.for %b = %b_start0 to %N step %c8 iter_args(%sum1 = %sum0, %b_start1 = %b_start0) -> (f64, index) { %v_C = vector.transfer_read %m_C[%b], %index_padding : memref<?xi64>, vector<8xi64> - %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64> + %segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64> %seg1_done = arith.cmpi "slt", %segB_max, %segA_min : i64 %r2, %next_b_start1 = scf.if %seg1_done -> (f64, index) { @@ -273,10 +273,10 @@ func.func @memref_dot_while(%m_A : memref<?xi64>, %m_B : memref<?xf64>, %v_C = vector.transfer_read %m_C[%b1], %index_padding : memref<?xi64>, vector<8xi64> - %segA_min = vector.extractelement %v_A[%i0 : i32] : vector<8xi64> - %segA_max = vector.extractelement %v_A[%i7 : i32] : vector<8xi64> - %segB_min = vector.extractelement %v_C[%i0 : i32] : vector<8xi64> - %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64> + %segA_min = vector.extract %v_A[%i0] : i64 from vector<8xi64> + %segA_max = vector.extract %v_A[%i7] : i64 from vector<8xi64> + %segB_min = vector.extract %v_C[%i0] : i64 from vector<8xi64> + %segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64> %seg1_done = arith.cmpi "slt", %segB_max, %segA_min : i64 %r2, %a2, %b2 = scf.if %seg1_done -> (f64, index, index) { @@ -370,8 +370,8 @@ func.func @memref_dot_while_branchless(%m_A : memref<?xi64>, %m_B : memref<?xf64 -> f64 %r2 = arith.addf %r1, %subresult : f64 - %segA_max = vector.extractelement %v_A[%i7 : i32] : vector<8xi64> - %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64> + %segA_max = vector.extract %v_A[%i7] : i64 from vector<8xi64> + %segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64> %cond_a = arith.cmpi "sle", %segA_max, %segB_max : i64 %cond_a_i64 = arith.extui %cond_a : i1 to i64 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/compress.mlir b/mlir/test/Integration/Dialect/Vector/CPU/compress.mlir index e9a66cc..1683fa5 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/compress.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/compress.mlir @@ -28,8 +28,7 @@ func.func @printmem16(%A: memref<?xf32>) { %mem = scf.for %i = %c0 to %c16 step %c1 iter_args(%m_iter = %m) -> (vector<16xf32>) { %c = memref.load %A[%i] : memref<?xf32> - %i32 = arith.index_cast %i : index to i32 - %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<16xf32> + %m_new = vector.insert %c, %m_iter[%i] : f32 into vector<16xf32> scf.yield %m_new : vector<16xf32> } vector.print %mem : vector<16xf32> @@ -49,7 +48,7 @@ func.func @entry() { memref.store %z, %A[%i] : memref<?xf32> %i32 = arith.index_cast %i : index to i32 %fi = arith.sitofp %i32 : i32 to f32 - %v_new = vector.insertelement %fi, %v_iter[%i32 : i32] : vector<16xf32> + %v_new = vector.insert %fi, %v_iter[%i] : f32 into vector<16xf32> scf.yield %v_new : vector<16xf32> } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/maskedstore.mlir b/mlir/test/Integration/Dialect/Vector/CPU/maskedstore.mlir index 2dc00df..826da53 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/maskedstore.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/maskedstore.mlir @@ -28,8 +28,7 @@ func.func @printmem16(%A: memref<?xf32>) { %mem = scf.for %i = %c0 to %c16 step %c1 iter_args(%m_iter = %m) -> (vector<16xf32>) { %c = memref.load %A[%i] : memref<?xf32> - %i32 = arith.index_cast %i : index to i32 - %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<16xf32> + %m_new = vector.insert %c, %m_iter[%i] : f32 into vector<16xf32> scf.yield %m_new : vector<16xf32> } vector.print %mem : vector<16xf32> @@ -53,7 +52,7 @@ func.func @entry() { iter_args(%v_iter = %v) -> (vector<16xf32>) { %i32 = arith.index_cast %i : index to i32 %fi = arith.sitofp %i32 : i32 to f32 - %v_new = vector.insertelement %fi, %v_iter[%i32 : i32] : vector<16xf32> + %v_new = vector.insert %fi, %v_iter[%i] : f32 into vector<16xf32> scf.yield %v_new : vector<16xf32> } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/scatter.mlir b/mlir/test/Integration/Dialect/Vector/CPU/scatter.mlir index 54b6e69..22b5eef 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/scatter.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/scatter.mlir @@ -21,8 +21,7 @@ func.func @printmem8(%A: memref<?xf32>) { %mem = scf.for %i = %c0 to %c8 step %c1 iter_args(%m_iter = %m) -> (vector<8xf32>) { %c = memref.load %A[%i] : memref<?xf32> - %i32 = arith.index_cast %i : index to i32 - %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<8xf32> + %m_new = vector.insert %c, %m_iter[%i] : f32 into vector<8xf32> scf.yield %m_new : vector<8xf32> } vector.print %mem : vector<8xf32> diff --git a/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir index 2393bd1..639eed4 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir @@ -200,7 +200,7 @@ func.func @entry() { // CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 ) // 6. Read a scalar from a 2D memref and broadcast the value to a 1D vector. - // Generates a loop with vector.insertelement. + // Generates a loop with vector.insert. call @transfer_read_1d_broadcast(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> () // CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ) diff --git a/mlir/test/Integration/GPU/Vulkan/vector-interleave.mlir b/mlir/test/Integration/GPU/Vulkan/vector-interleave.mlir index e665653..731bd5a 100644 --- a/mlir/test/Integration/GPU/Vulkan/vector-interleave.mlir +++ b/mlir/test/Integration/GPU/Vulkan/vector-interleave.mlir @@ -26,17 +26,17 @@ module attributes { %val2 = memref.load %arg1[%idx0] : memref<2xi32> %val3 = memref.load %arg1[%idx1] : memref<2xi32> - %lhs0 = vector.insertelement %val0, %lhs[%idx0 : index] : vector<2xi32> - %lhs1 = vector.insertelement %val1, %lhs0[%idx1 : index] : vector<2xi32> - %rhs0 = vector.insertelement %val2, %rhs[%idx0 : index] : vector<2xi32> - %rhs1 = vector.insertelement %val3, %rhs0[%idx1 : index] : vector<2xi32> + %lhs0 = vector.insert %val0, %lhs[%idx0] : i32 into vector<2xi32> + %lhs1 = vector.insert %val1, %lhs0[%idx1] : i32 into vector<2xi32> + %rhs0 = vector.insert %val2, %rhs[%idx0] : i32 into vector<2xi32> + %rhs1 = vector.insert %val3, %rhs0[%idx1] : i32 into vector<2xi32> %interleave = vector.interleave %lhs1, %rhs1 : vector<2xi32> -> vector<4xi32> - %res0 = vector.extractelement %interleave[%idx0 : index] : vector<4xi32> - %res1 = vector.extractelement %interleave[%idx1 : index] : vector<4xi32> - %res2 = vector.extractelement %interleave[%idx2 : index] : vector<4xi32> - %res3 = vector.extractelement %interleave[%idx3 : index] : vector<4xi32> + %res0 = vector.extract %interleave[%idx0] : i32 from vector<4xi32> + %res1 = vector.extract %interleave[%idx1] : i32 from vector<4xi32> + %res2 = vector.extract %interleave[%idx2] : i32 from vector<4xi32> + %res3 = vector.extract %interleave[%idx3] : i32 from vector<4xi32> memref.store %res0, %arg2[%idx0]: memref<4xi32> memref.store %res1, %arg2[%idx1]: memref<4xi32> diff --git a/mlir/test/Integration/GPU/Vulkan/vector-shuffle.mlir b/mlir/test/Integration/GPU/Vulkan/vector-shuffle.mlir index dc53fe3..c1b7dba 100644 --- a/mlir/test/Integration/GPU/Vulkan/vector-shuffle.mlir +++ b/mlir/test/Integration/GPU/Vulkan/vector-shuffle.mlir @@ -26,17 +26,17 @@ module attributes { %val2 = memref.load %arg1[%idx0] : memref<2xi32> %val3 = memref.load %arg1[%idx1] : memref<2xi32> - %lhs0 = vector.insertelement %val0, %lhs[%idx0 : index] : vector<2xi32> - %lhs1 = vector.insertelement %val1, %lhs0[%idx1 : index] : vector<2xi32> - %rhs0 = vector.insertelement %val2, %rhs[%idx0 : index] : vector<2xi32> - %rhs1 = vector.insertelement %val3, %rhs0[%idx1 : index] : vector<2xi32> + %lhs0 = vector.insert %val0, %lhs[%idx0] : i32 into vector<2xi32> + %lhs1 = vector.insert %val1, %lhs0[%idx1] : i32 into vector<2xi32> + %rhs0 = vector.insert %val2, %rhs[%idx0] : i32 into vector<2xi32> + %rhs1 = vector.insert %val3, %rhs0[%idx1] : i32 into vector<2xi32> %shuffle = vector.shuffle %lhs1, %rhs1[2, 1, 3, 3] : vector<2xi32>, vector<2xi32> - %res0 = vector.extractelement %shuffle[%idx0 : index] : vector<4xi32> - %res1 = vector.extractelement %shuffle[%idx1 : index] : vector<4xi32> - %res2 = vector.extractelement %shuffle[%idx2 : index] : vector<4xi32> - %res3 = vector.extractelement %shuffle[%idx3 : index] : vector<4xi32> + %res0 = vector.extract %shuffle[%idx0] : i32 from vector<4xi32> + %res1 = vector.extract %shuffle[%idx1] : i32 from vector<4xi32> + %res2 = vector.extract %shuffle[%idx2] : i32 from vector<4xi32> + %res3 = vector.extract %shuffle[%idx3] : i32 from vector<4xi32> memref.store %res0, %arg2[%idx0]: memref<4xi32> memref.store %res1, %arg2[%idx1]: memref<4xi32> diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir index cdbca72..7888462 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir @@ -595,16 +595,17 @@ module attributes {transform.with_named_sequence} { // ----- -// It is valid to fuse the pack op with padding semantics if the tiled -// dimensions do not need padding. +// It is valid to fuse the pack op with padding semantics if it is a perfect +// tiling case. func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<22x2x3x16xf32> { - %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) { - %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> - %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> - %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32> + %0 = scf.forall (%arg2, %arg3) = (0, 0) to (64, 32) step (15, 16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) { + %size = affine.min affine_map<(d0) -> (-d0 + 64, 15)>(%arg2) + %src = tensor.extract_slice %arg0[%arg2, %arg3] [%size, 16] [1, 1] : tensor<64x32xf32> to tensor<?x16xf32> + %dest = tensor.extract_slice %arg4[%arg2, %arg3] [%size, 16] [1, 1] : tensor<64x32xf32> to tensor<?x16xf32> + %2 = linalg.exp ins(%src : tensor<?x16xf32>) outs(%dest : tensor<?x16xf32>) -> tensor<?x16xf32> scf.forall.in_parallel { - tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32> + tensor.parallel_insert_slice %2 into %arg4[%arg2, %arg3] [%size, 16] [1, 1] : tensor<?x16xf32> into tensor<64x32xf32> } } %1 = tensor.empty() : tensor<22x2x3x16xf32> @@ -621,109 +622,39 @@ module attributes {transform.with_named_sequence} { transform.yield } } -// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (-d0 + 64, 15)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0 floordiv 3)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0) -> (d0 ceildiv 3)> +// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0) -> (d0 floordiv 16)> // CHECK: func.func @fuse_pack_consumer_with_padding_semantics( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] // CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<22x2x3x16xf32> // CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16) -// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]]) -// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1] -// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: %{{.*}}:2 = scf.forall (%[[I:.*]], %[[J:.*]]) = (0, 0) to (64, 32) step (15, 16) +// CHECK-SAME: shared_outs(%[[ELEM_OUT:.*]] = %[[ARG1]], %[[PACK_OUT:.*]] = %[[OUT_INIT]]) +// CHECK: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]]) +// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1] +// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1] // CHECK: %[[ELEM:.*]] = linalg.exp // CHECK-SAME: ins(%[[ELEM_SRC]] // CHECK-SAME: outs(%[[ELEM_DEST]] -// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) -// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1] -// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]] -// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32) -// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16] -// CHECK-SAME: into %[[TILED_PACK_DEST]] -// CHECK: scf.forall.in_parallel { -// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] -// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1] - -// ----- - -// It is valid to fuse the pack if the dimension is not tiled even when it needs -// extra padding. - -func.func @fuse_pack_consumer_with_untiled_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<33x2x3x16xf32> { - %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) { - %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> - %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> - %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32> - scf.forall.in_parallel { - tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32> - } - } - %1 = tensor.empty() : tensor<33x2x3x16xf32> - %cst = arith.constant 0.000000e+00 : f32 - %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<33x2x3x16xf32> - return %pack : tensor<33x2x3x16xf32> -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield - } -} -// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> -// CHECK: func.func @fuse_pack_consumer_with_untiled_extra_padding( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<33x2x3x16xf32> -// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16) -// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]]) -// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1] -// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] -// CHECK: %[[ELEM:.*]] = linalg.exp -// CHECK-SAME: ins(%[[ELEM_SRC]] -// CHECK-SAME: outs(%[[ELEM_DEST]] -// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) -// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1] -// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]] +// CHECK-DAG: %[[D0_OFFSET:.*]] = affine.apply #[[MAP1]](%[[I]]) +// CHECK-DAG: %[[D0_SIZE:.*]] = affine.apply #[[MAP2]](%[[SIZE]]) +// CHECK-DAG: %[[D1_OFFSET:.*]] = affine.apply #[[MAP3]](%[[J]]) +// CHECK-DAG: %[[PACK_INIT:.*]] = tensor.extract_slice %[[PACK_OUT]] +// CHECK-SAME: [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1] +// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] // CHECK-SAME: padding_value(%[[PAD_VAL]] : f32) // CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16] // CHECK-SAME: into %[[TILED_PACK_DEST]] // CHECK: scf.forall.in_parallel { -// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] -// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1] - -// ----- - -// If the dimension is tiled and it needs extra padding, do not fuse the pack -// op. - -func.func @nofuse_pack_consumer_with_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> { - %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) { - %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> - %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> - %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32> - scf.forall.in_parallel { - // expected-error @below {{failed to fuse consumer of slice}} - tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32> - } - } - %1 = tensor.empty() : tensor<23x32x3x16xf32> - %cst = arith.constant 0.000000e+00 : f32 - %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32> - return %pack : tensor<23x32x3x16xf32> -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield - } -} +// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT]] +// CHECK-SAME: [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1] // ----- diff --git a/mlir/test/Target/LLVMIR/omptarget-debug-reduc-fn-loc.mlir b/mlir/test/Target/LLVMIR/omptarget-debug-reduc-fn-loc.mlir deleted file mode 100644 index d889ef4..0000000 --- a/mlir/test/Target/LLVMIR/omptarget-debug-reduc-fn-loc.mlir +++ /dev/null @@ -1,121 +0,0 @@ -// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s - -module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui32>>, llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} { - omp.private {type = private} @_QFEi_private_i32 : i32 loc(#loc1) - omp.declare_reduction @add_reduction_i32 : i32 init { - ^bb0(%arg0: i32 loc("test.f90":8:7)): - %0 = llvm.mlir.constant(0 : i32) : i32 loc(#loc2) - omp.yield(%0 : i32) loc(#loc2) - } combiner { - ^bb0(%arg0: i32 loc("test.f90":8:7), %arg1: i32 loc("test.f90":8:7)): - %0 = llvm.add %arg0, %arg1 : i32 loc(#loc2) - omp.yield(%0 : i32) loc(#loc2) - } loc(#loc2) - llvm.func @_QQmain() { - %0 = llvm.mlir.constant(1 : i64) : i64 loc(#loc4) - %1 = llvm.alloca %0 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr<5> loc(#loc4) - %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr loc(#loc4) - %3 = llvm.mlir.constant(1 : i64) : i64 loc(#loc1) - %4 = llvm.alloca %3 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr<5> loc(#loc1) - %5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr loc(#loc1) - %6 = llvm.mlir.constant(8191 : index) : i64 loc(#loc5) - %7 = llvm.mlir.constant(0 : index) : i64 loc(#loc5) - %8 = llvm.mlir.constant(1 : index) : i64 loc(#loc5) - %9 = llvm.mlir.constant(0 : i32) : i32 loc(#loc5) - %10 = llvm.mlir.constant(8192 : index) : i64 loc(#loc5) - %11 = llvm.mlir.addressof @_QFEarr : !llvm.ptr<1> loc(#loc6) - %12 = llvm.addrspacecast %11 : !llvm.ptr<1> to !llvm.ptr loc(#loc6) - llvm.store %9, %2 : i32, !llvm.ptr loc(#loc7) - %15 = omp.map.info var_ptr(%2 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "x"} loc(#loc4) - %16 = omp.map.info var_ptr(%5 : !llvm.ptr, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr {name = "i"} loc(#loc7) - %17 = omp.map.bounds lower_bound(%7 : i64) upper_bound(%6 : i64) extent(%10 : i64) stride(%8 : i64) start_idx(%8 : i64) loc(#loc7) - %18 = omp.map.info var_ptr(%12 : !llvm.ptr, !llvm.array<8192 x i32>) map_clauses(implicit, tofrom) capture(ByRef) bounds(%17) -> !llvm.ptr {name = "arr"} loc(#loc7) - omp.target map_entries(%15 -> %arg0, %16 -> %arg1, %18 -> %arg2 : !llvm.ptr, !llvm.ptr, !llvm.ptr) { - %19 = llvm.mlir.constant(8192 : i32) : i32 loc(#loc5) - %20 = llvm.mlir.constant(1 : i32) : i32 loc(#loc5) - %21 = llvm.mlir.constant(8192 : index) : i64 loc(#loc6) - omp.teams reduction(@add_reduction_i32 %arg0 -> %arg3 : !llvm.ptr) { - omp.parallel private(@_QFEi_private_i32 %arg1 -> %arg4 : !llvm.ptr) { - omp.distribute { - omp.wsloop reduction(@add_reduction_i32 %arg3 -> %arg5 : !llvm.ptr) { - omp.loop_nest (%arg6) : i32 = (%20) to (%19) inclusive step (%20) { - llvm.store %arg6, %arg4 : i32, !llvm.ptr loc(#loc2) - %22 = llvm.load %arg5 : !llvm.ptr -> i32 loc(#loc8) - %23 = llvm.load %arg4 : !llvm.ptr -> i32 loc(#loc8) - %34 = llvm.add %22, %23 : i32 loc(#loc8) - llvm.store %34, %arg5 : i32, !llvm.ptr loc(#loc8) - omp.yield loc(#loc2) - } loc(#loc2) - } {omp.composite} loc(#loc2) - } {omp.composite} loc(#loc2) - omp.terminator loc(#loc2) - } {omp.composite} loc(#loc2) - omp.terminator loc(#loc2) - } loc(#loc2) - omp.terminator loc(#loc2) - } loc(#loc13) - llvm.return loc(#loc9) - } loc(#loc12) - llvm.mlir.global internal @_QFEarr() {addr_space = 1 : i32} : !llvm.array<8192 x i32> { - %0 = llvm.mlir.zero : !llvm.array<8192 x i32> loc(#loc6) - llvm.return %0 : !llvm.array<8192 x i32> loc(#loc6) - } loc(#loc6) -} loc(#loc) - -#loc = loc("test.f90":4:18) -#loc1 = loc("test.f90":4:18) -#loc2 = loc("test.f90":8:7) -#loc3 = loc("test.f90":1:7) -#loc4 = loc("test.f90":3:18) -#loc5 = loc(unknown) -#loc6 = loc("test.f90":5:18) -#loc7 = loc("test.f90":6:7) -#loc8 = loc("test.f90":10:7) -#loc9 = loc("test.f90":16:7) - -#di_file = #llvm.di_file<"target7.f90" in ""> -#di_null_type = #llvm.di_null_type -#di_compile_unit = #llvm.di_compile_unit<id = distinct[0]<>, - sourceLanguage = DW_LANG_Fortran95, file = #di_file, producer = "flang", - isOptimized = false, emissionKind = LineTablesOnly> -#di_subroutine_type = #llvm.di_subroutine_type< - callingConvention = DW_CC_program, types = #di_null_type> -#di_subprogram = #llvm.di_subprogram<id = distinct[1]<>, - compileUnit = #di_compile_unit, scope = #di_file, name = "main", - file = #di_file, subprogramFlags = "Definition|MainSubprogram", - type = #di_subroutine_type> -#di_subprogram1 = #llvm.di_subprogram<compileUnit = #di_compile_unit, - name = "target", file = #di_file, subprogramFlags = "Definition", - type = #di_subroutine_type> - - -#loc12 = loc(fused<#di_subprogram>[#loc3]) -#loc13 = loc(fused<#di_subprogram1>[#loc2]) - -// CHECK-DAG: define internal void @_omp_reduction_shuffle_and_reduce_func -// CHECK-NOT: !dbg -// CHECK: } -// CHECK-DAG: define internal void @_omp_reduction_inter_warp_copy_func -// CHECK-NOT: !dbg -// CHECK: } -// CHECK-DAG: define internal void @"__omp_offloading_{{.*}}__QQmain_l8_omp$reduction$reduction_func.1" -// CHECK-NOT: !dbg -// CHECK: } -// CHECK-DAG: define internal void @_omp_reduction_shuffle_and_reduce_func.2 -// CHECK-NOT: !dbg -// CHECK: } -// CHECK-DAG: define internal void @_omp_reduction_inter_warp_copy_func.3 -// CHECK-NOT: !dbg -// CHECK: } -// CHECK-DAG: define internal void @_omp_reduction_list_to_global_copy_func -// CHECK-NOT: !dbg -// CHECK: } -// CHECK-DAG: define internal void @_omp_reduction_list_to_global_reduce_func -// CHECK-NOT: !dbg -// CHECK: } -// CHECK-DAG: define internal void @_omp_reduction_global_to_list_copy_func -// CHECK-NOT: !dbg -// CHECK: } -// CHECK-DAG: define internal void @_omp_reduction_global_to_list_reduce_func -// CHECK-NOT: !dbg -// CHECK: } diff --git a/mlir/test/Target/LLVMIR/xevm.mlir b/mlir/test/Target/LLVMIR/xevm.mlir new file mode 100644 index 0000000..a3dd0b6 --- /dev/null +++ b/mlir/test/Target/LLVMIR/xevm.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-translate --split-input-file -mlir-to-llvmir %s | FileCheck %s + +module { + llvm.func spir_funccc @_Z8prefetchPU3AS1Kcm(!llvm.ptr<1>, i64) + llvm.func @prefetch(%arg0: !llvm.ptr<1>) { + %0 = llvm.mlir.constant(1 : i64) : i64 + // CHECK-LABEL: call spir_func void @_Z8prefetchPU3AS1Kcm + // CHECK-SAME: !spirv.DecorationCacheControlINTEL ![[DECO1:.*]] + llvm.call spir_funccc @_Z8prefetchPU3AS1Kcm(%arg0, %0) + {function_type = !llvm.func<void (ptr<1>, i64)>, linkage = #llvm.linkage<external>, + no_unwind, sym_name = "_Z8prefetchPU3AS1Kcm", visibility_ = 0 : i64, + xevm.DecorationCacheControl = [[6442 : i32, 0 : i32, 1 : i32, 0 : i32], [6442 : i32, 1 : i32, 1 : i32, 0 : i32]]} + : (!llvm.ptr<1>, i64) -> () + llvm.return + } +} + +// CHECK: ![[DECO1]] = !{![[DECO2:.*]], ![[DECO3:.*]]} +// CHECK: ![[DECO2]] = !{i32 6442, i32 0, i32 1, i32 0} +// CHECK: ![[DECO3]] = !{i32 6442, i32 1, i32 1, i32 0} + diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir index 76d34c2..6aca11e 100644 --- a/mlir/test/Target/SPIRV/constant.mlir +++ b/mlir/test/Target/SPIRV/constant.mlir @@ -1,6 +1,7 @@ // RUN: mlir-translate --no-implicit-module --split-input-file --test-spirv-roundtrip %s | FileCheck %s +// RUN: %if spirv-tools %{ mlir-translate -no-implicit-module --split-input-file -serialize-spirv %s | spirv-val %} -spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { +spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int64, Int16, Int8, Float64, Float16, CooperativeMatrixKHR], [SPV_KHR_vulkan_memory_model, SPV_KHR_cooperative_matrix]> { // CHECK-LABEL: @bool_const spirv.func @bool_const() -> () "None" { // CHECK: spirv.Constant true @@ -305,6 +306,8 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { %coop = spirv.Constant dense<4> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc> spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc> } + + spirv.EntryPoint "GLCompute" @bool_const } // ----- diff --git a/mlir/test/Target/SPIRV/lit.local.cfg b/mlir/test/Target/SPIRV/lit.local.cfg new file mode 100644 index 0000000..6d44394 --- /dev/null +++ b/mlir/test/Target/SPIRV/lit.local.cfg @@ -0,0 +1,4 @@ +if config.spirv_tools_tests: + config.available_features.add("spirv-tools") + config.substitutions.append(("spirv-as", os.path.join(config.llvm_tools_dir, "spirv-as"))) + config.substitutions.append(("spirv-val", os.path.join(config.llvm_tools_dir, "spirv-val"))) diff --git a/mlir/test/Transforms/compose-subview.mlir b/mlir/test/Transforms/compose-subview.mlir index 53fbb8a..d6fa442 100644 --- a/mlir/test/Transforms/compose-subview.mlir +++ b/mlir/test/Transforms/compose-subview.mlir @@ -1,9 +1,9 @@ // RUN: mlir-opt %s -test-compose-subview -split-input-file | FileCheck %s // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> { +// CHECK-SAME: %[[input:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> { func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> { - // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][3, 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: 3456>> + // CHECK: {{.*}} = memref.subview %[[input]][3, 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: 3456>> %0 = memref.subview %input[2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: 2304>> %1 = memref.subview %0[1, 128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: 2304>> to memref<1x128xf32, strided<[1024, 1], offset: 3456>> return %1 : memref<1x128xf32, strided<[1024, 1], offset: 3456>> @@ -12,9 +12,9 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, stri // ----- // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> { +// CHECK-SAME: %[[input:.*]]: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> { func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> { - // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][3, 673] [1, 10] [1, 1] : memref<4x1024xf32> to memref<1x10xf32, strided<[1024, 1], offset: 3745>> + // CHECK: {{.*}} = memref.subview %[[input]][3, 673] [1, 10] [1, 1] : memref<4x1024xf32> to memref<1x10xf32, strided<[1024, 1], offset: 3745>> %0 = memref.subview %input[1, 512] [3, 256] [1, 1] : memref<4x1024xf32> to memref<3x256xf32, strided<[1024, 1], offset: 1536>> %1 = memref.subview %0[1, 128] [2, 128] [1, 1] : memref<3x256xf32, strided<[1024, 1], offset: 1536>> to memref<2x128xf32, strided<[1024, 1], offset: 2688>> %2 = memref.subview %1[1, 33] [1, 10] [1, 1] : memref<2x128xf32, strided<[1024, 1], offset: 2688>> to memref<1x10xf32, strided<[1024, 1], offset: 3745>> @@ -24,12 +24,12 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x10xf32, strid // ----- // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> { +// CHECK-SAME: %[[input:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> { func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> { - // CHECK: %[[VAL_1:.*]] = arith.constant 3 : index + // CHECK: %[[C3:.*]] = arith.constant 3 : index %cst_1 = arith.constant 1 : index %cst_2 = arith.constant 2 : index - // CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>> + // CHECK: {{.*}} = memref.subview %[[input]]{{\[}}%[[C3]], 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>> %0 = memref.subview %input[%cst_2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: ?>> %1 = memref.subview %0[%cst_1, 128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: ?>> to memref<1x128xf32, strided<[1024, 1], offset: ?>> return %1 : memref<1x128xf32, strided<[1024, 1], offset: ?>> @@ -38,13 +38,13 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, stri // ----- // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> { +// CHECK-SAME: %[[input:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> { func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> { - // CHECK: %[[VAL_1:.*]] = arith.constant 3 : index + // CHECK: %[[C3:.*]] = arith.constant 3 : index %cst_2 = arith.constant 2 : index - // CHECK: %[[VAL_2:.*]] = arith.constant 384 : index + // CHECK: %[[C384:.*]] = arith.constant 384 : index %cst_128 = arith.constant 128 : index - // CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_2]]] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>> + // CHECK: {{.*}} = memref.subview %[[input]]{{\[}}%[[C3]], %[[C384]]] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>> %0 = memref.subview %input[%cst_2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: ?>> %1 = memref.subview %0[1, %cst_128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: ?>> to memref<1x128xf32, strided<[1024, 1], offset: ?>> return %1 : memref<1x128xf32, strided<[1024, 1], offset: ?>> @@ -53,9 +53,9 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, stri // ----- // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<8x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> { +// CHECK-SAME: %[[input:.*]]: memref<8x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> { func.func @subview_strided(%input: memref<8x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> { - // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][4, 384] [1, 64] [4, 4] : memref<8x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: 4480>> + // CHECK: {{.*}} = memref.subview %[[input]][4, 384] [1, 64] [4, 4] : memref<8x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: 4480>> %0 = memref.subview %input[2, 256] [2, 256] [2, 2] : memref<8x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: 2304>> %1 = memref.subview %0[1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: 2304>> to memref<1x64xf32, strided<[4096, 4], offset: 4480>> return %1 : memref<1x64xf32, strided<[4096, 4], offset: 4480>> @@ -64,9 +64,9 @@ func.func @subview_strided(%input: memref<8x1024xf32>) -> memref<1x64xf32, strid // ----- // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> { +// CHECK-SAME: %[[input:.*]]: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> { func.func @subview_strided(%input: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> { - // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][7, 7] [2, 2] [8, 8] : memref<30x30xf32> to memref<2x2xf32, strided<[240, 8], offset: 217>> + // CHECK: {{.*}} = memref.subview %[[input]][7, 7] [2, 2] [8, 8] : memref<30x30xf32> to memref<2x2xf32, strided<[240, 8], offset: 217>> %0 = memref.subview %input[1, 1] [12, 12] [2, 2] : memref<30x30xf32> to memref<12x12xf32, strided<[60, 2], offset: 31>> %1 = memref.subview %0[1, 1] [5, 5] [2, 2] : memref<12x12xf32, strided<[60, 2], offset: 31>> to memref<5x5xf32, strided<[120, 4], offset: 93>> %2 = memref.subview %1[1, 1] [2, 2] [2, 2] : memref<5x5xf32, strided<[120, 4], offset: 93>> to memref<2x2xf32, strided<[240, 8], offset: 217>> @@ -76,13 +76,13 @@ func.func @subview_strided(%input: memref<30x30xf32>) -> memref<2x2xf32, strided // ----- // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> { +// CHECK-SAME: %[[input:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> { func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> { - // CHECK: %[[VAL_1:.*]] = arith.constant 4 : index + // CHECK: %[[C4:.*]] = arith.constant 4 : index %cst_2 = arith.constant 2 : index - // CHECK: %[[VAL_2:.*]] = arith.constant 384 : index + // CHECK: %[[C384:.*]] = arith.constant 384 : index %cst_64 = arith.constant 64 : index - // CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_2]]] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>> + // CHECK: {{.*}} = memref.subview %[[input]]{{\[}}%[[C4]], %[[C384]]] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>> %0 = memref.subview %input[%cst_2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: ?>> %1 = memref.subview %0[1, %cst_64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: ?>> to memref<1x64xf32, strided<[4096, 4], offset: ?>> return %1 : memref<1x64xf32, strided<[4096, 4], offset: ?>> @@ -91,13 +91,39 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strid // ----- // CHECK-LABEL: func.func @subview_strided( -// CHECK-SAME: %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> { +// CHECK-SAME: %[[input:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> { func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> { - // CHECK: %[[VAL_1:.*]] = arith.constant 4 : index + // CHECK: %[[C4:.*]] = arith.constant 4 : index %cst_1 = arith.constant 1 : index %cst_2 = arith.constant 2 : index - // CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], 384] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>> + // CHECK: {{.*}} = memref.subview %[[input]]{{\[}}%[[C4]], 384] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>> %0 = memref.subview %input[%cst_2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: ?>> %1 = memref.subview %0[%cst_1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: ?>> to memref<1x64xf32, strided<[4096, 4], offset: ?>> return %1 : memref<1x64xf32, strided<[4096, 4], offset: ?>> } + +// ----- + +// CHECK-LABEL: func.func @single_dynamic_size_subview( +// CHECK-SAME: %[[input:.*]]: memref<256x?xf32>, +// CHECK-SAME: %{{.*}}: index, +// CHECK-SAME: %[[SIZE_1:.*]]: index) -> memref<8x?xf32> { +func.func @single_dynamic_size_subview(%input: memref<256x?xf32>, %size0 : index, %size1 : index) -> memref<8x?xf32>{ + %subview = memref.subview %input[0, 0][8, %size0][1, 1] : memref<256x?xf32> to memref<8x?xf32> + %subview_1 = memref.subview %subview[0, 0][8, %size1][1, 1] : memref<8x?xf32> to memref<8x?xf32> + // CHECK: %{{.*}} = memref.subview %[[input]][0, 0] [8, %[[SIZE_1]]] [1, 1] : memref<256x?xf32> to memref<8x?xf32> + return %subview_1 : memref<8x?xf32> +} + +// ----- + +// CHECK-LABEL: func.func @all_dynamic_size_subview( +// CHECK-SAME: %[[input:.*]]: memref<256x?xf32>, +// CHECK-SAME: %{{.*}}: index, +// CHECK-SAME: %[[SIZE1:.*]]: index) -> memref<?x?xf32> { +func.func @all_dynamic_size_subview(%input: memref<256x?xf32>, %size0 : index, %size1 : index) -> memref<?x?xf32>{ + %subview = memref.subview %input[0, 0][%size0, %size0][1, 1] : memref<256x?xf32> to memref<?x?xf32> + %subview_1 = memref.subview %subview[0, 0][%size1, %size1][1, 1] : memref<?x?xf32> to memref<?x?xf32> + // CHECK: {{.*}} = memref.subview %[[input]][0, 0] {{\[}}%[[SIZE1]], %[[SIZE1]]] [1, 1] : memref<256x?xf32> to memref<?x?xf32> + return %subview_1 : memref<?x?xf32> +} diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir index db8bd0f..9bffe92 100644 --- a/mlir/test/Transforms/test-legalize-type-conversion.mlir +++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir @@ -104,8 +104,8 @@ func.func @test_signature_conversion_no_converter() { "test.signature_conversion_no_converter"() ({ // expected-error@below {{failed to legalize unresolved materialization from ('f64') to ('f32') that remained live after conversion}} ^bb0(%arg0: f32): - "test.type_consumer"(%arg0) : (f32) -> () // expected-note@below{{see existing live user here}} + "test.type_consumer"(%arg0) : (f32) -> () "test.return"(%arg0) : (f32) -> () }) : () -> () return diff --git a/mlir/test/Transforms/test-legalizer-analysis.mlir b/mlir/test/Transforms/test-legalizer-analysis.mlir index 19a1310..5b07055 100644 --- a/mlir/test/Transforms/test-legalizer-analysis.mlir +++ b/mlir/test/Transforms/test-legalizer-analysis.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns -verify-diagnostics -test-legalize-mode=analysis %s | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns="test-legalize-mode=analysis" -verify-diagnostics %s | FileCheck %s // expected-remark@-2 {{op 'builtin.module' is legalizable}} // expected-remark@+1 {{op 'func.func' is legalizable}} diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir index 5f1148c..dcd0172 100644 --- a/mlir/test/Transforms/test-legalizer-full.mlir +++ b/mlir/test/Transforms/test-legalizer-full.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns -test-legalize-mode=full -split-input-file -verify-diagnostics %s | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns="test-legalize-mode=full" -split-input-file -verify-diagnostics %s | FileCheck %s // CHECK-LABEL: func @multi_level_mapping func.func @multi_level_mapping() { diff --git a/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp b/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp index 8a01a0a..016052c 100644 --- a/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp +++ b/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp @@ -69,25 +69,25 @@ struct MathCosToVCIX final : OpRewritePattern<math::CosOp> { if (legalType.isScalable()) // Use arbitrary runtime vector length when vector type is scalable. // Proper conversion pass should take it from the IR. - rvl = rewriter.create<arith::ConstantOp>(loc, - rewriter.getI64IntegerAttr(9)); + rvl = arith::ConstantOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(9)); Value res; if (n == 1) { - res = rewriter.create<vcix::BinaryImmOp>(loc, legalType, opcodeAttr, vec, - immAttr, rvl); + res = vcix::BinaryImmOp::create(rewriter, loc, legalType, opcodeAttr, vec, + immAttr, rvl); } else { const unsigned eltCount = legalType.getShape()[0]; Type eltTy = legalType.getElementType(); - Value zero = rewriter.create<arith::ConstantOp>( - loc, eltTy, rewriter.getZeroAttr(eltTy)); - res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); + Value zero = arith::ConstantOp::create(rewriter, loc, eltTy, + rewriter.getZeroAttr(eltTy)); + res = vector::BroadcastOp::create(rewriter, loc, opType, zero /*dummy*/); for (unsigned i = 0; i < n; ++i) { - Value extracted = rewriter.create<vector::ScalableExtractOp>( - loc, legalType, vec, i * eltCount); - Value v = rewriter.create<vcix::BinaryImmOp>(loc, legalType, opcodeAttr, - extracted, immAttr, rvl); - res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, - i * eltCount); + Value extracted = vector::ScalableExtractOp::create( + rewriter, loc, legalType, vec, i * eltCount); + Value v = vcix::BinaryImmOp::create( + rewriter, loc, legalType, opcodeAttr, extracted, immAttr, rvl); + res = vector::ScalableInsertOp::create(rewriter, loc, v, res, + i * eltCount); } } rewriter.replaceOp(op, res); @@ -112,25 +112,25 @@ struct MathSinToVCIX final : OpRewritePattern<math::SinOp> { if (legalType.isScalable()) // Use arbitrary runtime vector length when vector type is scalable. // Proper conversion pass should take it from the IR. - rvl = rewriter.create<arith::ConstantOp>(loc, - rewriter.getI64IntegerAttr(9)); + rvl = arith::ConstantOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(9)); Value res; if (n == 1) { - res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec, - vec, rvl); + res = vcix::BinaryOp::create(rewriter, loc, legalType, opcodeAttr, vec, + vec, rvl); } else { const unsigned eltCount = legalType.getShape()[0]; Type eltTy = legalType.getElementType(); - Value zero = rewriter.create<arith::ConstantOp>( - loc, eltTy, rewriter.getZeroAttr(eltTy)); - res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); + Value zero = arith::ConstantOp::create(rewriter, loc, eltTy, + rewriter.getZeroAttr(eltTy)); + res = vector::BroadcastOp::create(rewriter, loc, opType, zero /*dummy*/); for (unsigned i = 0; i < n; ++i) { - Value extracted = rewriter.create<vector::ScalableExtractOp>( - loc, legalType, vec, i * eltCount); - Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, - extracted, extracted, rvl); - res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, - i * eltCount); + Value extracted = vector::ScalableExtractOp::create( + rewriter, loc, legalType, vec, i * eltCount); + Value v = vcix::BinaryOp::create(rewriter, loc, legalType, opcodeAttr, + extracted, extracted, rvl); + res = vector::ScalableInsertOp::create(rewriter, loc, v, res, + i * eltCount); } } rewriter.replaceOp(op, res); @@ -152,28 +152,28 @@ struct MathTanToVCIX final : OpRewritePattern<math::TanOp> { Location loc = op.getLoc(); Value vec = op.getOperand(); Attribute opcodeAttr = rewriter.getI64IntegerAttr(0); - Value zero = rewriter.create<arith::ConstantOp>( - loc, eltTy, rewriter.getZeroAttr(eltTy)); + Value zero = arith::ConstantOp::create(rewriter, loc, eltTy, + rewriter.getZeroAttr(eltTy)); Value rvl = nullptr; if (legalType.isScalable()) // Use arbitrary runtime vector length when vector type is scalable. // Proper conversion pass should take it from the IR. - rvl = rewriter.create<arith::ConstantOp>(loc, - rewriter.getI64IntegerAttr(9)); + rvl = arith::ConstantOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(9)); Value res; if (n == 1) { - res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec, - zero, rvl); + res = vcix::BinaryOp::create(rewriter, loc, legalType, opcodeAttr, vec, + zero, rvl); } else { const unsigned eltCount = legalType.getShape()[0]; - res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); + res = vector::BroadcastOp::create(rewriter, loc, opType, zero /*dummy*/); for (unsigned i = 0; i < n; ++i) { - Value extracted = rewriter.create<vector::ScalableExtractOp>( - loc, legalType, vec, i * eltCount); - Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, - extracted, zero, rvl); - res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, - i * eltCount); + Value extracted = vector::ScalableExtractOp::create( + rewriter, loc, legalType, vec, i * eltCount); + Value v = vcix::BinaryOp::create(rewriter, loc, legalType, opcodeAttr, + extracted, zero, rvl); + res = vector::ScalableInsertOp::create(rewriter, loc, v, res, + i * eltCount); } } rewriter.replaceOp(op, res); @@ -195,30 +195,30 @@ struct MathLogToVCIX final : OpRewritePattern<math::LogOp> { Value vec = op.getOperand(); Attribute opcodeAttr = rewriter.getI64IntegerAttr(0); Value rvl = nullptr; - Value zeroInt = rewriter.create<arith::ConstantOp>( - loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); + Value zeroInt = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); if (legalType.isScalable()) // Use arbitrary runtime vector length when vector type is scalable. // Proper conversion pass should take it from the IR. - rvl = rewriter.create<arith::ConstantOp>(loc, - rewriter.getI64IntegerAttr(9)); + rvl = arith::ConstantOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(9)); Value res; if (n == 1) { - res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec, - zeroInt, rvl); + res = vcix::BinaryOp::create(rewriter, loc, legalType, opcodeAttr, vec, + zeroInt, rvl); } else { const unsigned eltCount = legalType.getShape()[0]; Type eltTy = legalType.getElementType(); - Value zero = rewriter.create<arith::ConstantOp>( - loc, eltTy, rewriter.getZeroAttr(eltTy)); - res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); + Value zero = arith::ConstantOp::create(rewriter, loc, eltTy, + rewriter.getZeroAttr(eltTy)); + res = vector::BroadcastOp::create(rewriter, loc, opType, zero /*dummy*/); for (unsigned i = 0; i < n; ++i) { - Value extracted = rewriter.create<vector::ScalableExtractOp>( - loc, legalType, vec, i * eltCount); - Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, - extracted, zeroInt, rvl); - res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, - i * eltCount); + Value extracted = vector::ScalableExtractOp::create( + rewriter, loc, legalType, vec, i * eltCount); + Value v = vcix::BinaryOp::create(rewriter, loc, legalType, opcodeAttr, + extracted, zeroInt, rvl); + res = vector::ScalableInsertOp::create(rewriter, loc, v, res, + i * eltCount); } } rewriter.replaceOp(op, res); diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp index ed5d06d..3569a73 100644 --- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp +++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp @@ -145,7 +145,7 @@ static LogicalResult testReifyValueBounds(FunctionOpInterface funcOp, if (reifiedScalable->map.getNumInputs() == 1) { // The only possible input to the bound is vscale. vscaleOperand.push_back(std::make_pair( - rewriter.create<vector::VectorScaleOp>(loc), std::nullopt)); + vector::VectorScaleOp::create(rewriter, loc), std::nullopt)); } reified = affine::materializeComputedBound( rewriter, loc, reifiedScalable->map, vscaleOperand); @@ -169,8 +169,9 @@ static LogicalResult testReifyValueBounds(FunctionOpInterface funcOp, rewriter.replaceOp(op, val); return WalkResult::skip(); } - Value constOp = rewriter.create<arith::ConstantIndexOp>( - op->getLoc(), cast<IntegerAttr>(cast<Attribute>(*reified)).getInt()); + Value constOp = arith::ConstantIndexOp::create( + rewriter, op->getLoc(), + cast<IntegerAttr>(cast<Attribute>(*reified)).getInt()); rewriter.replaceOp(op, constOp); return WalkResult::skip(); }); diff --git a/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp b/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp index 738d4ee59..a792d08 100644 --- a/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp +++ b/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp @@ -60,7 +60,7 @@ struct TestEmulateWideIntPass // casts (and vice versa) and using it insted of `llvm.bitcast`. auto addBitcast = [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) -> Value { - auto cast = builder.create<LLVM::BitcastOp>(loc, type, inputs); + auto cast = LLVM::BitcastOp::create(builder, loc, type, inputs); return cast->getResult(0); }; typeConverter.addSourceMaterialization(addBitcast); diff --git a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt index 226e0bb..2ee3222 100644 --- a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt @@ -1,5 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRBufferizationTestPasses + TestOneShotModuleBufferize.cpp TestTensorCopyInsertion.cpp TestTensorLikeAndBufferLike.cpp diff --git a/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp new file mode 100644 index 0000000..1e2d4a7 --- /dev/null +++ b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp @@ -0,0 +1,57 @@ +//===- TestOneShotModuleBufferzation.cpp - Bufferization Test -----*- c++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/Transforms.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { +struct TestOneShotModuleBufferizePass + : public PassWrapper<TestOneShotModuleBufferizePass, OperationPass<>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneShotModuleBufferizePass) + + TestOneShotModuleBufferizePass() = default; + TestOneShotModuleBufferizePass(const TestOneShotModuleBufferizePass &pass) + : PassWrapper(pass) {} + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<bufferization::BufferizationDialect>(); + } + StringRef getArgument() const final { + return "test-one-shot-module-bufferize"; + } + StringRef getDescription() const final { + return "Pass to test One Shot Module Bufferization"; + } + + void runOnOperation() override { + + llvm::errs() << "Running TestOneShotModuleBufferize on: " + << getOperation()->getName() << "\n"; + bufferization::OneShotBufferizationOptions opt; + + opt.bufferizeFunctionBoundaries = true; + bufferization::BufferizationState bufferizationState; + + if (failed(bufferization::runOneShotModuleBufferize(getOperation(), opt, + bufferizationState))) + signalPassFailure(); + } +}; +} // namespace + +namespace mlir::test { +void registerTestOneShotModuleBufferizePass() { + PassRegistration<TestOneShotModuleBufferizePass>(); +} +} // namespace mlir::test diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp index d0b62e7..c67bcd9 100644 --- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp +++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp @@ -48,8 +48,8 @@ static SmallVector<Value> buildDecomposeTuple(OpBuilder &builder, } for (unsigned i = 0, e = tupleType.size(); i < e; ++i) { Type elementType = tupleType.getType(i); - Value element = builder.create<test::GetTupleElementOp>( - loc, elementType, tuple, builder.getI32IntegerAttr(i)); + Value element = test::GetTupleElementOp::create( + builder, loc, elementType, tuple, builder.getI32IntegerAttr(i)); decompose(element); } }; @@ -94,7 +94,7 @@ static Value buildMakeTupleOp(OpBuilder &builder, TupleType resultType, } // Assemble the tuple from the elements. - return builder.create<test::MakeTupleOp>(loc, resultType, elements); + return test::MakeTupleOp::create(builder, loc, resultType, elements); } /// A pass for testing call graph type decomposition. diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp index 9eade75..9a394d2 100644 --- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp +++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp @@ -56,7 +56,7 @@ struct TestSCFForUtilsPass SmallVector<Value> newYieldValues; for (auto yieldVal : oldYieldValues) { newYieldValues.push_back( - b.create<arith::AddFOp>(loc, yieldVal, yieldVal)); + arith::AddFOp::create(b, loc, yieldVal, yieldVal)); } return newYieldValues; }; @@ -160,13 +160,13 @@ struct TestSCFPipeliningPass Value pred) { Location loc = op->getLoc(); auto ifOp = - rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true); + scf::IfOp::create(rewriter, loc, op->getResultTypes(), pred, true); // True branch. rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(), ifOp.getThenRegion().front().begin()); rewriter.setInsertionPointAfter(op); if (op->getNumResults() > 0) - rewriter.create<scf::YieldOp>(loc, op->getResults()); + scf::YieldOp::create(rewriter, loc, op->getResults()); // False branch. rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); SmallVector<Value> elseYieldOperands; @@ -181,12 +181,12 @@ struct TestSCFPipeliningPass } else { // Default to assuming constant numeric values. for (Type type : op->getResultTypes()) { - elseYieldOperands.push_back(rewriter.create<arith::ConstantOp>( - loc, rewriter.getZeroAttr(type))); + elseYieldOperands.push_back(arith::ConstantOp::create( + rewriter, loc, rewriter.getZeroAttr(type))); } } if (op->getNumResults() > 0) - rewriter.create<scf::YieldOp>(loc, elseYieldOperands); + scf::YieldOp::create(rewriter, loc, elseYieldOperands); return ifOp.getOperation(); } diff --git a/mlir/test/lib/Dialect/SCF/TestWhileOpBuilder.cpp b/mlir/test/lib/Dialect/SCF/TestWhileOpBuilder.cpp index d3113c0..d3f7f0e6 100644 --- a/mlir/test/lib/Dialect/SCF/TestWhileOpBuilder.cpp +++ b/mlir/test/lib/Dialect/SCF/TestWhileOpBuilder.cpp @@ -50,23 +50,23 @@ struct TestSCFWhileOpBuilderPass // Create a WhileOp with the same operands and result types. TypeRange resultTypes = whileOp->getResultTypes(); ValueRange operands = whileOp->getOperands(); - builder.create<WhileOp>( - loc, resultTypes, operands, /*beforeBuilder=*/ + WhileOp::create( + builder, loc, resultTypes, operands, /*beforeBuilder=*/ [&](OpBuilder &b, Location loc, ValueRange args) { // Just cast the before args into the right types for condition. ImplicitLocOpBuilder builder(loc, b); auto castOp = - builder.create<UnrealizedConversionCastOp>(resultTypes, args); - auto cmp = builder.create<ConstantIntOp>(/*value=*/1, /*width=*/1); - builder.create<ConditionOp>(cmp, castOp->getResults()); + UnrealizedConversionCastOp::create(builder, resultTypes, args); + auto cmp = ConstantIntOp::create(builder, /*value=*/1, /*width=*/1); + ConditionOp::create(builder, cmp, castOp->getResults()); }, /*afterBuilder=*/ [&](OpBuilder &b, Location loc, ValueRange args) { // Just cast the after args into the right types for yield. ImplicitLocOpBuilder builder(loc, b); - auto castOp = builder.create<UnrealizedConversionCastOp>( - operands.getTypes(), args); - builder.create<YieldOp>(castOp->getResults()); + auto castOp = UnrealizedConversionCastOp::create( + builder, operands.getTypes(), args); + YieldOp::create(builder, castOp->getResults()); }); }); } diff --git a/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp b/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp index ac71ff6..23fdad1 100644 --- a/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp +++ b/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp @@ -72,15 +72,14 @@ struct TestReshardingRewritePattern : OpRewritePattern<ShardOp> { ShapedType sourceShardShape = shardShapedType(op.getResult().getType(), grid, op.getSharding()); TypedValue<ShapedType> sourceShard = cast<TypedValue<ShapedType>>( - builder - .create<UnrealizedConversionCastOp>(sourceShardShape, op.getSrc()) + UnrealizedConversionCastOp::create(builder, sourceShardShape, + op.getSrc()) ->getResult(0)); TypedValue<ShapedType> targetShard = reshard(builder, grid, op, targetShardOp, sourceShard); Value newTargetUnsharded = - builder - .create<UnrealizedConversionCastOp>( - targetShardOp.getResult().getType(), targetShard) + UnrealizedConversionCastOp::create( + builder, targetShardOp.getResult().getType(), targetShard) ->getResult(0); rewriter.replaceAllUsesWith(targetShardOp.getResult(), newTargetUnsharded); diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp index 0e191c3..687473e 100644 --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -192,8 +192,8 @@ struct RewriteExtractSliceFromCollapseShapeBase // Create the destination tensor using the above values. Type elementType = op.getSourceType().getElementType(); SmallVector<OpFoldResult> outputShape = reifiedShapes[0]; - Value dest = rewriter.create<tensor::EmptyOp>(op->getLoc(), outputShape, - elementType); + Value dest = tensor::EmptyOp::create(rewriter, op->getLoc(), outputShape, + elementType); // Calculate the parameters for the tile loop nest. FailureOr<tensor::ExtractSliceFromCollapseHelper> params = @@ -215,8 +215,8 @@ struct RewriteExtractSliceFromCollapseShapeUsingScfFor PatternRewriter &rewriter) const override { Location loc = op.getLoc(); const unsigned numTiledDims = helper.getIterationSpaceSizes().size(); - auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); - auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1); + auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto one = arith::ConstantIndexOp::create(rewriter, loc, 1); SmallVector<Value> lbs(numTiledDims, zero); SmallVector<Value> steps(numTiledDims, one); @@ -228,8 +228,8 @@ struct RewriteExtractSliceFromCollapseShapeUsingScfFor helper.emitLoopNestBody(nestedBuilder, loc, outputIvs); // Insert the slice into the destination. - return {nestedBuilder.create<tensor::InsertSliceOp>( - loc, tile, iterArgs[0], insertParams)}; + return {tensor::InsertSliceOp::create(nestedBuilder, loc, tile, + iterArgs[0], insertParams)}; }); rewriter.replaceOp(op, nest.results); @@ -245,8 +245,9 @@ struct RewriteExtractSliceFromCollapseShapeUsingScfForeach tensor::ExtractSliceFromCollapseHelper &helper, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto forallOp = rewriter.create<scf::ForallOp>( - loc, /*numThreads=*/getAsOpFoldResult(helper.getIterationSpaceSizes()), + auto forallOp = scf::ForallOp::create( + rewriter, loc, + /*numThreads=*/getAsOpFoldResult(helper.getIterationSpaceSizes()), /*outputs=*/dest, /*mapping=*/std::nullopt, [&](OpBuilder &nestedBuilder, Location loc, ValueRange regionArgs) { @@ -261,10 +262,10 @@ struct RewriteExtractSliceFromCollapseShapeUsingScfForeach auto [tile, insertParams] = helper.emitLoopNestBody(nestedBuilder, loc, outputIvs); // Insert the slice into the destination. - auto term = nestedBuilder.create<scf::InParallelOp>(loc); + auto term = scf::InParallelOp::create(nestedBuilder, loc); nestedBuilder.setInsertionPointToStart(term.getBody()); - nestedBuilder.create<tensor::ParallelInsertSliceOp>( - loc, tile, outputArgs[0], insertParams); + tensor::ParallelInsertSliceOp::create(nestedBuilder, loc, tile, + outputArgs[0], insertParams); }); rewriter.replaceOp(op, forallOp->getResult(0)); return success(); @@ -355,8 +356,8 @@ static LogicalResult testTrackingListenerReplacements(Operation *rootOp) { MLIRContext *context = rootOp->getContext(); OpBuilder builder(context); OwningOpRef<transform::NamedSequenceOp> transformOp = - builder.create<transform::NamedSequenceOp>( - rootOp->getLoc(), + transform::NamedSequenceOp::create( + builder, rootOp->getLoc(), /*sym_name=*/"test_sequence", /*function_type=*/ TypeAttr::get(FunctionType::get(context, TypeRange{}, TypeRange{})), diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td index 382da59..5685004 100644 --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -347,6 +347,7 @@ def TestCopyCount : Test_Attr<"TestCopyCount"> { let mnemonic = "copy_count"; let parameters = (ins TestParamCopyCount:$copy_count); let assemblyFormat = "`<` $copy_count `>`"; + let genVerifyDecl = 1; } def TestConditionalAliasAttr : Test_Attr<"TestConditionalAlias"> { diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp index b31e90f..5890913 100644 --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -214,6 +214,16 @@ static void printTrueFalse(AsmPrinter &p, std::optional<int> result) { } //===----------------------------------------------------------------------===// +// TestCopyCountAttr Implementation +//===----------------------------------------------------------------------===// + +LogicalResult TestCopyCountAttr::verify( + llvm::function_ref<::mlir::InFlightDiagnostic()> /*emitError*/, + CopyCount /*copy_count*/) { + return success(); +} + +//===----------------------------------------------------------------------===// // CopyCountAttr Implementation //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 1bbf2cc..a4c615b 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -346,7 +346,7 @@ TestDialect::~TestDialect() { Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - return builder.create<TestOpConstant>(loc, type, value); + return TestOpConstant::create(builder, loc, type, value); } void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp index 01ae245..1235a5f 100644 --- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp @@ -354,7 +354,7 @@ struct TestInlinerInterface : public DialectInlinerInterface { !(input.getType().isSignlessInteger(16) || input.getType().isSignlessInteger(32))) return nullptr; - return builder.create<TestCastOp>(conversionLoc, resultType, input); + return TestCastOp::create(builder, conversionLoc, resultType, input); } Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, @@ -362,16 +362,16 @@ struct TestInlinerInterface : public DialectInlinerInterface { DictionaryAttr argumentAttrs) const final { if (!argumentAttrs.contains("test.handle_argument")) return argument; - return builder.create<TestTypeChangerOp>(call->getLoc(), argument.getType(), - argument); + return TestTypeChangerOp::create(builder, call->getLoc(), + argument.getType(), argument); } Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, Value result, DictionaryAttr resultAttrs) const final { if (!resultAttrs.contains("test.handle_result")) return result; - return builder.create<TestTypeChangerOp>(call->getLoc(), result.getType(), - result); + return TestTypeChangerOp::create(builder, call->getLoc(), result.getType(), + result); } void processInlinedCallBlocks( diff --git a/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp index dc6413b..b98f6ce 100644 --- a/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp +++ b/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp @@ -43,11 +43,11 @@ static LogicalResult convertLoad(OpBuilder &builder, llvm::Instruction *inst, if (failed(addr)) return failure(); // Create the LoadOp - Value loadOp = builder.create<LLVM::LoadOp>( - moduleImport.translateLoc(inst->getDebugLoc()), + Value loadOp = LLVM::LoadOp::create( + builder, moduleImport.translateLoc(inst->getDebugLoc()), moduleImport.convertType(inst->getType()), *addr); - moduleImport.mapValue(inst) = builder.create<SameOperandElementTypeOp>( - loadOp.getLoc(), loadOp.getType(), loadOp, loadOp); + moduleImport.mapValue(inst) = SameOperandElementTypeOp::create( + builder, loadOp.getLoc(), loadOp.getType(), loadOp, loadOp); return success(); } diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index 3ab4ef2..53055fe 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -18,6 +18,32 @@ using namespace mlir; using namespace test; //===----------------------------------------------------------------------===// +// OverridenSymbolVisibilityOp +//===----------------------------------------------------------------------===// + +SymbolTable::Visibility OverriddenSymbolVisibilityOp::getVisibility() { + return SymbolTable::Visibility::Private; +} + +static StringLiteral getVisibilityString(SymbolTable::Visibility visibility) { + switch (visibility) { + case SymbolTable::Visibility::Private: + return "private"; + case SymbolTable::Visibility::Nested: + return "nested"; + case SymbolTable::Visibility::Public: + return "public"; + } +} + +void OverriddenSymbolVisibilityOp::setVisibility( + SymbolTable::Visibility visibility) { + + emitOpError("cannot change visibility of symbol to ") + << getVisibilityString(visibility); +} + +//===----------------------------------------------------------------------===// // TestBranchOp //===----------------------------------------------------------------------===// @@ -286,9 +312,9 @@ LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value { return builder.createOrFold<tensor::DimOp>(loc, operand, dim); })); - shapes.push_back(builder.create<tensor::FromElementsOp>( - getLoc(), RankedTensorType::get({rank}, builder.getIndexType()), - currShape)); + shapes.push_back(tensor::FromElementsOp::create( + builder, getLoc(), + RankedTensorType::get({rank}, builder.getIndexType()), currShape)); } return success(); } @@ -1302,8 +1328,8 @@ llvm::SmallVector<MemorySlot> TestMultiSlotAlloca::getPromotableSlots() { Value TestMultiSlotAlloca::getDefaultValue(const MemorySlot &slot, OpBuilder &builder) { - return builder.create<TestOpConstant>(getLoc(), slot.elemType, - builder.getI32IntegerAttr(42)); + return TestOpConstant::create(builder, getLoc(), slot.elemType, + builder.getI32IntegerAttr(42)); } void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot, @@ -1335,7 +1361,7 @@ createNewMultiAllocaWithoutSlot(const MemorySlot &slot, OpBuilder &builder, OpBuilder::InsertionGuard guard(builder); builder.setInsertionPoint(oldOp); auto replacement = - builder.create<TestMultiSlotAlloca>(oldOp->getLoc(), newTypes); + TestMultiSlotAlloca::create(builder, oldOp->getLoc(), newTypes); for (auto [oldResult, newResult] : llvm::zip_equal(remainingValues, replacement.getResults())) oldResult.replaceAllUsesWith(newResult); @@ -1384,7 +1410,7 @@ DenseMap<Attribute, MemorySlot> TestMultiSlotAlloca::destructure( for (Attribute usedIndex : usedIndices) { Type elemType = slot.subelementTypes.lookup(usedIndex); MemRefType elemPtr = MemRefType::get({}, elemType); - auto subAlloca = builder.create<TestMultiSlotAlloca>(getLoc(), elemPtr); + auto subAlloca = TestMultiSlotAlloca::create(builder, getLoc(), elemPtr); newAllocators.push_back(subAlloca); slotMap.try_emplace<MemorySlot>(usedIndex, {subAlloca.getResult(0), elemType}); @@ -1412,8 +1438,8 @@ TestMultiSlotAlloca::handleDestructuringComplete( const auto bufferizedOutType = test::TestMemrefType::get( getContext(), outType.getShape(), outType.getElementType(), nullptr); // replace op with memref analogy - auto dummyMemrefOp = rewriter.create<test::TestDummyMemrefOp>( - getLoc(), bufferizedOutType, *buffer); + auto dummyMemrefOp = test::TestDummyMemrefOp::create( + rewriter, getLoc(), bufferizedOutType, *buffer); mlir::bufferization::replaceOpWithBufferizedValues(rewriter, getOperation(), dummyMemrefOp.getResult()); @@ -1434,7 +1460,7 @@ TestMultiSlotAlloca::handleDestructuringComplete( // replace op with memref analogy auto createMemrefOp = - rewriter.create<test::TestCreateMemrefOp>(getLoc(), *bufferizedOutType); + test::TestCreateMemrefOp::create(rewriter, getLoc(), *bufferizedOutType); mlir::bufferization::replaceOpWithBufferizedValues( rewriter, getOperation(), createMemrefOp.getResult()); diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index ab3f847..2eaad55 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -119,12 +119,28 @@ def SymbolOp : TEST_Op<"symbol", [NoMemoryEffect, Symbol]> { OptionalAttr<StrAttr>:$sym_visibility); } +def OverriddenSymbolVisibilityOp : TEST_Op<"overridden_symbol_visibility", [ + DeclareOpInterfaceMethods<Symbol, ["getVisibility", "setVisibility"]>, +]> { + let summary = "operation overridden symbol visibility accessors"; + let arguments = (ins StrAttr:$sym_name); +} + def SymbolScopeOp : TEST_Op<"symbol_scope", [SymbolTable, SingleBlockImplicitTerminator<"TerminatorOp">]> { let summary = "operation which defines a new symbol table"; let regions = (region SizedRegion<1>:$region); } +def SymbolScopeIsolatedOp + : TEST_Op<"symbol_scope_isolated", [IsolatedFromAbove, SymbolTable, + SingleBlockImplicitTerminator< + "TerminatorOp">]> { + let summary = + "operation which defines a new symbol table that is IsolatedFromAbove"; + let regions = (region SizedRegion<1>:$region); +} + def SymbolTableRegionOp : TEST_Op<"symbol_table_region", [SymbolTable]> { let summary = "operation which defines a new symbol table without a " "restriction on a terminator"; @@ -2035,7 +2051,7 @@ def IllegalOpWithRegion : TEST_Op<"illegal_op_with_region"> { OpBuilder::InsertionGuard g($_builder); Block *body = $_builder.createBlock(bodyRegion); $_builder.setInsertionPointToEnd(body); - $_builder.create<IllegalOpTerminator>($_state.location); + IllegalOpTerminator::create($_builder,$_state.location); }]>]; } def IllegalOpWithRegionAnchor : TEST_Op<"illegal_op_with_region_anchor">; @@ -2738,7 +2754,7 @@ def TestLinalgConvOp : static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block, mlir::ArrayRef<mlir::NamedAttribute> attrs, llvm::function_ref<mlir::InFlightDiagnostic()> emitError) { - b.create<mlir::linalg::YieldOp>(block.getArguments().back()); + mlir::linalg::YieldOp::create(b,block.getArguments().back()); } static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &, @@ -2801,7 +2817,7 @@ def TestLinalgFillOp : static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block, mlir::ArrayRef<mlir::NamedAttribute> attrs, llvm::function_ref<mlir::InFlightDiagnostic()> emitError) { - b.create<mlir::linalg::YieldOp>(block.getArguments().back()); + mlir::linalg::YieldOp::create(b,block.getArguments().back()); } static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &, diff --git a/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp b/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp index 6d4e5e3..cc131ad 100644 --- a/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp @@ -313,7 +313,7 @@ ParseResult WrappingRegionOp::parse(OpAsmParser &parser, SmallVector<Value, 8> returnOperands(wrappedOp->getResults()); OpBuilder builder(parser.getContext()); builder.setInsertionPointToEnd(&block); - builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands); + TestReturnOp::create(builder, wrappedOp->getLoc(), returnOperands); // Get the results type for the wrapping op from the terminator operands. Operation &returnOp = body.back().back(); @@ -397,7 +397,7 @@ ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser, builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType); // Insert a return statement in the block returning the inner-op's result. - builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults()); + TestReturnOp::create(builder, innerOp->getLoc(), innerOp->getResults()); // Populate the op operation-state with result-type and location. result.addTypes(opFntype.getResults()); diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 1fff57e..eda618f 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -33,14 +33,14 @@ static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { } static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { - rewriter.create<OpI>(loc, input); + OpI::create(rewriter, loc, input); } static void handleNoResultOp(PatternRewriter &rewriter, OpSymbolBindingNoResult op) { // Turn the no result op to a one-result op. - rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(), - op.getOperand()); + OpSymbolBindingB::create(rewriter, op.getLoc(), op.getOperand().getType(), + op.getOperand()); } static bool getFirstI32Result(Operation *op, Value &value) { @@ -120,8 +120,8 @@ public: return failure(); rewriter.setInsertionPointToStart(op->getBlock()); - auto constOp = rewriter.create<arith::ConstantOp>( - op.getLoc(), rewriter.getBoolAttr(true)); + auto constOp = arith::ConstantOp::create(rewriter, op.getLoc(), + rewriter.getBoolAttr(true)); rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(), Value(constOp)); return success(); @@ -844,8 +844,8 @@ struct TestRegionRewriteUndo : public RewritePattern { rewriter.getUnknownLoc()); // Add an explicitly illegal operation to ensure the conversion fails. - rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); - rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); + ILLegalOpF::create(rewriter, op->getLoc(), rewriter.getIntegerType(32)); + TestValidOp::create(rewriter, op->getLoc(), ArrayRef<Value>()); // Drop this operation. rewriter.eraseOp(op); @@ -864,7 +864,7 @@ struct TestCreateBlock : public RewritePattern { Type i32Type = rewriter.getIntegerType(32); Location loc = op->getLoc(); rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); - rewriter.create<TerminatorOp>(loc); + TerminatorOp::create(rewriter, loc); rewriter.eraseOp(op); return success(); } @@ -883,8 +883,8 @@ struct TestCreateIllegalBlock : public RewritePattern { Location loc = op->getLoc(); rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); // Create an illegal op to ensure the conversion fails. - rewriter.create<ILLegalOpF>(loc, i32Type); - rewriter.create<TerminatorOp>(loc); + ILLegalOpF::create(rewriter, loc, i32Type); + TerminatorOp::create(rewriter, loc); rewriter.eraseOp(op); return success(); } @@ -939,7 +939,7 @@ struct TestUndoBlockErase : public ConversionPattern { ConversionPatternRewriter &rewriter) const final { Block *secondBlock = &*std::next(op->getRegion(0).begin()); rewriter.setInsertionPointToStart(secondBlock); - rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); + ILLegalOpF::create(rewriter, op->getLoc(), rewriter.getF32Type()); rewriter.eraseBlock(secondBlock); rewriter.modifyOpInPlace(op, [] {}); return success(); @@ -1007,9 +1007,8 @@ struct TestPassthroughInvalidOp : public ConversionPattern { // This is a 1:N replacement. Insert a test.cast op. (That's what the // argument materialization used to do.) flattened.push_back( - rewriter - .create<TestCastOp>(op->getLoc(), - op->getOperand(it.index()).getType(), range) + TestCastOp::create(rewriter, op->getLoc(), + op->getOperand(it.index()).getType(), range) .getResult()); } rewriter.replaceOpWithNewOp<TestValidOp>(op, TypeRange(), flattened, @@ -1114,8 +1113,8 @@ struct TestNonRootReplacement : public RewritePattern { LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final { auto resultType = *op->result_type_begin(); - auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); - auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); + auto illegalOp = ILLegalOpF::create(rewriter, op->getLoc(), resultType); + auto legalOp = LegalOpB::create(rewriter, op->getLoc(), resultType); rewriter.replaceOp(illegalOp, legalOp); rewriter.replaceOp(op, illegalOp); @@ -1181,7 +1180,7 @@ struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> { LogicalResult matchAndRewrite(ILLegalOpG op, PatternRewriter &rewriter) const final { IntegerAttr attr = rewriter.getI32IntegerAttr(0); - Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr); + Value val = arith::ConstantOp::create(rewriter, op->getLoc(), attr); rewriter.replaceOpWithNewOp<LegalOpC>(op, val); return success(); }; @@ -1354,7 +1353,7 @@ struct TestTypeConverter : public TypeConverter { /// 1->N type mappings. static Value materializeCast(OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) { - return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); + return TestCastOp::create(builder, loc, resultType, inputs).getResult(); } }; @@ -1362,6 +1361,10 @@ struct TestLegalizePatternDriver : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver) + TestLegalizePatternDriver() = default; + TestLegalizePatternDriver(const TestLegalizePatternDriver &other) + : PassWrapper(other) {} + StringRef getArgument() const final { return "test-legalize-patterns"; } StringRef getDescription() const final { return "Run test dialect legalization patterns"; @@ -1369,8 +1372,6 @@ struct TestLegalizePatternDriver /// The mode of conversion to use with the driver. enum class ConversionMode { Analysis, Full, Partial }; - TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} - void getDependentDialects(DialectRegistry ®istry) const override { registry.insert<func::FuncDialect, test::TestDialect>(); } @@ -1499,24 +1500,19 @@ struct TestLegalizePatternDriver op->emitRemark() << "op '" << op->getName() << "' is legalizable"; } - /// The mode of conversion to use. - ConversionMode mode; + Option<ConversionMode> mode{ + *this, "test-legalize-mode", + llvm::cl::desc("The legalization mode to use with the test driver"), + llvm::cl::init(ConversionMode::Partial), + llvm::cl::values( + clEnumValN(ConversionMode::Analysis, "analysis", + "Perform an analysis conversion"), + clEnumValN(ConversionMode::Full, "full", "Perform a full conversion"), + clEnumValN(ConversionMode::Partial, "partial", + "Perform a partial conversion"))}; }; } // namespace -static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> - legalizerConversionMode( - "test-legalize-mode", - llvm::cl::desc("The legalization mode to use with the test driver"), - llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), - llvm::cl::values( - clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, - "analysis", "Perform an analysis conversion"), - clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", - "Perform a full conversion"), - clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, - "partial", "Perform a partial conversion"))); - //===----------------------------------------------------------------------===// // ConversionPatternRewriter::getRemappedValue testing. This method is used // to get the remapped value of an original value that was replaced using @@ -1916,15 +1912,15 @@ struct TestTypeConversionDriver // Allow casting from F64 back to F32. if (!resultType.isF16() && inputs.size() == 1 && inputs[0].getType().isF64()) - return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); + return TestCastOp::create(builder, loc, resultType, inputs).getResult(); // Allow producing an i32 or i64 from nothing. if ((resultType.isInteger(32) || resultType.isInteger(64)) && inputs.empty()) - return builder.create<TestTypeProducerOp>(loc, resultType); + return TestTypeProducerOp::create(builder, loc, resultType); // Allow producing an i64 from an integer. if (isa<IntegerType>(resultType) && inputs.size() == 1 && isa<IntegerType>(inputs[0].getType())) - return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); + return TestCastOp::create(builder, loc, resultType, inputs).getResult(); // Otherwise, fail. return nullptr; }); @@ -2007,7 +2003,7 @@ struct TestTargetMaterializationWithNoUses }); converter.addTargetMaterialization( [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { - return builder.create<TestCastOp>(loc, type, inputs).getResult(); + return TestCastOp::create(builder, loc, type, inputs).getResult(); }); ConversionTarget target(getContext()); @@ -2058,7 +2054,7 @@ struct TestUndoBlocksMerge : public ConversionPattern { Operation *branchOp = firstBlock.getTerminator(); Block *secondBlock = &*(std::next(op->getRegion(0).begin())); rewriter.setInsertionPointToStart(secondBlock); - rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); + ILLegalOpF::create(rewriter, op->getLoc(), rewriter.getF32Type()); auto succOperands = branchOp->getOperands(); SmallVector<Value, 2> replacements(succOperands); rewriter.eraseOp(branchOp); @@ -2202,9 +2198,7 @@ void registerPatternsTestPass() { PassRegistration<TestStrictPatternDriver>(); PassRegistration<TestWalkPatternDriver>(); - PassRegistration<TestLegalizePatternDriver>([] { - return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); - }); + PassRegistration<TestLegalizePatternDriver>(); PassRegistration<TestRemappedValue>(); diff --git a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp index 103817d..7831b27 100644 --- a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp +++ b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp @@ -68,8 +68,8 @@ LogicalResult TestDialectLLVMIRTranslationInterface::amendOperation( if (createSymbol) { OpBuilder builder(op->getRegion(0)); - builder.create<test::SymbolOp>( - op->getLoc(), + test::SymbolOp::create( + builder, op->getLoc(), StringAttr::get(op->getContext(), "sym_from_attr"), /*sym_visibility=*/nullptr); } diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp b/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp index bda614a..9550e4c 100644 --- a/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp +++ b/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp @@ -47,9 +47,9 @@ struct TestOpConversion : public OpConversionPattern<test_irdl_to_cpp::BeefOp> { op, op->getResultTypes().front()); rewriter.setInsertionPointAfter(bar); - rewriter.create<test_irdl_to_cpp::HashOp>( - bar.getLoc(), rewriter.getIntegerType(32), adaptor.getLhs(), - adaptor.getRhs()); + test_irdl_to_cpp::HashOp::create(rewriter, bar.getLoc(), + rewriter.getIntegerType(32), + adaptor.getLhs(), adaptor.getRhs()); return success(); } }; diff --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp index 3389a1c..6457487 100644 --- a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp +++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp @@ -87,9 +87,9 @@ ConvertTosaNegateOp::matchAndRewrite(Operation *op, return failure(); auto newConstOp = - rewriter.create<tosa::ConstOp>(op->getLoc(), dstQConstType, inputElems); - auto newNegateOp = rewriter.create<tosa::NegateOp>( - op->getLoc(), dstQConstType, newConstOp.getResult()); + tosa::ConstOp::create(rewriter, op->getLoc(), dstQConstType, inputElems); + auto newNegateOp = tosa::NegateOp::create( + rewriter, op->getLoc(), dstQConstType, newConstOp.getResult()); rewriter.replaceOp(op, {newNegateOp.getResult()}); return success(); @@ -145,8 +145,8 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op, auto newTosaConv2DOpType = RankedTensorType::get(outputType.getShape(), rewriter.getIntegerType(32)); - auto newTosaConv2DOp = rewriter.create<tosa::Conv2DOp>( - op->getLoc(), newTosaConv2DOpType, tosaConv2DOp.getInput(), + auto newTosaConv2DOp = tosa::Conv2DOp::create( + rewriter, op->getLoc(), newTosaConv2DOpType, tosaConv2DOp.getInput(), tosaConv2DOp.getWeight(), tosaConv2DOp.getBias(), tosaConv2DOp.getPadAttr(), tosaConv2DOp.getStrideAttr(), tosaConv2DOp.getDilationAttr(), tosaConv2DOp.getAccTypeAttr()); @@ -178,8 +178,8 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op, newTosaConv2DOp.getResult().getType().isUnsignedInteger(); bool outputUnsigned = outputType.isUnsignedInteger(); - auto newTosaRescaleOp = rewriter.create<tosa::RescaleOp>( - op->getLoc(), outputType, newTosaConv2DOp.getResult(), + auto newTosaRescaleOp = tosa::RescaleOp::create( + rewriter, op->getLoc(), outputType, newTosaConv2DOp.getResult(), getConstTensorInt<int32_t>(rewriter, op->getLoc(), {multiplier}), getConstTensorInt<int8_t>(rewriter, op->getLoc(), {static_cast<int8_t>(shift)}), diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp index cdf44c2..97fc699 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -796,8 +796,8 @@ DiagnosedSilenceableFailure mlir::test::TestProduceInvalidIR::applyToOne( transform::TransformState &state) { // Provide some IR that does not verify. rewriter.setInsertionPointToStart(&target->getRegion(0).front()); - rewriter.create<TestDummyPayloadOp>(target->getLoc(), TypeRange(), - ValueRange(), /*failToVerify=*/true); + TestDummyPayloadOp::create(rewriter, target->getLoc(), TypeRange(), + ValueRange(), /*failToVerify=*/true); return DiagnosedSilenceableFailure::success(); } @@ -877,7 +877,8 @@ public: Location loc) -> Value { if (inputs.size() != 1) return Value(); - return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, + inputs) .getResult(0); }; addSourceMaterialization(unrealizedCastConverter); diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index a7285ab..f89c944 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -546,8 +546,8 @@ static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder, auto ip = builder.saveInsertionPoint(); builder.setInsertionPoint(moduleOp); - auto global = builder.create<memref::GlobalOp>( - loc, + auto global = memref::GlobalOp::create( + builder, loc, /*sym_name=*/symbolName, /*sym_visibility=*/builder.getStringAttr("private"), /*type=*/memrefType, @@ -560,19 +560,18 @@ static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder, global->moveBefore(&moduleOp.front()); builder.restoreInsertionPoint(ip); - return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName); + return memref::GetGlobalOp::create(builder, loc, memrefType, symbolName); } static Value warpReduction(Location loc, OpBuilder &builder, Value input, CombiningKind kind, uint32_t size) { // First reduce on a single thread to get per lane reduction value. - Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input); + Value laneVal = vector::ReductionOp::create(builder, loc, kind, input); // Parallel reduction using butterfly shuffles. for (uint64_t i = 1; i < size; i <<= 1) { - Value shuffled = builder - .create<gpu::ShuffleOp>(loc, laneVal, i, - /*width=*/size, - /*mode=*/gpu::ShuffleMode::XOR) + Value shuffled = gpu::ShuffleOp::create(builder, loc, laneVal, i, + /*width=*/size, + /*mode=*/gpu::ShuffleMode::XOR) .getShuffleResult(); laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled); } @@ -647,12 +646,11 @@ struct TestVectorDistribution "unsupported shuffle type"); Type i32Type = builder.getIntegerType(32); Value srcIdxI32 = - builder.create<arith::IndexCastOp>(loc, i32Type, srcIdx); - Value warpSzI32 = builder.create<arith::ConstantOp>( - loc, builder.getIntegerAttr(i32Type, warpSz)); - Value result = builder - .create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32, - gpu::ShuffleMode::IDX) + arith::IndexCastOp::create(builder, loc, i32Type, srcIdx); + Value warpSzI32 = arith::ConstantOp::create( + builder, loc, builder.getIntegerAttr(i32Type, warpSz)); + Value result = gpu::ShuffleOp::create(builder, loc, val, srcIdxI32, + warpSzI32, gpu::ShuffleMode::IDX) .getResult(0); return result; }; @@ -680,7 +678,7 @@ struct TestVectorDistribution options.warpAllocationFn = allocateGlobalSharedMemory; options.warpSyncronizationFn = [](Location loc, OpBuilder &builder, gpu::WarpExecuteOnLane0Op warpOp) { - builder.create<gpu::BarrierOp>(loc); + gpu::BarrierOp::create(builder, loc); }; // Test on one pattern in isolation. if (warpOpToSCF) { diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp index f71fcf7..c6245b6 100644 --- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp +++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp @@ -20,8 +20,6 @@ using namespace mlir::xegpu; namespace { #define DEBUG_TYPE "test-xegpu-unroll" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") struct TestXeGPUUnrollingPatterns : public PassWrapper<TestXeGPUUnrollingPatterns, diff --git a/mlir/test/lib/IR/TestPrintInvalid.cpp b/mlir/test/lib/IR/TestPrintInvalid.cpp index 8697918..25d1b19 100644 --- a/mlir/test/lib/IR/TestPrintInvalid.cpp +++ b/mlir/test/lib/IR/TestPrintInvalid.cpp @@ -34,13 +34,14 @@ struct TestPrintInvalidPass void runOnOperation() override { Location loc = getOperation().getLoc(); OpBuilder builder(getOperation().getBodyRegion()); - auto funcOp = builder.create<func::FuncOp>( - loc, "test", FunctionType::get(getOperation().getContext(), {}, {})); + auto funcOp = func::FuncOp::create( + builder, loc, "test", + FunctionType::get(getOperation().getContext(), {}, {})); funcOp.addEntryBlock(); // The created function is invalid because there is no return op. llvm::outs() << "Invalid operation:\n" << funcOp << "\n"; builder.setInsertionPointToEnd(&funcOp.getBody().front()); - builder.create<func::ReturnOp>(loc); + func::ReturnOp::create(builder, loc); // Now this function is valid. llvm::outs() << "Valid operation:\n" << funcOp << "\n"; funcOp.erase(); diff --git a/mlir/test/lib/IR/TestSlicing.cpp b/mlir/test/lib/IR/TestSlicing.cpp index 92fd6de..5a5ac45 100644 --- a/mlir/test/lib/IR/TestSlicing.cpp +++ b/mlir/test/lib/IR/TestSlicing.cpp @@ -30,8 +30,8 @@ static LogicalResult createBackwardSliceFunction(Operation *op, OpBuilder builder(parentFuncOp); Location loc = op->getLoc(); std::string clonedFuncOpName = parentFuncOp.getName().str() + suffix.str(); - func::FuncOp clonedFuncOp = builder.create<func::FuncOp>( - loc, clonedFuncOpName, parentFuncOp.getFunctionType()); + func::FuncOp clonedFuncOp = func::FuncOp::create( + builder, loc, clonedFuncOpName, parentFuncOp.getFunctionType()); IRMapping mapper; builder.setInsertionPointToEnd(clonedFuncOp.addEntryBlock()); for (const auto &arg : enumerate(parentFuncOp.getArguments())) @@ -46,7 +46,7 @@ static LogicalResult createBackwardSliceFunction(Operation *op, (void)result; for (Operation *slicedOp : slice) builder.clone(*slicedOp, mapper); - builder.create<func::ReturnOp>(loc); + func::ReturnOp::create(builder, loc); return success(); } diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp index 7afe210..25c8e53 100644 --- a/mlir/test/lib/Pass/TestPassManager.cpp +++ b/mlir/test/lib/Pass/TestPassManager.cpp @@ -217,8 +217,8 @@ struct TestInvalidParentPass void runOnOperation() final { FunctionOpInterface op = getOperation(); OpBuilder b(op.getFunctionBody()); - b.create<test::TestCallOp>(op.getLoc(), TypeRange(), "some_unknown_func", - ValueRange()); + test::TestCallOp::create(b, op.getLoc(), TypeRange(), "some_unknown_func", + ValueRange()); } }; diff --git a/mlir/test/lib/Transforms/TestDialectConversion.cpp b/mlir/test/lib/Transforms/TestDialectConversion.cpp index 8278937..dc0538e 100644 --- a/mlir/test/lib/Transforms/TestDialectConversion.cpp +++ b/mlir/test/lib/Transforms/TestDialectConversion.cpp @@ -45,7 +45,7 @@ struct PDLLTypeConverter : public TypeConverter { /// Hook for materializing a conversion. static Value materializeCast(OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) { - return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs) .getResult(0); } }; diff --git a/mlir/test/lib/Transforms/TestInliningCallback.cpp b/mlir/test/lib/Transforms/TestInliningCallback.cpp index c518f3f..2888c3c 100644 --- a/mlir/test/lib/Transforms/TestInliningCallback.cpp +++ b/mlir/test/lib/Transforms/TestInliningCallback.cpp @@ -53,8 +53,8 @@ struct InlinerCallback mlir::Operation &call = inlineBlock->back(); builder.setInsertionPointAfter(&call); - auto executeRegionOp = builder.create<mlir::scf::ExecuteRegionOp>( - call.getLoc(), call.getResultTypes()); + auto executeRegionOp = mlir::scf::ExecuteRegionOp::create( + builder, call.getLoc(), call.getResultTypes()); mlir::Region ®ion = executeRegionOp.getRegion(); // Move the inlined blocks into the region @@ -70,8 +70,8 @@ struct InlinerCallback if (test::TestReturnOp returnOp = llvm::dyn_cast<test::TestReturnOp>(&op)) { mlir::OpBuilder returnBuilder(returnOp); - returnBuilder.create<mlir::scf::YieldOp>(returnOp.getLoc(), - returnOp.getOperands()); + mlir::scf::YieldOp::create(returnBuilder, returnOp.getLoc(), + returnOp.getOperands()); returnOp.erase(); } } @@ -79,8 +79,8 @@ struct InlinerCallback // Add test.return after scf.execute_region builder.setInsertionPointAfter(executeRegionOp); - builder.create<test::TestReturnOp>(executeRegionOp.getLoc(), - executeRegionOp.getResults()); + test::TestReturnOp::create(builder, executeRegionOp.getLoc(), + executeRegionOp.getResults()); } void runOnOperation() override { diff --git a/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp b/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp index 4e0213c..c1fb706 100644 --- a/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp +++ b/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp @@ -28,7 +28,7 @@ makeIsolatedFromAboveImpl(RewriterBase &rewriter, SmallVector<Value> operands = regionOp.getOperands(); operands.append(capturedValues); auto isolatedRegionOp = - rewriter.create<test::IsolatedOneRegionOp>(regionOp.getLoc(), operands); + test::IsolatedOneRegionOp::create(rewriter, regionOp.getLoc(), operands); rewriter.inlineRegionBefore(region, isolatedRegionOp.getRegion(), isolatedRegionOp.getRegion().begin()); rewriter.eraseOp(regionOp); diff --git a/mlir/test/lib/Transforms/TestTransformsOps.cpp b/mlir/test/lib/Transforms/TestTransformsOps.cpp index 9a5632b..ff5838d 100644 --- a/mlir/test/lib/Transforms/TestTransformsOps.cpp +++ b/mlir/test/lib/Transforms/TestTransformsOps.cpp @@ -74,8 +74,8 @@ transform::TestMakeComposedFoldedAffineApply::applyToOne( if (auto v = dyn_cast<Value>(ofr)) { result = v; } else { - result = rewriter.create<arith::ConstantIndexOp>( - loc, getConstantIntValue(ofr).value()); + result = arith::ConstantIndexOp::create(rewriter, loc, + getConstantIntValue(ofr).value()); } results.push_back(result.getDefiningOp()); rewriter.replaceOp(affineApplyOp, result); diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py index 233fef8..feaf5fb 100644 --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -343,7 +343,6 @@ if config.enable_assertions: else: config.available_features.add("noasserts") - def have_host_jit_feature_support(feature_name): mlir_runner_exe = lit.util.which("mlir-runner", config.mlir_tools_dir) diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in index 132aabe..b1185e1 100644 --- a/mlir/test/lit.site.cfg.py.in +++ b/mlir/test/lit.site.cfg.py.in @@ -5,6 +5,7 @@ import sys config.target_triple = "@LLVM_TARGET_TRIPLE@" config.llvm_src_root = "@LLVM_SOURCE_DIR@" config.llvm_tools_dir = lit_config.substitute("@LLVM_TOOLS_DIR@") +config.spirv_tools_tests = @LLVM_INCLUDE_SPIRV_TOOLS_TESTS@ config.llvm_shlib_ext = "@SHLIBEXT@" config.llvm_shlib_dir = lit_config.substitute(path(r"@SHLIBDIR@")) config.python_executable = "@Python3_EXECUTABLE@" @@ -41,7 +42,7 @@ config.mlir_run_amx_tests = @MLIR_RUN_AMX_TESTS@ config.mlir_run_arm_sve_tests = @MLIR_RUN_ARM_SVE_TESTS@ # This is a workaround for the fact that LIT's: # %if <cond> -# requires <cond> to be in the set of available features. +# requires <cond> to be in the set of available features. # TODO: Update LIT's TestRunner so that this is not required. if config.mlir_run_arm_sve_tests: config.available_features.add("mlir_arm_sve_tests") diff --git a/mlir/test/mlir-runner/simple.mlir b/mlir/test/mlir-runner/simple.mlir index 1a03b99..21dabdd 100644 --- a/mlir/test/mlir-runner/simple.mlir +++ b/mlir/test/mlir-runner/simple.mlir @@ -15,10 +15,10 @@ // RUN: ls %t.o // RUN: rm %t.o -// RUN: mlir-runner %s -dump-object-file -object-filename=%T/test.o \ +// RUN: mlir-runner %s -dump-object-file -object-filename=%t.o \ // RUN: %if target={{s390x-.*}} %{ -argext-abi-check=false %} | FileCheck %s -// RUN: ls %T/test.o -// RUN: rm %T/test.o +// RUN: ls %t.o +// RUN: rm %t.o // Declarations of C library functions. llvm.func @logbf(f32) -> f32 diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td index d47411d..a809611 100644 --- a/mlir/test/mlir-tblgen/attrdefs.td +++ b/mlir/test/mlir-tblgen/attrdefs.td @@ -115,6 +115,11 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> { // DEF: return new (allocator.allocate<CompoundAAttrStorage>()) // DEF-SAME: CompoundAAttrStorage(std::move(widthOfSomething), std::move(exampleTdType), std::move(apFloat), std::move(dims), std::move(inner)); +// DEF: CompoundAAttr CompoundAAttr::getChecked( +// DEF-SAME: int widthOfSomething, ::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef<int> dims, ::mlir::Type inner +// DEF-SAME: ) +// DEF-NEXT: return Base::getChecked(emitError, context, std::move(widthOfSomething), std::move(exampleTdType), std::move(apFloat), std::move(dims), std::move(inner)); + // DEF: ::mlir::Type CompoundAAttr::getInner() const { // DEF-NEXT: return getImpl()->inner; } diff --git a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td index 40af548..23ab24e 100644 --- a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td +++ b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td @@ -44,7 +44,7 @@ def test1 : Pat<(AOp (BOp:$x $y), $_), (AOp $x, $y)>; // CHECK: test::AOp::Properties tblgen_props; // CHECK: tblgen_values.push_back((*x.getODSResults(0).begin())); // CHECK: tblgen_props.y = ::llvm::dyn_cast_if_present<decltype(tblgen_props.y)>(y); -// CHECK: tblgen_AOp_0 = rewriter.create<test::AOp>(odsLoc, tblgen_types, tblgen_values, tblgen_props); +// CHECK: tblgen_AOp_0 = test::AOp::create(rewriter, odsLoc, tblgen_types, tblgen_values, tblgen_props); // Note: These use strings to pick up a non-trivial storage/interface type // difference. diff --git a/mlir/test/mlir-tblgen/rewriter-indexing.td b/mlir/test/mlir-tblgen/rewriter-indexing.td index 0a94746..9bb6103 100644 --- a/mlir/test/mlir-tblgen/rewriter-indexing.td +++ b/mlir/test/mlir-tblgen/rewriter-indexing.td @@ -55,7 +55,7 @@ def test2 : Pat<(COp $attr1, $op1, $attr2, (AOp $op2)), // We expect ODSOperand 0 here, the attribute before the operand in BOp // definition shouldn't shift the counter. // CHECK: op1 = (*castedOp0.getODSOperands(0).begin()).getDefiningOp(); -// CHECK: rewriter.create<test::BOp>((*a.getODSResults(0).begin()).getLoc() +// CHECK: test::BOp::create(rewriter, (*a.getODSResults(0).begin()).getLoc() def test3 : Pat<(BOp $attr, (AOp:$a $input)), (BOp $attr, (AOp $input), (location $a))>; diff --git a/mlir/tools/mlir-lsp-server/CMakeLists.txt b/mlir/tools/mlir-lsp-server/CMakeLists.txt index 6932e0f..0518620 100644 --- a/mlir/tools/mlir-lsp-server/CMakeLists.txt +++ b/mlir/tools/mlir-lsp-server/CMakeLists.txt @@ -2,8 +2,6 @@ set(LLVM_OPTIONAL_SOURCES null.cpp ) -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) set(LLVM_LINK_COMPONENTS Core Support @@ -35,22 +33,11 @@ if(MLIR_INCLUDE_TESTS) endif() set(LIBS - ${conversion_libs} - ${dialect_libs} - ${extension_libs} - - MLIRAffineAnalysis - MLIRAnalysis - MLIRDialect - MLIRFuncAllExtensions MLIRLspServerLib - MLIRParser - MLIRPass - MLIRTensorAllExtensions - MLIRTransforms - MLIRTransformUtils - MLIRSupport - MLIRIR + + MLIRRegisterAllDialects + MLIRRegisterAllExtensions + MLIRRegisterAllPasses ) add_mlir_tool(mlir-lsp-server diff --git a/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp b/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp index 6a759d9..10d602f 100644 --- a/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp +++ b/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/DialectRegistry.h" #include "mlir/IR/MLIRContext.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllExtensions.h" diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt index 6958fe3..7cc6e78 100644 --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -2,9 +2,6 @@ set(LLVM_OPTIONAL_SOURCES null.cpp ) -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) set(LLVM_LINK_COMPONENTS Core Support @@ -65,21 +62,11 @@ if(MLIR_INCLUDE_TESTS) endif() set(LIBS - ${dialect_libs} - ${conversion_libs} - ${extension_libs} - MLIRAffineAnalysis - MLIRAnalysis - MLIRCastInterfaces - MLIRDialect MLIROptLib - MLIRParser - MLIRPass - MLIRTransforms - MLIRTransformUtils - MLIRSupport - MLIRIR + MLIRRegisterAllDialects + MLIRRegisterAllExtensions + MLIRRegisterAllPasses # TODO: Remove when registerAllGPUToLLVMIRTranslations is no longer # registered directly in mlir-opt.cpp. diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 2c09753..14714c45 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -135,6 +135,7 @@ void registerTestShardSimplificationsPass(); void registerTestMultiBuffering(); void registerTestNextAccessPass(); void registerTestNVGPULowerings(); +void registerTestOneShotModuleBufferizePass(); void registerTestOpaqueLoc(); void registerTestOpLoweringPasses(); void registerTestPadFusion(); @@ -281,6 +282,7 @@ void registerTestPasses() { mlir::test::registerTestMultiBuffering(); mlir::test::registerTestNextAccessPass(); mlir::test::registerTestNVGPULowerings(); + mlir::test::registerTestOneShotModuleBufferizePass(); mlir::test::registerTestOpaqueLoc(); mlir::test::registerTestOpLoweringPasses(); mlir::test::registerTestPadFusion(); diff --git a/mlir/tools/mlir-pdll/mlir-pdll.cpp b/mlir/tools/mlir-pdll/mlir-pdll.cpp index 88a5f36..f99dcdb 100644 --- a/mlir/tools/mlir-pdll/mlir-pdll.cpp +++ b/mlir/tools/mlir-pdll/mlir-pdll.cpp @@ -201,6 +201,12 @@ int main(int argc, char **argv) { llvm::raw_string_ostream outputStrOS(outputStr); auto processFn = [&](std::unique_ptr<llvm::MemoryBuffer> chunkBuffer, raw_ostream &os) { + // Split does not guarantee null-termination. Make a copy of the buffer to + // ensure null-termination. + if (!chunkBuffer->getBuffer().ends_with('\0')) { + chunkBuffer = llvm::MemoryBuffer::getMemBufferCopy( + chunkBuffer->getBuffer(), chunkBuffer->getBufferIdentifier()); + } return processBuffer(os, std::move(chunkBuffer), outputType, includeDirs, dumpODS, includedFiles); }; diff --git a/mlir/tools/mlir-query/CMakeLists.txt b/mlir/tools/mlir-query/CMakeLists.txt index 1826397..1668bba 100644 --- a/mlir/tools/mlir-query/CMakeLists.txt +++ b/mlir/tools/mlir-query/CMakeLists.txt @@ -1,5 +1,3 @@ -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) - if(MLIR_INCLUDE_TESTS) set(test_libs MLIRTestDialect @@ -12,8 +10,8 @@ add_mlir_tool(mlir-query llvm_update_compile_flags(mlir-query) mlir_target_link_libraries(mlir-query PRIVATE - ${dialect_libs} MLIRQueryLib + MLIRRegisterAllDialects ) target_link_libraries(mlir-query PRIVATE ${test_libs}) diff --git a/mlir/tools/mlir-reduce/CMakeLists.txt b/mlir/tools/mlir-reduce/CMakeLists.txt index d71ac86..349d75b 100644 --- a/mlir/tools/mlir-reduce/CMakeLists.txt +++ b/mlir/tools/mlir-reduce/CMakeLists.txt @@ -1,6 +1,3 @@ -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) - if(MLIR_INCLUDE_TESTS) set(test_libs MLIRTestDialect @@ -8,12 +5,9 @@ if(MLIR_INCLUDE_TESTS) endif() set(LIBS - ${conversion_libs} - ${dialect_libs} - MLIRDialect - MLIRIR - MLIRPass MLIRReduceLib + MLIRRegisterAllDialects + MLIRRegisterAllPasses ) add_mlir_tool(mlir-reduce diff --git a/mlir/tools/mlir-rewrite/CMakeLists.txt b/mlir/tools/mlir-rewrite/CMakeLists.txt index 216491e..4120b175 100644 --- a/mlir/tools/mlir-rewrite/CMakeLists.txt +++ b/mlir/tools/mlir-rewrite/CMakeLists.txt @@ -1,21 +1,19 @@ -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) set(LLVM_LINK_COMPONENTS Support ) set(LIBS - ${dialect_libs} - MLIRAffineAnalysis MLIRAnalysis MLIRCastInterfaces MLIRDialect + MLIRIR MLIRParser MLIRPass - MLIRTransforms - MLIRTransformUtils + MLIRRegisterAllDialects MLIRSupport - MLIRIR + MLIRTransformUtils + MLIRTransforms ) include_directories(../../../clang/include) diff --git a/mlir/tools/mlir-rewrite/mlir-rewrite.cpp b/mlir/tools/mlir-rewrite/mlir-rewrite.cpp index 87df9e1..fd8ae7e 100644 --- a/mlir/tools/mlir-rewrite/mlir-rewrite.cpp +++ b/mlir/tools/mlir-rewrite/mlir-rewrite.cpp @@ -24,6 +24,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/LineIterator.h" +#include "llvm/Support/ManagedStatic.h" #include "llvm/Support/Regex.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp index dbae2143..3140f12 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -495,7 +495,7 @@ void DefGen::emitCheckedBuilder() { MethodBody &body = m->body().indent(); auto scope = body.scope("return Base::getChecked(emitError, context", ");"); for (const auto ¶m : params) - body << ", " << param.getName(); + body << ", std::move(" << param.getName() << ")"; } static SmallVector<MethodParameter> diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 975a524..605033d 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -632,7 +632,8 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { ++opArgIdx; continue; } - if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) { + if (auto *operand = + llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) { if (argTree.isVariadic()) { if (!operand->isVariadic()) { auto error = formatv("variadic DAG construct can't match op {0}'s " @@ -1695,7 +1696,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, // Then create the op. os.scope("", "\n}\n").os - << formatv("{0} = rewriter.create<{1}>({2}, tblgen_values, {3});", + << formatv("{0} = {1}::create(rewriter, {2}, tblgen_values, {3});", valuePackName, resultOp.getQualCppClassName(), locToUse, useProperties ? "tblgen_props" : "tblgen_attrs"); return resultValue; @@ -1714,7 +1715,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, // aggregate-parameter builders. createSeparateLocalVarsForOpArgs(tree, childNodeNames); - os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName, + os.scope().os << formatv("{0} = {1}::create(rewriter, {2}", valuePackName, resultOp.getQualCppClassName(), locToUse); supplyValuesForOpArgs(tree, childNodeNames, depth); os << "\n );\n}\n"; @@ -1753,7 +1754,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, resultIndex + i); } } - os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, " + os << formatv("{0} = {1}::create(rewriter, {2}, tblgen_types, " "tblgen_values, {3});\n", valuePackName, resultOp.getQualCppClassName(), locToUse, useProperties ? "tblgen_props" : "tblgen_attrs"); @@ -1772,8 +1773,8 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs( int valueIndex = 0; // An index for uniquing local variable names. for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { - const auto *operand = - llvm::dyn_cast_if_present<NamedTypeConstraint *>(resultOp.getArg(argIndex)); + const auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>( + resultOp.getArg(argIndex)); // We do not need special handling for attributes or properties. if (!operand) continue; @@ -1828,7 +1829,8 @@ void PatternEmitter::supplyValuesForOpArgs( Argument opArg = resultOp.getArg(argIndex); // Handle the case of operand first. - if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) { + if (auto *operand = + llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) { if (!operand->name.empty()) os << "/*" << operand->name << "=*/"; os << childNodeNames.lookup(argIndex); diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index 7256705..41ffdfc 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -397,10 +397,9 @@ static void emitAvailabilityQueryForBitEnum(const Record &enumDef, avail.getMergeInstanceType(), avail.getQueryFnName(), enumName); - os << formatv( - " assert(::llvm::popcount(static_cast<{0}>(value)) <= 1" - " && \"cannot have more than one bit set\");\n", - underlyingType); + os << formatv(" assert(::llvm::popcount(static_cast<{0}>(value)) <= 1" + " && \"cannot have more than one bit set\");\n", + underlyingType); os << " switch (value) {\n"; for (const auto &caseSpecPair : classCasePair.getValue()) { @@ -933,7 +932,8 @@ static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc, // Process operands/attributes for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) { auto argument = op.getArg(i); - if (auto *valueArg = llvm::dyn_cast_if_present<NamedTypeConstraint *>(argument)) { + if (auto *valueArg = + llvm::dyn_cast_if_present<NamedTypeConstraint *>(argument)) { if (valueArg->isVariableLength()) { if (i != e - 1) { PrintFatalError( @@ -1044,7 +1044,7 @@ static void emitDeserializationFunction(const Record *attrClass, emitDecorationDeserialization(op, " ", valueID, attributes, os); os << formatv(" Location loc = createFileLineColLoc(opBuilder);\n"); - os << formatv(" auto {1} = opBuilder.create<{0}>(loc, {2}, {3}, {4}); " + os << formatv(" auto {1} = {0}::create(opBuilder, loc, {2}, {3}, {4}); " "(void){1};\n", op.getQualCppClassName(), opVar, resultTypes, operands, attributes); diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp index c2ad09f..4343f2d 100644 --- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp +++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp @@ -52,33 +52,33 @@ Value createPredicate(OpBuilder &builder, tblgen::Pred pred) { } if (combiner == "PredCombinerAnd") { auto op = - builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints); + irdl::AllOfOp::create(builder, UnknownLoc::get(ctx), constraints); return op.getOutput(); } auto op = - builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints); + irdl::AnyOfOp::create(builder, UnknownLoc::get(ctx), constraints); return op.getOutput(); } } std::string condition = pred.getCondition(); // Build a CPredOp to match the C constraint built. - irdl::CPredOp op = builder.create<irdl::CPredOp>( - UnknownLoc::get(ctx), StringAttr::get(ctx, condition)); + irdl::CPredOp op = irdl::CPredOp::create(builder, UnknownLoc::get(ctx), + StringAttr::get(ctx, condition)); return op; } Value typeToConstraint(OpBuilder &builder, Type type) { MLIRContext *ctx = builder.getContext(); auto op = - builder.create<irdl::IsOp>(UnknownLoc::get(ctx), TypeAttr::get(type)); + irdl::IsOp::create(builder, UnknownLoc::get(ctx), TypeAttr::get(type)); return op.getOutput(); } Value baseToConstraint(OpBuilder &builder, StringRef baseClass) { MLIRContext *ctx = builder.getContext(); - auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), - StringAttr::get(ctx, baseClass)); + auto op = irdl::BaseOp::create(builder, UnknownLoc::get(ctx), + StringAttr::get(ctx, baseClass)); return op.getOutput(); } @@ -179,7 +179,7 @@ Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) { return createTypeConstraint(builder, predRec.getValueAsDef("baseType")); if (predRec.getName() == "AnyType") { - auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx)); + auto op = irdl::AnyOp::create(builder, UnknownLoc::get(ctx)); return op.getOutput(); } @@ -190,12 +190,12 @@ Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) { SmallVector<FlatSymbolRefAttr> nested = { SymbolRefAttr::get(ctx, combined)}; auto typeSymbol = SymbolRefAttr::get(ctx, dialect, nested); - auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), typeSymbol); + auto op = irdl::BaseOp::create(builder, UnknownLoc::get(ctx), typeSymbol); return op.getOutput(); } std::string typeName = ("!" + predRec.getValueAsString("typeName")).str(); - auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), - StringAttr::get(ctx, typeName)); + auto op = irdl::BaseOp::create(builder, UnknownLoc::get(ctx), + StringAttr::get(ctx, typeName)); return op.getOutput(); } @@ -205,7 +205,7 @@ Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) { constraints.push_back( createTypeConstraint(builder, tblgen::Constraint(child))); } - auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints); + auto op = irdl::AnyOfOp::create(builder, UnknownLoc::get(ctx), constraints); return op.getOutput(); } @@ -215,14 +215,14 @@ Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) { constraints.push_back( createTypeConstraint(builder, tblgen::Constraint(child))); } - auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints); + auto op = irdl::AllOfOp::create(builder, UnknownLoc::get(ctx), constraints); return op.getOutput(); } // Integer types if (predRec.getName() == "AnyInteger") { - auto op = builder.create<irdl::BaseOp>( - UnknownLoc::get(ctx), StringAttr::get(ctx, "!builtin.integer")); + auto op = irdl::BaseOp::create(builder, UnknownLoc::get(ctx), + StringAttr::get(ctx, "!builtin.integer")); return op.getOutput(); } @@ -235,7 +235,7 @@ Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) { IntegerType::get(ctx, width, IntegerType::Signed)), typeToConstraint(builder, IntegerType::get(ctx, width, IntegerType::Unsigned))}; - auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), types); + auto op = irdl::AnyOfOp::create(builder, UnknownLoc::get(ctx), types); return op.getOutput(); } @@ -253,7 +253,7 @@ Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) { for (const Record *child : predRec.getValueAsListOfDefs("predicateList")) { constraints.push_back(createPredicate(builder, tblgen::Pred(child))); } - auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints); + auto op = irdl::AllOfOp::create(builder, UnknownLoc::get(ctx), constraints); return op.getOutput(); } @@ -279,7 +279,7 @@ Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) { constraints.push_back(createPredicate( builder, tblgen::Pred(child->getValueAsDef("predicate")))); } - auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints); + auto op = irdl::AllOfOp::create(builder, UnknownLoc::get(ctx), constraints); return op.getOutput(); } @@ -290,12 +290,12 @@ Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) { constraints.push_back( createAttrConstraint(builder, tblgen::Constraint(child))); } - auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints); + auto op = irdl::AnyOfOp::create(builder, UnknownLoc::get(ctx), constraints); return op.getOutput(); } if (predRec.getName() == "AnyAttr") { - auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx)); + auto op = irdl::AnyOp::create(builder, UnknownLoc::get(ctx)); return op.getOutput(); } @@ -317,7 +317,7 @@ Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) { if (predRec.getName() == "UnitAttr") { auto op = - builder.create<irdl::IsOp>(UnknownLoc::get(ctx), UnitAttr::get(ctx)); + irdl::IsOp::create(builder, UnknownLoc::get(ctx), UnitAttr::get(ctx)); return op.getOutput(); } @@ -329,12 +329,12 @@ Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) { }; auto typeSymbol = SymbolRefAttr::get(ctx, dialect, nested); - auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), typeSymbol); + auto op = irdl::BaseOp::create(builder, UnknownLoc::get(ctx), typeSymbol); return op.getOutput(); } std::string typeName = ("#" + predRec.getValueAsString("attrName")).str(); - auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), - StringAttr::get(ctx, typeName)); + auto op = irdl::BaseOp::create(builder, UnknownLoc::get(ctx), + StringAttr::get(ctx, typeName)); return op.getOutput(); } @@ -348,15 +348,15 @@ Value createRegionConstraint(OpBuilder &builder, tblgen::Region constraint) { if (predRec.getName() == "AnyRegion") { ValueRange entryBlockArgs = {}; auto op = - builder.create<irdl::RegionOp>(UnknownLoc::get(ctx), entryBlockArgs); + irdl::RegionOp::create(builder, UnknownLoc::get(ctx), entryBlockArgs); return op.getResult(); } if (predRec.isSubClassOf("SizedRegion")) { ValueRange entryBlockArgs = {}; auto ty = IntegerType::get(ctx, 32); - auto op = builder.create<irdl::RegionOp>( - UnknownLoc::get(ctx), entryBlockArgs, + auto op = irdl::RegionOp::create( + builder, UnknownLoc::get(ctx), entryBlockArgs, IntegerAttr::get(ty, predRec.getValueAsInt("blocks"))); return op.getResult(); } @@ -388,8 +388,8 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder, MLIRContext *ctx = builder.getContext(); StringRef opName = getOperatorName(tblgenOp); - irdl::OperationOp op = builder.create<irdl::OperationOp>( - UnknownLoc::get(ctx), StringAttr::get(ctx, opName)); + irdl::OperationOp op = irdl::OperationOp::create( + builder, UnknownLoc::get(ctx), StringAttr::get(ctx, opName)); // Add the block in the region. Block &opBlock = op.getBody().emplaceBlock(); @@ -471,19 +471,19 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder, // Create the operands and results operations. if (!operands.empty()) - consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands, - ArrayAttr::get(ctx, operandNames), - operandVariadicity); + irdl::OperandsOp::create(consBuilder, UnknownLoc::get(ctx), operands, + ArrayAttr::get(ctx, operandNames), + operandVariadicity); if (!results.empty()) - consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results, - ArrayAttr::get(ctx, resultNames), - resultVariadicity); + irdl::ResultsOp::create(consBuilder, UnknownLoc::get(ctx), results, + ArrayAttr::get(ctx, resultNames), + resultVariadicity); if (!attributes.empty()) - consBuilder.create<irdl::AttributesOp>(UnknownLoc::get(ctx), attributes, - ArrayAttr::get(ctx, attrNames)); + irdl::AttributesOp::create(consBuilder, UnknownLoc::get(ctx), attributes, + ArrayAttr::get(ctx, attrNames)); if (!regions.empty()) - consBuilder.create<irdl::RegionsOp>(UnknownLoc::get(ctx), regions, - ArrayAttr::get(ctx, regionNames)); + irdl::RegionsOp::create(consBuilder, UnknownLoc::get(ctx), regions, + ArrayAttr::get(ctx, regionNames)); return op; } @@ -493,8 +493,8 @@ irdl::TypeOp createIRDLType(OpBuilder &builder, tblgen::TypeDef &tblgenType) { StringRef typeName = getTypeName(tblgenType); std::string combined = ("!" + typeName).str(); - irdl::TypeOp op = builder.create<irdl::TypeOp>( - UnknownLoc::get(ctx), StringAttr::get(ctx, combined)); + irdl::TypeOp op = irdl::TypeOp::create(builder, UnknownLoc::get(ctx), + StringAttr::get(ctx, combined)); op.getBody().emplaceBlock(); @@ -507,8 +507,8 @@ irdl::AttributeOp createIRDLAttr(OpBuilder &builder, StringRef attrName = getAttrName(tblgenAttr); std::string combined = ("#" + attrName).str(); - irdl::AttributeOp op = builder.create<irdl::AttributeOp>( - UnknownLoc::get(ctx), StringAttr::get(ctx, combined)); + irdl::AttributeOp op = irdl::AttributeOp::create( + builder, UnknownLoc::get(ctx), StringAttr::get(ctx, combined)); op.getBody().emplaceBlock(); @@ -517,8 +517,8 @@ irdl::AttributeOp createIRDLAttr(OpBuilder &builder, static irdl::DialectOp createIRDLDialect(OpBuilder &builder) { MLIRContext *ctx = builder.getContext(); - return builder.create<irdl::DialectOp>(UnknownLoc::get(ctx), - StringAttr::get(ctx, selectedDialect)); + return irdl::DialectOp::create(builder, UnknownLoc::get(ctx), + StringAttr::get(ctx, selectedDialect)); } static bool emitDialectIRDLDefs(const RecordKeeper &records, raw_ostream &os) { @@ -529,7 +529,7 @@ static bool emitDialectIRDLDefs(const RecordKeeper &records, raw_ostream &os) { // Create a module op and set it as the insertion point. OwningOpRef<ModuleOp> module = - builder.create<ModuleOp>(UnknownLoc::get(&ctx)); + ModuleOp::create(builder, UnknownLoc::get(&ctx)); builder = builder.atBlockBegin(module->getBody()); // Create the dialect and insert it. irdl::DialectOp dialect = createIRDLDialect(builder); diff --git a/mlir/unittests/Conversion/PDLToPDLInterp/RootOrderingTest.cpp b/mlir/unittests/Conversion/PDLToPDLInterp/RootOrderingTest.cpp index f82ece0..020c0fe 100644 --- a/mlir/unittests/Conversion/PDLToPDLInterp/RootOrderingTest.cpp +++ b/mlir/unittests/Conversion/PDLToPDLInterp/RootOrderingTest.cpp @@ -41,7 +41,7 @@ protected: builder.setInsertionPointToStart(&block); for (int i = 0; i < 4; ++i) // Ops will be deleted when `block` is destroyed. - v[i] = builder.create<ConstantIntOp>(builder.getUnknownLoc(), i, 32); + v[i] = ConstantIntOp::create(builder, builder.getUnknownLoc(), i, 32); } /// Checks that optimal branching on graph has the given cost and diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp index 836efdb..6ac9a87 100644 --- a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp +++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp @@ -45,7 +45,7 @@ protected: template <typename Op> void testAsyncOnly(OpBuilder &b, MLIRContext &context, Location loc, llvm::SmallVector<DeviceType> &dtypes) { - OwningOpRef<Op> op = b.create<Op>(loc, TypeRange{}, ValueRange{}); + OwningOpRef<Op> op = Op::create(b, loc, TypeRange{}, ValueRange{}); EXPECT_FALSE(op->hasAsyncOnly()); for (auto d : dtypes) EXPECT_FALSE(op->hasAsyncOnly(d)); @@ -82,12 +82,12 @@ void testAsyncOnlyDataEntry(OpBuilder &b, MLIRContext &context, Location loc, llvm::SmallVector<DeviceType> &dtypes) { auto memrefTy = MemRefType::get({}, b.getI32Type()); OwningOpRef<memref::AllocaOp> varPtrOp = - b.create<memref::AllocaOp>(loc, memrefTy); + memref::AllocaOp::create(b, loc, memrefTy); TypedValue<PointerLikeType> varPtr = cast<TypedValue<PointerLikeType>>(varPtrOp->getResult()); - OwningOpRef<Op> op = b.create<Op>(loc, varPtr, - /*structured=*/true, /*implicit=*/true); + OwningOpRef<Op> op = Op::create(b, loc, varPtr, + /*structured=*/true, /*implicit=*/true); EXPECT_FALSE(op->hasAsyncOnly()); for (auto d : dtypes) @@ -128,7 +128,7 @@ TEST_F(OpenACCOpsTest, asyncOnlyTestDataEntry) { template <typename Op> void testAsyncValue(OpBuilder &b, MLIRContext &context, Location loc, llvm::SmallVector<DeviceType> &dtypes) { - OwningOpRef<Op> op = b.create<Op>(loc, TypeRange{}, ValueRange{}); + OwningOpRef<Op> op = Op::create(b, loc, TypeRange{}, ValueRange{}); mlir::Value empty; EXPECT_EQ(op->getAsyncValue(), empty); @@ -136,7 +136,7 @@ void testAsyncValue(OpBuilder &b, MLIRContext &context, Location loc, EXPECT_EQ(op->getAsyncValue(d), empty); OwningOpRef<arith::ConstantIndexOp> val = - b.create<arith::ConstantIndexOp>(loc, 1); + arith::ConstantIndexOp::create(b, loc, 1); auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia); op->setAsyncOperandsDeviceTypeAttr(b.getArrayAttr({dtypeNvidia})); op->getAsyncOperandsMutable().assign(val->getResult()); @@ -158,12 +158,12 @@ void testAsyncValueDataEntry(OpBuilder &b, MLIRContext &context, Location loc, llvm::SmallVector<DeviceType> &dtypes) { auto memrefTy = MemRefType::get({}, b.getI32Type()); OwningOpRef<memref::AllocaOp> varPtrOp = - b.create<memref::AllocaOp>(loc, memrefTy); + memref::AllocaOp::create(b, loc, memrefTy); TypedValue<PointerLikeType> varPtr = cast<TypedValue<PointerLikeType>>(varPtrOp->getResult()); - OwningOpRef<Op> op = b.create<Op>(loc, varPtr, - /*structured=*/true, /*implicit=*/true); + OwningOpRef<Op> op = Op::create(b, loc, varPtr, + /*structured=*/true, /*implicit=*/true); mlir::Value empty; EXPECT_EQ(op->getAsyncValue(), empty); @@ -171,7 +171,7 @@ void testAsyncValueDataEntry(OpBuilder &b, MLIRContext &context, Location loc, EXPECT_EQ(op->getAsyncValue(d), empty); OwningOpRef<arith::ConstantIndexOp> val = - b.create<arith::ConstantIndexOp>(loc, 1); + arith::ConstantIndexOp::create(b, loc, 1); auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia); op->setAsyncOperandsDeviceTypeAttr(b.getArrayAttr({dtypeNvidia})); op->getAsyncOperandsMutable().assign(val->getResult()); @@ -197,13 +197,13 @@ template <typename Op> void testNumGangsValues(OpBuilder &b, MLIRContext &context, Location loc, llvm::SmallVector<DeviceType> &dtypes, llvm::SmallVector<DeviceType> &dtypesWithoutNone) { - OwningOpRef<Op> op = b.create<Op>(loc, TypeRange{}, ValueRange{}); + OwningOpRef<Op> op = Op::create(b, loc, TypeRange{}, ValueRange{}); EXPECT_EQ(op->getNumGangsValues().begin(), op->getNumGangsValues().end()); OwningOpRef<arith::ConstantIndexOp> val1 = - b.create<arith::ConstantIndexOp>(loc, 1); + arith::ConstantIndexOp::create(b, loc, 1); OwningOpRef<arith::ConstantIndexOp> val2 = - b.create<arith::ConstantIndexOp>(loc, 4); + arith::ConstantIndexOp::create(b, loc, 4); auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None); op->getNumGangsMutable().assign(val1->getResult()); op->setNumGangsDeviceTypeAttr(b.getArrayAttr({dtypeNone})); @@ -264,7 +264,7 @@ TEST_F(OpenACCOpsTest, numGangsValuesTest) { template <typename Op> void testVectorLength(OpBuilder &b, MLIRContext &context, Location loc, llvm::SmallVector<DeviceType> &dtypes) { - OwningOpRef<Op> op = b.create<Op>(loc, TypeRange{}, ValueRange{}); + OwningOpRef<Op> op = Op::create(b, loc, TypeRange{}, ValueRange{}); mlir::Value empty; EXPECT_EQ(op->getVectorLengthValue(), empty); @@ -272,7 +272,7 @@ void testVectorLength(OpBuilder &b, MLIRContext &context, Location loc, EXPECT_EQ(op->getVectorLengthValue(d), empty); OwningOpRef<arith::ConstantIndexOp> val = - b.create<arith::ConstantIndexOp>(loc, 1); + arith::ConstantIndexOp::create(b, loc, 1); auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia); op->setVectorLengthDeviceTypeAttr(b.getArrayAttr({dtypeNvidia})); op->getVectorLengthMutable().assign(val->getResult()); @@ -292,7 +292,7 @@ template <typename Op> void testWaitOnly(OpBuilder &b, MLIRContext &context, Location loc, llvm::SmallVector<DeviceType> &dtypes, llvm::SmallVector<DeviceType> &dtypesWithoutNone) { - OwningOpRef<Op> op = b.create<Op>(loc, TypeRange{}, ValueRange{}); + OwningOpRef<Op> op = Op::create(b, loc, TypeRange{}, ValueRange{}); EXPECT_FALSE(op->hasWaitOnly()); for (auto d : dtypes) EXPECT_FALSE(op->hasWaitOnly(d)); @@ -332,15 +332,15 @@ template <typename Op> void testWaitValues(OpBuilder &b, MLIRContext &context, Location loc, llvm::SmallVector<DeviceType> &dtypes, llvm::SmallVector<DeviceType> &dtypesWithoutNone) { - OwningOpRef<Op> op = b.create<Op>(loc, TypeRange{}, ValueRange{}); + OwningOpRef<Op> op = Op::create(b, loc, TypeRange{}, ValueRange{}); EXPECT_EQ(op->getWaitValues().begin(), op->getWaitValues().end()); OwningOpRef<arith::ConstantIndexOp> val1 = - b.create<arith::ConstantIndexOp>(loc, 1); + arith::ConstantIndexOp::create(b, loc, 1); OwningOpRef<arith::ConstantIndexOp> val2 = - b.create<arith::ConstantIndexOp>(loc, 4); + arith::ConstantIndexOp::create(b, loc, 4); OwningOpRef<arith::ConstantIndexOp> val3 = - b.create<arith::ConstantIndexOp>(loc, 5); + arith::ConstantIndexOp::create(b, loc, 5); auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None); op->getWaitOperandsMutable().assign(val1->getResult()); op->setWaitOperandsDeviceTypeAttr(b.getArrayAttr({dtypeNone})); @@ -426,7 +426,7 @@ TEST_F(OpenACCOpsTest, waitValuesTest) { } TEST_F(OpenACCOpsTest, loopOpGangVectorWorkerTest) { - OwningOpRef<LoopOp> op = b.create<LoopOp>(loc, TypeRange{}, ValueRange{}); + OwningOpRef<LoopOp> op = LoopOp::create(b, loc, TypeRange{}, ValueRange{}); EXPECT_FALSE(op->hasGang()); EXPECT_FALSE(op->hasVector()); EXPECT_FALSE(op->hasWorker()); @@ -473,7 +473,7 @@ TEST_F(OpenACCOpsTest, loopOpGangVectorWorkerTest) { TEST_F(OpenACCOpsTest, routineOpTest) { OwningOpRef<RoutineOp> op = - b.create<RoutineOp>(loc, TypeRange{}, ValueRange{}); + RoutineOp::create(b, loc, TypeRange{}, ValueRange{}); EXPECT_FALSE(op->hasSeq()); EXPECT_FALSE(op->hasVector()); @@ -564,12 +564,12 @@ void testShortDataEntryOpBuilders(OpBuilder &b, MLIRContext &context, Location loc, DataClause dataClause) { auto memrefTy = MemRefType::get({}, b.getI32Type()); OwningOpRef<memref::AllocaOp> varPtrOp = - b.create<memref::AllocaOp>(loc, memrefTy); + memref::AllocaOp::create(b, loc, memrefTy); TypedValue<PointerLikeType> varPtr = cast<TypedValue<PointerLikeType>>(varPtrOp->getResult()); - OwningOpRef<Op> op = b.create<Op>(loc, varPtr, - /*structured=*/true, /*implicit=*/true); + OwningOpRef<Op> op = Op::create(b, loc, varPtr, + /*structured=*/true, /*implicit=*/true); EXPECT_EQ(op->getVarPtr(), varPtr); EXPECT_EQ(op->getType(), memrefTy); @@ -579,24 +579,24 @@ void testShortDataEntryOpBuilders(OpBuilder &b, MLIRContext &context, EXPECT_TRUE(op->getBounds().empty()); EXPECT_FALSE(op->getVarPtrPtr()); - OwningOpRef<Op> op2 = b.create<Op>(loc, varPtr, - /*structured=*/false, /*implicit=*/false); + OwningOpRef<Op> op2 = Op::create(b, loc, varPtr, + /*structured=*/false, /*implicit=*/false); EXPECT_FALSE(op2->getImplicit()); EXPECT_FALSE(op2->getStructured()); OwningOpRef<arith::ConstantIndexOp> extent = - b.create<arith::ConstantIndexOp>(loc, 1); + arith::ConstantIndexOp::create(b, loc, 1); OwningOpRef<DataBoundsOp> bounds = - b.create<DataBoundsOp>(loc, extent->getResult()); + DataBoundsOp::create(b, loc, extent->getResult()); OwningOpRef<Op> opWithBounds = - b.create<Op>(loc, varPtr, - /*structured=*/true, /*implicit=*/true, bounds->getResult()); + Op::create(b, loc, varPtr, + /*structured=*/true, /*implicit=*/true, bounds->getResult()); EXPECT_FALSE(opWithBounds->getBounds().empty()); EXPECT_EQ(opWithBounds->getBounds().back(), bounds->getResult()); OwningOpRef<Op> opWithName = - b.create<Op>(loc, varPtr, - /*structured=*/true, /*implicit=*/true, "varName"); + Op::create(b, loc, varPtr, + /*structured=*/true, /*implicit=*/true, "varName"); EXPECT_EQ(opWithName->getNameAttr().str(), "varName"); } @@ -637,17 +637,17 @@ void testShortDataExitOpBuilders(OpBuilder &b, MLIRContext &context, Location loc, DataClause dataClause) { auto memrefTy = MemRefType::get({}, b.getI32Type()); OwningOpRef<memref::AllocaOp> varPtrOp = - b.create<memref::AllocaOp>(loc, memrefTy); + memref::AllocaOp::create(b, loc, memrefTy); TypedValue<PointerLikeType> varPtr = cast<TypedValue<PointerLikeType>>(varPtrOp->getResult()); - OwningOpRef<GetDevicePtrOp> accPtrOp = b.create<GetDevicePtrOp>( - loc, varPtr, /*structured=*/true, /*implicit=*/true); + OwningOpRef<GetDevicePtrOp> accPtrOp = GetDevicePtrOp::create( + b, loc, varPtr, /*structured=*/true, /*implicit=*/true); TypedValue<PointerLikeType> accPtr = cast<TypedValue<PointerLikeType>>(accPtrOp->getResult()); - OwningOpRef<Op> op = b.create<Op>(loc, accPtr, varPtr, - /*structured=*/true, /*implicit=*/true); + OwningOpRef<Op> op = Op::create(b, loc, accPtr, varPtr, + /*structured=*/true, /*implicit=*/true); EXPECT_EQ(op->getVarPtr(), varPtr); EXPECT_EQ(op->getAccPtr(), accPtr); @@ -656,24 +656,24 @@ void testShortDataExitOpBuilders(OpBuilder &b, MLIRContext &context, EXPECT_TRUE(op->getStructured()); EXPECT_TRUE(op->getBounds().empty()); - OwningOpRef<Op> op2 = b.create<Op>(loc, accPtr, varPtr, - /*structured=*/false, /*implicit=*/false); + OwningOpRef<Op> op2 = Op::create(b, loc, accPtr, varPtr, + /*structured=*/false, /*implicit=*/false); EXPECT_FALSE(op2->getImplicit()); EXPECT_FALSE(op2->getStructured()); OwningOpRef<arith::ConstantIndexOp> extent = - b.create<arith::ConstantIndexOp>(loc, 1); + arith::ConstantIndexOp::create(b, loc, 1); OwningOpRef<DataBoundsOp> bounds = - b.create<DataBoundsOp>(loc, extent->getResult()); + DataBoundsOp::create(b, loc, extent->getResult()); OwningOpRef<Op> opWithBounds = - b.create<Op>(loc, accPtr, varPtr, - /*structured=*/true, /*implicit=*/true, bounds->getResult()); + Op::create(b, loc, accPtr, varPtr, + /*structured=*/true, /*implicit=*/true, bounds->getResult()); EXPECT_FALSE(opWithBounds->getBounds().empty()); EXPECT_EQ(opWithBounds->getBounds().back(), bounds->getResult()); OwningOpRef<Op> opWithName = - b.create<Op>(loc, accPtr, varPtr, - /*structured=*/true, /*implicit=*/true, "varName"); + Op::create(b, loc, accPtr, varPtr, + /*structured=*/true, /*implicit=*/true, "varName"); EXPECT_EQ(opWithName->getNameAttr().str(), "varName"); } @@ -689,17 +689,17 @@ void testShortDataExitNoVarPtrOpBuilders(OpBuilder &b, MLIRContext &context, Location loc, DataClause dataClause) { auto memrefTy = MemRefType::get({}, b.getI32Type()); OwningOpRef<memref::AllocaOp> varPtrOp = - b.create<memref::AllocaOp>(loc, memrefTy); + memref::AllocaOp::create(b, loc, memrefTy); TypedValue<PointerLikeType> varPtr = cast<TypedValue<PointerLikeType>>(varPtrOp->getResult()); - OwningOpRef<GetDevicePtrOp> accPtrOp = b.create<GetDevicePtrOp>( - loc, varPtr, /*structured=*/true, /*implicit=*/true); + OwningOpRef<GetDevicePtrOp> accPtrOp = GetDevicePtrOp::create( + b, loc, varPtr, /*structured=*/true, /*implicit=*/true); TypedValue<PointerLikeType> accPtr = cast<TypedValue<PointerLikeType>>(accPtrOp->getResult()); - OwningOpRef<Op> op = b.create<Op>(loc, accPtr, - /*structured=*/true, /*implicit=*/true); + OwningOpRef<Op> op = Op::create(b, loc, accPtr, + /*structured=*/true, /*implicit=*/true); EXPECT_EQ(op->getAccPtr(), accPtr); EXPECT_EQ(op->getDataClause(), dataClause); @@ -707,24 +707,24 @@ void testShortDataExitNoVarPtrOpBuilders(OpBuilder &b, MLIRContext &context, EXPECT_TRUE(op->getStructured()); EXPECT_TRUE(op->getBounds().empty()); - OwningOpRef<Op> op2 = b.create<Op>(loc, accPtr, - /*structured=*/false, /*implicit=*/false); + OwningOpRef<Op> op2 = Op::create(b, loc, accPtr, + /*structured=*/false, /*implicit=*/false); EXPECT_FALSE(op2->getImplicit()); EXPECT_FALSE(op2->getStructured()); OwningOpRef<arith::ConstantIndexOp> extent = - b.create<arith::ConstantIndexOp>(loc, 1); + arith::ConstantIndexOp::create(b, loc, 1); OwningOpRef<DataBoundsOp> bounds = - b.create<DataBoundsOp>(loc, extent->getResult()); + DataBoundsOp::create(b, loc, extent->getResult()); OwningOpRef<Op> opWithBounds = - b.create<Op>(loc, accPtr, - /*structured=*/true, /*implicit=*/true, bounds->getResult()); + Op::create(b, loc, accPtr, + /*structured=*/true, /*implicit=*/true, bounds->getResult()); EXPECT_FALSE(opWithBounds->getBounds().empty()); EXPECT_EQ(opWithBounds->getBounds().back(), bounds->getResult()); OwningOpRef<Op> opWithName = - b.create<Op>(loc, accPtr, - /*structured=*/true, /*implicit=*/true, "varName"); + Op::create(b, loc, accPtr, + /*structured=*/true, /*implicit=*/true, "varName"); EXPECT_EQ(opWithName->getNameAttr().str(), "varName"); } @@ -742,16 +742,16 @@ void testShortDataEntryOpBuildersMappableVar(OpBuilder &b, MLIRContext &context, auto int64Ty = b.getI64Type(); auto memrefTy = MemRefType::get({}, int64Ty); OwningOpRef<memref::AllocaOp> varPtrOp = - b.create<memref::AllocaOp>(loc, memrefTy); + memref::AllocaOp::create(b, loc, memrefTy); SmallVector<Value> indices; OwningOpRef<memref::LoadOp> loadVarOp = - b.create<memref::LoadOp>(loc, int64Ty, varPtrOp->getResult(), indices); + memref::LoadOp::create(b, loc, int64Ty, varPtrOp->getResult(), indices); EXPECT_TRUE(isMappableType(loadVarOp->getResult().getType())); TypedValue<MappableType> var = cast<TypedValue<MappableType>>(loadVarOp->getResult()); - OwningOpRef<Op> op = b.create<Op>(loc, var, - /*structured=*/true, /*implicit=*/true); + OwningOpRef<Op> op = Op::create(b, loc, var, + /*structured=*/true, /*implicit=*/true); EXPECT_EQ(op->getVar(), var); EXPECT_EQ(op->getVarPtr(), nullptr); diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp index fecd960..ef23123 100644 --- a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp +++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp @@ -119,45 +119,45 @@ protected: TEST_F(SCFLoopLikeTest, queryUnidimensionalLooplikes) { OwningOpRef<arith::ConstantIndexOp> lb = - b.create<arith::ConstantIndexOp>(loc, 0); + arith::ConstantIndexOp::create(b, loc, 0); OwningOpRef<arith::ConstantIndexOp> ub = - b.create<arith::ConstantIndexOp>(loc, 10); + arith::ConstantIndexOp::create(b, loc, 10); OwningOpRef<arith::ConstantIndexOp> step = - b.create<arith::ConstantIndexOp>(loc, 2); + arith::ConstantIndexOp::create(b, loc, 2); OwningOpRef<scf::ForOp> forOp = - b.create<scf::ForOp>(loc, lb.get(), ub.get(), step.get()); + scf::ForOp::create(b, loc, lb.get(), ub.get(), step.get()); checkUnidimensional(forOp.get()); - OwningOpRef<scf::ForallOp> forallOp = b.create<scf::ForallOp>( - loc, ArrayRef<OpFoldResult>(lb->getResult()), + OwningOpRef<scf::ForallOp> forallOp = scf::ForallOp::create( + b, loc, ArrayRef<OpFoldResult>(lb->getResult()), ArrayRef<OpFoldResult>(ub->getResult()), ArrayRef<OpFoldResult>(step->getResult()), ValueRange(), std::nullopt); checkUnidimensional(forallOp.get()); - OwningOpRef<scf::ParallelOp> parallelOp = b.create<scf::ParallelOp>( - loc, ValueRange(lb->getResult()), ValueRange(ub->getResult()), + OwningOpRef<scf::ParallelOp> parallelOp = scf::ParallelOp::create( + b, loc, ValueRange(lb->getResult()), ValueRange(ub->getResult()), ValueRange(step->getResult()), ValueRange()); checkUnidimensional(parallelOp.get()); } TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) { OwningOpRef<arith::ConstantIndexOp> lb = - b.create<arith::ConstantIndexOp>(loc, 0); + arith::ConstantIndexOp::create(b, loc, 0); OwningOpRef<arith::ConstantIndexOp> ub = - b.create<arith::ConstantIndexOp>(loc, 10); + arith::ConstantIndexOp::create(b, loc, 10); OwningOpRef<arith::ConstantIndexOp> step = - b.create<arith::ConstantIndexOp>(loc, 2); + arith::ConstantIndexOp::create(b, loc, 2); - OwningOpRef<scf::ForallOp> forallOp = b.create<scf::ForallOp>( - loc, ArrayRef<OpFoldResult>({lb->getResult(), lb->getResult()}), + OwningOpRef<scf::ForallOp> forallOp = scf::ForallOp::create( + b, loc, ArrayRef<OpFoldResult>({lb->getResult(), lb->getResult()}), ArrayRef<OpFoldResult>({ub->getResult(), ub->getResult()}), ArrayRef<OpFoldResult>({step->getResult(), step->getResult()}), ValueRange(), std::nullopt); checkMultidimensional(forallOp.get()); - OwningOpRef<scf::ParallelOp> parallelOp = b.create<scf::ParallelOp>( - loc, ValueRange({lb->getResult(), lb->getResult()}), + OwningOpRef<scf::ParallelOp> parallelOp = scf::ParallelOp::create( + b, loc, ValueRange({lb->getResult(), lb->getResult()}), ValueRange({ub->getResult(), ub->getResult()}), ValueRange({step->getResult(), step->getResult()}), ValueRange()); checkMultidimensional(parallelOp.get()); @@ -165,22 +165,22 @@ TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) { TEST_F(SCFLoopLikeTest, testForallNormalize) { OwningOpRef<arith::ConstantIndexOp> lb = - b.create<arith::ConstantIndexOp>(loc, 1); + arith::ConstantIndexOp::create(b, loc, 1); OwningOpRef<arith::ConstantIndexOp> ub = - b.create<arith::ConstantIndexOp>(loc, 10); + arith::ConstantIndexOp::create(b, loc, 10); OwningOpRef<arith::ConstantIndexOp> step = - b.create<arith::ConstantIndexOp>(loc, 3); + arith::ConstantIndexOp::create(b, loc, 3); - scf::ForallOp forallOp = b.create<scf::ForallOp>( - loc, ArrayRef<OpFoldResult>({lb->getResult(), lb->getResult()}), + scf::ForallOp forallOp = scf::ForallOp::create( + b, loc, ArrayRef<OpFoldResult>({lb->getResult(), lb->getResult()}), ArrayRef<OpFoldResult>({ub->getResult(), ub->getResult()}), ArrayRef<OpFoldResult>({step->getResult(), step->getResult()}), ValueRange(), std::nullopt); // Create a user of the induction variable. Bitcast is chosen for simplicity // since it is unary. b.setInsertionPointToStart(forallOp.getBody()); - b.create<arith::BitcastOp>(UnknownLoc::get(&context), b.getF64Type(), - forallOp.getInductionVar(0)); + arith::BitcastOp::create(b, UnknownLoc::get(&context), b.getF64Type(), + forallOp.getInductionVar(0)); IRRewriter rewriter(b); FailureOr<scf::ForallOp> maybeNormalizedForallOp = normalizeForallOp(rewriter, forallOp); diff --git a/mlir/unittests/Dialect/SMT/QuantifierTest.cpp b/mlir/unittests/Dialect/SMT/QuantifierTest.cpp index d7c57f0..5cbc019 100644 --- a/mlir/unittests/Dialect/SMT/QuantifierTest.cpp +++ b/mlir/unittests/Dialect/SMT/QuantifierTest.cpp @@ -26,10 +26,10 @@ TEST(QuantifierTest, ExistsBuilderWithPattern) { OpBuilder builder(&context); auto boolTy = BoolType::get(&context); - OwningOpRef<ExistsOp> existsOp = builder.create<ExistsOp>( - loc, TypeRange{boolTy, boolTy}, + OwningOpRef<ExistsOp> existsOp = ExistsOp::create( + builder, loc, TypeRange{boolTy, boolTy}, [](OpBuilder &builder, Location loc, ValueRange boundVars) { - return builder.create<AndOp>(loc, boundVars); + return AndOp::create(builder, loc, boundVars); }, std::nullopt, [](OpBuilder &builder, Location loc, ValueRange boundVars) { @@ -57,10 +57,10 @@ TEST(QuantifierTest, ExistsBuilderNoPattern) { OpBuilder builder(&context); auto boolTy = BoolType::get(&context); - OwningOpRef<ExistsOp> existsOp = builder.create<ExistsOp>( - loc, TypeRange{boolTy, boolTy}, + OwningOpRef<ExistsOp> existsOp = ExistsOp::create( + builder, loc, TypeRange{boolTy, boolTy}, [](OpBuilder &builder, Location loc, ValueRange boundVars) { - return builder.create<AndOp>(loc, boundVars); + return AndOp::create(builder, loc, boundVars); }, ArrayRef<StringRef>{"a", "b"}, nullptr, /*weight=*/0, /*noPattern=*/true); @@ -82,10 +82,10 @@ TEST(QuantifierTest, ExistsBuilderDefault) { OpBuilder builder(&context); auto boolTy = BoolType::get(&context); - OwningOpRef<ExistsOp> existsOp = builder.create<ExistsOp>( - loc, TypeRange{boolTy, boolTy}, + OwningOpRef<ExistsOp> existsOp = ExistsOp::create( + builder, loc, TypeRange{boolTy, boolTy}, [](OpBuilder &builder, Location loc, ValueRange boundVars) { - return builder.create<AndOp>(loc, boundVars); + return AndOp::create(builder, loc, boundVars); }, ArrayRef<StringRef>{"a", "b"}); @@ -111,10 +111,10 @@ TEST(QuantifierTest, ForallBuilderWithPattern) { OpBuilder builder(&context); auto boolTy = BoolType::get(&context); - OwningOpRef<ForallOp> forallOp = builder.create<ForallOp>( - loc, TypeRange{boolTy, boolTy}, + OwningOpRef<ForallOp> forallOp = ForallOp::create( + builder, loc, TypeRange{boolTy, boolTy}, [](OpBuilder &builder, Location loc, ValueRange boundVars) { - return builder.create<AndOp>(loc, boundVars); + return AndOp::create(builder, loc, boundVars); }, ArrayRef<StringRef>{"a", "b"}, [](OpBuilder &builder, Location loc, ValueRange boundVars) { @@ -142,10 +142,10 @@ TEST(QuantifierTest, ForallBuilderNoPattern) { OpBuilder builder(&context); auto boolTy = BoolType::get(&context); - OwningOpRef<ForallOp> forallOp = builder.create<ForallOp>( - loc, TypeRange{boolTy, boolTy}, + OwningOpRef<ForallOp> forallOp = ForallOp::create( + builder, loc, TypeRange{boolTy, boolTy}, [](OpBuilder &builder, Location loc, ValueRange boundVars) { - return builder.create<AndOp>(loc, boundVars); + return AndOp::create(builder, loc, boundVars); }, ArrayRef<StringRef>{"a", "b"}, nullptr, /*weight=*/0, /*noPattern=*/true); @@ -167,10 +167,10 @@ TEST(QuantifierTest, ForallBuilderDefault) { OpBuilder builder(&context); auto boolTy = BoolType::get(&context); - OwningOpRef<ForallOp> forallOp = builder.create<ForallOp>( - loc, TypeRange{boolTy, boolTy}, + OwningOpRef<ForallOp> forallOp = ForallOp::create( + builder, loc, TypeRange{boolTy, boolTy}, [](OpBuilder &builder, Location loc, ValueRange boundVars) { - return builder.create<AndOp>(loc, boundVars); + return AndOp::create(builder, loc, boundVars); }, std::nullopt); diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp index ef89c16..af55296 100644 --- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp +++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp @@ -71,8 +71,8 @@ protected: spirv::GlobalVariableOp addGlobalVar(Type type, llvm::StringRef name) { OpBuilder builder(module->getRegion()); auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform); - return builder.create<spirv::GlobalVariableOp>( - UnknownLoc::get(&context), TypeAttr::get(ptrType), + return spirv::GlobalVariableOp::create( + builder, UnknownLoc::get(&context), TypeAttr::get(ptrType), builder.getStringAttr(name), nullptr); } @@ -82,14 +82,14 @@ protected: auto loc = UnknownLoc::get(&context); if (auto intType = dyn_cast<IntegerType>(type)) { - return builder.create<spirv::ConstantOp>( - loc, type, builder.getIntegerAttr(type, val)); + return spirv::ConstantOp::create(builder, loc, type, + builder.getIntegerAttr(type, val)); } if (auto vectorType = dyn_cast<VectorType>(type)) { Type elemType = vectorType.getElementType(); if (auto intType = dyn_cast<IntegerType>(elemType)) { - return builder.create<spirv::ConstantOp>( - loc, type, + return spirv::ConstantOp::create( + builder, loc, type, DenseElementsAttr::get(vectorType, IntegerAttr::get(elemType, val).getValue())); } diff --git a/mlir/unittests/ExecutionEngine/CMakeLists.txt b/mlir/unittests/ExecutionEngine/CMakeLists.txt index 4ef69a8..b83163e 100644 --- a/mlir/unittests/ExecutionEngine/CMakeLists.txt +++ b/mlir/unittests/ExecutionEngine/CMakeLists.txt @@ -10,14 +10,13 @@ add_mlir_unittest(MLIRExecutionEngineTests StridedMemRef.cpp Invoke.cpp ) -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) mlir_target_link_libraries(MLIRExecutionEngineTests PRIVATE MLIRArithToLLVM MLIRMemRefToLLVM MLIRReconcileUnrealizedCasts - ${dialect_libs} + MLIRRegisterAllDialects ) target_link_libraries(MLIRExecutionEngineTests PRIVATE diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp index a55592d..fd40404 100644 --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -477,8 +477,9 @@ TEST(SubElementTest, Nested) { {strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr})); } -// Test how many times we call copy-ctor when building an attribute. -TEST(CopyCountAttr, CopyCount) { +// Test how many times we call copy-ctor when building an attribute with the +// 'get' method. +TEST(CopyCountAttr, CopyCountGet) { MLIRContext context; context.loadDialect<test::TestDialect>(); @@ -489,15 +490,35 @@ TEST(CopyCountAttr, CopyCount) { test::CopyCount::counter = 0; test::TestCopyCountAttr::get(&context, std::move(copyCount)); #ifndef NDEBUG - // One verification enabled only in assert-mode requires a copy. - EXPECT_EQ(counter1, 1); - EXPECT_EQ(test::CopyCount::counter, 1); + // One verification enabled only in assert-mode requires two copies: one for + // calling 'verifyInvariants' and one for calling 'verify' inside + // 'verifyInvariants'. + EXPECT_EQ(counter1, 2); + EXPECT_EQ(test::CopyCount::counter, 2); #else EXPECT_EQ(counter1, 0); EXPECT_EQ(test::CopyCount::counter, 0); #endif } +// Test how many times we call copy-ctor when building an attribute with the +// 'getChecked' method. +TEST(CopyCountAttr, CopyCountGetChecked) { + MLIRContext context; + context.loadDialect<test::TestDialect>(); + test::CopyCount::counter = 0; + test::CopyCount copyCount("hello"); + auto loc = UnknownLoc::get(&context); + test::TestCopyCountAttr::getChecked(loc, &context, std::move(copyCount)); + int counter1 = test::CopyCount::counter; + test::CopyCount::counter = 0; + test::TestCopyCountAttr::getChecked(loc, &context, std::move(copyCount)); + // The verifiers require two copies: one for calling 'verifyInvariants' and + // one for calling 'verify' inside 'verifyInvariants'. + EXPECT_EQ(counter1, 2); + EXPECT_EQ(test::CopyCount::counter, 2); +} + // Test stripped printing using test dialect attribute. TEST(CopyCountAttr, PrintStripped) { MLIRContext context; diff --git a/mlir/unittests/IR/IRMapping.cpp b/mlir/unittests/IR/IRMapping.cpp index b88009d..983c41a 100644 --- a/mlir/unittests/IR/IRMapping.cpp +++ b/mlir/unittests/IR/IRMapping.cpp @@ -26,10 +26,10 @@ TEST(IRMapping, TypedValue) { Block block; builder.setInsertionPointToEnd(&block); - Value i64Val = builder.create<test::TestOpConstant>( - loc, builder.getI64Type(), builder.getI64IntegerAttr(0)); - Value f64Val = builder.create<test::TestOpConstant>( - loc, builder.getF64Type(), builder.getF64FloatAttr(0.0)); + Value i64Val = test::TestOpConstant::create( + builder, loc, builder.getI64Type(), builder.getI64IntegerAttr(0)); + Value f64Val = test::TestOpConstant::create( + builder, loc, builder.getF64Type(), builder.getF64FloatAttr(0.0)); IRMapping mapping; mapping.map(i64Val, f64Val); diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp index 1b5d3b8..e1e65da 100644 --- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp +++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp @@ -303,7 +303,7 @@ TEST(InterfaceAttachment, Operation) { // Initially, the operation doesn't have the interface. OwningOpRef<ModuleOp> moduleOp = - builder.create<ModuleOp>(UnknownLoc::get(&context)); + ModuleOp::create(builder, UnknownLoc::get(&context)); ASSERT_FALSE(isa<TestExternalOpInterface>(moduleOp->getOperation())); // We can attach an external interface and now the operaiton has it. @@ -317,8 +317,8 @@ TEST(InterfaceAttachment, Operation) { // Default implementation can be overridden. OwningOpRef<UnrealizedConversionCastOp> castOp = - builder.create<UnrealizedConversionCastOp>(UnknownLoc::get(&context), - TypeRange(), ValueRange()); + UnrealizedConversionCastOp::create(builder, UnknownLoc::get(&context), + TypeRange(), ValueRange()); ASSERT_FALSE(isa<TestExternalOpInterface>(castOp->getOperation())); UnrealizedConversionCastOp::attachInterface<TestExternalOpOverridingModel>( context); @@ -368,11 +368,11 @@ TEST(InterfaceAttachment, OperationDelayedContextConstruct) { OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context)); OpBuilder builder(module->getBody(), module->getBody()->begin()); auto opJ = - builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type()); + test::OpJ::create(builder, builder.getUnknownLoc(), builder.getI32Type()); auto opH = - builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult()); + test::OpH::create(builder, builder.getUnknownLoc(), opJ.getResult()); auto opI = - builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult()); + test::OpI::create(builder, builder.getUnknownLoc(), opJ.getResult()); EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation())); EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation())); @@ -399,11 +399,11 @@ TEST(InterfaceAttachment, OperationDelayedContextAppend) { OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context)); OpBuilder builder(module->getBody(), module->getBody()->begin()); auto opJ = - builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type()); + test::OpJ::create(builder, builder.getUnknownLoc(), builder.getI32Type()); auto opH = - builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult()); + test::OpH::create(builder, builder.getUnknownLoc(), opJ.getResult()); auto opI = - builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult()); + test::OpI::create(builder, builder.getUnknownLoc(), opJ.getResult()); EXPECT_FALSE(isa<TestExternalOpInterface>(module->getOperation())); EXPECT_FALSE(isa<TestExternalOpInterface>(opJ.getOperation())); diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp index 42196b0..235163c 100644 --- a/mlir/unittests/IR/InterfaceTest.cpp +++ b/mlir/unittests/IR/InterfaceTest.cpp @@ -27,12 +27,12 @@ TEST(InterfaceTest, OpInterfaceDenseMapKey) { OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context)); OpBuilder builder(module->getBody(), module->getBody()->begin()); - auto op1 = builder.create<test::SideEffectOp>(builder.getUnknownLoc(), - builder.getI32Type()); - auto op2 = builder.create<test::SideEffectOp>(builder.getUnknownLoc(), - builder.getI32Type()); - auto op3 = builder.create<test::SideEffectOp>(builder.getUnknownLoc(), - builder.getI32Type()); + auto op1 = test::SideEffectOp::create(builder, builder.getUnknownLoc(), + builder.getI32Type()); + auto op2 = test::SideEffectOp::create(builder, builder.getUnknownLoc(), + builder.getI32Type()); + auto op3 = test::SideEffectOp::create(builder, builder.getUnknownLoc(), + builder.getI32Type()); DenseSet<MemoryEffectOpInterface> opSet; opSet.insert(op1); opSet.insert(op2); @@ -64,8 +64,8 @@ TEST(InterfaceTest, TestCustomClassOf) { context.loadDialect<test::TestDialect>(); OpBuilder builder(&context); - auto op = builder.create<TestOpOptionallyImplementingInterface>( - builder.getUnknownLoc(), /*implementsInterface=*/true); + auto op = TestOpOptionallyImplementingInterface::create( + builder, builder.getUnknownLoc(), /*implementsInterface=*/true); EXPECT_TRUE(isa<TestOptionallyImplementedOpInterface>(*op)); op.setImplementsInterface(false); EXPECT_FALSE(isa<TestOptionallyImplementedOpInterface>(*op)); diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp index 7bc1a04..9f3e7ed 100644 --- a/mlir/unittests/IR/OperationSupportTest.cpp +++ b/mlir/unittests/IR/OperationSupportTest.cpp @@ -302,8 +302,8 @@ TEST(OperandStorageTest, PopulateDefaultAttrs) { auto req1 = b.getI32IntegerAttr(10); auto req2 = b.getI32IntegerAttr(60); // Verify default attributes populated post op creation. - Operation *op = b.create<test::OpAttrMatch1>(b.getUnknownLoc(), req1, nullptr, - nullptr, req2); + Operation *op = test::OpAttrMatch1::create(b, b.getUnknownLoc(), req1, + nullptr, nullptr, req2); auto opt = op->getInherentAttr("default_valued_attr"); EXPECT_NE(opt, nullptr) << *op; @@ -343,11 +343,11 @@ TEST(OperationEquivalenceTest, HashWorksWithFlags) { // Check ignore properties. auto req1 = b.getI32IntegerAttr(10); - Operation *opWithProperty1 = b.create<test::OpAttrMatch1>( - b.getUnknownLoc(), req1, nullptr, nullptr, req1); + Operation *opWithProperty1 = test::OpAttrMatch1::create( + b, b.getUnknownLoc(), req1, nullptr, nullptr, req1); auto req2 = b.getI32IntegerAttr(60); - Operation *opWithProperty2 = b.create<test::OpAttrMatch1>( - b.getUnknownLoc(), req2, nullptr, nullptr, req2); + Operation *opWithProperty2 = test::OpAttrMatch1::create( + b, b.getUnknownLoc(), req2, nullptr, nullptr, req2); EXPECT_EQ(getHash(opWithProperty1, OperationEquivalence::IgnoreProperties), getHash(opWithProperty2, OperationEquivalence::IgnoreProperties)); EXPECT_NE(getHash(opWithProperty1, OperationEquivalence::None), diff --git a/mlir/unittests/IR/SymbolTableTest.cpp b/mlir/unittests/IR/SymbolTableTest.cpp index cfc3fe0..4b3545b 100644 --- a/mlir/unittests/IR/SymbolTableTest.cpp +++ b/mlir/unittests/IR/SymbolTableTest.cpp @@ -132,4 +132,38 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) { }); } +TEST(SymbolOpInterface, Visibility) { + DialectRegistry registry; + ::test::registerTestDialect(registry); + MLIRContext context(registry); + + constexpr static StringLiteral kInput = R"MLIR( + "test.overridden_symbol_visibility"() {sym_name = "symbol_name"} : () -> () + )MLIR"; + OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(kInput, &context); + auto symOp = cast<SymbolOpInterface>(module->getBody()->front()); + + ASSERT_TRUE(symOp.isPrivate()); + ASSERT_FALSE(symOp.isPublic()); + ASSERT_FALSE(symOp.isNested()); + ASSERT_TRUE(symOp.canDiscardOnUseEmpty()); + + std::string diagStr; + context.getDiagEngine().registerHandler( + [&](Diagnostic &diag) { diagStr += diag.str(); }); + + std::string expectedDiag; + symOp.setPublic(); + expectedDiag += "'test.overridden_symbol_visibility' op cannot change " + "visibility of symbol to public"; + symOp.setNested(); + expectedDiag += "'test.overridden_symbol_visibility' op cannot change " + "visibility of symbol to nested"; + symOp.setPrivate(); + expectedDiag += "'test.overridden_symbol_visibility' op cannot change " + "visibility of symbol to private"; + + ASSERT_EQ(diagStr, expectedDiag); +} + } // namespace diff --git a/mlir/unittests/TableGen/OpBuildGen.cpp b/mlir/unittests/TableGen/OpBuildGen.cpp index 74e62aa..0943033 100644 --- a/mlir/unittests/TableGen/OpBuildGen.cpp +++ b/mlir/unittests/TableGen/OpBuildGen.cpp @@ -36,12 +36,11 @@ protected: OpBuildGenTest() : ctx(getContext()), builder(&ctx), loc(builder.getUnknownLoc()), i32Ty(builder.getI32Type()), f32Ty(builder.getF32Type()), - cstI32(builder.create<test::TableGenConstant>(loc, i32Ty)), - cstF32(builder.create<test::TableGenConstant>(loc, f32Ty)), - noAttrs(), attrStorage{builder.getNamedAttr("attr0", - builder.getBoolAttr(true)), - builder.getNamedAttr( - "attr1", builder.getI32IntegerAttr(33))}, + cstI32(test::TableGenConstant::create(builder, loc, i32Ty)), + cstF32(test::TableGenConstant::create(builder, loc, f32Ty)), noAttrs(), + attrStorage{ + builder.getNamedAttr("attr0", builder.getBoolAttr(true)), + builder.getNamedAttr("attr1", builder.getI32IntegerAttr(33))}, attrs(attrStorage) {} // Verify that `op` has the given set of result types, operands, and @@ -123,21 +122,21 @@ protected: /// Test basic build methods. TEST_F(OpBuildGenTest, BasicBuildMethods) { // Test separate args, separate results build method. - auto op = builder.create<test::TableGenBuildOp0>(loc, i32Ty, *cstI32); + auto op = test::TableGenBuildOp0::create(builder, loc, i32Ty, *cstI32); verifyOp(op, {i32Ty}, {*cstI32}, noAttrs); // Test separate args, collective results build method. - op = builder.create<test::TableGenBuildOp0>(loc, TypeRange{i32Ty}, *cstI32); + op = test::TableGenBuildOp0::create(builder, loc, TypeRange{i32Ty}, *cstI32); verifyOp(op, {i32Ty}, {*cstI32}, noAttrs); // Test collective args, collective params build method. - op = builder.create<test::TableGenBuildOp0>(loc, TypeRange{i32Ty}, - ValueRange{*cstI32}); + op = test::TableGenBuildOp0::create(builder, loc, TypeRange{i32Ty}, + ValueRange{*cstI32}); verifyOp(op, {i32Ty}, {*cstI32}, noAttrs); // Test collective args, collective results, non-empty attributes - op = builder.create<test::TableGenBuildOp0>(loc, TypeRange{i32Ty}, - ValueRange{*cstI32}, attrs); + op = test::TableGenBuildOp0::create(builder, loc, TypeRange{i32Ty}, + ValueRange{*cstI32}, attrs); verifyOp(op, {i32Ty}, {*cstI32}, attrs); } @@ -154,25 +153,25 @@ TEST_F(OpBuildGenTest, BasicBuildMethods) { /// variadic result. TEST_F(OpBuildGenTest, BuildMethodsSingleVariadicArgAndResult) { // Test collective args, collective results method, building a unary op. - auto op = builder.create<test::TableGenBuildOp1>(loc, TypeRange{i32Ty}, - ValueRange{*cstI32}); + auto op = test::TableGenBuildOp1::create(builder, loc, TypeRange{i32Ty}, + ValueRange{*cstI32}); verifyOp(op, {i32Ty}, {*cstI32}, noAttrs); // Test collective args, collective results method, building a unary op with // named attributes. - op = builder.create<test::TableGenBuildOp1>(loc, TypeRange{i32Ty}, - ValueRange{*cstI32}, attrs); + op = test::TableGenBuildOp1::create(builder, loc, TypeRange{i32Ty}, + ValueRange{*cstI32}, attrs); verifyOp(op, {i32Ty}, {*cstI32}, attrs); // Test collective args, collective results method, building a binary op. - op = builder.create<test::TableGenBuildOp1>(loc, TypeRange{i32Ty, f32Ty}, - ValueRange{*cstI32, *cstF32}); + op = test::TableGenBuildOp1::create(builder, loc, TypeRange{i32Ty, f32Ty}, + ValueRange{*cstI32, *cstF32}); verifyOp(op, {i32Ty, f32Ty}, {*cstI32, *cstF32}, noAttrs); // Test collective args, collective results method, building a binary op with // named attributes. - op = builder.create<test::TableGenBuildOp1>( - loc, TypeRange{i32Ty, f32Ty}, ValueRange{*cstI32, *cstF32}, attrs); + op = test::TableGenBuildOp1::create(builder, loc, TypeRange{i32Ty, f32Ty}, + ValueRange{*cstI32, *cstF32}, attrs); verifyOp(op, {i32Ty, f32Ty}, {*cstI32, *cstF32}, attrs); } @@ -181,22 +180,22 @@ TEST_F(OpBuildGenTest, BuildMethodsSingleVariadicArgAndResult) { TEST_F(OpBuildGenTest, BuildMethodsSingleVariadicArgNonVariadicResults) { // Test separate arg, separate param build method. auto op = - builder.create<test::TableGenBuildOp1>(loc, i32Ty, ValueRange{*cstI32}); + test::TableGenBuildOp1::create(builder, loc, i32Ty, ValueRange{*cstI32}); verifyOp(op, {i32Ty}, {*cstI32}, noAttrs); // Test collective params build method, no attributes. - op = builder.create<test::TableGenBuildOp1>(loc, TypeRange{i32Ty}, - ValueRange{*cstI32}); + op = test::TableGenBuildOp1::create(builder, loc, TypeRange{i32Ty}, + ValueRange{*cstI32}); verifyOp(op, {i32Ty}, {*cstI32}, noAttrs); // Test collective params build method no attributes, 2 inputs. - op = builder.create<test::TableGenBuildOp1>(loc, TypeRange{i32Ty}, - ValueRange{*cstI32, *cstF32}); + op = test::TableGenBuildOp1::create(builder, loc, TypeRange{i32Ty}, + ValueRange{*cstI32, *cstF32}); verifyOp(op, {i32Ty}, {*cstI32, *cstF32}, noAttrs); // Test collective params build method, non-empty attributes. - op = builder.create<test::TableGenBuildOp1>( - loc, TypeRange{i32Ty}, ValueRange{*cstI32, *cstF32}, attrs); + op = test::TableGenBuildOp1::create(builder, loc, TypeRange{i32Ty}, + ValueRange{*cstI32, *cstF32}, attrs); verifyOp(op, {i32Ty}, {*cstI32, *cstF32}, attrs); } @@ -205,18 +204,18 @@ TEST_F(OpBuildGenTest, BuildMethodsSingleVariadicArgNonVariadicResults) { TEST_F(OpBuildGenTest, BuildMethodsSingleVariadicArgAndMultipleVariadicResults) { // Test separate arg, separate param build method. - auto op = builder.create<test::TableGenBuildOp3>( - loc, TypeRange{i32Ty}, TypeRange{f32Ty}, ValueRange{*cstI32}); + auto op = test::TableGenBuildOp3::create( + builder, loc, TypeRange{i32Ty}, TypeRange{f32Ty}, ValueRange{*cstI32}); verifyOp(op, {i32Ty, f32Ty}, {*cstI32}, noAttrs); // Test collective params build method, no attributes. - op = builder.create<test::TableGenBuildOp3>(loc, TypeRange{i32Ty, f32Ty}, - ValueRange{*cstI32}); + op = test::TableGenBuildOp3::create(builder, loc, TypeRange{i32Ty, f32Ty}, + ValueRange{*cstI32}); verifyOp(op, {i32Ty, f32Ty}, {*cstI32}, noAttrs); // Test collective params build method, with attributes. - op = builder.create<test::TableGenBuildOp3>(loc, TypeRange{i32Ty, f32Ty}, - ValueRange{*cstI32}, attrs); + op = test::TableGenBuildOp3::create(builder, loc, TypeRange{i32Ty, f32Ty}, + ValueRange{*cstI32}, attrs); verifyOp(op, {i32Ty, f32Ty}, {*cstI32}, attrs); } @@ -227,29 +226,29 @@ TEST_F(OpBuildGenTest, // build methods with no result types as they are inferred from the input types. TEST_F(OpBuildGenTest, BuildMethodsSameOperandsAndResultTypeSuppression) { // Test separate arg, separate param build method. - auto op = builder.create<test::TableGenBuildOp4>( - loc, i32Ty, ValueRange{*cstI32, *cstI32}); + auto op = test::TableGenBuildOp4::create(builder, loc, i32Ty, + ValueRange{*cstI32, *cstI32}); verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs); // Test collective params build method. - op = builder.create<test::TableGenBuildOp4>(loc, TypeRange{i32Ty}, - ValueRange{*cstI32, *cstI32}); + op = test::TableGenBuildOp4::create(builder, loc, TypeRange{i32Ty}, + ValueRange{*cstI32, *cstI32}); verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs); // Test build method with no result types, default value of attributes. - op = - builder.create<test::TableGenBuildOp4>(loc, ValueRange{*cstI32, *cstI32}); + op = test::TableGenBuildOp4::create(builder, loc, + ValueRange{*cstI32, *cstI32}); verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs); // Test build method with no result types and supplied attributes. - op = builder.create<test::TableGenBuildOp4>(loc, ValueRange{*cstI32, *cstI32}, - attrs); + op = test::TableGenBuildOp4::create(builder, loc, + ValueRange{*cstI32, *cstI32}, attrs); verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, attrs); } TEST_F(OpBuildGenTest, BuildMethodsRegionsAndInferredType) { - auto op = builder.create<test::TableGenBuildOp5>( - loc, ValueRange{*cstI32, *cstF32}, /*attributes=*/noAttrs); + auto op = test::TableGenBuildOp5::create( + builder, loc, ValueRange{*cstI32, *cstF32}, /*attributes=*/noAttrs); ASSERT_EQ(op->getNumRegions(), 1u); verifyOp(op, {i32Ty}, {*cstI32, *cstF32}, noAttrs); } @@ -266,28 +265,28 @@ TEST_F(OpBuildGenTest, BuildMethodsVariadicProperties) { ArrayRef<NamedAttribute> attrs(attrsStorage); // Test separate arg, separate param build method. - auto op = builder.create<test::TableGenBuildOp6>( - loc, f32Ty, ValueRange{*cstI32}, ValueRange{*cstI32}); + auto op = test::TableGenBuildOp6::create( + builder, loc, f32Ty, ValueRange{*cstI32}, ValueRange{*cstI32}); verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, noAttrs); // Test build method with no result types, default value of attributes. - op = builder.create<test::TableGenBuildOp6>(loc, ValueRange{*cstI32}, - ValueRange{*cstI32}); + op = test::TableGenBuildOp6::create(builder, loc, ValueRange{*cstI32}, + ValueRange{*cstI32}); verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, noAttrs); // Test collective params build method. - op = builder.create<test::TableGenBuildOp6>( - loc, TypeRange{f32Ty}, ValueRange{*cstI32}, ValueRange{*cstI32}); + op = test::TableGenBuildOp6::create(builder, loc, TypeRange{f32Ty}, + ValueRange{*cstI32}, ValueRange{*cstI32}); verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, noAttrs); // Test build method with result types, supplied attributes. - op = builder.create<test::TableGenBuildOp6>( - loc, TypeRange{f32Ty}, ValueRange{*cstI32, *cstI32}, attrs); + op = test::TableGenBuildOp6::create(builder, loc, TypeRange{f32Ty}, + ValueRange{*cstI32, *cstI32}, attrs); verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, attrs); // Test build method with no result types and supplied attributes. - op = builder.create<test::TableGenBuildOp6>(loc, ValueRange{*cstI32, *cstI32}, - attrs); + op = test::TableGenBuildOp6::create(builder, loc, + ValueRange{*cstI32, *cstI32}, attrs); verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, attrs); } @@ -295,14 +294,14 @@ TEST_F(OpBuildGenTest, BuildMethodsInherentDiscardableAttrs) { test::TableGenBuildOp7::Properties props; props.attr0 = cast<BoolAttr>(attrs[0].getValue()); ArrayRef<NamedAttribute> discardableAttrs = attrs.drop_front(); - auto op7 = builder.create<test::TableGenBuildOp7>( - loc, TypeRange{}, ValueRange{}, props, discardableAttrs); + auto op7 = test::TableGenBuildOp7::create( + builder, loc, TypeRange{}, ValueRange{}, props, discardableAttrs); verifyOp(op7, {}, {}, attrs); // Check that the old-style builder where all the attributes go in the same // place works. - auto op7b = builder.create<test::TableGenBuildOp7>(loc, TypeRange{}, - ValueRange{}, attrs); + auto op7b = test::TableGenBuildOp7::create(builder, loc, TypeRange{}, + ValueRange{}, attrs); // Note: this goes before verifyOp() because verifyOp() calls erase(), causing // use-after-free. ASSERT_EQ(op7b.getProperties().getAttr0(), attrs[0].getValue()); diff --git a/mlir/unittests/Target/LLVM/CMakeLists.txt b/mlir/unittests/Target/LLVM/CMakeLists.txt index 15835b9..0a77deb 100644 --- a/mlir/unittests/Target/LLVM/CMakeLists.txt +++ b/mlir/unittests/Target/LLVM/CMakeLists.txt @@ -1,13 +1,11 @@ -set(LLVM_LINK_COMPONENTS nativecodegen) - -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +set(LLVM_LINK_COMPONENTS nativecodegen BitReader) add_mlir_unittest(MLIRTargetLLVMTests SerializeNVVMTarget.cpp SerializeROCDLTarget.cpp SerializeToLLVMBitcode.cpp DEPENDS - ${dialect_libs} + MLIRRegisterAllDialects ) mlir_target_link_libraries(MLIRTargetLLVMTests |